Commit 527aa26e authored by Manuel Marschall's avatar Manuel Marschall
Browse files

The experiments runner works fine, as does paper_plots

parent 86b49a9a
Pipeline #6063 failed with stages
in 13 minutes and 8 seconds
......@@ -45,6 +45,11 @@ conda create -n GenerativeBayes python==3.7
pip install -r requirements.txt
```
## Generative models as neural networks
We provide trained generative models in the form of `onnx` files. Details on the architecture and training can be made available upon request.
The onnx models are internally transformed to `tensorflow` graphs, which is used to perform linear algebra and automatic differentiation.
# Run the code
Simply call
......@@ -53,4 +58,17 @@ Simply call
python datainformed_prior/experiments_runner.py -n 100
```
Depending on your system, this might take a long time. With computations carried out on a Nvidia Tesla V100, the results are available in approximately 24h. Therefore, reduce the number of runs to `-n 3` for a smaller statistic. The final results are given in `paper_plots/`
\ No newline at end of file
You may have to set the pythonpath accordingly: `PYTHONPATH=/path/to/cloned/repo/datainformed_prior`.
Depending on your system, this might take a long time. With computations carried out on a Nvidia Tesla V100, the results are available in approximately 24h. Therefore, reduce the number of runs to `-n 2` for a smaller statistic. On CPU, even `-n 2` takes around an hour, since optimization in latent space takes a huge amount of time because Jacobian computations are expensive.
The final results are given in `paper_plots/`
## Plots
To generate the paper plots, adjust `datainformed_prior/paper_plots.py` in the method `plot_quality()` so that the experiment count is one less than the number of available experiments, and run
```
python datainformed_prior/paper_plots.py
```
The inference results from the Appendix are currently commented out, since they are again computationally expensive, as the solvers are called repeatedly for new data with different noise realizations.
\ No newline at end of file
......@@ -312,7 +312,8 @@ def run_experiments(num_experiments):
if __name__ == "__main__":
    # Parse the number of experiment repetitions from the CLI.
    # NOTE: the scraped diff showed both the old (default=3) and new
    # (default=2) registrations of "-n"; registering the same flag twice
    # raises argparse.ArgumentError, so only the committed default=2 is kept
    # (consistent with the README's `-n 2` recommendation).
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--num-exp", required=False, type=int, default=2)
    args = parser.parse_args()
    print(f"start Experiment: {args}")
    run_experiments(args.num_exp)
    run_oneshot()
