Commit f70e820b authored by Manuel Marschall's avatar Manuel Marschall
Browse files

remove old stuff

parent b7edb10c
# %%
import tensorflow as tf
from utils import (reconstruction_problem, load_vae_models, plot_recon, savefig, plot_z_coverage,
MSE, SSIM, PSNR)
import math
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# tf.config.run_functions_eagerly(True)
# Module-level handles shared by the helpers below. The @tf.function-decorated
# functions read these via `global` rather than taking them as arguments:
#   vae - trained VAE model, set in the script body after load_vae_models()
#   dec - dict of decoder linearization properties ("g", "gamma", "jac",
#         "sigma_z_inv", plus "sigma"), refreshed by get_decoder_properties()
vae = None
dec = None
def matvecmul(A, b):
    """Compute the matrix-vector product ``A @ b`` in float64.

    Both operands are cast to double precision; ``b`` is treated as a
    column vector and the result is returned flattened to shape (n,).
    """
    A64 = tf.cast(A, dtype=tf.double)
    b_col = tf.expand_dims(tf.cast(b, dtype=tf.double), 1)
    product = tf.matmul(A64, b_col)
    return tf.reshape(product, [-1])
def matvecsolve(A, b):
    """Solve the linear system ``A x = b`` in float64 and return x flat.

    Both operands are cast to double precision; ``b`` is treated as a
    column vector and the solution is returned flattened to shape (n,).
    """
    A64 = tf.cast(A, dtype=tf.double)
    b_col = tf.expand_dims(tf.cast(b, dtype=tf.double), 1)
    solution = tf.linalg.solve(A64, b_col)
    return tf.reshape(solution, [-1])
def tf_cond(x):
    """Return the 2-norm condition number of (batches of) matrices ``x``.

    Computed as the ratio of the largest to the smallest singular value.
    NaN handling: a NaN ratio is replaced by +inf, unless the input matrix
    itself already contained NaNs, in which case the NaN is kept.
    """
    x = tf.convert_to_tensor(x)
    s = tf.linalg.svd(x, compute_uv=False)
    # Singular values come back sorted descending: s[..., 0] is the largest.
    r = s[..., 0] / s[..., -1]
    # Replace NaNs in r with infinite unless there were NaNs before
    x_nan = tf.reduce_any(tf.math.is_nan(x), axis=(-2, -1))
    r_nan = tf.math.is_nan(r)
    r_inf = tf.fill(tf.shape(r), tf.constant(math.inf, r.dtype))
    # BUG FIX: the original discarded the result of tf.where, so the
    # NaN -> inf substitution never took effect; assign it back to r.
    r = tf.where(x_nan, r, tf.where(r_nan, r_inf, r))
    return r
# %%
@tf.function(input_signature=(tf.TensorSpec(shape=[20], dtype=tf.float32),))
def get_decoder_properties(z0):
    """Linearize the VAE decoder around latent point ``z0``.

    Reads the model from the global ``vae``. Returns a dict with:
      g           - decoder mean, flattened to (28*28,)
      gamma       - decoder variance, flattened to (28*28,)
      jac         - decoder Jacobian reshaped to (28*28, 20)
      sigma_z_inv - I + J' Gamma^-1 J, shape (20, 20)
    """
    global vae
    mean_img, var_img = vae.decoder(z0)  # each (1, 28, 28)
    g = tf.reshape(mean_img, [-1])
    gamma = tf.reshape(var_img, [-1])
    raw_jac = vae.J_z0_already_variable(z0)  # (1, 28, 28, 20, 1, 1)
    jac = tf.reshape(tf.squeeze(raw_jac), [28*28, 20])
    # Sigma_z^-1 = I + J' Gamma^-1 J  -> (20, 20)
    weighted_jac = tf.reshape(1/gamma, [-1, 1]) * jac
    sigma_z_inv = tf.eye(20, dtype=tf.double) + tf.matmul(tf.transpose(jac), weighted_jac)
    return {
        "g": g,
        "gamma": gamma,
        "jac": jac,
        "sigma_z_inv": sigma_z_inv
    }
# @tf.function(input_signature=(tf.TensorSpec(shape=[20], dtype=tf.float32),
#                               tf.TensorSpec(shape=[28*28], dtype=tf.float32)))
def update_z0(z0, x0):
    """Closed-form update of the latent estimate given the current image.

    Solves Sigma_z^-1 hatz = J' Gamma^-1 (x0 - g + J z0) using the decoder
    linearization stored in the global ``dec``.

    z0 : (20,) current latent estimate
    x0 : (28*28,) current image estimate
    returns (20,) float32 updated latent vector
    """
    global dec
    residual = x0 - dec["g"] + matvecmul(dec["jac"], z0)  # x - g + J z0
    weighted = (1/dec["gamma"]) * residual                # Gamma^-1 (.)
    rhs = matvecmul(tf.transpose(dec["jac"]), weighted)   # J' Gamma^-1 (.)
    hatz = matvecsolve(dec["sigma_z_inv"], rhs)
    return tf.cast(hatz, dtype=tf.float32)
# This influences the reconstruction quality... I do not know why
# @tf.function(input_signature=(tf.TensorSpec(shape=[28*28, 28*28], dtype=tf.double),
#                               tf.TensorSpec(shape=[28*28], dtype=tf.double),
#                               tf.TensorSpec(shape=[20], dtype=tf.float32)))
def update_x0(A, y, z0):
    """Closed-form update of the image estimate x0 given latent z0.

    Builds hatS_inv = S1 + S2 + S3 (+ optional Tikhonov term) from the
    decoder linearization in the global ``dec`` and solves
    hatS_inv @ x0 = hatx for the new image estimate.

    A  : (28*28, 28*28) forward/measurement matrix (double)
    y  : (28*28,) measurement vector
    z0 : (20,) current latent estimate
    returns (28*28,) double-precision image estimate
    """
    global dec
    regularization = 0
    S1 = tf.linalg.diag(1/dec["gamma"])  # G^-1 (diagonal decoder precision)
    S2 = (1/(dec["sigma"]*dec["sigma"]))*tf.matmul(tf.transpose(A), A)  # s^-2 A'A
    # S'^-1 J'
    S3 = tf.linalg.solve(tf.transpose(dec["sigma_z_inv"]), tf.transpose(dec["jac"]))
    S3 = tf.matmul(dec["jac"], S3)  # J S'^-1 J'
    S3 = tf.matmul(tf.matmul(tf.transpose(S1), S3), S1)  # G^-T J S'^-1 J' G^-1
    # TODO: regularize this matrix
    hatS_inv = S1 + S2 + S3 + regularization*tf.eye(28*28, dtype=tf.double)
    # NOTE(review): this evaluates to S1 + S3 (+ reg*I), i.e. the data term
    # S2 is excluded from the right-hand-side operator -- looks deliberate
    # (the s^-2 A'y term is added separately below) but confirm the algebra.
    hatx = hatS_inv - S2
    hatx = matvecmul(hatx, dec["g"] - matvecmul(dec["jac"], z0))
    hatx = (1/(dec["sigma"]*dec["sigma"]))*matvecmul(tf.transpose(A), y) + hatx
    return matvecsolve(hatS_inv, hatx)
@tf.function(input_signature=(tf.TensorSpec(shape=[20], dtype=tf.float32),
                              tf.TensorSpec(shape=[28*28], dtype=tf.double)))
def log_prod_prior(z0, x0):
    """Joint log-density log p(z0) + log p(x0 | z0), up to additive constants.

    Uses the decoder mean/variance from the global ``dec``:
      -0.5 z'z - 0.5 sum(log gamma) - 0.5 (x-g)' Gamma^-1 (x-g)
    """
    global dec
    latent_term = tf.cast(tf.reduce_sum(z0*z0), dtype=tf.double)
    diff = x0 - dec["g"]
    mahalanobis = tf.reduce_sum(diff*(1/dec["gamma"])*diff)
    log_det = tf.reduce_sum(tf.math.log(dec["gamma"]))
    return -0.5*latent_term - 0.5*log_det - 0.5*mahalanobis
@tf.function(input_signature=(tf.TensorSpec(shape=[28*28, 28*28], dtype=tf.double),
                              tf.TensorSpec(shape=[20], dtype=tf.float32),
                              tf.TensorSpec(shape=[28*28], dtype=tf.double),
                              tf.TensorSpec(shape=[28*28], dtype=tf.double)))
def log_post(A, z0, y, x0):
    """Approximate log-posterior of (x0, z0) given measurement y.

    Sums three quadratic terms (negated, halved): the data fit
    s^-2 ||Ax - y||^2, the linearized decoder fit, and the latent prior
    on the updated latent hatz. Reads decoder state from the global ``dec``.
    """
    hatz = update_z0(z0, x0)
    # Data-fit term: s^-2 ||A x0 - y||^2
    data_residual = matvecmul(A, x0) - y
    T1 = (1/(dec["sigma"]*dec["sigma"]))*tf.reduce_sum(data_residual*data_residual)
    # Linearized decoder term: (x - g - J(hatz - z0))' Gamma^-1 (.)
    lin_residual = x0 - dec["g"] - matvecmul(dec["jac"], hatz-z0)
    T2 = tf.reduce_sum(lin_residual*(1/dec["gamma"])*lin_residual)
    # Latent prior term: ||hatz||^2
    T3 = tf.cast(tf.reduce_sum(hatz*hatz), dtype=tf.double)
    return -0.5*(T1+T2+T3)
def analytic_posterior(A, x, y, sigma, path) -> None:
    """Iteratively reconstruct image x from measurement y using the VAE.

    Alternates analytic x0 updates with re-encoding to refresh z0 and the
    decoder linearization, saving a diagnostic figure per iteration.
    Reads the model from the global ``vae`` and writes the linearization
    into the global ``dec``.

    A     : (28*28, 28*28) forward matrix (double)
    x     : ground-truth image, used only for plots and final metrics
    y     : noisy measurement (flattened internally)
    sigma : measurement noise standard deviation
    path  : output directory for the figures

    Returns a dict with MSE/PSNR/SSIM of the final reconstruction vs x.
    (NOTE(review): annotated -> None but actually returns the dict.)
    """
    # ## define optimization problem
    # A, x, y, sigma
    # ## input:
    # Encoder x -> z (bs, 28 28) ->(bs, 10, 40) 40 = mean + sd
    # Decoder z -> x (bs, 20) -> (bs, 2, 28, 28) 2 = mean x sd
    # Jacobian z -> Jx (bs, 20) -> (bs, 28, 28, 1, 20)
    global vae
    global dec
    num_x0_iteration = 20
    # Latent code of the clean image -- used only as a plotting reference.
    oracle_z0 = tf.squeeze(vae.encoder(x)[0, :20]).numpy()
    # Initial value of z0
    # z0 = vae.encoder(x)[0, :20] # (28, 28) -> (1, 40) -> (20)
    # z0 = vae.encoder(y)[0, :20] # (28, 28) -> (1, 40) -> (20)
    z0 = tf.zeros(20)
    # Initial value of x0 ~ N(g(z0), Gamma(z0))
    # We can also take the mean g(z0) here but sampling yields different initial values each time
    x0 = vae.sample_decoder(z0)
    dec = get_decoder_properties(z0)
    dec["sigma"] = sigma
    y = tf.reshape(y, [-1])
    # Plot the initial configuration
    fig = plot_recon(tf.reshape(x0, [28, 28]), x, tf.reshape(y, [28, 28]))
    savefig(fig, path, "image_0.png")
    fig = plot_z_coverage(vae, x0, oracle_z0)
    savefig(fig, path, f"z0values_{0}.png")
    # Main loop: analytic x0 update, re-encode to get the next z0, then
    # re-linearize the decoder around the new latent point.
    for it in range(num_x0_iteration):
        x0 = update_x0(A, y, z0)
        z0 = vae.encoder(tf.reshape(x0, [28, 28]))[0, :20]
        dec = get_decoder_properties(z0)
        dec["sigma"] = sigma
        # plot results of current iteration
        fig = plot_recon(tf.reshape(x0, [28, 28]), x, tf.reshape(y, [28, 28]))
        savefig(fig, path, f"image_{it+1}.png")
        fig = plot_z_coverage(vae, x0, oracle_z0)
        savefig(fig, path, f"z0values_{it+1}.png")
    # Historical alternative: numeric optimization of the log-posterior.
    # elif optimization == "numeric":
    #     opt = tf.keras.optimizers.Nadam()
    #     # x0 = tf.Variable(update_x0(A, y, best_z0))
    #     for lia in range(num_x0_iteration):
    #         def neg_log_post():
    #             _retval = -log_post(A, best_z0, y, x0)
    #             print(f"loss: {_retval}")
    #             return _retval
    #         opt.minimize(neg_log_post, var_list=[x0])
    #         if lia % 1000 == 0:
    #             z0 = vae.encoder(tf.reshape(x0, [28, 28]))[0, :20]
    #             fig = plot_recon(tf.reshape(x0, [28, 28]), x, tf.reshape(y, [28, 28]))
    #             savefig(fig, "probabilistic_vae_comparison/vae_quality/good/numerically_recon/",
    #                     f"z{it}_image{lia}.png")
    retval = {
        "MSE": MSE(tf.reshape(x, [28, 28]), tf.reshape(x0, [28, 28])),
        "PSNR": PSNR(tf.reshape(x, [28, 28]), tf.reshape(x0, [28, 28])),
        "SSIM": SSIM(tf.reshape(x, [28, 28]), tf.reshape(x0, [28, 28]))
    }
    return retval
def reference(A, x, y, sigma, path):
    """L2-regularized (Tikhonov) reference reconstruction.

    Searches for a regularization weight ``ell`` by minimizing an oracle
    objective (it uses the true image x) with Adam, keeping the best
    positive value seen, then solves (A'A + ell*I) x0 = A'y and saves
    the reconstruction figure to ``path``.
    """
    def ell_loss(ellh):
        # s^-2 ||A x - y||^2 + |ell| ||x||^2  (oracle: evaluated at the true x)
        return (1/(sigma*sigma))*tf.linalg.norm(matvecmul(A, tf.reshape(x, [-1])) - tf.reshape(y, [-1]))**2 + tf.abs(ellh)*tf.linalg.norm(tf.reshape(x, [-1]))**2
    ell = tf.Variable(1.0, dtype=tf.double)
    loss = lambda: ell_loss(ell)
    opt = tf.keras.optimizers.Adam()
    best_loss = 1e18
    # BUG FIX: snapshot the value instead of aliasing the tf.Variable.
    # The original `best_ell = ell` kept a reference to the Variable, so if
    # the guard below never fired, best_ell silently tracked every later
    # optimizer update instead of holding the best value seen.
    best_ell = ell.numpy()
    # Bounded loop replaces the original `while True` + manual counter
    # (2001 minimize steps, same as before).
    for iteration in range(1, 2002):
        # In eager mode, simply call minimize to update the list of variables.
        opt.minimize(loss, var_list=[ell])
        current_loss = loss()
        if current_loss < best_loss and ell > 0:
            print(f"it: {iteration} best ell: {ell.numpy():.2g} with loss: {current_loss:.2f}")
            best_loss = current_loss.numpy()
            best_ell = ell.numpy()
    print(f"best lambda 'oracle': {best_ell}")
    x0 = matvecsolve(tf.matmul(tf.transpose(A), A) + best_ell*tf.eye(28*28, dtype=tf.double), matvecmul(tf.transpose(A), tf.reshape(y, [-1])))
    fig = plot_recon(tf.reshape(x0, [28, 28]), x, tf.reshape(y, [28, 28]))
    savefig(fig, path, "reference_L2.png")
# %%
# Script entry: build the reconstruction problem, add measurement noise,
# load the trained VAE models, and run the analytic posterior reconstruction.
A, x, y = reconstruction_problem(plot=False, sigma=4)
sigma = 1e-2  # measurement noise std, used both to corrupt y and in inference
y = tf.squeeze(y + tf.random.normal(y.shape, mean=0, stddev=sigma))
_, vae_good, _ = load_vae_models()
# Publish the chosen model through the module-level global read by the helpers.
vae = vae_good
input_dict = {
    "A": tf.cast(A.real, dtype=tf.double),
    "x": tf.cast(tf.reshape(x, [28, 28]), dtype=tf.double),
    "y": tf.cast(y, dtype=tf.double),
    "sigma": sigma,
    "path": "probabilistic_vae_comparison/vae_quality/good/analytic_recon/"
}
# reference(**input_dict)
analytic_posterior(**input_dict)
# %%
This diff is collapsed.
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment