"""
Train non-EiV model on the naval propulsion dataset using different seeds
"""
import random
import os

import numpy as np
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

from EIVArchitectures import Networks, initialize_weights
from EIVData.naval_propulsion import load_data
from EIVTrainingRoutines import train_and_store, loss_functions

# hyperparameters
lr = 1e-3
batch_size = 32
test_batch_size = 600
number_of_epochs = 30
unscaled_reg = 10  # divided by the number of training points before use
report_point = 5
p = 0.2  # passed as `p` to Networks.FNNBer
lr_update = 20  # step size of the StepLR scheduler
# pretraining = 300
epoch_offset = 20  # from this epoch on std_y_par is made trainable
init_std_y_list = [0.5]
gamma = 0.5  # decay factor of the StepLR scheduler
hidden_layers = [1024, 1024, 1024, 1024]
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')


# reproducibility
def set_seeds(seed):
    """
    Seed the random number generators of `numpy`, `random` and `torch`
    with `seed` and disable the cudnn benchmark mode.
    """
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)


seed_list = range(10)

# to store the RMSE
rmse_chain = []

# Tensorboard writer used by `UpdatedTrainEpoch.extra_report`. It is bound
# in the `__main__` section below; keeping a module level default avoids a
# NameError when this module is imported instead of run as a script.
writer = None


class UpdatedTrainEpoch(train_and_store.TrainEpoch):
    """
    Epoch map that sets up optimizer and scheduler, unfreezes `std_y_par`
    after `epoch_offset` epochs and reports the RMSE.
    """

    def pre_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method.
        In the first epoch (`epoch == 0`) the learning rate is reset to
        `self.initial_lr` and an Adam optimizer together with a StepLR
        scheduler (step size `lr_update`, factor `gamma`) is created.
        """
        if epoch == 0:
            self.lr = self.initial_lr
            self.optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, lr_update, gamma)

    def post_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method.
        Makes `net.std_y_par` trainable once `epoch >= epoch_offset` and
        advances the learning rate scheduler.
        """
        if epoch >= epoch_offset:
            net.std_y_par.requires_grad = True
        # NOTE(review): the scheduler is stepped after every epoch, not
        # only after `epoch_offset` -- confirm against the original layout
        self.lr_scheduler.step()

    def extra_report(self, net, i):
        """
        Overwrites the corresponding method.
        Computes the RMSE of `net`, appends it to `rmse_chain`, prints it
        and, if a tensorboard `writer` is set, logs the RMSE together with
        the last train and test loss. The argument `i` is not used.
        """
        rmse = self.rmse(net).item()
        rmse_chain.append(rmse)
        if writer is not None:
            writer.add_scalar('RMSE', rmse, self.total_count)
            writer.add_scalar('train loss', self.last_train_loss,
                              self.total_count)
            writer.add_scalar('test loss', self.last_test_loss,
                              self.total_count)
        print(f'RMSE {rmse:.3f}')

    def rmse(self, net):
        """
        Compute the root mean squared error for `net`.
        Only a single batch drawn from `self.test_dataloader` is used.
        The training state of `net` is restored before returning.
        """
        net_train_state = net.training
        net.eval()
        x, y = next(iter(self.test_dataloader))
        if len(y.shape) <= 1:
            y = y.view((-1, 1))
        # no gradients are needed for the evaluation
        with torch.no_grad():
            # first entry of the net output is compared with y
            # (presumably the predicted mean -- verify against FNNBer)
            out = net(x.to(device))[0].detach().cpu()
        assert out.shape == y.shape
        if net_train_state:
            net.train()
        return torch.sqrt(torch.mean((out-y)**2))


def train_on_data(init_std_y, seed):
    """
    Sets `seed`, loads data and trains a Bernoulli model, starting with
    `init_std_y`.
    :param init_std_y: Passed as `init_std_y` to `Networks.FNNBer`.
    :param seed: Used for seeding and passed to `load_data`.
    """
    # set seed
    set_seeds(seed)
    # load Datasets
    train_data, test_data = load_data(seed=seed, splitting_part=0.8,
                                      normalize=True)
    # make dataloaders
    train_dataloader = DataLoader(train_data, batch_size=batch_size,
                                  shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=test_batch_size,
                                 shuffle=True)
    # create a net
    input_dim = train_data[0][0].numel()
    output_dim = train_data[0][1].numel()
    net = Networks.FNNBer(p=p, init_std_y=init_std_y,
                          h=[input_dim, *hidden_layers, output_dim])
    net.apply(initialize_weights.glorot_init)
    net = net.to(device)
    # std_y_par stays frozen until `epoch_offset` is reached,
    # see UpdatedTrainEpoch.post_epoch_update
    net.std_y_par.requires_grad = False
    std_x_map = lambda: 0.0
    std_y_map = lambda: net.get_std_y().detach().cpu().item()
    # regularization
    reg = unscaled_reg/len(train_data)
    # create epoch_map
    criterion = loss_functions.nll_reg_loss
    epoch_map = UpdatedTrainEpoch(train_dataloader=train_dataloader,
                                  test_dataloader=test_dataloader,
                                  criterion=criterion, std_y_map=std_y_map,
                                  std_x_map=std_x_map,
                                  lr=lr, reg=reg, report_point=report_point,
                                  device=device)
    # run and save
    save_dir = 'saved_networks'
    # make sure the target folder exists before the training starts, so
    # that storing the network cannot fail because of a missing folder
    os.makedirs(save_dir, exist_ok=True)
    save_file = os.path.join(
        save_dir,
        f'noneiv_naval'
        f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'
        f'_p_{p:.2f}_seed_{seed}.pkl')
    train_and_store.train_and_store(net=net,
                                    epoch_map=epoch_map,
                                    number_of_epochs=number_of_epochs,
                                    save_file=save_file)


if __name__ == '__main__':
    for seed in seed_list:
        # Tensorboard monitoring
        writer = SummaryWriter(
            log_dir=f'/home/martin09/tmp/tensorboard/'
                    f'run_noneiv_naval_lr_{lr:.4f}_seed'
                    f'_{seed}_uregu_{unscaled_reg:.1f}_p_{p:.2f}')
        try:
            print(f'>>>>SEED: {seed}')
            for init_std_y in init_std_y_list:
                print(f'Using init_std_y={init_std_y:.3f}')
                train_on_data(init_std_y, seed)
        finally:
            # flush and release the event file of this seed
            writer.close()