"""
Computes the coverage of observations by predictions and uncertainty and
similar quantities. The two main functions of this module are
- epistemic_coverage: Computes the coverage of observations by predictions and
  their epistemic uncertainty and the averaged theoretical ground truth. If
  `normalize_errors` is set to True, residuals are normalized and a fixed
  interval length is considered.
- normalized_std: Computes the standard deviation of the normalized residuals.
"""
import numpy as np
import scipy.stats
import torch

def logical_and_along_dimension(x, dim=-1, keepdim=False):
    """
    Reduce the tensor `x` with a logical AND along the axis `dim`.

    `x` is cast to boolean first, so every non-zero entry counts as True.
    :param x: A tensor, interpreted as boolean.
    :param dim: The axis that is reduced. Defaults to the last axis.
    :param keepdim: If True the reduced axis is kept with length 1
    (defaults to False).
    :returns: A boolean tensor that is True wherever all entries along `dim`
    are True.
    """
    as_bool = x.to(torch.bool)
    return torch.all(as_bool, dim=dim, keepdim=keepdim)


def multivariate_interval_length(dim, q=0.95):
    """
    Returns the half side length of multivariate cube, symmetrically centered
    around 0 such that its measure under a standard normal distribution equals
    `q`.
    :param dim: A positive integer, the dimension.
    :param q: Float, should be between 0 and 1. Defaults to 0.95.
    :returns: The half side length, as returned by `scipy.stats.norm.ppf`.
    """
    # The components are independent, so the cube has measure `q` iff each
    # univariate interval [-a, a] has measure q**(1/dim). The corresponding
    # upper quantile level is 0.5 * (1 + q**(1/dim)).
    univariate_quantile = 0.5 * (1 + q**(1/dim))
    return scipy.stats.norm.ppf(univariate_quantile)


def epistemic_coverage(prediction_triple,  y, q=0.95, normalize_errors=False):
    """
    Returns the average coverage of `y` by a box centered at the predictions
    together with the coverage that is expected in theory.

    Let "quantile" denote the half side length of the cube of measure `q`
    under a standard normal distribution. A sample counts as covered if, for
    every feature,
    - `|predictions - y| <= quantile * epistemic uncertainty`, if
      `normalize_errors` is False, or
    - `|predictions - y| / total uncertainty <= quantile`, if
      `normalize_errors` is True,
    where the total uncertainty is the root of the sum of the squared
    epistemic and aleatoric uncertainties.
    :param prediction_triple: A triple of tensors containing (in this order): the
    predictions of the neural net (the average under the posterior), the
    epistemic uncertainty (the standard deviation under the posterior) and
    the aleatoric uncertainty. All tensors are expected to have two dimensions:
    a batch and a feature dimension.
    :param y: A `torch.tensor` of the same shape as the components of
    `prediction_triple`. If the feature dimension is missing, it is added.
    :param q: A float between 0 and 1. Defaults to 0.95.
    :param normalize_errors: If True, the deviations between predictions and
    `y` are normalized by the total uncertainty and the coverage w.r.t. the
    fixed q-interval is computed.
    :returns: numerical_coverage, theoretical_coverage. Both are Python
    floats. The first is the fraction of covered samples, the second is the
    coverage expected for normally distributed errors whose standard deviation
    is the total uncertainty (this equals `q` if `normalize_errors` is True).
    """
    mean, epis_unc, aleat_unc = prediction_triple
    assert epis_unc.shape == aleat_unc.shape
    assert mean.shape == epis_unc.shape
    # Add feature dimension to y if missing
    if len(y.shape) <= 1:
        y = y.view((-1,1))
    assert y.shape == mean.shape
    # Cast to a plain float so that arithmetic and comparisons with tensors
    # do not depend on how numpy scalars and tensors interact.
    quantile = float(multivariate_interval_length(dim=y.shape[1], q=q))
    total_unc = torch.sqrt(epis_unc**2 + aleat_unc**2)
    errors = mean - y
    if normalize_errors:
        assert errors.shape == total_unc.shape
        # Normalized residuals are compared with the fixed q-interval
        # (prefactor 1); scaling it by epis_unc would contradict the
        # theoretical coverage `q` returned for this case.
        errors = errors / total_unc
        interval_length = quantile
    else:
        # interval scaled by the epistemic uncertainty only
        interval_length = quantile * epis_unc
    # numerical computation: a sample is covered if all features are covered
    check_if_in_interval = logical_and_along_dimension(
            torch.abs(errors) <= interval_length, dim=1)
    numerical_coverage = torch.mean(
            check_if_in_interval.to(torch.float32)
            ).cpu().detach().item()
    # theoretical computation
    if normalize_errors:
        theoretical_coverage = q
    else:
        # Per feature the probability of |N(0, total_unc^2)| <= interval is
        # 2 * Phi(interval / total_unc) - 1; features are multiplied and the
        # result is averaged over the batch.
        cdf_args = (interval_length/total_unc).detach().cpu().numpy()
        cdf_values = scipy.stats.norm.cdf(cdf_args)
        prob_values = 2*cdf_values - 1
        assert len(cdf_values.shape) == 2
        theoretical_coverage = np.mean(np.prod(prob_values, axis=1)).item()
    return numerical_coverage, theoretical_coverage

def normalized_std(prediction_triple, y):
    """
    Returns the standard deviation of the residuals after dividing them by
    the total uncertainty. In theory this number should be equal to 1.0.

    The standard deviation is taken over the batch dimension for each
    feature and then averaged over the features.
    :param prediction_triple: A triple of tensors containing (in this order): the
    predictions of the neural net (the average under the posterior), the
    epistemic uncertainty (the standard deviation under the posterior) and
    the aleatoric uncertainty.
    :param y: A `torch.tensor` of the same shape as the components of
    `prediction_triple`. If the feature dimension is missing, it is added.
    :returns: A float, the feature-averaged standard deviation of the
    normalized residuals.
    """
    prediction, epistemic, aleatoric = prediction_triple
    assert epistemic.shape == aleatoric.shape
    assert prediction.shape == epistemic.shape
    # y without a feature axis is turned into a single column
    if y.dim() <= 1:
        y = y.view(-1, 1)
    assert y.shape == prediction.shape
    # combine both uncertainties into the total standard deviation
    combined = (epistemic.pow(2) + aleatoric.pow(2)).sqrt()
    residuals = prediction - y
    assert residuals.shape == combined.shape
    residuals.div_(combined)
    per_feature_std = torch.std(residuals, dim=0)
    return per_feature_std.mean().item()