From dc7d7c9fbb61a3d0fe5e55e51d58b03a3d1ab6a5 Mon Sep 17 00:00:00 2001 From: Bjoern Ludwig <bjoern.ludwig@ptb.de> Date: Mon, 16 Jan 2023 18:10:40 +0100 Subject: [PATCH] feat(dataset): leave storage location specification to pooch to share data across local projects --- src/zema_emc_annotated/dataset.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/zema_emc_annotated/dataset.py b/src/zema_emc_annotated/dataset.py index ec6dddd..b4e940f 100644 --- a/src/zema_emc_annotated/dataset.py +++ b/src/zema_emc_annotated/dataset.py @@ -2,7 +2,6 @@ __all__ = [ "ExtractionDataType", - "LOCAL_ZEMA_DATASET_PATH", "ZeMASamples", "ZEMA_DATASET_URL", "ZEMA_QUANTITIES", @@ -13,7 +12,7 @@ import os import pickle from enum import Enum from functools import reduce -from os.path import dirname, exists +from os.path import exists from pathlib import Path from typing import cast @@ -21,11 +20,10 @@ import h5py import numpy as np from h5py import Dataset from numpy._typing import NDArray -from pooch import retrieve +from pooch import os_cache, retrieve from zema_emc_annotated.data_types import RealMatrix, RealVector, UncertainArray -LOCAL_ZEMA_DATASET_PATH = Path(dirname(__file__), "datasets") ZEMA_DATASET_URL = "https://zenodo.org/record/5185953/files/axis11_2kHz_ZeMA_PTB_SI.h5" ZEMA_QUANTITIES = ( "Acceleration", @@ -109,7 +107,6 @@ class ZeMASamples: dataset_full_path = retrieve( url=ZEMA_DATASET_URL, known_hash=None, - path=LOCAL_ZEMA_DATASET_PATH, progressbar=True, ) assert exists(dataset_full_path) @@ -261,11 +258,13 @@ class ZeMASamples: if self.samples_slice.start is not None # pylint: disable=no-member else self.samples_slice.stop # pylint: disable=no-member ) - return LOCAL_ZEMA_DATASET_PATH.joinpath( - f"{str(n_samples)}_samples" - f"{'_starting_from_' + str(idx_start) if idx_start else ''}_with_" - f"{str(self.size_scaler)}_values_per_sensor" - f"{'_normalized' if normalize else ''}.pickle" + return Path( + os_cache("pooch").joinpath( + f"{str(n_samples)}_samples" + f"{'_starting_from_' + str(idx_start) if idx_start else ''}_with_" + f"{str(self.size_scaler)}_values_per_sensor" + f"{'_normalized' if normalize else ''}.pickle" + ) ) def _store_cache(self, normalize: bool) -> None: -- GitLab