From f0b8dfa85711e2fc3153361d147d6d70d66074b9 Mon Sep 17 00:00:00 2001
From: Joerg Martin <joerg.martin@ptb.de>
Date: Tue, 7 Jun 2022 13:21:48 +0000
Subject: [PATCH] Added evaluation of coverage via linear model

No plotting included yet
---
 EIVPackage/EIVData/cubic.py                |  8 +-
 EIVPackage/EIVData/linear.py               | 10 ++-
 EIVPackage/EIVData/quadratic.py            |  8 +-
 EIVPackage/EIVData/sine.py                 | 10 ++-
 EIVPackage/EIVGeneral/linear_evaluation.py | 52 +++++++++++++
 Experiments/evaluate_metrics.py            | 91 ++++++++++++++++++----
 6 files changed, 155 insertions(+), 24 deletions(-)
 create mode 100644 EIVPackage/EIVGeneral/linear_evaluation.py

diff --git a/EIVPackage/EIVData/cubic.py b/EIVPackage/EIVData/cubic.py
index 9d5e397..ca4d7c4 100644
--- a/EIVPackage/EIVData/cubic.py
+++ b/EIVPackage/EIVData/cubic.py
@@ -12,6 +12,12 @@ intercept = 0.0
 x_noise_strength = 0.2
 y_noise_strength = 0.05
 func = lambda true_x: slope * true_x**3 + intercept 
+def design_matrix(x, device):
+    x = x.to(device)
+    assert len(x.shape) == 2 and x.shape[1] == 1
+    return torch.cat((torch.ones(x.shape).to(device), x, x**2, x**3), dim=1)
+
+
 
 def load_data(seed=0, splitting_part=0.8, normalize=True,
         return_ground_truth=False,
@@ -86,7 +92,6 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
     true_cubic_testset = TensorDataset(true_test_x, true_test_y,
             noisy_test_x, noisy_test_y)
 
-
     # return different objects, depending on Booleans
     if not return_ground_truth:
         if not return_normalized_func:
@@ -100,4 +105,3 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
         else: 
             return cubic_trainset, cubic_testset, true_cubic_trainset,\
                 true_cubic_testset, normalized_func
-
diff --git a/EIVPackage/EIVData/linear.py b/EIVPackage/EIVData/linear.py
index 2a1b0d7..790102d 100644
--- a/EIVPackage/EIVData/linear.py
+++ b/EIVPackage/EIVData/linear.py
@@ -12,6 +12,12 @@ intercept = 0.0
 x_noise_strength = 0.1
 y_noise_strength = 0.2
 func = lambda true_x: slope * true_x + intercept 
+def design_matrix(x, device):
+    x = x.to(device)
+    assert len(x.shape) == 2 and x.shape[1] == 1
+    return torch.cat((torch.ones(x.shape).to(device), x), dim=1)
+
+
 
 def load_data(seed=0, splitting_part=0.8, normalize=True,
         return_ground_truth=False,
@@ -85,8 +91,7 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
             noisy_train_x, noisy_train_y)
     true_linear_testset = TensorDataset(true_test_x, true_test_y,
             noisy_test_x, noisy_test_y)
-
-
+    
     # return different objects, depending on Booleans
     if not return_ground_truth:
         if not return_normalized_func:
@@ -100,4 +105,3 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
         else: 
             return linear_trainset, linear_testset, true_linear_trainset,\
                 true_linear_testset, normalized_func
-
diff --git a/EIVPackage/EIVData/quadratic.py b/EIVPackage/EIVData/quadratic.py
index 655e6bf..304e43e 100644
--- a/EIVPackage/EIVData/quadratic.py
+++ b/EIVPackage/EIVData/quadratic.py
@@ -12,6 +12,12 @@ intercept = 0.0
 x_noise_strength = 0.1
 y_noise_strength = 0.1
 func = lambda true_x: slope * true_x**2 + intercept 
+def design_matrix(x, device):
+    x = x.to(device)
+    assert len(x.shape) == 2 and x.shape[1] == 1
+    return torch.cat((torch.ones(x.shape).to(device), x, x**2), dim=1)
+
+
 
 def load_data(seed=0, splitting_part=0.8, normalize=True,
         return_ground_truth=False,
@@ -86,7 +92,6 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
     true_quadratic_testset = TensorDataset(true_test_x, true_test_y,
             noisy_test_x, noisy_test_y)
 
-
     # return different objects, depending on Booleans
     if not return_ground_truth:
         if not return_normalized_func:
@@ -100,4 +105,3 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
         else: 
             return quadratic_trainset, quadratic_testset, true_quadratic_trainset,\
                 true_quadratic_testset, normalized_func
-
diff --git a/EIVPackage/EIVData/sine.py b/EIVPackage/EIVData/sine.py
index 2e9eabc..73a35a0 100644
--- a/EIVPackage/EIVData/sine.py
+++ b/EIVPackage/EIVData/sine.py
@@ -13,6 +13,14 @@ y_noise_strength = 0.01
 func = lambda true_x: true_x +\
             torch.sin(2 * torch.pi * true_x) +\
             torch.sin(4 * torch.pi * true_x)
+def design_matrix(x, device):
+    x = x.to(device)
+    assert len(x.shape) == 2 and x.shape[1] == 1
+    return torch.cat((x, 
+        torch.sin(2 * torch.pi * x), torch.sin(4 * torch.pi * x)), dim=1)
+
+
+
 
 def load_data(seed=0, splitting_part=0.8, normalize=True,
         return_ground_truth=False,
@@ -87,7 +95,6 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
     true_sine_testset = TensorDataset(true_test_x, true_test_y,
             noisy_test_x, noisy_test_y)
 
-
     # return different objects, depending on Booleans
     if not return_ground_truth:
         if not return_normalized_func:
@@ -101,4 +108,3 @@ def load_data(seed=0, splitting_part=0.8, normalize=True,
         else: 
             return sine_trainset, sine_testset, true_sine_trainset,\
                 true_sine_testset, normalized_func
-
diff --git a/EIVPackage/EIVGeneral/linear_evaluation.py b/EIVPackage/EIVGeneral/linear_evaluation.py
new file mode 100644
index 0000000..974b9fd
--- /dev/null
+++ b/EIVPackage/EIVGeneral/linear_evaluation.py
@@ -0,0 +1,52 @@
+import torch
+import numpy as np
+import numpy.linalg as LA
+from EIVGeneral.coverage_metrics import logical_and_along_dimension, multivariate_interval_length
+
+
+def compute_par_est_var(x_train, y_train, sigma_y, design_matrix, device):
+    X = design_matrix(x_train, device=device)
+    y = y_train
+    assert len(X.shape) == 2 and len(y_train.shape) == 2
+    X_T = torch.transpose(X, 0,1)
+    inv_X_T_X = torch.linalg.inv(torch.matmul(X_T, X))
+    par_est = torch.matmul(inv_X_T_X,
+            torch.matmul( X_T, y))
+    par_var = inv_X_T_X * sigma_y**2
+    return par_est, par_var
+
+def linear_pred_unc(x_train, y_train, sigma_y, design_matrix, x_test, device):
+    par_est, par_var = compute_par_est_var(x_train, y_train, sigma_y, design_matrix, device=device)
+    X = design_matrix(x_test, device=device)
+    X_T = torch.transpose(X, 0,1)
+    pred = torch.matmul(X, par_est)
+    # assume univariate, for simplicity
+    if len(pred.shape) > 1:
+        assert len(pred.shape) == 2
+        assert pred.shape[1] == 1
+    else:
+        pred = pred.view((-1,1))
+    cov = torch.matmul(X, torch.matmul(par_var, X_T))
+    unc = torch.sqrt(torch.diag(cov)).view((-1,1))
+    assert pred.shape == unc.shape
+    return pred, unc
+
+
+def linear_coverage(pred, epis_unc, true_y, q=0.95):
+    out_dim = true_y.shape[1]
+    interval_length = multivariate_interval_length(dim=out_dim, q=q) \
+                * epis_unc
+    # numerical computation
+    errors = pred - true_y
+    # assume univariate for simplicity
+    assert len(errors.shape) == 2
+    assert errors.shape[1] == 1
+    check_if_in_interval = logical_and_along_dimension(
+            torch.abs(errors) <= interval_length, dim=1)
+    numerical_coverage = torch.mean(check_if_in_interval.to(torch.float32)
+            ).cpu().detach().item()
+    return numerical_coverage
+    
+
+
+
diff --git a/Experiments/evaluate_metrics.py b/Experiments/evaluate_metrics.py
index 06f1b37..60172e4 100644
--- a/Experiments/evaluate_metrics.py
+++ b/Experiments/evaluate_metrics.py
@@ -18,6 +18,7 @@ from EIVTrainingRoutines import train_and_store
 from EIVGeneral.coverage_metrics import epistemic_coverage, normalized_std,\
         total_coverage
 from EIVData.repeated_sampling import repeated_sampling
+from EIVGeneral.linear_evaluation import linear_pred_unc, linear_coverage, compute_par_est_var
 
 # read in data via --data option
 parser = argparse.ArgumentParser()
@@ -46,6 +47,11 @@ print(f"Evaluating {long_dataname}")
 
 scale_outputs = False 
 load_data = importlib.import_module(f'EIVData.{long_dataname}').load_data
+try:
+    sigma_y = importlib.import_module(f'EIVData.{long_dataname}').y_noise_strength
+    design_matrix = importlib.import_module(f'EIVData.{long_dataname}').design_matrix
+except AttributeError:
+    sigma_y = design_matrix = None
 
 train_data, test_data = load_data(normalize=normalize)
 input_dim = train_data[0][0].numel()
@@ -70,7 +76,8 @@ except KeyError:
 def collect_metrics(x_y_pairs, seed=0,
     noneiv_number_of_draws=100, eiv_number_of_draws=[100,5],
     decouple_dimensions=False, device=device,
-    scale_outputs=scale_outputs):
+    scale_outputs=scale_outputs,
+    train_data = None):
     """
     Compute various metrics for EiV and non-EiV for single seeds. Will be
     returned as dictionaries.
@@ -250,7 +257,20 @@ def collect_metrics(x_y_pairs, seed=0,
         net.noise_on()
     else:
         net.noise_off()
-    return noneiv_metrics, eiv_metrics
+    metrics = {'eiv': eiv_metrics, 'noneiv': noneiv_metrics}
+    if train_data is not None and sigma_y is not None:
+        assert design_matrix is not None
+        # design_matrix not normalized in data, consistency check
+        assert not normalize
+        x_train, y_train = train_data.tensors
+        x_train, y_train = x_train.to(device), y_train.to(device)
+        lin_pred, lin_unc = linear_pred_unc(x_train, y_train, sigma_y, design_matrix, x, device=device)
+        assert true_y is not None
+        lin_coverage = linear_coverage(lin_pred, lin_unc, true_y)
+        metrics['lin'] = {'coverage': lin_coverage}
+        
+    return metrics
+
 
 
 
@@ -463,11 +483,13 @@ def collect_full_seed_range_metrics(load_data,
         bias_per_x = torch.mean(eiv_residual_collection, dim=-1)
         avg_bias = torch.mean(torch.abs(bias_per_x))
         eiv_metrics['avg_bias'] = avg_bias
-    return noneiv_metrics, eiv_metrics
+    metrics = {'eiv': eiv_metrics, 'noneiv': noneiv_metrics}
+    return metrics
 
 # single seed metrics
 noneiv_metrics_collection = {}
 eiv_metrics_collection = {}
+lin_metrics_collection = {}
 collection_keys = []
 num_test_epochs = 10
 assert noneiv_conf_dict["seed_range"] == eiv_conf_dict["seed_range"]
@@ -492,6 +514,9 @@ for seed in tqdm(seed_list):
     else:
         test_dataloader = DataLoader(true_test_data,
                 batch_size=int(np.min((len(true_test_data), 800))), shuffle=True)
+    per_seed_eiv_metrics_collection = {}
+    per_seed_noneiv_metrics_collection= {}
+    per_seed_lin_metrics_collection = {}
     for i in tqdm(range(num_test_epochs)):
         for j, x_y_pairs in enumerate(test_dataloader):
             if j > number_of_test_samples:
@@ -501,25 +526,49 @@ for seed in tqdm(seed_list):
                 x_y_pairs = (None, None, *x_y_pairs)
             # should contain (true_x,true_y,x,y) or (None,None,x,y)
             assert len(x_y_pairs) == 4
-            noneiv_metrics, eiv_metrics = collect_metrics(x_y_pairs,
-                    seed=seed)
+            metrics = collect_metrics(x_y_pairs,
+                    seed=seed, train_data=train_data)
+            eiv_metrics = metrics['eiv']
+            noneiv_metrics = metrics['noneiv']
+            try:
+                lin_metrics = metrics['lin']
+            except KeyError:
+                lin_metrics = {}
             if i==0 and j==0:
-                # fill collection keys
-                assert eiv_metrics.keys() == noneiv_metrics.keys()
-                collection_keys = list(eiv_metrics.keys())
+                if seed==0:
+                    # fill collection keys
+                    assert eiv_metrics.keys() == noneiv_metrics.keys()
+                    collection_keys = list(eiv_metrics.keys())
+                    lin_keys = list(lin_metrics.keys())
+                    for key in collection_keys:
+                        noneiv_metrics_collection[key] = []
+                        eiv_metrics_collection[key] = []
+                    for key in lin_keys:
+                        lin_metrics_collection[key] = []
                 for key in collection_keys:
-                    noneiv_metrics_collection[key] = []
-                    eiv_metrics_collection[key] = []
+                    per_seed_eiv_metrics_collection[key] = []
+                    per_seed_noneiv_metrics_collection[key]= []
+                for key in lin_keys:
+                    per_seed_lin_metrics_collection[key] = []
             # collect results
             for key in collection_keys:
-                noneiv_metrics_collection[key].append(noneiv_metrics[key])
-                eiv_metrics_collection[key].append(eiv_metrics[key])
+                per_seed_noneiv_metrics_collection[key].append(noneiv_metrics[key])
+                per_seed_eiv_metrics_collection[key].append(eiv_metrics[key])
+            for key in lin_metrics_collection:
+                per_seed_lin_metrics_collection[key].append(lin_metrics[key])
+    for key in collection_keys:
+        noneiv_metrics_collection[key].append(np.mean(per_seed_noneiv_metrics_collection[key]))
+        eiv_metrics_collection[key].append(np.mean(per_seed_eiv_metrics_collection[key]))
+    for key in lin_keys:
+        lin_metrics_collection[key].append(np.mean(per_seed_lin_metrics_collection[key]))
+
 
 # full seed range metrics
 print('Computing metrics that use all seeds at once...')
-noneiv_full_seed_range_metrics, eiv_full_seed_range_metrics =\
-        collect_full_seed_range_metrics(load_data=load_data,\
+full_seed_range_metrics = collect_full_seed_range_metrics(load_data=load_data,
                 seed_range=seed_list)
+eiv_full_seed_range_metrics = full_seed_range_metrics['eiv']
+noneiv_full_seed_range_metrics = full_seed_range_metrics['noneiv']
 # add keys to collection_keys
 assert noneiv_full_seed_range_metrics.keys() ==\
         eiv_full_seed_range_metrics.keys()
@@ -544,6 +593,7 @@ for key in collection_keys:
         results_dict['noneiv'][key] = metric
         print(f'{key}: {metric:.5f} (NaN)')
 
+
 print('\n')
 print('EiV:\n-----')
 results_dict['eiv'] = {}
@@ -552,7 +602,7 @@ for key in collection_keys:
         # per seed metrics
         metric_mean = float(np.mean(eiv_metrics_collection[key]))
         metric_std  = float(np.std(eiv_metrics_collection[key])/\
-                np.sqrt(num_test_epochs*len(seed_list)))
+                np.sqrt(len(seed_list)))
         print(f'{key}: {metric_mean:.5f} ({metric_std:.5f})')
         results_dict['eiv'][key] = (metric_mean, metric_std)
     else:
@@ -561,6 +611,17 @@ for key in collection_keys:
         results_dict['eiv'][key] = metric
         print(f'{key}: {metric:.5f} (NaN)')
 
+print('\n')
+print('Lin:\n-----')
+results_dict['lin'] = {}
+for key in lin_keys:
+    # exist only as seed metrics
+    metric_mean = float(np.mean(lin_metrics_collection[key]))
+    metric_std  = float(np.std(lin_metrics_collection[key])/\
+            np.sqrt(len(seed_list)))
+    results_dict['lin'][key] = (metric_mean, metric_std)
+    print(f'{key}: {metric_mean:.5f} ({metric_std:.5f})')
+
 # write results to a JSON file in the results folder
 with open(os.path.join('results',f'metrics_{short_dataname}.json'), 'w') as f:
     json.dump(results_dict, f)
-- 
GitLab