Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Manuel Marschall
datainformed-prior
Commits
355ddd5f
Commit
355ddd5f
authored
Nov 29, 2021
by
Manuel Marschall
Browse files
latpush solver implemented
parent
667578c6
Changes
4
Hide whitespace changes
Inline
Side-by-side
linear_inverse_problem.py
View file @
355ddd5f
...
...
@@ -35,7 +35,7 @@ class LinearInverseProblem():
noise_sigma
:
float
=
0.1
):
from
utils
import
reconstruction_problem
A
,
x
,
y
=
reconstruction_problem
(
sigma
=
blur_sigma
)
return
LinearInverseProblem
(
A
,
y
.
reshape
(
-
1
),
noise_sigma
,
x
.
numpy
().
reshape
(
-
1
))
return
LinearInverseProblem
(
A
.
real
,
y
.
reshape
(
-
1
),
noise_sigma
,
x
.
numpy
().
reshape
(
-
1
))
if
__name__
==
"__main__"
:
...
...
lip_solver.py
0 → 100644
View file @
355ddd5f
from
numpy
import
ndarray
as
array_type
from
numpy.linalg
import
solve
from
numpy
import
eye
as
npeye
from
numpy
import
logspace
as
nplogspace
from
numpy
import
argmin
as
npargmin
from
linear_inverse_problem
import
LinearInverseProblem
from
generative_model
import
GenerativeModel
from
utils
import
build_neighbour_matrix
class
LIPSolver
:
def
__init__
(
self
,
lip
:
LinearInverseProblem
,
gm
:
GenerativeModel
)
->
None
:
self
.
lip
=
lip
self
.
gm
=
gm
def
solve
(
self
)
->
array_type
:
# (A'A) x = A' y --> invert to solve
return
solve
(
self
.
lip
.
operator
.
T
.
dot
(
self
.
lip
.
operator
),
self
.
lip
.
operator
.
T
.
dot
(
self
.
lip
.
data
))
class
LIPThikonovSolver
(
LIPSolver
):
def
__init__
(
self
,
lip
:
LinearInverseProblem
,
gm
:
GenerativeModel
,
llambda
:
float
)
->
None
:
super
().
__init__
(
lip
,
gm
)
self
.
llambda
=
llambda
def
solve
(
self
)
->
array_type
:
return
solve
(
self
.
lip
.
operator
.
T
.
dot
(
self
.
lip
.
operator
)
+
self
.
llambda
*
npeye
(
self
.
gm
.
dim_x
),
self
.
lip
.
operator
.
T
.
dot
(
self
.
lip
.
data
))
def
solve_oracle
(
self
,
logspace_min
:
int
=
-
6
,
logspace_max
:
int
=
2
,
lospace_num
:
int
=
1000
)
->
array_type
:
lambda_list
=
nplogspace
(
logspace_min
,
logspace_max
,
num
=
logspace_max
)
x_list
=
[]
res_list
=
[]
for
ell
in
lambda_list
:
self
.
llambda
=
ell
_x
,
_res
=
self
.
solve
()
x_list
.
append
(
_x
)
res_list
.
append
(
_res
)
return
x_list
[
npargmin
(
res_list
)]
class
LIPGMRFSolver
(
LIPThikonovSolver
):
def
__init__
(
self
,
lip
:
LinearInverseProblem
,
gm
:
GenerativeModel
,
llambda
:
float
)
->
None
:
super
().
__init__
(
lip
,
gm
,
llambda
)
self
.
gmrf
=
build_neighbour_matrix
(
gm
.
dim_x
).
toarray
()
def
solve
(
self
)
->
array_type
:
return
solve
(
self
.
lip
.
operator
.
T
.
dot
(
self
.
lip
.
operator
)
+
self
.
llambda
*
self
.
gmrf
,
self
.
lip
.
operator
.
T
.
dot
(
self
.
lip
.
data
))
\ No newline at end of file
onnx_vae.py
View file @
355ddd5f
...
...
@@ -122,6 +122,7 @@ if __name__ == '__main__':
z0
=
tf
.
Variable
(
tf
.
random
.
normal
([
10
,
20
,
1
,
1
]))
grad
=
tf_vae
.
J_z0_already_variable
(
z0
)
assert
grad
.
shape
==
(
10
,
1
,
28
,
28
,
20
,
1
,
1
)
# previously: assert grad.shape == (10, 1, 28, 28, 20, 1, 1)
assert
grad
.
shape
==
(
10
,
28
,
28
,
20
,
1
,
1
)
print
(
"DET and STO VAE decoder checked"
)
vae_lip_solver.py
0 → 100644
View file @
355ddd5f
from
typing
import
Any
from
numpy.lib.arraysetops
import
isin
import
tensorflow
as
tf
from
lip_solver
import
LIPSolver
from
linear_inverse_problem
import
LinearInverseProblem
from
generative_model
import
GenerativeModel
,
ProbabilisticGenerator
,
DeterministicGenerator
from
utils
import
plot_recon
,
savefig
,
plot_z_coverage
,
plot_convergence
def
matvecmul
(
A
,
b
):
# pylint: disable=unexpected-keyword-arg, redundant-keyword-arg, no-value-for-parameter
return
tf
.
reshape
(
tf
.
matmul
(
tf
.
cast
(
A
,
dtype
=
tf
.
double
),
tf
.
expand_dims
(
tf
.
cast
(
b
,
dtype
=
tf
.
double
),
1
)),
[
-
1
])
def
matvecsolve
(
A
,
b
):
# pylint: disable=unexpected-keyword-arg, redundant-keyword-arg, no-value-for-parameter
return
tf
.
reshape
(
tf
.
linalg
.
solve
(
tf
.
cast
(
A
,
dtype
=
tf
.
double
),
tf
.
expand_dims
(
tf
.
cast
(
b
,
dtype
=
tf
.
double
),
1
)),
[
-
1
])
class
VAELIPSolver
(
LIPSolver
):
def
__init__
(
self
,
lip
:
LinearInverseProblem
,
gm
:
GenerativeModel
,
export_path
:
str
)
->
None
:
super
().
__init__
(
lip
,
gm
)
self
.
path
=
export_path
self
.
path_suffix
=
"generic/"
def
get_initial_value
(
self
,
iteration
:
int
=
10
)
->
Any
:
z0
=
tf
.
zeros
([
1
,
20
,
1
,
1
],
dtype
=
tf
.
float64
)
A
=
self
.
lip
.
operator
sigma
=
self
.
lip
.
sigma
counter
=
0
while
True
:
if
counter
>=
iteration
:
break
counter
+=
1
if
isinstance
(
self
.
gm
,
ProbabilisticGenerator
):
g
,
gamma
=
self
.
gm
(
z0
)
g
=
tf
.
reshape
(
g
,
[
-
1
])
gamma
=
tf
.
reshape
(
gamma
,
[
-
1
])
elif
isinstance
(
self
.
gm
,
DeterministicGenerator
):
g
=
self
.
gm
(
z0
)
g
=
tf
.
reshape
(
g
,
[
-
1
])
gamma
=
tf
.
ones
(
self
.
gm
.
dim_x
,
dtype
=
tf
.
double
)
else
:
raise
ValueError
(
f
"Unknown generative model type:
{
type
(
self
.
gm
)
}
"
)
lhs
=
tf
.
matmul
(
tf
.
transpose
(
A
),
A
)
/
(
sigma
*
sigma
)
+
tf
.
linalg
.
diag
(
1
/
gamma
)
rhs
=
matvecmul
(
tf
.
transpose
(
A
),
tf
.
reshape
(
self
.
lip
.
data
,
[
-
1
]))
/
(
sigma
*
sigma
)
+
\
matvecmul
(
tf
.
linalg
.
diag
(
1
/
gamma
),
g
)
x0
=
matvecsolve
(
lhs
,
rhs
)
z0
=
tf
.
reshape
(
self
.
gm
.
vae
.
encoder
(
tf
.
reshape
(
x0
,
[
28
,
28
]))[:,
:
20
],
[
20
])
return
z0
class
VAELatpushSolver
(
VAELIPSolver
):
def
__init__
(
self
,
lip
:
LinearInverseProblem
,
gm
:
GenerativeModel
,
export_path
:
str
)
->
None
:
super
().
__init__
(
lip
,
gm
,
export_path
)
self
.
path_suffix
=
"latpush/"
def
solve
(
self
,
x0_iteration
:
int
=
1
,
num_iteration
:
int
=
10000
,
loss_tolerance
:
float
=
1e-1
)
->
dict
:
oracle_z0
=
tf
.
squeeze
(
self
.
gm
.
vae
.
encoder
(
tf
.
reshape
(
self
.
lip
.
ground_truth
,
[
28
,
28
]))[
0
,
:
self
.
gm
.
dim_z
]).
numpy
()
z0
=
tf
.
Variable
(
self
.
get_initial_value
(
iteration
=
x0_iteration
))
# z0 = tf.Variable(tf.zeros(20))
y
=
tf
.
reshape
(
self
.
lip
.
data
,
[
-
1
])
loss_tolerance
=
1e-1
mse_list
=
[]
psnr_list
=
[]
ssim_list
=
[]
loss_list
=
[]
def
neg_log_post
():
if
isinstance
(
self
.
gm
,
ProbabilisticGenerator
):
g
,
_
=
self
.
gm
(
z0
)
elif
isinstance
(
self
.
gm
,
DeterministicGenerator
):
g
=
self
.
gm
(
z0
)
else
:
raise
ValueError
(
f
"Unknown generative model type:
{
type
(
self
.
gm
)
}
"
)
g
=
tf
.
reshape
(
g
,
[
-
1
])
s2
=
(
1
/
(
self
.
lip
.
sigma
*
self
.
lip
.
sigma
))
loss1
=
s2
*
tf
.
square
(
tf
.
linalg
.
norm
(
matvecmul
(
self
.
lip
.
operator
,
g
)
-
y
))
# pylint: disable=unexpected-keyword-arg, redundant-keyword-arg, no-value-for-parameter
loss2
=
tf
.
cast
(
tf
.
reduce_sum
(
z0
*
z0
),
dtype
=
tf
.
double
)
retval
=
(
loss1
+
loss2
)
return
retval
opt
=
tf
.
keras
.
optimizers
.
Adam
()
for
lia
in
range
(
num_iteration
):
opt
.
minimize
(
neg_log_post
,
var_list
=
[
z0
])
if
lia
%
100
==
0
or
lia
==
0
:
if
isinstance
(
self
.
gm
,
ProbabilisticGenerator
):
x0
=
tf
.
reshape
(
self
.
gm
(
z0
)[
0
],
[
-
1
])
elif
isinstance
(
self
.
gm
,
DeterministicGenerator
):
x0
=
tf
.
reshape
(
self
.
gm
(
z0
),
[
-
1
])
else
:
raise
ValueError
(
f
"Unknown generative model type:
{
type
(
self
.
gm
)
}
"
)
fig
,
mse
,
psnr
,
ssim
=
plot_recon
(
tf
.
reshape
(
x0
,
[
28
,
28
]),
tf
.
cast
(
tf
.
reshape
(
self
.
lip
.
ground_truth
,
[
28
,
28
]),
dtype
=
tf
.
double
),
tf
.
reshape
(
y
,
[
28
,
28
]),
return_stats
=
True
)
mse_list
.
append
(
mse
)
psnr_list
.
append
(
psnr
)
ssim_list
.
append
(
ssim
)
loss_list
.
append
(
neg_log_post
())
print
(
f
"Iteration
{
lia
}
/
{
num_iteration
}
, loss:
{
loss_list
[
-
1
]
}
"
)
savefig
(
fig
,
self
.
path
+
self
.
path_suffix
,
f
"image_
{
lia
}
.png"
)
fig
=
plot_z_coverage
(
self
.
gm
.
vae
,
x0
,
oracle_z0
)
savefig
(
fig
,
self
.
path
+
self
.
path_suffix
,
f
"z0values_
{
lia
}
.png"
)
if
lia
>
100
and
tf
.
abs
(
loss_list
[
-
2
]
-
loss_list
[
-
1
])
<
loss_tolerance
:
break
for
ll
,
label
in
zip
([
mse_list
,
psnr_list
,
ssim_list
,
loss_list
],
[
"mse"
,
"psnr"
,
"ssim"
,
"loss"
]):
fig
=
plot_convergence
(
ll
,
title
=
label
)
savefig
(
fig
,
self
.
path
+
self
.
path_suffix
,
f
"
{
label
}
_conv.png"
)
retval
=
{
"MSE"
:
mse_list
,
"PSNR"
:
psnr_list
,
"SSIM"
:
ssim_list
,
"LOSS"
:
loss_list
}
return
retval
if
__name__
==
"__main__"
:
from
onnx_vae
import
ONNX_VAE_STO
,
ONNX_VAE_DET
path_det
=
"onnx_models/deterministic/"
path_prob
=
"onnx_models/stochastic_good/"
path_add
=
"probabilistic_vae_comparison/"
onnx_vae1
=
ONNX_VAE_DET
(
path_add
+
path_det
+
"encoder.onnx"
,
path_add
+
path_det
+
"decoder.onnx"
)
tf_vae_det
=
onnx_vae1
.
to_tensorflow
()
det_generator
=
DeterministicGenerator
.
from_vae
(
tf_vae_det
,
int
(
28
*
28
),
20
)
onnx_vae2
=
ONNX_VAE_STO
(
path_add
+
path_prob
+
"good_probVAE_encoder.onnx"
,
path_add
+
path_prob
+
"good_probVAE_decoder.onnx"
)
tf_vae_good
=
onnx_vae2
.
to_tensorflow
()
prop_generator
=
ProbabilisticGenerator
.
from_vae
(
tf_vae_good
,
int
(
28
*
28
),
20
)
lip
=
LinearInverseProblem
.
mnist_recon_problem
()
solver
=
VAELatpushSolver
(
lip
,
det_generator
,
"test"
)
# solver = VAELatpushSolver(lip, prop_generator, "test")
solver
.
solve
()
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment