from math import pi

import numpy as np
import torch

from EIVGeneral.repetition import repeat_tensors, reshape_to_chunks


def _flatten_trailing(tensor, keep):
    """Collapse all dimensions after the first `keep` ones into a single one.

    `reshape` is used instead of `view` so that non-contiguous tensors
    (e.g. the result of a transpose) are handled as well; for contiguous
    tensors the result is the same view as before.
    """
    return tensor.reshape((*tensor.shape[:keep], -1))


def nll_reg_loss(net, x, y, reg):
    """
    Returns the neg log likelihood with an additional regularization term.

    The likelihood is Gaussian with the mean `out` and the standard
    deviation `sigmas` returned by `net(x)`. Per sample the terms are summed
    over all non-batch dimensions, then averaged over the batch dimension.

    **Note**: `reg` will not be divided by the data size (and by 2),
    this should be done beforehand.

    :param net: A torch.nn.Module. `net(x)` must return the pair
    `(out, sigmas)` and `net.regularizer(x, lamb=reg)` must return the
    regularization term.
    :param x: A torch.tensor, the input.
    :param y: A torch.tensor, the output. If it has at most one dimension,
    a label dimension of size 1 is appended.
    :param reg: A non-negative float, the regularization.
    :returns: The averaged negative log likelihood plus the regularization.
    """
    out, sigmas = net(x)
    # Add label dimension to y if missing
    if len(y.shape) <= 1:
        y = y.reshape((-1, 1))
    # squeeze last dimensions into one
    y = _flatten_trailing(y, 1)
    sigmas = _flatten_trailing(sigmas, 1)
    out = _flatten_trailing(out, 1)
    assert out.shape == y.shape, (
        f"prediction shape {tuple(out.shape)} does not match "
        f"label shape {tuple(y.shape)}"
    )
    neg_log_likelihood = torch.mean(
        torch.sum(
            0.5 * torch.log(2 * pi * sigmas**2)
            + ((out - y)**2) / (2 * sigmas**2),
            dim=1,
        )
    )
    regularization = net.regularizer(x, lamb=reg)
    return neg_log_likelihood + regularization


def nll_eiv(net, x, y, reg, number_of_draws=5):
    """
    negative log likelihood criterion for an Error in variables model (EIV)
    where `torch.logsumexp` is applied to partitions of size
    `number_of_draws` of the predicted mean and standard deviation in the
    batch dimension (that is the first one).

    *Note that `reg` will not be divided by the data size (and by 2),
    this should be done beforehand.*

    :param net: A torch.nn.Module. `net(x, repetition=number_of_draws)` must
    return the pair `(out, sigmas)` of predicted mean and standard deviation,
    and `net.regularizer(x, lamb=reg)` must return the regularization term.
    :param x: A torch.tensor, the input.
    :param y: A torch.tensor, the ground truth. If it has at most one
    dimension, a label dimension of size 1 is appended.
    :param reg: A non-negative float, the regularization.
    :param number_of_draws: Integer, supposed to be larger than 2
    :returns: The averaged negative log likelihood plus the regularization.
    """
    # Add label dimension to y if missing
    if len(y.shape) <= 1:
        y = y.reshape((-1, 1))
    # The regularizer is evaluated on the original (not yet repeated) input.
    regularization = net.regularizer(x, lamb=reg)
    # repeat_tensors
    x, y = repeat_tensors(x, y, number_of_draws=number_of_draws)
    out, sigmas = net(x, repetition=number_of_draws)
    # split into chunks of size number_of_draws along batch dimension
    out, sigmas, y = reshape_to_chunks(
        out, sigmas, y, number_of_draws=number_of_draws
    )
    # squeeze last dimensions into one
    y = _flatten_trailing(y, 2)
    sigmas = _flatten_trailing(sigmas, 2)
    out = _flatten_trailing(out, 2)
    assert out.shape == y.shape, (
        f"prediction shape {tuple(out.shape)} does not match "
        f"label shape {tuple(y.shape)}"
    )
    # Gaussian log density, summed over the flattened last dimension
    log_density = torch.sum(
        -1/2 * torch.log(sigmas**2 * 2 * pi)
        - ((y - out)**2) / (2 * sigmas**2),
        dim=2,
    )
    # apply logsumexp to chunks and average the results; subtracting
    # log(number_of_draws) turns the sum inside the logarithm into a mean
    nll = -1 * (
        torch.logsumexp(log_density, dim=1) - np.log(number_of_draws)
    ).mean()
    return nll + regularization