# loss_functions.py

from math import pi

import numpy as np
import torch

from EIVGeneral.repetition import repeat_tensors, reshape_to_chunks


def nll_reg_loss(net, x, y, reg):
    """
    Returns the negative log likelihood with an additional regularization
    term.
    **Note**: `reg` will not be divided by the data size (and by 2);
    this should be done beforehand.

    :param net: A torch.nn.Module.
    :param x: A torch.tensor, the input.
    :param y: A torch.tensor, the output.
    :param reg: A non-negative float, the regularization.
    """
    out, sigmas = net(x)
    # Add label dimension to y if missing
    if len(y.shape) <= 1:
        y = y.view((-1,1))
    # squeeze last dimensions into one
    y = y.view((*y.shape[:1], -1))
    sigmas = sigmas.view((*sigmas.shape[:1], -1))
    out = out.view((*out.shape[:1], -1))
    assert out.shape == y.shape
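    # Gaussian negative log likelihood: for each sample, sum over the label
    # dimension of 0.5*log(2*pi*sigma^2) + (out - y)^2 / (2*sigma^2),
    # then average over the batch.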
    neg_log_likelihood = torch.mean(torch.sum(
        0.5 * torch.log(2 * pi * sigmas**2)
        + ((out - y)**2) / (2 * sigmas**2), dim=1))
    regularization = net.regularizer(x, lamb=reg)
    return neg_log_likelihood + regularization
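

# A minimal usage sketch for `nll_reg_loss` (illustrative, not part of the
# original module): `_ToyNoiseNet` is a hypothetical stand-in for a network
# whose forward pass returns a `(prediction, sigma)` pair and that exposes
# a `regularizer(x, lamb=...)` method, as assumed by the loss above.
def _example_nll_reg_loss():
    import torch.nn as nn

    class _ToyNoiseNet(nn.Module):
        def __init__(self, in_dim=3, out_dim=1):
            super().__init__()
            self.linear = nn.Linear(in_dim, out_dim)
            # a single global log-sigma parameter, broadcast over outputs
            self.log_sigma = nn.Parameter(torch.zeros(1))

        def forward(self, x):
            out = self.linear(x)
            sigmas = torch.exp(self.log_sigma).expand_as(out)
            return out, sigmas

        def regularizer(self, x, lamb):
            # simple L2 penalty on all parameters, scaled by lamb
            return lamb * sum((p**2).sum() for p in self.parameters())

    net = _ToyNoiseNet()
    x, y = torch.randn(16, 3), torch.randn(16, 1)
    loss = nll_reg_loss(net, x, y, reg=1e-4)
    loss.backward()
    return loss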


def nll_eiv(net, x, y, reg, number_of_draws=5):
    """
    negative log likelihood criterion for an Error in variables model (EIV) 
    where `torch.logsumexp` is applied to partitions of size `number_of_draws`
    of `mu` and `sigma` in the batch dimension (that is the first one). 
    *Note that `reg` will not be divided by the data size (and by 2), 
     this should be done beforehand.*
    :param mu: predicted mu
    :param sigma: predicted sigma
    :param y: ground truth
    :number_of_draws: Integer, supposed to be larger than 2
    """
    # Add label dimension to y if missing
    if len(y.shape) <= 1:
        y = y.view((-1,1))
    regularization = net.regularizer(x, lamb=reg)
    # repeat x and y along the batch dimension, number_of_draws times each
    x, y = repeat_tensors(x, y, number_of_draws=number_of_draws)
    out, sigmas = net(x, repetition=number_of_draws) 
    # split into chunks of size number_of_draws along batch dimension
    out, sigmas, y = reshape_to_chunks(out, sigmas, y,
            number_of_draws=number_of_draws)
    # squeeze last dimensions into one
    y = y.view((*y.shape[:2], -1))
    sigmas = sigmas.view((*sigmas.shape[:2], -1))
    out = out.view((*out.shape[:2], -1))
    assert out.shape == y.shape
    # apply logsumexp to chunks and average the results
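    # Per data point, the log marginal likelihood log p(y|x) is estimated as
    #   logsumexp_j [ sum_d ( -0.5*log(2*pi*sigma_j^2)
    #                         - (y - out_j)^2 / (2*sigma_j^2) ) ]
    #     - log(number_of_draws),
    # i.e. the log of the average likelihood over the number_of_draws draws.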
    nll = -1 * (torch.logsumexp(torch.sum(
        -1/2 * torch.log(sigmas**2 * 2 * pi)
        - ((y - out)**2) / (2 * sigmas**2), dim=2), dim=1)
        - np.log(number_of_draws)).mean()
    return nll + regularization
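

# A minimal usage sketch for `nll_eiv` (illustrative, not part of the
# original module). It assumes the EIVGeneral helpers imported above are
# available; `_ToyEIVNet` is a hypothetical stand-in for an EIV network
# whose forward pass accepts the `repetition` keyword, here realized by
# perturbing the (already repeated) inputs with fresh noise on every call.
def _example_nll_eiv():
    import torch.nn as nn

    class _ToyEIVNet(nn.Module):
        def __init__(self, in_dim=3, out_dim=1, std_x=0.05):
            super().__init__()
            self.linear = nn.Linear(in_dim, out_dim)
            self.log_sigma = nn.Parameter(torch.zeros(1))
            self.std_x = std_x

        def forward(self, x, repetition=1):
            # each repeated sample gets an independent input-noise draw
            x_noisy = x + self.std_x * torch.randn_like(x)
            out = self.linear(x_noisy)
            sigmas = torch.exp(self.log_sigma).expand_as(out)
            return out, sigmas

        def regularizer(self, x, lamb):
            # simple L2 penalty on all parameters, scaled by lamb
            return lamb * sum((p**2).sum() for p in self.parameters())

    net = _ToyEIVNet()
    x, y = torch.randn(16, 3), torch.randn(16, 1)
    loss = nll_eiv(net, x, y, reg=1e-4, number_of_draws=5)
    loss.backward()
    return loss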