# plot_components.py (8.11 KiB)
# Authored by Lionel GUEZ.
#!/usr/bin/env python3
import itertools
import sys
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib import patches
import cartopy.crs as ccrs
import report_graph
import util_eddies
# Shared, endless colour supply for successive components. The twenty
# values are the dark/light pairs of matplotlib's "tab20" palette. Being
# a single module-level iterator, it continues where the previous
# consumer stopped rather than restarting at blue.
color_iter = itertools.cycle((
    '#1f77b4', '#aec7e8',  # blue
    '#ff7f0e', '#ffbb78',  # orange
    '#2ca02c', '#98df8a',  # green
    '#d62728', '#ff9896',  # red
    '#9467bd', '#c5b0d5',  # purple
    '#8c564b', '#c49c94',  # brown
    '#e377c2', '#f7b6d2',  # pink
    '#7f7f7f', '#c7c7c7',  # grey
    '#bcbd22', '#dbdb8d',  # olive
    '#17becf', '#9edae5',  # cyan
))
def assign_all_components(G):
    """Compute the weakly connected components of G and attach them to G.

    The list of components (sets of nodes) is stored in
    G.graph["component_list"]. Every node receives a "component" attribute
    that refers to the very set object it belongs to, so all nodes of one
    component share a single set.
    """
    components = list(nx.weakly_connected_components(G))
    G.graph["component_list"] = components
    membership = {node: comp for comp in components for node in comp}
    nx.set_node_attributes(G, membership, "component")
def is_node_type(G, n, label):
    """Tell whether node n of graph G matches the selection label.

    label is either one of the strings:
      "root"      -- in-degree 0,
      "leaf"      -- out-degree 0,
      "merge"     -- in-degree at least 2,
      "split"     -- out-degree at least 2,
      "all"       -- every node,
      "important" -- root, leaf, or total degree at least 3,
    or a pair (tag, value):
      ("date", d) -- nodes whose date, decoded by
                     util_eddies.node_to_date_eddy with
                     G.graph["e_overestim"], equals d,
      ("node", s) -- nodes contained in the collection s.

    Any other label (unknown string, empty string, unknown tag, wrong
    type) terminates the program through sys.exit with an error message.
    """
    if isinstance(label, str):
        if label == "root":
            return G.in_degree[n] == 0
        if label == "leaf":
            return G.out_degree[n] == 0
        if label == "merge":
            return G.in_degree[n] >= 2
        if label == "split":
            return G.out_degree[n] >= 2
        if label == "all":
            return True
        if label == "important":
            return (G.in_degree[n] == 0 or G.out_degree[n] == 0
                    or G.degree[n] >= 3)
    elif isinstance(label, (tuple, list)) and len(label) == 2:
        tag, value = label
        if tag == "date":
            return util_eddies.node_to_date_eddy(n, G.graph["e_overestim"],
                                                 only_date = True) == value
        if tag == "node":
            return n in value

    # Shape is checked before any indexing, so that every malformed label
    # ends here instead of raising IndexError/TypeError on label[0].
    sys.exit(f"is_node_type: bad label: {label}")
def plot_nbunch(G, nbunch, color = '#1f78b4', label = None, ax = None):
    """Draw the nodes of nbunch, and the edges G.edges(nbunch), on a map.

    G: graph whose nodes carry a "coordinates" attribute, read as
        (longitude, latitude) in degrees through the PlateCarree source CRS.
    nbunch: iterable of nodes of G. Nodes without coordinates are reported
        on stdout and skipped; if none is left, nothing is drawn.
    color: matplotlib colour for markers, arrows and annotations.
    label: None, or a selection accepted by is_node_type; matching nodes
        are annotated with their identifier.
    ax: axes to draw into, default the current axes. It must be a cartopy
        GeoAxes, since cartopy transforms and ax.projection are used.
    """
    if ax is None:
        ax = plt.gca()

    pos = G.nodes.data("coordinates")

    # Keep the nodes that can be placed on the map; report the others.
    nbunch_plot = []
    for n in nbunch:
        if pos[n] is None:
            print("plot_nbunch: missing coordinates for node ", n)
        else:
            nbunch_plot.append(n)

    # With no node left, np.asarray would give a 1-D empty array and
    # xy[:, 0] would raise IndexError.
    if not nbunch_plot:
        return

    xy = np.asarray([pos[n] for n in nbunch_plot])
    src_crs = ccrs.PlateCarree()
    ax.scatter(xy[:, 0], xy[:, 1], s = 10, c = color, marker = 'o',
               transform = src_crs)

    for u, v in G.edges(nbunch_plot):
        # u comes from nbunch_plot, so only the head v (possibly outside
        # nbunch) may lack coordinates. Compare with None, as above:
        # truthiness is ambiguous if coordinates are arrays.
        if pos[v] is not None:
            arrow = patches.FancyArrowPatch(pos[u], pos[v],
                                            arrowstyle = '-|>', color = color,
                                            mutation_scale = 10,
                                            transform = src_crs)
            ax.add_patch(arrow)

    if label is not None:
        for n in nbunch_plot:
            if is_node_type(G, n, label):
                # annotate takes data coordinates, i.e. the axes'
                # projection, so convert from longitude-latitude first.
                xy_proj = ax.projection.transform_point(*pos[n], src_crs)
                ax.annotate(str(n), xy_proj, color = color, xytext = (2, 2),
                            textcoords = 'offset points',
                            backgroundcolor = "white", fontsize = "xx-small")
def plot_all_components(G, label):
    """Plot each component of G.graph["component_list"] in its own colour.

    Colours are drawn one per component from the shared module-level
    color_iter, so consecutive calls keep advancing through the palette.
    label is forwarded to plot_nbunch.
    """
    for component in G.graph["component_list"]:
        plot_nbunch(G, component, next(color_iter), label)
def plot_descendants(G, n, label):
    """Plot node n together with every node reachable from it in G."""
    reachable = set(nx.descendants(G, n))
    reachable.add(n)
    plot_nbunch(G, reachable, label = label)
def animate_nbunch(G, nbunch):
    """Plot the nodes of nbunch cumulatively, one date after another.

    Each node's date is decoded with util_eddies.node_to_date_eddy and
    G.graph["e_overestim"], as in is_node_type. For each date present in
    nbunch, in increasing order, the figure is cleared and all nodes up to
    and including that date are drawn, the nodes of that date being
    labelled. The function then waits for a key or mouse-button press
    before showing the next date. Dates with no node are skipped; an empty
    nbunch draws nothing.
    """
    def node_date(n):
        return util_eddies.node_to_date_eddy(n, G.graph["e_overestim"],
                                             only_date = True)

    # Nodes are integers, so dates must be decoded rather than read as
    # n[0]. Sort by (date, node) so that groupby sees one run per date.
    dated = sorted((node_date(n), n) for n in nbunch)
    shown = []

    for date, pairs in itertools.groupby(dated, key = lambda pair: pair[0]):
        shown.extend(n for _, n in pairs)
        plt.clf()
        # plot_nbunch relies on cartopy transforms, so a GeoAxes must be
        # recreated after clf(); plt.gca() would return plain Axes.
        ax = plt.axes(projection = ccrs.PlateCarree())
        plot_nbunch(G, shown, label = ("date", date), ax = ax)
        ax.coastlines()
        plt.waitforbuttonpress()
if __name__ == "__main__":
    # Command-line entry point: read an eddy graph, select what to plot
    # (all components, the component of one node, or components touching a
    # longitude-latitude window), then show or save the map.
    import argparse
    import time

    parser = argparse.ArgumentParser(description = __doc__)
    parser.add_argument("edgelist", help = "path to input CSV file")
    parser.add_argument("shpc_dir", help = "directory containing SHPC, with "
                        "visible eddies at all dates")
    parser.add_argument("orientation", choices = ["Anticyclones", "Cyclones"])

    # Label group: which nodes get their identifier written on the map.
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-l", "--label_type",
                       choices = ["root", "leaf", "split", "merge", "all",
                                  "important"])
    group.add_argument("--label_date", type = int, metavar = "DATE_INDEX")
    group.add_argument("--label_node", metavar = 'NODE', type = int,
                       nargs = "+")

    parser.add_argument("-s", "--save", metavar = "FORMAT",
                        help = "Save file to specified format")
    parser.add_argument("-t", "--time", action = "store_true",
                        help = "Report elapsed time")

    # Selection group: which part of the graph is plotted.
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-n", "--node", help = "Select component containing "
                       "node", type = int)
    group.add_argument("-w", "--window", help = "choose a limited plot window",
                       type = float, nargs = 4,
                       metavar = ("LLLON", "LLLAT", "URLON", "URLAT"))
    args = parser.parse_args()

    # Compare with None: a date index of 0 is a legitimate request and
    # must not be mistaken for "no label".
    if args.label_type is not None:
        label = args.label_type
    elif args.label_date is not None:
        label = ("date", args.label_date)
    elif args.label_node is not None:
        label = ("node", args.label_node)
    else:
        label = None

    plt.figure()

    if args.window is not None and args.window[2] - args.window[0] > 360:
        sys.exit("bad values of urlon and lllon")

    # The graph is read the same way whatever the selection mode; the
    # edge list path comes from the command line.
    if args.time:
        print("Reading edge list and SHPC...")
        t0 = time.perf_counter()

    G = report_graph.read_eddy_graph(args.edgelist, args.shpc_dir,
                                     args.orientation)

    if args.time:
        t1 = time.perf_counter()
        print("Elapsed time:", t1 - t0, "s")
        t0 = t1
        print("Finding component..." if args.node is not None
              else "Finding components...")

    if args.window is not None:
        # add_component is called for every node inside the window that
        # has not yet been given a "component" attribute.
        for n, d in G.nodes.items():
            if (util_eddies.in_window(d["coordinates"], args.window)
                    and "component" not in d):
                report_graph.add_component(G, n)
    elif args.node is not None:
        report_graph.add_component(G, args.node)
    else:
        assign_all_components(G)

    if args.time:
        t1 = time.perf_counter()
        print("Elapsed time:", t1 - t0, "s")
        print("Plotting...")

    if args.window is not None:
        # NOTE(review): plot_all_components reads G.graph["component_list"],
        # which this branch does not fill itself; presumably
        # report_graph.add_component maintains it -- confirm.
        dest_crs = ccrs.PlateCarree((args.window[0] + args.window[2]) / 2)
        ax = plt.axes(projection = dest_crs)
        plot_all_components(G, label)
    elif args.node is not None:
        # Centre the map on the longitude of the selected node.
        dest_crs = ccrs.PlateCarree(G.nodes[args.node]["coordinates"][0])
        ax = plt.axes(projection = dest_crs)
        plot_nbunch(G, G.nodes[args.node]["component"], label = label)
    else:
        ax = plt.axes(projection = ccrs.PlateCarree())
        plot_all_components(G, label)

    ax.coastlines()
    ax.gridlines(draw_labels = True)
    if args.time:
        print("Elapsed time:", time.perf_counter() - t1, "s")

    if args.save:
        plt.savefig(f"plot_traj.{args.save}")
        print(f'Created "plot_traj.{args.save}".')
    else:
        plt.show()