Skip to content

Commit

Permalink
Use value_vars directly
Browse files Browse the repository at this point in the history
  • Loading branch information
zaxtax committed Nov 15, 2021
1 parent 13211ab commit 5ead708
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions pymc/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def sample_numpyro_nuts(
var_names = model.unobserved_value_vars

vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
inputs = [model.rvs_to_values[i] for i in model.free_RVs]

tic1 = pd.Timestamp.now()
print("Compiling...", file=sys.stdout)
Expand Down Expand Up @@ -164,7 +163,7 @@ def sample_numpyro_nuts(
print("Transforming variables...", file=sys.stdout)
mcmc_samples = {}
for v in vars_to_sample:
fgraph = FunctionGraph(inputs, [v], clone=False)
fgraph = FunctionGraph(model.value_vars, [v], clone=False)
jax_fn = jax_funcify(fgraph)
result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
mcmc_samples[v.name] = result
Expand Down

0 comments on commit 5ead708

Please sign in to comment.