#!/usr/bin/env python3

from os import path

import shapefile
import numpy as np

import jumble_matplotlib


def read(dirname):
    """Collect eddy diagnostics from the shapefile collection in dirname.

    The shapefiles "extremum", "outermost_contour" and "max_speed_contour"
    in dirname are read record by record, in lockstep: records are paired
    by position, and iteration stops at the end of the shortest file.

    Return a dictionary of Numpy arrays:

    - "speed": speed of the extremum, skipping records where it equals
      1e4 (presumably a missing-value marker);
    - "rad_outer", "amp_outer": equivalent radius of the outermost
      contour, and absolute difference of ssh between the extremum and
      that contour, for every record;
    - "rad_speed", "amp_speed": the same two quantities for the
      max-speed contour, only for records whose r_eq_area is not -100.

    Because of the selections, the arrays may have different lengths.

    """
    extr_path, outer_path, max_speed_path = (
        path.join(dirname, base)
        for base in ("extremum", "outermost_contour", "max_speed_contour")
    )

    speed, rad_outer, amp_outer, rad_speed, amp_speed = [], [], [], [], []

    with shapefile.Reader(extr_path) as extremum, shapefile.Reader(
        outer_path
    ) as outermost, shapefile.Reader(max_speed_path) as max_speed:
        triples = zip(
            extremum.iterRecords(),
            outermost.iterRecords(),
            max_speed.iterRecords(),
        )

        for extr, outer, max_contour in triples:
            if extr.speed != 1e4:
                speed.append(extr.speed)

            rad_outer.append(outer.r_eq_area)
            amp_outer.append(extr.ssh - outer.ssh)

            # Skip records flagged with -100 (presumably no usable
            # max-speed contour for this extremum).
            if max_contour.r_eq_area == -100:
                continue

            rad_speed.append(max_contour.r_eq_area)
            amp_speed.append(extr.ssh - max_contour.ssh)

    arrays = {
        "speed": np.array(speed),
        "rad_outer": np.array(rad_outer),
        "rad_speed": np.array(rad_speed),
    }
    # Amplitudes are reported as magnitudes, whatever the sign of the
    # ssh difference.
    arrays["amp_outer"] = np.abs(np.array(amp_outer))
    arrays["amp_speed"] = np.abs(np.array(amp_speed))
    return arrays


def plot_all(dict_list, label=None):
    """Draw every diagnostic figure for one or several datasets.

    dict_list: list of dictionaries, such as those returned by read,
    providing the keys "speed", "rad_outer", "rad_speed", "amp_outer"
    and "amp_speed". label: list of labels, one label for each
    dictionary, forwarded unchanged to the plotting functions.

    Return the list of figures, in this order:

    - a histogram of speed;
    - distribution functions of the positive speeds and of the
      magnitude of the negative speeds;
    - then, for each of the outermost-contour radius, the max-speed
      contour radius, the outermost-contour amplitude and the max-speed
      contour amplitude, a histogram followed by a distribution function.

    """
    hist = jumble_matplotlib.fig_hist
    distr = jumble_matplotlib.fig_distr_funct

    figures = [
        hist("speed, in m s-1", [d["speed"] for d in dict_list], label),
        distr(
            "speed, when positive, in m s-1",
            [d["speed"][d["speed"] >= 0] for d in dict_list],
            label,
        ),
        distr(
            "speed, when negative, in m s-1",
            # Negated so that magnitudes of the negative speeds are plotted.
            [-d["speed"][d["speed"] < 0] for d in dict_list],
            label,
        ),
    ]

    quantities = (
        ("rad_outer", "equivalent radius of outermost contour, in km"),
        ("rad_speed", "equivalent radius of max-speed contour, in km"),
        ("amp_outer", "amplitude of outermost contour, in m"),
        ("amp_speed", "amplitude of max-speed contour, in m"),
    )

    for key, title in quantities:
        for plot in (hist, distr):
            figures.append(plot(title, [d[key] for d in dict_list], label))

    return figures


if __name__ == "__main__":
    # Command-line entry point: read one shapefile collection, build all
    # figures, then either save them as PNG files or display them.
    import argparse

    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser(
        description="Plot statistics of eddy speed, radius and amplitude "
        "read from a collection of shapefiles."
    )
    parser.add_argument(
        "dir", help="directory containing collection of shapefiles"
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="save the figures as figure_<i>.png in the current directory "
        "instead of displaying them",
    )
    args = parser.parse_args()
    d = read(args.dir)
    fig_list = plot_all([d])

    if args.save:
        for i, fig in enumerate(fig_list):
            fig.savefig(f"figure_{i}.png")

        # Name the first and last files explicitly: a bracket range such
        # as "figure_[0-10].png" reads like a shell glob, and as a glob it
        # would only match figure_0.png and figure_1.png.
        print(f"Created figure_0.png to figure_{len(fig_list) - 1}.png")
    else:
        plt.show()