import importlib
import os

import numpy as np
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
from tqdm import tqdm

from EIVArchitectures import Networks
from EIVTrainingRoutines import train_and_store
from EIVGeneral.coverage_metrics import epistemic_coverage, normalized_std

# Dataset identifiers: the long name selects the EIVData loader module, the
# short name selects the training-configuration modules and the stored
# network files.
long_dataname, short_dataname = 'energy_efficiency', 'energy'

# When True, RMSE, bias and log-density are rescaled by the label std of the
# training data (see collect_metrics).
scale_outputs = False
load_data = getattr(importlib.import_module(f'EIVData.{long_dataname}'),
        'load_data')
# Training scripts whose module-level hyperparameters (seed_list, p,
# hidden_layers, ...) determine which stored networks are evaluated.
train_noneiv, train_eiv = (
        importlib.import_module(f'train_{variant}_{short_dataname}')
        for variant in ('noneiv', 'eiv'))

# Initial load without an explicit seed; used to read off the input/output
# dimensions from the first training example. The evaluation loop rebinds
# train_data/test_data for every seed.
train_data, test_data = load_data()
input_dim, output_dim = (train_data[0][k].numel() for k in (0, 1))

def _scaled_label_std(scale_outputs, device):
    """
    Return the label std used to rescale outputs, or None.

    :param scale_outputs: Boolean. If False, None is returned.
    :param device: torch.device the returned tensor is moved to.
    :returns: `train_data.dataset.std_labels` flattened to 1-D and moved to
    `device`, or None. `train_data` is the module-level dataset, which the
    evaluation loop rebinds for every seed.
    """
    if not scale_outputs:
        return None
    return train_data.dataset.std_labels.view((-1,)).to(device)


def _residual_metrics(not_averaged_predictions, y, label_std=None):
    """
    Compute RMSE, bias, coverage and residual-std metrics from sampled
    predictions.

    :param not_averaged_predictions: Output of a network's `predict` with
    `take_average_of_prediction=False`; element 0 is averaged over dim 1 to
    obtain the predicted mean.
    :param y: A torch.tensor of labels; reshaped to a column if it has at
    most one dimension.
    :param label_std: None, or a 1-D tensor by which the residuals are
    multiplied (per output dimension) before RMSE and bias are computed.
    :returns: (metrics, y) -- a dict with keys 'rmse', 'bias',
    'coverage_numerical', 'coverage_theory', 'coverage_normalized',
    'res_std', and the possibly reshaped y.
    """
    mean = torch.mean(not_averaged_predictions[0], dim=1)
    if len(y.shape) <= 1:
        y = y.view((-1, 1))
    assert y.shape == mean.shape, (y.shape, mean.shape)
    res = y - mean
    if label_std is not None:
        res = res * label_std.view((1, -1))
    res = res.detach().cpu().numpy().flatten()
    metrics = {'rmse': np.sqrt(np.mean(res**2)), 'bias': np.mean(res)}
    # Coverage and residual std are computed on the unscaled labels.
    metrics['coverage_numerical'], metrics['coverage_theory'] = \
            epistemic_coverage(not_averaged_predictions, y,
                    normalize_errors=False)
    metrics['coverage_normalized'], _ = epistemic_coverage(
            not_averaged_predictions, y, normalize_errors=True)
    metrics['res_std'] = normalized_std(not_averaged_predictions, y)
    return metrics, y


def collect_metrics(x, y, seed=0,
        noneiv_number_of_draws=100, eiv_number_of_draws=None,
        decouple_dimensions=False, device=torch.device('cuda:1'),
        scale_outputs=scale_outputs):
    """
    Compute various metrics for EiV and non-EiV. Will be returned as
    dictionaries.

    Both networks are loaded from 'saved_networks', with file names built
    from the hyperparameters of the `train_noneiv` / `train_eiv` modules and
    `seed`. Predictions are drawn with the networks in training mode (and,
    for EiV, with `noise_on()`), and the same draws are reused for the
    log-density.
    :param x: A torch.tensor, taken as input
    :param y: A torch.tensor, taken as output. If it has at most one
    dimension it is reshaped to a column.
    :param seed: Integer. The seed used for loading, defaults to 0.
    :param noneiv_number_of_draws: Number of draws for non-EiV model
    for sampling from the posterior predictive. Defaults to 100.
    :param eiv_number_of_draws: Number of draws for EiV model
    for sampling from the posterior predictive, passed to `FNNEIV.predict`
    and `predictive_logdensity`. Defaults to [100, 5] (None selects this
    default).
    :param decouple_dimensions: Boolean. If True, the unusual convention
    of Gal et al. is followed where, in the evaluation of the
    log-posterior-predictive, each dimension is treated independently and then
    averaged. If False (default), a multivariate distribution is used.
    :param device: torch.device the networks and data are moved to.
    :param scale_outputs: Boolean, scale the outputs for the RMSE, the bias and
    the log-dens to make them comparable with the literature.
    :returns: Dictionaries noneiv_metrics, eiv_metrics, each with keys
    'rmse', 'bias', 'coverage_numerical', 'coverage_theory',
    'coverage_normalized', 'res_std' and 'logdens'.
    """
    if eiv_number_of_draws is None:
        # Fresh list per call instead of a shared mutable default.
        eiv_number_of_draws = [100, 5]
    x, y = x.to(device), y.to(device)
    label_std = _scaled_label_std(scale_outputs, device)

    # non-EiV
    init_std_y = train_noneiv.init_std_y_list[0]
    unscaled_reg = train_noneiv.unscaled_reg
    p = train_noneiv.p
    hidden_layers = train_noneiv.hidden_layers
    saved_file = os.path.join('saved_networks',
            f'noneiv_{short_dataname}'
            f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'
            f'_p_{p:.2f}_seed_{seed}.pkl')
    net = Networks.FNNBer(p=p, init_std_y=init_std_y,
            h=[input_dim, *hidden_layers, output_dim]).to(device)
    train_and_store.open_stored_training(saved_file=saved_file,
            net=net, device=device)

    training_state = net.training
    # Draw in training mode (presumably so that dropout with rate p stays
    # active for the posterior-predictive samples).
    net.train()
    not_averaged_predictions = net.predict(x,
            number_of_draws=noneiv_number_of_draws,
            take_average_of_prediction=False)
    noneiv_metrics, y = _residual_metrics(not_averaged_predictions, y,
            label_std)
    # Log posterior predictive density on the same draws; the number of
    # draws matches the one used for sampling above.
    noneiv_metrics['logdens'] = net.predictive_logdensity(
            not_averaged_predictions, y,
            number_of_draws=noneiv_number_of_draws,
            decouple_dimensions=decouple_dimensions,
            scale_labels=label_std).mean().detach().cpu().numpy()
    if training_state:
        net.train()
    else:
        net.eval()

    # EiV
    init_std_y = train_eiv.init_std_y_list[0]
    unscaled_reg = train_eiv.unscaled_reg
    p = train_eiv.p
    hidden_layers = train_eiv.hidden_layers
    fixed_std_x = train_eiv.fixed_std_x
    saved_file = os.path.join('saved_networks',
            f'eiv_{short_dataname}'
            f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'
            f'_p_{p:.2f}_fixed_std_x_{fixed_std_x:.3f}'
            f'_seed_{seed}.pkl')
    net = Networks.FNNEIV(p=p, init_std_y=init_std_y,
            h=[input_dim, *hidden_layers, output_dim],
            fixed_std_x=fixed_std_x).to(device)
    # Load onto the same device as the non-EiV network.
    train_and_store.open_stored_training(saved_file=saved_file,
            net=net, device=device)

    training_state = net.training
    noise_state = net.noise_is_on
    net.train()
    net.noise_on()
    # Sample with eiv_number_of_draws so that sampling and
    # predictive_logdensity below use the same number of draws.
    not_averaged_predictions = net.predict(x,
            number_of_draws=eiv_number_of_draws,
            take_average_of_prediction=False)
    eiv_metrics, y = _residual_metrics(not_averaged_predictions, y,
            label_std)
    eiv_metrics['logdens'] = net.predictive_logdensity(
            not_averaged_predictions, y,
            number_of_draws=eiv_number_of_draws,
            decouple_dimensions=decouple_dimensions,
            scale_labels=label_std).mean().detach().cpu().numpy()
    if training_state:
        net.train()
    else:
        net.eval()
    if noise_state:
        net.noise_on()
    else:
        net.noise_off()
    return noneiv_metrics, eiv_metrics


# Metric names collected for every evaluated test batch.
collection_keys = ['rmse', 'logdens', 'bias', 'coverage_numerical',
        'coverage_theory', 'coverage_normalized', 'res_std']
noneiv_metrics_collection = {key: [] for key in collection_keys}
eiv_metrics_collection = {key: [] for key in collection_keys}
# Number of (shuffled) passes over the test set per seed.
num_test_epochs = 10
# The seed also selects the stored network files, so both models must have
# been trained for the same seeds.
if train_noneiv.seed_list != train_eiv.seed_list:
    raise ValueError('train_noneiv.seed_list and train_eiv.seed_list differ')
seed_list = train_noneiv.seed_list
# Batches with index 0..max_batch_number (inclusive) are evaluated per
# epoch, i.e. at most max_batch_number + 1 batches.
max_batch_number = 2
for seed in tqdm(seed_list):
    # Rebinds the module-level train_data/test_data; collect_metrics reads
    # train_data when scaling outputs.
    train_data, test_data = load_data(seed=seed)
    test_dataloader = DataLoader(test_data,
            batch_size=min(len(test_data), 800), shuffle=True)
    for _ in tqdm(range(num_test_epochs)):
        for j, (x, y) in enumerate(test_dataloader):
            if j > max_batch_number:
                break
            noneiv_metrics, eiv_metrics = collect_metrics(x, y, seed=seed)
            for key in collection_keys:
                noneiv_metrics_collection[key].append(noneiv_metrics[key])
                eiv_metrics_collection[key].append(eiv_metrics[key])

# Report mean and standard error (std / sqrt(number of collected values)).
# The denominator is the actual number of evaluations, which exceeds
# num_test_epochs * len(seed_list) whenever an epoch has several batches.
for title, collection in (('Non-EiV', noneiv_metrics_collection),
        ('EiV', eiv_metrics_collection)):
    print(f'{title}\n-----')
    for key in collection_keys:
        values = collection[key]
        std_err = np.std(values) / np.sqrt(len(values))
        print(f'{key} {np.mean(values):.5f}({std_err:.5f})')