From 00982faed6cd87f0ba7b301914037d8440ab11ab Mon Sep 17 00:00:00 2001 From: Joerg Martin <joerg.martin@ptb.de> Date: Thu, 6 Jan 2022 15:30:45 +0100 Subject: [PATCH] Added repeated linear and quadratic --- EIVPackage/EIVData/repeated_linear.py | 11 +++++++++++ EIVPackage/EIVData/repeated_quadratic.py | 11 +++++++++++ EIVPackage/EIVData/repeated_sampling.py | 17 +++++++++++++++++ 3 files changed, 39 insertions(+) create mode 100644 EIVPackage/EIVData/repeated_linear.py create mode 100644 EIVPackage/EIVData/repeated_quadratic.py diff --git a/EIVPackage/EIVData/repeated_linear.py b/EIVPackage/EIVData/repeated_linear.py new file mode 100644 index 0000000..6e0bd53 --- /dev/null +++ b/EIVPackage/EIVData/repeated_linear.py @@ -0,0 +1,11 @@ +""" +Repeated sampling from the linear dataset. +""" +from EIVData import linear + +from EIVData.repeated_sampling import repeated_sampling + +fixed_seed = 0 + +load_data = repeated_sampling(dataclass=linear, + fixed_seed=fixed_seed) diff --git a/EIVPackage/EIVData/repeated_quadratic.py b/EIVPackage/EIVData/repeated_quadratic.py new file mode 100644 index 0000000..4c40982 --- /dev/null +++ b/EIVPackage/EIVData/repeated_quadratic.py @@ -0,0 +1,11 @@ +""" +Repeated sampling from the quadratic dataset. +""" +from EIVData import quadratic + +from EIVData.repeated_sampling import repeated_sampling + +fixed_seed = 0 + +load_data = repeated_sampling(dataclass=quadratic, + fixed_seed=fixed_seed) diff --git a/EIVPackage/EIVData/repeated_sampling.py b/EIVPackage/EIVData/repeated_sampling.py index 771577d..465a20e 100644 --- a/EIVPackage/EIVData/repeated_sampling.py +++ b/EIVPackage/EIVData/repeated_sampling.py @@ -1,3 +1,7 @@ +""" +Contains the class `repeated_sampling` that can be used to generate +datasets for repeated sampling from datasets with a ground truth. +""" import sys import torch @@ -6,6 +10,19 @@ from torch.utils.data import TensorDataset from EIVGeneral.manipulate_tensors import add_noise class repeated_sampling(): + """ + A class for repeated sampling from datasets with a known ground truth and + known input and output noise. The class `dataclass` should contain a + `load_data` routine that returns a ground truth and two positive floats + `x_noise_strength` and `y_noise_strength` that will be used as the standard + deviation of input and output noise. + :param dataclass: A module that contains a routine `load_data`, which + accepts the keyword `return_ground_truth` and returns the noisy and true + train and test datasets, and two positive floats `x_noise_strength` and + `y_noise_strength`. + :param fixed_seed: Integer. The seed to load the unnoisy ground truth, + defaults to 0. + """ def __init__(self, dataclass, fixed_seed=0): self.dataclass = dataclass self.fixed_seed = fixed_seed -- GitLab