From f9f0119d8c480608c90350807a185bfddcb8c3f6 Mon Sep 17 00:00:00 2001 From: Joerg Martin <joerg.martin@ptb.de> Date: Fri, 28 Jan 2022 09:32:04 +0100 Subject: [PATCH] repeated sampling updated --- EIVPackage/EIVData/repeated_sampling.py | 53 +++++++++++++++++++++---- 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/EIVPackage/EIVData/repeated_sampling.py b/EIVPackage/EIVData/repeated_sampling.py index bff0c9f..f6339df 100644 --- a/EIVPackage/EIVData/repeated_sampling.py +++ b/EIVPackage/EIVData/repeated_sampling.py @@ -7,7 +7,8 @@ import sys import torch from torch.utils.data import TensorDataset -from EIVGeneral.manipulate_tensors import add_noise +from EIVGeneral.manipulate_tensors import add_noise, normalize_tensor,\ + unnormalize_tensor class repeated_sampling(): """ @@ -25,17 +26,35 @@ class repeated_sampling(): """ def __init__(self, dataclass, fixed_seed=0): self.dataclass = dataclass + self.func = dataclass.func self.fixed_seed = fixed_seed self.x_noise_strength = dataclass.x_noise_strength self.y_noise_strength = dataclass.y_noise_strength def __call__(self,seed=0, splitting_part=0.8, normalize=True, - return_ground_truth=False): + return_ground_truth=False, + return_normalized_func=False): _, _, true_trainset, true_testset\ = self.dataclass.load_data( seed=self.fixed_seed, splitting_part=splitting_part, normalize=False, return_ground_truth=True) + """ + Loads repeated sampling data + :param seed: Seed for the used noise + :param splitting_part: Which fraction of the data to use as training + data. Defaults to 0.8. + :param normalize: Whether to normalize the data, defaults to True. + :param return_ground_truth: Boolean. If True, the unnoisy ground truth will + also be returned. Defaults to False. + :param return_normalized_func: Boolean (default False). If True, the + normalized version of the used function is returned as a last element. + :returns: trainset, testset, (, normalized_func) if + return_ground_truth is False, + else trainset, testset, true_trainset, + true_testset, (, normalized_func). The "true" datasets each return + **four tensors**: The true x,y and their noisy counterparts. + """ full_noisy_x = torch.concat((true_trainset.tensors[2], true_testset.tensors[2]), dim=0) full_noisy_y = torch.concat((true_trainset.tensors[3], @@ -46,23 +65,41 @@ class repeated_sampling(): # draw different seeds for noise and splitting seeds = [int(t) for t in torch.randint(0,sys.maxsize,(2,),\ generator=random_generator)] - (noisy_train_x, noisy_train_y), (true_train_x, true_train_y) =\ - add_noise((true_train_x, true_train_y), + # use same normalization for train and test + (noisy_train_x, noisy_train_y), (true_train_x, true_train_y),\ + normalization_list= add_noise((true_train_x, true_train_y), (self.x_noise_strength, self.y_noise_strength), seeds, normalize=normalize, - normalization_list=[full_noisy_x, full_noisy_y]) + normalization_list=[full_noisy_x, full_noisy_y], + return_normalization=True) (noisy_test_x, noisy_test_y), (true_test_x, true_test_y) =\ add_noise((true_test_x, true_test_y), (self.x_noise_strength, self.y_noise_strength), seeds, normalize=normalize, - normalization_list=[full_noisy_x, full_noisy_y]) + normalization_list=[full_noisy_x, full_noisy_y], + return_normalization=False) # same normalization + def normalized_func(x): + unnormalized_x = unnormalize_tensor(x, normalization_list[0]) + y = self.func(unnormalized_x) + normalized_y = normalize_tensor(y, normalization_list[1]) + return normalized_y trainset = TensorDataset(noisy_train_x, noisy_train_y) testset = TensorDataset(noisy_test_x, noisy_test_y) true_trainset = TensorDataset(true_train_x, true_train_y, noisy_train_x, noisy_train_y) true_testset = TensorDataset(true_test_x, true_test_y, noisy_test_x, noisy_test_y) + # return different objects, depending on Booleans if not return_ground_truth: - return trainset, testset + if not return_normalized_func: + return trainset, testset + else: + return trainset, testset, normalized_func else: - return trainset, testset, true_trainset, true_testset + if not return_normalized_func: + return trainset, testset, true_trainset,\ + true_testset + else: + return trainset, testset, true_trainset,\ + true_testset, normalized_func + -- GitLab