Skip to content
Snippets Groups Projects
quadratic.py 2.88 KiB
Newer Older
Jörg Martin's avatar
Jörg Martin committed
import torch
import sys
from torch.utils.data import TensorDataset, random_split

# Number of samples drawn by load_data before the train/test split.
total_number_of_datapoints = 2_000
# Lower and upper bound of the interval the unnoisy inputs are drawn from.
input_range = [-1, 1]
# Ground truth used by load_data: y = slope * x**2 + intercept.
slope = 1.0
intercept = 0.0
# Factors scaling the standard normal noise added to the inputs and targets.
x_noise_strength = 0.05
y_noise_strength = 0.1

def get_normalization(*args):
    """
    Computes normalization statistics for each of the given tensors.

    :param args: Tensors to compute the statistics for. Mean and standard
    deviation are taken along dimension 0, which is kept with size 1.
    :returns: A tuple holding one (mean, std) pair per tensor in *args, in
    the order the tensors were passed. Empty if no tensor is given.
    """
    return tuple(
        (
            torch.mean(tensor, dim=0, keepdim=True),
            torch.std(tensor, dim=0, keepdim=True),
        )
        for tensor in args
    )

def load_data(seed=0, splitting_part=0.8, normalize=True,
        return_ground_truth=False):
    """
    Draws noisy samples of a quadratic function and splits them into a
    training and a test set.

    The unnoisy inputs are drawn uniformly from input_range and the unnoisy
    targets are slope * x**2 + intercept. Standard normal noise scaled by
    x_noise_strength (inputs) and y_noise_strength (targets) is added to
    obtain the data that ends up in the returned datasets.
    :param seed: Seed for drawing and splitting the data.
    :param splitting_part: Which fraction of the data to use as training
    data. Defaults to 0.8.
    :param normalize: Whether to normalize the data, defaults to True. The
    mean and standard deviation are computed from the noisy data before
    splitting and are applied to the unnoisy ground truth as well.
    :param return_ground_truth: Boolean. If True, the unnoisy ground truth will
    also be returned. Defaults to False.
    :returns: trainset, testset if return_ground_truth is False,
    else trainset, testset, (true_x, true_y). The ground truth tensors
    cover all datapoints, they are not split.
    """
    def seeded_generator(seed_value):
        # Fresh generator so that each random draw has its own stream.
        return torch.Generator().manual_seed(seed_value)

    # Derive separate seeds for the inputs, both noise terms and the split
    # from the single user supplied seed.
    derived_seeds = torch.randint(0, sys.maxsize, (4,),
                                  generator=seeded_generator(seed))
    input_seed, x_noise_seed, y_noise_seed, split_seed = \
        derived_seeds.tolist()

    sample_shape = (total_number_of_datapoints, 1)
    lower, upper = input_range[0], input_range[1]

    # Unnoisy ground truth
    uniform_draw = torch.rand(sample_shape,
                              generator=seeded_generator(input_seed))
    true_x = lower + (upper - lower) * uniform_draw
    true_y = slope * true_x**2 + intercept

    # Noisy observations
    x_noise = torch.randn(sample_shape,
                          generator=seeded_generator(x_noise_seed))
    y_noise = torch.randn(sample_shape,
                          generator=seeded_generator(y_noise_seed))
    noisy_x = true_x + x_noise_strength * x_noise
    noisy_y = true_y + y_noise_strength * y_noise

    if normalize:
        # Statistics come from the noisy data; the ground truth is mapped
        # with the same statistics to stay comparable.
        (x_mean, x_std), (y_mean, y_std) = get_normalization(noisy_x, noisy_y)
        noisy_x = (noisy_x - x_mean) / x_std
        true_x = (true_x - x_mean) / x_std
        noisy_y = (noisy_y - y_mean) / y_std
        true_y = (true_y - y_mean) / y_std

    full_dataset = TensorDataset(noisy_x, noisy_y)
    total_len = len(full_dataset)
    train_len = int(total_len * splitting_part)
    test_len = total_len - train_len
    trainset, testset = random_split(
        full_dataset,
        lengths=[train_len, test_len],
        generator=seeded_generator(split_seed))

    if return_ground_truth:
        return trainset, testset, (true_x, true_y)
    return trainset, testset