-
Jörg Martin authored
kin8nm.py 1010 B
import torch
from EIVData.csv_dataset import CSVData
from torch.utils.data import random_split
def load_data(seed=0, splitting_part=0.8, normalize=True):
    """
    Loads the kin8nm dataset and splits it into a training and a test set.

    :param seed: Seed for splitting and shuffling the data. It is passed to
        ``CSVData`` as ``shuffle_seed`` and also seeds the generator used by
        ``random_split``. Defaults to 0.
    :param splitting_part: Which fraction of the data to use as training
        data; must lie in the closed interval [0, 1]. The training set gets
        ``int(len(dataset) * splitting_part)`` samples (rounded down) and the
        test set gets the rest. Defaults to 0.8.
    :param normalize: Whether to normalize the data; forwarded unchanged to
        ``CSVData``. Defaults to True.
    :returns: kin8nm_trainset, kin8nm_testset -- the two subsets returned by
        ``torch.utils.data.random_split``.
    :raises ValueError: If ``splitting_part`` is not within [0, 1].
    """
    # Validate before reading the CSV. A fraction above 1 would make
    # test_len negative, and random_split only checks that the integer
    # lengths sum to the dataset size, so it would silently return an empty
    # test set (and a negative fraction would yield overlapping subsets).
    if not 0 <= splitting_part <= 1:
        raise ValueError(
            f"splitting_part must be in [0, 1], got {splitting_part!r}"
        )
    # 'y' is passed as class_name -- presumably the target column of the
    # CSV; confirm against CSVData.
    kin8nm_dataset = CSVData(
        '~/SharedData/AI/datasets/kin8nm/dataset_2175_kin8nm.csv',
        class_name='y',
        shuffle_seed=seed,
        normalize=normalize,
    )
    dataset_len = len(kin8nm_dataset)
    train_len = int(dataset_len * splitting_part)
    test_len = dataset_len - train_len
    # A dedicated, seeded generator makes the split reproducible without
    # touching torch's global RNG state.
    kin8nm_trainset, kin8nm_testset = random_split(
        kin8nm_dataset,
        lengths=[train_len, test_len],
        generator=torch.Generator().manual_seed(seed),
    )
    return kin8nm_trainset, kin8nm_testset