Skip to content
Snippets Groups Projects
Commit abf78c52 authored by Jörg Martin
Browse files

Added repeated_sampling

parent 4d96c6dd
No related branches found
No related tags found
No related merge requests found
import sys
import torch
from torch.utils.data import TensorDataset
from EIVGeneral.manipulate_tensors import add_noise
class repeated_sampling():
def __init__(self, dataclass, fixed_seed=0):
self.dataclass = dataclass
......@@ -7,12 +14,36 @@ class repeated_sampling():
def __call__(self,seed=0, splitting_part=0.8, normalize=True,
return_ground_truth=False):
_, _, _, true_testset = self.dataclass.load_data(
_, _, true_trainset, true_testset\
= self.dataclass.load_data(
seed=self.fixed_seed, splitting_part=splitting_part,
normalize=normalize,
return_ground_truth=return_ground_truth)
true_x, true_y = true_testset.tensors[:2]
return_ground_truth=True)
true_train_x, true_train_y = true_trainset.tensors[:2]
true_test_x, true_test_y = true_testset.tensors[:2]
random_generator = torch.Generator().manual_seed(seed)
# draw two different seeds: one for the x noise and one for the y noise
seeds = [int(t) for t in torch.randint(0,sys.maxsize,(2,),\
generator=random_generator)]
(noisy_train_x, noisy_train_y), (true_train_x, true_train_y) =\
add_noise((true_train_x, true_train_y),
(self.x_noise_strength, self.y_noise_strength), seeds,
normalize=True,
normalization_list=true_trainset.tensors[2:])
(noisy_test_x, noisy_test_y), (true_test_x, true_test_y) =\
add_noise((true_test_x, true_test_y),
(self.x_noise_strength, self.y_noise_strength), seeds,
normalize=True,
normalization_list=true_testset.tensors[2:])
trainset = TensorDataset(noisy_train_x, noisy_train_y)
testset = TensorDataset(noisy_test_x, noisy_test_y)
true_trainset = TensorDataset(true_train_x, true_train_y,
noisy_train_x, noisy_train_y)
true_testset = TensorDataset(true_test_x, true_test_y,
noisy_test_x, noisy_test_y)
if not return_ground_truth:
return trainset, testset
else:
return trainset, testset, true_trainset, true_testset
......@@ -19,7 +19,8 @@ def normalize_tensor(t, mean_std):
return (t-mean_std[0])/mean_std[1]
def add_noise(tensor_list, noise_strength_list, seed_list, normalize=True):
def add_noise(tensor_list, noise_strength_list, seed_list, normalize=True,
normalization_list = None):
"""
Takes the tensors in `tensor_list`, adds random noise using the standard
deviations in `noise_strength_list` and the seeds in `seed_list`, then, if
......@@ -31,15 +32,26 @@ def add_noise(tensor_list, noise_strength_list, seed_list, normalize=True):
:param noise_strength_list: A list of positive floats
:param seed_list: A list of integers.
:param normalize: A Boolean, defaults to True.
:param normalization_list: Either None (default) or a list of tensors.
If the latter, these tensors will be used for normalization and `normalize`
is assumed to be True.
:returns: noisy_tensor_list, unnoisy_tensor_list, both normalized
"""
noisy_t_list = []
unnoisy_t_list = []
for t,noise,seed in zip(tensor_list, noise_strength_list, seed_list):
if normalization_list is not None:
assert normalize
assert len(normalization_list) == len(tensor_list)
for i, (t,noise,seed) in enumerate(zip(tensor_list, noise_strength_list,\
seed_list)):
noisy_t = t + noise * torch.randn(t.shape,
generator=torch.Generator().manual_seed(seed))
if normalize:
noisy_t_normalization = get_normalization(noisy_t)
if normalization_list is not None:
noisy_t_normalization =\
get_normalization(normalization_list[i])
else:
noisy_t_normalization = get_normalization(noisy_t)
noisy_t = normalize_tensor(noisy_t, noisy_t_normalization)
t = normalize_tensor(t, noisy_t_normalization)
noisy_t_list.append(noisy_t)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment