Skip to content
Snippets Groups Projects
Commit 947c258d authored by Jörg Martin's avatar Jörg Martin
Browse files

Corrected normalization in repeated_sampling

parent 6e80eaff
No related branches found
No related tags found
No related merge requests found
...@@ -34,7 +34,12 @@ class repeated_sampling(): ...@@ -34,7 +34,12 @@ class repeated_sampling():
_, _, true_trainset, true_testset\ _, _, true_trainset, true_testset\
= self.dataclass.load_data( = self.dataclass.load_data(
seed=self.fixed_seed, splitting_part=splitting_part, seed=self.fixed_seed, splitting_part=splitting_part,
normalize=False,
return_ground_truth=True) return_ground_truth=True)
full_noisy_x = torch.concat((true_trainset.tensors[2],
true_testset.tensors[2]), dim=0)
full_noisy_y = torch.concat((true_trainset.tensors[3],
true_testset.tensors[3]), dim=0)
true_train_x, true_train_y = true_trainset.tensors[:2] true_train_x, true_train_y = true_trainset.tensors[:2]
true_test_x, true_test_y = true_testset.tensors[:2] true_test_x, true_test_y = true_testset.tensors[:2]
random_generator = torch.Generator().manual_seed(seed) random_generator = torch.Generator().manual_seed(seed)
...@@ -45,13 +50,12 @@ class repeated_sampling(): ...@@ -45,13 +50,12 @@ class repeated_sampling():
add_noise((true_train_x, true_train_y), add_noise((true_train_x, true_train_y),
(self.x_noise_strength, self.y_noise_strength), seeds, (self.x_noise_strength, self.y_noise_strength), seeds,
normalize=normalize, normalize=normalize,
normalization_list=true_trainset.tensors[2:]) normalization_list=[full_noisy_x, full_noisy_y])
(noisy_test_x, noisy_test_y), (true_test_x, true_test_y) =\ (noisy_test_x, noisy_test_y), (true_test_x, true_test_y) =\
add_noise((true_test_x, true_test_y), add_noise((true_test_x, true_test_y),
(self.x_noise_strength, self.y_noise_strength), seeds, (self.x_noise_strength, self.y_noise_strength), seeds,
normalize=normalize, normalize=normalize,
# normalize both datasets with train set normalization_list=[full_noisy_x, full_noisy_y])
normalization_list=true_trainset.tensors[2:])
trainset = TensorDataset(noisy_train_x, noisy_train_y) trainset = TensorDataset(noisy_train_x, noisy_train_y)
testset = TensorDataset(noisy_test_x, noisy_test_y) testset = TensorDataset(noisy_test_x, noisy_test_y)
true_trainset = TensorDataset(true_train_x, true_train_y, true_trainset = TensorDataset(true_train_x, true_train_y,
...@@ -62,5 +66,3 @@ class repeated_sampling(): ...@@ -62,5 +66,3 @@ class repeated_sampling():
return trainset, testset return trainset, testset
else: else:
return trainset, testset, true_trainset, true_testset return trainset, testset, true_trainset, true_testset
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment