"""
Train EIV networks on the housing data for a list of deming factors and
initial output standard deviations, repeated over several seeds.
"""
import random
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from EIVArchitectures import Networks
from generate_housing_data import train_x, train_y,\
        test_x, test_y, train_data, test_data
from EIVTrainingRoutines import train_and_store, loss_functions

# hyperparameters
lr = 1e-3
batch_size = 16
number_of_epochs = 1000
reg = 1e-2
report_point = 20
precision_prior_zeta = 0.0
dim = train_x.shape[-1]
p = 0.5
lr_update = 950
pretraining = 300
epoch_offset = pretraining
deming_factor_list = [0.01, 0.05, 0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7]
init_std_y_list = [0.15]
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# for reproducibility
torch.backends.cudnn.benchmark = False
def set_seeds(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
seed_list = range(20)

# to store the RMSE
rmse_chain = []

def mse(net, x, y, reg):
    """
    Computes the mean squared error plus an L2 regularization term
    """
    out = net(x)[0]
    regularization = 0
    for param in net.parameters():
        regularization += reg * torch.sum(param**2)
    return nn.MSELoss()(out, y) + regularization

def deming_gen(deming, initial_deming=0.0, stepsize=1/20):
    """
    Generator that ramps the deming factor linearly from `initial_deming`
    in steps of `stepsize * deming` and then keeps it capped at `deming`
    """
    updated_deming = initial_deming
    while True:
        updated_deming += stepsize * deming
        yield min(updated_deming, deming)

class UpdatedTrainEpoch(train_and_store.TrainEpoch):
    def pre_epoch_update(self, net, epoch):
        """
        Overrides the corresponding method
        """
        if epoch == 0:
            self.lr = self.initial_lr
            self.optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                    self.optimizer, lr_update, 0.1)
        if epoch >= pretraining:
            # switch from plain MSE pretraining to the EIV loss
            self.criterion = loss_functions.nll_eiv_no_jensen

    def post_epoch_update(self, net, epoch):
        """
        Overrides the corresponding method

        **Note**: self.deming_updater has to be defined explicitly
        and fed after initialization of this class.
        """
        if epoch >= pretraining:
            net.change_deming(next(self.deming_updater))
        if epoch >= epoch_offset:
            net.std_y_par.requires_grad = True
        self.lr_scheduler.step()

    def extra_report(self, net, i):
        """
        Overrides the corresponding method

        **Note**: self.test_couple has to be defined explicitly
        and fed after initialization of this class.
        """
        rmse = self.rmse(net).item()
        rmse_chain.append(rmse)
        print('RMSE %.2f' % rmse)

    def rmse(self, net):
        """
        Compute the root mean squared error of `net` on the test couple
        """
        net_train_state = net.training
        net_noise_state = net.noise_is_on
        net.eval()
        net.noise_off()
        x, y = self.test_couple
        out = net(x.to(device))[0].detach().cpu().view((-1,))
        y = y.view((-1,))
        # restore the previous training/noise state of the net
        if net_train_state:
            net.train()
        if net_noise_state:
            net.noise_on()
        return torch.sqrt(torch.mean((out - y)**2))

def train_on_data(init_std_y, deming_factor, seed):
    """
    Trains an EIV model
    """
    set_seeds(seed)
    deming = deming_factor
    # wrap the datasets in dataloaders
    train_dataloader = DataLoader(train_data, batch_size=batch_size,
            shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size,
            shuffle=True)
    # create a net; the deming factor starts at 0 and is ramped up after
    # pretraining via deming_gen
    net = Networks.FNNEIV(p=p,
            init_std_y=init_std_y,
            precision_prior_zeta=precision_prior_zeta,
            deming=0.0,
            h=[dim, 200, 100, 50, 1])
    net = net.to(device)
    net.std_y_par.requires_grad = False
    std_x_map = lambda: net.get_std_x().detach().cpu().item()
    std_y_map = lambda: net.get_std_y().detach().cpu().item()
    # create epoch_map
    criterion = mse
    # criterion = loss_functions.nll_eiv_no_jensen
    epoch_map = UpdatedTrainEpoch(train_dataloader=train_dataloader,
            test_dataloader=test_dataloader,
            criterion=criterion, std_y_map=std_y_map, std_x_map=std_x_map,
            lr=lr, reg=reg, report_point=report_point,
            device=device)
    epoch_map.deming_updater = iter(deming_gen(deming=deming))
    epoch_map.test_couple = (test_x, test_y)
    # run and save
    save_file = os.path.join('saved_networks',
            'eiv_housing_init_std_y_%.3f_deming'
            '_factor_%.3f_seed_%i.pkl'
            % (init_std_y, deming_factor, seed))
    train_and_store.train_and_store(net=net,
            epoch_map=epoch_map,
            number_of_epochs=number_of_epochs,
            save_file=save_file,
            rmse=rmse_chain)

if __name__ == '__main__':
    for seed in seed_list:
        for init_std_y in init_std_y_list:
            for deming_factor in deming_factor_list:
                rmse_chain.clear()
                print('->-> Using init_std_y=%.2f '
                        'and deming_factor=%.2f <-<-<-<-'
                        % (init_std_y, deming_factor))
                train_on_data(init_std_y, deming_factor, seed)
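
# Optional sanity check (a sketch, not part of the training run): deming_gen
# ramps the deming factor linearly. With deming=0.2 and the default stepsize
# of 1/20, each call adds 0.01 until the value is capped at 0.2 after 20 calls.
# Uncomment to inspect the ramp:
# _gen = deming_gen(deming=0.2)
# print([round(next(_gen), 3) for _ in range(25)])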