diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index a105d966be4b3e8bcb0d0ca6d9a40dc868c871bc..f9050b1059edda371d437671e8c37ff178eff4c3 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -3,10 +3,22 @@
   <component name="ChangeListManager">
     <list default="true" id="714ca32e-a67a-44cf-90a6-e7ab36748bed" name="Changes" comment="">
       <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/.ipynb_checkpoints/train-checkpoint.ipynb" beforeDir="false" afterPath="$PROJECT_DIR$/.ipynb_checkpoints/train-checkpoint.ipynb" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/README.md" beforeDir="false" afterPath="$PROJECT_DIR$/README.md" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/data.py" beforeDir="false" afterPath="$PROJECT_DIR$/data.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/data_inspect.py" beforeDir="false" afterPath="$PROJECT_DIR$/data_inspect.py" afterDir="false" />
       <change beforePath="$PROJECT_DIR$/eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/eval.py" afterDir="false" />
       <change beforePath="$PROJECT_DIR$/losses.py" beforeDir="false" afterPath="$PROJECT_DIR$/losses.py" afterDir="false" />
-      <change beforePath="$PROJECT_DIR$/temp.png" beforeDir="false" />
-      <change beforePath="$PROJECT_DIR$/temp2.png" beforeDir="false" />
+      <change beforePath="$PROJECT_DIR$/model.py" beforeDir="false" afterPath="$PROJECT_DIR$/model.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/orig_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/orig_model.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/reco.py" beforeDir="false" />
+      <change beforePath="$PROJECT_DIR$/train.ipynb" beforeDir="false" afterPath="$PROJECT_DIR$/train.ipynb" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/train_locally.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_locally.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/train_test.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_test.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/trainer.py" beforeDir="false" afterPath="$PROJECT_DIR$/trainer.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/utils.py" beforeDir="false" afterPath="$PROJECT_DIR$/utils.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/validate.py" beforeDir="false" />
     </list>
     <option name="SHOW_DIALOG" value="false" />
     <option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -26,7 +38,6 @@
   <component name="HighlightingSettingsPerFile">
     <setting file="file://$PROJECT_DIR$/trainer.py" root0="FORCE_HIGHLIGHTING" />
     <setting file="file://$PROJECT_DIR$/summary.txt" root0="FORCE_HIGHLIGHTING" />
-    <setting file="file://$PROJECT_DIR$/reco.py" root0="FORCE_HIGHLIGHTING" />
     <setting file="file://$USER_HOME$/.conda/envs/nn/Lib/site-packages/torch/nn/modules/conv.py" root0="SKIP_INSPECTION" />
     <setting file="file://$PROJECT_DIR$/model.py" root0="FORCE_HIGHLIGHTING" />
     <setting file="file://$PROJECT_DIR$/utils.py" root0="FORCE_HIGHLIGHTING" />
@@ -173,15 +184,4 @@
       </map>
     </option>
   </component>
-  <component name="XDebuggerManager">
-    <breakpoint-manager>
-      <breakpoints>
-        <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
-          <url>file://$PROJECT_DIR$/utils.py</url>
-          <line>19</line>
-          <option name="timeStamp" value="2" />
-        </line-breakpoint>
-      </breakpoints>
-    </breakpoint-manager>
-  </component>
 </project>
\ No newline at end of file
diff --git a/.ipynb_checkpoints/train-checkpoint.ipynb b/.ipynb_checkpoints/train-checkpoint.ipynb
index 2837960a7e74f5bed1f91751911faffec490981f..6a924e828f7cca47cb05c3d63539856403962864 100644
--- a/.ipynb_checkpoints/train-checkpoint.ipynb
+++ b/.ipynb_checkpoints/train-checkpoint.ipynb
@@ -1,9 +1,17 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "fe78fece",
+   "metadata": {},
+   "source": [
+    "Notebook to train a model on the high performance cluster"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": 1,
-   "id": "84110e61",
+   "execution_count": null,
+   "id": "4b7a7e5e",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -33,29 +41,81 @@
     "    'test':\n",
     "        '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}\n",
     "\n",
-    "ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'\n",
-    "\n",
-    "model_path = Path('./model') / f'{date.today().strftime(\"%Y-%m-%d\")}_unet.pt'"
+    "ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 3,
+   "id": "13b8a348",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# small set data locations\n",
+    "cache_files = {\n",
+    "    'train':\n",
+    "        '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_train_fbp.npy',\n",
+    "    'validation':\n",
+    "        '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_validation_fbp.npy',\n",
+    "    'test':\n",
+    "        '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_test_fbp.npy'}\n",
+    "\n",
+    "ground_truth_data_loc = '/oneFS/daten841/kaspar01/lodopab-ct/data/'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "f02e654d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# output\n",
+    "model_output_path = Path('./models') / f'{date.today().strftime(\"%Y-%m-%d\")}_unet.pt'\n",
+    "epochs = 2\n",
+    "save_after_epochs = 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
    "id": "b12a5345",
    "metadata": {},
    "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "    ... Loading cached files for dataset: \"train\" ...\r",
+      "Done loading cached files for dataset: \"train\"\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Loading Ground Truth for dataset: \"train\": 100%|██████████████████████████████████████████| 3/3 [00:00<00:00,  9.89it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "    ... Loading cached files for dataset: \"validation\" ...\r",
+      "Done loading cached files for dataset: \"validation\"\n"
+     ]
+    },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "/home/kaspar01/projects/unet/data.py:80: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /opt/conda/conda-bld/pytorch_1607370131125/work/torch/csrc/utils/tensor_numpy.cpp:141.)\n",
-      "  self.inputs = torch.from_numpy(self.inputs)\n"
+      "Loading Ground Truth for dataset: \"validation\": 100%|█████████████████████████████████████| 3/3 [00:00<00:00, 10.15it/s]\n"
      ]
     }
    ],
    "source": [
     "# create Dataloaders\n",
-    "training_dataset = LodopabDataset(cache_files=small_cache_files,\n",
+    "training_dataset = LodopabDataset(cache_files=cache_files,\n",
     "                                  ground_truth_data_loc=ground_truth_data_loc,\n",
     "                                  split='train',\n",
     "                                  transform=functools.partial(torch.unsqueeze, dim=0),\n",
@@ -65,7 +125,7 @@
     "                                                  batch_size=16,\n",
     "                                                  shuffle=True)\n",
     "\n",
-    "validation_dataset = LodopabDataset(cache_files=small_cache_files,\n",
+    "validation_dataset = LodopabDataset(cache_files=cache_files,\n",
     "                                    ground_truth_data_loc=ground_truth_data_loc,\n",
     "                                    split='validation',\n",
     "                                    transform=functools.partial(torch.unsqueeze, dim=0),\n",
@@ -78,46 +138,55 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
-   "id": "ada1e3df",
+   "execution_count": 6,
+   "id": "c31c73c8",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using cuda device.\n"
+     ]
+    }
+   ],
    "source": [
-    "# model defition\n",
-    "model = UNet()\n",
-    "print('model initialized:')\n",
-    "summary(model)"
+    "# auto define the correct device\n",
+    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+    "print(f\"Using {device} device.\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
-   "id": "235cced9",
+   "execution_count": 7,
+   "id": "ada1e3df",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Using cuda device.\n"
+      "model initialized:\n"
      ]
     }
    ],
    "source": [
+    "# model defition\n",
+    "model = UNet()\n",
+    "print('model initialized:')\n",
+    "summary(model)\n",
+    "model.to(device)\n",
+    "\n",
     "# training parameters\n",
     "criterion = torch.nn.MSELoss()\n",
     "optimizer = torch.optim.SGD(model.parameters(),\n",
     "                            lr=0.01,\n",
-    "                            weight_decay=1e-8)\n",
-    "\n",
-    "# auto define the correct device\n",
-    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
-    "print(f\"Using {device} device.\")"
+    "                            weight_decay=1e-8)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 8,
    "id": "cadd871b",
    "metadata": {},
    "outputs": [],
@@ -130,26 +199,28 @@
     "                  training_dataloader=training_dataloader,\n",
     "                  validation_dataloader=validation_dataloader,\n",
     "                  lr_scheduler=None,\n",
-    "                  epochs=2,\n",
+    "                  epochs=10,\n",
     "                  epoch=0,\n",
-    "                  notebook=True)"
+    "                  notebook=True,\n",
+    "                  model_output_path=model_output_path,\n",
+    "                  save_after_epochs=save_after_epochs)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": null,
    "id": "4c1bd4e1",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "65d276a3bd7b46ec90de8b74d13176c4",
+       "model_id": "73279fc9b95e4e09aab2553b86ab0874",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
-       "Epoch:   0%|          | 0/2 [00:00<?, ?it/s]"
+       "Epoch:   0%|          | 0/10 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
@@ -158,36 +229,212 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "e5a3cceed1a1401390c92a470c27a1ed",
+       "model_id": "d03f9f365e49484db8f2a3baeaf73038",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
-       "Training:   0%|          | 0/2239 [00:00<?, ?it/s]"
+       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
-     "ename": "IndexError",
-     "evalue": "index 29756 is out of bounds for dimension 0 with size 128",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
-      "Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# start training\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m training_losses, validation_losses, lr_rates \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_trainer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
-      "File \u001b[0;32m~/projects/unet/trainer.py:45\u001b[0m, in \u001b[0;36mTrainer.run_trainer\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     42\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m epoch_progress:\n\u001b[1;32m     43\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 45\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     47\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvalidation_dataloader \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m     48\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate()\n",
-      "File \u001b[0;32m~/projects/unet/trainer.py:63\u001b[0m, in \u001b[0;36mTrainer._train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     61\u001b[0m train_losses \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m     62\u001b[0m batch_progress \u001b[38;5;241m=\u001b[39m tqdm(\u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining_dataloader), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTraining\u001b[39m\u001b[38;5;124m'\u001b[39m, total\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining_dataloader))\n\u001b[0;32m---> 63\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, (x, y) \u001b[38;5;129;01min\u001b[39;00m batch_progress:\n\u001b[1;32m     64\u001b[0m     \u001b[38;5;28minput\u001b[39m, target \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice), y\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m     65\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
-      "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/tqdm/notebook.py:257\u001b[0m, in \u001b[0;36mtqdm_notebook.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    255\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__iter__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    256\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 257\u001b[0m         \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28msuper\u001b[39m(tqdm_notebook, \u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__iter__\u001b[39m():\n\u001b[1;32m    258\u001b[0m             \u001b[38;5;66;03m# return super(tqdm...) will not catch exception\u001b[39;00m\n\u001b[1;32m    259\u001b[0m             \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m    260\u001b[0m     \u001b[38;5;66;03m# NB: except ... [ as ...] breaks IPython async KeyboardInterrupt\u001b[39;00m\n",
-      "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/tqdm/std.py:1195\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1192\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[1;32m   1194\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1195\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[1;32m   1196\u001b[0m         \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m   1197\u001b[0m         \u001b[38;5;66;03m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m   1198\u001b[0m         \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
-      "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/dataloader.py:435\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    433\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    434\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()\n\u001b[0;32m--> 435\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    436\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m    437\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    438\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    439\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
-      "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/dataloader.py:475\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    473\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    474\u001b[0m     index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index()  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 475\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m    476\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m    477\u001b[0m         data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data)\n",
-      "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:44\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     42\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfetch\u001b[39m(\u001b[38;5;28mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m     43\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_collation:\n\u001b[0;32m---> 44\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m     45\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     46\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
-      "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:44\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     42\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfetch\u001b[39m(\u001b[38;5;28mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m     43\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_collation:\n\u001b[0;32m---> 44\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m     45\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     46\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
-      "File \u001b[0;32m~/projects/unet/data.py:88\u001b[0m, in \u001b[0;36mLodopabDataset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m     86\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, idx):\n\u001b[1;32m     87\u001b[0m     \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minputs[idx]) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minputs[idx]\n\u001b[0;32m---> 88\u001b[0m     target \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_transform(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtargets[idx]) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_transform \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtargets\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m     89\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m, target\n",
-      "\u001b[0;31mIndexError\u001b[0m: index 29756 is out of bounds for dimension 0 with size 128"
-     ]
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fff18bcff6554c1ca04c8966e0f96f62",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b33ba398ee734e51b729b1aede972bb5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3df25a49264f4a7e8532fa150bf7089a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "82ed12bfe65947368fcdc770bba9430e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a3a1261be29e44d4813b727277f52219",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c7757599671e44bb8753e2eaa6648ff9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bf6bbee755544529a2c3944e2082efad",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e8d22e916dd942df933fa8dcdcd4fd99",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6e61297656af4c26b28620dd5271a7a6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "78571ab864004405beedee980425b3f2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "705eca833be14c92bb0c0eaa756a8fda",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2159f388e3134a7e8cf2bdd3f76401b7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e17f9b41998e40c4ab0846ed21f5218c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Validation:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3789b6062f2645f2b817d9ffaf9fb0d3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Training:   0%|          | 0/24 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
     }
    ],
    "source": [
diff --git a/README.md b/README.md
index 3adc2763b5b37e1ebb80c6a7623255b8dc7523c6..01d97a43d6c766e496d71538b2cc068dacb4acf1 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,51 @@
-My implementation of a U-Net similar to the FBPConvNet after:
+### UNet for Low Dose CT
+This repository contains my implementation of a U-Net similar to the FBPConvNet after:
 
 K. H. Jin, M. T. McCann, E. Froustey and M. Unser, "Deep Convolutional Neural Network for Inverse Problems in Imaging," 
 in IEEE Transactions on Image Processing, vol. 26, no. 9, pp. 4509-4522, Sept. 2017,
 https://doi.org/10.1109/TIP.2017.2713099
 https://github.com/panakino/FBPConvNet
+
+### Project structure
+```text
+Unet
+├───mlruns
+│    '''trained models and losses'''
+├───models
+│    '''trained models'''
+│   └───losses
+│        '''loss files of the trained models'''
+├───msub
+│    '''scripts for high performance cluster job management'''
+├───conda_env_nn.yml
+│    '''conda environment file (prerequisites)'''
+├───data.py
+│    '''functions for data loading and preparation'''
+├───data_inspect.py
+│    '''script to inspect single data samples (FBP and ground truth)'''
+├───eval.py
+│    '''script to evaluate data with a trained model'''
+├───losses.py
+│    '''script to inspect and plot losses of a trained model'''
+├───model.py
+│    '''UNet class and functions'''
+├───orig_model.py
+│    '''FBPConvNet class and functions'''
+├───README.md
+│    '''information on the repository'''
+├───train.ipynb
+│    '''jupyter notebook to train a model on the high performance cluster'''
+├───train.py
+│    '''script to train a model on the high performance cluster'''
+├───trainer.py
+│    '''Trainer class and functions'''
+├───train_locally.py
+│    '''script to train a model on my local machine'''
+├───train_test.py
+│    '''script to train a model with small datasets'''
+└───utils.py
+     '''additional functions'''
+```
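+
+### Usage
+A minimal sketch of setting up a training run (paths are placeholders; `train.py`,
+`train_locally.py` and `train.ipynb` contain the complete scripts):
+
+```python
+import functools
+import torch
+from data import LodopabDataset
+from model import UNet
+
+# cached FBP reconstructions and LoDoPaB ground truth (adjust the paths to your data)
+cache_files = {'train': '/path/to/cache_lodopab_train_fbp.npy',
+               'validation': '/path/to/cache_lodopab_validation_fbp.npy',
+               'test': '/path/to/cache_lodopab_test_fbp.npy'}
+ground_truth_data_loc = '/path/to/lodopab-ct/data'
+
+training_dataset = LodopabDataset(cache_files=cache_files,
+                                  ground_truth_data_loc=ground_truth_data_loc,
+                                  split='train',
+                                  transform=functools.partial(torch.unsqueeze, dim=0),
+                                  target_transform=functools.partial(torch.unsqueeze, dim=0))
+training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
+                                                  batch_size=16, shuffle=True)
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = UNet().to(device)
+criterion = torch.nn.MSELoss()
+optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-8)
+```
+Model, dataloaders, criterion and optimizer are then handed to the `Trainer` class
+(see `trainer.py`); its `run_trainer()` method runs the training and returns the
+training losses, validation losses and learning rates.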
\ No newline at end of file
diff --git a/data.py b/data.py
index 9afeec6ddaa6ba60d965a7751f9de49b8ab953f6..9888ffa3caff066c97fbe72074065717388da29e 100644
--- a/data.py
+++ b/data.py
@@ -1,15 +1,20 @@
+"""
+functions for data loading and preparation
+"""
+
 import math
 from pathlib import Path
 import numpy as np
 import h5py
 import torch
 from torch.utils.data import Dataset
-from typing import Union, List, Callable, Tuple
+from typing import Union, List, Callable, Tuple, Dict
 from tqdm import tqdm
 import torch.nn.functional as F
 
 
-def get_im(input: Union[torch.Tensor, np.ndarray]):
+def get_im(input: Union[torch.Tensor, np.ndarray]
+           ) -> np.ndarray:
     """
     returns the 2D image of a sample
 
@@ -20,12 +25,15 @@ def get_im(input: Union[torch.Tensor, np.ndarray]):
 
 
 def get_lodopab_ground_truth(data_loc: Union[Path, str],
-                             indicator: Union[str, List[str]] = ['train', 'validation', 'test']):
+                             indicator: Union[str, List[str]] = ['train', 'validation', 'test']
+                             ) -> Union[Dict[str, np.ndarray], np.ndarray]:
     """
     Function to load the Ground Truth Lodopab data similarly to the cached data
 
     :param data_loc: Path to directory with all necessary data files in h5py
     :param indicator: 'train', 'test', or 'validation', or a list of several indicators to define which data is loaded
+
+    :return data: np.ndarray of the data or, if more than one indicator is given, a dictionary of the form {indicator: np.ndarray}
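+
+    Example (sketch, paths are placeholders):
+        gt = get_lodopab_ground_truth('/path/to/lodopab-ct/data', indicator='train')
+        gts = get_lodopab_ground_truth('/path/to/lodopab-ct/data', indicator=['train', 'validation'])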
     """
     data_loc = Path(data_loc)
     indicator = [indicator] if type(indicator) is not list else indicator
@@ -49,7 +57,8 @@ def get_lodopab_ground_truth(data_loc: Union[Path, str],
 
 
 def load_cache_files(cache_files: dict,
-                     indicator: str):
+                     indicator: str
+                     ) -> np.ndarray:
     """
     :param cache_files: input location - dict of form {'train': <path to .npy file>,
                                                        'validation': <path to .npy file>,
@@ -64,7 +73,8 @@ def load_cache_files(cache_files: dict,
 
 def get_ground_truth(idx: Union[int, List[int]],
                      indicator: str,
-                     path: Union[Path, str]) -> np.ndarray:
+                     path: Union[Path, str]
+                     ) -> np.ndarray:
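+    """
+    Load the ground truth sample(s) with index idx for the given indicator ('train', 'validation' or 'test').
+    Each LoDoPaB ground truth file stores 128 samples, so idx is split into a file number (idx // 128)
+    and an index within that file (idx % 128).
+    """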
     if type(idx) is int:
         gt_idx = idx % 128
         gt_file_num = math.floor(idx / 128)
@@ -171,10 +181,10 @@ class LodopabDataset(Dataset):
         self.inputs = F.pad(input=torch.from_numpy(self.inputs), pad=(0, 6, 6, 0), mode='constant', value=0)
         self.targets = F.pad(input=torch.from_numpy(self.targets), pad=(0, 6, 6, 0), mode='constant', value=0)
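+        # pad=(0, 6, 6, 0) zero pads the 362x362 LoDoPaB slices above to 368x368, presumably so the
+        # spatial size stays cleanly divisible through the U-Net down- and up-sampling stages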
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.inputs.shape[0]
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
         input = self.transform(self.inputs[idx]) if self.transform else self.inputs[idx]
         target = self.target_transform(self.targets[idx]) if self.target_transform else self.targets[idx]
         return input, target
diff --git a/data_inspect.py b/data_inspect.py
index d38cbf7fb946519e8b9aa57020024798d2cf311b..06c681c59ffc095175ec4705c9c291e9fe2bf20a 100644
--- a/data_inspect.py
+++ b/data_inspect.py
@@ -1,3 +1,7 @@
+"""
+script to inspect single data samples (FBP and ground truth)
+"""
+
 import torch
 from data import get_samples_from_cache, get_im
 from torchvision import transforms
diff --git a/eval.py b/eval.py
index 3929e43d28d23566a8ee088fe0f962d4e5d7a6e1..f2eda7ced98dbbc1777a8a0792d5b8b014dc0d7e 100644
--- a/eval.py
+++ b/eval.py
@@ -1,3 +1,7 @@
+"""
+script to evaluate data with a trained model
+"""
+
 from pathlib import Path
 import torch
 import numpy as np
diff --git a/losses.py b/losses.py
index 68e2b9ae43777ceb3a2da6014f342fc930336235..452d03a6df992c91c16f39f5ca0d6e7caee26c04 100644
--- a/losses.py
+++ b/losses.py
@@ -1,3 +1,7 @@
+"""
+script to inspect and plot losses of a trained model
+"""
+
 import numpy as np
 import pickle
 import matplotlib.pyplot as plt
diff --git a/model.py b/model.py
index 6736804d7f21a9c1f6ac515633b2f52915b2d619..2f50e0ed5b54ea7a7017ac24c679c08b08e75d36 100644
--- a/model.py
+++ b/model.py
@@ -1,3 +1,7 @@
+"""
+UNet class and functions
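+
+Example (sketch, input shaped like the zero padded LoDoPaB FBP slices from data.py):
+    model = UNet()
+    out = model(torch.rand(1, 1, 368, 368))  # (batch, channel, height, width)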
+"""
+
 from turtle import forward
 import torch
 from torch import nn
diff --git a/orig_model.py b/orig_model.py
index 606a5f5003cce528b21845b02678b0ec8fde710c..ddbd9852467a6b84765aafdd63e02b002bf0b0e9 100644
--- a/orig_model.py
+++ b/orig_model.py
@@ -1,3 +1,7 @@
+"""
+FBPConvNet class and functions
+"""
+
 import torch
 import torch.nn as nn
 
diff --git a/reco.py b/reco.py
deleted file mode 100644
index b22b7464ea01191b9c723a5e6ac957c518449623..0000000000000000000000000000000000000000
--- a/reco.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from pathlib import Path
-import torch
-import numpy as np
-import matplotlib.pyplot as plt
-from model import UNet
-from torchinfo import summary
-from data import get_samples_from_cache
-import torch.nn.functional as fun
-import functools
-
-
-
-# data locations
-cache_files = {
-    'train':
-        '../data/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
-    'validation':
-        '../data/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',
-    'test':
-        '../data/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}
-
-model_path = Path("models") / "2022-05-12_unet.pt"
-sample = get_samples_from_cache(cache_files=cache_files, transform=functools.partial(torch.unsqueeze, dim=0))
-transform = functools.partial(torch.unsqueeze, dim=0)
-input = torch.unsqueeze(sample, dim=0)
-input2 = transform(sample)
-
-# idx = np.random.randint(cache_data.shape[0])
-# sample = torch.from_numpy(cache_data[idx])[None, None, :, :]
-# input = fun.pad(input=sample, pad=(0, 6, 6, 0), mode='constant', value=0)
-
-model = UNet()
-model.load_state_dict(torch.load(model_path))
-# model.eval()
-
-img = torch.rand((1, 1, 362, 362))
-img = torch.rand((1, 1, 358, 358))
-img = torch.rand((1, 1, 368, 368))
-
-model(img)
-out = model(input)
-summary(model, input_data=input)
-
-plt.figure()
-plt.imshow(out.detach().numpy()[0][0])
-plt.savefig('temp.png')
\ No newline at end of file
diff --git a/train.ipynb b/train.ipynb
index 09af310e62dc27dc48003cc4743fa97e7fd9ccd4..6a924e828f7cca47cb05c3d63539856403962864 100644
--- a/train.ipynb
+++ b/train.ipynb
@@ -1,9 +1,17 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "fe78fece",
+   "metadata": {},
+   "source": [
+    "Notebook to train a model on the high performance cluster"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": 1,
-   "id": "84110e61",
+   "execution_count": null,
+   "id": "4b7a7e5e",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -469,7 +477,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.13"
+   "version": "3.9.7"
   }
  },
  "nbformat": 4,
diff --git a/train.py b/train.py
index 3968ef167eb68b9e4447ab92ce6de8de28baa3c4..e14545025b8890f83609e50c910f435abfec34ab 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,6 @@
+"""
+script to train a model on the high performance cluster
+"""
 import torch
 from data import LodopabDataset
 from trainer import Trainer
@@ -13,7 +16,7 @@ from utils import check_filepath
 start = datetime.now()
 print(f'Starting training at {start.strftime("%Y-%m-%d  %T")}')
 
-# data locations
+# input data locations
 cache_files = {
     'train':
         '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
@@ -24,12 +27,19 @@ cache_files = {
 
 ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
 
+# output paths
 model_name = f'{date.today().strftime("%Y-%m-%d")}_unet'
 model_dir = Path(f'./models/{model_name}/')
 model_output_path = model_dir / f'{model_name}.pt'
 loss_output_path = model_dir / f'{model_name}_losses.p'
+
+# adjustable parameters
 epochs = 500
 save_after_epochs = 25
+learning_rate = 0.1
+weight_decay = 1e-8
+batch_size = 16
+
 
 # create Dataloaders
 training_dataset = LodopabDataset(cache_files=cache_files,
@@ -39,7 +49,7 @@ training_dataset = LodopabDataset(cache_files=cache_files,
                                   target_transform=functools.partial(torch.unsqueeze, dim=0))
 
 training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
-                                                  batch_size=16,
+                                                  batch_size=batch_size,
                                                   shuffle=True)
 
 validation_dataset = LodopabDataset(cache_files=cache_files,
@@ -49,7 +59,7 @@ validation_dataset = LodopabDataset(cache_files=cache_files,
                                     target_transform=functools.partial(torch.unsqueeze, dim=0))
 
 validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
-                                                    batch_size=16,
+                                                    batch_size=batch_size,
                                                     shuffle=True)
 
 # auto define the correct device
@@ -86,7 +96,7 @@ trainer = Trainer(model=model,
 # start training
 training_losses, validation_losses, lr_rates = trainer.run_trainer()
 
-# save
+# save final model and losses
 torch.save(model.state_dict(), check_filepath(filepath=model_output_path, expand='num'))
 losses = {'training_losses': training_losses,
           'validation_losses': validation_losses,
diff --git a/train_locally.py b/train_locally.py
index dd1f2477fb96eafdde49afed46cc1e116f04b322..51034bdc749a07abae4a97ce549d6e6dc3f3eccb 100644
--- a/train_locally.py
+++ b/train_locally.py
@@ -4,10 +4,16 @@ from trainer import Trainer
 from model import UNet
 import functools
 from torchinfo import summary
-from datetime import date
+from datetime import date, datetime
 from pathlib import Path
 import pickle
 
+
+# timer
+start = datetime.now()
+print(f'Starting training at {start.strftime("%Y-%m-%d  %T")}')
+
+# input paths
 cache_files = {
     'train':
         '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
@@ -18,10 +24,16 @@ cache_files = {
 
 ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/'
 
+# output paths
 model_output_path = Path('./models') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt'
 loss_output_path = Path('./models/losses') / f'{date.today().strftime("%Y-%m-%d")}_unet.p'
+
+# adjustable parameters
 epochs = 2
 save_after_epochs = 1
+learning_rate = 0.1
+weight_decay = 1e-8
+batch_size = 16
 
 # create Dataloaders
 training_dataset = LodopabDataset(cache_files=cache_files,
@@ -31,7 +43,7 @@ training_dataset = LodopabDataset(cache_files=cache_files,
                                   target_transform=functools.partial(torch.unsqueeze, dim=0))
 
 training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
-                                                  batch_size=16,
+                                                  batch_size=batch_size,
                                                   shuffle=True)
 
 validation_dataset = LodopabDataset(cache_files=cache_files,
@@ -41,7 +53,7 @@ validation_dataset = LodopabDataset(cache_files=cache_files,
                                     target_transform=functools.partial(torch.unsqueeze, dim=0))
 
 validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
-                                                    batch_size=16,
+                                                    batch_size=batch_size,
                                                     shuffle=True)
 
 # auto define the correct device
@@ -50,15 +62,15 @@ print(f"Using {device} device.")
 
 # model defition
 model = UNet()
-print('model initialized:')
-summary(model)
 model.to(device)
+# print('model initialized:')
+# summary(model)
 
 # training parameters
 criterion = torch.nn.MSELoss()
 optimizer = torch.optim.SGD(model.parameters(),
-                            lr=0.01,
-                            weight_decay=1e-8)
+                            lr=learning_rate,
+                            weight_decay=weight_decay)
 
 # initiate Trainer
 trainer = Trainer(model=model,
@@ -77,10 +89,16 @@ trainer = Trainer(model=model,
 # start training
 training_losses, validation_losses, lr_rates = trainer.run_trainer()
 
-# save
+# save final model and losses
 torch.save(model.state_dict(), Path.cwd() / model_output_path)
 losses = {'training_losses': training_losses,
           'validation_losses': validation_losses,
           'lr_rates': lr_rates}
 with open(Path.cwd() / loss_output_path, 'w') as file:
     pickle.dump(training_losses, file)
+
+# timer
+end = datetime.now()
+duration = end - start
+print(f'Ending training at {end.strftime("%Y-%m-%d  %T")}')
+print(f'Duration of training: {str(duration)}')
diff --git a/train_test.py b/train_test.py
index f743c9a6db8df981c46ce074bd45dbbd14b8be77..11b69f33c4f42863ee65ad5da2ecc26aaf3ba37a 100644
--- a/train_test.py
+++ b/train_test.py
@@ -1,3 +1,7 @@
+"""
+script to train a model with small datasets
+"""
+
 import torch
 from data import LodopabDataset
 from trainer import Trainer
@@ -13,7 +17,7 @@ from utils import check_filepath
 start = datetime.now()
 print(f'Starting training at {start.strftime("%Y-%m-%d  %T")}')
 
-# data locations
+# # cluster input data locations
 # cache_files = {
 #     'train':
 #         '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',
@@ -24,7 +28,7 @@ print(f'Starting training at {start.strftime("%Y-%m-%d  %T")}')
 #
 # ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'
 
-# local data locations
+# local input data locations
 cache_files = {
     'train':
         '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_train_fbp.npy',
@@ -35,12 +39,18 @@ cache_files = {
 
 ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/data/'
 
+# output paths
 model_name = f'{date.today().strftime("%Y-%m-%d")}_unet_test'
 model_dir = Path(f'./models/{model_name}/')
 model_output_path = model_dir / f'{model_name}.pt'
 losses_output_path = model_dir / f'{model_name}_losses.p'
+
+# adjustable parameters
 epochs = 2
 save_after_epochs = 1
+learning_rate = 0.1
+weight_decay = 1e-8
+batch_size = 16
 
 # create Dataloaders
 training_dataset = LodopabDataset(cache_files=cache_files,
@@ -50,7 +60,7 @@ training_dataset = LodopabDataset(cache_files=cache_files,
                                   target_transform=functools.partial(torch.unsqueeze, dim=0))
 
 training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,
-                                                  batch_size=16,
+                                                  batch_size=batch_size,
                                                   shuffle=True)
 
 validation_dataset = LodopabDataset(cache_files=cache_files,
@@ -60,7 +70,7 @@ validation_dataset = LodopabDataset(cache_files=cache_files,
                                     target_transform=functools.partial(torch.unsqueeze, dim=0))
 
 validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
-                                                    batch_size=16,
+                                                    batch_size=batch_size,
                                                     shuffle=True)
 
 # auto define the correct device
@@ -76,8 +86,8 @@ model.to(device)
 # training parameters
 criterion = torch.nn.MSELoss()
 optimizer = torch.optim.SGD(model.parameters(),
-                            lr=0.01,
-                            weight_decay=1e-8)
+                            lr=learning_rate,
+                            weight_decay=weight_decay)
 
 # initiate Trainer
 trainer = Trainer(model=model,
@@ -96,7 +106,7 @@ trainer = Trainer(model=model,
 # start training
 training_losses, validation_losses, lr_rates = trainer.run_trainer()
 
-# save
+# save final model and losses
 torch.save(model.state_dict(), check_filepath(filepath=model_output_path, expand='num'))
 losses = {'training_losses': training_losses,
           'validation_losses': validation_losses,
diff --git a/trainer.py b/trainer.py
index 65a0155f694f3653dcb63f8753426e060dc1b964..486e36c322f021f65b48fb2d8623bdecc22c465f 100644
--- a/trainer.py
+++ b/trainer.py
@@ -1,5 +1,7 @@
+"""
+Trainer class and functions
+"""
 import pickle
-
 import numpy as np
 import torch
 from typing import Union
@@ -9,6 +11,9 @@ from utils import check_filepath
 
 
 class Trainer:
+    """
+    Runs the training and validation loops and periodically saves the model and losses.
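+
+    Typical usage (sketch; argument names as used in train.py, see there for the full setup):
+        trainer = Trainer(model=model, device=device, criterion=criterion,
+                          optimizer=optimizer,
+                          training_dataloader=training_dataloader,
+                          validation_dataloader=validation_dataloader,
+                          lr_scheduler=None, epochs=epochs, epoch=0,
+                          notebook=False,
+                          model_output_path=model_output_path,
+                          save_after_epochs=save_after_epochs)
+        training_losses, validation_losses, lr_rates = trainer.run_trainer()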
+    """
     def __init__(self,
                  model: torch.nn.Module,
                  device: torch.device,
@@ -53,7 +58,7 @@ class Trainer:
             from tqdm import trange
 
         epoch_progress = trange(self.epochs, desc='Epoch')
-        for epoch in epoch_progress:
+        for _ in epoch_progress:
             self.epoch += 1
 
             self._train()
@@ -76,9 +81,7 @@ class Trainer:
         losses = {'training_losses': self.training_loss,
                   'validation_losses': self.validation_loss,
                   'lr_rates': self.learning_rate}
-        save_loss_as = check_filepath(filepath=self.loss_output_path.with_name(
-            self.loss_output_path.name.replace('.p', f'_e{self.epoch:03}.pt')),
-            expand='num')
+        save_loss_as = check_filepath(filepath=self.loss_output_path, replace=True)
         with open(save_loss_as, 'wb') as file:
             pickle.dump(losses, file)
 
@@ -117,7 +120,7 @@ class Trainer:
 
         self.model.eval()
         validation_losses = []
-        batch_progress = tqdm(enumerate(self.validation_dataloader), 'Validation', total=len(self.training_dataloader),
+        batch_progress = tqdm(enumerate(self.validation_dataloader), 'Validation', total=len(self.validation_dataloader),
                               leave=False)
 
         for i, (x, y) in batch_progress:
diff --git a/utils.py b/utils.py
index d052c0f48edcab72eecb975ccfcfa2eb25da18f9..091af06283fd9619b4db7a115be6e7ac547f6d87 100644
--- a/utils.py
+++ b/utils.py
@@ -1,3 +1,6 @@
+"""
+additional functions
+"""
 import os
 from pathlib import Path
 from datetime import datetime
@@ -5,14 +8,14 @@ from datetime import datetime
 
 def check_filepath(filepath: (str, Path),
                    makedir: bool = True,
-                   overwrite: bool = False,
+                   replace: bool = False,
                    expand: str = 'date',
                    **kwargs) -> Path:
     """
     Checks a filepath existence and possibly creates directories
     :param filepath: filepath to the target
     :param makedir: create nonexistent directories in the filepath
-    :param overwrite: toggles allowance to overwrite existing data in the filepath
+    :param replace: allow overwriting an existing file at the filepath
     :param expand: 'date' (default): expands existing files with datetime (up to minute),
         'num': expands with running digit
     :return filepath: The correct path to the existing or created folder
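+
+    Example (sketch): check_filepath(Path('./models/run.pt'), expand='num')
+        creates './models/' if needed (makedir defaults to True) and, since replace defaults to False,
+        expands the filename with a running digit when 'run.pt' already exists.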
@@ -23,7 +26,7 @@ def check_filepath(filepath: (str, Path),
             os.makedirs(Path(filepath.parent))
         else:
             raise ValueError('Filepath does not exist. To create, rerun with makedir = True.')
-    if not overwrite:
+    if not replace:
         if expand == 'date':
             if filepath.exists():
                 filepath = Path(filepath.as_posix().replace(filepath.stem + filepath.suffix,
diff --git a/validate.py b/validate.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000