Skip to content
Snippets Groups Projects
Commit 947c258d authored by Jörg Martin's avatar Jörg Martin
Browse files

Corrected normalization in repeated_sampling

parent 6e80eaff
No related branches found
No related tags found
No related merge requests found
......@@ -34,7 +34,12 @@ class repeated_sampling():
_, _, true_trainset, true_testset\
= self.dataclass.load_data(
seed=self.fixed_seed, splitting_part=splitting_part,
normalize=False,
return_ground_truth=True)
full_noisy_x = torch.concat((true_trainset.tensors[2],
true_testset.tensors[2]), dim=0)
full_noisy_y = torch.concat((true_trainset.tensors[3],
true_testset.tensors[3]), dim=0)
true_train_x, true_train_y = true_trainset.tensors[:2]
true_test_x, true_test_y = true_testset.tensors[:2]
random_generator = torch.Generator().manual_seed(seed)
......@@ -45,13 +50,12 @@ class repeated_sampling():
add_noise((true_train_x, true_train_y),
(self.x_noise_strength, self.y_noise_strength), seeds,
normalize=normalize,
normalization_list=true_trainset.tensors[2:])
normalization_list=[full_noisy_x, full_noisy_y])
(noisy_test_x, noisy_test_y), (true_test_x, true_test_y) =\
add_noise((true_test_x, true_test_y),
(self.x_noise_strength, self.y_noise_strength), seeds,
normalize=normalize,
# normalize both datasets with train set
normalization_list=true_trainset.tensors[2:])
normalization_list=[full_noisy_x, full_noisy_y])
trainset = TensorDataset(noisy_train_x, noisy_train_y)
testset = TensorDataset(noisy_test_x, noisy_test_y)
true_trainset = TensorDataset(true_train_x, true_train_y,
......@@ -62,5 +66,3 @@ class repeated_sampling():
return trainset, testset
else:
return trainset, testset, true_trainset, true_testset
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment