from math import pi
import numpy as np
import torch
import torch.nn as nn
from EIVGeneral import repetition


class EIVDropout(nn.Module):
    """
    A Dropout Layer with Dropout probability `p` (default 0.5) that repeats the
    same Bernoulli mask L times along the batch dimension - instead of taking a
    different one for each batch member - where L is the output of
    repetition_map(). When evaluation `forward(x)` the batch dimension of `x`
    (the first one) is asserted to be a multiple of `L=repetition_map()`. The
    `repetition_map` defaults to the constant map to 1, in which case
    `EIVDropout` is equivalent to `torch.nn.Dropout`.  
    :param p: A float between 0 and 1. Defaults to 0.5.  
    :param repetition_map: Map that takes no
    argument and returns an integer. Defaults to the constant map to 1. 
    """
    def __init__(self, p=0.5, repetition_map=lambda: 1):
        super().__init__()
        self.p = p
        self.repetition_map = repetition_map
        self._train = True

    def train(self, training=True):
        # Overrides nn.Module.train; the training mode is tracked in
        # self._train rather than in self.training.
        self._train = bool(training)

    def eval(self):
        self.train(training=False)

    def forward(self, x):
        if not self._train:
            return x
        else: 
            device = self.study_device(x)
            L = self.repetition_map()
            input_shape = x.shape 
            assert input_shape[0] % L == 0
            mask_shape = list(input_shape)
            mask_shape[0] = int(input_shape[0] / L)
            mask = torch.bernoulli(torch.ones(mask_shape) * (1-self.p))\
                    / (1-self.p)
            repeated_mask = repetition.repeat_tensors(mask,
                    number_of_draws=L)[0]
            assert x.shape == repeated_mask.shape
            return x * repeated_mask.to(device)
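
    # Note on `forward` above (added comment, not in the original): the
    # Bernoulli mask is divided by (1 - p) ("inverted dropout"), so that it
    # has expectation 1 and no rescaling is needed at evaluation time, where
    # `forward` simply returns its input unchanged.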

    @staticmethod
    def study_device(x):
        if x.is_cuda:
            return torch.device('cuda:' + str(x.get_device()))
        else:
            return torch.device('cpu')
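
# A minimal usage sketch for EIVDropout (added; not part of the original
# module). It assumes that `repetition.repeat_tensors` repeats its arguments
# L times along the batch dimension, as relied upon in `forward` above; the
# exact ordering of the repeated blocks is determined by that helper.
#
#     dropout = EIVDropout(p=0.5, repetition_map=lambda: 3)
#     x = torch.ones(6, 4)       # batch size 6 is a multiple of L = 3
#     y = dropout(x)             # only 2 distinct mask rows are drawn, each
#                                # reused for 3 of the 6 batch members
#     dropout.eval()
#     assert torch.equal(dropout(x), x)   # identity at evaluation time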
        
class EIVVariationalDropout(nn.Module):
    """
    A Variational Dropout Layer (of Type A, as in Kingma et al.) with Dropout
    probability `p`
    (default 0.5) that repeats the same Bernoulli mask L times along the batch
    dimension - instead of taking a
    different one for each batch member - where L is the output of
    repetition_map(). When evaluation `forward(x)` the batch dimension of `x`
    (the first one) is asserted to be a multiple of `L=repetition_map()`. The
    `repetition_map` defaults to the constant map to 1, in which case
    `EIVDropout` is equivalent to classical variational Dropout.
    :param initial_alpha: A positive float, will be taken as initial value for
    alpha (the variance of the Gaussian dropout mask).
    :param repetition_map: Map that takes no
    argument and returns an integer. Defaults to the constant map to 1. 
    """
    c_1 = 1.16145124
    c_2 = -1.50204118
    c_3 = 0.58629921
    def __init__(self, initial_alpha=0.5, repetition_map=lambda: 1):
        super().__init__()
        self.initial_alpha = initial_alpha
        self.repetition_map = repetition_map
        initial_alpha_par = self.invert_softplus(self.initial_alpha)
        # cast to float so that alpha_par uses the default (float32) dtype
        # rather than the float64 returned by invert_softplus
        self.alpha_par = nn.Parameter(torch.tensor(float(initial_alpha_par)))
        self._train = True

    def train(self, training=True):
        # Overrides nn.Module.train; the training mode is tracked in
        # self._train rather than in self.training.
        self._train = bool(training)

    def eval(self):
        self.train(training=False)

    @staticmethod
    def invert_softplus(softplus_value):
        """
        Inverts the Softplus function
        :param softplus_value: A float
        """
        return np.log(np.exp(softplus_value)-1)

    def alpha(self):
        return nn.Softplus()(self.alpha_par)
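
    # Note on `alpha` above (added comment, not in the original): the Softplus
    # parametrization keeps alpha strictly positive while `alpha_par` itself
    # remains an unconstrained parameter; `invert_softplus` is used in
    # `__init__` so that `alpha()` initially equals `initial_alpha`.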

    def forward(self, x):
        if not self._train:
            return x
        else: 
            device = self.study_device(x)
            L = self.repetition_map()
            input_shape = x.shape 
            assert input_shape[0] % L == 0
            mask_shape = list(input_shape)
            mask_shape[0] = int(input_shape[0] / L)
            mask = 1.0 + torch.sqrt(self.alpha()) * \
                torch.randn(mask_shape).to(device)
            repeated_mask = repetition.repeat_tensors(mask,
                    number_of_draws=L)[0]
            assert x.shape == repeated_mask.shape
            return x * repeated_mask.to(device)

    @staticmethod
    def study_device(x):
        if x.is_cuda:
            return torch.device('cuda:' + str(x.get_device()))
        else:
            return torch.device('cpu')
    

    def variational_dropout_regularizer(self):
        """
        Taken from Kingma et al. Identical, up to a constant, to the neg.
        KL-div of the variational distribution and the improper prior.
        """
        alpha = self.alpha()
        return 0.5 * torch.log(alpha)\
                + self.c_1 * alpha + self.c_2 * alpha**2 + self.c_3 * alpha**3
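
# A minimal usage sketch for EIVVariationalDropout (added; not part of the
# original module). The layer multiplies its input by a Gaussian mask
# 1 + sqrt(alpha) * z with z ~ N(0, 1), repeated L times along the batch
# dimension just like in EIVDropout; alpha is learnable through `alpha_par`.
# In an ELBO-style objective, `variational_dropout_regularizer()` (a negative
# KL term) would typically enter the loss with a minus sign; the exact
# weighting depends on the surrounding training code. `data_loss` below is a
# hypothetical placeholder.
#
#     vdropout = EIVVariationalDropout(initial_alpha=0.5)
#     y = vdropout(torch.randn(8, 4))
#     loss = data_loss(y) - vdropout.variational_dropout_regularizer()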


class EIVInput(nn.Module):
    """
    Input class for an Error-in-Variables model based on variational inference.
    For regularization there is a method EIVInput.x_evidence included. The
    underlying model is x = zeta + epsilon, where zeta denotes the true (but
    unknown) input variable and epsilon denotes normal noise with standard
    deviation std_x.  
    :param init_std_x: The initial value of sigma_x (std of
    noise on input), will be made a learnable parameter. Defaults to 1.0
    :param precision_prior_zeta: The precision of the prior on zeta (the true,
    but unknown input). Defaults to 0 (= flat prior)
    :param external_std_x: If not None, should be a map that returns (without
    arguments) std_x (i.e. as the output of self.get_std_x). If None (the
    Default) will be parametrized via a Softmax of a torch.nn.Parameter.
    :param std_x_scale: Will be used for scaling std_x internally. Defaults
    to 1.0.
    """
    def __init__(self, init_std_x = 1.0, precision_prior_zeta=0.0,
            external_std_x=None, std_x_scale=1.0):
        super(EIVInput, self).__init__()
        self.init_std_x = init_std_x
        self.precision_prior_zeta = precision_prior_zeta
        self.external_std_x = external_std_x
        InverseSoftplus = lambda sigma: torch.log(torch.exp(sigma) - 1 )
        self.std_x_scale = std_x_scale
        if external_std_x is None:
            self.std_x_par = nn.Parameter(
                    InverseSoftplus(
                        torch.tensor([init_std_x]))/self.std_x_scale)
        self.noise = True
        # self.fixed_z will be used as "noise" if self.noise is False
        self.fixed_z = 0.0

    def get_std_x(self):
        """
        Returns the standard deviation of the distribution of x given zeta.
        """
        if self.external_std_x is None:
            return nn.Softplus()(self.std_x_par * self.std_x_scale)
        else:
            return self.external_std_x()


    def forward(self, x):
        std_x = self.get_std_x()
        if std_x.item() > 0:
            std_zeta = 1/(1/std_x**2 + self.precision_prior_zeta)**0.5
            mean_zeta = std_zeta**2/std_x**2 * x
        else:
            assert std_x.item() == 0
            std_zeta = 0.0
            mean_zeta = x
        if self.noise:
            z = torch.randn_like(x)
            return mean_zeta + std_zeta * z
        else: 
            return mean_zeta + std_zeta * self.fixed_z
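
    # Sketch of the computation in `forward` above (added comment, not in the
    # original): with x = zeta + epsilon, epsilon ~ N(0, std_x^2), and a
    # Gaussian prior on zeta with mean 0 and precision precision_prior_zeta,
    # the posterior of zeta given x is Gaussian with
    #     1 / std_zeta^2 = 1 / std_x^2 + precision_prior_zeta,
    #     mean_zeta      = (std_zeta^2 / std_x^2) * x,
    # from which `forward` samples via mean_zeta + std_zeta * z, z ~ N(0, 1).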

    def neg_x_evidence(self, x, remove_divergent_term=True):
        """
        Returns the value of x under the prior predictive of the EIV model
        :param x: A torch.tensor
        :param remove_divergent_term: If True (default) the constant term, that
        diverges for a float prior will be removed.
        """
        if not remove_divergent_term:
            if self.precision_prior_zeta == 0:
                raise ValueError('Divergent term infinite in EIV model')
            else:
                # np.log, since precision_prior_zeta is a plain Python float,
                # which torch.log does not accept
                divergent_term = 0.5 * np.log(self.precision_prior_zeta)
        else:
            divergent_term = 0
        batch_size = x.shape[0]
        regularization = 0.5/batch_size\
                * self.precision_prior_zeta/(
                        self.precision_prior_zeta * self.get_std_x()**2 + 1) \
                            * torch.sum(x**2)
        regularization += 0.5 * torch.log(2 * pi * 
                (self.get_std_x()**2 * self.precision_prior_zeta + 1) ) 
        regularization += divergent_term
        return  regularization
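

if __name__ == '__main__':
    # A minimal, hedged smoke test (added; not part of the original module).
    # It only exercises the layers defined above and makes no claim about how
    # they are combined in the full EIV architecture; like the module itself,
    # it relies on EIVGeneral.repetition being importable.
    torch.manual_seed(0)
    x = torch.randn(6, 3)
    eiv_input = EIVInput(init_std_x=1.0, precision_prior_zeta=0.0)
    dropout = EIVDropout(p=0.5, repetition_map=lambda: 3)
    vdropout = EIVVariationalDropout(initial_alpha=0.5)
    noisy_zeta = eiv_input(x.repeat(3, 1))   # three repeated copies of x
    out = dropout(noisy_zeta)                # each mask row is reused 3 times
    print('std_x:', eiv_input.get_std_x().item())
    print('alpha:', vdropout.alpha().item())
    print('output shape:', tuple(out.shape))
    print('neg. x evidence:', eiv_input.neg_x_evidence(x).item())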