From 6e255c64ba8fdeacff91faf028fbaf8eff45a10b Mon Sep 17 00:00:00 2001
From: Joerg Martin <joerg.martin@ptb.de>
Date: Tue, 14 Dec 2021 13:06:29 +0100
Subject: [PATCH] Add post-training update

---
 EIVPackage/EIVArchitectures/Networks.py       |  3 +-
 .../EIVTrainingRoutines/train_and_store.py    | 16 ++++
 Experiments/train_eiv.py                      | 93 ++++++++++++-------
 Experiments/train_noneiv.py                   | 89 +++++++++++-------
 4 files changed, 130 insertions(+), 71 deletions(-)

diff --git a/EIVPackage/EIVArchitectures/Networks.py b/EIVPackage/EIVArchitectures/Networks.py
index d9c0f1b..22f48bb 100644
--- a/EIVPackage/EIVArchitectures/Networks.py
+++ b/EIVPackage/EIVArchitectures/Networks.py
@@ -32,7 +32,7 @@ class FNNEIV(nn.Module):
     **Note**:
     - To change the deming factor afterwards, use the method `change_deming`
     - To change fixed_std_x afterwards, use the method `change_fixed_std_x`
-    - To change std_y use the method `change_std_x`
+    - To change std_y use the method `change_std_y`
     """
     LeakyReLUSlope = 1e-2
     def __init__(self, p = 0.2, init_std_y=1.0, precision_prior_zeta=0.0,
@@ -349,6 +349,7 @@ class FNNBer(nn.Module):
     :param h: A list specifying the number of neurons in each layer.
     :param std_y_requires_grad: Whether `sigma_y` will require_grad and thus
     be updated during optimization. Defaults to False.
+    To change std_y use the method `change_std_y`
     """
     LeakyReLUSlope = 1e-2
     def __init__(self, p=0.2, init_std_y=1.0, h=[10, 1024,1024,1024,1024, 1],
diff --git a/EIVPackage/EIVTrainingRoutines/train_and_store.py b/EIVPackage/EIVTrainingRoutines/train_and_store.py
index eb0378e..908c445 100644
--- a/EIVPackage/EIVTrainingRoutines/train_and_store.py
+++ b/EIVPackage/EIVTrainingRoutines/train_and_store.py
@@ -69,6 +69,15 @@ class TrainEpoch():
         """
         pass
 
+
+    def post_train_update(self, net, epoch=None):
+        """
+        Will be executed after the training is finished.
+        :param net: The current net, a torch.nn.Module
+        :param epoch: The last epoch number, an integer.
+        """
+        pass
+
     def extra_report(self, net, step):
         """
         Overwrite for reporting on state of net
@@ -156,6 +165,8 @@ def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs):
     Calls `epoch_map` with `epoch` and the current epoch number
     `number_of_epochs` times and stores a list of its output as a pickled file
     under `save_file`.
+    After the training is done, `epoch_map.post_train_update(net,
+    epoch_number)` is called, if it exists.
     **Note**: The output of `epoch_map` is supposed to consist of 4 specific
     arguments, see below.
     :param net: A torch.nn.Module
@@ -179,6 +190,11 @@ def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs):
         test_loss_collection.append(test_loss)
         std_x_collection.append(std_x)
         std_y_collection.append(std_y)
+    try:
+        epoch_map.post_train_update(net, number_of_epochs-1)
+        print('Training done. Performed post-training update.')
+    except AttributeError:
+        print('Training done. No post-training update performed.')
     # Saving
     state_dict = net.state_dict()
     to_save = {
diff --git a/Experiments/train_eiv.py b/Experiments/train_eiv.py
index 2ab2dce..eb8d9d8 100644
--- a/Experiments/train_eiv.py
+++ b/Experiments/train_eiv.py
@@ -85,48 +85,71 @@ class UpdatedTrainEpoch(train_and_store.TrainEpoch):
 
         self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                 self.optimizer, lr_update, gamma)
+    def update_std_y(self, net):
+        """
+        Update the std_y of `net` via the RMSE of the prediction.
+        """
+        net_train_state = net.training
+        net_noise_state = net.noise_is_on
+        net.train()
+        net.noise_on()
+        pred_collection = []
+        y_collection = []
+        for i, (x,y) in enumerate(self.train_dataloader):
+            if i >= eiv_prediction_number_of_batches:
+                break
+            if len(y.shape) <= 1:
+                y = y.view((-1,1))
+            x,y = x.to(device), y.to(device)
+            pred, _ = net.predict(x,
+                    number_of_draws=eiv_prediction_number_of_draws,
+                    remove_graph = True,
+                    take_average_of_prediction=True)
+            pred_collection.append(pred)
+            y_collection.append(y)
+        pred_collection = torch.cat(pred_collection, dim=0)
+        y_collection = torch.cat(y_collection, dim=0)
+        assert pred_collection.shape == y_collection.shape
+        rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
+        net.change_std_y(rmse)
+        if not net_train_state:
+            net.eval()
+        if not net_noise_state:
+            net.noise_off()
+
+    def check_if_update_std_y(self, epoch):
+        """
+        Check whether to update std_y according to `epoch` and
+        `std_y_update_points`. If the latter is an integer, an update will be
+        made after every epoch greater than or equal to this number (i.e.
+        `True` will be returned). If it is a list, only an `epoch` greater
+        than or equal to `std_y_update_points[0]` and divisible by
+        `std_y_update_points[1]` will result in `True`.
+        """
+        if type(std_y_update_points) is int:
+            return epoch >= std_y_update_points
+        else:
+            assert type(std_y_update_points) is list
+            return epoch >= std_y_update_points[0]\
+                    and epoch % std_y_update_points[1] == 0
 
     def post_epoch_update(self, net, epoch):
         """
         Overwrites the corresponding method
         """
-        def update_std_y(epoch_number):
-            """
-            Check whether to update std_y according to `epoch_number` and
-            `std_y_update_points`. If the later is an integer, after all epochs
-            greater than this number an update will be made (i.e. `True` will
-            be returned). If it is a list, only `epoch_number` greater than
-            `std_y_update_points[0]` that divide `std_y_update_points[1]` will
-            result in a True.
-            """
-            if type(std_y_update_points) is int:
-                return epoch >= std_y_update_points
-            else:
-                assert type(std_y_update_points) is list
-                return epoch_number >= std_y_update_points[0]\
-                        and epoch_number % std_y_update_points[1] == 0
-        if update_std_y(epoch):
-            pred_collection = []
-            y_collection = []
-            for i, (x,y) in enumerate(self.train_dataloader):
-                if i >= eiv_prediction_number_of_batches:
-                    break
-                if len(y.shape) <= 1:
-                    y = y.view((-1,1))
-                x,y = x.to(device), y.to(device)
-                pred, _ = net.predict(x,
-                        number_of_draws=eiv_prediction_number_of_draws,
-                        remove_graph = True,
-                        take_average_of_prediction=True)
-                pred_collection.append(pred)
-                y_collection.append(y)
-            pred_collection = torch.cat(pred_collection, dim=0)
-            y_collection = torch.cat(y_collection, dim=0)
-            assert pred_collection.shape == y_collection.shape
-            rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
-            net.change_std_y(rmse)
+        if self.check_if_update_std_y(epoch):
+            self.update_std_y(net)
         self.lr_scheduler.step()
 
+    def post_train_update(self, net, epoch):
+        """
+        Overwrites the corresponding method. If std_y of `net` was not updated
+        in the last training epoch, update it once training has finished.
+        `epoch` should be the number of the last training epoch.
+        """
+        if not self.check_if_update_std_y(epoch):
+            self.update_std_y(net)
+
     def extra_report(self, net, i):
         """
         Overwrites the corresponding method
diff --git a/Experiments/train_noneiv.py b/Experiments/train_noneiv.py
index 1f27ceb..316dda1 100644
--- a/Experiments/train_noneiv.py
+++ b/Experiments/train_noneiv.py
@@ -84,48 +84,67 @@ class UpdatedTrainEpoch(train_and_store.TrainEpoch):
 
         self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                 self.optimizer, lr_update, gamma)
+    def update_std_y(self, net):
+        """
+        Update the std_y of `net` via the RMSE of the prediction.
+        """
+        net_train_state = net.training
+        net.train()
+        pred_collection = []
+        y_collection = []
+        for i, (x,y) in enumerate(self.train_dataloader):
+            if i >= noneiv_prediction_number_of_batches:
+                break
+            if len(y.shape) <= 1:
+                y = y.view((-1,1))
+            x,y = x.to(device), y.to(device)
+            pred, _ = net.predict(x,
+                    number_of_draws=noneiv_prediction_number_of_draws,
+                    remove_graph = True,
+                    take_average_of_prediction=True)
+            pred_collection.append(pred)
+            y_collection.append(y)
+        pred_collection = torch.cat(pred_collection, dim=0)
+        y_collection = torch.cat(y_collection, dim=0)
+        assert pred_collection.shape == y_collection.shape
+        rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
+        net.change_std_y(rmse)
+        if not net_train_state:
+            net.eval()
+
+    def check_if_update_std_y(self, epoch):
+        """
+        Check whether to update std_y according to `epoch` and
+        `std_y_update_points`. If the latter is an integer, an update will be
+        made after every epoch greater than or equal to this number (i.e.
+        `True` will be returned). If it is a list, only an `epoch` greater
+        than or equal to `std_y_update_points[0]` and divisible by
+        `std_y_update_points[1]` will result in `True`.
+        """
+        if type(std_y_update_points) is int:
+            return epoch >= std_y_update_points
+        else:
+            assert type(std_y_update_points) is list
+            return epoch >= std_y_update_points[0]\
+                    and epoch % std_y_update_points[1] == 0
 
     def post_epoch_update(self, net, epoch):
         """
         Overwrites the corresponding method
         """
-        def update_std_y(epoch_number):
-            """
-            Check whether to update std_y according to `epoch_number` and
-            `std_y_update_points`. If the later is an integer, after all epochs
-            greater than this number an update will be made (i.e. `True` will
-            be returned). If it is a list, only `epoch_number` greater than
-            `std_y_update_points[0]` that divide `std_y_update_points[1]` will
-            result in a True.
-            """
-            if type(std_y_update_points) is int:
-                return epoch >= std_y_update_points
-            else:
-                assert type(std_y_update_points) is list
-                return epoch_number >= std_y_update_points[0]\
-                        and epoch_number % std_y_update_points[1] == 0
-        if update_std_y(epoch):
-            pred_collection = []
-            y_collection = []
-            for i, (x,y) in enumerate(self.train_dataloader):
-                if i >= noneiv_prediction_number_of_batches:
-                    break
-                if len(y.shape) <= 1:
-                    y = y.view((-1,1))
-                x,y = x.to(device), y.to(device)
-                pred, _ = net.predict(x,
-                        number_of_draws=noneiv_prediction_number_of_draws,
-                        remove_graph = True,
-                        take_average_of_prediction=True)
-                pred_collection.append(pred)
-                y_collection.append(y)
-            pred_collection = torch.cat(pred_collection, dim=0)
-            y_collection = torch.cat(y_collection, dim=0)
-            assert pred_collection.shape == y_collection.shape
-            rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
-            net.change_std_y(rmse)
+        if self.check_if_update_std_y(epoch):
+            self.update_std_y(net)
         self.lr_scheduler.step()
 
+    def post_train_update(self, net, epoch):
+        """
+        Overwrites the corresponding method. If std_y of `net` was not updated
+        in the last training epoch, update it once training has finished.
+        `epoch` should be the number of the last training epoch.
+        """
+        if not self.check_if_update_std_y(epoch):
+            self.update_std_y(net)
+
     def extra_report(self, net, i):
         """
         Overwrites the corresponding method
-- 
GitLab
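The hook mechanism added in train_and_store.py is easiest to see in isolation. The snippet below is a minimal, self-contained sketch of that dispatch and is not code from the package: MinimalEpochMap and run_training are illustrative stand-ins for TrainEpoch and train_and_store, mirroring only the try/except AttributeError call of post_train_update after the last epoch.

# Illustrative sketch only: `MinimalEpochMap` and `run_training` are simplified
# stand-ins for `TrainEpoch` and `train_and_store`; they mirror the dispatch
# added by this patch but are not the package code itself.

class MinimalEpochMap:
    """Epoch map that defines the optional post-training hook."""
    def __call__(self, net, epoch):
        # one epoch of "training" (omitted in this sketch)
        pass

    def post_train_update(self, net, epoch=None):
        print(f'post_train_update called after epoch {epoch}')


def run_training(net, epoch_map, number_of_epochs):
    for epoch in range(number_of_epochs):
        epoch_map(net, epoch)
    # Mirrors the patched train_and_store: call the hook if it exists;
    # otherwise the AttributeError is caught and training simply ends.
    try:
        epoch_map.post_train_update(net, number_of_epochs - 1)
        print('Training done. Performed post-training update.')
    except AttributeError:
        print('Training done. No post-training update performed.')


if __name__ == '__main__':
    run_training(net=None, epoch_map=MinimalEpochMap(), number_of_epochs=3)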
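The scheduling rule implemented by the new check_if_update_std_y can also be read as a standalone function. The helper below is written for illustration only (the name should_update_std_y and the example values are not part of the patch) and assumes the same semantics: an integer std_y_update_points means std_y is updated after every epoch from that number on, while a list [start, interval] means it is updated only after epochs that are at least start and divisible by interval; post_train_update then covers the case where the final epoch happens to be skipped.

from typing import List, Union

def should_update_std_y(epoch: int,
        std_y_update_points: Union[int, List[int]]) -> bool:
    """Illustrative re-implementation of the patched check_if_update_std_y."""
    if type(std_y_update_points) is int:
        # update after every epoch >= std_y_update_points
        return epoch >= std_y_update_points
    else:
        assert type(std_y_update_points) is list
        # update only after epochs >= start that are divisible by the interval
        return epoch >= std_y_update_points[0] \
                and epoch % std_y_update_points[1] == 0

# With std_y_update_points = 10, every epoch from 10 on triggers an update.
assert [e for e in range(15) if should_update_std_y(e, 10)] == [10, 11, 12, 13, 14]
# With std_y_update_points = [10, 5], only epochs 10 and 15 do (of the first 16),
# so a final epoch such as 13 would be skipped and left to post_train_update.
assert [e for e in range(16) if should_update_std_y(e, [10, 5])] == [10, 15]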