Skip to content
Snippets Groups Projects
Commit 20359f86 authored by Jörg Martin's avatar Jörg Martin
Browse files

Added zoom prediction plot

parent 5406914b
No related branches found
No related tags found
No related merge requests found
......@@ -11,11 +11,20 @@ import json
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from EIVArchitectures import Networks
from EIVTrainingRoutines import train_and_store
font = {'family' : 'DejaVu Sans',
'weight' : 'normal',
'size' : 16}
matplotlib.rc('font', **font)
linewidth = 2.0
# coverage factor
k = 1.96
......@@ -107,6 +116,7 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws,
return_ground_truth=False,
return_normalized_func=True,
normalize=normalize)
plotting_dictionary['func'] = normalized_func
input_dim = test_data[0][0].numel()
output_dim = test_data[0][1].numel()
assert output_dim == 1
......@@ -223,8 +233,19 @@ list_x_range = [torch.linspace(-1.0,1.0, 50),
torch.linspace(-0.2,0.8, 50)]
list_color = [('red','blue')] * len(data_list)
list_number_of_draws = [((100,5), 100)] * len(data_list)
for i, (data, x_range, color, number_of_draws) in enumerate(zip(data_list,
list_x_range, list_color, list_number_of_draws)):
# create an extra zoom plot for zoom_example
zoom_example = 'linear'
# where to zoom
zoom_point = 14
# size of the zoom plot
x_zoom_radius = 0.4
y_zoom_radius = x_zoom_radius
fignum = 0
for data, x_range, color, number_of_draws in zip(data_list,
list_x_range, list_color, list_number_of_draws):
fignum += 1
eiv_plotting_dictionary = compute_predictions_and_uncertainties(
data=data,
x_range=x_range,
......@@ -237,25 +258,62 @@ for i, (data, x_range, color, number_of_draws) in enumerate(zip(data_list,
number_of_draws=number_of_draws[1])
input_dim = eiv_plotting_dictionary['input_dim']
if input_dim == 1:
plt.figure(i+1)
plt.figure(fignum)
plt.clf()
x_values, y_values = eiv_plotting_dictionary['range_points']
plt.plot(x_values.flatten(), y_values.flatten(),'-', color='k')
noisy_x_values, _ = eiv_plotting_dictionary['noisy_range_points']
plt.plot(x_values.flatten(), y_values.flatten(),'-', color='k', linewidth=linewidth)
eiv_pred = eiv_plotting_dictionary['prediction']
eiv_unc = eiv_plotting_dictionary['uncertainty']
plt.plot(x_values, eiv_pred,'-',
color=color[0])
color=color[0], linewidth=linewidth)
plt.fill_between(x_values.flatten(), eiv_pred-k * eiv_unc,
eiv_pred + k * eiv_unc,
color=color[0], alpha=0.5)
noneiv_pred = noneiv_plotting_dictionary['prediction']
noneiv_unc = noneiv_plotting_dictionary['uncertainty']
plt.plot(x_values.flatten(), noneiv_pred,'-',
color=color[1])
color=color[1], linewidth=linewidth)
plt.fill_between(x_values.flatten(), noneiv_pred-k * noneiv_unc,
noneiv_pred + k * noneiv_unc,
color=color[1], alpha=0.5)
plt.tight_layout()
plt.savefig(f'results/figures/prediction_{data}.pdf')
if data == zoom_example:
fignum += 1
func = eiv_plotting_dictionary['func']
x_point = x_values[zoom_point]
y_point = func(x_point)
noisy_x_point = noisy_x_values[zoom_point]
func_noisy_x_point = func(noisy_x_point)
plt.figure(fignum)
plt.clf()
plt.plot(x_values.flatten(), y_values.flatten(),'-', color='k', linewidth=linewidth)
plt.plot(x_values, eiv_pred,'-',
color=color[0], linewidth=linewidth)
plt.fill_between(x_values.flatten(), eiv_pred-k * eiv_unc,
eiv_pred + k * eiv_unc,
color=color[0], alpha=0.5)
plt.plot(x_values.flatten(), noneiv_pred,'-',
color=color[1], linewidth=linewidth)
plt.fill_between(x_values.flatten(), noneiv_pred-k * noneiv_unc,
noneiv_pred + k * noneiv_unc,
color=color[1], alpha=0.5)
plt.axvline(x_point, color='black', linestyle='dotted')
plt.axhline(y_point, color='black', linestyle='dotted')
plt.axvline(noisy_x_point, color='gray', linestyle='dashed')
plt.axhline(func_noisy_x_point, color='gray', linestyle='dashed')
plt.text(x_point - 0.1 * x_zoom_radius,y_point-0.9 * y_zoom_radius, r'$\zeta$', color='k')
plt.text(noisy_x_point - 0.1 * x_zoom_radius,y_point-0.9 * y_zoom_radius, r'$x$', color='gray')
plt.text(x_point - 0.92 * x_zoom_radius,y_point-0.13 * y_zoom_radius, r'$g(\zeta)$', color='k')
plt.text(x_point - 0.92 * x_zoom_radius,func_noisy_x_point-0.13 * y_zoom_radius, r'$g(x)$', color='gray')
plt.gca().set_xlim(left=x_point - x_zoom_radius,
right=x_point + x_zoom_radius)
plt.gca().set_ylim(bottom=y_point - y_zoom_radius,
top=y_point + y_zoom_radius)
plt.gca().set_aspect('equal', adjustable='box')
plt.tight_layout()
plt.savefig(f'results/figures/prediction_{data}_zoom.pdf')
else:
# multidimensional handling not included yet
pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment