Skip to content
Snippets Groups Projects
Verified Commit 6d197965 authored by Björn Ludwig's avatar Björn Ludwig
Browse files

refactor(test_dataset): adapt test suite to new implementation

parent 0f7345fe
No related branches found
No related tags found
No related merge requests found
Pipeline #15632 failed
"""Common strategies"""
from typing import cast
import numpy as np
from hypothesis import strategies as hst
from hypothesis.extra import numpy as hnp
from hypothesis.strategies import composite, DrawFn, SearchStrategy
from numpy._typing import NDArray
from zema_emc_annotated.data_types import UncertainArray
@composite
def uncertain_arrays(
draw: DrawFn,
greater_than: float = -1e2,
less_than: float = 1e2,
samples: int | None = None,
scaler: int | None = None,
) -> SearchStrategy[UncertainArray]:
if samples is None:
samples = draw(hst.integers(min_value=1, max_value=10))
if scaler is None:
scaler = draw(hst.integers(min_value=1, max_value=10))
values: NDArray[np.float64] = cast(
NDArray[np.float64],
draw(
hnp.arrays(
dtype=np.float64,
shape=(samples, scaler * 11),
elements=hst.floats(min_value=greater_than, max_value=less_than),
)
),
)
std_uncertainties = cast(
NDArray[np.float64],
draw(
hnp.arrays(
dtype=np.float64,
shape=values.shape,
elements=hst.floats(
min_value=np.abs(values).min() * 1e-3,
max_value=np.abs(values).min() * 1e2,
),
)
),
)
return cast(
SearchStrategy[UncertainArray],
UncertainArray(values, std_uncertainties),
)
...@@ -4,7 +4,7 @@ from pathlib import Path ...@@ -4,7 +4,7 @@ from pathlib import Path
import numpy as np import numpy as np
import pytest import pytest
from hypothesis import assume, given, settings, strategies as hst from hypothesis import given, settings, strategies as hst
from zema_emc_annotated import dataset from zema_emc_annotated import dataset
from zema_emc_annotated.data_types import UncertainArray from zema_emc_annotated.data_types import UncertainArray
...@@ -16,7 +16,6 @@ from zema_emc_annotated.dataset import ( ...@@ -16,7 +16,6 @@ from zema_emc_annotated.dataset import (
ZEMA_QUANTITIES, ZEMA_QUANTITIES,
ZeMASamples, ZeMASamples,
) )
from .conftest import uncertain_arrays
def test_dataset_has_docstring() -> None: def test_dataset_has_docstring() -> None:
...@@ -163,28 +162,8 @@ def test_check_and_load_cache_has_docstring() -> None: ...@@ -163,28 +162,8 @@ def test_check_and_load_cache_has_docstring() -> None:
assert ZeMASamples._check_and_load_cache.__doc__ is not None assert ZeMASamples._check_and_load_cache.__doc__ is not None
def test_check_and_load_cache_expects_parameter_n_samples() -> None: def test_check_and_load_cache_expects_parameter_normalize() -> None:
assert "n_samples" in signature(ZeMASamples._check_and_load_cache).parameters assert "normalize" in signature(ZeMASamples._check_and_load_cache).parameters
def test_check_and_load_cache_expects_parameter_n_samples_as_int() -> None:
assert (
signature(ZeMASamples._check_and_load_cache).parameters["n_samples"].annotation
is int
)
def test_check_and_load_cache_expects_parameter_size_scaler() -> None:
assert "size_scaler" in signature(ZeMASamples._check_and_load_cache).parameters
def test_check_and_load_cache_expects_parameter_size_scaler_as_int() -> None:
assert (
signature(ZeMASamples._check_and_load_cache)
.parameters["size_scaler"]
.annotation
is int
)
def test_zema_samples_has_attribute_cache_path() -> None: def test_zema_samples_has_attribute_cache_path() -> None:
...@@ -195,33 +174,30 @@ def test_dataset_cache_path_is_callable() -> None: ...@@ -195,33 +174,30 @@ def test_dataset_cache_path_is_callable() -> None:
assert callable(ZeMASamples._cache_path) assert callable(ZeMASamples._cache_path)
def test_cache_path_has_docstring() -> None: def test_cache_path_expects_parameter_normalize() -> None:
assert ZeMASamples._cache_path.__doc__ is not None assert "normalize" in signature(ZeMASamples._cache_path).parameters
def test_cache_path_expects_parameter_n_samples() -> None:
assert "n_samples" in signature(ZeMASamples._cache_path).parameters
def test_check_and_load_cache_expects_parameter_normalize_as_bool() -> None:
def test_cache_path_expects_parameter_size_scaler() -> None: assert (
assert "size_scaler" in signature(ZeMASamples._cache_path).parameters signature(ZeMASamples._check_and_load_cache).parameters["normalize"].annotation
is bool
)
def test_cache_path_expects_parameter_n_samples_as_int() -> None: def test_cache_path_has_docstring() -> None:
assert signature(ZeMASamples._cache_path).parameters["n_samples"].annotation is int assert ZeMASamples._cache_path.__doc__ is not None
def test_cache_path_expects_parameter_size_scaler_as_int() -> None: def test_cache_path_actually_returns_path() -> None:
assert ( assert isinstance(
signature(ZeMASamples._cache_path).parameters["size_scaler"].annotation is int ZeMASamples()._cache_path(
signature(ZeMASamples).parameters["normalize"].default
),
Path,
) )
@given(hst.integers(), hst.integers())
def test_cache_path_actually_returns_path(n_samples: int, size_scaler: int) -> None:
assert isinstance(ZeMASamples._cache_path(n_samples, size_scaler), Path)
def test_zema_samples_has_attribute_store_cache() -> None: def test_zema_samples_has_attribute_store_cache() -> None:
assert hasattr(ZeMASamples, "_store_cache") assert hasattr(ZeMASamples, "_store_cache")
...@@ -234,47 +210,57 @@ def test_store_cache_has_docstring() -> None: ...@@ -234,47 +210,57 @@ def test_store_cache_has_docstring() -> None:
assert ZeMASamples._store_cache.__doc__ is not None assert ZeMASamples._store_cache.__doc__ is not None
def test_store_cache_expects_parameter_uncertain_values() -> None: def test_store_cache_expects_parameter_normalize() -> None:
assert "uncertain_values" in signature(ZeMASamples._store_cache).parameters assert "normalize" in signature(ZeMASamples._store_cache).parameters
@given(uncertain_arrays(samples=11)) @pytest.mark.webtest
@given(hst.integers(min_value=1, max_value=10))
@settings(deadline=None) @settings(deadline=None)
def test_store_cache_runs_for_random_uncertain_values( def test_store_cache_stores_pickle_file_for_random_input(size_scaler: int) -> None:
uncertain_array: UncertainArray, zema_samples = ZeMASamples(11, size_scaler)
) -> None:
ZeMASamples._store_cache(uncertain_array)
assert os.path.exists( assert os.path.exists(
ZeMASamples._cache_path(11, int(uncertain_array.values.shape[1] / 11)) zema_samples._cache_path(signature(ZeMASamples).parameters["normalize"].default)
) )
@given(hst.integers(), hst.integers()) @pytest.mark.webtest
@given(hst.integers(min_value=1, max_value=10), hst.integers(min_value=1, max_value=10))
@settings(deadline=None)
def test_check_and_load_cache_runs_for_random_uncertain_values_and_returns( def test_check_and_load_cache_runs_for_random_uncertain_values_and_returns(
n_samples: int, size_scaler: int n_samples: int, size_scaler: int
) -> None: ) -> None:
result = ZeMASamples._check_and_load_cache(n_samples, size_scaler) result = ZeMASamples(n_samples, size_scaler)._check_and_load_cache(
signature(ZeMASamples).parameters["normalize"].default
)
assert result is None or isinstance(result, UncertainArray) assert result is None or isinstance(result, UncertainArray)
@given(uncertain_arrays(samples=12)) @pytest.mark.webtest
@given(hst.integers(min_value=1, max_value=10))
@settings(deadline=None)
def test_check_and_load_cache_returns_something_for_existing_file( def test_check_and_load_cache_returns_something_for_existing_file(
uncertain_array: UncertainArray, size_scaler: int,
) -> None: ) -> None:
ZeMASamples._store_cache(uncertain_array) zema_samples = ZeMASamples(12, size_scaler)
assert ( assert (
ZeMASamples._check_and_load_cache(12, int(uncertain_array.values.shape[1] / 11)) zema_samples._check_and_load_cache(
signature(ZeMASamples).parameters["normalize"].default
)
is not None is not None
) )
def test_store_cache_expects_parameter_uncertain_values_as_uncertain_array() -> None: def test_store_cache_expects_parameter_normalize_as_bool() -> None:
assert ( assert (
signature(ZeMASamples._store_cache).parameters["uncertain_values"].annotation signature(ZeMASamples._store_cache).parameters["normalize"].annotation is bool
is UncertainArray
) )
def test_cache_path_expects_parameter_normalize_as_bool() -> None:
assert signature(ZeMASamples._cache_path).parameters["normalize"].annotation is bool
def test_cache_path_expects_stats_to_return_path() -> None: def test_cache_path_expects_stats_to_return_path() -> None:
assert signature(ZeMASamples._cache_path).return_annotation is Path assert signature(ZeMASamples._cache_path).return_annotation is Path
...@@ -385,11 +371,39 @@ def test_extract_samples_returns_values_and_uncertainties_which_are_not_similar( ...@@ -385,11 +371,39 @@ def test_extract_samples_returns_values_and_uncertainties_which_are_not_similar(
@pytest.mark.webtest @pytest.mark.webtest
# @given(hst.integers(min_value=2, max_value=10), hst.integers(min_value=2, max_value=10)) def test_zema_samples_fails_for_more_than_4766_samples() -> None:
# @settings(deadline=None) with pytest.raises(
def test_extract_samples_returns_normalized_values( ValueError,
# n_samples: int, size_scaler: int match=r"all the input array dimensions except for the concatenation axis must "
r"match exactly.*",
):
ZeMASamples(4767)
@pytest.mark.webtest
def test_zema_samples_creates_pickle_files() -> None:
for size_scaler in (1, 10, 100, 1000, 2000):
for normalize in (True, False):
assert ZeMASamples(size_scaler=size_scaler, normalize=normalize)
@pytest.mark.webtest
@given(hst.integers(min_value=1, max_value=10), hst.integers(min_value=1, max_value=10))
@settings(deadline=None)
def test_zema_samples_normalized_mean_is_smaller_or_equal(
n_samples: int, size_scaler: int
) -> None:
normalized_result = ZeMASamples(n_samples, size_scaler, True)
not_normalized_result = ZeMASamples(n_samples, size_scaler)
assert not_normalized_result.values.mean() >= normalized_result.values.mean()
@pytest.mark.webtest
@given(hst.integers(min_value=1, max_value=10), hst.integers(min_value=1, max_value=10))
@settings(deadline=None)
def test_zema_samples_normalized_std_is_smaller_or_equal(
n_samples: int, size_scaler: int
) -> None: ) -> None:
# result = ZeMASamples(n_samples, size_scaler, True) normalized_result = ZeMASamples(n_samples, size_scaler, True)
result = ZeMASamples(2, 5, True) not_normalized_result = ZeMASamples(n_samples, size_scaler)
assert result.values.shape[1] == 11 * 5 assert not_normalized_result.values.std() >= normalized_result.values.std()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment