Skip to content
Snippets Groups Projects
plot_velocity.py 3.06 KiB
#!/usr/bin/env python3

import numpy as np
import cartopy.crs as ccrs
import sys
import matplotlib.pyplot as plt
import argparse
import netCDF4

# Command-line interface: arrow scale, optional plot window, output format,
# a switch to show undefined-velocity points, and the input NetCDF file.
parser = argparse.ArgumentParser()
add_option = parser.add_argument

add_option("-s", "--scale", type = float, default = 20,
           help = "scale of arrows for the velocity field")
add_option("-w", "--window", type = float, nargs = 4,
           help = "choose a limited plot window",
           metavar = ("llcrnrlon", "llcrnrlat", "urcrnrlon", "urcrnrlat"))
add_option("--save", metavar = "format",
           help = "Save file to specified format")
add_option("-u", "--undefined", action = "store_true",
           help = "plot points where velocity is not defined")
add_option("input_file", help = "NetCDF file containing velocity")

args = parser.parse_args()

# Read the 1-D coordinate variables.  The file may name them either
# "lon"/"lat" or "longitude"/"latitude"; the short names win if present.
with netCDF4.Dataset(args.input_file) as f:
    lon, lat = (("lon", "lat") if "lon" in f.variables
                else ("longitude", "latitude"))
    longitude, latitude = f[lon][:], f[lat][:]

# Boolean masks selecting the grid points to plot.  Without a window every
# point is kept.  With a window, longitudes are first shifted by multiples
# of 360 degrees into [llcrnrlon, llcrnrlon + 360[, so the window bounds
# may be given independently of the longitude convention used in the file.
if args.window is None:
    lon_mask = np.ones(len(longitude), dtype = bool)
    lat_mask = np.ones(len(latitude), dtype = bool)
else:
    llcrnrlon, llcrnrlat, urcrnrlon, urcrnrlat = args.window

    # An inverted or wider-than-360-degree longitude range cannot be
    # represented after the shift below.
    if not 0 <= urcrnrlon - llcrnrlon <= 360:
        sys.exit("bad values of urcrnrlon and llcrnrlon")

    if llcrnrlat > urcrnrlat:
        sys.exit("bad values of urcrnrlat and llcrnrlat")

    # Not done in place: the coordinate variable may have an integer dtype,
    # which cannot receive the float result of np.ceil.
    longitude = longitude + np.ceil((llcrnrlon - longitude) / 360) * 360
    # (now in [llcrnrlon, llcrnrlon + 360[, in degrees)

    lon_mask = longitude <= urcrnrlon
    lat_mask = np.logical_and(latitude >= llcrnrlat,
                              latitude <= urcrnrlat)

    if not (lon_mask.any() and lat_mask.any()):
        sys.exit("no grid point inside the chosen window")

# Map axes.  The data coordinates are longitude/latitude (src_crs); the
# display projection is chosen separately and could be replaced, e.g. by
# ccrs.NorthPolarStereo() for a polar view.
fig = plt.figure()
src_crs, projection = ccrs.PlateCarree(), ccrs.PlateCarree()
ax = plt.axes(projection = projection)

# Read the velocity components "ugos"/"vgos" and draw either the vector
# field (quiver, key arrow labelled 1 m s^-1) or, with --undefined, the
# grid points where at least one component is masked.
with netCDF4.Dataset(args.input_file) as f:
    # Keep only the first record when the variables have a "time"
    # dimension (assumed to be the leading one -- TODO confirm).
    if "time" in f["ugos"].dimensions:
        ugos = f["ugos"][0]
        vgos = f["vgos"][0]
    else:
        ugos = f["ugos"][:]
        vgos = f["vgos"][:]

# Restrict to the plot window once, so both branches below work on arrays
# whose shape matches the windowed coordinates.  Indexing assumes the
# remaining axes are (latitude, longitude).
ugos = ugos[lat_mask][:, lon_mask]
vgos = vgos[lat_mask][:, lon_mask]

if args.undefined:
    # getmaskarray always returns a full boolean array, even for a plain
    # ndarray or a masked array whose mask is the scalar `nomask`.
    undef_velocity = np.logical_or(np.ma.getmaskarray(ugos),
                                   np.ma.getmaskarray(vgos))
    lon_2d, lat_2d = np.meshgrid(longitude[lon_mask], latitude[lat_mask])
    ax.plot(lon_2d[undef_velocity], lat_2d[undef_velocity],
            transform = src_crs, marker = "*", color = "violet",
            linestyle = "None")
else:
    quiver_return = ax.quiver(longitude[lon_mask], latitude[lat_mask],
                              ugos, vgos, scale = args.scale,
                              scale_units = "width", transform = src_crs)
    plt.quiverkey(quiver_return, 0.9, 0.9, 1, r"1 m s$^{-1}$",
                  coordinates = "figure")

# Decorate the map, then either display it or write it to
# "plot_velocity.<format>" when --save was given.
ax.gridlines(draw_labels = True)
ax.coastlines()

if not args.save:
    plt.show()
else:
    output_name = f"plot_velocity.{args.save}"
    plt.savefig(output_name)
    print(f'Created "{output_name}".')