Skip to content
Snippets Groups Projects
Commit 9af89ca0 authored by Kerstin Kaspar's avatar Kerstin Kaspar
Browse files

loss files at save

parent 3cfe33e0
No related branches found
No related tags found
No related merge requests found
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="714ca32e-a67a-44cf-90a6-e7ab36748bed" name="Changes" comment=""> <list default="true" id="714ca32e-a67a-44cf-90a6-e7ab36748bed" name="Changes" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" /> <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/2022-05-30_15-30_eval_1000.png" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/eval.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/eval.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/losses.py" beforeDir="false" afterPath="$PROJECT_DIR$/losses.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/losses.py" beforeDir="false" afterPath="$PROJECT_DIR$/losses.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" />
...@@ -80,6 +79,27 @@ ...@@ -80,6 +79,27 @@
<option name="INPUT_FILE" value="" /> <option name="INPUT_FILE" value="" />
<method v="2" /> <method v="2" />
</configuration> </configuration>
<configuration name="losses" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="unet" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/losses.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="train_locally" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true"> <configuration name="train_locally" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="unet" /> <module name="unet" />
<option name="INTERPRETER_OPTIONS" value="" /> <option name="INTERPRETER_OPTIONS" value="" />
...@@ -104,6 +124,7 @@ ...@@ -104,6 +124,7 @@
<recent_temporary> <recent_temporary>
<list> <list>
<item itemvalue="Python.eval" /> <item itemvalue="Python.eval" />
<item itemvalue="Python.losses" />
<item itemvalue="Python.train_locally" /> <item itemvalue="Python.train_locally" />
</list> </list>
</recent_temporary> </recent_temporary>
......
...@@ -20,8 +20,8 @@ cache_files = { ...@@ -20,8 +20,8 @@ cache_files = {
gt_path = 'C:/Users/heinec03/CodeProjects/data/lodopab-ct' gt_path = 'C:/Users/heinec03/CodeProjects/data/lodopab-ct'
model_name = '2022-06-01_unet' model_name = '2022-06-02_unet'
epoch = None epoch = 250
model_dir = Path(f'./models/{model_name}/') model_dir = Path(f'./models/{model_name}/')
model_name = f'{model_name}_e{epoch}' if epoch else model_name model_name = f'{model_name}_e{epoch}' if epoch else model_name
model_path = model_dir / f'{model_name}.pt' model_path = model_dir / f'{model_name}.pt'
......
...@@ -3,7 +3,7 @@ import pickle ...@@ -3,7 +3,7 @@ import pickle
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from pathlib import Path from pathlib import Path
model_name = '2022-06-01_unet' model_name = '2022-06-02_unet'
model_dir = Path(f'./models/{model_name}/') model_dir = Path(f'./models/{model_name}/')
losses_path = model_dir / f'{model_name}_losses.p' losses_path = model_dir / f'{model_name}_losses.p'
fig_path = model_dir / f'{model_name}_loss_fig.png' fig_path = model_dir / f'{model_name}_loss_fig.png'
......
...@@ -27,7 +27,7 @@ ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data' ...@@ -27,7 +27,7 @@ ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
model_name = f'{date.today().strftime("%Y-%m-%d")}_unet' model_name = f'{date.today().strftime("%Y-%m-%d")}_unet'
model_dir = Path(f'./models/{model_name}/') model_dir = Path(f'./models/{model_name}/')
model_output_path = model_dir / f'{model_name}.pt' model_output_path = model_dir / f'{model_name}.pt'
losses_output_path = model_dir / f'{model_name}_losses.p' loss_output_path = model_dir / f'{model_name}_losses.p'
epochs = 500 epochs = 500
save_after_epochs = 25 save_after_epochs = 25
...@@ -80,6 +80,7 @@ trainer = Trainer(model=model, ...@@ -80,6 +80,7 @@ trainer = Trainer(model=model,
epoch=0, epoch=0,
notebook=False, notebook=False,
model_output_path=model_output_path, model_output_path=model_output_path,
loss_output_path=loss_output_path,
save_after_epochs=save_after_epochs) save_after_epochs=save_after_epochs)
# start training # start training
...@@ -90,7 +91,7 @@ torch.save(model.state_dict(), check_filepath(filepath=model_output_path, expand ...@@ -90,7 +91,7 @@ torch.save(model.state_dict(), check_filepath(filepath=model_output_path, expand
losses = {'training_losses': training_losses, losses = {'training_losses': training_losses,
'validation_losses': validation_losses, 'validation_losses': validation_losses,
'lr_rates': lr_rates} 'lr_rates': lr_rates}
with open(check_filepath(filepath=losses_output_path, expand='num'), 'wb') as file: with open(check_filepath(filepath=loss_output_path, expand='num'), 'wb') as file:
pickle.dump(losses, file) pickle.dump(losses, file)
# timer # timer
......
import pickle
import numpy as np import numpy as np
import torch import torch
from typing import Union from typing import Union
...@@ -20,6 +22,7 @@ class Trainer: ...@@ -20,6 +22,7 @@ class Trainer:
epoch: int = 0, epoch: int = 0,
notebook: bool = False, notebook: bool = False,
model_output_path: Union[str, Path] = None, model_output_path: Union[str, Path] = None,
loss_output_path: Union[str, Path] = None,
save_after_epochs: int = None save_after_epochs: int = None
): ):
self.model = model self.model = model
...@@ -35,6 +38,8 @@ class Trainer: ...@@ -35,6 +38,8 @@ class Trainer:
self.notebook = notebook self.notebook = notebook
self.model_output_path = model_output_path if model_output_path \ self.model_output_path = model_output_path if model_output_path \
else Path('./model') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt' else Path('./model') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
self.loss_output_path = loss_output_path if loss_output_path \
else Path('./model') / f'{date.today().strftime("%Y-%m-%d")}_unet_loss.p'
self.save_after_epochs = save_after_epochs if save_after_epochs else epochs self.save_after_epochs = save_after_epochs if save_after_epochs else epochs
self.training_loss = [] self.training_loss = []
...@@ -64,10 +69,18 @@ class Trainer: ...@@ -64,10 +69,18 @@ class Trainer:
return self.training_loss, self.validation_loss, self.learning_rate return self.training_loss, self.validation_loss, self.learning_rate
def _save(self): def _save(self):
save_as = check_filepath(filepath=self.model_output_path.with_name( save_model_as = check_filepath(filepath=self.model_output_path.with_name(
self.model_output_path.name.replace('.pt', f'_e{self.epoch:03}.pt')), self.model_output_path.name.replace('.pt', f'_e{self.epoch:03}.pt')),
expand='num') expand='num')
torch.save(self.model.state_dict(), save_as) torch.save(self.model.state_dict(), save_model_as)
losses = {'training_losses': self.training_loss,
'validation_losses': self.validation_loss,
'lr_rates': self.learning_rate}
save_loss_as = check_filepath(filepath=self.loss_output_path.with_name(
self.loss_output_path.name.replace('.p', f'_e{self.epoch:03}.pt')),
expand='num')
with open(save_loss_as, 'wb') as file:
pickle.dump(losses, file)
def _train(self): def _train(self):
if self.notebook: if self.notebook:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment