Skip to content
Snippets Groups Projects
plot_components.py 8.11 KiB
#!/usr/bin/env python3

import itertools
import sys

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib import patches
import cartopy.crs as ccrs

import report_graph
import util_eddies

# Endless cycle over 20 hex colors, consumed one color per plotted
# component. The iterator is module-level state: it is never reset, so
# successive plotting calls continue where the previous one stopped.
# NOTE(review): the values look like matplotlib's "tab20" palette --
# confirm before relying on that.
color_iter = itertools.cycle(
    (
        "#1f77b4", "#aec7e8",
        "#ff7f0e", "#ffbb78",
        "#2ca02c", "#98df8a",
        "#d62728", "#ff9896",
        "#9467bd", "#c5b0d5",
        "#8c564b", "#c49c94",
        "#e377c2", "#f7b6d2",
        "#7f7f7f", "#c7c7c7",
        "#bcbd22", "#dbdb8d",
        "#17becf", "#9edae5",
    )
)

def assign_all_components(G):
    """Compute the weakly connected components of G and tag its nodes.

    The components produced by nx.weakly_connected_components are
    stored as a list in G.graph["component_list"]. Each node of G also
    receives a "component" attribute that refers to the very same
    component object as the one held in that list.
    """
    components = list(nx.weakly_connected_components(G))
    G.graph["component_list"] = components

    for members in components:
        for node in members:
            G.nodes[node]["component"] = members

def is_node_type(G, n, label):
    """Tell whether node n of graph G matches the criterion label.

    label is one of:

    - "root": n has no incoming edge;
    - "leaf": n has no outgoing edge;
    - "merge": n has at least two incoming edges;
    - "split": n has at least two outgoing edges;
    - "all": always true;
    - "important": n is a root, a leaf, or has total degree >= 3;
    - ("date", d): the date that util_eddies.node_to_date_eddy
      returns for n (using G.graph["e_overestim"]) equals d;
    - ("node", nodes): n is contained in nodes.

    Any other label, including one that cannot be indexed (an integer,
    None, an empty string...), terminates the program through sys.exit
    with a "bad label" message.
    """
    if label == "root":
        return G.in_degree[n] == 0
    elif label == "leaf":
        return G.out_degree[n] == 0
    elif label == "merge":
        return G.in_degree[n] >= 2
    elif label == "split":
        return G.out_degree[n] >= 2
    elif label == "all":
        return True
    elif label == "important":
        return G.in_degree[n] == 0 or G.out_degree[n] == 0 or G.degree[n] >= 3

    # The remaining valid labels are (kind, value) pairs. Guard the
    # indexing so that an unusable label is reported by the sys.exit
    # below instead of surfacing as a bare TypeError or IndexError.
    try:
        kind = label[0]
    except (TypeError, IndexError, KeyError):
        kind = None

    if kind == "date":
        return util_eddies.node_to_date_eddy(n, G.graph["e_overestim"],
                                             only_date = True) == label[1]
    elif kind == "node":
        return n in label[1]

    sys.exit(f"is_node_type: bad label: {label}")

def plot_nbunch(G, nbunch, color = '#1f78b4', label = None, ax = None):
    """Plot the nodes of nbunch, and the edges G.edges reports for them.

    G      -- graph whose nodes may carry a "coordinates" attribute,
              used as (x, y) in the PlateCarree coordinate system.
    nbunch -- iterable of nodes of G. It is not modified. Nodes without
              coordinates are reported on standard output and skipped.
    color  -- color of markers, arrows and annotations.
    label  -- None for no annotation, else a criterion understood by
              is_node_type: matching nodes are annotated with str(node).
    ax     -- axes to draw on, default plt.gca(). Drawing uses cartopy
              transforms and annotating reads ax.projection, so ax is
              expected to be a cartopy map axes.

    If no node of nbunch has coordinates (in particular when nbunch is
    empty), nothing is drawn.
    """
    if ax is None: ax = plt.gca()
    pos = G.nodes.data("coordinates")
    nbunch_plot = []

    # Build a fresh list rather than copying nbunch and removing from
    # it: this accepts any iterable and avoids repeated list.remove.
    for n in nbunch:
        if pos[n] is None:
            print("plot_nbunch: missing coordinates for node ", n)
        else:
            nbunch_plot.append(n)

    # An empty coordinate array is one-dimensional and could not be
    # sliced column-wise below; there is nothing to draw anyway.
    if not nbunch_plot: return

    xy = np.asarray([pos[n] for n in nbunch_plot])
    src_crs = ccrs.PlateCarree()
    ax.scatter(xy[:, 0], xy[:, 1], s = 10, c = color, marker='o',
               transform = src_crs)

    for e in G.edges(nbunch_plot):
        # The far end of an edge may lie outside nbunch and may lack
        # coordinates. Test against None explicitly, consistently with
        # the node filter above.
        if pos[e[0]] is not None and pos[e[1]] is not None:
            arrow = patches.FancyArrowPatch(pos[e[0]], pos[e[1]],
                                            arrowstyle = '-|>', color = color,
                                            mutation_scale = 10,
                                            transform = src_crs)
            ax.add_patch(arrow)

    if label is not None:
        for n in nbunch_plot:
            if is_node_type(G, n, label):
                # annotate is given the position converted to the
                # projection of the axes.
                xy_text = ax.projection.transform_point(*pos[n], src_crs)
                ax.annotate(str(n), xy_text, color = color, xytext = (2, 2),
                            textcoords = 'offset points',
                            backgroundcolor = "white", fontsize = "xx-small")

def plot_all_components(G, label):
    """Plot every component of G.graph["component_list"] in its own color.

    One color is drawn from the module-level color_iter cycle per
    component. label is forwarded unchanged to plot_nbunch.
    """
    for members in G.graph["component_list"]:
        plot_nbunch(G, members, color = next(color_iter), label = label)
    
def plot_descendants(G, n, label):
    """Plot node n together with all its descendants in G.

    The descendants are those given by nx.descendants(G, n). label is
    forwarded to plot_nbunch, which uses its default color.
    """
    below = nx.descendants(G, n)
    nbunch = below.union({n})
    plot_nbunch(G, nbunch, label = label)

def animate_nbunch(G, nbunch):
    """Step through the nodes of nbunch date by date, one frame per date.

    Each frame clears the current figure, then plots every node of
    nbunch up to and including the current date, annotating the nodes
    of the current date. The function waits for a key or button press
    before moving on to the next date present in nbunch. Nothing is
    drawn when nbunch is empty.

    The date of a node is obtained with util_eddies.node_to_date_eddy
    and G.graph["e_overestim"], i.e. in the same way as is_node_type
    evaluates a ("date", d) label, so frames and annotations agree.
    NOTE(review): this assumes the nodes of nbunch are in the format
    accepted by util_eddies.node_to_date_eddy -- confirm.
    """
    def date_of(n):
        return util_eddies.node_to_date_eddy(n, G.graph["e_overestim"],
                                             only_date = True)

    dated_nodes = sorted((date_of(n), n) for n in nbunch)
    shown = []

    # Grouping by the dates actually present keeps the frame label in
    # step with the data even when some dates are missing.
    for date, same_date in itertools.groupby(dated_nodes,
                                             key = lambda pair: pair[0]):
        shown.extend(n for _, n in same_date)
        plt.clf()
        # plt.clf() removed the previous axes. plot_nbunch needs axes
        # with a projection, so create map axes explicitly.
        ax = plt.axes(projection = ccrs.PlateCarree())
        # is_node_type expects a ("date", d) pair, not a bare date.
        plot_nbunch(G, shown, label = ("date", date), ax = ax)
        plt.waitforbuttonpress()
    
if __name__ == "__main__":
    # Command-line entry point: read an eddy graph, select what to
    # plot (components touching a window, the component of one node,
    # or every component), then show the map or save it to a file.
    import argparse
    import time

    parser = argparse.ArgumentParser(description = __doc__)
    parser.add_argument("edgelist", help = "path to input CSV file")
    parser.add_argument("shpc_dir", help = "directory containing SHPC, with "
                        "visible eddies at all dates")
    parser.add_argument("orientation", choices = ["Anticyclones", "Cyclones"])

    # Label group:
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-l", "--label_type",
                        choices = ["root", "leaf", "split", "merge", "all",
                                   "important"])
    group.add_argument("--label_date", type = int, metavar = "DATE_INDEX")
    group.add_argument("--label_node", metavar = 'NODE', type = int,
                       nargs = "+")

    parser.add_argument("-s", "--save", metavar = "FORMAT",
                        help = "Save file to specified format")
    parser.add_argument("-t", "--time", action = "store_true",
                        help = "Report elapsed time")

    # Selection group:
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-n", "--node", help = "Select component containing "
                       "node", type = int)
    group.add_argument("-w", "--window", help = "choose a limited plot window",
                       type = float, nargs = 4,
                       metavar = ("LLLON", "LLLAT", "URLON", "URLAT"))

    args = parser.parse_args()

    # Translate the label options into the criterion expected by
    # is_node_type. Compare with None rather than testing truthiness,
    # so that "--label_date 0" is honoured instead of being ignored.
    if args.label_type is not None:
        label = args.label_type
    elif args.label_date is not None:
        label = ("date", args.label_date)
    elif args.label_node is not None:
        label = ("node", args.label_node)
    else:
        label = None

    plt.figure()

    if args.window is not None:
        # Plot the components that have at least one node in the window.
        if args.window[2] - args.window[0] > 360:
            sys.exit("bad values of urlon and lllon")

        if args.time:
            print("Reading edge list and SHPC...")
            t0 = time.perf_counter()

        G = report_graph.read_eddy_graph(args.edgelist, args.shpc_dir,
                                         args.orientation)

        if args.time:
            t1 = time.perf_counter()
            print("Elapsed time:", t1 - t0, "s")
            t0 = t1
            print("Finding components...")

        # NOTE(review): report_graph.add_component presumably sets the
        # "component" attribute of the nodes it reaches and maintains
        # G.graph["component_list"], which plot_all_components reads
        # below -- confirm in report_graph.
        for n, d in G.nodes.items():
            if util_eddies.in_window(d["coordinates"], args.window):
                if "component" not in d: report_graph.add_component(G, n)

        if args.time:
            t1 = time.perf_counter()
            print("Elapsed time:", t1 - t0, "s")
            t0 = t1
            print("Plotting...")

        # Center the map on the middle longitude of the window.
        dest_crs = ccrs.PlateCarree((args.window[0] + args.window[2]) / 2)
        ax = plt.axes(projection = dest_crs)
        plot_all_components(G, label)
    elif args.node is not None:
        # Plot only the component that contains the requested node.
        if args.time:
            print("Reading edge list in current directory...")
            t0 = time.perf_counter()

        G = report_graph.read_eddy_graph(args.edgelist, args.shpc_dir,
                                         args.orientation)

        if args.time:
            t1 = time.perf_counter()
            print("Elapsed time:", t1 - t0, "s")
            t0 = t1
            print("Finding component...")

        report_graph.add_component(G, args.node)

        if args.time:
            t1 = time.perf_counter()
            print("Elapsed time:", t1 - t0, "s")
            print("Plotting...")

        # Center the map on the first coordinate of the requested node.
        dest_crs = ccrs.PlateCarree(G.nodes[args.node]["coordinates"][0])
        ax = plt.axes(projection = dest_crs)
        plot_nbunch(G, G.nodes[args.node]["component"], label = label)
    else:
        # No selection: plot every component of the graph.
        if args.time:
            print("Reading edge lists in current directory...")
            t0 = time.perf_counter()

        G = report_graph.read_eddy_graph(args.edgelist, args.shpc_dir,
                                         args.orientation)

        if args.time:
            t1 = time.perf_counter()
            print("Elapsed time:", t1 - t0, "s")
            t0 = t1
            print("Finding components...")

        assign_all_components(G)

        if args.time:
            t1 = time.perf_counter()
            print("Elapsed time:", t1 - t0, "s")
            print("Plotting...")

        ax = plt.axes(projection = ccrs.PlateCarree())
        plot_all_components(G, label)

    ax.coastlines()
    ax.gridlines(draw_labels = True)

    # In every branch, t1 was last set just before "Plotting...", so
    # this reports the plotting time.
    if args.time: print("Elapsed time:", time.perf_counter() - t1, "s")

    if args.save:
        plt.savefig(f"plot_traj.{args.save}")
        print(f'Created "plot_traj.{args.save}".')
    else:
        plt.show()