Commit 3b1e170c authored by Kerstin Kaspar

documentation

parent c85e5bd3
......@@ -3,10 +3,22 @@
<component name="ChangeListManager">
<list default="true" id="714ca32e-a67a-44cf-90a6-e7ab36748bed" name="Changes" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.ipynb_checkpoints/train-checkpoint.ipynb" beforeDir="false" afterPath="$PROJECT_DIR$/.ipynb_checkpoints/train-checkpoint.ipynb" afterDir="false" />
<change beforePath="$PROJECT_DIR$/README.md" beforeDir="false" afterPath="$PROJECT_DIR$/README.md" afterDir="false" />
<change beforePath="$PROJECT_DIR$/data.py" beforeDir="false" afterPath="$PROJECT_DIR$/data.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/data_inspect.py" beforeDir="false" afterPath="$PROJECT_DIR$/data_inspect.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/eval.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/losses.py" beforeDir="false" afterPath="$PROJECT_DIR$/losses.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/temp.png" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/temp2.png" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/model.py" beforeDir="false" afterPath="$PROJECT_DIR$/model.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/orig_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/orig_model.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/reco.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/train.ipynb" beforeDir="false" afterPath="$PROJECT_DIR$/train.ipynb" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train_locally.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_locally.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/train_test.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_test.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/trainer.py" beforeDir="false" afterPath="$PROJECT_DIR$/trainer.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/utils.py" beforeDir="false" afterPath="$PROJECT_DIR$/utils.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/validate.py" beforeDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
......@@ -26,7 +38,6 @@
<component name="HighlightingSettingsPerFile">
<setting file="file://$PROJECT_DIR$/trainer.py" root0="FORCE_HIGHLIGHTING" />
<setting file="file://$PROJECT_DIR$/summary.txt" root0="FORCE_HIGHLIGHTING" />
<setting file="file://$PROJECT_DIR$/reco.py" root0="FORCE_HIGHLIGHTING" />
<setting file="file://$USER_HOME$/.conda/envs/nn/Lib/site-packages/torch/nn/modules/conv.py" root0="SKIP_INSPECTION" />
<setting file="file://$PROJECT_DIR$/model.py" root0="FORCE_HIGHLIGHTING" />
<setting file="file://$PROJECT_DIR$/utils.py" root0="FORCE_HIGHLIGHTING" />
......@@ -173,15 +184,4 @@
</map>
</option>
</component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/utils.py</url>
<line>19</line>
<option name="timeStamp" value="2" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
</project>
\ No newline at end of file
My implementation of a U-Net similar to the FBPConvNet after:
### UNet for Low Dose CT
This repository contains my implementation of a U-Net similar to the FBPConvNet after:
K. H. Jin, M. T. McCann, E. Froustey and M. Unser, "Deep Convolutional Neural Network for Inverse Problems in Imaging,"
in IEEE Transactions on Image Processing, vol. 26, no. 9, pp. 4509-4522, Sept. 2017,
https://doi.org/10.1109/TIP.2017.2713099
https://github.com/panakino/FBPConvNet
### Project structure
```text
Unet
├───mlruns
│ '''trained models and losses'''
├───models
│ '''trained models'''
│ └───losses
│ '''loss files of the trained models'''
├───msub
│ '''scripts for high performance cluster job management'''
│ conda_env_nn.yml
│ '''conda environment file (prerequisites)'''
│ data.py
│ '''functions for data loading and preparation'''
│ data_inspect.py
│ '''script to inspect single data samples (FBP and ground truth)'''
│ eval.py
│ '''script to evaluate data with a trained model'''
│ losses.py
│ '''script to inspect and plot losses of a trained model'''
│ model.py
│ '''UNet class and functions'''
│ orig_model.py
│ '''FBPConvNet class and functions'''
│ README.md
│ '''information on the repository'''
│ train.ipynb
│ '''jupyter notebook to train a model on the high performance cluster'''
│ train.py
│ '''script to train a model on the high performance cluster'''
│ trainer.py
│ '''Trainer class and functions'''
│ train_locally.py
│ '''script to train a model on my local machine'''
│ train_test.py
│ '''script to train a model with small datasets'''
│ utils.py
│ '''additional functions'''
```
\ No newline at end of file
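For orientation, building a dataset and dataloader the way the training scripts do might look as follows (the cache-file and ground-truth paths are placeholders and must point at your own copies of the cached LoDoPaB FBP data):

``` python
# minimal sketch: construct the dataset and dataloader as in train_locally.py
import functools
import torch
from data import LodopabDataset

cache_files = {
    'train': 'path/to/cache_lodopab_train_fbp.npy',            # placeholder paths
    'validation': 'path/to/cache_lodopab_validation_fbp.npy',
    'test': 'path/to/cache_lodopab_test_fbp.npy'}

training_dataset = LodopabDataset(cache_files=cache_files,
                                  ground_truth_data_loc='path/to/lodopab-ct/data',
                                  split='train',
                                  transform=functools.partial(torch.unsqueeze, dim=0),
                                  target_transform=functools.partial(torch.unsqueeze, dim=0))
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
                                                  batch_size=16, shuffle=True)
```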
"""
functions for data loading and preparation
"""
import math
from pathlib import Path
import numpy as np
import h5py
import torch
from torch.utils.data import Dataset
from typing import Union, List, Callable, Tuple
from typing import Union, List, Callable, Tuple, Dict
from tqdm import tqdm
import torch.nn.functional as F
def get_im(input: Union[torch.Tensor, np.ndarray]):
def get_im(input: Union[torch.Tensor, np.ndarray]
) -> np.ndarray:
"""
returns the 2D image of a sample
......@@ -20,12 +25,15 @@ def get_im(input: Union[torch.Tensor, np.ndarray]):
def get_lodopab_ground_truth(data_loc: Union[Path, str],
indicator: Union[str, List[str]] = ['train', 'validation', 'test']):
indicator: Union[str, List[str]] = ['train', 'validation', 'test']
) -> Union[Dict[str, np.ndarray], np.ndarray]:
"""
Function to load the Ground Truth Lodopab data similarly to the cached data
:param data_loc: Path to directory with all necessary data files in h5py
:param indicator: 'train', 'test', or 'validation', or a list of several indicators to define which data is loaded
:return data: np.ndarray of data, or, if more than one indicator is given, a dictionary of the form {indicator: np.ndarray}
"""
data_loc = Path(data_loc)
indicator = [indicator] if type(indicator) is not list else indicator
......@@ -49,7 +57,8 @@ def get_lodopab_ground_truth(data_loc: Union[Path, str],
def load_cache_files(cache_files: dict,
indicator: str):
indicator: str
) -> np.ndarray:
"""
:param cache_files: input location - dict of form {'train': <path to .npy file>,
'validation': <path to .npy file>,
......@@ -64,7 +73,8 @@ def load_cache_files(cache_files: dict,
def get_ground_truth(idx: Union[int, List[int]],
indicator: str,
path: Union[Path, str]) -> np.ndarray:
path: Union[Path, str]
) -> np.ndarray:
    if type(idx) is int:
        # LoDoPaB ground-truth files hold 128 slices each, so a flat sample
        # index splits into a file number and an index within that file
        gt_idx = idx % 128
        gt_file_num = math.floor(idx / 128)
......@@ -171,10 +181,10 @@ class LodopabDataset(Dataset):
self.inputs = F.pad(input=torch.from_numpy(self.inputs), pad=(0, 6, 6, 0), mode='constant', value=0)
self.targets = F.pad(input=torch.from_numpy(self.targets), pad=(0, 6, 6, 0), mode='constant', value=0)
def __len__(self):
def __len__(self) -> int:
return self.inputs.shape[0]
def __getitem__(self, idx):
def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
input = self.transform(self.inputs[idx]) if self.transform else self.inputs[idx]
target = self.target_transform(self.targets[idx]) if self.target_transform else self.targets[idx]
return input, target
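The pad of (0, 6, 6, 0) above grows the 362×362 LoDoPaB slices to 368×368, presumably so that the U-Net's four 2×2 pooling steps divide the size exactly (368 = 16·23). A quick check of that arithmetic:

``` python
import torch
import torch.nn.functional as F

x = torch.zeros(1, 362, 362)
# pad = (left, right, top, bottom) over the last two dims: 362 + 6 = 368
y = F.pad(x, pad=(0, 6, 6, 0), mode='constant', value=0)
assert tuple(y.shape[-2:]) == (368, 368)
assert 368 % 2 ** 4 == 0  # four pooling halvings stay integer-sized
```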
"""
script to inspect single data samples (FBP and ground truth)
"""
import torch
from data import get_samples_from_cache, get_im
from torchvision import transforms
......
"""
script to evaluate data with a trained model
"""
from pathlib import Path
import torch
import numpy as np
......
"""
script to inspect and plot losses of a trained model
"""
import numpy as np
import pickle
import matplotlib.pyplot as plt
......
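The training scripts pickle the losses as a dict with keys 'training_losses', 'validation_losses' and 'lr_rates'. A plausible loading-and-plotting sketch (the loss-file path below is a placeholder, and one entry per epoch is assumed):

``` python
import pickle
import matplotlib.pyplot as plt

# placeholder path; train.py writes e.g. models/<name>/<name>_losses.p
with open('models/2022-05-12_unet/2022-05-12_unet_losses.p', 'rb') as file:
    losses = pickle.load(file)

plt.figure()
plt.plot(losses['training_losses'], label='training')
plt.plot(losses['validation_losses'], label='validation')
plt.xlabel('epoch')  # assuming one loss entry per epoch
plt.ylabel('MSE loss')
plt.legend()
plt.savefig('losses.png')
```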
"""
UNet class and functions
"""
import torch
from torch import nn
......
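The UNet class itself is elided from this diff. For reference, the standard building block of such a network after Jin et al. is a pair of 3×3 convolutions with batch normalization and ReLU; a generic sketch, not necessarily the exact block used here:

``` python
import torch
from torch import nn

class DoubleConv(nn.Module):
    """Two 3x3 convolutions with batch norm and ReLU, as in the U-Net/FBPConvNet papers."""
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)
```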
"""
FBPConvNet class and functions
"""
import torch
import torch.nn as nn
......
from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt
from model import UNet
from torchinfo import summary
from data import get_samples_from_cache
import torch.nn.functional as fun
import functools
# data locations
cache_files = {
'train':
'../data/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
'validation':
'../data/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
'test':
'../data/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}
model_path = Path("models") / "2022-05-12_unet.pt"
sample = get_samples_from_cache(cache_files=cache_files, transform=functools.partial(torch.unsqueeze, dim=0))
transform = functools.partial(torch.unsqueeze, dim=0)
input = torch.unsqueeze(sample, dim=0)
input2 = transform(sample)
# idx = np.random.randint(cache_data.shape[0])
# sample = torch.from_numpy(cache_data[idx])[None, None, :, :]
# input = fun.pad(input=sample, pad=(0, 6, 6, 0), mode='constant', value=0)
model = UNet()
model.load_state_dict(torch.load(model_path))
# model.eval()
# probe the model with a few input sizes; only 368×368 (= 16·23, divisible by
# 2^4) is expected to pass the U-Net's four pooling stages without size errors
img = torch.rand((1, 1, 362, 362))
img = torch.rand((1, 1, 358, 358))
img = torch.rand((1, 1, 368, 368))
model(img)
out = model(input)
summary(model, input_data=input)
plt.figure()
plt.imshow(out.detach().numpy()[0][0])
plt.savefig('temp.png')
\ No newline at end of file
{
"cells": [
{
"cell_type": "markdown",
"id": "fe78fece",
"metadata": {},
"source": [
"Notebook to train a model on the high performance cluster"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "84110e61",
"execution_count": null,
"id": "4b7a7e5e",
"metadata": {},
"outputs": [],
"source": [
......@@ -469,7 +477,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.9.7"
}
},
"nbformat": 4,
......
%% Cell type:code id:84110e61 tags:
%% Cell type:markdown id:fe78fece tags:
Notebook to train a model on the high performance cluster
%% Cell type:code id:4b7a7e5e tags:
``` python
import torch
from data import LodopabDataset
from trainer import Trainer
from model import UNet
from torchinfo import summary
import functools
from datetime import date
from pathlib import Path
```
%% Cell type:code id:15aaa025 tags:
``` python
# data locations
cache_files = {
'train':
'/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
'validation':
'/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
'test':
'/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}
ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
```
%% Cell type:code id:13b8a348 tags:
``` python
# small set data locations
cache_files = {
'train':
'/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_train_fbp.npy',
'validation':
'/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_validation_fbp.npy',
'test':
'/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_test_fbp.npy'}
ground_truth_data_loc = '/oneFS/daten841/kaspar01/lodopab-ct/data/'
```
%% Cell type:code id:f02e654d tags:
``` python
# output
model_output_path = Path('./models') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
epochs = 2
save_after_epochs = 1
```
%% Cell type:code id:b12a5345 tags:
``` python
# create Dataloaders
training_dataset = LodopabDataset(cache_files=cache_files,
ground_truth_data_loc=ground_truth_data_loc,
split='train',
transform=functools.partial(torch.unsqueeze, dim=0),
target_transform=functools.partial(torch.unsqueeze, dim=0))
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
batch_size=16,
shuffle=True)
validation_dataset = LodopabDataset(cache_files=cache_files,
ground_truth_data_loc=ground_truth_data_loc,
split='validation',
transform=functools.partial(torch.unsqueeze, dim=0),
target_transform=functools.partial(torch.unsqueeze, dim=0))
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
batch_size=16,
shuffle=True)
```
%%%% Output: stream
... Loading cached files for dataset: "train" ...
Done loading cached files for dataset: "train"
%%%% Output: stream
Loading Ground Truth for dataset: "train": 100%|██████████████████████████████████████████| 3/3 [00:00<00:00, 9.89it/s]
%%%% Output: stream
... Loading cached files for dataset: "validation" ...
Done loading cached files for dataset: "validation"
%%%% Output: stream
Loading Ground Truth for dataset: "validation": 100%|█████████████████████████████████████| 3/3 [00:00<00:00, 10.15it/s]
%% Cell type:code id:c31c73c8 tags:
``` python
# automatically select the available device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device.")
```
%%%% Output: stream
Using cuda device.
%% Cell type:code id:ada1e3df tags:
``` python
# model definition
model = UNet()
print('model initialized:')
summary(model)
model.to(device)
# training parameters
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),
lr=0.01,
weight_decay=1e-8)
```
%%%% Output: stream
model initialized:
%% Cell type:code id:cadd871b tags:
``` python
# initiate Trainer
trainer = Trainer(model=model,
device=torch.device(device),
criterion=criterion,
optimizer=optimizer,
training_dataloader=training_dataloader,
validation_dataloader=validation_dataloader,
lr_scheduler=None,
epochs=10,
epoch=0,
notebook=True,
model_output_path=model_output_path,
save_after_epochs=save_after_epochs)
```
%% Cell type:code id:4c1bd4e1 tags:
``` python
# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()
```
%%%% Output: display_data (repeated 16×: tqdm progress-bar widgets, not rendered)
%% Cell type:code id:bf6c39f5 tags:
``` python
torch.save(model.state_dict(), Path.cwd() / model_output_path)
```
%% Cell type:code id:5b69efe6 tags:
``` python
```
......
"""
script to train a model on the high performance cluster
"""
import torch
from data import LodopabDataset
from trainer import Trainer
......@@ -13,7 +16,7 @@ from utils import check_filepath
start = datetime.now()
print(f'Starting training at {start.strftime("%Y-%m-%d %T")}')
# data locations
# input data locations
cache_files = {
'train':
'/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
......@@ -24,12 +27,19 @@ cache_files = {
ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
# output paths
model_name = f'{date.today().strftime("%Y-%m-%d")}_unet'
model_dir = Path(f'./models/{model_name}/')
model_output_path = model_dir / f'{model_name}.pt'
loss_output_path = model_dir / f'{model_name}_losses.p'
# adjustable parameters
epochs = 500
save_after_epochs = 25
learning_rate = 0.1
weight_decay = 1e-8
batch_size = 16
# create Dataloaders
training_dataset = LodopabDataset(cache_files=cache_files,
......@@ -39,7 +49,7 @@ training_dataset = LodopabDataset(cache_files=cache_files,
target_transform=functools.partial(torch.unsqueeze, dim=0))
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
batch_size=16,
batch_size=batch_size,
shuffle=True)
validation_dataset = LodopabDataset(cache_files=cache_files,
......@@ -49,7 +59,7 @@ validation_dataset = LodopabDataset(cache_files=cache_files,
target_transform=functools.partial(torch.unsqueeze, dim=0))
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
batch_size=16,
batch_size=batch_size,
shuffle=True)
# automatically select the available device
......@@ -86,7 +96,7 @@ trainer = Trainer(model=model,
# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()
# save
# save final model and losses
torch.save(model.state_dict(), check_filepath(filepath=model_output_path, expand='num'))
losses = {'training_losses': training_losses,
'validation_losses': validation_losses,
......
......@@ -4,10 +4,16 @@ from trainer import Trainer
from model import UNet
import functools
from torchinfo import summary
from datetime import date
from datetime import date, datetime
from pathlib import Path
import pickle
# timer
start = datetime.now()
print(f'Starting training at {start.strftime("%Y-%m-%d %T")}')
# input paths
cache_files = {
'train':
'//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
......@@ -18,10 +24,16 @@ cache_files = {
ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/'
# output paths
model_output_path = Path('./models') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
loss_output_path = Path('./models/losses') / f'{date.today().strftime("%Y-%m-%d")}_unet.p'
# adjustable parameters
epochs = 2
save_after_epochs = 1
learning_rate = 0.1
weight_decay = 1e-8
batch_size = 16
# create Dataloaders
training_dataset = LodopabDataset(cache_files=cache_files,
......@@ -31,7 +43,7 @@ training_dataset = LodopabDataset(cache_files=cache_files,
target_transform=functools.partial(torch.unsqueeze, dim=0))
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
batch_size=16,
batch_size=batch_size,
shuffle=True)
validation_dataset = LodopabDataset(cache_files=cache_files,
......@@ -41,7 +53,7 @@ validation_dataset = LodopabDataset(cache_files=cache_files,
target_transform=functools.partial(torch.unsqueeze, dim=0))
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
batch_size=16,
batch_size=batch_size,
shuffle=True)
# automatically select the available device
......@@ -50,15 +62,15 @@ print(f"Using {device} device.")
# model definition
model = UNet()
print('model initialized:')
summary(model)
model.to(device)
# print('model initialized:')
# summary(model)
# training parameters
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),
lr=0.01,
weight_decay=1e-8)
lr=learning_rate,
weight_decay=weight_decay)
# initiate Trainer
trainer = Trainer(model=model,
......@@ -77,10 +89,16 @@ trainer = Trainer(model=model,
# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()
# save
# save final model and losses
torch.save(model.state_dict(), Path.cwd() / model_output_path)
losses = {'training_losses': training_losses,
'validation_losses': validation_losses,
'lr_rates': lr_rates}
with open(Path.cwd() / loss_output_path, 'wb') as file:
    pickle.dump(losses, file)
# timer
end = datetime.now()
duration = end - start
print(f'Ending training at {end.strftime("%Y-%m-%d %T")}')
print(f'Duration of training: {str(duration)}')
"""
script to train a model with small datasets
"""
import torch
from data import LodopabDataset
from trainer import Trainer
......@@ -13,7 +17,7 @@ from utils import check_filepath
start = datetime.now()
print(f'Starting training at {start.strftime("%Y-%m-%d %T")}')
# data locations
# # cluster input data locations
# cache_files = {
# 'train':
# '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
......@@ -24,7 +28,7 @@ print(f'Starting training at {start.strftime("%Y-%m-%d %T")}')
#
# ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
# local data locations
# local input data locations
cache_files = {
'train':
'//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_train_fbp.npy',
......@@ -35,12 +39,18 @@ cache_files = {
ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/data/'
# output paths
model_name = f'{date.today().strftime("%Y-%m-%d")}_unet_test'
model_dir = Path(f'./models/{model_name}/')
model_output_path = model_dir / f'{model_name}.pt'
losses_output_path = model_dir / f'{model_name}_losses.p'
# adjustable parameters
epochs = 2
save_after_epochs = 1
learning_rate = 0.1
weight_decay = 1e-8
batch_size = 16
# create Dataloaders
training_dataset = LodopabDataset(cache_files=cache_files,
......@@ -50,7 +60,7 @@ training_dataset = LodopabDataset(cache_files=cache_files,
target_transform=functools.partial(torch.unsqueeze, dim=0))
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
batch_size=16,
batch_size=batch_size,
shuffle=True)
validation_dataset = LodopabDataset(cache_files=cache_files,
......@@ -60,7 +70,7 @@ validation_dataset = LodopabDataset(cache_files=cache_files,
target_transform=functools.partial(torch.unsqueeze, dim=0))
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
batch_size=16,
batch_size=batch_size,
shuffle=True)
# automatically select the available device
......@@ -76,8 +86,8 @@ model.to(device)
# training parameters
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),
lr=0.01,
weight_decay=1e-8)
lr=learning_rate,
weight_decay=weight_decay)
# initiate Trainer
trainer = Trainer(model=model,
......@@ -96,7 +106,7 @@ trainer = Trainer(model=model,
# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()
# save
# save final model and losses
torch.save(model.state_dict(), check_filepath(filepath=model_output_path, expand='num'))
losses = {'training_losses': training_losses,
'validation_losses': validation_losses,
......
"""
Trainer class and functions
"""
import pickle
import numpy as np
import torch
from typing import Union
......@@ -9,6 +11,9 @@ from utils import check_filepath
class Trainer:
"""
Handles training and validation of a model.
"""
def __init__(self,
model: torch.nn.Module,
device: torch.device,
......@@ -53,7 +58,7 @@ class Trainer:
from tqdm import trange
epoch_progress = trange(self.epochs, desc='Epoch')
for epoch in epoch_progress:
for _ in epoch_progress:
self.epoch += 1
self._train()
......@@ -76,9 +81,7 @@ class Trainer:
losses = {'training_losses': self.training_loss,
'validation_losses': self.validation_loss,
'lr_rates': self.learning_rate}
save_loss_as = check_filepath(filepath=self.loss_output_path.with_name(
self.loss_output_path.name.replace('.p', f'_e{self.epoch:03}.pt')),
expand='num')
save_loss_as = check_filepath(filepath=self.loss_output_path, replace=True)
with open(save_loss_as, 'wb') as file:
pickle.dump(losses, file)
......@@ -117,7 +120,7 @@ class Trainer:
self.model.eval()
validation_losses = []
batch_progress = tqdm(enumerate(self.validation_dataloader), 'Validation', total=len(self.training_dataloader),
batch_progress = tqdm(enumerate(self.validation_dataloader), 'Validation', total=len(self.validation_dataloader),
leave=False)
for i, (x, y) in batch_progress:
......
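Only fragments of the Trainer body appear in this diff; a training step presumably follows the usual pattern below (a sketch consistent with the visible pieces, such as the tqdm batch progress bars, not the verbatim implementation):

``` python
import numpy as np
from tqdm import tqdm

def _train(self):
    """One training epoch: sketch consistent with the fragments above."""
    self.model.train()
    batch_losses = []
    batch_progress = tqdm(enumerate(self.training_dataloader), 'Training',
                          total=len(self.training_dataloader), leave=False)
    for i, (x, y) in batch_progress:
        x, y = x.to(self.device), y.to(self.device)   # move batch to device
        self.optimizer.zero_grad()
        loss = self.criterion(self.model(x), y)       # MSE against ground truth
        loss.backward()
        self.optimizer.step()
        batch_losses.append(loss.item())
    self.training_loss.append(np.mean(batch_losses))  # one entry per epoch
```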
"""
additional functions
"""
import os
from pathlib import Path
from datetime import datetime
......@@ -5,14 +8,14 @@ from datetime import datetime
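check_filepath from utils.py is elided here, but the call sites (expand='num' in the training scripts, replace=True in the Trainer) suggest behaviour along the following lines; this is an assumed sketch, not the actual implementation:

``` python
from pathlib import Path

def check_filepath(filepath: Path, expand: str = None, replace: bool = False) -> Path:
    """Assumed behaviour: create parent dirs and avoid clobbering existing files.

    With replace=True the path is returned as-is (overwriting is allowed);
    with expand='num' a numeric suffix is appended until the name is free.
    """
    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)
    if replace or not filepath.exists():
        return filepath
    if expand == 'num':
        i = 1
        while filepath.with_name(f'{filepath.stem}_{i}{filepath.suffix}').exists():
            i += 1
        return filepath.with_name(f'{filepath.stem}_{i}{filepath.suffix}')
    return filepath
```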