Commit 355ddd5f authored by Manuel Marschall's avatar Manuel Marschall
Browse files

latpush solver implemented

parent 667578c6
......@@ -35,7 +35,7 @@ class LinearInverseProblem():
noise_sigma: float = 0.1):
from utils import reconstruction_problem
A, x, y = reconstruction_problem(sigma=blur_sigma)
return LinearInverseProblem(A, y.reshape(-1), noise_sigma, x.numpy().reshape(-1))
return LinearInverseProblem(A.real, y.reshape(-1), noise_sigma, x.numpy().reshape(-1))
if __name__ == "__main__":
......
from numpy import ndarray as array_type
from numpy.linalg import solve
from numpy import eye as npeye
from numpy import logspace as nplogspace
from numpy import argmin as npargmin
from linear_inverse_problem import LinearInverseProblem
from generative_model import GenerativeModel
from utils import build_neighbour_matrix
class LIPSolver:
def __init__(self,
lip: LinearInverseProblem,
gm: GenerativeModel) -> None:
self.lip = lip
self.gm = gm
def solve(self) -> array_type:
# (A'A) x = A' y --> invert to solve
return solve(self.lip.operator.T.dot(self.lip.operator), self.lip.operator.T.dot(self.lip.data))
class LIPThikonovSolver(LIPSolver):
def __init__(self,
lip: LinearInverseProblem,
gm: GenerativeModel,
llambda: float) -> None:
super().__init__(lip, gm)
self.llambda = llambda
def solve(self) -> array_type:
return solve(self.lip.operator.T.dot(self.lip.operator) + self.llambda*npeye(self.gm.dim_x),
self.lip.operator.T.dot(self.lip.data))
def solve_oracle(self,
logspace_min: int = -6,
logspace_max: int = 2,
lospace_num: int = 1000) -> array_type:
lambda_list = nplogspace(logspace_min, logspace_max, num=logspace_max)
x_list = []
res_list = []
for ell in lambda_list:
self.llambda = ell
_x, _res = self.solve()
x_list.append(_x)
res_list.append(_res)
return x_list[npargmin(res_list)]
class LIPGMRFSolver(LIPThikonovSolver):
def __init__(self,
lip: LinearInverseProblem,
gm: GenerativeModel,
llambda: float) -> None:
super().__init__(lip, gm, llambda)
self.gmrf = build_neighbour_matrix(gm.dim_x).toarray()
def solve(self) -> array_type:
return solve(self.lip.operator.T.dot(self.lip.operator) + self.llambda*self.gmrf,
self.lip.operator.T.dot(self.lip.data))
\ No newline at end of file
......@@ -122,6 +122,7 @@ if __name__ == '__main__':
z0 = tf.Variable(tf.random.normal([10, 20, 1, 1]))
grad = tf_vae.J_z0_already_variable(z0)
assert grad.shape == (10, 1, 28, 28, 20, 1, 1)
# previously: assert grad.shape == (10, 1, 28, 28, 20, 1, 1)
assert grad.shape == (10, 28, 28, 20, 1, 1)
print("DET and STO VAE decoder checked")
from typing import Any
from numpy.lib.arraysetops import isin
import tensorflow as tf
from lip_solver import LIPSolver
from linear_inverse_problem import LinearInverseProblem
from generative_model import GenerativeModel, ProbabilisticGenerator, DeterministicGenerator
from utils import plot_recon, savefig, plot_z_coverage, plot_convergence
def matvecmul(A, b):
# pylint: disable=unexpected-keyword-arg, redundant-keyword-arg, no-value-for-parameter
return tf.reshape(tf.matmul(tf.cast(A, dtype=tf.double), tf.expand_dims(tf.cast(b, dtype=tf.double), 1)), [-1])
def matvecsolve(A, b):
# pylint: disable=unexpected-keyword-arg, redundant-keyword-arg, no-value-for-parameter
return tf.reshape(tf.linalg.solve(tf.cast(A, dtype=tf.double),
tf.expand_dims(tf.cast(b, dtype=tf.double), 1)), [-1])
class VAELIPSolver(LIPSolver):
def __init__(self,
lip: LinearInverseProblem,
gm: GenerativeModel,
export_path: str) -> None:
super().__init__(lip, gm)
self.path = export_path
self.path_suffix = "generic/"
def get_initial_value(self, iteration: int = 10) -> Any:
z0 = tf.zeros([1, 20, 1, 1], dtype=tf.float64)
A = self.lip.operator
sigma = self.lip.sigma
counter = 0
while True:
if counter >= iteration:
break
counter += 1
if isinstance(self.gm, ProbabilisticGenerator):
g, gamma = self.gm(z0)
g = tf.reshape(g, [-1])
gamma = tf.reshape(gamma, [-1])
elif isinstance(self.gm, DeterministicGenerator):
g = self.gm(z0)
g = tf.reshape(g, [-1])
gamma = tf.ones(self.gm.dim_x, dtype=tf.double)
else:
raise ValueError(f"Unknown generative model type: {type(self.gm)}")
lhs = tf.matmul(tf.transpose(A), A)/(sigma*sigma) + tf.linalg.diag(1/gamma)
rhs = matvecmul(tf.transpose(A), tf.reshape(self.lip.data, [-1]))/(sigma*sigma) + \
matvecmul(tf.linalg.diag(1/gamma), g)
x0 = matvecsolve(lhs, rhs)
z0 = tf.reshape(self.gm.vae.encoder(tf.reshape(x0, [28, 28]))[:, :20], [20])
return z0
class VAELatpushSolver(VAELIPSolver):
def __init__(self, lip: LinearInverseProblem, gm: GenerativeModel, export_path: str) -> None:
super().__init__(lip, gm, export_path)
self.path_suffix = "latpush/"
def solve(self,
x0_iteration: int = 1,
num_iteration: int = 10000,
loss_tolerance: float = 1e-1) -> dict:
oracle_z0 = tf.squeeze(self.gm.vae.encoder(tf.reshape(self.lip.ground_truth, [28, 28]))[0, :self.gm.dim_z]).numpy()
z0 = tf.Variable(self.get_initial_value(iteration=x0_iteration))
# z0 = tf.Variable(tf.zeros(20))
y = tf.reshape(self.lip.data, [-1])
loss_tolerance = 1e-1
mse_list = []
psnr_list = []
ssim_list = []
loss_list = []
def neg_log_post():
if isinstance(self.gm, ProbabilisticGenerator):
g, _ = self.gm(z0)
elif isinstance(self.gm, DeterministicGenerator):
g = self.gm(z0)
else:
raise ValueError(f"Unknown generative model type: {type(self.gm)}")
g = tf.reshape(g, [-1])
s2 = (1/(self.lip.sigma*self.lip.sigma))
loss1 = s2*tf.square(tf.linalg.norm(matvecmul(self.lip.operator, g) - y))
# pylint: disable=unexpected-keyword-arg, redundant-keyword-arg, no-value-for-parameter
loss2 = tf.cast(tf.reduce_sum(z0*z0), dtype=tf.double)
retval = (loss1 + loss2)
return retval
opt = tf.keras.optimizers.Adam()
for lia in range(num_iteration):
opt.minimize(neg_log_post, var_list=[z0])
if lia % 100 == 0 or lia == 0:
if isinstance(self.gm, ProbabilisticGenerator):
x0 = tf.reshape(self.gm(z0)[0], [-1])
elif isinstance(self.gm, DeterministicGenerator):
x0 = tf.reshape(self.gm(z0), [-1])
else:
raise ValueError(f"Unknown generative model type: {type(self.gm)}")
fig, mse, psnr, ssim = plot_recon(tf.reshape(x0, [28, 28]),
tf.cast(tf.reshape(self.lip.ground_truth, [28, 28]), dtype=tf.double),
tf.reshape(y, [28, 28]), return_stats=True)
mse_list.append(mse)
psnr_list.append(psnr)
ssim_list.append(ssim)
loss_list.append(neg_log_post())
print(f"Iteration {lia}/{num_iteration}, loss: {loss_list[-1]}")
savefig(fig, self.path + self.path_suffix, f"image_{lia}.png")
fig = plot_z_coverage(self.gm.vae, x0, oracle_z0)
savefig(fig, self.path + self.path_suffix, f"z0values_{lia}.png")
if lia > 100 and tf.abs(loss_list[-2] - loss_list[-1]) < loss_tolerance:
break
for ll, label in zip([mse_list, psnr_list, ssim_list, loss_list], ["mse", "psnr", "ssim", "loss"]):
fig = plot_convergence(ll, title=label)
savefig(fig, self.path + self.path_suffix, f"{label}_conv.png")
retval = {
"MSE": mse_list,
"PSNR": psnr_list,
"SSIM": ssim_list,
"LOSS": loss_list
}
return retval
if __name__ == "__main__":
from onnx_vae import ONNX_VAE_STO, ONNX_VAE_DET
path_det = "onnx_models/deterministic/"
path_prob = "onnx_models/stochastic_good/"
path_add = "probabilistic_vae_comparison/"
onnx_vae1 = ONNX_VAE_DET(path_add + path_det + "encoder.onnx",
path_add + path_det + "decoder.onnx")
tf_vae_det = onnx_vae1.to_tensorflow()
det_generator = DeterministicGenerator.from_vae(tf_vae_det, int(28*28), 20)
onnx_vae2 = ONNX_VAE_STO(path_add + path_prob + "good_probVAE_encoder.onnx",
path_add + path_prob + "good_probVAE_decoder.onnx")
tf_vae_good = onnx_vae2.to_tensorflow()
prop_generator = ProbabilisticGenerator.from_vae(tf_vae_good, int(28*28), 20)
lip = LinearInverseProblem.mnist_recon_problem()
solver = VAELatpushSolver(lip, det_generator, "test")
# solver = VAELatpushSolver(lip, prop_generator, "test")
solver.solve()
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment