diff --git a/Experiments/evaluate_metrics.py b/Experiments/evaluate_metrics.py index cfae6f4d25b5000899fe8453500c7dbf56c47d9f..4a59088edac1fa63379fc5e4ed751bf686d3fd73 100644 --- a/Experiments/evaluate_metrics.py +++ b/Experiments/evaluate_metrics.py @@ -267,7 +267,7 @@ def collect_metrics(x_y_pairs, seed=0, lin_pred, lin_unc = linear_pred_unc(x_train, y_train, sigma_y, design_matrix, x, device=device) assert true_y is not None lin_coverage = linear_coverage(lin_pred, lin_unc, true_y) - metrics['lin'] = {'coverage': lin_coverage} + metrics['lin'] = {'true_coverage_numerical': lin_coverage} return metrics diff --git a/Experiments/plot_summary.py b/Experiments/plot_summary.py index 90eacc83d2fa768a36ba8e3f7b3d83e8642ec5d6..9c1253479ac1ed25d935a75905de5c45948158d9 100644 --- a/Experiments/plot_summary.py +++ b/Experiments/plot_summary.py @@ -82,14 +82,14 @@ for i, ([(eiv_metric_mean, eiv_metric_std), width = 1.0, bottom = eiv_metric_mean - eiv_bar_size, color=colors[0], - alpha=1.0) + alpha=0.5) plt.plot(i+1, noneiv_metric_mean, '^', color=colors[1], markersize=16, zorder=0) plt.bar(i+1, height = 2 * k *noneiv_bar_size, width = 1.0, bottom = noneiv_metric_mean - k* noneiv_bar_size, color=colors[1], - alpha=1.0) + alpha=0.5) plt.ylim(bottom=0, top=ymax) ax = plt.gca() ax.set_xticks(np.arange(1,len(data_list)+1)) @@ -102,13 +102,14 @@ plt.savefig('results/figures/RMSE_bar_plot.pdf') metric = 'true_coverage_numerical' data_list = ['linear','quadratic','cubic','sine'] -colors = ['red', 'blue'] +colors = ['red', 'blue','black'] ymax = 1.0 minimal_bar_size = ymax * 1.5e-3 # read out EiV and non-EiV results for all datasets metric_results = [ (save_readout(results[data]['eiv'], metric), - save_readout(results[data]['noneiv'], metric)) + save_readout(results[data]['noneiv'], metric), + save_readout(results[data]['lin'], metric)) for data in data_list] # create figure @@ -118,7 +119,8 @@ plt.gcf().canvas.manager.set_window_title('coverage (ground truth)') # plot bars for i, ([(eiv_metric_mean, eiv_metric_std), - (noneiv_metric_mean, noneiv_metric_std)],\ + (noneiv_metric_mean, noneiv_metric_std), + (lin_metric_mean, lin_metric_std)],\ data) in\ enumerate(zip(metric_results, data_list)): if eiv_metric_mean is not None: @@ -127,6 +129,7 @@ for i, ([(eiv_metric_mean, eiv_metric_std), assert noneiv_metric_std is not None eiv_bar_size = max(eiv_metric_std, minimal_bar_size) noneiv_bar_size = max(noneiv_metric_std, minimal_bar_size) + lin_bar_size = max(lin_metric_std, minimal_bar_size) plt.plot(i+1, eiv_metric_mean, '^', color=colors[0], markersize=16) plt.bar(i+1, height = 2*eiv_bar_size, @@ -141,6 +144,13 @@ for i, ([(eiv_metric_mean, eiv_metric_std), bottom = noneiv_metric_mean - k* noneiv_bar_size, color=colors[1], alpha=0.5) + plt.plot(i+1, lin_metric_mean, '^', color=colors[2], markersize=16) + plt.bar(i+1, + height = 2 * k *lin_bar_size, + width = 0.3, + bottom = lin_metric_mean - k* lin_bar_size, + color=colors[2], + alpha=0.5) plt.axhline(0.95,0.0,1.0,color='k', linestyle='dashed') plt.ylim(bottom=0, top=ymax) ax = plt.gca()