Skip to content
Snippets Groups Projects
Commit be480346 authored by Kerstin Kaspar's avatar Kerstin Kaspar
Browse files

train

parent 6749dc4b
No related branches found
No related tags found
No related merge requests found
...@@ -8,15 +8,19 @@ from datetime import date ...@@ -8,15 +8,19 @@ from datetime import date
from pathlib import Path from pathlib import Path
small_cache_files = { # data locations
cache_files = {
'train': 'train':
'//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_train_fbp.npy', '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
'validation': 'validation':
'//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_validation_fbp.npy', '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
'test': 'test':
'//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_test_fbp.npy'} '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}
ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/data/' ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
model_path = Path('./model') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
# create Dataloaders # create Dataloaders
training_dataset = LodopabDataset(cache_files=small_cache_files, training_dataset = LodopabDataset(cache_files=small_cache_files,
...@@ -62,7 +66,7 @@ trainer = Trainer(model=model, ...@@ -62,7 +66,7 @@ trainer = Trainer(model=model,
training_dataloader=training_dataloader, training_dataloader=training_dataloader,
validation_dataloader=validation_dataloader, validation_dataloader=validation_dataloader,
lr_scheduler=None, lr_scheduler=None,
epochs=2, epochs=10,
epoch=0, epoch=0,
notebook=False) notebook=False)
......
...@@ -9,21 +9,33 @@ from pathlib import Path ...@@ -9,21 +9,33 @@ from pathlib import Path
# data locations # data locations
# cache_files = {
# 'train':
# '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
# 'validation':
# '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
# 'test':
# '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}
#
# ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
# test
# cluster
cache_files = { cache_files = {
'train': 'train':
'/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy', '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_train_fbp.npy',
'validation': 'validation':
'/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy', '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_validation_fbp.npy',
'test': 'test':
'/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'} '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_test_fbp.npy'}
ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data' ground_truth_data_loc = '/oneFS/daten841/kaspar01/lodopab-ct/data/'
model_path = Path('./model') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt' model_path = Path('./model') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
# create Dataloaders # create Dataloaders
training_dataset = LodopabDataset(cache_files=small_cache_files, training_dataset = LodopabDataset(cache_files=cache_files,
ground_truth_data_loc=ground_truth_data_loc, ground_truth_data_loc=ground_truth_data_loc,
split='train', split='train',
transform=functools.partial(torch.unsqueeze, dim=0), transform=functools.partial(torch.unsqueeze, dim=0),
...@@ -33,7 +45,7 @@ training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset, ...@@ -33,7 +45,7 @@ training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
batch_size=16, batch_size=16,
shuffle=True) shuffle=True)
validation_dataset = LodopabDataset(cache_files=small_cache_files, validation_dataset = LodopabDataset(cache_files=cache_files,
ground_truth_data_loc=ground_truth_data_loc, ground_truth_data_loc=ground_truth_data_loc,
split='validation', split='validation',
transform=functools.partial(torch.unsqueeze, dim=0), transform=functools.partial(torch.unsqueeze, dim=0),
......
import torch
from data import LodopabDataset
from trainer import Trainer
from model import UNet
cache_files = {
'train':
'//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
'validation':
'//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
'test':
'//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}
ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/'
# auto define the correct device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device.")
training_dataset = LodopabDataset(cache_files=cache_files,
ground_truth_data_loc=ground_truth_data_loc,
split='train')
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
batch_size=16,
shuffle=True)
validation_dataset = LodopabDataset(cache_files=cache_files,
ground_truth_data_loc=ground_truth_data_loc,
split='validation')
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
batch_size=16,
shuffle=True)
model = UNet()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),
lr=0.01,
weight_decay=1e-8)
trainer = Trainer(model=model,
device=torch.device(device),
criterion=criterion,
optimizer=optimizer,
training_dataloader=training_dataloader,
validation_dataloader=validation_dataloader,
lr_scheduler=None,
epochs=2,
epoch=0,
notebook=False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment