Newer
Older
"id": "84110e61",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from data import LodopabDataset\n",
"from trainer import Trainer\n",
"from model import UNet\n",
"from torchinfo import summary\n",
"import functools\n",
"from datetime import date\n",
"from pathlib import Path"
"id": "15aaa025",
"metadata": {},
"outputs": [],
"source": [
"# data locations\n",
"cache_files = {\n",
" 'train':\n",
" '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_train_fbp.npy',\n",
" '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_validation_fbp.npy',\n",
" '/gpfs001/scratch/kaspar01/lodopab-ct/fbp/cache_lodopab_test_fbp.npy'}\n",
"ground_truth_data_loc = '/gpfs001/scratch/kaspar01/lodopab-ct/data'\n",
"\n",
"model_path = Path('./model') / f'{date.today().strftime(\"%Y-%m-%d\")}_unet.pt'"
"name": "stderr",
"output_type": "stream",
"text": [
"/home/kaspar01/projects/unet/data.py:80: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /opt/conda/conda-bld/pytorch_1607370131125/work/torch/csrc/utils/tensor_numpy.cpp:141.)\n",
" self.inputs = torch.from_numpy(self.inputs)\n"
"training_dataset = LodopabDataset(cache_files=small_cache_files,\n",
" split='train',\n",
" transform=functools.partial(torch.unsqueeze, dim=0),\n",
" target_transform=functools.partial(torch.unsqueeze, dim=0))\n",
"\n",
"training_dataloader = torch.utils.data.DataLoader(dataset=training_dataset,\n",
" batch_size=16,\n",
" shuffle=True)\n",
"\n",
"validation_dataset = LodopabDataset(cache_files=small_cache_files,\n",
" split='validation',\n",
" transform=functools.partial(torch.unsqueeze, dim=0),\n",
" target_transform=functools.partial(torch.unsqueeze, dim=0))\n",
"\n",
"validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,\n",
" batch_size=16,\n",
" shuffle=True)"
]
},
{
"cell_type": "code",
"id": "ada1e3df",
"metadata": {},
"outputs": [],
"source": [
"# model definition\n",
"model = UNet()\n",
"print('model initialized:')\n",
"summary(model)"
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cuda device.\n"
]
}
],
"source": [
"# training parameters\n",
"criterion = torch.nn.MSELoss()\n",
"optimizer = torch.optim.SGD(model.parameters(),\n",
" lr=0.01,\n",
" weight_decay=1e-8)\n",
"\n",
"# automatically select the device: CUDA GPU if available, otherwise CPU\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(f\"Using {device} device.\")"
]
},
{
"cell_type": "code",
"id": "cadd871b",
"metadata": {},
"outputs": [],
"source": [
"# instantiate the Trainer\n",
"trainer = Trainer(model=model,\n",
" device=torch.device(device),\n",
" criterion=criterion,\n",
" optimizer=optimizer,\n",
" training_dataloader=training_dataloader,\n",
" validation_dataloader=validation_dataloader,\n",
" lr_scheduler=None,\n",
" epochs=2,\n",
" epoch=0,\n",
" notebook=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "4c1bd4e1",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "65d276a3bd7b46ec90de8b74d13176c4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e5a3cceed1a1401390c92a470c27a1ed",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0%| | 0/2239 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "IndexError",
"evalue": "index 29756 is out of bounds for dimension 0 with size 128",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# start training\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m training_losses, validation_losses, lr_rates \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_trainer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/projects/unet/trainer.py:45\u001b[0m, in \u001b[0;36mTrainer.run_trainer\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m epoch_progress:\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 45\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvalidation_dataloader \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate()\n",
"File \u001b[0;32m~/projects/unet/trainer.py:63\u001b[0m, in \u001b[0;36mTrainer._train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 61\u001b[0m train_losses \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 62\u001b[0m batch_progress \u001b[38;5;241m=\u001b[39m tqdm(\u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining_dataloader), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTraining\u001b[39m\u001b[38;5;124m'\u001b[39m, total\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining_dataloader))\n\u001b[0;32m---> 63\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, (x, y) \u001b[38;5;129;01min\u001b[39;00m batch_progress:\n\u001b[1;32m 64\u001b[0m \u001b[38;5;28minput\u001b[39m, target \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice), y\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
"File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/tqdm/notebook.py:257\u001b[0m, in \u001b[0;36mtqdm_notebook.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__iter__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 256\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 257\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28msuper\u001b[39m(tqdm_notebook, \u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__iter__\u001b[39m():\n\u001b[1;32m 258\u001b[0m \u001b[38;5;66;03m# return super(tqdm...) will not catch exception\u001b[39;00m\n\u001b[1;32m 259\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m 260\u001b[0m \u001b[38;5;66;03m# NB: except ... [ as ...] breaks IPython async KeyboardInterrupt\u001b[39;00m\n",
"File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/tqdm/std.py:1195\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1195\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[1;32m 1196\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m 1197\u001b[0m \u001b[38;5;66;03m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
"File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/dataloader.py:435\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 433\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 434\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()\n\u001b[0;32m--> 435\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 436\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 437\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 438\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 439\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
"File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/dataloader.py:475\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 473\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 474\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 475\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 476\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m 477\u001b[0m data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data)\n",
"File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:44\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfetch\u001b[39m(\u001b[38;5;28mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_collation:\n\u001b[0;32m---> 44\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
"File \u001b[0;32m~/anaconda3/yes/envs/dival/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:44\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfetch\u001b[39m(\u001b[38;5;28mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_collation:\n\u001b[0;32m---> 44\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
"File \u001b[0;32m~/projects/unet/data.py:88\u001b[0m, in \u001b[0;36mLodopabDataset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, idx):\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minputs[idx]) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minputs[idx]\n\u001b[0;32m---> 88\u001b[0m target \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_transform(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtargets[idx]) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_transform \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtargets\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m, target\n",
"\u001b[0;31mIndexError\u001b[0m: index 29756 is out of bounds for dimension 0 with size 128"
]
}
],
"source": [
"# start training\n",
"training_losses, validation_losses, lr_rates = trainer.run_trainer()"
{
"cell_type": "code",
"execution_count": null,
"id": "bf6c39f5",
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), Path.cwd() / model_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b69efe6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
}
},
"nbformat": 4,
"nbformat_minor": 5
}