Skip to content
Snippets Groups Projects
Commit 2a01272a authored by Nando Farchmin's avatar Nando Farchmin
Browse files

Add script to draw images

parent 5184039b
No related branches found
No related tags found
No related merge requests found
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate
from scipy.stats import multivariate_normal
import neural_networks_101.src as src
def relu(x, slope=0):
ret = np.zeros(x.size)
idx = np.where(x >= 0)[0]
ret[idx] = x[idx]
idx = np.where(x < 0)[0]
ret[idx] = slope * x[idx]
return ret
def softmax(z):
return np.exp(z) / np.sum(np.exp(z))
def argmax(z):
ret = np.zeros(z.size)
idx = np.argmax(z)
ret[idx] = z[idx]
return ret
def maxpool(z):
ret = np.zeros((2, 2))
ret[0, 0] = np.max(z[:2, :2])
ret[0, 1] = np.max(z[:2, 2:])
ret[1, 0] = np.max(z[2:, :2])
ret[1, 1] = np.max(z[2:, 2:])
return ret
def spline_interp(x, y, res):
tck = interpolate.splrep(x, y, s=0, k=3)
x_new = np.linspace(np.min(x), np.max(x), res)
y_fit = interpolate.BSpline(*tck)(x_new)
return x_new, y_fit
def plot_relu(file_name):
x = np.linspace(-2, 2, 200)
with plt.xkcd():
fig = plt.figure()
plt.axvline(x=0, ls="--", color="k")
plt.plot(x, np.zeros(x.size), ls="--", color="k")
plt.plot(x, relu(x))
plt.savefig(file_name, dpi=200)
def plot_leaky_relu(file_name):
x = np.linspace(-2, 2, 200)
with plt.xkcd():
fig = plt.figure()
plt.axvline(x=0, ls="--", color="k")
plt.plot(x, np.zeros(x.size), ls="--", color="k")
plt.plot(x, relu(x, 0.1))
plt.savefig(file_name, dpi=200)
def plot_tanh(file_name):
x = np.linspace(-3, 3, 200)
with plt.xkcd():
fig = plt.figure()
plt.axvline(x=0, ls="--", color="k")
plt.plot(x, np.zeros(x.size), ls="--", color="k")
plt.plot(x, np.tanh(x))
plt.savefig(file_name, dpi=200)
def plot_softmax(file_name):
z = np.array([0.229, 0.070, 1.163, 1.826, 1.184,
1.311, 0.189, 0.200, 1.881, 0.738])
val = softmax(z)
with plt.xkcd():
fig = plt.figure()
plt.bar(range(1, z.size+1), val)
x, y = spline_interp(np.arange(1, z.size+1), val, 200)
plt.plot(x, y, c="k")
plt.savefig(file_name, dpi=200)
def plot_argmax(file_name):
z = np.array([0.229, 0.070, 1.163, 1.826, 1.184,
1.311, 0.189, 0.200, 1.881, 0.738])
with plt.xkcd():
fig = plt.figure()
plt.bar(range(1, z.size+1), z)
plt.bar(range(1, z.size+1), argmax(z), color="grey")
plt.savefig(file_name, dpi=200)
def plot_maxpool(file_name):
row1 = np.concatenate(
[np.random.uniform(0, 1, (2, 2)), np.random.uniform(1, 2, (2, 2))],
axis=1)
row2 = np.concatenate(
[np.random.uniform(2, 3, (2, 2)), np.random.uniform(3, 4, (2, 2))],
axis=1)
mat = np.concatenate([row1, row2], axis=0)
with plt.xkcd():
fig, ax = plt.subplot_mosaic([["big", "arrow", "small"]])
ax["big"].matshow(mat, cmap="Blues")
ax["arrow"].arrow(0.5, 1, 0.8, 0, head_width=0.2, width=0.05)
ax["arrow"].set_xlim(0, 2)
ax["arrow"].set_ylim(0, 2)
ax["arrow"].axis("off")
ax["small"].matshow(maxpool(mat), cmap="Blues")
plt.savefig(file_name, dpi=200)
def plot_venn(file_name):
with plt.xkcd():
fig, ax = plt.subplots()
ai = plt.Circle((0.5, 0.5), 0.45, color="k", fill=False)
ml = plt.Circle((0.6, 0.42), 0.3, color="k", fill=False)
nn = plt.Circle((0.5, 0.4), 0.15, color="k", fill=False)
ai_fill = plt.Circle((0.5, 0.5), 0.45, alpha=0.15)
ml_fill = plt.Circle((0.6, 0.42), 0.3, alpha=0.15)
nn_fill = plt.Circle((0.5, 0.4), 0.15, alpha=0.15)
ax.add_patch(ai_fill)
ax.add_patch(ml_fill)
ax.add_patch(nn_fill)
ax.add_patch(ai)
ax.add_patch(ml)
ax.add_patch(nn)
ax.text(0.2, 0.6, "AI", fontsize="xx-large")
ax.text(0.7, 0.5, "ML", fontsize="xx-large")
ax.text(0.4, 0.35, "NN", fontsize="xx-large")
ax.axis("off")
plt.savefig(file_name, dpi=200)
def landscape(xs):
mean1, cov1 = np.array([1, -1]), np.eye(2)
f1 = multivariate_normal(mean1, cov1)
ret = f1.pdf(xs)/f1.pdf(mean1)
mean2, cov2 = np.array([1, 1]), 0.2*np.array([[1, -.1], [.1, 1]])
f2 = multivariate_normal(mean2, cov2)
ret += f2.pdf(xs)/f2.pdf(mean2)
mean3, cov3 = np.array([-3.5, -1]), 2.0*np.array([[1, 0], [0, 5]])
f3 = multivariate_normal(mean3, cov3)
ret += f3.pdf(xs)/f3.pdf(mean3)
mean4, cov4 = np.array([0, 2]), 0.1*np.array([[1, 0], [0, 5]])
f4 = multivariate_normal(mean4, cov4)
ret += f4.pdf(xs)/f4.pdf(mean4)
mean5, cov5 = np.array([-1.5, 2]), 0.5*np.array([[3, -1], [1, 1]])
f5 = multivariate_normal(mean5, cov5)
ret += f5.pdf(xs)/f5.pdf(mean5)
return ret
def plot_sgd(file_name):
x = np.linspace(-3, 2, 50)
y = np.linspace(-1.5, 2.5, 50)
xs = src.misc.cart_prod([x, y])
fig, ax = plt.subplot_mosaic([["sgd"]])
ax["sgd"].contourf(x, x, landscape(xs).reshape(x.size, -1).T, cmap="Blues")
ax["sgd"].axis("off")
start_end = np.array([[-0.1, 1.7],
[-0.8, -1.4],
])
points1 = np.array([start_end[0],
[-0.6, 0.5],
start_end[-1],
])
points2 = np.array([start_end[0],
[-0.6, 1.6],
[-0.8, 1.7],
[-1.1, 1.4],
[-1.05, 1.2],
[-1.6, 0.7],
[-1.7, 0.4],
[-1.3, 0.0],
[-1.8, -0.3],
[-1.3, -0.8],
[-1.2, -1.4],
[-0.7, -2.4],
[-0.2, -1.6],
[-0.4, -1.0],
[-1.0, -1.2],
[-0.95, -1.5],
[-0.7, -1.6],
start_end[-1],
])
ax["sgd"].plot(points1[:, 0], points1[:, 1], "-o",
lw=4, color="k", ms=8)
ax["sgd"].plot(points1[:, 0], points1[:, 1], "-o",
lw=2, color="green")
ax["sgd"].plot(points2[:, 0], points2[:, 1], "-o",
lw=4, color="k", ms=8)
ax["sgd"].plot(points2[:, 0], points2[:, 1], "-o",
lw=2, color="orange")
ax["sgd"].plot(start_end[:, 0], start_end[:, 1], "o",
lw=4, color="k", ms=8)
ax["sgd"].plot(start_end[:, 0], start_end[:, 1], "o",
lw=2, color="w")
# color1, edgecolor1 = "orange", "k"
# ax["sgd"].arrow(x=-0.05, y=1.7, dx=-0.3, dy=-0.5, width=0.04,
# facecolor=color1, edgecolor=edgecolor1)
# ax["sgd"].arrow(x=-0.45, y=0.99, dx=-0.2, dy=-1.2, width=0.04,
# facecolor=color1, edgecolor=edgecolor1)
# color2, edgecolor2 = "green", "k"
# ax["sgd"].arrow(x=-0.05, y=1.7, dx=-0.4, dy=-0.08, width=0.04,
# facecolor=color2, edgecolor=edgecolor2)
# ax["sgd"].arrow(x=-0.05, y=1.7, dx=-0.4, dy=-0.08, width=0.04,
# facecolor=color2, edgecolor=edgecolor2)
plt.savefig(file_name, dpi=200)
def main():
"""Main."""
# plot_venn("./venn.png")
# plot_relu("./relu.png")
# plot_leaky_relu("./leaky_relu.png")
# plot_tanh("./tanh.png")
# plot_argmax("./argmax.png")
# plot_softmax("./softmax.png")
# plot_maxpool("./maxpool_tmp.png")
plot_sgd("./sgd.png")
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment