Add find_MAP with close JAX integration and fix bug with Laplace fit #385

Merged: 21 commits, Dec 4, 2024

Commits
6aa20f7
Add JAX-based `find_MAP`
jessegrabowski Oct 27, 2024
7ed3b2f
add `better_optimize` to CI envs
jessegrabowski Oct 27, 2024
e412f6f
Fix relative import
jessegrabowski Oct 27, 2024
f9b6258
Remove `find_MAP` import from module-level `__init__.py`
jessegrabowski Oct 27, 2024
ad3abd9
Update docstring
jessegrabowski Oct 27, 2024
be1d790
Allow calling `find_MAP` inside model context without model argument
jessegrabowski Oct 27, 2024
923eb26
Required patched better_optimize
jessegrabowski Oct 27, 2024
f705d43
in-progress refactor
jessegrabowski Nov 30, 2024
a23762b
More refactor
jessegrabowski Dec 3, 2024
2d21403
Generalize code to use any pytensor backend
jessegrabowski Dec 3, 2024
4c2529d
Reconcile the two laplace approximation functions
jessegrabowski Dec 3, 2024
07ebe40
Use absolute import in doctest
jessegrabowski Dec 3, 2024
b40e101
Fix imports
jessegrabowski Dec 4, 2024
bc340c2
Fix unrelated statespace test
jessegrabowski Dec 4, 2024
da338bf
- Rename argument `use_jax_gradients` -> `gradient_backend`
jessegrabowski Dec 4, 2024
3ebbf20
Fix typo introduced by rename refactor
jessegrabowski Dec 4, 2024
2035202
use `mode=FAST_COMPILE` to get `unobserved_value_vars` after MAP opti…
jessegrabowski Dec 4, 2024
f2504e9
Rename `test_jax_find_map.py` -> `test_find_map.py`
jessegrabowski Dec 4, 2024
a81079b
Improve docstring for `fit_laplace`
jessegrabowski Dec 4, 2024
4d88343
Update tests to match new signature
jessegrabowski Dec 4, 2024
9b1cd0e
Update docstring
jessegrabowski Dec 4, 2024
33 changes: 25 additions & 8 deletions pymc_experimental/inference/find_map.py
@@ -1,7 +1,7 @@
import logging

from collections.abc import Callable
from typing import cast
from typing import Literal, cast, get_args

import jax
import numpy as np
@@ -17,11 +17,15 @@
from pymc.pytensorf import join_nonshared_inputs
from pymc.util import get_default_varnames
from pytensor.compile import Function
from pytensor.compile.mode import Mode
from pytensor.tensor import TensorVariable
from scipy.optimize import OptimizeResult

_log = logging.getLogger(__name__)

GradientBackend = Literal["pytensor", "jax"]
VALID_BACKENDS = get_args(GradientBackend)


def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
method_info = MINIMIZE_MODE_KWARGS[method].copy()
@@ -85,7 +89,11 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,

out.append(untransformed_X)

f_untransform = pytensor.function([X], out, mode="JAX")
f_untransform = pytensor.function(
inputs=[pytensor.In(X, borrow=True)],
outputs=pytensor.Out(out, borrow=True),
mode=Mode(linker="py", optimizer=None),
)
return f_untransform(posterior_draws)


Expand Down Expand Up @@ -209,7 +217,7 @@ def scipy_optimize_funcs_from_loss(
use_grad: bool,
use_hess: bool,
use_hessp: bool,
use_jax_gradients: bool = False,
gradient_backend: GradientBackend = "pytensor",
compile_kwargs: dict | None = None,
) -> tuple[Callable, ...]:
"""
Expand All @@ -230,8 +238,8 @@ def scipy_optimize_funcs_from_loss(
Whether to compile a function that computes the Hessian of the loss function.
use_hessp: bool
Whether to compile a function that computes the Hessian-vector product of the loss function.
use_jax_gradients: bool
If True, use JAX to compute gradients. This is only possible when ``compile_kwargs["mode"]`` is set to "JAX".
gradient_backend: str, one of "jax" or "pytensor"
Which backend to use to compute gradients.
compile_kwargs:
Additional keyword arguments to pass to the ``pm.compile_pymc`` function.

@@ -252,7 +260,12 @@
"Cannot compute hessian or hessian-vector product without also computing the gradient"
)

use_jax_gradients = use_jax_gradients and use_grad
if gradient_backend not in VALID_BACKENDS:
raise ValueError(
f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}"
)

use_jax_gradients = (gradient_backend == "jax") and use_grad

mode = compile_kwargs.get("mode", None)
if mode is None and use_jax_gradients:
@@ -307,7 +320,7 @@ def find_MAP(
jitter_rvs: list[TensorVariable] | None = None,
progressbar: bool = True,
include_transformed: bool = True,
use_jax_gradients: bool = False,
gradient_backend: GradientBackend = "pytensor",
compile_kwargs: dict | None = None,
**optimizer_kwargs,
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
@@ -342,6 +355,10 @@
Whether to display a progress bar during optimization. Defaults to True.
include_transformed: bool, optional
Whether to include transformed variable values in the returned dictionary. Defaults to True.
gradient_backend: str, default "pytensor"
Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
compile_kwargs: dict, optional
Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
**optimizer_kwargs
Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.

@@ -380,7 +397,7 @@
use_grad=use_grad,
use_hess=use_hess,
use_hessp=use_hessp,
use_jax_gradients=use_jax_gradients,
gradient_backend=gradient_backend,
compile_kwargs=compile_kwargs,
)

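Pulling the new `find_map.py` API together: `gradient_backend` chooses who computes derivatives ("pytensor" or "jax"), independently of `compile_kwargs["mode"]`, which controls how the compiled loss function itself runs. A minimal usage sketch (not part of the diff), with a toy model and an illustrative optimizer choice; argument names follow the signature and tests above:

```python
import numpy as np
import pymc as pm

from pymc_experimental.inference.find_map import find_MAP

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("obs", mu=mu, sigma=sigma,
              observed=np.random.default_rng(0).normal(3.0, 1.5, size=100))

    # find_MAP can be called inside the model context without a model argument
    # (commit be1d790). Gradients come from JAX here, and the loss itself is
    # also compiled with mode="JAX".
    optimized_point = find_MAP(
        method="L-BFGS-B",
        use_grad=True,
        gradient_backend="jax",
        compile_kwargs={"mode": "JAX"},
        progressbar=False,
    )
```

Per the annotated return type, the result is a dictionary mapping variable names to optimized values (including transformed values when `include_transformed=True`), or a `(dict, OptimizeResult)` tuple depending on the optimizer settings.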
23 changes: 14 additions & 9 deletions pymc_experimental/inference/laplace.py
@@ -41,6 +41,7 @@
from scipy import stats

from pymc_experimental.inference.find_map import (
GradientBackend,
_unconstrained_vector_to_constrained_rvs,
find_MAP,
get_nearest_psd,
@@ -235,7 +236,7 @@ def fit_mvn_to_MAP(
model: pm.Model | None = None,
on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
transform_samples: bool = False,
use_jax_gradients: bool = False,
gradient_backend: GradientBackend = "pytensor",
zero_tol: float = 1e-8,
diag_jitter: float | None = 1e-8,
compile_kwargs: dict | None = None,
@@ -256,12 +257,16 @@
If 'error', an error will be raised.
transform_samples : bool
Whether to transform the samples back to the original parameter space. Default is True.
gradient_backend: str, default "pytensor"
The backend to use for gradient computations. Must be one of "pytensor" or "jax".
zero_tol: float
Value below which an element of the Hessian matrix is counted as 0.
This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
diag_jitter: float | None
A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
If None, no jitter is added. Default is 1e-8.
compile_kwargs: dict, optional
Additional keyword arguments to pass to pytensor.function when compiling loss functions

Returns
-------
@@ -294,7 +299,7 @@ def fit_mvn_to_MAP(
use_grad=True,
use_hess=True,
use_hessp=False,
use_jax_gradients=use_jax_gradients,
gradient_backend=gradient_backend,
compile_kwargs=compile_kwargs,
)

@@ -323,7 +328,7 @@ def stabilize(x, jitter):
return mu, H_inv


def laplace(
def sample_laplace_posterior(
mu: RaveledVars,
H_inv: np.ndarray,
model: pm.Model | None = None,
@@ -416,7 +421,7 @@ def fit_laplace(
jitter_rvs: list[pt.TensorVariable] | None = None,
progressbar: bool = True,
include_transformed: bool = True,
use_jax_gradients: bool = False,
gradient_backend: GradientBackend = "pytensor",
chains: int = 2,
draws: int = 500,
on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -461,8 +466,8 @@
Whether to display a progress bar during optimization. Defaults to True.
include_transformed: bool, optional
Whether to include transformed variable values in the returned dictionary. Defaults to True.
use_jax_gradients: bool, optional
Whether to use JAX for gradient calculations. Defaults to False.
gradient_backend: str, default "pytensor"
The backend to use for gradient computations. Must be one of "pytensor" or "jax".
chains: int, default: 2
The number of sampling chains running in parallel.
Contributor review comment:
I'd probably add something here reiterating that this isn't a sampling inference method. This is just sampling from the approximated posterior. There were already people on the forum asking about the differences between these methods.

draws: int, default: 500
@@ -489,7 +494,7 @@

Examples
--------
>>> from pymc_experimental.inference.laplace import fit_laplace
>>> from pymc_experimental.inference.sample_laplace_posterior import fit_laplace
>>> import numpy as np
>>> import pymc as pm
>>> import arviz as az
Expand Down Expand Up @@ -526,7 +531,7 @@ def fit_laplace(
jitter_rvs=jitter_rvs,
progressbar=progressbar,
include_transformed=include_transformed,
use_jax_gradients=use_jax_gradients,
gradient_backend=gradient_backend,
compile_kwargs=compile_kwargs,
**optimizer_kwargs,
)
@@ -541,7 +546,7 @@
compile_kwargs=compile_kwargs,
)

return laplace(
return sample_laplace_posterior(
mu=mu,
H_inv=H_inv,
model=model,
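Echoing the review comment above: `fit_laplace` is not an MCMC method. It finds the MAP, fits a multivariate normal there via the inverse Hessian, and then draws from that approximation, so `idata.posterior` contains samples from the Gaussian approximation only. A hedged sketch of the high-level entry point (keyword names follow the diff and the tests; `optimize_method` is taken from the `pmx.fit` call in `test_laplace_only_fit`, and the toy model is illustrative):

```python
import numpy as np
import pymc as pm

from pymc_experimental.inference.laplace import fit_laplace

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("obs", mu=mu, sigma=sigma,
              observed=np.random.default_rng(0).normal(3.0, 1.5, size=100))

    # Draws come from the Laplace (Gaussian) approximation around the MAP,
    # not from a sampler; chains/draws only control how many are taken.
    idata = fit_laplace(
        optimize_method="BFGS",
        gradient_backend="pytensor",
        chains=2,
        draws=500,
        progressbar=False,
    )
```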
13 changes: 7 additions & 6 deletions tests/test_jax_find_map.py
@@ -4,6 +4,7 @@
import pytest

from pymc_experimental.inference.find_map import (
GradientBackend,
find_MAP,
scipy_optimize_funcs_from_loss,
)
@@ -17,8 +18,8 @@ def rng():
return np.random.default_rng(seed)


@pytest.mark.parametrize("use_jax_gradients", [True, False], ids=["jax_grad", "pt_grad"])
def test_jax_functions_from_graph(use_jax_gradients):
@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
def test_jax_functions_from_graph(gradient_backend: GradientBackend):
x = pt.tensor("x", shape=(2,))

def compute_z(x):
@@ -34,7 +35,7 @@ def compute_z(x):
use_grad=True,
use_hess=True,
use_hessp=True,
use_jax_gradients=use_jax_gradients,
gradient_backend=gradient_backend,
compile_kwargs=dict(mode="JAX"),
)

@@ -69,8 +70,8 @@ def compute_z(x):
("trust-constr", True, True),
],
)
@pytest.mark.parametrize("use_jax_gradients", [True, False], ids=["jax_grad", "pt_grad"])
def test_JAX_map(method, use_grad, use_hess, use_jax_gradients, rng):
@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng):
extra_kwargs = {}
if method == "dogleg":
# HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
@@ -88,7 +89,7 @@ def test_JAX_map(method, use_grad, use_hess, use_jax_gradients, rng):
use_grad=use_grad,
use_hess=use_hess,
progressbar=False,
use_jax_gradients=use_jax_gradients,
gradient_backend=gradient_backend,
compile_kwargs={"mode": "JAX"},
)
mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]
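Side note on the new type used in these test signatures: `GradientBackend` is a `typing.Literal`, and `get_args` turns it into the runtime tuple used for validation in `find_map.py`. A small self-contained sketch of that pattern (`check_backend` is a hypothetical wrapper; in the diff the check lives inline in `scipy_optimize_funcs_from_loss`):

```python
from typing import Literal, get_args

GradientBackend = Literal["pytensor", "jax"]
VALID_BACKENDS = get_args(GradientBackend)  # ("pytensor", "jax")


def check_backend(gradient_backend: GradientBackend) -> None:
    # Type checkers restrict annotated callers to the two literals; this
    # runtime check also catches arbitrary strings from untyped code.
    if gradient_backend not in VALID_BACKENDS:
        raise ValueError(
            f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}"
        )


check_backend("pytensor")  # OK
# check_backend("numba")   # would raise ValueError
```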
16 changes: 11 additions & 5 deletions tests/test_laplace.py
@@ -20,7 +20,7 @@
import pymc_experimental as pmx

from pymc_experimental.inference.find_map import find_MAP
from pymc_experimental.inference.laplace import fit_laplace, fit_mvn_to_MAP, laplace
from pymc_experimental.inference.laplace import (
fit_laplace,
fit_mvn_to_MAP,
sample_laplace_posterior,
)


@pytest.fixture(scope="session")
@@ -86,7 +90,7 @@ def test_laplace_only_fit():
method="laplace",
optimize_method="BFGS",
progressbar=True,
use_jax_gradients=True,
gradient_backend="jax",
compile_kwargs={"mode": "JAX"},
optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100),
random_seed=173300,
@@ -127,7 +131,7 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
use_hessp=True,
progressbar=False,
compile_kwargs=dict(mode=mode),
use_jax_gradients=mode == "JAX",
gradient_backend="jax" if mode == "JAX" else "pytensor",
)

for value in optimized_point.values():
@@ -139,7 +143,9 @@
transform_samples=transform_samples,
)

idata = laplace(mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples)
idata = sample_laplace_posterior(
mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples
)

np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.5)
np.testing.assert_allclose(
@@ -182,7 +188,7 @@ def test_fit_laplace_ragged_coords(rng):
progressbar=False,
use_grad=True,
use_hessp=True,
use_jax_gradients=True,
gradient_backend="jax",
compile_kwargs={"mode": "JAX"},
)

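The `test_fit_laplace_coords` changes also show the lower-level, three-step path that `fit_laplace` wraps: `find_MAP`, then `fit_mvn_to_MAP`, then the renamed `sample_laplace_posterior`. A condensed sketch under the same assumptions as the earlier examples (the optimizer choice is illustrative, and `fit_mvn_to_MAP`'s first argument is passed positionally because its keyword name is not visible in the diff):

```python
import numpy as np
import pymc as pm

from pymc_experimental.inference.find_map import find_MAP
from pymc_experimental.inference.laplace import fit_mvn_to_MAP, sample_laplace_posterior

with pm.Model() as model:
    mu_rv = pm.Normal("mu", 0, 1)
    sigma_rv = pm.Exponential("sigma", 1)
    pm.Normal("obs", mu=mu_rv, sigma=sigma_rv,
              observed=np.random.default_rng(0).normal(3.0, 1.5, size=100))

    # 1. Optimize to the MAP point (trust-ncg can use the Hessian-vector product).
    optimized_point = find_MAP(
        method="trust-ncg", use_grad=True, use_hessp=True, progressbar=False
    )

# 2. Fit a multivariate normal at the MAP via the (stabilized) inverse Hessian.
mu, H_inv = fit_mvn_to_MAP(optimized_point, model=model, transform_samples=True)

# 3. Sample from that Gaussian approximation of the posterior.
idata = sample_laplace_posterior(
    mu=mu, H_inv=H_inv, model=model, transform_samples=True
)
```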