Skip to content
Snippets Groups Projects
Verified Commit dc7d7c9f authored by Björn Ludwig
Browse files

feat(dataset): leave storage location specification to pooch to share data across local projects

parent 3c852929
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
__all__ = [ __all__ = [
"ExtractionDataType", "ExtractionDataType",
"LOCAL_ZEMA_DATASET_PATH",
"ZeMASamples", "ZeMASamples",
"ZEMA_DATASET_URL", "ZEMA_DATASET_URL",
"ZEMA_QUANTITIES", "ZEMA_QUANTITIES",
...@@ -13,7 +12,7 @@ import os ...@@ -13,7 +12,7 @@ import os
import pickle import pickle
from enum import Enum from enum import Enum
from functools import reduce from functools import reduce
from os.path import dirname, exists from os.path import exists
from pathlib import Path from pathlib import Path
from typing import cast from typing import cast
...@@ -21,11 +20,10 @@ import h5py ...@@ -21,11 +20,10 @@ import h5py
import numpy as np import numpy as np
from h5py import Dataset from h5py import Dataset
from numpy._typing import NDArray from numpy._typing import NDArray
from pooch import retrieve from pooch import os_cache, retrieve
from zema_emc_annotated.data_types import RealMatrix, RealVector, UncertainArray from zema_emc_annotated.data_types import RealMatrix, RealVector, UncertainArray
LOCAL_ZEMA_DATASET_PATH = Path(dirname(__file__), "datasets")
ZEMA_DATASET_URL = "https://zenodo.org/record/5185953/files/axis11_2kHz_ZeMA_PTB_SI.h5" ZEMA_DATASET_URL = "https://zenodo.org/record/5185953/files/axis11_2kHz_ZeMA_PTB_SI.h5"
ZEMA_QUANTITIES = ( ZEMA_QUANTITIES = (
"Acceleration", "Acceleration",
...@@ -109,7 +107,6 @@ class ZeMASamples: ...@@ -109,7 +107,6 @@ class ZeMASamples:
dataset_full_path = retrieve( dataset_full_path = retrieve(
url=ZEMA_DATASET_URL, url=ZEMA_DATASET_URL,
known_hash=None, known_hash=None,
path=LOCAL_ZEMA_DATASET_PATH,
progressbar=True, progressbar=True,
) )
assert exists(dataset_full_path) assert exists(dataset_full_path)
...@@ -261,11 +258,13 @@ class ZeMASamples: ...@@ -261,11 +258,13 @@ class ZeMASamples:
if self.samples_slice.start is not None # pylint: disable=no-member if self.samples_slice.start is not None # pylint: disable=no-member
else self.samples_slice.stop # pylint: disable=no-member else self.samples_slice.stop # pylint: disable=no-member
) )
return LOCAL_ZEMA_DATASET_PATH.joinpath( return Path(
f"{str(n_samples)}_samples" os_cache("pooch").joinpath(
f"{'_starting_from_' + str(idx_start) if idx_start else ''}_with_" f"{str(n_samples)}_samples"
f"{str(self.size_scaler)}_values_per_sensor" f"{'_starting_from_' + str(idx_start) if idx_start else ''}_with_"
f"{'_normalized' if normalize else ''}.pickle" f"{str(self.size_scaler)}_values_per_sensor"
f"{'_normalized' if normalize else ''}.pickle"
)
) )
def _store_cache(self, normalize: bool) -> None: def _store_cache(self, normalize: bool) -> None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment