Skip to content
Snippets Groups Projects
Verified Commit f790e44a authored by Björn Ludwig's avatar Björn Ludwig
Browse files

test(propagate): adapt test suite to new module design

parent 00afcda5
Branches
Tags v0.12.0
No related merge requests found
import os
from glob import glob
from inspect import signature
from typing import Any, Type
from itertools import chain
from typing import Any, Callable, Generator, Type
import pytest
from hypothesis import given, settings, strategies as hst
......@@ -10,6 +13,7 @@ from pytorch_gum_uncertainty_propagation.examples import propagate
from pytorch_gum_uncertainty_propagation.examples.propagate import (
_construct_out_features_counts,
assemble_pipeline,
iterate_over_activations_and_architectures,
)
from pytorch_gum_uncertainty_propagation.modules import (
GUMQuadLUMLP,
......@@ -18,6 +22,26 @@ from pytorch_gum_uncertainty_propagation.modules import (
)
@pytest.fixture(scope="module")
def file_deleter() -> Callable[[tuple[str, ...]], None]:
def deleter(endings: tuple[str, ...]) -> None:
for file in chain(*(glob(f"*{ending}") for ending in endings)):
try:
os.remove(file)
except FileNotFoundError:
pass
return deleter
@pytest.fixture
def cleanup_traces_after_run(
file_deleter: Callable[[tuple[str, ...]], None]
) -> Generator[None, None, None]:
yield
file_deleter(("*_layers_trace.json",))
def test_propagate_has_docstring() -> None:
assert propagate.__doc__ is not None
......@@ -86,14 +110,16 @@ def test_construct_out_features_counts_parameter_in_features_is_of_type_int() ->
def test_construct_out_features_counts_parameter_states_to_return_int_list() -> None:
assert signature(_construct_out_features_counts).return_annotation == list[int]
assert (
signature(_construct_out_features_counts).return_annotation == tuple[int, ...]
)
@given(hst.integers(min_value=1, max_value=10))
def test_construct_out_features_counts_actually_returns_int_list(
in_features: int,
) -> None:
assert isinstance(_construct_out_features_counts(in_features), list)
assert isinstance(_construct_out_features_counts(in_features), tuple)
@given(hst.integers(min_value=1, max_value=100))
......@@ -102,7 +128,7 @@ def test_construct_out_features_counts_returns_non_empty_list(in_features: int)
def test_construct_out_features_counts_returns_correct_small_example() -> None:
assert _construct_out_features_counts(89, 8, 6) == [76, 63, 50, 36, 22, 8]
assert _construct_out_features_counts(89, 8, 6) == (76, 63, 50, 36, 22, 8)
@given(hst.integers(min_value=1, max_value=100))
......@@ -112,7 +138,7 @@ def test_construct_out_features_counts_is_descending(in_features: int) -> None:
def test_construct_out_features_counts_returns_correct_large_example() -> None:
assert _construct_out_features_counts(99, 3, 59) == [
assert _construct_out_features_counts(99, 3, 59) == (
98,
97,
96,
......@@ -172,7 +198,7 @@ def test_construct_out_features_counts_returns_correct_large_example() -> None:
7,
5,
3,
]
)
@pytest.mark.webtest
......@@ -259,3 +285,9 @@ def test_assemble_pipeline_actually_returns_profiler(
activation_module: Type[Module],
) -> None:
assert isinstance(assemble_pipeline(activation_module), profile)
def test_iterate_over_activations_and_architectures_runs(
cleanup_traces_after_run: Generator[None, None, None],
) -> None:
iterate_over_activations_and_architectures((1, 2), (1, 2, 10))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment