Skip to content
Snippets Groups Projects
train.py 3.59 KiB
Newer Older
Kerstin Kaspar's avatar
Kerstin Kaspar committed
import torch
Kerstin Kaspar's avatar
Kerstin Kaspar committed
from data import LodopabDataset
from trainer import Trainer
from model import UNet
Kerstin Kaspar's avatar
Kerstin Kaspar committed
from torchinfo import summary
import functools
Kerstin Kaspar's avatar
Kerstin Kaspar committed
from datetime import date, datetime
Kerstin Kaspar's avatar
Kerstin Kaspar committed
from pathlib import Path
Kerstin Kaspar's avatar
Kerstin Kaspar committed
import pickle
Kerstin Kaspar's avatar
Kerstin Kaspar committed
from utils import check_filepath
Kerstin Kaspar's avatar
Kerstin Kaspar committed

Kerstin Kaspar's avatar
Kerstin Kaspar committed
# timer
# "%T" is a platform-specific strftime extension that is not among Python's
# documented format codes; "%H:%M:%S" is portable and yields the same text.
TIMESTAMP_FORMAT = "%Y-%m-%d  %H:%M:%S"
start = datetime.now()  # wall-clock start, used to report the run duration
print(f'Starting training at {start.strftime(TIMESTAMP_FORMAT)}')
Kerstin Kaspar's avatar
Kerstin Kaspar committed

# data locations
# one cached .npy file per dataset split; the dict is handed to LodopabDataset
cache_files = {
    'train':
        '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
    'validation':
        '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
    'test':
        '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy',
}

# directory passed to LodopabDataset as ground_truth_data_loc
ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
Kerstin Kaspar's avatar
Kerstin Kaspar committed

Kerstin Kaspar's avatar
Kerstin Kaspar committed
# output naming: all artefacts of this run go to ./models/<YYYY-MM-DD>_unet/
model_name = f'{date.today().strftime("%Y-%m-%d")}_unet'
model_dir = Path('./models') / model_name
model_output_path = model_dir / f'{model_name}.pt'        # model weights
loss_output_path = model_dir / f'{model_name}_losses.p'   # pickled loss history

# training schedule; both values are forwarded to the Trainer
epochs = 100
save_after_epochs = 10  # presumably the checkpoint interval in epochs - confirm in Trainer
Kerstin Kaspar's avatar
Kerstin Kaspar committed

Kerstin Kaspar's avatar
Kerstin Kaspar committed
# create Dataloaders
# Inputs and targets each get a leading singleton axis via torch.unsqueeze(dim=0)
# (presumably the channel axis the UNet expects - confirm against the model).
training_dataset = LodopabDataset(
    cache_files=cache_files,
    ground_truth_data_loc=ground_truth_data_loc,
    split='train',
    transform=functools.partial(torch.unsqueeze, dim=0),
    target_transform=functools.partial(torch.unsqueeze, dim=0),
)

# training batches are reshuffled on every pass over the dataset
training_dataloader = torch.utils.data.DataLoader(
    dataset=training_dataset,
    batch_size=16,
    shuffle=True,
)

Kerstin Kaspar's avatar
Kerstin Kaspar committed
# Validation split, prepared with the same unsqueeze(dim=0) transforms as the
# training split so both loaders deliver identically shaped samples.
validation_dataset = LodopabDataset(
    cache_files=cache_files,
    ground_truth_data_loc=ground_truth_data_loc,
    split='validation',
    transform=functools.partial(torch.unsqueeze, dim=0),
    target_transform=functools.partial(torch.unsqueeze, dim=0),
)

# No shuffling for evaluation: sample order does not help validation, and a
# fixed order keeps the batch composition (and thus per-batch losses)
# reproducible from epoch to epoch.
validation_dataloader = torch.utils.data.DataLoader(
    dataset=validation_dataset,
    batch_size=16,
    shuffle=False,
)
Kerstin Kaspar's avatar
Kerstin Kaspar committed

Kerstin Kaspar's avatar
Kerstin Kaspar committed
# choose the compute device: CUDA when torch can see a GPU, CPU otherwise
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device.")

Kerstin Kaspar's avatar
Kerstin Kaspar committed
# model definition
model = UNet()
print('model initialized:')
summary(model)  # torchinfo overview of the freshly built network
model.to(device)  # move the parameters to the device selected above
Kerstin Kaspar's avatar
Kerstin Kaspar committed

# training parameters
# mean-squared-error loss, optimised with plain SGD over all model parameters
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    weight_decay=1e-8,
)
Kerstin Kaspar's avatar
Kerstin Kaspar committed

Kerstin Kaspar's avatar
Kerstin Kaspar committed
# initiate Trainer
# Bundles model, loss, optimiser and both dataloaders; no learning-rate
# scheduler is used for this run.
trainer = Trainer(
    model=model,
    device=torch.device(device),
    criterion=criterion,
    optimizer=optimizer,
    training_dataloader=training_dataloader,
    validation_dataloader=validation_dataloader,
    lr_scheduler=None,
    epochs=epochs,
    epoch=0,  # presumably the starting epoch counter - confirm in Trainer
    notebook=False,
    model_output_path=model_output_path,
    save_after_epochs=save_after_epochs,
)

# start training
# run_trainer() hands back the loss and learning-rate histories that are
# pickled once training has finished
training_losses, validation_losses, lr_rates = trainer.run_trainer()
Kerstin Kaspar's avatar
Kerstin Kaspar committed

# save
# Create the run directory up front so a completed training is not lost to a
# missing output folder when the final weights are written.
final_model_path = Path.cwd() / model_output_path
final_model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), final_model_path)

# histories returned by the trainer, stored together in one pickle
losses = {'training_losses': training_losses,
          'validation_losses': validation_losses,
          'lr_rates': lr_rates}
# check_filepath(expand='num') decides the final file name
# (presumably appends a number instead of overwriting - confirm in utils)
with open(check_filepath(filepath=loss_output_path, expand='num'), 'wb') as loss_file:
    pickle.dump(losses, loss_file)
Kerstin Kaspar's avatar
Kerstin Kaspar committed

# timer
# report wall-clock end time and the elapsed time since `start`
end = datetime.now()
duration = end - start
# "%H:%M:%S" instead of "%T": the latter is a platform-specific strftime
# extension, the former is portable and produces the same text
print(f'Ending training at {end.strftime("%Y-%m-%d  %H:%M:%S")}')
print(f'Duration of training: {duration}')