# train.py
"""Train a UNet on LoDoPaB-CT FBP reconstructions and save its weights.

The script loads the cached FBP reconstructions together with the ground-truth
data for the 'train' and 'validation' splits. It trains ``UNet`` with an MSE
loss and SGD through ``Trainer``. It then writes the final ``state_dict`` to
``models/<YYYY-MM-DD>_unet.pt``, relative to the current working directory.
"""
import functools
from datetime import date
from pathlib import Path

import torch
from torchinfo import summary

from data import LodopabDataset
from trainer import Trainer
from model import UNet


# data locations
# Cached FBP reconstructions, one .npy file per dataset split.
cache_files = {
    'train':
        '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
    'validation':
        '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
    'test':
        '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy',
}

ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'

# Single output location for the trained weights. It lives under ``models/``
# (where the weights were actually being saved) rather than ``model/``, which
# is easily confused with the local ``model`` module imported above.
model_path = Path('models') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'

# training hyper-parameters
batch_size = 16
learning_rate = 0.01
weight_decay = 1e-8
epochs = 10

# Inputs and targets both get a new leading axis of size 1. This is presumably
# the channel axis UNet expects — TODO confirm against UNet / LodopabDataset.
add_leading_dim = functools.partial(torch.unsqueeze, dim=0)


def _make_dataset(split):
    """Return the LodopabDataset for ``split`` (a key of ``cache_files``).

    The same unsqueeze transform is applied to both the input and the target.
    """
    return LodopabDataset(cache_files=cache_files,
                          ground_truth_data_loc=ground_truth_data_loc,
                          split=split,
                          transform=add_leading_dim,
                          target_transform=add_leading_dim)


# create Dataloaders
training_dataset = _make_dataset('train')
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=True)

validation_dataset = _make_dataset('validation')
# Shuffling is not needed to evaluate a validation loss, and a fixed order keeps
# validation runs reproducible. This assumes Trainer iterates the whole
# validation loader — TODO confirm.
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                    batch_size=batch_size,
                                                    shuffle=False)

# model definition
model = UNet()
print('model initialized:')
summary(model)

# loss and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),
                            lr=learning_rate,
                            weight_decay=weight_decay)

# auto define the correct device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device.")

# torch.save does not create missing parent directories. Create the output
# folder before training so the weights are not lost after a long run.
model_path.parent.mkdir(parents=True, exist_ok=True)

# initiate Trainer
# NOTE(review): the model is not moved to `device` here; presumably Trainer
# does this with the `device` argument — confirm.
trainer = Trainer(model=model,
                  device=torch.device(device),
                  criterion=criterion,
                  optimizer=optimizer,
                  training_dataloader=training_dataloader,
                  validation_dataloader=validation_dataloader,
                  lr_scheduler=None,
                  epochs=epochs,
                  epoch=0,
                  notebook=False)

# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()
torch.save(model.state_dict(), model_path)