From b07442689220c6ebaf093c7f2c81e3d2d986522b Mon Sep 17 00:00:00 2001 From: Joerg Martin <joerg.martin@ptb.de> Date: Fri, 4 Feb 2022 08:30:32 +0000 Subject: [PATCH] Allowed for different forward pass numbers --- Experiments/configurations/eiv_sine.json | 1 + Experiments/train_eiv.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/Experiments/configurations/eiv_sine.json b/Experiments/configurations/eiv_sine.json index 6d0a36d..b2cce34 100644 --- a/Experiments/configurations/eiv_sine.json +++ b/Experiments/configurations/eiv_sine.json @@ -11,6 +11,7 @@ "p": 0.1, "lr_update": 20, "std_y_update_points": [1,40], + "eiv_number_of_forward_draws": 10, "eiv_prediction_number_of_draws": [100,5], "eiv_prediction_number_of_batches": 10, "init_std_y_list": [0.1], diff --git a/Experiments/train_eiv.py b/Experiments/train_eiv.py index 3ecccea..973fc59 100644 --- a/Experiments/train_eiv.py +++ b/Experiments/train_eiv.py @@ -41,10 +41,16 @@ p = conf_dict["p"] lr_update = conf_dict["lr_update"] # offset before updating sigma_y after each epoch std_y_update_points = conf_dict["std_y_update_points"] -# will be used to predict the RMSE and update sigma_y accordingly + +# will be used to predict, to compute the RMSE and update sigma_y accordingly +try: + eiv_number_of_forward_draws = conf_dict['eiv_number_of_forward_draws'] +except KeyError: + eiv_number_of_forward_draws = 5 eiv_prediction_number_of_draws = conf_dict["eiv_prediction_number_of_draws"] eiv_prediction_number_of_batches = \ conf_dict["eiv_prediction_number_of_batches"] + init_std_y_list = conf_dict["init_std_y_list"] fixed_std_x = conf_dict['fixed_std_x'] gamma = conf_dict["gamma"] @@ -225,7 +231,9 @@ def train_on_data(init_std_y, seed): # regularization reg = unscaled_reg/len(train_data) # create epoch_map - criterion = loss_functions.nll_eiv + def criterion(net, x, y, reg): + return loss_functions.nll_eiv(net, x, y, reg, + number_of_draws=eiv_number_of_forward_draws) epoch_map = UpdatedTrainEpoch(train_dataloader=train_dataloader, test_dataloader=test_dataloader, criterion=criterion, std_y_map=std_y_map, std_x_map=std_x_map, -- GitLab