Commit 801b0b2f authored by Manuel Marschall's avatar Manuel Marschall

some clean up and refactoring

parent 10f75c93
Pipeline #6060 failed with stages in 14 minutes and 40 seconds
......@@ -5,4 +5,6 @@ experiments_no17/*
logs/*
onnx_models/*
paper/*
train_vae_test_no17/*
\ No newline at end of file
train_vae_test_no17/*
*__pycache__*
......@@ -16,7 +16,7 @@ mypy:
- pip install -r requirements.txt
- apt-get update -qq && apt-get install -y -qq pandoc
- python3 -m pip install --upgrade mypy
- python -m mypy datainformed-prior
- python -m mypy datainformed_prior
flake8:
stage: Static Analysis
......@@ -26,7 +26,7 @@ flake8:
- pip install -r requirements.txt
- apt-get update -qq && apt-get install -y -qq pandoc
- python3 -m pip install --upgrade flake8
- flake8 --max-line-length=120 datainformed-prior/*.py
- flake8 --max-line-length=120 datainformed_prior/*.py
pylint:
stage: Static Analysis
......@@ -36,7 +36,7 @@ pylint:
- pip install -r requirements.txt
- apt-get update -qq && apt-get install -y -qq pandoc
- python3 -m pip install --upgrade pylint
- pylint -d C0301 datainformed-prior/*.py
- pylint -d C0301 datainformed_prior/*.py
unit_test:
stage: Test
......@@ -47,7 +47,7 @@ unit_test:
- apt-get update -qq && apt-get install -y -qq pandoc
- python3 -m pip install --upgrade pytest pytest-cov
- python3 -m pytest
- python3 -m pytest --cov-report term-missing --cov=datainformed-prior tests/
- python3 -m pytest --cov-report term-missing --cov=datainformed_prior tests/
# pypi:
# stage: upload
......
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.linalg import dft
from scipy.fftpack import fftshift
import scipy.stats as st
from scipy.sparse import coo_matrix, lil_matrix
import tensorflow as tf
import tensorflow_datasets as tfds
from skimage.metrics import structural_similarity as ssim
from onnx_vae import ONNX_VAE_DET, ONNX_VAE_STO, ONNX_VAE_BAD
from generative_model import DeterministicGenerator, ProbabilisticGenerator
tfk = tf.keras
tfkl = tf.keras.layers
def load_mnist(shuffle=True, bs=256):
@tf.autograph.experimental.do_not_convert
def _preprocess(sample):
image = tf.cast(sample['image'], tf.float32) / 255. # Scale to unit interval.
# image = 2*image - 1 # Scale to [-1, 1] simpler to learn
# Binarizing is sometimes done for some reason. We do not do this here.
# image = image < tf.random.uniform(tf.shape(image)) # Randomly binarize.
return image, image
# Downloads the MNIST data automatically if not found on your system
datasets = tfds.load(name='mnist', with_info=False, as_supervised=False, shuffle_files=shuffle)
# process and create batches for the training and test data
train_dataset = (datasets['train'].map(_preprocess).batch(bs).prefetch(tf.data.experimental.AUTOTUNE))
eval_dataset = (datasets['test'].map(_preprocess).batch(bs).prefetch(tf.data.experimental.AUTOTUNE))
return train_dataset, eval_dataset
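# Hedged usage sketch: `_demo_load_mnist` is an illustrative helper, not part
# of the original module. It pulls one batch from the pipeline above; tfds
# downloads MNIST on first use.
def _demo_load_mnist():
    _, eval_ds = load_mnist(shuffle=False, bs=8)
    images, _ = next(iter(eval_ds))  # _preprocess yields (image, image) pairs
    return images.shape  # expected: (8, 28, 28, 1)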
def display_imgs(_x, y=None, return_fig=False):
if not return_fig:
plt.ioff()
if not isinstance(_x, (np.ndarray, np.generic)):
_x = np.array(_x)
if len(_x.shape) == 2:
_x = np.reshape(_x, (-1, 28, 28))
n = _x.shape[0]
fig, axs = plt.subplots(1, n, figsize=(n, 1))
if y is not None:
fig.suptitle(np.argmax(y, axis=1))
if n > 1:
for i in range(n):
axs.flat[i].imshow(_x[i].squeeze(), cmap='gray', vmin=0, vmax=1)
axs.flat[i].axis('off')
else:
axs.imshow(_x.squeeze(), cmap="gray")
if return_fig:
return fig
plt.show()
plt.close()
plt.ion()
def savefig(fig, path, filename):
if not os.path.exists(path):
os.makedirs(path)
fig.savefig(os.path.join(path, filename))
def MSE(x, xh):
"""Mean squared error between two images."""
return (np.linalg.norm(x-xh)**2)/np.prod(x.shape)
def PSNR(x, xh):
"""Peak signal-to-noise ratio for intensities in [0, 1] (peak value 1, so the first term is 0)."""
return (20*np.log10(1)-10*np.log10(MSE(x, xh)))
def SSIM(x, xh):
if not isinstance(x, np.ndarray):
x = x.numpy()
if not isinstance(xh, np.ndarray):
xh = xh.numpy()
return ssim(x, xh, data_range=x.max() - x.min())
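# Hedged sanity check of the three metrics above; `_demo_metrics` and its
# noise level are illustrative, not taken from the experiments in this repo.
def _demo_metrics(seed=0):
    rng = np.random.default_rng(seed)
    x = rng.random((28, 28))
    xh = np.clip(x + 0.05 * rng.standard_normal((28, 28)), 0.0, 1.0)
    # for unit-range images, PSNR reduces to -10*log10(MSE)
    return MSE(x, xh), PSNR(x, xh), SSIM(x, xh)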
def plot_z_coverage(vae, x0, opt_z0, _title=None):
fig = plt.figure()
curr_grid = range(20)
# curr_z0 = np.squeeze(z0.numpy())
if _title is not None:
plt.title(_title)
z0_a = vae.encoder(tf.reshape(x0, [28, 28]))
mu = z0_a[0, :20]
rho = z0_a[0, 20:]
curr_mean = tf.squeeze(mu).numpy()
sd = tf.squeeze(tf.math.exp(0.5*rho)).numpy()
plt.plot(curr_grid, curr_mean, 'or', label="Mean")
plt.plot(curr_grid, opt_z0, '*', color="black", label="Dec-Mean(x_true)")
plt.fill_between(curr_grid, curr_mean - 1.98*sd, curr_mean + 1.98*sd, label="+-1.98*sd", alpha=0.5)
# for lia in range(20):
# if curr_mean[lia] - 1.98*sd[lia] <= mean_ast[lia] <= curr_mean[lia] + 1.98*sd[lia]:
# plt.plot(curr_grid[lia], mean_ast[lia], 'xb', label="z0" if lia == 0 else "")
# else:
# plt.plot(curr_grid[lia], curr_z0[lia], 'x', color="orange", label="z0" if lia == 0 else "")
plt.legend()
plt.ylim([-4, 4])
return fig
def plot_recon(loc_x, x_true, yh, _title=None, return_stats=False):
fig, ax = plt.subplots(1, 4, figsize=(16, 3))
im = ax[0].imshow(x_true, cmap="jet", vmin=0, vmax=1)
ax[0].set_title("orig")
plt.colorbar(im, ax=ax[0])
im = ax[1].imshow(yh, cmap="jet", vmin=0, vmax=1)
ax[1].set_title("Observ.")
plt.colorbar(im, ax=ax[1])
im = ax[2].imshow(loc_x, cmap="jet", vmin=0, vmax=1)
plt.colorbar(im, ax=ax[2])
im = ax[3].imshow(np.abs(loc_x - x_true), cmap="jet", vmin=0, vmax=2)
ax[3].set_title("difference")
plt.colorbar(im, ax=ax[3])
mse = MSE(x_true.numpy(), loc_x.numpy())
psnr = PSNR(x_true.numpy(), loc_x.numpy())
ssim_val = SSIM(x_true.numpy(), loc_x.numpy())  # avoid shadowing the imported `ssim`
if _title is None:
_title = f"RMSE/PSNR/SSIM: {np.sqrt(mse):.2f}/{psnr:.2f}/{ssim_val:.2f}"
ax[2].set_title(_title)
print(f"RMSE/PSNR/SSIM: {np.sqrt(mse):.2f}/{psnr:.2f}/{ssim_val:.2f}")
if return_stats:
return fig, mse, psnr, ssim_val
return fig
def plot_convergence(y, title=None):
fig = plt.figure()
plt.plot(y)
if title is not None:
plt.title(title)
return fig
def build_neighbour_matrix(n, m):
"""
Generates a matrix of size (n*m x n*m) indicating the number of neighbor
of a pixel by calling neig_metric for every pixel/ vertex combination
Arguments:
n {int} -- number of vertices in x direction
m {int} -- number of vertices in y direction
Returns:
sparse lil-matrix -- neighbor matrix
"""
X = np.arange(0, int(n*m)).reshape([int(n), int(m)], order="F")
# nrc = n*m
ic = np.zeros([n+2, m+2])
ic[1:-1, 1:-1] = X
Ind = np.ones([n+2, m+2], dtype=bool)  # np.bool was removed in NumPy 1.24
Ind[0, :] = 0
Ind[-1, :] = 0
Ind[:, 0] = 0
Ind[:, -1] = 0
icd = np.zeros([np.prod(n*m), 8])
icd[:, 0] = ic[np.roll(Ind, 1, axis=1)] # shift right
icd[:, 1] = ic[np.roll(Ind, 1, axis=0)] # shift down
icd[:, 2] = ic[np.roll(Ind, -1, axis=1)] # shift left
icd[:, 3] = ic[np.roll(Ind, -1, axis=0)] # shift up
# shift up and right
icd[:, 4] = ic[np.roll(np.roll(Ind, 1, axis=1), -1, axis=0)]
# shift up and left
icd[:, 5] = ic[np.roll(np.roll(Ind, -1, axis=1), -1, axis=0)]
# shift down and right
icd[:, 6] = ic[np.roll(np.roll(Ind, 1, axis=1), 1, axis=0)]
# shift down and left
icd[:, 7] = ic[np.roll(np.roll(Ind, -1, axis=1), 1, axis=0)]
ic = np.tile(ic[Ind].reshape(-1, order="F"), (8, 1)).ravel(order="C")
icd = icd.reshape(-1, order="F")
data = np.ones([len(icd), 1]).ravel(order="F")
Kcol_A_py = coo_matrix((data, (ic, icd)), shape=[int(n*m), int(n*m)])
su = np.sum(Kcol_A_py, axis=0).T
Kcol_A_py = lil_matrix(-Kcol_A_py)
Kcol_A_py[2:, 0] = 0
if m > 1:
Kcol_A_py[n, 0] = -1
Kcol_A_py[n+1, 0] = -1
Kcol_A_py[1, 0] = -1
Kcol_A_py.setdiag(su, 0)
if m > 1:
Kcol_A_py[0, 0] = 3
else:
Kcol_A_py[0, 0] = 1
dist = Kcol_A_py
return dist
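# Hedged usage sketch: `do_recon_experiment` below turns this neighbour
# (graph-Laplacian-style) matrix into a GMRF regularization matrix via an
# SVD square root; `_demo_gmrf_sqrt` is an illustrative helper.
def _demo_gmrf_sqrt(n=28, m=28):
    gmrf = build_neighbour_matrix(n, m).toarray()
    _u, _s, _v = np.linalg.svd(gmrf, full_matrices=False)
    return _u.dot(np.diag(_s**0.5).dot(_v))  # matrix square root of the Laplacian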
def blur_matrix_operator(kernlen=28, _sigma=4):
"""Returns a matrix which acts as a Gaussian blur operator"""
def gkern(kernlen=28, nsig=9):
"""Returns a 2D Gaussian kernel."""
x = np.linspace(-nsig, nsig, kernlen+1)
kern1d = np.diff(st.norm.cdf(x))
kern2d = np.outer(kern1d, kern1d)
return kern2d/kern2d.sum()
_F = dft(28, scale="sqrtn") # size (N, N) - 1D DFT
_F = np.kron(_F, _F) # stacking dimensions, not (N, N, N, N) but (N*N, N*N)
_K = _F.dot(fftshift(gkern(kernlen, _sigma)).ravel())  # shift the centred kernel to the origin, then transform
return np.linalg.inv(_F).dot(np.diag(_K.ravel())).dot(_F) * np.sqrt(kernlen**2)
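# Hedged usage sketch (`_demo_blur` is an illustrative helper): the operator
# is complex-valued, so the real part is taken after applying it, as in
# `reconstruction_problem` below.
def _demo_blur(x, sigma=4):
    A = blur_matrix_operator(_sigma=sigma)
    return A.dot(np.asarray(x).ravel()).real.reshape(28, 28)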
def derivatives(_xx, _yy):
d = np.array(np.diff(_yy)/np.diff(_xx))
dd = np.array(np.diff(d)/np.diff(_xx[1:]))
return np.array(_xx), np.insert(d, 0, d[0]), np.insert(dd, 0, [dd[0], dd[0]])
def solve_lsq_simple(_mat, _rhs, noise=None):
nn = 1 if noise is None else 1/noise
return np.linalg.solve(nn*_mat.T.dot(nn*_mat), nn*_mat.T.dot(nn*_rhs)).reshape(28, 28).real
def solve_regumat_lsq(_mat, _rhs, _regumat, _lmbda, noise=None):
nn = 1 if noise is None else 1/noise
B = np.vstack([nn*_mat, _lmbda*_regumat])
yh0 = np.hstack([nn*_rhs, np.zeros(_rhs.shape[0]).ravel()])
return np.linalg.solve(B.T.dot(B), B.T.dot(yh0)).reshape(28, 28).real
def solve_regumat_regumean_lsq(_mat, _rhs, _regumat, _regumean, _lmbda, noise=None):
nn = 1 if noise is None else 1/noise
B = np.vstack([nn*_mat, _lmbda*_regumat])
yh0 = np.hstack([nn*_rhs, _lmbda*_regumat.dot(_regumean)])
return np.linalg.solve(B.T.dot(B), B.T.dot(yh0)).reshape(28, 28).real
def solve_lsq(_mat, _rhs, _regumat=None, _regumean=None, _lmbda=1, noise=None):
if _regumat is None and _regumean is None:
return solve_lsq_simple(_mat, _rhs, noise=noise)
if _regumat is not None and _regumean is None:
return solve_regumat_lsq(_mat, _rhs, _regumat, _lmbda, noise=noise)
return solve_regumat_regumean_lsq(_mat, _rhs, _regumat, _regumean, _lmbda, noise=noise)
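# Hedged sketch of the dispatch above (assumes a square 784x784 operator such
# as the one from `blur_matrix_operator`; `_demo_solvers` is illustrative):
def _demo_solvers(A, y):
    x_plain = solve_lsq(A, y)                              # unregularized
    x_l2 = solve_lsq(A, y, _regumat=np.eye(A.shape[0]))    # penalizes ||x||
    x_mean = solve_lsq(A, y, _regumat=np.eye(A.shape[0]),
                       _regumean=0.5*np.ones(A.shape[1]))  # penalizes ||x - m||
    return x_plain, x_l2, x_mean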
def solve(_mat, _rhs, _regumat=None, _regumean=None, verbose=False, lnum=100, noise=None):
if _regumat is None and _regumean is None:
return solve_lsq(_mat, _rhs, noise=noise)
assert _regumat is not None
LogNtoS = 1.1  # logarithmic noise-to-signal ratio
# guess for the L-curve scan limits
SearchInterval = 3  # logarithmic search interval ratio
alpha1 = -2*LogNtoS-2*SearchInterval+np.log10(np.max(_rhs)**2)  # lower limit of the L-curve scan
alpha2 = -2*LogNtoS+2*SearchInterval+np.log10(np.max(_rhs)**2)  # upper limit of the L-curve scan
lmd_range = np.logspace(alpha1, alpha2, num=lnum)
res = np.zeros(len(lmd_range))
reg = np.zeros(len(lmd_range))
for lia, lmd in enumerate(lmd_range):
if verbose and (lia % 10 == 0):
from IPython import display
display.clear_output(wait=True)
print("Do L-Curve")
print(" {}/{}".format(lia, len(lmd_range)))
xl = solve_lsq(_mat, _rhs, _regumat, _regumean, lmd, noise=noise).ravel()
res[lia] = np.linalg.norm(_mat.dot(xl) - _rhs)
if _regumean is None:
reg[lia] = np.linalg.norm(_regumat.dot(xl))
else:
reg[lia] = np.linalg.norm(_regumat.dot(xl - _regumean))
lognn2 = np.log10(res**2)
logrr2 = np.log10(reg**2)
a, drho, ddrho = derivatives(lmd_range, logrr2)
a, deta, ddeta = derivatives(lmd_range, lognn2)
a = lmd_range
# curvature of the L-curve; its maximum marks the corner, i.e. the best
# trade-off between residual norm and regularization norm
kappa = 2*(drho*ddeta-deta*ddrho)/(drho**2+deta**2)**(3/2)
a = a.astype(float)
p = np.argsort(kappa)[::-1]  # indices by descending curvature
aopt_ind = p[0]
aopt = a[aopt_ind]
if verbose:
plt.figure()
plt.loglog(res, reg)
plt.plot(res[aopt_ind], reg[aopt_ind], 'xr')
return solve_lsq(_mat, _rhs, _regumat, _regumean, aopt, noise=noise)
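# Hedged usage sketch: let the L-curve search above pick the regularization
# weight for a plain Tikhonov prior (`_demo_lcurve` is illustrative only).
def _demo_lcurve(A, y):
    return solve(A, y, _regumat=np.eye(A.shape[0]), lnum=50)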
def solve_from_mean_std(_mat, _rhs, _mean, _std, lmd=None):
# note the -0.5 power: this is the inverse square root (whitening matrix)
# of the diagonal covariance, used as regularization matrix
_u, _s, _v = np.linalg.svd(np.diag(_std.ravel()**2), full_matrices=False)
sqrtCov = _v.T.dot(np.diag(_s**-0.5).dot(_u.T))
if lmd is None:
return solve(_mat, _rhs, sqrtCov, _mean.ravel())
else:
return solve_lsq(_mat, _rhs, sqrtCov, _mean.ravel(), lmd)
def solve_from_mean_cov(_mat, _rhs, _mean, _cov, lmd=None):
# inverse square root of the full covariance (whitening), cf. above
_u, _s, _v = np.linalg.svd(_cov, full_matrices=False)
sqrtCov = _v.T.dot(np.diag(_s**-0.5).dot(_u.T))
if lmd is None:
return solve(_mat, _rhs, sqrtCov, _mean.ravel())
else:
return solve_lsq(_mat, _rhs, sqrtCov, _mean.ravel(), lmd)
def solve_from_mean_prec_cov(_mat, _rhs, _mean, _cov, lmd=None, noise=None):
# here `_cov` is a precision matrix, so the +0.5 power yields its square root
_u, _s, _v = np.linalg.svd(_cov, full_matrices=False)
sqrtCov = _u.dot(np.diag(_s**0.5).dot(_v))
if lmd is None:
return solve(_mat, _rhs, sqrtCov, _mean.ravel(), noise=noise)
else:
return solve_lsq(_mat, _rhs, sqrtCov, _mean.ravel(), lmd, noise=noise)
def solve_recon_vae(vae, Amat, yh, encoded_size, regu=1e-10, lmd=None, _iv=None):
if _iv is None:
# iv = tfd.Independent(tfd.Normal(loc=tf.zeros(encoded_size), scale=1), reinterpreted_batch_ndims=1)
iv = np.random.normal(loc=np.zeros(encoded_size), scale=1)
prior = vae.decoder(iv)  # unused in this branch
# standard-normal prior in image space: zero mean, unit stddev
mean = np.zeros((28, 28))
stddev = np.ones((28, 28)) + regu*np.ones((28, 28))
else:
iv = vae.encoder(tf.reshape(tf.convert_to_tensor(_iv), (1, 28, 28, 1)))
prior = vae.decoder(iv.sample(1000))
mean = tf.reduce_mean(prior.mean(), axis=0).numpy().squeeze().reshape(28, 28)
stddev = tf.reduce_mean(prior.stddev(), axis=0).numpy().squeeze().reshape(28, 28) + regu*np.ones((28, 28))
return solve_from_mean_std(Amat, yh.ravel(), _mean=mean, _std=(stddev.ravel()), lmd=lmd)
def do_recon_experiment(x, vae, encoded_size):
Amat = blur_matrix_operator()
yh = Amat.dot(x.ravel()).reshape(28, 28).real
x_plain = solve(Amat, yh.ravel())
x_l2 = solve(Amat, yh.ravel(), np.eye(Amat.shape[0]))
gmrf = build_neighbour_matrix(28, 28).toarray()
_u, _s, _v = np.linalg.svd(gmrf, full_matrices=False)
sqrtCov = _u.dot(np.diag(_s**0.5).dot(_v))
x_gmrf = solve(Amat, yh.ravel(), sqrtCov)
x_vae_normalprior = solve_recon_vae(vae, Amat, yh, encoded_size)
x_vae_l2prior = solve_recon_vae(vae, Amat, yh, encoded_size, _iv=x_l2)
retval = {
"x": x,
"Amat": Amat,
"yh": yh,
"x_plain": x_plain,
"x_l2": x_l2,
"x_gmrf": x_gmrf,
"x_vae_normalprior": x_vae_normalprior,
"x_vae_l2prior": x_vae_l2prior
}
return retval
def reconstruction_problem(plot=False, sigma=4):
A = blur_matrix_operator(_sigma=sigma)
_, eval_dataset = load_mnist(shuffle=False)
# Take some random x
x = next(iter(eval_dataset))[0][8]
y = A.dot(x.numpy().ravel()).real.reshape(x.shape)
if plot:
plt.ioff()
plt.figure()
plt.subplot(121)
plt.imshow(x, vmin=0, vmax=1)
plt.subplot(122)
plt.imshow(y, vmin=0, vmax=1)
plt.show()
plt.close()
plt.ion()
return A, x, y
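# Hedged end-to-end sketch (`_demo_reconstruction` is illustrative; it
# downloads MNIST on first use and builds the 784x784 blur operator): set up
# the deblurring problem and run the unregularized baseline from
# `do_recon_experiment`.
def _demo_reconstruction():
    A, _, y = reconstruction_problem(plot=False)
    return solve(A, y.ravel())  # 28x28 least-squares reconstruction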
def load_vae_models():
path1 = "onnx_models/deterministic/"
path2 = "onnx_models/stochastic_good/"
path3 = "onnx_models/stochastic_bad/"
path_add = "probabilistic_vae_comparison/"
try:
onnx_vae1 = ONNX_VAE_DET(path1 + "encoder.onnx", path1 + "decoder.onnx")
tf_vae_det = onnx_vae1.to_tensorflow()
tf_vae_det = DeterministicGenerator.from_vae(tf_vae_det, int(28*28), 20)
# onnx_vae2 = ONNX_VAE_STO(path2 + "encoder.onnx", path2 + "decoder.onnx")
onnx_vae2 = ONNX_VAE_STO(path2 + "good_probVAE_encoder.onnx", path2 + "good_probVAE_decoder.onnx")
tf_vae_good = onnx_vae2.to_tensorflow()
tf_vae_good = ProbabilisticGenerator.from_vae(tf_vae_good, int(28*28), 20)
onnx_vae3 = ONNX_VAE_BAD(path3 + "bad_probVAE_encoder.onnx", path3 + "bad_probVAE_decoder.onnx")
tf_vae_bad = onnx_vae3.to_tensorflow()
tf_vae_bad = ProbabilisticGenerator.from_vae(tf_vae_bad, int(28*28), 20)
except IOError:
onnx_vae1 = ONNX_VAE_DET(path_add + path1 + "encoder.onnx", path_add + path1 + "decoder.onnx")
tf_vae_det = onnx_vae1.to_tensorflow()
tf_vae_det = DeterministicGenerator.from_vae(tf_vae_det, int(28*28), 20)
onnx_vae2 = ONNX_VAE_STO(path_add + path2 + "good_probVAE_encoder.onnx",
path_add + path2 + "good_probVAE_decoder.onnx")
tf_vae_good = onnx_vae2.to_tensorflow()
tf_vae_good = ProbabilisticGenerator.from_vae(tf_vae_good, int(28*28), 20)
onnx_vae3 = ONNX_VAE_BAD(path_add + path3 + "bad_probVAE_encoder.onnx",
path_add + path3 + "bad_probVAE_decoder.onnx")
tf_vae_bad = onnx_vae3.to_tensorflow()
tf_vae_bad = ProbabilisticGenerator.from_vae(tf_vae_bad, int(28*28), 20)
return tf_vae_det, tf_vae_good, tf_vae_bad
......@@ -5,7 +5,7 @@ from abc import abstractmethod
from math import sqrt
from typing import Optional
from numpy import ndarray as array_type
from onnx_vae import ONNX_VAE_DET, ONNX_VAE_STO
from onnx_vae import ONNX_VAE_STO
from tf_vae import TF_VAE
......@@ -48,20 +48,29 @@ class GenerativeModel():
"about the generative model")
return self.vae.decoder(z_0)
def J_z0_already_variable(self, z0):
return self.vae.J_z0_already_variable(z0)
def J_z0_already_variable(self, curr_z0):  # pylint: disable=invalid-name
"""
compute jacobian of generator with respect to input
Arguments:
curr_z0 {tf.Tensor} -- latent vector
Returns:
tf.Tensor -- Jacobian
"""
return self.vae.J_z0_already_variable(curr_z0)
def encoder(self, x):
def encoder(self, curr_x):
"""
wraps the encoder function of a variational auto-encoder. Not clean design but works atm.
Arguments:
x {array_like} -- given image to decode
curr_x {array_like} -- given image to decode
Returns:
array_like -- encoded latent variable
"""
return self.vae.encoder(x)
return self.vae.encoder(curr_x)
class DeterministicGenerator(GenerativeModel):
......@@ -77,7 +86,6 @@ class DeterministicGenerator(GenerativeModel):
super().__init__(dim_x, dim_z, channels)
self.class_type = "DeterministicGenerator"
@staticmethod
def from_vae(vae: TF_VAE,
dim_x: int,
......@@ -97,7 +105,6 @@ class DeterministicGenerator(GenerativeModel):
return retval
class ProbabilisticGenerator(GenerativeModel):
"""Probabilistic / Variational generative models"""
def __init__(self, dim_x: int, dim_z: int, channels: Optional[int] = 1) -> None:
......@@ -150,7 +157,6 @@ class ConvexProbabilisticGenerator():
self.image_dim_x = (int(sqrt(genm1.dim_x)), int(sqrt(genm1.dim_x)), genm1.channels)
self.dim_z = genm1.dim_z
def __call__(self, z_0: array_type) -> array_type:
"""Call of decoder method of generative model VAEs and convex combine them
......@@ -171,10 +177,9 @@ class ConvexProbabilisticGenerator():
mean2, cov2 = self.genm2(z_0)
new_mean = self.alpha*mean1 + (1-self.alpha)*mean2
new_cov = self.alpha**2*cov1 + (1-self.alpha)**2*cov2
return new_mean, new_cov
return new_mean, new_cov # type: ignore
def encoder(self, x):
def encoder(self, variable_vec):
"""
wraps the encoder function of a variational auto-encoder. Not clean design but works atm.
......@@ -184,21 +189,29 @@ class ConvexProbabilisticGenerator():
Returns:
array_like -- encoded latent variable
"""
z1 = self.genm1.encoder(x)
mean1, cov1 = z1[:, :20], z1[:, 20:]
z2 = self.genm2.encoder(x)
mean2, cov2 = z2[:, :20], z2[:, 20:]
latent_vec1 = self.genm1.encoder(variable_vec)
mean1, cov1 = latent_vec1[:, :20], latent_vec1[:, 20:]
latent_vec2 = self.genm2.encoder(variable_vec)
mean2, cov2 = latent_vec2[:, :20], latent_vec2[:, 20:]
new_mean = self.alpha*mean1 + (1-self.alpha)*mean2
new_cov = self.alpha**2*cov1 + (1-self.alpha)**2*cov2
import numpy as np
retval = np.zeros([new_mean.shape[0], int(2*self.dim_z)])
retval[:, :self.dim_z] = new_mean.numpy()
retval[:, self.dim_z:] = new_cov.numpy()
return retval
def J_z0_already_variable(self, z0):
jac1 = self.genm1.J_z0_already_variable(z0)
jac2 = self.genm2.J_z0_already_variable(z0)
def J_z0_already_variable(self, curr_z0): # pylint: disable=invalid-name
"""
compute jacobian of convex combination of generators with respect to input
Arguments:
curr_z0 {tf.Tensor} -- latent vector
Returns:
tf.Tensor -- Jacobian
"""
jac1 = self.genm1.J_z0_already_variable(curr_z0)
jac2 = self.genm2.J_z0_already_variable(curr_z0)
return self.alpha*jac1 + (1-self.alpha)*jac2
......
"""Wrapper class for a linear inverse problem (LIP)"""
from typing import Optional