"""
Compute metrics for datasets for which there is not necessarily a ground truth.
Results will be stored in the results folder
"""
import importlib
import os
import argparse
import json

import numpy as np
import torch
import torch.backends.cudnn
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

from EIVArchitectures import Networks
from EIVTrainingRoutines import train_and_store
from EIVGeneral.coverage_plotting import get_coverage_distribution
from EIVGeneral.manipulate_datasets import VerticalCut


# Command-line interface: the dataset to evaluate is chosen via --data.
parser = argparse.ArgumentParser()
parser.add_argument("--data", help="Loads data", default='naval')
# Dummy flag so the script can be run from IPython without argument clashes.
parser.add_argument("--no-autoindent", help="",
        action="store_true")
args = parser.parse_args()
data = args.data

# Read the hyperparameter configurations (EiV and non-EiV) for this dataset
# from their JSON files.
with open(os.path.join('configurations', f'eiv_{data}.json'), 'r') as f:
    eiv_conf_dict = json.load(f)
with open(os.path.join('configurations', f'noneiv_{data}.json'), 'r') as f:
    noneiv_conf_dict = json.load(f)

long_dataname = eiv_conf_dict["long_dataname"]
short_dataname = eiv_conf_dict["short_dataname"]

print(f"Plotting coverage for {long_dataname}")

scale_outputs = False
# Resolve the dataset-specific loader function dynamically by module name.
load_data = importlib.import_module(f'EIVData.{long_dataname}').load_data

# Pick a compute device: use the GPU index from the config if present and
# usable, otherwise fall back to GPU 0 (if any GPU works) or the CPU.
if "gpu_number" in eiv_conf_dict:
    gpu_number = eiv_conf_dict["gpu_number"]
    device = torch.device(f'cuda:{gpu_number}')
    try:
        # Smoke test: moving a tiny tensor fails if the index is invalid.
        torch.tensor([0.0]).to(device)
    except RuntimeError:
        if torch.cuda.is_available():
            print('Switched to GPU 0')
            device = torch.device('cuda:0')
        else:
            print('No cuda available, using CPU')
            device = torch.device('cpu')
else:
    # No "gpu_number" key in the config: run on the CPU.
    device = torch.device('cpu')


# test whether there is a ground truth
# Probe the loader: datasets with a ground truth accept
# return_ground_truth=True and return four splits; the others raise a
# TypeError for the unexpected keyword, leaving only the noisy pair.
try:
    train_data, test_data, true_train_data, true_test_data \
            = load_data(seed=0, return_ground_truth=True)
    ground_truth_exists = True
except TypeError:
    train_data, test_data = load_data(seed=0)
    true_train_data, true_test_data = None, None
    ground_truth_exists = False

# NOTE(review): this reload discards the seed=0 splits loaded above and uses
# the loader's default seed — presumably only the input/output dimensions
# are needed from it; confirm the overwrite is intentional.
train_data, test_data = load_data()
input_dim = train_data[0][0].numel()
output_dim = train_data[0][1].numel()

## Create iterators
# One seed per trained repetition, range taken from the non-EiV config.
seed_list = range(noneiv_conf_dict["seed_range"][0],
        noneiv_conf_dict["seed_range"][1])

# networks
def net_iterator(eiv=True, seed_list=seed_list):
    """Yield one network per seed with that seed's stored weights loaded.

    A single network object is created up front and reused between yields;
    only its parameters are replaced by `open_stored_training`. With
    `eiv=True` an FNNEIV is built from the EiV config, otherwise an FNNBer
    from the non-EiV config.
    """
    if eiv:
        conf = eiv_conf_dict
        init_std_y = conf["init_std_y_list"][0]
        unscaled_reg = conf["unscaled_reg"]
        p, hidden_layers = conf["p"], conf["hidden_layers"]
        fixed_std_x = conf["fixed_std_x"]
        net = Networks.FNNEIV(p=p, init_std_y=init_std_y,
                h=[input_dim, *hidden_layers, output_dim],
                fixed_std_x=fixed_std_x).to(device)
        for seed in seed_list:
            # Filename encodes every hyperparameter of the stored run.
            stored_weights = os.path.join(
                    'saved_networks',
                    f'eiv_{short_dataname}'
                    f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'
                    f'_p_{p:.2f}_fixed_std_x_{fixed_std_x:.3f}'
                    f'_seed_{seed}.pkl')
            train_and_store.open_stored_training(saved_file=stored_weights,
                    net=net, device=device)
            yield net
    else:
        conf = noneiv_conf_dict
        init_std_y = conf["init_std_y_list"][0]
        unscaled_reg = conf["unscaled_reg"]
        p, hidden_layers = conf["p"], conf["hidden_layers"]
        net = Networks.FNNBer(p=p, init_std_y=init_std_y,
                h=[input_dim, *hidden_layers, output_dim]).to(device)
        for seed in seed_list:
            stored_weights = os.path.join(
                    'saved_networks',
                    f'noneiv_{short_dataname}'
                    f'_init_std_y_{init_std_y:.3f}_ureg_{unscaled_reg:.1f}'
                    f'_p_{p:.2f}_seed_{seed}.pkl')
            train_and_store.open_stored_training(saved_file=stored_weights,
                    net=net, device=device)
            yield net

# dataloaders
def dataloader_iterator(seed_list=seed_list, use_ground_truth=False,
        batch_size = 1000):
    """Yield one shuffled test DataLoader per seed.

    With `use_ground_truth=True` (requires a dataset with a ground truth)
    the loader serves the noisy inputs paired with the noise-free targets;
    otherwise it serves the ordinary noisy test split.
    """
    for seed in seed_list:
        if use_ground_truth:
            assert ground_truth_exists
            _, _, _, true_test = load_data(seed=seed,
                    return_ground_truth=True)
            # take noisy x but unnoisy y
            dataset = VerticalCut(true_test, components_to_pick=[2, 1])
        else:
            # Second element of the returned pair is the test split.
            dataset = load_data(seed=seed)[1]
        yield DataLoader(dataset,
                batch_size=batch_size,
                shuffle=True)




# EiV: per-seed numerical vs. theoretical coverage, then mean/std over seeds.
eiv_numerical_coverage, eiv_theoretical_coverage = get_coverage_distribution(
        net_iterator=net_iterator(eiv=True),
        dataloader_iterator=dataloader_iterator(),
        device=device,
        number_of_draws=[100,5])
mean_eiv_theoretical_coverage = np.mean(eiv_theoretical_coverage, axis=1)
std_eiv_theoretical_coverage = np.std(eiv_theoretical_coverage, axis=1)
mean_eiv_numerical_coverage = np.mean(eiv_numerical_coverage, axis=1)
std_eiv_numerical_coverage = np.std(eiv_numerical_coverage, axis=1)
# non-EiV: same statistics, single draw parameter.
noneiv_numerical_coverage, noneiv_theoretical_coverage = get_coverage_distribution(
        net_iterator=net_iterator(eiv=False),
        dataloader_iterator=dataloader_iterator(),
        device=device,
        number_of_draws=100)
mean_noneiv_theoretical_coverage = np.mean(noneiv_theoretical_coverage, axis=1)
std_noneiv_theoretical_coverage = np.std(noneiv_theoretical_coverage, axis=1)
mean_noneiv_numerical_coverage = np.mean(noneiv_numerical_coverage, axis=1)
std_noneiv_numerical_coverage = np.std(noneiv_numerical_coverage, axis=1)
# Mean coverage curves with +/- one standard deviation bands over seeds.
plt.plot(mean_eiv_theoretical_coverage, mean_eiv_numerical_coverage, color='r', label='EiV')
plt.fill_between(mean_eiv_theoretical_coverage, mean_eiv_numerical_coverage
        - std_eiv_numerical_coverage,
        mean_eiv_numerical_coverage + std_eiv_numerical_coverage, color='r', alpha=0.5)
plt.plot(mean_noneiv_theoretical_coverage, mean_noneiv_numerical_coverage, color='b', label='nonEiV')
plt.fill_between(mean_noneiv_theoretical_coverage, mean_noneiv_numerical_coverage
        - std_noneiv_numerical_coverage,
        mean_noneiv_numerical_coverage + std_noneiv_numerical_coverage, color='b', alpha=0.5)
# Diagonal reference line (ideal calibration: numerical == theoretical).
# NOTE(review): the upper bound uses the *numerical* coverage although the
# x-axis is *theoretical* coverage -- confirm this is intended.
diag_x = np.linspace(0, np.max(mean_eiv_numerical_coverage))
plt.plot(diag_x, diag_x, 'k--')
# FIX: labels were set on the plots above but plt.legend() was never
# called, so no legend was rendered.
plt.legend()
plt.xlabel('theoretical coverage')
plt.ylabel('numerical coverage')
plt.show()