From 03c02b1896070c2dd223e70da2027299620f39c7 Mon Sep 17 00:00:00 2001
From: Joerg Martin <joerg.martin@ptb.de>
Date: Mon, 13 Dec 2021 13:44:44 +0100
Subject: [PATCH] Argument loading in training scripts

---
 Experiments/train_eiv.py    | 12 +++++++++++-
 Experiments/train_noneiv.py | 10 +++++++++-
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/Experiments/train_eiv.py b/Experiments/train_eiv.py
index 5c9032b..78f6d61 100644
--- a/Experiments/train_eiv.py
+++ b/Experiments/train_eiv.py
@@ -4,6 +4,7 @@ Train EiV model using different seeds
 import random
 import importlib
 import os
+import argparse
 import json
 
 import numpy as np
@@ -15,7 +16,14 @@ from torch.utils.tensorboard.writer import SummaryWriter
 from EIVArchitectures import Networks, initialize_weights
 from EIVTrainingRoutines import train_and_store, loss_functions
 
-data = 'california'
+
+# read in data via --data option
+parser = argparse.ArgumentParser()
+parser.add_argument("--data", help="Loads data", default='california')
+parser.add_argument("--no-autoindent", help="",
+        action="store_true") # to avoid conflics in IPython
+args = parser.parse_args()
+data = args.data
 
 # load hyperparameters from JSON file
 with open(os.path.join('configurations',f'eiv_{data}.json'),'r') as conf_file:
@@ -42,6 +50,8 @@ gamma = conf_dict["gamma"]
 hidden_layers = conf_dict["hidden_layers"]
 seed_range = conf_dict['seed_range']
 
+print(f"Training on {long_dataname} data")
+
 try:
     gpu_number = conf_dict["gpu_number"]
     device = torch.device(f'cuda:{gpu_number}' if torch.cuda.is_available() else 'cpu')
diff --git a/Experiments/train_noneiv.py b/Experiments/train_noneiv.py
index 3d741b3..24d051f 100644
--- a/Experiments/train_noneiv.py
+++ b/Experiments/train_noneiv.py
@@ -4,6 +4,7 @@ Train non-EiV model using different seeds
 import random
 import importlib
 import os
+import argparse
 import json
 
 import numpy as np
@@ -15,7 +16,14 @@ from torch.utils.tensorboard.writer import SummaryWriter
 from EIVArchitectures import Networks, initialize_weights
 from EIVTrainingRoutines import train_and_store, loss_functions
 
-data = 'california'
+
+# read in data via --data option
+parser = argparse.ArgumentParser()
+parser.add_argument("--data", help="Loads data", default='california')
+parser.add_argument("--no-autoindent", help="",
+        action="store_true") # to avoid conflics in IPython
+args = parser.parse_args()
+data = args.data
 
 # load hyperparameters from JSON file
 with open(os.path.join('configurations',f'noneiv_{data}.json'),'r') as conf_file:
-- 
GitLab