Skip to content
Snippets Groups Projects
train.ipynb 12.6 KiB
Newer Older
Kerstin Kaspar's avatar
Kerstin Kaspar committed
{
 "cells": [
  {
   "cell_type": "code",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "execution_count": 1,
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "id": "84110e61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from data import LodopabDataset\n",
    "from trainer import Trainer\n",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
    "from model import UNet\n",
    "from torchinfo import summary\n",
    "import functools\n",
    "from datetime import date\n",
    "from pathlib import Path"
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   ]
  },
  {
   "cell_type": "code",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "execution_count": 2,
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "id": "15aaa025",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data locations: full LoDoPaB-CT set (cached FBP reconstructions + ground truth)\n",
    "cache_files = dict(\n",
    "    train='/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',\n",
    "    validation='/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',\n",
    "    test='/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy',\n",
    ")\n",
    "\n",
    "ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'"
   ]
  },
  {
   "cell_type": "code",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "execution_count": 3,
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "id": "13b8a348",
   "metadata": {},
   "outputs": [],
   "source": [
    "# small set data locations for quick experiments.\n",
    "# Set USE_SMALL_SET = False to keep the full-set locations from the previous cell\n",
    "# instead of silently overwriting them.\n",
    "USE_SMALL_SET = True\n",
    "\n",
    "if USE_SMALL_SET:\n",
    "    cache_files = {\n",
    "        'train':\n",
    "            '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_train_fbp.npy',\n",
    "        'validation':\n",
    "            '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_validation_fbp.npy',\n",
    "        'test':\n",
    "            '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_test_fbp.npy'}\n",
    "    ground_truth_data_loc = '/oneFS/daten841/kaspar01/lodopab-ct/data/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f02e654d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# output: dated checkpoint file under ./models\n",
    "model_output_path = Path('./models') / f'{date.today().strftime(\"%Y-%m-%d\")}_unet.pt'\n",
    "# create the checkpoint directory up front so a save into it cannot fail on a\n",
    "# missing folder after training time has already been spent (idempotent)\n",
    "model_output_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# training schedule\n",
    "epochs = 2\n",
    "save_after_epochs = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "id": "b12a5345",
   "metadata": {},
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "outputs": [
Kerstin Kaspar's avatar
Kerstin Kaspar committed
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    ... Loading cached files for dataset: \"train\" ...\r",
      "Done loading cached files for dataset: \"train\"\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading Ground Truth for dataset: \"train\": 100%|██████████████████████████████████████████| 3/3 [00:00<00:00,  9.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    ... Loading cached files for dataset: \"validation\" ...\r",
      "Done loading cached files for dataset: \"validation\"\n"
     ]
    },
Kerstin Kaspar's avatar
Kerstin Kaspar committed
    {
Kerstin Kaspar's avatar
Kerstin Kaspar committed
     "name": "stderr",
     "output_type": "stream",
     "text": [
Kerstin Kaspar's avatar
Kerstin Kaspar committed
      "Loading Ground Truth for dataset: \"validation\": 100%|█████████████████████████████████████| 3/3 [00:00<00:00, 10.15it/s]\n"
Kerstin Kaspar's avatar
Kerstin Kaspar committed
     ]
    }
   ],
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "source": [
    "# create Dataloaders\n",
    "batch_size = 16\n",
    "# insert a leading singleton axis (dim=0) into both inputs and targets\n",
    "add_channel_dim = functools.partial(torch.unsqueeze, dim=0)\n",
    "\n",
    "training_dataset = LodopabDataset(cache_files=cache_files,\n",
    "                                  ground_truth_data_loc=ground_truth_data_loc,\n",
    "                                  split='train',\n",
    "                                  transform=add_channel_dim,\n",
    "                                  target_transform=add_channel_dim)\n",
    "\n",
    "training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,\n",
    "                                                  batch_size=batch_size,\n",
    "                                                  shuffle=True)\n",
    "\n",
    "validation_dataset = LodopabDataset(cache_files=cache_files,\n",
    "                                    ground_truth_data_loc=ground_truth_data_loc,\n",
    "                                    split='validation',\n",
    "                                    transform=add_channel_dim,\n",
    "                                    target_transform=add_channel_dim)\n",
    "\n",
    "# validation is only evaluated, so shuffling adds nothing; a fixed order keeps the\n",
    "# batch composition (incl. the last partial batch) and thus the validation loss reproducible\n",
    "validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,\n",
    "                                                    batch_size=batch_size,\n",
    "                                                    shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "execution_count": 6,
   "id": "c31c73c8",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "metadata": {},
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda device.\n"
     ]
    }
   ],
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "source": [
    "# use the GPU when torch can see one, otherwise fall back to the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "print(f\"Using {device} device.\")"
   ]
  },
  {
   "cell_type": "code",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "execution_count": 7,
   "id": "ada1e3df",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "metadata": {},
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
Kerstin Kaspar's avatar
Kerstin Kaspar committed
      "model initialized:\n"
Kerstin Kaspar's avatar
Kerstin Kaspar committed
     ]
    }
   ],
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "source": [
    "# model definition\n",
    "model = UNet()\n",
    "print('model initialized:')\n",
    "# torchinfo stays silent by default in interactive sessions and only shows the\n",
    "# returned summary when it is the cell's last expression; verbose=1 prints it here\n",
    "summary(model, verbose=1)\n",
    "model.to(device)\n",
    "\n",
    "# training parameters (optimizer is built after .to(device) so it sees the moved parameters)\n",
    "criterion = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(),\n",
    "                            lr=0.01,\n",
    "                            weight_decay=1e-8)"
   ]
  },
  {
   "cell_type": "code",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "execution_count": 8,
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "id": "cadd871b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initiate Trainer\n",
    "trainer = Trainer(model=model,\n",
    "                  device=torch.device(device),\n",
    "                  criterion=criterion,\n",
    "                  optimizer=optimizer,\n",
    "                  training_dataloader=training_dataloader,\n",
    "                  validation_dataloader=validation_dataloader,\n",
    "                  lr_scheduler=None,\n",
    "                  epochs=epochs,  # use the configured epoch count instead of a hard-coded value\n",
    "                  epoch=0,  # presumably the starting epoch index; confirm in Trainer\n",
    "                  notebook=True,\n",
    "                  model_output_path=model_output_path,\n",
    "                  save_after_epochs=save_after_epochs)"
   ]
  },
  {
   "cell_type": "code",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "execution_count": null,
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "id": "4c1bd4e1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
Kerstin Kaspar's avatar
Kerstin Kaspar committed
       "model_id": "73279fc9b95e4e09aab2553b86ab0874",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
Kerstin Kaspar's avatar
Kerstin Kaspar committed
       "Epoch:   0%|          | 0/10 [00:00<?, ?it/s]"
Kerstin Kaspar's avatar
Kerstin Kaspar committed
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
Kerstin Kaspar's avatar
Kerstin Kaspar committed
       "model_id": "d03f9f365e49484db8f2a3baeaf73038",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
Kerstin Kaspar's avatar
Kerstin Kaspar committed
       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
Kerstin Kaspar's avatar
Kerstin Kaspar committed
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
Kerstin Kaspar's avatar
Kerstin Kaspar committed
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fff18bcff6554c1ca04c8966e0f96f62",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b33ba398ee734e51b729b1aede972bb5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3df25a49264f4a7e8532fa150bf7089a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "82ed12bfe65947368fcdc770bba9430e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a3a1261be29e44d4813b727277f52219",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c7757599671e44bb8753e2eaa6648ff9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bf6bbee755544529a2c3944e2082efad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e8d22e916dd942df933fa8dcdcd4fd99",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e61297656af4c26b28620dd5271a7a6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "78571ab864004405beedee980425b3f2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "705eca833be14c92bb0c0eaa756a8fda",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2159f388e3134a7e8cf2bdd3f76401b7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e17f9b41998e40c4ab0846ed21f5218c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3789b6062f2645f2b817d9ffaf9fb0d3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
Kerstin Kaspar's avatar
Kerstin Kaspar committed
    }
   ],
   "source": [
    "# train, then unpack the three values returned by run_trainer()\n",
    "(training_losses,\n",
    " validation_losses,\n",
    " lr_rates) = trainer.run_trainer()"
   ]
Kerstin Kaspar's avatar
Kerstin Kaspar committed
  },
Kerstin Kaspar's avatar
Kerstin Kaspar committed
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf6c39f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the final weights to the configured output path\n",
    "final_model_path = Path.cwd() / model_output_path\n",
    "# the directory may not exist yet; creating it is a no-op when it does\n",
    "final_model_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "torch.save(model.state_dict(), final_model_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
Kerstin Kaspar's avatar
Kerstin Kaspar committed
   "version": "3.8.13"
Kerstin Kaspar's avatar
Kerstin Kaspar committed
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}