diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 5d318a33d05..6fdbd64e730 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -114,7 +114,6 @@ def sample_numpyro_nuts( var_names = model.unobserved_value_vars vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) - inputs = [model.rvs_to_values[i] for i in model.free_RVs] tic1 = pd.Timestamp.now() print("Compiling...", file=sys.stdout) @@ -164,7 +163,7 @@ def sample_numpyro_nuts( print("Transforming variables...", file=sys.stdout) mcmc_samples = {} for v in vars_to_sample: - fgraph = FunctionGraph(inputs, [v], clone=False) + fgraph = FunctionGraph(model.value_vars, [v], clone=False) jax_fn = jax_funcify(fgraph) result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0] mcmc_samples[v.name] = result