Commit 6e255c64 authored by Jörg Martin

Post train updating added

parent fddd1aa3
@@ -32,7 +32,7 @@ class FNNEIV(nn.Module):
     **Note**:
     - To change the deming factor afterwards, use the method `change_deming`
     - To change fixed_std_x afterwards, use the method `change_fixed_std_x`
-    - To change std_y use the method `change_std_x`
+    - To change std_y use the method `change_std_y`
     """
     LeakyReLUSlope = 1e-2
     def __init__(self, p = 0.2, init_std_y=1.0, precision_prior_zeta=0.0,
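For illustration, a minimal usage sketch of the setter interface described in the note above; the import path `EIVArchitectures.Networks` and the constructor arguments are assumptions, not taken from this commit:

import torch
from EIVArchitectures.Networks import FNNEIV  # assumed import path

net = FNNEIV(p=0.2, init_std_y=1.0)
# Noise-related quantities are changed through the dedicated setters
# rather than by touching attributes directly:
net.change_std_y(torch.tensor(0.05))  # e.g. an RMSE-based estimate of std_y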
@@ -349,6 +349,7 @@ class FNNBer(nn.Module):
     :param h: A list specifying the number of neurons in each layer.
     :param std_y_requires_grad: Whether `sigma_y` will require_grad and thus
     be updated during optimization. Defaults to False.
+    To change std_y use the method `change_std_y`
     """
     LeakyReLUSlope = 1e-2
     def __init__(self, p=0.2, init_std_y=1.0, h=[10, 1024,1024,1024,1024, 1],
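A similarly hedged sketch for FNNBer, assuming the class lives next to FNNEIV; the import path is an assumption and only details stated in the docstring above are used:

import torch
from EIVArchitectures.Networks import FNNBer  # assumed import path

# h lists the layer widths from input (10) to output (1); with the
# default std_y_requires_grad=False, sigma_y is not optimized during
# training and is only ever modified explicitly via change_std_y.
net = FNNBer(p=0.2, init_std_y=1.0,
        h=[10, 1024, 1024, 1024, 1024, 1])
net.change_std_y(torch.tensor(0.1))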
......
@@ -69,6 +69,15 @@ class TrainEpoch():
"""
pass
def post_train_update(self, net, epoch=None):
"""
Will be executed after the training is finished
:param net: The current net, a torch.nn.Module
:param epoch: Tue last epochn number, an integer.
"""
pass
def extra_report(self, net, step):
"""
Overwrite for reporting on state of net
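The new hook is a no-op by default; a subclass overrides it when something should happen exactly once after the final epoch. A minimal sketch (the base-class constructor is not shown in this diff, so it is left out here):

class PrintingTrainEpoch(TrainEpoch):
    def post_train_update(self, net, epoch=None):
        # called once by train_and_store after the last epoch
        print(f'Finished training after epoch {epoch}; '
              f'net in training mode: {net.training}')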
@@ -156,6 +165,8 @@ def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs):
     Calls `epoch_map` with `epoch` and the current epoch number
     `number_of_epochs` times and stores a list of
     its output as a pickled file under `save_file`.
+    After training is done, `epoch_map.post_train_update(net,
+    epoch_number)` is called, if it exists.
     **Note**: The output of `epoch_map` is supposed to
     consist of 4 specific arguments, see below.
     :param net: A torch.nn.Module
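A hedged call sketch matching the signature in the hunk header above; the epoch_map instance and the save path are placeholders:

epoch_map = PrintingTrainEpoch(...)  # placeholder constructor arguments
train_and_store(net=net,
        epoch_map=epoch_map,
        number_of_epochs=100,
        save_file='example_run.pkl')  # hypothetical file name
# After the final epoch, train_and_store calls
# epoch_map.post_train_update(net, 99) if the method exists.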
@@ -179,6 +190,11 @@ def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs):
         test_loss_collection.append(test_loss)
         std_x_collection.append(std_x)
         std_y_collection.append(std_y)
+    try:
+        epoch_map.post_train_update(net, number_of_epochs-1)
+        print('Training done. Performed post training updating.')
+    except AttributeError:
+        print('Training done. No post train updating performed.')
     # Saving
     state_dict = net.state_dict()
     to_save = {
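The try/except above relies on AttributeError to detect a missing post_train_update. An equivalent, slightly stricter guard can be written with getattr; this is only an illustration of the pattern, not code from the repository:

post_train_update = getattr(epoch_map, 'post_train_update', None)
if callable(post_train_update):
    post_train_update(net, number_of_epochs - 1)
    print('Training done. Performed post training updating.')
else:
    print('Training done. No post train updating performed.')
# Unlike the except-based variant, this does not silently swallow an
# AttributeError raised inside a faulty post_train_update itself.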
......
@@ -85,48 +85,71 @@ class UpdatedTrainEpoch(train_and_store.TrainEpoch):
         self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                 self.optimizer, lr_update, gamma)
+    def update_std_y(self, net):
+        """
+        Update the std_y of `net` via the RMSE of the prediction.
+        """
+        net_train_state = net.training
+        net_noise_state = net.noise_is_on
+        net.train()
+        net.noise_on()
+        pred_collection = []
+        y_collection = []
+        for i, (x,y) in enumerate(self.train_dataloader):
+            if i >= eiv_prediction_number_of_batches:
+                break
+            if len(y.shape) <= 1:
+                y = y.view((-1,1))
+            x,y = x.to(device), y.to(device)
+            pred, _ = net.predict(x,
+                    number_of_draws=eiv_prediction_number_of_draws,
+                    remove_graph = True,
+                    take_average_of_prediction=True)
+            pred_collection.append(pred)
+            y_collection.append(y)
+        pred_collection = torch.cat(pred_collection, dim=0)
+        y_collection = torch.cat(y_collection, dim=0)
+        assert pred_collection.shape == y_collection.shape
+        rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
+        net.change_std_y(rmse)
+        if not net_train_state:
+            net.eval()
+        if not net_noise_state:
+            net.noise_off()
+    def check_if_update_std_y(self, epoch):
+        """
+        Check whether to update std_y according to `epoch` and
+        `std_y_update_points`. If the latter is an integer, an update will
+        be made for every epoch greater than or equal to this number
+        (i.e. `True` will be returned). If it is a list, only epochs
+        greater than or equal to `std_y_update_points[0]` that are
+        divisible by `std_y_update_points[1]` result in `True`.
+        """
+        if type(std_y_update_points) is int:
+            return epoch >= std_y_update_points
+        else:
+            assert type(std_y_update_points) is list
+            return epoch >= std_y_update_points[0]\
+                    and epoch % std_y_update_points[1] == 0
     def post_epoch_update(self, net, epoch):
         """
         Overwrites the corresponding method
         """
-        def update_std_y(epoch_number):
-            """
-            Check whether to update std_y according to `epoch_number` and
-            `std_y_update_points`. If the later is an integer, after all epochs
-            greater than this number an update will be made (i.e. `True` will
-            be returned). If it is a list, only `epoch_number` greater than
-            `std_y_update_points[0]` that divide `std_y_update_points[1]` will
-            result in a True.
-            """
-            if type(std_y_update_points) is int:
-                return epoch >= std_y_update_points
-            else:
-                assert type(std_y_update_points) is list
-                return epoch_number >= std_y_update_points[0]\
-                        and epoch_number % std_y_update_points[1] == 0
-        if update_std_y(epoch):
-            pred_collection = []
-            y_collection = []
-            for i, (x,y) in enumerate(self.train_dataloader):
-                if i >= eiv_prediction_number_of_batches:
-                    break
-                if len(y.shape) <= 1:
-                    y = y.view((-1,1))
-                x,y = x.to(device), y.to(device)
-                pred, _ = net.predict(x,
-                        number_of_draws=eiv_prediction_number_of_draws,
-                        remove_graph = True,
-                        take_average_of_prediction=True)
-                pred_collection.append(pred)
-                y_collection.append(y)
-            pred_collection = torch.cat(pred_collection, dim=0)
-            y_collection = torch.cat(y_collection, dim=0)
-            assert pred_collection.shape == y_collection.shape
-            rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
-            net.change_std_y(rmse)
+        if self.check_if_update_std_y(epoch):
+            self.update_std_y(net)
         self.lr_scheduler.step()
+    def post_train_update(self, net, epoch):
+        """
+        Overwrites the corresponding method. If std_y of `net` was not
+        updated in the last training epoch, update it once training is
+        finished. `epoch` should be the number of the last training epoch.
+        """
+        if not self.check_if_update_std_y(epoch):
+            self.update_std_y(net)
     def extra_report(self, net, i):
         """
         Overwrites the corresponding method
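To make the schedule encoded in check_if_update_std_y concrete, here is a standalone restatement with `std_y_update_points` passed explicitly instead of being read from the surrounding module; the helper name is hypothetical:

def should_update_std_y(epoch, std_y_update_points):
    # integer: update in every epoch from this number onwards
    if type(std_y_update_points) is int:
        return epoch >= std_y_update_points
    # list: update from std_y_update_points[0] onwards, but only in
    # epochs divisible by std_y_update_points[1]
    return (epoch >= std_y_update_points[0]
            and epoch % std_y_update_points[1] == 0)

assert should_update_std_y(12, 10)            # 12 >= 10
assert not should_update_std_y(9, 10)         # 9 < 10
assert should_update_std_y(15, [10, 5])       # 15 >= 10 and 15 % 5 == 0
assert not should_update_std_y(12, [10, 5])   # 12 % 5 != 0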
......
@@ -84,48 +84,67 @@ class UpdatedTrainEpoch(train_and_store.TrainEpoch):
         self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                 self.optimizer, lr_update, gamma)
+    def update_std_y(self, net):
+        """
+        Update the std_y of `net` via the RMSE of the prediction.
+        """
+        net_train_state = net.training
+        net.train()
+        pred_collection = []
+        y_collection = []
+        for i, (x,y) in enumerate(self.train_dataloader):
+            if i >= noneiv_prediction_number_of_batches:
+                break
+            if len(y.shape) <= 1:
+                y = y.view((-1,1))
+            x,y = x.to(device), y.to(device)
+            pred, _ = net.predict(x,
+                    number_of_draws=noneiv_prediction_number_of_draws,
+                    remove_graph = True,
+                    take_average_of_prediction=True)
+            pred_collection.append(pred)
+            y_collection.append(y)
+        pred_collection = torch.cat(pred_collection, dim=0)
+        y_collection = torch.cat(y_collection, dim=0)
+        assert pred_collection.shape == y_collection.shape
+        rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
+        net.change_std_y(rmse)
+        if not net_train_state:
+            net.eval()
+    def check_if_update_std_y(self, epoch):
+        """
+        Check whether to update std_y according to `epoch` and
+        `std_y_update_points`. If the latter is an integer, an update will
+        be made for every epoch greater than or equal to this number
+        (i.e. `True` will be returned). If it is a list, only epochs
+        greater than or equal to `std_y_update_points[0]` that are
+        divisible by `std_y_update_points[1]` result in `True`.
+        """
+        if type(std_y_update_points) is int:
+            return epoch >= std_y_update_points
+        else:
+            assert type(std_y_update_points) is list
+            return epoch >= std_y_update_points[0]\
+                    and epoch % std_y_update_points[1] == 0
     def post_epoch_update(self, net, epoch):
         """
         Overwrites the corresponding method
         """
-        def update_std_y(epoch_number):
-            """
-            Check whether to update std_y according to `epoch_number` and
-            `std_y_update_points`. If the later is an integer, after all epochs
-            greater than this number an update will be made (i.e. `True` will
-            be returned). If it is a list, only `epoch_number` greater than
-            `std_y_update_points[0]` that divide `std_y_update_points[1]` will
-            result in a True.
-            """
-            if type(std_y_update_points) is int:
-                return epoch >= std_y_update_points
-            else:
-                assert type(std_y_update_points) is list
-                return epoch_number >= std_y_update_points[0]\
-                        and epoch_number % std_y_update_points[1] == 0
-        if update_std_y(epoch):
-            pred_collection = []
-            y_collection = []
-            for i, (x,y) in enumerate(self.train_dataloader):
-                if i >= noneiv_prediction_number_of_batches:
-                    break
-                if len(y.shape) <= 1:
-                    y = y.view((-1,1))
-                x,y = x.to(device), y.to(device)
-                pred, _ = net.predict(x,
-                        number_of_draws=noneiv_prediction_number_of_draws,
-                        remove_graph = True,
-                        take_average_of_prediction=True)
-                pred_collection.append(pred)
-                y_collection.append(y)
-            pred_collection = torch.cat(pred_collection, dim=0)
-            y_collection = torch.cat(y_collection, dim=0)
-            assert pred_collection.shape == y_collection.shape
-            rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
-            net.change_std_y(rmse)
+        if self.check_if_update_std_y(epoch):
+            self.update_std_y(net)
         self.lr_scheduler.step()
+    def post_train_update(self, net, epoch):
+        """
+        Overwrites the corresponding method. If std_y of `net` was not
+        updated in the last training epoch, update it once training is
+        finished. `epoch` should be the number of the last training epoch.
+        """
+        if not self.check_if_update_std_y(epoch):
+            self.update_std_y(net)
     def extra_report(self, net, i):
         """
         Overwrites the corresponding method
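Taken together, post_epoch_update and post_train_update guarantee that std_y reflects the final network state. A small trace of that control flow, with hypothetical values number_of_epochs=20 and std_y_update_points=[10, 8]:

number_of_epochs = 20
std_y_update_points = [10, 8]

for epoch in range(number_of_epochs):
    # post_epoch_update refreshes std_y only at scheduled epochs
    if epoch >= std_y_update_points[0] \
            and epoch % std_y_update_points[1] == 0:
        print(f'epoch {epoch}: std_y updated (post_epoch_update)')
last_epoch = number_of_epochs - 1
# post_train_update catches the case where the last epoch was not a
# scheduled update point (here: 19 % 8 != 0)
if not (last_epoch >= std_y_update_points[0]
        and last_epoch % std_y_update_points[1] == 0):
    print(f'epoch {last_epoch}: std_y updated (post_train_update)')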
......