Skip to content
Snippets Groups Projects
Commit b26d4fee authored by Nando Farchmin's avatar Nando Farchmin
Browse files

Update function approximation scripts

parent 417dfd1b
Branches
No related tags found
No related merge requests found
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from src.misc import time_stamp, timeit
from src import target_function
import matplotlib.pyplot as plt
import neural_networks_101.src as src
def main() -> None:
print(time_stamp(), "Initialize main file")
with timeit("create x_train data ({:4.2f} s)"):
x_train = np.random.uniform(0, 1, (100000, 2))
with timeit("create y_train data ({:4.2f} s)"):
y_train = target_function.sin2d(x_train)
plt.figure()
plt.hexbin(x_train[:, 0], x_train[:, 1], y_train, gridsize=50)
plt.show()
# Get CPU or GPU device for training
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
# generate samples
print(src.misc.time_stamp(), "generate training and test data")
x_train = np.random.uniform(0, 1, (100000, 2))
x_test = np.random.uniform(0, 1, (10000, 2))
y_train = src.target_function.sin2d(x_train).reshape(-1, 1)
y_test = src.target_function.sin2d(x_test).reshape(-1, 1)
# define model, loss and optimization algorithm
model = src.approximation.NeuralNetwork(
x_train.shape[1], y_train.shape[1], width=1024).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-02)
loss_function = torch.nn.MSELoss(reduction="mean")
n_epochs = 5
for epoch in range(n_epochs):
with src.misc.timeit("time: {:4.2f} s"):
src.approximation.train(
model, device, x_train, y_train, loss_function, optimizer, log_interval=100)
test_loss = src.approximation.test(
model, device, x_test, y_test, loss_function)
print(src.misc.time_stamp(), f"test set avg. loss: {test_loss}")
if not os.path.isdir("../img"):
os.makedirs("../img", exist_ok=True)
src.approximation.plot_function_approximation(
model, device, src.target_function.sin2d, figsize=(15, 5))
file_name = "../img/function_approximation.png"
plt.savefig(file_name, dpi=200)
print(src.misc.time_stamp(), f"save to: {file_name}")
if __name__ == "__main__":
......
This diff is collapsed.
......@@ -168,3 +168,29 @@ def evaluate(model: NeuralNetwork,
model.eval()
xs = torch.from_numpy(xs).type(torch.float32).to(device)
return model(xs).detach().numpy()
def plot_function_approximation(model, device, target, **kwargs):
""" Plot function approximation error.
Parameters
----------
model : NeuralNetwork
Neural network model.
device : torch.device
Hardware to train the model on.
target : Callable
Target function.
"""
fig, axes = plt.subplot_mosaic([["true", "approx", "error"]], **kwargs)
x = np.linspace(0, 1, 50)
X, Y = np.meshgrid(x, x)
xx = np.concatenate([X.reshape(-1, 1), Y.reshape(-1, 1)], axis=1)
f_val = target(xx).reshape(x.size, -1)
f_nn = evaluate(model, device, xx).reshape(x.size, -1)
im = axes["true"].contourf(x, x, f_val)
plt.colorbar(im, ax=axes["true"])
im = axes["approx"].contourf(x, x, f_nn)
plt.colorbar(im, ax=axes["approx"])
im = axes["error"].contourf(x, x, f_val-f_nn)
plt.colorbar(im, ax=axes["error"])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment