Skip to content
Snippets Groups Projects
plot_traj.py 4.12 KiB
#!/usr/bin/env python3

import report_graph
import random
import numpy as np

def get_extr_coord(traj, e_overestim, SHPC, orientation):
    """Return the coordinates of the extremum point of every instantaneous
    eddy visited by a trajectory.

    traj: sequence of graph nodes, in trajectory order.
    e_overestim: forwarded to report_graph.node_to_date_eddy, which decodes
        each node into a (date index, eddy index) pair.
    SHPC: object providing get_slice, comp_ishape and get_reader
        (presumably the shapefile collection of the eddies).
    orientation: eddy orientation, forwarded to the SHPC methods.

    Returns two lists, x and y, with one entry per node: the first and
    second coordinates of the first point of the eddy's "extremum" shape.

    """

    extrema = []

    for node in traj:
        i_date, i_eddy = report_graph.node_to_date_eddy(node, e_overestim)
        i_slice = SHPC.get_slice(i_date)
        ishape = SHPC.comp_ishape(i_date, i_eddy, i_slice, orientation)
        reader = SHPC.get_reader(i_slice, orientation, layer = "extremum")
        extrema.append(reader.shape(ishape).points[0])

    x = [point[0] for point in extrema]
    y = [point[1] for point in extrema]
    return x, y

def plot_single_traj(traj, e_overestim, SHPC, orientation, ax, src_crs,
                     annotate_flag, color):
    """Plot one trajectory on a map.

    Draws the line joining the extrema of the eddies along `traj` and marks
    the first point with a small green square. If `annotate_flag` is true,
    the first point is also labelled with the first node number.

    traj: sequence of graph nodes of the trajectory. An empty trajectory is
        skipped (nothing is drawn).
    e_overestim, SHPC, orientation: forwarded to get_extr_coord.
    ax: axes with a `projection` attribute (cartopy GeoAxes).
    src_crs: coordinate reference system of the extremum coordinates.
    color: any Matplotlib color specification for the trajectory line
        (name, hex string, RGB(A) tuple...).

    """

    if len(traj) == 0:
        # Nothing to draw, and x[0] below would raise IndexError.
        return

    x, y = get_extr_coord(traj, e_overestim, SHPC, orientation)

    # Pass color as a keyword: as a positional format string it only
    # accepted string color specs, and a tuple would be taken as data.
    ax.plot(x, y, color = color, linewidth = 0.5, transform = src_crs)
    ax.plot(x[0], y[0], marker = "s", markersize = 2, color = "green",
            transform = src_crs)

    if annotate_flag:
        # Small random offset (0 to 3 points on each axis), presumably to
        # reduce overlap between labels of nearby starting points.
        ax.annotate(str(traj[0]),
                    ax.projection.transform_point(x[0], y[0], src_crs),
                    xytext = (3 * random.random(), 3 * random.random()),
                    textcoords = "offset points")

def get_duration(expanded_traj, e_overestim):
    """Return, as a NumPy array, the duration of each trajectory.

    The duration is the date index of the last node minus the date index of
    the first node. It counts elapsed time steps, so a trajectory whose
    nodes all lie on a single date has duration 0.

    expanded_traj: iterable of trajectories, each a sequence of graph nodes.
    e_overestim: forwarded to report_graph.node_to_date_eddy.

    """

    def span(traj):
        # Decode the first node, then the last one, into date indices.
        init_date, final_date = (
            report_graph.node_to_date_eddy(node, e_overestim,
                                           only_date = True)
            for node in (traj[0], traj[- 1]))
        return final_date - init_date

    return np.array([span(traj) for traj in expanded_traj])

if __name__ == "__main__":
    import util_eddies
    import matplotlib.pyplot as plt
    import json
    import cartopy.crs as ccrs
    import argparse
    import cartopy.feature as cfeature

    parser = argparse.ArgumentParser()
    parser.add_argument("expanded_traj", help = "JSon file")
    parser.add_argument("e_overestim", help = "text file")
    parser.add_argument("SHPC", help = "directory")
    parser.add_argument("orientation", choices = ["Anticyclones", "Cyclones"])
    parser.add_argument("--save", metavar = "format",
                        help = "Save file to specified format")
    parser.add_argument("--annotate", action = "store_true", help = "annotate "
                        "the first point of trajectory with node number")
    parser.add_argument("--min_duration", type = int, default = 1,
                        help = "minimum duration of plotted trajectories (in "
                        "time steps, first and last dates included), >= 1")
    args = parser.parse_args()

    # The help text documents this constraint; enforce it.
    if args.min_duration < 1:
        parser.error("--min_duration must be >= 1")

    with open(args.expanded_traj) as f: expanded_traj = json.load(f)
    print("Number of trajectories:", len(expanded_traj))
    with open(args.e_overestim) as f: e_overestim = int(f.read())
    SHPC = util_eddies.SHPC_class(args.SHPC, def_orient = args.orientation)
    src_crs = ccrs.Geodetic()
    projection = ccrs.PlateCarree(central_longitude = 110)
    fig, ax = plt.subplots(subplot_kw = {"projection": projection})
    random.seed(0)

    if args.min_duration == 1:
        # Every trajectory spans at least one time step: no filtering, and
        # no need to compute durations.
        selected_traj = expanded_traj
    else:
        # get_duration returns final_date - init_date (0 for a single-date
        # trajectory). --min_duration counts time steps inclusively, which
        # is why the default 1 selects every trajectory. Hence the + 1,
        # which keeps this branch consistent with the one above.
        # NOTE(review): inclusive counting inferred from the default and
        # the ">= 1" help text -- confirm intended semantics.
        duration_array = get_duration(expanded_traj, e_overestim)
        selected_traj = [traj for traj, duration
                         in zip(expanded_traj, duration_array)
                         if duration + 1 >= args.min_duration]
        print("Number of trajectories with sufficient duration:",
              len(selected_traj))

    for traj in selected_traj:
        plot_single_traj(traj, e_overestim, SHPC, args.orientation, ax,
                         src_crs, args.annotate, color = "red")

    ax.add_feature(cfeature.LAND, edgecolor = "black")
    ax.gridlines(draw_labels = True)

    if args.save:
        plt.savefig(f"plot_traj.{args.save}")
        print(f'Created "plot_traj.{args.save}".')
    else:
        plt.show()