# power_plant.py
import torch
# Author: Jörg Martin
from EIVData.csv_dataset import CSVData
from torch.utils.data import random_split

def load_data(seed=0, splitting_part=0.8, normalize=True,
        path='~/SharedData/AI/datasets/combined_cycle_power_plant/'
             'Folds5x2_pp_single_sheet.csv'):
    """
    Loads the power plant dataset and splits it into train and test parts.

    :param seed: Seed for splitting and shuffling the data. The same seed is
    passed to CSVData as ``shuffle_seed`` and used to seed the generator of
    ``random_split``. Defaults to 0.
    :param splitting_part: Which fraction of the data to use as training
    data, must lie in [0, 1]. Defaults to 0.8.
    :param normalize: Whether to normalize the data, defaults to True.
    :param path: Path of the CSV file to load. Defaults to the original
    shared-data location of the combined cycle power plant CSV.
    :returns: power_trainset, power_testset (the two subsets returned by
    ``torch.utils.data.random_split``)
    :raises ValueError: If ``splitting_part`` is outside [0, 1].
    """
    # Without this check a fraction > 1 or < 0 produces a negative test
    # length. The lengths still sum to len(dataset), so random_split would
    # accept them and silently return a wrong split.
    if not 0.0 <= splitting_part <= 1.0:
        raise ValueError(
            f"splitting_part must be in [0, 1], got {splitting_part!r}")

    # class_name="PE" is handed to CSVData, presumably selecting the target
    # column of the CSV; verify against CSVData.
    # NOTE(review): normalization happens inside CSVData before the split.
    # If it uses statistics of the full dataset, the test part influences
    # them. Confirm whether this is intended.
    power_dataset = CSVData(path,
            class_name="PE",
            shuffle_seed=seed,
            normalize=normalize,
            delimiter=",")

    dataset_len = len(power_dataset)
    # int() truncates, so any rounding remainder goes to the test set.
    train_len = int(dataset_len * splitting_part)
    test_len = dataset_len - train_len

    # A dedicated, explicitly seeded generator makes the split reproducible
    # without touching torch's global RNG state.
    power_trainset, power_testset = random_split(power_dataset,
            lengths=[train_len, test_len],
            generator=torch.Generator().manual_seed(seed))

    return power_trainset, power_testset