From 5fd3aa28e0f15592f426c509c2fa98468eb0d9fd Mon Sep 17 00:00:00 2001 From: Joerg Martin <joerg.martin@ptb.de> Date: Fri, 4 Feb 2022 13:46:02 +0000 Subject: [PATCH] Updated configs for cubic --- EIVPackage/EIVData/cubic.py | 8 ++++---- Experiments/configurations/eiv_cubic.json | 2 +- Experiments/plot_prediction.py | 13 ++++++++----- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/EIVPackage/EIVData/cubic.py b/EIVPackage/EIVData/cubic.py index 710ea87..dce6ea4 100644 --- a/EIVPackage/EIVData/cubic.py +++ b/EIVPackage/EIVData/cubic.py @@ -5,12 +5,12 @@ from torch.utils.data import TensorDataset from EIVGeneral.manipulate_tensors import add_noise, normalize_tensor,\ unnormalize_tensor -total_number_of_datapoints = 500 -input_range = [-4,4] +total_number_of_datapoints = 1000 +input_range = [-1,1] slope = 1.0 intercept = 0.0 -x_noise_strength = 0.2 -y_noise_strength = 3 +x_noise_strength = 0.1 +y_noise_strength = 0.1 func = lambda true_x: slope * true_x**3 + intercept def load_data(seed=0, splitting_part=0.8, normalize=True, diff --git a/Experiments/configurations/eiv_cubic.json b/Experiments/configurations/eiv_cubic.json index 2a3ac2c..0e9f9e8 100644 --- a/Experiments/configurations/eiv_cubic.json +++ b/Experiments/configurations/eiv_cubic.json @@ -16,7 +16,7 @@ "init_std_y_list": [0.5], "gamma": 0.5, "hidden_layers": [128, 128, 128, 128], - "fixed_std_x": 0.20, + "fixed_std_x": 0.10, "seed_range": [0,10], "gpu_number": 1 } diff --git a/Experiments/plot_prediction.py b/Experiments/plot_prediction.py index 9ba7fd8..c0e1817 100644 --- a/Experiments/plot_prediction.py +++ b/Experiments/plot_prediction.py @@ -216,10 +216,13 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws, return plotting_dictionary -data_list = ['sine'] # short datanames -list_x_range = [torch.linspace(-0.3,0.9, 50)] -list_color = [('red','blue')] -list_number_of_draws = [((100,5), 100)] +data_list = ['linear','quadratic','cubic','sine'] # short datanames +list_x_range = [torch.linspace(-1.0,1.0, 50), + torch.linspace(-1.0,1.0, 50), + torch.linspace(-1.0,1.0, 50), + torch.linspace(-0.2,0.8, 50)] +list_color = [('red','blue')] * len(data_list) +list_number_of_draws = [((100,5), 100)] * len(data_list) for i, (data, x_range, color, number_of_draws) in enumerate(zip(data_list, list_x_range, list_color, list_number_of_draws)): eiv_plotting_dictionary = compute_predictions_and_uncertainties( @@ -234,7 +237,7 @@ for i, (data, x_range, color, number_of_draws) in enumerate(zip(data_list, number_of_draws=number_of_draws[1]) input_dim = eiv_plotting_dictionary['input_dim'] if input_dim == 1: - plt.figure(i) + plt.figure(i+1) plt.clf() x_values, y_values = eiv_plotting_dictionary['range_points'] plt.plot(x_values.flatten(), y_values.flatten(),'-', color='k') -- GitLab