From 20359f86a41d5b884da649c942db4f8e5efc2b6c Mon Sep 17 00:00:00 2001
From: Joerg Martin <joerg.martin@ptb.de>
Date: Tue, 8 Feb 2022 14:10:15 +0000
Subject: [PATCH] Added zoom prediction plot

---
 Experiments/plot_prediction.py | 70 +++++++++++++++++++++++++++++++---
 1 file changed, 64 insertions(+), 6 deletions(-)

diff --git a/Experiments/plot_prediction.py b/Experiments/plot_prediction.py
index a603d78..a574c52 100644
--- a/Experiments/plot_prediction.py
+++ b/Experiments/plot_prediction.py
@@ -11,11 +11,20 @@ import json
 import torch
 
 import numpy as np
+import matplotlib
 import matplotlib.pyplot as plt
 
 from EIVArchitectures import Networks
 from EIVTrainingRoutines import train_and_store
 
+
+font = {'family' : 'DejaVu Sans',
+        'weight' : 'normal',
+        'size'   : 16}
+
+matplotlib.rc('font', **font)
+linewidth = 2.0
+
 # coverage factor
 k = 1.96
 
@@ -107,6 +116,7 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws,
             return_ground_truth=False,
             return_normalized_func=True,
             normalize=normalize)
+    plotting_dictionary['func'] = normalized_func
     input_dim = test_data[0][0].numel()
     output_dim = test_data[0][1].numel()
     assert output_dim == 1
@@ -223,8 +233,19 @@ list_x_range = [torch.linspace(-1.0,1.0, 50),
         torch.linspace(-0.2,0.8, 50)]
 list_color = [('red','blue')] * len(data_list)
 list_number_of_draws = [((100,5), 100)] * len(data_list)
-for i, (data, x_range, color, number_of_draws) in enumerate(zip(data_list,
-    list_x_range, list_color, list_number_of_draws)):
+
+# create an extra zoom plot for zoom_example
+zoom_example = 'linear'
+# where to zoom
+zoom_point = 14
+# size of the zoom plot
+x_zoom_radius = 0.4
+y_zoom_radius = x_zoom_radius
+
+fignum = 0
+for data, x_range, color, number_of_draws in zip(data_list,
+        list_x_range, list_color, list_number_of_draws):
+    fignum += 1
     eiv_plotting_dictionary = compute_predictions_and_uncertainties(
             data=data,
             x_range=x_range,
@@ -237,25 +258,62 @@ for i, (data, x_range, color, number_of_draws) in enumerate(zip(data_list,
             number_of_draws=number_of_draws[1])
     input_dim = eiv_plotting_dictionary['input_dim']
     if input_dim == 1:
-        plt.figure(i+1)
+        plt.figure(fignum)
         plt.clf()
         x_values, y_values = eiv_plotting_dictionary['range_points']
-        plt.plot(x_values.flatten(), y_values.flatten(),'-', color='k')
+        noisy_x_values, _ = eiv_plotting_dictionary['noisy_range_points']
+        plt.plot(x_values.flatten(), y_values.flatten(),'-', color='k', linewidth=linewidth)
         eiv_pred = eiv_plotting_dictionary['prediction']
         eiv_unc = eiv_plotting_dictionary['uncertainty']
         plt.plot(x_values, eiv_pred,'-',
-                color=color[0])
+                color=color[0], linewidth=linewidth)
         plt.fill_between(x_values.flatten(), eiv_pred-k * eiv_unc,
                 eiv_pred + k * eiv_unc,
                 color=color[0], alpha=0.5)
         noneiv_pred = noneiv_plotting_dictionary['prediction']
         noneiv_unc = noneiv_plotting_dictionary['uncertainty']
         plt.plot(x_values.flatten(), noneiv_pred,'-',
-                color=color[1])
+                color=color[1], linewidth=linewidth)
         plt.fill_between(x_values.flatten(), noneiv_pred-k * noneiv_unc,
                 noneiv_pred + k * noneiv_unc,
                 color=color[1], alpha=0.5)
+        plt.tight_layout()
         plt.savefig(f'results/figures/prediction_{data}.pdf')
+        if data == zoom_example:
+            fignum += 1
+            func = eiv_plotting_dictionary['func'] 
+            x_point = x_values[zoom_point] 
+            y_point = func(x_point)
+            noisy_x_point = noisy_x_values[zoom_point]
+            func_noisy_x_point = func(noisy_x_point)
+            plt.figure(fignum)
+            plt.clf()
+            plt.plot(x_values.flatten(), y_values.flatten(),'-', color='k', linewidth=linewidth)
+            plt.plot(x_values, eiv_pred,'-',
+                    color=color[0], linewidth=linewidth)
+            plt.fill_between(x_values.flatten(), eiv_pred-k * eiv_unc,
+                    eiv_pred + k * eiv_unc,
+                    color=color[0], alpha=0.5)
+            plt.plot(x_values.flatten(), noneiv_pred,'-',
+                    color=color[1], linewidth=linewidth)
+            plt.fill_between(x_values.flatten(), noneiv_pred-k * noneiv_unc,
+                    noneiv_pred + k * noneiv_unc,
+                    color=color[1], alpha=0.5)
+            plt.axvline(x_point, color='black', linestyle='dotted')
+            plt.axhline(y_point, color='black', linestyle='dotted')
+            plt.axvline(noisy_x_point, color='gray', linestyle='dashed')
+            plt.axhline(func_noisy_x_point, color='gray', linestyle='dashed')
+            plt.text(x_point - 0.1 * x_zoom_radius,y_point-0.9 * y_zoom_radius, r'$\zeta$', color='k')
+            plt.text(noisy_x_point - 0.1 * x_zoom_radius,y_point-0.9 * y_zoom_radius, r'$x$', color='gray')
+            plt.text(x_point - 0.92 * x_zoom_radius,y_point-0.13 * y_zoom_radius, r'$g(\zeta)$', color='k')
+            plt.text(x_point - 0.92 * x_zoom_radius,func_noisy_x_point-0.13 * y_zoom_radius, r'$g(x)$', color='gray')
+            plt.gca().set_xlim(left=x_point - x_zoom_radius, 
+                    right=x_point + x_zoom_radius)
+            plt.gca().set_ylim(bottom=y_point - y_zoom_radius,
+                    top=y_point + y_zoom_radius)
+            plt.gca().set_aspect('equal', adjustable='box')
+            plt.tight_layout()
+            plt.savefig(f'results/figures/prediction_{data}_zoom.pdf')
     else:
         # multidimensional handling not included yet
         pass
-- 
GitLab