import random
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from EIVArchitectures import Networks
from generate_wine_data import train_x, train_y,\
                                       test_x, test_y, train_data, test_data
from EIVTrainingRoutines import train_and_store, loss_functions

# hyperparameters
# Initial learning rate handed to UpdatedTrainEpoch (keyword `lr`).
lr = 1e-3
# Batch size used for both the train and the test DataLoader.
batch_size = 16
# Number of epochs passed to train_and_store.train_and_store.
number_of_epochs = 400
# Passed as `reg` to UpdatedTrainEpoch; presumably the regularization
# strength used by the loss -- confirm in EIVTrainingRoutines.
reg = 1e-9
# Passed as `report_point` to UpdatedTrainEpoch; presumably the reporting
# interval -- confirm in EIVTrainingRoutines.
report_point = 20
# NOTE(review): not referenced anywhere in the visible part of this script.
precision_prior_zeta=0.0
# Size of the last axis of train_x; used as the input width of the network.
dim = train_x.shape[-1]
# Size of the first axis of train_x.
# NOTE(review): not referenced anywhere in the visible part of this script.
train_len = train_x.shape[0]
# NOTE(review): not referenced in the visible part of this script; presumably
# a dropout probability -- confirm.
p = 0.5
# step_size of the StepLR scheduler: the learning rate is multiplied by 0.1
# after every `lr_update` calls of lr_scheduler.step().
lr_update = 250
# std_y_par is frozen at the start of training and made trainable once the
# epoch index reaches `epoch_offset` (see UpdatedTrainEpoch.post_epoch_update).
pretraining = 50 
epoch_offset = pretraining
# Initial values of std_y; one training run per value and seed.
init_std_y_list = [0.15]
# Use the first CUDA device if one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# reproducibility
torch.backends.cudnn.benchmark = False


def set_seeds(seed):
    """
    Seed the random number generators of numpy, `random` and torch
    (in that order) with the same integer.
    :param seed: The seed handed to every generator.
    """
    for seed_generator in (np.random.seed, random.seed, torch.manual_seed):
        seed_generator(seed)


# One training run is started per seed in this collection.
seed_list = range(20)

# Module-level collector of the RMSE values reported during training.
rmse_chain = []


class UpdatedTrainEpoch(train_and_store.TrainEpoch):
    """
    Epoch routine with a step-wise learning rate decay, a delayed release of
    `net.std_y_par` for training and RMSE reporting on a fixed test couple.

    **Note**: `self.test_couple`, a tuple `(x, y)` of test tensors, is not
    set by this class. It has to be assigned after initialization and before
    `rmse` or `extra_report` is called.
    """
    def pre_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method.
        At epoch 0 the learning rate is reset to `self.initial_lr` and a fresh
        Adam optimizer plus a StepLR scheduler (step size `lr_update`,
        factor 0.1) are created. Nothing happens for later epochs.
        :param net: The network whose parameters are optimized.
        :param epoch: Index of the epoch that is about to start.
        """
        if epoch == 0:
            # NOTE(review): `initial_lr` is not set in this file; presumably
            # the base class stores the `lr` argument there -- confirm.
            self.lr = self.initial_lr
            self.optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, lr_update, 0.1
            )

    def post_epoch_update(self, net, epoch):
        """
        Overwrites the corresponding method.
        Makes `net.std_y_par` trainable once `epoch` has reached
        `epoch_offset` and advances the learning rate scheduler.
        :param net: The network that is trained.
        :param epoch: Index of the epoch that has just finished.
        """
        if epoch >= epoch_offset:
            net.std_y_par.requires_grad = True
        self.lr_scheduler.step()

    def extra_report(self, net, i):
        """
        Overwrites the corresponding method.
        Computes the RMSE of `net` on `self.test_couple`, appends it to the
        module-level `rmse_chain` and prints it.
        **Note**: self.test_couple has to be defined explicitly
        and fed after initialization of this class
        :param net: The network to evaluate.
        :param i: Not used by this implementation.
        """
        rmse = self.rmse(net).item()
        rmse_chain.append(rmse)
        # Apply the format string with `%`; passing `rmse` as a second
        # argument to print would output the raw '%.2f' placeholder.
        print('RMSE %.2f' % rmse)

    def rmse(self, net):
        """
        Compute the root mean squared error for `net` on `self.test_couple`.
        The network is evaluated in eval mode; its previous training mode is
        restored afterwards, also if the forward pass raises.
        :param net: The network to evaluate. The first element of its output
        is taken as the prediction.
        :returns: A zero-dimensional tensor containing the RMSE.
        """
        x, y = self.test_couple
        net_train_state = net.training
        net.eval()
        try:
            out = net(x.to(device))[0].detach().cpu().view((-1,))
        finally:
            if net_train_state:
                net.train()
        # NOTE(review): `out` is moved to the CPU, so `y` is assumed to be a
        # CPU tensor as well -- confirm against generate_wine_data.
        y = y.view((-1,))
        return torch.sqrt(torch.mean((out-y)**2))


def train_on_data(init_std_y, seed):
    """
    Trains a Bernoulli model (`Networks.FNNBer`) on the wine data and hands
    it to `train_and_store.train_and_store` together with the RMSE chain
    collected during this run.
    :param init_std_y: Initial value of std_y used to create the network.
    :param seed: Seed for the random number generators; also part of the
    name of the file the result is stored in.
    """
    set_seeds(seed)
    # `rmse_chain` is module-level and only ever appended to by
    # UpdatedTrainEpoch.extra_report. Empty it in place so that the chain
    # stored for this run does not contain the values of previous runs.
    rmse_chain.clear()
    # make to dataloader
    train_dataloader = DataLoader(train_data, batch_size=batch_size, 
            shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size,
            shuffle=True)
    # Create a net
    net = Networks.FNNBer(init_std_y=init_std_y, h=[dim,200,100,50,1])
    net = net.to(device)
    # std_y_par stays frozen until UpdatedTrainEpoch.post_epoch_update
    # releases it.
    net.std_y_par.requires_grad = False
    std_x_map = lambda: 0.0
    std_y_map = lambda: net.get_std_y().detach().cpu().item()
    # Create epoch_map
    criterion = loss_functions.nll_reg_loss
    epoch_map = UpdatedTrainEpoch(
            train_dataloader=train_dataloader,
            test_dataloader=test_dataloader,
            criterion=criterion, std_y_map=std_y_map, std_x_map=std_x_map,
            lr=lr, reg=reg,report_point=report_point, device=device)
    # The RMSE in UpdatedTrainEpoch is computed on this couple.
    epoch_map.test_couple = (test_x, test_y)
    # run and save
    save_dir = 'saved_networks'
    # Create the target directory up front (no-op if it already exists) so
    # that a missing folder cannot make the run fail when storing.
    os.makedirs(save_dir, exist_ok=True)
    save_file = os.path.join(save_dir,
            'noneiv_wine_init_std_y_%.3f_seed_%i.pkl'\
            % (init_std_y,seed))
    train_and_store.train_and_store(net=net, 
            epoch_map=epoch_map,
            number_of_epochs=number_of_epochs,
            save_file=save_file,
            rmse=rmse_chain)
    

if __name__ == '__main__':
    # One training run per (seed, init_std_y) pair, seeds in the outer loop.
    for seed in seed_list:
        for init_std_y in init_std_y_list:
            banner = '->->Using  init_std_y=%.2f<-<-<-<-' % (init_std_y)
            print(banner)
            train_on_data(init_std_y, seed)