"""
Compute metrics for datasets for which there is not necessarily a ground truth.
Results will be stored in the results folder
"""
import importlib
import os
import argparse
import json

import numpy as np
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
from tqdm import tqdm

from EIVArchitectures import Networks
from EIVTrainingRoutines import train_and_store
from EIVGeneral.coverage_metrics import epistemic_coverage, normalized_std
from EIVData.repeated_sampling import repeated_sampling

# The dataset to evaluate is selected on the command line via --data
parser = argparse.ArgumentParser()
parser.add_argument("--data", default='replin', help="Loads data")
# --no-autoindent is only accepted to avoid conflicts when run from IPython
parser.add_argument("--no-autoindent", action="store_true", help="")
args = parser.parse_args()
data = args.data

# Hyperparameters of the EiV and of the non-EiV model, one JSON file each
eiv_conf_path = os.path.join('configurations', f'eiv_{data}.json')
with open(eiv_conf_path, 'r') as conf_file:
    eiv_conf_dict = json.load(conf_file)
noneiv_conf_path = os.path.join('configurations', f'noneiv_{data}.json')
with open(noneiv_conf_path, 'r') as conf_file:
    noneiv_conf_dict = json.load(conf_file)

# Both names of the dataset are taken from the EiV configuration
long_dataname = eiv_conf_dict["long_dataname"]
short_dataname = eiv_conf_dict["short_dataname"]

print(f"Evaluating {long_dataname}")

# Default for the `scale_outputs` argument of the metric functions below
scale_outputs = False
# The module EIVData.<long_dataname> provides the `load_data` map
data_module = importlib.import_module(f'EIVData.{long_dataname}')
load_data = data_module.load_data

# Input and output dimension are read off the first training sample
train_data, test_data = load_data()
input_dim = train_data[0][0].numel()
output_dim = train_data[0][1].numel()

# Select the device. A GPU is only considered if the EiV configuration
# contains "gpu_number"; without that key everything runs on the CPU.
try:
    gpu_number = eiv_conf_dict["gpu_number"]
except KeyError:
    device = torch.device('cpu')
else:
    device = torch.device(f'cuda:{gpu_number}')
    try:
        # probe the requested device by moving a small tensor onto it
        torch.tensor([0.0]).to(device)
    except (RuntimeError, AssertionError):
        # RuntimeError: the requested GPU cannot be used.
        # AssertionError: raised by torch builds without CUDA support
        # ("Torch not compiled with CUDA enabled").
        if torch.cuda.is_available():
            print('Switched to GPU 0')
            device = torch.device('cuda:0')
        else:
            print('No cuda available, using CPU')
            device = torch.device('cpu')


def _prediction_metrics(not_averaged_predictions, y, true_y, residual_scale):
    """
    Compute the metrics that only need sampled predictions and labels. Shared
    by the EiV and the non-EiV evaluation in `collect_metrics`.
    :param not_averaged_predictions: Output of `net.predict` with
    `take_average_of_prediction=False`. Its first component is averaged along
    dimension 1 to obtain the mean prediction.
    :param y: torch.tensor, the labels. Must have the same shape as the mean
    prediction.
    :param true_y: torch.tensor with the unnoisy labels or None. If None, the
    metrics that need a ground truth are skipped.
    :param residual_scale: torch.tensor the residuals w.r.t. `y` are
    multiplied with, or None for no scaling.
    :returns: Dictionary with the keys 'rmse', 'bias', 'coverage_numerical',
    'coverage_theory', 'coverage_normalized', 'res_std' and, if `true_y` is
    not None, 'true_coverage_numerical', 'true_coverage_theory', 'true_rmse'.
    """
    metrics = {}
    prediction_mean = torch.mean(not_averaged_predictions[0], dim=1)
    assert y.shape == prediction_mean.shape
    res = y - prediction_mean
    if residual_scale is not None:
        res = res * residual_scale
    res = res.detach().cpu().numpy().flatten()
    metrics['rmse'] = np.sqrt(np.mean(res**2))
    metrics['bias'] = np.mean(res)
    metrics['coverage_numerical'], metrics['coverage_theory'] =\
            epistemic_coverage(not_averaged_predictions, y,
                    normalize_errors=False)
    metrics['coverage_normalized'], _ =\
            epistemic_coverage(not_averaged_predictions, y,
                    normalize_errors=True)
    metrics['res_std'] = normalized_std(not_averaged_predictions, y)

    # metrics that need a ground truth (these residuals are not scaled)
    if true_y is not None:
        metrics['true_coverage_numerical'],\
                metrics['true_coverage_theory'] =\
                epistemic_coverage(not_averaged_predictions, true_y,
                        normalize_errors=False, noisy_y=False)
        true_res = (true_y - prediction_mean).detach().cpu().numpy().flatten()
        metrics['true_rmse'] = np.sqrt(np.mean(true_res**2))
    return metrics


def collect_metrics(x_y_pairs, seed=0,
    noneiv_number_of_draws=100, eiv_number_of_draws=[100,5],
    decouple_dimensions=False, device=device,
    scale_outputs=scale_outputs):
    """
    Compute various metrics for EiV and non-EiV for a single seed. The
    networks are loaded from the `saved_networks` folder, using the
    hyperparameters of the module level configuration dictionaries.
    The metrics will be returned as dictionaries.
    :param x_y_pairs: A tuple of either the shape (None,None,x,y) or
    (x_true,y_true,x,y) containing torch.tensor or None. x and y are
    considered as input and corresponding label. If the first two components
    are not None, they are considered to be the unnoisy counterparts.
    :param seed: Integer. The seed used for loading, defaults to 0.
    :param noneiv_number_of_draws: Number of draws for non-EiV model
    for sampling from the posterior predictive. Defaults to 100.
    :param eiv_number_of_draws: Number of draws for EiV model
    for sampling from the posterior predictive. Defaults to [100,5]. The
    default list is never modified by this function.
    :param decouple_dimensions: Boolean. If True, the unusual convention
    of Gal et al. is followed where, in the evaluation of the
    log-posterior-predictive, each dimension is treated independently and then
    averaged. If False (default), a multivariate distribution is used.
    :param device: The device to use.
    :param scale_outputs: Boolean, scale the outputs for the RMSE, the bias and
    the log-dens to make them comparable with the literature.
    :returns: Dictionaries noneiv_metrics, eiv_metrics
    """
    true_x, true_y, x, y = x_y_pairs
    x,y = x.to(device), y.to(device)
    if true_x is not None:
        assert true_y is not None
        true_x,true_y = true_x.to(device), true_y.to(device)
    else:
        assert true_y is None

    # one-dimensional labels are reshaped to a column, to be comparable with
    # the mean prediction; the same is done for the unnoisy labels to avoid
    # an unintended broadcast in the residuals
    if len(y.shape) <= 1:
        y = y.view((-1,1))
    if true_x is not None:
        if len(true_y.shape) <= 1:
            true_y = true_y.view((-1,1))
        ground_truth_y = true_y
    else:
        ground_truth_y = None

    # scaling, uses the module level `train_data`
    if scale_outputs:
        std_labels = train_data.dataset.std_labels.to(device)
        residual_scale = std_labels.view((1,-1))
        scale_labels = std_labels.view((-1,))
    else:
        residual_scale = None
        scale_labels = None

    # non-EiV
    init_std_y = noneiv_conf_dict["init_std_y_list"][0]
    unscaled_reg = noneiv_conf_dict["unscaled_reg"]
    p = noneiv_conf_dict["p"]
    hidden_layers = noneiv_conf_dict["hidden_layers"]
    saved_file = os.path.join('saved_networks',
                f'noneiv_{short_dataname}'\
                        f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'\
                        f'_p_{p:.2f}_seed_{seed}.pkl')
    net = Networks.FNNBer(p=p, init_std_y=init_std_y,
            h=[input_dim, *hidden_layers, output_dim]).to(device)
    train_and_store.open_stored_training(saved_file=saved_file,
            net=net, device=device)

    # predictions are sampled in training mode, the previous mode is restored
    # afterwards
    training_state = net.training
    net.train()
    not_averaged_predictions = net.predict(x,
            number_of_draws=noneiv_number_of_draws,
            take_average_of_prediction=False)
    # RMSE, bias and coverage
    noneiv_metrics = _prediction_metrics(not_averaged_predictions, y,
            ground_truth_y, residual_scale)
    # NLL, with the same number of draws as used for the predictions
    noneiv_metrics['logdens'] = net.predictive_logdensity(
            not_averaged_predictions, y,
            number_of_draws=noneiv_number_of_draws,
            decouple_dimensions=decouple_dimensions,
            scale_labels=scale_labels).mean().detach().cpu().numpy()
    if training_state:
        net.train()
    else:
        net.eval()

    # EiV
    init_std_y = eiv_conf_dict["init_std_y_list"][0]
    unscaled_reg = eiv_conf_dict["unscaled_reg"]
    p = eiv_conf_dict["p"]
    hidden_layers = eiv_conf_dict["hidden_layers"]
    fixed_std_x = eiv_conf_dict["fixed_std_x"]
    saved_file = os.path.join('saved_networks',
            f'eiv_{short_dataname}'\
                    f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'\
                    f'_p_{p:.2f}_fixed_std_x_{fixed_std_x:.3f}'\
                    f'_seed_{seed}.pkl')
    net = Networks.FNNEIV(p=p, init_std_y=init_std_y,
            h=[input_dim, *hidden_layers, output_dim],
            fixed_std_x=fixed_std_x).to(device)
    # pass the device, as for the non-EiV network above
    train_and_store.open_stored_training(saved_file=saved_file,
            net=net, device=device)

    # predictions are sampled in training mode with the noise switched on,
    # both states are restored afterwards
    training_state = net.training
    noise_state = net.noise_is_on
    net.train()
    net.noise_on()
    not_averaged_predictions = net.predict(x,
            number_of_draws=eiv_number_of_draws,
            take_average_of_prediction=False)
    # RMSE, bias and coverage
    eiv_metrics = _prediction_metrics(not_averaged_predictions, y,
            ground_truth_y, residual_scale)
    # NLL
    eiv_metrics['logdens'] = net.predictive_logdensity(
            not_averaged_predictions, y,
            number_of_draws=eiv_number_of_draws,
            decouple_dimensions=decouple_dimensions,
            scale_labels=scale_labels).mean().detach().cpu().numpy()
    if training_state:
        net.train()
    else:
        net.eval()
    if noise_state:
        net.noise_on()
    else:
        net.noise_off()
    return noneiv_metrics, eiv_metrics



def _collect_true_residuals(net, true_test_data, batch_size, test_samples,
        number_of_draws, residual_scale, device, has_input_noise):
    """
    Collect the residuals between the unnoisy labels and the mean prediction
    of `net` on the noisy inputs, for the first `test_samples` batches of
    `true_test_data`.
    :param net: The network used for the predictions.
    :param true_test_data: Dataset whose items unpack into four components,
    read here as (true_x, true_y, noisy_x, ignored).
    :param batch_size: Batch size of the (unshuffled) dataloader.
    :param test_samples: Maximal number of batches to use.
    :param number_of_draws: Passed on to `net.predict`.
    :param residual_scale: torch.tensor the residuals are multiplied with, or
    None for no scaling.
    :param device: The torch.device to use.
    :param has_input_noise: Boolean. If True, `net` is treated as an EiV
    network, whose noise is switched on for the predictions.
    :returns: residuals, true_x_sum. The residuals of all used batches,
    concatenated along the batch dimension, and the sum of the absolute
    values of all used true_x (for checking that the same inputs were used).
    """
    # a fresh dataloader without shuffling yields the same true_x each time
    true_test_dataloader = DataLoader(true_test_data, batch_size=batch_size,
            shuffle=False)
    residuals = []
    true_x_sum = 0
    # remember the state of the net, to restore it after the predictions
    training_state = net.training
    if has_input_noise:
        noise_state = net.noise_is_on
    net.train()
    if has_input_noise:
        net.noise_on()
    try:
        for j, (true_x, true_y, noisy_x, _) in\
                enumerate(true_test_dataloader):
            if j >= test_samples:
                break
            # store the sum of the true_x
            true_x_sum += true_x.abs().sum().item()
            true_y, noisy_x = true_y.to(device), noisy_x.to(device)
            not_averaged_predictions = net.predict(noisy_x,
                    number_of_draws=number_of_draws,
                    take_average_of_prediction=False)
            prediction_mean = torch.mean(not_averaged_predictions[0], dim=1)
            if len(true_y.shape) <= 1:
                true_y = true_y.view((-1,1))
            assert true_y.shape == prediction_mean.shape
            true_res = true_y - prediction_mean
            if residual_scale is not None:
                true_res = true_res * residual_scale
            residuals.append(true_res)
    finally:
        # restore net
        if training_state:
            net.train()
        else:
            net.eval()
        if has_input_noise:
            if noise_state:
                net.noise_on()
            else:
                net.noise_off()
    # concatenate batches along batch dimension
    return torch.concat(residuals, dim=0), true_x_sum


def collect_full_seed_range_metrics(load_data,
        seed_range,test_batch_size = 100, test_samples = 10,
        noneiv_number_of_draws=100, eiv_number_of_draws=[100,5], device=device,
        scale_outputs=scale_outputs):
    """
    Collect metrics that need all seeds for their computation. Currently this
    is the averaged x-dependent bias ('avg_bias'), which is only computed if
    `load_data` is an instance of `repeated_sampling` and provides a ground
    truth. Otherwise, empty dictionaries are returned.
    :param load_data: load_data map should take seed as an argument and,
    optionally, `return_ground_truth`.
    :param seed_range: iterator for seeds.
    :param test_batch_size: An integer, used for drawing samples from the test
    data.
    :param test_samples: Number of test samples with batch size
    `test_batch_size` to take.
    :param noneiv_number_of_draws: Number of samples to take for the prediction
    of the non-EiV model. Defaults to 100.
    :param eiv_number_of_draws: Number of samples to take for the prediction
    of the EiV model. Defaults to [100,5]. The default list is never modified
    by this function.
    :param device: The torch.device to use
    :param scale_outputs: Boolean, scale the outputs for some metrics. Defaults
    to False.
    :returns: Dictionaries noneiv_metrics, eiv_metrics
    """
    noneiv_metrics = {}
    eiv_metrics = {}
    # the x-dependent bias is only computed for repeated_sampling datasets,
    # for everything else there is no need to load any data
    if not isinstance(load_data, repeated_sampling):
        return noneiv_metrics, eiv_metrics

    noneiv_residual_collection = []
    eiv_residual_collection = []
    # sum of the used true_x of the previous seed, to check that all seeds
    # loop over the same true_x
    previous_true_x_sum = None
    for seed in seed_range:
        # load data according to seed
        try:
            train_data, test_data, _, true_test_data \
                    = load_data(seed=seed, return_ground_truth=True)
        except TypeError:
            train_data, test_data = load_data(seed=seed)
            true_test_data = None
        # only if there is a ground truth
        if true_test_data is None:
            continue

        batch_size = int(np.min((len(test_data), test_batch_size)))
        if scale_outputs:
            residual_scale = \
                    train_data.dataset.std_labels.to(device).view((1,-1))
        else:
            residual_scale = None

        # non-EiV
        init_std_y = noneiv_conf_dict["init_std_y_list"][0]
        unscaled_reg = noneiv_conf_dict["unscaled_reg"]
        p = noneiv_conf_dict["p"]
        hidden_layers = noneiv_conf_dict["hidden_layers"]
        saved_file = os.path.join('saved_networks',
                    f'noneiv_{short_dataname}'\
                            f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'\
                            f'_p_{p:.2f}_seed_{seed}.pkl')
        net = Networks.FNNBer(p=p, init_std_y=init_std_y,
                h=[input_dim, *hidden_layers, output_dim]).to(device)
        # load network
        train_and_store.open_stored_training(saved_file=saved_file,
                net=net, device=device)
        noneiv_residuals, noneiv_true_x_sum = _collect_true_residuals(
                net, true_test_data, batch_size, test_samples,
                noneiv_number_of_draws, residual_scale, device,
                has_input_noise=False)
        if previous_true_x_sum is not None:
            # check that the used true x are the same for each
            # seed, by comparing their sum
            assert noneiv_true_x_sum == previous_true_x_sum
        noneiv_residual_collection.append(noneiv_residuals)

        # EiV
        init_std_y = eiv_conf_dict["init_std_y_list"][0]
        unscaled_reg = eiv_conf_dict["unscaled_reg"]
        p = eiv_conf_dict["p"]
        hidden_layers = eiv_conf_dict["hidden_layers"]
        fixed_std_x = eiv_conf_dict["fixed_std_x"]
        saved_file = os.path.join('saved_networks',
                f'eiv_{short_dataname}'\
                        f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'\
                        f'_p_{p:.2f}_fixed_std_x_{fixed_std_x:.3f}'\
                        f'_seed_{seed}.pkl')
        net = Networks.FNNEIV(p=p, init_std_y=init_std_y,
                h=[input_dim, *hidden_layers, output_dim],
                fixed_std_x=fixed_std_x).to(device)
        # load network
        train_and_store.open_stored_training(saved_file=saved_file,
                net=net, device=device)
        eiv_residuals, eiv_true_x_sum = _collect_true_residuals(
                net, true_test_data, batch_size, test_samples,
                eiv_number_of_draws, residual_scale, device,
                has_input_noise=True)
        # check whether EiV and non-EiV used the same true_x for each
        # seed by comparing their sum
        assert eiv_true_x_sum == noneiv_true_x_sum
        eiv_residual_collection.append(eiv_residuals)
        previous_true_x_sum = noneiv_true_x_sum

    ## Store quantities

    # Compute and store (averaged) x-dependent bias
    if len(noneiv_residual_collection) > 0 and\
            len(eiv_residual_collection) > 0:
        for metrics, residual_collection in (
                (noneiv_metrics, noneiv_residual_collection),
                (eiv_metrics, eiv_residual_collection)):
            # the seeds are stacked along a new last dimension and then
            # averaged out, this gives the bias for each x
            stacked_residuals = torch.stack(tuple(residual_collection),
                    dim=-1)
            bias_per_x = torch.mean(stacked_residuals, dim=-1)
            metrics['avg_bias'] = torch.mean(torch.abs(bias_per_x))
    return noneiv_metrics, eiv_metrics

# single seed metrics
# For each metric, the values of all seeds, test epochs and batches are
# gathered in one list per model.
noneiv_metrics_collection = {}
eiv_metrics_collection = {}
collection_keys = []
num_test_epochs = 10
assert noneiv_conf_dict["seed_range"] == eiv_conf_dict["seed_range"]
seed_list = range(noneiv_conf_dict["seed_range"][0],
        noneiv_conf_dict["seed_range"][1])

# batches with an index larger than max_batch_number are skipped in each
# epoch, i.e. at most max_batch_number + 1 batches are used per epoch
max_batch_number = 2
for seed in tqdm(seed_list):
    try:
        train_data, test_data, true_train_data, true_test_data \
                = load_data(seed=seed, return_ground_truth=True)
    except TypeError:
        train_data, test_data = load_data(seed=seed)
        true_train_data, true_test_data = None, None
    # use the data with ground truth, if existent
    if true_test_data is None:
        test_dataloader = DataLoader(test_data,
            batch_size=min(len(test_data), 800), shuffle=True)
    else:
        test_dataloader = DataLoader(true_test_data,
            batch_size=min(len(true_test_data), 800), shuffle=True)
    for i in tqdm(range(num_test_epochs)):
        for j, x_y_pairs in enumerate(test_dataloader):
            if j > max_batch_number:
                break
            # fill in ground truth with None, if not existent
            if true_test_data is None:
                x_y_pairs = (None, None, *x_y_pairs)
            # should contain (true_x,true_y,x,y) or (None,None,x,y)
            assert len(x_y_pairs) == 4
            noneiv_metrics, eiv_metrics = collect_metrics(x_y_pairs,
                    seed=seed)
            if not collection_keys:
                # Fill collection keys. This must only happen once, for the
                # very first batch: doing it again for a later seed would
                # throw away the results of all previous seeds.
                assert eiv_metrics.keys() == noneiv_metrics.keys()
                collection_keys = list(eiv_metrics.keys())
                for key in collection_keys:
                    noneiv_metrics_collection[key] = []
                    eiv_metrics_collection[key] = []
            # collect results
            for key in collection_keys:
                noneiv_metrics_collection[key].append(noneiv_metrics[key])
                eiv_metrics_collection[key].append(eiv_metrics[key])

# Metrics that are computed from all seeds at once
print('Computing metrics that use all seeds at once...')
noneiv_full_seed_range_metrics, eiv_full_seed_range_metrics = (
    collect_full_seed_range_metrics(load_data=load_data,
                                    seed_range=seed_list)
)
# Both models must provide the same metrics; their names are appended to the
# keys of the single seed metrics
assert (noneiv_full_seed_range_metrics.keys()
        == eiv_full_seed_range_metrics.keys())
full_seed_range_collection_keys = list(noneiv_full_seed_range_metrics)
collection_keys.extend(full_seed_range_collection_keys)


# Summarize the metrics of both models: print them and gather them in
# results_dict, with one sub-dictionary per model
results_dict = {}
model_summaries = (
    ('noneiv', 'Non-EiV:\n-----',
        noneiv_metrics_collection, noneiv_full_seed_range_metrics),
    ('eiv', 'EiV:\n-----',
        eiv_metrics_collection, eiv_full_seed_range_metrics),
)
for position, (model_name, headline, per_seed_collection,
        all_seeds_metrics) in enumerate(model_summaries):
    if position > 0:
        # separates the output of the models
        print('\n')
    print(headline)
    model_results = {}
    results_dict[model_name] = model_results
    for key in collection_keys:
        if key in full_seed_range_collection_keys:
            # full seed range metrics (without a std)
            metric = float(all_seeds_metrics[key])
            model_results[key] = metric
            print(f'{key}: {metric:.5f} (NaN)')
        else:
            # per seed metrics: mean and standard error
            # NOTE(review): the denominator assumes one collected value per
            # test epoch and seed - confirm against the number of batches
            # that are actually collected per epoch
            collected_values = per_seed_collection[key]
            metric_mean = float(np.mean(collected_values))
            metric_std = float(np.std(collected_values)
                    / np.sqrt(num_test_epochs*len(seed_list)))
            model_results[key] = (metric_mean, metric_std)
            print(f'{key}: {metric_mean:.5f} ({metric_std:.5f})')

# write results to a JSON file in the results folder; the folder is created
# if it is missing, so that the computed metrics are not lost at the very end
os.makedirs('results', exist_ok=True)
with open(os.path.join('results',f'metrics_{short_dataname}.json'), 'w') as f:
    json.dump(results_dict, f)