From 29cd996a2af8172b34032eaed24a16a58abf91ca Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Fri, 11 Oct 2024 10:04:55 -0500 Subject: [PATCH] Changed backend arg defaults to strings --- python/nutpie/compile_pymc.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 621ddcd..0679ed2 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -354,8 +354,8 @@ def expand(x, **shared): def compile_pymc_model( model: "pm.Model", *, - backend: Literal["numba", "jax"] | None = None, - gradient_backend: Literal["pytensor", "jax"] | None = None, + backend: Literal["numba", "jax"] = "numba", + gradient_backend: Literal["pytensor", "jax"] = "pytensor", **kwargs, ) -> CompiledModel: """Compile necessary functions for sampling a pymc model. @@ -384,10 +384,9 @@ def compile_pymc_model( "and restart your kernel in case you are in an interactive session." ) - if backend is None: - backend = "numba" - if backend.lower() == "numba": + if gradient_backend == "jax": + raise ValueError("Gradient backend cannot be jax when using numba backend") return _compile_pymc_model_numba(model, **kwargs) elif backend.lower() == "jax": return _compile_pymc_model_jax(