# train_noneiv_multinomial_ensemble_seed.py
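#
# Trains non-EIV networks on multinomial data for an ensemble of random seeds
# and a range of input noise levels `std_x`; the trained networks and RMSE
# traces are stored via `train_and_store`.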
import random
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from EIVArchitectures import Networks
from generate_multinomial_data import get_data
from EIVTrainingRoutines import train_and_store, loss_functions

# hyperparameters
lr = 1e-3
batch_size = 200
number_of_epochs = 350
reg = 1e-6
report_point = 40
precision_prior_zeta = 0.0
n_train = 100000
dim = 5
p = 0.5
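# the learning rate is decayed by a factor of 0.1 every `lr_update` epochs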
lr_update = 300
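# std_y is kept fixed during the first `pretraining` epochs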
pretraining = 200
epoch_offset = pretraining
std_x_list = [0.05, 0.07, 0.10]
deming_scale_list = [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5]
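# note: `deming_scale_list` is not referenced below; it is presumably consumed
# by the companion EIV training scripts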
init_std_y_list = [0.15]
std_y = 0.30
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# reproducibility
torch.backends.cudnn.benchmark = False  # avoid non-deterministic cuDNN autotuning
def set_seeds(seed):
    np.random.seed(seed)
    random.seed(seed) 
    torch.manual_seed(seed)
seed_list = range(20, 20 + 20*5)  # 100 seeds: 20, ..., 119

# collects the validation RMSE values reported during training
rmse_chain = []

class UpdatedTrainEpoch(train_and_store.TrainEpoch):
    def pre_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method
        """
        if epoch == 0:
            self.lr = self.initial_lr
            self.optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, lr_update, 0.1
            )


    def post_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method
        """
        if epoch >= epoch_offset:
            net.std_y_par.requires_grad = True
        self.lr_scheduler.step()

    def extra_report(self, net, i):
        """
        Overwrites the corresponding method
        **Note**: self.val_data_pure has to be defined explicitely
        and fed after initialiaztion of this class
        """
        rmse = self.rmse(net).item()
        rmse_chain.append(rmse)
        print('RMSE %.2f', rmse)

    def rmse(self, net):
        """
        Compute the root mean squared error for `net`
        """
        mse = 0 
        net_train_state = net.training
        net.eval()
        x, y = self.val_data_pure
        out = net(x.to(device))[0].detach().cpu()
        if net_train_state:
            net.train()
        return torch.sqrt(torch.mean((out-y)**2))


def train_on_data(std_x, init_std_y, seed):
    """
    Loads data associated with `std_x` and trains an Bernoulli Modell.
    """
    # load Datasets
    _, train_data, _, test_data, val_data_pure, _, _ =\
            get_data(std_x=std_x, std_y=std_y, dim=dim, n_train=n_train)
    train_data = TensorDataset(*train_data)
    test_data = TensorDataset(*test_data)
    set_seeds(seed)
    # make to dataloader
    train_dataloader = DataLoader(train_data, batch_size=batch_size, 
            shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size,
            shuffle=True)
    # Create a net
    net = Networks.FNNBer(p=p, init_std_y=init_std_y,
            h=[dim,500,300,100,1])
    net = net.to(device)
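    # std_y stays frozen until `epoch_offset` (unfrozen in post_epoch_update)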
    net.std_y_par.requires_grad = False
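    # noise-level maps handed to the training routine; the non-EIV model
    # presumably assumes no input noise, hence std_x_map returns 0.0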
    std_x_map = lambda: 0.0
    std_y_map = lambda: net.get_std_y().detach().cpu().item()
    # Create epoch_map
    criterion = loss_functions.nll_reg_loss
    epoch_map = UpdatedTrainEpoch(train_dataloader=train_dataloader,
            test_dataloader=test_dataloader,
            criterion=criterion, std_y_map=std_y_map, std_x_map=std_x_map,
            lr=lr, reg=reg, report_point=report_point, device=device)
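    # attach the validation data required by extra_report / rmse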
    epoch_map.val_data_pure = val_data_pure
    # run and save
    save_file = os.path.join('saved_networks','noneiv_multinomial_std_x_%.3f'\
            '_std_y_%.3f_init_std_y_%.3f_ensemble_seed_%i.pkl'\
            % (std_x, std_y, init_std_y, seed))
    train_and_store.train_and_store(net=net, 
            epoch_map=epoch_map,
            number_of_epochs=number_of_epochs,
            save_file=save_file,
            rmse=rmse_chain)
    

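# train one network per (seed, init_std_y, std_x) combination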
if __name__ == '__main__':
    for seed in seed_list:
        print('SEED: %i' % (seed,))
        for init_std_y in init_std_y_list:
            for std_x in std_x_list:
                print('->->Using std_x=%.2f and init_std_y=%.2f<-<-<-<-'
                        %(std_x, init_std_y))
                rmse_chain.clear()
                train_on_data(std_x, init_std_y, seed)