Skip to content
Snippets Groups Projects
Commit 6e255c64 authored by Jörg Martin's avatar Jörg Martin
Browse files

Post train updating added

parent fddd1aa3
No related branches found
No related tags found
No related merge requests found
...@@ -32,7 +32,7 @@ class FNNEIV(nn.Module): ...@@ -32,7 +32,7 @@ class FNNEIV(nn.Module):
**Note**: **Note**:
- To change the deming factor afterwards, use the method `change_deming` - To change the deming factor afterwards, use the method `change_deming`
- To change fixed_std_x afterwards, use the method `change_fixed_std_x` - To change fixed_std_x afterwards, use the method `change_fixed_std_x`
- To change std_y use the method `change_std_x` - To change std_y use the method `change_std_y`
""" """
LeakyReLUSlope = 1e-2 LeakyReLUSlope = 1e-2
def __init__(self, p = 0.2, init_std_y=1.0, precision_prior_zeta=0.0, def __init__(self, p = 0.2, init_std_y=1.0, precision_prior_zeta=0.0,
...@@ -349,6 +349,7 @@ class FNNBer(nn.Module): ...@@ -349,6 +349,7 @@ class FNNBer(nn.Module):
:param h: A list specifying the number of neurons in each layer. :param h: A list specifying the number of neurons in each layer.
:param std_y_requires_grad: Whether `sigma_y` will require_grad and thus :param std_y_requires_grad: Whether `sigma_y` will require_grad and thus
be updated during optimization. Defaults to False. be updated during optimization. Defaults to False.
To change std_y use the method `change_std_y`
""" """
LeakyReLUSlope = 1e-2 LeakyReLUSlope = 1e-2
def __init__(self, p=0.2, init_std_y=1.0, h=[10, 1024,1024,1024,1024, 1], def __init__(self, p=0.2, init_std_y=1.0, h=[10, 1024,1024,1024,1024, 1],
......
...@@ -69,6 +69,15 @@ class TrainEpoch(): ...@@ -69,6 +69,15 @@ class TrainEpoch():
""" """
pass pass
def post_train_update(self, net, epoch=None):
"""
Will be executed after the training is finished
:param net: The current net, a torch.nn.Module
:param epoch: Tue last epochn number, an integer.
"""
pass
def extra_report(self, net, step): def extra_report(self, net, step):
""" """
Overwrite for reporting on state of net Overwrite for reporting on state of net
...@@ -156,6 +165,8 @@ def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs): ...@@ -156,6 +165,8 @@ def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs):
Calls `epoch_map` with `epoch` and the current epoch number Calls `epoch_map` with `epoch` and the current epoch number
`number_of_epochs` times and stores a list of `number_of_epochs` times and stores a list of
its output as a pickled file under `save_file`. its output as a pickled file under `save_file`.
After this (the training) is done, `epoch_map.post_train_update(net,
epoch_number)` is called, if existent.
**Note**: The output of `epoch_map` is supposed to **Note**: The output of `epoch_map` is supposed to
consist of 4 specific arguments, see below. consist of 4 specific arguments, see below.
:param net: A torch.nn.Module :param net: A torch.nn.Module
...@@ -179,6 +190,11 @@ def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs): ...@@ -179,6 +190,11 @@ def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs):
test_loss_collection.append(test_loss) test_loss_collection.append(test_loss)
std_x_collection.append(std_x) std_x_collection.append(std_x)
std_y_collection.append(std_y) std_y_collection.append(std_y)
try:
epoch_map.post_train_update(net, number_of_epochs-1)
print('Training done. Performed post training updating.')
except AttributeError:
print('Training done. No post train updating performed.')
# Saving # Saving
state_dict = net.state_dict() state_dict = net.state_dict()
to_save = { to_save = {
......
...@@ -85,48 +85,71 @@ class UpdatedTrainEpoch(train_and_store.TrainEpoch): ...@@ -85,48 +85,71 @@ class UpdatedTrainEpoch(train_and_store.TrainEpoch):
self.lr_scheduler = torch.optim.lr_scheduler.StepLR( self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
self.optimizer, lr_update, gamma) self.optimizer, lr_update, gamma)
def update_std_y(self, net):
"""
Update the std_y of `net` via the RMSE of the prediction.
"""
net_train_state = net.training
net_noise_state = net.noise_is_on
net.train()
net.noise_on()
pred_collection = []
y_collection = []
for i, (x,y) in enumerate(self.train_dataloader):
if i >= eiv_prediction_number_of_batches:
break
if len(y.shape) <= 1:
y = y.view((-1,1))
x,y = x.to(device), y.to(device)
pred, _ = net.predict(x,
number_of_draws=eiv_prediction_number_of_draws,
remove_graph = True,
take_average_of_prediction=True)
pred_collection.append(pred)
y_collection.append(y)
pred_collection = torch.cat(pred_collection, dim=0)
y_collection = torch.cat(y_collection, dim=0)
assert pred_collection.shape == y_collection.shape
rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
net.change_std_y(rmse)
if not net_train_state:
net.eval()
if not net_noise_state:
net.noise_off()
def check_if_update_std_y(self, epoch):
"""
Check whether to update std_y according to `epoch_number` and
`std_y_update_points`. If the later is an integer, after all epochs
greater than this number an update will be made (i.e. `True` will
be returned). If it is a list, only `epoch_number` greater than
`std_y_update_points[0]` that divide `std_y_update_points[1]` will
result in a True.
"""
if type(std_y_update_points) is int:
return epoch >= std_y_update_points
else:
assert type(std_y_update_points) is list
return epoch >= std_y_update_points[0]\
and epoch % std_y_update_points[1] == 0
def post_epoch_update(self, net, epoch): def post_epoch_update(self, net, epoch):
""" """
Overwrites the corresponding method Overwrites the corresponding method
""" """
def update_std_y(epoch_number): if self.check_if_update_std_y(epoch):
""" self.update_std_y(net)
Check whether to update std_y according to `epoch_number` and
`std_y_update_points`. If the later is an integer, after all epochs
greater than this number an update will be made (i.e. `True` will
be returned). If it is a list, only `epoch_number` greater than
`std_y_update_points[0]` that divide `std_y_update_points[1]` will
result in a True.
"""
if type(std_y_update_points) is int:
return epoch >= std_y_update_points
else:
assert type(std_y_update_points) is list
return epoch_number >= std_y_update_points[0]\
and epoch_number % std_y_update_points[1] == 0
if update_std_y(epoch):
pred_collection = []
y_collection = []
for i, (x,y) in enumerate(self.train_dataloader):
if i >= eiv_prediction_number_of_batches:
break
if len(y.shape) <= 1:
y = y.view((-1,1))
x,y = x.to(device), y.to(device)
pred, _ = net.predict(x,
number_of_draws=eiv_prediction_number_of_draws,
remove_graph = True,
take_average_of_prediction=True)
pred_collection.append(pred)
y_collection.append(y)
pred_collection = torch.cat(pred_collection, dim=0)
y_collection = torch.cat(y_collection, dim=0)
assert pred_collection.shape == y_collection.shape
rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
net.change_std_y(rmse)
self.lr_scheduler.step() self.lr_scheduler.step()
def post_train_update(self, net, epoch):
"""
Overwrites the corresponding method. If std_y of `net` was not updated
in the last training step, update it when finished with training.
`epoch` should be the number of the last training epoch.
"""
if not self.check_if_update_std_y(epoch):
self.update_std_y(net)
def extra_report(self, net, i): def extra_report(self, net, i):
""" """
Overwrites the corresponding method Overwrites the corresponding method
......
...@@ -84,48 +84,67 @@ class UpdatedTrainEpoch(train_and_store.TrainEpoch): ...@@ -84,48 +84,67 @@ class UpdatedTrainEpoch(train_and_store.TrainEpoch):
self.lr_scheduler = torch.optim.lr_scheduler.StepLR( self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
self.optimizer, lr_update, gamma) self.optimizer, lr_update, gamma)
def update_std_y(self, net):
"""
Update the std_y of `net` via the RMSE of the prediction.
"""
net_train_state = net.training
net.train()
pred_collection = []
y_collection = []
for i, (x,y) in enumerate(self.train_dataloader):
if i >= noneiv_prediction_number_of_batches:
break
if len(y.shape) <= 1:
y = y.view((-1,1))
x,y = x.to(device), y.to(device)
pred, _ = net.predict(x,
number_of_draws=noneiv_prediction_number_of_draws,
remove_graph = True,
take_average_of_prediction=True)
pred_collection.append(pred)
y_collection.append(y)
pred_collection = torch.cat(pred_collection, dim=0)
y_collection = torch.cat(y_collection, dim=0)
assert pred_collection.shape == y_collection.shape
rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
net.change_std_y(rmse)
if not net_train_state:
net.eval()
def check_if_update_std_y(self, epoch):
"""
Check whether to update std_y according to `epoch_number` and
`std_y_update_points`. If the later is an integer, after all epochs
greater than this number an update will be made (i.e. `True` will
be returned). If it is a list, only `epoch_number` greater than
`std_y_update_points[0]` that divide `std_y_update_points[1]` will
result in a True.
"""
if type(std_y_update_points) is int:
return epoch >= std_y_update_points
else:
assert type(std_y_update_points) is list
return epoch >= std_y_update_points[0]\
and epoch % std_y_update_points[1] == 0
def post_epoch_update(self, net, epoch): def post_epoch_update(self, net, epoch):
""" """
Overwrites the corresponding method Overwrites the corresponding method
""" """
def update_std_y(epoch_number): if self.check_if_update_std_y(epoch):
""" self.update_std_y(net)
Check whether to update std_y according to `epoch_number` and
`std_y_update_points`. If the later is an integer, after all epochs
greater than this number an update will be made (i.e. `True` will
be returned). If it is a list, only `epoch_number` greater than
`std_y_update_points[0]` that divide `std_y_update_points[1]` will
result in a True.
"""
if type(std_y_update_points) is int:
return epoch >= std_y_update_points
else:
assert type(std_y_update_points) is list
return epoch_number >= std_y_update_points[0]\
and epoch_number % std_y_update_points[1] == 0
if update_std_y(epoch):
pred_collection = []
y_collection = []
for i, (x,y) in enumerate(self.train_dataloader):
if i >= noneiv_prediction_number_of_batches:
break
if len(y.shape) <= 1:
y = y.view((-1,1))
x,y = x.to(device), y.to(device)
pred, _ = net.predict(x,
number_of_draws=noneiv_prediction_number_of_draws,
remove_graph = True,
take_average_of_prediction=True)
pred_collection.append(pred)
y_collection.append(y)
pred_collection = torch.cat(pred_collection, dim=0)
y_collection = torch.cat(y_collection, dim=0)
assert pred_collection.shape == y_collection.shape
rmse = torch.sqrt(torch.mean((pred_collection - y_collection)**2))
net.change_std_y(rmse)
self.lr_scheduler.step() self.lr_scheduler.step()
def post_train_update(self, net, epoch):
"""
Overwrites the corresponding method. If std_y of `net` was not updated
in the last training step, update it when finished with training.
`epoch` should be the number of the last training epoch.
"""
if not self.check_if_update_std_y(epoch):
self.update_std_y(net)
def extra_report(self, net, i): def extra_report(self, net, i):
""" """
Overwrites the corresponding method Overwrites the corresponding method
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment