From 17dc588830bded229b3591dc591b54d08696ceed Mon Sep 17 00:00:00 2001 From: htjb Date: Wed, 8 Nov 2023 16:37:46 +0000 Subject: [PATCH] fixing broken kl and bmd tests --- tests/test_clusters.py | 11 +++++++---- tests/test_stats.py | 22 ++++++++++++++-------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/tests/test_clusters.py b/tests/test_clusters.py index 8bea60e..ed69671 100644 --- a/tests/test_clusters.py +++ b/tests/test_clusters.py @@ -41,15 +41,18 @@ def test_maf_clustering(): assert_equal(bij.flow[f].mades[i].get_weights(), loaded_bijector.flow[f].mades[i].get_weights()) - def check_stats(i): - if i ==0: + def check_stats(label): + if label == "KL Divergence": value = samples_kl + assert_allclose(stats[label], value, rtol=1, atol=1) else: value = samples_d - assert_allclose(stats['Value'][i], value, rtol=1, atol=1) + assert_allclose(stats[label], value, rtol=1, atol=1) + + stats_label = ["KL Divergence", "BMD"] stats = calculate(bij).statistics() - [check_stats(i) for i in range(2)] + [check_stats(l) for l in stats_label] equal_weight_theta = samples.compress(100)[names].values x = bij.sample(len(equal_weight_theta)) diff --git a/tests/test_stats.py b/tests/test_stats.py index 6a5931a..f2512ab 100755 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -50,18 +50,21 @@ def d_g(logL, weights): def test_maf(): - def check_stats(i): - if i == 0: + def check_stats(label): + if label == "KL Divergence": value = samples_kl + assert_allclose(stats[label], value, rtol=1, atol=1) else: value = samples_d - assert_allclose(stats["Value"][i], value, rtol=1, atol=1) + assert_allclose(stats[label], value, rtol=1, atol=1) bij = MAF(theta, weights=weights) bij.train(10000, early_stop=True) + stats_label = ["KL Divergence", "BMD"] + stats = calculate(bij).statistics() - [check_stats(i) for i in range(2)] + [check_stats(l) for l in stats_label] equal_weight_theta = mcmc_samples.compress(50)[names].values x = bij.sample(len(equal_weight_theta)) @@ -124,18 +127,21 @@ def test_maf_save_load(): def test_kde(): - def check_stats(i): - if i ==0: + def check_stats(label): + if label == "KL Divergence": value = samples_kl + assert_allclose(stats[label], value, rtol=1, atol=1) else: value = samples_d - assert_allclose(stats['Value'][i], value, rtol=1, atol=1) + assert_allclose(stats[label], value, rtol=1, atol=1) kde = KDE(theta, weights=weights) kde.generate_kde() + stats_label = ["KL Divergence", "BMD"] + stats = calculate(kde).statistics() - [check_stats(i) for i in range(2)] + [check_stats(l) for l in stats_label] equal_weight_theta = mcmc_samples.compress(50)[names].values x = kde.sample(len(equal_weight_theta))