From 30a5cf99c27ad652c60552e0d8acb2ac251e696c Mon Sep 17 00:00:00 2001
From: Bjoern Ludwig <bjoern.ludwig@ptb.de>
Date: Thu, 29 Dec 2022 15:38:26 -0500
Subject: [PATCH] feat(dataset): introduce size_scaler parameter to retrieve
 several data points from each cycle at once

---
 src/zema_emc_annotated/dataset.py | 99 +++++++++++++++----------------
 1 file changed, 49 insertions(+), 50 deletions(-)
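
A quick usage sketch for reviewers (illustrative only, not part of the patch).
It assumes the package is importable as below, that the first call downloads
the dataset from Zenodo via pooch, and that UncertainArray exposes its two
constructor arguments as values and uncertainties; the expected shapes follow
from the docstring added in this patch. A note on the reduce-based HDF5
traversal follows after the diff.

    from zema_emc_annotated.dataset import provide_zema_samples

    # Ten samples, each holding two consecutive readings per sensor, centered
    # around zero and scaled to unit standard deviation via the new flag.
    samples = provide_zema_samples(n_samples=10, size_scaler=2, normalize=True)

    # Eleven sensors times size_scaler readings per sample.
    assert samples.values.shape == (10, 11 * 2)
    assert samples.uncertainties.shape == (10, 11 * 2)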

diff --git a/src/zema_emc_annotated/dataset.py b/src/zema_emc_annotated/dataset.py
index 83572a0..ce1b935 100644
--- a/src/zema_emc_annotated/dataset.py
+++ b/src/zema_emc_annotated/dataset.py
@@ -6,20 +6,21 @@ __all__ = [
     "LOCAL_ZEMA_DATASET_PATH",
     "ZEMA_DATASET_HASH",
     "ZEMA_DATASET_URL",
-    "ZEMA_DATATYPES",
     "ZEMA_QUANTITIES",
 ]
 
+import operator
 import os
 import pickle
 from enum import Enum
+from functools import reduce
 from os.path import dirname, exists
 from pathlib import Path
 from typing import cast
 
 import h5py
 import numpy as np
-from h5py import Dataset, File, Group
+from h5py import Dataset
 from numpy._typing import NDArray
 from pooch import retrieve
 
@@ -30,7 +31,6 @@ ZEMA_DATASET_HASH = (
     "sha256:fb0e80de4e8928ae8b859ad9668a1b6ea6310028a6690bb8d4c1abee31cb8833"
 )
 ZEMA_DATASET_URL = "https://zenodo.org/record/5185953/files/axis11_2kHz_ZeMA_PTB_SI.h5"
-ZEMA_DATATYPES = ("qudt:standardUncertainty", "qudt:value")
 ZEMA_QUANTITIES = (
     "Acceleration",
     "Active_Current",
@@ -57,7 +57,9 @@ class ExtractionDataType(Enum):
     VALUES = "qudt:value"
 
 
-def provide_zema_samples(n_samples: int = 1) -> UncertainArray:
+def provide_zema_samples(
+    n_samples: int = 1, size_scaler: int = 1, normalize: bool = False
+) -> UncertainArray:
     """Extracts requested number of samples of values with associated uncertainties
 
     The underlying dataset is the annotated "Sensor data set of one electromechanical
@@ -65,26 +67,29 @@ def provide_zema_samples(n_samples: int = 1) -> UncertainArray:
 
     Parameters
     ----------
-    n_samples : int
-        number of samples each containing one reading from each of the eleven sensors
-        with associated uncertainties
-
+    n_samples : int, optional
+        number of samples each containing ``size_scaler`` readings from each of
+        the eleven sensors with associated uncertainties, defaults to 1
+    size_scaler : int, optional
+        number of sensor readings from each of the individual sensors per sample,
+        defaults to 1
+    normalize : bool, optional
+        if ``True``, data is centered around zero and scaled to unit standard
+        deviation, defaults to ``False``
+
     Returns
     -------
     UncertainArray
-        The collection of samples of values with associated uncertainties
+        The collection of samples of values with associated uncertainties, of
+        shape ``(n_samples, 11 * size_scaler)``
     """
 
-    def _hdf5_part(hdf5_file: File, keys: list[str]) -> Group | Dataset:
-        part = hdf5_file
-        for key in keys:
-            part = part[key]
-        return part
-
-    def _extract_sample_from_dataset(
-        data_set: Dataset, ns_samples: tuple[slice, int]
-    ) -> NDArray[np.double]:
-        return np.expand_dims(np.array(data_set[ns_samples]), 1)
+    def _normalize_if_requested(data: Dataset) -> NDArray[np.double]:
+        _potentially_normalized_data = data[np.s_[1 : size_scaler + 1, :n_samples]]
+        if normalize:
+            _potentially_normalized_data -= np.mean(data[:, :n_samples], axis=0)
+            _potentially_normalized_data /= np.std(data[:, :n_samples], axis=0)
+        return _potentially_normalized_data.transpose()
 
     def _append_to_extraction(
         append_to: NDArray[np.double], appendix: NDArray[np.double]
@@ -102,46 +107,40 @@ def provide_zema_samples(n_samples: int = 1) -> UncertainArray:
     assert exists(dataset_full_path)
     uncertainties = np.empty((n_samples, 0))
     values = np.empty((n_samples, 0))
-    indices = np.s_[0:n_samples, 0]
     relevant_datasets = (
-        ["ZeMA_DAQ", quantity, datatype]
+        ["ZeMA_DAQ", quantity, datatype.value]
         for quantity in ZEMA_QUANTITIES
-        for datatype in ZEMA_DATATYPES
+        for datatype in ExtractionDataType
     )
     with h5py.File(dataset_full_path, "r") as h5f:
-        for dataset in relevant_datasets:
-            if ExtractionDataType.UNCERTAINTIES.value in dataset:
+        for dataset_descriptor in relevant_datasets:
+            dataset = cast(Dataset, reduce(operator.getitem, dataset_descriptor, h5f))
+            if ExtractionDataType.UNCERTAINTIES.value in dataset.name:
                 extracted_data = uncertainties
-                print(f"    Extract uncertainties from {dataset}")
-            elif ExtractionDataType.VALUES.value in dataset:
+                print(f"    Extract uncertainties from {dataset.name}")
+            elif ExtractionDataType.VALUES.value in dataset.name:
                 extracted_data = values
-                print(f"    Extract values from {dataset}")
+                print(f"    Extract values from {dataset.name}")
             else:
-                extracted_data = None
-            if extracted_data is not None:
-                if len(_hdf5_part(h5f, dataset).shape) == 3:
-                    for sensor in _hdf5_part(h5f, dataset):
-                        extracted_data = _append_to_extraction(
-                            extracted_data,
-                            _extract_sample_from_dataset(sensor, indices),
-                        )
-                else:
+                raise RuntimeError(
+                    "Somehow there is unexpected data in the dataset to be processed. "
+                    f"Did not expect to find {dataset.name}"
+                )
+            if dataset.shape[0] == 3:
+                for sensor in dataset:
                     extracted_data = _append_to_extraction(
-                        extracted_data,
-                        _extract_sample_from_dataset(
-                            _hdf5_part(h5f, dataset),
-                            indices,
-                        ),
+                        extracted_data, _normalize_if_requested(sensor)
                     )
-                if (
-                    ExtractionDataType.UNCERTAINTIES.value
-                    in _hdf5_part(h5f, dataset).name
-                ):
-                    uncertainties = extracted_data
-                    print("    Uncertainties extracted")
-                elif ExtractionDataType.VALUES.value in _hdf5_part(h5f, dataset).name:
-                    values = extracted_data
-                    print("    Values extracted")
+            else:
+                extracted_data = _append_to_extraction(
+                    extracted_data, _normalize_if_requested(dataset)
+                )
+            if ExtractionDataType.UNCERTAINTIES.value in dataset.name:
+                uncertainties = extracted_data
+                print("    Uncertainties extracted")
+            elif ExtractionDataType.VALUES.value in dataset.name:
+                values = extracted_data
+                print("    Values extracted")
     uncertain_values = UncertainArray(np.array(values), np.array(uncertainties))
     _store_cache(uncertain_values)
     return uncertain_values
-- 
GitLab
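
Note on the HDF5 traversal change: the removed _hdf5_part helper walked the
file with an explicit loop over keys, while the patch now uses
reduce(operator.getitem, dataset_descriptor, h5f) for the same lookup. A
minimal sketch of the equivalence, using a plain nested dict as a stand-in
for the open h5py.File:

    import operator
    from functools import reduce

    # Stand-in for an open HDF5 file: nested mappings keyed by group names.
    h5f = {"ZeMA_DAQ": {"Acceleration": {"qudt:value": [1.0, 2.0, 3.0]}}}
    descriptor = ["ZeMA_DAQ", "Acceleration", "qudt:value"]

    # reduce chains the lookups: h5f["ZeMA_DAQ"]["Acceleration"]["qudt:value"]
    assert reduce(operator.getitem, descriptor, h5f) == [1.0, 2.0, 3.0]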