#!/usr/bin/env python3

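"""Plot a velocity field read from a NetCDF file.

Little in this script is specific to surface ocean currents from AVISO
ADT files, apart from the variable names it reads: ``ugos`` and ``vgos``
for the velocity components, and ``lon``/``lat`` or
``longitude``/``latitude`` for the coordinates.
"""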

import sys
import argparse

import numpy as np
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import netCDF4

import wind_cartopy

def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "-s",
        "--scale",
        default=20,
        type=float,
        help="scale of arrows for the velocity field",
    )
    parser.add_argument(
        "-w",
        "--window",
        help="choose a limited plot window",
        type=float,
        nargs=4,
        metavar=("llcrnrlon", "llcrnrlat", "urcrnrlon", "urcrnrlat"),
    )
    parser.add_argument(
        "--save", metavar="format", help="Save file to specified format"
    )
    parser.add_argument(
        "-u",
        "--undefined",
        action="store_true",
        help="plot points where velocity is not defined",
    )
    parser.add_argument("input_file", help="NetCDF file containing velocity")
    return parser.parse_args()


def _read_velocity(path):
    """Return ``(longitude, latitude, ugos, vgos)`` read from NetCDF *path*.

    Coordinates come from ``lon``/``lat`` when a ``lon`` variable exists,
    otherwise from ``longitude``/``latitude``. When ``ugos`` has a ``time``
    dimension, only index 0 along the first axis of ``ugos`` and ``vgos``
    is read.
    """
    with netCDF4.Dataset(path) as f:
        if "lon" in f.variables:
            lon_name, lat_name = "lon", "lat"
        else:
            lon_name, lat_name = "longitude", "latitude"

        longitude = f[lon_name][:]
        latitude = f[lat_name][:]

        # NOTE(review): assumes "time" is the first dimension of ugos/vgos.
        if "time" in f["ugos"].dimensions:
            ugos = f["ugos"][0]
            vgos = f["vgos"][0]
        else:
            ugos = f["ugos"][:]
            vgos = f["vgos"][:]

    return longitude, latitude, ugos, vgos


def _window_masks(longitude, latitude, window):
    """Return ``(longitude, lon_mask, lat_mask)`` for an optional plot window.

    *window* is ``None`` or ``(llcrnrlon, llcrnrlat, urcrnrlon, urcrnrlat)``
    in degrees. When it is given, each longitude is shifted by a multiple of
    360 into ``[llcrnrlon, llcrnrlon + 360)``, and the shifted longitudes are
    returned. This lets a window straddle the seam of the grid's longitude
    range. The masks are plain boolean arrays. Exits the program if the
    window is wider than 360 degrees.
    """
    if window is None:
        return (
            longitude,
            np.ones(len(longitude), dtype=bool),
            np.ones(len(latitude), dtype=bool),
        )

    llcrnrlon, llcrnrlat, urcrnrlon, urcrnrlat = window

    if urcrnrlon - llcrnrlon > 360:
        sys.exit("bad values of urcrnrlon and llcrnrlon")

    # Shift into [llcrnrlon, llcrnrlon + 360) (degrees). Done out of place so
    # the result dtype is not forced back to the file's coordinate dtype.
    longitude = longitude + np.ceil((llcrnrlon - longitude) / 360) * 360

    lon_mask = np.ma.filled(longitude <= urcrnrlon, False)
    lat_mask = np.ma.filled(
        (latitude >= llcrnrlat) & (latitude <= urcrnrlat), False
    )
    return longitude, lon_mask, lat_mask


def plot_velocity():
    """Parse the command line, read the velocity field and plot it.

    Without ``--undefined``, draws arrows of (``ugos``, ``vgos``) with
    ``wind_cartopy.plot`` and adds a 1 m/s quiver key. With ``--undefined``,
    marks the grid points where either component is masked instead. The
    figure is shown interactively, or saved as ``plot_velocity.<format>``
    when ``--save`` is given.
    """
    args = _parse_args()
    longitude, latitude, ugos, vgos = _read_velocity(args.input_file)
    longitude, lon_mask, lat_mask = _window_masks(
        longitude, latitude, args.window
    )

    if not (lon_mask.any() and lat_mask.any()):
        sys.exit("no grid point inside the plot window")

    # Subset coordinates and both velocity components with the same masks,
    # so that they stay aligned in every branch below (including the
    # undefined-points plot, which previously used the full arrays).
    longitude = longitude[lon_mask]
    latitude = latitude[lat_mask]
    ugos = ugos[lat_mask][:, lon_mask]
    vgos = vgos[lat_mask][:, lon_mask]

    src_crs = ccrs.PlateCarree()

    # Use a conformal projection for quiver, centred on the plotted region.
    # Possible alternative: projection = ccrs.NorthPolarStereo()
    projection = ccrs.Stereographic(
        central_latitude=latitude.mean(), central_longitude=longitude.mean()
    )

    fig = plt.figure()
    ax = plt.axes(projection=projection)

    if args.undefined:
        # getmaskarray always gives a full boolean array, even when the mask
        # is the scalar nomask or the data is a plain ndarray.
        undef_velocity = np.ma.getmaskarray(ugos) | np.ma.getmaskarray(vgos)
        lon_2d, lat_2d = np.meshgrid(longitude, latitude)
        ax.plot(
            lon_2d[undef_velocity].reshape(-1),
            lat_2d[undef_velocity].reshape(-1),
            transform=src_crs,
            marker="*",
            color="violet",
            linestyle="None",
        )
    else:
        quiver_return = wind_cartopy.plot(
            ax,
            longitude,
            latitude,
            ugos,
            vgos,
            scale=args.scale,
            scale_units="width",
        )
        ax.quiverkey(
            quiver_return, 0.9, 0.9, 1, r"1 m s$^{-1}$", coordinates="figure"
        )

    ax.gridlines(draw_labels=True)
    ax.coastlines()

    if args.save:
        fig.savefig(f"plot_velocity.{args.save}")
        print(f'Created "plot_velocity.{args.save}".')
    else:
        plt.show()

if __name__ == "__main__":
    # Run the plot only when executed as a script, not when imported.
    plot_velocity()