From 20359f86a41d5b884da649c942db4f8e5efc2b6c Mon Sep 17 00:00:00 2001 From: Joerg Martin <joerg.martin@ptb.de> Date: Tue, 8 Feb 2022 14:10:15 +0000 Subject: [PATCH] Added zoom prediction plot --- Experiments/plot_prediction.py | 70 +++++++++++++++++++++++++++++++--- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/Experiments/plot_prediction.py b/Experiments/plot_prediction.py index a603d78..a574c52 100644 --- a/Experiments/plot_prediction.py +++ b/Experiments/plot_prediction.py @@ -11,11 +11,20 @@ import json import torch import numpy as np +import matplotlib import matplotlib.pyplot as plt from EIVArchitectures import Networks from EIVTrainingRoutines import train_and_store + +font = {'family' : 'DejaVu Sans', + 'weight' : 'normal', + 'size' : 16} + +matplotlib.rc('font', **font) +linewidth = 2.0 + # coverage factor k = 1.96 @@ -107,6 +116,7 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws, return_ground_truth=False, return_normalized_func=True, normalize=normalize) + plotting_dictionary['func'] = normalized_func input_dim = test_data[0][0].numel() output_dim = test_data[0][1].numel() assert output_dim == 1 @@ -223,8 +233,19 @@ list_x_range = [torch.linspace(-1.0,1.0, 50), torch.linspace(-0.2,0.8, 50)] list_color = [('red','blue')] * len(data_list) list_number_of_draws = [((100,5), 100)] * len(data_list) -for i, (data, x_range, color, number_of_draws) in enumerate(zip(data_list, - list_x_range, list_color, list_number_of_draws)): + +# create an extra zoom plot for zoom_example +zoom_example = 'linear' +# where to zoom +zoom_point = 14 +# size of the zoom plot +x_zoom_radius = 0.4 +y_zoom_radius = x_zoom_radius + +fignum = 0 +for data, x_range, color, number_of_draws in zip(data_list, + list_x_range, list_color, list_number_of_draws): + fignum += 1 eiv_plotting_dictionary = compute_predictions_and_uncertainties( data=data, x_range=x_range, @@ -237,25 +258,62 @@ for i, (data, x_range, color, number_of_draws) in enumerate(zip(data_list, number_of_draws=number_of_draws[1]) input_dim = eiv_plotting_dictionary['input_dim'] if input_dim == 1: - plt.figure(i+1) + plt.figure(fignum) plt.clf() x_values, y_values = eiv_plotting_dictionary['range_points'] - plt.plot(x_values.flatten(), y_values.flatten(),'-', color='k') + noisy_x_values, _ = eiv_plotting_dictionary['noisy_range_points'] + plt.plot(x_values.flatten(), y_values.flatten(),'-', color='k', linewidth=linewidth) eiv_pred = eiv_plotting_dictionary['prediction'] eiv_unc = eiv_plotting_dictionary['uncertainty'] plt.plot(x_values, eiv_pred,'-', - color=color[0]) + color=color[0], linewidth=linewidth) plt.fill_between(x_values.flatten(), eiv_pred-k * eiv_unc, eiv_pred + k * eiv_unc, color=color[0], alpha=0.5) noneiv_pred = noneiv_plotting_dictionary['prediction'] noneiv_unc = noneiv_plotting_dictionary['uncertainty'] plt.plot(x_values.flatten(), noneiv_pred,'-', - color=color[1]) + color=color[1], linewidth=linewidth) plt.fill_between(x_values.flatten(), noneiv_pred-k * noneiv_unc, noneiv_pred + k * noneiv_unc, color=color[1], alpha=0.5) + plt.tight_layout() plt.savefig(f'results/figures/prediction_{data}.pdf') + if data == zoom_example: + fignum += 1 + func = eiv_plotting_dictionary['func'] + x_point = x_values[zoom_point] + y_point = func(x_point) + noisy_x_point = noisy_x_values[zoom_point] + func_noisy_x_point = func(noisy_x_point) + plt.figure(fignum) + plt.clf() + plt.plot(x_values.flatten(), y_values.flatten(),'-', color='k', linewidth=linewidth) + plt.plot(x_values, eiv_pred,'-', + color=color[0], linewidth=linewidth) + plt.fill_between(x_values.flatten(), eiv_pred-k * eiv_unc, + eiv_pred + k * eiv_unc, + color=color[0], alpha=0.5) + plt.plot(x_values.flatten(), noneiv_pred,'-', + color=color[1], linewidth=linewidth) + plt.fill_between(x_values.flatten(), noneiv_pred-k * noneiv_unc, + noneiv_pred + k * noneiv_unc, + color=color[1], alpha=0.5) + plt.axvline(x_point, color='black', linestyle='dotted') + plt.axhline(y_point, color='black', linestyle='dotted') + plt.axvline(noisy_x_point, color='gray', linestyle='dashed') + plt.axhline(func_noisy_x_point, color='gray', linestyle='dashed') + plt.text(x_point - 0.1 * x_zoom_radius,y_point-0.9 * y_zoom_radius, r'$\zeta$', color='k') + plt.text(noisy_x_point - 0.1 * x_zoom_radius,y_point-0.9 * y_zoom_radius, r'$x$', color='gray') + plt.text(x_point - 0.92 * x_zoom_radius,y_point-0.13 * y_zoom_radius, r'$g(\zeta)$', color='k') + plt.text(x_point - 0.92 * x_zoom_radius,func_noisy_x_point-0.13 * y_zoom_radius, r'$g(x)$', color='gray') + plt.gca().set_xlim(left=x_point - x_zoom_radius, + right=x_point + x_zoom_radius) + plt.gca().set_ylim(bottom=y_point - y_zoom_radius, + top=y_point + y_zoom_radius) + plt.gca().set_aspect('equal', adjustable='box') + plt.tight_layout() + plt.savefig(f'results/figures/prediction_{data}_zoom.pdf') else: # multidimensional handling not included yet pass -- GitLab