Commit 52b21c8c authored by Manuel Marschall

initial commit

parent eb8c5742
from abc import ABC, abstractmethod
from typing import Optional

from numpy import ndarray as array_type

from onnx_vae import ONNX_VAE_DET, ONNX_VAE_STO
from tf_vae import TF_VAE


class GenerativeModel(ABC):
    @abstractmethod
    def __init__(self, dim_x: int, dim_z: int) -> None:
        if dim_x < dim_z:
            raise ValueError("Is this really a generative model if dim Z > dim X?")
        self.dim_x = dim_x
        self.dim_z = dim_z
        self.vae: Optional[TF_VAE] = None

    def __call__(self, z: array_type) -> array_type:
        if self.vae is None:
            raise NotImplementedError("This function is not implemented without knowledge "
                                      "about the generative model")
        return self.vae.decoder(z)


class DeterministicGenerator(GenerativeModel):
    def __init__(self, dim_x: int, dim_z: int) -> None:
        super().__init__(dim_x, dim_z)

    @staticmethod
    def from_vae(vae: TF_VAE,
                 dim_x: int,
                 dim_z: int) -> GenerativeModel:
        retval = DeterministicGenerator(dim_x, dim_z)
        retval.vae = vae
        return retval


class ProbabilisticGenerator(GenerativeModel):
    def __init__(self, dim_x: int, dim_z: int) -> None:
        super().__init__(dim_x, dim_z)

    @staticmethod
    def from_vae(vae: TF_VAE,
                 dim_x: int,
                 dim_z: int) -> GenerativeModel:
        retval = ProbabilisticGenerator(dim_x, dim_z)
        retval.vae = vae
        return retval


if __name__ == "__main__":
    import numpy as np

    path_det = "onnx_models/deterministic/"
    path_prob = "onnx_models/stochastic_good/"
    path_add = "probabilistic_vae_comparison/"
    onnx_vae1 = ONNX_VAE_DET(path_add + path_det + "encoder.onnx",
                             path_add + path_det + "decoder.onnx")
    tf_vae_det = onnx_vae1.to_tensorflow()
    det_generator = DeterministicGenerator.from_vae(tf_vae_det, int(28*28), 20)
    onnx_vae2 = ONNX_VAE_STO(path_add + path_prob + "good_probVAE_encoder.onnx",
                             path_add + path_prob + "good_probVAE_decoder.onnx")
    tf_vae_good = onnx_vae2.to_tensorflow()
    prob_generator = ProbabilisticGenerator.from_vae(tf_vae_good, int(28*28), 20)

    z = np.random.randn(10, 20)
    x = det_generator(z)
    assert x.shape == (10, 28, 28)
    m, s = prob_generator(z)
    assert m.shape == (10, 28, 28)
    assert s.shape == (10, 28, 28)
    print("Successfully tested: Generative models")
from typing import Optional

from numpy import ndarray as array_type
from numpy.linalg import norm
from numpy.testing import assert_almost_equal


class LinearInverseProblem:
    def __init__(self,
                 operator: array_type,
                 data: array_type,
                 sigma: float,
                 ground_truth: Optional[array_type] = None,
                 ) -> None:
        assert len(operator.shape) == 2  # A is a matrix
        assert len(data.shape) == 1      # y is a vector
        if ground_truth is not None:
            assert operator.shape[1] == ground_truth.shape[0]  # A times x is defined
        assert operator.shape[0] == data.shape[0]  # A times x is of the dimension of y
        assert sigma > 0
        self.operator = operator
        self.ground_truth = ground_truth
        self.data = data
        self.sigma = sigma

    def apply_operator(self, item: array_type) -> array_type:
        return self.operator.dot(item)

    def res(self, item: array_type) -> float:
        return self.sigma**(-2) * norm(self.apply_operator(item) - self.data)**2

    @staticmethod
    def mnist_recon_problem(blur_sigma: int = 4,
                            noise_sigma: float = 0.1) -> "LinearInverseProblem":
        from utils import reconstruction_problem
        A, x, y = reconstruction_problem(sigma=blur_sigma)
        return LinearInverseProblem(A, y.reshape(-1), noise_sigma, x.numpy().reshape(-1))
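
# Note: ``res`` above is, up to an additive constant and a factor of 2, the
# Gaussian negative log-likelihood of the data model y = A x + eta with
# eta ~ N(0, sigma^2 I), i.e. res(x) = sigma^(-2) * ||A x - y||^2.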

if __name__ == "__main__":
    import numpy as np
    from utils import reconstruction_problem

    A, x, y = reconstruction_problem()
    Lip = LinearInverseProblem(A, y.ravel(), 0.1, x.numpy().ravel())
    item = np.random.randn(28*28)
    lhs = Lip.apply_operator(item)
    assert_almost_equal(np.dot(A, item), lhs, err_msg="Operator application failed")
    lhs = Lip.res(item)
    # sigma = 0.1, hence sigma**(-2) == 100
    assert_almost_equal(100 * np.linalg.norm(np.dot(A, item) - y.reshape(-1))**2, lhs,
                        err_msg="Residual computation failed")
    Lip2 = LinearInverseProblem.mnist_recon_problem()
    assert_almost_equal(Lip.operator, Lip2.operator)
    assert_almost_equal(Lip.data, Lip2.data)
    assert_almost_equal(Lip.sigma, Lip2.sigma)
    assert_almost_equal(Lip.ground_truth, Lip2.ground_truth)
    print("Successfully tested: LinearInverseProblem")
from abc import ABC, abstractmethod
import os
from typing import Tuple

import numpy as np
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf

from tf_vae import TF_VAE, TF_VAE_DET, TF_VAE_BAD


class ONNX_VAE(ABC):
    def __init__(self,
                 path_encoder: str,
                 path_decoder: str) -> None:
        self.path_encoder = path_encoder
        self.path_decoder = path_decoder
        encoder = onnx.load(path_encoder)
        decoder = onnx.load(path_decoder)
        self.tf_encoder = prepare(encoder)
        self.tf_decoder = prepare(decoder)
        self.tf_enc_inp = self.tf_encoder.inputs[0]
        self.tf_dec_inp = self.tf_decoder.inputs[0]

    @staticmethod
    def parse_z_sample(z: np.ndarray) -> np.ndarray:
        z = np.array(z)
        if len(z.shape) == 1:
            curr_z = np.reshape(z, [1, z.shape[0], 1, 1]).astype("float32")
        elif len(z.shape) == 2:
            curr_z = np.reshape(z, [z.shape[0], z.shape[1], 1, 1]).astype("float32")
        else:
            raise ValueError(f"Unknown shape of latent variable: {z.shape}")
        return curr_z
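
    # Note: ``parse_z_sample`` brings latent samples into the
    # (batch, dim_z, 1, 1) layout that the exported ONNX decoders
    # expect as input.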

    @abstractmethod
    def parse_decoder_output(self,
                             x_dec: np.ndarray) -> np.ndarray:
        pass

    def decoder(self,
                z: np.ndarray) -> np.ndarray:
        curr_z = self.parse_z_sample(z)
        x_dec = self.tf_decoder.run({self.tf_dec_inp: curr_z})  # returns a tuple
        x_dec = self.parse_decoder_output(x_dec[0])
        return x_dec

    def encoder_sample(self,
                       x: np.ndarray) -> np.ndarray:
        pass  # not implemented yet

    def to_tensorflow_graph(self) -> None:
        self.tf_encoder.export_graph(self.path_encoder + ".tf")
        self.tf_decoder.export_graph(self.path_decoder + ".tf")

    def to_tensorflow(self) -> TF_VAE:
        # export the graphs once and cache them on disk
        if not os.path.exists(self.path_encoder + ".tf"):
            self.to_tensorflow_graph()
        return TF_VAE(self.path_encoder + ".tf", self.path_decoder + ".tf")


class ONNX_VAE_DET(ONNX_VAE):
    def __init__(self,
                 path_encoder: str,
                 path_decoder: str) -> None:
        super().__init__(path_encoder, path_decoder)

    def parse_decoder_output(self,
                             x: np.ndarray) -> np.ndarray:
        assert len(x.shape) == 4
        return np.squeeze(x[:, 0, :, :])

    def to_tensorflow(self) -> TF_VAE:
        if not os.path.exists(self.path_encoder + ".tf"):
            self.to_tensorflow_graph()
        return TF_VAE_DET(self.path_encoder + ".tf", self.path_decoder + ".tf")


class ONNX_VAE_STO(ONNX_VAE):
    def __init__(self,
                 path_encoder: str,
                 path_decoder: str) -> None:
        super().__init__(path_encoder, path_decoder)

    def parse_decoder_output(self,
                             x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        # channel 0 holds the mean, channel 1 the standard deviation
        assert len(x.shape) == 4
        assert x.shape[1] == 2
        return np.squeeze(x[:, 0, :, :]), np.squeeze(x[:, 1, :, :])


class ONNX_VAE_BAD(ONNX_VAE_STO):
    def __init__(self,
                 path_encoder: str,
                 path_decoder: str) -> None:
        super().__init__(path_encoder, path_decoder)

    def to_tensorflow(self) -> TF_VAE_BAD:
        if not os.path.exists(self.path_encoder + ".tf"):
            self.to_tensorflow_graph()
        return TF_VAE_BAD(self.path_encoder + ".tf", self.path_decoder + ".tf")


if __name__ == '__main__':
    path = "probabilistic_vae_comparison/onnx_models/deterministic/"
    z = np.random.randn(10, 20)
    vae_det = ONNX_VAE_DET(path + "encoder.onnx", path + "decoder.onnx")
    x = vae_det.decoder(z)
    assert x.shape == (10, 28, 28)
    vae_sto = ONNX_VAE_STO(path + "encoder.onnx", path + "decoder.onnx")
    m, s = vae_sto.decoder(z)
    assert m.shape == (10, 28, 28)
    assert s.shape == (10, 28, 28)
    tf_vae = vae_sto.to_tensorflow()
    z0 = tf.Variable(tf.random.normal([10, 20, 1, 1]))
    grad = tf_vae.J_z0_already_variable(z0)
    assert grad.shape == (10, 1, 28, 28, 20, 1, 1)
    print("DET and STO VAE decoder checked")
# %%
import math

import tensorflow as tf

from utils import (reconstruction_problem, load_vae_models, plot_recon, savefig,
                   plot_z_coverage, MSE, SSIM, PSNR)

# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# tf.config.run_functions_eagerly(True)

# populated below: the VAE under study and a dict with its local linearization
vae = None
dec = None


def matvecmul(A, b):
    # matrix-vector product in double precision, returned as a flat vector
    return tf.reshape(tf.matmul(tf.cast(A, dtype=tf.double),
                                tf.expand_dims(tf.cast(b, dtype=tf.double), 1)), [-1])


def matvecsolve(A, b):
    # solve the linear system A x = b in double precision, returned as a flat vector
    return tf.reshape(tf.linalg.solve(tf.cast(A, dtype=tf.double),
                                      tf.expand_dims(tf.cast(b, dtype=tf.double), 1)), [-1])


def tf_cond(x):
    x = tf.convert_to_tensor(x)
    s = tf.linalg.svd(x, compute_uv=False)
    r = s[..., 0] / s[..., -1]
    # Replace NaNs in r with infinity, unless there were NaNs in x before
    x_nan = tf.reduce_any(tf.math.is_nan(x), axis=(-2, -1))
    r_nan = tf.math.is_nan(r)
    r_inf = tf.fill(tf.shape(r), tf.constant(math.inf, r.dtype))
    r = tf.where(x_nan, r, tf.where(r_nan, r_inf, r))
    return r
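
# ``tf_cond`` computes the 2-norm condition number kappa = s_max / s_min from
# the singular values; it is only referenced by the commented-out diagnostics
# further below.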

# %%
@tf.function(input_signature=(tf.TensorSpec(shape=[20], dtype=tf.float32),))
def get_decoder_properties(z0):
    # z0 : (20)
    global vae
    g, gamma = vae.decoder(z0)  # (20) -> (1, 28, 28), (1, 28, 28)
    g = tf.reshape(g, [-1])  # (1, 28, 28) -> (28*28)
    gamma = tf.reshape(gamma, [-1])  # (1, 28, 28) -> (28*28)
    # gamma = gamma*gamma
    jac = vae.J_z0_already_variable(z0)  # (20) -> (1, 28, 28, 20, 1, 1)
    jac = tf.reshape(tf.squeeze(jac), [28*28, 20])  # (28, 28, 20) -> (28*28, 20)
    # Sigma_z^-1 -> (20, 20)
    sigma_z_inv = tf.eye(20, dtype=tf.double) + tf.matmul(tf.transpose(jac),
                                                          tf.reshape(1/gamma, [-1, 1])*jac)
    retval = {
        "g": g,
        "gamma": gamma,
        "jac": jac,
        "sigma_z_inv": sigma_z_inv
    }
    return retval
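
# Reading of ``get_decoder_properties``: with the decoder linearized around z0,
# x ~ g(z0) + J (z - z0), Gamma = diag(gamma), and a standard normal prior on z,
# ``sigma_z_inv`` is the posterior precision of the latent variable,
#   Sigma_z^{-1} = I + J^T Gamma^{-1} J.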


# @tf.function(input_signature=(tf.TensorSpec(shape=[20], dtype=tf.float32),
#                               tf.TensorSpec(shape=[28*28], dtype=tf.float32)))
def update_z0(z0, x0):
    # z0 : (20)
    # x0 : (28*28)
    global dec
    # print(f"Global iteration in decoder: {dec['it']}")
    rhs = x0 - dec["g"] + matvecmul(dec["jac"], z0)  # x - g + J z0 : (28*28)
    rhs = (1/dec["gamma"])*rhs
    rhs = matvecmul(tf.transpose(dec["jac"]), rhs)
    # cond = tf_cond(dec["sigma_z_inv"])
    # print(f"Condition number of sigma z inv: {cond}")
    retval = matvecsolve(dec["sigma_z_inv"], rhs)
    return tf.cast(retval, dtype=tf.float32)
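
# ``update_z0`` hence performs a single Gauss-Newton-type step: it solves
#   Sigma_z^{-1} z_hat = J^T Gamma^{-1} (x0 - g(z0) + J z0),
# so z_hat is the mean of the linearized Gaussian posterior over z.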


# Enabling this tf.function decorator influences the reconstruction quality;
# the reason is not understood yet.
# @tf.function(input_signature=(tf.TensorSpec(shape=[28*28, 28*28], dtype=tf.double),
#                               tf.TensorSpec(shape=[28*28], dtype=tf.double),
#                               tf.TensorSpec(shape=[20], dtype=tf.float32)))
def update_x0(A, y, z0):
    global dec
    regularization = 0
    S1 = tf.linalg.diag(1/dec["gamma"])  # G^-1
    S2 = (1/(dec["sigma"]*dec["sigma"]))*tf.matmul(tf.transpose(A), A)  # s^-2 A'A
    # S'^-1 J'
    S3 = tf.linalg.solve(tf.transpose(dec["sigma_z_inv"]), tf.transpose(dec["jac"]))
    S3 = tf.matmul(dec["jac"], S3)  # J S'^-1 J'
    S3 = tf.matmul(tf.matmul(tf.transpose(S1), S3), S1)  # G^-T J S'^-1 J' G^-1
    # TODO: regularize this matrix
    hatS_inv = S1 + S2 + S3 + regularization*tf.eye(28*28, dtype=tf.double)
    hatx = hatS_inv - S2  # hatS^-1 - s^-2 A'A
    hatx = matvecmul(hatx, dec["g"] - matvecmul(dec["jac"], z0))
    hatx = (1/(dec["sigma"]*dec["sigma"]))*matvecmul(tf.transpose(A), y) + hatx
    # cond = tf_cond(hatS_inv)
    # print(f"Condition number of hat S inv: {cond}")
    return matvecsolve(hatS_inv, hatx)
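
# In the notation above, ``update_x0`` appears to assemble the approximate
# posterior precision of the image,
#   hatS^{-1} = Gamma^{-1} + sigma^{-2} A^T A + Gamma^{-1} J Sigma_z J^T Gamma^{-1},
# and to solve hatS^{-1} x0 = sigma^{-2} A^T y + (hatS^{-1} - sigma^{-2} A^T A)(g(z0) - J z0).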


@tf.function(input_signature=(tf.TensorSpec(shape=[20], dtype=tf.float32),
                              tf.TensorSpec(shape=[28*28], dtype=tf.double)))
def log_prod_prior(z0, x0):
    global dec
    ztz = tf.cast(tf.reduce_sum(z0*z0), dtype=tf.double)
    xgtxg = tf.reduce_sum((x0 - dec["g"])*(1/dec["gamma"])*(x0 - dec["g"]))
    # p = 20
    # n = 28*28
    # return -0.5*(p+n)*log(2*pi) + 0.5*n*log(lambda_d[0]) - 0.5*(ztz + lambda_d[0]*xgtxg)
    return (-0.5*ztz - 0.5*tf.reduce_sum(tf.math.log(dec["gamma"])) - 0.5*xgtxg)


@tf.function(input_signature=(tf.TensorSpec(shape=[28*28, 28*28], dtype=tf.double),
                              tf.TensorSpec(shape=[20], dtype=tf.float32),
                              tf.TensorSpec(shape=[28*28], dtype=tf.double),
                              tf.TensorSpec(shape=[28*28], dtype=tf.double)))
def log_post(A, z0, y, x0):
    hatz = update_z0(z0, x0)
    Ax_y = matvecmul(A, x0) - y
    T1 = (1/(dec["sigma"]*dec["sigma"]))*tf.reduce_sum(Ax_y*Ax_y)
    # print(f"T1: {T1}")
    x_g_Jz = x0 - dec["g"] - matvecmul(dec["jac"], hatz - z0)
    T2 = tf.reduce_sum(x_g_Jz*(1/dec["gamma"])*x_g_Jz)
    # print(f"T2: {T2}")
    T3 = tf.cast(tf.reduce_sum(hatz*hatz), dtype=tf.double)
    # print(f"T3: {T3}")
    return -0.5*(T1 + T2 + T3)
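
# ``log_post`` evaluates the (unnormalized) log-posterior under the same
# linearization: T1 is the data misfit, T2 the misfit between x0 and the
# linearized decoder at the updated latent hatz, and T3 the standard normal
# prior on hatz.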


def analytic_posterior(A, x, y, sigma, path) -> dict:
    # ## define optimization problem
    # A, x, y, sigma
    # ## input:
    # Encoder x -> z (bs, 28, 28) -> (bs, 10, 40), 40 = mean + sd
    # Decoder z -> x (bs, 20) -> (bs, 2, 28, 28), 2 = mean x sd
    # Jacobian z -> Jx (bs, 20) -> (bs, 28, 28, 1, 20)
    global vae
    global dec
    num_x0_iteration = 20
    oracle_z0 = tf.squeeze(vae.encoder(x)[0, :20]).numpy()
    # Initial value of z0
    # z0 = vae.encoder(x)[0, :20]  # (28, 28) -> (1, 40) -> (20)
    # z0 = vae.encoder(y)[0, :20]  # (28, 28) -> (1, 40) -> (20)
    z0 = tf.zeros(20)
    # Initial value of x0 ~ N(g(z0), Gamma(z0)).
    # We could also take the mean g(z0) here, but sampling yields different
    # initial values each time.
    x0 = vae.sample_decoder(z0)
    dec = get_decoder_properties(z0)
    dec["sigma"] = sigma
    y = tf.reshape(y, [-1])
    # Plot the initial configuration
    fig = plot_recon(tf.reshape(x0, [28, 28]), x, tf.reshape(y, [28, 28]))
    savefig(fig, path, "image_0.png")
    fig = plot_z_coverage(vae, x0, oracle_z0)
    savefig(fig, path, f"z0values_{0}.png")
    # #######
    for it in range(num_x0_iteration):
        x0 = update_x0(A, y, z0)
        z0 = vae.encoder(tf.reshape(x0, [28, 28]))[0, :20]
        dec = get_decoder_properties(z0)
        dec["sigma"] = sigma
        # plot the results of the current iteration
        fig = plot_recon(tf.reshape(x0, [28, 28]), x, tf.reshape(y, [28, 28]))
        savefig(fig, path, f"image_{it+1}.png")
        fig = plot_z_coverage(vae, x0, oracle_z0)
        savefig(fig, path, f"z0values_{it+1}.png")
    # elif optimization == "numeric":
    #     opt = tf.keras.optimizers.Nadam()
    #     # x0 = tf.Variable(update_x0(A, y, best_z0))
    #     for lia in range(num_x0_iteration):
    #         def neg_log_post():
    #             _retval = -log_post(A, best_z0, y, x0)
    #             print(f"loss: {_retval}")
    #             return _retval
    #         opt.minimize(neg_log_post, var_list=[x0])
    #         if lia % 1000 == 0:
    #             z0 = vae.encoder(tf.reshape(x0, [28, 28]))[0, :20]
    #             fig = plot_recon(tf.reshape(x0, [28, 28]), x, tf.reshape(y, [28, 28]))
    #             savefig(fig, "probabilistic_vae_comparison/vae_quality/good/numerically_recon/",
    #                     f"z{it}_image{lia}.png")
    retval = {
        "MSE": MSE(tf.reshape(x, [28, 28]), tf.reshape(x0, [28, 28])),
        "PSNR": PSNR(tf.reshape(x, [28, 28]), tf.reshape(x0, [28, 28])),
        "SSIM": SSIM(tf.reshape(x, [28, 28]), tf.reshape(x0, [28, 28]))
    }
    return retval
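
# ``analytic_posterior`` thus alternates closed-form updates: refresh the
# decoder linearization at z0, solve for the image x0 via ``update_x0``, and
# re-encode x0 to obtain the next z0; no gradient-based optimization is used.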


def reference(A, x, y, sigma, path):
    def ell_loss(ellh):
        ret = (1/(sigma*sigma))*tf.linalg.norm(matvecmul(A, tf.reshape(x, [-1]))
                                               - tf.reshape(y, [-1]))**2 \
            + tf.abs(ellh)*tf.linalg.norm(tf.reshape(x, [-1]))**2
        return ret

    ell = tf.Variable(1.0, dtype=tf.double)
    loss = lambda: ell_loss(ell)
    opt = tf.keras.optimizers.Adam()
    best_loss = 1e18
    best_ell = ell
    iteration = 0
    while True:
        iteration += 1
        # In eager mode, calling minimize updates the variables in var_list.
        opt.minimize(loss, var_list=[ell])
        if loss() < best_loss and ell > 0:
            print(f"it: {iteration} best ell: {ell.numpy():.2g} with loss: {loss():.2f}")
            best_loss = loss().numpy()
            best_ell = ell.numpy()
        if iteration > 2000:
            break
    # ########
    print(f"best lambda 'oracle': {best_ell}")
    x0 = matvecsolve(tf.matmul(tf.transpose(A), A) + best_ell*tf.eye(28*28, dtype=tf.double),
                     matvecmul(tf.transpose(A), tf.reshape(y, [-1])))
    fig = plot_recon(tf.reshape(x0, [28, 28]), x, tf.reshape(y, [28, 28]))
    savefig(fig, path, "reference_L2.png")


# %%
A, x, y = reconstruction_problem(plot=False, sigma=4)
sigma = 1e-2
y = tf.squeeze(y + tf.random.normal(y.shape, mean=0, stddev=sigma))
_, vae_good, _ = load_vae_models()
vae = vae_good
input_dict = {
    "A": tf.cast(A.real, dtype=tf.double),
    "x": tf.cast(tf.reshape(x, [28, 28]), dtype=tf.double),
    "y": tf.cast(y, dtype=tf.double),
    "sigma": sigma,
    "path": "probabilistic_vae_comparison/vae_quality/good/analytic_recon/"
}
# reference(**input_dict)
analytic_posterior(**input_dict)
# %%
import tensorflow as tf
import matplotlib.pyplot as plt

from utils import display_imgs, load_vae_models, savefig, load_mnist


# %%
def test_decoder_quality():
    vae_det, vae_good, vae_bad = load_vae_models()
    z = tf.random.normal([10, 20, 1, 1])
    x_det = vae_det.decoder(z)
    x_good, s_good = vae_good.decoder(z)
    x_bad, s_bad = vae_bad.decoder(z)
    print("x det")
    fig = display_imgs(x_det, return_fig=True)
    savefig(fig, "probabilistic_vae_comparison/vae_quality/det/", "random_z.png")
    print("x bad")
    fig = display_imgs(x_bad, return_fig=True)
    savefig(fig, "probabilistic_vae_comparison/vae_quality/bad/", "random_z.png")
    print("x good")
    fig = display_imgs(x_good, return_fig=True)
    savefig(fig, "probabilistic_vae_comparison/vae_quality/good/", "random_z.png")
    print("s bad")
    fig = display_imgs(s_bad, return_fig=True)
    savefig(fig, "probabilistic_vae_comparison/vae_quality/bad/", "random_z_sd.png")
    print("s good")
    fig = display_imgs(s_good, return_fig=True)
    savefig(fig, "probabilistic_vae_comparison/vae_quality/good/", "random_z_sd.png")


def test_encoder_decoder_quality():
    _, eval_dataset = load_mnist(shuffle=False)
    # Take some random x
    x = tf.reshape(next(iter(eval_dataset))[0][:8], (8, 28, 28))
    fig = display_imgs(x, return_fig=True)
    savefig(fig, "probabilistic_vae_comparison/vae_quality/det/", "original_x.png")
    savefig(fig, "probabilistic_vae_comparison/vae_quality/bad/", "original_x.png")
    savefig(fig, "probabilistic_vae_comparison/vae_quality/good/", "original_x.png")
    vae_det, vae_good, vae_bad = load_vae_models()
    z_det = vae_det.encoder(x)
    z_good = vae_good.encoder(x)
    z_bad = vae_bad.encoder(x)

    def plot_z(z):
        # the first 20 entries are the latent mean, the last 20 the standard deviation
        fig, axes = plt.subplots(2, 4, sharex=True, sharey=True)
        ax = axes.ravel()
        for lia in range(z.shape[0]):
            x_range = tf.range(20)
            ax[lia].plot(x_range, z[lia, :20], label=f"{lia+1}")
            ax[lia].fill_between(x_range, z[lia, :20] - z[lia, 20:],
                                 z[lia, :20] + z[lia, 20:], alpha=0.2)
            ax[lia].set_title(f"z[{lia+1}]")
        return fig

    fig = plot_z(z_det)
    savefig(fig, "probabilistic_vae_comparison/vae_quality/det/", "x_to_z.png")
    fig = plot_z(z_good)
    savefig(fig, "probabilistic_vae_comparison/vae_quality/good/", "x_to_z.png")
    fig = plot_z(z_bad)
    savefig(fig, "probabilistic_vae_comparison/vae_quality/bad/", "x_to_z.png")
    x_det = vae_det.decoder(z_det[:, :20])
    x_good, s_good = vae_good.decoder(z_good[:, :20])
    x_bad, s_bad = vae_bad.decoder(z_bad[:, :20])
    print("x det")
    fig = display_imgs(x_det, return_fig=True)
    savefig(fig, "probabilistic_vae_comparison/vae_quality/det/", "x_to_x.png")
    print("x bad")
    fig = display_imgs(x_bad, return_fig=True)
    savefig(fig, "probabilistic_vae_comparison/vae_quality/bad/", "x_to_x.png")
    print("x good")
    fig = display_imgs(x_good, return_fig=True)
    savefig(fig, "probabilistic_vae_comparison/vae_quality/good/", "x_to_x.png")