From 03c02b1896070c2dd223e70da2027299620f39c7 Mon Sep 17 00:00:00 2001 From: Joerg Martin <joerg.martin@ptb.de> Date: Mon, 13 Dec 2021 13:44:44 +0100 Subject: [PATCH] Argument loading in training scripts --- Experiments/train_eiv.py | 12 +++++++++++- Experiments/train_noneiv.py | 10 +++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/Experiments/train_eiv.py b/Experiments/train_eiv.py index 5c9032b..78f6d61 100644 --- a/Experiments/train_eiv.py +++ b/Experiments/train_eiv.py @@ -4,6 +4,7 @@ Train EiV model using different seeds import random import importlib import os +import argparse import json import numpy as np @@ -15,7 +16,14 @@ from torch.utils.tensorboard.writer import SummaryWriter from EIVArchitectures import Networks, initialize_weights from EIVTrainingRoutines import train_and_store, loss_functions -data = 'california' + +# read in data via --data option +parser = argparse.ArgumentParser() +parser.add_argument("--data", help="Loads data", default='california') +parser.add_argument("--no-autoindent", help="", + action="store_true") # to avoid conflics in IPython +args = parser.parse_args() +data = args.data # load hyperparameters from JSON file with open(os.path.join('configurations',f'eiv_{data}.json'),'r') as conf_file: @@ -42,6 +50,8 @@ gamma = conf_dict["gamma"] hidden_layers = conf_dict["hidden_layers"] seed_range = conf_dict['seed_range'] +print(f"Training on {long_dataname} data") + try: gpu_number = conf_dict["gpu_number"] device = torch.device(f'cuda:{gpu_number}' if torch.cuda.is_available() else 'cpu') diff --git a/Experiments/train_noneiv.py b/Experiments/train_noneiv.py index 3d741b3..24d051f 100644 --- a/Experiments/train_noneiv.py +++ b/Experiments/train_noneiv.py @@ -4,6 +4,7 @@ Train non-EiV model using different seeds import random import importlib import os +import argparse import json import numpy as np @@ -15,7 +16,14 @@ from torch.utils.tensorboard.writer import SummaryWriter from EIVArchitectures import Networks, initialize_weights from EIVTrainingRoutines import train_and_store, loss_functions -data = 'california' + +# read in data via --data option +parser = argparse.ArgumentParser() +parser.add_argument("--data", help="Loads data", default='california') +parser.add_argument("--no-autoindent", help="", + action="store_true") # to avoid conflics in IPython +args = parser.parse_args() +data = args.data # load hyperparameters from JSON file with open(os.path.join('configurations',f'noneiv_{data}.json'),'r') as conf_file: -- GitLab