"""Train a UNet on LoDoPaB-CT FBP reconstructions and save the model and loss history."""
import functools
import pickle
from datetime import date
from pathlib import Path

import torch
from torchinfo import summary

from data import LodopabDataset
from model import UNet
from trainer import Trainer

cache_files = {
    'train': '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
    'validation': '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
    'test': '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}
ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/'

# Compute the date stamp once so the model file and the loss file of one run
# always carry the same date (even if the run crosses midnight).
_run_date = date.today().strftime("%Y-%m-%d")
model_output_path = Path('./models') / f'{_run_date}_unet.pt'
loss_output_path = Path('./models/losses') / f'{_run_date}_unet.p'

epochs = 2
save_after_epochs = 1


def main():
    """Build data loaders, model and optimizer, run training, and persist results.

    Writes the final model ``state_dict`` to ``model_output_path`` and a pickled
    dict with keys ``'training_losses'``, ``'validation_losses'`` and
    ``'lr_rates'`` to ``loss_output_path`` (both relative to the working directory).
    """
    # Both inputs and targets get a singleton leading axis via unsqueeze(dim=0);
    # presumably the channel axis UNet expects — confirm against LodopabDataset/UNet.
    add_leading_axis = functools.partial(torch.unsqueeze, dim=0)

    # create Dataloaders
    training_dataset = LodopabDataset(cache_files=cache_files,
                                      ground_truth_data_loc=ground_truth_data_loc,
                                      split='train',
                                      transform=add_leading_axis,
                                      target_transform=add_leading_axis)
    training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset, batch_size=16, shuffle=True)

    validation_dataset = LodopabDataset(cache_files=cache_files,
                                        ground_truth_data_loc=ground_truth_data_loc,
                                        split='validation',
                                        transform=add_leading_axis,
                                        target_transform=add_leading_axis)
    # Evaluation does not benefit from shuffling; a fixed order keeps
    # validation batches reproducible between epochs and runs.
    validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=16, shuffle=False)

    # auto define the correct device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device.")

    # model definition
    model = UNet()
    print('model initialized:')
    summary(model)
    model.to(device)

    # training parameters
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-8)

    # Make sure the output directories exist before anything is written there
    # (the Trainer is handed model_output_path and may checkpoint into it).
    for directory in (model_output_path.parent, loss_output_path.parent):
        (Path.cwd() / directory).mkdir(parents=True, exist_ok=True)

    # initiate Trainer
    trainer = Trainer(model=model,
                      device=torch.device(device),
                      criterion=criterion,
                      optimizer=optimizer,
                      training_dataloader=training_dataloader,
                      validation_dataloader=validation_dataloader,
                      lr_scheduler=None,
                      epochs=epochs,
                      epoch=0,
                      notebook=False,
                      model_output_path=model_output_path,
                      save_after_epochs=save_after_epochs)

    # start training
    training_losses, validation_losses, lr_rates = trainer.run_trainer()

    # save
    torch.save(model.state_dict(), Path.cwd() / model_output_path)

    losses = {'training_losses': training_losses,
              'validation_losses': validation_losses,
              'lr_rates': lr_rates}
    # pickle emits bytes, so the file must be opened in binary mode ('wb');
    # text mode ('w') raises TypeError. Dump the full history, not just one list.
    with open(Path.cwd() / loss_output_path, 'wb') as file:
        pickle.dump(losses, file)


if __name__ == "__main__":
    main()