diff --git a/.idea/workspace.xml b/.idea/workspace.xml index a105d966be4b3e8bcb0d0ca6d9a40dc868c871bc..f9050b1059edda371d437671e8c37ff178eff4c3 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -3,10 +3,22 @@ <component name="ChangeListManager"> <list default="true" id="714ca32e-a67a-44cf-90a6-e7ab36748bed" name="Changes" comment=""> <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/.ipynb_checkpoints/train-checkpoint.ipynb" beforeDir="false" afterPath="$PROJECT_DIR$/.ipynb_checkpoints/train-checkpoint.ipynb" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/README.md" beforeDir="false" afterPath="$PROJECT_DIR$/README.md" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/data.py" beforeDir="false" afterPath="$PROJECT_DIR$/data.py" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/data_inspect.py" beforeDir="false" afterPath="$PROJECT_DIR$/data_inspect.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/eval.py" afterDir="false" /> <change beforePath="$PROJECT_DIR$/losses.py" beforeDir="false" afterPath="$PROJECT_DIR$/losses.py" afterDir="false" /> - <change beforePath="$PROJECT_DIR$/temp.png" beforeDir="false" /> - <change beforePath="$PROJECT_DIR$/temp2.png" beforeDir="false" /> + <change beforePath="$PROJECT_DIR$/model.py" beforeDir="false" afterPath="$PROJECT_DIR$/model.py" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/orig_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/orig_model.py" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/reco.py" beforeDir="false" /> + <change beforePath="$PROJECT_DIR$/train.ipynb" beforeDir="false" afterPath="$PROJECT_DIR$/train.ipynb" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/train_locally.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_locally.py" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/train_test.py" beforeDir="false" afterPath="$PROJECT_DIR$/train_test.py" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/trainer.py" beforeDir="false" afterPath="$PROJECT_DIR$/trainer.py" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/utils.py" beforeDir="false" afterPath="$PROJECT_DIR$/utils.py" afterDir="false" /> + <change beforePath="$PROJECT_DIR$/validate.py" beforeDir="false" /> </list> <option name="SHOW_DIALOG" value="false" /> <option name="HIGHLIGHT_CONFLICTS" value="true" /> @@ -26,7 +38,6 @@ <component name="HighlightingSettingsPerFile"> <setting file="file://$PROJECT_DIR$/trainer.py" root0="FORCE_HIGHLIGHTING" /> <setting file="file://$PROJECT_DIR$/summary.txt" root0="FORCE_HIGHLIGHTING" /> - <setting file="file://$PROJECT_DIR$/reco.py" root0="FORCE_HIGHLIGHTING" /> <setting file="file://$USER_HOME$/.conda/envs/nn/Lib/site-packages/torch/nn/modules/conv.py" root0="SKIP_INSPECTION" /> <setting file="file://$PROJECT_DIR$/model.py" root0="FORCE_HIGHLIGHTING" /> <setting file="file://$PROJECT_DIR$/utils.py" root0="FORCE_HIGHLIGHTING" /> @@ -173,15 +184,4 @@ </map> </option> </component> - <component name="XDebuggerManager"> - <breakpoint-manager> - <breakpoints> - <line-breakpoint enabled="true" suspend="THREAD" type="python-line"> - <url>file://$PROJECT_DIR$/utils.py</url> - <line>19</line> - <option name="timeStamp" value="2" /> - </line-breakpoint> - 
</breakpoints> - </breakpoint-manager> - </component> </project> \ No newline at end of file diff --git a/.ipynb_checkpoints/train-checkpoint.ipynb b/.ipynb_checkpoints/train-checkpoint.ipynb index 2837960a7e74f5bed1f91751911faffec490981f..6a924e828f7cca47cb05c3d63539856403962864 100644 --- a/.ipynb_checkpoints/train-checkpoint.ipynb +++ b/.ipynb_checkpoints/train-checkpoint.ipynb @@ -1,9 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "fe78fece", + "metadata": {}, + "source": [ + "Notebook to train a model on the high performance cluster" + ] + }, { "cell_type": "code", - "execution_count": 1, - "id": "84110e61", + "execution_count": null, + "id": "4b7a7e5e", "metadata": {}, "outputs": [], "source": [ @@ -33,29 +41,81 @@ " 'test':\n", " '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}\n", "\n", - "ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'\n", - "\n", - "model_path = Path('./model') / f'{date.today().strftime(\"%Y-%m-%d\")}_unet.pt'" + "ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'" ] }, { "cell_type": "code", "execution_count": 3, + "id": "13b8a348", + "metadata": {}, + "outputs": [], + "source": [ + "# small set data locations\n", + "cache_files = {\n", + " 'train':\n", + " '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_train_fbp.npy',\n", + " 'validation':\n", + " '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_validation_fbp.npy',\n", + " 'test':\n", + " '/oneFS/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_test_fbp.npy'}\n", + "\n", + "ground_truth_data_loc = '/oneFS/daten841/kaspar01/lodopab-ct/data/'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "f02e654d", + "metadata": {}, + "outputs": [], + "source": [ + "# output\n", + "model_output_path = Path('./models') / f'{date.today().strftime(\"%Y-%m-%d\")}_unet.pt'\n", + "epochs = 2\n", + "save_after_epochs = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 5, "id": "b12a5345", "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " ... Loading cached files for dataset: \"train\" ...\r", + "Done loading cached files for dataset: \"train\"\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading Ground Truth for dataset: \"train\": 100%|██████████████████████████████████████████| 3/3 [00:00<00:00, 9.89it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " ... Loading cached files for dataset: \"validation\" ...\r", + "Done loading cached files for dataset: \"validation\"\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "/home/kaspar01/projects/unet/data.py:80: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1607370131125/work/torch/csrc/utils/tensor_numpy.cpp:141.)\n", "  self.inputs = torch.from_numpy(self.inputs)\n" ] } ], "source": [ "# create Dataloaders\n", - "training_dataset = LodopabDataset(cache_files=small_cache_files,\n", + "training_dataset = LodopabDataset(cache_files=cache_files,\n", " ground_truth_data_loc=ground_truth_data_loc,\n", " split='train',\n", " transform=functools.partial(torch.unsqueeze, dim=0),\n", @@ -65,7 +125,7 @@ " batch_size=16,\n", " shuffle=True)\n", "\n", - "validation_dataset = LodopabDataset(cache_files=small_cache_files,\n", + "validation_dataset = LodopabDataset(cache_files=cache_files,\n", " ground_truth_data_loc=ground_truth_data_loc,\n", " split='validation',\n", " transform=functools.partial(torch.unsqueeze, dim=0),\n", @@ -78,46 +138,55 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "ada1e3df", + "execution_count": 6, + "id": "c31c73c8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using cuda device.\n" + ] + } + ], "source": [ - "# model defition\n", - "model = UNet()\n", - "print('model initialized:')\n", - "summary(model)" + "# auto define the correct device\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "print(f\"Using {device} device.\")" ] }, { "cell_type": "code", - "execution_count": 5, - "id": "235cced9", + "execution_count": 7, + "id": "ada1e3df", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Using cuda device.\n" + "model initialized:\n" ] } ], "source": [ + "# model definition\n", + "model = UNet()\n", + "print('model initialized:')\n", + "summary(model)\n", + "model.to(device)\n", + "\n", "# training parameters\n", "criterion = torch.nn.MSELoss()\n", "optimizer = torch.optim.SGD(model.parameters(),\n", " lr=0.01,\n", - " weight_decay=1e-8)\n", - "\n", - "# auto define the correct device\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "print(f\"Using {device} device.\")" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "cadd871b", "metadata": {}, "outputs": [], @@ -130,26 +199,28 @@ " training_dataloader=training_dataloader,\n", " validation_dataloader=validation_dataloader,\n", " lr_scheduler=None,\n", - " epochs=2,\n", + " epochs=10,\n", " epoch=0,\n", - " notebook=True)" + " notebook=True,\n", + " model_output_path=model_output_path,\n", + " save_after_epochs=save_after_epochs)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "4c1bd4e1", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "65d276a3bd7b46ec90de8b74d13176c4", + "model_id": "73279fc9b95e4e09aab2553b86ab0874", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Epoch: 0%| | 0/2 [00:00<?, ?it/s]" + "Epoch: 0%| | 0/10 [00:00<?, ?it/s]" ] }, "metadata": {}, @@ -158,36 +229,212 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e5a3cceed1a1401390c92a470c27a1ed", + "model_id": "d03f9f365e49484db8f2a3baeaf73038", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Training: 0%| | 0/2239 [00:00<?, ?it/s]" + "Training: 0%| | 0/24 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { - "ename": "IndexError", - "evalue": "index
29756 is out of bounds for dimension 0 with size 128", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# start training\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m training_losses, validation_losses, lr_rates \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_trainer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/projects/unet/trainer.py:45\u001b[0m, in \u001b[0;36mTrainer.run_trainer\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m epoch_progress:\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 45\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvalidation_dataloader \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate()\n", - "File \u001b[0;32m~/projects/unet/trainer.py:63\u001b[0m, in \u001b[0;36mTrainer._train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 61\u001b[0m train_losses \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 62\u001b[0m batch_progress \u001b[38;5;241m=\u001b[39m tqdm(\u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining_dataloader), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTraining\u001b[39m\u001b[38;5;124m'\u001b[39m, total\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining_dataloader))\n\u001b[0;32m---> 63\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, (x, y) \u001b[38;5;129;01min\u001b[39;00m batch_progress:\n\u001b[1;32m 64\u001b[0m \u001b[38;5;28minput\u001b[39m, target \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice), y\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n", - "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/tqdm/notebook.py:257\u001b[0m, in \u001b[0;36mtqdm_notebook.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__iter__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 256\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 257\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28msuper\u001b[39m(tqdm_notebook, \u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__iter__\u001b[39m():\n\u001b[1;32m 258\u001b[0m \u001b[38;5;66;03m# return super(tqdm...) 
will not catch exception\u001b[39;00m\n\u001b[1;32m 259\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m 260\u001b[0m \u001b[38;5;66;03m# NB: except ... [ as ...] breaks IPython async KeyboardInterrupt\u001b[39;00m\n", - "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/tqdm/std.py:1195\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1195\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[1;32m 1196\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m 1197\u001b[0m \u001b[38;5;66;03m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n", - "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/dataloader.py:435\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 433\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 434\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()\n\u001b[0;32m--> 435\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 436\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 437\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 438\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 439\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n", - "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/dataloader.py:475\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 473\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 474\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 475\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 476\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m 477\u001b[0m data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data)\n", - "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:44\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfetch\u001b[39m(\u001b[38;5;28mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_collation:\n\u001b[0;32m---> 44\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", - "File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:44\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfetch\u001b[39m(\u001b[38;5;28mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_collation:\n\u001b[0;32m---> 44\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", - "File \u001b[0;32m~/projects/unet/data.py:88\u001b[0m, in \u001b[0;36mLodopabDataset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, idx):\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minputs[idx]) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minputs[idx]\n\u001b[0;32m---> 88\u001b[0m target \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_transform(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtargets[idx]) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_transform \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtargets\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m, target\n", - "\u001b[0;31mIndexError\u001b[0m: index 29756 is out of bounds for dimension 0 with size 128" - ] 
+ "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fff18bcff6554c1ca04c8966e0f96f62", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b33ba398ee734e51b729b1aede972bb5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3df25a49264f4a7e8532fa150bf7089a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "82ed12bfe65947368fcdc770bba9430e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a3a1261be29e44d4813b727277f52219", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c7757599671e44bb8753e2eaa6648ff9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bf6bbee755544529a2c3944e2082efad", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e8d22e916dd942df933fa8dcdcd4fd99", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6e61297656af4c26b28620dd5271a7a6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "78571ab864004405beedee980425b3f2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "705eca833be14c92bb0c0eaa756a8fda", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2159f388e3134a7e8cf2bdd3f76401b7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e17f9b41998e40c4ab0846ed21f5218c", 
+ "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3789b6062f2645f2b817d9ffaf9fb0d3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0%| | 0/24 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ diff --git a/README.md b/README.md index 3adc2763b5b37e1ebb80c6a7623255b8dc7523c6..01d97a43d6c766e496d71538b2cc068dacb4acf1 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,51 @@ -My implementation of a U-Net similar to the FBPConvNet after: +### UNet for Loow Dose CT +This repository contains my implementation of a U-Net similar to the FBPConvNet after: K. H. Jin, M. T. McCann, E. Froustey and M. Unser, "Deep Convolutional Neural Network for Inverse Problems in Imaging," in IEEE Transactions on Image Processing, vol. 26, no. 9, pp. 4509-4522, Sept. 2017, https://doi.org/10.1109/TIP.2017.2713099 https://github.com/panakino/FBPConvNet + +### Project structure +```text +Unet +├───mlruns +│ '''trained models and losses''' +├───models +│ '''trained models''' +│ └───losses +│ '''loss files of the trained models''' +├───msub +│ '''scripts for high performance cluster job management''' +│ +│ conda_env_nn.yml +│ '''conda environment file (prerequisites)''' +│ data.py +│ '''functions for data loading and preparation''' +│ data_inspect.py +│ '''script to inspect single data samples (FBP and ground truth)''' +│ eval.py +│ '''script to evaluate data with a trained model''' +│ losses.py +│ '''script to inspect and plot losses of a trained model''' +│ model.py +│ '''UNet class and functions''' +│ orig_model.py +│ '''FBPConvNet class and functions''' +│ README.md +│ '''information on the repository''' +│ train.ipynb +│ '''jupyter notebook to train a model on the high performance cluster''' +│ train.py +│ '''script to train a model on the high performance cluster''' +│ trainer.py +│ '''Trainer class and functions''' +│ train_locally.py +│ '''script to train a model on my local machine''' +│ train_test.py +│ '''script to train a model with small datasets''' +│ utils.py +│ '''additional functions''' + + +``` \ No newline at end of file diff --git a/data.py b/data.py index 9afeec6ddaa6ba60d965a7751f9de49b8ab953f6..9888ffa3caff066c97fbe72074065717388da29e 100644 --- a/data.py +++ b/data.py @@ -1,15 +1,20 @@ +""" +functions for data loading and preparation +""" + import math from pathlib import Path import numpy as np import h5py import torch from torch.utils.data import Dataset -from typing import Union, List, Callable, Tuple +from typing import Union, List, Callable, Tuple, Dict from tqdm import tqdm import torch.nn.functional as F -def get_im(input: Union[torch.Tensor, np.ndarray]): +def get_im(input: Union[torch.Tensor, np.ndarray] + ) -> np.ndarray: """ returns the 2D image of a sample @@ -20,12 +25,15 @@ def get_im(input: Union[torch.Tensor, np.ndarray]): def get_lodopab_ground_truth(data_loc: Union[Path, str], - indicator: Union[str, List[str]] = ['train', 'validation', 'test']): + indicator: Union[str, List[str]] = ['train', 'validation', 'test'] + ) -> Union[Dict[str, np.ndarray], np.ndarray]: """ Function to load the Ground Truth Lodopab data similarly to the cached data :param data_loc: Path to directory with all necessary data files in h5py :param indicator: 'train', 'test', or 'validation', or a list of several indicators to define 
which data is loaded + + :return data: np.ndarray of the data, or, if more than one indicator is given, a dictionary of the form {indicator: np.ndarray} """ data_loc = Path(data_loc) indicator = [indicator] if type(indicator) is not list else indicator @@ -49,7 +57,8 @@ def get_lodopab_ground_truth(data_loc: Union[Path, str], def load_cache_files(cache_files: dict, - indicator: str): + indicator: str + ) -> np.ndarray: """ :param cache_files: input location - dict of form {'train': <path to .npy file>, 'validation': <path to .npy file>, @@ -64,7 +73,8 @@ def load_cache_files(cache_files: dict, def get_ground_truth(idx: Union[int, List[int]], indicator: str, - path: Union[Path, str]) -> np.ndarray: + path: Union[Path, str] + ) -> np.ndarray: if type(idx) is int: gt_idx = idx % 128 gt_file_num = math.floor(idx / 128) @@ -171,10 +181,10 @@ class LodopabDataset(Dataset): self.inputs = F.pad(input=torch.from_numpy(self.inputs), pad=(0, 6, 6, 0), mode='constant', value=0) self.targets = F.pad(input=torch.from_numpy(self.targets), pad=(0, 6, 6, 0), mode='constant', value=0) - def __len__(self): + def __len__(self) -> int: return self.inputs.shape[0] - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]: input = self.transform(self.inputs[idx]) if self.transform else self.inputs[idx] target = self.target_transform(self.targets[idx]) if self.target_transform else self.targets[idx] return input, target diff --git a/data_inspect.py b/data_inspect.py index d38cbf7fb946519e8b9aa57020024798d2cf311b..06c681c59ffc095175ec4705c9c291e9fe2bf20a 100644 --- a/data_inspect.py +++ b/data_inspect.py @@ -1,3 +1,7 @@ +""" +script to inspect single data samples (FBP and ground truth) +""" + import torch from data import get_samples_from_cache, get_im from torchvision import transforms diff --git a/eval.py b/eval.py index 3929e43d28d23566a8ee088fe0f962d4e5d7a6e1..f2eda7ced98dbbc1777a8a0792d5b8b014dc0d7e 100644 --- a/eval.py +++ b/eval.py @@ -1,3 +1,7 @@ +""" +script to evaluate data with a trained model +""" + from pathlib import Path import torch import numpy as np diff --git a/losses.py b/losses.py index 68e2b9ae43777ceb3a2da6014f342fc930336235..452d03a6df992c91c16f39f5ca0d6e7caee26c04 100644 --- a/losses.py +++ b/losses.py @@ -1,3 +1,7 @@ +""" +script to inspect and plot losses of a trained model +""" + import numpy as np import pickle import matplotlib.pyplot as plt diff --git a/model.py b/model.py index 6736804d7f21a9c1f6ac515633b2f52915b2d619..2f50e0ed5b54ea7a7017ac24c679c08b08e75d36 100644 --- a/model.py +++ b/model.py @@ -1,3 +1,7 @@ +""" +UNet class and functions +""" + from turtle import forward import torch from torch import nn diff --git a/orig_model.py b/orig_model.py index 606a5f5003cce528b21845b02678b0ec8fde710c..ddbd9852467a6b84765aafdd63e02b002bf0b0e9 100644 --- a/orig_model.py +++ b/orig_model.py @@ -1,3 +1,7 @@ +""" +FBPConvNet class and functions +""" + import torch import torch.nn as nn diff --git a/reco.py b/reco.py deleted file mode 100644 index b22b7464ea01191b9c723a5e6ac957c518449623..0000000000000000000000000000000000000000 --- a/reco.py +++ /dev/null @@ -1,46 +0,0 @@ -from pathlib import Path -import torch -import numpy as np -import matplotlib.pyplot as plt -from model import UNet -from torchinfo import summary -from data import get_samples_from_cache -import torch.nn.functional as fun -import functools - - - -# data locations -cache_files = { - 'train': - '../data/lodopab-ct/fbp/cache_lodopab_train_fbp.npy', - 'validation': -
'../data/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy', - 'test': - '../data/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'} - -model_path = Path("models") / "2022-05-12_unet.pt" -sample = get_samples_from_cache(cache_files=cache_files, transform=functools.partial(torch.unsqueeze, dim=0)) -transform = functools.partial(torch.unsqueeze, dim=0) -input = torch.unsqueeze(sample, dim=0) -input2 = transform(sample) - -# idx = np.random.randint(cache_data.shape[0]) -# sample = torch.from_numpy(cache_data[idx])[None, None, :, :] -# input = fun.pad(input=sample, pad=(0, 6, 6, 0), mode='constant', value=0) - -model = UNet() -model.load_state_dict(torch.load(model_path)) -# model.eval() - -img = torch.rand((1, 1, 362, 362)) -img = torch.rand((1, 1, 358, 358)) -img = torch.rand((1, 1, 368, 368)) - -model(img) -out = model(input) -summary(model, input_data=input) - -plt.figure() -plt.imshow(out.detach().numpy()[0][0]) -plt.savefig('temp.png') \ No newline at end of file diff --git a/train.ipynb b/train.ipynb index 09af310e62dc27dc48003cc4743fa97e7fd9ccd4..6a924e828f7cca47cb05c3d63539856403962864 100644 --- a/train.ipynb +++ b/train.ipynb @@ -1,9 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "fe78fece", + "metadata": {}, + "source": [ + "Notebook to train a model on the high performance cluster" + ] + }, { "cell_type": "code", - "execution_count": 1, - "id": "84110e61", + "execution_count": null, + "id": "4b7a7e5e", "metadata": {}, "outputs": [], "source": [ @@ -469,7 +477,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.9.7" } }, "nbformat": 4, diff --git a/train.py b/train.py index 3968ef167eb68b9e4447ab92ce6de8de28baa3c4..e14545025b8890f83609e50c910f435abfec34ab 100644 --- a/train.py +++ b/train.py @@ -1,3 +1,6 @@ +""" +script to train a model on the high performance cluster +""" import torch from data import LodopabDataset from trainer import Trainer @@ -13,7 +16,7 @@ from utils import check_filepath start = datetime.now() print(f'Starting training at {start.strftime("%Y-%m-%d %T")}') -# data locations +# input data locations cache_files = { 'train': '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy', @@ -24,12 +27,19 @@ cache_files = { ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data' +# output paths model_name = f'{date.today().strftime("%Y-%m-%d")}_unet' model_dir = Path(f'./models/{model_name}/') model_output_path = model_dir / f'{model_name}.pt' loss_output_path = model_dir / f'{model_name}_losses.p' + +# adjustable parameters epochs = 500 save_after_epochs = 25 +learning_rate = 0.1 +weight_decay = 1e-8 +batch_size = 16 + # create Dataloaders training_dataset = LodopabDataset(cache_files=cache_files, @@ -39,7 +49,7 @@ training_dataset = LodopabDataset(cache_files=cache_files, target_transform=functools.partial(torch.unsqueeze, dim=0)) training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset, - batch_size=16, + batch_size=batch_size, shuffle=True) validation_dataset = LodopabDataset(cache_files=cache_files, @@ -49,7 +59,7 @@ validation_dataset = LodopabDataset(cache_files=cache_files, target_transform=functools.partial(torch.unsqueeze, dim=0)) validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, - batch_size=16, + batch_size=batch_size, shuffle=True) # auto define the correct device @@ -86,7 +96,7 @@ trainer = Trainer(model=model, # start training training_losses, validation_losses, lr_rates = trainer.run_trainer() -# save +# 
save final model and losses torch.save(model.state_dict(), check_filepath(filepath=model_output_path, expand='num')) losses = {'training_losses': training_losses, 'validation_losses': validation_losses, diff --git a/train_locally.py b/train_locally.py index dd1f2477fb96eafdde49afed46cc1e116f04b322..51034bdc749a07abae4a97ce549d6e6dc3f3eccb 100644 --- a/train_locally.py +++ b/train_locally.py @@ -4,10 +4,16 @@ from trainer import Trainer from model import UNet import functools from torchinfo import summary -from datetime import date +from datetime import date, datetime from pathlib import Path import pickle + +# timer +start = datetime.now() +print(f'Starting training at {start.strftime("%Y-%m-%d %T")}') + +# input paths cache_files = { 'train': '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy', @@ -18,10 +24,16 @@ cache_files = { ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/' +# output paths model_output_path = Path('./models') / f'{date.today().strftime("%Y-%m-%d")}_unet.pt' loss_output_path = Path('./models/losses') / f'{date.today().strftime("%Y-%m-%d")}_unet.p' + +# adjustable parameters epochs = 2 save_after_epochs = 1 +learning_rate = 0.1 +weight_decay = 1e-8 +batch_size = 16 # create Dataloaders training_dataset = LodopabDataset(cache_files=cache_files, @@ -31,7 +43,7 @@ training_dataset = LodopabDataset(cache_files=cache_files, target_transform=functools.partial(torch.unsqueeze, dim=0)) training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset, - batch_size=16, + batch_size=batch_size, shuffle=True) validation_dataset = LodopabDataset(cache_files=cache_files, @@ -41,7 +53,7 @@ validation_dataset = LodopabDataset(cache_files=cache_files, target_transform=functools.partial(torch.unsqueeze, dim=0)) validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, - batch_size=16, + batch_size=batch_size, shuffle=True) # auto define the correct device @@ -50,15 +62,15 @@ print(f"Using {device} device.") # model defition model = UNet() -print('model initialized:') -summary(model) model.to(device) +# print('model initialized:') +# summary(model) # training parameters criterion = torch.nn.MSELoss() optimizer = torch.optim.SGD(model.parameters(), - lr=0.01, - weight_decay=1e-8) + lr=learning_rate, + weight_decay=weight_decay) # initiate Trainer trainer = Trainer(model=model, @@ -77,10 +89,16 @@ trainer = Trainer(model=model, # start training training_losses, validation_losses, lr_rates = trainer.run_trainer() -# save +# save final model and losses torch.save(model.state_dict(), Path.cwd() / model_output_path) losses = {'training_losses': training_losses, 'validation_losses': validation_losses, 'lr_rates': lr_rates} with open(Path.cwd() / loss_output_path, 'w') as file: pickle.dump(training_losses, file) + +# timer +end = datetime.now() +duration = end - start +print(f'Ending training at {end.strftime("%Y-%m-%d %T")}') +print(f'Duration of training: {str(duration)}') diff --git a/train_test.py b/train_test.py index f743c9a6db8df981c46ce074bd45dbbd14b8be77..11b69f33c4f42863ee65ad5da2ecc26aaf3ba37a 100644 --- a/train_test.py +++ b/train_test.py @@ -1,3 +1,7 @@ +""" +script to train a model with small datasets +""" + import torch from data import LodopabDataset from trainer import Trainer @@ -13,7 +17,7 @@ from utils import check_filepath start = datetime.now() print(f'Starting training at {start.strftime("%Y-%m-%d %T")}') -# data locations +# # cluster input data locations # cache_files = 
{ # 'train': # '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy', @@ -24,7 +28,7 @@ print(f'Starting training at {start.strftime("%Y-%m-%d %T")}') # # ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data' -# local data locations +# local input data locations cache_files = { 'train': '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/fbp/small_cache_lodopab_train_fbp.npy', @@ -35,12 +39,18 @@ cache_files = { ground_truth_data_loc = '//hpc.isc.pad.ptb.de/hipcl/daten841/kaspar01/lodopab-ct/data/' +# output paths model_name = f'{date.today().strftime("%Y-%m-%d")}_unet_test' model_dir = Path(f'./models/{model_name}/') model_output_path = model_dir / f'{model_name}.pt' losses_output_path = model_dir / f'{model_name}_losses.p' + +# adjustable parameters epochs = 2 save_after_epochs = 1 +learning_rate = 0.1 +weight_decay = 1e-8 +batch_size = 16 # create Dataloaders training_dataset = LodopabDataset(cache_files=cache_files, @@ -50,7 +60,7 @@ training_dataset = LodopabDataset(cache_files=cache_files, target_transform=functools.partial(torch.unsqueeze, dim=0)) training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset, - batch_size=16, + batch_size=batch_size, shuffle=True) validation_dataset = LodopabDataset(cache_files=cache_files, @@ -60,7 +70,7 @@ validation_dataset = LodopabDataset(cache_files=cache_files, target_transform=functools.partial(torch.unsqueeze, dim=0)) validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset, - batch_size=16, + batch_size=batch_size, shuffle=True) # auto define the correct device @@ -76,8 +86,8 @@ model.to(device) # training parameters criterion = torch.nn.MSELoss() optimizer = torch.optim.SGD(model.parameters(), - lr=0.01, - weight_decay=1e-8) + lr=learning_rate, + weight_decay=weight_decay) # initiate Trainer trainer = Trainer(model=model, @@ -96,7 +106,7 @@ trainer = Trainer(model=model, # start training training_losses, validation_losses, lr_rates = trainer.run_trainer() -# save +# save final model and losses torch.save(model.state_dict(), check_filepath(filepath=model_output_path, expand='num')) losses = {'training_losses': training_losses, 'validation_losses': validation_losses, diff --git a/trainer.py b/trainer.py index 65a0155f694f3653dcb63f8753426e060dc1b964..486e36c322f021f65b48fb2d8623bdecc22c465f 100644 --- a/trainer.py +++ b/trainer.py @@ -1,5 +1,7 @@ +""" +Trainer class and functions +""" import pickle - import numpy as np import torch from typing import Union @@ -9,6 +11,9 @@ from utils import check_filepath class Trainer: + """ + for training and validation + """ def __init__(self, model: torch.nn.Module, device: torch.device, @@ -53,7 +58,7 @@ class Trainer: from tqdm import trange epoch_progress = trange(self.epochs, desc='Epoch') - for epoch in epoch_progress: + for _ in epoch_progress: self.epoch += 1 self._train() @@ -76,9 +81,7 @@ class Trainer: losses = {'training_losses': self.training_loss, 'validation_losses': self.validation_loss, 'lr_rates': self.learning_rate} - save_loss_as = check_filepath(filepath=self.loss_output_path.with_name( - self.loss_output_path.name.replace('.p', f'_e{self.epoch:03}.pt')), - expand='num') + save_loss_as = check_filepath(filepath=self.loss_output_path, replace=True) with open(save_loss_as, 'wb') as file: pickle.dump(losses, file) @@ -117,7 +120,7 @@ class Trainer: self.model.eval() validation_losses = [] - batch_progress = tqdm(enumerate(self.validation_dataloader), 'Validation', total=len(self.training_dataloader), + batch_progress 
= tqdm(enumerate(self.validation_dataloader), 'Validation', total=len(self.validation_dataloader), leave=False) for i, (x, y) in batch_progress: diff --git a/utils.py b/utils.py index d052c0f48edcab72eecb975ccfcfa2eb25da18f9..091af06283fd9619b4db7a115be6e7ac547f6d87 100644 --- a/utils.py +++ b/utils.py @@ -1,3 +1,6 @@ +""" +additional functions +""" import os from pathlib import Path from datetime import datetime @@ -5,14 +8,14 @@ from datetime import datetime def check_filepath(filepath: (str, Path), makedir: bool = True, - overwrite: bool = False, + replace: bool = False, expand: str = 'date', **kwargs) -> Path: """ Checks a filepath existence and possibly creates directories :param filepath: filepath to the target :param makedir: create nonexistent directories in the filepath - :param overwrite: toggles allowance to overwrite existing data in the filepath + :param replace: toggles allowance to overwrite existing data in the filepath :param expand: 'date' (default): expands existing files with datetime (up to minute), 'num': expands with running digit :return filepath: The correct path to the existing or created folder @@ -23,7 +26,7 @@ def check_filepath(filepath: (str, Path), os.makedirs(Path(filepath.parent)) else: raise ValueError('Filepath does not exist. To create, rerun with makedir = True.') - if not overwrite: + if not replace: if expand == 'date': if filepath.exists(): filepath = Path(filepath.as_posix().replace(filepath.stem + filepath.suffix, diff --git a/validate.py b/validate.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
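
Note on the `utils.py` change: the `overwrite` keyword of `check_filepath` is renamed to `replace`, and the `Trainer` now saves intermediate losses with `replace=True` instead of expanding the filename per epoch. The following is a minimal usage sketch of both call patterns, assuming this repository's `utils.py` and `model.py`; the paths and the `UNet` instance are illustrative placeholders, not values from the diff.

```python
# Sketch of check_filepath after the overwrite -> replace rename.
# Assumes this repo's utils.py and model.py; paths are placeholders.
from pathlib import Path

import torch

from model import UNet
from utils import check_filepath

model = UNet()
model_output_path = Path('./models/example_unet.pt')
loss_output_path = Path('./models/losses/example_unet_losses.p')

# replace=False (default): an existing file is never overwritten; the
# filename is expanded instead, with a datetime stamp (expand='date')
# or a running digit (expand='num') -- the pattern train.py uses to
# save the final model.
torch.save(model.state_dict(),
           check_filepath(filepath=model_output_path, makedir=True, expand='num'))

# replace=True: always return the same path, so repeated saves
# overwrite the file -- the pattern the Trainer now uses when it
# writes intermediate losses every save_after_epochs epochs.
save_loss_as = check_filepath(filepath=loss_output_path, replace=True)
```

The `replace=True` path keeps the loss file from multiplying across epochs, while the default still protects finished model checkpoints from accidental overwrites.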