#!/usr/bin/env python3

"""Plot extrema, outermost contours and max-speed contours.

Red for anti-cyclones, blue for cyclones (magenta and cyan,
respectively, for max-speed contours). Green if the cyclonicity of an
extremum is not recorded.

Without the --light option: empty circles for extrema with a valid
outermost contour, empty squares for other extrema. Empty squares on
the outermost contour for a well-defined but invalid outermost
contour. Crosses on the outermost contour for a valid outermost
contour with no distinct max-speed contour. Filled circles on the
outermost contour and max-speed contour for a valid outermost contour
with a distinct max-speed contour. Each extremum is annotated with its
eddy index.

With the --light option: crosses for extrema with a valid outermost
contour; contours are drawn without markers.

This script takes about 30 s of CPU time for a 90° by 90° window, and
about 4 min for a 180° by 180° window.

"""

import shapefile
import numpy as np
import cartopy.crs as ccrs
from os import path
from matplotlib import animation, patches
import f90nml
import sys
sys.path.append(path.join(sys.path[0], "../../Common"))
import util_detect_eddies

def select_ishapes(d, d_init, ishape_last, reader_extr, window = None):
    """Select the shape indices of the extrema at date d, optionally in a window.

    Shapes are stored contiguously by date: ishape_last[k] is the index
    of the last shape at date d_init + k. So the shapes at date d run
    from ishape_last[d - d_init - 1] + 1 (or 0 for the first date) to
    ishape_last[d - d_init], inclusive.

    d: date, must be >= d_init.
    d_init: first date of the collection.
    ishape_last: sequence of last shape indices, one per date.
    reader_extr: shapefile reader for extrema, only read when window is
    not None.
    window: None, or a window passed to util_detect_eddies.in_window.

    Returns a (possibly empty) list of shape indices.
    """

    assert d >= d_init

    i_date = d - d_init
    ishape_first = 0 if i_date == 0 else ishape_last[i_date - 1] + 1
    ishapes = range(ishape_first, ishape_last[i_date] + 1)

    if window is None:
        return list(ishapes)

    # Keep only the extrema whose position lies inside the window.
    return [ishape for ishape in ishapes
            if util_detect_eddies.in_window(
                    reader_extr.shapeRecord(ishape).shape.points[0], window)]

def _record_colors(record):
    """Return (color, max-speed color) for an extremum record.

    Red and magenta for anti-cyclones (cyclone == 0), blue and cyan
    otherwise, green for both if the record has no cyclone field."""

    try:
        anticyclone = record.cyclone == 0
    except AttributeError:
        return "green", "green"

    return ("red", "magenta") if anticyclone else ("blue", "cyan")


def snapshot(ax, ishape_list, readers, *, dashed = False, light = False,
             src_crs = ccrs.PlateCarree()):
    """Plot extrema, outermost contours and max-speed contours.

    ax: cartopy GeoAxes (its projection is used to place annotations).
    ishape_list: indices of the shapes to plot.
    readers: dictionary of shapefile readers with keys "extremum" and
    "outermost_contour", and optionally "max_speed_contour".
    dashed: boolean, draw contours with a dashed linestyle.
    light: boolean, lighter plot: no eddy index annotation, no markers
    on contours, and outermost contours only for valid extrema.
    src_crs: coordinate reference system of the shapefile points.
    """

    for ishape in ishape_list:
        shape_rec_extr = readers["extremum"].shapeRecord(ishape)
        record = shape_rec_extr.record
        shape_outer = readers["outermost_contour"].shape(ishape)

        if "max_speed_contour" in readers:
            shape_m_s = readers["max_speed_contour"].shape(ishape)
        else:
            shape_m_s = None

        has_m_s = shape_m_s is not None \
            and shape_m_s.shapeType != shapefile.NULL
        # A record without a "valid" field is treated as valid.
        valid = getattr(record, "valid", 1)
        color, color_m_s = _record_colors(record)

        # Extremum.
        points = shape_rec_extr.shape.points[0]
        lines = ax.plot(points[0], points[1], markersize = 10,
                        color = color, fillstyle = "none",
                        transform = src_crs)

        if valid == 1:
            lines[0].set_marker("+" if light else "o")
        elif not light:
            # Invalid outermost contour
            lines[0].set_marker("s")

        if not light:
            ax.annotate(str(record.eddy_index),
                        ax.projection.transform_point(points[0], points[1],
                                                      src_crs),
                        xytext = (3, 3), textcoords = "offset points")

        # Outermost contour: a null shape has no points to plot, and in
        # light mode only valid outermost contours are drawn. The
        # parentheses matter: without them, a null shape with a truthy
        # "valid" field would be plotted and crash.
        if shape_outer.shapeType != shapefile.NULL and (not light or valid):
            points = np.array(shape_outer.points)
            lines = ax.plot(points[:, 0], points[:, 1], color = color,
                            transform = src_crs)

            if not light:
                if valid == 0:
                    # Invalid outermost contour
                    lines[0].set_marker("s")
                    lines[0].set_fillstyle("none")
                elif has_m_s:
                    lines[0].set_marker("o")
                else:
                    # Valid, but no distinct max-speed contour.
                    lines[0].set_marker("x")

            if dashed: lines[0].set_linestyle("dashed")

            if has_m_s:
                points = np.array(shape_m_s.points)
                lines = ax.plot(points[:, 0], points[:, 1],
                                color = color_m_s, transform = src_crs)
                if not light: lines[0].set_marker("o")
                if dashed: lines[0].set_linestyle("dashed")

def plot_grid_bb(shpc_dir, ax):
    """Draw the grid bounding box described in grid_nml.txt.

    shpc_dir: directory where grid_nml.txt is looked for.
    ax: axes to which an unfilled black rectangle is added.

    If grid_nml.txt is missing, print a message and draw nothing."""

    nml_path = path.join(shpc_dir, "grid_nml.txt")

    try:
        nml = f90nml.read(nml_path)["grid_nml"]
    except FileNotFoundError:
        print("grid_nml.txt not found. Will not plot bounding box.")
        return

    step = nml["STEP_DEG"]
    corner = nml["corner_deg"]
    # Width and height span nlon - 1 and nlat - 1 grid steps.
    width = (nml["nlon"] - 1) * step[0]
    height = (nml["nlat"] - 1) * step[1]
    box = patches.Rectangle(corner, width, height, edgecolor="black",
                            fill=False)
    ax.add_patch(box)

def func(d, ax, ishape_lists, readers, bbox, light):
    """Draw the animation frame for date d (passed to FuncAnimation).

    d: date of the frame, a key of ishape_lists.
    ax: GeoAxes, cleared before drawing.
    ishape_lists: dictionary mapping dates to lists of shape indices.
    readers: dictionary of shapefile readers, as for snapshot.
    bbox: [xmin, ymin, xmax, ymax] plot limits, or None to keep the
    default limits. compute_bbox returns None when no selected shape
    has a bounding box.
    light: boolean, passed to snapshot.
    """

    ax.cla()

    if bbox is not None:
        ax.set_xlim(bbox[0], bbox[2])
        ax.set_ylim(bbox[1], bbox[3])

    snapshot(ax, ishape_lists[d], readers, light = light)
    ax.gridlines(draw_labels = True)
    ax.coastlines()
    ax.set_title(f"d = {d}", y = 1.05)

def bbox_union(b1, b2):
    """Return the smallest box [xmin, ymin, xmax, ymax] enclosing b1 and b2.

    b1 may be None, in which case b2 itself is returned unchanged."""

    if b1 is None:
        return b2

    lower_left = [min(b1[i], b2[i]) for i in (0, 1)]
    upper_right = [max(b1[i], b2[i]) for i in (2, 3)]
    return lower_left + upper_right

def compute_bbox(ishape_lists, reader_outer):
    """Return the union of the bounding boxes of the selected contours.

    ishape_lists: dictionary whose values are lists of shape indices.
    reader_outer: shapefile reader for outermost contours.

    Shapes without a bbox attribute are skipped. Returns None if no
    selected shape has one."""

    union = None
    all_ishapes = (ishape for ishapes in ishape_lists.values()
                   for ishape in ishapes)

    for ishape in all_ishapes:
        try:
            union = bbox_union(union, reader_outer.shape(ishape).bbox)
        except AttributeError:
            # Shape without a bounding box: ignore it.
            pass

    return union

def make_animation(fig, ax, d_init, ishape_last, readers, window, d_min, d_max,
                   light):
    """Build an animation of snapshots for dates d_min to d_max inclusive.

    Plot limits are the window if one is given, otherwise the union of
    the bounding boxes of the selected outermost contours. Frames are
    drawn by func with a 500 ms interval.

    Returns the FuncAnimation object."""

    dates = range(d_min, d_max + 1)
    ishape_lists = {}

    for d in dates:
        ishape_lists[d] = select_ishapes(d, d_init, ishape_last,
                                         readers["extremum"], window)

    if window is not None:
        bbox = window
    else:
        bbox = compute_bbox(ishape_lists, readers["outermost_contour"])

    frame_args = (ax, ishape_lists, readers, bbox, light)
    return animation.FuncAnimation(fig, func, dates, fargs = frame_args,
                                   interval = 500)

if __name__ == "__main__":
    import matplotlib.pyplot as plt
    import argparse
    import netCDF4
    import pygifsicle

    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--velocity", help = "plot velocity field",
                        action = "store_true")
    parser.add_argument("-s", "--scale", default = 20, type = float,
                        help = "scale of arrows for the velocity field")
    parser.add_argument("-d", "--date", type = int,
                        help = "date, as days since 1950-1-1")
    parser.add_argument("-g", "--grid", help = "plot grid",
                        action = "store_true")
    parser.add_argument("-w", "--window", help = "choose a limited plot window",
                        type = float, nargs = 4,
                        metavar = ("llcrnrlon", "llcrnrlat", "urcrnrlon",
                                   "urcrnrlat"))
    parser.add_argument("-l", "--light", help = "lighter plot",
                        action = "store_true")
    parser.add_argument("--dashed", action = "store_true",
                        help = "dashed linestyle, useful for a second snapshot")
    parser.add_argument("-a", "--anim", type = int, nargs = 2,
                        metavar = ("d_min", "d_max"), help = "make animation")
    parser.add_argument("shpc_dir", help = "directory containing the "
                        "collection of shapefiles")
    parser.add_argument("--save", metavar = "format",
                        help = "Save file to specified format")
    args = parser.parse_args()

    # Grid coordinates are needed to plot either the grid or the velocity.
    read_grid = args.grid or args.velocity

    if read_grid:
        with netCDF4.Dataset("h.nc") as f:
            if "lon" in f.variables:
                lon = "lon"
                lat = "lat"
            else:
                lon = "longitude"
                lat = "latitude"

            longitude = f[lon][:]
            latitude = f[lat][:]

    if args.window is not None:
        llcrnrlon, llcrnrlat, urcrnrlon, urcrnrlat = args.window

        if urcrnrlon - llcrnrlon > 360:
            sys.exit("bad values of urcrnrlon and llcrnrlon")

        if read_grid:
            # Shift longitudes by multiples of 360° into
            # [llcrnrlon, llcrnrlon + 360°[, so that selecting the window
            # only needs an upper bound.
            longitude += np.ceil((llcrnrlon - longitude) / 360) * 360
            lon_mask = longitude <= urcrnrlon
            lat_mask = np.logical_and(latitude >= llcrnrlat,
                                      latitude <= urcrnrlat)
    elif read_grid:
        lon_mask = np.ones(len(longitude), dtype = bool)
        lat_mask = np.ones(len(latitude), dtype = bool)

    readers, d_init, ishape_last \
        = util_detect_eddies.open_shpc(args.shpc_dir)

    if ishape_last is None:
        print("We will use all the shapes.")
        ishape_last = [len(readers["extremum"]) - 1]

    # Last date available in the collection of shapefiles.
    d_last = d_init + len(ishape_last) - 1

    fig = plt.figure()
    src_crs = ccrs.PlateCarree()
    projection = ccrs.PlateCarree()
    ##projection = ccrs.NorthPolarStereo()
    ax = plt.axes(projection = projection)

    if args.anim is None:
        # Validate the date against the dates present in the collection,
        # rather than letting select_ishapes fail with an assertion or an
        # index error.
        if args.date is None:
            d = d_init

            if len(ishape_last) > 1:
                print("No option date, plotting first date:", d_init)
        elif d_init <= args.date <= d_last:
            d = args.date
        else:
            sys.exit("Bad value of option date")

        if args.grid:
            lon_2d, lat_2d = np.meshgrid(longitude[lon_mask],
                                         latitude[lat_mask])
            ax.plot(lon_2d.reshape(-1), lat_2d.reshape(-1), transform = src_crs,
                    marker = "+", color = "gray", linestyle = "None")

        if args.window is None:
            plot_grid_bb(args.shpc_dir, ax)

        ishape_list = select_ishapes(d, d_init, ishape_last,
                                     readers["extremum"], args.window)

        if not ishape_list:
            print("No eddy found")

        ax.set_title(f"d = {d}", y = 1.05)
        snapshot(ax, ishape_list, readers, dashed = args.dashed,
                 light = args.light)
        ax.gridlines(draw_labels = True)
        ax.coastlines()

        if args.velocity:
            with netCDF4.Dataset("uv.nc") as f:
                quiver_return = ax.quiver(longitude[lon_mask],
                                          latitude[lat_mask],
                                          f["ugos"][0, lat_mask][:, lon_mask],
                                          f["vgos"][0, lat_mask][:, lon_mask],
                                          scale = args.scale,
                                          scale_units = "width",
                                          transform = src_crs)
            plt.quiverkey(quiver_return, 0.9, 0.9, 1, r"1 m s$^{-1}$",
                          coordinates = "figure")

        if args.save:
            plt.savefig(f"plot_eddy_contours.{args.save}")
            print(f'Created "plot_eddy_contours.{args.save}".')
        else:
            plt.show()
    else:
        d_min, d_max = args.anim

        if not d_init <= d_min < d_max <= d_last:
            sys.exit("Bad dates specified in option anim")

        ani = make_animation(fig, ax, d_init, ishape_last, readers,
                             window = args.window, d_min = d_min,
                             d_max = d_max, light = args.light)
        ani.save("eddy_contours.gif", writer = "imagemagick")

        # The repeat = False option of the save method of FuncAnimation
        # does not work, hence the post-processing with gifsicle.
        pygifsicle.gifsicle("eddy_contours.gif", options = ["--no-loopcount"])

        print('Created file "eddy_contours.gif".')