Skip to content
Snippets Groups Projects
Verified Commit 37f02739 authored by Björn Ludwig's avatar Björn Ludwig
Browse files

refactor(propagate): extract main logic into function to make it testable

parent 2ee3d860
No related branches found
No related tags found
No related merge requests found
...@@ -78,10 +78,13 @@ def _construct_out_features_counts( ...@@ -78,10 +78,13 @@ def _construct_out_features_counts(
return list(sorted(partition, reverse=True)) return list(sorted(partition, reverse=True))
if __name__ == "__main__": def iterate_over_activations_and_architectures(
depths: tuple[int, ...], size_scalers: tuple[int, ...]
) -> None:
"""Iterate over GUM modules for hard coded architectures and ZeMA sample sizes"""
for MLPModule in (GUMSoftplusMLP, GUMQuadLUMLP, GUMSigmoidMLP): for MLPModule in (GUMSoftplusMLP, GUMQuadLUMLP, GUMSigmoidMLP):
for layers_additional_to_input in (1, 3, 5, 8): for layers_additional_to_input in depths:
for samples_per_sensor in (1, 10, 100, 1000, 2000): for samples_per_sensor in size_scalers:
for set_to_none in (False, True): for set_to_none in (False, True):
with open("timings.txt", "a", encoding="utf-8") as timings_file: with open("timings.txt", "a", encoding="utf-8") as timings_file:
timings_file.write( timings_file.write(
...@@ -113,3 +116,7 @@ if __name__ == "__main__": ...@@ -113,3 +116,7 @@ if __name__ == "__main__":
f"{MLPModule.__name__}_{samples_per_sensor * 11}_inputs_" f"{MLPModule.__name__}_{samples_per_sensor * 11}_inputs_"
f"{layers_additional_to_input}_layers_trace.json" f"{layers_additional_to_input}_layers_trace.json"
) )
if __name__ == "__main__":
iterate_over_activations_and_architectures((1, 3, 5, 8), (1, 10, 100, 1000, 2000))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment