# train_noneiv.py
# Author: Jörg Martin
"""
Train non-EiV model using different seeds
"""
import random
import importlib
import os
import argparse
import json
import numpy as np
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from EIVArchitectures import Networks, initialize_weights
from EIVTrainingRoutines import train_and_store, loss_functions
# Read in the name of the dataset via the --data option
parser = argparse.ArgumentParser()
parser.add_argument("--data", help="Loads data", default='california')
parser.add_argument("--no-autoindent", help="",
        action="store_true") # to avoid conflicts in IPython
args = parser.parse_args()
data = args.data

# Load the hyperparameters from configurations/noneiv_<data>.json.
# JSON is UTF-8 by specification, so read it as such instead of relying on
# the platform's default encoding.
with open(os.path.join('configurations', f'noneiv_{data}.json'), 'r',
        encoding='utf-8') as conf_file:
    conf_dict = json.load(conf_file)
long_dataname = conf_dict["long_dataname"]
short_dataname = conf_dict["short_dataname"]
lr = conf_dict["lr"]
batch_size = conf_dict["batch_size"]
test_batch_size = conf_dict["test_batch_size"]
number_of_epochs = conf_dict["number_of_epochs"]
unscaled_reg = conf_dict["unscaled_reg"]
report_point = conf_dict["report_point"]
p = conf_dict["p"]
lr_update = conf_dict["lr_update"]
# Controls after which epochs sigma_y (std_y) is updated: either an int
# (update after every epoch >= this value) or a list [start, period]
# (update after epochs >= start that are divisible by period).
std_y_update_points = conf_dict["std_y_update_points"]
# Number of draws and of training batches used to predict the RMSE from
# which sigma_y is updated.
noneiv_prediction_number_of_draws = conf_dict["noneiv_prediction_number_of_draws"]
noneiv_prediction_number_of_batches = \
        conf_dict["noneiv_prediction_number_of_batches"]
init_std_y_list = conf_dict["init_std_y_list"]
gamma = conf_dict["gamma"]
hidden_layers = conf_dict["hidden_layers"]
seed_range = conf_dict['seed_range']
print(f"Training on {long_dataname} data")

# The optional "gpu_number" entry selects the CUDA device; fall back to the
# CPU when it is missing or CUDA is unavailable. Only the dictionary lookup
# is guarded by the try.
try:
    gpu_number = conf_dict["gpu_number"]
except KeyError:
    device = torch.device('cpu')
else:
    device = torch.device(f'cuda:{gpu_number}' if torch.cuda.is_available()
            else 'cpu')

# The dataset module EIVData.<long_dataname> provides `load_data`
load_data = importlib.import_module(f'EIVData.{long_dataname}').load_data

# Reproducibility: one training run per seed in [seed_range[0], seed_range[1])
seed_list = range(seed_range[0], seed_range[1])
def set_seeds(seed):
    """
    Seed every random number generator used during training.

    Seeds Python's `random`, NumPy's global RNG and PyTorch's RNG, and
    configures cuDNN so that repeated runs with the same `seed` give the
    same results.

    Parameters
    ----------
    seed : int
        Seed passed to all random number generators.
    """
    # benchmark=False stops cuDNN from auto-tuning (and possibly picking a
    # different algorithm per run); deterministic=True additionally restricts
    # it to deterministic algorithms. Both are needed for reproducible GPU runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)
    # torch.manual_seed seeds the generators of all devices (CPU and CUDA).
    torch.manual_seed(seed)
# Module-level list collecting every RMSE value that
# UpdatedTrainEpoch.extra_report computes; it is never reset, so it
# accumulates values over all seeds and init_std_y runs of this script.
rmse_chain = list()
class UpdatedTrainEpoch(train_and_store.TrainEpoch):
    """
    Training-epoch hooks for the non-EiV (Bernoulli) network.

    Extends `train_and_store.TrainEpoch` with
    - an Adam optimizer and a StepLR learning-rate schedule,
    - periodic re-estimation of `std_y` from the RMSE on training data,
    - TensorBoard / console reporting of RMSE, std_y and losses.

    Relies on the module-level settings `device`, `lr_update`, `gamma`,
    `std_y_update_points`, `noneiv_prediction_number_of_draws` and
    `noneiv_prediction_number_of_batches`, and on the module-level `writer`
    and `rmse_chain`.
    """

    def pre_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method.

        At the first epoch (`epoch == 0`) reset the learning rate to
        `self.initial_lr` and create a fresh Adam optimizer over
        `net.parameters()` together with a StepLR scheduler that multiplies
        the learning rate by `gamma` every `lr_update` scheduler steps
        (one step per epoch, see `post_epoch_update`).
        """
        if epoch == 0:
            self.lr = self.initial_lr
            self.optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                    self.optimizer, lr_update, gamma)

    def update_std_y(self, net):
        """
        Update the std_y of `net` via the RMSE of the prediction.

        Predictions are averaged over `noneiv_prediction_number_of_draws`
        draws and computed on at most `noneiv_prediction_number_of_batches`
        batches of `self.train_dataloader`, with `net` switched to training
        mode (presumably to keep dropout active during the draws - TODO
        confirm against `net.predict`). The resulting RMSE is passed to
        `net.change_std_y`. The previous training/evaluation mode of `net`
        is restored afterwards, also if an error occurs.
        """
        net_train_state = net.training
        net.train()
        try:
            pred_collection = []
            y_collection = []
            for i, (x, y) in enumerate(self.train_dataloader):
                if i >= noneiv_prediction_number_of_batches:
                    break
                # targets are compared as column vectors
                if len(y.shape) <= 1:
                    y = y.view((-1, 1))
                x, y = x.to(device), y.to(device)
                pred, _ = net.predict(x,
                        number_of_draws=noneiv_prediction_number_of_draws,
                        remove_graph=True,
                        take_average_of_prediction=True)
                pred_collection.append(pred)
                y_collection.append(y)
            pred_collection = torch.cat(pred_collection, dim=0)
            y_collection = torch.cat(y_collection, dim=0)
            assert pred_collection.shape == y_collection.shape
            rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
            net.change_std_y(rmse)
        finally:
            # restore the mode `net` was in before this call
            if not net_train_state:
                net.eval()

    def check_if_update_std_y(self, epoch):
        """
        Decide whether std_y should be updated after `epoch`.

        If `std_y_update_points` is an int, return True for every `epoch`
        greater than or equal to it. If it is a list `[start, period]`,
        return True only for epochs with `epoch >= start` that are divisible
        by `period`. Any other type fails the assertion.
        """
        if type(std_y_update_points) is int:
            return epoch >= std_y_update_points
        else:
            assert type(std_y_update_points) is list
            return epoch >= std_y_update_points[0] \
                    and epoch % std_y_update_points[1] == 0

    def post_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method.

        Update std_y if `check_if_update_std_y(epoch)` says so, then advance
        the learning-rate scheduler by one step.
        """
        if self.check_if_update_std_y(epoch):
            self.update_std_y(net)
        self.lr_scheduler.step()

    def post_train_update(self, net, epoch):
        """
        Overwrites the corresponding method. If std_y of `net` was not updated
        in the last training step, update it when finished with training.
        `epoch` should be the number of the last training epoch.
        """
        if not self.check_if_update_std_y(epoch):
            self.update_std_y(net)

    def extra_report(self, net, i):
        """
        Overwrites the corresponding method.

        Compute the test RMSE of `net`, append it to the module-level
        `rmse_chain`, log RMSE, std_y, train loss and test loss to the
        module-level TensorBoard `writer` at step `self.total_count` and
        print the RMSE. `writer` is only created in the `__main__` section of
        this script, so this method requires it to have been set there.
        The argument `i` is not used.
        """
        rmse = self.rmse(net).item()
        rmse_chain.append(rmse)
        writer.add_scalar('RMSE', rmse, self.total_count)
        writer.add_scalar('std_y', self.last_std_y, self.total_count)
        writer.add_scalar('train loss', self.last_train_loss, self.total_count)
        writer.add_scalar('test loss', self.last_test_loss, self.total_count)
        print(f'RMSE {rmse:.3f}')

    def rmse(self, net):
        """
        Compute the root mean squared error of `net` on a single batch drawn
        from `self.test_dataloader`.

        `net` is evaluated in eval mode without recording gradients, and its
        previous training/evaluation mode is restored afterwards, also if an
        error occurs. Returns a 0-dim tensor.
        """
        net_train_state = net.training
        net.eval()
        try:
            x, y = next(iter(self.test_dataloader))
            # targets are compared as column vectors
            if len(y.shape) <= 1:
                y = y.view((-1, 1))
            # only the values are needed, so do not build an autograd graph
            with torch.no_grad():
                out = net(x.to(device))[0].detach().cpu()
            assert out.shape == y.shape
        finally:
            if net_train_state:
                net.train()
        return torch.sqrt(torch.mean((out - y)**2))
def train_on_data(init_std_y, seed):
    """
    Train and store one Bernoulli network (`Networks.FNNBer`).

    All random number generators are seeded with `seed`, the normalized
    train/test split for that seed is loaded (80 % training data) and a
    network with initial output noise `init_std_y` is trained for
    `number_of_epochs` epochs. The parameter behind std_y is excluded from
    gradient updates (`requires_grad = False`). The trained network is
    stored under `saved_networks/` with a file name encoding the dataset,
    `init_std_y`, the unscaled regularization, `p` and `seed`.
    """
    set_seeds(seed)
    train_data, test_data = load_data(seed=seed, splitting_part=0.8,
            normalize=True)

    train_dataloader = DataLoader(train_data, batch_size=batch_size,
            shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=test_batch_size,
            shuffle=True)

    # layer widths: input size, hidden layers from the config, output size
    layer_widths = [train_data[0][0].numel(), *hidden_layers,
            train_data[0][1].numel()]
    net = Networks.FNNBer(p=p, init_std_y=init_std_y, h=layer_widths)
    net.apply(initialize_weights.glorot_init)
    net = net.to(device)
    net.std_y_par.requires_grad = False

    def std_x_map():
        # no input noise for the non-EiV model
        return 0.0

    def std_y_map():
        return net.get_std_y().detach().cpu().item()

    epoch_map = UpdatedTrainEpoch(
            train_dataloader=train_dataloader,
            test_dataloader=test_dataloader,
            criterion=loss_functions.nll_reg_loss,
            std_y_map=std_y_map,
            std_x_map=std_x_map,
            lr=lr,
            # regularization scaled by the size of the training set
            reg=unscaled_reg / len(train_data),
            report_point=report_point,
            device=device)

    file_name = (f'noneiv_{short_dataname}'
            f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'
            f'_p_{p:.2f}_seed_{seed}.pkl')
    train_and_store.train_and_store(net=net,
            epoch_map=epoch_map,
            number_of_epochs=number_of_epochs,
            save_file=os.path.join('saved_networks', file_name))
if __name__ == '__main__':
    # Root directory of the TensorBoard logs. It can be overridden through the
    # TENSORBOARD_LOG_DIR environment variable; the default is the original
    # machine-specific path.
    tensorboard_root = os.environ.get('TENSORBOARD_LOG_DIR',
            '/home/martin09/tmp/tensorboard/')
    for seed in seed_list:
        # Tensorboard monitoring. `writer` must be a module-level name because
        # UpdatedTrainEpoch.extra_report logs to it.
        # NOTE(review): all init_std_y runs of one seed share this log dir.
        writer = SummaryWriter(log_dir=os.path.join(tensorboard_root,
                f'run_noneiv_{short_dataname}_lr_{lr:.4f}_seed'
                f'_{seed}_uregu_{unscaled_reg:.1f}_p_{p:.2f}'))
        try:
            print(f'>>>>SEED: {seed}')
            for init_std_y in init_std_y_list:
                print(f'Using init_std_y={init_std_y:.3f}')
                train_on_data(init_std_y, seed)
        finally:
            # flush pending events and release this seed's event file
            writer.close()