#!/usr/bin/env python3

"""A script that takes the graph of segments without cost functions
and computes the non-local cost functions applied to edges.

Input:

- "node_id_param.json", expected to be in the current directory

- the graph of segments without cost functions, "segments.gt" or
  "segments.graphml", expected to be in the current directory

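- one or more SHPC directories (collections of shapefiles),
  specified as command-line arguments

Output: the graph of segments with cost functions,
'segments_cost_functions.gt' (or 'segments_cost_functions.graphml'
with the --graphml option), and the file 'timings.txt' giving the
duration of each step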

"""

import graph_tool
import time
import json
import math
from os import path
import shapefile
import datetime
from numpy import loadtxt
import report_graph
import util_eddies
import bisect
import argparse

def calculate_radii_rossby(list_eddies, e_overestim, handlers, array_d_init):
    """Compute average on list_eddies of Rossby number and radius of
    maximum speed contour.

    list_eddies: sequence of graph node identifiers of instantaneous
    eddies, decoded with report_graph.node_to_date_eddy and e_overestim.

    handlers: SHPC handlers, as returned by util_eddies.open_shpc.

    array_d_init: the "d_init" value of each handler, in the same
    order, used by get_SHPC to select the handler of each eddy.

    Returns a dictionary with keys:

    - "radii": mean over all eddies of list_eddies of the radius of the
      maximum speed contour (field r_eq_area times 1000), in m;

    - "rossby": mean Rossby number V_max / (f R_Vmax) over the eddies
      whose maximum speed is below 100 m/s, or None if there is no such
      eddy.

    Raises ZeroDivisionError if list_eddies is empty. NOTE(review): a
    valid eddy with an extremum at latitude 0 also gives f = 0 and hence
    ZeroDivisionError.

    """

    omega = 2 * math.pi / 86164. # Earth rotation rate, in rad/s
    n_eddies = len(list_eddies)
    radius_sum = 0 # in m
    rossby_sum = 0
    n_valid = 0 # number of eddies with V_max < 100 m/s

    for node in list_eddies:
        date_eddy = report_graph.node_to_date_eddy(node, e_overestim)
        date_index = date_eddy['date_index']
        handler = handlers[get_SHPC(array_d_init, date_index)]
        readers = handler["readers"]

        # Position of the eddy in the shapefiles of its SHPC:
        ishape = util_eddies.comp_ishape(handler, date_index,
                                         date_eddy['eddy_index'])

        extremum = readers["extremum"].shapeRecord(ishape)
        lat = math.radians(extremum.shape.points[0][1])
        coriolis = 2 * omega * math.sin(lat) # in s-1
        speed = extremum.record[4] # V_max, in m/s
        contour = readers["max_speed_contour"].record(ishape)
        radius = contour['r_eq_area'] * 1000 # R_Vmax, in m

        # Speeds of 100 m/s or more are excluded from the Rossby
        # average, but not from the radius average:
        if speed < 100:
            rossby_sum += speed / (coriolis * radius)
            n_valid += 1

        radius_sum += radius

    mean_radius = radius_sum / n_eddies
    mean_rossby = rossby_sum / n_valid if n_valid > 0 else None
    return {"radii": mean_radius, "rossby": mean_rossby}

def get_SHPC(array_d_ini, date_index):
    """Return the index of the SHPC containing date_index.

    array_d_ini: the "d_init" value of each SHPC, in the order of the
    SHPC list. It must be sorted in ascending order, as required by
    bisect.

    The returned index i satisfies array_d_ini[i] <= date_index and,
    if array_d_ini[i + 1] exists, date_index < array_d_ini[i + 1].

    Raises ValueError if date_index is smaller than array_d_ini[0] or
    if array_d_ini is empty.

    """

    # Use the parameter, not the module-level array_d_init, so the
    # function does not depend on a global of the script:
    i_SHPC = bisect.bisect(array_d_ini, date_index)

    # An explicit exception rather than assert: with python -O an
    # assert would be stripped and -1 would silently select the last
    # SHPC.
    if i_SHPC < 1:
        raise ValueError(f"date index {date_index} precedes the first "
                         "SHPC")

    return i_SHPC - 1

t0 = time.perf_counter()
timings = open("timings.txt", "w")

# Command line: one or more SHPC directories, and an option for the
# output format.
parser = argparse.ArgumentParser()
parser.add_argument("SHPC_dir", nargs='+')
parser.add_argument("--graphml", action="store_true",
                    help="save to graphml format")
args = parser.parse_args()

# e_overestim is needed to decode node identifiers into date and eddy
# indices:
with open("node_id_param.json") as f:
    node_id_param = json.load(f)

e_overestim = node_id_param["e_overestim"]

# Means and standard deviations used to normalize the three terms of
# the cost function:
delta_cent_mean = 3.8481 # [in km]
delta_cent_std = 8.0388
delta_ro_mean = -0.0025965
delta_ro_std = 5.2168
delta_r_mean = -0.0094709 * 1000 # [in m]
delta_r_std = 8.6953 * 1000

# Load the graph, in graph_tool format if available, else graphml:
print('Loading graph...')
g = graph_tool.Graph()

try:
    g.load('segments.gt')
except FileNotFoundError:
    g.load('segments.graphml')

print('Loading done...')
print("Input graph:")

for label, count in (("Number of vertices:", g.num_vertices()),
                     ("Number of edges:", g.num_edges())):
    print(label, count)

print("Internal properties:")
g.list_properties()
t1 = time.perf_counter()
timings.write(f"loading: {t1 - t0:.0f} s\n")
t0 = t1

# New properties, filled by the iterations on vertices and edges:
for prop_name in ('pos_first', 'pos_last'):
    g.vp[prop_name] = g.new_vp('object')

for prop_name in ('first_av_rad', 'first_av_ros', 'last_av_rad',
                  'last_av_ros'):
    g.vp[prop_name] = g.new_vp('float')

g.ep['cost_function'] = g.new_ep('float')

# Open the SHPCs given on the command line:
handlers = [util_eddies.open_shpc(shpc_dir) for shpc_dir in args.SHPC_dir]

# The d_init of each SHPC, looked up by get_SHPC (created once and for
# all):
array_d_init = [shpc["d_init"] for shpc in handlers]

num_of_days_to_avg = 7 # number of days to average
print("Iterating on vertices...")

for n in g.vertices():
    segment = g.vp.inst_eddies[n]

    # Store the positions, in degrees, of the extrema of the first and
    # last instantaneous eddies of the segment:
    for end_node, pos_prop in ((segment[0], g.vp.pos_first),
                               (segment[-1], g.vp.pos_last)):
        end_eddy = report_graph.node_to_date_eddy(end_node, e_overestim)
        end_handler = handlers[get_SHPC(array_d_init,
                                        end_eddy['date_index'])]
        end_loc = util_eddies.comp_ishape(end_handler,
                                          end_eddy['date_index'],
                                          end_eddy['eddy_index'])
        pos_prop[n] = end_handler["readers"]["extremum"]\
            .shape(end_loc).points[0]

    # Average radius and Rossby number at each end of the segment. A
    # segment not longer than num_of_days_to_avg gets a single average
    # over all its eddies, used for both ends.
    if len(segment) > num_of_days_to_avg:
        first_res = calculate_radii_rossby(segment[:num_of_days_to_avg],
                                           e_overestim, handlers,
                                           array_d_init)
        last_res = calculate_radii_rossby(segment[- num_of_days_to_avg:],
                                          e_overestim, handlers,
                                          array_d_init)
    else:
        first_res = last_res = calculate_radii_rossby(segment, e_overestim,
                                                      handlers, array_d_init)

    g.vp.first_av_rad[n] = first_res['radii']
    g.vp.last_av_rad[n] = last_res['radii']

    # A Rossby number of None (no valid eddy) leaves the property
    # untouched:
    if first_res['rossby'] is not None:
        g.vp.first_av_ros[n] = first_res['rossby']

    if last_res['rossby'] is not None:
        g.vp.last_av_ros[n] = last_res['rossby']

t1 = time.perf_counter()
timings.write(f"iterating on vertices: {t1 - t0:.0f} s\n")
t0 = t1
print("Iterating on edges...")

for edge in g.edges():
    source_node = edge.source()
    target_node = edge.target()

    # Extremum positions in degrees, [0] being longitude and [1]
    # latitude: end of the source segment, start of the target segment.
    pos_source = g.vp.pos_last[source_node]
    pos_target = g.vp.pos_first[target_node]

    # Mean latitude, in radians, needed for the conversion of the
    # longitude difference from degrees to kilometers:
    lat_for_conv = math.radians((pos_source[1] + pos_target[1]) / 2)

    # Take the shorter way round the circle of latitude, so that
    # longitudes wrapping around 360° (e.g. 359° and 1°) are close. Any
    # difference above 180° is reduced, not only those above 300°.
    lon_diff = abs(pos_source[0] - pos_target[0])
    lon_diff = min(lon_diff, 360 - lon_diff)

    # Distance between the two extrema, in km (equirectangular
    # approximation, 111.32 km per degree of longitude at the equator,
    # 110.574 km per degree of latitude):
    Delta_Cent = math.sqrt((lon_diff * 111.32 * math.cos(lat_for_conv))**2
                           + ((pos_source[1] - pos_target[1]) * 110.574)**2)

    # First term: distance between the extrema.
    first_term = ((Delta_Cent - delta_cent_mean) / delta_cent_std) ** 2

    # Second term: Rossby numbers. A falsy (zero) average Rossby number
    # means the vertex loop did not assign it (no valid eddy), so the
    # property kept its initial value, presumably 0. In that case fall
    # back to Delta_Ro = 0 (using delta_ro_mean instead would make the
    # second term vanish).
    if g.vp.first_av_ros[target_node] and g.vp.last_av_ros[source_node]:
        Delta_Ro = g.vp.last_av_ros[source_node] \
            - g.vp.first_av_ros[target_node]
    else:
        Delta_Ro = 0

    second_term = ((Delta_Ro - delta_ro_mean) / delta_ro_std) ** 2

    # Third term: radii of maximum speed contours, already averaged by
    # the vertex loop.
    Delta_R_Vmax = g.vp.last_av_rad[source_node] \
        - g.vp.first_av_rad[target_node]
    third_term = ((Delta_R_Vmax - delta_r_mean) / delta_r_std) ** 2

    # Cost function, assigned as weight to the edge:
    cf = math.sqrt(first_term + second_term + third_term)
    g.ep.cost_function[edge] = cf

t1 = time.perf_counter()
timings.write(f"iterating on edges: {t1 - t0:.0f} s\n")
t0 = t1
print("Saving...")

# Output format chosen by the --graphml option:
g.save('segments_cost_functions.graphml' if args.graphml
       else 'segments_cost_functions.gt')

print('All done')

# Record the last timing, then close timings.txt on leaving the block:
with timings:
    t1 = time.perf_counter()
    timings.write(f"saving: {t1 - t0:.0f} s\n")