Skip to content
Snippets Groups Projects
Commit f9f0119d authored by Jörg Martin
Browse files

repeated sampling updated

parent b4f1f576
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,8 @@ import sys ...@@ -7,7 +7,8 @@ import sys
import torch import torch
from torch.utils.data import TensorDataset from torch.utils.data import TensorDataset
from EIVGeneral.manipulate_tensors import add_noise from EIVGeneral.manipulate_tensors import add_noise, normalize_tensor,\
unnormalize_tensor
class repeated_sampling(): class repeated_sampling():
""" """
...@@ -25,17 +26,35 @@ class repeated_sampling(): ...@@ -25,17 +26,35 @@ class repeated_sampling():
""" """
def __init__(self, dataclass, fixed_seed=0): def __init__(self, dataclass, fixed_seed=0):
self.dataclass = dataclass self.dataclass = dataclass
self.func = dataclass.func
self.fixed_seed = fixed_seed self.fixed_seed = fixed_seed
self.x_noise_strength = dataclass.x_noise_strength self.x_noise_strength = dataclass.x_noise_strength
self.y_noise_strength = dataclass.y_noise_strength self.y_noise_strength = dataclass.y_noise_strength
def __call__(self,seed=0, splitting_part=0.8, normalize=True, def __call__(self,seed=0, splitting_part=0.8, normalize=True,
return_ground_truth=False): return_ground_truth=False,
return_normalized_func=False):
_, _, true_trainset, true_testset\ _, _, true_trainset, true_testset\
= self.dataclass.load_data( = self.dataclass.load_data(
seed=self.fixed_seed, splitting_part=splitting_part, seed=self.fixed_seed, splitting_part=splitting_part,
normalize=False, normalize=False,
return_ground_truth=True) return_ground_truth=True)
"""
Loads repeated sampling data
:param seed: Seed for the used noise
:param splitting_part: Which fraction of the data to use as training
data. Defaults to 0.8.
:param normalize: Whether to normalize the data, defaults to True.
:param return_ground_truth: Boolean. If True, the unnoisy ground truth will
also be returned. Defaults to False.
:param return_normalized_func: Boolean (default False). If True, the
normalized version of the used function is returned as a last element.
:returns: trainset, testset (, normalized_func) if
return_ground_truth is False,
else trainset, testset, true_trainset,
true_testset (, normalized_func). The parenthesized normalized_func is
only included if return_normalized_func is True. The "true" datasets each
contain **four tensors**: the true x, y and their noisy counterparts.
"""
full_noisy_x = torch.concat((true_trainset.tensors[2], full_noisy_x = torch.concat((true_trainset.tensors[2],
true_testset.tensors[2]), dim=0) true_testset.tensors[2]), dim=0)
full_noisy_y = torch.concat((true_trainset.tensors[3], full_noisy_y = torch.concat((true_trainset.tensors[3],
...@@ -46,23 +65,41 @@ class repeated_sampling(): ...@@ -46,23 +65,41 @@ class repeated_sampling():
# draw different seeds for noise and splitting # draw different seeds for noise and splitting
seeds = [int(t) for t in torch.randint(0,sys.maxsize,(2,),\ seeds = [int(t) for t in torch.randint(0,sys.maxsize,(2,),\
generator=random_generator)] generator=random_generator)]
(noisy_train_x, noisy_train_y), (true_train_x, true_train_y) =\ # use same normalization for train and test
add_noise((true_train_x, true_train_y), (noisy_train_x, noisy_train_y), (true_train_x, true_train_y),\
normalization_list= add_noise((true_train_x, true_train_y),
(self.x_noise_strength, self.y_noise_strength), seeds, (self.x_noise_strength, self.y_noise_strength), seeds,
normalize=normalize, normalize=normalize,
normalization_list=[full_noisy_x, full_noisy_y]) normalization_list=[full_noisy_x, full_noisy_y],
return_normalization=True)
(noisy_test_x, noisy_test_y), (true_test_x, true_test_y) =\ (noisy_test_x, noisy_test_y), (true_test_x, true_test_y) =\
add_noise((true_test_x, true_test_y), add_noise((true_test_x, true_test_y),
(self.x_noise_strength, self.y_noise_strength), seeds, (self.x_noise_strength, self.y_noise_strength), seeds,
normalize=normalize, normalize=normalize,
normalization_list=[full_noisy_x, full_noisy_y]) normalization_list=[full_noisy_x, full_noisy_y],
return_normalization=False) # same normalization
def normalized_func(x):
unnormalized_x = unnormalize_tensor(x, normalization_list[0])
y = self.func(unnormalized_x)
normalized_y = normalize_tensor(y, normalization_list[1])
return normalized_y
trainset = TensorDataset(noisy_train_x, noisy_train_y) trainset = TensorDataset(noisy_train_x, noisy_train_y)
testset = TensorDataset(noisy_test_x, noisy_test_y) testset = TensorDataset(noisy_test_x, noisy_test_y)
true_trainset = TensorDataset(true_train_x, true_train_y, true_trainset = TensorDataset(true_train_x, true_train_y,
noisy_train_x, noisy_train_y) noisy_train_x, noisy_train_y)
true_testset = TensorDataset(true_test_x, true_test_y, true_testset = TensorDataset(true_test_x, true_test_y,
noisy_test_x, noisy_test_y) noisy_test_x, noisy_test_y)
# return different objects, depending on Booleans
if not return_ground_truth: if not return_ground_truth:
return trainset, testset if not return_normalized_func:
return trainset, testset
else:
return trainset, testset, normalized_func
else: else:
return trainset, testset, true_trainset, true_testset if not return_normalized_func:
return trainset, testset, true_trainset,\
true_testset
else:
return trainset, testset, true_trainset,\
true_testset, normalized_func
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment