import pandas as pd
import datetime
import os
import re
import xmltodict
import xml.etree.cElementTree as ET
from io import StringIO

import utils
from dataprovider.picarroprovider import PicarroProvider


class ExploProvider:
    """Index the dataset directories found directly under a root directory.

    One `Dataset` object is created per dataset directory; they are kept in `datasets`, keyed by directory name.
    """

    def __init__(self, picarro_prvd: PicarroProvider):
        # Root directory given to the last explore_root_directory() call ("" before any call).
        self.datasets_root_directory = ""
        # Sorted names of the dataset directories found by the last exploration.
        self.datasets_dirs = []
        # Dataset objects of the last exploration, keyed by directory name.
        self.datasets = {}
        # Picarro data provider, handed over unchanged to every Dataset.
        self.picarro_prvd = picarro_prvd

    def explore_root_directory(self, root_directory: str) -> list:
        """Get the names of the datasets directories and load the matching datasets.

        The instance state (`datasets_root_directory`, `datasets_dirs`, `datasets`) is replaced as a whole, and only
        once every `Dataset` has been built, so it never mixes the results of two different explorations.

        Parameters
        ----------
        root_directory: str
            Full path of the directory containing the datasets directories.

        Returns
        -------
        list
            List of dataset directories name (without full path), sorted in alphabetical order.

        """
        # Dataset directories are named "<8 digits>_<anything>". The pattern is applied with match(), i.e. anchored at
        # the beginning of the name, so that directories merely *containing* a date (e.g. "old_20200101_x") are
        # ignored like the other non-dataset directories (pump_calibration, conduct_calib, old, etc.).
        dataset_dir_pattern = re.compile(r'[0-9]{8}_.*')

        # Not recursive: only the direct children of the root directory are considered.
        # Alphabetical order is, with this naming scheme, the ascending date order.
        dataset_directories = sorted(
            element for element in os.listdir(root_directory)
            if dataset_dir_pattern.match(element) and os.path.isdir(os.path.join(root_directory, element))
        )

        datasets = {directory: Dataset(root_directory, directory, self.picarro_prvd)
                    for directory in dataset_directories}

        self.datasets_root_directory = root_directory
        self.datasets_dirs = dataset_directories
        self.datasets = datasets

        return dataset_directories


class Dataset:
    """One dataset directory: its instrument log files and its saved plot setups.

    The directory is scanned at construction time (see `explore_dataset`): every recognised log file is loaded into
    `instlogs` (or `manual_event_log`), and `first_data_datetime` / `last_data_datetime` are narrowed to the time
    span covered by the periodic logs.
    """

    def __init__(self, root_directory: str, directory_name: str, picarro_prvd: PicarroProvider):
        self.root_directory = root_directory
        self.directory_name = directory_name
        self.full_directory_name = root_directory + "/" + directory_name

        self.picarro_prvd = picarro_prvd

        # Get dataset name
        # NOTE(review): this keeps the LAST 9 characters of the directory name. If directories are named
        # "<8-digit date>_<text>", `directory_name[9:]` may have been intended -- confirm before changing.
        self.dataset_text = directory_name[-9:]

        # Start from an "inverted" range (now .. 1970) so that the first non-empty periodic log replaces both bounds.
        self.first_data_datetime = datetime.datetime.now(tz=datetime.timezone.utc)
        self.last_data_datetime = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)

        # Instrument logs keyed by "<instrument>_<log type>" (e.g. "PICARRO_periodic").
        self.instlogs = {}
        self.manual_event_log = None

        # Setup save/load
        self.saved_setup_dir = self.full_directory_name + "/saved_setups/"
        self.saved_setup_ext = ".xml"

        self.explore_dataset()

    def explore_dataset(self) -> None:
        """Load every log file of the dataset directory.

        Log files are named "<directory name>_<instrument>_<log type>.log", the log type being "instant" or
        "periodic", or "<directory name>_manual-event.log". Any other entry of the directory is reported on stdout
        and skipped. If the directory holds no "PICARRO_periodic" log, one is built from the Picarro provider over
        the time span of the periodic logs, written to the directory and loaded.

        Raises
        ------
        ValueError
            If a log file name carries a log type other than "instant" or "periodic".
        """
        filenames = os.listdir(self.full_directory_name)

        # The directory name is escaped (it is data, not a pattern) and so is the dot of the ".log" extension.
        log_file_pattern = re.compile("^" + re.escape(self.directory_name) + r"_(.+?)\.log$")

        for filename in filenames:
            found = log_file_pattern.search(filename)
            if found is None:
                # The found file does not match normal instrument's log file pattern
                print("File [" + filename + "] does not appear to be a valid CFA log file")
                continue

            inst_and_type = found.group(1)
            name_parts = inst_and_type.split("_")
            instrument_name = name_parts[0]

            if len(name_parts) == 2:
                log_type = name_parts[1]
                if log_type == "instant":
                    if instrument_name == "ICBKCTRL":
                        instrument_log = IceblockInstantLog(self.full_directory_name, filename, instrument_name)
                    else:
                        instrument_log = InstrumentInstantLog(self.full_directory_name, filename, instrument_name)
                elif log_type == "periodic":
                    instrument_log = InstrumentPeriodicLog(self.full_directory_name, filename, instrument_name)
                    # An empty log has no timestamp to contribute to the dataset's time span.
                    if not instrument_log.df.empty:
                        self.first_data_datetime = min(self.first_data_datetime,
                                                       instrument_log.df["datetime"].min())
                        self.last_data_datetime = max(self.last_data_datetime,
                                                      instrument_log.df["datetime"].max())
                else:
                    raise ValueError("Unknown log type: [" + log_type + "]")
                self.instlogs[inst_and_type] = instrument_log
            elif instrument_name == "manual-event":
                self.manual_event_log = ManualEventLog(self.full_directory_name, filename, instrument_name)

        # Picarro data are not logged the same way as the others, it is logged directly in the Picarro instrument.
        # In order to have comparable data files, create "artificial" PICARRO_periodic log file from the Picarro log
        # files.
        picarro_filename = self.directory_name + "_PICARRO_periodic.log"
        if picarro_filename not in filenames:
            # NOTE(review): when no periodic log contributed, the range passed below is still the inverted
            # (now .. 1970) initial one; how the provider reacts is not visible from here.
            try:
                picarro_df = self.picarro_prvd.get_df(self.first_data_datetime,
                                                      self.last_data_datetime,
                                                      ["H2O", "Delta_D_H", "Delta_18_16"])
            except ValueError as e:
                print("Failed to get Picarro data: " + str(e))
                return

            picarro_df.to_csv(path_or_buf=self.full_directory_name + "/" + picarro_filename,
                              sep="\t",
                              index=False,
                              mode='w',  # Always override file content
                              date_format=utils.datetime_format
                              )
            picarro_log = InstrumentPeriodicLog(self.full_directory_name, picarro_filename, "PICARRO")
            self.instlogs["PICARRO_periodic"] = picarro_log

    def save_setup(self, setup_name: str, variable_df: pd.DataFrame, view_range: list) -> None:
        """Save a plot setup as an XML file in the 'saved_setups' directory of the dataset.

        Parameters
        ----------
        setup_name: str
            Name of the setup, used as file name (without extension).
        variable_df: pd.DataFrame
            Variables table; stored as ';'-separated CSV text inside the XML file.
        view_range: list
            View limits, indexed as [[xmin, xmax], [ymin, ymax]].
        """
        # Build 'saved setup' full file name
        # makedirs(exist_ok=True) does not fail if the directory appears between a check and the creation.
        os.makedirs(self.saved_setup_dir, exist_ok=True)
        filename = self.saved_setup_dir + setup_name + self.saved_setup_ext

        # Variables table (no target given: to_csv returns the CSV text)
        variables_str = variable_df.to_csv(sep=";", index=False)

        # Create XML file
        # NOTE(review): at file level, ET is bound to xml.etree.cElementTree, a module removed in Python 3.9 --
        # confirm the target interpreter or switch that import to xml.etree.ElementTree.
        root_elmt = ET.Element("save")
        ET.SubElement(root_elmt, "variables").text = variables_str
        view_range_elmt = ET.SubElement(root_elmt, "view_range")
        # x limits are stored with 2 decimals, y limits with 4.
        ET.SubElement(view_range_elmt, "xmin").text = "{:.2f}".format(view_range[0][0])
        ET.SubElement(view_range_elmt, "xmax").text = "{:.2f}".format(view_range[0][1])
        ET.SubElement(view_range_elmt, "ymin").text = "{:.4f}".format(view_range[1][0])
        ET.SubElement(view_range_elmt, "ymax").text = "{:.4f}".format(view_range[1][1])

        tree = ET.ElementTree(root_elmt)
        tree.write(filename)

    def load_setup(self, filename: str) -> tuple:
        """Load a plot setup previously written by `save_setup`.

        Parameters
        ----------
        filename: str
            Name of the setup (file name without extension).

        Returns
        -------
        pd.DataFrame:
            The variables table.
        dict:
            The view limits, as floats, under the keys "xmin", "xmax", "ymin" and "ymax".
        """
        full_filename = self.saved_setup_dir + filename + self.saved_setup_ext

        # Open XML file
        tree = ET.parse(full_filename)
        root = tree.getroot()

        # Variable CSV table as pd.Dataframe
        variables_str = root.findall("variables")[0].text
        variable_df = pd.read_csv(StringIO(variables_str), sep=";")

        # View range
        view_range_elmt = root.findall("view_range")[0]
        view_range_dict = {limit: float(view_range_elmt.findall(limit)[0].text)
                           for limit in ("xmin", "xmax", "ymin", "ymax")}

        return variable_df, view_range_dict

    def setup_filename_is_valid(self, filename: str) -> tuple:
        """Check if the file name is valid: not empty, no special characters, file does not already exists.

        Parameters
        ----------
        filename: str
            filename (without extension) to be tested.

        Returns
        -------
        bool:
            True if the file name is valid, False otherwise
        str:
            The error message explaining why the file name is not valid ; an empty string if file name is valid.
        """
        if not filename:
            # An empty name would produce a file made of the extension only, which cannot be listed or reloaded.
            return False, "File name cannot be empty."

        # fullmatch: the whole name must be made of allowed characters (a "$" anchor would tolerate a trailing
        # newline).
        if not re.fullmatch("[A-Za-z0-9_-]+", filename):
            error_msg = "File name can only contain letters, digits and '-' or '_'. File extension is automatically set."
            return False, error_msg

        if filename in self.get_setup_saved_files():
            error_msg = "File already exists."
            return False, error_msg

        return True, ""

    def get_setup_saved_files(self) -> list:
        """Get a sorted list of the 'setup' file names (without extension) existing in the 'saved_setups' directory.

        Only the files carrying the setup extension are listed, so every returned name can be given to `load_setup`.
        """
        if not os.path.exists(self.saved_setup_dir):
            return []

        setup_names = [name
                       for name, extension in map(os.path.splitext, os.listdir(self.saved_setup_dir))
                       if extension == self.saved_setup_ext]
        setup_names.sort()
        return setup_names


class InstrumentLog:
    """Base class of the log file readers.

    The constructor records where the log file is, then stores in `df` whatever the subclass' `__get_df__` returns.
    The three methods below are placeholders that subclasses are expected to override.
    """

    def __init__(self, full_directory_name: str, filename: str, instrument_name: str):
        self.instrument_name = instrument_name
        self.filename = filename
        self.full_directory_name = full_directory_name
        # Full path of the log file: directory and file name joined with a "/".
        self.full_file_name = "/".join((full_directory_name, filename))

        # Loading happens last, so the subclass' loader can rely on the attributes set above.
        self.df = self.__get_df__()

    def __get_df__(self) -> pd.DataFrame:
        """Load the content of the log file; called once by the constructor."""
        raise NotImplementedError("Subclasses should implement this.")

    def get_variables(self):
        """Return the variables available in this log."""
        raise NotImplementedError("Subclasses should implement this.")

    def get_timeseries(self, variable: str) -> pd.DataFrame:
        """Return the time series of one variable of this log."""
        raise NotImplementedError("Subclasses should implement this.")


class InstrumentInstantLog(InstrumentLog):
    """Reader of an "instant" log file: a ','-separated file with (at least) the columns "datetime", "name" and
    "value", one row per logged value of the variable called "name"."""

    def __init__(self, full_directory_name: str, filename: str, instrument_name: str):
        InstrumentLog.__init__(self, full_directory_name, filename, instrument_name)

    def __get_df__(self) -> pd.DataFrame:
        """Read the log file and mark its timestamps as UTC."""
        df = pd.read_csv(self.full_file_name, sep=",", parse_dates=["datetime"])
        # Same guard as in InstrumentPeriodicLog: a file without any data row has nothing to localize, and its
        # "datetime" column is not guaranteed to be parsed as datetimes (in which case `.dt` would raise).
        if not df.empty:
            df["datetime"] = df["datetime"].dt.tz_localize('UTC')
        return df

    def get_variables(self):
        """Return the distinct values of the "name" column."""
        return self.df["name"].unique()

    def get_timeseries(self, variable: str) -> pd.DataFrame:
        """Return the rows of `variable`, without the "name" column.

        If every value converts to float, "value" is returned as float. Otherwise "value" is left as read and a
        "value_int" column holding the category code of each value is added.
        """
        is_requested_variable = self.df["name"] == variable
        # Explicit copy: the columns written below must never reach the cached `self.df`.
        timeseries_df = self.df.loc[is_requested_variable].drop(columns=['name']).copy()

        try:
            timeseries_df["value"] = timeseries_df["value"].astype(float)
        except ValueError:
            # Non-numeric values (e.g. states): give each distinct value an integer code.
            timeseries_df["value_int"] = timeseries_df["value"].astype("category").cat.codes
        return timeseries_df


class IceblockInstantLog(InstrumentLog):
    """Reader of the iceblock "instant" log: a ','-separated file with (at least) the columns "datetime", "id",
    "name" and "status". Its only variable, "melting", tells which iceblock is melting over time."""

    def __init__(self, full_directory_name: str, filename: str, instrument_name: str):
        InstrumentLog.__init__(self, full_directory_name, filename, instrument_name)

    def __get_df__(self) -> pd.DataFrame:
        """Read the log file and mark its timestamps as UTC."""
        df = pd.read_csv(self.full_file_name, sep=",", parse_dates=["datetime"])
        # Same guard as in InstrumentPeriodicLog: a file without any data row has nothing to localize, and its
        # "datetime" column is not guaranteed to be parsed as datetimes (in which case `.dt` would raise).
        if not df.empty:
            df["datetime"] = df["datetime"].dt.tz_localize('UTC')
        return df

    def get_variables(self):
        """Return the list of available variables (only "melting")."""
        return ["melting"]

    def get_timeseries(self, variable: str) -> pd.DataFrame:
        """Return the time series of `variable`.

        Raises
        ------
        ValueError
            If `variable` is not "melting".
        """
        if variable == "melting":
            timeseries_df = self.__get_melting_timeseries__()
        else:
            raise ValueError("Variable name [" + variable + "] not yet managed.")

        return timeseries_df

    def __get_melting_timeseries__(self) -> pd.DataFrame:
        """Build the "melting" time series.

        Returns
        -------
        pd.DataFrame
            One row per melting start, with the columns "datetime", "value_int" (iceblock id) and "value" (iceblock
            name). When at least one block is "Done", a closing row with id 0 / name "None" is added.
        """
        # Get the mapping between iceblock id and iceblock name (assuming that the last name's modification is the
        # good one). Columns are selected with a list: tuple selection on a groupby was removed in pandas 2.0.
        last_row_per_id = self.df.groupby("id")[["id", "name"]].tail(1)
        mapping_dict = last_row_per_id.set_index("id")["name"].to_dict()
        # id 0 stands for "no iceblock is melting".
        mapping_dict[0] = "None"

        # Get the datetime of the beginning of each iceblock's melting
        status_df = self.df[["datetime", "id", "status"]]
        start_df = status_df[status_df["status"] == "Melting"].groupby("id")[["datetime", "id"]].head(1)

        # Get the end of the last iceblock's melting, and set that after that the current melting block is 0/None.
        end_df = status_df[status_df["status"] == "Done"].groupby("id").head(1)
        if end_df.empty:
            # No block is done yet: there is no end to report, keep the starts only.
            melting_df = start_df.reset_index(drop=True)
        else:
            end_row = pd.DataFrame({"datetime": [end_df.iloc[-1]["datetime"]], "id": [0]})
            # pd.concat replaces DataFrame.append, removed in pandas 2.0.
            melting_df = pd.concat([start_df, end_row], ignore_index=True)

        # Get the value (iceblocks name) and value_int (coded value, iceblock id in this case).
        melting_df = melting_df.rename(columns={"id": 'value_int'})
        melting_df["value"] = melting_df["value_int"].map(mapping_dict)

        return melting_df


class ManualEventLog(InstrumentLog):
    """Reader of the manual-event log: a sequence of XML <event> elements, each with a "datetime" and a
    "description". Its only variable is "event"."""

    def __init__(self, full_directory_name: str, filename: str, instrument_name: str):
        InstrumentLog.__init__(self, full_directory_name, filename, instrument_name)

    def __get_df__(self) -> pd.DataFrame:
        """Read the events into a DataFrame with a UTC "datetime" column and an "event" column."""
        # The manual-event log file is not a valid XML file: the root tag is missing. So open the content of the file,
        # and add the root tags
        with open(self.full_file_name) as f:
            xml_str = f.read()
        xml_str = "<root>" + xml_str + "</root>"

        # Convert the XML to dict, then convert the dict to pd.Dataframe
        root_content = xmltodict.parse(xml_str)["root"]
        # xmltodict gives None for an element without content, i.e. a log in which no event was written yet.
        events = [] if root_content is None else root_content["event"]
        if not isinstance(events, list):
            # Only 1 event -> one less level in dict tree
            events = [events]

        if events:
            df = pd.DataFrame(events)
        else:
            df = pd.DataFrame(columns=["datetime", "description"])

        # Rename "description" column
        df = df.rename(columns={"description": 'event'})

        # Format datetime column.
        df["datetime"] = pd.to_datetime(df["datetime"]).dt.tz_localize('UTC')
        return df

    def get_variables(self):
        """Return the list of available variables (only "event")."""
        return ["event"]

    def get_timeseries(self, variable: str) -> pd.DataFrame:
        """Return the "datetime" column and the `variable` column, the latter renamed "value"."""
        # rename() without inplace builds a new frame instead of modifying a selection of the cached `self.df`.
        return self.df[["datetime", variable]].rename(columns={variable: 'value'})


class InstrumentPeriodicLog(InstrumentLog):
    """Reader of a "periodic" log file: a tab-separated file with a "datetime" column and one column per
    variable."""

    def __init__(self, full_directory_name: str, filename: str, instrument_name: str):
        InstrumentLog.__init__(self, full_directory_name, filename, instrument_name)

    def __get_df__(self) -> pd.DataFrame:
        """Read the log file and mark its timestamps as UTC (a file without data rows is returned as read)."""
        df = pd.read_csv(self.full_file_name, sep="\t", parse_dates=["datetime"])
        if not df.empty:
            df["datetime"] = df["datetime"].dt.tz_localize('UTC')
        return df

    def get_variables(self):
        """Return the names of all the columns but "datetime", in file order."""
        return [colname for colname in self.df.columns if colname != "datetime"]

    def get_timeseries(self, variable: str) -> pd.DataFrame:
        """Return the "datetime" column and the `variable` column, the latter renamed "value"."""
        # rename() without inplace builds a new frame instead of modifying a selection of the cached `self.df`
        # (the in-place form is the typical trigger of pandas' SettingWithCopyWarning).
        return self.df[["datetime", variable]].rename(columns={variable: 'value'})