from math import pi

import numpy as np
import torch
import torch.nn as nn

from EIVGeneral import repetition


class EIVDropout(nn.Module):
    """
    A Dropout layer with dropout probability `p` (default 0.5) that repeats
    the same Bernoulli mask L times along the batch dimension - instead of
    drawing a different one for each batch member - where L is the output of
    `repetition_map()`. When evaluating `forward(x)`, the batch dimension of
    `x` (the first one) is asserted to be a multiple of `L=repetition_map()`.
    The `repetition_map` defaults to the constant map to 1, in which case
    `EIVDropout` is equivalent to `torch.nn.Dropout`.
    :param p: A float between 0 and 1. Defaults to 0.5.
    :param repetition_map: Map that takes no argument and returns an integer.
    Defaults to the constant map to 1.
    """
    def __init__(self, p=0.5, repetition_map=lambda: 1):
        super().__init__()
        self.p = p
        self.repetition_map = repetition_map
        self._train = True

    def train(self, training=True):
        if training:
            self._train = True
        else:
            self._train = False

    def eval(self):
        self.train(training=False)

    def forward(self, x):
        if not self._train:
            return x
        else:
            device = self.study_device(x)
            L = self.repetition_map()
            input_shape = x.shape
            assert input_shape[0] % L == 0
            mask_shape = list(input_shape)
            mask_shape[0] = int(input_shape[0] / L)
            # Draw one Bernoulli mask per group of L batch members and rescale
            # by 1/(1-p), as in standard (inverted) dropout.
            mask = torch.bernoulli(torch.ones(mask_shape) * (1-self.p))\
                    / (1-self.p)
            repeated_mask = repetition.repeat_tensors(mask,
                    number_of_draws=L)[0]
            assert x.shape == repeated_mask.shape
            return x * repeated_mask.to(device)

    @staticmethod
    def study_device(x):
        if x.is_cuda:
            return torch.device('cuda:' + str(x.get_device()))
        else:
            return torch.device('cpu')
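
# Illustrative usage sketch (not part of the original module): repeat the
# input batch with the same repetition utility that EIVDropout uses
# internally, so that all L copies of one data point share a single
# Bernoulli mask. The choice L=5 and the tensor shapes are assumptions made
# for this sketch only.
def _eiv_dropout_sketch():
    L = 5
    dropout = EIVDropout(p=0.1, repetition_map=lambda: L)
    x = torch.randn(16, 10)
    # repeat the batch L times; dropout expects a batch size divisible by L
    x_rep = repetition.repeat_tensors(x, number_of_draws=L)[0]
    assert x_rep.shape[0] == 16 * L
    return dropout(x_rep)
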
""" c_1 = 1.16145124 c_2 = -1.50204118 c_3 = 0.58629921 def __init__(self, initial_alpha=0.5, repetition_map=lambda: 1): super().__init__() self.initial_alpha = initial_alpha self.repetition_map = repetition_map initial_alpha_par = self.invert_softplus(self.initial_alpha) self.alpha_par = nn.Parameter(torch.tensor(initial_alpha_par)) self._train = True def train(self, training=True): if training: self._train = True else: self._train = False def eval(self): self.train(training=False) @staticmethod def invert_softplus(softplus_value): """ Inverts the Softplus function :param softplus_value: A float """ return np.log(np.exp(softplus_value)-1) def alpha(self): return nn.Softplus()(self.alpha_par) def forward(self, x): if not self._train: return x else: device = self.study_device(x) L = self.repetition_map() input_shape = x.shape assert input_shape[0] % L == 0 mask_shape = list(input_shape) mask_shape[0] = int(input_shape[0] / L) mask = 1.0 + torch.sqrt(self.alpha()) * \ torch.randn(mask_shape).to(device) repeated_mask = repetition.repeat_tensors(mask, number_of_draws=L)[0] assert x.shape == repeated_mask.shape return x * repeated_mask.to(device) @staticmethod def study_device(x): if x.is_cuda: return torch.device('cuda:' + str(x.get_device())) else: return torch.device('cpu') def variational_dropout_regularizer(self): """ Taken from Kingma et al. Identical, up to a constant, to the neg. KL-div of the variational distribution and the improper prior. """ alpha = self.alpha() return 0.5 * torch.log(alpha)\ + self.c_1 * alpha + self.c_2 * alpha**2 + self.c_3 * alpha**3 class EIVInput(nn.Module): """ Input class for an Error-in-Variables model based on variational inference. For regularization there is a method EIVInput.x_evidence included. The underlying model is x = zeta + epsilon, where zeta denotes the true (but unknown) input variable and epsilon denotes normal noise with standard deviation std_x. :param init_std_x: The initial value of sigma_x (std of noise on input), will be made a learnable parameter. Defaults to 1.0 :param precision_prior_zeta: The precision of the prior on zeta (the true, but unknown input). Defaults to 0 (= flat prior) :param external_std_x: If not None, should be a map that returns (without arguments) std_x (i.e. as the output of self.get_std_x). If None (the Default) will be parametrized via a Softmax of a torch.nn.Parameter. :param std_x_scale: Will be used for scaling std_x internally. Defaults to 1.0. """ def __init__(self, init_std_x = 1.0, precision_prior_zeta=0.0, external_std_x=None, std_x_scale=1.0): super(EIVInput, self).__init__() self.init_std_x = init_std_x self.precision_prior_zeta = precision_prior_zeta self.external_std_x = external_std_x InverseSoftplus = lambda sigma: torch.log(torch.exp(sigma) - 1 ) self.std_x_scale = std_x_scale if external_std_x is None: self.std_x_par = nn.Parameter( InverseSoftplus( torch.tensor([init_std_x]))/self.std_x_scale) self.noise = True # self.fixed_z will be used as "noise" if self.noise is False self.fixed_z = 0.0 def get_std_x(self): """ Returns the standard deviation of the distribution of x given zeta. 
""" if self.external_std_x is None: return nn.Softplus()(self.std_x_par * self.std_x_scale) else: return self.external_std_x() def forward(self, x): std_x = self.get_std_x() if std_x.item() > 0: std_zeta = 1/(1/std_x**2 + self.precision_prior_zeta)**0.5 mean_zeta = std_zeta**2/std_x**2 * x else: assert std_x.item() == 0 std_zeta = 0.0 mean_zeta = x if self.noise: z = torch.randn_like(x) return mean_zeta + std_zeta * z else: return mean_zeta + std_zeta * self.fixed_z def neg_x_evidence(self, x, remove_divergent_term=True): """ Returns the value of x under the prior predictive of the EIV model :param x: A torch.tensor :param remove_divergent_term: If True (default) the constant term, that diverges for a float prior will be removed. """ if not remove_divergent_term: if self.precision_prior_zeta == 0: raise ValueError('Divergent term infinite in EIV model') else: divergent_term = 0.5 * torch.log(self.precision_prior_zeta) else: divergent_term = 0 batch_size = x.shape[0] regularization = 0.5/batch_size\ * self.precision_prior_zeta/( self.precision_prior_zeta * self.get_std_x()**2 + 1) \ * torch.sum(x**2) regularization += 0.5 * torch.log(2 * pi * (self.get_std_x()**2 * self.precision_prior_zeta + 1) ) regularization += divergent_term return regularization