{ "cells": [ { "cell_type": "markdown", "id": "charitable-joshua", "metadata": {}, "source": [ "# Results for a 1D Mexican hat curve (incl. VD)\n", "\n", "This notebook produces the results of EiV and non-EiV models, including their variational dropout (VD) variants, for data following a 1D Mexican hat curve, as presented in the preprint \"Errors-in-Variables for deep learning: rethinking aleatoric uncertainty\" submitted to NeurIPS 2021.\n", "\n", "\n", "This notebook produces Figures 1, 2, 3 and part of Table 1 of the preprint. \n", "\n", "How to use this notebook: \n", "\n", "+ This notebook assumes that the corresponding trained networks exist in `saved_networks`. To achieve this, either run the training scripts described in the `README` or download the pre-trained networks from the link in the `README` into the `saved_networks` folder. \n", "\n", "+ To run this notebook, click \"Run\" in the menu above. \n", "\n", "+ To consider different levels of input noise, change `std_x` in the cell under \"Values that can be changed\" below.\n", "\n", "+ To run this notebook on a GPU, set `use_gpu` to `True` in the same cell (default is `False`).\n", "\n", "+ Plots are displayed inline and, in addition, saved to `saved_images`.\n", "\n", "+ The content of Table 1 is produced under \"Coverage\" below. To get the different columns of Table 1, change `std_x` as explained above."
] }, { "cell_type": "code", "execution_count": 5, "id": "buried-recipe", "metadata": {}, "outputs": [], "source": [ "import random\n", "import os\n", "\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "from tqdm.notebook import tqdm\n", "\n", "from train_eiv_mexican import report_point, batch_size, seed_list\n", "from train_noneiv_mexican_ensemble_seed import seed_list as ensemble_seed_list\n", "from EIVArchitectures import Networks\n", "import data_frameworks, generate_mexican_data\n", "from EIVTrainingRoutines import train_and_store\n", "from EIVGeneral.ensemble_handling import create_strings, Ensemble\n", "\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "id": "4043eac9", "metadata": {}, "source": [ "## Fix relevant hyperparameters" ] }, { "cell_type": "markdown", "id": "4dd31b60", "metadata": {}, "source": [ "### Values that can be changed" ] }, { "cell_type": "code", "execution_count": 6, "id": "41ee1a77", "metadata": {}, "outputs": [], "source": [ "# The std_x used for data generation and model loading. 
\n", "# Pick either 0.05, 0.07 or 0.10.\n", "# For the figures in the preprint, 0.07 was used.\n", "std_x = 0.07\n", "\n", "# Switch to True if a GPU should be used\n", "use_gpu = False\n", "\n", "# Uncertainty coverage factor (1.96 is the 97.5% quantile of the standard normal distribution, giving nominal 95% intervals)\n", "k=1.96" ] }, { "cell_type": "code", "execution_count": 7, "id": "bbc99f21", "metadata": {}, "outputs": [], "source": [ "# graphics\n", "fontsize=15\n", "matplotlib.rcParams.update({'font.size': fontsize})" ] }, { "cell_type": "code", "execution_count": 8, "id": "007d7aac", "metadata": {}, "outputs": [], "source": [ "# Set device\n", "if not use_gpu:\n", " device = torch.device('cpu')\n", "else:\n", " device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "markdown", "id": "4c0c5646", "metadata": {}, "source": [ "### Values to keep fixed\n", "The following values assume the settings from the training scripts. Changing them requires adapting and rerunning these scripts." ] }, { "cell_type": "code", "execution_count": 9, "id": "45130d71", "metadata": {}, "outputs": [], "source": [ "# Set further hyperparameters\n", "from train_eiv_mexican import std_y, init_std_y_list, std_x_list\n", "from train_eiv_mexican_fixed_std_x import std_x_list as fixed_std_x_list\n", "from train_eiv_mexican_fixed_std_x import deming_scale_list\n", "init_std_y = init_std_y_list[0]\n", "fixed_std_x = fixed_std_x_list[0] # used only for the plot of the std_y evolution" ] }, { "cell_type": "code", "execution_count": 10, "id": "29e84920", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Choosing deming factor 0.2\n" ] } ], "source": [ "# Choose the largest Deming factor in deming_scale_list that lies strictly below std_x/std_y\n", "def find_nearest(a, x):\n", " idx = (np.abs(a - x)).argmin()\n", " return a[idx]\n", "\n", "def find_min_max(a, x):\n", " # Largest entry of a that is smaller than x (works whether or not a is sorted)\n", " return a[a < x].max()\n", "\n", "deming = find_min_max(np.array(deming_scale_list), std_x/std_y)\n", 
"print('Choosing deming factor', deming)" ] }, { "cell_type": "code", "execution_count": 11, "id": "nervous-restoration", "metadata": {}, "outputs": [], "source": [ "# Function to fix seeds (for reproducibility)\n", "def set_seeds(seed):\n", " torch.backends.cudnn.benchmark = False\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " torch.manual_seed(seed)" ] }, { "cell_type": "markdown", "id": "545190ff", "metadata": {}, "source": [ "## Prediction (for a single seed)\n", "Produces Figure 2 of the preprint (if `std_x` is set to 0.07). By default, the networks trained with random seed 0 are used." ] }, { "cell_type": "code", "execution_count": 12, "id": "91413ba9", "metadata": {}, "outputs": [], "source": [ "# Change this to use networks trained with a different seed\n", "# Choose an integer between 0 and 19\n", "single_seed = 0" ] }, { "cell_type": "markdown", "id": "806e87b2", "metadata": {}, "source": [ "### Load networks and data" ] }, { "cell_type": "code", "execution_count": 9, "id": "instrumental-survey", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FNNEIV(\n", " (main): Sequential(\n", " (0): EIVInput()\n", " (1): Linear(in_features=1, out_features=50, bias=True)\n", " (2): LeakyReLU(negative_slope=0.01)\n", " (3): EIVDropout()\n", " (4): Linear(in_features=50, out_features=100, bias=True)\n", " (5): LeakyReLU(negative_slope=0.01)\n", " (6): EIVDropout()\n", " (7): Linear(in_features=100, out_features=50, bias=True)\n", " (8): LeakyReLU(negative_slope=0.01)\n", " (9): EIVDropout()\n", " (10): Linear(in_features=50, out_features=1, bias=True)\n", " )\n", ")" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load EiV model\n", "net = Networks.FNNEIV(p=0.5, deming=deming)\n", "saved_file = os.path.join('saved_networks',\n", " 'eiv_mexican_std_x_%.3f_std_y_%.3f_init_std_y_%.3f_deming_scale_%.3f_seed_%i.pkl'\\\n", " %(std_x, std_y, init_std_y, deming, single_seed))\n", "train_loss, test_loss, stored_std_x, 
stored_std_y, state_dict\\\n", " = train_and_store.open_stored_training(saved_file, net=net, device=device)\n", "net.to(device)" ] }, { "cell_type": "code", "execution_count": 10, "id": "lovely-register", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FNNBer(\n", " (main): Sequential(\n", " (0): Linear(in_features=1, out_features=50, bias=True)\n", " (1): LeakyReLU(negative_slope=0.01)\n", " (2): Dropout(p=0.5, inplace=False)\n", " (3): Linear(in_features=50, out_features=100, bias=True)\n", " (4): LeakyReLU(negative_slope=0.01)\n", " (5): Dropout(p=0.5, inplace=False)\n", " (6): Linear(in_features=100, out_features=50, bias=True)\n", " (7): LeakyReLU(negative_slope=0.01)\n", " (8): Dropout(p=0.5, inplace=False)\n", " (9): Linear(in_features=50, out_features=1, bias=True)\n", " )\n", ")" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load non-EiV model\n", "ber_net = Networks.FNNBer(p=0.5, init_std_y=init_std_y)\n", "ber_saved_file = os.path.join('saved_networks', \n", " 'noneiv_mexican_std_x_%.3f_std_y_%.3f_init_std_y_%.3f_seed_%i.pkl'\n", " % (std_x, std_y, init_std_y, single_seed))\n", "ber_train_loss, ber_test_loss, ber_stored_std_x, ber_stored_std_y, ber_state_dict\\\n", " = train_and_store.open_stored_training(ber_saved_file, net=ber_net, device=device)\n", "ber_net.to(device)" ] }, { "cell_type": "code", "execution_count": 13, "id": "08892aaa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " -- Generating Mexican hat data with std_x = 0.07 and std_y = 0.30 --\n", " -- Generating Mexican hat data with std_x = 0.07 and std_y = 0.30 --\n", " -- Generating Mexican hat data with std_x = 0.07 and std_y = 0.30 --\n" ] } ], "source": [ "# Generate data\n", "train_data_pure, train_data,\\\n", " test_data_pure,test_data,\\\n", " val_data_pure,val_data,(func, normean, norstd) = generate_mexican_data.get_data(std_x=std_x, std_y=std_y)\n", "val_x, val_y = 
val_data[0].numpy().flatten(), val_data[1].numpy().flatten()\n", "val_pure_x, val_pure_y = val_data_pure[0].numpy().flatten(), val_data_pure[1].numpy().flatten()\n", "# Sort all arrays by the noise-free inputs\n", "val_pure_ind = np.argsort(val_pure_x)\n", "val_pure_x = val_pure_x[val_pure_ind]\n", "val_pure_y = val_pure_y[val_pure_ind]\n", "val_x = val_x[val_pure_ind]\n", "val_y = val_y[val_pure_ind]" ] }, { "cell_type": "markdown", "id": "uniform-toolbox", "metadata": {}, "source": [ "### Predictions for noisy input\n", "Produces Figure 2b of the preprint" ] }, { "cell_type": "code", "execution_count": 12, "id": "reserved-dispatch", "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# The ground truth\n", "plot_x = np.linspace(-1.1,1.1)\n", "plot_y = func(plot_x)[1]\n", "\n", "# Fix seeds\n", "set_seeds(0)\n", "\n", "\n", "## EiV model\n", "net_train_state = net.training\n", "net_noise_state = net.noise_is_on\n", "# Sample with dropout active (train mode) and the EiV input noise switched on\n", "net.train()\n", "net.noise_on()\n", "# Collect predictions\n", "pred, _= [t.cpu().detach().numpy()\n", " for t in net.predict(torch.tensor(val_x, dtype=torch.float32)[:,None].to(device), number_of_draws=5000,\n", " take_average_of_prediction=False)]\n", "pred_mean = np.mean(pred, axis=1).flatten()\n", "pred_std = np.std(pred, axis=1).flatten()\n", "\n", "plt.ylim([-1.5,1.5])\n", "\n", "# Restore the previous training and noise states\n", "if net_train_state:\n", " net.train()\n", "else:\n", " net.eval()\n", "if net_noise_state:\n", " net.noise_on()\n", "else:\n", " net.noise_off()\n", "\n", "\n", "## Non-EiV model\n", "ber_net_state = ber_net.training\n", "ber_net.train()\n", "ber_net.to(device)\n", "# Collect predictions\n", "ber_pred, _ = [t.cpu().detach().numpy()\n", " for t in ber_net.predict(torch.tensor(val_x, dtype=torch.float32)[:,None].to(device), number_of_draws=5000,\n", " take_average_of_prediction=False)]\n", "ber_pred_mean = 
np.mean(ber_pred, axis=1).flatten()\n", "ber_pred_std = np.std(ber_pred, axis=1).flatten()\n", "plt.plot(plot_x, plot_y, color='b', label='ground truth', linewidth=2)\n", "plt.plot(val_pure_x, ber_pred_mean, color='k', label='No EiV', linewidth=2)\n", "plt.fill_between(val_pure_x, ber_pred_mean-k*ber_pred_std, ber_pred_mean+k*ber_pred_std, color='k', alpha=0.2)\n", "plt.plot(val_pure_x, pred_mean, color='r', label='EiV', linewidth=2)\n", "plt.fill_between(val_pure_x, pred_mean-k*pred_std, pred_mean+k*pred_std, color='r', alpha=0.2)\n", "plt.xlabel(r'$\\zeta$')\n", "if ber_net_state:\n", " ber_net.train()\n", "else:\n", " ber_net.eval()\n", "\n", "## Save result\n", "plt.savefig(os.path.join('saved_images','mexican_noisy_prediction_std_x_%.3f_std_y_%.3f.pdf' % (std_x, std_y)) )" ] }, { "cell_type": "markdown", "id": "corrected-broadcasting", "metadata": {}, "source": [ "### Predictions for noise-free input\n", "Produces Figure 2a of the preprint" ] }, { "cell_type": "code", "execution_count": 13, "id": "coupled-reputation", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# The ground truth\n", "plot_x = np.linspace(-1.1,1.1)\n", "plot_y = func(plot_x)[1]\n", "\n", "## EiV model (dropout active, EiV input noise switched off)\n", "net_train_state = net.training\n", "net_noise_state = net.noise_is_on\n", "net.train()\n", "net.noise_off()\n", "net.to(device)\n", "set_seeds(0)\n", "pred, _= [t.cpu().detach().numpy()\n", " for t in net.predict(torch.tensor(plot_x, dtype=torch.float32)[:,None].to(device), number_of_draws=5000,\n", " take_average_of_prediction=False)]\n", "pred_mean = np.mean(pred, axis=1).flatten()\n", "pred_std = np.std(pred, axis=1).flatten()\n", "\n", "plt.ylim([-1.5,1.5])\n", "if net_train_state:\n", " net.train()\n", "else:\n", " net.eval()\n", "if net_noise_state:\n", " net.noise_on()\n", "else:\n", " net.noise_off()\n", "\n", 
"## Non-EiV model\n", "ber_net_state = ber_net.training\n", "ber_net.train()\n", "ber_pred, _ = [t.cpu().detach().numpy()\n", " for t in ber_net.predict(torch.tensor(plot_x, dtype=torch.float32)[:,None].to(device), number_of_draws=5000,\n", " take_average_of_prediction=False)]\n", "ber_pred_mean = np.mean(ber_pred, axis=1).flatten()\n", "ber_pred_std = np.std(ber_pred, axis=1).flatten()\n", "plt.plot(plot_x, plot_y, color='b', label='ground truth', linewidth=2)\n", "plt.plot(plot_x, ber_pred_mean, color='k', label='No EiV', linewidth=2)\n", "plt.fill_between(plot_x, ber_pred_mean-k*ber_pred_std, ber_pred_mean+k*ber_pred_std, color='k', alpha=0.2)\n", "plt.plot(plot_x, pred_mean, color='r', label='EiV', linewidth=2)\n", "plt.fill_between(plot_x, pred_mean-k*pred_std, pred_mean+k*pred_std, color='r', alpha=0.2)\n", "plt.xlabel(r'$\\zeta$')\n", "if ber_net_state:\n", " ber_net.train()\n", "else:\n", " ber_net.eval()\n", "plt.savefig(os.path.join('saved_images','mexican_non_noisy_prediction_std_x_%.3f_std_y_%.3f.pdf' % (std_x, std_y)) )" ] }, { "cell_type": "markdown", "id": "baking-legislature", "metadata": {}, "source": [ "## Coverage\n", "Produces Figure 3 of the preprint (if `std_x` is set to 0.07) and part of Table 1." ] }, { "cell_type": "code", "execution_count": 14, "id": "de3878d0", "metadata": {}, "outputs": [], "source": [ "## Functions used to compute the coverage and the RMSE\n", "\n", "# Boolean mask: does the truth lie inside mean +- k*std of the sampled predictions (axis 1 = draws)?\n", "def inside_uncertainties(predictions, truth, k=1.96):\n", " mean = np.mean(predictions, axis=1).flatten()\n", " std = np.std(predictions, axis=1).flatten()\n", " inside = np.logical_and(truth > mean-k*std, truth < mean+k*std).flatten()\n", " return inside\n", "\n", "# Same as above, but with a precomputed mean and std\n", "def inside_explicit_uncertainties(mean, std, truth, k=1.96):\n", " mean = mean.flatten()\n", " std = std.flatten()\n", " truth = truth.flatten()\n", " inside = np.logical_and(truth > mean-k*std, truth < mean+k*std).flatten()\n", 
" return inside\n", "\n", "# Alternative using the empirical 2.5% and 97.5% quantiles instead of mean +- k*std (not used in the preprint, for consistency reasons)\n", "def inside_intervals(predictions, truth):\n", " up_quantile = np.quantile(predictions, 0.975, axis=1).flatten()\n", " low_quantile = np.quantile(predictions, 0.025, axis=1).flatten()\n", " inside = np.logical_and(truth > low_quantile, truth < up_quantile).flatten()\n", " return inside\n", "\n", "# Mean squared error of the averaged prediction with respect to the (noisy) targets\n", "def compute_mse(predictions, noisy_truth):\n", " pred = np.mean(predictions, axis=1).flatten()\n", " y = noisy_truth.flatten()\n", " assert pred.shape == y.shape\n", " mse = np.mean((pred-y)**2)\n", " return mse" ] }, { "cell_type": "code", "execution_count": 15, "id": "southwest-english", "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8365440245794dedadcd5b1221e97286", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/20 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "abef68d623154578b34377db92a29b31", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9a6aae6c63284cd585172afef38d2886", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2786699400334384a23050dce515ac97", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "911d1d3855bc499c873d98d77536bc16", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, 
"output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a5e5710b61394673afd9c64f7e249973", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "94d24947982c4690b893e099c0b3b580", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4f3ca928cfae4cf2bca14fcc712e3595", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e96dd712202a44e3bf61ca36ecbf915e", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "684395289f6b4be988be7ae9047fc503", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "abe46b872dab4cf094a3646b6adecfae", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e5e490fecd704454a761c623430687be", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7795eff751e24cfab5b1eb151520592c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, 
"output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c5d22b18629d40ed9b5add85aca4de5a", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "711dd0806baa45d1a8d96a74b456db0b", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b28bec3581864e79b4e08e79ac534b41", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f655d7af203b4c2d98d30cf4944ae7dc", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f18f1ae342bc457d8d2abf77f0720725", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "888b5761162243158c04ae6bedd8b890", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "51e53f03504a4dd5a46161cfbd6405c0", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0c17d8ddb5e1443c8563e1a143f5be08", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, 
"output_type": "display_data" } ], "source": [ "# Collect coverage and RMSE for the networks trained with a given seed\n", "def coverage_computation(net, ber_net, seed):\n", " set_seeds(seed)\n", " coverage_x = np.linspace(-1.1,1.1, num=100)\n", " coverage_y = func(coverage_x)[1]\n", " net_train_state = net.training\n", " net_noise_state = net.noise_is_on\n", " ber_net_state = ber_net.training\n", " # Number of independent noisy realisations of the inputs and targets\n", " number_of_repeated_draws = 100\n", " net.train()\n", " net.noise_on()\n", " ber_net.train()\n", " inside_map = inside_uncertainties\n", " net_inside_list, ber_net_inside_list = [], []\n", " mse_list, ber_mse_list = [], []\n", " for _ in tqdm(range(number_of_repeated_draws)):\n", " noisy_coverage_x = coverage_x + std_x * np.random.normal(0,1,size=coverage_x.size)\n", " noisy_coverage_y = coverage_y + std_y * np.random.normal(0,1,size=coverage_y.size)\n", " pred, _ = [t.cpu().detach().numpy()\n", " for t in net.predict(torch.tensor(noisy_coverage_x, dtype=torch.float32)[:,None].to(device), number_of_draws=200,\n", " take_average_of_prediction=False)]\n", " ber_pred, _ = [t.cpu().detach().numpy()\n", " for t in ber_net.predict(torch.tensor(noisy_coverage_x, dtype=torch.float32)[:,None].to(device), number_of_draws=200,\n", " take_average_of_prediction=False)]\n", " # Coverage is checked against the noise-free ground truth, the MSE against the noisy targets\n", " net_inside_list.append(inside_map(pred, coverage_y))\n", " ber_net_inside_list.append(inside_map(ber_pred, coverage_y))\n", " mse_list.append(compute_mse(pred, noisy_coverage_y))\n", " ber_mse_list.append(compute_mse(ber_pred, noisy_coverage_y))\n", " net_inside = np.mean(np.stack(net_inside_list), axis=0)\n", " ber_net_inside = np.mean(np.stack(ber_net_inside_list), axis=0)\n", " mse = np.mean(np.array(mse_list))\n", " ber_mse = np.mean(np.array(ber_mse_list))\n", " if net_train_state:\n", " net.train()\n", " else:\n", " net.eval()\n", " if net_noise_state:\n", " net.noise_on()\n", " else:\n", " net.noise_off()\n", " if ber_net_state:\n", " ber_net.train()\n", " else:\n", " ber_net.eval()\n", " return coverage_x, coverage_y, net_inside, 
ber_net_inside, np.sqrt(mse), np.sqrt(ber_mse)\n", "\n", "# Loop over seeds: load the EiV and non-EiV networks trained with each seed and evaluate them\n", "net_inside_collection, ber_net_inside_collection, rmse_collection, ber_rmse_collection = [], [], [], []\n", "for seed in tqdm(seed_list):\n", " seed_net = Networks.FNNEIV(p=0.5, deming=deming).to(device)\n", " seed_ber_net = Networks.FNNBer(p=0.5, init_std_y=init_std_y).to(device)\n", " ber_saved_file = os.path.join('saved_networks',\n", " 'noneiv_mexican_std_x_%.3f_std_y_%.3f_init_std_y_%.3f_seed_%i.pkl'\n", " % (std_x, std_y, init_std_y, seed))\n", " ber_train_loss, ber_test_loss, ber_stored_std_x, ber_stored_std_y, ber_state_dict\\\n", " = train_and_store.open_stored_training(ber_saved_file, net=seed_ber_net, device=device)\n", " saved_file = os.path.join('saved_networks', 'eiv_mexican_std_x_%.3f_std_y_%.3f_init_std_y_%.3f_deming_scale_%.3f_seed_%i.pkl'% (std_x, std_y, init_std_y, deming, seed))\n", " train_loss, test_loss, stored_std_x, stored_std_y, state_dict\\\n", " = train_and_store.open_stored_training(saved_file, net=seed_net, device=device)\n", " coverage_x, coverage_y, net_inside, ber_net_inside, rmse, ber_rmse = coverage_computation(seed=seed, net=seed_net, ber_net=seed_ber_net)\n", " net_inside_collection.append(net_inside)\n", " ber_net_inside_collection.append(ber_net_inside)\n", " rmse_collection.append(rmse)\n", " ber_rmse_collection.append(ber_rmse)\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "opening-thanks", "metadata": {}, "outputs": [], "source": [ "# Reshape and process results (one row per seed, one column per grid point)\n", "net_inside_collection = np.stack(net_inside_collection)\n", "rmse_collection = np.stack(rmse_collection)\n", "ber_net_inside_collection = np.stack(ber_net_inside_collection)\n", "# Despite its name, this is the number of seeds\n", "number_of_draws = net_inside_collection.shape[0]\n", "# Mean coverage per grid point over the seeds and its standard error\n", "net_inside_mean = np.mean(net_inside_collection, axis=0)\n", "net_inside_std = np.std(net_inside_collection, axis=0)/np.sqrt(number_of_draws)\n", "ber_net_inside_mean = np.mean(ber_net_inside_collection, axis=0)\n", "ber_net_inside_std = 
np.std(ber_net_inside_collection, axis=0)/np.sqrt(number_of_draws)" ] }, { "cell_type": "code", "execution_count": 17, "id": "connected-rwanda", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Coverage plot (Figure 3 in preprint)\n", "plt.plot(coverage_x, net_inside_mean, color='orange', linewidth=2,alpha=0.9)\n", "plt.errorbar(coverage_x, net_inside_mean, net_inside_std, color='r', linewidth=3, alpha=0.9, ecolor='red',fmt='o', linestyle='None')\n", "plt.axhline(0.95, color='b', linestyle='dashed',linewidth=2,alpha=0.9)\n", "plt.plot(coverage_x, ber_net_inside_mean, color='gray',linewidth=2,alpha=0.9)\n", "plt.errorbar(coverage_x, ber_net_inside_mean, ber_net_inside_std, color='k', linewidth=3,alpha=0.9,ecolor='k',fmt='o',linestyle='None')\n", "plt.xlabel(r'$\\zeta$')\n", "plt.ylabel(r'coverage')\n", "plt.savefig(os.path.join('saved_images','mexican_coverage_std_x_%.3f_std_y_%.3f.pdf' % (std_x, std_y)) )" ] }, { "cell_type": "code", "execution_count": 18, "id": "smart-turkey", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RMSE\n", "===========\n", "EiV: Average 0.353582, Error 0.000571\n", "non-EiV: Average 0.355463, Error 0.000507\n", "\n", "\n", "Coverage\n", "===========\n", "EiV: Average 0.926340, Error 0.000322\n", "non-EiV: Average 0.815800, Error 0.000417\n" ] } ], "source": [ "# Results for Table 1 in preprint\n", "print('RMSE\\n===========')\n", "print('EiV: Average %.6f, Error %.6f' %( np.mean(rmse_collection),\n", " np.std(rmse_collection)/np.sqrt(len(rmse_collection))))\n", "print('non-EiV: Average %.6f, Error %.6f' % (np.mean(ber_rmse_collection), \n", " np.std(ber_rmse_collection)/np.sqrt(len(ber_rmse_collection))))\n", "print('\\n')\n", "\n", "print('Coverage\\n===========')\n", "print('EiV: Average %.6f, Error %.6f' %(net_inside_collection.mean(), 
\n", " net_inside_collection.mean(axis=1).std()/np.sqrt(net_inside_collection.size)))\n", "print('non-EiV: Average %.6f, Error %.6f' % (ber_net_inside_collection.mean(),\n", " ber_net_inside_collection.mean(axis=1).std()\n", " /np.sqrt(net_inside_collection.size)))\n" ] }, { "cell_type": "markdown", "id": "b489002d", "metadata": {}, "source": [ "# Results for Variational Dropout" ] }, { "cell_type": "code", "execution_count": 19, "id": "a7bd8c98", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FNN_VD_EIV(\n", " (main): Sequential(\n", " (0): EIVInput()\n", " (1): Linear(in_features=1, out_features=50, bias=True)\n", " (2): LeakyReLU(negative_slope=0.01)\n", " (3): EIVVariationalDropout()\n", " (4): Linear(in_features=50, out_features=100, bias=True)\n", " (5): LeakyReLU(negative_slope=0.01)\n", " (6): EIVVariationalDropout()\n", " (7): Linear(in_features=100, out_features=50, bias=True)\n", " (8): LeakyReLU(negative_slope=0.01)\n", " (9): EIVVariationalDropout()\n", " (10): Linear(in_features=50, out_features=1, bias=True)\n", " )\n", ")" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load EiV VD model\n", "vd_net = Networks.FNN_VD_EIV(initial_alpha=0.5, deming=deming)\n", "saved_file = os.path.join('saved_networks',\n", " 'eiv_vd_mexican_std_x_%.3f_std_y_%.3f_init_std_y_%.3f_deming_scale_%.3f_seed_%i.pkl'\\\n", " %(std_x, std_y, init_std_y, deming, single_seed))\n", "train_loss, test_loss, stored_std_x, stored_std_y, state_dict\\\n", " = train_and_store.open_stored_training(saved_file, net=vd_net, device=device)\n", "vd_net.to(device)" ] }, { "cell_type": "code", "execution_count": 20, "id": "4443a7ea", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.0070, dtype=torch.float64, grad_fn=<SoftplusBackward>)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Learned alpha of the last variational dropout layer (main[9])\n", "vd_net.main[9].alpha()" ] }, { "cell_type": "code", "execution_count": 21, "id": 
"8d8f11df", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FNN_VD_Ber(\n", " (main): Sequential(\n", " (0): Linear(in_features=1, out_features=50, bias=True)\n", " (1): LeakyReLU(negative_slope=0.01)\n", " (2): EIVVariationalDropout()\n", " (3): Linear(in_features=50, out_features=100, bias=True)\n", " (4): LeakyReLU(negative_slope=0.01)\n", " (5): EIVVariationalDropout()\n", " (6): Linear(in_features=100, out_features=50, bias=True)\n", " (7): LeakyReLU(negative_slope=0.01)\n", " (8): EIVVariationalDropout()\n", " (9): Linear(in_features=50, out_features=1, bias=True)\n", " )\n", ")" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load non-EiV VD model\n", "vd_ber_net = Networks.FNN_VD_Ber(initial_alpha=0.5, init_std_y=init_std_y)\n", "vd_ber_saved_file = os.path.join('saved_networks', \n", " 'noneiv_vd_mexican_std_x_%.3f_std_y_%.3f_init_std_y_%.3f_seed_%i.pkl'\n", " % (std_x, std_y, init_std_y, single_seed))\n", "vd_ber_train_loss, ber_test_loss, ber_stored_std_x, ber_stored_std_y, ber_state_dict\\\n", " = train_and_store.open_stored_training(vd_ber_saved_file, net=vd_ber_net, device=device)\n", "vd_ber_net.to(device)" ] }, { "cell_type": "markdown", "id": "8fddd747", "metadata": {}, "source": [ "### VD Predictions for noisy input" ] }, { "cell_type": "code", "execution_count": 22, "id": "6696d9e0", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# The ground truth\n", "plot_x = np.linspace(-1.1,1.1)\n", "plot_y = func(plot_x)[1]\n", "\n", "# Fix seeds\n", "set_seeds(0)\n", "\n", "\n", "## EiV model\n", "vd_net_train_state = vd_net.training\n", "vd_net_noise_state = vd_net.noise_is_on\n", "vd_net.train()\n", "vd_net.noise_on()\n", "# Collect predictions\n", "vd_pred, _= [t.cpu().detach().numpy()\n", " for t in vd_net.predict(torch.tensor(val_x, 
dtype=torch.float32)[:,None].to(device), number_of_draws=5000,\n", " take_average_of_prediction=False)]\n", "vd_pred_mean = np.mean(vd_pred, axis=1).flatten()\n", "vd_pred_std = np.std(vd_pred, axis=1).flatten()\n", "\n", "plt.ylim([-1.5,1.5])\n", "\n", "if vd_net_train_state:\n", " vd_net.train()\n", "else:\n", " vd_net.eval()\n", "if vd_net_noise_state:\n", " vd_net.noise_on()\n", "else:\n", " vd_net.noise_off()\n", "\n", "\n", "## Non-EiV model\n", "vd_ber_net_state = vd_ber_net.training\n", "vd_ber_net.train()\n", "vd_ber_net.to(device)\n", "# Collect predictions\n", "vd_ber_pred, _ = [t.cpu().detach().numpy()\n", " for t in vd_ber_net.predict(torch.tensor(val_x, dtype=torch.float32)[:,None].to(device), number_of_draws=5000,\n", " take_average_of_prediction=False)]\n", "vd_ber_pred_mean = np.mean(vd_ber_pred, axis=1).flatten()\n", "vd_ber_pred_std = np.std(vd_ber_pred, axis=1).flatten()\n", "plt.plot(plot_x, plot_y, color='b', label='ground truth', linewidth=2)\n", "plt.plot(val_pure_x, vd_ber_pred_mean, color='k', label='No EiV', linewidth=2)\n", "plt.fill_between(val_pure_x, vd_ber_pred_mean-k*vd_ber_pred_std, vd_ber_pred_mean+k*vd_ber_pred_std, color='k', alpha=0.2)\n", "plt.plot(val_pure_x, vd_pred_mean, color='r', label='EiV', linewidth=2)\n", "plt.fill_between(val_pure_x, vd_pred_mean-k*vd_pred_std, vd_pred_mean+k*vd_pred_std, color='r', alpha=0.2)\n", "plt.xlabel(r'$\\zeta$')\n", "if vd_ber_net_state:\n", " vd_ber_net.train()\n", "else:\n", " vd_ber_net.eval()\n" ] }, { "cell_type": "markdown", "id": "3625ca6c", "metadata": {}, "source": [ "### VD Predictions for noise-free input" ] }, { "cell_type": "code", "execution_count": 23, "id": "0ba613ef", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYgAAAEaCAYAAAAL7cBuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAABIwElEQVR4nO3dd3wUdfrA8c93N72QhCR06YQQOoSmWLDAcaACiooVzn7n2c5yP8+Ces3Ts3sq6HmKKJ4FrIAHKgLSe0looRPSCKRvsrvP74/ZhBAWCCGbTcjz9jWvyc7Od/eZYd1n59vGiAhKKaVUVTZ/B6CUUqp+0gShlFLKK00QSimlvNIEoZRSyitNEEoppbzSBKGUUsornyUIY0xnY8zbxph1xhiXMeanapRpb4wRL8sMX8WplFLKuwAfvnZ34NfAUiDoNMs+BCyu9Di7toJSSilVPb5MEF+LyJcAxpjPgLjTKLtFRJb6JiyllFLV4bMqJhFx++q1lVJK+V59baR+z9NukW6MedEYE+rvgJRSqrHxZRVTTTiAN4DvgTzgIuBRoBNwpbcCxpg7gDsAwsPD+ycmJtZJoEopdTZYtWpVtojEe3vO1MVkfeVtECJyUQ3K3g38C+grImtPtm9ycrKsXLmyRjEqpVRjZIxZJSLJ3p6rr1VMlX3mWffzaxRKKdXINIQEIVXWSiml6kBDSBBXe9ar/BqFUko1Mj5rpDbGhGENlANoDTQxxpR/2X8nIkXGmO3AAhG51VNmMhCJNUguD7gAeBj4QkTW+ypWpZRSx/NlL6ZmwKdVtpU/7gDs8ry/vdLzqVijqG8DQoE9wPPAX3wYp1JKKS98liBEZBdgTrFP+yqPZwA675JSStUDDaENQimllB9oglBKKeWVJgillFJeaYJQSinllSYIpZRSXmmCUEop5ZUmCKWUUl5pglBKKeWVJgillFJeaYJQSinllSYIpZRSXmmCUEop5ZUmCKWUUl5pglBKKeWVJgilfKSsrMzfISh1RjRBKOUDRUVFrFmzBhG9lbpquDRBKOUD+fn5ZGRkUFRU5O9QlKoxTRBK+UBWVhYlJSUcPnzY36EoVWOaIJSqZSJCZmYmsbGxHDx40N/hKFVjmiCUqmXFxcU4HA4iIyPJycnB5XL5OySlakQThFK1LD8/H2MMNpsNt9tNXl6ev0NSqkY0QShVy7KzswkKCgLAbreTk5Pj54iUqhlNEErVsoyMDLZs2cKuXbsIDw8nPT3d3yEpVSM+SxDGmM7GmLeNMeuMMS5jzE/VLBdljHnPGJNrjDlijJlujIn1VZxK1abi4mJ27drFPffcw7333ktQUBAFBQWUlJT4OzSlTpsvryC6A78GtnqW6voEuAi4DZgIDABm1W5oSvlGfn4+q1evxu12c+DAAbZs2QKg7RCqQfJlgvhaRM4RkfHApuoUMMYMAUYAt4jI5yIyE7gRGGqMudSHsSpVK7Kzs1m/fn3F4wULFhASEkJGRoYfo1KqZnyWIETEXYNiI4EMEfm50ussB3Z6nlOqXktPT2fdunUVjxcuXEh4eDgZGRm43TX5X0Ip/6lvjdSJQKqX7Sme55Sqt4qLi0lJSSEvL4/4+HhCQkJITU0lJyeHsrIyCgoK/B2iUqclwN8BVBEDHPayPRfoWLehKHV68vPzK64e2rS5gOzsfPbu/Z5nn93AOedczTffCE2aQL9+MH68n4NVqhrqW4IA8Db9pTnBdowxdwB3ALRt29aHYSl1cjk5ORXtD2vWjAFKgO9ZsuRnlix55Jh9n/u78Mijps5jVOp01Lcqplwg2sv2aLxfWSAiU0QkWUSS4+PjfReZUqewf/9+1qzZCIDNNoxrrkkGwG6fx+23b2PixK08eL8TY4RH/2j494uHQafhUPVYfbuCSAXO97I9Ee3qquqxkpIS/v3vfZSVFQGJPPlkKaNHu9i4MYnNmzeTmPg13bp1Y1DXADo4Svj9m0nc/nAUsYUbuPLmKGjVCgID/X0
YSh2jvl1BzAZaGGOGlm8wxiRjtT/M9ltUSp3CrFkOZszYA0CvXgMZOTKLnJwcLrjgAsDqzRQUFETxf//LPbzOczeux+02XPvnHiz45CD8+CNs3w5Opz8PQ6lj+HIkdZgx5mpjzNVAayC+/LExJsyzz3ZjzLvlZURkCTAX+MAYM84YMwaYDiwSkXm+ilWpM7FkCUyaFInIDwDceGMPCgsLKS0t5dxzzwWsBBFms9HiqafgzTd5+JNklrYdT+/SFVzxdH/WZrSELVtAp+VQ9YgvryCaAZ96lsFAUqXHzTz7BAD2KuWuAxYA/wY+AFYBY30Yp1I1tmkTjBplVTHZbEswxtC/f38cDgfNmzenWbNmNG/enJycHIpmzcJeWoqEhmKcTgbt+YxlDGZ+0WDeu3s5aYdjrKsIHS+h6glfDpTbJSLmBMsuzz7tRWRilXKHRWSSiESLSBMRuV5Esn0Vp1I1lZ4OI0ZAbi4kJn6D211K165diYqKQkRo164dbre7oprJPtuqJS0aNw5mzYKbb0aiokhmFa8U3kbYzePJ3bgfDh3y41EpdVR9a4NQqsF49lnYvx8GDiyje/cvABgwYAAOh4OIiIiKwXJDhgwBICEtDYDMhARo3RruvRfz7bcU/3Ey24O60cJ9gC2vzrWuIpSqBzRBKFUD+/bBu++CMfD443vZuHEtYCWIwsJCWrZsiTGGNm3akJCQwIDgYFq5XDgiIznkduM8eBAcDggJIfTq0RRNfh6A3ts/I2f9ftDJ/VQ9oAlCqRr4+9+htNQaEW3MZrZt24bdbqdv3744nU5iY60Z6ps1a4bNZuPOFi0A2NisGaVNmpAeFQVlZZCZCYcP0+uSeJbHjCCUEla/vhj27vXn4SkFaIJQ6rTt3w9Tp1p/P/KIgxUrVuB2u+nZsychISHYbDYiIyMBaNKkCcHBwQz33A/ic4eD4M6dSS0uxjFoEJx7LrRsCYcPE3ndKAB6b5hOzurdoPeQUH6mCUKp0/Tcc9bVw9VXQ/v2BRXzLw0YMICioiJiY2MJCLDGoBpjaB8WRpuMDBzAmwcOUBwSgoiwa/duiIqCpCS4+GK6PTiS7aE9aEYmP7++Hg4c8ONRKqUJQqnTcuAATJli/f3kk9YEfWvXrgWsBFFcXEwLT3VSuRbLl2OAFaGhHHa7WbJ+PTExMaSlpVFUVGTtFBAA7doh464CIHHFNA6t2a1TcSi/0gSh1Gl47jmrbfmqq6BnT9i2bRu7d+8mODiYnj174na7iY6OPqZM6Ny5ABzsaE1I/PPPP2Oz2QgMDGR75R5LgYF0+dO1ZAW2pJts5uvXdkJWVl0dmlLH0QShVDWlpx979QDWHeMA+vbtizGGoKAgwsPDjxYqLcX89BMAwcnW5H2LFy/G6XQSHR3N/v37j70daYcOFF5mjQttu/AjctfsBPE6kbFSPqcJQqlq+sc/rHbjsWOhVy8oKytj5cqVwNH2h+bNm2NMpWm8f/wRiopwdehAfJcutGvXjkOHDvGf//wHYwyhoaFs2bIFKU8CISG0f+ImimzhDHPP5+PXsuDw4bo/WKXQBKFUtRw8CG+9Zf1dfvVQXFx8TAO1w+GgWbNmxxb8whpAZxs0CHfLljz88MMAvPPOO2zfvp3IyEiysrI4VHn0dI8e5J5r9WiK+d+nHN6gXV6Vf2iCUKoann/eunoYMwb69LG2paamcvDgQSIjI+natStgdWutIALffQeASU4mvnt3unbtylVXXYXT6WTy5Mk4nU4iIyNJSUk5es/qiAhaP3oTLmxc5fyEd1/Og/LGbKXqkCYIpU4hIwPefNP6u/zqAWD+/PkA9OvXD6fTSUREBCEhIUd3SE21hlxHR0NSEvGdO1NWVsa9995LixYtSE1N5YMPPiAsLIz8/HwyMzOPlh06lEM9LySIMvjmaw6nHvT9gSpVhSYIpU7h1VehuBiuuAL69j26vbz9oVevXhQWFh7XvbW8eomBA6FlS5rExBAUFERwcDB
PPPEEAFOmTGH79u1ERUWRkpJCWVmZVSY6mvh7rwdgUtlU3n2lQGd5VXVOE4RSJ+F0wnvvWX8/9NDR7SLChg0bAOjevfsx02tU+Ppra92/P7Rqhc1mo3Xr1hQUFDBo0CDGjh2L0+nkmWeewW6343K5WLhwIenp6VZ105gxHDqnF03JJe/L+cjhI3VwxEodpQlCqZP4/nure2uXLjB06NHtubm5pKWlYYwhMTHxmOk1AGvK7hUrwG63EoRnbETz5s0pLS0F4L777qN58+Zs3ryZDz/8kJiYGEJCQli7di1LlizhkDFE3WJ1eR195CMWf5NbV4etFKAJQqmTKr96mDTJmrm13PLly3G5XHTs2BFjzDHTawDwzTdWlVDv3tC2LQQFARAVFUVwcDAlJSVERERUVDW9/fbbpKWlERwcTLNmzXC73SxZupTNA/pRHBjJAFYyZ8oua4I/peqIJgilTiA7G778Emw2uPnmY59bunQpYFUvFRcX07Jly2N3mDnTWicnW/d+8LDZbPTp04e8vDycTieDBw9m7NixlJWV8fTTT1NYWAhAWFgYLVq0ICs0lPTE3gDELJ1Dwb7DPjlWpbzRBKHUCXz0kfWDffjwY77jAesKAqwE4Xa7iYqKOvpkaSnM89xCfeBAiIk5pmzTpk3p1asXOTk5uN3uiqqmTZs2cfnllzN16tSK0dVRTZviumE0ANe6PuK/U/Vuc6ruaIJQ6gTKq5d+85vjn1u/fj0AiYmJx0+v8d57UFAA7dtD164QGnpc+datW9O5c2eys7OJiIjg5Zdfpm/fvuTl5fH2229z+eWX88Ybb5Cbm0v+0KEcimhFG/az6r2VFOj8TKqOaIJQyos1a2DtWmja1OreWllGRgb79+8nODiYli1bHju9httt3U0IYNw4aNPmhO/RpUsXWrVqRU5ODl26dGHq1KlMmTKFgQMHUlhYyHvvvcfll1/Oi2++Se551jxOgw/O4csPFpOlSULVAU0QSnlRfvVw/fUQHHzsc4sXLwasqwen03ns9BoffQS7dkGLFnDRRVC162slxhi6d+9OZGQkhz3zLfXr149//etfvPfee5x//vmUlJQwffp0rlu+EIBxfMHSOU1ZsXw5aWlpR+dwUsoHNEEoVYXDAdOnW397q16q3EANHO3eKgJ/+Yv19w03QESEtZxEQEAAffv2xWazVTRQA/Ts2ZOXXnqJDz/8kF69erEyN5cFQDhFmEXriA0OIzU1lZSUlDM6VqVORhOEUlV8/bU1jKF372NHTpcrb6BOTEwkJCSE0PI2hq++sqbXiI2F88+HxMRj+8aeQEhICMnJyRQXF5OXl3fMVUFiYiJTpkzh7rvvZrrntcaW/JGv/7ubZs2asWvXrmOn6FCqFmmCUKqKf//bWk+adPxzIlLRQN2+ffuj02uIwNNPW39fd53Vc6l582q/Z2RkJEOGDCE6OprMzEyys7NxOp2AdZVx6623csmLL1KMYRhFTJ96O9M++IDIyEjWrVtHid6/WvmAzxKEMSbJGDPfGFNkjDlgjHnGGGM/RZn2xhjxsszwVZxKVbZ/P8ydC4GBVi1RVbt27SI3N5fo6GhiY2OJi4uznpg3z2rZjoqCiy+27jNtO73/vZo0aULfvn0ZNmwYXbp0qZjAr/y2pJ2HDiW350AAbsDJq6+9xt/+9jeMMWzevFnbI1St80mCMMbEAPMAAa4EngH+ADxdzZd4CBhSaXncB2EqdZxp06yOSFdcAeXf/ZUtWrQIsNofjDFH2x/Krx7Gj4dWrU7aOH0qoaGhdOzYkWHDhtG/f38CAwM5ePAgJQ4HJVf/GoCbaUWQPZTZs2czb948Dh48yIEDB2r8nkp5E3DqXWrkLiAUGCciecD/jDFNgMnGmH94tp3MFhFZ6qPYVCMnIixbtoygoCBatGhBdHQ0YWFhiJy8egmONlAnJiYSFhZmTe+9eLG1hIdbo+qq2fZwKna7nWbNmtGsWTOys7PZtGkT23v0oFl
oUxKKD/Cr8Ef4Ku8fvPjiiyQmJrJhwwaio6OPHZOh1BnwVRXTSGBulUQwAytpXOij91SqWkpLS60BaPn5bNiwgQULFvDjjz8yY8Yetm2Dli2FESO8l12xYgUAnTp1onl5G0P51cOYMdCpU8XEfLUpLi6OoUOH0qVvXw4N6gfAyLwjjLhwHGVlZfzpT3+itLSUjRs3Hr3xkFJnyFcJIhFIrbxBRPYARZ7nTuU9Y4zLGJNujHnRGHP8UFSlasjhcGCMISIigri4OJo1a0ZoaCjTpllNZOPGFRLg5dra6XRWTPHdoUMHq/1h1Sr43/8gJARGj4aEBJ/Fbbfb6dChA83+z7pt6bV8QjP3AyQlJXHgwAFeeOEFsrKy2LNnj89iUI2LrxJEDHDYy/Zcz3Mn4gDeAG4FLgHeBu7GuvrwyhhzhzFmpTFmpY4uVdXhrceP2x3Ezz/HA/n07bu+YkruyjZt2kRJSQmtW7cmKiqKyOBgeOop68nRo6FXr1OOe6gNIRdfTGGbrsRwmIAVaTxz3wM0adKEhQsXMnfuXFJSUirmclLqTPiym6u3LhXmBNutAiLpInKPiHwlIj+JyGTgQeAKY0yfE5SZIiLJIpIcHx9fG3Grs1xRURG2Kj2MPvvsAIWF/TCmJcXFq0lLSzu2kNPJLz/8QBwwMTKSPm+8QfCIEfDtt1aXp7FjoUOHujmAgADCbrDuEzGx5B0OLA/nSc+04W+++Sbbtm1jw4YN2qtJnTFfJYhcINrL9ii8X1mczGeedb8ziEepCvn5+QR57s/gdruZNm0aL798FbAJkUJef+010tato2DFCpg6Fe68Ey66iPGPPUYW8OfUVFp++y0sWWK94HXXwYABXifl8xVz910UBsdwAQthcQa94+O56aabcLlcPPvss+zZs6di+g6laspXvZhSqdLWYIw5BwinSttENUiVtVJnJC8vj6CgIA4ePMjkyZMr7i3djuu5K2QWXbduZcjEiURUqYqKAwqBvFatiBk8mJDevaFjR2tAXLt2dXsQbduSd/3dhL/3V67a8jzFeU9yzbhxbNy4kTVr1vDKK6+QlJRETMzJanSVOjlfXUHMBkYYYyrdg5FrgWJgwWm+1tWe9araCEw1biJCQUEBCxYsYMJ117Fy5UrCQ6Lpx0ustX3HH0uKGAvElJTgDAykLCkJrruOkj/+kR5AjDFsu/VW7BMnwpAh1nTevXtX3DGuzhhDyyduY2NwP1rJAVwfL6Z9QAAPP/wwYWFhLFu2jOXLl+sIa3VGfJUg3sJqcP7CGHOpMeYOYDLwYuWur8aY7caYdys9nmyM+acxZpyn3DPAS8AXIrLeR7GqRqS0tJQpU6bw2GOPkV9QwEXdu3NDsxf4gaeIdh8mNyGBv7VtSzdgZJcubLrpJspuuIHl8fFsAtq3bUvw4MEEXnopXHqpVbVUeTbXutSyJatGPIYbQ8elH9N13z7atWjBhRdaPcnnz59PRkaGf2JTZwWfJAgRycXqhWQHvsYaQf0S8FSVXQM8+5RLxRon8R7wHXA98LxnrdQZy8nJYebMmdiM4clbbuH1i8fw4p57iSKP9J7JHLrrLnrccAN7Q0KYt3kzXx06xO727VnqmRepS/fuxCYkWO0NtTAY7oyEhHDR4+fzDrcRIE7MW1Pp17Qpw4cPB+CHH35g27ZtOi5C1ZjPejGJyGYRuVhEQkWkpYg8ISKuKvu0F5GJlR7P8PRIihKRIBHpLCJPiojDV3GqxqV8zqKOzZtzW9OmdH3jH4RTxLcx13DgurGcc+GF9L7rLm72DKV+4/33WZuaWjHFRkJCAk2bNvXnIRyjXXI8s7o/xiFiCFi3mpD//peJ11xj3c86K4tly5ZpY7WqMZ3NVTUqGzduBGBSUBBtXnmFAHcZr/J7lg6/mTbt2hHUvTtt2rThxhtvpGvXrmRmZjJjxoxjpviumH+pPjCGX9/amj/huQ/F1KlE79vHmDFjAPjxxx/ZvXu3/+JTDZomCNW
opKSkcAvwxz17MG43f+dRHgn4J6O7byf+oosgOBibzUavXr245557sNlszJw5k4yMDEJCQkhKSiIwMNDfh3GM8dcH8o65nTX0hawsePFFbpswAbAmF9yxYwfFxcV+jlI1RJogVKOSlprKG1gf/Dm9f8v/8XfO7bqDxHN7YGvVqmK/iIgIRo8ezZVXXllRh5+QkEDr1q39E/hJNG8Owy6181vesDZ88QW9MjLo3asXDoeDRYsWaWO1qhFNEKrREBECt20jHMhv2pSH8p4B4OpzdxM1ePBxjc5t27bld7/7XcU9pxMSEurtuIIJEwxLGcKcmAlQVkbgM89w84ABgFXNlJaWpo3V6rRpglCNRmFhIS09v6SPtOrCpp2xhAc7mPDb9tZU3VXYbDYGDhzIww8/zODBgxk5cmT9an+oZOxYayjGLbkv4YpoAhs3cs/8+XQIDGT9+vWkpaVx6NAhf4epGhhNEKrR2Lp1Kz08v6LXuK1f178elE5Mz84nLBMZGcm4ceN49NFHSUpKIsDbNK/1QHQ0jBoFmTTnw/FfQYsWBO3axTKgB/DzvHns2rXLv0GqBkcThGo0Nm3aRF/P3zPTzwfglt81BftJ74RLu3btiI+PP3r/6XrK0y7NvzaeD//6FyQlEV9WxmKg7Ntvydm+veL2pUpVhyYI1Whs2riRPp6/Z+eeT1xMGcPHNjllObvdTv/+/etlA3Vlo0dbs40vX2FjR9dfwxtvUDpkCE2AaQUFhLzyCpnrdUICVX2aIFSjcWjNGqKBQ/YwDtKCa661Ud0eqyEhIfWue2tVoaHWTe0AZswMhqFDCXjtNb5r144A4MKlSwl95BFcR474M0zVgGiCUI1GaEoKAGvpChgm3HDyqqWGqLya6aOPQOwB2Pr1I+Spp5gElAHNFy6k5Ntv/RmiakA0QahGweVy0cLTg2mFawitmxVz7rl+DsoHLrsM4uJg82ZYuxYwhn5jx7I6MZEpnn1sn34K2uVVVYMmCKx7FK9du5b8/Hx/h6J8ZO/evXQrKwNgLecxfkwhtrPw0x8YaN2/CGDaNGvdpEkTRlx+OZ979gn45RfIzfVLfKphOQv/Fzl9TqeTvXv3snDhQnbt2qUDis5CmzZtqmigXktfbrzJn9H41k2eY/voI3A6rfEcEyZMYFlQENlAYGYm/PSTP0NUDYQmCGDfvn0EBgYSGxvL5s2bWbZsGQUFBf4OS9Wi7StW0B4owk5Aqwh6Dzx176WGasAA6z5GGRnwv/9Z2zp16kTX7t350rOPa+ZMKC31W4yqYWj0CSI9PZ0BAwbw9OTJHDx4kObNm+NwOFi4cCG7d+/Wq4mzRLHn/tEbaMXIc3cSUNd3gKtDxhy9iiivZoqMjCQxMbGimolFiyA72x/hqQak0SeItWvXEuhysXTZMq4ZP543X3uNgIAAmjZtyubNm1mxYgWFhYX+DlOdoaDNWwBYRyJXXnH2Dxa74QZrPWsW5OeDMYbBgwczHyiw2bDv3g0LF/ozRNUANPoEMfLcc8kIDeW72Fjal5Xx7vvvc9WVV/LD3LnEx8dTVFTEwoUL2bt3LyLi73BVDTU/mAlATmxX2nX30y1C61D79nDBBVBcDJ97LhsuuOACSoHZ5dOFzJ0LWpWqTqLRJwjmzSMgJ4eROTmkAvNDQ0k4dIjHnnqKO2++mYzdu4mJiWHjxo2sWrVKpypogPLz8+nutO6H0LxXc0JjY/0cUd0or2b64ANr3atXL8LCwvi4vO3hl18gPd0/wakGQRPEVVdRtGgRmcnJSGAgFxcXsxBYabPRKSWFmydN4tVnnyUsKIgjR46wcOFC9u/fr1cTDci3M9eQBLiBroPchNajW4b60vjxEBxsdVjau9fqzZSUlMQcoCwgALZsgRUrdEyEOiFNEIAtMZG9V1zB2qefJmvMGJwREfR3u/kvsBlwffst48eOZencuTSJjGTdunWsWbOGkpISf4euqmHR2z8SBKTZwgho3oQQL1N7n42
iouDKK0EEpk+3tvXp04diYH18vLVhwQIdE6FOSBMEEBITQ+LllxPdtCm7+vVj/ZNPkj5xIqXNm9MZmAHMzs/nx+ef5/eTJlGQkUFubi4LFy4kMzPT3+Grk3C7wbVqJQAHopoS0KwZtrNxhNwJVO7NJAL9+/cH4NuwMOuJpUutywulvGg8/6ecjM1GeJ8+dLz1VpImTKBJ8+bs69CBjfffz4FJkyiLjqY/MA/429atvHDrrbz/0ksYp5MVK1awefNmyjyjdFX9smABJDpSAXC2aUp4PZ+RtbaNGAHx8dbUG6tXw6BBgwD4MD8fCQiADRsgJQUcDj9HquojTRCVBQQQ3rkznW+6iW633kpEQgL7OnRg04MPknH11ThDQ7kMWC7CyG+/5U/XXcf2ZcvYt2cPv/zyC0d0lsx6Z9o06MMBANxtz6FJPb+nQ22rOvVGz549CQoKYltmJjlduliXFUuWQE6OfwNV9ZLPEoQxJskYM98YU2SMOWCMecYYc8rpM40xUcaY94wxucaYI8aY6caYuu12YgwRrVuTMG4cSXffTVCvXuzu1YvURx8l51e/whUQwATg57w8Sp99ljcffZQju3ezePFiduzYoYPr6omiIvj0v2X0wep5Fta5c6PpwVTZzTdb648/BgggISEBgFXt21tPLF0KO3f6IzRVz/kkQRhjYrBqZAS4EngG+APwdDWKfwJcBNwGTAQGALN8EGa1RMTFkXTllXS9806kVy92DB3K1ocf5vCQIQQDDwNT16/n51tv5ZcZM9iybh0rVqzQ7rD1wFdfQVzhIqKATGOwnXMOoRER/g6rzvXvD4mJkJkJ338PvXv3BmBOZCRis1l1T/v3WyPqlKrEV1cQdwGhwDgR+Z+IvIWVHB40xpxwEhxjzBBgBHCLiHwuIjOBG4GhxphLfRRrtUQ1a0bPsWPpctttFPXrx5Zf/5rt991HbseOxAEvOJ3cPG0a/7vrLrYvWMDin38mwzO9tPKPDz6APliTEaWFheGKjiY4ONjPUdW9ylNvfPDB0YbqdZmZlCUlgcsFy5drglDH8VWCGAnMFZG8SttmYCWNC09RLkNEfi7fICLLgZ2e5/wupnlz+l11FZ1+8xsODRrEtokT2XPrreRGR5MAvJSVRfwf/8imf/2LNT/+yObNm3E6nf4Ou9E5eND6tdyX5QBkRkcT1qoVxhg/R+Yf5VNvfPkldOs2EIAdaWkcOu8864lfftF2CHWcAB+9biLwQ+UNIrLHGFPkee7rk5RL9bI9xfNcrdu4Ea6+2rpvfUCAta78d2CgNdgoKOjoEhxsCA5uRVjYdbjLcnEcTqPJkDGcn/4lF6z/D8PdTi74+ms++mU5Rff8gUND0+k9aCCRkZG+OATlxdtvWz+MBwVvAQcUtmxJZPPm/g7Lb9q1g0sugfnzYf36/tjtdvbu3UtKr160MAZWrbK6u/bs6e9Qz0oi1qwmhw5ZeTgnxxp+UlAAhYXHLkVFUFJiTbZbvjgc1rqszPpcO53Hr+fMgQ4dajduXyWIGOCwl+25nudqUq6jtwLGmDuAOwDatm17OjEC1j/Ili2nXaz83YGmngX+wWji+Qv/YDwT+Znf5GSw/en/4z7zNMuihtCmnYPOnYNo29bQrh20bWstCQmguaP2OBzw5pvW3z2cnmq+tm0bdYIA+P3vrQTx1lshdOzYkW3btrGmoIALExOxpaTA4sVw6aXWr6DGSMRqqNm5E3bvtqrciouhuJjS/HwKs7MpPnSIMoeDMhGcLhdOEcrcboodbo4UCXklNvIcdo44bOQVGw4XG/JKbOSX2igRgwMbpRgcGEo9S/m20mOes2HVO4hncVdau70+3r37dTp0qN1u3L5KEGBFX5U5wfYalxORKWDdTTE5Ofm057/o1cvqBu4tI5cvlTN5+VJcbGX6ypn/yBEnGfsDeHXvNGZnzeXJI/fRnWK+lcf59PAH/PHwTD5fl+Q1jo4doXfvY5f27a36Y3V6Pv3Uuhf
Ced1yaJ1SRiEQ2a0bIY2wgbqy0aOtX5g7d8KQIf3Ytm0b23bupHTAAEJSUqz/EYqKGkeCcLutOshZs3Bv20bZzp0E7N+P/QT3yAjyLCf7dVvrIQIOoNSzdlR6XFrluVIgNv9OoGEkiFwg2sv2KLxfIVQuF+9le/QpytVYaKjVw6N2BABNyc5ykfJTNIVbHuPL6dO5NDWV8WzlKrqzJLINi3vcz9Yud5JxKIJdu2DrVkhLs5aZM4++Wlyc9YNu+HDrXsNt2tRWnGcvEXjlFevv+y76CVJgkzE0ad2a0NBQv8bmb3a7dRXx4IOQkTEY+IS0tDQKR40i5IMPrASRnw/R0f4OtdaVlpZSUlAAmzbhfvddAmbOJCLPaiK1AeVdFw5hNXjuwvoyKgGKgVJsOGzhFLqjcBGODTt2ArBX/GcjItBFuL2MCLuTMHsZ4TYnYbYywmwugnERiJsAEQLdbuxuNwFu63H533a3G7sIAS4XdhFsIoRiNdxWx5GY2v8B5KsEkUqVNgNjzDlAON7bGCqXO9/L9kT82NX1dMXFxzPwiitIWbeO6Lg45i1fjvOTT7i8qIjz8vdx3pKH2LnuSRy33ELCjOdxBYWTmgrr1h27ZGbCjBnWAtCtm5UsfvUrK3EE+PL6r4FauhRWroTYWOhZ+j0Au8LDaRUbS1Bj+GV8CpMmwRNPQFqap6F6xw7Su3Uj1hjYvh0OHIBzzvFzlLUrb+9etj78MNHz5tG5UkP8bmAasBY4EhFBSVwcEW3aENeiNUVliRzMSiQtLZGDB9sDUeC2LuebN3fTq5fQu7fQvRf07WsjMdHm/f9HEetqpXx9On+XlVlXdEVFVpVFSYlVf1pUZK0djqONFS4XUT161Pq589VXzGzgYWNMpIiU9527FisZLzhFuSeMMUNFZBGAMSYZq/1hto9i9Yng4GB6DxjAgdatSW3WjKg+ffhi9WqKv/mG0Tk5dCgqgjffJGfqVLJ79iQwLo4BcXEMjo/HfnFLAm9oRa5py8od7fh+eQzfLmxCSoqNlBTrF3KLFnDLLdb/8F27+vto64/yq4c77gD3V4sByI6NJaGRtz+Ui46GiRPhjTd6A4Zdu3axz+mkR/v2Vt3TsmXgmY7jbHDoxx85PGoUycXWdO/FwBfAwhYtCOrVi76DBnHZ+ecT1aoV+w/G8vnnUXz6aTDZ2Uc7eDZpYjXwX3aZ9QOtU6fT6PxpjHXp1kAZX0xb7RkotxnYCDyH9QX/IvCyiDxeab/twAIRubXStjlAAvAQVjXcc0CmiHi7sjhGcnKyrFy5sjYPpVYUFhaycf168vfuJejIEQ5u2EDGV19x8e7ddK/ma7iBQnsA+bZQstxN2enqyR46sIe2hHVuxcArWjLs5nMIT2pndb1qhPbts9ptALYvPoDt4i60LSrin4MGMfb99+momRSwOmVY1ardgFReeeUVfvfzz9g//xx+8xt44w0ICfFzlGfI5SLz7bex33svsS4X24Fv27XDPWgQvbt0wd62LXH9+hHbpjtffhnKv/9tDQUp17271btx+HAYOPDsvlo3xqwSkWRvz/nksEUk1xhzCfA6VpfWw8BLwGQv7181vV7n2fffWNWD3wD3+iLOuhIeHs6AQYPIT0qiuLiYZkOHkjdmDCkbNvDN9OmEZGQQ4nAQWlpKeGkpEU4nkS4XkS4XTUSIxmq8iXQ5iXTl04p8erP76Btsx0q/L8KyDtfSafozxA3s1KB/udTEm29aHQyuGVtK8LpviC8qwgXYOncmspHcA6I6una1qinnzOkHpLJjxw5KBgwg/PPPjzZUN+QEUVBA1uOP0+TVVwkR4Se7nSO/+x0Xt2lDTlQUkUlJRLXow+uvRzBlilV7A9aVwvXXWzkyOVk7iIAPezGJyGbg4lPs097LtsPAJM9y1rDZbERFRREVFUWLFi2gWzcGDBtG8e234ygsxF1WhrusDHE6K9ZlJSXs27OHBdu2kXHgAAXp6RRnZbEnNRXnwYO
0BTra7HQPbU58aRRdynYwaOcnbD9vNZ+OeJ5rXxpC067xjeKTXlxsjX0AuGvAEuxTphCAdQnbvHPnRt9AXdV995UniI/Yvn0HeaNGEQ6QmgpHjkBDTah795L9298S/803ALwfFETsffcR16oVBb1707pNX6ZMieaNN44mhmHDrKQwbhyUz4KuLGfxhVP9Z4whLCyMsJN8KtsmJzNEhJKSEgoLC8nLyyMzM5Mf5s9n5uef89KaNVB4gAB7Bje278Vj+w/RpWwb7eZczV/nP4W58QbueyqaqHZ12UGv7n38sTX46KIuexn43u2Eb9tGIfCIzcaDHToQ0pB/EfvA8OHQqlVvDhyAdev2cDAujpZRUdborQ0ban/EVR2Qbds4fN11xK1ejRt4NjycwQ88QFjTpgQO/BUzv+7Ka68ZCgut/ceMgaeftrq6qxMQkbNm6d+/vzQWhw4dkkWLFskrr7wiF198sRhjBJAQkP+ERolY/SDkJy6QHmHb5OVHD4iz1OXvsH3C7Rbp1cstHdghh6LbiYAcCgqSfiAJrVrJwq+/9neI9dJf/5opgBgTLHPmzBEZNsz63DzyiHVSG5DilBTJbddOBCQf5NboaPlh8mT54Z+vykMPZkhkpLv8fwkZNUpk5Up/R1x/ACvlBN+pej+IBiomJoYhQ4Zw1VVX8dhjj/HOO+8wfvx47OHhTCw+Yk1qheFCfmZhUT/+99wqBvUpYu2as28q8gULIHT9MpabwcQc3k1aaCi9S0vZFBjIAyNHNrp7QFTX738fjzEdEHGwYEEmZQOtrq9s3mx1n2wgsjZs4OAFFxC9ezf7gBtbtGDSH/7A1iNd+e2UO3jhxWbk5xtGjLC6QX/zjTXDrTo1TRANmM1mo3Xr1lxwwQWcf/75TJo0iY8++ojHHnuMPV260BPhKyCafGZyBa03/5X+/V3ccUcOBw7knh33rXC7WfHAR/zIMOIki6Xh4fQpLiY3NJSpDzxAv4QEops183eU9VJEBLRpY43s//LLHPL69LGe2LyZinqYeszpdDLrP/9h28CBtM/K4gBwb8eO3HfHg7w8bxS/feVaUrcE07GjNWh6zpyzqgdvndAEcRYIDAykS5cuXHDBBXTr1o1hw4bxwiuv8JcXXuCDQYN4yWYjEOFT/savpTVTp75Dv355vPDCalJSUhru7VK3biV/+DgeXnsDoZTwSXgk5xcWEtSkCdPuuYeBHTqQ37Yt4eHh/o603hoxwuponZq6k4WuWMRuh127rPtD1GO7d+9mwpgxBE2axLklJWQawzfXXMPwS55nwmt38tmCvhhjePhhq0nlssv8HXEDdaK6p4a4NKY2iJNxuVySnZ0ta9eulf/NnCnfP/GELOjUSQTEATIKBIIFJsrQofNk9uyf5fDhw/4Ou/pyc0X+8AeR8HARkFIC5E8BcQJI66ZNZc7TT8va2bNl5bJlMnv2bCksLPR3xPXWjBkzPLPBnS+jR++Ski5drIr611/3d2heFRUVybPPPitNIyLkC0+jQl5goCx95GkZ1Wd9RTtDv34iq1f7O9qGgZO0Qfj9S702F00Qx3M4HHIgLU3Wvv66HBgyxEoSxsjIimkig6Rp0+fkjTcWyM6dO8Vdnxsny8pE3n1XpE2bikb4b+krCbQWQDq1bClzp02TlcuXy8qVK2XFihUyd+5ccbnOzsb52pCWlub5HESIMU7ZcO4V1rm95RaRenTenE6nfPnll5KQkCB2kBmef39HSIh8dfPfpE1MjoBIaKhbXnjB+qio6tEEocRVVCTZs2ZJxnnniYA47Xb5v/Y9y+cSFptttNx228+yYsVKKS4u9ne4x3K7RZYtExk8uCIxHAwPl9EmpiL+7gkJMu/772XlypWyZMkSmTdvnnz33XeybNkyf0dfr7lcLomNjfWcx63yUJup1jlOThYpKPB3eOJ2u2Xjxo3y61//WgAJB/ksMFAEpCwkRF648E0JtDsFRHr3dktqqr8jbng0QSiLwyH
OBQuk8LLLREBcAQHy6UXDJTAgyvMF0Vr69JkpX375o2RnZ/s7Wsvhw+K6/35xBwWJgBTYbPKwMRLoSQx2e4L8/rf3y3fffSezZ8+W7777Tn744QfZvHmzZGZmisPh8PcR1HtDhw4VQIKCbpdz2CUC4g4LE0lP92tchw8flgcffFAiIiLEgEwyRnI9n4OyoGC5t+usiiql3/1OpL79rmkoNEGooxwOkcWLxTVqlHUlERQkC2++Wdq27uNJEjaJjPyTvPjiAsnLy/NfnGVlUvTFF1LSoUPFVcOHIM1B7DabBNjHCMyTP/1pjcyZM0fWrVsnBw8elKKiIv/F3EBNnz5dAgMDPf/+T8l+08r6bHzzjd9i2rFjh3Tp0kUAOQ8kJSSk4nOQ3SpBRsYsFhCJauKSzz/3W5hnBU0Q6lgOh8iiRSKXXioCUhodLcv/8Ae56spJAsYzeOoSefTR7XU+XsrldErW2rWSPmqUlNlsIiD7QEaDxEZGyt2jR8vIS5YKiPTpky3bt++QkpKSug3yLHPw4EF5/PHHxWazCSCf0FcEJPfuu/0ST1lZmfTr10/agnzpuWIQEEd0tMy86DkJMGUCIoP7l8rOnX4J8ayiCUIdr7RU5OefRbp3FwEp6dBBlj7zjLzwt39KcHCc59fk+3L99Q6pix/lTqdTDuzYIasffVT2NWlS8aUwBSQxJkaevPpqWfjBB/Laqz+LzeYWu90ta9aU+j6wRuDIkSMye/ZsefzxxwWQ+zznfmf3c6W4rtsh3G558YYbZApIsScOZ2CgpF85Xm44b0dFldJD9zqkVP/5a4UmCOVdaanIt9+KtGxpXUkMGSJLX35ZHv7DHzxXEd0EXNKnj8iOHb4KoVT27NolP7//vvyYmCilnm+AHSA3Nm0qz914o6x97TX5/r//lfnzf5ABAxwCIvff75t4GqOysjKZPXu2rFy5Uu655x4Z6Pk32B7QSjYuWVo3QTgcItOnS07XrhU/DgQke8AAWfvnj6XbOXkCIpGRbq1SqmUnSxA6WV9jFhhoTWX55z/DvfcSuGQJ/c85h6DRo/lPfDxZWSnExHzC2rUTSE6G6dNh5MjaeWun08mePXvYuX49mR99RN9Zs0j0DNibFhlJ/oQJ/KlrV3IiIsg65xy6du/OTz+1YcUKG82aweTJtROHgoCAAEJDQykrK+OWW24hc18OJbM+ppPzAA88v4j/e7MDsbGx2H0xfXxBAbz6Krz1FuzdS1OgEPilTRuaX3cdm81V3PG3AeQX2klKgi++MHqDrDrkkxsG+Ut9vWFQvXfkiPU/6eTJ4HbjfvBBnigo4K9TptClSyItWvzCwoUxGANPPQWPP17zW02ICJmZmWzesIF9v/xC2euvc1NWFoHALpuNxaNG0W/ECIoLC8lt04ZzkpNp3749DkcQXbtCejr8+9/WnfRU7dm4cSN79uwhLi4OYwy2C39F/+JDjCaCbg/9nqFDB9GuXTtatmxJVFTUmc+OKwKffmrdINszajsnOJh/OBz80Lw5z//mDj7bNpE3PusIwLXXwjvvWNODqNp1shsG+b1aqDYXrWI6A/v3i9x5p3VpHxgoh554QqKbNBFA/vznv8jjjxeLMdbTw4aJ7Nt3+m+Rn58vK1askC/ff18e6d9fVlaqSljQqZOseeUV2fCPf8jCl16S9cuWVYyAdrlEbrrJ2nXQoHo1fuusUVJSIlu3bpU5c+bI3LlzZc/V14iAPOPpThwUFCTnnnuuPPDAAzJjxgxZuHCh7NixQ3Jzc8XpdJ7em23dKjJ8+NGqpA4d5MfLL5cAkMCAAHnxwRekV4/DAiJ2u8hLLzW4yWUbFLQNQlXLli0io0dbH4voaHn80ksFkL59+8qqVatk7lyRZs2sp2NjRWbNqt7LlpaWytatW+W7WbNk+lNPyfMREeLwfDlkhITIunvvlc1vvSUrnn1WfpkxQ3IqjcFwuUR+8xvxjJIVWbXKR8euRMRKFNu3b5fVTz4pAjK
PzgLnVgxIBMRut0u/fv3krrvuknfeeUfmzp0ra9askfT09JN3My4sFHniiYopUiQ4WJx33imbX3lFYsLDBZCRwx+V8PBSAZHWra1+FMq3TpYgtIpJHeV2w6pVcMcdsHYtxT17Erd1K0UOBy+99BKTJk2ipCSKW26BuXOtIr/9LbzwApzohm25ubmsXbsW98GDbP7wQwbMns1gz2cubcAAym69lZIjRyiy2Yi/5BLa9+xJgOcGwG63Fcq771qv/+23VpOJ8r3SvXsJatuW0oBgmrjzcLhzSE6eht3+BStXrsTlclXse8455zBo0CD69u1Ljx49iIqKIi4ujtjYWOuGWHY7AXPnwp/+ZN2xDnAkJ7P3iivIBX7/2WcsW7+euLgLyM7+EbAxZoxVpRQb65fDb1ROVsWkCUIdq6zM+ia++WbIz+eDLl24Zds2zjvvPF5++WWSk5Nxu+GVV+DRR63de/Sw7ujWo8fRlxER9uzZQ8rq1YTt38/mv/+dm/fvJwzICQri0N134+jeneIDB7AnJNB5xAiaVLrNpdsNd95pfUmEhlpz+F980hvYqlrXvj3s3s13D8zkyteuwOm0MXbsLm67LYVVq1ayaNEilixZQn5+fkWR0NBQ+vXrR/fu3Unq0oVzs7LoMnMm0du2AVAaHc32yy9nZVgYqzIy+GX3blauWYPNFoPbvZ6goFa8/DLcdZetMdwpt17QBKFOT1ER/PWv8Je/4A4KoqfLxWaXizfffJMJEyYQFRUFwOrVMGECbN1q3eP+0UetNsfQ0DI2b9hAzrp1FC5cSOy0aZzrdAKwMSGBwAcfxFFYSLHDQauRI2ndqxc229GZ591uuOsumDrVet1vvoFLLvHLmWjcJkyAGTNg2DAWdL2Dy6aOp8xl57pr93PTzeux220EBQWxY8cOFi9ezOLFi9nmSQTDgMc5elP6ErudJV268HJAAAvS0jhSVFTpjYKAT2jf/mI++cTOwIE6PXtd0gShTl9WFlxzDfz0E1sjIkgqKODiyy7jr3/9K8nJRz9LBQVw331WzyKApk2F68duYXy7b3B+No0B69cTCeTYbOy+6SZCL7oIR3o6pU2bkjhuHE3i4495W7fbqrZ6+20rOXz9NVx6aR0etzrq/fdh4sSKh4UxrZl6eDyfyzgG3NSFBx/KImPvHop37SIkI4PIw4dx79tHk0WLaJ2VBcAR4FXgZeBQpZeOjGxDcfFlOJ2XARdz+eXFPPec0K1bw7sXdkOnCULVzNKlcPnlkJ3Nk8bwV2OY+s479O7dm+joaMLDwwkJCSE4OJjly4N55CEHy1YEM5bP+Se30oE8AJY3b07YY48RGBpKQWYm7qQkeowYQWhYWMVblZbCd99ZiWHOHCs5fPWV3ujFr0TgzTdh9mz45Rc4dPQrPpN43MGhxLsysDsdxxUtCwsje/hwdvfvz4q0NJbv3k222+AKuJRly8eQl9cVMAwceISbbtpK376lDB482DdjLdRJaYJQNeN0wuuvwwMP4AQGAx3HjuV3991HaWkpLpcL43QSUFREUF4eUavX0HTmT7Q5uBmAjcCTob+l+WV30jn6AHFNM+gysiMDhg0mMDAQEViyBKZNg//+9+j3T2gozJoFw4f76bjVUenpVl2iMZCWBkuXUjjnZ8Jz9lbskk8kRZHNCG0dS0TbaEo7dyanXz8OZOSzMSeeLSUd2Zvbkvk/tCQjIxiArl0PM3HiVvr3P0JsbCwJCQlE6CAHv/BLgjDG3A48ApwDbAIeEZH5pygzGXjKy1MjRWTOqd5TE4QP5OVZDdZffkkKMDgggHvvv5+LevSglQgBWVmE7NpF5Lp1RP3yC0aEHOBZexCfNZnB/tyxx7ycMUKHDoZu3SAlxfrOKderF9x0E1x/PbRqVadHqU5ExKpHzMuDjAzIzgaXiwMbcvh6ZUumLunOqn0tKnZPaF1Ip5ZFpBxowu70IESObWnu2LGQhx46zFVXBREREU5oaChGW6P9qs4
ThDHmOmA6MBlYBEwCxgMDRGTjScpNBu4HflXlqRQROXKq99UE4SNbtuC+9FJs+/bxDtY/6BDgosBAupSVVdzY3Am8AbwTG8vfH3iADmFN+F9qAssyW5NT1IE9e8LYts1QqYckrVrBDTfAjTdaCULVcy6XlTCOHIEjR5CgYNZsj2Tal034+MtQMjKPdjYICICEBOjWDZKSYMAAGDUKKvVHUPWAPxLEFmCxiPzG89gGrAPWiciNJyk3GbhHROJq8r6aIHzE7YaPP8Z9yy3YKn+7A2XAGmAJ8BYQ2rkz/5o0iZiYGPJbtSLTbqdPv3608lwSlJbCtm1Wd/jYWDj//JpP26HqF6cTfvrJuthISoJOnazpvlT9drIEUeuT9RljOgIJwH3l20TEbYz5tPI21YDYbDBuHI4FC2DmTMo6dCCvfXtW2u18np7OvA0bOHjoEKP69+dvN9yAdOpEdlQUeUVFDEhOJr5ST6WgIOje3VrU2SUgQHucnW18MZtromedWmV7CtDUGBMvIlknKR9tjMkGorDaOZ8VkS98EKc6HaGhhDzzDJs6dcIRHExEUBBJLhdJwBPjxpFRXEx0p06UtW5NQVkZJSUlDBkyhOjoaH9HrpSqIV8kiBjP+nCV7bmVnj9RgtiO1bC9FogA7gQ+N8ZcdaIkYYy5A7gDoG3btjUOWp2aadGC9sOHs2fDBg46HJj4eCLi4iA0lLDAQEqB/Px83G43Q4YMITIy0t8hK6XOQLXaIIwxUUDLU+0nIqnGmBuAD4Hoyg3LxpjLgO+BBBHZVq3grO4NvwChItLnVPtrG0TdKSgoYM+ePezZswdjDNHR0RQUFBAQEMCAAQMIqzTGQSlVf9VGG8R4YGp13oujVwrRWAMpqfQYjr+yOCEREWPMF8Bzxhi7iLhOWUjViYiICJKSkujUqRP79+8nLS2N0NBQ+vfvf+b3ClBK1QvVShAi8g7wTjVfs7ztIRHYXWl7InDoFO0PJwyhBmVUHQgODqZjx460a9cOQEfCKnUWqfUeySKSBmzFuuoAKrq5jgdmn85reaqYxmJ1j9Wrh3rMbrdrclDqLOOre1JPBj40xuwCFgO3AF2A68t3MMZcCMwHLhGRBZ5tC4DPsa5CwoHbsWZ4GOOjOJVSSp2ATxKEiHxsjIkAHgWewJpqY3SVUdQGsHvW5bZjjaRuCbiB1cAoETmtKw+llFJnzldXEIjIVE7SsC0iP3FsckBEbvVVPEoppU6PzoqilFLKK00QSimlvNIEoZRSyitNEEoppbzSBKGUUsorTRBKKaW80gShlFLKK00QSimlvNIEoZRSyitNEEoppbzSBKGUUsorTRBKKaW80gShlFLKK00QSimlvNIEoZRSyitNEEoppbzSBKGUUsorTRBKKaW80gShlFLKK00QSimlvNIEoZRSyitNEEoppbzSBKGUUsornyQIY8y1xpgvjDHpxhgxxkw8jbLnGWOWGWOKjTE7jTH3+iJGpZRSJ+erK4irgfbAN6dTyBjTGZgL7ARGAW8DLxpjbqvtAJVSSp1cgI9e91oRcRtjIoDT+XJ/GDgA3CgiTuAHY0xb4CljzLsiIr4IViml1PF8cgUhIu4aFh0JfOFJDuVmAG2AHmccmFJKqWqrN43Uxphw4BwgtcpTKZ51Yt1GpJRSjVu9SRBAtGd9uMr2XM86ps4iUUopVb02CGNMFNDyVPuJSNVf/zVxonYGr9uNMXcAdwC0bdu2Ft5eKaUUVL+RejwwtRr7mTOI5bBnHV1le0yV548hIlOAKQDJycnaiK2UUrWkWlVMIvKOiJhTLWcSiIgUAns5vq2h/HFtXJ0opZSqpvrUBgEwGxhrjLFX2nYtVuLY6J+QlFKqcfLJOAhjTBKQBIR4NiUbYwqALBFZ4NnnQmA+cEn5NuB54AZgmjFmKjAAuBO4W8dAKKVU3fLVQLlrgKcqPf6dZ1kAXOTZZgA7ldotRGS7MeZXwItYVxMHgT+IyDs+ilMppdQJmLP
ph3lycrKsXLnS32EopVSDYYxZJSLJ3p6rb20QSiml6glNEEoppbzSBKGUUsorTRBKKaW80gShlFLKK00QSimlvNIEoZRSyitNEEoppbzSBKGUUsorTRBKKaW80gShlFLKK00QSimlvNIEoZRSyitNEEoppbzSBKGUUsorTRBKKaW80gShlFLKK00QSimlvNIEoZRSyitNEEoppbzSBKGUUsorTRBKKaW80gShlFLKK00QSimlvPJJgjDGXGuM+cIYk26MEWPMxGqWm+zZv+ryK1/EqZRS6sQCfPS6VwPtgW+A206z7BGgakJIqYWYlFJKnQZfJYhrRcRtjIng9BOEU0SW+iIopZRS1eeTKiYRcfvidZVSStWd+thIHW2MyTbGlBlj1hhjxvk7IKWUaox8VcVUU9uBR4C1QARwJ/C5MeYqEfnCWwFjzB3AHZ6HBcaYLXURaCVxQHYdv2d9p+fkeHpOjqfn5Hj+OCftTvSEEZFTljbGRAEtT7WfiKRWKRcB5AOTROQ/p3yj49/XAL8AoSLS53TL1wVjzEoRSfZ3HPWJnpPj6Tk5np6T49W3c1LdK4jxwNRq7GfOIJbjiIgYY74AnjPG2EXEVZuvr5RS6sSq1QYhIu+IiDnV4sM4T32Zo5RSqlbVx0bqCp4qprHAunp89TDF3wHUQ3pOjqfn5Hh6To5Xr85JtdogTvtFjUkCkoAQYBrwBvATkCUiCzz7XAjMBy6ptG0B8DmQCoQDt2MNmhsjIl/VeqBKKaVOyFe9mK4Bnqr0+HeeZQFwkWebAewc226xHbgfq0HcDawGRonIbB/FqZRS6gR8cgWhlFKq4avXbRD1TU0nIfSUPc8Ys8wYU2yM2WmMudeHodYpY8ztxphtxpgSY8wqY8wl1ShzVkzMaIxJMsbMN8YUGWMOGGOeMcbYq1EuyhjznjEm1xhzxBgz3RgTWxcx+1pNzokxpv0JPg8z6ipuXzLGdDbGvG2MWWeMcRljfqpmOb9+TurbQLn6rkaTEBpjOgNzPeX+DxgIvGiMKRKRd3wQZ50xxlwHvAVMBhYBk4BvjDEDRGTjKYo36IkZjTExwDxgM3Al0An4J9YPr8dPUfwToCvW58gNPAfMAs73Ubh14gzPCcBDwOJKj8+WgXTdgV8DS4Gg0yjn38+JiOhSzQWwedYRWF1vJ1az3NvAViCg0rZ/AXvxVPM11AXYAvy78jkCNgAfnqLcZCDb3/Gf4bH/H5ALNKm07RGgqPI2L+WGeD4/F1TaNtCz7VJ/H5efzkl7z/GP9vcx+Oi82Cr9/RnwUzXK+P1zolVMp0FqPgnhSOALEXFW2jYDaAP0OOPA/MQY0xFIAP5bvs1zjj7FOuaz3UhgrojkVdo2AwgFLjxFuQwR+bl8g4gsB3bS8M9bTc/JWa2G3x1+/5xogvAxY0w4cA5W193KyqtSEus2olpVHru3Y2tqjIk/RfmGPjFjIlWOXUT2YP1aPtm/63HlPFJOUa4hqOk5Kfeep44+3RjzojEm1BdBNhB+/5xogvC9aM/6cJXtuZ51TJ1FUvvKYz9cZXt1jq18YsZrgKuAA1gTMzakJBHD8ccO1vGf7NhrWq4hqOmxObDGS90KXIJVLXs31tVHY+X3z0mjbqSu6SSENXSi/sT1qp9xDc9J1WMwJ9heufyHVd73a6yJGZ8EvM7cW095O0Zzgu21Ua4hOO1jE5F04J5Km34yxmQA/zLG9BGRtbUbYoPh189Jo04Q1M0khIc96+gq20/069vfTueclF8pRGP1SKLSYziNYxNpkBMz5nL8vytAFCc/9lzAW/Vb9CnKNQQ1PSfefIbVmaMf1i0AGhu/f04adRWT1MEkhCJSiNVbqWqd4Ynq7/3qNM9Jeezeju2QiGTVJIQaB1/3Uqly7MaYc7CmiTnZv+tx5TxOVOfckNT0nHgjVdaNjd8/J406QdSh2cDYKoOFrsVKHKcaK1BviUgaVvfd8eXbjDE2z+PTmh6lgUzMWNVsYIQ
xJrLStmuBYqxpZU5WroUxZmj5BmNMMtCR0zxv9VBNz4k3V3vWq2ojsAbI/58Tf/cPbkgL1gSEVwM3Yv2qed3z+MJK+1wIOKts6wwUAB8Bw7AaZ8uA2/x9TLVwTiYALqxBUMOA/2B9GfQ4xTlZANwLDMdKDN9hDQS6wt/HdBrHHgOkA/8DLsW6s2EB8Ocq+20H3q2ybQ6QBowDxmCNJ1no72Py1znBGhfzT8/5uBR4xvM5+tzfx1RL5yXM811xNbAE2FTpcVh9/Zz4/cQ1pMXzIRYvy0+V9rnIs+2iKmWHAsuBEmAXcK+/j6cWz8vtng+3A2uCxUuqPH/cOQHe9Xzwi4FCYCEw0t/HUoNjTwJ+8BxHOvAsYK+yzy7gP1W2RQPvYdUl52H9eIjz9/H465wA1wErsdqySj2fp2eAYH8fTy2dk/Yn+O4QoH19/ZzoZH1KKaW80jYIpZRSXmmCUEop5ZUmCKWUUl5pglBKKeWVJgillFJeaYJQSinllSYIpZRSXmmCUEop5ZUmCKV8yBjzH2OMeFkW+js2pU5FE4RSvvUPrHsLDwHewppaIQN4zZ9BKVUdOtWGUj5mjIkBpmFNTPgC8BexpoFXql5r7DcMUsqnjDFxWDObRgPnishK/0akVPXpFYRSPuK5N8ZirCmwLxSRDD+HpNRp0SsIpXznfqAH0EeTg2qItJFaKd+5D3hRRHb4OxClakIThFI+YIzpD7QFZvg7FqVqShOEUr7R1bPe79colDoDmiCU8o0iz7qbX6NQ6gxoLyalfMAYE4V1X+Ui4C9YN6nfICJ5fg1MqdOgCUIpHzHG9AT+CpyHNQ7iMJAsIml+DEupatMqJqV8REQ2iMjlItIUaApEAOf7OSylqk0ThFI+ZowJAEZgzcO02M/hKFVtmiCU8r0hwLPABBHZ7u9glKoubYNQSinllV5BKKWU8koThFJKKa80QSillPJKE4RSSimvNEEopZTyShOEUkoprzRBKKWU8ur/Ae8lNtrGlkMaAAAAAElFTkSuQmCC\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_x = np.linspace(-1.1,1.1)\n", "plot_y = func(plot_x)[1]\n", "vd_net_train_state = vd_net.training\n", "vd_net_noise_state = vd_net.noise_is_on\n", "vd_net.train()\n", "vd_net.noise_off()\n", "vd_net.to(device)\n", "set_seeds(0)\n", "vd_pred, _= [t.cpu().detach().numpy()\n", " for t in vd_net.predict(torch.tensor(plot_x, dtype=torch.float32)[:,None].to(device), number_of_draws=5000,\n", " take_average_of_prediction=False)]\n", "vd_pred_mean = np.mean(vd_pred, axis=1).flatten()\n", "vd_pred_std = np.std(vd_pred, axis=1).flatten()\n", "\n", "# plt.ylim([-0.5,1.5])\n", "plt.ylim([-1.5,1.5])\n", "#plt.show()\n", "if vd_net_train_state:\n", " vd_net.train()\n", "else:\n", " vd_net.eval()\n", "if vd_net_noise_state:\n", " vd_net.noise_on()\n", "else:\n", " vd_net.noise_off()\n", "\n", "vd_ber_net_state = 
vd_ber_net.training\n", "vd_ber_net.train()\n", "vd_ber_pred, _ = [t.cpu().detach().numpy()\n", " for t in vd_ber_net.predict(torch.tensor(plot_x, dtype=torch.float32)[:,None].to(device), number_of_draws=5000,\n", " take_average_of_prediction=False)]\n", "vd_ber_pred_mean = np.mean(vd_ber_pred, axis=1).flatten()\n", "vd_ber_pred_std = np.std(vd_ber_pred, axis=1).flatten()\n", "#plt.figure()\n", "#plt.plot(plot_x, plot_y, color='b', label='ground truth', linewidth=1)\n", "plt.plot(plot_x, plot_y, color='b', label='ground truth', linewidth=2)\n", "plt.plot(plot_x, vd_ber_pred_mean, color='k', label='No EiV', linewidth=2)\n", "plt.fill_between(plot_x, vd_ber_pred_mean-k*vd_ber_pred_std, vd_ber_pred_mean+k*vd_ber_pred_std, color='k', alpha=0.2)\n", "plt.plot(plot_x, vd_pred_mean, color='r', label='EiV', linewidth=2)\n", "plt.fill_between(plot_x, vd_pred_mean-k*vd_pred_std, vd_pred_mean+k*vd_pred_std, color='r', alpha=0.2)\n", "plt.xlabel(r'$\\zeta$')\n", "if vd_ber_net_state:\n", " vd_ber_net.train()\n", "else:\n", " vd_ber_net.eval()" ] }, { "cell_type": "markdown", "id": "c407df3b", "metadata": {}, "source": [ "## VD Coverage" ] }, { "cell_type": "code", "execution_count": 24, "id": "1f9dbd64", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0d091cae337149858b16790a2fcb2660", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/20 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ace93cad5bf347db8d510a299cc82f17", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ace349462acf4bf6bad3d25fcf5666fb", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": 
{ "application/vnd.jupyter.widget-view+json": { "model_id": "f349825279164961a44f80ea85f62600", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ed7a655c742c4eee841b3f88b7903652", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4ea8d787e31d4f8092d22b7775b81e89", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8d5ef6677c34499aade5cfd9f2a36c11", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c9463d985ede4065aa71076edaff62cb", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "eb4d5495a35b4aa09db601095bfd5a68", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bf589c77cc36497ba76516bf5349fed2", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "072183d014f34618a0e3afc635841eb4", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": 
{ "application/vnd.jupyter.widget-view+json": { "model_id": "260028ee55704513a7b93973c052c530", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "97c8118677bd4155bd36d07df82a73d4", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "838896a6f85f488e8de2e707a4f9a0de", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9922729bc1ea496085c3a67e517e78fa", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a2db6b50a0ba41059c9c7718177b1137", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4d436aafbfa64dd2a0bd31d59563c865", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0b35d684cc3949beb22cda1462c73ec9", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a20ab2f8df264cfe8a44cda08972f6e7", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": 
{ "application/vnd.jupyter.widget-view+json": { "model_id": "24916b809c3a4405a5e9706f5ef65528", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "47af9668899b4c7a9349767ac549e19c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Collect coverage and RMSE given a seed\n", "## NOTE: the following function can be deleted later on\n", "\n", "def coverage_computation(net, ber_net, seed):\n", "    set_seeds(seed)\n", "    coverage_x = np.linspace(-1.1,1.1, num=100)\n", "    coverage_y = func(coverage_x)[1]\n", "    # Remember the current modes of both nets; they are restored at the end\n", "    net_train_state = net.training\n", "    net_noise_state = net.noise_is_on\n", "    ber_net_state = ber_net.training\n", "    number_of_repeated_draws = 100#0\n", "    # Train mode (and, for the EiV net, input noise) for the repeated stochastic predictions\n", "    net.train()\n", "    net.noise_on()\n", "    ber_net.train()\n", "    inside_map = inside_uncertainties\n", "    net_inside_list, ber_net_inside_list = [], []\n", "    mse_list, ber_mse_list = [], []\n", "    for _ in tqdm(range(number_of_repeated_draws)):\n", "        # Fresh noisy inputs and outputs around the ground truth\n", "        noisy_coverage_x = coverage_x + std_x * np.random.normal(0,1,size=coverage_x.size)\n", "        noisy_coverage_y = coverage_y + std_y * np.random.normal(0,1,size=coverage_y.size)\n", "        pred, _ = [t.cpu().detach().numpy()\n", "                   for t in net.predict(torch.tensor(noisy_coverage_x, dtype=torch.float32)[:,None].to(device), number_of_draws=200,\n", "                                        take_average_of_prediction=False)]\n", "        ber_pred, _ = [t.cpu().detach().numpy()\n", "                       for t in ber_net.predict(torch.tensor(noisy_coverage_x, dtype=torch.float32)[:,None].to(device), number_of_draws=200,\n", "                                                take_average_of_prediction=False)]\n", "        # Coverage is checked against the noise-free ground truth, the MSE against the noisy outputs\n", "        net_inside_list.append(inside_map(pred, coverage_y))\n", "        ber_net_inside_list.append(inside_map(ber_pred, coverage_y))\n", "        mse_list.append(compute_mse(pred, noisy_coverage_y))\n", "        ber_mse_list.append(compute_mse(ber_pred, noisy_coverage_y))\n", "    net_inside = np.mean(np.stack(net_inside_list), axis=0)\n", "    ber_net_inside = np.mean(np.stack(ber_net_inside_list), axis=0)\n", "    mse = np.mean(np.array(mse_list))\n", "    ber_mse = np.mean(np.array(ber_mse_list))\n", "    if net_train_state:\n", "        net.train()\n", "    else:\n", "        net.eval()\n", "    if net_noise_state:\n", "        net.noise_on()\n", "    else:\n", "        net.noise_off()\n", "    if ber_net_state:\n", "        ber_net.train()\n", "    else:\n", "        ber_net.eval()\n", "    return coverage_x, coverage_y, net_inside, ber_net_inside, np.sqrt(mse), np.sqrt(ber_mse)\n", "\n", 
"# Loop over seeds\n", "vd_net_inside_collection, vd_ber_net_inside_collection, vd_rmse_collection, vd_ber_rmse_collection = [], [], [], []\n", "for seed in tqdm(seed_list):\n", "    seed_vd_net = Networks.FNN_VD_EIV(initial_alpha=0.5, deming=deming).to(device)\n", "    seed_vd_ber_net = Networks.FNN_VD_Ber(initial_alpha=0.5, init_std_y=init_std_y).to(device)\n", "    vd_ber_saved_file = os.path.join('saved_networks', \n", "                                     'noneiv_vd_mexican_std_x_%.3f_std_y_%.3f_init_std_y_%.3f_seed_%i.pkl'\n", "                                     % (std_x, std_y, init_std_y, seed))\n", "    vd_ber_train_loss, vd_ber_test_loss, vd_ber_stored_std_x, vd_ber_stored_std_y, vd_ber_state_dict\\\n", "        = train_and_store.open_stored_training(vd_ber_saved_file, net=seed_vd_ber_net, device=device)\n", "    vd_saved_file = os.path.join('saved_networks', 'eiv_vd_mexican_std_x_%.3f_std_y_%.3f_init_std_y_%.3f_deming_scale_%.3f_seed_%i.pkl'% (std_x, std_y, init_std_y, deming, seed))\n", "    vd_train_loss, vd_test_loss, vd_stored_std_x, vd_stored_std_y, vd_state_dict\\\n", "        = train_and_store.open_stored_training(vd_saved_file, net=seed_vd_net, device=device)\n", "    vd_coverage_x, vd_coverage_y, vd_net_inside, vd_ber_net_inside, vd_rmse, vd_ber_rmse = coverage_computation(seed=seed, net=seed_vd_net, ber_net=seed_vd_ber_net)\n", "    vd_net_inside_collection.append(vd_net_inside)\n", "    vd_ber_net_inside_collection.append(vd_ber_net_inside)\n", "    vd_rmse_collection.append(vd_rmse)\n", "    vd_ber_rmse_collection.append(vd_ber_rmse)" ] }, 
{ "cell_type": "code", "execution_count": 25, "id": "9766873f", "metadata": {}, "outputs": [], "source": [ "# Reshape and process results (one row per seed)\n", "vd_net_inside_collection = np.stack(vd_net_inside_collection)\n", "vd_rmse_collection = np.stack(vd_rmse_collection)\n", "vd_ber_net_inside_collection = np.stack(vd_ber_net_inside_collection)\n", "vd_ber_rmse_collection = np.stack(vd_ber_rmse_collection)\n", "number_of_seeds = vd_net_inside_collection.shape[0]\n", "# Mean coverage over the seeds and its standard error (std over the seeds divided by sqrt(number_of_seeds))\n", "vd_net_inside_mean = np.mean(vd_net_inside_collection, axis=0)\n", "vd_net_inside_std = np.std(vd_net_inside_collection, axis=0)/np.sqrt(number_of_seeds)\n", "vd_ber_net_inside_mean = np.mean(vd_ber_net_inside_collection, axis=0)\n", "vd_ber_net_inside_std = np.std(vd_ber_net_inside_collection, axis=0)/np.sqrt(number_of_seeds)" ] }, { "cell_type": "code", "execution_count": 26, "id": "396260b6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'coverage')" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Coverage plot\n", "plt.plot(vd_coverage_x, vd_net_inside_mean, color='orange', linewidth=2,alpha=0.9)\n", "plt.errorbar(vd_coverage_x, vd_net_inside_mean, vd_net_inside_std, color='r', linewidth=3, alpha=0.9, ecolor='red',fmt='o', linestyle='None')\n", "plt.axhline(0.95, color='b', linestyle='dashed',linewidth=2,alpha=0.9)\n", "plt.plot(vd_coverage_x, vd_ber_net_inside_mean, color='gray',linewidth=2,alpha=0.9)\n", "plt.errorbar(vd_coverage_x, vd_ber_net_inside_mean, vd_ber_net_inside_std, color='k', linewidth=3,alpha=0.9,ecolor='k',fmt='o',linestyle='None')\n", "plt.xlabel(r'$\\zeta$')\n", "plt.ylabel(r'coverage')" ] }, { "cell_type": "markdown", "id": "42c60be6", "metadata": {}, "source": [ "# Results for Ensemble" ] }, { "cell_type": "code", "execution_count": 14, "id": "1cdcb410", "metadata": 
{}, "outputs": [], "source": [ "ensemble_files = create_strings('noneiv_mexican_std_x_%.3f'\\\n", "'_std_y_%.3f_init_std_y_%.3f_ensemble_seed_%i.pkl', ensemble_seed_list, (std_x, std_y, init_std_y,), ())\n", "ensemble_files = [os.path.join('saved_networks',s) for s in ensemble_files]\n", "ensemble_size = 5\n", "assert len(ensemble_files) % ensemble_size == 0\n", "number_of_ensembles = int(len(ensemble_files) / ensemble_size)" ] }, { "cell_type": "code", "execution_count": 15, "id": "981ad23f", "metadata": {}, "outputs": [], "source": [ "plot_x = np.linspace(-1.1,1.1)\n", "plot_y = func(plot_x)[1]" ] }, { "cell_type": "markdown", "id": "df2124ef", "metadata": {}, "source": [ "### Loop through the ensembles and collect predictions, uncertainties, RMSE and coverage" ] }, { "cell_type": "code", "execution_count": 22, "id": "3d216d87", "metadata": {}, "outputs": [], "source": [ "# The ground truth\n", "plot_x = np.linspace(-1.1,1.1)\n", "plot_y = func(plot_x)[1]\n", "\n", "# Fix seeds\n", "set_seeds(0)\n", "\n", "# Load the first ensemble (the first `ensemble_size` files) for the example plot below\n", "ens = Ensemble(saved_files=ensemble_files[:ensemble_size],\n", "               architecture_class=Networks.FNNBer,\n", "               device=device,\n", "               p=0.5, init_std_y=init_std_y)\n", "\n", "# Prediction and uncertainty on the validation inputs\n", "mean, std = ens.mean_and_std(torch.tensor(val_x, dtype=torch.float32)[:,None].to(device))\n", "mean = mean.detach().cpu().numpy().flatten()\n", "std = std.detach().cpu().numpy().flatten()" ] }, { "cell_type": "code", "execution_count": 23, "id": "25303106", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(200,)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"mean.shape" ] }, { "cell_type": "code", "execution_count": 24, "id": "13bdbecd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.collections.PolyCollection at 0x7fe235fd00d0>" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(plot_x, plot_y, color='b', label='ground truth', linewidth=2)\n", "plt.plot(val_pure_x, mean, color='g', label='Ensemble', linewidth=2)\n", "plt.fill_between(val_pure_x, mean-k*std,\n", "                 mean+k*std, color='g', alpha=0.2)" ] }, { "cell_type": "code", "execution_count": 29, "id": "1f517745", "metadata": {}, "outputs": [], "source": [ "def ens_coverage_computation(ens, seed):\n", "    set_seeds(seed)\n", "    coverage_x = np.linspace(-1.1,1.1, num=100)\n", "    coverage_y = func(coverage_x)[1]\n", "    number_of_repeated_draws = 100#0\n", "    ens_inside_list = []\n", "    ens_mse_list = []\n", "    for _ in tqdm(range(number_of_repeated_draws)):\n", "        noisy_coverage_x = coverage_x + std_x * np.random.normal(0,1,size=coverage_x.size)\n", "        noisy_coverage_y = coverage_y + std_y * np.random.normal(0,1,size=coverage_y.size)\n", "        # The ensemble was loaded with device=device, so the input is moved there as well\n", "        mean, std = [t.cpu().detach().numpy()\n", "                     for t in ens.mean_and_std(torch.tensor(noisy_coverage_x, dtype=torch.float32)[:,None].to(device))]\n", "        ens_inside_list.append(inside_explicit_uncertainties(mean, std, coverage_y))\n", "        ens_mse_list.append(np.mean((mean.flatten()-noisy_coverage_y.flatten())**2))\n", "    ens_inside = np.mean(np.stack(ens_inside_list), axis=0)\n", "    mse = np.mean(np.array(ens_mse_list))\n", "    return coverage_x, coverage_y, ens_inside, np.sqrt(mse)" ] }, { "cell_type": "code", "execution_count": 30, "id": "f381570a", "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dd4b7a51aec1456f84e27485a040e335", "version_major": 2, 
"version_minor": 0 }, "text/plain": [ " 0%| | 0/20 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "07e30e8291d04add90d48f7d503d0354", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f593a76c68ac4f1a8bfd85d71e12098c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7ec8e577b7f1417a9bfefd513386c206", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1be1725b580642fa89a0f8210aadd273", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4fb5374806604868ae2045dba4212649", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d3b6d811f1104df5a60573c84a93e732", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "128dbf5eebab427e9ba9365e5695cb21", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c1ef006def724ff4a7d0ee770f5b4bae", "version_major": 2, 
"version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2fe8e01ae32f4b27910d5ab95c95d0dc", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "14fc39fbbbaf477d9592fd30cbff87ac", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8dabab7d3789460bb7030fe372d0cd9c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0519917e872c42fbba452bd9538f2674", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2abcc1e5e1d94710b45aa33fdbb5b7be", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "425d0345fcd94156b2b5a5ff66793843", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "675499fc83ec413bbebe18929ff50d19", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "43442a470dc343d38022f33ee2904ea5", "version_major": 2, 
"version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2c1f11ba3d7544cb9cd503c77b41a933", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cba3c620bdf44234932ffc9710929801", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "243f702210e34bdebe892f1e2ca5bfd7", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "db69c709717043daa1223d2f6b609244", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ens_inside_collection = []\n", "rmse_collection = []\n", "for i in tqdm(range(number_of_ensembles)):\n", " file_chunk = ensemble_files[i*ensemble_size: (i+1)*ensemble_size]\n", " ens = Ensemble(saved_files = file_chunk,\n", " architecture_class=Networks.FNNBer,\n", " device=device,\n", " p=0.5, init_std_y=init_std_y)\n", " _,_, ens_inside, rmse = ens_coverage_computation(ens, seed=i*ensemble_size)\n", " ens_inside_collection.append(ens_inside)\n", " rmse_collection.append(rmse)\n", "ens_inside_collection = np.stack(ens_inside_collection)\n", "rmse_collection = np.stack(rmse_collection)" ] }, { "cell_type": "code", "execution_count": 31, "id": "94112e7e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(20, 100)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ens_inside_collection.shape" ] }, 
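{ "cell_type": "markdown", "id": "3f6c2a9e", "metadata": {}, "source": [ "### Aggregation for Table 1\n", "\n", "The next cell condenses the ensemble results into the numbers reported in Table 1. Let $c_{e,j}$ be the fraction of the $R = 100$ repeated noisy draws in `ens_coverage_computation` for which the noise-free ground truth at the input $\\zeta_j$ lies inside the uncertainty interval that ensemble $e$ predicts from the noisy input. Here $e = 1, \\dots, E$ runs over the ensembles ($E = 20$, the first dimension of `ens_inside_collection`) and $j = 1, \\dots, J$ runs over the $J = 100$ points of `coverage_x`. With $\\mathrm{RMSE}_e$ the RMSE returned for ensemble $e$, the cell prints\n", "\n", "$$\\bar{c} = \\frac{1}{E J} \\sum_{e=1}^{E} \\sum_{j=1}^{J} c_{e,j}, \\qquad \\mathrm{SE}(\\bar{c}) = \\frac{1}{\\sqrt{E}}\\, \\mathrm{sd}_e\\Big(\\frac{1}{J} \\sum_{j=1}^{J} c_{e,j}\\Big),$$\n", "\n", "$$\\overline{\\mathrm{RMSE}} = \\frac{1}{E} \\sum_{e=1}^{E} \\mathrm{RMSE}_e, \\qquad \\mathrm{SE}(\\overline{\\mathrm{RMSE}}) = \\frac{1}{\\sqrt{E}}\\, \\mathrm{sd}_e\\big(\\mathrm{RMSE}_e\\big),$$\n", "\n", "where $\\mathrm{sd}_e$ denotes the standard deviation over the ensembles as computed by `np.std` (default `ddof=0`). The printed `Error` values are these standard errors." ] }, 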
{ "cell_type": "code", "execution_count": 32, "id": "d2a1c07f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RMSE\n", "===========\n", "Average 0.358710, Error 0.000718\n", "\n", "\n", "Coverage\n", "===========\n", "Average 0.323130, Error 0.011597\n" ] } ], "source": [ "# Results for Table 1 in preprint\n", "print('RMSE\\n===========')\n", "print('Average %.6f, Error %.6f' %( np.mean(rmse_collection),\n", " np.std(rmse_collection)/np.sqrt(len(rmse_collection))))\n", "print(\"\\n\")\n", "\n", "print('Coverage\\n===========')\n", "print('Average %.6f, Error %.6f' %(ens_inside_collection.mean(), \n", " ens_inside_collection.mean(axis=1).std()/np.sqrt(ens_inside_collection.shape[0])))" ] }, { "cell_type": "code", "execution_count": 40, "id": "0670fdf3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'coverage')" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Coverage plot\n", "plt.plot(coverage_x, ens_inside_collection.mean(axis=0), color='cyan', linewidth=2,alpha=0.9)\n", "plt.errorbar(coverage_x, ens_inside_collection.mean(axis=0), ens_inside_collection.std(axis=0)/np.sqrt(ens_inside_collection.shape[0]), color='g', linewidth=3, alpha=0.9, ecolor='green',fmt='o', linestyle='None')\n", "plt.axhline(0.95, color='b', linestyle='dashed',linewidth=2,alpha=0.9)\n", "plt.xlabel(r'$\\zeta$')\n", "plt.ylabel(r'coverage')" ] }, { "cell_type": "markdown", "id": "stable-playing", "metadata": {}, "source": [ "## Evolution of $\\sigma_y$\n", "\n", "Produces Figure 1 from the preprint. 
Unlike the plots above, this section uses the results of the training scripts `train_eiv_mexican_fixed_std_x.py` and `train_noneiv_mexican_fixed_std_x.py`, together with the input noise level `fixed_std_x` instead of `std_x`.\n", "\n", "(This distinction was made only for computational reasons.)" ] }, { "cell_type": "code", "execution_count": 33, "id": "5918ea1f", "metadata": {}, "outputs": [], "source": [ "# For scaling the x-axis\n", "train_len = generate_mexican_data.n_train\n", "epoch_scale = (report_point-1)*batch_size/train_len" ] }, { "cell_type": "code", "execution_count": 34, "id": "forbidden-armstrong", "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bc69118b2f9a4a2987d21953f3102ffd", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/9 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Load sigma_y evolution from files\n", "stored_std_y_collection, ber_stored_std_y_collection = [], []\n", "for i, deming_scale_test in enumerate(tqdm(deming_scale_list)):\n", "    fixed_deming_std_y_collection = []\n", "    for seed in seed_list:\n", "        saved_file = os.path.join('saved_networks', 'eiv_mexican_fixed_std_x_%.3f_std_y_%.3f_init_std_y_%.3f_deming_scale_%.3f_seed_%i.pkl'% (fixed_std_x, std_y, init_std_y, deming_scale_test, seed))\n", "        _, _, _, stored_std_y, _ = train_and_store.open_stored_training(saved_file, net=None, device=device)\n", "        fixed_deming_std_y_collection.append(stored_std_y.cpu().numpy())\n", "        # The non-EiV runs do not depend on the Deming scale, so they are loaded only once\n", "        if i == 0:\n", "            ber_saved_file = os.path.join('saved_networks', 'noneiv_mexican_fixed_std_x_%.3f_std_y_%.3f_init_std_y_%.3f_seed_%i.pkl'% (fixed_std_x, std_y, init_std_y, seed))\n", "            _, _, _, ber_stored_std_y, _ = train_and_store.open_stored_training(ber_saved_file, net=None, device=device)\n", "            ber_stored_std_y_collection.append(ber_stored_std_y.cpu().numpy())\n", "    stored_std_y_collection.append(np.stack(fixed_deming_std_y_collection, axis=0))\n", "ber_sy = 
np.stack(ber_stored_std_y_collection, axis=0)" ] }, { "cell_type": "code", "execution_count": 35, "id": "palestinian-federation", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Create plots with standard errors (Figure 1 in the preprint)\n", "plt.figure()\n", "evol_x = np.arange(0, len(stored_std_y)) * epoch_scale\n", "scale = 1/np.sqrt(len(seed_list))\n", "for i, (sy, dem) in enumerate(zip(stored_std_y_collection, deming_scale_list)):\n", "    color = plt.get_cmap(\"tab20\")(i)\n", "    plt.plot(evol_x, np.mean(sy,axis=0), label=r'$\\delta$' + '=' + str(dem), linewidth=2, color=color)\n", "    plt.fill_between(evol_x, np.mean(sy,axis=0)-scale*np.std(sy, axis=0),\n", "                     np.mean(sy,axis=0)+scale*np.std(sy, axis=0),\n", "                     linewidth=2, color=color)\n", "plt.ylabel(r'$\\sigma_y$')\n", "plt.xlabel('epochs')\n", "plt.plot(evol_x, np.mean(ber_sy, axis=0), color='k', label='non-EiV', linewidth=2)\n", "plt.fill_between(evol_x, np.mean(ber_sy, axis=0)-\n", "                 scale*np.std(ber_sy, axis=0), \n", "                 np.mean(ber_sy, axis=0)+\n", "                 scale*np.std(ber_sy, axis=0),\n", "                 color='k', alpha=0.7)\n", "plt.axhline(std_y, 0,1, color='blue', linestyle='dotted', linewidth=2)\n", "plt.legend(loc=(1.04,0))\n", "plt.tight_layout()\n", "plt.savefig(os.path.join('saved_images','mexican_sigmay_evol_std_x_%.3f_std_y_%.3f.pdf' % (fixed_std_x, std_y)) )" ] }, { "cell_type": "code", "execution_count": null, "id": "fff4d432", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 5 }