{ "cells": [ { "cell_type": "markdown", "id": "rocky-imaging", "metadata": {}, "source": [ "# Results for the wine quality dataset\n", "\n", "This notebook reproduces the results of the EiV and non-EiV models for \n", "the wine quality dataset (taken from https://archive.ics.uci.edu/ml/datasets/wine+quality) as presented in the preprint \"Errors-in-Variables for deep learning: rethinking aleatoric uncertainty\" submitted to NeurIPS 2021.\n", "\n", "This notebook produces Figures 5a and 6 of the preprint.\n", "\n", "How to use this notebook:\n", "\n", "+ This notebook assumes that the trained networks are stored in `saved_networks`. To set this up, either run the training scripts described in the `README` or download the pre-trained networks from the link in the `README` into the `saved_networks` folder.\n", "\n", "+ To run the notebook, run all cells in order (e.g. via \"Run All\" in the menu above).\n", "\n", "+ To run the notebook on a GPU, set `use_gpu` to `True` in the cell below \"Values that can be changed\" (default is `False`).\n", "\n", "+ Plots are displayed inline and additionally saved to `saved_images`.\n" ] }, { "cell_type": "code", "execution_count": 31, "id": "attractive-punch", "metadata": {}, "outputs": [], "source": [ "import random\n", "import os\n", "\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader, TensorDataset\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", "from tqdm.notebook import tqdm\n", "\n", "from EIVArchitectures import Networks\n", "from generate_wine_data import test_x, test_y, train_x, train_y\n", "from EIVTrainingRoutines import train_and_store\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "id": "92190849", "metadata": {}, "source": [ "## Fix relevant hyperparameters" ] }, { "cell_type": "markdown", "id": "7f0337bf", "metadata": {}, "source": [ "### Values that can be changed" ] }, { "cell_type": "code", "execution_count": 32, "id": "2a8a4fbc", "metadata": {}, "outputs": 
[], "source": [ "# The Deming factor to use for the scatter plots (Figure 6)\n", "# Choose one of 0.01, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6 and 2.0\n", "deming = 1.0\n", "\n", "# Set to True to use a GPU (falls back to the CPU if CUDA is not available)\n", "use_gpu = False\n", "\n", "# Coverage factor for the uncertainty intervals\n", "# (1.96 is the 97.5% quantile of the standard normal, i.e. about 95% coverage)\n", "k = 1.96" ] }, { "cell_type": "code", "execution_count": 33, "id": "copyrighted-taxation", "metadata": {}, "outputs": [], "source": [ "# Graphics settings\n", "fontsize = 15\n", "matplotlib.rcParams.update({'font.size': fontsize})" ] }, { "cell_type": "code", "execution_count": 34, "id": "3f5515e7", "metadata": {}, "outputs": [], "source": [ "# Set device\n", "if not use_gpu:\n", "    device = torch.device('cpu')\n", "else:\n", "    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "markdown", "id": "d6ee9d43", "metadata": {}, "source": [ "### Values to keep fixed\n", "The following values must match the settings used in the training scripts. To change them, adapt and rerun these scripts." 
] }, { "cell_type": "code", "execution_count": 35, "id": "564904dc", "metadata": {}, "outputs": [], "source": [ "from train_eiv_wine import dim, init_std_y_list, precision_prior_zeta, deming_factor_list, seed_list\n", "init_std_y = init_std_y_list[0]" ] }, { "cell_type": "code", "execution_count": 36, "id": "auburn-immune", "metadata": {}, "outputs": [], "source": [ "# Function to fix the random seeds (for reproducibility)\n", "def set_seeds(seed):\n", "    torch.backends.cudnn.benchmark = False\n", "    random.seed(seed)\n", "    np.random.seed(seed)\n", "    torch.manual_seed(seed)" ] }, { "cell_type": "markdown", "id": "extraordinary-suspension", "metadata": {}, "source": [ "## Comparison of EiV and non-EiV for one Deming factor\n", "Produces Figure 6 of the preprint." ] }, { "cell_type": "code", "execution_count": 37, "id": "daily-draft", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FNNEIV(\n", " (main): Sequential(\n", " (0): EIVInput()\n", " (1): Linear(in_features=11, out_features=200, bias=True)\n", " (2): LeakyReLU(negative_slope=0.01)\n", " (3): EIVDropout()\n", " (4): Linear(in_features=200, out_features=100, bias=True)\n", " (5): LeakyReLU(negative_slope=0.01)\n", " (6): EIVDropout()\n", " (7): Linear(in_features=100, out_features=50, bias=True)\n", " (8): LeakyReLU(negative_slope=0.01)\n", " (9): EIVDropout()\n", " (10): Linear(in_features=50, out_features=1, bias=True)\n", " )\n", ")" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load EiV model\n", "net = Networks.FNNEIV(p=0.5, init_std_y=init_std_y,\n", " precision_prior_zeta=precision_prior_zeta, deming=deming,\n", " h=[dim, 200,100,50,1])\n", "saved_file = os.path.join('saved_networks', \n", " 'eiv_wine_init_std_y_%.3f_deming_factor_%.3f_seed_%i.pkl'\n", " % (init_std_y, deming, 0))\n", "train_loss, test_loss, stored_std_x, stored_std_y, state_dict, extra_list = train_and_store.open_stored_training(saved_file, \n", " net=net, extra_keys=['rmse'], 
device=device)\n", "rmse = extra_list[0]\n", "net.to(device)" ] }, { "cell_type": "code", "execution_count": 38, "id": "swiss-paris", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FNNBer(\n", " (main): Sequential(\n", " (0): Linear(in_features=11, out_features=200, bias=True)\n", " (1): LeakyReLU(negative_slope=0.01)\n", " (2): Dropout(p=0.5, inplace=False)\n", " (3): Linear(in_features=200, out_features=100, bias=True)\n", " (4): LeakyReLU(negative_slope=0.01)\n", " (5): Dropout(p=0.5, inplace=False)\n", " (6): Linear(in_features=100, out_features=50, bias=True)\n", " (7): LeakyReLU(negative_slope=0.01)\n", " (8): Dropout(p=0.5, inplace=False)\n", " (9): Linear(in_features=50, out_features=1, bias=True)\n", " )\n", ")" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load non-EiV model\n", "ber_net = Networks.FNNBer(p=0.5, init_std_y=init_std_y,\n", " h=[dim, 200,100,50,1])\n", "ber_saved_file = os.path.join('saved_networks', \n", " 'noneiv_wine_init_std_y_%.3f_seed_%i.pkl'\n", " % (init_std_y, 0))\n", "ber_train_loss, ber_test_loss, ber_stored_std_x,\\\n", " ber_stored_std_y, ber_state_dict, ber_extra_list\\\n", " = train_and_store.open_stored_training(ber_saved_file, net=ber_net, extra_keys=['rmse'], device=device)\n", "ber_rmse = ber_extra_list[0]\n", "ber_net.to(device)" ] }, { "cell_type": "code", "execution_count": 39, "id": "91a110fd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "std_y for EiV: 0.593\n", "std_y for non-EiV: 0.587\n" ] } ], "source": [ "# print std_y for EiV and non-EiV\n", "print('std_y for EiV: %.3f' % (stored_std_y[-1]))\n", "print('std_y for non-EiV: %.3f' % (ber_stored_std_y[-1]))" ] }, { "cell_type": "code", "execution_count": 40, "id": "brave-fiction", "metadata": {}, "outputs": [], "source": [ "# Collect predictions for test data\n", "\n", "set_seeds(0)\n", "\n", "# Save network state\n", "ber_net_train_state = ber_net.training\n", 
"net_train_state = net.training\n", "net_noise_state = net.noise_is_on\n", "\n", "# Switch on Dropout and input noise\n", "ber_net.train()\n", "net.train()\n", "net.noise_on()\n", "\n", "# Collect predictions and compute errors and uncertainties\n", "pred, _ = [t.cpu().detach().numpy()\n", " for t in net.predict(test_x.to(device), number_of_draws=1000,\n", " take_average_of_prediction=False)]\n", "ber_pred, _ = [t.cpu().detach().numpy()\n", " for t in ber_net.predict(test_x.to(device), number_of_draws=1000,\n", " take_average_of_prediction=False)]\n", "val_y = test_y.detach().cpu().numpy()\n", "err = np.mean(pred, axis=1).flatten()-val_y\n", "unc = np.std(pred, axis=1).flatten()\n", "ber_err = np.mean(ber_pred, axis=1).flatten()-val_y\n", "ber_unc = np.std(ber_pred, axis=1).flatten()\n", "\n", "# Restore networks\n", "if net_train_state:\n", " net.train()\n", "else:\n", " net.eval()\n", "if net_noise_state:\n", " net.noise_on()\n", "else:\n", " net.noise_off()\n", "if ber_net_train_state:\n", " ber_net.train()\n", "else:\n", " ber_net.eval()" ] }, { "cell_type": "code", "execution_count": 41, "id": "descending-newark", "metadata": {}, "outputs": [], "source": [ "def diagonal(x_array, eps=0.1):\n", " min_x = np.min(x_array)\n", " max_x = np.max(x_array)\n", " x = np.linspace(min_x-eps, max_x+eps)\n", " return x,x" ] }, { "cell_type": "markdown", "id": "e770c860", "metadata": {}, "source": [ "Figure 6 in the preprint" ] }, { "cell_type": "code", "execution_count": 42, "id": "mobile-consumption", "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot and save figures\n", "plt.figure(1)\n", "plt.clf()\n", "plt.scatter(ber_unc, 
unc, color='r')\n", "plt.plot(*diagonal(unc), color='k', linestyle='dashed')\n", "plt.xlabel('uncertainty - non-EiV')\n", "plt.ylabel('uncertainty - EiV')\n", "plt.axis('square')\n", "plt.gca().set_aspect('equal','box')\n", "plt.tight_layout()\n", "plt.savefig(os.path.join('saved_images', 'wine_unc_scatter.pdf'))\n", "\n", "plt.figure(2)\n", "plt.clf()\n", "plt.scatter(ber_err, err, color='r')\n", "plt.plot(*diagonal(ber_err), color='k', linestyle='dashed')\n", "plt.xlabel('error - non-EiV')\n", "plt.ylabel('error - EiV')\n", "plt.axis('square')\n", "plt.gca().set_aspect('equal','box')\n", "plt.tight_layout()\n", "plt.savefig(os.path.join('saved_images', 'wine_error_scatter.pdf'))" ] }, { "cell_type": "code", "execution_count": 83, "id": "4c5e40a4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.95228493\n", "1.0843716\n", "1.6982132\n", "2.5368009\n" ] } ], "source": [ "# Root mean squared quotient of error and uncertainty, sqrt(mean(err**2 / sigma**2));\n", "# a value close to 1 suggests that sigma matches the size of the errors\n", "last_std_y = stored_std_y[-1].item()\n", "ber_last_std_y = ber_stored_std_y[-1].item()\n", "# sigma = sqrt(uncertainty**2 + std_y**2): EiV, then non-EiV\n", "print(np.sqrt(np.mean(np.abs(err)**2/(unc**2 + last_std_y**2))))\n", "print(np.sqrt(np.mean(np.abs(ber_err)**2/(ber_unc**2 + ber_last_std_y**2))))\n", "# sigma = uncertainty only: EiV, then non-EiV\n", "print(np.sqrt(np.mean(np.abs(err)**2/np.abs(unc)**2)))\n", "print(np.sqrt(np.mean(np.abs(ber_err)**2/np.abs(ber_unc)**2)))\n", "#print(np.corrcoef(np.abs(err),unc)[0,1])\n", "# print(np.corrcoef(np.abs(ber_err),ber_unc)[0,1])\n" ] }, { "cell_type": "markdown", "id": "statewide-aurora", "metadata": {}, "source": [ "## Deming factor vs. RMSE\n", "Produces Figure 5a of the preprint." ] }, { "cell_type": "code", "execution_count": 84, "id": "smart-polls", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b23ed94b293e4a7da02cb3ab583524a1", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/11 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Cycle through Deming factors and seeds\n", "# and collect the RMSE\n", "\n", 
"set_seeds(0)\n", "\n", "# Deming factor for which the error/uncertainty quotient and the coverage are computed\n", "deming_for_quotient = 1.0\n", "\n", "rmse_list = []\n", "\n", "# Quotient sqrt(mean(err**2/sigma**2)) with sigma = uncertainty (unc)\n", "# or sigma = sqrt(uncertainty**2 + std_y**2) (full)\n", "unc_q = []\n", "full_q = []\n", "ber_unc_q = []\n", "ber_full_q = []\n", "ber_rmse_fixed_seed_list = []\n", "\n", "# Coverage of the noisy test labels by k*sigma, for the same two choices of sigma\n", "unc_cov = []\n", "full_cov = []\n", "ber_unc_cov = []\n", "ber_full_cov = []\n", "\n", "def compute_coverage(errors, uncertainties, k=1.96):\n", "    # Fraction of points whose absolute error lies within k*uncertainty\n", "    return np.mean(np.abs(errors) <= k*uncertainties)\n", "\n", "\n", "for i, deming_scale_test in enumerate(tqdm(deming_factor_list)):\n", "    rmse_fixed_seed_list = []\n", "    for seed in seed_list:\n", "        net = Networks.FNNEIV(p=0.5, init_std_y=init_std_y,\n", "                              precision_prior_zeta=precision_prior_zeta,\n", "                              h=[dim, 200,100,50,1],\n", "                              deming=deming_scale_test)\n", "        saved_file = os.path.join('saved_networks',\n", "                'eiv_wine_init_std_y_%.3f_deming_factor_%.3f_seed_%i.pkl'\n", "                % (init_std_y, deming_scale_test, seed))\n", "        net.to(device)\n", "        # Keep std_y of *this* network (not of the network loaded further above)\n", "        _, _, _, seed_std_y, _, _ = train_and_store.open_stored_training(\n", "                saved_file, net=net, extra_keys=['rmse'], device=device)\n", "        net_train_state = net.training\n", "        net_noise_state = net.noise_is_on\n", "        # Switch on Dropout and input noise\n", "        net.train()\n", "        net.noise_on()\n", "        pred, _ = [t.cpu().detach().numpy()\n", "                for t in net.predict(test_x.to(device), number_of_draws=300,\n", "                    take_average_of_prediction=False)]\n", "        val_y = test_y.detach().cpu().numpy()\n", "        if deming_scale_test == deming_for_quotient:\n", "            err = np.abs(np.mean(pred, axis=1).flatten()-val_y.flatten())\n", "            unc = np.std(pred, axis=1).flatten()\n", "            last_std_y = seed_std_y[-1].item()\n", "            unc_q.append(np.sqrt(np.mean(err**2/unc**2)))\n", "            full_q.append(np.sqrt(np.mean(err**2/(unc**2+last_std_y**2))))\n", "            unc_cov.append(compute_coverage(err, unc, k=k))\n", "            full_cov.append(compute_coverage(err, np.sqrt(unc**2 + last_std_y**2), k=k))\n", "        r = np.sqrt(np.mean((np.mean(pred, axis=1).flatten()-val_y.flatten())**2))\n", 
"        rmse_fixed_seed_list.append(r)\n", "        if i==0:\n", "            # non-EiV\n", "            ber_net = Networks.FNNBer(p=0.5, init_std_y=init_std_y, h=[dim, 200,100,50,1])\n", "            ber_saved_file = os.path.join('saved_networks',\n", "                    'noneiv_wine_init_std_y_%.3f_seed_%i.pkl' % (init_std_y, seed))\n", "            ber_net.to(device)\n", "            _, _, _, ber_seed_std_y, _, _ = train_and_store.open_stored_training(\n", "                    ber_saved_file, net=ber_net, extra_keys=['rmse'], device=device)\n", "            ber_net_train_state = ber_net.training\n", "            # Switch on Dropout\n", "            ber_net.train()\n", "            ber_pred, _ = [t.cpu().detach().numpy()\n", "                    for t in ber_net.predict(test_x.to(device), number_of_draws=300,\n", "                        take_average_of_prediction=False)]\n", "            ber_r = np.sqrt(np.mean((np.mean(ber_pred, axis=1).flatten()-val_y.flatten())**2))\n", "            ber_err = np.abs(np.mean(ber_pred, axis=1).flatten()-val_y.flatten())\n", "            ber_unc = np.std(ber_pred, axis=1).flatten()\n", "            ber_last_std_y = ber_seed_std_y[-1].item()\n", "            ber_rmse_fixed_seed_list.append(ber_r)\n", "            ber_unc_q.append(np.sqrt(np.mean(ber_err**2/ber_unc**2)))\n", "            ber_full_q.append(np.sqrt(np.mean(ber_err**2/(ber_unc**2+ber_last_std_y**2))))\n", "            ber_unc_cov.append(compute_coverage(ber_err, ber_unc, k=k))\n", "            ber_full_cov.append(compute_coverage(ber_err, np.sqrt(ber_unc**2 + ber_last_std_y**2), k=k))\n", "    rmse_list.append(np.array(rmse_fixed_seed_list))\n", "\n", "# Convert lists to arrays\n", "rmse_list = np.stack(rmse_list, axis=0)\n", "ber_rmse_list = np.stack(ber_rmse_fixed_seed_list, axis=0)\n", "\n", "# Restore the modes of the last loaded networks\n", "if net_train_state:\n", "    net.train()\n", "else:\n", "    net.eval()\n", "if net_noise_state:\n", "    net.noise_on()\n", "else:\n", "    net.noise_off()\n", "if ber_net_train_state:\n", "    ber_net.train()\n", "else:\n", "    ber_net.eval()" ] }, { "cell_type": "code", "execution_count": 85, "id": "amino-expression", "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, 
"output_type": "display_data" } ], "source": [ "# plot and save figure\n", "plt.figure()\n", "plt.plot(deming_factor_list, np.mean(rmse_list,axis=1), linewidth=2, color='orange')\n", "plt.errorbar(deming_factor_list, np.mean(rmse_list,axis=1), np.std(rmse_list,axis=1)/np.sqrt(rmse_list.shape[1]), color='r', linewidth=3,alpha=0.9,ecolor='r',fmt='o',linestyle='None')\n", "plt.fill_between(deming_factor_list, np.mean(ber_rmse_list)+ np.std(ber_rmse_list)/np.sqrt(len(ber_rmse_list)), np.mean(ber_rmse_list)- np.std(ber_rmse_list)/np.sqrt(len(ber_rmse_list)), color='k', alpha=0.4)\n", "plt.xlabel(r'$\\delta$')\n", "plt.ylabel('RMSE')\n", "plt.tight_layout()\n", "plt.savefig(os.path.join('saved_images','wine_deming_rmse.pdf'))" ] }, { "cell_type": "code", "execution_count": 88, "id": "d502a44f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "EiV: 1.716 (0.009)\n", "non-EiV: 2.545 (0.026)\n", "EiV: 0.960 (0.009)\n", "non-EiV: 1.085 (0.002)\n" ] } ], "source": [ "# Error vs. 
Uncertainty\n", "# Mean over the seeds, standard error in parentheses\n", "# Quotient based on the model uncertainty only\n", "print('EiV: %.3f (%.3f)' % (np.mean(unc_q), np.std(unc_q)/np.sqrt(len(unc_q))))\n", "print('non-EiV: %.3f (%.3f)' % (np.mean(ber_unc_q), np.std(ber_unc_q)/np.sqrt(len(ber_unc_q))))\n", "# Quotient based on uncertainty + noise (std_y)\n", "print('EiV: %.3f (%.3f)' % (np.mean(full_q), np.std(full_q)/np.sqrt(len(full_q))))\n", "print('non-EiV: %.3f (%.3f)' % (np.mean(ber_full_q), np.std(ber_full_q)/np.sqrt(len(ber_full_q))))" ] }, { "cell_type": "code", "execution_count": 89, "id": "df5e5040", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Uncertainty only \n", "--------\n", "EiV: 0.842 (0.002)\n", "non-EiV: 0.674 (0.003)\n", "Uncertainty + Noise \n", "--------\n", "EiV: 0.958 (0.002)\n", "non-EiV: 0.931 (0.001)\n" ] } ], "source": [ "# Coverage (mean over the seeds, standard error in parentheses)\n", "print('Uncertainty only \\n--------')\n", "print('EiV: %.3f (%.3f)' % (np.mean(unc_cov), np.std(unc_cov)/np.sqrt(len(unc_cov))))\n", "print('non-EiV: %.3f (%.3f)' % (np.mean(ber_unc_cov), np.std(ber_unc_cov)/np.sqrt(len(ber_unc_cov))))\n", "print('Uncertainty + Noise \\n--------')\n", "print('EiV: %.3f (%.3f)' % (np.mean(full_cov), np.std(full_cov)/np.sqrt(len(full_cov))))\n", "print('non-EiV: %.3f (%.3f)' % (np.mean(ber_full_cov), np.std(ber_full_cov)/np.sqrt(len(ber_full_cov))))" ] }, { "cell_type": "code", "execution_count": null, "id": "607b92ac", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 5 }