Skip to content
Snippets Groups Projects
repeated_sampling.py 3.24 KiB
"""
Defines the class `repeated_sampling`, which generates datasets for repeated
sampling: a fixed ground truth is combined with freshly drawn input and
output noise for every repetition.
"""
import sys

import torch
from torch.utils.data import TensorDataset

from EIVGeneral.manipulate_tensors import add_noise

class repeated_sampling():
    """
    A class for repeated sampling from datasets with a known ground truth and
    known input and output noise. The ground truth is always loaded with
    `fixed_seed`; the noise added on top of it depends on the `seed` passed
    to each call.
    :param dataclass: A module (or object) that provides a routine
    `load_data` and two positive float attributes `x_noise_strength` and
    `y_noise_strength`, which are used as the standard deviations of the
    input and output noise. `load_data` must accept the keywords `seed`,
    `splitting_part`, `normalize` and `return_ground_truth` and, for
    `return_ground_truth=True`, return the noisy and true train and test
    datasets.
    :param fixed_seed: Integer. The seed to load the unnoisy ground truth,
    defaults to 0.
    """
    def __init__(self, dataclass, fixed_seed=0):
        self.dataclass = dataclass
        self.fixed_seed = fixed_seed
        # Noise levels are read once from the data module.
        self.x_noise_strength = dataclass.x_noise_strength
        self.y_noise_strength = dataclass.y_noise_strength

    @staticmethod
    def _draw_noise_seeds(seed):
        """
        Derive separate noise seeds for the train and the test set from
        `seed`.
        :param seed: Integer seed of the current repetition.
        :returns: Tuple `(train_seeds, test_seeds)` of two lists, each
        holding one seed for the x noise and one for the y noise.
        """
        random_generator = torch.Generator().manual_seed(seed)
        # The train seeds are drawn exactly as in earlier versions, so the
        # noisy training data for a given `seed` is unchanged. The test set
        # gets its own seeds: reusing the train seeds would let `add_noise`
        # (presumably seeding one generator per tensor) reproduce the same
        # noise draws for both splits.
        train_seeds = [int(t) for t in torch.randint(
            0, sys.maxsize, (2,), generator=random_generator)]
        test_seeds = [int(t) for t in torch.randint(
            0, sys.maxsize, (2,), generator=random_generator)]
        return train_seeds, test_seeds

    def __call__(self, seed=0, splitting_part=0.8, normalize=True,
            return_ground_truth=False):
        """
        Create one noisy realization of the train and test data.
        The ground truth and the train/test split are loaded with
        `fixed_seed`; only the noise added to them depends on `seed`.
        :param seed: Integer. Seed for the noise of this repetition,
        defaults to 0.
        :param splitting_part: Float, passed on to `load_data`,
        defaults to 0.8.
        :param normalize: Boolean, passed on to `add_noise`. `add_noise`
        receives, as `normalization_list`, the noisy x and y that
        `load_data` returns for `fixed_seed` (train and test combined), so
        the normalization reference is the same for every `seed`.
        Defaults to True.
        :param return_ground_truth: Boolean. If True, additionally return
        datasets holding the true and the noisy tensors. Defaults to False.
        :returns: `trainset, testset` (TensorDatasets of noisy x and y) and,
        if `return_ground_truth` is True, also `true_trainset, true_testset`
        (TensorDatasets of true x, true y, noisy x, noisy y).
        """
        _, _, true_trainset, true_testset = self.dataclass.load_data(
                seed=self.fixed_seed, splitting_part=splitting_part,
                normalize=False, return_ground_truth=True)
        # Tensors 2 and 3 of the ground-truth datasets are taken as the noisy
        # x and y (same layout as the datasets this method returns); they
        # serve as the seed-independent normalization reference.
        full_noisy_x = torch.concat((true_trainset.tensors[2],
                true_testset.tensors[2]), dim=0)
        full_noisy_y = torch.concat((true_trainset.tensors[3],
                true_testset.tensors[3]), dim=0)
        true_train_x, true_train_y = true_trainset.tensors[:2]
        true_test_x, true_test_y = true_testset.tensors[:2]
        train_seeds, test_seeds = self._draw_noise_seeds(seed)
        noise_strengths = (self.x_noise_strength, self.y_noise_strength)
        (noisy_train_x, noisy_train_y), (true_train_x, true_train_y) =\
                add_noise((true_train_x, true_train_y), noise_strengths,
                        train_seeds, normalize=normalize,
                        normalization_list=[full_noisy_x, full_noisy_y])
        (noisy_test_x, noisy_test_y), (true_test_x, true_test_y) =\
                add_noise((true_test_x, true_test_y), noise_strengths,
                        test_seeds, normalize=normalize,
                        normalization_list=[full_noisy_x, full_noisy_y])
        trainset = TensorDataset(noisy_train_x, noisy_train_y)
        testset = TensorDataset(noisy_test_x, noisy_test_y)
        if not return_ground_truth:
            return trainset, testset
        true_trainset = TensorDataset(true_train_x, true_train_y,
                noisy_train_x, noisy_train_y)
        true_testset = TensorDataset(true_test_x, true_test_y,
                noisy_test_x, noisy_test_y)
        return trainset, testset, true_trainset, true_testset