diff --git a/EIVPackage/EIVData/repeated_sampling.py b/EIVPackage/EIVData/repeated_sampling.py index 360cd508b065b4ddb6c475d7c4e8142369ad4884..771577de48c49bef96b06e7d05d72e52ceaefb57 100644 --- a/EIVPackage/EIVData/repeated_sampling.py +++ b/EIVPackage/EIVData/repeated_sampling.py @@ -1,3 +1,10 @@ +import sys + +import torch +from torch.utils.data import TensorDataset + +from EIVGeneral.manipulate_tensors import add_noise + class repeated_sampling(): def __init__(self, dataclass, fixed_seed=0): self.dataclass = dataclass @@ -7,12 +14,36 @@ class repeated_sampling(): def __call__(self,seed=0, splitting_part=0.8, normalize=True, return_ground_truth=False): - _, _, _, true_testset = self.dataclass.load_data( + _, _, true_trainset, true_testset\ + = self.dataclass.load_data( seed=self.fixed_seed, splitting_part=splitting_part, normalize=normalize, - return_ground_truth=return_ground_truth) - true_x, true_y = true_testset.tensors[:2] - - - + return_ground_truth=True) + true_train_x, true_train_y = true_trainset.tensors[:2] + true_test_x, true_test_y = true_testset.tensors[:2] + random_generator = torch.Generator().manual_seed(seed) + # draw different seeds for noise and splitting + seeds = [int(t) for t in torch.randint(0,sys.maxsize,(2,),\ + generator=random_generator)] + (noisy_train_x, noisy_train_y), (true_train_x, true_train_y) =\ + add_noise((true_train_x, true_train_y), + (self.x_noise_strength, self.y_noise_strength), seeds, + normalize=True, + normalization_list=true_trainset.tensors[2:]) + (noisy_test_x, noisy_test_y), (true_test_x, true_test_y) =\ + add_noise((true_test_x, true_test_y), + (self.x_noise_strength, self.y_noise_strength), seeds, + normalize=True, + normalization_list=true_testset.tensors[2:]) + trainset = TensorDataset(noisy_train_x, noisy_train_y) + testset = TensorDataset(noisy_test_x, noisy_test_y) + true_trainset = TensorDataset(true_train_x, true_train_y, + noisy_train_x, noisy_train_y) + true_testset = TensorDataset(true_test_x, true_test_y, + noisy_test_x, noisy_test_y) + if not return_ground_truth: + return trainset, testset + else: + return trainset, testset, true_trainset, true_testset + diff --git a/EIVPackage/EIVGeneral/manipulate_tensors.py b/EIVPackage/EIVGeneral/manipulate_tensors.py index 8fa4fd8fee483f8ca250f55e067ff3a2a397e17a..de166822e45905f24ffe8c46082cc81965040f4f 100644 --- a/EIVPackage/EIVGeneral/manipulate_tensors.py +++ b/EIVPackage/EIVGeneral/manipulate_tensors.py @@ -19,7 +19,8 @@ def normalize_tensor(t, mean_std): return (t-mean_std[0])/mean_std[1] -def add_noise(tensor_list, noise_strength_list, seed_list, normalize=True): +def add_noise(tensor_list, noise_strength_list, seed_list, normalize=True, + normalization_list = None): """ Takes the tensors in `tensor_list`, adds random noise using the standard deviations in `noise_strength_list` and the seeds in `seed_list`, then, if @@ -31,15 +32,26 @@ def add_noise(tensor_list, noise_strength_list, seed_list, normalize=True): :param noise_strength_list: A list of positive floats :param seed_list: A list of integers. :param normalize: A Boolean, defaults to True. + :param normalization_list: Either None (default) or a list of tensors. + If the latter, these tensors will be used for normalization and `normalize` + is assumed to be True. :returns: noisy_tensor_list, unnoisy_tensor_list, both normalized """ noisy_t_list = [] unnoisy_t_list = [] - for t,noise,seed in zip(tensor_list, noise_strength_list, seed_list): + if normalization_list is not None: + assert normalize + assert len(normalization_list) == len(tensor_list) + for i, (t,noise,seed) in enumerate(zip(tensor_list, noise_strength_list,\ + seed_list)): noisy_t = t + noise * torch.randn(t.shape, generator=torch.Generator().manual_seed(seed)) if normalize: - noisy_t_normalization = get_normalization(noisy_t) + if normalization_list is not None: + noisy_t_normalization =\ + get_normalization(normalization_list[i]) + else: + noisy_t_normalization = get_normalization(noisy_t) noisy_t = normalize_tensor(noisy_t, noisy_t_normalization) t = normalize_tensor(t, noisy_t_normalization) noisy_t_list.append(noisy_t)