From 40c427306bde690a39f22f75ad2db52927eff5c9 Mon Sep 17 00:00:00 2001
From: Joerg Martin <joerg.martin@ptb.de>
Date: Fri, 3 Dec 2021 14:43:20 +0100
Subject: [PATCH] Fixed bug in prediction method of FNNEIV

---
 EIVPackage/EIVArchitectures/Networks.py | 4 ++--
 Experiments/evaluate_tabular.py         | 6 ++++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/EIVPackage/EIVArchitectures/Networks.py b/EIVPackage/EIVArchitectures/Networks.py
index 3a26b02..09d75b8 100644
--- a/EIVPackage/EIVArchitectures/Networks.py
+++ b/EIVPackage/EIVArchitectures/Networks.py
@@ -206,7 +206,7 @@ class FNNEIV(nn.Module):
             pred, sigma = reshape_to_chunks(pred, sigma, 
                     number_of_draws=parameter_sample_size * number_of_draws[1])
             pred_collection.append(pred)
-            sigma_collection.append(pred)
+            sigma_collection.append(sigma)
             remaining_draws -= parameter_sample_size
         pred = torch.cat(pred_collection, dim=1)
         sigma = torch.cat(sigma_collection, dim=1)
@@ -279,7 +279,7 @@ class FNNEIV(nn.Module):
         # average over parameter values
         predictive_log_density_values = \
                 torch.logsumexp(input=exp_arg, dim=1)\
-                    - torch.log(torch.tensor(number_of_draws)) 
+                    - torch.log(torch.prod(torch.tensor(number_of_draws))) 
         if average_batch_dimension:
             return torch.mean(predictive_log_density_values, dim=0)
         else:
diff --git a/Experiments/evaluate_tabular.py b/Experiments/evaluate_tabular.py
index 90b972f..e13f9d5 100644
--- a/Experiments/evaluate_tabular.py
+++ b/Experiments/evaluate_tabular.py
@@ -25,7 +25,7 @@ input_dim = train_data[0][0].numel()
 output_dim = train_data[0][1].numel()
 
 def collect_metrics(x,y, seed=0,
-        noneiv_number_of_draws=100, eiv_number_of_draws=100,
+        noneiv_number_of_draws=100, eiv_number_of_draws=[100,1],
         decouple_dimensions=False, device=torch.device('cuda:1')):
     """
     :param x: A torch.tensor, taken as input
@@ -128,7 +128,8 @@ def collect_metrics(x,y, seed=0,
     # NLL
     training_state = net.training
     net.train()
-    eiv_logdens = net.predictive_logdensity(x, y, number_of_draws=100,
+    eiv_logdens = net.predictive_logdensity(x, y,
+            number_of_draws=eiv_number_of_draws,
             decouple_dimensions=decouple_dimensions,
             scale_labels=\
             train_data.dataset.std_labels.view((-1,)).to(device)\
@@ -162,6 +163,7 @@ for seed in tqdm(seed_list):
 
 
 # TODO: Despite statistics, the fluctuations seem to be large
+# TODO: fix sqrt scaling, missing factor
 print('Non-EiV')
 print(f'RMSE {np.mean(noneiv_rmse_collection):.3f}'\
         f'({np.std(noneiv_rmse_collection)/np.sqrt(num_test_epochs):.3f})')
-- 
GitLab