Newer
Older
import torch
from data import LodopabDataset
from trainer import Trainer
from model import UNet
import functools
from torchinfo import summary
from datetime import date
from pathlib import Path
import pickle
# Root of the LoDoPaB-CT data on the HPC share. It is used directly as the
# ground-truth location, and the cached reconstructions live under fbp/.
lodopab_root = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/'
# One cached .npy input file per split (presumably filtered back-projection
# reconstructions, judging by the "fbp" naming — confirm).
cache_files = {split: f'{lodopab_root}fbp/cache_lodopab_{split}_fbp.npy'
               for split in ('train', 'validation', 'test')}
ground_truth_data_loc = lodopab_root
# Evaluate the date once so the model file and the loss file always carry the
# same stamp, even if the script is started right around midnight.
run_date = date.today().strftime("%Y-%m-%d")
model_output_path = Path('./models') / f'{run_date}_unet.pt'
loss_output_path = Path('./models/losses') / f'{run_date}_unet.p'
# Training length and the checkpoint interval handed to Trainer
# (presumably "save every N epochs" — confirm against trainer.py).
epochs = 2
save_after_epochs = 1
# Inputs and targets both get a new leading axis of size 1
# (torch.unsqueeze(x, dim=0)), e.g. a channel axis for the network.
_add_leading_dim = functools.partial(torch.unsqueeze, dim=0)


def _make_split(split, shuffle, batch_size=16):
    """Build the LodopabDataset for `split` and wrap it in a DataLoader.

    Args:
        split: dataset split name passed to LodopabDataset ('train', 'validation', ...).
        shuffle: whether the DataLoader reshuffles the samples every epoch.
        batch_size: samples per batch.

    Returns:
        (dataset, dataloader) tuple.
    """
    dataset = LodopabDataset(cache_files=cache_files,
                             ground_truth_data_loc=ground_truth_data_loc,
                             split=split,
                             transform=_add_leading_dim,
                             target_transform=_add_leading_dim)
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         shuffle=shuffle)
    return dataset, loader


training_dataset, training_dataloader = _make_split('train', shuffle=True)
# Shuffling only matters for optimisation. A fixed validation order keeps the
# per-epoch validation losses reproducible and comparable.
validation_dataset, validation_dataloader = _make_split('validation', shuffle=False)
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device.")
# model definition
# The script previously used `model` without ever creating it (NameError).
# NOTE(review): UNet's constructor lives in model.py and is not visible here.
# Confirm that its defaults suit this data, or pass the intended arguments
# (the transforms add a leading size-1 axis to inputs and targets, so
# presumably one input and one output channel).
model = UNet()
print('model initialized:')
summary(model)
model.to(device)
# training parameters
# Mean-squared error between the network output and the target image.
criterion = torch.nn.MSELoss()
# Plain SGD. weight_decay adds a (very small) L2 penalty on the parameters.
optimizer = torch.optim.SGD(model.parameters(),
                            lr=0.01,
                            weight_decay=1e-8)
trainer = Trainer(model=model,
                  device=torch.device(device),
                  criterion=criterion,
                  optimizer=optimizer,
                  training_dataloader=training_dataloader,
                  validation_dataloader=validation_dataloader,
                  lr_scheduler=None,
                  # `epochs` was configured but never handed to the trainer,
                  # so it had no effect on training.
                  # NOTE(review): assumes Trainer.__init__ accepts `epochs` — confirm.
                  epochs=epochs,
                  notebook=False,
                  model_output_path=model_output_path,
                  save_after_epochs=save_after_epochs)
# start training; run_trainer returns the loss and learning-rate histories
training_losses, validation_losses, lr_rates = trainer.run_trainer()
# save: final model weights plus the loss / learning-rate history of this run
model_path = Path.cwd() / model_output_path
loss_path = Path.cwd() / loss_output_path
# ./models and ./models/losses may not exist yet on a fresh checkout.
model_path.parent.mkdir(parents=True, exist_ok=True)
loss_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)
losses = {'training_losses': training_losses,
          'validation_losses': validation_losses,
          'lr_rates': lr_rates}
# pickle writes bytes, so the file must be opened in binary mode ('wb'; text
# mode raises TypeError). Store the whole history dict, not just the
# training curve.
with open(loss_path, 'wb') as file:
    pickle.dump(losses, file)