import numpy as np import torch class TrainEpoch(): """ A simple implementation of the training during one epoch. The training is implemented in the __call__ method that returns collected samples of the train loss and test loss, averaged between two "report points", and the std_x and std_y returned by `std_x_map` and `std_y_map` (which should be **floats**). If `verbose` is `True` (the default) these values will be printed at "report points". :param train_dataloader: A torch.nn.utils.DataLoader (for train data) :param test_dataloader: A torch.nn.utils.DataLoader (for test data) :param criterion: A map that takes net, x, y and reg as input and returns a single-element torch.tensor :param std_y_map: Map without arguments that returns a float. :param std_x_map: Map without arguments that returns a float. :param lr: The learn rate (a positive float), defaults to 1e-3 :param reg: The regularization, defaults to 1.0 :param report_point: Positive integer, distance between two report points. Defaults to 100. :param verbose: If True (the default) prints information while training. :param device: Do training on this device (defaults to 'cpu') std_x and std_y are printed at each report point. """ def __init__(self, train_dataloader, test_dataloader, criterion, std_y_map, std_x_map=lambda: 0.0, lr=1e-3, reg=1.0, report_point=100, verbose=True, device='cpu'): self.train_dataloader = train_dataloader self.test_dataloader = test_dataloader self.initial_lr = lr self.criterion = criterion self.std_x_map = std_x_map self.std_y_map = std_y_map self.reg = reg self.report_point = report_point self.verbose = verbose self.device = device # self.lr_generator = iter(self.next_lr()) self.lr = None self.total_count = 0 def next_lr(self): while True: yield self.initial_lr def pre_epoch_update(self, net, epoch): """ Overwrite to update optimizer, the learn rate or similar in a different manner. *Note* This method is expected to define at least `self.optimizer`. :param net: The net to be used for the optimizer :param epoch: The current epoch, an integer. """ old_lr = self.lr self.lr = next(self.lr_generator) if old_lr != self.lr: self.optimizer = torch.optim.Adam(net.parameters(), lr=self.lr) def post_epoch_update(self, net, epoch): """ Overwrite for inheritance to update the net (e.g. its deming factor for EIV) after each epoch :param net: The current net, a torch.nn.Module. :param epoch: The current epoch, an integer. """ pass def post_train_update(self, net, epoch=None): """ Will be executed after the training is finished :param net: The current net, a torch.nn.Module :param epoch: Tue last epochn number, an integer. """ pass def extra_report(self, net, step): """ Overwrite for reporting on state of net at each report point """ pass def __call__(self, net, epoch): """ :param net: A torch.nn.Module :param epoch: The current epoch, a non-negative integer. Will only be used for formatting of the printing. """ self.pre_epoch_update(net, epoch) train_loss, test_loss = [], [] stored_train_loss, stored_test_loss = [], [] stored_std_x, stored_std_y = [], [] if self.verbose: print('>>>> Epoch %i' % (epoch,)) stored_train_loss_to_average = [] stored_test_loss_to_average = [] for i, (x,y) in enumerate(self.train_dataloader): self.total_count += 1 # optimize on train data x, y = x.to(self.device), y.to(self.device) loss = self.criterion(net, x, y, self.reg) loss.backward() self.optimizer.step() net.zero_grad() stored_train_loss_to_average.append(loss.detach().cpu().item()) if i % self.report_point == 0: # evaluate on test_data for j, (x,y) in enumerate(self.test_dataloader): x,y = x.to(self.device), y.to(self.device) loss = self.criterion(net, x, y, self.reg).detach().cpu().item() stored_test_loss_to_average.append(loss) std_x = self.std_x_map() std_y = self.std_y_map() # store values stored_train_loss.append( np.mean(stored_train_loss_to_average)) stored_test_loss.append( np.mean(stored_test_loss_to_average)) stored_std_x.append(std_x) stored_std_y.append(std_y) if i>0 and self.verbose: print('Step %i,'\ ' train_loss: %.2f,'\ ' test_loss: %.2f,'\ ' std_x: %.2f,' ' std_y: %.2f' % (i, stored_train_loss[-1], stored_test_loss[-1], std_x, std_y )) # to be used for extra reporting self.last_train_loss = stored_train_loss[-1] self.last_test_loss = stored_test_loss[-1] self.last_std_x = std_x self.last_std_y = std_y # extra reporting self.extra_report(net, i) stored_train_loss_to_average = [] stored_test_loss_to_average = [] self.post_epoch_update(net, epoch) # convert to tensors stored_train_loss = torch.tensor(stored_train_loss, dtype=torch.float32) stored_test_loss = torch.tensor(stored_test_loss, dtype=torch.float32) stored_std_x = torch.tensor(stored_std_x, dtype=torch.float32) stored_std_y = torch.tensor(stored_std_y, dtype=torch.float32) return stored_train_loss,\ stored_test_loss,\ stored_std_x,\ stored_std_y def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs): """ Calls `epoch_map` with `epoch` and the current epoch number `number_of_epochs` times and stores a list of its output as a pickled file under `save_file`. After this (the training) is done, `epoch_map.post_train_update(net, epoch_number)` is called, if existent. **Note**: The output of `epoch_map` is supposed to consist of 4 specific arguments, see below. :param net: A torch.nn.Module :param epoch_map: A map taking net and an integer (the epoch number) as an input and returning 4 torch.tensors containing the evolution of train_loss, test_loss, std_x, std_y during one epoch. :param save_file: A string containing the path to the file to be pickled. :param kwargs: keywords and their values will also be store after the training. Use lists to ensure they are updated. """ std_x_collection = [] std_y_collection = [] train_loss_collection = [] test_loss_collection = [] std_x_collection = [] std_x_collection = [] # Saving for epoch in range(number_of_epochs): train_loss, test_loss, std_x, std_y = epoch_map(net, epoch) train_loss_collection.append(train_loss) test_loss_collection.append(test_loss) std_x_collection.append(std_x) std_y_collection.append(std_y) try: epoch_map.post_train_update(net, number_of_epochs-1) print('Training done. Performed post training updating.') except AttributeError: print('Training done. No post train updating performed.') # Saving state_dict = net.state_dict() to_save = { 'train_loss' : train_loss_collection, 'test_loss' : test_loss_collection, 'std_x' : std_x_collection, 'std_y' : std_y_collection, 'state_dict' : state_dict } for key, value in kwargs.items(): to_save[key] = value with open(save_file, 'wb') as file_handle: torch.save(to_save, file_handle) def open_stored_training(saved_file, net=None, join_epochs = True, extra_keys = None, device=torch.device('cpu')): """ Counterpart to `train_and_store`, opens `saved_file` and returns :param saved_file: A pickle file generated with train_and_store :param net: The corresponding torch.nn.Module :param join_epochs: If True (default), all epochs will be concatenated. :param extra_keys: None (default) or a list of strings. If the latter, is used to extract extra entries of the stored dictionary to be returned as a list. :param device: The device to be used for loading the state_dict. """ with open(saved_file, 'rb') as file_handle: loaded_dict = torch.load(file_handle, map_location=device) train_loss = loaded_dict['train_loss'] test_loss = loaded_dict['test_loss'] std_x = loaded_dict['std_x'] std_y = loaded_dict['std_y'] state_dict = loaded_dict['state_dict'] if net is not None: net.load_state_dict(state_dict) if join_epochs: train_loss = torch.cat(train_loss, dim=0) test_loss = torch.cat(test_loss, dim=0) std_x = torch.cat(std_x, dim=0) std_y = torch.cat(std_y, dim=0) if extra_keys is None: return train_loss, test_loss, std_x, std_y, state_dict else: extra_list = [loaded_dict[key] for key in extra_keys] return train_loss, test_loss, std_x, std_y, state_dict, extra_list