From 32ada0a68a5630b8881937a12663d06a172ffc09 Mon Sep 17 00:00:00 2001 From: Joerg Martin <joerg.martin@ptb.de> Date: Wed, 2 Feb 2022 15:46:29 +0000 Subject: [PATCH] train_noneiv fixed --- Experiments/plot_prediction.py | 13 ++++++++++--- Experiments/train_noneiv.py | 7 ++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/Experiments/plot_prediction.py b/Experiments/plot_prediction.py index 6702506..3d91d48 100644 --- a/Experiments/plot_prediction.py +++ b/Experiments/plot_prediction.py @@ -69,6 +69,11 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws, # get datanames long_dataname = conf_dict["long_dataname"] short_dataname = conf_dict["short_dataname"] + try: + normalize = conf_dict['normalize'] + except KeyError: + # normalize by default + normalize = True # load hyperparameters @@ -98,8 +103,10 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws, # determine dimensions - _, test_data, normalized_func = load_data(seed=plotting_seed, return_ground_truth=False, - return_normalized_func=True) + _, test_data, normalized_func = load_data(seed=plotting_seed, + return_ground_truth=False, + return_normalized_func=True, + normalize=normalize) input_dim = test_data[0][0].numel() output_dim = test_data[0][1].numel() assert output_dim == 1 @@ -209,7 +216,7 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws, return plotting_dictionary -data_list = ['sine'] # short datanames +data_list = ['cubic'] # short datanames #TODO: Check which ranges are "correct" list_x_range = [torch.linspace(-2.5,2.5, 50)] list_color = [('red','blue')] diff --git a/Experiments/train_noneiv.py b/Experiments/train_noneiv.py index a585ce6..f18617a 100644 --- a/Experiments/train_noneiv.py +++ b/Experiments/train_noneiv.py @@ -49,6 +49,11 @@ init_std_y_list = conf_dict["init_std_y_list"] gamma = conf_dict["gamma"] hidden_layers = conf_dict["hidden_layers"] seed_range = conf_dict['seed_range'] +try: + normalize = conf_dict['normalize'] +except KeyError: + # normalize by default + normalize = True print(f"Training on {long_dataname} data") @@ -190,7 +195,7 @@ def train_on_data(init_std_y, seed): set_seeds(seed) # load Datasets train_data, test_data = load_data(seed=seed, splitting_part=0.8, - normalize=True) + normalize=normalize) # make dataloaders train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True) -- GitLab