import torch
from data import LodopabDataset
from trainer import Trainer
from model import UNet
import functools
from torchinfo import summary
from datetime import date
from pathlib import Path
import pickle

cache_files = {
    'train':
        '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
    'validation':
        '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
    'test':
        '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}

ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/'

model_output_path = Path('./models') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
loss_output_path = Path('./models/losses') / f'{date.today().strftime("%Y-%m-%d")}_unet.p'
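# create the output directories up front so the Trainer checkpoints and the
# final save below do not fail with a FileNotFoundError
model_output_path.parent.mkdir(parents=True, exist_ok=True)
loss_output_path.parent.mkdir(parents=True, exist_ok=True)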
epochs = 2
save_after_epochs = 1

# create datasets and dataloaders; the unsqueeze transforms add a channel
# dimension ((H, W) -> (1, H, W)) so batches match the Conv2d input layout
training_dataset = LodopabDataset(cache_files=cache_files,
                                  ground_truth_data_loc=ground_truth_data_loc,
                                  split='train',
                                  transform=functools.partial(torch.unsqueeze, dim=0),
                                  target_transform=functools.partial(torch.unsqueeze, dim=0))

training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
                                                  batch_size=16,
                                                  shuffle=True)

validation_dataset = LodopabDataset(cache_files=cache_files,
                                    ground_truth_data_loc=ground_truth_data_loc,
                                    split='validation',
                                    transform=functools.partial(torch.unsqueeze, dim=0),
                                    target_transform=functools.partial(torch.unsqueeze, dim=0))

validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                    batch_size=16,
                                                    shuffle=False)  # no need to shuffle for validation

# automatically select the device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device.")

# model definition
model = UNet()
print('model initialized:')
summary(model)
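# note: summary(model) reports parameter counts only; passing an input_size,
# e.g. summary(model, input_size=(16, 1, 362, 362)) for LoDoPaB-CT slices,
# would additionally show per-layer output shapes (shape values assumed here)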
model.to(device)

# training parameters
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),
                            lr=0.01,
                            weight_decay=1e-8)
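
# optional: an LR scheduler could be defined here and passed to the Trainer
# below instead of None, e.g. (hypothetical schedule):
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)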

# instantiate the Trainer
trainer = Trainer(model=model,
                  device=torch.device(device),
                  criterion=criterion,
                  optimizer=optimizer,
                  training_dataloader=training_dataloader,
                  validation_dataloader=validation_dataloader,
                  lr_scheduler=None,
                  epochs=epochs,
                  epoch=0,
                  notebook=False,
                  model_output_path=model_output_path,
                  save_after_epochs=save_after_epochs)

# start training
training_losses, validation_losses, lr_rates = trainer.run_trainer()

# save the final model weights and the loss curves
torch.save(model.state_dict(), Path.cwd() / model_output_path)
losses = {'training_losses': training_losses,
          'validation_losses': validation_losses,
          'lr_rates': lr_rates}
# pickle requires binary mode; dump the full losses dict, not just training_losses
with open(Path.cwd() / loss_output_path, 'wb') as file:
    pickle.dump(losses, file)
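
# the saved artefacts can be restored later along these lines (sketch):
# model = UNet()
# model.load_state_dict(torch.load(Path.cwd() / model_output_path))
# with open(Path.cwd() / loss_output_path, 'rb') as file:
#     losses = pickle.load(file)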