From 5fd3aa28e0f15592f426c509c2fa98468eb0d9fd Mon Sep 17 00:00:00 2001
From: Joerg Martin <joerg.martin@ptb.de>
Date: Fri, 4 Feb 2022 13:46:02 +0000
Subject: [PATCH] Updated configs for cubic

---
 EIVPackage/EIVData/cubic.py               |  8 ++++----
 Experiments/configurations/eiv_cubic.json |  2 +-
 Experiments/plot_prediction.py            | 13 ++++++++-----
 3 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/EIVPackage/EIVData/cubic.py b/EIVPackage/EIVData/cubic.py
index 710ea87..dce6ea4 100644
--- a/EIVPackage/EIVData/cubic.py
+++ b/EIVPackage/EIVData/cubic.py
@@ -5,12 +5,12 @@ from torch.utils.data import TensorDataset
 from EIVGeneral.manipulate_tensors import add_noise, normalize_tensor,\
         unnormalize_tensor
 
-total_number_of_datapoints = 500
-input_range = [-4,4]
+total_number_of_datapoints = 1000
+input_range = [-1,1]
 slope = 1.0
 intercept = 0.0
-x_noise_strength = 0.2 
-y_noise_strength = 3
+x_noise_strength = 0.1 
+y_noise_strength = 0.1
 func = lambda true_x: slope * true_x**3 + intercept 
 
 def load_data(seed=0, splitting_part=0.8, normalize=True,
diff --git a/Experiments/configurations/eiv_cubic.json b/Experiments/configurations/eiv_cubic.json
index 2a3ac2c..0e9f9e8 100644
--- a/Experiments/configurations/eiv_cubic.json
+++ b/Experiments/configurations/eiv_cubic.json
@@ -16,7 +16,7 @@
 	"init_std_y_list": [0.5],
 	"gamma": 0.5,
 	"hidden_layers": [128, 128, 128, 128],
-	"fixed_std_x": 0.20,
+	"fixed_std_x": 0.10,
 	"seed_range": [0,10],
 	"gpu_number": 1
 }
diff --git a/Experiments/plot_prediction.py b/Experiments/plot_prediction.py
index 9ba7fd8..c0e1817 100644
--- a/Experiments/plot_prediction.py
+++ b/Experiments/plot_prediction.py
@@ -216,10 +216,13 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws,
     return plotting_dictionary
 
 
-data_list = ['sine'] # short datanames
-list_x_range = [torch.linspace(-0.3,0.9, 50)]
-list_color = [('red','blue')]
-list_number_of_draws = [((100,5), 100)]
+data_list = ['linear','quadratic','cubic','sine'] # short datanames
+list_x_range = [torch.linspace(-1.0,1.0, 50),
+                torch.linspace(-1.0,1.0, 50),
+                torch.linspace(-1.0,1.0, 50),
+        torch.linspace(-0.2,0.8, 50)]
+list_color = [('red','blue')] * len(data_list)
+list_number_of_draws = [((100,5), 100)] * len(data_list)
 for i, (data, x_range, color, number_of_draws) in enumerate(zip(data_list,
     list_x_range, list_color, list_number_of_draws)):
     eiv_plotting_dictionary = compute_predictions_and_uncertainties(
@@ -234,7 +237,7 @@ for i, (data, x_range, color, number_of_draws) in enumerate(zip(data_list,
             number_of_draws=number_of_draws[1])
     input_dim = eiv_plotting_dictionary['input_dim']
     if input_dim == 1:
-        plt.figure(i)
+        plt.figure(i+1)
         plt.clf()
         x_values, y_values = eiv_plotting_dictionary['range_points']
         plt.plot(x_values.flatten(), y_values.flatten(),'-', color='k')
-- 
GitLab