# repeated_sampling.py
# Author: Jörg Martin
"""
Contains the class `repeated_sampling` that can be used to generate
datasets for repeated sampling from datasets with a ground truth.
"""
import sys
import torch
from torch.utils.data import TensorDataset
from EIVGeneral.manipulate_tensors import add_noise
class repeated_sampling():
    """
    Callable that produces freshly noised train/test datasets from a dataset
    with a known ground truth.

    On construction the input and output noise strengths are read from
    `dataclass`. Every call reloads the ground truth via
    `dataclass.load_data` (always with the seed `fixed_seed`) and passes it,
    together with the noise strengths and seeds derived from the call's
    `seed`, to `add_noise` from `EIVGeneral.manipulate_tensors`.

    :param dataclass: Object (e.g. a module) that provides the attributes
    `x_noise_strength` and `y_noise_strength` and a routine `load_data`
    accepting the keywords `seed`, `splitting_part`, `normalize` and
    `return_ground_truth`. `load_data` must return four items, the last two
    being the ground truth train and test datasets; their `tensors` are read
    as (true x, true y, noisy x, noisy y).
    :param fixed_seed: Integer. The seed handed to `load_data` when loading
    the ground truth, defaults to 0.
    """

    def __init__(self, dataclass, fixed_seed=0):
        self.dataclass = dataclass
        self.fixed_seed = fixed_seed
        # The noise strengths are captured once, at construction time.
        self.x_noise_strength = dataclass.x_noise_strength
        self.y_noise_strength = dataclass.y_noise_strength

    def __call__(self, seed=0, splitting_part=0.8, normalize=True,
                 return_ground_truth=False):
        """
        Build one noisy realisation of the train and test datasets.

        :param seed: Integer used to seed the generator from which the two
        seeds passed to `add_noise` are drawn, defaults to 0.
        :param splitting_part: Forwarded unchanged to `load_data`, defaults
        to 0.8.
        :param normalize: Forwarded unchanged to `add_noise`, defaults to
        True.
        :param return_ground_truth: If falsy (default) only the noisy
        datasets are returned, otherwise the ground truth datasets are
        returned as well.
        :returns: `trainset, testset`, or
        `trainset, testset, true_trainset, true_testset` if
        `return_ground_truth` is truthy. The noisy datasets hold
        (noisy x, noisy y); the ground truth datasets hold
        (true x, true y, noisy x, noisy y).
        """
        # The ground truth is always loaded with the fixed seed and without
        # normalization; only the last two returned items are used.
        loaded = self.dataclass.load_data(
            seed=self.fixed_seed,
            splitting_part=splitting_part,
            normalize=False,
            return_ground_truth=True,
        )
        _, _, gt_train, gt_test = loaded
        train_tensors = gt_train.tensors
        test_tensors = gt_test.tensors

        # Tensors 2 and 3 of both splits are stacked along dim 0 and handed
        # to `add_noise` as `normalization_list`.
        reference_x = torch.concat((train_tensors[2], test_tensors[2]), dim=0)
        reference_y = torch.concat((train_tensors[3], test_tensors[3]), dim=0)

        clean_train_x, clean_train_y = train_tensors[:2]
        clean_test_x, clean_test_y = test_tensors[:2]

        # Derive two integer seeds from `seed`. They are passed to
        # `add_noise` next to the (x, y) pair -- presumably one seed per
        # tensor; confirm against `add_noise`.
        rng = torch.Generator().manual_seed(seed)
        noise_seeds = torch.randint(
            0, sys.maxsize, (2,), generator=rng).tolist()

        def perturb(clean_x, clean_y):
            # NOTE(review): the identical seed list is used for the train and
            # the test split -- verify that `add_noise` does not produce
            # coinciding noise for both splits because of this.
            return add_noise(
                (clean_x, clean_y),
                (self.x_noise_strength, self.y_noise_strength),
                noise_seeds,
                normalize=normalize,
                normalization_list=[reference_x, reference_y],
            )

        (noisy_train_x, noisy_train_y), (clean_train_x, clean_train_y) = \
            perturb(clean_train_x, clean_train_y)
        (noisy_test_x, noisy_test_y), (clean_test_x, clean_test_y) = \
            perturb(clean_test_x, clean_test_y)

        trainset = TensorDataset(noisy_train_x, noisy_train_y)
        testset = TensorDataset(noisy_test_x, noisy_test_y)
        true_trainset = TensorDataset(clean_train_x, clean_train_y,
                                      noisy_train_x, noisy_train_y)
        true_testset = TensorDataset(clean_test_x, clean_test_y,
                                     noisy_test_x, noisy_test_y)

        if return_ground_truth:
            return trainset, testset, true_trainset, true_testset
        return trainset, testset