Skip to content
Snippets Groups Projects
Commit 1ee21c3d authored by Kerstin Kaspar's avatar Kerstin Kaspar
Browse files

small changes

parent 5374110e
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,8 @@
<component name="ChangeListManager">
<list default="true" id="714ca32e-a67a-44cf-90a6-e7ab36748bed" name="Changes" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/trainer.py" beforeDir="false" afterPath="$PROJECT_DIR$/trainer.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train_test.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_test.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
......
......@@ -24,8 +24,9 @@ cache_files = {
ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
model_output_path = Path('./models') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
loss_output_path = Path('./models/losses') / f'{date.today().strftime("%Y-%m-%d")}_unet.p'
model_name = f'{date.today().strftime("%Y-%m-%d")}_unet'
model_output_path = Path(f'./models/{model_name}/') / f'{model_name}.pt'
loss_output_path = Path(f'./models/{model_name}/') / f'{model_name}_losses.p'
epochs = 100
save_after_epochs = 10
......
......@@ -7,6 +7,7 @@ import functools
from datetime import date, datetime
from pathlib import Path
import pickle
from utils import check_filepath
# timer
start = datetime.now()
......@@ -25,7 +26,7 @@ ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
model_name = f'{date.today().strftime("%Y-%m-%d")}_unet_test'
model_output_path = Path(f'./models/{model_name}/') / f'{model_name}.pt'
loss_output_path = Path('./models/losses') / f'{model_name}.p'
loss_output_path = Path(f'./models/{model_name}/') / f'{model_name}_losses.p'
epochs = 2
save_after_epochs = 1
......@@ -88,7 +89,8 @@ torch.save(model.state_dict(), Path.cwd() / model_output_path)
losses = {'training_losses': training_losses,
'validation_losses': validation_losses,
'lr_rates': lr_rates}
with open(Path.cwd() / loss_output_path, 'wb') as file:
with open(loss_output_path.absolute(), 'wb') as file:
file = check_filepath(filepath=file, expand='num')
pickle.dump(training_losses, file)
# timer
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment