import pandas as pd
import numpy as np
import torch
import torch.nn
from torch.utils.data import Dataset, DataLoader

class CSVData(Dataset):
    """
    A dataset to load csv ("comma separated values") data.
    To change for a different delimiter than `,` use the argument `delimiter`.

    Each item is one CSV row converted to a ``torch.float32`` tensor. If
    `y_columns` is given, an item is instead a tuple ``(x, y)``: ``y`` holds
    the listed columns (in the order given) and ``x`` holds every remaining
    column in ascending column order.

    :param data_path: Path to CSV File
    :param load_into_memory: Boolean. If True (default), whole dataset
        will be loaded into memory. If False, only the row count is computed
        up front and the file is parsed again for every requested item.
    :param y_columns: If not None should be a numpy array (or similar) and
        list the columns which will be returned as second item of __getitem__.
        Negative column indices count from the last column.
    :param delimiter: Will be used for reading CSV, defaults to ','
    :param header: Which row to take as header, defaults to None (=no header)
    :param chunksize: This argument is only used once and only if
        `load_into_memory` is false to scan through the dataset.
        Defaults to 10,000.
    """
    def __init__(self, data_path, load_into_memory=True,
            y_columns=None, delimiter=",", header=None, chunksize=10000):
        self.data_path = data_path
        self.load_into_memory = load_into_memory
        self.y_indices = y_columns
        self.delimiter = delimiter
        self.header = header
        if self.load_into_memory:
            # Load the whole file once as a 2-D array; its row count is the length.
            self.array_dataset = np.array(pd.read_csv(self.data_path,
                                                      delimiter=delimiter,
                                                      header=self.header))
            self.len = self.array_dataset.shape[0]
        else:
            # Only count rows, streaming in chunks so the file never has to fit
            # in memory. Summing chunk lengths is also correct when the reader
            # yields no chunks at all.
            with pd.read_csv(self.data_path,
                             delimiter=delimiter,
                             header=self.header,
                             chunksize=chunksize) as reader:
                self.len = sum(len(chunk) for chunk in reader)

    def __getitem__(self, i):
        """
        Return row `i` as a float32 tensor, or an ``(x, y)`` tuple of float32
        tensors when `y_columns` was given.

        Negative indices count from the end, as for Python sequences.

        :raises IndexError: if `i` is outside ``[-len(self), len(self))``.
            Using IndexError (not an assert) is what lets plain iteration,
            e.g. ``for row in dataset``, terminate normally.
        """
        if not -self.len <= i < self.len:
            raise IndexError(
                f"index {i} is out of range for dataset of length {self.len}")
        if i < 0:
            i = i + self.len
        if self.load_into_memory:
            ith_row = self.array_dataset[i, :]
        else:
            # Skip the first i lines, then parse a single row. With an integer
            # `header` h the header lookup shifts to file line i + h, so the
            # first parsed data row is still data row i (assumes one record
            # per line, no blank/comment lines).
            ith_row = np.array(pd.read_csv(self.data_path,
                                           delimiter=self.delimiter,
                                           header=self.header,
                                           skiprows=i,
                                           nrows=1))[0, :]
        if self.y_indices is None:
            return torch.tensor(ith_row, dtype=torch.float32)

        y_index_array = np.array(self.y_indices)
        all_positions = np.arange(len(ith_row))
        # Map y indices (possibly negative) to absolute column positions so
        # they are really excluded from x; out-of-range indices raise here.
        y_positions = all_positions[y_index_array]
        x_index_array = np.setdiff1d(all_positions, y_positions)
        x = torch.tensor(ith_row[x_index_array], dtype=torch.float32)
        y = torch.tensor(ith_row[y_index_array], dtype=torch.float32)
        return x, y

    def __len__(self):
        """Return the number of data rows in the CSV file."""
        return self.len