Kerstin Kaspar / Unet · Commits

Commit 3b1e170c, authored Jun 15, 2022 by Kerstin Kaspar

documentation

parent c85e5bd3
Changes 17
.idea/workspace.xml
...
...
@@ -3,10 +3,22 @@
  <component name="ChangeListManager">
    <list default="true" id="714ca32e-a67a-44cf-90a6-e7ab36748bed" name="Changes" comment="">
      <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/.ipynb_checkpoints/train-checkpoint.ipynb" beforeDir="false" afterPath="$PROJECT_DIR$/.ipynb_checkpoints/train-checkpoint.ipynb" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/README.md" beforeDir="false" afterPath="$PROJECT_DIR$/README.md" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/data.py" beforeDir="false" afterPath="$PROJECT_DIR$/data.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/data_inspect.py" beforeDir="false" afterPath="$PROJECT_DIR$/data_inspect.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/eval.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/losses.py" beforeDir="false" afterPath="$PROJECT_DIR$/losses.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/temp.png" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/temp2.png" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/model.py" beforeDir="false" afterPath="$PROJECT_DIR$/model.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/orig_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/orig_model.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/reco.py" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/train.ipynb" beforeDir="false" afterPath="$PROJECT_DIR$/train.ipynb" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_locally.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_locally.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/train_test.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_test.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/trainer.py" beforeDir="false" afterPath="$PROJECT_DIR$/trainer.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/utils.py" beforeDir="false" afterPath="$PROJECT_DIR$/utils.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/validate.py" beforeDir="false" />
    </list>
    <option name="SHOW_DIALOG" value="false" />
    <option name="HIGHLIGHT_CONFLICTS" value="true" />
...
...
@@ -26,7 +38,6 @@
  <component name="HighlightingSettingsPerFile">
    <setting file="file://$PROJECT_DIR$/trainer.py" root0="FORCE_HIGHLIGHTING" />
    <setting file="file://$PROJECT_DIR$/summary.txt" root0="FORCE_HIGHLIGHTING" />
    <setting file="file://$PROJECT_DIR$/reco.py" root0="FORCE_HIGHLIGHTING" />
    <setting file="file://$USER_HOME$/.conda/envs/nn/Lib/site-packages/torch/nn/modules/conv.py" root0="SKIP_INSPECTION" />
    <setting file="file://$PROJECT_DIR$/model.py" root0="FORCE_HIGHLIGHTING" />
    <setting file="file://$PROJECT_DIR$/utils.py" root0="FORCE_HIGHLIGHTING" />
...
...
@@ -173,15 +184,4 @@
      </map>
    </option>
  </component>
  <component name="XDebuggerManager">
    <breakpoint-manager>
      <breakpoints>
        <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
          <url>file://$PROJECT_DIR$/utils.py</url>
          <line>19</line>
          <option name="timeStamp" value="2" />
        </line-breakpoint>
      </breakpoints>
    </breakpoint-manager>
  </component>
</project>
\ No newline at end of file
.ipynb_checkpoints/train-checkpoint.ipynb
This diff is collapsed.
README.md
### UNet for Low Dose CT

This repository contains my implementation of a U-Net similar to the FBPConvNet after:

K. H. Jin, M. T. McCann, E. Froustey and M. Unser, "Deep Convolutional Neural Network for Inverse Problems in Imaging," IEEE Transactions on Image Processing, vol. 26, no. 9, pp. 4509-4522, Sept. 2017, https://doi.org/10.1109/TIP.2017.2713099

https://github.com/panakino/FBPConvNet
### Project structure

```text
Unet
├───mlruns
│       '''trained models and losses'''
├───models
│       '''trained models'''
│   └───losses
│           '''loss files of the trained models'''
├───msub
│       '''scripts for high performance cluster job management'''
│
│   conda_env_nn.yml
│       '''conda environment file (prerequisites)'''
│   data.py
│       '''functions for data loading and preparation'''
│   data_inspect.py
│       '''script to inspect single data samples (FBP and ground truth)'''
│   eval.py
│       '''script to evaluate data with a trained model'''
│   losses.py
│       '''script to inspect and plot losses of a trained model'''
│   model.py
│       '''UNet class and functions'''
│   orig_model.py
│       '''FBPConvNet class and functions'''
│   README.md
│       '''information on the repository'''
│   train.ipynb
│       '''jupyter notebook to train a model on the high performance cluster'''
│   train.py
│       '''script to train a model on the high performance cluster'''
│   trainer.py
│       '''Trainer class and functions'''
│   train_locally.py
│       '''script to train a model on my local machine'''
│   train_test.py
│       '''script to train a model with small datasets'''
│   utils.py
│       '''additional functions'''
```
\ No newline at end of file
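For orientation, here is a minimal training sketch stitched together from the interfaces that appear in this commit's diffs below. The data paths are placeholders, and the exact keyword arguments are assumptions taken from `train.py`/`train.ipynb`, not a verified API:

```python
import functools
import torch
from data import LodopabDataset
from model import UNet
from trainer import Trainer

# placeholder paths; the real caches live on the cluster or a local share
cache_files = {'train': 'cache_lodopab_train_fbp.npy',
               'validation': 'cache_lodopab_validation_fbp.npy',
               'test': 'cache_lodopab_test_fbp.npy'}

# unsqueeze adds the channel axis expected by the 2-D convolutions
make_dataset = functools.partial(LodopabDataset,
                                 cache_files=cache_files,
                                 ground_truth_data_loc='data/',
                                 transform=functools.partial(torch.unsqueeze, dim=0),
                                 target_transform=functools.partial(torch.unsqueeze, dim=0))
train_loader = torch.utils.data.DataLoader(make_dataset(split='train'), batch_size=16, shuffle=True)
val_loader = torch.utils.data.DataLoader(make_dataset(split='validation'), batch_size=16, shuffle=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = UNet().to(device)
trainer = Trainer(model=model,
                  device=torch.device(device),
                  criterion=torch.nn.MSELoss(),
                  optimizer=torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-8),
                  training_dataloader=train_loader,
                  validation_dataloader=val_loader,
                  lr_scheduler=None, epochs=2, epoch=0, notebook=False,
                  model_output_path='models/unet.pt', save_after_epochs=1)
training_losses, validation_losses, lr_rates = trainer.run_trainer()
```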
data.py
"""
functions for data loading and preparation
"""
import
math
from
pathlib
import
Path
import
numpy
as
np
import
h5py
import
torch
from
torch.utils.data
import
Dataset
from
typing
import
Union
,
List
,
Callable
,
Tuple
from
typing
import
Union
,
List
,
Callable
,
Tuple
,
Dict
from
tqdm
import
tqdm
import
torch.nn.functional
as
F
def
get_im
(
input
:
Union
[
torch
.
Tensor
,
np
.
ndarray
]):
def
get_im
(
input
:
Union
[
torch
.
Tensor
,
np
.
ndarray
]
)
->
np
.
ndarray
:
"""
returns the 2D image of a sample
...
...
@@ -20,12 +25,15 @@ def get_im(input: Union[torch.Tensor, np.ndarray]):
def get_lodopab_ground_truth(data_loc: Union[Path, str],
-                             indicator: Union[str, List[str]] = ['train', 'validation', 'test']):
+                             indicator: Union[str, List[str]] = ['train', 'validation', 'test']
+                             ) -> Union[Dict[str, np.ndarray], np.ndarray]:
    """
    Function to load the Ground Truth Lodopab data similarly to the cached data
    :param data_loc: Path to directory with all necessary data files in h5py
    :param indicator: 'train', 'test', or 'validation', or a list of several indicators to define which data is loaded
    :return data: np.ndarray of data or, if more than one indicator is given, a dictionary of the form {indicator: np.ndarray}
    """
    data_loc = Path(data_loc)
    indicator = [indicator] if type(indicator) is not list else indicator
...
...
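As an aside (not part of the commit; the path is a placeholder), the two return shapes promised by the annotation above look like this in use:

```python
# single indicator -> plain np.ndarray
val_gt = get_lodopab_ground_truth(data_loc='../data/lodopab-ct/data',
                                  indicator='validation')

# several indicators -> Dict[str, np.ndarray] keyed by indicator
gt = get_lodopab_ground_truth(data_loc='../data/lodopab-ct/data',
                              indicator=['train', 'validation'])
train_gt = gt['train']
```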
@@ -49,7 +57,8 @@ def get_lodopab_ground_truth(data_loc: Union[Path, str],
def load_cache_files(cache_files: dict,
-                     indicator: str):
+                     indicator: str) -> np.ndarray:
    """
    :param cache_files: input location - dict of form {'train': <path to .npy file>,
                                                       'validation': <path to .npy file>,
...
...
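A hedged usage sketch matching the docstring's expected dict layout (the paths are placeholders):

```python
cache_files = {'train': '../fbp/cache_lodopab_train_fbp.npy',
               'validation': '../fbp/cache_lodopab_validation_fbp.npy',
               'test': '../fbp/cache_lodopab_test_fbp.npy'}
train_fbp = load_cache_files(cache_files=cache_files, indicator='train')  # -> np.ndarray
```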
@@ -64,7 +73,8 @@ def load_cache_files(cache_files: dict,
def get_ground_truth(idx: Union[int, List[int]],
                     indicator: str,
-                     path: Union[Path, str]) -> np.ndarray:
+                     path: Union[Path, str]
+                     ) -> np.ndarray:
    if type(idx) is int:
        gt_idx = idx % 128
        gt_file_num = math.floor(idx / 128)
...
...
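The integer branch above maps a global sample index to (file, slice) coordinates, which implies the ground truth is stored as 128 slices per file; worked through for one index:

```python
import math

idx = 300
gt_file_num = math.floor(idx / 128)  # 2 -> the third ground-truth file
gt_idx = idx % 128                   # 44 -> slice 44 within that file
```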
@@ -171,10 +181,10 @@ class LodopabDataset(Dataset):
        self.inputs = F.pad(input=torch.from_numpy(self.inputs), pad=(0, 6, 6, 0), mode='constant', value=0)
        self.targets = F.pad(input=torch.from_numpy(self.targets), pad=(0, 6, 6, 0), mode='constant', value=0)

-    def __len__(self):
+    def __len__(self) -> int:
        return self.inputs.shape[0]

-    def __getitem__(self, idx):
+    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
        input = self.transform(self.inputs[idx]) if self.transform else self.inputs[idx]
        target = self.target_transform(self.targets[idx]) if self.target_transform else self.targets[idx]
        return input, target
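The `F.pad` calls above add 6 pixels on the right of the last axis and 6 on top of the second-to-last, growing each 362×362 LODoPaB slice to 368×368, presumably so the size divides cleanly through the U-Net's repeated 2× down-sampling. A quick shape check under that assumption:

```python
import torch
import torch.nn.functional as F

x = torch.zeros(10, 362, 362)  # a small stack of slices
x = F.pad(input=x, pad=(0, 6, 6, 0), mode='constant', value=0)
print(x.shape)       # torch.Size([10, 368, 368])
print(368 % 2**4)    # 0 -> survives four pooling stages intact
```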
data_inspect.py
"""
script to inspect single data samples (FBP and ground truth)
"""
import
torch
from
data
import
get_samples_from_cache
,
get_im
from
torchvision
import
transforms
...
...
eval.py
"""
script to evaluate data with a trained model
"""
from
pathlib
import
Path
import
torch
import
numpy
as
np
...
...
losses.py
"""
script to inspect and plot losses of a trained model
"""
import
numpy
as
np
import
pickle
import
matplotlib.pyplot
as
plt
...
...
model.py
"""
UNet class and functions
"""
from
turtle
import
forward
import
torch
from
torch
import
nn
...
...
orig_model.py
"""
FBPConvNet class and functions
"""
import
torch
import
torch.nn
as
nn
...
...
reco.py (deleted, 100644 → 0)
from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt
from model import UNet
from torchinfo import summary
from data import get_samples_from_cache
import torch.nn.functional as fun
import functools

# data locations
cache_files = {'train': '../data/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
               'validation': '../data/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
               'test': '../data/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}
model_path = Path("models") / "2022-05-12_unet.pt"

sample = get_samples_from_cache(cache_files=cache_files,
                                transform=functools.partial(torch.unsqueeze, dim=0))
transform = functools.partial(torch.unsqueeze, dim=0)
input = torch.unsqueeze(sample, dim=0)
input2 = transform(sample)
# idx = np.random.randint(cache_data.shape[0])
# sample = torch.from_numpy(cache_data[idx])[None, None, :, :]
# input = fun.pad(input=sample, pad=(0, 6, 6, 0), mode='constant', value=0)

model = UNet()
model.load_state_dict(torch.load(model_path))
# model.eval()

img = torch.rand((1, 1, 362, 362))
img = torch.rand((1, 1, 358, 358))
img = torch.rand((1, 1, 368, 368))
model(img)

out = model(input)
summary(model, input_data=input)

plt.figure()
plt.imshow(out.detach().numpy()[0][0])
plt.savefig('temp.png')
\ No newline at end of file
train.ipynb
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fe78fece",
   "metadata": {},
   "source": [
    "Notebook to train a model on the high performance cluster"
   ]
  },
  {
   "cell_type": "code",
-   "execution_count": 1,
-   "id": "84110e61",
+   "execution_count": null,
+   "id": "4b7a7e5e",
   "metadata": {},
   "outputs": [],
   "source": [
...
@@ -469,7 +477,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.8.13"
+   "version": "3.9.7"
  }
 },
 "nbformat": 4,
...
...
-%% Cell type:code id:84110e61 tags:
+%% Cell type:markdown id:fe78fece tags:

Notebook to train a model on the high performance cluster

%% Cell type:code id:4b7a7e5e tags:
```python
import torch
from data import LodopabDataset
from trainer import Trainer
from model import UNet
from torchinfo import summary
import functools
from datetime import date
from pathlib import Path
```
%% Cell type:code id:15aaa025 tags:

```python
# data locations
cache_files = {'train': '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
               'validation': '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
               'test': '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}
ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
```
%% Cell type:code id:13b8a348 tags:

```python
# small set data locations
cache_files = {'train': '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_train_fbp.npy',
               'validation': '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_validation_fbp.npy',
               'test': '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_test_fbp.npy'}
ground_truth_data_loc = '/oneFS/daten841/kaspar01/lodopab-ct/data/'
```
%% Cell type:code id:f02e654d tags:

```python
# output
model_output_path = Path('./models') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'

epochs = 2
save_after_epochs = 1
```
%% Cell type:code id:b12a5345 tags:

```python
# create Dataloaders
training_dataset = LodopabDataset(cache_files=cache_files,
                                  ground_truth_data_loc=ground_truth_data_loc,
                                  split='train',
                                  transform=functools.partial(torch.unsqueeze, dim=0),
                                  target_transform=functools.partial(torch.unsqueeze, dim=0))
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
                                                  batch_size=16,
                                                  shuffle=True)

validation_dataset = LodopabDataset(cache_files=cache_files,
                                    ground_truth_data_loc=ground_truth_data_loc,
                                    split='validation',
                                    transform=functools.partial(torch.unsqueeze, dim=0),
                                    target_transform=functools.partial(torch.unsqueeze, dim=0))
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                    batch_size=16,
                                                    shuffle=True)
```
%%%% Output: stream
... Loading cached files for dataset: "train" ...
Done loading cached files for dataset: "train"
%%%% Output: stream
Loading Ground Truth for dataset: "train": 100%|██████████████████████████████████████████| 3/3 [00:00<00:00, 9.89it/s]
%%%% Output: stream
... Loading cached files for dataset: "validation" ...
Done loading cached files for dataset: "validation"
%%%% Output: stream
Loading Ground Truth for dataset: "validation": 100%|█████████████████████████████████████| 3/3 [00:00<00:00, 10.15it/s]
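Side note: the `functools.partial(torch.unsqueeze, dim=0)` transforms above insert the channel axis that 2-D convolutions expect, so a padded (368, 368) slice becomes (1, 368, 368) and the loader yields (16, 1, 368, 368) batches. A minimal check of that assumption:

```python
import functools
import torch

to_chw = functools.partial(torch.unsqueeze, dim=0)
print(to_chw(torch.zeros(368, 368)).shape)  # torch.Size([1, 368, 368])
```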
%% Cell type:code id:c31c73c8 tags:

```python
# auto define the correct device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device.")
```
%%%% Output: stream
Using cuda device.
%% Cell type:code id:ada1e3df tags:

```python
# model definition
model = UNet()
print('model initialized:')
summary(model)
model.to(device)

# training parameters
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-8)
```
%%%% Output: stream
model initialized:
%% Cell type:code id:cadd871b tags:

```python
# initiate Trainer
trainer = Trainer(model=model,
                  device=torch.device(device),
                  criterion=criterion,
                  optimizer=optimizer,
                  training_dataloader=training_dataloader,
                  validation_dataloader=validation_dataloader,
                  lr_scheduler=None,
                  epochs=10,
                  epoch=0,
                  notebook=True,
                  model_output_path=model_output_path,
                  save_after_epochs=save_after_epochs)
```
%% Cell type:code id:4c1bd4e1 tags:

```python
# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()
```
%%%% Output: display_data
%% Cell type:code id:bf6c39f5 tags:

```python
torch.save(model.state_dict(), Path.cwd() / model_path)
```
%% Cell type:code id:5b69efe6 tags:

```python
```
...
...
train.py
"""
script to train a model on the high performance cluster
"""
import
torch
from
data
import
LodopabDataset
from
trainer
import
Trainer
...
...
@@ -13,7 +16,7 @@ from utils import check_filepath
start
=
datetime
.
now
()
print
(
f
'Starting training at
{
start
.
strftime
(
"%Y-%m-%d %T"
)
}
'
)
# data locations
#
input
data locations
cache_files
=
{
'train'
:
'/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy'
,
...
...
@@ -24,12 +27,19 @@ cache_files = {
ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'

# output paths
model_name = f'{date.today().strftime("%Y-%m-%d")}_unet'
model_dir = Path(f'./models/{model_name}/')
model_output_path = model_dir / f'{model_name}.pt'
loss_output_path = model_dir / f'{model_name}_losses.p'

# adjustable parameters
epochs = 500
save_after_epochs = 25
learning_rate = 0.1
weight_decay = 1e-8
batch_size = 16
# create Dataloaders
training_dataset = LodopabDataset(cache_files=cache_files,
...
...
@@ -39,7 +49,7 @@ training_dataset = LodopabDataset(cache_files=cache_files,
                                  target_transform=functools.partial(torch.unsqueeze, dim=0))
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
-                                                  batch_size=16,
+                                                  batch_size=batch_size,
                                                   shuffle=True)
validation_dataset = LodopabDataset(cache_files=cache_files,
...
...
@@ -49,7 +59,7 @@ validation_dataset = LodopabDataset(cache_files=cache_files,
                                    target_transform=functools.partial(torch.unsqueeze, dim=0))
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
-                                                    batch_size=16,
+                                                    batch_size=batch_size,
                                                     shuffle=True)

# auto define the correct device
...
...
@@ -86,7 +96,7 @@ trainer = Trainer(model=model,
# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()

-# save
+# save final model and losses
torch.save(model.state_dict(), check_filepath(filepath=model_output_path, expand='num'))
losses = {'training_losses': training_losses,
          'validation_losses': validation_losses,
...
...
train_locally.py
...
...
@@ -4,10 +4,16 @@ from trainer import Trainer
from model import UNet
import functools
from torchinfo import summary
-from datetime import date
+from datetime import date, datetime
from pathlib import Path
import pickle

# timer
start = datetime.now()
print(f'Starting training at {start.strftime("%Y-%m-%d %T")}')

# input paths
cache_files = {
    'train': '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
...
...
@@ -18,10 +24,16 @@ cache_files = {
ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/'

# output paths
model_output_path = Path('./models') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
loss_output_path = Path('./models/losses') / f'{date.today().strftime("%Y-%m-%d")}_unet.p'

# adjustable parameters
epochs = 2
save_after_epochs = 1
learning_rate = 0.1
weight_decay = 1e-8
batch_size = 16
# create Dataloaders
training_dataset = LodopabDataset(cache_files=cache_files,
...
...
@@ -31,7 +43,7 @@ training_dataset = LodopabDataset(cache_files=cache_files,
                                  target_transform=functools.partial(torch.unsqueeze, dim=0))
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
-                                                  batch_size=16,
+                                                  batch_size=batch_size,
                                                   shuffle=True)
validation_dataset = LodopabDataset(cache_files=cache_files,
...
...
@@ -41,7 +53,7 @@ validation_dataset = LodopabDataset(cache_files=cache_files,
                                    target_transform=functools.partial(torch.unsqueeze, dim=0))
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
-                                                    batch_size=16,
+                                                    batch_size=batch_size,
                                                     shuffle=True)

# auto define the correct device
...
...
@@ -50,15 +62,15 @@ print(f"Using {device} device.")
# model definition
model = UNet()
-print('model initialized:')
-summary(model)
model.to(device)
+# print('model initialized:')
+# summary(model)

# training parameters
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),
-                            lr=0.01,
-                            weight_decay=1e-8)
+                            lr=learning_rate,
+                            weight_decay=weight_decay)
# initiate Trainer
trainer = Trainer(model=model,
...
...
@@ -77,10 +89,16 @@ trainer = Trainer(model=model,
# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()

-# save
+# save final model and losses
torch.save(model.state_dict(), Path.cwd() / model_output_path)
losses = {'training_losses': training_losses,
          'validation_losses': validation_losses,
          'lr_rates': lr_rates}
with open(Path.cwd() / loss_output_path, 'w') as file:  # note: pickle.dump needs binary mode ('wb'); text mode fails at dump time
    pickle.dump(training_losses, file)

# timer
end = datetime.now()
duration = end - start
print(f'Ending training at {end.strftime("%Y-%m-%d %T")}')
print(f'Duration of training: {str(duration)}')
train_test.py
"""
script to train a model with small datasets
"""
import
torch
from
data
import
LodopabDataset
from
trainer
import
Trainer
...
...
@@ -13,7 +17,7 @@ from utils import check_filepath
start = datetime.now()
print(f'Starting training at {start.strftime("%Y-%m-%d %T")}')

-# data locations
+# cluster input data locations
# cache_files = {
#     'train':
#     '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
...
...
@@ -24,7 +28,7 @@ print(f'Starting training at {start.strftime("%Y-%m-%d %T")}')
#
# ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'

-# local data locations
+# local input data locations
cache_files = {
    'train': '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_train_fbp.npy',
...
...
@@ -35,12 +39,18 @@ cache_files = {
ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/data/'

# output paths
model_name = f'{date.today().strftime("%Y-%m-%d")}_unet_test'
model_dir = Path(f'./models/{model_name}/')
model_output_path = model_dir / f'{model_name}.pt'
losses_output_path = model_dir / f'{model_name}_losses.p'

# adjustable parameters
epochs = 2
save_after_epochs = 1
learning_rate = 0.1
weight_decay = 1e-8
batch_size = 16
# create Dataloaders
training_dataset = LodopabDataset(cache_files=cache_files,
...
...
@@ -50,7 +60,7 @@ training_dataset = LodopabDataset(cache_files=cache_files,
                                  target_transform=functools.partial(torch.unsqueeze, dim=0))
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
-                                                  batch_size=16,
+                                                  batch_size=batch_size,
                                                   shuffle=True)
validation_dataset = LodopabDataset(cache_files=cache_files,
...
...
@@ -60,7 +70,7 @@ validation_dataset = LodopabDataset(cache_files=cache_files,
                                    target_transform=functools.partial(torch.unsqueeze, dim=0))
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
-                                                    batch_size=16,
+                                                    batch_size=batch_size,
                                                     shuffle=True)

# auto define the correct device
...
...
@@ -76,8 +86,8 @@ model.to(device)
# training parameters
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),
-                            lr=0.01,
-                            weight_decay=1e-8)
+                            lr=learning_rate,
+                            weight_decay=weight_decay)
# initiate Trainer
trainer = Trainer(model=model,
...
...
@@ -96,7 +106,7 @@ trainer = Trainer(model=model,
# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()

-# save
+# save final model and losses
torch.save(model.state_dict(), check_filepath(filepath=model_output_path, expand='num'))
losses = {'training_losses': training_losses,
          'validation_losses': validation_losses,
...
...
trainer.py
"""
Trainer class and functions
"""
import
pickle
import
numpy
as
np
import
torch
from
typing
import
Union
...
...
@@ -9,6 +11,9 @@ from utils import check_filepath
class Trainer:
+    """
+    for training and validation
+    """
    def __init__(self,
                 model: torch.nn.Module,
                 device: torch.device,
...
...
@@ -53,7 +58,7 @@ class Trainer:
        from tqdm import trange

        epoch_progress = trange(self.epochs, desc='Epoch')
-        for epoch in epoch_progress:
+        for _ in epoch_progress:
            self.epoch += 1
            self._train()
...
...
@@ -76,9 +81,7 @@ class Trainer:
            losses = {'training_losses': self.training_loss,
                      'validation_losses': self.validation_loss,
                      'lr_rates': self.learning_rate}
-            save_loss_as = check_filepath(filepath=self.loss_output_path.with_name(
-                self.loss_output_path.name.replace('.p', f'_e{self.epoch:03}.pt')),
-                expand='num')
+            save_loss_as = check_filepath(filepath=self.loss_output_path, replace=True)
            with open(save_loss_as, 'wb') as file:
                pickle.dump(losses, file)
...
...
@@ -117,7 +120,7 @@ class Trainer:
        self.model.eval()
        validation_losses = []
        batch_progress = tqdm(enumerate(self.validation_dataloader), 'Validation',
-                              total=len(self.training_dataloader),
+                              total=len(self.validation_dataloader),
                               leave=False)
        for i, (x, y) in batch_progress:
...
...
utils.py
"""
additional functions
"""
import
os
from
pathlib
import
Path
from
datetime
import
datetime
...
...
@@ -5,14 +8,14 @@ from datetime import datetime