Skip to content
Snippets Groups Projects
Commit b0744268 authored by Jörg Martin's avatar Jörg Martin
Browse files

Allowed for different forward pass numbers

parent 32ada0a6
Branches
Tags
No related merge requests found
......@@ -11,6 +11,7 @@
"p": 0.1,
"lr_update": 20,
"std_y_update_points": [1,40],
"eiv_number_of_forward_draws": 10,
"eiv_prediction_number_of_draws": [100,5],
"eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.1],
......
......@@ -41,10 +41,16 @@ p = conf_dict["p"]
lr_update = conf_dict["lr_update"]
# offset before updating sigma_y after each epoch
std_y_update_points = conf_dict["std_y_update_points"]
# will be used to predict the RMSE and update sigma_y accordingly
# will be used to predict, to compute the RMSE and update sigma_y accordingly
try:
eiv_number_of_forward_draws = conf_dict['eiv_number_of_forward_draws']
except KeyError:
eiv_number_of_forward_draws = 5
eiv_prediction_number_of_draws = conf_dict["eiv_prediction_number_of_draws"]
eiv_prediction_number_of_batches = \
conf_dict["eiv_prediction_number_of_batches"]
init_std_y_list = conf_dict["init_std_y_list"]
fixed_std_x = conf_dict['fixed_std_x']
gamma = conf_dict["gamma"]
......@@ -225,7 +231,9 @@ def train_on_data(init_std_y, seed):
# regularization
reg = unscaled_reg/len(train_data)
# create epoch_map
criterion = loss_functions.nll_eiv
def criterion(net, x, y, reg):
return loss_functions.nll_eiv(net, x, y, reg,
number_of_draws=eiv_number_of_forward_draws)
epoch_map = UpdatedTrainEpoch(train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
criterion=criterion, std_y_map=std_y_map, std_x_map=std_x_map,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment