import importlib
import os

import numpy as np
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
from tqdm import tqdm

from EIVArchitectures import Networks
from EIVTrainingRoutines import train_and_store


long_dataname = 'energy_efficiency'
short_dataname = 'energy'

# Pull in the dataset loader and the two training scripts; the training
# modules supply the hyperparameters (and seed list) of the stored networks
# that are evaluated below.
load_data = importlib.import_module(f'EIVData.{long_dataname}').load_data
train_noneiv = importlib.import_module(f'train_noneiv_{short_dataname}')
train_eiv = importlib.import_module(f'train_eiv_{short_dataname}')

train_data, test_data = load_data()
# NOTE(review): np.max makes batch_size at least len(test_data), so the whole
# test split arrives as a single batch — confirm np.min (a batch-size cap)
# was not intended here.
test_dataloader = DataLoader(test_data, batch_size=int(np.max((len(test_data),
    64))), shuffle=True)
# Input/output dimensionality inferred from the first (input, label) sample.
input_dim = train_data[0][0].numel()
output_dim = train_data[0][1].numel()

def _rmse_and_logdens(net, x, y, number_of_draws, decouple_dimensions, device):
    """
    Compute the label-rescaled RMSE and the mean predictive log-density of
    `net` on the batch (x, y), drawing `number_of_draws` samples from the
    posterior predictive.  Reads the label scaling from the module-level
    `train_data`.  Returns a tuple (rmse, logdens) of numpy scalars.
    """
    out = net.predict(x, number_of_draws=number_of_draws,
            take_average_of_prediction=True)[0]
    if len(y.shape) <= 1:
        y = y.view((-1, 1))
    assert y.shape == out.shape
    # undo the label normalization before computing the error
    scale = train_data.dataset.std_labels.to(device)
    scaled_res = ((y - out) * scale.view((1, -1)))\
            .detach().cpu().numpy().flatten()
    rmse = np.sqrt(np.mean(scaled_res**2))
    logdens = net.predictive_logdensity(x, y,
            number_of_draws=number_of_draws,
            decouple_dimensions=decouple_dimensions,
            scale_labels=\
                    train_data.dataset.std_labels.view((-1,)).to(device)\
                    ).mean().detach().cpu().numpy()
    return rmse, logdens


def collect_metrics(x, y, seed=0,
        noneiv_number_of_draws=100, eiv_number_of_draws=100,
        decouple_dimensions=False, device=torch.device('cuda:1')):
    """
    Load the stored non-EiV and EiV networks for `seed` and evaluate them on
    the batch (x, y).

    :param x: A torch.tensor, taken as input
    :param y: A torch.tensor, taken as output
    :param seed: Integer. The seed used for loading, defaults to 0.
    :param noneiv_number_of_draws: Number of draws for non-EiV model
    for sampling from the posterior predictive. Defaults to 100.
    :param eiv_number_of_draws: Number of draws for EiV model
    for sampling from the posterior predictive. Defaults to 100.
    :param decouple_dimensions: Boolean. If True, the unusual convention
    of Gal et al. is followed where, in the evaluation of the
    log-posterior-predictive, each dimension is treated independently and then
    averaged. If False (default), a multivariate distribution is used.
    :returns: noneiv_rmse, noneiv_logdens, eiv_rmse, eiv_logdens
    """
    x, y = x.to(device), y.to(device)

    # non-EiV: rebuild the architecture from the training hyperparameters
    # and load the stored weights.
    init_std_y = train_noneiv.init_std_y_list[0]
    unscaled_reg = train_noneiv.unscaled_reg
    p = train_noneiv.p
    hidden_layers = train_noneiv.hidden_layers
    saved_file = os.path.join('saved_networks',
                f'noneiv_{short_dataname}'\
                        f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'\
                        f'_p_{p:.2f}_seed_{seed}.pkl')
    net = Networks.FNNBer(p=p, init_std_y=init_std_y,
            h=[input_dim, *hidden_layers, output_dim]).to(device)
    train_and_store.open_stored_training(saved_file=saved_file,
            net=net, device=device)
    # Remember the mode *before* switching to train mode, so the restore
    # below really returns the net to its loaded state (the old code
    # re-read net.training after net.train() and always restored train mode).
    training_state = net.training
    net.train()  # dropout must be active to sample from the posterior
    noneiv_rmse, noneiv_logdens = _rmse_and_logdens(net, x, y,
            noneiv_number_of_draws, decouple_dimensions, device)
    if training_state:
        net.train()
    else:
        net.eval()

    # EiV: same procedure, additionally switching the input noise on.
    init_std_y = train_eiv.init_std_y_list[0]
    unscaled_reg = train_eiv.unscaled_reg
    p = train_eiv.p
    hidden_layers = train_eiv.hidden_layers
    fixed_std_x = train_eiv.fixed_std_x
    # use short_dataname (not a hard-coded 'energy') so the path stays
    # consistent with the non-EiV branch when the dataset changes
    saved_file = os.path.join('saved_networks',
            f'eiv_{short_dataname}'\
                    f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'\
                    f'_p_{p:.2f}_fixed_std_x_{fixed_std_x:.3f}'\
                    f'_seed_{seed}.pkl')
    net = Networks.FNNEIV(p=p, init_std_y=init_std_y,
            h=[input_dim, *hidden_layers, output_dim],
            fixed_std_x=fixed_std_x).to(device)
    # pass device here as well (the non-EiV branch already did) so the
    # checkpoint is mapped onto the evaluation device
    train_and_store.open_stored_training(saved_file=saved_file,
            net=net, device=device)
    training_state = net.training
    noise_state = net.noise_is_on
    net.train()
    net.noise_on()
    eiv_rmse, eiv_logdens = _rmse_and_logdens(net, x, y,
            eiv_number_of_draws, decouple_dimensions, device)
    if training_state:
        net.train()
    else:
        net.eval()
    if noise_state:
        net.noise_on()
    else:
        net.noise_off()
    return noneiv_rmse, noneiv_logdens, eiv_rmse, eiv_logdens

# Repeatedly evaluate both stored models on freshly loaded test splits and
# accumulate the per-batch metrics over all seeds and test epochs.
noneiv_rmse_collection = []
noneiv_logdens_collection = []
eiv_rmse_collection = []
eiv_logdens_collection = []
num_test_epochs = 20
assert train_noneiv.seed_list == train_eiv.seed_list
seed_list = train_noneiv.seed_list
for seed in tqdm(seed_list):
    train_data, test_data = load_data(seed=seed)
    # batch size is at least len(test_data): the whole split per batch
    batch_size = int(np.max((len(test_data), 800)))
    test_dataloader = DataLoader(test_data, batch_size=batch_size,
            shuffle=True)
    for _ in tqdm(range(num_test_epochs)):
        for x, y in test_dataloader:
            metrics = collect_metrics(x, y, seed=seed)
            noneiv_rmse, noneiv_logdens, eiv_rmse, eiv_logdens = metrics
            noneiv_rmse_collection.append(noneiv_rmse)
            noneiv_logdens_collection.append(noneiv_logdens)
            eiv_rmse_collection.append(eiv_rmse)
            eiv_logdens_collection.append(eiv_logdens)


def _summarize(values):
    """
    Format a metric collection as 'mean(standard error)' with three decimals.
    The standard error divides by sqrt of the actual sample count; the old
    code divided by sqrt(num_test_epochs) only, which overstated the error
    whenever several seeds/batches contributed to the collection.
    """
    values = np.asarray(values, dtype=float)
    sem = np.std(values) / np.sqrt(len(values))
    return f'{np.mean(values):.3f}({sem:.3f})'


print('Non-EiV')
print(f'RMSE {_summarize(noneiv_rmse_collection)}')
print(f'LogDens {_summarize(noneiv_logdens_collection)}')
print('EiV')
print(f'RMSE {_summarize(eiv_rmse_collection)}')
print(f'LogDens {_summarize(eiv_logdens_collection)}')