import importlib
import os

import numpy as np
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
from tqdm import tqdm

from EIVArchitectures import Networks
from EIVTrainingRoutines import train_and_store


# Dataset identifiers: the long name locates the data module, the short
# name locates the corresponding training scripts.
long_dataname = 'energy_efficiency'
short_dataname = 'energy'

# Resolve the dataset loader and the two training modules dynamically so
# this script can be cloned for other datasets by changing the names above.
load_data = importlib.import_module(f'EIVData.{long_dataname}').load_data
train_noneiv = importlib.import_module(f'train_noneiv_{short_dataname}')
train_eiv = importlib.import_module(f'train_eiv_{short_dataname}')

train_data, test_data = load_data()
# One batch large enough (at least 800, or the full test set) to hold the
# whole test set in a single draw from the dataloader.
_test_batch_size = int(np.max((len(test_data), 800)))
test_dataloader = DataLoader(test_data, batch_size=_test_batch_size)
input_dim = train_data[0][0].numel()
output_dim = train_data[0][1].numel()

def _scaled_rmse(y, out):
    """Root-mean-square error of ``y - out`` with residuals rescaled by the
    training-label standard deviation, i.e. the RMSE in original units.

    Expects ``y`` and ``out`` to have identical (2D) shapes.
    """
    assert y.shape == out.shape
    scale = train_data.dataset.std_labels
    scaled_res = (y - out) * scale.view((1, -1))
    scaled_res = scaled_res.detach().cpu().numpy().flatten()
    return np.sqrt(np.mean(scaled_res**2))


def collect_metrics(x, y, seed=0,
        noneiv_number_of_draws=500, eiv_number_of_draws=500,
        decouple_dimensions=False):
    """
    Load the stored non-EiV and EiV networks for `seed` and evaluate their
    RMSE and mean predictive log-density on the batch ``(x, y)``.

    :param x: A torch.tensor, taken as input
    :param y: A torch.tensor, taken as output
    :param seed: Integer. The seed used for loading, defaults to 0.
    :param noneiv_number_of_draws: Number of draws for the non-EiV model
    for sampling from the posterior predictive. Defaults to 500.
    :param eiv_number_of_draws: Number of draws for the EiV model
    for sampling from the posterior predictive. Defaults to 500.
    :param decouple_dimensions: Boolean. If True, the unusual convention
    of Gal et al. is followed where, in the evaluation of the
    log-posterior-predictive, each dimension is treated independently and then
    averaged. If False (default), a multivariate distribution is used.
    :returns: noneiv_rmse, noneiv_logdens, eiv_rmse, eiv_logdens
    (RMSEs as numpy floats, log-densities as Python floats)
    """
    # Normalize y to 2D once; both models and both metrics use this shape.
    if len(y.shape) <= 1:
        y = y.view((-1, 1))

    # ---- non-EiV model ----------------------------------------------------
    init_std_y = train_noneiv.init_std_y_list[0]
    unscaled_reg = train_noneiv.unscaled_reg
    p = train_noneiv.p
    hidden_layers = train_noneiv.hidden_layers
    saved_file = os.path.join('saved_networks',
                f'noneiv_{short_dataname}'\
                        f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'\
                        f'_p_{p:.2f}_seed_{seed}.pkl')
    net = Networks.FNNBer(p=p, init_std_y=init_std_y,
            h=[input_dim, *hidden_layers, output_dim])
    train_and_store.open_stored_training(saved_file=saved_file,
            net=net)

    # Dropout must stay active (train mode) so predictions are Monte-Carlo
    # samples; no_grad avoids building a graph we never backpropagate.
    net.train()
    with torch.no_grad():
        # RMSE
        out = net.predict(x, number_of_draws=noneiv_number_of_draws,
                take_average_of_prediction=True)[0]
        noneiv_rmse = _scaled_rmse(y, out)

        # NLL
        # NOTE(review): number_of_draws is hard-coded to 100 here rather than
        # using noneiv_number_of_draws -- confirm this is intentional.
        noneiv_logdens = net.predictive_logdensity(x, y, number_of_draws=100,
                decouple_dimensions=decouple_dimensions,
                scale_labels=train_data.dataset.std_labels.view((-1,))
                ).mean().item()

    # ---- EiV model --------------------------------------------------------
    init_std_y = train_eiv.init_std_y_list[0]
    unscaled_reg = train_eiv.unscaled_reg
    p = train_eiv.p
    hidden_layers = train_eiv.hidden_layers
    fixed_std_x = train_eiv.fixed_std_x
    saved_file = os.path.join('saved_networks',
                f'eiv_{short_dataname}'\
                        f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'\
                        f'_p_{p:.2f}_seed_{seed}.pkl')
    net = Networks.FNNEIV(p=p, init_std_y=init_std_y,
            h=[input_dim, *hidden_layers, output_dim], fixed_std_x=fixed_std_x)
    train_and_store.open_stored_training(saved_file=saved_file,
            net=net)

    # RMSE: train mode with input noise explicitly on.
    noise_state = net.noise_is_on  # noise setting as loaded from disk
    net.train()
    net.noise_on()
    with torch.no_grad():
        out = net.predict(x, number_of_draws=eiv_number_of_draws,
                take_average_of_prediction=True)[0]
        eiv_rmse = _scaled_rmse(y, out)

    # Restore the loaded noise setting before the log-density evaluation;
    # this mirrors the original evaluation order, in which the NLL is
    # computed with the net's post-load noise state.
    if noise_state:
        net.noise_on()
    else:
        net.noise_off()

    # NLL (see the NOTE above on the hard-coded draw count).
    net.train()
    with torch.no_grad():
        eiv_logdens = net.predictive_logdensity(x, y, number_of_draws=100,
                decouple_dimensions=decouple_dimensions,
                scale_labels=train_data.dataset.std_labels.view((-1,))
                ).mean().item()
    return noneiv_rmse, noneiv_logdens, eiv_rmse, eiv_logdens

# Repeat the (stochastic) evaluation `number_of_samples` times; the spread
# of the metrics comes from the Monte-Carlo sampling inside collect_metrics,
# not from the data, so the same test batch is reused for every repetition.
number_of_samples = 20
# The dataloader is unshuffled, so next(iter(...)) always yields the same
# first batch -- fetch it once instead of rebuilding an iterator per loop.
x, y = next(iter(test_dataloader))

noneiv_rmse_collection = []
noneiv_logdens_collection = []
eiv_rmse_collection = []
eiv_logdens_collection = []
for _ in tqdm(range(number_of_samples)):
    noneiv_rmse, noneiv_logdens, eiv_rmse, eiv_logdens = collect_metrics(x, y)
    noneiv_rmse_collection.append(noneiv_rmse)
    noneiv_logdens_collection.append(noneiv_logdens)
    eiv_rmse_collection.append(eiv_rmse)
    eiv_logdens_collection.append(eiv_logdens)


def _sem(values):
    """Standard error of the mean over the repetitions."""
    return np.std(values) / np.sqrt(number_of_samples)


# Report mean (standard error) for each model and metric.
print('Non-EiV')
print(f'RMSE {np.mean(noneiv_rmse_collection):.3f} ({_sem(noneiv_rmse_collection):.3f})')
print(f'LogDens {np.mean(noneiv_logdens_collection):.3f} ({_sem(noneiv_logdens_collection):.3f})')
print('EiV')
print(f'RMSE {np.mean(eiv_rmse_collection):.3f} ({_sem(eiv_rmse_collection):.3f})')
print(f'LogDens {np.mean(eiv_logdens_collection):.3f} ({_sem(eiv_logdens_collection):.3f})')