"""
Computes the coverage of observations by predictions and uncertainty and
similar quantities. This module contains three functions
- epistemic_coverage: Computes the coverage of observations by predictions and
  their epistemic uncertainty and the averaged theoretical ground truth. If
  `normalized` is set to True, residuals are normalized and a fixed interval
  length is considered.
- normalized_std: Computes the standard deviation of the normalized residuals.
"""
import numpy as np
import scipy.stats
import torch

def logical_and_along_dimension(x, dim=-1, keepdim=False):
    """
    For a boolean tensor `x`, perform a logical AND along the axis `dim`.
    If `keepdim` is True, the axis `dim` will be kept (defaults to False).
    """
    return torch.prod(x.to(torch.bool), dim=dim, keepdim=keepdim).to(torch.bool)
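
# A small illustration of the expected behavior (values shown only for
# illustration, not part of the original module):
#   x = torch.tensor([[True, False], [True, True]])
#   logical_and_along_dimension(x)   # -> tensor([False, True])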


def multivariate_interval_length(dim, q=0.95):
    """
    Returns the half side length of a multivariate cube, symmetrically
    centered around 0, such that its measure under a standard normal
    distribution equals `q`.
    :param dim: A positive integer, the dimension.
    :param q: Float, should be between 0 and 1. Defaults to 0.95.
    """
    # use independence of components to reduce to a univariate quantile
    univariate_quantile = 0.5 * (1+q**(1/dim))
    return scipy.stats.norm.ppf(univariate_quantile)
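
# The formula above uses the independence of the coordinates of a standard
# normal vector: P(|X_i| <= a for all i) = (2*Phi(a) - 1)**dim, so requiring
# this probability to equal q gives Phi(a) = (1 + q**(1/dim)) / 2.
# A quick sanity check (values shown only for illustration):
#   multivariate_interval_length(dim=1, q=0.95)    # -> approx. 1.96
#   a = multivariate_interval_length(dim=3, q=0.9)
#   (2 * scipy.stats.norm.cdf(a) - 1)**3           # -> approx. 0.9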


def epistemic_coverage(not_averaged_predictions, y, q=0.95,
        normalize_errors=False, noisy_y=True):
    """
    Returns the average coverage of `y` by the interval
    "predictions + prefactor * q-interval", where
    - "q-interval" is the interval of measure `q` under the standard normal,
    - "predictions" are the entries of the first component of the tuple
      `not_averaged_predictions`, averaged over their second dimension, and
    - "prefactor" either equals the epistemic uncertainty, computed from the
      first component of `not_averaged_predictions`, if `normalize_errors` is
      set to False, or 1 if it is set to True.
    The coverage is returned as the numerical coverage given by `y` and as a
    theoretical coverage computed from the epistemic uncertainty and the
    aleatoric uncertainty (the second component of `not_averaged_predictions`).
    **Note**: If `noisy_y` is True (the default), `y` will be treated as noisy
    and the total uncertainty is used. If it is False, `y` will be treated as
    the unnoisy ground truth. For real data, only use the former.
    :param not_averaged_predictions: A tuple of tensors as in the output of
    `FNNEIV.predict` with `take_average_of_prediction` set to `False`, i.e.:
    the predictions of the neural net, not averaged over their second
    dimension (the repetition dimension in `FNNEIV.predict`), and
    the aleatoric uncertainty with a batch dimension and a feature dimension.
    :param y: A `torch.tensor` of the same shape as the second component
    of `not_averaged_predictions`. If the feature dimension is missing, it is
    added.
    :param q: A float between 0 and 1. Defaults to 0.95.
    :param normalize_errors: If True, the deviations between predictions and
    `y` are normalized by the total uncertainty, computed from the aleatoric
    and epistemic uncertainty, and the coverage w.r.t. the q-interval is
    computed.
    :param noisy_y: Boolean. If True (the default), `y` is treated as noisy and
    the total uncertainty is considered. If False, `y` is treated as the
    unnoisy ground truth.
    :returns: numerical_coverage, theoretical_coverage
    """
    out, sigmas = not_averaged_predictions
    # add an output axis if necessary
    if len(y.shape) <= 1:
        y = y[...,None]
    if len(sigmas.shape) <= 1:
        sigmas = sigmas[...,None]
    # flatten the trailing dimensions into one
    y = y.view((y.shape[0], -1))
    sigmas = sigmas.view((sigmas.shape[0], -1))
    out = out.view((*out.shape[:2], -1))
    # compute epistemic uncertainty and average over the repetition dimension
    epis_unc = torch.std(out, dim=1)
    out = torch.mean(out, dim=1)
    # check that dimensions are consistent
    assert y.shape == sigmas.shape
    assert y.shape == out.shape
    assert epis_unc.shape == sigmas.shape
    # compute total uncertainty
    if noisy_y:
        total_unc = torch.sqrt(epis_unc**2 + sigmas **2)
    else:
        # for unnoisy y, the aleatoric uncertainty is treated as 0
        total_unc = epis_unc
    # set the interval half length (scaled by epis_unc unless errors are normalized)
    out_dim = y.shape[1]
    if not normalize_errors:
        interval_length = multivariate_interval_length(dim=out_dim, q=q) \
                * epis_unc
    else:
        interval_length = multivariate_interval_length(dim=out_dim, q=q)
    # numerical computation
    errors = out - y
    if normalize_errors:
        assert errors.shape == total_unc.shape
        errors /= total_unc
    check_if_in_interval = logical_and_along_dimension(
            torch.abs(errors) <= interval_length, dim=1)
    numerical_coverage = torch.mean(
            check_if_in_interval.to(torch.float32)
            ).cpu().detach().item()
    # theoretical computation
    if not normalize_errors:
        cdf_args = (interval_length/total_unc).detach().cpu().numpy()
        cdf_values = scipy.stats.norm.cdf(cdf_args)
        prob_values = 2*cdf_values -1
        assert len(cdf_values.shape) == 2
        # take product over feature dimension 
        # and average over batch dimension
        theoretical_coverage = np.mean(np.prod(prob_values, axis=1)).item()
    else:
        theoretical_coverage = q
    return numerical_coverage, theoretical_coverage
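
# A minimal usage sketch (the shapes below are assumptions based on how this
# function reshapes its inputs; they are not taken from `FNNEIV.predict`):
#   out = torch.randn(100, 5, 2)      # (batch, repetition, feature)
#   sigmas = torch.ones(100, 2)       # aleatoric uncertainty
#   y = torch.randn(100, 2)           # observations
#   epistemic_coverage((out, sigmas), y)                          # raw errors
#   epistemic_coverage((out, sigmas), y, normalize_errors=True)   # normalized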

def normalized_std(not_averaged_predictions, y):
    """
    Returns the standard deviation of the normalized residuals, averaged over
    the feature dimension. In theory this number should be equal to 1.0.
    :param not_averaged_predictions: A tuple of tensors as in the output of
    `FNNEIV.predict` with `take_average_of_prediction` set to `False`, i.e.:
    the predictions of the neural net, not averaged over the repetition
    dimension, and the aleatoric uncertainty with a batch dimension and a
    feature dimension.
    :param y: A `torch.tensor` of the same shape as the second component
    of `not_averaged_predictions`. If the feature dimension is missing, it is
    added.
    :returns: The averaged standard deviation of the normalized residuals, as
    a float.
    """
    out = not_averaged_predictions[0]
    sigmas = not_averaged_predictions[1]
    # add repetition dimension
    y = y[:,None,...]
    sigmas = sigmas[:,None,...]
    # add an output axis if necessary
    if len(y.shape) <= 2:
        y = y[...,None]
    if len(sigmas.shape) <= 2:
        sigmas = sigmas[...,None]
    # flatten the trailing dimensions into one
    y = y.view((*y.shape[:2], -1))
    sigmas = sigmas.view((*sigmas.shape[:2], -1))
    out = out.view((*out.shape[:2], -1))
    # check if dimensions consistent
    assert y.shape == sigmas.shape
    assert y.shape[0] == out.shape[0]
    assert y.shape[2] == out.shape[2]
    # compute epistemic uncertainty
    epis_unc = torch.std(out, dim=1, keepdim=True)
    assert epis_unc.shape == sigmas.shape
    total_unc = torch.sqrt(epis_unc**2 + sigmas **2)
    # numerical computation
    errors = out - y
    errors /= total_unc
    # average over feature dimension
    return torch.mean(torch.std(errors, dim=(0,1))).item()
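

if __name__ == "__main__":
    # A minimal, self-contained demonstration with synthetic tensors. The
    # shapes below are assumptions based on how the functions above reshape
    # their inputs; they are not taken from `FNNEIV.predict`.
    torch.manual_seed(0)
    batch_size, repetitions, features = 200, 10, 2
    # synthetic not-averaged predictions, a constant aleatoric uncertainty
    # and synthetic observations
    predictions = torch.randn(batch_size, repetitions, features)
    aleatoric_unc = 0.5 * torch.ones(batch_size, features)
    observations = torch.randn(batch_size, features)
    numerical, theoretical = epistemic_coverage(
            (predictions, aleatoric_unc), observations, q=0.95)
    print(f"numerical coverage: {numerical:.3f}, "
          f"theoretical coverage: {theoretical:.3f}")
    norm_std = normalized_std((predictions, aleatoric_unc), observations)
    print(f"averaged std of normalized residuals: {norm_std:.3f}")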