Skip to content
Snippets Groups Projects
Commit b0744268 authored by Jörg Martin's avatar Jörg Martin
Browse files

Allowed for different forward pass numbers

parent 32ada0a6
No related branches found
No related tags found
No related merge requests found
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
"p": 0.1, "p": 0.1,
"lr_update": 20, "lr_update": 20,
"std_y_update_points": [1,40], "std_y_update_points": [1,40],
"eiv_number_of_forward_draws": 10,
"eiv_prediction_number_of_draws": [100,5], "eiv_prediction_number_of_draws": [100,5],
"eiv_prediction_number_of_batches": 10, "eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.1], "init_std_y_list": [0.1],
......
...@@ -41,10 +41,16 @@ p = conf_dict["p"] ...@@ -41,10 +41,16 @@ p = conf_dict["p"]
lr_update = conf_dict["lr_update"] lr_update = conf_dict["lr_update"]
# offset before updating sigma_y after each epoch # offset before updating sigma_y after each epoch
std_y_update_points = conf_dict["std_y_update_points"] std_y_update_points = conf_dict["std_y_update_points"]
# will be used to predict the RMSE and update sigma_y accordingly
# will be used to predict, to compute the RMSE and update sigma_y accordingly
try:
eiv_number_of_forward_draws = conf_dict['eiv_number_of_forward_draws']
except KeyError:
eiv_number_of_forward_draws = 5
eiv_prediction_number_of_draws = conf_dict["eiv_prediction_number_of_draws"] eiv_prediction_number_of_draws = conf_dict["eiv_prediction_number_of_draws"]
eiv_prediction_number_of_batches = \ eiv_prediction_number_of_batches = \
conf_dict["eiv_prediction_number_of_batches"] conf_dict["eiv_prediction_number_of_batches"]
init_std_y_list = conf_dict["init_std_y_list"] init_std_y_list = conf_dict["init_std_y_list"]
fixed_std_x = conf_dict['fixed_std_x'] fixed_std_x = conf_dict['fixed_std_x']
gamma = conf_dict["gamma"] gamma = conf_dict["gamma"]
...@@ -225,7 +231,9 @@ def train_on_data(init_std_y, seed): ...@@ -225,7 +231,9 @@ def train_on_data(init_std_y, seed):
# regularization # regularization
reg = unscaled_reg/len(train_data) reg = unscaled_reg/len(train_data)
# create epoch_map # create epoch_map
criterion = loss_functions.nll_eiv def criterion(net, x, y, reg):
return loss_functions.nll_eiv(net, x, y, reg,
number_of_draws=eiv_number_of_forward_draws)
epoch_map = UpdatedTrainEpoch(train_dataloader=train_dataloader, epoch_map = UpdatedTrainEpoch(train_dataloader=train_dataloader,
test_dataloader=test_dataloader, test_dataloader=test_dataloader,
criterion=criterion, std_y_map=std_y_map, std_x_map=std_x_map, criterion=criterion, std_y_map=std_y_map, std_x_map=std_x_map,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment