Skip to content
Snippets Groups Projects
Commit 52526a39 authored by Jörg Martin
Browse files

Summed over dimension in loss function

parent 6e255c64
Branches
Tags
No related merge requests found
...@@ -231,8 +231,9 @@ class FNNEIV(nn.Module): ...@@ -231,8 +231,9 @@ class FNNEIV(nn.Module):
sigma = torch.mean(sigma, dim=1) sigma = torch.mean(sigma, dim=1)
return pred, sigma return pred, sigma
def predictive_logdensity(self, x_or_predictions, y, number_of_draws=[100, 5], number_of_parameter_chunks = None, remove_graph=True, def predictive_logdensity(self, x_or_predictions, y,
average_batch_dimension=True, scale_labels=None, number_of_draws=[100, 5], number_of_parameter_chunks = None,
remove_graph=True, average_batch_dimension=True, scale_labels=None,
decouple_dimensions=False): decouple_dimensions=False):
""" """
Computes the logarithm of the predictive density evaluated at `y`. If Computes the logarithm of the predictive density evaluated at `y`. If
...@@ -264,7 +265,8 @@ class FNNEIV(nn.Module): ...@@ -264,7 +265,8 @@ class FNNEIV(nn.Module):
False. False.
""" """
if type(x_or_predictions) is torch.tensor: if type(x_or_predictions) is torch.tensor:
out, sigmas = self.predict(x_or_predictions, number_of_draws=number_of_draws, out, sigmas = self.predict(x_or_predictions,
number_of_draws=number_of_draws,
number_of_parameter_chunks=number_of_parameter_chunks, number_of_parameter_chunks=number_of_parameter_chunks,
remove_graph=remove_graph, remove_graph=remove_graph,
take_average_of_prediction=False) take_average_of_prediction=False)
...@@ -307,7 +309,8 @@ class FNNEIV(nn.Module): ...@@ -307,7 +309,8 @@ class FNNEIV(nn.Module):
else: else:
return predictive_log_density_values return predictive_log_density_values
def predict_mean_and_unc(self, x, number_of_draws=[100,5], number_of_parameter_chunks = None, def predict_mean_and_unc(self, x, number_of_draws=[100,5],
number_of_parameter_chunks = None,
remove_graph=True): remove_graph=True):
""" """
Take the mean and standard deviation over `number_of_draws` forward Take the mean and standard deviation over `number_of_draws` forward
......
...@@ -16,13 +16,17 @@ def nll_reg_loss(net, x, y, reg): ...@@ -16,13 +16,17 @@ def nll_reg_loss(net, x, y, reg):
:param y: A torch.tensor, the output. :param y: A torch.tensor, the output.
:param reg: A non-negative float, the regularization. :param reg: A non-negative float, the regularization.
""" """
out, std_y = net(x) out, sigmas = net(x)
# Add label dimension to y if missing # Add label dimension to y if missing
if len(y.shape) <= 1: if len(y.shape) <= 1:
y = y.view((-1,1)) y = y.view((-1,1))
# squeeze last dimensions into one
y = y.view((*y.shape[:1], -1))
sigmas = sigmas.view((*sigmas.shape[:1], -1))
out = out.view((*out.shape[:1], -1))
assert out.shape == y.shape assert out.shape == y.shape
neg_log_likelihood = torch.mean(0.5* torch.log(2*pi*std_y**2) \ neg_log_likelihood = torch.mean(torch.sum(0.5* torch.log(2*pi*sigmas**2) \
+ ((out-y)**2)/(2*std_y**2)) + ((out-y)**2)/(2*sigmas**2), dim=1))
regularization = net.regularizer(x, lamb=reg) regularization = net.regularizer(x, lamb=reg)
return neg_log_likelihood + regularization return neg_log_likelihood + regularization
...@@ -45,12 +49,17 @@ def nll_eiv(net, x, y, reg, number_of_draws=5): ...@@ -45,12 +49,17 @@ def nll_eiv(net, x, y, reg, number_of_draws=5):
regularization = net.regularizer(x, lamb=reg) regularization = net.regularizer(x, lamb=reg)
# repeat_tensors # repeat_tensors
x, y = repeat_tensors(x, y, number_of_draws=number_of_draws) x, y = repeat_tensors(x, y, number_of_draws=number_of_draws)
pred, sigma = net(x, repetition=number_of_draws) out, sigmas = net(x, repetition=number_of_draws)
# split into chunks of size number_of_draws along batch dimension # split into chunks of size number_of_draws along batch dimension
pred, sigma, y = reshape_to_chunks(pred, sigma, y, number_of_draws=number_of_draws) out, sigmas, y = reshape_to_chunks(out, sigmas, y,
assert pred.shape == y.shape number_of_draws=number_of_draws)
# squeeze last dimensions into one
y = y.view((*y.shape[:2], -1))
sigmas = sigmas.view((*sigmas.shape[:2], -1))
out = out.view((*out.shape[:2], -1))
assert out.shape == y.shape
# apply logsumexp to chunks and average the results # apply logsumexp to chunks and average the results
nll = -1 * (torch.logsumexp(-1 * sigma.log() nll = -1 * (torch.logsumexp(torch.sum(-1/2 * torch.log(sigmas**2 * 2 * pi)
-((y-pred)**2)/(2*sigma**2), dim=1) -((y-out)**2)/(2*sigmas**2), dim=2), dim=1)
- np.log(number_of_draws)).mean() - np.log(number_of_draws)).mean()
return nll + regularization return nll + regularization
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
"report_point": 5, "report_point": 5,
"p": 0.2, "p": 0.2,
"lr_update": 200, "lr_update": 200,
"std_y_update_points": [20,20], "std_y_update_points": [1,500],
"eiv_prediction_number_of_draws": 100, "eiv_prediction_number_of_draws": 100,
"eiv_prediction_number_of_batches": 10, "eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5], "init_std_y_list": [0.5],
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
"report_point": 5, "report_point": 5,
"p": 0.2, "p": 0.2,
"lr_update": 200, "lr_update": 200,
"std_y_update_points": [20,20], "std_y_update_points": [1,500],
"noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_draws": 100,
"noneiv_prediction_number_of_batches": 10, "noneiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5], "init_std_y_list": [0.5],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment