# NOTE(review): removed stray "Newer"/"Older" page-navigation artifacts left
# over from the source scrape — they were bare names that would raise
# NameError at import time and are not part of the module.
from math import pi
import numpy as np
import torch
from EIVGeneral.repetition import repeat_tensors, reshape_to_chunks
def nll_reg_loss(net, x, y, reg):
    """
    Returns the negative log likelihood of a Gaussian predictive distribution
    plus an additional regularization term.
    **Note**: `reg` will not be divided by the data size (and by 2),
    this should be done beforehand.
    :param net: A torch.nn.Module; calling `net(x)` must return a tuple
        `(out, sigmas)` of predicted means and standard deviations, and
        `net.regularizer(x, lamb=reg)` must return the penalty term.
    :param x: A torch.tensor, the input.
    :param y: A torch.tensor, the output.
    :param reg: A non-negative float, the regularization.
    """
    # BUG FIX: `out` and `sigmas` were used without ever being assigned
    # (unconditional NameError). Obtain them from the network, matching the
    # `net(...) -> (out, sigmas)` contract used by `nll_eiv` below.
    out, sigmas = net(x)
    # Add a label dimension to y if missing
    if len(y.shape) <= 1:
        y = y.view((-1, 1))
    # squeeze all label dimensions into one, keeping the batch dimension
    y = y.view((*y.shape[:1], -1))
    sigmas = sigmas.view((*sigmas.shape[:1], -1))
    out = out.view((*out.shape[:1], -1))
    assert out.shape == y.shape
    # Gaussian negative log likelihood, summed over labels, averaged over batch
    neg_log_likelihood = torch.mean(torch.sum(0.5 * torch.log(2 * pi * sigmas**2)
                                              + ((out - y)**2) / (2 * sigmas**2), dim=1))
    regularization = net.regularizer(x, lamb=reg)
    return neg_log_likelihood + regularization
def nll_eiv(net, x, y, reg, number_of_draws=5):
    """
    Negative log likelihood criterion for an error-in-variables model (EIV)
    where `torch.logsumexp` is applied to partitions of size `number_of_draws`
    of the predicted means and standard deviations in the batch dimension
    (that is the first one).
    *Note that `reg` will not be divided by the data size (and by 2),
    this should be done beforehand.*
    :param net: A torch.nn.Module; calling `net(x, repetition=...)` must
        return a tuple `(out, sigmas)` of predicted means and standard
        deviations, and `net.regularizer(x, lamb=reg)` must return the
        penalty term.
    :param x: A torch.tensor, the input.
    :param y: A torch.tensor, the ground truth.
    :param reg: A non-negative float, the regularization.
    :param number_of_draws: Integer, supposed to be larger than 2.
    """
    # Add a label dimension to y if missing
    if len(y.shape) <= 1:
        y = y.view((-1, 1))
    regularization = net.regularizer(x, lamb=reg)
    # repeat each sample `number_of_draws` times along the batch dimension
    x, y = repeat_tensors(x, y, number_of_draws=number_of_draws)
    out, sigmas = net(x, repetition=number_of_draws)
    # split into chunks of size number_of_draws along the batch dimension
    out, sigmas, y = reshape_to_chunks(out, sigmas, y,
                                       number_of_draws=number_of_draws)
    # squeeze all label dimensions into one, keeping batch and draw dimensions
    y = y.view((*y.shape[:2], -1))
    sigmas = sigmas.view((*sigmas.shape[:2], -1))
    out = out.view((*out.shape[:2], -1))
    assert out.shape == y.shape
    # log-mean-exp over the draws (dim 1: logsumexp minus log(number_of_draws))
    # of the per-draw Gaussian log likelihoods (summed over labels, dim 2),
    # negated and averaged over the batch.
    nll = -1 * (torch.logsumexp(torch.sum(-1 / 2 * torch.log(sigmas**2 * 2 * pi)
                                          - ((y - out)**2) / (2 * sigmas**2), dim=2), dim=1)
                - np.log(number_of_draws)).mean()
    return nll + regularization