diff --git a/EIVPackage/EIVArchitectures/Networks.py b/EIVPackage/EIVArchitectures/Networks.py
index 22f48bb5b5b6f2a800c1b2066c2a951c58816963..8928b3fde92517f7397eb0542f8189ea9bf132bd 100644
--- a/EIVPackage/EIVArchitectures/Networks.py
+++ b/EIVPackage/EIVArchitectures/Networks.py
@@ -231,8 +231,9 @@ class FNNEIV(nn.Module):
         sigma = torch.mean(sigma, dim=1)
         return pred, sigma

-    def predictive_logdensity(self, x_or_predictions, y, number_of_draws=[100, 5], number_of_parameter_chunks = None, remove_graph=True,
-            average_batch_dimension=True, scale_labels=None,
+    def predictive_logdensity(self, x_or_predictions, y,
+            number_of_draws=[100, 5], number_of_parameter_chunks = None,
+            remove_graph=True, average_batch_dimension=True, scale_labels=None,
             decouple_dimensions=False):
         """
         Computes the logarithm of the predictive density evaluated at `y`. If
@@ -264,7 +265,8 @@ class FNNEIV(nn.Module):
         False.
         """
         if type(x_or_predictions) is torch.tensor:
-            out, sigmas = self.predict(x_or_predictions, number_of_draws=number_of_draws,
+            out, sigmas = self.predict(x_or_predictions,
+                    number_of_draws=number_of_draws,
                 number_of_parameter_chunks=number_of_parameter_chunks,
                 remove_graph=remove_graph,
                 take_average_of_prediction=False)
@@ -307,7 +309,8 @@ class FNNEIV(nn.Module):
         else:
             return predictive_log_density_values

-    def predict_mean_and_unc(self, x, number_of_draws=[100,5], number_of_parameter_chunks = None,
+    def predict_mean_and_unc(self, x, number_of_draws=[100,5],
+            number_of_parameter_chunks = None,
             remove_graph=True):
         """
         Take the mean and standard deviation over `number_of_draws` forward
diff --git a/EIVPackage/EIVTrainingRoutines/loss_functions.py b/EIVPackage/EIVTrainingRoutines/loss_functions.py
index 88e924df76c1b2bd00e4c4abcdf18089f699539e..e1c552d3d621dd29c91b555c9c50a1089c836a12 100644
--- a/EIVPackage/EIVTrainingRoutines/loss_functions.py
+++ b/EIVPackage/EIVTrainingRoutines/loss_functions.py
@@ -16,13 +16,17 @@ def nll_reg_loss(net, x, y, reg):
     :param y: A torch.tensor, the output.
     :param reg: A non-negative float, the regularization.
     """
-    out, std_y = net(x)
+    out, sigmas = net(x)
     # Add label dimension to y if missing
     if len(y.shape) <= 1:
         y = y.view((-1,1))
+    # squeeze last dimensions into one
+    y = y.view((*y.shape[:1], -1))
+    sigmas = sigmas.view((*sigmas.shape[:1], -1))
+    out = out.view((*out.shape[:1], -1))
     assert out.shape == y.shape
-    neg_log_likelihood = torch.mean(0.5* torch.log(2*pi*std_y**2) \
-            + ((out-y)**2)/(2*std_y**2))
+    neg_log_likelihood = torch.mean(torch.sum(0.5* torch.log(2*pi*sigmas**2) \
+            + ((out-y)**2)/(2*sigmas**2), dim=1))
     regularization = net.regularizer(x, lamb=reg)
     return neg_log_likelihood + regularization

@@ -45,12 +49,17 @@ def nll_eiv(net, x, y, reg, number_of_draws=5):
     regularization = net.regularizer(x, lamb=reg)
     # repeat_tensors
     x, y = repeat_tensors(x, y, number_of_draws=number_of_draws)
-    pred, sigma = net(x, repetition=number_of_draws)
+    out, sigmas = net(x, repetition=number_of_draws)
     # split into chunks of size number_of_draws along batch dimension
-    pred, sigma, y = reshape_to_chunks(pred, sigma, y, number_of_draws=number_of_draws)
-    assert pred.shape == y.shape
+    out, sigmas, y = reshape_to_chunks(out, sigmas, y,
+            number_of_draws=number_of_draws)
+    # squeeze last dimensions into one
+    y = y.view((*y.shape[:2], -1))
+    sigmas = sigmas.view((*sigmas.shape[:2], -1))
+    out = out.view((*out.shape[:2], -1))
+    assert out.shape == y.shape
     # apply logsumexp to chunks and average the results
-    nll = -1 * (torch.logsumexp(-1 * sigma.log()
-        -((y-pred)**2)/(2*sigma**2), dim=1)
+    nll = -1 * (torch.logsumexp(torch.sum(-1/2 * torch.log(sigmas**2 * 2 * pi)
+        -((y-out)**2)/(2*sigmas**2), dim=2), dim=1)
         - np.log(number_of_draws)).mean()
     return nll + regularization
diff --git a/Experiments/configurations/eiv_yacht.json b/Experiments/configurations/eiv_yacht.json
index f95841b1a8074294a5d1ba0b921c8deee0dd5d4f..7617289d9e49bb9ceddd9962d6d865fee3413ae7 100644
--- a/Experiments/configurations/eiv_yacht.json
+++ b/Experiments/configurations/eiv_yacht.json
@@ -9,7 +9,7 @@
     "report_point": 5,
     "p": 0.2,
     "lr_update": 200,
-    "std_y_update_points": [20,20],
+    "std_y_update_points": [1,500],
     "eiv_prediction_number_of_draws": 100,
     "eiv_prediction_number_of_batches": 10,
     "init_std_y_list": [0.5],
diff --git a/Experiments/configurations/noneiv_yacht.json b/Experiments/configurations/noneiv_yacht.json
index ae373b4a120ee252f0737adf940c9b352e48dcc6..81a0d1459d5faf474a922d7916336c57c255c07f 100644
--- a/Experiments/configurations/noneiv_yacht.json
+++ b/Experiments/configurations/noneiv_yacht.json
@@ -9,7 +9,7 @@
     "report_point": 5,
     "p": 0.2,
     "lr_update": 200,
-    "std_y_update_points": [20,20],
+    "std_y_update_points": [1,500],
     "noneiv_prediction_number_of_draws": 100,
     "noneiv_prediction_number_of_batches": 10,
     "init_std_y_list": [0.5],