From 43360eb405aafc468c3d4bd15794e95873c58ccc Mon Sep 17 00:00:00 2001
From: Bjoern Ludwig <bjoern.ludwig@ptb.de>
Date: Sat, 21 Jan 2023 15:33:42 +0100
Subject: [PATCH] feat(dataset): reintroduce strict hash checking, which can
 optionally be skipped

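The constructor now bundles the extraction parameters into a ``SampleSize``
tuple and exposes ``skip_hash_check``. A minimal usage sketch (assuming
``SampleSize`` accepts its fields ``idx_first_cycle``, ``n_cycles`` and
``datapoints_per_cycle`` as keyword arguments, as the diff below suggests;
the concrete values are illustrative only):

    from zema_emc_annotated.data_types import SampleSize
    from zema_emc_annotated.dataset import ZeMASamples

    # Default: strict hash check against ZEMA_DATASET_HASH during retrieval.
    samples = ZeMASamples(
        sample_size=SampleSize(
            idx_first_cycle=0, n_cycles=2, datapoints_per_cycle=100
        ),
        normalize=True,
    )

    # Opt out of the potentially slow hash check, e.g. for concurrent calls
    # once the file is known to be cached locally.
    quick_samples = ZeMASamples(skip_hash_check=True)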
---
 src/zema_emc_annotated/dataset.py | 55 +++++++++++++++++--------------
 1 file changed, 31 insertions(+), 24 deletions(-)

diff --git a/src/zema_emc_annotated/dataset.py b/src/zema_emc_annotated/dataset.py
index ad9a26e..f7411a7 100644
--- a/src/zema_emc_annotated/dataset.py
+++ b/src/zema_emc_annotated/dataset.py
@@ -23,7 +23,12 @@ from h5py import Dataset
 from numpy._typing import NDArray
 from pooch import os_cache, retrieve
 
-from zema_emc_annotated.data_types import RealMatrix, RealVector, UncertainArray
+from zema_emc_annotated.data_types import (
+    RealMatrix,
+    RealVector,
+    SampleSize,
+    UncertainArray,
+)
 
 ZEMA_DATASET_HASH = (
     "sha256:fb0e80de4e8928ae8b859ad9668a1b6ea6310028a6690bb8d4c1abee31cb8833"
@@ -63,54 +68,56 @@ class ZeMASamples:
 
     Parameters
     ----------
-    n_samples : int, optional
-        number of samples each containing the first ``size_scaler`` readings from each
-        of the eleven sensors for one of the cycles with associated uncertainties,
-        defaults to 1 and must be between 1 and 4766 - idx_start
-    size_scaler : int, optional
-        number of sensor readings from each of the individual sensors per sample/cycle,
-        defaults to 1 and should be between 1 and 2000, as there are only 2000
-        readings per cycle, higher values will be clipped to 2000
+    sample_size : SampleSize, optional
+        tuple specifying which samples to extract, defaults to a default instance
+        of :class:`~zema_emc_annotated.data_types.SampleSize`
     normalize : bool, optional
         if ``True``, then values are centered around zero and values and
         uncertainties are scaled to values' unit std, defaults to ``False``
-    idx_start : int, optional
-        index of first sample to be extracted, defaults to 0 and must be between 0
-        and 4765
+    skip_hash_check : bool, optional
+        if ``True``, the strict hash check during retrieval of the dataset file is
+        skipped to speed up concurrent calls, as each check of the large file can
+        take several seconds, defaults to ``False``
 
     Attributes
     ----------
     uncertain_values : UncertainArray
         The collection of samples of values with associated uncertainties,
-        will be of shape (n_samples, 11 x size_scaler)
+        will be of shape (``sample_size.n_cycles``, 11 x
+        ``sample_size.datapoints_per_cycle``)
     """
 
     uncertain_values: UncertainArray
 
     def __init__(
         self,
-        n_samples: int = 1,
-        size_scaler: int = 1,
+        sample_size: SampleSize = SampleSize(),
         normalize: bool = False,
-        idx_start: int = 0,
+        skip_hash_check: bool = False,
     ):
 
-        self.samples_slice: slice = np.s_[idx_start : idx_start + n_samples]
-        self.size_scaler = size_scaler
+        self.samples_slice: slice = np.s_[
+            sample_size.idx_first_cycle : sample_size.idx_first_cycle
+            + sample_size.n_cycles
+        ]
+        self.size_scaler = sample_size.datapoints_per_cycle
         if cached_data := self._check_and_load_cache(normalize):
             self.uncertain_values = cached_data
         else:
-            self._uncertainties = np.empty((n_samples, 0))
-            self._values = np.empty((n_samples, 0))
-            self.uncertain_values = self._extract_data(normalize)
+            self._uncertainties = np.empty((sample_size.n_cycles, 0))
+            self._values = np.empty((sample_size.n_cycles, 0))
+            self.uncertain_values = self._extract_data(normalize, skip_hash_check)
             self._store_cache(normalize)
             del self._uncertainties
             del self._values
 
-    def _extract_data(self, normalize: bool) -> UncertainArray:
+    def _extract_data(
+        self, normalize: bool, skip_hash_check: bool = False
+    ) -> UncertainArray:
+        """Extract the data as specified"""
         dataset_full_path = retrieve(
             url=ZEMA_DATASET_URL,
-            known_hash=ZEMA_DATASET_HASH,
+            known_hash=None if skip_hash_check else ZEMA_DATASET_HASH,
             progressbar=True,
         )
         assert exists(dataset_full_path)
@@ -235,7 +242,7 @@ class ZeMASamples:
         return self.uncertain_values.uncertainties
 
     def _check_and_load_cache(self, normalize: bool) -> UncertainArray | None:
-        """Checks if corresponding file for n_samples exists and loads it with pickle"""
+        """Checks if corresponding file for n_cycles exists and loads it with pickle"""
         if os.path.exists(cache_path := self._cache_path(normalize)):
             with open(cache_path, "rb") as cache_file:
                 return cast(UncertainArray, pickle.load(cache_file))
-- 
GitLab