Skip to content
Snippets Groups Projects
Commit 355ddd5f authored by Manuel Marschall's avatar Manuel Marschall
Browse files

latpush solver implemented

parent 667578c6
Branches
No related tags found
No related merge requests found
...@@ -35,7 +35,7 @@ class LinearInverseProblem(): ...@@ -35,7 +35,7 @@ class LinearInverseProblem():
noise_sigma: float = 0.1): noise_sigma: float = 0.1):
from utils import reconstruction_problem from utils import reconstruction_problem
A, x, y = reconstruction_problem(sigma=blur_sigma) A, x, y = reconstruction_problem(sigma=blur_sigma)
return LinearInverseProblem(A, y.reshape(-1), noise_sigma, x.numpy().reshape(-1)) return LinearInverseProblem(A.real, y.reshape(-1), noise_sigma, x.numpy().reshape(-1))
if __name__ == "__main__": if __name__ == "__main__":
......
from numpy import ndarray as array_type
from numpy.linalg import solve
from numpy import eye as npeye
from numpy import logspace as nplogspace
from numpy import argmin as npargmin
from linear_inverse_problem import LinearInverseProblem
from generative_model import GenerativeModel
from utils import build_neighbour_matrix
class LIPSolver:
def __init__(self,
lip: LinearInverseProblem,
gm: GenerativeModel) -> None:
self.lip = lip
self.gm = gm
def solve(self) -> array_type:
# (A'A) x = A' y --> invert to solve
return solve(self.lip.operator.T.dot(self.lip.operator), self.lip.operator.T.dot(self.lip.data))
class LIPThikonovSolver(LIPSolver):
def __init__(self,
lip: LinearInverseProblem,
gm: GenerativeModel,
llambda: float) -> None:
super().__init__(lip, gm)
self.llambda = llambda
def solve(self) -> array_type:
return solve(self.lip.operator.T.dot(self.lip.operator) + self.llambda*npeye(self.gm.dim_x),
self.lip.operator.T.dot(self.lip.data))
def solve_oracle(self,
logspace_min: int = -6,
logspace_max: int = 2,
lospace_num: int = 1000) -> array_type:
lambda_list = nplogspace(logspace_min, logspace_max, num=logspace_max)
x_list = []
res_list = []
for ell in lambda_list:
self.llambda = ell
_x, _res = self.solve()
x_list.append(_x)
res_list.append(_res)
return x_list[npargmin(res_list)]
class LIPGMRFSolver(LIPThikonovSolver):
def __init__(self,
lip: LinearInverseProblem,
gm: GenerativeModel,
llambda: float) -> None:
super().__init__(lip, gm, llambda)
self.gmrf = build_neighbour_matrix(gm.dim_x).toarray()
def solve(self) -> array_type:
return solve(self.lip.operator.T.dot(self.lip.operator) + self.llambda*self.gmrf,
self.lip.operator.T.dot(self.lip.data))
\ No newline at end of file
...@@ -122,6 +122,7 @@ if __name__ == '__main__': ...@@ -122,6 +122,7 @@ if __name__ == '__main__':
z0 = tf.Variable(tf.random.normal([10, 20, 1, 1])) z0 = tf.Variable(tf.random.normal([10, 20, 1, 1]))
grad = tf_vae.J_z0_already_variable(z0) grad = tf_vae.J_z0_already_variable(z0)
assert grad.shape == (10, 1, 28, 28, 20, 1, 1) # previously: assert grad.shape == (10, 1, 28, 28, 20, 1, 1)
assert grad.shape == (10, 28, 28, 20, 1, 1)
print("DET and STO VAE decoder checked") print("DET and STO VAE decoder checked")
from typing import Any
from numpy.lib.arraysetops import isin
import tensorflow as tf
from lip_solver import LIPSolver
from linear_inverse_problem import LinearInverseProblem
from generative_model import GenerativeModel, ProbabilisticGenerator, DeterministicGenerator
from utils import plot_recon, savefig, plot_z_coverage, plot_convergence
def matvecmul(A, b):
# pylint: disable=unexpected-keyword-arg, redundant-keyword-arg, no-value-for-parameter
return tf.reshape(tf.matmul(tf.cast(A, dtype=tf.double), tf.expand_dims(tf.cast(b, dtype=tf.double), 1)), [-1])
def matvecsolve(A, b):
# pylint: disable=unexpected-keyword-arg, redundant-keyword-arg, no-value-for-parameter
return tf.reshape(tf.linalg.solve(tf.cast(A, dtype=tf.double),
tf.expand_dims(tf.cast(b, dtype=tf.double), 1)), [-1])
class VAELIPSolver(LIPSolver):
def __init__(self,
lip: LinearInverseProblem,
gm: GenerativeModel,
export_path: str) -> None:
super().__init__(lip, gm)
self.path = export_path
self.path_suffix = "generic/"
def get_initial_value(self, iteration: int = 10) -> Any:
z0 = tf.zeros([1, 20, 1, 1], dtype=tf.float64)
A = self.lip.operator
sigma = self.lip.sigma
counter = 0
while True:
if counter >= iteration:
break
counter += 1
if isinstance(self.gm, ProbabilisticGenerator):
g, gamma = self.gm(z0)
g = tf.reshape(g, [-1])
gamma = tf.reshape(gamma, [-1])
elif isinstance(self.gm, DeterministicGenerator):
g = self.gm(z0)
g = tf.reshape(g, [-1])
gamma = tf.ones(self.gm.dim_x, dtype=tf.double)
else:
raise ValueError(f"Unknown generative model type: {type(self.gm)}")
lhs = tf.matmul(tf.transpose(A), A)/(sigma*sigma) + tf.linalg.diag(1/gamma)
rhs = matvecmul(tf.transpose(A), tf.reshape(self.lip.data, [-1]))/(sigma*sigma) + \
matvecmul(tf.linalg.diag(1/gamma), g)
x0 = matvecsolve(lhs, rhs)
z0 = tf.reshape(self.gm.vae.encoder(tf.reshape(x0, [28, 28]))[:, :20], [20])
return z0
class VAELatpushSolver(VAELIPSolver):
def __init__(self, lip: LinearInverseProblem, gm: GenerativeModel, export_path: str) -> None:
super().__init__(lip, gm, export_path)
self.path_suffix = "latpush/"
def solve(self,
x0_iteration: int = 1,
num_iteration: int = 10000,
loss_tolerance: float = 1e-1) -> dict:
oracle_z0 = tf.squeeze(self.gm.vae.encoder(tf.reshape(self.lip.ground_truth, [28, 28]))[0, :self.gm.dim_z]).numpy()
z0 = tf.Variable(self.get_initial_value(iteration=x0_iteration))
# z0 = tf.Variable(tf.zeros(20))
y = tf.reshape(self.lip.data, [-1])
loss_tolerance = 1e-1
mse_list = []
psnr_list = []
ssim_list = []
loss_list = []
def neg_log_post():
if isinstance(self.gm, ProbabilisticGenerator):
g, _ = self.gm(z0)
elif isinstance(self.gm, DeterministicGenerator):
g = self.gm(z0)
else:
raise ValueError(f"Unknown generative model type: {type(self.gm)}")
g = tf.reshape(g, [-1])
s2 = (1/(self.lip.sigma*self.lip.sigma))
loss1 = s2*tf.square(tf.linalg.norm(matvecmul(self.lip.operator, g) - y))
# pylint: disable=unexpected-keyword-arg, redundant-keyword-arg, no-value-for-parameter
loss2 = tf.cast(tf.reduce_sum(z0*z0), dtype=tf.double)
retval = (loss1 + loss2)
return retval
opt = tf.keras.optimizers.Adam()
for lia in range(num_iteration):
opt.minimize(neg_log_post, var_list=[z0])
if lia % 100 == 0 or lia == 0:
if isinstance(self.gm, ProbabilisticGenerator):
x0 = tf.reshape(self.gm(z0)[0], [-1])
elif isinstance(self.gm, DeterministicGenerator):
x0 = tf.reshape(self.gm(z0), [-1])
else:
raise ValueError(f"Unknown generative model type: {type(self.gm)}")
fig, mse, psnr, ssim = plot_recon(tf.reshape(x0, [28, 28]),
tf.cast(tf.reshape(self.lip.ground_truth, [28, 28]), dtype=tf.double),
tf.reshape(y, [28, 28]), return_stats=True)
mse_list.append(mse)
psnr_list.append(psnr)
ssim_list.append(ssim)
loss_list.append(neg_log_post())
print(f"Iteration {lia}/{num_iteration}, loss: {loss_list[-1]}")
savefig(fig, self.path + self.path_suffix, f"image_{lia}.png")
fig = plot_z_coverage(self.gm.vae, x0, oracle_z0)
savefig(fig, self.path + self.path_suffix, f"z0values_{lia}.png")
if lia > 100 and tf.abs(loss_list[-2] - loss_list[-1]) < loss_tolerance:
break
for ll, label in zip([mse_list, psnr_list, ssim_list, loss_list], ["mse", "psnr", "ssim", "loss"]):
fig = plot_convergence(ll, title=label)
savefig(fig, self.path + self.path_suffix, f"{label}_conv.png")
retval = {
"MSE": mse_list,
"PSNR": psnr_list,
"SSIM": ssim_list,
"LOSS": loss_list
}
return retval
if __name__ == "__main__":
from onnx_vae import ONNX_VAE_STO, ONNX_VAE_DET
path_det = "onnx_models/deterministic/"
path_prob = "onnx_models/stochastic_good/"
path_add = "probabilistic_vae_comparison/"
onnx_vae1 = ONNX_VAE_DET(path_add + path_det + "encoder.onnx",
path_add + path_det + "decoder.onnx")
tf_vae_det = onnx_vae1.to_tensorflow()
det_generator = DeterministicGenerator.from_vae(tf_vae_det, int(28*28), 20)
onnx_vae2 = ONNX_VAE_STO(path_add + path_prob + "good_probVAE_encoder.onnx",
path_add + path_prob + "good_probVAE_decoder.onnx")
tf_vae_good = onnx_vae2.to_tensorflow()
prop_generator = ProbabilisticGenerator.from_vae(tf_vae_good, int(28*28), 20)
lip = LinearInverseProblem.mnist_recon_problem()
solver = VAELatpushSolver(lip, det_generator, "test")
# solver = VAELatpushSolver(lip, prop_generator, "test")
solver.solve()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment