diff --git a/Analysis/plot_traj.py b/Analysis/plot_traj.py index 927ea1d536c0aa4bc35e742f857afe2d596b67cf..e93e18b666bc5328b54a61c037d1caa3bd4c3083 100755 --- a/Analysis/plot_traj.py +++ b/Analysis/plot_traj.py @@ -8,6 +8,9 @@ directory for interpolated eddies. import matplotlib.pyplot as plt import networkx as nx import itertools +import numpy as np +from matplotlib import patches +import cartopy.crs as ccrs color_iter = itertools.cycle(('#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', @@ -31,25 +34,31 @@ def is_node_type(G, n, label): elif isinstance(label, int): return n[0] == label -def plot_nbunch(G, nbunch, color = None, label = None, reset = True): - nx.draw_networkx(G, G.nodes.data("coordinates"), - edgelist = G.edges(nbunch), nodelist = nbunch, - edge_color=color, node_size=10, node_color=color, - with_labels=False) +def plot_nbunch(G, nbunch, color = '#1f78b4', label = None, ax = None): + if ax is None: ax = plt.gca() + pos = G.nodes.data("coordinates") + xy = np.asarray([pos[n] for n in nbunch]) + src_crs = ccrs.PlateCarree() + ax.scatter(xy[:, 0], xy[:, 1], s = 10, c = color, marker='o', + transform = src_crs) + + for e in G.edges(nbunch): + arrow = patches.FancyArrowPatch(pos[e[0]], pos[e[1]], + arrowstyle = '-|>', color = color, + mutation_scale = 10, + transform = src_crs) + ax.add_patch(arrow) if label is not None: for n in nbunch: if is_node_type(G, n, label): - plt.annotate(str(n), G.nodes[n]["coordinates"], color = color, - xytext = (2, 2), textcoords = 'offset points') - - if reset: plt.tick_params(reset = True) + xy = ccrs.PlateCarree().transform_point(*pos[n], src_crs) + ax.annotate(str(n), xy, color = color, xytext = (2, 2), + textcoords = 'offset points') def plot_all_traj(G, label): for component, color in zip(nx.weakly_connected_components(G), color_iter): - plot_nbunch(G, component, color, label, reset = False) - - plt.tick_params(reset = True) + plot_nbunch(G, component, color, label) def plot_descendants(G, n, label): nbunch = nx.descendants(G, n) | {n} @@ -91,6 +100,7 @@ if __name__ == "__main__": print("Reading edgelist.csv in current directory...") G = report_graph.read_eddy_graph("edgelist.csv", args.shp_tr_dir) plt.figure() + ax = plt.axes(projection=ccrs.PlateCarree()) if args.window is not None: if args.window[2] - args.window[0] > 360: @@ -111,4 +121,6 @@ if __name__ == "__main__": else: plot_all_traj(G, args.label) + ax.coastlines() + ax.gridlines(draw_labels=True) plt.show()