{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "84110e61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from data import LodopabDataset\n",
    "from trainer import Trainer\n",
    "from model import UNet\n",
    "from torchinfo import summary\n",
    "import functools\n",
    "from datetime import date\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "15aaa025",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data locations\n",
    "cache_files = {\n",
    "    'train':\n",
    "        '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',\n",
    "    'validation':\n",
    "        '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',\n",
    "    'test':\n",
    "        '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}\n",
    "\n",
    "ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'\n",
    "\n",
    "model_path = Path('./model') / f'{date.today().strftime(\"%Y-%m-%d\")}_unet.pt'"
   ]
  },
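  {
   "cell_type": "markdown",
   "id": "0f3a7c11",
   "metadata": {},
   "source": [
    "Added sanity check (a minimal sketch, assuming the paths above are reachable from this node): fail early if a cache file or the ground-truth directory is missing, and make sure the checkpoint directory used by `model_path` exists before training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b4c8d22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# verify the data locations before building any datasets (added check)\n",
    "for split, cache_file in cache_files.items():\n",
    "    assert Path(cache_file).exists(), f'missing cache file for {split}: {cache_file}'\n",
    "assert Path(ground_truth_data_loc).exists(), f'missing ground truth directory: {ground_truth_data_loc}'\n",
    "\n",
    "# create the checkpoint directory for model_path if it does not exist yet\n",
    "(Path.cwd() / model_path).parent.mkdir(parents=True, exist_ok=True)"
   ]
  },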
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b12a5345",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/kaspar01/projects/unet/data.py:80: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /opt/conda/conda-bld/pytorch_1607370131125/work/torch/csrc/utils/tensor_numpy.cpp:141.)\n",
      "  self.inputs = torch.from_numpy(self.inputs)\n"
     ]
    }
   ],
   "source": [
    "# create Dataloaders\n",
    "training_dataset = LodopabDataset(cache_files=small_cache_files,\n",
    "                                  ground_truth_data_loc=ground_truth_data_loc,\n",
    "                                  split='train',\n",
    "                                  transform=functools.partial(torch.unsqueeze, dim=0),\n",
    "                                  target_transform=functools.partial(torch.unsqueeze, dim=0))\n",
    "\n",
    "training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,\n",
    "                                                  batch_size=16,\n",
    "                                                  shuffle=True)\n",
    "\n",
    "validation_dataset = LodopabDataset(cache_files=small_cache_files,\n",
    "                                    ground_truth_data_loc=ground_truth_data_loc,\n",
    "                                    split='validation',\n",
    "                                    transform=functools.partial(torch.unsqueeze, dim=0),\n",
    "                                    target_transform=functools.partial(torch.unsqueeze, dim=0))\n",
    "\n",
    "validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,\n",
    "                                                    batch_size=16,\n",
    "                                                    shuffle=True)"
   ]
  },
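  {
   "cell_type": "markdown",
   "id": "2c5d9e33",
   "metadata": {},
   "source": [
    "Added shape check (a sketch, not part of the original pipeline): pull one batch from the training DataLoader and print its shape. With the `torch.unsqueeze` transforms above, batches are expected to carry a singleton channel dimension, i.e. `(batch, 1, H, W)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d6eaf44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# draw a single batch to confirm tensor shapes and dataset sizes before training\n",
    "x_batch, y_batch = next(iter(training_dataloader))\n",
    "print(f'input batch:  {tuple(x_batch.shape)}, dtype={x_batch.dtype}')\n",
    "print(f'target batch: {tuple(y_batch.shape)}, dtype={y_batch.dtype}')\n",
    "print(f'training samples: {len(training_dataset)}, validation samples: {len(validation_dataset)}')"
   ]
  },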
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ada1e3df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model defition\n",
    "model = UNet()\n",
    "print('model initialized:')\n",
    "summary(model)"
   ]
  },
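  {
   "cell_type": "markdown",
   "id": "4e7fb055",
   "metadata": {},
   "source": [
    "Added forward-pass check (a sketch under assumptions): LoDoPaB-CT reconstructions are 362 x 362 pixels, and this `UNet` is assumed to take a single-channel image and return an output of the same spatial size. If the implementation expects a different input size, adjust the dummy tensor accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f80c166",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dummy forward pass: check that input and output shapes match (assumed 1x1x362x362)\n",
    "with torch.no_grad():\n",
    "    dummy = torch.zeros(1, 1, 362, 362)\n",
    "    dummy_out = model(dummy)\n",
    "print(f'input {tuple(dummy.shape)} -> output {tuple(dummy_out.shape)}')"
   ]
  },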
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "235cced9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda device.\n"
     ]
    }
   ],
   "source": [
    "# training parameters\n",
    "criterion = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(),\n",
    "                            lr=0.01,\n",
    "                            weight_decay=1e-8)\n",
    "\n",
    "# auto define the correct device\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using {device} device.\")"
   ]
  },
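  {
   "cell_type": "markdown",
   "id": "6091d277",
   "metadata": {},
   "source": [
    "The `Trainer` below also accepts an optional learning-rate scheduler. As a sketch, a standard `torch.optim.lr_scheduler.StepLR` could be passed instead of `None`; the step size and decay factor here are placeholder values, not tuned for this problem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71a2e388",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional: halve the learning rate after every epoch (placeholder hyperparameters)\n",
    "lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)"
   ]
  },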
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cadd871b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initiate Trainer\n",
    "trainer = Trainer(model=model,\n",
    "                  device=torch.device(device),\n",
    "                  criterion=criterion,\n",
    "                  optimizer=optimizer,\n",
    "                  training_dataloader=training_dataloader,\n",
    "                  validation_dataloader=validation_dataloader,\n",
    "                  lr_scheduler=None,\n",
    "                  epochs=2,\n",
    "                  epoch=0,\n",
    "                  notebook=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4c1bd4e1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "65d276a3bd7b46ec90de8b74d13176c4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e5a3cceed1a1401390c92a470c27a1ed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/2239 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "IndexError",
     "evalue": "index 29756 is out of bounds for dimension 0 with size 128",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# start training\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m training_losses, validation_losses, lr_rates \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_trainer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/projects/unet/trainer.py:45\u001b[0m, in \u001b[0;36mTrainer.run_trainer\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     42\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m epoch_progress:\n\u001b[1;32m     43\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 45\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     47\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvalidation_dataloader \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m     48\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate()\n",
      "File \u001b[0;32m~/projects/unet/trainer.py:63\u001b[0m, in \u001b[0;36mTrainer._train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     61\u001b[0m train_losses \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m     62\u001b[0m batch_progress \u001b[38;5;241m=\u001b[39m tqdm(\u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining_dataloader), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTraining\u001b[39m\u001b[38;5;124m'\u001b[39m, total\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining_dataloader))\n\u001b[0;32m---> 63\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, (x, y) \u001b[38;5;129;01min\u001b[39;00m batch_progress:\n\u001b[1;32m     64\u001b[0m     \u001b[38;5;28minput\u001b[39m, target \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice), y\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m     65\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
      "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/tqdm/notebook.py:257\u001b[0m, in \u001b[0;36mtqdm_notebook.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    255\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__iter__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    256\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 257\u001b[0m         \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28msuper\u001b[39m(tqdm_notebook, \u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__iter__\u001b[39m():\n\u001b[1;32m    258\u001b[0m             \u001b[38;5;66;03m# return super(tqdm...) will not catch exception\u001b[39;00m\n\u001b[1;32m    259\u001b[0m             \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m    260\u001b[0m     \u001b[38;5;66;03m# NB: except ... [ as ...] breaks IPython async KeyboardInterrupt\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/tqdm/std.py:1195\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1192\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[1;32m   1194\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1195\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[1;32m   1196\u001b[0m         \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m   1197\u001b[0m         \u001b[38;5;66;03m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m   1198\u001b[0m         \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/dataloader.py:435\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    433\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    434\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()\n\u001b[0;32m--> 435\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    436\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m    437\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    438\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    439\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
      "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/dataloader.py:475\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    473\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    474\u001b[0m     index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index()  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 475\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m    476\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m    477\u001b[0m         data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data)\n",
      "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:44\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     42\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfetch\u001b[39m(\u001b[38;5;28mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m     43\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_collation:\n\u001b[0;32m---> 44\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m     45\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     46\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
      "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:44\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     42\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfetch\u001b[39m(\u001b[38;5;28mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m     43\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_collation:\n\u001b[0;32m---> 44\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m     45\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     46\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
      "File \u001b[0;32m~/projects/unet/data.py:88\u001b[0m, in \u001b[0;36mLodopabDataset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m     86\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, idx):\n\u001b[1;32m     87\u001b[0m     \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minputs[idx]) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minputs[idx]\n\u001b[0;32m---> 88\u001b[0m     target \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_transform(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtargets[idx]) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_transform \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtargets\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m     89\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m, target\n",
      "\u001b[0;31mIndexError\u001b[0m: index 29756 is out of bounds for dimension 0 with size 128"
     ]
    }
   ],
   "source": [
    "# start training\n",
    "training_losses, validation_losses, lr_rates = trainer.run_trainer()"
   ]
  },
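  {
   "cell_type": "markdown",
   "id": "82b3f499",
   "metadata": {},
   "source": [
    "Note on the `IndexError` above: the DataLoader requests index 29756, so `len(training_dataset)` reports far more samples than the 128 entries actually held by `self.targets` in `data.py`; the cached FBP inputs and the ground-truth targets were apparently loaded with different sizes. The cell below is a diagnostic sketch; it assumes the dataset exposes `inputs` and `targets` attributes with a `shape`, as suggested by the traceback."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c405aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# diagnostic sketch: compare the reported dataset length with the stored array sizes\n",
    "print(f'len(training_dataset) = {len(training_dataset)}')\n",
    "print(f'inputs shape:  {tuple(training_dataset.inputs.shape)}')\n",
    "print(f'targets shape: {tuple(training_dataset.targets.shape)}')"
   ]
  },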
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf6c39f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), Path.cwd() / model_path)"
   ]
  },
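  {
   "cell_type": "markdown",
   "id": "a4d516bb",
   "metadata": {},
   "source": [
    "For later evaluation, the saved checkpoint can be restored into a fresh `UNet` instance with the standard `state_dict` round trip:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5e627cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reload sketch: restore the saved weights into a fresh model for evaluation\n",
    "restored_model = UNet()\n",
    "restored_model.load_state_dict(torch.load(Path.cwd() / model_path, map_location='cpu'))\n",
    "restored_model.eval()"
   ]
  },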
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b69efe6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}