Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Nov 14, 2020
1 parent a473028 commit 5a7b552
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_trace_report(self, step_cls, discard):
assert isinstance(trace.report.t_sampling, float)
pass

def test_trace_report_bart():
def test_trace_report_bart(self):
X = np.random.normal(0, 1, size=(3, 250)).T
Y = np.random.normal(0, 1, size=250)
X[:, 0] = np.random.normal(Y, 0.1)
Expand All @@ -175,7 +175,7 @@ def test_trace_report_bart():
mu = pm.BART("mu", X, Y, m=20)
sigma = pm.HalfNormal("sigma", 1)
y = pm.Normal("y", mu, sigma, observed=Y)
trace = pm.sample(500, tune=100, chains=1, random_seed=3415)
trace = pm.sample(500, tune=100, random_seed=3415)
var_imp = trace.report.variable_importance
assert var_imp[0] > var_imp[1:].sum()
npt.assert_almost_equal(var_imp.sum(), 1)
Expand Down

0 comments on commit 5a7b552

Please sign in to comment.