import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import torch.distributions


class CSVData(Dataset):
    """
    Wraps a csv file, read into a pandas.DataFrame, as a
    torch.utils.data.Dataset.

    Each item is a tuple ``(features, labels)`` of float32 tensors, where
    ``labels`` holds the value(s) of the column(s) named by `class_name` and
    ``features`` holds all remaining columns.

    :filename: Path (or file-like object) of the csv file; passed to
    `pandas.read_csv`.
    :class_name: String or List. Name(s) of column(s) that will be
    interpreted as class(es).
    :shuffle_seed: If not None, this will be used as a seed for
    shuffling the data when reading in. Caution: Take the same seed for
    training and testing.
    :condition: If not None, this will be used to cut the dataset.
    Should take in a dataframe and return a list of True/False of the
    same length as the dataset; only rows marked True are kept.
    :header: Row number to use as column names; passed to `pandas.read_csv`.
    :delimiter: Field delimiter; passed to `pandas.read_csv`.
    :normalize: If True, every returned sample is standardized with the
    per-column mean and standard deviation of the loaded (filtered,
    NaN-free) data.
    """
    def __init__(self, filename, class_name,
                 shuffle_seed=None, condition=None, header=0, delimiter=',',
                 normalize=True):
        full_df = pd.read_csv(filename, header=header, delimiter=delimiter)
        if condition is not None:
            # Keep, by position, the rows for which `condition` is truthy.
            keep = np.nonzero(np.asarray(condition(full_df)))[0]
            full_df = full_df.iloc[keep]
        if shuffle_seed is not None:
            full_df = full_df.sample(frac=1, random_state=shuffle_seed)
        # Drop incomplete rows so that statistics and tensors contain no NaNs.
        if full_df.isnull().to_numpy().any():
            old_size = len(full_df)
            full_df = full_df.dropna()
            print(f'Removed {old_size-len(full_df)} rows due to missing data')
        self.full_df = full_df
        self.data_df = self.full_df.drop(columns=class_name)
        self.labels_df = self.full_df[class_name]
        self.full_columns = self.full_df.columns
        self.data_columns = self.data_df.columns
        self.normalize = normalize
        if self.normalize:
            self.save_mean_and_std()

    def __len__(self):
        return len(self.full_df)

    def save_mean_and_std(self):
        """
        Saves the column-wise mean and (population, ddof=0) standard
        deviation of `self.data_df` and `self.labels_df` as float32 tensors in
        `self.mean_features`, `self.std_features`, `self.mean_labels`,
        `self.std_labels`.
        """
        features_array = np.asarray(self.data_df)
        labels_array = np.asarray(self.labels_df)
        self.mean_features = torch.tensor(np.mean(features_array, axis=0), dtype=torch.float32)
        self.std_features = torch.tensor(np.std(features_array, axis=0), dtype=torch.float32)
        self.mean_labels = torch.tensor(np.mean(labels_array, axis=0), dtype=torch.float32)
        self.std_labels = torch.tensor(np.std(labels_array, axis=0), dtype=torch.float32)

    def normalize_sample(self, sample):
        """
        Applies `normalize_array` to the tuple `sample=(features, labels)`
        using `self.mean_features`, `self.std_features`, `self.mean_labels`,
        `self.std_labels`, and returns the normalized tuple.
        """
        features, labels = sample
        normalized_features = self.normalize_array(
            features, self.mean_features, self.std_features)
        normalized_labels = self.normalize_array(
            labels, self.mean_labels, self.std_labels)
        return normalized_features, normalized_labels

    @staticmethod
    def normalize_array(array, mean, std):
        """
        Normalizes a one or two-dimensional array by `mean` and `std`,
        where `mean` and `std` are supposed to be the mean and standard
        deviation of `array` w.r.t. its first dimension, i.e. one entry per
        column when `array` is two-dimensional.
        **Note**: Whenever std is 0, the corresponding array component is set
        to 0.
        :param array: A `torch.tensor` of either one or two dimensions
        :param mean: A `torch.tensor` of either zero or one dimension
        :param std: A `torch.tensor` of either zero or one dimension
        :returns: A `torch.tensor`, the normalized array
        """
        assert len(array.shape) <= 2
        assert len(mean.shape) <= 1 and len(std.shape) <= 1
        if torch.numel(std) == 1:
            if std.item() == 0:
                return array * 0.0
            return (array - mean) / std.item()
        normalized_array = array * 0.0
        idx = std > 0
        # Mask the *last* axis: for a 2-D batch of shape (rows, features) the
        # per-feature mask must select columns, not rows. For a 1-D array this
        # is identical to plain `[idx]`.
        normalized_array[..., idx] = (array[..., idx] - mean[idx]) / std[idx]
        return normalized_array

    def __getitem__(self, i):
        """
        Returns the i-th sample as a tuple ``(features, labels)`` of float32
        tensors, normalized when `self.normalize` is True.

        :raises IndexError: if `i` is not in ``[0, len(self))``. Raising
        IndexError (instead of failing an assert) keeps the check active under
        ``python -O`` and lets plain iteration such as
        ``for features, labels in dataset`` terminate correctly.
        """
        if not 0 <= i < len(self):
            raise IndexError(
                f'index {i} is out of range for dataset of size {len(self)}')
        features = torch.tensor(np.asarray(self.data_df.iloc[i]),
                                dtype=torch.float32)
        labels = torch.tensor(np.asarray(self.labels_df.iloc[i]),
                              dtype=torch.float32)
        sample = (features, labels)
        if self.normalize:
            return self.normalize_sample(sample)
        return sample