Skip to content
Snippets Groups Projects
Commit 558a7e41 authored by Kerstin Kaspar's avatar Kerstin Kaspar
Browse files

save files

parent 1715d91e
No related branches found
No related tags found
No related merge requests found
......@@ -2,13 +2,11 @@
<project version="4">
<component name="ChangeListManager">
<list default="true" id="714ca32e-a67a-44cf-90a6-e7ab36748bed" name="Changes" comment="">
<change afterPath="$PROJECT_DIR$/train_test.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/data.py" beforeDir="false" afterPath="$PROJECT_DIR$/data.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/eval.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/temp2.png" beforeDir="false" afterPath="$PROJECT_DIR$/temp2.png" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train_locally.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_locally.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train_test.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_test.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/trainer.py" beforeDir="false" afterPath="$PROJECT_DIR$/trainer.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/utils.py" beforeDir="false" afterPath="$PROJECT_DIR$/utils.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
......
tmp.p 0 → 100644
File added
......@@ -7,6 +7,7 @@ import functools
from datetime import date, datetime
from pathlib import Path
import pickle
from utils import check_filepath
# timer
start = datetime.now()
......@@ -87,7 +88,8 @@ torch.save(model.state_dict(), Path.cwd() / model_output_path)
losses = {'training_losses': training_losses,
'validation_losses': validation_losses,
'lr_rates': lr_rates}
with open(Path.cwd() / loss_output_path, 'w') as file:
with open(loss_output_path.absolute(), 'wb') as file:
file = check_filepath(filepath=file, expand='num')
pickle.dump(training_losses, file)
# timer
......
......@@ -23,8 +23,9 @@ cache_files = {
ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
model_output_path = Path('./models') / f'{date.today().strftime("%Y-%m-%d")}_unet_test.pt'
loss_output_path = Path('./models/losses') / f'{date.today().strftime("%Y-%m-%d")}_unet_test.p'
model_name = f'{date.today().strftime("%Y-%m-%d")}_unet_test'
model_output_path = Path(f'./models/{model_name}/') / f'{model_name}.pt'
loss_output_path = Path('./models/losses') / f'{model_name}.p'
epochs = 2
save_after_epochs = 1
......@@ -87,7 +88,7 @@ torch.save(model.state_dict(), Path.cwd() / model_output_path)
losses = {'training_losses': training_losses,
'validation_losses': validation_losses,
'lr_rates': lr_rates}
with open(Path.cwd() / loss_output_path, 'w') as file:
with open(Path.cwd() / loss_output_path, 'wb') as file:
pickle.dump(training_losses, file)
# timer
......
......@@ -3,6 +3,7 @@ import torch
from typing import Union
from pathlib import Path
from datetime import date
from utils import check_filepath
class Trainer:
......@@ -34,7 +35,7 @@ class Trainer:
self.notebook = notebook
self.model_output_path = model_output_path if model_output_path \
else Path('./model') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
self.save_after_epochs = save_after_epochs if save_after_epochs else epochs-1
self.save_after_epochs = save_after_epochs if save_after_epochs else epochs - 1
self.training_loss = []
self.validation_loss = []
......@@ -58,8 +59,10 @@ class Trainer:
if self.lr_scheduler is not None:
self.lr_scheduler.batch()
if epoch % self.save_after_epochs:
torch.save(self.model.state_dict(), self.model_output_path.with_name(
self.model_output_path.name.replace('.pt', f'_e{epoch}.pt')))
save_as = check_filepath(filepath=self.model_output_path.with_name(
self.model_output_path.name.replace('.pt', f'_e{epoch}.pt')),
expand='num')
torch.save(self.model.state_dict(), save_as)
return self.training_loss, self.validation_loss, self.learning_rate
def _train(self):
......@@ -113,4 +116,4 @@ class Trainer:
self.validation_loss.append(np.mean(validation_losses))
batch_progress.close()
\ No newline at end of file
batch_progress.close()
import os
from pathlib import Path
from datetime import datetime
def check_filepath(filepath: (str, Path),
makedir: bool = True,
overwrite: bool = False,
expand: str = 'date',
**kwargs) -> Path:
"""
Checks a filepath existence and possibly creates directories
:param filepath: filepath to the target
:param makedir: create nonexistent directories in the filepath
:param overwrite: toggles allowance to overwrite existing data in the filepath
:param expand: 'date' (default): expands existing files with datetime (up to minute),
'num': expands with running digit
:return filepath: The correct path to the existing or created folder
"""
filepath = Path(filepath)
if not filepath.parent.exists():
if makedir:
os.makedirs(Path(filepath.parent))
else:
raise ValueError('Filepath does not exist. To create, rerun with makedir = True.')
if not overwrite:
if expand == 'date':
if filepath.exists():
filepath = Path(filepath.as_posix().replace(filepath.stem + filepath.suffix,
filepath.stem + '__{date}'.format(
date=datetime.now().strftime("%Y-%m-%d_%H-%M"))
+ filepath.suffix))
elif expand == 'num':
i = (len(filepath.suffix) + 1) * -1
while filepath.exists():
if filepath.name[i].isdigit():
if int(filepath.name[i]) + 1 < 10:
filepath = Path(filepath.as_posix().replace(filepath.name[i],
str(int(filepath.name[i]) + 1)))
else:
i -= 1
else:
filepath = Path(filepath.as_posix().replace(filepath.stem + filepath.suffix,
filepath.stem + '__1' + filepath.suffix))
else:
raise ValueError('Can not expand with ' + expand + ', only date and num are valid.')
return filepath
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment