# loss_functions.py
# (Last committed by Jörg Martin.)
from math import pi

import numpy as np
import torch

from EIVGeneral.repetition import repeat_tensors, reshape_to_chunks


def nll_reg_loss(net, x, y, reg):
    """
    Returns the Gaussian negative log likelihood of `y` under the prediction
    of `net`, with an additional regularization term.

    The likelihood term is averaged over all entries:
    ``mean(0.5*log(2*pi*std_y**2) + (out - y)**2 / (2*std_y**2))``.

    **Note**: `reg` will not be divided by the data size (and by 2),
    this should be done beforehand.

    :param net: A torch.nn.Module. ``net(x)`` must return a pair
     ``(out, std_y)`` and the module must provide
     ``net.regularizer(x, lamb=...)``.
    :param x: A torch.tensor, the input.
    :param y: A torch.tensor, the output. A 0- or 1-dimensional `y` is
     reshaped to ``(-1, 1)``, i.e. a label dimension is added.
    :param reg: A non-negative float, the regularization.
    :returns: A torch.tensor, the averaged negative log likelihood plus
     ``net.regularizer(x, lamb=reg)``.
    :raises AssertionError: if the prediction and label shapes differ.
    """
    out, std_y = net(x)
    # Add label dimension to y if missing
    if len(y.shape) <= 1:
        y = y.view((-1, 1))
    assert out.shape == y.shape, (
        f"prediction shape {tuple(out.shape)} does not match "
        f"label shape {tuple(y.shape)}"
    )
    # Full Gaussian NLL per entry (including the 0.5*log(2*pi) constant),
    # averaged over all entries.
    neg_log_likelihood = torch.mean(
        0.5 * torch.log(2 * pi * std_y**2)
        + ((out - y)**2) / (2 * std_y**2)
    )
    regularization = net.regularizer(x, lamb=reg)
    return neg_log_likelihood + regularization


def nll_eiv(net, x, y, reg, number_of_draws=5):
    """
    Negative log likelihood criterion for an Error in variables model (EIV)
    where `torch.logsumexp` is applied to partitions of size
    `number_of_draws` of the predictions `pred` and `sigma` in the batch
    dimension (that is the first one).

    Each input is repeated `number_of_draws` times and `net` is evaluated
    on the repeated batch. Per partition, the Gaussian densities of the
    draws are then averaged before taking the logarithm. With
    ``K = number_of_draws``, the returned value is
    ``-mean(logsumexp_k(-log(sigma_k) - (y - pred_k)**2 / (2*sigma_k**2))
    - log(K)) + regularization``.

    *Note that `reg` will not be divided by the data size (and by 2),
    this should be done beforehand.*

    *Note that, unlike `nll_reg_loss`, the constant ``0.5*log(2*pi)`` is
    not included. The value is therefore smaller than the full Gaussian
    negative log likelihood by exactly that constant; gradients are
    unaffected.*

    :param net: A torch.nn.Module. ``net(x, repetition=...)`` must return a
     pair ``(pred, sigma)`` and the module must provide
     ``net.regularizer(x, lamb=...)``.
    :param x: A torch.tensor, the input.
    :param y: A torch.tensor, the ground truth. A 0- or 1-dimensional `y`
     is reshaped to ``(-1, 1)``.
    :param reg: A non-negative float, the regularization.
    :param number_of_draws: Integer, supposed to be larger than 2. The
     number of times every input is repeated.
    :returns: A torch.tensor, the negative log likelihood plus
     ``net.regularizer(x, lamb=reg)``.
    :raises AssertionError: if the chunked prediction and label shapes
     differ.
    """
    # Add label dimension to y if missing
    if len(y.shape) <= 1:
        y = y.view((-1, 1))
    # The regularizer is evaluated on the original, non-repeated input.
    regularization = net.regularizer(x, lamb=reg)
    # Repeat every sample `number_of_draws` times and evaluate the net on
    # the enlarged batch.
    x, y = repeat_tensors(x, y, number_of_draws=number_of_draws)
    pred, sigma = net(x, repetition=number_of_draws)
    # Split into chunks of size number_of_draws along the batch dimension.
    # NOTE(review): the logsumexp below reduces dim=1, so this presumably
    # places the draws of one input along dim 1 — confirm in
    # EIVGeneral.repetition.
    pred, sigma, y = reshape_to_chunks(pred, sigma, y,
                                       number_of_draws=number_of_draws)
    assert pred.shape == y.shape, (
        f"prediction shape {tuple(pred.shape)} does not match "
        f"label shape {tuple(y.shape)}"
    )
    # Log Gaussian density of each draw, up to the constant -0.5*log(2*pi).
    log_density = -1 * sigma.log() - ((y - pred)**2) / (2 * sigma**2)
    # Log of the mean density over the draws: logsumexp minus log(K).
    log_mean_density = (torch.logsumexp(log_density, dim=1)
                        - np.log(number_of_draws))
    nll = -1 * log_mean_density.mean()
    return nll + regularization