diff --git a/EIVPackage/EIVArchitectures/Networks.py b/EIVPackage/EIVArchitectures/Networks.py index d9c0f1ba854549882ba22499232e8d33e1619334..8928b3fde92517f7397eb0542f8189ea9bf132bd 100644 --- a/EIVPackage/EIVArchitectures/Networks.py +++ b/EIVPackage/EIVArchitectures/Networks.py @@ -32,7 +32,7 @@ class FNNEIV(nn.Module): **Note**: - To change the deming factor afterwards, use the method `change_deming` - To change fixed_std_x afterwards, use the method `change_fixed_std_x` - - To change std_y use the method `change_std_x` + - To change std_y use the method `change_std_y` """ LeakyReLUSlope = 1e-2 def __init__(self, p = 0.2, init_std_y=1.0, precision_prior_zeta=0.0, @@ -231,8 +231,9 @@ class FNNEIV(nn.Module): sigma = torch.mean(sigma, dim=1) return pred, sigma - def predictive_logdensity(self, x_or_predictions, y, number_of_draws=[100, 5], number_of_parameter_chunks = None, remove_graph=True, - average_batch_dimension=True, scale_labels=None, + def predictive_logdensity(self, x_or_predictions, y, + number_of_draws=[100, 5], number_of_parameter_chunks = None, + remove_graph=True, average_batch_dimension=True, scale_labels=None, decouple_dimensions=False): """ Computes the logarithm of the predictive density evaluated at `y`. If @@ -264,7 +265,8 @@ class FNNEIV(nn.Module): False. """ if type(x_or_predictions) is torch.tensor: - out, sigmas = self.predict(x_or_predictions, number_of_draws=number_of_draws, + out, sigmas = self.predict(x_or_predictions, + number_of_draws=number_of_draws, number_of_parameter_chunks=number_of_parameter_chunks, remove_graph=remove_graph, take_average_of_prediction=False) @@ -307,7 +309,8 @@ class FNNEIV(nn.Module): else: return predictive_log_density_values - def predict_mean_and_unc(self, x, number_of_draws=[100,5], number_of_parameter_chunks = None, + def predict_mean_and_unc(self, x, number_of_draws=[100,5], + number_of_parameter_chunks = None, remove_graph=True): """ Take the mean and standard deviation over `number_of_draws` forward @@ -349,6 +352,7 @@ class FNNBer(nn.Module): :param h: A list specifying the number of neurons in each layer. :param std_y_requires_grad: Whether `sigma_y` will require_grad and thus be updated during optimization. Defaults to False. + To change std_y use the method `change_std_y` """ LeakyReLUSlope = 1e-2 def __init__(self, p=0.2, init_std_y=1.0, h=[10, 1024,1024,1024,1024, 1], diff --git a/EIVPackage/EIVTrainingRoutines/loss_functions.py b/EIVPackage/EIVTrainingRoutines/loss_functions.py index 88e924df76c1b2bd00e4c4abcdf18089f699539e..e1c552d3d621dd29c91b555c9c50a1089c836a12 100644 --- a/EIVPackage/EIVTrainingRoutines/loss_functions.py +++ b/EIVPackage/EIVTrainingRoutines/loss_functions.py @@ -16,13 +16,17 @@ def nll_reg_loss(net, x, y, reg): :param y: A torch.tensor, the output. :param reg: A non-negative float, the regularization. """ - out, std_y = net(x) + out, sigmas = net(x) # Add label dimension to y if missing if len(y.shape) <= 1: y = y.view((-1,1)) + # squeeze last dimensions into one + y = y.view((*y.shape[:1], -1)) + sigmas = sigmas.view((*sigmas.shape[:1], -1)) + out = out.view((*out.shape[:1], -1)) assert out.shape == y.shape - neg_log_likelihood = torch.mean(0.5* torch.log(2*pi*std_y**2) \ - + ((out-y)**2)/(2*std_y**2)) + neg_log_likelihood = torch.mean(torch.sum(0.5* torch.log(2*pi*sigmas**2) \ + + ((out-y)**2)/(2*sigmas**2), dim=1)) regularization = net.regularizer(x, lamb=reg) return neg_log_likelihood + regularization @@ -45,12 +49,17 @@ def nll_eiv(net, x, y, reg, number_of_draws=5): regularization = net.regularizer(x, lamb=reg) # repeat_tensors x, y = repeat_tensors(x, y, number_of_draws=number_of_draws) - pred, sigma = net(x, repetition=number_of_draws) + out, sigmas = net(x, repetition=number_of_draws) # split into chunks of size number_of_draws along batch dimension - pred, sigma, y = reshape_to_chunks(pred, sigma, y, number_of_draws=number_of_draws) - assert pred.shape == y.shape + out, sigmas, y = reshape_to_chunks(out, sigmas, y, + number_of_draws=number_of_draws) + # squeeze last dimensions into one + y = y.view((*y.shape[:2], -1)) + sigmas = sigmas.view((*sigmas.shape[:2], -1)) + out = out.view((*out.shape[:2], -1)) + assert out.shape == y.shape # apply logsumexp to chunks and average the results - nll = -1 * (torch.logsumexp(-1 * sigma.log() - -((y-pred)**2)/(2*sigma**2), dim=1) + nll = -1 * (torch.logsumexp(torch.sum(-1/2 * torch.log(sigmas**2 * 2 * pi) + -((y-out)**2)/(2*sigmas**2), dim=2), dim=1) - np.log(number_of_draws)).mean() return nll + regularization diff --git a/EIVPackage/EIVTrainingRoutines/train_and_store.py b/EIVPackage/EIVTrainingRoutines/train_and_store.py index eb0378e953c00eee2bb0465b1aea8580b92baa01..908c4457baad75de7fbe40c66716ae3ebe7f5c7c 100644 --- a/EIVPackage/EIVTrainingRoutines/train_and_store.py +++ b/EIVPackage/EIVTrainingRoutines/train_and_store.py @@ -69,6 +69,15 @@ class TrainEpoch(): """ pass + + def post_train_update(self, net, epoch=None): + """ + Will be executed after the training is finished + :param net: The current net, a torch.nn.Module + :param epoch: Tue last epochn number, an integer. + """ + pass + def extra_report(self, net, step): """ Overwrite for reporting on state of net @@ -156,6 +165,8 @@ def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs): Calls `epoch_map` with `epoch` and the current epoch number `number_of_epochs` times and stores a list of its output as a pickled file under `save_file`. + After this (the training) is done, `epoch_map.post_train_update(net, + epoch_number)` is called, if existent. **Note**: The output of `epoch_map` is supposed to consist of 4 specific arguments, see below. :param net: A torch.nn.Module @@ -179,6 +190,11 @@ def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs): test_loss_collection.append(test_loss) std_x_collection.append(std_x) std_y_collection.append(std_y) + try: + epoch_map.post_train_update(net, number_of_epochs-1) + print('Training done. Performed post training updating.') + except AttributeError: + print('Training done. No post train updating performed.') # Saving state_dict = net.state_dict() to_save = { diff --git a/Experiments/configurations/eiv_california.json b/Experiments/configurations/eiv_california.json index 29a71622d99299d2c7ccde4f148b560de52418ee..6cfa57f9f95c63dd347629d0acf4df37b96b180f 100644 --- a/Experiments/configurations/eiv_california.json +++ b/Experiments/configurations/eiv_california.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.1, "lr_update": 20, - "epoch_offset": 10, + "std_y_update_points": [10,5], "eiv_prediction_number_of_draws": 100, "eiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/eiv_concrete.json b/Experiments/configurations/eiv_concrete.json index 11121c144b429ad2f970aa815f2a4e7fed8a893a..0e234346f6dda090daacb9aebf33f8c3f666081c 100644 --- a/Experiments/configurations/eiv_concrete.json +++ b/Experiments/configurations/eiv_concrete.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 20, - "epoch_offset": 10, + "std_y_update_points": 10, "eiv_prediction_number_of_draws": 100, "eiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/eiv_energy.json b/Experiments/configurations/eiv_energy.json index bca6774dd7b4ce71ca78b487246d6a5e7f4b76f3..f13ab1b52417aa7ed20c4db9d75a0c653ff0b321 100644 --- a/Experiments/configurations/eiv_energy.json +++ b/Experiments/configurations/eiv_energy.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 100, - "epoch_offset": 100, + "std_y_update_points": 100, "eiv_prediction_number_of_draws": 100, "eiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/eiv_kin8nm.json b/Experiments/configurations/eiv_kin8nm.json index fa3718a096e417c2ab456a8589009d0052ae6670..df2e4bafbee99aca95dcd36f8f67097d3a6ce3b5 100644 --- a/Experiments/configurations/eiv_kin8nm.json +++ b/Experiments/configurations/eiv_kin8nm.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 20, - "epoch_offset": 19, + "std_y_update_points": 19, "eiv_prediction_number_of_draws": 100, "eiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/eiv_msd.json b/Experiments/configurations/eiv_msd.json index 1c2527630dac6172c9271ee230b5c34a2abfd71f..a738f2e8bf034e173f60b17e5cd6d2f7d6fdb856 100644 --- a/Experiments/configurations/eiv_msd.json +++ b/Experiments/configurations/eiv_msd.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 4, - "epoch_offset": 4, + "std_y_update_points": 4, "eiv_prediction_number_of_draws": 100, "eiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/eiv_naval.json b/Experiments/configurations/eiv_naval.json index 3358831d627188d5c2b93ceaa748b77de893c23d..092f4f822abf0e9c677affa8ae275e46309435b8 100644 --- a/Experiments/configurations/eiv_naval.json +++ b/Experiments/configurations/eiv_naval.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 20, - "epoch_offset": 20, + "std_y_update_points": 20, "eiv_prediction_number_of_draws": 100, "eiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/eiv_power.json b/Experiments/configurations/eiv_power.json index 842f6e2662bef07696de1eb5f041e677b586dcd6..ab5649f61de1ea2214249b8a07e25cd5b2adea9e 100644 --- a/Experiments/configurations/eiv_power.json +++ b/Experiments/configurations/eiv_power.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 10, - "epoch_offset": 15, + "std_y_update_points": 15, "eiv_prediction_number_of_draws": 100, "eiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/eiv_protein.json b/Experiments/configurations/eiv_protein.json index 97f5946d431d04bc83c71904846152160fa0c7d7..9623799e3d8b85b2f568987afb4d406fbb7fe45c 100644 --- a/Experiments/configurations/eiv_protein.json +++ b/Experiments/configurations/eiv_protein.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 10, - "epoch_offset": 10, + "std_y_update_points": 10, "eiv_prediction_number_of_draws": 100, "eiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/eiv_wine.json b/Experiments/configurations/eiv_wine.json index 0b40a61272719cee6afc839f171103c54a830af8..f0beaf490846c2281a477f627f05c98b18cfa7f3 100644 --- a/Experiments/configurations/eiv_wine.json +++ b/Experiments/configurations/eiv_wine.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 30, - "epoch_offset": 50, + "std_y_update_points": 50, "eiv_prediction_number_of_draws": 100, "eiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/eiv_yacht.json b/Experiments/configurations/eiv_yacht.json index c668113462816474eca51186853bbc241ca3160d..7617289d9e49bb9ceddd9962d6d865fee3413ae7 100644 --- a/Experiments/configurations/eiv_yacht.json +++ b/Experiments/configurations/eiv_yacht.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 200, - "epoch_offset": 20, + "std_y_update_points": [1,500], "eiv_prediction_number_of_draws": 100, "eiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/noneiv_california.json b/Experiments/configurations/noneiv_california.json index 5005b43d8272c22893903a8a6c7494edeed4e554..e0729764699ba987d4a8e8b4e344601303c16f3b 100644 --- a/Experiments/configurations/noneiv_california.json +++ b/Experiments/configurations/noneiv_california.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.1, "lr_update": 20, - "epoch_offset": 0 , + "std_y_update_points": [10,5] , "noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/noneiv_concrete.json b/Experiments/configurations/noneiv_concrete.json index 66552b7193675eda193479ff9909494fc2bc2133..15e28c53a67064987c267f932733a09d00698c43 100644 --- a/Experiments/configurations/noneiv_concrete.json +++ b/Experiments/configurations/noneiv_concrete.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 20, - "epoch_offset": 10, + "std_y_update_points": 10, "noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/noneiv_energy.json b/Experiments/configurations/noneiv_energy.json index 74eb45c721793e94f0efbb4d67a3f4acaa05e8da..18c068ea7a9cb96f1075a522ae3f3ad6552bacac 100644 --- a/Experiments/configurations/noneiv_energy.json +++ b/Experiments/configurations/noneiv_energy.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 100, - "epoch_offset": 100, + "std_y_update_points": 100, "noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/noneiv_kin8nm.json b/Experiments/configurations/noneiv_kin8nm.json index 22615f61b00cf2b06c6e5fbc0051f3c86e8f92cb..8ecf817149b1a37dd9a4ab5c44fd2700795fbcde 100644 --- a/Experiments/configurations/noneiv_kin8nm.json +++ b/Experiments/configurations/noneiv_kin8nm.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 20, - "epoch_offset": 19, + "std_y_update_points": 19, "noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/noneiv_msd.json b/Experiments/configurations/noneiv_msd.json index 1c4f338fe3152c6a21d66cff05376efcbf88f148..a0e2f73beb63e165c7da5266b774f227e90310c1 100644 --- a/Experiments/configurations/noneiv_msd.json +++ b/Experiments/configurations/noneiv_msd.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 4, - "epoch_offset": 4, + "std_y_update_points": 4, "noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/noneiv_naval.json b/Experiments/configurations/noneiv_naval.json index 5dea6c49692e193dc25818ae83c6f467860a1520..2da83b73101d1a5a957f9773053b9610b0cb8aff 100644 --- a/Experiments/configurations/noneiv_naval.json +++ b/Experiments/configurations/noneiv_naval.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 20, - "epoch_offset": 20, + "std_y_update_points": 20, "noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/noneiv_power.json b/Experiments/configurations/noneiv_power.json index 6e524e6c4f684f9c0affc983f3d2d5a7e4597248..48aac9dcf4cc208095142760d7bf286c2589b22b 100644 --- a/Experiments/configurations/noneiv_power.json +++ b/Experiments/configurations/noneiv_power.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 10, - "epoch_offset": 15, + "std_y_update_points": 15, "noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/noneiv_protein.json b/Experiments/configurations/noneiv_protein.json index 1edd04fa1d28b8634080beec1865d30b6f854153..454deaef823c1a25c03588c4f430b5d0f67e5e93 100644 --- a/Experiments/configurations/noneiv_protein.json +++ b/Experiments/configurations/noneiv_protein.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 10, - "epoch_offset": 10, + "std_y_update_points": 10, "noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/noneiv_wine.json b/Experiments/configurations/noneiv_wine.json index 717fb33c371b6d3743da58c4489ed96bd3e62ea6..ddd2199fae58e2dd62409d554f51983b5bd1ef02 100644 --- a/Experiments/configurations/noneiv_wine.json +++ b/Experiments/configurations/noneiv_wine.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 30, - "epoch_offset": 50, + "std_y_update_points": 50, "noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/configurations/noneiv_yacht.json b/Experiments/configurations/noneiv_yacht.json index edb775cc550ec8b2ca8ed32a4962cb64eea20e65..81a0d1459d5faf474a922d7916336c57c255c07f 100644 --- a/Experiments/configurations/noneiv_yacht.json +++ b/Experiments/configurations/noneiv_yacht.json @@ -9,7 +9,7 @@ "report_point": 5, "p": 0.2, "lr_update": 200, - "epoch_offset": 20, + "std_y_update_points": [1,500], "noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_batches": 10, "init_std_y_list": [0.5], diff --git a/Experiments/evaluate_tabular.py b/Experiments/evaluate_tabular.py index 7493e13b37fd417bb00fec04873d740a79781950..ba0654f071cdb96cfa72145fe048da4b088bc16b 100644 --- a/Experiments/evaluate_tabular.py +++ b/Experiments/evaluate_tabular.py @@ -144,7 +144,7 @@ def collect_metrics(x,y, seed=0, noise_state = net.noise_is_on net.train() net.noise_on() - not_averaged_predictions = net.predict(x, number_of_draws=noneiv_number_of_draws, + not_averaged_predictions = net.predict(x, number_of_draws=eiv_number_of_draws, take_average_of_prediction=False) eiv_mean = torch.mean(not_averaged_predictions[0], dim=1) if len(y.shape) <= 1: diff --git a/Experiments/train_eiv.py b/Experiments/train_eiv.py index 78f6d61d5f0af1d0e29de0f10f99275db40f42c4..eb8d9d88292dd63130b38bc78db4eb1ed1db9014 100644 --- a/Experiments/train_eiv.py +++ b/Experiments/train_eiv.py @@ -40,10 +40,11 @@ report_point = conf_dict["report_point"] p = conf_dict["p"] lr_update = conf_dict["lr_update"] # offset before updating sigma_y after each epoch -epoch_offset = conf_dict["epoch_offset"] +std_y_update_points = conf_dict["std_y_update_points"] # will be used to predict the RMSE and update sigma_y accordingly eiv_prediction_number_of_draws = conf_dict["eiv_prediction_number_of_draws"] -eiv_prediction_number_of_batches = conf_dict["eiv_prediction_number_of_batches"] +eiv_prediction_number_of_batches = \ + conf_dict["eiv_prediction_number_of_batches"] init_std_y_list = conf_dict["init_std_y_list"] fixed_std_x = conf_dict['fixed_std_x'] gamma = conf_dict["gamma"] @@ -54,7 +55,8 @@ print(f"Training on {long_dataname} data") try: gpu_number = conf_dict["gpu_number"] - device = torch.device(f'cuda:{gpu_number}' if torch.cuda.is_available() else 'cpu') + device = torch.device(f'cuda:{gpu_number}' if torch.cuda.is_available() + else 'cpu') except KeyError: device = torch.device('cpu') @@ -65,7 +67,7 @@ seed_list = range(seed_range[0], seed_range[1]) def set_seeds(seed): torch.backends.cudnn.benchmark = False - np.random.seed(seed) + np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) @@ -83,33 +85,71 @@ class UpdatedTrainEpoch(train_and_store.TrainEpoch): self.lr_scheduler = torch.optim.lr_scheduler.StepLR( self.optimizer, lr_update, gamma) + def update_std_y(self, net): + """ + Update the std_y of `net` via the RMSE of the prediction. + """ + net_train_state = net.training + net_noise_state = net.noise_is_on + net.train() + net.noise_on() + pred_collection = [] + y_collection = [] + for i, (x,y) in enumerate(self.train_dataloader): + if i >= eiv_prediction_number_of_batches: + break + if len(y.shape) <= 1: + y = y.view((-1,1)) + x,y = x.to(device), y.to(device) + pred, _ = net.predict(x, + number_of_draws=eiv_prediction_number_of_draws, + remove_graph = True, + take_average_of_prediction=True) + pred_collection.append(pred) + y_collection.append(y) + pred_collection = torch.cat(pred_collection, dim=0) + y_collection = torch.cat(y_collection, dim=0) + assert pred_collection.shape == y_collection.shape + rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2)) + net.change_std_y(rmse) + if not net_train_state: + net.eval() + if not net_noise_state: + net.noise_off() + + def check_if_update_std_y(self, epoch): + """ + Check whether to update std_y according to `epoch_number` and + `std_y_update_points`. If the later is an integer, after all epochs + greater than this number an update will be made (i.e. `True` will + be returned). If it is a list, only `epoch_number` greater than + `std_y_update_points[0]` that divide `std_y_update_points[1]` will + result in a True. + """ + if type(std_y_update_points) is int: + return epoch >= std_y_update_points + else: + assert type(std_y_update_points) is list + return epoch >= std_y_update_points[0]\ + and epoch % std_y_update_points[1] == 0 def post_epoch_update(self, net, epoch): """ Overwrites the corresponding method """ - if epoch >= epoch_offset: - pred_collection = [] - y_collection = [] - for i, (x,y) in enumerate(self.train_dataloader): - if i >= eiv_prediction_number_of_batches: - break - if len(y.shape) <= 1: - y = y.view((-1,1)) - x,y = x.to(device), y.to(device) - pred, _ = net.predict(x, - number_of_draws=eiv_prediction_number_of_draws, - remove_graph = True, - take_average_of_prediction=True) - pred_collection.append(pred) - y_collection.append(y) - pred_collection = torch.cat(pred_collection, dim=0) - y_collection = torch.cat(y_collection, dim=0) - assert pred_collection.shape == y_collection.shape - rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2)) - net.change_std_y(rmse) + if self.check_if_update_std_y(epoch): + self.update_std_y(net) self.lr_scheduler.step() + def post_train_update(self, net, epoch): + """ + Overwrites the corresponding method. If std_y of `net` was not updated + in the last training step, update it when finished with training. + `epoch` should be the number of the last training epoch. + """ + if not self.check_if_update_std_y(epoch): + self.update_std_y(net) + def extra_report(self, net, i): """ Overwrites the corresponding method @@ -119,7 +159,6 @@ class UpdatedTrainEpoch(train_and_store.TrainEpoch): rmse_chain.append(rmse) writer.add_scalar('RMSE', rmse, self.total_count) writer.add_scalar('std_y', self.last_std_y, self.total_count) - writer.add_scalar('RMSE:std_y', rmse/self.last_std_y, self.total_count) writer.add_scalar('train loss', self.last_train_loss, self.total_count) writer.add_scalar('test loss', self.last_test_loss, self.total_count) print(f'RMSE {rmse:.3f}') diff --git a/Experiments/train_noneiv.py b/Experiments/train_noneiv.py index e63dbf233062d7ee47465cd5cd50f5708271cd68..316dda1e8f5fbf42f46ae9135dadd5c803623df8 100644 --- a/Experiments/train_noneiv.py +++ b/Experiments/train_noneiv.py @@ -40,10 +40,11 @@ report_point = conf_dict["report_point"] p = conf_dict["p"] lr_update = conf_dict["lr_update"] # offset before updating sigma_y after each epoch -epoch_offset = conf_dict["epoch_offset"] +std_y_update_points = conf_dict["std_y_update_points"] # will be used to predict the RMSE and update sigma_y accordingly noneiv_prediction_number_of_draws = conf_dict["noneiv_prediction_number_of_draws"] -noneiv_prediction_number_of_batches = conf_dict["noneiv_prediction_number_of_batches"] +noneiv_prediction_number_of_batches = \ + conf_dict["noneiv_prediction_number_of_batches"] init_std_y_list = conf_dict["init_std_y_list"] gamma = conf_dict["gamma"] hidden_layers = conf_dict["hidden_layers"] @@ -53,7 +54,8 @@ print(f"Training on {long_dataname} data") try: gpu_number = conf_dict["gpu_number"] - device = torch.device(f'cuda:{gpu_number}' if torch.cuda.is_available() else 'cpu') + device = torch.device(f'cuda:{gpu_number}' if torch.cuda.is_available() + else 'cpu') except KeyError: device = torch.device('cpu') @@ -82,33 +84,67 @@ class UpdatedTrainEpoch(train_and_store.TrainEpoch): self.lr_scheduler = torch.optim.lr_scheduler.StepLR( self.optimizer, lr_update, gamma) + def update_std_y(self, net): + """ + Update the std_y of `net` via the RMSE of the prediction. + """ + net_train_state = net.training + net.train() + pred_collection = [] + y_collection = [] + for i, (x,y) in enumerate(self.train_dataloader): + if i >= noneiv_prediction_number_of_batches: + break + if len(y.shape) <= 1: + y = y.view((-1,1)) + x,y = x.to(device), y.to(device) + pred, _ = net.predict(x, + number_of_draws=noneiv_prediction_number_of_draws, + remove_graph = True, + take_average_of_prediction=True) + pred_collection.append(pred) + y_collection.append(y) + pred_collection = torch.cat(pred_collection, dim=0) + y_collection = torch.cat(y_collection, dim=0) + assert pred_collection.shape == y_collection.shape + rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2)) + net.change_std_y(rmse) + if not net_train_state: + net.eval() + + def check_if_update_std_y(self, epoch): + """ + Check whether to update std_y according to `epoch_number` and + `std_y_update_points`. If the later is an integer, after all epochs + greater than this number an update will be made (i.e. `True` will + be returned). If it is a list, only `epoch_number` greater than + `std_y_update_points[0]` that divide `std_y_update_points[1]` will + result in a True. + """ + if type(std_y_update_points) is int: + return epoch >= std_y_update_points + else: + assert type(std_y_update_points) is list + return epoch >= std_y_update_points[0]\ + and epoch % std_y_update_points[1] == 0 def post_epoch_update(self, net, epoch): """ Overwrites the corresponding method """ - if epoch >= epoch_offset: - pred_collection = [] - y_collection = [] - for i, (x,y) in enumerate(self.train_dataloader): - if i>= noneiv_prediction_number_of_batches: - break - if len(y.shape) <= 1: - y = y.view((-1,1)) - x,y = x.to(device), y.to(device) - pred, _ = net.predict(x, - number_of_draws=noneiv_prediction_number_of_draws, - remove_graph = True, - take_average_of_prediction=True) - pred_collection.append(pred) - y_collection.append(y) - pred_collection = torch.cat(pred_collection, dim=0) - y_collection = torch.cat(y_collection, dim=0) - assert pred_collection.shape == y_collection.shape - rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2)) - net.change_std_y(rmse) + if self.check_if_update_std_y(epoch): + self.update_std_y(net) self.lr_scheduler.step() + def post_train_update(self, net, epoch): + """ + Overwrites the corresponding method. If std_y of `net` was not updated + in the last training step, update it when finished with training. + `epoch` should be the number of the last training epoch. + """ + if not self.check_if_update_std_y(epoch): + self.update_std_y(net) + def extra_report(self, net, i): """ Overwrites the corresponding method @@ -118,7 +154,6 @@ class UpdatedTrainEpoch(train_and_store.TrainEpoch): rmse_chain.append(rmse) writer.add_scalar('RMSE', rmse, self.total_count) writer.add_scalar('std_y', self.last_std_y, self.total_count) - writer.add_scalar('RMSE:std_y', rmse/self.last_std_y, self.total_count) writer.add_scalar('train loss', self.last_train_loss, self.total_count) writer.add_scalar('test loss', self.last_test_loss, self.total_count) print(f'RMSE {rmse:.3f}')