Skip to content
Snippets Groups Projects
Commit a69a71fe authored by Jörg Martin's avatar Jörg Martin
Browse files

summary plot updated

parent 18b947ae
No related branches found
No related tags found
No related merge requests found
...@@ -267,7 +267,7 @@ def collect_metrics(x_y_pairs, seed=0, ...@@ -267,7 +267,7 @@ def collect_metrics(x_y_pairs, seed=0,
lin_pred, lin_unc = linear_pred_unc(x_train, y_train, sigma_y, design_matrix, x, device=device) lin_pred, lin_unc = linear_pred_unc(x_train, y_train, sigma_y, design_matrix, x, device=device)
assert true_y is not None assert true_y is not None
lin_coverage = linear_coverage(lin_pred, lin_unc, true_y) lin_coverage = linear_coverage(lin_pred, lin_unc, true_y)
metrics['lin'] = {'coverage': lin_coverage} metrics['lin'] = {'true_coverage_numerical': lin_coverage}
return metrics return metrics
......
...@@ -82,14 +82,14 @@ for i, ([(eiv_metric_mean, eiv_metric_std), ...@@ -82,14 +82,14 @@ for i, ([(eiv_metric_mean, eiv_metric_std),
width = 1.0, width = 1.0,
bottom = eiv_metric_mean - eiv_bar_size, bottom = eiv_metric_mean - eiv_bar_size,
color=colors[0], color=colors[0],
alpha=1.0) alpha=0.5)
plt.plot(i+1, noneiv_metric_mean, '^', color=colors[1], markersize=16, zorder=0) plt.plot(i+1, noneiv_metric_mean, '^', color=colors[1], markersize=16, zorder=0)
plt.bar(i+1, plt.bar(i+1,
height = 2 * k *noneiv_bar_size, height = 2 * k *noneiv_bar_size,
width = 1.0, width = 1.0,
bottom = noneiv_metric_mean - k* noneiv_bar_size, bottom = noneiv_metric_mean - k* noneiv_bar_size,
color=colors[1], color=colors[1],
alpha=1.0) alpha=0.5)
plt.ylim(bottom=0, top=ymax) plt.ylim(bottom=0, top=ymax)
ax = plt.gca() ax = plt.gca()
ax.set_xticks(np.arange(1,len(data_list)+1)) ax.set_xticks(np.arange(1,len(data_list)+1))
...@@ -102,13 +102,14 @@ plt.savefig('results/figures/RMSE_bar_plot.pdf') ...@@ -102,13 +102,14 @@ plt.savefig('results/figures/RMSE_bar_plot.pdf')
metric = 'true_coverage_numerical' metric = 'true_coverage_numerical'
data_list = ['linear','quadratic','cubic','sine'] data_list = ['linear','quadratic','cubic','sine']
colors = ['red', 'blue'] colors = ['red', 'blue','black']
ymax = 1.0 ymax = 1.0
minimal_bar_size = ymax * 1.5e-3 minimal_bar_size = ymax * 1.5e-3
# read out EiV and non-EiV results for all datasets # read out EiV and non-EiV results for all datasets
metric_results = [ metric_results = [
(save_readout(results[data]['eiv'], metric), (save_readout(results[data]['eiv'], metric),
save_readout(results[data]['noneiv'], metric)) save_readout(results[data]['noneiv'], metric),
save_readout(results[data]['lin'], metric))
for data in data_list] for data in data_list]
# create figure # create figure
...@@ -118,7 +119,8 @@ plt.gcf().canvas.manager.set_window_title('coverage (ground truth)') ...@@ -118,7 +119,8 @@ plt.gcf().canvas.manager.set_window_title('coverage (ground truth)')
# plot bars # plot bars
for i, ([(eiv_metric_mean, eiv_metric_std), for i, ([(eiv_metric_mean, eiv_metric_std),
(noneiv_metric_mean, noneiv_metric_std)],\ (noneiv_metric_mean, noneiv_metric_std),
(lin_metric_mean, lin_metric_std)],\
data) in\ data) in\
enumerate(zip(metric_results, data_list)): enumerate(zip(metric_results, data_list)):
if eiv_metric_mean is not None: if eiv_metric_mean is not None:
...@@ -127,6 +129,7 @@ for i, ([(eiv_metric_mean, eiv_metric_std), ...@@ -127,6 +129,7 @@ for i, ([(eiv_metric_mean, eiv_metric_std),
assert noneiv_metric_std is not None assert noneiv_metric_std is not None
eiv_bar_size = max(eiv_metric_std, minimal_bar_size) eiv_bar_size = max(eiv_metric_std, minimal_bar_size)
noneiv_bar_size = max(noneiv_metric_std, minimal_bar_size) noneiv_bar_size = max(noneiv_metric_std, minimal_bar_size)
lin_bar_size = max(lin_metric_std, minimal_bar_size)
plt.plot(i+1, eiv_metric_mean, '^', color=colors[0], markersize=16) plt.plot(i+1, eiv_metric_mean, '^', color=colors[0], markersize=16)
plt.bar(i+1, plt.bar(i+1,
height = 2*eiv_bar_size, height = 2*eiv_bar_size,
...@@ -141,6 +144,13 @@ for i, ([(eiv_metric_mean, eiv_metric_std), ...@@ -141,6 +144,13 @@ for i, ([(eiv_metric_mean, eiv_metric_std),
bottom = noneiv_metric_mean - k* noneiv_bar_size, bottom = noneiv_metric_mean - k* noneiv_bar_size,
color=colors[1], color=colors[1],
alpha=0.5) alpha=0.5)
plt.plot(i+1, lin_metric_mean, '^', color=colors[2], markersize=16)
plt.bar(i+1,
height = 2 * k *lin_bar_size,
width = 0.3,
bottom = lin_metric_mean - k* lin_bar_size,
color=colors[2],
alpha=0.5)
plt.axhline(0.95,0.0,1.0,color='k', linestyle='dashed') plt.axhline(0.95,0.0,1.0,color='k', linestyle='dashed')
plt.ylim(bottom=0, top=ymax) plt.ylim(bottom=0, top=ymax)
ax = plt.gca() ax = plt.gca()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment