From 32ada0a68a5630b8881937a12663d06a172ffc09 Mon Sep 17 00:00:00 2001
From: Joerg Martin <joerg.martin@ptb.de>
Date: Wed, 2 Feb 2022 15:46:29 +0000
Subject: [PATCH] train_noneiv fixed

---
 Experiments/plot_prediction.py | 13 ++++++++++---
 Experiments/train_noneiv.py    |  7 ++++++-
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/Experiments/plot_prediction.py b/Experiments/plot_prediction.py
index 6702506..3d91d48 100644
--- a/Experiments/plot_prediction.py
+++ b/Experiments/plot_prediction.py
@@ -69,6 +69,11 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws,
     # get datanames
     long_dataname = conf_dict["long_dataname"]
     short_dataname = conf_dict["short_dataname"]
+    try:
+        normalize = conf_dict['normalize']
+    except KeyError:
+        # normalize by default
+        normalize = True
 
 
     # load hyperparameters
@@ -98,8 +103,10 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws,
 
 
     # determine dimensions
-    _, test_data, normalized_func = load_data(seed=plotting_seed, return_ground_truth=False,
-            return_normalized_func=True)
+    _, test_data, normalized_func = load_data(seed=plotting_seed,
+            return_ground_truth=False,
+            return_normalized_func=True,
+            normalize=normalize)
     input_dim = test_data[0][0].numel()
     output_dim = test_data[0][1].numel()
     assert output_dim == 1
@@ -209,7 +216,7 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws,
     return plotting_dictionary
 
 
-data_list = ['sine'] # short datanames
+data_list = ['cubic'] # short datanames
 #TODO: Check which ranges are "correct"
 list_x_range = [torch.linspace(-2.5,2.5, 50)]
 list_color = [('red','blue')]
diff --git a/Experiments/train_noneiv.py b/Experiments/train_noneiv.py
index a585ce6..f18617a 100644
--- a/Experiments/train_noneiv.py
+++ b/Experiments/train_noneiv.py
@@ -49,6 +49,11 @@ init_std_y_list = conf_dict["init_std_y_list"]
 gamma = conf_dict["gamma"]
 hidden_layers = conf_dict["hidden_layers"]
 seed_range = conf_dict['seed_range']
+try:
+    normalize = conf_dict['normalize']
+except KeyError:
+    # normalize by default
+    normalize = True
 
 print(f"Training on {long_dataname} data")
 
@@ -190,7 +195,7 @@ def train_on_data(init_std_y, seed):
     set_seeds(seed)
     # load Datasets
     train_data, test_data = load_data(seed=seed, splitting_part=0.8,
-            normalize=True)
+            normalize=normalize)
     # make dataloaders
     train_dataloader = DataLoader(train_data, batch_size=batch_size, 
             shuffle=True)
-- 
GitLab