diff --git a/Experiments/evaluate_tabular.py b/Experiments/evaluate_tabular.py index 7493e13b37fd417bb00fec04873d740a79781950..ba0654f071cdb96cfa72145fe048da4b088bc16b 100644 --- a/Experiments/evaluate_tabular.py +++ b/Experiments/evaluate_tabular.py @@ -144,7 +144,7 @@ def collect_metrics(x,y, seed=0, noise_state = net.noise_is_on net.train() net.noise_on() - not_averaged_predictions = net.predict(x, number_of_draws=noneiv_number_of_draws, + not_averaged_predictions = net.predict(x, number_of_draws=eiv_number_of_draws, take_average_of_prediction=False) eiv_mean = torch.mean(not_averaged_predictions[0], dim=1) if len(y.shape) <= 1: