# train_locally.py — entry point for training the UNet locally on LoDoPaB-CT.
"""Train a UNet on the LoDoPaB-CT FBP cache and persist the model and losses."""

import functools
import pickle
from datetime import date
from pathlib import Path

import torch
from torchinfo import summary

from data import LodopabDataset
from model import UNet
from trainer import Trainer

# Pre-computed filtered-back-projection (FBP) caches, one per dataset split.
cache_files = {
    'train':
        '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
    'validation':
        '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
    'test':
        '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}

ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/'

# Date-stamped output paths so repeated runs do not overwrite earlier results.
model_output_path = Path('./models') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
loss_output_path = Path('./models/losses') / f'{date.today().strftime("%Y-%m-%d")}_unet.p'
epochs = 2
save_after_epochs = 1

# Ensure the output directories exist before training starts; otherwise the
# final torch.save / pickle.dump would fail after hours of work.
model_output_path.parent.mkdir(parents=True, exist_ok=True)
loss_output_path.parent.mkdir(parents=True, exist_ok=True)

# The dataset yields 2-D images; unsqueeze(dim=0) adds the channel dimension
# the UNet expects. One shared callable instead of four identical partials.
add_channel_dim = functools.partial(torch.unsqueeze, dim=0)

# create Dataloaders
training_dataset = LodopabDataset(cache_files=cache_files,
                                  ground_truth_data_loc=ground_truth_data_loc,
                                  split='train',
                                  transform=add_channel_dim,
                                  target_transform=add_channel_dim)

training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
                                                  batch_size=16,
                                                  shuffle=True)

validation_dataset = LodopabDataset(cache_files=cache_files,
                                    ground_truth_data_loc=ground_truth_data_loc,
                                    split='validation',
                                    transform=add_channel_dim,
                                    target_transform=add_channel_dim)

# NOTE(review): shuffling the validation set is unusual (metrics are unaffected,
# but batches are not reproducible across runs) — confirm this is intended.
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                    batch_size=16,
                                                    shuffle=True)

# auto define the correct device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device.")

# model definition
model = UNet()
print('model initialized:')
summary(model)
model.to(device)

# training parameters
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),
                            lr=0.01,
                            weight_decay=1e-8)

# initiate Trainer
trainer = Trainer(model=model,
                  device=torch.device(device),
                  criterion=criterion,
                  optimizer=optimizer,
                  training_dataloader=training_dataloader,
                  validation_dataloader=validation_dataloader,
                  lr_scheduler=None,
                  epochs=epochs,
                  epoch=0,
                  notebook=False,
                  model_output_path=model_output_path,
                  save_after_epochs=save_after_epochs)

# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()

# save final weights and the complete loss history
torch.save(model.state_dict(), Path.cwd() / model_output_path)
losses = {'training_losses': training_losses,
          'validation_losses': validation_losses,
          'lr_rates': lr_rates}
# BUG FIX: pickle needs a binary-mode handle ('wb', not 'w'), and the full
# `losses` dict must be dumped — the original dumped only `training_losses`,
# silently discarding validation losses and learning rates.
with open(Path.cwd() / loss_output_path, 'wb') as file:
    pickle.dump(losses, file)