Skip to content
Snippets Groups Projects
Commit f0b8dfa8 authored by Jörg Martin's avatar Jörg Martin
Browse files

Added evaluation of coverage via linear model

No plotting included yet
parent bbae08f1
No related branches found
No related tags found
No related merge requests found
...@@ -12,6 +12,12 @@ intercept = 0.0 ...@@ -12,6 +12,12 @@ intercept = 0.0
x_noise_strength = 0.2 x_noise_strength = 0.2
y_noise_strength = 0.05 y_noise_strength = 0.05
func = lambda true_x: slope * true_x**3 + intercept func = lambda true_x: slope * true_x**3 + intercept
def design_matrix(x, device):
    """Return the cubic polynomial design matrix [1, x, x**2, x**3].

    :param x: A torch.tensor of shape (n, 1).
    :param device: The torch device to compute on and return the result on.
    :returns: A torch.tensor of shape (n, 4).
    """
    x = x.to(device)
    assert len(x.shape) == 2 and x.shape[1] == 1
    # ones_like inherits dtype and device from x, so torch.cat cannot
    # fail on a dtype mismatch when x is not float32.
    return torch.cat((torch.ones_like(x), x, x**2, x**3), dim=1)
def load_data(seed=0, splitting_part=0.8, normalize=True, def load_data(seed=0, splitting_part=0.8, normalize=True,
return_ground_truth=False, return_ground_truth=False,
...@@ -86,7 +92,6 @@ def load_data(seed=0, splitting_part=0.8, normalize=True, ...@@ -86,7 +92,6 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
true_cubic_testset = TensorDataset(true_test_x, true_test_y, true_cubic_testset = TensorDataset(true_test_x, true_test_y,
noisy_test_x, noisy_test_y) noisy_test_x, noisy_test_y)
# return different objects, depending on Booleans # return different objects, depending on Booleans
if not return_ground_truth: if not return_ground_truth:
if not return_normalized_func: if not return_normalized_func:
...@@ -100,4 +105,3 @@ def load_data(seed=0, splitting_part=0.8, normalize=True, ...@@ -100,4 +105,3 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
else: else:
return cubic_trainset, cubic_testset, true_cubic_trainset,\ return cubic_trainset, cubic_testset, true_cubic_trainset,\
true_cubic_testset, normalized_func true_cubic_testset, normalized_func
...@@ -12,6 +12,12 @@ intercept = 0.0 ...@@ -12,6 +12,12 @@ intercept = 0.0
x_noise_strength = 0.1 x_noise_strength = 0.1
y_noise_strength = 0.2 y_noise_strength = 0.2
func = lambda true_x: slope * true_x + intercept func = lambda true_x: slope * true_x + intercept
def design_matrix(x, device):
    """Return the linear design matrix [1, x].

    :param x: A torch.tensor of shape (n, 1).
    :param device: The torch device to compute on and return the result on.
    :returns: A torch.tensor of shape (n, 2).
    """
    x = x.to(device)
    assert len(x.shape) == 2 and x.shape[1] == 1
    # ones_like inherits dtype and device from x, so torch.cat cannot
    # fail on a dtype mismatch when x is not float32.
    return torch.cat((torch.ones_like(x), x), dim=1)
def load_data(seed=0, splitting_part=0.8, normalize=True, def load_data(seed=0, splitting_part=0.8, normalize=True,
return_ground_truth=False, return_ground_truth=False,
...@@ -85,8 +91,7 @@ def load_data(seed=0, splitting_part=0.8, normalize=True, ...@@ -85,8 +91,7 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
noisy_train_x, noisy_train_y) noisy_train_x, noisy_train_y)
true_linear_testset = TensorDataset(true_test_x, true_test_y, true_linear_testset = TensorDataset(true_test_x, true_test_y,
noisy_test_x, noisy_test_y) noisy_test_x, noisy_test_y)
# return different objects, depending on Booleans # return different objects, depending on Booleans
if not return_ground_truth: if not return_ground_truth:
if not return_normalized_func: if not return_normalized_func:
...@@ -100,4 +105,3 @@ def load_data(seed=0, splitting_part=0.8, normalize=True, ...@@ -100,4 +105,3 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
else: else:
return linear_trainset, linear_testset, true_linear_trainset,\ return linear_trainset, linear_testset, true_linear_trainset,\
true_linear_testset, normalized_func true_linear_testset, normalized_func
...@@ -12,6 +12,12 @@ intercept = 0.0 ...@@ -12,6 +12,12 @@ intercept = 0.0
x_noise_strength = 0.1 x_noise_strength = 0.1
y_noise_strength = 0.1 y_noise_strength = 0.1
func = lambda true_x: slope * true_x**2 + intercept func = lambda true_x: slope * true_x**2 + intercept
def design_matrix(x, device):
    """Return the quadratic polynomial design matrix [1, x, x**2].

    :param x: A torch.tensor of shape (n, 1).
    :param device: The torch device to compute on and return the result on.
    :returns: A torch.tensor of shape (n, 3).
    """
    x = x.to(device)
    assert len(x.shape) == 2 and x.shape[1] == 1
    # ones_like inherits dtype and device from x, so torch.cat cannot
    # fail on a dtype mismatch when x is not float32.
    return torch.cat((torch.ones_like(x), x, x**2), dim=1)
def load_data(seed=0, splitting_part=0.8, normalize=True, def load_data(seed=0, splitting_part=0.8, normalize=True,
return_ground_truth=False, return_ground_truth=False,
...@@ -86,7 +92,6 @@ def load_data(seed=0, splitting_part=0.8, normalize=True, ...@@ -86,7 +92,6 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
true_quadratic_testset = TensorDataset(true_test_x, true_test_y, true_quadratic_testset = TensorDataset(true_test_x, true_test_y,
noisy_test_x, noisy_test_y) noisy_test_x, noisy_test_y)
# return different objects, depending on Booleans # return different objects, depending on Booleans
if not return_ground_truth: if not return_ground_truth:
if not return_normalized_func: if not return_normalized_func:
...@@ -100,4 +105,3 @@ def load_data(seed=0, splitting_part=0.8, normalize=True, ...@@ -100,4 +105,3 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
else: else:
return quadratic_trainset, quadratic_testset, true_quadratic_trainset,\ return quadratic_trainset, quadratic_testset, true_quadratic_trainset,\
true_quadratic_testset, normalized_func true_quadratic_testset, normalized_func
...@@ -13,6 +13,14 @@ y_noise_strength = 0.01 ...@@ -13,6 +13,14 @@ y_noise_strength = 0.01
func = lambda true_x: true_x +\ func = lambda true_x: true_x +\
torch.sin(2 * torch.pi * true_x) +\ torch.sin(2 * torch.pi * true_x) +\
torch.sin(4 * torch.pi * true_x) torch.sin(4 * torch.pi * true_x)
def design_matrix(x, device):
    """Return the design matrix with columns x, sin(2*pi*x), sin(4*pi*x).

    Note that, matching the data-generating function, no intercept
    column is included.

    :param x: A torch.tensor of shape (n, 1).
    :param device: The torch device to compute on and return the result on.
    :returns: A torch.tensor of shape (n, 3).
    """
    x = x.to(device)
    assert len(x.shape) == 2 and x.shape[1] == 1
    basis = [x] + [torch.sin(k * torch.pi * x) for k in (2, 4)]
    return torch.cat(basis, dim=1)
def load_data(seed=0, splitting_part=0.8, normalize=True, def load_data(seed=0, splitting_part=0.8, normalize=True,
return_ground_truth=False, return_ground_truth=False,
...@@ -87,7 +95,6 @@ def load_data(seed=0, splitting_part=0.8, normalize=True, ...@@ -87,7 +95,6 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
true_sine_testset = TensorDataset(true_test_x, true_test_y, true_sine_testset = TensorDataset(true_test_x, true_test_y,
noisy_test_x, noisy_test_y) noisy_test_x, noisy_test_y)
# return different objects, depending on Booleans # return different objects, depending on Booleans
if not return_ground_truth: if not return_ground_truth:
if not return_normalized_func: if not return_normalized_func:
...@@ -101,4 +108,3 @@ def load_data(seed=0, splitting_part=0.8, normalize=True, ...@@ -101,4 +108,3 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
else: else:
return sine_trainset, sine_testset, true_sine_trainset,\ return sine_trainset, sine_testset, true_sine_trainset,\
true_sine_testset, normalized_func true_sine_testset, normalized_func
import torch
import numpy as np
import numpy.linalg as LA
from EIVGeneral.coverage_metrics import logical_and_along_dimension, multivariate_interval_length
def compute_par_est_var(x_train, y_train, sigma_y, design_matrix, device):
    """Ordinary least squares parameter estimate and its covariance.

    :param x_train: torch.tensor of shape (n, 1), training inputs.
    :param y_train: torch.tensor of shape (n, 1), training outputs.
    :param sigma_y: Standard deviation of the (homoscedastic) output noise.
    :param design_matrix: Callable mapping (x, device) to the (n, p)
        design matrix of x.
    :param device: torch device used for the computation.
    :returns: A tuple (par_est, par_var) containing the (p, 1) parameter
        estimate and its (p, p) covariance matrix.
    """
    X = design_matrix(x_train, device=device)
    assert len(X.shape) == 2 and len(y_train.shape) == 2
    # covariance of the OLS estimator is sigma_y**2 * (X^T X)^{-1}
    gram_inv = torch.linalg.inv(torch.matmul(X.T, X))
    par_est = torch.matmul(gram_inv, torch.matmul(X.T, y_train))
    par_var = sigma_y**2 * gram_inv
    return par_est, par_var
def linear_pred_unc(x_train, y_train, sigma_y, design_matrix, x_test, device):
    """Prediction and epistemic uncertainty of the linear model on `x_test`.

    :param x_train: torch.tensor of shape (n, 1), training inputs.
    :param y_train: torch.tensor of shape (n, 1), training outputs.
    :param sigma_y: Standard deviation of the output noise.
    :param design_matrix: Callable mapping (x, device) to the design matrix.
    :param x_test: torch.tensor of shape (m, 1), test inputs.
    :param device: torch device used for the computation.
    :returns: A tuple (pred, unc) of torch.tensors, both of shape (m, 1).
    """
    par_est, par_var = compute_par_est_var(x_train, y_train, sigma_y,
            design_matrix, device=device)
    X = design_matrix(x_test, device=device)
    pred = torch.matmul(X, par_est)
    # assume univariate, for simplicity
    if len(pred.shape) > 1:
        assert len(pred.shape) == 2
        assert pred.shape[1] == 1
    else:
        pred = pred.view((-1,1))
    # Only the diagonal of X @ par_var @ X^T is needed; computing it
    # row-wise avoids materializing the (m, m) covariance matrix.
    var_diag = torch.sum(torch.matmul(X, par_var) * X, dim=1)
    unc = torch.sqrt(var_diag).view((-1,1))
    assert pred.shape == unc.shape
    return pred, unc
def linear_coverage(pred, epis_unc, true_y, q=0.95):
    """Return the fraction of `true_y` covered by the q prediction intervals.

    :param pred: torch.tensor of shape (n, 1), model predictions.
    :param epis_unc: torch.tensor of shape (n, 1), epistemic uncertainties.
    :param true_y: torch.tensor of shape (n, 1), ground truth outputs.
    :param q: Nominal coverage level of the intervals, defaults to 0.95.
    :returns: The empirical coverage as a float.
    """
    out_dim = true_y.shape[1]
    half_width = multivariate_interval_length(dim=out_dim, q=q) * epis_unc
    # numerical computation
    deviations = pred - true_y
    # assume univariate for simplicity
    assert len(deviations.shape) == 2
    assert deviations.shape[1] == 1
    inside = logical_and_along_dimension(
            torch.abs(deviations) <= half_width, dim=1)
    return torch.mean(inside.to(torch.float32)).cpu().detach().item()
...@@ -18,6 +18,7 @@ from EIVTrainingRoutines import train_and_store ...@@ -18,6 +18,7 @@ from EIVTrainingRoutines import train_and_store
from EIVGeneral.coverage_metrics import epistemic_coverage, normalized_std,\ from EIVGeneral.coverage_metrics import epistemic_coverage, normalized_std,\
total_coverage total_coverage
from EIVData.repeated_sampling import repeated_sampling from EIVData.repeated_sampling import repeated_sampling
from EIVGeneral.linear_evaluation import linear_pred_unc, linear_coverage, compute_par_est_var
# read in data via --data option # read in data via --data option
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -46,6 +47,11 @@ print(f"Evaluating {long_dataname}") ...@@ -46,6 +47,11 @@ print(f"Evaluating {long_dataname}")
scale_outputs = False scale_outputs = False
load_data = importlib.import_module(f'EIVData.{long_dataname}').load_data load_data = importlib.import_module(f'EIVData.{long_dataname}').load_data
try:
sigma_y = importlib.import_module(f'EIVData.{long_dataname}').y_noise_strength
design_matrix = importlib.import_module(f'EIVData.{long_dataname}').design_matrix
except ImportError:
sigma_y = None
train_data, test_data = load_data(normalize=normalize) train_data, test_data = load_data(normalize=normalize)
input_dim = train_data[0][0].numel() input_dim = train_data[0][0].numel()
...@@ -70,7 +76,8 @@ except KeyError: ...@@ -70,7 +76,8 @@ except KeyError:
def collect_metrics(x_y_pairs, seed=0, def collect_metrics(x_y_pairs, seed=0,
noneiv_number_of_draws=100, eiv_number_of_draws=[100,5], noneiv_number_of_draws=100, eiv_number_of_draws=[100,5],
decouple_dimensions=False, device=device, decouple_dimensions=False, device=device,
scale_outputs=scale_outputs): scale_outputs=scale_outputs,
train_data = None):
""" """
Compute various metrics for EiV and non-EiV for single seeds. Will be Compute various metrics for EiV and non-EiV for single seeds. Will be
returned as dictionaries. returned as dictionaries.
...@@ -250,7 +257,20 @@ def collect_metrics(x_y_pairs, seed=0, ...@@ -250,7 +257,20 @@ def collect_metrics(x_y_pairs, seed=0,
net.noise_on() net.noise_on()
else: else:
net.noise_off() net.noise_off()
return noneiv_metrics, eiv_metrics metrics = {'eiv': eiv_metrics, 'noneiv': noneiv_metrics}
if train_data is not None and sigma_y is not None:
assert design_matrix is not None
# design_matrix not normalized in data, consistency check
assert not normalize
x_train, y_train = train_data.tensors
x_train, y_train = x_train.to(device), y_train.to(device)
lin_pred, lin_unc = linear_pred_unc(x_train, y_train, sigma_y, design_matrix, x, device=device)
assert true_y is not None
lin_coverage = linear_coverage(lin_pred, lin_unc, true_y)
metrics['lin'] = {'coverage': lin_coverage}
return metrics
...@@ -463,11 +483,13 @@ def collect_full_seed_range_metrics(load_data, ...@@ -463,11 +483,13 @@ def collect_full_seed_range_metrics(load_data,
bias_per_x = torch.mean(eiv_residual_collection, dim=-1) bias_per_x = torch.mean(eiv_residual_collection, dim=-1)
avg_bias = torch.mean(torch.abs(bias_per_x)) avg_bias = torch.mean(torch.abs(bias_per_x))
eiv_metrics['avg_bias'] = avg_bias eiv_metrics['avg_bias'] = avg_bias
return noneiv_metrics, eiv_metrics metrics = {'eiv': eiv_metrics, 'noneiv': noneiv_metrics}
return metrics
# single seed metrics # single seed metrics
noneiv_metrics_collection = {} noneiv_metrics_collection = {}
eiv_metrics_collection = {} eiv_metrics_collection = {}
lin_metrics_collection = {}
collection_keys = [] collection_keys = []
num_test_epochs = 10 num_test_epochs = 10
assert noneiv_conf_dict["seed_range"] == eiv_conf_dict["seed_range"] assert noneiv_conf_dict["seed_range"] == eiv_conf_dict["seed_range"]
...@@ -492,6 +514,9 @@ for seed in tqdm(seed_list): ...@@ -492,6 +514,9 @@ for seed in tqdm(seed_list):
else: else:
test_dataloader = DataLoader(true_test_data, test_dataloader = DataLoader(true_test_data,
batch_size=int(np.min((len(true_test_data), 800))), shuffle=True) batch_size=int(np.min((len(true_test_data), 800))), shuffle=True)
per_seed_eiv_metrics_collection = {}
per_seed_noneiv_metrics_collection= {}
per_seed_lin_metrics_collection = {}
for i in tqdm(range(num_test_epochs)): for i in tqdm(range(num_test_epochs)):
for j, x_y_pairs in enumerate(test_dataloader): for j, x_y_pairs in enumerate(test_dataloader):
if j > number_of_test_samples: if j > number_of_test_samples:
...@@ -501,25 +526,49 @@ for seed in tqdm(seed_list): ...@@ -501,25 +526,49 @@ for seed in tqdm(seed_list):
x_y_pairs = (None, None, *x_y_pairs) x_y_pairs = (None, None, *x_y_pairs)
# should contain (true_x,true_y,x,y) or (None,None,x,y) # should contain (true_x,true_y,x,y) or (None,None,x,y)
assert len(x_y_pairs) == 4 assert len(x_y_pairs) == 4
noneiv_metrics, eiv_metrics = collect_metrics(x_y_pairs, metrics = collect_metrics(x_y_pairs,
seed=seed) seed=seed, train_data=train_data)
eiv_metrics = metrics['eiv']
noneiv_metrics = metrics['noneiv']
try:
lin_metrics = metrics['lin']
except KeyError:
lin_metrics = {}
if i==0 and j==0: if i==0 and j==0:
# fill collection keys if seed==0:
assert eiv_metrics.keys() == noneiv_metrics.keys() # fill collection keys
collection_keys = list(eiv_metrics.keys()) assert eiv_metrics.keys() == noneiv_metrics.keys()
collection_keys = list(eiv_metrics.keys())
lin_keys = list(lin_metrics.keys())
for key in collection_keys:
noneiv_metrics_collection[key] = []
eiv_metrics_collection[key] = []
for key in lin_keys:
lin_metrics_collection[key] = []
for key in collection_keys: for key in collection_keys:
noneiv_metrics_collection[key] = [] per_seed_eiv_metrics_collection[key] = []
eiv_metrics_collection[key] = [] per_seed_noneiv_metrics_collection[key]= []
for key in lin_keys:
per_seed_lin_metrics_collection[key] = []
# collect results # collect results
for key in collection_keys: for key in collection_keys:
noneiv_metrics_collection[key].append(noneiv_metrics[key]) per_seed_noneiv_metrics_collection[key].append(noneiv_metrics[key])
eiv_metrics_collection[key].append(eiv_metrics[key]) per_seed_eiv_metrics_collection[key].append(eiv_metrics[key])
for key in lin_metrics_collection:
per_seed_lin_metrics_collection[key].append(lin_metrics[key])
for key in collection_keys:
noneiv_metrics_collection[key].append(np.mean(per_seed_noneiv_metrics_collection[key]))
eiv_metrics_collection[key].append(np.mean(per_seed_eiv_metrics_collection[key]))
for key in lin_keys:
lin_metrics_collection[key].append(np.mean(per_seed_lin_metrics_collection[key]))
# full seed range metrics # full seed range metrics
print('Computing metrics that use all seeds at once...') print('Computing metrics that use all seeds at once...')
noneiv_full_seed_range_metrics, eiv_full_seed_range_metrics =\ full_seed_range_metrics = collect_full_seed_range_metrics(load_data=load_data,
collect_full_seed_range_metrics(load_data=load_data,\
seed_range=seed_list) seed_range=seed_list)
eiv_full_seed_range_metrics = full_seed_range_metrics['eiv']
noneiv_full_seed_range_metrics = full_seed_range_metrics['noneiv']
# add keys to collection_keys # add keys to collection_keys
assert noneiv_full_seed_range_metrics.keys() ==\ assert noneiv_full_seed_range_metrics.keys() ==\
eiv_full_seed_range_metrics.keys() eiv_full_seed_range_metrics.keys()
...@@ -544,6 +593,7 @@ for key in collection_keys: ...@@ -544,6 +593,7 @@ for key in collection_keys:
results_dict['noneiv'][key] = metric results_dict['noneiv'][key] = metric
print(f'{key}: {metric:.5f} (NaN)') print(f'{key}: {metric:.5f} (NaN)')
print('\n') print('\n')
print('EiV:\n-----') print('EiV:\n-----')
results_dict['eiv'] = {} results_dict['eiv'] = {}
...@@ -552,7 +602,7 @@ for key in collection_keys: ...@@ -552,7 +602,7 @@ for key in collection_keys:
# per seed metrics # per seed metrics
metric_mean = float(np.mean(eiv_metrics_collection[key])) metric_mean = float(np.mean(eiv_metrics_collection[key]))
metric_std = float(np.std(eiv_metrics_collection[key])/\ metric_std = float(np.std(eiv_metrics_collection[key])/\
np.sqrt(num_test_epochs*len(seed_list))) np.sqrt(len(seed_list)))
print(f'{key}: {metric_mean:.5f} ({metric_std:.5f})') print(f'{key}: {metric_mean:.5f} ({metric_std:.5f})')
results_dict['eiv'][key] = (metric_mean, metric_std) results_dict['eiv'][key] = (metric_mean, metric_std)
else: else:
...@@ -561,6 +611,17 @@ for key in collection_keys: ...@@ -561,6 +611,17 @@ for key in collection_keys:
results_dict['eiv'][key] = metric results_dict['eiv'][key] = metric
print(f'{key}: {metric:.5f} (NaN)') print(f'{key}: {metric:.5f} (NaN)')
print('\n')
print('Lin:\n-----')
results_dict['lin'] = {}
for key in lin_keys:
# exist only as seed metrics
metric_mean = float(np.mean(lin_metrics_collection[key]))
metric_std = float(np.std(lin_metrics_collection[key])/\
np.sqrt(num_test_epochs*len(seed_list)))
results_dict['lin'][key] = (metric_mean, metric_std)
print(f'{key}: {metric_mean:.5f} ({metric_std:.5f})')
# write results to a JSON file in the results folder # write results to a JSON file in the results folder
with open(os.path.join('results',f'metrics_{short_dataname}.json'), 'w') as f: with open(os.path.join('results',f'metrics_{short_dataname}.json'), 'w') as f:
json.dump(results_dict, f) json.dump(results_dict, f)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment