Skip to content
Snippets Groups Projects
Commit 8b1adc0f authored by Jörg Martin's avatar Jörg Martin
Browse files

simulated datasets revised to show EiV effect

parent c2a8a3c4
No related branches found
No related tags found
No related merge requests found
...@@ -9,8 +9,8 @@ total_number_of_datapoints = 1000 ...@@ -9,8 +9,8 @@ total_number_of_datapoints = 1000
input_range = [-1,1] input_range = [-1,1]
slope = 1.0 slope = 1.0
intercept = 0.0 intercept = 0.0
x_noise_strength = 0.1 x_noise_strength = 0.2
y_noise_strength = 0.1 y_noise_strength = 0.05
func = lambda true_x: slope * true_x**3 + intercept func = lambda true_x: slope * true_x**3 + intercept
def load_data(seed=0, splitting_part=0.8, normalize=True, def load_data(seed=0, splitting_part=0.8, normalize=True,
......
...@@ -10,7 +10,7 @@ input_range = [-1,1] ...@@ -10,7 +10,7 @@ input_range = [-1,1]
slope = 1.0 slope = 1.0
intercept = 0.0 intercept = 0.0
x_noise_strength = 0.1 x_noise_strength = 0.1
y_noise_strength = 0.1 y_noise_strength = 0.2
func = lambda true_x: slope * true_x + intercept func = lambda true_x: slope * true_x + intercept
def load_data(seed=0, splitting_part=0.8, normalize=True, def load_data(seed=0, splitting_part=0.8, normalize=True,
......
...@@ -8,8 +8,8 @@ from EIVGeneral.manipulate_tensors import add_noise, normalize_tensor,\ ...@@ -8,8 +8,8 @@ from EIVGeneral.manipulate_tensors import add_noise, normalize_tensor,\
total_number_of_datapoints = 2000 total_number_of_datapoints = 2000
input_range = [-0.2,0.8] input_range = [-0.2,0.8]
intercept = 0.0 intercept = 0.0
x_noise_strength = 0.02 x_noise_strength = 0.04
y_noise_strength = 0.02 y_noise_strength = 0.01
func = lambda true_x: true_x +\ func = lambda true_x: true_x +\
torch.sin(2 * torch.pi * true_x) +\ torch.sin(2 * torch.pi * true_x) +\
torch.sin(4 * torch.pi * true_x) torch.sin(4 * torch.pi * true_x)
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
"std_y_update_points": [1,40], "std_y_update_points": [1,40],
"eiv_prediction_number_of_draws": [100,5], "eiv_prediction_number_of_draws": [100,5],
"eiv_prediction_number_of_batches": 10, "eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5], "init_std_y_list": [0.05],
"gamma": 0.5, "gamma": 0.5,
"hidden_layers": [128, 128, 128, 128], "hidden_layers": [128, 128, 128, 128],
"fixed_std_x": 0.10, "fixed_std_x": 0.20,
"seed_range": [0,10], "seed_range": [0,10],
"gpu_number": 1 "gpu_number": 1
} }
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
"std_y_update_points": [1,40], "std_y_update_points": [1,40],
"eiv_prediction_number_of_draws": [100,5], "eiv_prediction_number_of_draws": [100,5],
"eiv_prediction_number_of_batches": 10, "eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5], "init_std_y_list": [0.1],
"gamma": 0.5, "gamma": 0.5,
"hidden_layers": [128, 128, 128, 128], "hidden_layers": [128, 128, 128, 128],
"fixed_std_x": 0.10, "fixed_std_x": 0.10,
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
"std_y_update_points": [1,40], "std_y_update_points": [1,40],
"eiv_prediction_number_of_draws": [100,5], "eiv_prediction_number_of_draws": [100,5],
"eiv_prediction_number_of_batches": 10, "eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5], "init_std_y_list": [0.1],
"gamma": 0.5, "gamma": 0.5,
"hidden_layers": [128, 128, 128, 128], "hidden_layers": [128, 128, 128, 128],
"fixed_std_x": 0.10, "fixed_std_x": 0.10,
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
"eiv_number_of_forward_draws": 10, "eiv_number_of_forward_draws": 10,
"eiv_prediction_number_of_draws": [100,5], "eiv_prediction_number_of_draws": [100,5],
"eiv_prediction_number_of_batches": 10, "eiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.1], "init_std_y_list": [0.01],
"gamma": 0.5, "gamma": 0.5,
"hidden_layers": [128, 128, 128, 128], "hidden_layers": [128, 128, 128, 128],
"fixed_std_x": 0.02, "fixed_std_x": 0.04,
"seed_range": [0,10], "seed_range": [0,10],
"gpu_number": 1 "gpu_number": 1
} }
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
"std_y_update_points": [1,40] , "std_y_update_points": [1,40] ,
"noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_draws": 100,
"noneiv_prediction_number_of_batches": 10, "noneiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5], "init_std_y_list": [0.05],
"gamma": 0.5, "gamma": 0.5,
"hidden_layers": [128, 128, 128, 128], "hidden_layers": [128, 128, 128, 128],
"seed_range": [0,10], "seed_range": [0,10],
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
"std_y_update_points": [1,40] , "std_y_update_points": [1,40] ,
"noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_draws": 100,
"noneiv_prediction_number_of_batches": 10, "noneiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5], "init_std_y_list": [0.1],
"gamma": 0.5, "gamma": 0.5,
"hidden_layers": [128, 128, 128, 128], "hidden_layers": [128, 128, 128, 128],
"seed_range": [0,10], "seed_range": [0,10],
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
"std_y_update_points": [1,40] , "std_y_update_points": [1,40] ,
"noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_draws": 100,
"noneiv_prediction_number_of_batches": 10, "noneiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.5], "init_std_y_list": [0.1],
"gamma": 0.5, "gamma": 0.5,
"hidden_layers": [128, 128, 128, 128], "hidden_layers": [128, 128, 128, 128],
"seed_range": [0,10], "seed_range": [0,10],
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
"std_y_update_points": [1,40] , "std_y_update_points": [1,40] ,
"noneiv_prediction_number_of_draws": 100, "noneiv_prediction_number_of_draws": 100,
"noneiv_prediction_number_of_batches": 10, "noneiv_prediction_number_of_batches": 10,
"init_std_y_list": [0.1], "init_std_y_list": [0.01],
"gamma": 0.5, "gamma": 0.5,
"hidden_layers": [128, 128, 128, 128], "hidden_layers": [128, 128, 128, 128],
"seed_range": [0,10], "seed_range": [0,10],
......
...@@ -13,6 +13,7 @@ import torch.backends.cudnn ...@@ -13,6 +13,7 @@ import torch.backends.cudnn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from matplotlib.pyplot import cm from matplotlib.pyplot import cm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from tqdm import tqdm
from EIVArchitectures import Networks from EIVArchitectures import Networks
from EIVTrainingRoutines import train_and_store from EIVTrainingRoutines import train_and_store
...@@ -94,7 +95,7 @@ plt.ylabel('coverage') ...@@ -94,7 +95,7 @@ plt.ylabel('coverage')
# datasets to plot and their coloring # datasets to plot and their coloring
datasets = ['linear', 'quadratic','cubic','sine'] datasets = ['linear', 'quadratic','cubic','sine']
colors = ['cyan', 'magenta', 'yellow', 'green'] colors = ['#084519', '#7D098D', '#77050C', '#09017F']
def compute_coverages(data, eiv, number_of_draws): def compute_coverages(data, eiv, number_of_draws):
""" """
...@@ -148,7 +149,6 @@ def compute_coverages(data, eiv, number_of_draws): ...@@ -148,7 +149,6 @@ def compute_coverages(data, eiv, number_of_draws):
= load_data(seed=0, return_ground_truth=True, = load_data(seed=0, return_ground_truth=True,
normalize=normalize) normalize=normalize)
print(f"Computing {'EiV' if eiv else 'non-EiV'} coverage for {long_dataname}")
# train_data only used for finding dimensions # train_data only used for finding dimensions
input_dim = train_data[0][0].numel() input_dim = train_data[0][0].numel()
...@@ -233,14 +233,13 @@ def compute_coverages(data, eiv, number_of_draws): ...@@ -233,14 +233,13 @@ def compute_coverages(data, eiv, number_of_draws):
return numerical_coverage, theoretical_coverage return numerical_coverage, theoretical_coverage
# loop through data # loop through data
for data, color in zip(datasets, colors): for data, color in tqdm(zip(datasets, colors)):
# compute coverages # compute coverages
eiv_coverages = compute_coverages(data=data, eiv=True, eiv_coverages = compute_coverages(data=data, eiv=True,
number_of_draws=[100,5]) number_of_draws=[100,5])
noneiv_coverages = compute_coverages(data=data, eiv=False, noneiv_coverages = compute_coverages(data=data, eiv=False,
number_of_draws=100) number_of_draws=100)
# create plots # create plots
plt.figure(1)
coverage_diagonal_plot(eiv_coverages, noneiv_coverages, coverage_diagonal_plot(eiv_coverages, noneiv_coverages,
color=color, against_theoretical=False, label=data) color=color, against_theoretical=False, label=data)
......
...@@ -80,7 +80,7 @@ for i, ([(eiv_metric_mean, eiv_metric_std), ...@@ -80,7 +80,7 @@ for i, ([(eiv_metric_mean, eiv_metric_std),
bottom = noneiv_metric_mean - k* noneiv_metric_std, bottom = noneiv_metric_mean - k* noneiv_metric_std,
color=colors[1], color=colors[1],
alpha=0.5) alpha=0.5)
plt.ylim(bottom=0, top=y_max) plt.ylim(bottom=0, top=ymax)
ax = plt.gca() ax = plt.gca()
ax.set_xticks(np.arange(1,len(data_list)+1)) ax.set_xticks(np.arange(1,len(data_list)+1))
ax.set_xticklabels(data_list, rotation='vertical') ax.set_xticklabels(data_list, rotation='vertical')
...@@ -127,7 +127,7 @@ for i, ([(eiv_metric_mean, eiv_metric_std), ...@@ -127,7 +127,7 @@ for i, ([(eiv_metric_mean, eiv_metric_std),
color=colors[1], color=colors[1],
alpha=0.5) alpha=0.5)
plt.axhline(0.95,0.0,1.0,color='k', linestyle='dashed') plt.axhline(0.95,0.0,1.0,color='k', linestyle='dashed')
plt.ylim(bottom=0, top=y_max) plt.ylim(bottom=0, top=ymax)
ax = plt.gca() ax = plt.gca()
ax.set_xticks(np.arange(1,len(data_list)+1)) ax.set_xticks(np.arange(1,len(data_list)+1))
ax.set_xticklabels(data_list, rotation='vertical') ax.set_xticklabels(data_list, rotation='vertical')
...@@ -173,7 +173,7 @@ for i, ([(eiv_metric_mean, eiv_metric_std), ...@@ -173,7 +173,7 @@ for i, ([(eiv_metric_mean, eiv_metric_std),
bottom = noneiv_metric_mean - k* noneiv_metric_std, bottom = noneiv_metric_mean - k* noneiv_metric_std,
color=colors[1], color=colors[1],
alpha=0.5) alpha=0.5)
plt.ylim(bottom=0, top=y_max) plt.ylim(bottom=0, top=ymax)
ax = plt.gca() ax = plt.gca()
ax.set_xticks(np.arange(1,len(data_list)+1)) ax.set_xticks(np.arange(1,len(data_list)+1))
ax.set_xticklabels(data_list, rotation='vertical') ax.set_xticklabels(data_list, rotation='vertical')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment