import functools
from datetime import date
from pathlib import Path

import torch
from torchinfo import summary

from data import LodopabDataset
from model import UNet
from trainer import Trainer
# cached FBP reconstructions (the network inputs), one file per split;
# assuming the dataset expects a dict keyed by split name
small_cache_files = {
    'train': '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
    'validation': '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
    'test': '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'
}
# location of the LoDoPaB-CT ground truth data
ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
model_path = Path('./model') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
# create datasets and dataloaders (unsqueeze adds the channel dimension)
training_dataset = LodopabDataset(cache_files=small_cache_files,
                                  split='train',
                                  transform=functools.partial(torch.unsqueeze, dim=0),
                                  target_transform=functools.partial(torch.unsqueeze, dim=0))
training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
                                                  batch_size=16,
                                                  shuffle=True)

validation_dataset = LodopabDataset(cache_files=small_cache_files,
                                    split='validation',
                                    transform=functools.partial(torch.unsqueeze, dim=0),
                                    target_transform=functools.partial(torch.unsqueeze, dim=0))
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                    batch_size=16,
                                                    shuffle=True)
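# instantiate the U-Net; the constructor arguments are omitted here because the
# signature of model.UNet is not shown - a plain UNet() call is assumed, adjust
# as needed
model = UNet()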
print('model initialized:')
summary(model)
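# optionally pass an input size so torchinfo also reports per-layer output
# shapes; the 362x362 LoDoPaB-CT resolution and the batch size of 16 used here
# are assumptions
# summary(model, input_size=(16, 1, 362, 362))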
# training parameters
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),
                            lr=0.01,
                            weight_decay=1e-8)
# auto define the correct device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device.")
# initiate Trainer
trainer = Trainer(model=model,
                  device=torch.device(device),
                  criterion=criterion,
                  optimizer=optimizer,
                  training_dataloader=training_dataloader,
                  validation_dataloader=validation_dataloader,
                  lr_scheduler=None)
# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()
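# a minimal sketch for inspecting the returned loss curves; assumes matplotlib
# is available and that the lists hold one value per epoch
import matplotlib.pyplot as plt

plt.plot(training_losses, label='training loss')
plt.plot(validation_losses, label='validation loss')
plt.xlabel('epoch')
plt.ylabel('MSE loss')
plt.legend()
plt.savefig('losses.png')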
# save the trained weights to the path defined above
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)
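# a minimal sketch of restoring the weights later for evaluation; the plain
# UNet() call mirrors the assumed constructor used above
restored_model = UNet()
restored_model.load_state_dict(torch.load(model_path, map_location=device))
restored_model.eval()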