From 52526a395d811a6468578c6093c77362d235b981 Mon Sep 17 00:00:00 2001
From: Joerg Martin <joerg.martin@ptb.de>
Date: Wed, 15 Dec 2021 09:37:05 +0100
Subject: [PATCH] Summed over dimension in loss function

---
 EIVPackage/EIVArchitectures/Networks.py       | 11 +++++---
 .../EIVTrainingRoutines/loss_functions.py     | 25 +++++++++++++------
 Experiments/configurations/eiv_yacht.json     |  2 +-
 Experiments/configurations/noneiv_yacht.json  |  2 +-
 4 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/EIVPackage/EIVArchitectures/Networks.py b/EIVPackage/EIVArchitectures/Networks.py
index 22f48bb..8928b3f 100644
--- a/EIVPackage/EIVArchitectures/Networks.py
+++ b/EIVPackage/EIVArchitectures/Networks.py
@@ -231,8 +231,9 @@ class FNNEIV(nn.Module):
             sigma = torch.mean(sigma, dim=1)
         return pred, sigma
 
-    def predictive_logdensity(self, x_or_predictions, y, number_of_draws=[100, 5], number_of_parameter_chunks = None, remove_graph=True,
-            average_batch_dimension=True, scale_labels=None,
+    def predictive_logdensity(self, x_or_predictions, y,
+            number_of_draws=[100, 5], number_of_parameter_chunks = None,
+            remove_graph=True, average_batch_dimension=True, scale_labels=None,
             decouple_dimensions=False):
         """
         Computes the logarithm of the predictive density evaluated at `y`. If
@@ -264,7 +265,8 @@ class FNNEIV(nn.Module):
         False.
         """
         if type(x_or_predictions) is torch.tensor:
-            out, sigmas = self.predict(x_or_predictions, number_of_draws=number_of_draws,
+            out, sigmas = self.predict(x_or_predictions,
+                number_of_draws=number_of_draws,
                 number_of_parameter_chunks=number_of_parameter_chunks,
                 remove_graph=remove_graph,
                 take_average_of_prediction=False)
@@ -307,7 +309,8 @@ class FNNEIV(nn.Module):
         else:
             return predictive_log_density_values
 
-    def predict_mean_and_unc(self, x, number_of_draws=[100,5], number_of_parameter_chunks = None,
+    def predict_mean_and_unc(self, x, number_of_draws=[100,5],
+            number_of_parameter_chunks = None,
             remove_graph=True):
         """
         Take the mean and standard deviation over `number_of_draws` forward
diff --git a/EIVPackage/EIVTrainingRoutines/loss_functions.py b/EIVPackage/EIVTrainingRoutines/loss_functions.py
index 88e924d..e1c552d 100644
--- a/EIVPackage/EIVTrainingRoutines/loss_functions.py
+++ b/EIVPackage/EIVTrainingRoutines/loss_functions.py
@@ -16,13 +16,17 @@ def nll_reg_loss(net, x, y, reg):
     :param y: A torch.tensor, the output.
     :param reg: A non-negative float, the regularization.
     """
-    out, std_y = net(x)
+    out, sigmas = net(x)
     # Add label dimension to y if missing
     if len(y.shape) <= 1:
         y = y.view((-1,1))
+    # squeeze last dimensions into one
+    y = y.view((*y.shape[:1], -1))
+    sigmas = sigmas.view((*sigmas.shape[:1], -1))
+    out = out.view((*out.shape[:1], -1))
     assert out.shape == y.shape
-    neg_log_likelihood = torch.mean(0.5* torch.log(2*pi*std_y**2) \
-            + ((out-y)**2)/(2*std_y**2))
+    neg_log_likelihood = torch.mean(torch.sum(0.5* torch.log(2*pi*sigmas**2) \
+            + ((out-y)**2)/(2*sigmas**2), dim=1))
     regularization = net.regularizer(x, lamb=reg)
     return neg_log_likelihood + regularization
 
@@ -45,12 +49,17 @@ def nll_eiv(net, x, y, reg, number_of_draws=5):
     regularization = net.regularizer(x, lamb=reg)
     # repeat_tensors
     x, y = repeat_tensors(x, y, number_of_draws=number_of_draws)
-    pred, sigma = net(x, repetition=number_of_draws)
+    out, sigmas = net(x, repetition=number_of_draws)
     # split into chunks of size number_of_draws along batch dimension
-    pred, sigma, y = reshape_to_chunks(pred, sigma, y, number_of_draws=number_of_draws)
-    assert pred.shape == y.shape
+    out, sigmas, y = reshape_to_chunks(out, sigmas, y,
+            number_of_draws=number_of_draws)
+    # squeeze last dimensions into one
+    y = y.view((*y.shape[:2], -1))
+    sigmas = sigmas.view((*sigmas.shape[:2], -1))
+    out = out.view((*out.shape[:2], -1))
+    assert out.shape == y.shape
     # apply logsumexp to chunks and average the results
-    nll = -1 * (torch.logsumexp(-1 * sigma.log()
-        -((y-pred)**2)/(2*sigma**2), dim=1)
+    nll = -1 * (torch.logsumexp(torch.sum(-1/2 * torch.log(sigmas**2 * 2 * pi)
+        -((y-out)**2)/(2*sigmas**2), dim=2), dim=1)
             - np.log(number_of_draws)).mean()
     return nll + regularization
diff --git a/Experiments/configurations/eiv_yacht.json b/Experiments/configurations/eiv_yacht.json
index f95841b..7617289 100644
--- a/Experiments/configurations/eiv_yacht.json
+++ b/Experiments/configurations/eiv_yacht.json
@@ -9,7 +9,7 @@
     "report_point": 5,
     "p": 0.2,
     "lr_update": 200,
-    "std_y_update_points": [20,20],
+    "std_y_update_points": [1,500],
     "eiv_prediction_number_of_draws": 100,
     "eiv_prediction_number_of_batches": 10,
     "init_std_y_list": [0.5],
diff --git a/Experiments/configurations/noneiv_yacht.json b/Experiments/configurations/noneiv_yacht.json
index ae373b4..81a0d14 100644
--- a/Experiments/configurations/noneiv_yacht.json
+++ b/Experiments/configurations/noneiv_yacht.json
@@ -9,7 +9,7 @@
     "report_point": 5,
    "p": 0.2,
     "lr_update": 200,
-    "std_y_update_points": [20,20],
+    "std_y_update_points": [1,500],
     "noneiv_prediction_number_of_draws": 100,
     "noneiv_prediction_number_of_batches": 10,
     "init_std_y_list": [0.5],
--
GitLab