Skip to content
Snippets Groups Projects
Commit ba899f54 authored by Jörg Martin's avatar Jörg Martin
Browse files

Included the possibility for intervals between std_y updates

parents 243bdcb2 b144db33
No related branches found
No related tags found
No related merge requests found
Showing
with 59 additions and 30 deletions
......@@ -32,7 +32,7 @@ class FNNEIV(nn.Module):
**Note**:
- To change the deming factor afterwards, use the method `change_deming`
- To change fixed_std_x afterwards, use the method `change_fixed_std_x`
- To change std_y use the method `change_std_x`
- To change std_y use the method `change_std_y`
"""
LeakyReLUSlope = 1e-2
def __init__(self, p = 0.2, init_std_y=1.0, precision_prior_zeta=0.0,
......@@ -231,8 +231,9 @@ class FNNEIV(nn.Module):
sigma = torch.mean(sigma, dim=1)
return pred, sigma
def predictive_logdensity(self, x_or_predictions, y, number_of_draws=[100, 5], number_of_parameter_chunks = None, remove_graph=True,
average_batch_dimension=True, scale_labels=None,
def predictive_logdensity(self, x_or_predictions, y,
number_of_draws=[100, 5], number_of_parameter_chunks = None,
remove_graph=True, average_batch_dimension=True, scale_labels=None,
decouple_dimensions=False):
"""
Computes the logarithm of the predictive density evaluated at `y`. If
......@@ -264,7 +265,8 @@ class FNNEIV(nn.Module):
False.
"""
if type(x_or_predictions) is torch.tensor:
out, sigmas = self.predict(x_or_predictions, number_of_draws=number_of_draws,
out, sigmas = self.predict(x_or_predictions,
number_of_draws=number_of_draws,
number_of_parameter_chunks=number_of_parameter_chunks,
remove_graph=remove_graph,
take_average_of_prediction=False)
......@@ -307,7 +309,8 @@ class FNNEIV(nn.Module):
else:
return predictive_log_density_values
def predict_mean_and_unc(self, x, number_of_draws=[100,5], number_of_parameter_chunks = None,
def predict_mean_and_unc(self, x, number_of_draws=[100,5],
number_of_parameter_chunks = None,
remove_graph=True):
"""
Take the mean and standard deviation over `number_of_draws` forward
......@@ -349,6 +352,7 @@ class FNNBer(nn.Module):
:param h: A list specifying the number of neurons in each layer.
:param std_y_requires_grad: Whether `sigma_y` will require_grad and thus
be updated during optimization. Defaults to False.
To change std_y use the method `change_std_y`
"""
LeakyReLUSlope = 1e-2
def __init__(self, p=0.2, init_std_y=1.0, h=[10, 1024,1024,1024,1024, 1],
......
......@@ -16,13 +16,17 @@ def nll_reg_loss(net, x, y, reg):
:param y: A torch.tensor, the output.
:param reg: A non-negative float, the regularization.
"""
out, std_y = net(x)
out, sigmas = net(x)
# Add label dimension to y if missing
if len(y.shape) <= 1:
y = y.view((-1,1))
# squeeze last dimensions into one
y = y.view((*y.shape[:1], -1))
sigmas = sigmas.view((*sigmas.shape[:1], -1))
out = out.view((*out.shape[:1], -1))
assert out.shape == y.shape
neg_log_likelihood = torch.mean(0.5* torch.log(2*pi*std_y**2) \
+ ((out-y)**2)/(2*std_y**2))
neg_log_likelihood = torch.mean(torch.sum(0.5* torch.log(2*pi*sigmas**2) \
+ ((out-y)**2)/(2*sigmas**2), dim=1))
regularization = net.regularizer(x, lamb=reg)
return neg_log_likelihood + regularization
......@@ -45,12 +49,17 @@ def nll_eiv(net, x, y, reg, number_of_draws=5):
regularization = net.regularizer(x, lamb=reg)
# repeat_tensors
x, y = repeat_tensors(x, y, number_of_draws=number_of_draws)
pred, sigma = net(x, repetition=number_of_draws)
out, sigmas = net(x, repetition=number_of_draws)
# split into chunks of size number_of_draws along batch dimension
pred, sigma, y = reshape_to_chunks(pred, sigma, y, number_of_draws=number_of_draws)
assert pred.shape == y.shape
out, sigmas, y = reshape_to_chunks(out, sigmas, y,
number_of_draws=number_of_draws)
# squeeze last dimensions into one
y = y.view((*y.shape[:2], -1))
sigmas = sigmas.view((*sigmas.shape[:2], -1))
out = out.view((*out.shape[:2], -1))
assert out.shape == y.shape
# apply logsumexp to chunks and average the results
nll = -1 * (torch.logsumexp(-1 * sigma.log()
-((y-pred)**2)/(2*sigma**2), dim=1)
nll = -1 * (torch.logsumexp(torch.sum(-1/2 * torch.log(sigmas**2 * 2 * pi)
-((y-out)**2)/(2*sigmas**2), dim=2), dim=1)
- np.log(number_of_draws)).mean()
return nll + regularization
......@@ -69,6 +69,15 @@ class TrainEpoch():
"""
pass
def post_train_update(self, net, epoch=None):
"""
Will be executed after the training is finished
:param net: The current net, a torch.nn.Module
:param epoch: Tue last epochn number, an integer.
"""
pass
def extra_report(self, net, step):
"""
Overwrite for reporting on state of net
......@@ -156,6 +165,8 @@ def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs):
Calls `epoch_map` with `epoch` and the current epoch number
`number_of_epochs` times and stores a list of
its output as a pickled file under `save_file`.
After this (the training) is done, `epoch_map.post_train_update(net,
epoch_number)` is called, if existent.
**Note**: The output of `epoch_map` is supposed to
consist of 4 specific arguments, see below.
:param net: A torch.nn.Module
......@@ -179,6 +190,11 @@ def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs):
test_loss_collection.append(test_loss)
std_x_collection.append(std_x)
std_y_collection.append(std_y)
try:
epoch_map.post_train_update(net, number_of_epochs-1)
print('Training done. Performed post training updating.')
except AttributeError:
print('Training done. No post train updating performed.')
# Saving
state_dict = net.state_dict()
to_save = {
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.1,
"lr_update": 20,
"epoch_offset": 10,
"std_y_update_points": [10,5],
"eiv_prediction_number_of_draws": 100,
"eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 20,
"epoch_offset": 10,
"std_y_update_points": 10,
"eiv_prediction_number_of_draws": 100,
"eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 100,
"epoch_offset": 100,
"std_y_update_points": 100,
"eiv_prediction_number_of_draws": 100,
"eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 20,
"epoch_offset": 19,
"std_y_update_points": 19,
"eiv_prediction_number_of_draws": 100,
"eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 4,
"epoch_offset": 4,
"std_y_update_points": 4,
"eiv_prediction_number_of_draws": 100,
"eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 20,
"epoch_offset": 20,
"std_y_update_points": 20,
"eiv_prediction_number_of_draws": 100,
"eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 10,
"epoch_offset": 15,
"std_y_update_points": 15,
"eiv_prediction_number_of_draws": 100,
"eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 10,
"epoch_offset": 10,
"std_y_update_points": 10,
"eiv_prediction_number_of_draws": 100,
"eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 30,
"epoch_offset": 50,
"std_y_update_points": 50,
"eiv_prediction_number_of_draws": 100,
"eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 200,
"epoch_offset": 20,
"std_y_update_points": [1,500],
"eiv_prediction_number_of_draws": 100,
"eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.1,
"lr_update": 20,
"epoch_offset": 0 ,
"std_y_update_points": [10,5] ,
"noneiv_prediction_number_of_draws": 100,
"noneiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 20,
"epoch_offset": 10,
"std_y_update_points": 10,
"noneiv_prediction_number_of_draws": 100,
"noneiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 100,
"epoch_offset": 100,
"std_y_update_points": 100,
"noneiv_prediction_number_of_draws": 100,
"noneiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 20,
"epoch_offset": 19,
"std_y_update_points": 19,
"noneiv_prediction_number_of_draws": 100,
"noneiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 4,
"epoch_offset": 4,
"std_y_update_points": 4,
"noneiv_prediction_number_of_draws": 100,
"noneiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 20,
"epoch_offset": 20,
"std_y_update_points": 20,
"noneiv_prediction_number_of_draws": 100,
"noneiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
......@@ -9,7 +9,7 @@
"report_point": 5,
"p": 0.2,
"lr_update": 10,
"epoch_offset": 15,
"std_y_update_points": 15,
"noneiv_prediction_number_of_draws": 100,
"noneiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment