"""
Compute metrics for datasets for which there is not necessarily a ground
truth. The results are stored as a JSON file in the results folder.
"""
import importlib
import os
import argparse
import json

import numpy as np
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
from tqdm import tqdm

from EIVArchitectures import Networks
from EIVTrainingRoutines import train_and_store
from EIVGeneral.coverage_metrics import epistemic_coverage, normalized_std

# parse command line arguments; the dataset is chosen via the --data option
parser = argparse.ArgumentParser()
parser.add_argument("--data", help="Name of the dataset to use", default='linear')
parser.add_argument("--no-autoindent", help="",
        action="store_true") # to avoid conflicts in IPython
args = parser.parse_args()
data = args.data

# load hyperparameters from JSON file
with open(os.path.join('configurations',f'eiv_{data}.json'),'r') as conf_file:
    eiv_conf_dict = json.load(conf_file)
with open(os.path.join('configurations',f'noneiv_{data}.json'),'r') as conf_file:
    noneiv_conf_dict = json.load(conf_file)

long_dataname = eiv_conf_dict["long_dataname"]
short_dataname = eiv_conf_dict["short_dataname"]

print(f"Evaluating {long_dataname}")

scale_outputs = False  # if True, rescale RMSE, bias and log-density by the std of the labels
load_data = importlib.import_module(f'EIVData.{long_dataname}').load_data

train_data, test_data = load_data()
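# deduce input and output dimensionality from a single data point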
input_dim = train_data[0][0].numel()
output_dim = train_data[0][1].numel()

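# use the GPU number from the configuration if that device is usable,
# fall back to GPU 0 and finally to the CPU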
try:
    gpu_number = eiv_conf_dict["gpu_number"]
    device = torch.device(f'cuda:{gpu_number}')
    try:
        torch.tensor([0.0]).to(device)
    except RuntimeError:
        if torch.cuda.is_available():
            print('Switched to GPU 0')
            device = torch.device('cuda:0')
        else:
            print('No cuda available, using CPU')
            device = torch.device('cpu')
except KeyError:
    device = torch.device('cpu')


def collect_metrics(x,y, seed=0,
    noneiv_number_of_draws=100, eiv_number_of_draws=[100,5],
    decouple_dimensions=False, device=device,
    scale_outputs=scale_outputs):
    """
    Compute various metrics for EiV and non-EiV. Will be returned as
    dictionaries.
    :param x: A torch.tensor, taken as input
    :param y: A torch.tensor, taken as output
    :param seed: Integer. The seed of the stored networks to be loaded,
    defaults to 0.
    :param noneiv_number_of_draws: Number of draws for non-EiV model
    for sampling from the posterior predictive. Defaults to 100.
    :param eiv_number_of_draws: Number of draws for the EiV model
    for sampling from the posterior predictive, given as a list of two
    integers. Defaults to [100,5].
    :param decouple_dimensions: Boolean. If True, the unusual convention
    of Gal et al. is followed where, in the evaluation of the
    log-posterior-predictive, each dimension is treated independently and then
    averaged. If False (default), a multivariate distribution is used.
    :param device: The device to use.
    :param scale_outputs: Boolean, scale the outputs for the RMSE, the bias
    and the log-density to make them comparable with the literature.
    :returns: Dictionaries noneiv_metrics, eiv_metrics
    """
    x,y = x.to(device), y.to(device)


    # non-EiV: load hyperparameters and the corresponding stored network
    noneiv_metrics = {}
    init_std_y = noneiv_conf_dict["init_std_y_list"][0]
    unscaled_reg = noneiv_conf_dict["unscaled_reg"]
    p = noneiv_conf_dict["p"]
    hidden_layers = noneiv_conf_dict["hidden_layers"]
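    # the filename of the stored network encodes hyperparameters and seed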
    saved_file = os.path.join('saved_networks',
                f'noneiv_{short_dataname}'\
                        f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'\
                        f'_p_{p:.2f}_seed_{seed}.pkl')
    net = Networks.FNNBer(p=p, init_std_y=init_std_y,
            h=[input_dim, *hidden_layers, output_dim]).to(device)
    train_and_store.open_stored_training(saved_file=saved_file,
            net=net, device=device)


    # RMSE, bias and coverage
    training_state = net.training
    net.train()
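    # keep dropout active so that repeated forward passes sample from the
    # posterior predictive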
    not_averaged_predictions = net.predict(x, number_of_draws=noneiv_number_of_draws, 
            take_average_of_prediction=False)
    noneiv_mean = torch.mean(not_averaged_predictions[0], dim=1)
    if len(y.shape) <= 1:
        y = y.view((-1,1))
    assert y.shape == noneiv_mean.shape
    res = y-noneiv_mean
    if scale_outputs:
        scale = train_data.dataset.std_labels.to(device)
        scaled_res = res * scale.view((1,-1))
    else:
        scaled_res = res
    scaled_res = scaled_res.detach().cpu().numpy().flatten()
    noneiv_metrics['rmse'] = np.sqrt(np.mean(scaled_res**2))
    noneiv_metrics['bias'] = np.mean(scaled_res)
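    # compare the observed coverage of the uncertainty regions with the
    # theoretically expected one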
    noneiv_metrics['coverage_numerical'], noneiv_metrics['coverage_theory'] =\
            epistemic_coverage(not_averaged_predictions, y, normalize_errors=False)
    noneiv_metrics['coverage_normalized'],_ =\
            epistemic_coverage(not_averaged_predictions, y, normalize_errors=True)
    noneiv_metrics['res_std'] = normalized_std(not_averaged_predictions, y)
    


    # predictive log-density (the negative of the NLL)
    if scale_outputs:
        scale_labels = train_data.dataset.std_labels.view((-1,)).to(device)
    else:
        scale_labels = None
    noneiv_metrics['logdens'] = net.predictive_logdensity(
            not_averaged_predictions, y,
            number_of_draws=noneiv_number_of_draws,
            decouple_dimensions=decouple_dimensions,
            scale_labels=scale_labels).mean().detach().cpu().numpy()
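    # restore the training state of the network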
    if training_state:
        net.train()
    else:
        net.eval()

    # EiV: load hyperparameters and the corresponding stored network
    eiv_metrics = {}
    init_std_y = eiv_conf_dict["init_std_y_list"][0]
    unscaled_reg = eiv_conf_dict["unscaled_reg"]
    p = eiv_conf_dict["p"]
    hidden_layers = eiv_conf_dict["hidden_layers"]
    fixed_std_x = eiv_conf_dict["fixed_std_x"]
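    # as above, the filename encodes hyperparameters (including fixed_std_x)
    # and seed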
    saved_file = os.path.join('saved_networks',
            f'eiv_{short_dataname}'\
                    f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'\
                    f'_p_{p:.2f}_fixed_std_x_{fixed_std_x:.3f}'\
                    f'_seed_{seed}.pkl')
    net = Networks.FNNEIV(p=p, init_std_y=init_std_y,
            h=[input_dim, *hidden_layers, output_dim],
            fixed_std_x=fixed_std_x).to(device)
    train_and_store.open_stored_training(saved_file=saved_file,
            net=net, device=device)

    # RMSE, bias and coverage
    training_state = net.training
    noise_state = net.noise_is_on
    net.train()
    net.noise_on()
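    # with dropout and the EiV input noise switched on, repeated forward
    # passes sample from the posterior predictive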
    not_averaged_predictions = net.predict(x, number_of_draws=eiv_number_of_draws, 
            take_average_of_prediction=False)
    eiv_mean = torch.mean(not_averaged_predictions[0], dim=1)
    if len(y.shape) <= 1:
        y = y.view((-1,1))
    assert y.shape == eiv_mean.shape
    res = y-eiv_mean
    if scale_outputs:
        scale = train_data.dataset.std_labels.to(device)
        scaled_res = res * scale.view((1,-1))
    else:
        scaled_res = res
    scaled_res = scaled_res.detach().cpu().numpy().flatten()
    eiv_metrics['rmse'] = np.sqrt(np.mean(scaled_res**2))
    eiv_metrics['bias'] = np.mean(scaled_res)
    eiv_metrics['coverage_numerical'], eiv_metrics['coverage_theory'] =\
            epistemic_coverage(not_averaged_predictions, y, normalize_errors=False)
    eiv_metrics['coverage_normalized'],_ =\
            epistemic_coverage(not_averaged_predictions, y, normalize_errors=True)
    eiv_metrics['res_std'] = normalized_std(not_averaged_predictions, y)


    # predictive log-density (the negative of the NLL)
    if scale_outputs:
        scale_labels = train_data.dataset.std_labels.view((-1,)).to(device)
    else:
        scale_labels = None
    eiv_metrics['logdens'] = net.predictive_logdensity(
            not_averaged_predictions, y,
            number_of_draws=eiv_number_of_draws,
            decouple_dimensions=decouple_dimensions,
            scale_labels=scale_labels).mean().detach().cpu().numpy()
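    # restore the training and noise state of the network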
    if training_state:
        net.train()
    else:
        net.eval()
    if noise_state:
        net.noise_on()
    else:
        net.noise_off()
    return noneiv_metrics, eiv_metrics


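# collect the metrics of both models over several seeds, test epochs and
# batches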
collection_keys = ['rmse','logdens','bias','coverage_numerical',
        'coverage_theory','coverage_normalized','res_std']
noneiv_metrics_collection = {}
eiv_metrics_collection = {}
for key in collection_keys:
    noneiv_metrics_collection[key] = []
    eiv_metrics_collection[key] = []
num_test_epochs = 10  # number of passes over (part of) the shuffled test data
assert noneiv_conf_dict["seed_range"] == eiv_conf_dict["seed_range"]
seed_list = range(noneiv_conf_dict["seed_range"][0],
        noneiv_conf_dict["seed_range"][1])
max_batch_number = 2  # at most max_batch_number + 1 batches are used per test epoch
for seed in tqdm(seed_list):
    train_data, test_data = load_data(seed=seed)
    test_dataloader = DataLoader(test_data,
            batch_size=min(len(test_data), 800), shuffle=True)
    for i in tqdm(range(num_test_epochs)):
        for j, (x,y) in enumerate(test_dataloader):
            if j > max_batch_number:
                break

            noneiv_metrics, eiv_metrics = collect_metrics(x,y, seed=seed)
            for key in collection_keys:
                noneiv_metrics_collection[key].append(noneiv_metrics[key])
                eiv_metrics_collection[key].append(eiv_metrics[key])

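# summarize each metric by its mean and the spread
# std / sqrt(num_test_epochs * number_of_seeds)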
results_dict = {}
print('Non-EiV:\n-----')
results_dict['noneiv'] = {}
for key in collection_keys:
    metric_mean = float(np.mean(noneiv_metrics_collection[key]))
    metric_std = float(np.std(noneiv_metrics_collection[key])
            /np.sqrt(num_test_epochs*len(seed_list)))
    results_dict['noneiv'][key] = (metric_mean, metric_std)
    print(f'{key}: {metric_mean:.5f} ({metric_std:.5f})')
print('\n')
print('EiV:\n-----')
results_dict['eiv'] = {}
for key in collection_keys:
    metric_mean = float(np.mean(eiv_metrics_collection[key]))
    metric_std = float(np.std(eiv_metrics_collection[key])
            /np.sqrt(num_test_epochs*len(seed_list)))
    print(f'{key}: {metric_mean:.5f} ({metric_std:.5f})')
    results_dict['eiv'][key] = (metric_mean, metric_std)

# write the results to a JSON file in the results folder
os.makedirs('results', exist_ok=True)
with open(os.path.join('results',f'metrics_{short_dataname}.json'), 'w') as f:
    json.dump(results_dict, f)