From 40c427306bde690a39f22f75ad2db52927eff5c9 Mon Sep 17 00:00:00 2001 From: Joerg Martin <joerg.martin@ptb.de> Date: Fri, 3 Dec 2021 14:43:20 +0100 Subject: [PATCH] Fixed bug in prediction method of FNNEIV --- EIVPackage/EIVArchitectures/Networks.py | 4 ++-- Experiments/evaluate_tabular.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/EIVPackage/EIVArchitectures/Networks.py b/EIVPackage/EIVArchitectures/Networks.py index 3a26b02..09d75b8 100644 --- a/EIVPackage/EIVArchitectures/Networks.py +++ b/EIVPackage/EIVArchitectures/Networks.py @@ -206,7 +206,7 @@ class FNNEIV(nn.Module): pred, sigma = reshape_to_chunks(pred, sigma, number_of_draws=parameter_sample_size * number_of_draws[1]) pred_collection.append(pred) - sigma_collection.append(pred) + sigma_collection.append(sigma) remaining_draws -= parameter_sample_size pred = torch.cat(pred_collection, dim=1) sigma = torch.cat(sigma_collection, dim=1) @@ -279,7 +279,7 @@ class FNNEIV(nn.Module): # average over parameter values predictive_log_density_values = \ torch.logsumexp(input=exp_arg, dim=1)\ - - torch.log(torch.tensor(number_of_draws)) + - torch.log(torch.prod(torch.tensor(number_of_draws))) if average_batch_dimension: return torch.mean(predictive_log_density_values, dim=0) else: diff --git a/Experiments/evaluate_tabular.py b/Experiments/evaluate_tabular.py index 90b972f..e13f9d5 100644 --- a/Experiments/evaluate_tabular.py +++ b/Experiments/evaluate_tabular.py @@ -25,7 +25,7 @@ input_dim = train_data[0][0].numel() output_dim = train_data[0][1].numel() def collect_metrics(x,y, seed=0, - noneiv_number_of_draws=100, eiv_number_of_draws=100, + noneiv_number_of_draws=100, eiv_number_of_draws=[100,1], decouple_dimensions=False, device=torch.device('cuda:1')): """ :param x: A torch.tensor, taken as input @@ -128,7 +128,8 @@ def collect_metrics(x,y, seed=0, # NLL training_state = net.training net.train() - eiv_logdens = net.predictive_logdensity(x, y, number_of_draws=100, + eiv_logdens = net.predictive_logdensity(x, y, + number_of_draws=eiv_number_of_draws, decouple_dimensions=decouple_dimensions, scale_labels=\ train_data.dataset.std_labels.view((-1,)).to(device)\ @@ -162,6 +163,7 @@ for seed in tqdm(seed_list): # TODO: Despite statistics, the fluctuations seem to be large +# TODO: fix sqrt scaling, missing factor print('Non-EiV') print(f'RMSE {np.mean(noneiv_rmse_collection):.3f}'\ f'({np.std(noneiv_rmse_collection)/np.sqrt(num_test_epochs):.3f})') -- GitLab