import numpy as np
import torch
class TrainEpoch():
    """
    A simple implementation of the training during one epoch.
    The training is implemented in the __call__ method that returns collected
    samples of the train loss and test loss, averaged between
    two "report points", and the std_x and std_y returned by `std_x_map` and
    `std_y_map` (which should be **floats**).
    If `verbose` is `True` (the default) these values will be printed at
    "report points".
    :param train_dataloader: A torch.utils.data.DataLoader (for train data)
    :param test_dataloader: A torch.utils.data.DataLoader (for test data)
    :param criterion: A map that takes net, x, y and reg as input and returns
    a single-element torch.tensor
    :param std_y_map: Map without arguments that returns a float.
    :param std_x_map: Map without arguments that returns a float.
    :param lr: The learning rate (a positive float), defaults to 1e-3
    :param reg: The regularization, defaults to 1.0
    :param report_point: Positive integer, distance between two report points.
    Defaults to 100.
    :param verbose: If True (the default) prints information while training.
    :param device: Do training on this device (defaults to 'cpu')
    std_x and std_y are printed at each report point.
    """
    def __init__(self, train_dataloader, test_dataloader,
            criterion, std_y_map, std_x_map=lambda: 0.0, lr=1e-3,
            reg=1.0, report_point=100, verbose=True, device='cpu'):
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.initial_lr = lr
        self.criterion = criterion
        self.std_x_map = std_x_map
        self.std_y_map = std_y_map
        self.reg = reg
        self.report_point = report_point
        self.verbose = verbose
        self.device = device
        #
        self.lr_generator = iter(self.next_lr())
        self.lr = None
        self.total_count = 0
    def next_lr(self):
        while True:
            yield self.initial_lr
    def pre_epoch_update(self, net, epoch):
        """
        Overwrite to update the optimizer, the learning rate or similar in a
        different manner.
        *Note* This method is expected to define at least `self.optimizer`.
        :param net: The net to be used for the optimizer
        :param epoch: The current epoch, an integer.
        """
        old_lr = self.lr
        self.lr = next(self.lr_generator)
        if old_lr != self.lr:
            self.optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
    def post_epoch_update(self, net, epoch):
        """
        Overwrite in subclasses to update the net (e.g.
        its Deming factor for EIV) after each epoch.
        :param net: The current net, a torch.nn.Module.
        :param epoch: The current epoch, an integer.
        """
        pass
    def post_train_update(self, net, epoch=None):
        """
        Will be executed after the training is finished.
        :param net: The current net, a torch.nn.Module.
        :param epoch: The last epoch number, an integer.
        """
        pass
    def extra_report(self, net, step):
        """
        Overwrite to report on the state of the net
        at each report point.
        """
        pass
    def __call__(self, net, epoch):
        """
        :param net: A torch.nn.Module
        :param epoch: The current epoch, a non-negative integer.
        Will only be used for formatting of the printing.
        """
        self.pre_epoch_update(net, epoch)
        train_loss, test_loss = [], []
        stored_train_loss, stored_test_loss = [], []
        stored_std_x, stored_std_y = [], []
        if self.verbose:
            print('>>>> Epoch %i' % (epoch,))
        stored_train_loss_to_average = []
        stored_test_loss_to_average = []
        for i, (x,y) in enumerate(self.train_dataloader):
            self.total_count += 1
            # optimize on train data
            x, y = x.to(self.device), y.to(self.device)
            loss = self.criterion(net, x, y, self.reg)
            loss.backward()
            self.optimizer.step()
            net.zero_grad()
            stored_train_loss_to_average.append(loss.detach().cpu().item())
            if i % self.report_point == 0:
                # evaluate on test_data
                for j, (x,y) in enumerate(self.test_dataloader):
                    x,y = x.to(self.device), y.to(self.device)
                    loss = self.criterion(net, x, y,
                            self.reg).detach().cpu().item()
                    stored_test_loss_to_average.append(loss)
                std_x = self.std_x_map()
                std_y = self.std_y_map()
                # store values
                stored_train_loss.append(
                        np.mean(stored_train_loss_to_average))
                stored_test_loss.append(
                        np.mean(stored_test_loss_to_average))
                stored_std_x.append(std_x)
                stored_std_y.append(std_y)
                if i>0 and self.verbose:
                    print('Step %i,'
                            ' train_loss: %.2f,'
                            ' test_loss: %.2f,'
                            ' std_x: %.2f,'
                            ' std_y: %.2f' % (i,
                                stored_train_loss[-1],
                                stored_test_loss[-1],
                                std_x,
                                std_y
                                ))
                # to be used for extra reporting
                self.last_train_loss = stored_train_loss[-1]
                self.last_test_loss = stored_test_loss[-1]
                self.last_std_x = std_x
                self.last_std_y = std_y
                # extra reporting
                self.extra_report(net, i)
                stored_train_loss_to_average = []
                stored_test_loss_to_average = []
        self.post_epoch_update(net, epoch)
        # convert to tensors
        stored_train_loss = torch.tensor(stored_train_loss,
                dtype=torch.float32)
        stored_test_loss = torch.tensor(stored_test_loss,
                dtype=torch.float32)
        stored_std_x = torch.tensor(stored_std_x, dtype=torch.float32)
        stored_std_y = torch.tensor(stored_std_y, dtype=torch.float32)
        return stored_train_loss,\
                stored_test_loss,\
                stored_std_x,\
                stored_std_y
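# The two sketches below are illustrative only; the names, the MSE + L2 loss
# and the decay schedule are assumptions made for demonstration, not part of
# this module's API.
#
# A minimal criterion with the (net, x, y, reg) signature that `TrainEpoch`
# expects: an MSE data term plus an L2 penalty weighted by `reg`. Any map
# returning a single-element torch.tensor would do.
def example_mse_criterion(net, x, y, reg):
    data_term = torch.nn.functional.mse_loss(net(x), y)
    l2_term = sum((p**2).sum() for p in net.parameters())
    return data_term + reg * l2_term

# A minimal sketch of overwriting `next_lr`, as suggested in the
# `pre_epoch_update` docstring, to decay the learning rate every
# `decay_every` epochs; `pre_epoch_update` rebuilds the Adam optimizer
# whenever the yielded learning rate changes.
class ExampleDecayingLRTrainEpoch(TrainEpoch):
    def __init__(self, *args, lr_decay=0.5, decay_every=10, **kwargs):
        self.lr_decay = lr_decay
        self.decay_every = decay_every
        super().__init__(*args, **kwargs)

    def next_lr(self):
        lr = self.initial_lr
        while True:
            # keep the current learning rate for `decay_every` epochs
            for _ in range(self.decay_every):
                yield lr
            # then shrink it
            lr *= self.lr_decay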
def train_and_store(net, epoch_map, number_of_epochs, save_file, **kwargs):
    """
    Calls `epoch_map` with `net` and the current epoch number,
    `number_of_epochs` times, and stores a list of
    its output as a pickled file under `save_file`.
    After this (the training) is done, `epoch_map.post_train_update(net,
    epoch_number)` is called, if existent.
    **Note**: The output of `epoch_map` is supposed to
    consist of 4 specific tensors, see below.
    :param net: A torch.nn.Module
    :param epoch_map: A map taking net and an integer (the epoch number) as an
    input and returning 4 torch.tensors containing the evolution of
    train_loss, test_loss, std_x, std_y during one epoch.
    :param number_of_epochs: A positive integer, the number of epochs to train.
    :param save_file: A string containing the path to the file to be pickled.
    :param kwargs: keywords and their values will
    also be stored after the training. Use lists to ensure they are updated.
    """
    std_x_collection = []
    std_y_collection = []
    train_loss_collection = []
    test_loss_collection = []
    # Training
    for epoch in range(number_of_epochs):
        train_loss, test_loss, std_x, std_y = epoch_map(net, epoch)
        train_loss_collection.append(train_loss)
        test_loss_collection.append(test_loss)
        std_x_collection.append(std_x)
        std_y_collection.append(std_y)
    try:
        epoch_map.post_train_update(net, number_of_epochs-1)
        print('Training done. Performed post training updating.')
    except AttributeError:
        print('Training done. No post train updating performed.')
    # Saving
    state_dict = net.state_dict()
    to_save = {
            'train_loss' : train_loss_collection,
            'test_loss' : test_loss_collection,
            'std_x' : std_x_collection,
            'std_y' : std_y_collection,
            'state_dict' : state_dict
            }
    for key, value in kwargs.items():
        to_save[key] = value
    with open(save_file, 'wb') as file_handle:
        torch.save(to_save, file_handle)
def open_stored_training(saved_file, net=None, join_epochs = True,
        extra_keys = None, device=torch.device('cpu')):
    """
    Counterpart to `train_and_store`: opens `saved_file` and returns the
    stored train_loss, test_loss, std_x, std_y and state_dict (and, if
    `extra_keys` is given, a list with the corresponding extra entries).
    :param saved_file: A pickle file generated with train_and_store
    :param net: The corresponding torch.nn.Module; if given, the stored
    state_dict is loaded into it.
    :param join_epochs: If True (default), all epochs will be concatenated.
    :param extra_keys: None (default) or a list of strings. If the latter,
    it is used to extract extra entries of the stored dictionary, which
    are returned as a list.
    :param device: The device to be used for loading the state_dict.
    """
    with open(saved_file, 'rb') as file_handle:
        loaded_dict = torch.load(file_handle, map_location=device)
    train_loss = loaded_dict['train_loss']
    test_loss = loaded_dict['test_loss']
    std_x = loaded_dict['std_x']
    std_y = loaded_dict['std_y']
    state_dict = loaded_dict['state_dict']
    if net is not None:
        net.load_state_dict(state_dict)
    if join_epochs:
        train_loss = torch.cat(train_loss, dim=0)
        test_loss = torch.cat(test_loss, dim=0)
        std_x = torch.cat(std_x, dim=0)
        std_y = torch.cat(std_y, dim=0)
    if extra_keys is None:
        return train_loss, test_loss, std_x, std_y, state_dict
    else:
        extra_list = [loaded_dict[key] for key in extra_keys]
        return train_loss, test_loss, std_x, std_y, state_dict, extra_list
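
# A minimal end-to-end sketch of the intended workflow on synthetic data,
# using the illustrative `example_mse_criterion` above. The file name
# 'example_training.pt' and all hyperparameters are assumptions made for
# demonstration, not defaults of this module.
if __name__ == '__main__':
    from torch.utils.data import TensorDataset, DataLoader
    # synthetic 1D regression data, split into train and test parts
    x = torch.linspace(-1, 1, 200).reshape(-1, 1)
    y = 2.0 * x + 0.1 * torch.randn_like(x)
    train_dataloader = DataLoader(TensorDataset(x[:160], y[:160]),
            batch_size=16, shuffle=True)
    test_dataloader = DataLoader(TensorDataset(x[160:], y[160:]),
            batch_size=16)
    net = torch.nn.Sequential(torch.nn.Linear(1, 10), torch.nn.ReLU(),
            torch.nn.Linear(10, 1))
    # one TrainEpoch instance serves as the epoch_map for train_and_store
    epoch_map = TrainEpoch(train_dataloader, test_dataloader,
            criterion=example_mse_criterion, std_y_map=lambda: 0.1,
            lr=1e-2, reg=1e-4, report_point=5)
    train_and_store(net, epoch_map, number_of_epochs=3,
            save_file='example_training.pt')
    # reload the collected losses and the trained weights
    train_loss, test_loss, std_x, std_y, state_dict = \
            open_stored_training('example_training.pt', net=net)
    print('Collected %i report points.' % len(train_loss))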