Commit feed4ac9 authored by Jörg Martin

Updated training parameters for simulated data

parent b0744268
Showing with 21 additions and 23 deletions
@@ -5,11 +5,11 @@ from torch.utils.data import TensorDataset
 from EIVGeneral.manipulate_tensors import add_noise, normalize_tensor,\
         unnormalize_tensor
-total_number_of_datapoints = 2000
+total_number_of_datapoints = 500
 input_range = [-4,4]
 slope = 1.0
 intercept = 0.0
-x_noise_strength = 0.05 * (input_range[1] - input_range[0])/2
+x_noise_strength = 0.2
 y_noise_strength = 3
 func = lambda true_x: slope * true_x**3 + intercept
...
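For orientation: this hunk reduces the simulated cubic dataset from 2000 to 500 points and writes the x-noise level as a literal; the old expression 0.05 * (input_range[1] - input_range[0])/2 already evaluates to 0.2, so only its spelling changes. The sketch below illustrates, with plain torch only, what these constants mean for the generated data; the repository's actual helpers add_noise and normalize_tensor from EIVGeneral.manipulate_tensors are not shown in this diff and may differ in detail.

import torch

# Minimal illustration of the data generation implied by the new constants;
# not the repository's add_noise/normalize_tensor implementation.
total_number_of_datapoints = 500
input_range = [-4, 4]
slope, intercept = 1.0, 0.0
x_noise_strength = 0.2   # std of Gaussian noise on the inputs
y_noise_strength = 3     # std of Gaussian noise on the targets
func = lambda true_x: slope * true_x**3 + intercept

true_x = torch.rand(total_number_of_datapoints) \
        * (input_range[1] - input_range[0]) + input_range[0]
noisy_x = true_x + x_noise_strength * torch.randn_like(true_x)
noisy_y = func(true_x) + y_noise_strength * torch.randn_like(true_x)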
@@ -5,11 +5,11 @@ from torch.utils.data import TensorDataset
 from EIVGeneral.manipulate_tensors import add_noise, normalize_tensor,\
         unnormalize_tensor
-total_number_of_datapoints = 2000
+total_number_of_datapoints = 500
 input_range = [-1,1]
 slope = 1.0
 intercept = 0.0
-x_noise_strength = 0.05
+x_noise_strength = 0.1
 y_noise_strength = 0.1
 func = lambda true_x: slope * true_x + intercept
...
@@ -5,11 +5,11 @@ from torch.utils.data import TensorDataset
 from EIVGeneral.manipulate_tensors import add_noise, normalize_tensor,\
         unnormalize_tensor
-total_number_of_datapoints = 2000
+total_number_of_datapoints = 500
 input_range = [-1,1]
 slope = 1.0
 intercept = 0.0
-x_noise_strength = 0.05
+x_noise_strength = 0.1
 y_noise_strength = 0.1
 func = lambda true_x: slope * true_x**2 + intercept
...
@@ -9,7 +9,7 @@ total_number_of_datapoints = 2000
 input_range = [-0.2,0.8]
 intercept = 0.0
 x_noise_strength = 0.02
-y_noise_strength = 0.05
+y_noise_strength = 0.02
 func = lambda true_x: true_x +\
         torch.sin(2 * torch.pi * true_x) +\
         torch.sin(4 * torch.pi * true_x)
...
@@ -3,7 +3,7 @@
     "short_dataname": "cubic",
     "normalize": false,
     "lr": 1e-3,
-    "batch_size": 64,
+    "batch_size": 16,
     "test_batch_size": 800,
     "number_of_epochs": 100,
     "unscaled_reg": 10,
@@ -16,7 +16,7 @@
     "init_std_y_list": [0.5],
     "gamma": 0.5,
     "hidden_layers": [128, 128, 128, 128],
-    "fixed_std_x": 0.05,
+    "fixed_std_x": 0.20,
     "seed_range": [0,10],
     "gpu_number": 1
 }
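The two keys touched in this configuration are batch_size (64 to 16) and fixed_std_x (0.05 to 0.20); the latter now agrees with the x_noise_strength = 0.2 used to simulate the cubic data above. How the file is consumed is not part of this diff, so the following is only a sketch with a made-up path, showing that the changed values are reachable as ordinary JSON:

import json

# Hypothetical filename -- the real configuration path is not visible in this diff.
with open("configurations/cubic_config.json") as f:
    cfg = json.load(f)

batch_size = cfg["batch_size"]      # 16 after this commit (was 64)
fixed_std_x = cfg["fixed_std_x"]    # 0.20 after this commit, matching x_noise_strength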
@@ -3,7 +3,7 @@
     "short_dataname": "linear",
     "normalize": false,
     "lr": 1e-3,
-    "batch_size": 64,
+    "batch_size": 16,
     "test_batch_size": 800,
     "number_of_epochs": 100,
     "unscaled_reg": 10,
@@ -16,7 +16,7 @@
     "init_std_y_list": [0.5],
     "gamma": 0.5,
     "hidden_layers": [128, 128, 128, 128],
-    "fixed_std_x": 0.05,
+    "fixed_std_x": 0.10,
     "seed_range": [0,10],
     "gpu_number": 1
 }
@@ -3,7 +3,7 @@
     "short_dataname": "quadratic",
     "normalize": false,
     "lr": 1e-3,
-    "batch_size": 64,
+    "batch_size": 16,
     "test_batch_size": 800,
     "number_of_epochs": 100,
     "unscaled_reg": 10,
@@ -16,7 +16,7 @@
     "init_std_y_list": [0.5],
     "gamma": 0.5,
     "hidden_layers": [128, 128, 128, 128],
-    "fixed_std_x": 0.05,
+    "fixed_std_x": 0.10,
     "seed_range": [0,10],
     "gpu_number": 1
 }
@@ -3,7 +3,7 @@
     "short_dataname": "sine",
     "normalize": false,
     "lr": 1e-3,
-    "batch_size": 64,
+    "batch_size": 16,
     "test_batch_size": 800,
     "number_of_epochs": 100,
     "unscaled_reg": 10,
...
@@ -3,7 +3,7 @@
     "short_dataname": "cubic",
     "normalize": false,
     "lr": 1e-3,
-    "batch_size": 64,
+    "batch_size": 16,
     "test_batch_size": 800,
     "number_of_epochs": 100,
     "unscaled_reg": 10,
...
@@ -3,7 +3,7 @@
     "short_dataname": "linear",
     "normalize": false,
     "lr": 1e-3,
-    "batch_size": 64,
+    "batch_size": 16,
     "test_batch_size": 800,
     "number_of_epochs": 100,
     "unscaled_reg": 10,
...
@@ -3,7 +3,7 @@
     "short_dataname": "quadratic",
     "normalize": false,
     "lr": 1e-3,
-    "batch_size": 64,
+    "batch_size": 16,
     "test_batch_size": 800,
     "number_of_epochs": 100,
     "unscaled_reg": 10,
...
@@ -3,7 +3,7 @@
     "short_dataname": "sine",
     "normalize": false,
     "lr": 1e-3,
-    "batch_size": 64,
+    "batch_size": 16,
     "test_batch_size": 800,
     "number_of_epochs": 100,
     "unscaled_reg": 10,
...
@@ -216,9 +216,8 @@ def compute_predictions_and_uncertainties(data, x_range, eiv, number_of_draws,
     return plotting_dictionary
-data_list = ['cubic'] # short datanames
-#TODO: Check which ranges are "correct"
-list_x_range = [torch.linspace(-2.5,2.5, 50)]
+data_list = ['sine'] # short datanames
+list_x_range = [torch.linspace(-0.3,0.9, 50)]
 list_color = [('red','blue')]
 list_number_of_draws = [((100,5), 100)]
 for i, (data, x_range, color, number_of_draws) in enumerate(zip(data_list,
@@ -258,5 +257,4 @@ for i, (data, x_range, color, number_of_draws) in enumerate(zip(data_list,
     pass
 plt.show()
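The plotting script now targets the sine dataset and evaluates on a grid of 50 points over [-0.3, 0.9], which pads the sine generator's input_range = [-0.2, 0.8] by 0.1 on each side, so predictions are also drawn slightly outside the training support. A small sketch of the per-dataset loop follows; the remaining zip arguments are cut off in the diff, so pairing the four lists in the order shown is an assumption, and the loop body here is a placeholder print rather than the script's plotting code.

import torch

data_list = ['sine']                            # short datanames
list_x_range = [torch.linspace(-0.3, 0.9, 50)]  # evaluation grid per dataset
list_color = [('red', 'blue')]
list_number_of_draws = [((100, 5), 100)]

# With single-entry lists the loop runs exactly once, for 'sine'.
for i, (data, x_range, color, number_of_draws) in enumerate(
        zip(data_list, list_x_range, list_color, list_number_of_draws)):
    print(i, data, x_range[0].item(), x_range[-1].item(), number_of_draws)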