Skip to content
Snippets Groups Projects
plot_traj.py 4.33 KiB
#!/usr/bin/env python3

import random

import numpy as np

import util_eddies

def get_extr_coord(traj, e_overestim, SHPC, orientation):
    """Collect the extremum coordinates of the instantaneous eddies that
    make up a trajectory.

    Each node of traj is converted to a (date index, eddy index) pair,
    the matching shape is read from the "extremum" layer of SHPC, and
    the first point of that shape is kept.

    Returns two lists, x and y, in the same order as traj.

    """

    points = []

    for node in traj:
        date_index, eddy_index = util_eddies.node_to_date_eddy(
            node, e_overestim
        )
        i_slice = SHPC.get_slice(date_index)
        ishape = SHPC.comp_ishape(date_index, eddy_index, i_slice, orientation)
        reader = SHPC.get_reader(i_slice, orientation, layer="extremum")
        points.append(reader.shape(ishape).points[0])

    x = [point[0] for point in points]
    y = [point[1] for point in points]
    return x, y


def plot_single_traj(
    traj, e_overestim, SHPC, orientation, ax, src_crs, annotate_flag, color
):
    """Plot one trajectory on a map axes.

    The extrema of the eddies of traj are joined by a thin line of the
    given color, and the first point is marked with a small green square.
    If annotate_flag is true, the first point is labelled with the node
    number traj[0].

    traj: sequence of node numbers, passed to get_extr_coord.
    ax: axes exposing a ``projection`` attribute (presumably a cartopy
        GeoAxes; confirm against callers).
    src_crs: coordinate reference system of the extremum coordinates,
        used as the ``transform`` of the plotted artists.
    color: Matplotlib color specification for the trajectory line.

    """
    x, y = get_extr_coord(traj, e_overestim, SHPC, orientation)
    # Pass color as a keyword: as a positional argument Matplotlib parses
    # it as a format string, which rejects color specs like RGB tuples.
    ax.plot(x, y, color=color, linewidth=0.5, transform=src_crs)
    ax.plot(
        x[0], y[0], marker="s", markersize=2, color="green", transform=src_crs
    )

    if annotate_flag:
        ax.annotate(
            str(traj[0]),
            # annotate works in the axes' data (projected) coordinates, so
            # convert the starting point from src_crs first.
            ax.projection.transform_point(x[0], y[0], src_crs),
            # Random offset of up to 3 points, presumably so that labels of
            # nearby starting points overlap less.
            xytext=(3 * random.random(), 3 * random.random()),
            textcoords="offset points",
        )


def get_duration(expanded_traj):
    """Return the duration of every trajectory as a NumPy array.

    The duration of a trajectory is the date index of its last node minus
    the date index of its first node. A trajectory whose first and last
    nodes share the same date therefore has duration 0.

    expanded_traj: mapping with keys "traj" (list of node sequences) and
        "e_overestim" (passed to util_eddies.node_to_date_eddy).

    """

    def date_of(node):
        return util_eddies.node_to_date_eddy(
            node, expanded_traj["e_overestim"], only_date=True
        )

    def span(traj):
        first_date = date_of(traj[0])
        last_date = date_of(traj[-1])
        return last_date - first_date

    return np.array([span(traj) for traj in expanded_traj["traj"]])


if __name__ == "__main__":
    import json
    import argparse

    import matplotlib.pyplot as plt
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature

    parser = argparse.ArgumentParser()
    parser.add_argument("expanded_traj", help="JSon file")
    parser.add_argument("SHPC", help="directory")
    parser.add_argument(
        "--save", metavar="format", help="Save file to specified format"
    )
    parser.add_argument(
        "--annotate",
        action="store_true",
        help="annotate the first point of trajectory with node number",
    )
    parser.add_argument(
        "--min_duration",
        type=int,
        default=1,
        help="minimum duration of plotted trajectories (in time steps), >= 1",
    )
    args = parser.parse_args()

    # The help text requires >= 1: reject other values instead of plotting
    # with a threshold the rest of the script does not expect.
    if args.min_duration < 1:
        parser.error("argument --min_duration: must be >= 1")

    with open(args.expanded_traj) as f:
        expanded_traj = json.load(f)

    print("Number of trajectories:", len(expanded_traj["traj"]))
    SHPC = util_eddies.SHPC_class(
        args.SHPC, def_orient=expanded_traj["orientation"]
    )
    src_crs = ccrs.Geodetic()
    projection = ccrs.Mercator(central_longitude=110)
    fig, ax = plt.subplots(subplot_kw={"projection": projection})
    # Fixed seed so that the random annotation offsets are reproducible.
    random.seed(0)

    if args.min_duration == 1:
        # No filtering: every trajectory is plotted and durations are not
        # computed.
        selected_traj = expanded_traj["traj"]
    else:
        # NOTE(review): get_duration returns final date - initial date, so
        # a single-date trajectory has duration 0, while the branch above
        # treats min_duration == 1 as "all trajectories". Confirm whether
        # durations should count dates inclusively (difference + 1).
        duration_array = get_duration(expanded_traj)
        selected_traj = [
            traj
            for traj, duration in zip(expanded_traj["traj"], duration_array)
            if duration >= args.min_duration
        ]
        n_long_traj = len(selected_traj)
        print("Number of trajectories with sufficient duration:", n_long_traj)
        ax.set_title(
            rf"lifetime $\geq$ {args.min_duration} time steps"
            f"\nnumber of trajectories: {n_long_traj}"
        )

    for traj in selected_traj:
        plot_single_traj(
            traj,
            expanded_traj["e_overestim"],
            SHPC,
            expanded_traj["orientation"],
            ax,
            src_crs,
            args.annotate,
            color="red",
        )

    ax.set_title(expanded_traj["orientation"], loc="left")
    ax.add_feature(cfeature.LAND, edgecolor="black")
    ax.gridlines(draw_labels=True)

    if args.save:
        plt.savefig(f"plot_traj.{args.save}")
        print(f'Created "plot_traj.{args.save}".')
    else:
        plt.show()