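# NOTE(review): stray "Newer"/"Older" labels (apparently diff-viewer navigation
# residue) stood here; as bare names they would raise NameError when run.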
from data import LodopabDataset
from trainer import Trainer
from model import UNet
# Record the wall-clock start time; the duration of the run is reported at the end.
start = datetime.now()
# '%H:%M:%S' is spelled out instead of '%T': '%T' is not one of the strftime
# codes Python documents and only works where the platform C library has it.
print(f'Starting training at {start.strftime("%Y-%m-%d %H:%M:%S")}')
# NOTE(review): the opening of this literal (and the name it is bound to) is not
# visible here. As shown, the entries carry no keys, so a brace-delimited literal
# closed by the `}` below would be a set (unordered), not a split -> path
# mapping — confirm.
'/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
'/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
'/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}
# Path to the ground-truth data.
ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'

# All outputs of this run go to ./models/<YYYY-MM-DD>_unet/.
model_name = f'{date.today().strftime("%Y-%m-%d")}_unet'
model_dir = Path('models') / model_name
model_output_path = model_dir / f'{model_name}.pt'
loss_output_path = model_dir / f'{model_name}_losses.p'
# Create the output directory before training starts. The final torch.save()
# and the pickle dump of the losses only run after training has finished, so a
# missing directory would otherwise crash the script at the very end of the run.
model_dir.mkdir(parents=True, exist_ok=True)
# NOTE(review): the opening of this call is not visible here. It is presumably
# `training_dataset = LodopabDataset(` plus its data-location arguments, since
# `training_dataset` is used just below — confirm.
split='train',
# torch.unsqueeze(x, dim=0) inserts a leading size-1 axis into each input and
# target (presumably the channel axis the UNet expects — TODO confirm).
transform=functools.partial(torch.unsqueeze, dim=0),
target_transform=functools.partial(torch.unsqueeze, dim=0))
# Training batches of 16 samples, reshuffled on every pass over the data.
training_dataloader = torch.utils.data.DataLoader(
    dataset=training_dataset,
    shuffle=True,
    batch_size=16,
)
# NOTE(review): the opening of this call is not visible here. It is presumably
# `validation_dataset = LodopabDataset(` plus its data-location arguments, since
# `validation_dataset` is used just below — confirm.
split='validation',
# Same leading size-1 axis as for the training split (torch.unsqueeze, dim=0).
transform=functools.partial(torch.unsqueeze, dim=0),
target_transform=functools.partial(torch.unsqueeze, dim=0))
# Validation batches of 16 samples, served in a fixed order. Shuffling brings no
# benefit to a validation pass. It would also change which samples end up in the
# last (partial) batch every epoch, so if the trainer averages per-batch losses
# the reported validation loss would vary even for an unchanged model.
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                    batch_size=16,
                                                    shuffle=False)
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device.")
# Loss: mean-squared error averaged over all elements.
# Optimiser: SGD without momentum, with a very small L2 weight decay.
# NOTE(review): `model` is not defined anywhere in the visible script (UNet is
# imported but never instantiated) — confirm it is created before this point.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.01,
    weight_decay=1e-8,
)
# Hand the model, data, loss and optimiser to the project's Trainer. No LR
# scheduler is used; notebook=False presumably selects plain-console progress
# output (TODO confirm against Trainer).
# NOTE(review): `save_after_epochs` is not defined in the visible script — confirm.
trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    lr_scheduler=None,
    training_dataloader=training_dataloader,
    validation_dataloader=validation_dataloader,
    device=torch.device(device),
    model_output_path=model_output_path,
    save_after_epochs=save_after_epochs,
    notebook=False,
)
# Run training; run_trainer() hands back three histories: training loss,
# validation loss and learning rate.
training_losses, validation_losses, lr_rates = trainer.run_trainer()

# Persist the final weights.
# NOTE(review): this writes to the same model_output_path that was given to the
# Trainer; if the trainer saved a checkpoint there, it is overwritten — confirm
# that is intended.
torch.save(model.state_dict(), Path.cwd() / model_output_path)

# Persist the loss / learning-rate histories as a pickled dict.
losses = dict(
    training_losses=training_losses,
    validation_losses=validation_losses,
    lr_rates=lr_rates,
)
# check_filepath(..., expand='num') presumably picks a non-clobbering file name
# (e.g. by appending a number) — TODO confirm against its definition.
loss_file_path = check_filepath(filepath=loss_output_path, expand='num')
with open(loss_file_path, 'wb') as loss_file:
    pickle.dump(losses, loss_file)
# Stop the timer and report the wall-clock duration since `start` was taken.
end = datetime.now()
duration = end - start
# '%H:%M:%S' instead of the platform-dependent '%T' (identical output where
# '%T' is supported).
print(f'Ending training at {end.strftime("%Y-%m-%d %H:%M:%S")}')
print(f'Duration of training: {str(duration)}')