From 800f63085d223b0072b57d5111a7029630ffcb11 Mon Sep 17 00:00:00 2001 From: Nando Farchmin <nando.farchmin@gmail.com> Date: Mon, 27 Jun 2022 08:52:40 +0200 Subject: [PATCH] Add function approximation example --- nbs/function_approximation.ipynb | 206 +++++++++++++++++++++++++++++++ src/__init__.py | 1 + src/approximation.py | 170 +++++++++++++++++++++++++ 3 files changed, 377 insertions(+) create mode 100644 nbs/function_approximation.ipynb create mode 100644 src/approximation.py diff --git a/nbs/function_approximation.ipynb b/nbs/function_approximation.ipynb new file mode 100644 index 0000000..660fb2a --- /dev/null +++ b/nbs/function_approximation.ipynb @@ -0,0 +1,206 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2022-06-27T06:51:44.140941Z", + "start_time": "2022-06-27T06:51:43.403583Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "import neural_networks_101.src as src\n", + "\n", + "%matplotlib inline\n", + "%reload_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2022-06-27T06:51:44.157724Z", + "start_time": "2022-06-27T06:51:44.143028Z" + } + }, + "outputs": [], + "source": [ + "# Get CPU or GPU device for training\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "device = torch.device(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2022-06-27T06:51:44.184037Z", + "start_time": "2022-06-27T06:51:44.164227Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2022-06-27 08:51:44] generate training and test data\n" + ] + } + ], + "source": [ + "# generate samples\n", + "print(src.misc.time_stamp(), \"generate training and test data\")\n", + "x_train = np.random.uniform(0, 1, (10000, 2))\n", + "x_test = np.random.uniform(0, 1, (1000, 2))\n", + "y_train = src.target_function.sin2d(x_train).reshape(-1, 1)\n", + "y_test = src.target_function.sin2d(x_test).reshape(-1, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2022-06-27T06:51:44.209802Z", + "start_time": "2022-06-27T06:51:44.185625Z" + } + }, + "outputs": [], + "source": [ + "# define model, loss and optimization algorithm\n", + "model = src.approximation.NeuralNetwork(x_train.shape[1], y_train.shape[1], width=1024).to(device)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-02)\n", + "loss_function = torch.nn.MSELoss(reduction=\"mean\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2022-06-27T06:52:01.055372Z", + "start_time": "2022-06-27T06:51:44.211488Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "progress: 0.0 % -- loss: 0.2509300410747528\n", + "progress: 63.7 % -- loss: 0.22826001048088074\n", + "[2022-06-27 08:51:48] time: 3.93 s\n", + "[2022-06-27 08:51:48] test set avg. loss: 0.015558036975562572\n", + "progress: 0.0 % -- loss: 0.30233126878738403\n", + "progress: 63.7 % -- loss: 0.20259657502174377\n", + "[2022-06-27 08:51:52] time: 3.79 s\n", + "[2022-06-27 08:51:52] test set avg. loss: 0.0056471554562449455\n", + "progress: 0.0 % -- loss: 0.0772915855050087\n", + "progress: 63.7 % -- loss: 0.07037431001663208\n", + "[2022-06-27 08:51:56] time: 3.75 s\n", + "[2022-06-27 08:51:56] test set avg. loss: 0.004849676508456469\n", + "progress: 0.0 % -- loss: 0.0791594386100769\n", + "progress: 63.7 % -- loss: 0.025842074304819107\n", + "[2022-06-27 08:52:00] time: 4.19 s\n", + "[2022-06-27 08:52:01] test set avg. loss: 0.0014430878218263388\n" + ] + } + ], + "source": [ + "n_epochs = 4\n", + "for epoch in range(n_epochs):\n", + " with src.misc.timeit(\"time: {:4.2f} s\"):\n", + " src.approximation.train(model, device, x_train, y_train, loss_function, optimizer, log_interval=100)\n", + " test_loss = src.approximation.test(model, device, x_test, y_test, loss_function)\n", + " print(src.misc.time_stamp(), f\"test set avg. loss: {test_loss}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2022-06-27T06:52:01.684805Z", + "start_time": "2022-06-27T06:52:01.061488Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAuc0lEQVR4nO2df4xe1Znfv4/HnoFie7DBBNZjg2uBB4NCRSjZRv1BVbEBIi3qilVhu6yCEllUy6oSUrv+p2ylKIqqpFITLbvISlGUqFqUpmhDkbMsWmmbP6JUuChh7bVNbPxrIF4b7B1sB+a1Pad/zNzhzp374/w+z7n3+Uiv7Hnvve97nvO853ue85xzzyWlFARBEIT+syp1AQRBEIQ4iOALgiAMBBF8QRCEgSCCLwiCMBBE8AVBEAaCCL4gCMJA6BR8InqRiM4Q0f6G40RE3yKiI0T0FhHd67+Ygm/Er/1FfCs0oRPhfwfAQy3HHwZw++JrF4A/dS+WEIHvQPzaV74D8a1QQ6fgK6V+DOBcyymPAviuWuCnAK4nolt8FVAIg/i1v4hvhSZWe/iMzQBOlf6eWXzvl9UTiWgXFiIK0MT4Z9bcssnD1zdDIwr6+Wo8zl3KoewYv+lmXD73PojorFKq6gw2ftWxf2xk9plXx7vP8e3fOjvqyj328fLvvXrNyuu6yj/xa1OYe2/masNhK9+O0ZrPXLd6Q/sX1zB/7ZoV71Vtujqx8rqxiabid3N1bmzhM+bqj9fVcfW9glUfXV72d9mesh1lG4qyrx//GABww9jF2s/+4Orapf9/OLpmWdmXlbdkx0dnZt6vaa9a+BD8utZYW3NKqT0A9gDAxLYpdct//gMPX9/MxEmNVu3A3FZDlbEklB2Xz5/Du//j2xidOX2i5jAbv+rYv+6EuThfuLW9I/Ht3zo76so9eXSlSs1uX66IXWUHgLefe/ZywyEr306Of0p97qZ/0/m9ZT66e/OK987vWFkPF7bP116/dtus0fcBwMVjk0v/X3fUbF3KhsNmPi/bUrahKPdvbD0MAHhyw09qr//e+c8t/f8vT+4AsLz8BWU73vrWs3XtVQsfq3RmAGwp/T0F4D0PnyukRfzqGZ8diE0HVyKKb+vE3pQ68cuR753/3DJxL94rKMS+DtNOqw0fEf4rAJ4hopcAfBbArFJqxdBQyI4gfr1t6uyK947PhE3tCSsI3mZNxL4pui+oir5N1K/L+R3jxlF+G395csdSlA9ghehXcRmd6NAp+ET0ZwAeAHAjEc0A+CMAawBAKfUCgL0AHgFwBMCvADzlvZSCd375P7+HXx07gqu/ugQAnyaiLyGAX+sEvumcJuGf2zpqTevYRrvrTiit1EhILtxKrtF6GxMx26yOyNelc0wpi2JV/H2MCHRF34ctQHt075tOwVdKPdFxXAH4fW8lEqJwy28/ufT/t5979i2l1H8vH/fhVx2xr57PKdqfODkebZ4mEG8qpe6rvhmizfpI39hw8dhkkIjfNdIvl6sa5bddUxAiugf8pHSC0SUYnMShiT7YEBNuoq9L1c++bZg8Ordi4pYLLmLflc7RoSmqdxXNNtGvi+7XHV1lbE9ddB9K7AGmWyvcNnXWKBXgwroTqvblCgcbAqYKguKjTmJSV94YNnDwb6rIPhYuaZtyR1Qn7OX3Yk1Os4vwY6UBuhpLcdwmx8vFhuKc1HlqzrjWT26dkw/6LvJVCtEvon3bTkAnVx8yugeYRfi2jcf0OpPIyDSK6oMNHKjWR1MuPaVtXT6rO575nEA2Yh9COM/vGPc2UVsm5tJTVoLPlViiEjJazFH0OTPEyD4F646uCiLeF7bPe5k/6KJLzGPfZ8BG8F0bkO71IYUvlggMRbx91ufk0bnaO1hzIvfym1IW+kL4yy9bYgh9GV1Rr9rk836AAnY5/CoP3nxoxXuvn56OXg6XXG8fbMiVqkjWrXYxrZehRvd1e+LUUU57hBCtLto6g9hiX1C3fLRtGWZRbxsOj7ymkVhE+E0NqE4o297vaohtkXERAfqOotpsqDvG0YZcCVEPQxV7XariVP3bZwfgI9UTswO4eGxy2csG1/KyEPw6moRS97gudQJpIxR1QuCrjF2Y2JBbOqhcr9UJz67Oz+ZYGR8TrKE7CA7+LCYzbSJRV9GuXu8a3bvm9m3siZHKKWAp+LpC6SqoPkShiVg2tCGRvj46winR/XJ0RT7EypYyviZ1y0Ifa1K3SugUGEvB50JVMENFUy6i3yXq3EW/SG2VX65wt9mFPtvmgstEbpuw2wi/STl0zvXZCbATfNMGr3t+VaxDNhxX0dKNJHNv/LpzNCaRtW6d5F53qXGJ2lNM5HaJtunxrvN1VhPFTOUUsBN8bnASBtsRBicbCmLN0eSGja845PG7aOsgQt9d6osi2i/EXjfyrxP+VDYnr2nfeVGdz8tRAIW0SP7+E0Ln5AG/ghgyF2/y2W1Rf6xRD6t1+HWi99j6N1e894MP711xXYp17SZU7ajaYArHTis1pnVSXZMf6j6FXHcANaEQPl2h9r2+3BSbnS2buLB9PptRCutS1ol92/sujB+aWXqZUo7+dDut2DaUxTCHFAAgo54muHX2Pla3VO+qjUEuowifsBb8WHSJvEsDaxN2n6JftcGm4xI+wWQNvq/VRTnSJHS2ou9b7OvKwS0ajzmJzcvyEl1iaCuWVfGuE8Y2sQwZHduKxpDE3fZmqPGDpzyXZIGyz4Ys/EIeUT5bwe8DIdI2gP6Io08dgcukaSH2daKvU5emW38I6bBZRcORa/e/G+Rz2Qi+a+Px3fhiiWWoTqEvuPq1KvK+In0Re/e17SFTGU3fzS2dExuW1uuKoKtY9iEC7oMNfSNFZxBrIt5lZU3KVTlcqev0iug+RJTPUvB9E6sxlBs6l8i93CFwW92RCtcoP6agN5U1F1/GTKvo3P06dKQGWpDoWQhNLktkObDh8CjJtgwxCZW7LxDB1ySXiGpIDOVhLqFWGIWg+pAPX5SFnrPo+xzRhHh+8GAFn1P0ziX9MyTKIiqduTlVYVu7bXZJ7E1EX0e8fQi8pHMWkFoQBCa4djxc00MuUW+T2Fffz3kJZkxE8Jmis+6c0yglJDkugfRR5pxSOSEwiew5RfCcOx8+tWSJpEOEkFQ73lSdD2fxD5W374JzLp8r7AQ/pIBLrlYQ+BFauDlF/6mRmhAGC+eoue+EXn6YGq5pHRF8IRtCP4SkD8s8uU7cmhIrXWMb/ZdXJfmkbilmcYeyj05EBF/oNaM7t3QeKz8EJSckRZkeV9Gv227io7s3B1mDD4jgCz2Ck3CXn8BWftrVxEnZTyZ3qiLfJPoc0zrsBN/10X+CkAt9Sb8I+cBO8EPCKQIU7LF9CAqXzw9NLh1JeeI2VM4+1gqdVEtTTdGqDSJ6iIgOE9ERItpdc3ySiP43Ef2ciA4Q0VP+i1pPX0cEOg+9Hk1POX3HpV8cxLFvfg0A7ubmV8EL4lfPmAq7a1rH95bSnYJPRGMAngfwMICdAJ4gop2V034fwN8qpe4B8ACA/0pEkqxkjJqfx5lXX8bmJ3cBwAEE8KtOp8WRtjx7OTefAUH86oNQk5JCOzoR/v0Ajiil3lFKjQC8BODRyjkKwDoiIgBrAZwDcMVrST3jGh37JMUo5eOZk1iz8UaMb7wBWPBfL/xaR91KnbbVOylY8axlx3sEFtM6LPyqm1bJ/c5Z28lbnSi++AzX1JGOJzYDKP/6ZhbfK/PHAO4E8B6AvwHw75VSK6wkol1EtI+I9l29cMmyyGnoW/7/yoVZrJ68vvzWIP3aRa6jlBLWfgWW+/byKI5vcxd+zugIft3dKNVZoc8D+BmAXwPwjwD8MRGtX3GRUnuUUvcppe4bW3edYVHjw2kU4B1VO7EXza+ZpUaMYGiblV+B5b5dMx6+zfZV7MsPV9clxCMhdQR/BkB5/DuFhcigzFMAXlYLHAFwDACbX32sOyjLDZ3LZHJTp7V6/fW4Mvv35bfY+rVNQKu+1RmJcUvnBIatX13z+FVBtF2R43MlT1fKpUn4Yz3vV8fSNwDcTkTbFid2HgfwSuWckwD+FQAQ0acA7ADwjm2hdMWSi6imxGQUUhbDazZvweVzZ3H5/AfAwiguiF+5pEQKkQ8t9syie69+vXpNfeB08dikUyGFdnzewNUp+EqpKwCeAfAagIMAvq+UOkBETxPR04unfQXA54jobwD8FYA/VEq9b1KQFA2FQ8omVadFY2PY9IXfwsx39wDAXQjkV074EnuT32riDi+JX5s6gBjpmtDr7n11bl1Rfl3E72Ot/2qdk5RSewHsrbz3Qun/7wH4DefSOOC7w+DQGTQxu31Cax+VLhvW3rETa+/Yibefe3a/UuqrQFy/vn56OsuHmzTR9Rs02VZhdOeWZSt1mjqryaNzbWmsaH69eGySzc1H646uMo6Kba5pYu22WeuOIXRqh+2dtl2Rr21kXG0cpsJuMh9gWkbbTotz5+STENGyr9VXIUeobSMTTqvHLh6bXCF0oSJuHWE0/e4h7Jvffws1qApm9e9QjcpnOqfLhlzREVKbiVtXXj89vVQ2G7HX2f5gdOeWoU0wJ6dJ9E0f0O5rtON7AzbWgt8kiCHy3qPpqaWXKW2RZ2wbyv9ygcvEbQiYTdIOghDRvQ6mIs4lxVWGteADC8JYiGP5/1ypE4BquUPa0CT25aiX04M+2gQzdzHtc0cn6BHqQSm2JBf8cqNoa+BtImkqDKnynm0dVi425IJO/XDtBHXh/huoRtmx1prHwMdqnRT75ScXfN/4jqo4NSpbUeJkQx3Vzq6cHy/Iza8S3YeDw924MaL2EB0CO8EPNYyPObHnaoOuWHAX8jJdNhUib1t3dZ1hU/30QexT+74rR950vE9Rfkx8dTDsBN+UkHneWI2qDzbERvchJbPbJ1a8TDERcInsu3ER/fM7xtl0GiZpHS55fJaCryuArkLZ1vjrjpmkVGJNOPZV0MvEFlHuT7zi7nOdFTKuwt2XtE7sPD5LwQfSCqavKFDHBh92Ntng2mlxhuPDwNs6pq7y6v7mOIl9nbCbLoc0Ff0Nh0dLYs9B9HNDa2uF0Byf2YTbps6ueL/p1vsmkeyKBC/cSrU3vITO55vY0EUKG0ygUXOH0uTnNkyi+6a60bnOlSHk7evwsd79/I7x6OJd/b5yx+NzmwUdLmyfj3aXLwvBb4PLWmwXUchos63gmIh+LnXRVU7d0YjuHkl9xFb0NxweWY0ShgrblE6OxBIo284nh3TO8ZlNS68c8CX2BU2pOI7RfUx8TdQ2iX1XJ9CXLaDZRPg2w/3q9TrYDPtjCWVIkYsp9hMnx1snPmPYaZPaMSWkHUMX+Co+onLOkX01rRMqpcQqwrdtQKbXmYifqVByjPJTRPapJ1Uv3ErLXm3nlWnqqKp+1fFz6jrIDS7LLfsMK8EHzAXTVmB1RNBWKG3KZHNNl5gV5/QNGyHVqasuTNJNbWWMMfrgQortA0IROq3TVle+vptNSqeMbnrHNZoOKYaxbAD4inpXaicF5ZRPqHqTyH65eBX/112fX6Reioi/LRWjOyrgkM5xeTCKL1gKPpDPCo02+mADN3yJqYh9fMqdQJv4c0jtxF6aWRD6O9kKvtAPXKL83MRTp7y+0jlNHRaHdJGOaJmsPU+xTp8jPh4jKYIvBKcQwjbh5yDuLumnWGI/xDmbJlKMBFxF1zWt45oSEsEXosFB1EMQy64cxNwkJeFyhymHtE+OsFul4xNuE4ZC/9AVe5fo3scKo9zxLfDX7n+385xQ2x2k3DkzWYTftueKT/og+n2wIRamaaPQaRygXeybbhQbusC34WMrhY/u3uyrOFmRNKUTcigcUyT7Ykff8VmXrmJfFfQcBP5qx82/NitMTCdvTTGZ7LXZl8eWVEs0JYfvSF/z0jkSq3OMkcbhisma+hT4Xs1jszyzLORcHnxSwNNrwqDhLJRDFvsyXO6gLfbHL++Tz4m2KD5FZyARvpCUJmEs3tdNdcSI7kXsm3HpAIq0ju6NWbnhY/28L/pTq0I2rDuhll65IGJfz4Xt816i/epnpBpB1I0SQnY+sTsCifCB2j1vctsWIQcbbMRQZ88b0+i+XFe+dr00sc12grbOzlRzSH2KwH3SlMLxEeX7qHP2gh9quVrXxmbFcV+iGcKO2Da44BL5um501lZPXXUUc7uENpo6teL9FMIfOgqP+ei/ocBW8Ltyu4C9aJo8aOW2qbPWgtkmBKY56rpymZ6bSvhDCmJbdB/Dz6HFXnf0Mrd1JCvGKhQ3V/Vhzb2vjo9l96nbiGwam81TtWyu4WZD7pjWk4866hLQEGI/t3W07GVzbSw4R986d9I2EXu1T8w8PjuPmTYik/NdRCCkyMaa6EvRUaSI7mPY6dsun2IdU/hDij6nDkWnLKn3uteBT432CNvJSR1cxYzr6GDy6Bwmj84lLYNu3Zj4N7VNMSjEcN3RVd5E2uVzXKJ7jvjs+LQ+iYgeIqLDRHSEiHY3nPMAEf2MiA4Q0f+xKYxt1KRznQ+h0/mM1MvyHrz5kPa5H711GO/u/gYA3B3Crzp1URV6V+Hn1KHp2hEwGg/i1zp8ipLJZ4XaCoHjTVw+6Jy0JaIxAM8DeBDADIA3iOgVpdTfls65HsCfAHhIKXWSiG4KVF7vNAnk66enI5ekmyYxq9pQ/rvJDjU/j3Pf+yFu+g9fwnv/8esHADwR06+ukW+dSHJJ5VRtmzw6h9ntHRvRaGI4AR/dr0C6p0Vxx2Zppu+0ls6n3Q/giFLqHaXUCMBLAB6tnPM7AF5WSp0EAKXUGdOCNDWkIuIrv3zRFg2bRMpd6NpgMzroKmfd8dumzmL0zims/tQNWHPTDQCgEMivdej4MEQq5MGbDy17mWIj9j4pd2ianVtUv3Ik1gqdWA81cR156Aj+ZgCnSn/PLL5X5g4AG4jor4no/xHR79V9EBHtIqJ9RLTv6qVLnV/c1Hjq3jcVS50G3ySWJpjYYIpLp3Tl/IdYvXHZD82rX2Ontrr8UldXPjt1IJzY3zZ1ttY+zd+itV+Bim8vdrfZAk4TrpyoinvsiV4dr9QtFK+25tUAPgPgCwA+D+A/EdEdKy5Sao9S6j6l1H1j113X+qVdjcekcVUbhklDNznXVORiTejV2qBqy5rcryHwIexdvg1hV5PQV8/RwMqvQMW3a9t9yw2u6+9TrubREfwZAFtKf08BeK/mnL9QSl1SSr0P4McA7vFTxHwJmbpwFbHVGydx5dyyfCI7v6YeBeniW+x1hL56fgvs/JoLuukTGwG/eGxyxXXVUVH1bx8TyTqC/waA24loGxGNA3gcwCuVc34I4J8R0Woi+gcAPgvgoG2hdBtQ0gg5EE2RpOuEZNWG8W1TuPJ3H+Dy2XPAwijOm1+5pHNs03YmhBB7j3j1qykp0zoho/uQEXqxtDVU3XWu0lFKXSGiZwC8BmAMwItKqQNE9PTi8ReUUgeJ6C8AvAVgHsC3lVL7g5TYkVjizX39NY2NYePv/ibOfONFALgLwFdC+5V7nZjCyZ6GrSGi+NUnueX+c7jZqozWXjpKqb0A9lbee6Hy99cBfN2mEL4iQttNth5b/+ayv3/w4b1eyqOL6bK9pk6rbEedDQ/efGhpmeZtU2dxHNPYfM80Tnxx936l1FcBd7+OMVm+HLpjTz0Xobls2Jtfc4Jr7r6NWB0du+60qSGNH5rB+KEZ7fN1qYp903s+BKTJBh9Uy1xnw1CImYLzSVs6p+1eC53rU5FbxK5LXQ4+B9h7I6RItmEqmOVRSl0nVLbBtz1NZc1F9McPnmo81tWhpxC52NF9UweWa8eWC7EeaB4T9oJfxUUsqw3EVhBNRSZkh5WLqDeO3BbFvk30TchNBLu2VcjNnjpCR/l9FOYqvrZ6YC34TUIZK+IPKaZVG3SjRlMByKVDAPyJfmrGD55KZgvHtI7AB9aCH5JYQpgiHZULnAXe5mEivoU+xrJSTvQ1388JVjVsmxvltDzOBNfOwKbT4i4QLoKZ0jbfnRdHP41NXLW+VsS8mZh1w9YLXWIokbM+OaV1OJJ6y2uhmyHk8X3AVvB94ZrTNBXL8mhDOqVmOKdzTOmTLakY6gggtt3DrGVLOA6z+4gIqMCRPnRK+VtggaQ44sJhjsWXz6u2cOicJBBZILe0Toonew1S8DkhaZ/w5NbB91HA256A5TNyji36uUX9WnvpCHHR2RMoNxELjc7+Qo+tf1N7n6S5rSOrpZm26Mw1FbbE3uspN6qi7/P5tBwe0u5CXt2T0As4pEFCM7pzS/dJBvS1g88tQs6drGvbJB3Sx2Gy0E5fRLLLDs2dM9kRS+xDp3ly6rRYllTy2kKfaNr6umsfnQJJ4Qi+YCn4ghCCOuHMNToWBBtE8IVewjWdY/OAnjI5RvttK3T6gmlaJ1WdiOALQgb84MN7sxR7LsRYrpnDNtAi+IKwSM0zYaMwxC2Nc5roNMHErhRRPstaH01PeT0PkFxt3+HkX9Mlmak6mpT0Oc3DuTPjWzKht/heo66DSzrENe8uCE3E7vhE8BmiIzCSz23HtX583WXbtCQzFBxHC5wjXg7EFH3xRGJM0lK5Elv0QuJii+5IgVN6SugXIvgCO8opnzqBNb1hKeZoqC1dJQ9SyYeuFTEbDo8a9+ixidhjRfmDFHxbAZDIyx+x8vhtvhZ/2tH3FE1u2yyb0G/PwT2nado5lCPSIaRrhk6540oxGZ0LuXQSLpE9wH/1EVsvdImliKk+OU3w5iqadeU2yfdXAxPd0UfMUcrVubFo38URn9ss1xGjs2Al+LYTYrlOCrp2WjZCzimNkau4V7G1I+Z++4IbdWJfHQ1wj+4BZoIfk1hRr4xEFtDplLl2AKkmWzl1zn2gK13jkrvPQewB5oLfJJaxRDRkp1C1QXeUYioC3NM5hcj7EHuTugktpjFGnbl1COU8fi45fUAvus+FfGrdA9UGYiuGphPBITso7oKuA9fIPgSmo4XXT0/XCnvdexxvuhoCuUT3QAbPtB1NTy17IArX6P7CrdTamMt2cLXBB1fzDHy8YhvdH5/Z1LiRWm7RPFfO7xhnHbFf2D4fdPSTRYQ/mp5aelVxHTrXiaLrgzKaytRkgyu6NsSi6Y5SU1/lOhkv8KYQ9/M7xpdeNtcDeUX3AJMIvxwdz26fwOTROevPsSF1WsSHsOnYUO60chv+l307cXJ86W7balT8+unpzucXh0yHmPqybIsQDy4RfWyyiPB90oehcQ42pNxhsql+mvLhvkg5IsmtAxfscO2otASfiB4iosNEdISIdrec94+J6CoRPeZSKN2GE6uB2YiEbdmahDJEg/7orcN4d/c3AODuGH4F4vqs+soBz36O5tehkHM6B9AQfCIaA/A8gIcB7ATwBBHtbDjvvwB4zXch+4ytAJoKWPV8NT+Pc9/7IW569ikAOIAAfnWJ8nPL35uUN+K6/iB+FZpZu20Wa7fNOn1GyI5EJ8K/H8ARpdQ7SqkRgJcAPFpz3h8A+F8AzvgoWFcDinHbuum5VXza4JvRO6ew+lM3YM1NNwCAQiS/ArzrRYdq+XyX1zTKbzk/ql+HjqvQlwkl+jqCvxnAqdLfM4vvLUFEmwH8awAvtH0QEe0ion1EtO/qpUudX9zUkHw0MFshb2uMdREtBxvqzrty/kOs3jhZfiuaX4F6+2e3T7AX+4KinD7K67LFgkbnYO3XxXM/8e3FS9nfMNWEjsC25c+rYu9D/EOIvo736sbl1THpfwPwh0qpq20fpJTao5S6Tyl139h11y071raUr9y4mhqYTfqgSzB95X2r5fYpajY2HJ/ZBKjatEI0vwKf1EuX0Hf5NtWEZcjO6fjMpk67DOy28itQ8e3a67pOZ0Mh9jqiXwiricDmmL8H9JZlzgAo3wo5BeC9yjn3AXiJiADgRgCPENEVpdSf+ygkEK5xFYJYXsoXaoKvywYdYau7McfGhtUbJ3Hl3LIoJIlfbRjSMkYPnVk2fk2BrnD3ZRmnjuC/AeB2ItoG4F0AjwP4nfIJSqltxf+J6DsAXrX58XTdrdp2XRdtdzECeiKv0/hsbfCBSUc1vm0KV/7uA1w+ew5YGMWx82ssqn7t0S6WQf2aO76idJ+5+yq+77zt/CSl1BUAz2BhNv8ggO8rpQ4Q0dNE9LS3kgwc3ZSUa8RXXE9jY9j4u7+JM994EQDuAlO/mtSLTd30fP06S7/mNA8A5LkNchNad9oqpfYC2Ft5r3bCRyn1RZcCmUaDJrn7rii/61pdQtrgk2vvmcbme6Zx4ou79yulvgqE82sschHwdSdU493DPlgcpfTGr77JWbRdYNnV6gqgjVDGigK52eBynQ9S3nkr8CG36N4FX6ken50T29q/cCs1ikTbMR1MhM9FJLnYYHN+CExt5tJJdJWjsCtVeSdOjvdp3sE71ZU6LgKa+8iAxeZpbYRqRIUANqV4fApkKhvK53CCi5D7oGqLywS1aVqnKvKhRX/Mbk/DwXLx2GT3SZFhL/ih4SiIpvTBBu4UQq7TWYUWfYnmu2na9z5XfK3WGbzgC4IupukoF9EX4jOE+YX+Wyj0ipxuuupT6ipnfN00lXv+HhDBFzIihtj7jq6bJnOlM1iJ7xRMyrtjQ+TvfXQ4ktJxZG7rSIbgDDD1QeyRQiHwnO84Tkl57xufQq37WUNI5wCJBT+n4XkbfbEjJOU6sukg2+rY5vPK1xSfHaPjlsg+PeuOrvKanrl4bHLZmvuQq3Ncy51M8NV4nEgnVCOOKfI5jyDq6kmn7nRt9lE3Pj5DRnr54iO657gEs45ep3RCNkB5+HQ3LvVTjrqbPoeDwJbLFkv069JCMnIwYygpnCq9FnwhHb46Q46daluZfIu+bs5f9x4BjvRpvTx3RPAFQQOTjseH6Oc0uet7opUjfViSCciyTCEAHKNyW+a2jqzssa2DdSeUtdjn1EkIaWAb4etsY5zDlgJdduRggwm5T8aXce24TCN9H4KdKrXjK8rv+0ghNSwFX3fP+tumzjoLZlMjc200HGwAhjWZ1yWYKepiCKt3XEQ6l/y976WcqWAn+KYPKLEVzC5xKI7biAQXG4pzOIm+bt2EGPmU66urTnSj+7I9LmWWdIwQA1Y5fNunUZleZ9K4TBtiLBtM4CImJjaG9Glxvku93DZ1dkUZm8rcpzkNF6rRfPVvSeeEh43guwqe7vU2jTyWYJrUgY3ApcTGvyE7wQIb4W8rFwfRT+3rOspbJxSvMn0S+3VHV7Fd58+zVJkSQ6AA+wbNUQi6qKvTak48pl06PjYV/a7yTx6dq33FZuxju3rOJU/vA65CX8Auh1/lwZsPrXjv9dPT0cvhkgvnYkMqYnWEoUlhR5uwTx6dw+z2iYilsVuN07eHkTTBXewBJhF+U0OqE8q297saZFskFSpyarOh7piLqKSM/kLCocOwmYivwyS1o+PHPvg6l1U+OQh6F2wtaBJK3eO61AmkTSOqa+A6ZTS1o67TMrEhdlontFjHsCfFRHwfhDw0nEYNpp1Bqjw/S8HXFUFX0e8aLlcxERdfHVIXpjYI3ehG4cUozbevffgtZEfISWgB9/LUXW/ymTZib3tteRtmG9jn8FMSK0f64M2HguX0U+R5TQgxv1EnmL7roFruJh/6uLFOCI+uwFdvwGoTbJNzY5G+BBVMoyXd86sRT8jo1zXi000FcI/g2+xoi4xd6q+pTnzNb4Qqtwvcfwd9o0jHmAh4V8cQC3aCzw1Ojcl2mM7JBsBubsPnXEDszp7DpLMQlyKy57YmP3lJyo2hTQgeW/9m47HydTqNi5sAAumiw5xxGbVx+Q3EHHkKZrjODfjee8c1fw9kkMMvC33x/x98eG+q4lhR11m52uAiDDH21zFdatt0Lqf7FUwWE3Aqd9/hepeuTcon9AZtySP8Npqi+rZo35bxQzNLL1N0RyllYtsgkaMdko6ph9tKHW7YpnFCp39YC34sukTeRSzbhN2n6FdtsOm4QuM7bcVpF9AqkqIbJj4i9JCiz1bwu8TQViyr4l0njG1iaTJxalpGW5HgKO6xcR3B2EyIP7b+zaVXDMYPnmo8JiM4PnAWfbaCnxrOIqrbuDnbwIGYIumaGirEvk30h8b5HeOs8vc5PCCFjeC7DoFtrucgiLGiw9T0PcUxFD8KeaMl+ET0EBEdJqIjRLS75vi/JaK3Fl8/IaJ7XAql23ikkbl1Wpd+cRDHvvk1ALjbp1+HMNGp89vz1ck5RPVe/So0k0N0D2gIPhGNAXgewMMAdgJ4goh2Vk47BuBfKKU+DeArAPb4LqgLsTYMKzdwLp1RuUMopzDU/DzOvPoyNj+5CwAOIEO/djF+8NSyV2hcfR4gxcTGr5xSL7kQIo+v84n3AziilHpHKTUC8BKAR8snKKV+opQ6v/jnTwFM+S1mGjikfELx8cxJrNl4I8Y33gAACgH96hLp+kwFDTD/HdSvfaGtM9JZfppLdA/oCf5mAOWWMrP4XhNfAvCjugNEtIuI9hHRvqsXLumXkgF9WwVx5cIsVk9eX34rmV9DrHQZoLg3Ye1XYLlvL4/8t9k+RP7VSJzTVgpVdO60rVvsXJsjIaJ/iYUf0D+tO66U2oPF4ePEtqmkz9vrc/Suhaqt/uh+jZ36Gj94CqM7txhf53OkMbd1tOIxjU146ris/Aos9+26ybRtNiSuT+WyEfni+2J2ejqlnAFQbiFTAN6rnkREnwbwbQCPKqU+8FO8YZAi3796/fW4Mvv35bei+7VqN5d5j57hza+rPrq87G+529aeVHWnI/hvALidiLYR0TiAxwG8Uj6BiLYCeBnAk0qpt/0Xs5m+ioTOSheXUco1m7fg8rmzuHz+A2BhFMfKr7nC7PfIyq/c1s1XSVW2mOLfKfhKqSsAngHwGoCDAL6vlDpARE8T0dOLpz0H4AYAf0JEPyOifcFKLHiBxsaw6Qu/hZnv7gGAuyB+7SMs/Tp00Q/5NK0utHbLVErtBbC38t4Lpf9/GcCXfRQoZITUt4lXV9besRNr79iJt597dr9S6qtAOL9W0fUzl10nXz89nePNY9H92jdidU4bDo+ifBff6WRBEATBKyL4glBBdwWNYA/ntE4qulI9F49NOn+HCL4QFG5pEJslmS5ws1/ops+dkQi+EB1mK1mSwnlPfx1cVphwF9ZQ5UtpNzvBD/n4wtntE8E+WxCA/B6/6YPU6/FDfD/3zsgWdoJvyhAbWF8J7Uuf6ZwQZZWAxJxC7DccHi17cafoUGJ3LNkLviAI6TsLbkLLqSxNpBB9EXymHJ/Z1HnOaHrQmxwaEWOyVkabcYU2B1HXJdZdyIMVfE5iKULhTjXCjb0aR3z4CX0S4r4xWME3JfWQue+EFMwm8dfxadtIy7TMPp4EptuR5b76R6jHdS2+1tYKQ4XTKGCIuG6r4DvKr9teQSJ7IScGEeHHinbKAsVFCFJ3Wlz2whHiEjqto/P5fUkt+dxAjaXg64olF1FNiYmgc0lLVf3GwY+hgoIYoxQufhX4wyal47oboU3DGk1PNe4pHysy5iB2KTC1W2fVUt+R/L3gCssIvy+EEnPdiC51OoczMaJiSWcJ3GAr+F1iaSumK5bvGYpiyOjJViBE2IUYzF+7JnURjOjT9gi+8vhsBT8mVcGs/h0qGvQ5AuiyITe4Rse65ao7zzYtNbpzi9PS0lT0ZdK0T7AW/CZBDJEqGU1PLb1MKTfkakPnYgNnYfCFro2h68KlsxqCn3zQp+g9JmwmbZsohPGx9W82iiTXaLAg1EZbto9sDD2pd3xm07KbjHw+HpD7w0l0f4tzW0dLtly4lbDuhPLy/dwmbGM9uq8K9w7hwvZ5AGapmnVHVy1dZ0vyCL8tOi7jUzQ5RlHcO61cCOFbriuEOP6O6wi1sRp3UW+iLNquAm5KcsH3je/GyalR2UZvnGzQwbXza7O37pivqLit3Fw7jZjEEv3cOoKYos9O8ENFutVGHVIEXW3QFYechNylTmzEsq5ufNVXnS0+f7c65QzZcYUkRLRfFnhXse/7RDM7wTclZCoklqD2wQYddOwMJZwpO3iJ7uOQW2SfApaC77L0zYSQQ/9YOXlOgq7D66enG+smRJ3Nbp8IVkdtthQ0ib3L5HNuPtflwvZ56/SGT7EPGeXHztlXYSn4QFrBtGlQdQ07VkTbZAPnYX8hluUXZ6r+9V1e15QjF7/6ILUopiCWzSwEvykKaosEbW5saWoUhTiGiAZNbegihQ02xEpjuApdSKEMUQccfOuDahRdFjyfu0Oa0Pf8PZDBOnwukZ+LMJjYEFIo+xQFcseHH3XvtRC/DgfXzpBFhA/kExG20QcbfOKjPsqf0ZT3tq2PUPVoa3dXefoS3QvpYCP4rvQhMu6DDangYp+OD00mbEXk41BN56RI78TI47MSfFvBM73ORBxMhYRjlJ9SDI/PbOpNncSwo65cbfMyXDo6X6TK3w8FdrVr2qhchs9djcW2MdmUyeaakDb4Jmad+DjHhtjr7bn41gfrjq6SydoIsJy0rW6+1XaeKyEbf+42+KZsq07d2FLUR92GZH0Re8EPbWLve+O3cspm7bZZAMDFY5PePl8HloIP9KMB9cGGUFTrpugAfNZZdRdKH2Jf15GblFknf6+ze2YunTxndB+EHvMO3gvb54OOdNgKvjAsQnWOMbaCFvLDJI0TUvTXbpuNGuWzy+ELy+G+//sQSTV5q3NMcOPa/e/Wvt+XPL+W4BPRQ0R0mIiOENHumuNERN9aPP4WEYV5erfglUu/OIhj3/waANwtfjUj5uqjMoZi79WvV6/h09EU++647L9TpRD7NtH3IfxF/r7p75DLMzsFn4jGADwP4GEAOwE8QUQ7K6c9DOD2xdcuAH/a+bkjPj8eW3KOvtX8PM68+jI2P7kLAA4gA7/Obc07ykrwe/Hm1wIOm5TVCaKrSFZFvkn0gbyjfZ0I/34AR5RS7yilRgBeAvBo5ZxHAXxXLfBTANcT0S2eyzo4QgrExzMnsWbjjRjfeAMAKHj0a8hy5y76JlSjeYtUjle/FqQU/TZh1xV93e9MKfqhonydSdvNAE6V/p4B8FmNczYD+GX5JCLahYWIAgDmTnxx936j0vLjRgDvpy6EJRsArH/7uWdPANgBj359+7lnc/crkK9vNwBYD8DZr8BK3771rQC+fc37J7Zh5tdmzY9d7jI7bC/UEfy6sKK6ZkznHCil9gDYAwBEtE8pdZ/G97MlZxuI6LcBfF4p9WUi2rf4tvh1kVzt8OlXoH++7YsNttfqpHRmAGwp/T0F4D2LcwReiF/7ifhVaERH8N8AcDsRbSOicQCPA3ilcs4rAH5vcfb/1wHMKqVWDA8FViz5FQsRn/i1H4hfhUY6UzpKqStE9AwWMlZjAF5USh0goqcXj78AYC+ARwAcAfArAE9pfPce61LzIVsbKn69HsA3xa/LyNKOgH4FMq2TCoO2gZRqv4VbEARB6Adyp60gCMJAEMEXBEEYCMEFvw/bMmjY8AARzRLRzxZfz6UoZxtE9CIRnSGi2nXUpn4Qv/JA/LoS8WsLSqlgLyxM8h4F8A8BjAP4OYCdlXMeAfAjLKwo+HUA/zdkmQLZ8ACAV1OXtcOOfw7gXgD7G45r+0H8yuclfhW/mvghdITfh20ZdGxgj1LqxwDOtZxi4gfxKxPErysQv7YQWvCbbuE2PScluuX7J0T0cyL6ERHdFadoXjHxg/g1H8Sv4tclQj8Axdu2DAnRKd+bAG5VSl0kokcA/DkWdiLMCRM/iF/zQfwqfl0idITfh9u8O8unlPpQKXVx8f97AawhohvjFdELJn4Qv+aD+FX8ukRowe/DtgydNhDRzUREi/+/Hwv1+kH0krph4gfxaz6IX8WvSwRN6ahw2zJEQ9OGxwD8OyK6AuAjAI+rxal0LhDRn2FhdcKNRDQD4I8ArAHM/SB+5YP4dTni147PZWanIAiCEAi501YQBGEgiOALgiAMBBF8QRCEgSCCLwiCMBBE8AVBEAaCCL4gCMJAEMEXBEEYCP8feOtm3QJl25IAAAAASUVORK5CYII=\n", + "text/plain": [ + "<Figure size 432x288 with 3 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, axes = plt.subplot_mosaic([[\"true\", \"approx\", \"error\"]])\n", + "x = np.linspace(0, 1, 50)\n", + "X, Y = np.meshgrid(x,x)\n", + "xx = np.concatenate([X.reshape(-1, 1), Y.reshape(-1, 1)], axis=1)\n", + "axes[\"true\"].contourf(x, x, src.target_function.sin2d(xx).reshape(x.size, -1))\n", + "axes[\"approx\"].contourf(x, x, src.approximation.evaluate(model, device, xx).reshape(x.size, -1))\n", + "im = axes[\"error\"].contourf(x, x, src.target_function.sin2d(xx).reshape(x.size, -1)-src.approximation.evaluate(model, device, xx).reshape(x.size, -1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": true, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "165px" + }, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/__init__.py b/src/__init__.py index 3a794ad..b403c80 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,5 +1,6 @@ from . import ( misc, target_function, + approximation, mnist, ) diff --git a/src/approximation.py b/src/approximation.py new file mode 100644 index 0000000..d88b748 --- /dev/null +++ b/src/approximation.py @@ -0,0 +1,170 @@ +"""Utility functions for handling the MNIST data set.""" +from typing import Tuple, List, Optional +import numpy as np +import torch +import matplotlib.pyplot as plt + + +class NeuralNetwork(torch.nn.Module): + """Neural network used for function approximation.""" + + def __init__(self, + input_dim: int, + output_dim: int, + width: int = 1024, + ) -> None: + """Initialize network layers. + + Parameters + ---------- + input_dim : int + Width of the input layer. + output_dim : int + Width of the output layer. + width : int, default=1024 + Width of the hidden layers. + """ + super(NeuralNetwork, self).__init__() + self.input = torch.nn.Linear(input_dim, width) + self.hidden = torch.nn.Linear(width, width) + self.output = torch.nn.Linear(width, output_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Evalute network in input data. + + This function defines the topology of the network. + + Parameters: + x : torch.Tensor + Two dimensional input data. + + Returns: + : + Output vector with function values. + """ + model = torch.nn.Sequential( + self.input, + torch.nn.ReLU(inplace=True), + self.hidden, + torch.nn.ReLU(inplace=True), + self.hidden, + torch.nn.ReLU(inplace=True), + self.hidden, + torch.nn.ReLU(inplace=True), + self.output) + return model(x) + + +def train(model: NeuralNetwork, + device: torch.device, + x_train: np.ndarray, + y_train: np.ndarray, + loss_function: torch.nn.modules.loss._Loss, + optimizer: torch.optim.Optimizer, + log_interval=100, + ) -> None: + """Train the given model. + + Parameters + ---------- + model : NeuralNetwork + Neural network model. + device : torch.device + Hardware to train the model on. + x_train : np.ndarray + Input training data the model is trained on. + y_train : np.ndarray + Output training data the model is trained on. + loss_function : torch.nn.modules.loss._Loss + Loss function for the optimization algorithm. + optimizer : torch.optim.Optimizer + Optimization employed to train the model. + log_interval : int + Number of steps the progress is printed after. + """ + kwargs = {'num_workers': 1, 'pin_memory': True, 'shuffle': True} + train_loader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + torch.from_numpy(x_train), torch.from_numpy(y_train) + ), batch_size=64, **kwargs) + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data = data.type(torch.float32).to(device) + target = target.type(torch.float32).to(device) + optimizer.zero_grad() + loss = loss_function(input=model(data), target=target) + loss.require_grad = True + loss.backward() + optimizer.step() + if batch_idx % log_interval == 0: + percentage = 100. * batch_idx / len(train_loader) + print(f"progress: {percentage:>4.1f} % -- loss: {loss.item()}") + + +def test(model: NeuralNetwork, + device: torch.device, + x_test: np.ndarray, + y_test: np.ndarray, + loss_function: torch.nn.modules.loss._Loss, + ) -> float: + """Test the network on a test data set. + + Parameters + ---------- + model : NeuralNetwork + Neural network model. + device : torch.device + Hardware to train the model on. + x_train : np.ndarray + Input test data the model is evaluated in. + y_train : np.ndarray + Validation data the model is compared against. + loss_function : torch.nn.modules.loss._Loss + Loss function for the optimization algorithm. + + Returns + ------- + : + Loss on the test data set. + """ + kwargs = {'num_workers': 1, 'pin_memory': True, 'shuffle': True} + test_loader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset( + torch.from_numpy(x_test), torch.from_numpy(y_test) + ), batch_size=16, **kwargs) + model.eval() + loss = 0 + with torch.no_grad(): + for data, target in test_loader: + data = data.type(torch.float32).to(device) + target = target.type(torch.float32).to(device) + # sum up batch loss + loss += loss_function(input=model(data), target=target) + loss /= len(test_loader.dataset) + return loss + + +def evaluate(model: NeuralNetwork, + device: torch.device, + xs: np.ndarray, + ) -> float: + """Evaluate the network. + + Parameters + ---------- + model : NeuralNetwork + Neural network model. + device : torch.device + Hardware to train the model on. + xs : np.ndarray + Input test data the model is evaluated in. + + Returns + ------- + : + Network outputs. + """ + kwargs = {'num_workers': 1, 'pin_memory': True, 'shuffle': True} + model.eval() + xs = torch.from_numpy(xs).type(torch.float32).to(device) + return model(xs).detach().numpy() -- GitLab