......@@ -15,109 +15,8 @@ import seaborn as sns; sns.set_theme(style="whitegrid")
# Identifiers of all solver/prior variants that experiment result files may
# contain; plotting code below selects subsets of these keys.
available_methods = ["laplace", "thikonov", "gmrf", "mnist_regu", "generic", "generic_encoder", "latpush", "eb_latpush",
"latpush_linear", "eb_qmin"]
def plot_quality_only_etas():
    """Plot per-eta PSNR boxplots comparing reconstruction methods.

    Reads ``result.json`` files from ``experiments/..._genericz0init/run_*``
    (plus per-method ``biaseval2`` results), collects the ``mse``/``psnr``/
    ``ssim`` metrics into a nested dict keyed by
    ``[metric][method][vae_quality][noise_sigma][blur_sigma]``, and writes one
    boxplot figure per blurring precision ``eta`` to
    ``paper_plots/quality_eta{eta}.png``.

    NOTE(review): this source was recovered from a whitespace-stripped diff;
    the indentation below is reconstructed from the loop/assignment structure.
    """
    vae_qualities = [1]
    plot_vae_qualities = [1]
    noise_sigmas = [0.1, 0.01, 0.001, 0.0001]
    blur_sigmas = [2, 3, 4, 5]  # , 6]
    plot_blur_sigmas = [1.5, 2.3, 3.5, 5.5, 8]
    num_experiments = 2
    methods = ["laplace", "thikonov", "latpush"]
    ticks = ["Laplace", "L2", "Latent"]
    metrics = ["mse", "psnr", "ssim"]
    # Pre-build the empty nested result dictionary so the collection loops
    # below can append unconditionally.
    plot_dict = {}  # type: Dict[str, Dict[str, list]]
    for metric in metrics:
        plot_dict[metric] = {}
        for method in methods:
            plot_dict[metric][method] = {}
            for vae_quality in vae_qualities:
                plot_dict[metric][method][vae_quality] = {}
                for noise_sigma in noise_sigmas:
                    plot_dict[metric][method][vae_quality][noise_sigma] = {}
                    for blur_sigma in blur_sigmas:
                        plot_dict[metric][method][vae_quality][noise_sigma][blur_sigma] = {}
                        plot_dict[metric][method][vae_quality][noise_sigma][blur_sigma]["value"] = list()
                        plot_dict[metric][method][vae_quality][noise_sigma][blur_sigma]["bias"] = {}
                        for crossmethod in methods:
                            plot_dict[metric][method][vae_quality][noise_sigma][blur_sigma]["bias"][crossmethod] = list()
    # Load every run's metrics; runs are 1-based, hence range(1, num_experiments).
    for vae_quality in vae_qualities:
        for noise_sigma in noise_sigmas:
            for blur_sigma in blur_sigmas:
                for lia in range(1, num_experiments):
                    exp_str = f"experiments/{vae_quality}_noise{noise_sigma}_blur{blur_sigma}_genericz0init/"
                    with open(exp_str + f"run_{lia}/result.json", "r") as fp:
                        curr_res = json.load(fp)
                    for metric in metrics:
                        for method in methods:
                            plot_dict[metric][method][vae_quality][noise_sigma][blur_sigma]["value"].append(curr_res[method][metric])
                            with open(exp_str + f"run_{lia}/biaseval2/{method}/result.json", "r") as fp_bias:
                                curr_res_bias = json.load(fp_bias)
                            for crossmethod in methods:
                                plot_dict[metric][method][vae_quality][noise_sigma][blur_sigma]["bias"][crossmethod].append(curr_res_bias[crossmethod][metric])
    # Convert the collected lists to numpy arrays for indexing in the plots.
    for metric in metrics:
        for method in methods:
            for vae_quality in vae_qualities:
                for noise_sigma in noise_sigmas:
                    for blur_sigma in blur_sigmas:
                        plot_dict[metric][method][vae_quality][noise_sigma][blur_sigma]["value"] = np.array(plot_dict[metric][method][vae_quality][noise_sigma][blur_sigma]["value"])
                        for crossmethod in methods:
                            plot_dict[metric][method][vae_quality][noise_sigma][blur_sigma]["bias"][crossmethod] = np.array(plot_dict[metric][method][vae_quality][noise_sigma][blur_sigma]["bias"][crossmethod])
    # One figure per blurring precision eta, with boxes per noise level.
    for eta in blur_sigmas:
        fig = plt.figure(figsize=(8, 6))
        # NOTE(review): bp1 indexes blur_sigma 2 regardless of eta (unlike
        # bp2/bp3, which use eta) — possibly intentional as a fixed baseline;
        # confirm against plot_quality(), which uses eta for thikonov as well.
        bp1 = plt.boxplot([plot_dict["psnr"]["thikonov"][1][x][2]["value"] for x in noise_sigmas],
                          notch=True, patch_artist=True, boxprops=dict(facecolor="whitesmoke"), widths=0.15,
                          positions=np.array(list(range(1, len(noise_sigmas)+1))) - 0.25)
        bp2 = plt.boxplot([plot_dict["psnr"]["latpush"][1][x][eta]["value"] for x in noise_sigmas],
                          notch=True, patch_artist=True, boxprops=dict(facecolor="lightgray"), widths=0.15,
                          positions=np.array(list(range(1, len(noise_sigmas)+1))) - 0.1)
        bp3 = plt.boxplot([plot_dict["psnr"]["laplace"][1][x][eta]["value"] for x in noise_sigmas],
                          notch=True, patch_artist=True, boxprops=dict(facecolor="dimgray"), widths=0.15,
                          positions=np.array(list(range(1, len(noise_sigmas)+1))) + 0.1)
        # "Guide": per run, pick whichever method its OWN bias evaluation
        # favours, and box the winners' PSNR values.
        winner_boxes = []
        for x in noise_sigmas:
            winner_box = []
            for lib, (lap_item, lat_item) in enumerate(zip(plot_dict["psnr"]["laplace"][1][x][eta]["bias"]["laplace"],
                                                           plot_dict["psnr"]["latpush"][1][x][eta]["bias"]["latpush"])):
                if lap_item >= lat_item:
                    winner_box.append(plot_dict["psnr"]["laplace"][1][x][eta]["value"][lib])
                else:
                    winner_box.append(plot_dict["psnr"]["latpush"][1][x][eta]["value"][lib])
            winner_boxes.append(winner_box)
        bp4 = plt.boxplot(winner_boxes,
                          notch=True, patch_artist=True, boxprops=dict(facecolor="black"), widths=0.15,
                          positions=np.array(list(range(1, len(noise_sigmas)+1))) + 0.25)
        # "Guidecross": same selection but using each method's CROSS bias
        # evaluation (laplace judged by latpush and vice versa).
        winner_boxes_cross = []
        for x in noise_sigmas:
            winner_box_cross = []
            for lib, (lap_item, lat_item) in enumerate(zip(plot_dict["psnr"]["laplace"][1][x][eta]["bias"]["latpush"],
                                                           plot_dict["psnr"]["latpush"][1][x][eta]["bias"]["laplace"])):
                if lap_item >= lat_item:
                    winner_box_cross.append(plot_dict["psnr"]["laplace"][1][x][eta]["value"][lib])
                else:
                    winner_box_cross.append(plot_dict["psnr"]["latpush"][1][x][eta]["value"][lib])
            winner_boxes_cross.append(winner_box_cross)
        bp5 = plt.boxplot(winner_boxes_cross,
                          notch=True, patch_artist=True, boxprops=dict(facecolor="blue"), widths=0.15,
                          positions=np.array(list(range(1, len(noise_sigmas)+1))) + 0.45)
        plt.legend([bp["boxes"][0] for bp in [bp1, bp2, bp3, bp4, bp5]], [r"$L^2$", "Latent", "Laplace", "Guide", "Guidecross"], loc='upper left')
        plt.gca().set_xlabel(r"observational noise $-\log_{10}\sigma$")
        # NOTE(review): the title hard-codes $\eta=2$ although this loop runs
        # over all eta in blur_sigmas — verify whether it should use {eta}.
        plt.gca().set(xticks=[1, 2, 3, 4], xticklabels=["1", "2", "3", "4"], ylabel="PSNR difference", title="Blurring precision $\eta=2$")
        plt.gca().grid(axis="y")
        plt.tight_layout()
        fig.savefig(f"paper_plots/quality_eta{eta}.png")
def plot_quality_oneshot():
vae_quality = 1
vae_quality = "good"
sigma = 0.01
blur_sigma = 3
experiment = (3359, 13, 30)
......@@ -172,12 +71,12 @@ def plot_quality_oneshot():
fig.savefig(f"paper_plots/oneshot.pdf", dpi=300)
def plot_quality():
vae_qualities = [1]
vae_qualities = ["good"]
plot_vae_qualities = [1]
noise_sigmas = [0.1, 0.01, 0.001, 0.0001]
blur_sigmas = [2, 3, 4, 5] #, 6]
plot_blur_sigmas = [1.5, 2.3, 3.5, 5.5, 8]
num_experiments = 2
num_experiments = 2 # set this to your available number of runs - 1
methods = ["laplace", "thikonov", "latpush"]
......@@ -208,7 +107,7 @@ def plot_quality():
for metric in metrics:
for method in methods:
plot_dict[metric][method][vae_quality][noise_sigma][blur_sigma]["value"].append(curr_res[method][metric])
with open(exp_str + f"run_{lia}/biaseval/{method}/result.json", "r") as fp_bias:
with open(exp_str + f"run_{lia}/biaseval2/{method}/result.json", "r") as fp_bias:
curr_res_bias = json.load(fp_bias)
# no cross evaluation for all experiments available
# for crossmethod in methods:
......@@ -228,18 +127,17 @@ def plot_quality():
fig, axes = plt.subplots(2, 2, figsize=(8, 6))
axes = axes.ravel()
for lia, eta in enumerate(blur_sigmas):
# bp1 = axes[lia].boxplot([plot_dict["psnr"]["laplace"][1][x][eta] - plot_dict["psnr"]["latpush"][1][x][eta] for x in noise_sigmas],
# notch=True, patch_artist=True, boxprops=dict(facecolor="lightgray"), widths=0.25,
# positions=np.array(list(range(1, len(blur_sigmas)+1))) - 0.15)
bp1 = axes[lia].boxplot([plot_dict["psnr"]["thikonov"][1][x][eta]["value"] for x in noise_sigmas],
bp1 = axes[lia].boxplot([plot_dict["psnr"]["thikonov"]["good"][x][eta]["value"] for x in noise_sigmas],
notch=True, patch_artist=True, boxprops=dict(facecolor="whitesmoke"), widths=0.15,
positions=np.array(list(range(1, len(noise_sigmas)+1))) - 0.25)
bp2 = axes[lia].boxplot([plot_dict["psnr"]["latpush"][1][x][eta]["value"] for x in noise_sigmas],
bp2 = axes[lia].boxplot([plot_dict["psnr"]["latpush"]["good"][x][eta]["value"] for x in noise_sigmas],
notch=True, patch_artist=True, boxprops=dict(facecolor="lightgray"), widths=0.15,
positions=np.array(list(range(1, len(noise_sigmas)+1))) - 0.1)
bp3 = axes[lia].boxplot([plot_dict["psnr"]["laplace"][1][x][eta]["value"] for x in noise_sigmas],
bp3 = axes[lia].boxplot([plot_dict["psnr"]["laplace"]["good"][x][eta]["value"] for x in noise_sigmas],
notch=True, patch_artist=True, boxprops=dict(facecolor="dimgray"), widths=0.15,
positions=np.array(list(range(1, len(noise_sigmas)+1))) + 0.1)
winner_boxes = []
......@@ -248,15 +146,15 @@ def plot_quality():
lat_wins = 0
lap_wins = 0
num_wins = 0
for lib, (lap_item, lat_item) in enumerate(zip(plot_dict["psnr"]["laplace"][1][x][eta]["bias"]["laplace"],
plot_dict["psnr"]["latpush"][1][x][eta]["bias"]["latpush"])):
for lib, (lap_item, lat_item) in enumerate(zip(plot_dict["psnr"]["laplace"]["good"][x][eta]["bias"]["laplace"],
plot_dict["psnr"]["latpush"]["good"][x][eta]["bias"]["latpush"])):
num_wins += 1
if lap_item >= lat_item:
lap_wins += 1
winner_box.append(plot_dict["psnr"]["laplace"][1][x][eta]["value"][lib])
winner_box.append(plot_dict["psnr"]["laplace"]["good"][x][eta]["value"][lib])
else:
lat_wins += 1
winner_box.append(plot_dict["psnr"]["latpush"][1][x][eta]["value"][lib])
winner_box.append(plot_dict["psnr"]["latpush"]["good"][x][eta]["value"][lib])
print(f"eta: {eta}, sigma: {x}, Laplace wins: {lap_wins}/{num_wins} Latent wins: {lat_wins}/{num_wins}")
winner_boxes.append(winner_box)
bp4 = axes[lia].boxplot(winner_boxes,
......@@ -336,7 +234,7 @@ def plot_priors():
plt.savefig("paper_plots/prior_mnist.pdf", dpi=300)
from experiments_runner import get_generator
gen = get_generator("good")
gen = get_generator()
np.random.seed(4)
z = np.random.randn(5, 20)*1.5
x, _ = gen(z)
......@@ -421,7 +319,7 @@ def plot_inference_results():
sigmas = [0.1, 0.01, 0.001, 0.0001]
repetitions = 2
chosen_image_idx = 8
prob_generator = get_generator("good")
prob_generator = get_generator()
A = blur_matrix_operator(_sigma=eta).real
_, eval_dataset = load_mnist(shuffle=False)
......@@ -522,9 +420,8 @@ def plot_inference_results():
# NOTE(review): this fragment is a flattened diff — the commit's removed and
# added lines are interleaved without markers (e.g. the two
# `compare_covariances` comment variants, and `plot_inference_results` both
# called and commented out). Reconcile against the repository before running;
# per the README, the Appendix inference results are meant to stay commented
# out because they are computationally expensive.
if __name__ == "__main__":
plot_quality_oneshot()
plot_inference_results()
# # compare_covariances()
# compare_covariances()
plot_blurring_and_noise()
plot_priors()
plot_quality()
# # plot_quality_only_etas()
# plot_inference_results()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment