"""
Train non-EiV model on energy efficiency dataset using different seeds
"""
import random
import os

import numpy as np
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

from EIVArchitectures import Networks, initialize_weights
from EIVData.energy_efficiency import load_data
from EIVTrainingRoutines import train_and_store, loss_functions

# hyperparameters
lr = 1e-3  # initial learning rate for Adam
batch_size = 32
test_batch_size = 600  # meant to cover the test split in a single batch (cf. `rmse` below)
number_of_epochs = 600
unscaled_reg = 10  # regularization strength before scaling by the training set size
report_point = 5  # reporting interval passed to TrainEpoch
p = 0.2  # dropout probability of the Bernoulli network
lr_update = 100  # epochs between StepLR learning rate decays
# pretraining = 300
epoch_offset = 100  # number of epochs before std_y_par becomes trainable
init_std_y_list = [0.5]  # initial values of the output noise std
gamma = 0.5  # multiplicative factor of the StepLR decay
hidden_layers = [1024, 1024, 1024, 1024]
# 'cuda:1' assumes a second GPU; falls back to CPU only if no CUDA device is available
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

# reproducibility
def set_seeds(seed):
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed) 
    torch.manual_seed(seed)
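    # Note: for stricter reproducibility on GPU one could additionally call
    # torch.cuda.manual_seed_all(seed) and set
    # torch.backends.cudnn.deterministic = True.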
seed_list = range(10)

# stores the RMSE values reported during training (filled by extra_report)
rmse_chain = []

class UpdatedTrainEpoch(train_and_store.TrainEpoch):
    def pre_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method
        """
        if epoch == 0:
            self.lr = self.initial_lr
            self.optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                    self.optimizer, lr_update, gamma)


    def post_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method
        """
        # make the output noise std trainable once `epoch_offset` epochs have passed
        if epoch >= epoch_offset:
            net.std_y_par.requires_grad = True
        self.lr_scheduler.step()

    def extra_report(self, net, i):
        """
        Overwrites the corresponding method.
        Logs the RMSE and the last train/test losses to TensorBoard
        and prints the RMSE.
        """
        rmse = self.rmse(net).item()
        rmse_chain.append(rmse)
        writer.add_scalar('RMSE', rmse, self.total_count)
        writer.add_scalar('train loss', self.last_train_loss, self.total_count)
        writer.add_scalar('test loss', self.last_test_loss, self.total_count)
        print(f'RMSE {rmse:.3f}')

    def rmse(self, net):
        """
        Compute the root mean squared error for `net`
        """
        net_train_state = net.training
        net.eval()
        # only the first batch of the test dataloader is used
        x, y = next(iter(self.test_dataloader))
        if len(y.shape) <= 1:
            y = y.view((-1,1))
        out = net(x.to(device))[0].detach().cpu()
        assert out.shape == y.shape
        if net_train_state:
            net.train()
        return torch.sqrt(torch.mean((out-y)**2))

def train_on_data(init_std_y, seed):
    """
    Sets `seed`, loads the data and trains a Bernoulli (non-EiV) model,
    starting with `init_std_y`.
    """
    # set seed
    set_seeds(seed)
    # load Datasets
    train_data, test_data = load_data(seed=seed, splitting_part=0.8,
            normalize=True)
    # make dataloaders
    train_dataloader = DataLoader(train_data, batch_size=batch_size, 
            shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=test_batch_size,
            shuffle=True)
    # create a net
    input_dim = train_data[0][0].numel()
    output_dim = train_data[0][1].numel()
    net = Networks.FNNBer(p=p,
            init_std_y=init_std_y,
            h=[input_dim, *hidden_layers, output_dim])
    net.apply(initialize_weights.glorot_init)
    net = net.to(device)
    # keep the output noise std fixed until `epoch_offset` epochs have passed
    net.std_y_par.requires_grad = False
    # non-EiV: no noise is assumed on the inputs
    std_x_map = lambda: 0.0
    std_y_map = lambda: net.get_std_y().detach().cpu().item()
    # regularization strength, scaled by the size of the training set
    reg = unscaled_reg/len(train_data)
    # create epoch_map
    criterion = loss_functions.nll_reg_loss
    epoch_map = UpdatedTrainEpoch(train_dataloader=train_dataloader,
            test_dataloader=test_dataloader,
            criterion=criterion, std_y_map=std_y_map, std_x_map=std_x_map,
            lr=lr, reg=reg, report_point=report_point, device=device)
    # run and save
    save_file = os.path.join('saved_networks',
            f'noneiv_energy'\
                    f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'\
                    f'_p_{p:.2f}_seed_{seed}.pkl')
    train_and_store.train_and_store(net=net, 
            epoch_map=epoch_map,
            number_of_epochs=number_of_epochs,
            save_file=save_file)
    

if __name__ == '__main__':
    for seed in seed_list:
        # Tensorboard monitoring
        writer = SummaryWriter(log_dir=f'/home/martin09/tmp/tensorboard/'\
                f'run_noneiv_energy_lr_{lr:.4f}_seed'\
                f'_{seed}_uregu_{unscaled_reg:.1f}_p_{p:.2f}')
        print(f'>>>>SEED: {seed}')
        for init_std_y in init_std_y_list:
            print(f'Using init_std_y={init_std_y:.3f}')
            train_on_data(init_std_y, seed)
        writer.close()
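
# Usage note: each run trains one network per (init_std_y, seed) combination and
# stores it under saved_networks/; one TensorBoard run directory is written per seed.
# It is assumed that the saved_networks directory and the TensorBoard log path
# either exist beforehand or are created by train_and_store / SummaryWriter.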