"""
Train EiV model using different seeds

The hyperparameters are read from `configurations/eiv_<data>.json`, where
`<data>` is given via the `--data` command line option. For every seed in
the configured seed range and every initial value of std_y a network is
trained and stored under `saved_networks`.
"""
import random
import importlib
import os
import argparse
import json

import numpy as np
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

from EIVArchitectures import Networks, initialize_weights
from EIVTrainingRoutines import train_and_store, loss_functions

# read in data via --data option
parser = argparse.ArgumentParser()
parser.add_argument("--data", help="Loads data", default='california')
# to avoid conflicts in IPython
parser.add_argument("--no-autoindent", help="", action="store_true")
args = parser.parse_args()
data = args.data

# load hyperparameters from JSON file
# (JSON is UTF-8 encoded, do not rely on the locale default)
with open(os.path.join('configurations', f'eiv_{data}.json'), 'r',
          encoding='utf-8') as conf_file:
    conf_dict = json.load(conf_file)

long_dataname = conf_dict["long_dataname"]
short_dataname = conf_dict["short_dataname"]
lr = conf_dict["lr"]
batch_size = conf_dict["batch_size"]
test_batch_size = conf_dict["test_batch_size"]
number_of_epochs = conf_dict["number_of_epochs"]
unscaled_reg = conf_dict["unscaled_reg"]
report_point = conf_dict["report_point"]
p = conf_dict["p"]
lr_update = conf_dict["lr_update"]
# offset before updating sigma_y after each epoch
epoch_offset = conf_dict["epoch_offset"]
# will be used to predict the RMSE and update sigma_y accordingly
eiv_prediction_number_of_draws = conf_dict["eiv_prediction_number_of_draws"]
eiv_prediction_number_of_batches = \
    conf_dict["eiv_prediction_number_of_batches"]
init_std_y_list = conf_dict["init_std_y_list"]
fixed_std_x = conf_dict['fixed_std_x']
gamma = conf_dict["gamma"]
hidden_layers = conf_dict["hidden_layers"]
seed_range = conf_dict['seed_range']

print(f"Training on {long_dataname} data")

# use the GPU given in the configuration (if available), fall back to
# the CPU if no "gpu_number" is configured
try:
    gpu_number = conf_dict["gpu_number"]
except KeyError:
    device = torch.device('cpu')
else:
    device = torch.device(f'cuda:{gpu_number}'
                          if torch.cuda.is_available() else 'cpu')

# the module EIVData.<long_dataname> provides the function `load_data`
load_data = importlib.import_module(f'EIVData.{long_dataname}').load_data

# reproducibility
seed_list = range(seed_range[0], seed_range[1])


def set_seeds(seed):
    """
    Seed the random number generators of numpy, random and torch
    with `seed` and switch off the cudnn benchmark mode.
    :param seed: The seed to use (an integer)
    """
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)


# to store the RMSE
rmse_chain = []

# Tensorboard writer, set in the `__main__` part below. If this stays None
# (e.g. when this file is imported instead of run as a script),
# `extra_report` skips the Tensorboard logging instead of failing
# with a NameError.
writer = None


