"""
Implements fully connected neural networks, with and without
Errors-in-Variables (EiV) input. This module contains four classes
- FNNEIV: a fully connected neural network with EiV input
- FNNBer: a fully connected neural network without EiV input
- SmallFNNBer, ShallowFNNBer: smaller variants of FNNBer
FNNEIV and FNNBer have 4 hidden layers with Bernoulli dropout in between;
the two variants use fewer hidden layers.
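
A minimal usage sketch (the batch size 16 is arbitrary and the input
dimension 10 follows the default `h`; FNNBer is used analogously, but with
an integer `number_of_draws`):

    import torch
    net = FNNEIV()
    x = torch.randn(16, 10)
    mu, sigma = net(x)                                      # one stochastic pass
    pred, sigma = net.predict(x, number_of_draws=[100, 5])  # averaged prediction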
"""
import torch
import torch.nn as nn
from EIVArchitectures.Layers import EIVInput, EIVDropout 
from EIVGeneral.repetition import repeat_tensors, reshape_to_chunks

class FNNEIV(nn.Module):
    """
    A fully connected net with Errors-in-Variables (EiV) input and Bernoulli
    dropout layers.
    :param p: dropout rate, defaults to 0.2
    :param init_std_y: Initial estimated standard deviation for y.
    :param precision_prior_zeta: precision of the prior for zeta.
    Defaults to 0.0 (= improper prior).
    :param deming: Coupling factor between std_y and std_x (as in Deming
    regression), that is std_x = deming * std_y, unless `fixed_std_x` is
    different from `None`.
    :param h: A list specifying the number of neurons in each layer.
    :param fixed_std_x: If given, this value will be the output of the method
    `get_std_x()`, regardless of the deming factor.
    :param repetition: Positive integer, the default number of repetitions of
    the input, defaults to 1. For a single call this can also be specified in
    the `forward` method.
    :param std_y_requires_grad: Whether `std_y` requires gradients and is thus
    updated during optimization. Defaults to False.
    **Note**:
    - To change the deming factor afterwards, use the method `change_deming`
    - To change fixed_std_x afterwards, use the method `change_fixed_std_x`
    - To change std_y, use the method `change_std_y`
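    **Example** (a minimal sketch; the batch size 32 is arbitrary and the
    stated shapes follow the default `h`)::

        net = FNNEIV(p=0.2, init_std_y=1.0)
        x = torch.randn(32, 10)
        mu, sigma = net(x)        # mu: (32, 1), sigma: same shape as mu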
    """
    LeakyReLUSlope = 1e-2
    def __init__(self, p = 0.2, init_std_y=1.0, precision_prior_zeta=0.0, 
            deming=1.0, h=[10, 1024,1024,1024,1024, 1], 
            fixed_std_x = None, repetition = 1, std_y_requires_grad = False):
        super().__init__()
        # part before Bernoulli dropout
        self.init_std_y = init_std_y
        # inverse of nn.Softplus, used to parametrize std_y on the real line
        self.InverseSoftplus = lambda sigma: torch.log(torch.exp(sigma) - 1)
        self.std_y_par = nn.parameter.Parameter(
                self.InverseSoftplus(torch.tensor([init_std_y])))
        self.std_y_par.requires_grad = std_y_requires_grad
        self._repetition = repetition
        self.main = nn.Sequential(
                EIVInput(precision_prior_zeta=precision_prior_zeta, 
                    external_std_x=self.get_std_x),
                nn.Linear(h[0], h[1]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                EIVDropout(p=p, repetition_map=self._repetition_map),
                nn.Linear(h[1],h[2]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                EIVDropout(p=p, repetition_map=self._repetition_map),
                nn.Linear(h[2],h[3]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                EIVDropout(p=p, repetition_map=self._repetition_map),
                nn.Linear(h[3],h[4]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                EIVDropout(p=p, repetition_map=self._repetition_map),
                nn.Linear(h[4], h[5]))
        self.p = p
        self._deming = deming
        if fixed_std_x is not None:
            if not isinstance(fixed_std_x, torch.Tensor):
                fixed_std_x = torch.tensor(fixed_std_x)
        self._fixed_std_x = fixed_std_x
        # needed for switch_noise_off()
        self.noise_is_on = True


    def change_deming(self, deming):
        """
        Update deming factor to `deming`
        :param deming: A positive float
        """
        print('Updating deming from %.3f to %.3f' % (self._deming, deming))
        self._deming = deming

    def change_fixed_std_x(self, fixed_std_x):
        """
        Update internal fixed_std_x to `fixed_std_x`
        :param fixed_std_x: A positive float or None
        """
        print('Updating fixed_std_x from %s to %s' % (self._fixed_std_x,
            fixed_std_x))
        if fixed_std_x is not None:
            if not isinstance(fixed_std_x, torch.Tensor):
                fixed_std_x = torch.tensor(fixed_std_x)
        self._fixed_std_x = fixed_std_x

    def change_std_y(self, std_y):
        """
        Update internal std_y to `std_y`
        :param std_y: A singular, positive torch.tensor
        """
        assert std_y.numel() == 1
        std_y = std_y.view((1,))
        print('Updating std_y from %.3f to %.3f' % (self.get_std_y().item(),
            std_y.item()))
        self.std_y_par.data = self.InverseSoftplus(std_y)


    def noise_off(self):
        self.noise_is_on = False

    def noise_on(self):
        self.noise_is_on = True

    def sigma(self, y):
        scalar_sigma = self.get_std_y()
        return scalar_sigma.repeat(y.shape)

    def get_std_x(self):
        if self.noise_is_on:
            if self._fixed_std_x is None:
                return self._deming * self.get_std_y() 
            else:
                return self._fixed_std_x
        else:
            return torch.tensor(0.0, dtype=torch.float32)

    def get_std_y(self):
        return nn.Softplus()(self.std_y_par)

    def _repetition_map(self):
        return self._repetition

    def forward(self, x, repetition=1):
        old_repetition = self._repetition
        self._repetition = repetition
        mu = self.main(x)
        sigma = self.sigma(mu)
        self._repetition = old_repetition
        return mu, sigma

    def regularizer(self, x, lamb):
        """
        Regularization for EIV net: prior KL term, 
        from "Bernoulli Dropout" by Gal et al., plus the regularizer term 
        from the EIV model (which is constant if precision_prior_zeta is 0).
        :param x: A torch.tensor, the input
        :param lamb: float, prefactor for regularization
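        **Example** (a sketch of how the term typically enters a training
        step; `net` is an FNNEIV instance and `loss_fn`, `optimizer`,
        `x_batch`, `y_batch` are placeholders, not part of this module)::

            mu, sigma = net(x_batch)
            loss = loss_fn(mu, y_batch, sigma) \
                    + net.regularizer(x_batch, lamb=1e-4)
            loss.backward()
            optimizer.step()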
        """
        regularization = 0
        p = torch.tensor(self.p)
        last_Dropout_layer = None
        for i, layer in enumerate(self.main):
            if type(layer) is not EIVDropout and last_Dropout_layer != i-1:
                for par in layer.parameters():
                    regularization += lamb*(par**2).sum().view((-1,))
            elif type(layer) is EIVDropout:
                next_layer = self.main[i+1]
                assert type(next_layer) is nn.Linear
                regularization += lamb*(next_layer.bias**2).sum().view((-1,))
                regularization += lamb/(1-p) \
                        * (next_layer.weight**2).sum().view((1,))
                last_Dropout_layer = i
            else:
                pass
        # entropy actually not needed here, added for completeness
        entropy = -1 * (p * torch.log(p) + (1-p) * torch.log(1-p))
        regularization += entropy
        if self._deming > 0 and self._fixed_std_x is None:
                # add EIV regularization term 
                # (constant if precision_prior_zeta is 0)
                regularization += self.main[0].neg_x_evidence(x)
        if self._fixed_std_x is not None:
            if self._fixed_std_x > 0:
                # add EIV regularization term 
                # (constant if precision_prior_zeta is 0)
                regularization += self.main[0].neg_x_evidence(x)
        return regularization 

    def predict(self, x, number_of_draws=[100, 5],
            number_of_parameter_chunks=None, remove_graph=True,
            take_average_of_prediction=True):
        """
        Average over `number_of_draws` forward passes. If 
        `take_average_of_prediction` is False, the averaging is skipped and
        all forward passes are returned.
        **Note**: This method does neither touch the input noise nor Dropout.
        The corresponding setting is left to the user!
        :param x: A torch.tensor, the input
        :param number_of_draws: An integer or a list. An integer
        `number_of_draws` will be converted internally to
        `[number_of_draws, 1]`. Numbers of draws to obtain from x via
        parameter sampling (first element) and noise input sampling (second
        element).
        :param number_of_parameter_chunks: An integer or None (default). If
        None, the second element of `number_of_draws` will be taken (and will
        thus be identical to 1 if `number_of_draws` is an integer, see above).
        Samples in the parameter space will be divided into
        `number_of_parameter_chunks` chunks when collected. Can be used to
        reduce the memory usage.
        :param remove_graph: If True (default) the output will 
        be detached to save memory
        :param take_average_of_prediction: If False, no averaging will be 
        applied to the prediction and the second dimension of the first output
        will count the number_of_draws.
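        **Example** (a minimal sketch; `net` is an FNNEIV instance and `x` a
        batch of 32 inputs matching the default `h`, so the stated shape is
        an assumption)::

            pred, sigma = net.predict(x, number_of_draws=[100, 5])
            # pred: (32, 1); with take_average_of_prediction=False the
            # draws are kept along dim 1 instead of being averaged.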
        """
        if type(number_of_draws) is int:
            number_of_draws = [number_of_draws, 1]
        if number_of_parameter_chunks is None:
            number_of_parameter_chunks = number_of_draws[1]
        chunk_size = number_of_draws[0] // number_of_parameter_chunks
        pred_collection, sigma_collection = [], []
        remaining_draws = number_of_draws[0]
        while remaining_draws > 0:
            if remaining_draws < chunk_size:
                parameter_sample_size = remaining_draws
            else:
                parameter_sample_size = chunk_size
            repeated_x, = repeat_tensors(x,
                    number_of_draws=parameter_sample_size * number_of_draws[1])
            pred, sigma = self.forward(repeated_x,
                    repetition=number_of_draws[1])
            if remove_graph:
                pred, sigma = pred.detach(), sigma.detach()
            pred, sigma = reshape_to_chunks(pred, sigma, 
                    number_of_draws=parameter_sample_size * number_of_draws[1])
            pred_collection.append(pred)
            sigma_collection.append(sigma)
            remaining_draws -= parameter_sample_size
        pred = torch.cat(pred_collection, dim=1)
        sigma = torch.cat(sigma_collection, dim=1)
        # reduce along the draws (actually redundant for sigma)
        if take_average_of_prediction:
            pred, sigma = torch.mean(pred, dim=1), torch.mean(sigma, dim=1)
        else: 
            sigma = torch.mean(sigma, dim=1)
        return pred, sigma

    def predictive_logdensity(self, x_or_predictions, y,
            number_of_draws=[100, 5], number_of_parameter_chunks = None,
            remove_graph=True, average_batch_dimension=True, scale_labels=None,
            decouple_dimensions=False):
        """
        Computes the logarithm of the predictive density evaluated at `y`. If
        `average_batch_dimension` is `True` these values will be averaged over
        the batch dimension.
        :param x_or_predictions: Either a torch.tensor or a tuple. If
        `x_or_predictions` is a tensor, it will be used as input. If it is a
        tuple, it will be interpreted as the output of `predict` with
        `take_average_of_prediction` set to False.
        :param y: A torch.tensor, labels on which to evaluate the density
        :param number_of_draws: An integer or a list. An integer
        `number_of_draws` will be converted internally to
        `[number_of_draws, 1]`. Numbers of draws to obtain from x via
        parameter sampling (first element) and noise input sampling (second
        element).
        :param number_of_parameter_chunks: An integer or None (default). If
        None, the second element of `number_of_draws` will be taken (and will
        thus be identical to 1 if `number_of_draws` is an integer, see above).
        Samples in the parameter space will be divided into
        `number_of_parameter_chunks` chunks when collected. Can be used to
        reduce the memory usage.
        :param remove_graph: If True (default) the output will 
        be detached to save memory
        :param average_batch_dimension: Boolean. If True (default) the values
        will be averaged over the batch dimension. If False, the batch
        dimension will be left untouched and all values will be returned.
        :param scale_labels: If not None (the default), scale labels in the
        evaluation to make the result comparable with the literature.
        :param decouple_dimensions: If True, treat dimensions separately and
        average at the end, to make results comparable with the literature.
        Defaults to False.
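        **Example** (a minimal sketch; `net` is an FNNEIV instance, `x` has
        shape (32, 10) and `y` shape (32, 1) under the default `h`)::

            log_dens = net.predictive_logdensity(x, y,
                    number_of_draws=[100, 5])
            # or reuse existing draws:
            out, sigmas = net.predict(x, number_of_draws=[100, 5],
                    take_average_of_prediction=False)
            log_dens = net.predictive_logdensity((out, sigmas), y,
                    number_of_draws=[100, 5])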
        """
        if isinstance(x_or_predictions, torch.Tensor):
            out, sigmas = self.predict(x_or_predictions,
                    number_of_draws=number_of_draws,
                number_of_parameter_chunks=number_of_parameter_chunks,
                remove_graph=remove_graph,
                take_average_of_prediction=False)
        else:
            out, sigmas = x_or_predictions
        # Add "repetition" dimension to y and sigmas
        y = y[:,None,...]
        sigmas = sigmas[:,None,...]
        # add an output axis if necessary
        if len(y.shape) <= 2:
            y = y[...,None]
        if len(sigmas.shape) <= 2:
            sigmas = sigmas[...,None]
        # squeeze last dimensions into one
        y = y.view((*y.shape[:2], -1))
        sigmas = sigmas.view((*sigmas.shape[:2], -1))
        out = out.view((*out.shape[:2], -1))
        # check if dimensions consistent
        assert y.shape == sigmas.shape
        assert y.shape[0] == out.shape[0]
        assert y.shape[2] == out.shape[2]
        if scale_labels is not None:
            extended_scale_labels = scale_labels.flatten()[None,None,:]
            out = out * extended_scale_labels
            y = y * extended_scale_labels
            sigmas = sigmas * extended_scale_labels
        # exponential argument for density
        if not decouple_dimensions:
            exp_arg =  torch.sum(-1/(2*sigmas**2) * (y-out)**2-\
                        1/2 * torch.log(2 * torch.pi * sigmas**2), dim=2)
        else:
            exp_arg =  -1/(2*sigmas**2) * (y-out)**2-\
                            1/2 * torch.log(2 * torch.pi * sigmas**2)
        # average over parameter values
        predictive_log_density_values = \
                torch.logsumexp(input=exp_arg, dim=1)\
                    - torch.log(torch.prod(torch.tensor(number_of_draws))) 
        if average_batch_dimension:
            return torch.mean(predictive_log_density_values, dim=0)
        else:
            return predictive_log_density_values

    def predict_mean_and_unc(self, x, number_of_draws=[100,5],
            number_of_parameter_chunks = None,
            remove_graph=True):
        """
        Take the mean and standard deviation over `number_of_draws` forward
        passes and return them together with the predicted sigmas.
        **Note**: This method does neither touch the input noise nor Dropout.
        The corresponding setting is left to the user!
        :param x: A torch.tensor, the input
        :param number_of_draws: An integer or a list. An integer
        `number_of_draws` will be converted internally to
        `[number_of_draws, 1]`. Numbers of draws to obtain from x via
        parameter sampling (first element) and noise input sampling (second
        element).
        :param number_of_parameter_chunks: An integer or None (default). If
        None, the second element of `number_of_draws` will be taken (and will
        thus be identical to 1 if `number_of_draws` is an integer, see above).
        Samples in the parameter space will be divided into
        `number_of_parameter_chunks` chunks when collected. Can be used to
        reduce the memory usage.
        :param remove_graph: If True (default) the output will 
        be detached to save memory
        :return: mean, std, sigmas
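        **Example** (a minimal sketch; `net` is an FNNEIV instance and `x`
        has shape (32, 10) under the default `h`)::

            mean, std, sigmas = net.predict_mean_and_unc(x,
                    number_of_draws=[100, 5])
            # mean/std: statistics over the draws, sigmas: predicted std_y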
        """
        out, sigmas = self.predict(x=x,
                number_of_draws=number_of_draws,
                number_of_parameter_chunks=number_of_parameter_chunks,
                remove_graph=remove_graph,
                take_average_of_prediction=False)
        mean = torch.mean(out, dim=1)
        std = torch.std(out, dim=1)
        return mean, std, sigmas

        


class FNNBer(nn.Module):
    """
    A fully connected net with Bernoulli dropout layers.
    :param p: dropout rate, defaults to 0.2
    :param init_std_y: Initial estimated standard deviation for y.
    :param h: A list specifying the number of neurons in each layer.
    :param std_y_requires_grad: Whether `std_y` requires gradients and is thus
    updated during optimization. Defaults to False.
    To change std_y use the method `change_std_y`
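    **Example** (a minimal sketch; the batch size 32 is arbitrary and the
    stated shapes follow the default `h`)::

        net = FNNBer(p=0.2, init_std_y=1.0)
        x = torch.randn(32, 10)
        mu, sigma = net(x)        # mu: (32, 1), sigma: same shape as mu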
    """
    LeakyReLUSlope = 1e-2
    def __init__(self, p=0.2, init_std_y=1.0, h=[10, 1024,1024,1024,1024, 1],
            std_y_requires_grad=False):
        super().__init__()
        # part before Bernoulli dropout
        self.init_std_y = init_std_y
        # inverse of nn.Softplus, used to parametrize std_y on the real line
        self.InverseSoftplus = lambda sigma: torch.log(torch.exp(sigma) - 1)
        self.std_y_par = nn.parameter.Parameter(
                self.InverseSoftplus(torch.tensor([init_std_y])))
        self.std_y_par.requires_grad = std_y_requires_grad
        self.main = nn.Sequential(
                nn.Linear(h[0], h[1]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                nn.Dropout(p=p),
                nn.Linear(h[1],h[2]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                nn.Dropout(p=p),
                nn.Linear(h[2],h[3]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                nn.Dropout(p=p),
                nn.Linear(h[3],h[4]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                nn.Dropout(p=p),
                nn.Linear(h[4], h[5]))
        self.p = p

    def sigma(self, y):
        scalar_sigma = self.get_std_y()
        return scalar_sigma.repeat(y.shape)

    def get_std_y(self):
        return nn.Softplus()(self.std_y_par)

    def change_std_y(self, std_y):
        """
        Update internal std_y to `std_y`
        :param std_y: A singular, positive torch.tensor
        """
        assert std_y.numel() == 1
        std_y = std_y.view((1,))
        print('Updating std_y from %.3f to %.3f' % (self.get_std_y().item(),
            std_y.item()))
        self.std_y_par.data = self.InverseSoftplus(std_y)

    def forward(self, x):
        mu = self.main(x)
        sigma = self.sigma(mu)
        return mu, sigma

    def regularizer(self, x, lamb):
        """
        Regularization (prior KL term), from "Bernoulli Dropout" by Gal et al.
        :param x: A torch.tensor, the input
        :param lamb: float, prefactor for regularization
        """
        regularization = 0
        p = torch.tensor(self.p)
        last_Dropout_layer = None
        for i, layer in enumerate(self.main):
            if type(layer) is not nn.Dropout and last_Dropout_layer != i-1:
                for par in layer.parameters():
                    regularization += lamb*(par**2).sum().view((-1,))
            elif type(layer) is nn.Dropout:
                next_layer = self.main[i+1]
                assert type(next_layer) is nn.Linear
                regularization += lamb*(next_layer.bias**2).sum().view((-1,))
                regularization += lamb/(1-p) \
                        * (next_layer.weight**2).sum().view((1,))
                last_Dropout_layer = i
            else:
                pass
        # entropy actually not needed here, added for completeness
        entropy = -1 * (p * torch.log(p) + (1-p) * torch.log(1-p))
        regularization += entropy
        return regularization 


    def predict(self, x, number_of_draws=100, remove_graph=True,
            take_average_of_prediction=True):
        """
        Average over `number_of_draws` forward passes. If 
        `take_average_of_prediction` is False, the averaging is skipped and
        all forward passes are returned.
        **Note**: This method does not touch the Dropout.
        The corresponding setting is left to the user! (analogous to
        the corresponding method for FNNEIV)
        :param x: A torch.tensor, the input
        :param number_of_draws: Number of draws to obtain from x
        :param remove_graph: If True (default) the output will 
        be detached to save memory
        :param take_average_of_prediction: If False, no averaging will be 
        applied to the prediction and the second dimension of the first output 
        will count the number_of_draws.
        :returns: predictions, sigmas
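        **Example** (a minimal sketch; `net` is an FNNBer instance and `x` a
        batch of 32 inputs matching the default `h`, so the stated shape is
        an assumption)::

            pred, sigma = net.predict(x, number_of_draws=100)
            # pred: (32, 1); with take_average_of_prediction=False the
            # draws are kept along dim 1 instead of being averaged.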
        """
        x, = repeat_tensors(x, number_of_draws=number_of_draws)
        pred, sigma = self.forward(x)
        if remove_graph:
            pred, sigma = pred.detach(), sigma.detach()
        pred, sigma = reshape_to_chunks(pred, sigma, 
                number_of_draws=number_of_draws)
        # reduce along the draws (actually redundant for sigma)
        if take_average_of_prediction:
            pred, sigma = torch.mean(pred, dim=1), torch.mean(sigma, dim=1)
        else: 
            sigma = torch.mean(sigma, dim=1)
        return pred, sigma

    def predictive_logdensity(self, x_or_predictions, y, number_of_draws=100,
            remove_graph=True, average_batch_dimension=True, scale_labels=None,
            decouple_dimensions=False):
        """
        Computes the logarithm of the predictive density evaluated at `y`. If
        `average_batch_dimension` is `True` these values will be averaged over
        the batch dimension.
        :param x_or_predictions: Either a torch.tensor or a tuple. If
        `x_or_predictions` is a tensor, it will be used as input. If it is a
        tuple, it will be interpreted as the output of `predict` with
        `take_average_of_prediction` set to False.
        :param y: A torch.tensor, labels on which to evaluate the density
        :param number_of_draws: Number of draws to obtain from x
        :param remove_graph: If True (default) the output will 
        be detached to save memory
        :param average_batch_dimension: Boolean. If True (default) the values
        will be averaged over the batch dimension. If False, the batch
        dimension will be left untouched and all values will be returned.
        :param scale_labels: If not None (the default), scale labels in the
        evaluation to make the result comparable with the literature.
        :param decouple_dimensions: If True, treat dimensions separately and
        average at the end, to make results comparable with the literature.
        Defaults to False.
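        **Example** (a minimal sketch; `net` is an FNNBer instance, `x` has
        shape (32, 10) and `y` shape (32, 1) under the default `h`)::

            log_dens = net.predictive_logdensity(x, y, number_of_draws=100)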
        """
        if isinstance(x_or_predictions, torch.Tensor):
            out, sigmas = self.predict(x_or_predictions,
                    number_of_draws=number_of_draws,
                    take_average_of_prediction=False, remove_graph=remove_graph)
        else:
            out, sigmas = x_or_predictions
        # Add "repetition" dimension to y and sigmas
        y = y[:,None,...]
        sigmas = sigmas[:,None,...]
        if len(y.shape) <= 2:
            # add an output axis if necessary
            y = y[...,None]
        if len(sigmas.shape) <= 2:
            # add an output axis if necessary
            sigmas = sigmas[...,None]
        # squeeze last dimensions into one
        y = y.view((*y.shape[:2], -1))
        sigmas = sigmas.view((*sigmas.shape[:2], -1))
        out = out.view((*out.shape[:2], -1))
        # check if dimensions consistent
        assert y.shape == sigmas.shape
        assert y.shape[0] == out.shape[0]
        assert y.shape[2] == out.shape[2]
        if scale_labels is not None:
            extended_scale_labels = scale_labels.flatten()[None,None,:]
            out = out * extended_scale_labels
            y = y * extended_scale_labels
            sigmas = sigmas * extended_scale_labels
        # exponential argument for density
        if not decouple_dimensions:
            exp_arg =  torch.sum(-1/(2*sigmas**2) * (y-out)**2-\
                        1/2 * torch.log(2 * torch.pi * sigmas**2), dim=2)
        else:
            exp_arg =  -1/(2*sigmas**2) * (y-out)**2-\
                            1/2 * torch.log(2 * torch.pi * sigmas**2)
        # average over parameter values
        predictive_log_density_values = \
                torch.logsumexp(input=exp_arg, dim=1)\
                    - torch.log(torch.tensor(number_of_draws)) 
        if average_batch_dimension:
            return torch.mean(predictive_log_density_values, dim=0)
        else:
            return predictive_log_density_values

    def predict_mean_and_unc(self, x, number_of_draws=100, 
            remove_graph=True):
        """
        Take the mean and standard deviation over `number_of_draws` forward
        passes and return them together with the predicted sigmas.
        **Note**: This method does not touch the Dropout.
        The corresponding setting is left to the user!
        :param x: A torch.tensor, the input
        :param number_of_draws: Number of draws to obtain from x
        :param remove_graph: If True (default) the output will 
        be detached to save memory
        :return: mean, std, sigmas
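        **Example** (a minimal sketch; `net` is an FNNBer instance and `x`
        has shape (32, 10) under the default `h`)::

            mean, std, sigmas = net.predict_mean_and_unc(x,
                    number_of_draws=100)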
        """
        out, sigmas = self.predict(x=x,
                number_of_draws=number_of_draws,
                remove_graph=remove_graph,
                take_average_of_prediction=False)
        mean = torch.mean(out, dim=1)
        std = torch.std(out, dim=1)
        return mean, std, sigmas



class SmallFNNBer(FNNBer):
    """
    A smaller variant of FNNBer: a fully connected net with Bernoulli dropout
    layers and three hidden layers.
    :param p: dropout rate, defaults to 0.2
    :param init_std_y: Initial estimated standard deviation for y.
    :param h: A list specifying the number of neurons in each layer.
    :param std_y_requires_grad: Whether `std_y` requires gradients and is thus
    updated during optimization. Defaults to False.
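    **Example** (a minimal sketch; the hidden width 512 is an arbitrary
    illustration, not a recommended setting)::

        net = SmallFNNBer(h=[10, 512, 512, 512, 1])
        mu, sigma = net(torch.randn(32, 10))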
    """
    def __init__(self, p=0.2, init_std_y=1.0, h=[10, 1024,1024,1024, 1],
           std_y_requires_grad=False):
        super().__init__(p=p, init_std_y=init_std_y,
                std_y_requires_grad=std_y_requires_grad)
        self.main = nn.Sequential(
                nn.Linear(h[0], h[1]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                nn.Dropout(p=p),
                nn.Linear(h[1],h[2]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                nn.Dropout(p=p),
                nn.Linear(h[2],h[3]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                nn.Dropout(p=p),
                nn.Linear(h[3],h[4]))

class ShallowFNNBer(FNNBer):
    """
    A shallow variant of FNNBer: a fully connected net with Bernoulli dropout
    layers and two hidden layers.
    :param p: dropout rate, defaults to 0.2
    :param init_std_y: Initial estimated standard deviation for y.
    :param h: A list specifying the number of neurons in each layer.
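    **Example** (a minimal sketch with the default `h`)::

        net = ShallowFNNBer(p=0.2, init_std_y=1.0)
        mu, sigma = net(torch.randn(32, 10))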
    """
    def __init__(self, p=0.2, init_std_y=1.0, h=[10, 1024,1024, 1]):
        super().__init__(p=p, init_std_y=init_std_y)
        self.main = nn.Sequential(
                nn.Linear(h[0], h[1]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                nn.Dropout(p=p),
                nn.Linear(h[1],h[2]),
                nn.LeakyReLU(self.LeakyReLUSlope),
                nn.Dropout(p=p),
                nn.Linear(h[2],h[3]))