From dc7d7c9fbb61a3d0fe5e55e51d58b03a3d1ab6a5 Mon Sep 17 00:00:00 2001
From: Bjoern Ludwig <bjoern.ludwig@ptb.de>
Date: Mon, 16 Jan 2023 18:10:40 +0100
Subject: [PATCH] feat(dataset): leave storage location specification to pooch
 to share data across local projects

---
 src/zema_emc_annotated/dataset.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/zema_emc_annotated/dataset.py b/src/zema_emc_annotated/dataset.py
index ec6dddd..b4e940f 100644
--- a/src/zema_emc_annotated/dataset.py
+++ b/src/zema_emc_annotated/dataset.py
@@ -2,7 +2,6 @@
 
 __all__ = [
     "ExtractionDataType",
-    "LOCAL_ZEMA_DATASET_PATH",
     "ZeMASamples",
     "ZEMA_DATASET_URL",
     "ZEMA_QUANTITIES",
@@ -13,7 +12,7 @@ import os
 import pickle
 from enum import Enum
 from functools import reduce
-from os.path import dirname, exists
+from os.path import exists
 from pathlib import Path
 from typing import cast
 
@@ -21,11 +20,10 @@ import h5py
 import numpy as np
 from h5py import Dataset
 from numpy._typing import NDArray
-from pooch import retrieve
+from pooch import os_cache, retrieve
 
 from zema_emc_annotated.data_types import RealMatrix, RealVector, UncertainArray
 
-LOCAL_ZEMA_DATASET_PATH = Path(dirname(__file__), "datasets")
 ZEMA_DATASET_URL = "https://zenodo.org/record/5185953/files/axis11_2kHz_ZeMA_PTB_SI.h5"
 ZEMA_QUANTITIES = (
     "Acceleration",
@@ -109,7 +107,6 @@ class ZeMASamples:
         dataset_full_path = retrieve(
             url=ZEMA_DATASET_URL,
             known_hash=None,
-            path=LOCAL_ZEMA_DATASET_PATH,
             progressbar=True,
         )
         assert exists(dataset_full_path)
@@ -261,11 +258,13 @@ class ZeMASamples:
             if self.samples_slice.start is not None  # pylint: disable=no-member
             else self.samples_slice.stop  # pylint: disable=no-member
         )
-        return LOCAL_ZEMA_DATASET_PATH.joinpath(
-            f"{str(n_samples)}_samples"
-            f"{'_starting_from_' + str(idx_start) if idx_start else ''}_with_"
-            f"{str(self.size_scaler)}_values_per_sensor"
-            f"{'_normalized' if normalize else ''}.pickle"
+        return Path(
+            os_cache("pooch").joinpath(
+                f"{str(n_samples)}_samples"
+                f"{'_starting_from_' + str(idx_start) if idx_start else ''}_with_"
+                f"{str(self.size_scaler)}_values_per_sensor"
+                f"{'_normalized' if normalize else ''}.pickle"
+            )
         )
 
     def _store_cache(self, normalize: bool) -> None:
-- 
GitLab