class UpdatedTrainEpoch(train_and_store.TrainEpoch):
    """
    Version of `train_and_store.TrainEpoch` that uses Adam with a StepLR
    scheduler, updates std_y after each epoch (once `epoch_offset`
    is reached) and reports the RMSE.
    """

    def pre_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method.
        Creates the optimizer and the learning rate scheduler
        before the first epoch (`epoch == 0`).
        """
        if epoch == 0:
            self.lr = self.initial_lr
            self.optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, lr_update, gamma)

    def post_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method.
        For `epoch >= epoch_offset` the RMSE of the prediction of `net` is
        computed on the first `eiv_prediction_number_of_batches` batches
        of the training data and handed to `net.change_std_y`.
        The learning rate scheduler is stepped after every epoch.
        """
        if epoch >= epoch_offset:
            pred_collection = []
            y_collection = []
            for i, (x, y) in enumerate(self.train_dataloader):
                if i >= eiv_prediction_number_of_batches:
                    break
                # make sure y is a column, to match the shape of pred
                if len(y.shape) <= 1:
                    y = y.view((-1, 1))
                x, y = x.to(device), y.to(device)
                pred, _ = net.predict(
                    x,
                    number_of_draws=eiv_prediction_number_of_draws,
                    remove_graph=True,
                    take_average_of_prediction=True)
                pred_collection.append(pred)
                y_collection.append(y)
            pred_collection = torch.cat(pred_collection, dim=0)
            y_collection = torch.cat(y_collection, dim=0)
            # differing shapes would broadcast silently in the difference
            assert pred_collection.shape == y_collection.shape
            rmse = torch.sqrt(
                torch.mean((pred_collection - y_collection)**2))
            net.change_std_y(rmse)
        # independent of the update of std_y
        self.lr_scheduler.step()

    def extra_report(self, net, i):
        """
        Overwrites the corresponding method.
        Computes the RMSE on the test data, appends it to `rmse_chain`,
        logs it together with std_y and the last train and test loss to
        Tensorboard (if the global `writer` is set) and prints it.
        """
        rmse = self.rmse(net).item()
        rmse_chain.append(rmse)
        if writer is not None:
            writer.add_scalar('RMSE', rmse, self.total_count)
            writer.add_scalar('std_y', self.last_std_y, self.total_count)
            writer.add_scalar('RMSE:std_y', rmse/self.last_std_y,
                              self.total_count)
            writer.add_scalar('train loss', self.last_train_loss,
                              self.total_count)
            writer.add_scalar('test loss', self.last_test_loss,
                              self.total_count)
        print(f'RMSE {rmse:.3f}')

    def rmse(self, net):
        """
        Compute the root mean squared error for `net`, using one batch
        of `self.test_dataloader`. For the computation `net` is put in
        evaluation mode with noise switched off; a previous train or
        noise state is restored afterwards.
        :param net: The network to evaluate
        :returns: The RMSE, as a torch.tensor
        """
        # remember the state of net, to restore it below
        net_train_state = net.training
        net_noise_state = net.noise_is_on
        net.eval()
        net.noise_off()
        # NOTE: only the first batch drawn from the test dataloader is used
        x, y = next(iter(self.test_dataloader))
        if len(y.shape) <= 1:
            y = y.view((-1, 1))
        out = net(x.to(device))[0].detach().cpu()
        assert out.shape == y.shape
        if net_train_state:
            net.train()
        if net_noise_state:
            net.noise_on()
        return torch.sqrt(torch.mean((out-y)**2))


def train_on_data(init_std_y, seed):
    """
    Sets `seed`, loads data and trains an EiV model (`Networks.FNNEIV`),
    starting with `init_std_y`. The training is run by
    `train_and_store.train_and_store`, with a save file in `saved_networks`.
    :param init_std_y: The initial value of std_y
    :param seed: The seed, used for the random number generators
    and handed to `load_data`
    """
    # set seed
    set_seeds(seed)
    # load Datasets
    train_data, test_data = load_data(seed=seed, splitting_part=0.8,
                                      normalize=True)
    # make dataloaders
    train_dataloader = DataLoader(train_data, batch_size=batch_size,
                                  shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=test_batch_size,
                                 shuffle=True)
    # create a net
    input_dim = train_data[0][0].numel()
    output_dim = train_data[0][1].numel()
    net = Networks.FNNEIV(p=p, init_std_y=init_std_y,
                          h=[input_dim, *hidden_layers, output_dim],
                          fixed_std_x=fixed_std_x)
    net.apply(initialize_weights.glorot_init)
    net = net.to(device)
    # std_y is not trained via the gradient, but updated in
    # `UpdatedTrainEpoch.post_epoch_update`
    net.std_y_par.requires_grad = False
    std_x_map = lambda: net.get_std_x().detach().cpu().item()
    std_y_map = lambda: net.get_std_y().detach().cpu().item()
    # regularization
    reg = unscaled_reg/len(train_data)
    # create epoch_map
    criterion = loss_functions.nll_eiv
    epoch_map = UpdatedTrainEpoch(train_dataloader=train_dataloader,
                                  test_dataloader=test_dataloader,
                                  criterion=criterion,
                                  std_y_map=std_y_map,
                                  std_x_map=std_x_map,
                                  lr=lr, reg=reg,
                                  report_point=report_point,
                                  device=device)
    # run and save
    save_file = os.path.join(
        'saved_networks',
        f'eiv_{short_dataname}'
        f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'
        f'_p_{p:.2f}_fixed_std_x_{fixed_std_x:.3f}'
        f'_seed_{seed}.pkl')
    train_and_store.train_and_store(net=net,
                                    epoch_map=epoch_map,
                                    number_of_epochs=number_of_epochs,
                                    save_file=save_file)


if __name__ == '__main__':
    for seed in seed_list:
        # Tensorboard monitoring
        # NOTE: the log directory is a hard-coded, user specific path
        writer = SummaryWriter(
            log_dir=f'/home/martin09/tmp/tensorboard/'
                    f'run_eiv_{short_dataname}_lr_{lr:.4f}_seed'
                    f'_{seed}_uregu_{unscaled_reg:.1f}_p_{p:.2f}'
                    f'_fixed_std_x_{fixed_std_x:.3f}')
        try:
            print(f'>>>>SEED: {seed}')
            for init_std_y in init_std_y_list:
                print(f'Using init_std_y={init_std_y:.3f}')
                train_on_data(init_std_y, seed)
        finally:
            # flush pending events and release the event file of this
            # seed before the writer of the next seed is created
            writer.close()