diff --git a/Experiments/train_eiv.py b/Experiments/train_eiv.py index 5c9032bbc5af5515a06453548370b7fbb9c713da..78f6d61d5f0af1d0e29de0f10f99275db40f42c4 100644 --- a/Experiments/train_eiv.py +++ b/Experiments/train_eiv.py @@ -4,6 +4,7 @@ Train EiV model using different seeds import random import importlib import os +import argparse import json import numpy as np @@ -15,7 +16,14 @@ from torch.utils.tensorboard.writer import SummaryWriter from EIVArchitectures import Networks, initialize_weights from EIVTrainingRoutines import train_and_store, loss_functions -data = 'california' + +# read in data via --data option +parser = argparse.ArgumentParser() +parser.add_argument("--data", help="Loads data", default='california') +parser.add_argument("--no-autoindent", help="", + action="store_true") # to avoid conflics in IPython +args = parser.parse_args() +data = args.data # load hyperparameters from JSON file with open(os.path.join('configurations',f'eiv_{data}.json'),'r') as conf_file: @@ -42,6 +50,8 @@ gamma = conf_dict["gamma"] hidden_layers = conf_dict["hidden_layers"] seed_range = conf_dict['seed_range'] +print(f"Training on {long_dataname} data") + try: gpu_number = conf_dict["gpu_number"] device = torch.device(f'cuda:{gpu_number}' if torch.cuda.is_available() else 'cpu') diff --git a/Experiments/train_noneiv.py b/Experiments/train_noneiv.py index 3d741b30e632f8c7232da70c3eb91dbfb4210734..24d051fc8f72971c5b3501b1221e732894b0c1c7 100644 --- a/Experiments/train_noneiv.py +++ b/Experiments/train_noneiv.py @@ -4,6 +4,7 @@ Train non-EiV model using different seeds import random import importlib import os +import argparse import json import numpy as np @@ -15,7 +16,14 @@ from torch.utils.tensorboard.writer import SummaryWriter from EIVArchitectures import Networks, initialize_weights from EIVTrainingRoutines import train_and_store, loss_functions -data = 'california' + +# read in data via --data option +parser = argparse.ArgumentParser() +parser.add_argument("--data", help="Loads data", default='california') +parser.add_argument("--no-autoindent", help="", + action="store_true") # to avoid conflics in IPython +args = parser.parse_args() +data = args.data # load hyperparameters from JSON file with open(os.path.join('configurations',f'noneiv_{data}.json'),'r') as conf_file: