Skip to content
Snippets Groups Projects
plot_traj.py 5.67 KiB
#!/usr/bin/env python3

"""This script plots on a map the lines linking the extrema of the
eddies along each trajectory, with a small square marker on the first
extremum of each trajectory. All trajectories are drawn in the same color.

"""

import random

import numpy as np

import util_eddies


def get_extr_coord(traj, e_overestim, SHPC, orientation):
    """Collect the coordinates of the extrema of the instantaneous eddies
    making up a trajectory.

    For every node of ``traj``, ``SHPC.comp_ishape_n`` gives the slice
    number and the shape index. The first point of that shape in the
    "extremum" layer is taken as the extremum position.

    Returns a pair of lists ``(x, y)``, with one entry per node, in
    trajectory order.

    """

    extrema = []

    for node in traj:
        i_slice, ishape = SHPC.comp_ishape_n(node, e_overestim, orientation)
        reader = SHPC.get_reader(i_slice, orientation, layer="extremum")
        extrema.append(reader.shape(ishape).points[0])

    x = [point[0] for point in extrema]
    y = [point[1] for point in extrema]
    return x, y


def plot_single_traj(
    traj, e_overestim, SHPC, orientation, ax, src_crs, annotate_flag, color
):
    """Plot one trajectory on a map.

    A line joins the extrema of the successive eddies of the trajectory,
    and a small green square marks the first extremum.

    Parameters:
      traj: sequence of nodes of the trajectory, in order.
      e_overestim, SHPC, orientation: passed on to ``get_extr_coord``
        to locate the extrema.
      ax: cartopy GeoAxes to draw on (``ax.projection`` is used).
      src_crs: coordinate reference system of the extremum coordinates.
      annotate_flag: if true, label the first point with its node number.
      color: any Matplotlib color specification for the line (a color
        name, hex string, RGB(A) tuple, ...).

    An empty trajectory draws nothing.

    """

    if not traj:
        # Nothing to draw; avoids an IndexError on x[0] below.
        return

    x, y = get_extr_coord(traj, e_overestim, SHPC, orientation)

    # Pass the color as a keyword. Positionally it would be parsed as a
    # format string, and a tuple color would be taken as extra data.
    ax.plot(x, y, color=color, linewidth=0.5, transform=src_crs)
    ax.plot(
        x[0], y[0], marker="s", markersize=2, color="green", transform=src_crs
    )

    if annotate_flag:
        # A small random offset keeps labels of trajectories starting at
        # nearby points from being drawn exactly on top of each other.
        ax.annotate(
            str(traj[0]),
            ax.projection.transform_point(x[0], y[0], src_crs),
            xytext=(3 * random.random(), 3 * random.random()),
            textcoords="offset points",
        )


def get_duration(traj_list, e_overestim):
    """Compute the duration of every trajectory in ``traj_list``.

    The duration of a trajectory is the date of its last node minus the
    date of its first node. Dates come from
    ``util_eddies.node_to_date_eddy(..., only_date=True)``.

    Returns a NumPy array with one duration per trajectory, in the order
    of ``traj_list``.

    """

    def node_date(node):
        return util_eddies.node_to_date_eddy(node, e_overestim, only_date=True)

    def span(traj):
        first = node_date(traj[0])
        last = node_date(traj[-1])
        return last - first

    return np.array([span(traj) for traj in traj_list])


if __name__ == "__main__":
    import argparse
    import json
    import sys

    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument("expanded_traj", help="JSon file")
    parser.add_argument("SHPC", help="directory")
    parser.add_argument(
        "--save", metavar="format", help="Save file to specified format"
    )
    parser.add_argument(
        "--annotate",
        action="store_true",
        help="annotate the first point of trajectory with node number",
    )
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--min_duration",
        type=int,
        default=1,
        help="minimum duration of plotted trajectories (in time steps), >= 1",
    )
    group.add_argument(
        "--first_node",
        help="only plot the trajectory beginning with given node",
        type=int,
    )
    group.add_argument(
        "--traj_index",
        help="only plot the trajectory with given index",
        type=int,
    )
    args = parser.parse_args()

    # Enforce the constraint stated in the help text.
    if args.min_duration < 1:
        parser.error("--min_duration must be >= 1")

    # JSON text is UTF-8 (RFC 8259); do not depend on the locale encoding.
    with open(args.expanded_traj, encoding="utf-8") as f:
        expanded_traj = json.load(f)

    traj_list = expanded_traj["traj"]
    e_overestim = expanded_traj["e_overestim"]
    orientation = expanded_traj["orientation"]

    print("Number of trajectories:", len(traj_list))
    SHPC = util_eddies.SHPC_class(args.SHPC, def_orient=orientation)
    src_crs = ccrs.Geodetic()
    projection = ccrs.Mercator(central_longitude=110)
    _, ax = plt.subplots(subplot_kw={"projection": projection})

    # Fixed seed so that the random annotation offsets are reproducible.
    random.seed(0)

    def draw(traj):
        """Plot one trajectory with the settings shared by every mode."""
        plot_single_traj(
            traj,
            e_overestim,
            SHPC,
            orientation,
            ax,
            src_crs,
            args.annotate,
            color="red",
        )

    # Compare with None explicitly: node 0 and index 0 are valid values
    # but falsy, and would otherwise fall through to plotting everything.
    if args.first_node is not None:
        for traj in traj_list:
            if traj[0] == args.first_node:
                draw(traj)
                break
        else:
            sys.exit("No trajectory found with this first node")
    elif args.traj_index is not None:
        try:
            selected = traj_list[args.traj_index]
        except IndexError:
            sys.exit(f"No trajectory with index {args.traj_index}")

        draw(selected)
    elif args.min_duration == 1:
        # Every trajectory qualifies: no need to compute durations.
        for traj in traj_list:
            draw(traj)
    else:
        # args.min_duration > 1
        # NOTE(review): duration is (date of last node - date of first
        # node), so a single-node trajectory has duration 0, yet all
        # trajectories are plotted when min_duration == 1. If "lifetime" in
        # the title means duration + 1, this filter may be off by one.
        # Confirm against util_eddies.node_to_date_eddy.
        n_long_traj = 0
        duration_array = get_duration(traj_list, e_overestim)

        for traj, duration in zip(traj_list, duration_array):
            if duration >= args.min_duration:
                n_long_traj += 1
                draw(traj)

        print(
            "Number of trajectories with sufficient duration:", n_long_traj
        )
        ax.set_title(
            rf"lifetime $\geq$ {args.min_duration} time steps"
            f"\nnumber of trajectories: {n_long_traj}"
        )

    ax.set_title(orientation, loc="left")
    ax.add_feature(cfeature.LAND, edgecolor="black")
    ax.gridlines(draw_labels=True)

    if args.save:
        plt.savefig(f"plot_traj.{args.save}")
        print(f'Created "plot_traj.{args.save}".')
    else:
        plt.show()