Skip to content
Snippets Groups Projects
Commit 6749dc4b authored by Kerstin Kaspar's avatar Kerstin Kaspar
Browse files

cluster train py

parent 23b326f5
No related branches found
No related tags found
No related merge requests found
...@@ -2,55 +2,63 @@ import torch ...@@ -2,55 +2,63 @@ import torch
from data import LodopabDataset from data import LodopabDataset
from trainer import Trainer from trainer import Trainer
from model import UNet from model import UNet
from torchinfo import summary
import functools
from datetime import date
from pathlib import Path
# data locations
cache_files = { cache_files = {
'train': 'train':
'//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy', '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
'validation': 'validation':
'//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy', '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
'test': 'test':
'//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'} '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}
ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/'
# # cluster ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
# cache_files = {
# 'train':
# '/oneFS/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
# 'validation':
# '/oneFS/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
# 'test':
# '/oneFS/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}
#
# ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data/'
# auto define the correct device model_path = Path('./model') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device.")
training_dataset = LodopabDataset(cache_files=cache_files, # create Dataloaders
training_dataset = LodopabDataset(cache_files=small_cache_files,
ground_truth_data_loc=ground_truth_data_loc, ground_truth_data_loc=ground_truth_data_loc,
split='train') split='train',
transform=functools.partial(torch.unsqueeze, dim=0),
target_transform=functools.partial(torch.unsqueeze, dim=0))
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset, training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
batch_size=16, batch_size=16,
shuffle=True) shuffle=True)
validation_dataset = LodopabDataset(cache_files=cache_files, validation_dataset = LodopabDataset(cache_files=small_cache_files,
ground_truth_data_loc=ground_truth_data_loc, ground_truth_data_loc=ground_truth_data_loc,
split='validation') split='validation',
transform=functools.partial(torch.unsqueeze, dim=0),
target_transform=functools.partial(torch.unsqueeze, dim=0))
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
batch_size=16, batch_size=16,
shuffle=True) shuffle=True)
# model defition
model = UNet() model = UNet()
print('model initialized:')
summary(model)
# training parameters
criterion = torch.nn.MSELoss() criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), optimizer = torch.optim.SGD(model.parameters(),
lr=0.01, lr=0.01,
weight_decay=1e-8) weight_decay=1e-8)
# auto define the correct device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device.")
# initiate Trainer
trainer = Trainer(model=model, trainer = Trainer(model=model,
device=torch.device(device), device=torch.device(device),
criterion=criterion, criterion=criterion,
...@@ -58,6 +66,11 @@ trainer = Trainer(model=model, ...@@ -58,6 +66,11 @@ trainer = Trainer(model=model,
training_dataloader=training_dataloader, training_dataloader=training_dataloader,
validation_dataloader=validation_dataloader, validation_dataloader=validation_dataloader,
lr_scheduler=None, lr_scheduler=None,
epochs=2, epochs=10,
epoch=0, epoch=0,
notebook=False) notebook=False)
# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()
model_name = f'models/{date.today().strftime("%Y-%m-%d")}_unet.pt'
torch.save(model.state_dict(), Path.cwd() / model_name)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment