Skip to content
Snippets Groups Projects
Commit 32ada0a6 authored by Jörg Martin's avatar Jörg Martin
Browse files

train_noneiv fixed

parent e2c57f2d
No related branches found
No related tags found
No related merge requests found
...@@ -69,6 +69,11 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws, ...@@ -69,6 +69,11 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws,
# get datanames # get datanames
long_dataname = conf_dict["long_dataname"] long_dataname = conf_dict["long_dataname"]
short_dataname = conf_dict["short_dataname"] short_dataname = conf_dict["short_dataname"]
try:
normalize = conf_dict['normalize']
except KeyError:
# normalize by default
normalize = True
# load hyperparameters # load hyperparameters
...@@ -98,8 +103,10 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws, ...@@ -98,8 +103,10 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws,
# determine dimensions # determine dimensions
_, test_data, normalized_func = load_data(seed=plotting_seed, return_ground_truth=False, _, test_data, normalized_func = load_data(seed=plotting_seed,
return_normalized_func=True) return_ground_truth=False,
return_normalized_func=True,
normalize=normalize)
input_dim = test_data[0][0].numel() input_dim = test_data[0][0].numel()
output_dim = test_data[0][1].numel() output_dim = test_data[0][1].numel()
assert output_dim == 1 assert output_dim == 1
...@@ -209,7 +216,7 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws, ...@@ -209,7 +216,7 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws,
return plotting_dictionary return plotting_dictionary
data_list = ['sine'] # short datanames data_list = ['cubic'] # short datanames
#TODO: Check which ranges are "correct" #TODO: Check which ranges are "correct"
list_x_range = [torch.linspace(-2.5,2.5, 50)] list_x_range = [torch.linspace(-2.5,2.5, 50)]
list_color = [('red','blue')] list_color = [('red','blue')]
......
...@@ -49,6 +49,11 @@ init_std_y_list = conf_dict["init_std_y_list"] ...@@ -49,6 +49,11 @@ init_std_y_list = conf_dict["init_std_y_list"]
gamma = conf_dict["gamma"] gamma = conf_dict["gamma"]
hidden_layers = conf_dict["hidden_layers"] hidden_layers = conf_dict["hidden_layers"]
seed_range = conf_dict['seed_range'] seed_range = conf_dict['seed_range']
try:
normalize = conf_dict['normalize']
except KeyError:
# normalize by default
normalize = True
print(f"Training on {long_dataname} data") print(f"Training on {long_dataname} data")
...@@ -190,7 +195,7 @@ def train_on_data(init_std_y, seed): ...@@ -190,7 +195,7 @@ def train_on_data(init_std_y, seed):
set_seeds(seed) set_seeds(seed)
# load Datasets # load Datasets
train_data, test_data = load_data(seed=seed, splitting_part=0.8, train_data, test_data = load_data(seed=seed, splitting_part=0.8,
normalize=True) normalize=normalize)
# make dataloaders # make dataloaders
train_dataloader = DataLoader(train_data, batch_size=batch_size, train_dataloader = DataLoader(train_data, batch_size=batch_size,
shuffle=True) shuffle=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment