Skip to content

Commit

Permalink
fixing broken kl and bmd tests
Browse files Browse the repository at this point in the history
  • Loading branch information
htjb committed Nov 8, 2023
1 parent 82bfb77 commit 17dc588
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
11 changes: 7 additions & 4 deletions tests/test_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,18 @@ def test_maf_clustering():
assert_equal(bij.flow[f].mades[i].get_weights(),
loaded_bijector.flow[f].mades[i].get_weights())

def check_stats(i):
if i ==0:
def check_stats(label):
if label == "KL Divergence":
value = samples_kl
assert_allclose(stats[label], value, rtol=1, atol=1)
else:
value = samples_d
assert_allclose(stats['Value'][i], value, rtol=1, atol=1)
assert_allclose(stats[label], value, rtol=1, atol=1)

stats_label = ["KL Divergence", "BMD"]

stats = calculate(bij).statistics()
[check_stats(i) for i in range(2)]
[check_stats(l) for l in stats_label]

equal_weight_theta = samples.compress(100)[names].values
x = bij.sample(len(equal_weight_theta))
Expand Down
22 changes: 14 additions & 8 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,21 @@ def d_g(logL, weights):


def test_maf():
def check_stats(i):
if i == 0:
def check_stats(label):
if label == "KL Divergence":
value = samples_kl
assert_allclose(stats[label], value, rtol=1, atol=1)
else:
value = samples_d
assert_allclose(stats["Value"][i], value, rtol=1, atol=1)
assert_allclose(stats[label], value, rtol=1, atol=1)

bij = MAF(theta, weights=weights)
bij.train(10000, early_stop=True)

stats_label = ["KL Divergence", "BMD"]

stats = calculate(bij).statistics()
[check_stats(i) for i in range(2)]
[check_stats(l) for l in stats_label]

equal_weight_theta = mcmc_samples.compress(50)[names].values
x = bij.sample(len(equal_weight_theta))
Expand Down Expand Up @@ -124,18 +127,21 @@ def test_maf_save_load():

def test_kde():

def check_stats(i):
if i ==0:
def check_stats(label):
if label == "KL Divergence":
value = samples_kl
assert_allclose(stats[label], value, rtol=1, atol=1)
else:
value = samples_d
assert_allclose(stats['Value'][i], value, rtol=1, atol=1)
assert_allclose(stats[label], value, rtol=1, atol=1)

kde = KDE(theta, weights=weights)
kde.generate_kde()

stats_label = ["KL Divergence", "BMD"]

stats = calculate(kde).statistics()
[check_stats(i) for i in range(2)]
[check_stats(l) for l in stats_label]

equal_weight_theta = mcmc_samples.compress(50)[names].values
x = kde.sample(len(equal_weight_theta))
Expand Down

0 comments on commit 17dc588

Please sign in to comment.