Skip to content
Snippets Groups Projects
train_and_store.py 9.76 KiB
Newer Older
Jörg Martin's avatar
Jörg Martin committed
import numpy as np
import torch

class TrainEpoch():
    """
    A simple implementation of the training during one epoch.
    The training is implemented in the __call__ method that returns collected
    samples of the train loss and test loss, averaged between
    two "report points", and the std_x and std_y returned by `std_x_map` and
    `std_y_map` (which should be **floats**).
    If `verbose` is `True` (the default) these values will be printed at
    "report points" (the report point at step 0 is stored but not printed).
    :param train_dataloader: A torch.utils.data.DataLoader (for train data)
    :param test_dataloader: A torch.utils.data.DataLoader (for test data)
    :param criterion: A map that takes net, x, y and reg as input and returns
    a single-element torch.tensor
    :param std_y_map: Map without arguments that returns a float.
    :param std_x_map: Map without arguments that returns a float. Defaults to
    a map that always returns 0.0.
    :param lr: The learning rate (a positive float), defaults to 1e-3
    :param reg: The regularization, defaults to 1.0
    :param report_point: Positive integer, distance (in train batches) between
    two report points. Defaults to 100.
    :param verbose: If True (the default) prints information while training.
    :param device: Do training on this device (defaults to 'cpu')
    std_x and std_y are printed at each report point.
    """
    def __init__(self, train_dataloader, test_dataloader,
            criterion, std_y_map, std_x_map=lambda: 0.0, lr=1e-3,
            reg=1.0, report_point=100, verbose=True, device='cpu'):
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.initial_lr = lr
        self.criterion = criterion
        self.std_x_map = std_x_map
        self.std_y_map = std_y_map
        self.reg = reg
        self.report_point = report_point
        self.verbose = verbose
        self.device = device
        # One learning rate is drawn per epoch in `pre_epoch_update`. `iter`
        # keeps this working for subclasses whose `next_lr` returns a plain
        # iterable (e.g. a list) instead of a generator.
        self.lr_generator = iter(self.next_lr())
        self.lr = None
        # The net the current `self.optimizer` was built for.
        self._optimized_net = None

    def next_lr(self):
        """
        Source of the per-epoch learning rates, consumed once per epoch by
        `pre_epoch_update`. The default yields `self.initial_lr` forever;
        overwrite to use a different sequence of learning rates.
        """
        while True:
            yield self.initial_lr

    def pre_epoch_update(self, net, epoch):
        """
        Overwrite to update optimizer, the learning rate or similar in a
        different manner.
        *Note* This method is expected to define at least `self.optimizer`.
        The default draws the next learning rate from `self.lr_generator` and
        creates a fresh Adam optimizer for `net` whenever the learning rate
        changed or `net` is not the net the current optimizer was built for.
        A fresh optimizer starts with empty Adam state.
        :param net: The net to be used for the optimizer
        :param epoch: The current epoch, an integer.
        """
        old_lr = self.lr
        self.lr = next(self.lr_generator)
        # Also rebuild on a new net: an optimizer created for an earlier net
        # would keep stepping that net's parameters and leave `net` untrained.
        # getattr guards subclasses that skip this class's __init__.
        net_changed = net is not getattr(self, '_optimized_net', None)
        if old_lr != self.lr or net_changed:
            self.optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
            self._optimized_net = net

    def post_epoch_update(self, net, epoch):
        """
        Overwrite for inheritance to update the net (e.g.
        its deming factor for EIV) after each epoch
        :param net: The current net, a torch.nn.Module.
        :param epoch: The current epoch, an integer.
        """
        pass

    def post_train_update(self, net, epoch=None):
        """
        Will be executed after the training is finished
        :param net: The current net, a torch.nn.Module
        :param epoch: The last epoch number, an integer.
        """
        pass

    def extra_report(self, net, step):
        """
        Overwrite for reporting on state of net
        at each report point. At call time `self.last_train_loss`,
        `self.last_test_loss`, `self.last_std_x` and `self.last_std_y`
        hold the values of the current report point.
        :param net: The current net, a torch.nn.Module
        :param step: The index of the current train batch, an integer.
        """
        pass

    def __call__(self, net, epoch):
        """
        Train `net` for one pass over `self.train_dataloader`.
        :param net: A torch.nn.Module
        :param epoch: The current epoch, a non-negative integer.
         Will only be used for formatting of the printing.
        :returns: Four float32 torch.tensors (train_loss, test_loss, std_x,
         std_y) with one entry per report point of this epoch.
        """
        self.pre_epoch_update(net, epoch)
        stored_train_loss, stored_test_loss = [], []
        stored_std_x, stored_std_y = [], []
        if self.verbose:
            print('>>>> Epoch %i' % (epoch,))
        stored_train_loss_to_average = []
        stored_test_loss_to_average = []
        for i, (x, y) in enumerate(self.train_dataloader):
            # optimize on train data
            x, y = x.to(self.device), y.to(self.device)
            loss = self.criterion(net, x, y, self.reg)
            loss.backward()
            self.optimizer.step()
            net.zero_grad()
            stored_train_loss_to_average.append(loss.detach().cpu().item())
            if i % self.report_point == 0:
                # evaluate on the complete test data
                for x_test, y_test in self.test_dataloader:
                    x_test = x_test.to(self.device)
                    y_test = y_test.to(self.device)
                    test_loss = self.criterion(net, x_test, y_test,
                            self.reg).detach().cpu().item()
                    stored_test_loss_to_average.append(test_loss)
                std_x = self.std_x_map()
                std_y = self.std_y_map()
                # store the averages since the previous report point
                stored_train_loss.append(
                        np.mean(stored_train_loss_to_average))
                stored_test_loss.append(
                        np.mean(stored_test_loss_to_average))
                stored_std_x.append(std_x)
                stored_std_y.append(std_y)
                if i > 0 and self.verbose:
                    print('Step %i,'
                            ' train_loss: %.2f,'
                            ' test_loss: %.2f,'
                            ' std_x: %.2f,'
                            ' std_y: %.2f' % (i,
                               stored_train_loss[-1],
                               stored_test_loss[-1],
                               std_x,
                               std_y
                               ))
                # to be used for extra reporting
                self.last_train_loss = stored_train_loss[-1]
                self.last_test_loss = stored_test_loss[-1]
                self.last_std_x = std_x
                self.last_std_y = std_y
                # extra reporting
                self.extra_report(net, i)
                stored_train_loss_to_average = []
                stored_test_loss_to_average = []
        self.post_epoch_update(net, epoch)
        # convert to tensors
        return (torch.tensor(stored_train_loss, dtype=torch.float32),
                torch.tensor(stored_test_loss, dtype=torch.float32),
                torch.tensor(stored_std_x, dtype=torch.float32),
                torch.tensor(stored_std_y, dtype=torch.float32))



def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs):
    """
    Calls `epoch_map` with `net` and the current epoch number
    `number_of_epochs` times and stores lists of its outputs, together with
    the final `net.state_dict()`, via `torch.save` under `save_file`.
    After this (the training) is done, `epoch_map.post_train_update(net,
    number_of_epochs-1)` is called, if `epoch_map` has such an attribute.
    Exceptions raised inside `post_train_update` are propagated.
    **Note**: The output of `epoch_map` is supposed to
    consist of 4 specific arguments, see below.
    :param net: A torch.nn.Module
    :param epoch_map: A map taking net and an integer (the epoch number) as an
    input and returning 4 torch.tensors containing the evolution of
    train_loss, test_loss, std_x, std_y during one epoch.
    :param number_of_epochs: Non-negative integer, how often `epoch_map` is
    called.
    :param save_file: A string containing the path to the file to be pickled.
    :param kwargs: keywords and their values will
    also be stored after the training. Use lists to ensure they are updated.
    The stored dictionary has the keys 'train_loss', 'test_loss', 'std_x',
    'std_y' (lists with one entry per epoch) and 'state_dict', plus the keys
    of `kwargs` (which overwrite the former on a name clash).
    """
    train_loss_collection = []
    test_loss_collection = []
    std_x_collection = []
    std_y_collection = []
    # Training
    for epoch in range(number_of_epochs):
        train_loss, test_loss, std_x, std_y = epoch_map(net, epoch)
        train_loss_collection.append(train_loss)
        test_loss_collection.append(test_loss)
        std_x_collection.append(std_x)
        std_y_collection.append(std_y)
    # Look the hook up first and call it outside of any try/except, so that
    # an AttributeError raised *inside* post_train_update is not mistaken for
    # "no hook available" and silently swallowed.
    post_train_update = getattr(epoch_map, 'post_train_update', None)
    if post_train_update is None:
        print('Training done. No post train updating performed.')
    else:
        post_train_update(net, number_of_epochs-1)
        print('Training done. Performed post training updating.')
    # Saving
    to_save = {
            'train_loss' : train_loss_collection,
            'test_loss' : test_loss_collection,
            'std_x' : std_x_collection,
            'std_y' : std_y_collection,
            'state_dict' : net.state_dict()
            }
    to_save.update(kwargs)
    with open(save_file, 'wb') as file_handle:
        torch.save(to_save, file_handle)


def open_stored_training(saved_file, net=None, join_epochs = True,
        extra_keys = None, device=torch.device('cpu')):
    """
    Counterpart to `train_and_store`, opens `saved_file` and returns the
    stored training history and state dict.
    :param saved_file: A pickle file generated with train_and_store
    :param net: The corresponding torch.nn.Module. If given, the stored
    state_dict is loaded into it.
    :param join_epochs: If True (default), the per-epoch tensors of each
    quantity are concatenated into one tensor (an empty tensor if no epoch
    was stored); if False, the per-epoch lists are returned as stored.
    :param extra_keys: None (default) or a list of strings. If the latter, is
    used to extract extra entries of the stored dictionary to be returned
    as a list.
    :param device: The device to be used for loading the state_dict.
    :returns: train_loss, test_loss, std_x, std_y, state_dict and, if
    `extra_keys` is not None, additionally the list of extra entries.
    """
    # NOTE(review): torch.load unpickles the file, so only open trusted
    # files. Depending on the torch version, its `weights_only` default may
    # reject arbitrary objects that were stored via train_and_store's kwargs.
    with open(saved_file, 'rb') as file_handle:
        loaded_dict = torch.load(file_handle, map_location=device)
    train_loss = loaded_dict['train_loss']
    test_loss = loaded_dict['test_loss']
    std_x = loaded_dict['std_x']
    std_y = loaded_dict['std_y']
    state_dict = loaded_dict['state_dict']
    if net is not None:
        net.load_state_dict(state_dict)

    def _join(per_epoch):
        # torch.cat rejects an empty list, e.g. for a zero-epoch training.
        if len(per_epoch) == 0:
            return torch.empty(0)
        return torch.cat(per_epoch, dim=0)

    if join_epochs:
        train_loss = _join(train_loss)
        test_loss = _join(test_loss)
        std_x = _join(std_x)
        std_y = _join(std_y)
    if extra_keys is None:
        return train_loss, test_loss, std_x, std_y, state_dict
    extra_list = [loaded_dict[key] for key in extra_keys]
    return train_loss, test_loss, std_x, std_y, state_dict, extra_list