#!/usr/bin/env python3

"""This script takes the graph of segments without cost and computes
the cost applied to edges.

Input:

-- the graph of segments without cost;

-- the SHPC.

Output: the graph of segments with cost.

The inst_eddies property of vertices is not modified by this
script. All the content of the input graph is part of the output graph
so the input file may be removed, if desired, after running this
script.

"""

import time
import math
import argparse

import graph_tool

import util_eddies

# Angular velocity of the Earth's rotation: 2 pi divided by the
# sidereal day (86164 s).
Omega = 2 * math.pi / 86164.0  # in s-1
r_Earth = 6371  # radius of the Earth, in km


def calculate_radii_rossby(properties):
    """Compute averages, over some instantaneous eddies, of the Rossby
    number and of the radius.

    "properties" is a non-empty list of dictionaries, one per eddy,
    such as the list returned by node_to_prop. Each dictionary must
    contain:

    - "pos": [longitude, latitude] of the eddy, in rad;
    - "radius": radius of the eddy, in km;
    - "speed": speed of the eddy (presumably in m s-1). A speed whose
      absolute value is 100 or more is treated as undefined.

    Returns the tuple (avg_rad, avg_Rossby):

    - avg_rad is the average radius, in m, over all the eddies;
    - avg_Rossby is the average of speed / (f * radius), f being the
      Coriolis parameter, over the eddies for which this ratio is
      defined. The speed must be defined and f * radius must be
      non-zero (f vanishes on the equator). If the ratio is defined
      for none of the eddies, avg_Rossby is 0.

    Raises ValueError if properties is empty.

    """

    if not properties:
        raise ValueError("calculate_radii_rossby: empty list of properties")

    sum_rad = 0  # in m
    sum_Rossby = 0
    n_valid_Rossby = 0

    for prop in properties:
        f = 2 * Omega * math.sin(prop["pos"][1])  # Coriolis parameter, in s-1
        radius = prop["radius"] * 1000  # in m
        sum_rad += radius

        # Skip undefined speeds (flagged by large absolute values) and
        # a zero denominator (eddy on the equator), for which the
        # Rossby number is undefined.
        if abs(prop["speed"]) < 100 and f * radius != 0:
            sum_Rossby += prop["speed"] / (f * radius)
            n_valid_Rossby += 1

    avg_rad = sum_rad / len(properties)
    avg_Rossby = sum_Rossby / n_valid_Rossby if n_valid_Rossby != 0 else 0
    return avg_rad, avg_Rossby


def node_to_prop(node_list, e_overestim, SHPC, orientation):
    """Read from the shapefiles of SHPC some properties of the
    instantaneous eddies identified by the node numbers in node_list.

    Returns a list containing one dictionary per element of node_list,
    in the same order. Each dictionary has the keys:

    - "pos": coordinates of the extremum of the eddy, converted from
      degrees to radians;
    - "speed": the speed field of the extremum record;
    - "radius": r_eq_area of the maximum speed contour or, when that
      value is negative, r_eq_area of the outermost contour
      (presumably in km, since calculate_radii_rossby multiplies it
      by 1000 to get meters).

    """

    result = []

    for node in node_list:
        date_index, eddy_index = util_eddies.node_to_date_eddy(
            node, e_overestim
        )
        first_guess_slice = SHPC.get_slice(date_index)
        i_slice, ishape = SHPC.comp_ishape(
            date_index, eddy_index, orientation, first_guess_slice
        )

        extremum_reader = SHPC.get_reader(i_slice, orientation, "extremum")
        extremum = extremum_reader.shapeRecord(ishape)
        eddy = {
            "pos": list(map(math.radians, extremum.shape.points[0])),
            "speed": extremum.record.speed,
        }

        max_speed_reader = SHPC.get_reader(
            i_slice, orientation, "max_speed_contour"
        )
        r_eq_area = max_speed_reader.record(ishape).r_eq_area

        # A negative radius for the maximum speed contour means it is
        # not usable: fall back on the outermost contour.
        if r_eq_area < 0:
            outermost_reader = SHPC.get_reader(
                i_slice, orientation, "outermost_contour"
            )
            r_eq_area = outermost_reader.record(ishape).r_eq_area

        eddy["radius"] = r_eq_area
        result.append(eddy)

    return result


t0 = time.perf_counter()
# Elapsed durations of the main steps of the script go to this file:
timings = open("timings_cost.txt", "w")

parser = argparse.ArgumentParser()
parser.add_argument("SHPC_dir")
parser.add_argument("orientation", choices=["Anticyclones", "Cyclones"])

# The input and output graphs accept the same two file formats:
for arg_name, arg_help in (
    (
        "input_segments",
        "input graph of segments without cost, suffix .gt (graph-tool) or "
        ".graphml",
    ),
    (
        "output_segments",
        "output graph of segments with cost, suffix .gt (graph-tool) or "
        ".graphml",
    ),
):
    parser.add_argument(arg_name, help=arg_help)

parser.add_argument(
    "--debug", action="store_true", help="save properties to output file"
)
args = parser.parse_args()

# Mean and standard deviation of each of the three differences entering
# the cost function (distance between centers, Rossby number, radius).
# Each difference is standardized with them before the cost is
# computed. NOTE(review): presumably estimated from a reference set of
# eddy pairs -- confirm.
delta_cent_mean, delta_cent_std = 3.8481, 8.0388  # in km
delta_ro_mean, delta_ro_std = -2.5965e-3, 5.2168  # dimensionless
delta_r_mean, delta_r_std = -9.4709, 8695.3  # in m

# Load the input graph (graph-tool or GraphML file) and describe it:

print("Loading graph...")
g = graph_tool.load_graph(args.input_segments)
print("Loading done...")
print("Input graph:")

for label, count_method in (
    ("Number of vertices:", g.num_vertices),
    ("Number of edges:", g.num_edges),
):
    print(label, count_method())

print("Internal properties:")
g.list_properties()

t1 = time.perf_counter()
loading_seconds = t1 - t0
timings.write(f"loading: {loading_seconds:.0f} s\n")
t0 = t1

# It is useful to save the orientation to the output graph of this
# script for further processing of the output graph by other scripts.
# The first statement creates an internal graph property of type
# string, the second one stores the orientation in it:
g.graph_properties["orientation"] = g.new_graph_property("string")
g.graph_properties["orientation"] = args.orientation

# Vertex properties describing the two ends of each segment. pos_first
# and pos_last hold the position [longitude, latitude], in rad, of the
# first and last instantaneous eddy of the segment. first_av_rad,
# first_av_ros, last_av_rad and last_av_ros hold the averages computed
# by calculate_radii_rossby (radius in m, Rossby number) over the first
# or the last n_days_avg instantaneous eddies of the segment.
pos_first = g.new_vp("vector<double>")
pos_last = g.new_vp("vector<double>")
first_av_rad = g.new_vp("float")
first_av_ros = g.new_vp("float")
last_av_rad = g.new_vp("float")
last_av_ros = g.new_vp("float")

if args.debug:
    # Make the properties internal to the graph, so that they are
    # saved to the output file. Without --debug they stay external and
    # are only used to compute the cost:
    g.vp["pos_first"] = pos_first
    g.vp["pos_last"] = pos_last
    g.vp["first_av_rad"] = first_av_rad
    g.vp["first_av_ros"] = first_av_ros
    g.vp["last_av_rad"] = last_av_rad
    g.vp["last_av_ros"] = last_av_ros

# Internal edge property, saved with the output graph, receiving the
# cost of each edge:
g.ep["cost_function"] = g.new_ep("float")
SHPC = util_eddies.SHPC_class(args.SHPC_dir, args.orientation)
# At most this many instantaneous eddies (elements of inst_eddies) are
# averaged at each end of a segment (presumably one eddy per day):
n_days_avg = 7  # number of days to average
print("Iterating on vertices...")

# For each segment, compute the properties of its beginning (needed
# only if the segment is the target of some edge) and of its end
# (needed only if the segment is the source of some edge).
for n in g.vertices():
    if n.in_degree() != 0:
        # Define properties for beginning of the segment, from its
        # first n_days_avg instantaneous eddies:
        properties = node_to_prop(
            g.vp.inst_eddies[n][:n_days_avg],
            g.gp.e_overestim,
            SHPC,
            args.orientation,
        )
        first_av_rad[n], first_av_ros[n] = calculate_radii_rossby(properties)
        pos_first[n] = properties[0]["pos"]  # in rad

    if n.out_degree() != 0:
        # Define properties for end of the segment, from its last
        # n_days_avg instantaneous eddies:

        len_seg = len(g.vp.inst_eddies[n])

        if n.in_degree() == 0 or len_seg > n_days_avg:
            # We have to read more from the shapefiles and redefine
            # properties.

            if n.in_degree() == 0 or len_seg >= 2 * n_days_avg:
                # We cannot use part of properties from the beginning
                # of the segment.
                properties = node_to_prop(
                    g.vp.inst_eddies[n][-n_days_avg:],
                    g.gp.e_overestim,
                    SHPC,
                    args.orientation,
                )
            else:
                # assertion: n.in_degree() != 0 and n_days_avg <
                # len_seg < 2 * n_days_avg

                # We can use part of the properties from the beginning
                # of the segment: properties holds eddies 0 to
                # n_days_avg - 1, of which we keep those from index
                # len_seg - n_days_avg on, then we read the eddies from
                # index n_days_avg to the end of the segment.
                properties = properties[len_seg - n_days_avg :] + node_to_prop(
                    g.vp.inst_eddies[n][n_days_avg:],
                    g.gp.e_overestim,
                    SHPC,
                    args.orientation,
                )

            last_av_rad[n], last_av_ros[n] = calculate_radii_rossby(properties)
        else:
            # The number of eddies in the segment is lower than or
            # equal to the number of days over which to average. The
            # values for the end of the segment will be the same as
            # for the beginning, except for the position. Here
            # properties, computed above for the beginning of this
            # segment, contains all the eddies of the segment.
            last_av_rad[n] = first_av_rad[n]
            last_av_ros[n] = first_av_ros[n]

        # In every branch above, properties ends with the last eddy of
        # the segment:
        pos_last[n] = properties[-1]["pos"]  # in rad

t1 = time.perf_counter()
timings.write(f"iterating on vertices: {t1 - t0:.0f} s\n")
t0 = t1
print("Iterating on edges...")

# Loop-invariant lookup of the edge property map receiving the cost:
cost_map = g.ep.cost_function

for edge in g.edges():
    source_node = edge.source()
    target_node = edge.target()

    # Positions [longitude, latitude], in rad, of the last eddy of the
    # source segment and of the first eddy of the target segment:
    source_pos = pos_last[source_node]
    target_pos = pos_first[target_node]

    latitude = (source_pos[1] + target_pos[1]) / 2
    lon_diff = abs(source_pos[0] - target_pos[0])

    # Take the shorter way around the globe: a longitude difference
    # larger than pi is equivalent to 2 pi minus that difference.
    if lon_diff > math.pi:
        lon_diff = 2 * math.pi - lon_diff

    # Distance between the two centers, in km, with the
    # equirectangular approximation:
    Delta_Cent = r_Earth * math.hypot(
        lon_diff * math.cos(latitude), source_pos[1] - target_pos[1]
    )

    # Rossby numbers. calculate_radii_rossby returns 0 when no valid
    # Rossby number could be computed, so 0 marks an invalid average.
    if first_av_ros[target_node] and last_av_ros[source_node]:
        Delta_Ro = first_av_ros[target_node] - last_av_ros[source_node]
    else:
        # At least one of the Rossby numbers is invalid.
        Delta_Ro = 0

    # The average radii are already computed, just get the delta (in m):
    Delta_R_Vmax = first_av_rad[target_node] - last_av_rad[source_node]

    # The cost is the Euclidean norm of the three standardized
    # differences:
    cost_map[edge] = math.sqrt(
        ((Delta_Cent - delta_cent_mean) / delta_cent_std) ** 2
        + ((Delta_Ro - delta_ro_mean) / delta_ro_std) ** 2
        + ((Delta_R_Vmax - delta_r_mean) / delta_r_std) ** 2
    )

t1 = time.perf_counter()
timings.write(f"iterating on edges: {t1 - t0:.0f} s\n")
t0 = t1
# Write the output graph, then the last timing, and close the timings
# file:
print("Saving...")
g.save(args.output_segments)
print("All done")
t1 = time.perf_counter()
saving_seconds = t1 - t0
timings.write(f"saving: {saving_seconds:.0f} s\n")
timings.close()