From 159f8cdb5547d4cfc454725d4cab1bfc0be467bf Mon Sep 17 00:00:00 2001
From: Lionel GUEZ <guez@lmd.ens.fr>
Date: Wed, 21 Feb 2024 20:25:52 +0100
Subject: [PATCH] Blacken

---
 Overlap/Analysis/plot_components.py | 158 ++++++++++++++++++----------
 1 file changed, 105 insertions(+), 53 deletions(-)

diff --git a/Overlap/Analysis/plot_components.py b/Overlap/Analysis/plot_components.py
index 7cdc1185..ce680efa 100755
--- a/Overlap/Analysis/plot_components.py
+++ b/Overlap/Analysis/plot_components.py
@@ -13,19 +13,39 @@ import report_graph
 import util_eddies
 import extract_component_nx
 
-color_iter = itertools.cycle(('#1f77b4', '#aec7e8', '#ff7f0e',
-                              '#ffbb78', '#2ca02c', '#98df8a',
-                              '#d62728', '#ff9896', '#9467bd',
-                              '#c5b0d5', '#8c564b', '#c49c94',
-                              '#e377c2', '#f7b6d2', '#7f7f7f',
-                              '#c7c7c7', '#bcbd22', '#dbdb8d',
-                              '#17becf', '#9edae5'))
+color_iter = itertools.cycle(
+    (
+        "#1f77b4",
+        "#aec7e8",
+        "#ff7f0e",
+        "#ffbb78",
+        "#2ca02c",
+        "#98df8a",
+        "#d62728",
+        "#ff9896",
+        "#9467bd",
+        "#c5b0d5",
+        "#8c564b",
+        "#c49c94",
+        "#e377c2",
+        "#f7b6d2",
+        "#7f7f7f",
+        "#c7c7c7",
+        "#bcbd22",
+        "#dbdb8d",
+        "#17becf",
+        "#9edae5",
+    )
+)
+
 
 def assign_all_components(G):
     G.graph["component_list"] = list(nx.weakly_connected_components(G))
 
     for component in G.graph["component_list"]:
-        for n in component: G.nodes[n]["component"] = component
+        for n in component:
+            G.nodes[n]["component"] = component
+
 
 def is_node_type(G, n, label):
     if label == "root":
@@ -41,20 +61,26 @@ def is_node_type(G, n, label):
     elif label == "important":
         return G.in_degree[n] == 0 or G.out_degree[n] == 0 or G.degree[n] >= 3
     elif label[0] == "date":
-        return util_eddies.node_to_date_eddy(n, G.graph["e_overestim"],
-                                             only_date = True) == label[1]
+        return (
+            util_eddies.node_to_date_eddy(
+                n, G.graph["e_overestim"], only_date=True
+            )
+            == label[1]
+        )
     elif label[0] == "node":
         return n in label[1]
     else:
         sys.exit(f"is_node_type: bad label: {label}")
 
-def plot_nbunch(G, nbunch, color = '#1f78b4', label = None, ax = None):
+
+def plot_nbunch(G, nbunch, color="#1f78b4", label=None, ax=None):
     """This function plots all nodes in nbunch and all edges with source
     or target in nbunch.
 
     """
 
-    if ax is None: ax = plt.gca()
+    if ax is None:
+        ax = plt.gca()
     pos = G.nodes.data("coordinates")
     nbunch_plot = nbunch.copy()
 
@@ -65,32 +91,44 @@ def plot_nbunch(G, nbunch, color = '#1f78b4', label = None, ax = None):
 
     xy = np.asarray([pos[n] for n in nbunch_plot])
     src_crs = ccrs.PlateCarree()
-    ax.scatter(xy[:, 0], xy[:, 1], s = 10, c = color, marker = 'o',
-               transform = src_crs)
+    ax.scatter(xy[:, 0], xy[:, 1], s=10, c=color, marker="o", transform=src_crs)
 
     for e in G.edges(nbunch_plot):
         if pos[e[0]] and pos[e[1]]:
-            arrow = patches.FancyArrowPatch(pos[e[0]], pos[e[1]],
-                                            arrowstyle = '-|>', color = color,
-                                            mutation_scale = 10,
-                                            transform = src_crs)
+            arrow = patches.FancyArrowPatch(
+                pos[e[0]],
+                pos[e[1]],
+                arrowstyle="-|>",
+                color=color,
+                mutation_scale=10,
+                transform=src_crs,
+            )
             ax.add_patch(arrow)
 
     if label is not None:
         for n in nbunch_plot:
             if is_node_type(G, n, label):
                 xy = ax.projection.transform_point(*pos[n], src_crs)
-                ax.annotate(str(n), xy, color = color, xytext = (2, 2),
-                            textcoords = 'offset points',
-                            backgroundcolor = "white", fontsize = "xx-small")
+                ax.annotate(
+                    str(n),
+                    xy,
+                    color=color,
+                    xytext=(2, 2),
+                    textcoords="offset points",
+                    backgroundcolor="white",
+                    fontsize="xx-small",
+                )
+
 
 def plot_all_components(G, label):
     for component, color in zip(G.graph["component_list"], color_iter):
         plot_nbunch(G, component, color, label)
 
+
 def plot_descendants(G, n, label):
     nbunch = nx.descendants(G, n) | {n}
-    plot_nbunch(G, nbunch, label = label)
+    plot_nbunch(G, nbunch, label=label)
+
 
 def animate_nbunch(G, nbunch):
     sorted_nbunch = sorted(nbunch)
@@ -100,43 +138,54 @@ def animate_nbunch(G, nbunch):
     while j < len(sorted_nbunch):
         # {sorted_nbunch[j][0] == date}
         j += 1
-        while j < len(sorted_nbunch) and sorted_nbunch[j][0] == date: j += 1
+        while j < len(sorted_nbunch) and sorted_nbunch[j][0] == date:
+            j += 1
         plt.clf()
-        plot_nbunch(G, sorted_nbunch[:j], label = date)
+        plot_nbunch(G, sorted_nbunch[:j], label=date)
         plt.waitforbuttonpress()
         date += 1
 
+
 if __name__ == "__main__":
     import argparse
     from os import path
     import time
 
-    parser = argparse.ArgumentParser(description = __doc__)
-    parser.add_argument("edgelist", help = "path to input CSV file")
-    parser.add_argument("shpc_dir", help = "directory containing SHPC")
-    parser.add_argument("orientation", choices = ["Anticyclones", "Cyclones"])
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("edgelist", help="path to input CSV file")
+    parser.add_argument("shpc_dir", help="directory containing SHPC")
+    parser.add_argument("orientation", choices=["Anticyclones", "Cyclones"])
 
     # Label group:
     group = parser.add_mutually_exclusive_group()
-    group.add_argument("-l", "--label_type",
-                        choices = ["root", "leaf", "split", "merge", "all",
-                                   "important"])
-    group.add_argument("--label_date", type = int, metavar = "DATE_INDEX")
-    group.add_argument("--label_node", metavar = 'NODE', type = int,
-                       nargs = "+")
-
-    parser.add_argument("-s", "--save", metavar = "FORMAT",
-                        help = "Save file to specified format")
-    parser.add_argument("-t", "--time", action = "store_true",
-                        help = "Report elapsed time")
+    group.add_argument(
+        "-l",
+        "--label_type",
+        choices=["root", "leaf", "split", "merge", "all", "important"],
+    )
+    group.add_argument("--label_date", type=int, metavar="DATE_INDEX")
+    group.add_argument("--label_node", metavar="NODE", type=int, nargs="+")
+
+    parser.add_argument(
+        "-s", "--save", metavar="FORMAT", help="Save file to specified format"
+    )
+    parser.add_argument(
+        "-t", "--time", action="store_true", help="Report elapsed time"
+    )
 
     # Selection group:
     group = parser.add_mutually_exclusive_group()
-    group.add_argument("-n", "--node", help = "Select component containing "
-                       "node", type = int)
-    group.add_argument("-w", "--window", help = "choose a limited plot window",
-                       type = float, nargs = 4,
-                       metavar = ("LLLON", "LLLAT", "URLON", "URLAT"))
+    group.add_argument(
+        "-n", "--node", help="Select component containing " "node", type=int
+    )
+    group.add_argument(
+        "-w",
+        "--window",
+        help="choose a limited plot window",
+        type=float,
+        nargs=4,
+        metavar=("LLLON", "LLLAT", "URLON", "URLAT"),
+    )
 
     args = parser.parse_args()
 
@@ -159,8 +208,9 @@ if __name__ == "__main__":
         print("Reading edge list and SHPC...")
         t0 = time.perf_counter()
 
-    G = report_graph.read_eddy_graph(args.edgelist, args.shpc_dir,
-                                     args.orientation)
+    G = report_graph.read_eddy_graph(
+        args.edgelist, args.shpc_dir, args.orientation
+    )
 
     if args.time:
         t1 = time.perf_counter()
@@ -171,7 +221,8 @@ if __name__ == "__main__":
     if args.window is not None:
         for n, d in G.nodes.items():
             if util_eddies.in_window(d["coordinates"], args.window):
-                if "component" not in d: extract_component_nx.add_component(G, n)
+                if "component" not in d:
+                    extract_component_nx.add_component(G, n)
 
         if args.time:
             t1 = time.perf_counter()
@@ -180,7 +231,7 @@ if __name__ == "__main__":
             print("Plotting...")
 
         dest_crs = ccrs.PlateCarree((args.window[0] + args.window[2]) / 2)
-        ax = plt.axes(projection = dest_crs)
+        ax = plt.axes(projection=dest_crs)
         plot_all_components(G, label)
     elif args.node is not None:
         extract_component_nx.add_component(G, args.node)
@@ -191,8 +242,8 @@ if __name__ == "__main__":
             print("Plotting...")
 
         dest_crs = ccrs.PlateCarree(G.nodes[args.node]["coordinates"][0])
-        ax = plt.axes(projection = dest_crs)
-        plot_nbunch(G, G.nodes[args.node]["component"], label = label)
+        ax = plt.axes(projection=dest_crs)
+        plot_nbunch(G, G.nodes[args.node]["component"], label=label)
     else:
         assign_all_components(G)
 
@@ -201,12 +252,13 @@ if __name__ == "__main__":
             print("Elapsed time:", t1 - t0, "s")
             print("Plotting...")
 
-        ax = plt.axes(projection = ccrs.PlateCarree())
+        ax = plt.axes(projection=ccrs.PlateCarree())
         plot_all_components(G, label)
 
     ax.coastlines()
-    ax.gridlines(draw_labels = True)
-    if args.time: print("Elapsed time:", time.perf_counter() - t1, "s")
+    ax.gridlines(draw_labels=True)
+    if args.time:
+        print("Elapsed time:", time.perf_counter() - t1, "s")
 
     if args.save:
         plt.savefig(f"plot_components.{args.save}")
-- 
GitLab