Skip to content
Snippets Groups Projects
Verified Commit 706a3ec6 authored by Björn Ludwig
Browse files

refactor(solve_one_instance): rename optimize and include refactor to function

parent 21afadbf
No related branches found
No related tags found
No related merge requests found
Optimize Optimize
======== ========
.. literalinclude:: examples/optimize.py .. literalinclude:: examples/solve_one_instance.py
Solve chosen instances for several network architectures Solve chosen instances for several network architectures
======================================================== ========================================================
......
...@@ -29,13 +29,15 @@ from lp_nn_robustness_verification.data_types import UncertainArray ...@@ -29,13 +29,15 @@ from lp_nn_robustness_verification.data_types import UncertainArray
from lp_nn_robustness_verification.linear_program import RobustVerifier from lp_nn_robustness_verification.linear_program import RobustVerifier
from lp_nn_robustness_verification.pre_processing import LinearInclusion from lp_nn_robustness_verification.pre_processing import LinearInclusion
if __name__ == "__main__":
SAMPLES_PER_SENSOR = 10 def optimize() -> None:
DEPTH = 1 """Solve one specific hard coded instance and time the process"""
zema_data = ZeMASamples(size_scaler=SAMPLES_PER_SENSOR, normalize=True) samples_per_sensor = 10
depth = 1
zema_data = ZeMASamples(size_scaler=samples_per_sensor, normalize=True)
nn_params = generate_weights_and_biases( nn_params = generate_weights_and_biases(
len(zema_data.values[0]), len(zema_data.values[0]),
construct_out_features_counts(len(zema_data.values[0]), depth=DEPTH), construct_out_features_counts(len(zema_data.values[0]), depth=depth),
seed=0, seed=0,
) )
for (values, uncertainties) in zip(zema_data.values, zema_data.uncertainties): for (values, uncertainties) in zip(zema_data.values, zema_data.uncertainties):
...@@ -48,15 +50,26 @@ if __name__ == "__main__": ...@@ -48,15 +50,26 @@ if __name__ == "__main__":
optimization = RobustVerifier(linear_inclusion) optimization = RobustVerifier(linear_inclusion)
optimization.solve() optimization.solve()
yappi.stop() yappi.stop()
with open("timings.txt", "a", encoding="utf-8") as timings_file: with open(
(
f"{samples_per_sensor * 11}_inputs_and_"
f"{depth}_layers_with_sample_0_and_seed_0_timings.txt"
),
"a",
encoding="utf-8",
) as timings_file:
timings_file.write( timings_file.write(
f"\n===================================================================" f"\n==================================================================="
f"===================\n" f"===================\n"
f"Timings for {SAMPLES_PER_SENSOR * 11} inputs and {DEPTH} " f"Timings for {samples_per_sensor * 11} inputs and "
f"{'layers' if DEPTH > 1 else 'layer'}" f"{depth} {'layers' if depth > 1 else 'layer'} with sample 0 and seed 0"
f"\n===================================================================" f"\n==================================================================="
f"===================\n" f"===================\n"
) )
yappi.get_func_stats().print_all( yappi.get_func_stats().print_all( # type: ignore[import]
out=timings_file, columns={0: ("name", 180), 3: ("ttot", 8)} out=timings_file, columns={0: ("name", 180), 3: ("ttot", 8)}
) )
if __name__ == "__main__":
optimize()
0% Loading
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment