from EIVData.csv_dataset import CSVData
from torch.utils.data import random_split

def load_data(seed=0, splitting_part=0.8, normalize=True,
        csv_file='~/SharedData/AI/datasets/naval_propulsion/navalplantmaintenance.csv'):
    """
    Loads the naval propulsion dataset and splits it into train and test sets.
    :param seed: Seed for shuffling the data (passed to ``CSVData`` as
    ``shuffle_seed``) and for the random train/test split. Defaults to 0.
    :param splitting_part: Which fraction of the data to use as training
    data; must lie in [0, 1]. Defaults to 0.8.
    :param normalize: Whether to normalize the data (passed to ``CSVData``),
    defaults to True.
    :param csv_file: Path of the data file handed to ``CSVData``. Defaults to
    the shared naval propulsion file.
    :returns: naval_trainset, naval_testset (the two subsets produced by
    ``torch.utils.data.random_split``)
    :raises ValueError: If ``splitting_part`` is not within [0, 1].
    """
    # Reject bad fractions up front: otherwise test_len becomes negative
    # (or the lengths are nonsensical for NaN) and random_split can silently
    # return subsets of the wrong size.
    if not 0.0 <= splitting_part <= 1.0:
        raise ValueError(
            f"splitting_part must be in [0, 1], got {splitting_part}")

    import torch  # local import: only needed to seed the split generator

    # Columns 16 and 17 are passed as ``class_name`` — presumably the two
    # target columns of this file; confirm against CSVData.
    naval_dataset = CSVData(csv_file,
            class_name=[16,17],
            shuffle_seed=seed,
            normalize=normalize,
            header=None,
            delimiter=r"\s+")
    dataset_len = len(naval_dataset)
    train_len = int(dataset_len*splitting_part)
    test_len = dataset_len - train_len
    # Use a dedicated, seeded generator so the split is reproducible for a
    # given ``seed`` and neither depends on nor advances torch's global RNG.
    split_generator = torch.Generator().manual_seed(seed)
    naval_trainset, naval_testset = random_split(
        naval_dataset,
        lengths=[train_len, test_len],
        generator=split_generator)
    return naval_trainset, naval_testset