# train_eiv_vd_multinomial.py
# Author: Jörg Martin
import random
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from EIVArchitectures import Networks
from generate_multinomial_data import get_data
from EIVTrainingRoutines import train_and_store, loss_functions
# hyperparameters
lr = 1e-3  # initial learning rate for Adam (reset at epoch 0 by UpdatedTrainEpoch)
batch_size = 200
number_of_epochs = 350
n_train = 100000  # number of training samples requested from get_data
reg = [1e-6, 1/n_train]  # reg[0]: L2 weight penalty used in mse(); reg[1]: presumably a variational term — consumed inside the training routine, verify there
report_point = 40  # reporting interval, forwarded to the train-epoch object
precision_prior_zeta=0.0  # prior precision passed to Networks.FNN_VD_EIV
dim = 5  # input dimension (also first layer width of the network)
initial_alpha = 0.5  # initial dropout-alpha passed to Networks.FNN_VD_EIV
lr_update = 300  # StepLR period in scheduler steps for the 0.1 lr decay
pretraining = 200  # epochs of plain MSE training before switching to the EIV loss
epoch_offset = pretraining  # epoch from which std_y_par becomes trainable
std_x_list = [0.05, 0.07, 0.10]  # input-noise levels swept in the main loop
deming_scale_list = [0.15, 0.20, 0.30]  # target deming factor, zipped with std_x_list
init_std_y_list = [0.15]  # initial std_y values swept in the main loop
std_y = 0.3  # output-noise level used when generating the data
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# reproducibility
torch.backends.cudnn.benchmark = False
def set_seeds(seed):
    """Seed Python's, NumPy's and torch's RNGs for a reproducible run."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
seed_list = range(20)  # seeds iterated over in the __main__ loop below
# to store the RMSE
rmse_chain = []  # appended to by UpdatedTrainEpoch.extra_report; cleared per configuration in __main__
def mse(net, x, y, reg):
    """
    Mean squared error of `net` on (x, y) plus an L2 weight penalty.

    :param net: Module whose forward pass returns a tuple; the first
        entry is taken as the prediction.
    :param x: Input tensor.
    :param y: Target tensor.
    :param reg: Sequence whose first entry is the L2 penalty strength.
    :returns: Scalar loss tensor.
    """
    prediction = net(x)[0]
    l2_penalty = sum(reg[0] * torch.sum(p**2) for p in net.parameters())
    return nn.MSELoss()(prediction, y) + l2_penalty
def deming_gen(deming, initial_deming = 0.0, stepsize = 1/20):
    """
    Generator that ramps the deming factor up to `deming`.

    Starting from `initial_deming`, every `next()` raises the internal
    value by `stepsize * deming` and yields it, clipped at `deming`.
    """
    value = initial_deming
    increment = stepsize * deming
    while True:
        value += increment
        yield value if value < deming else deming
class UpdatedTrainEpoch(train_and_store.TrainEpoch):
    """
    TrainEpoch variant for EIV training: pretrains with MSE, then switches
    to the EIV negative log-likelihood, anneals the deming factor and
    reports the RMSE on noise-free validation data.
    """
    def pre_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method.

        At epoch 0 the optimizer and its StepLR scheduler are (re)created;
        from epoch `pretraining` onwards the criterion is switched from
        MSE to the EIV negative log-likelihood.
        """
        if epoch == 0:
            self.lr = self.initial_lr
            self.optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, lr_update, 0.1
            )
        if epoch >= pretraining:
            self.criterion = loss_functions.nll_eiv_no_jensen

    def post_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method.

        **Note**: self.deming_updater has to be defined explicitly
        and fed after initialization of this class.
        """
        if epoch >= pretraining:
            # anneal the deming factor towards its target value
            net.change_deming(next(self.deming_updater))
        if epoch >= epoch_offset:
            # std_y is frozen during pretraining; unfreeze it afterwards
            net.std_y_par.requires_grad = True
        self.lr_scheduler.step()

    def extra_report(self, net, i):
        """
        Overwrites the corresponding method.

        **Note**: self.val_data_pure has to be defined explicitly
        and fed after initialization of this class.
        """
        rmse = self.rmse(net).item()
        rmse_chain.append(rmse)
        # fixed: the original did print('RMSE %.2f', rmse), passing rmse as
        # a second print argument instead of %-formatting it into the string
        print('RMSE %.2f' % rmse)

    def rmse(self, net):
        """
        Compute the root mean squared error for `net` on the noise-free
        validation data, restoring the net's train/noise state afterwards.
        """
        # (removed an unused local `mse = 0` that shadowed the module-level
        # mse function)
        net_train_state = net.training
        net_noise_state = net.noise_is_on
        net.eval()
        net.noise_off()
        x, y = self.val_data_pure
        out = net(x.to(device))[0].detach().cpu()
        if net_train_state:
            net.train()
        if net_noise_state:
            net.noise_on()
        return torch.sqrt(torch.mean((out - y)**2))
def train_on_data(std_x, init_std_y, deming_scale, seed):
    """
    Loads data associated with `std_x` and trains an EIV Modell
    """
    # fetch the datasets for this noise level
    _, train_collection, _, test_collection, val_data_pure, _, _ = \
        get_data(std_x=std_x, std_y=std_y, dim=dim, n_train=n_train)
    train_set = TensorDataset(*train_collection)
    test_set = TensorDataset(*test_collection)
    set_seeds(seed)
    # wrap the datasets in dataloaders
    train_dataloader = DataLoader(train_set, batch_size=batch_size,
            shuffle=True)
    test_dataloader = DataLoader(test_set, batch_size=batch_size,
            shuffle=True)
    # build the network; deming starts at 0 and is annealed during training
    net = Networks.FNN_VD_EIV(
        initial_alpha=initial_alpha,
        init_std_y=init_std_y,
        precision_prior_zeta=precision_prior_zeta,
        deming=0.0,
        h=[dim, 500, 300, 100, 1],
    ).to(device)
    net.std_y_par.requires_grad = False

    def std_x_map():
        return net.get_std_x().detach().cpu().item()

    def std_y_map():
        return net.get_std_y().detach().cpu().item()

    # pretraining criterion; swapped for the EIV NLL inside the epoch map
    criterion = mse
    epoch_map = UpdatedTrainEpoch(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        criterion=criterion,
        std_y_map=std_y_map,
        std_x_map=std_x_map,
        lr=lr,
        reg=reg,
        report_point=report_point,
        device=device,
    )
    epoch_map.deming_updater = iter(deming_gen(deming=deming_scale))
    epoch_map.val_data_pure = val_data_pure
    # run and save
    save_file = os.path.join(
        'saved_networks',
        'eiv_vd_multinomial_std_x_%.3f_std_y_%.3f_init_std_y_%.3f'
        '_deming_scale_%.3f_seed_%i.pkl'
        % (std_x, std_y, init_std_y, deming_scale, seed),
    )
    train_and_store.train_and_store(
        net=net,
        epoch_map=epoch_map,
        number_of_epochs=number_of_epochs,
        save_file=save_file,
        rmse=rmse_chain,
    )
if __name__ == '__main__':
    # sweep all seeds over every (std_x, deming_scale) pair for each init_std_y
    for seed in seed_list:
        print('SEED: %i' % (seed,))
        for init_std_y in init_std_y_list:
            for std_x, deming_scale in zip(std_x_list, deming_scale_list):
                # reset the module-level RMSE log for this configuration
                rmse_chain.clear()
                # fixed: the implicitly concatenated string was missing a
                # space, printing "init_std_y=0.15and deming_scale ..."
                print('->->Using std_x=%.2f and init_std_y=%.2f '
                      'and deming_scale %.2f<-<-<-<-'
                      % (std_x, init_std_y, deming_scale))
                train_on_data(std_x, init_std_y, deming_scale, seed)