From 52526a395d811a6468578c6093c77362d235b981 Mon Sep 17 00:00:00 2001
From: Joerg Martin <joerg.martin@ptb.de>
Date: Wed, 15 Dec 2021 09:37:05 +0100
Subject: [PATCH] Summed over dimension in loss function

---
 EIVPackage/EIVArchitectures/Networks.py       | 11 +++++---
 .../EIVTrainingRoutines/loss_functions.py     | 25 +++++++++++++------
 Experiments/configurations/eiv_yacht.json     |  2 +-
 Experiments/configurations/noneiv_yacht.json  |  2 +-
 4 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/EIVPackage/EIVArchitectures/Networks.py b/EIVPackage/EIVArchitectures/Networks.py
index 22f48bb..8928b3f 100644
--- a/EIVPackage/EIVArchitectures/Networks.py
+++ b/EIVPackage/EIVArchitectures/Networks.py
@@ -231,8 +231,9 @@ class FNNEIV(nn.Module):
             sigma = torch.mean(sigma, dim=1)
         return pred, sigma
 
-    def predictive_logdensity(self, x_or_predictions, y, number_of_draws=[100, 5], number_of_parameter_chunks = None, remove_graph=True,
-            average_batch_dimension=True, scale_labels=None,
+    def predictive_logdensity(self, x_or_predictions, y,
+            number_of_draws=[100, 5], number_of_parameter_chunks = None,
+            remove_graph=True, average_batch_dimension=True, scale_labels=None,
             decouple_dimensions=False):
         """
         Computes the logarithm of the predictive density evaluated at `y`. If
@@ -264,7 +265,8 @@ class FNNEIV(nn.Module):
         False.
         """
         if type(x_or_predictions) is torch.tensor:
-            out, sigmas = self.predict(x_or_predictions, number_of_draws=number_of_draws,
+            out, sigmas = self.predict(x_or_predictions,
+                    number_of_draws=number_of_draws,
                 number_of_parameter_chunks=number_of_parameter_chunks,
                 remove_graph=remove_graph,
                 take_average_of_prediction=False)
@@ -307,7 +309,8 @@ class FNNEIV(nn.Module):
         else:
             return predictive_log_density_values
 
-    def predict_mean_and_unc(self, x, number_of_draws=[100,5], number_of_parameter_chunks = None,
+    def predict_mean_and_unc(self, x, number_of_draws=[100,5],
+            number_of_parameter_chunks = None,
             remove_graph=True):
         """
         Take the mean and standard deviation over `number_of_draws` forward
diff --git a/EIVPackage/EIVTrainingRoutines/loss_functions.py b/EIVPackage/EIVTrainingRoutines/loss_functions.py
index 88e924d..e1c552d 100644
--- a/EIVPackage/EIVTrainingRoutines/loss_functions.py
+++ b/EIVPackage/EIVTrainingRoutines/loss_functions.py
@@ -16,13 +16,17 @@ def nll_reg_loss(net, x, y, reg):
     :param y: A torch.tensor, the output.
     :param reg: A non-negative float, the regularization.
     """
-    out, std_y = net(x)
+    out, sigmas = net(x)
     # Add label dimension to y if missing
     if len(y.shape) <= 1:
         y = y.view((-1,1))
+    # flatten all dimensions after the first into a single one
+    y = y.view((*y.shape[:1], -1))
+    sigmas = sigmas.view((*sigmas.shape[:1], -1))
+    out = out.view((*out.shape[:1], -1))
     assert out.shape == y.shape
-    neg_log_likelihood = torch.mean(0.5* torch.log(2*pi*std_y**2) \
-            + ((out-y)**2)/(2*std_y**2)) 
+    neg_log_likelihood = torch.mean(torch.sum(0.5* torch.log(2*pi*sigmas**2) \
+            + ((out-y)**2)/(2*sigmas**2), dim=1)) 
     regularization = net.regularizer(x, lamb=reg)
     return neg_log_likelihood + regularization
 
@@ -45,12 +49,17 @@ def nll_eiv(net, x, y, reg, number_of_draws=5):
     regularization = net.regularizer(x, lamb=reg)
     # repeat_tensors
     x, y = repeat_tensors(x, y, number_of_draws=number_of_draws)
-    pred, sigma = net(x, repetition=number_of_draws) 
+    out, sigmas = net(x, repetition=number_of_draws) 
     # split into chunks of size number_of_draws along batch dimension
-    pred, sigma, y = reshape_to_chunks(pred, sigma, y, number_of_draws=number_of_draws)
-    assert pred.shape == y.shape
+    out, sigmas, y = reshape_to_chunks(out, sigmas, y,
+            number_of_draws=number_of_draws)
+    # flatten all dimensions after the first two into a single one
+    y = y.view((*y.shape[:2], -1))
+    sigmas = sigmas.view((*sigmas.shape[:2], -1))
+    out = out.view((*out.shape[:2], -1))
+    assert out.shape == y.shape
     # apply logsumexp to chunks and average the results
-    nll = -1 * (torch.logsumexp(-1 * sigma.log()
-        -((y-pred)**2)/(2*sigma**2), dim=1)
+    nll = -1 * (torch.logsumexp(torch.sum(-1/2 * torch.log(sigmas**2 * 2 * pi)
+        -((y-out)**2)/(2*sigmas**2), dim=2), dim=1)
         - np.log(number_of_draws)).mean()
     return nll + regularization
diff --git a/Experiments/configurations/eiv_yacht.json b/Experiments/configurations/eiv_yacht.json
index f95841b..7617289 100644
--- a/Experiments/configurations/eiv_yacht.json
+++ b/Experiments/configurations/eiv_yacht.json
@@ -9,7 +9,7 @@
 	"report_point": 5,
 	"p": 0.2,
 	"lr_update": 200,
-	"std_y_update_points": [20,20],
+	"std_y_update_points": [1,500],
 	"eiv_prediction_number_of_draws": 100,
 	"eiv_prediction_number_of_batches": 10,
 	"init_std_y_list": [0.5],
diff --git a/Experiments/configurations/noneiv_yacht.json b/Experiments/configurations/noneiv_yacht.json
index ae373b4..81a0d14 100644
--- a/Experiments/configurations/noneiv_yacht.json
+++ b/Experiments/configurations/noneiv_yacht.json
@@ -9,7 +9,7 @@
 	"report_point": 5,
 	"p": 0.2,
 	"lr_update": 200,
-	"std_y_update_points": [20,20],
+	"std_y_update_points": [1,500],
 	"noneiv_prediction_number_of_draws": 100,
 	"noneiv_prediction_number_of_batches": 10,
 	"init_std_y_list": [0.5],
-- 
GitLab