Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transforming adaptation with normalizing flows #154

Draft
wants to merge 34 commits into
base: main
Choose a base branch
from

Conversation

aseyboldt
Copy link
Member

@aseyboldt aseyboldt commented Oct 17, 2024

Experimental new algorithm that uses a normalizing flow instead of a mass matrix.

Set up using pixi:

git clone https://github.com/pymc-devs/nutpie
cd nutpie
git fetch origin pull/154/head:transform
git switch transform

pixi run develop
pixi shell

Gives a shell with an appropriate python setup.

Usage with pymc:

import pymc as pm
import nutpie
import numpy as np
import jax

jax.config.update("jax_enable_x64", True)

with pm.Model() as model:
    log_sd = pm.Normal("log_sd")
    pm.Normal("y", sigma=np.exp(log_sd))

compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")

compiled = (
    compiled
    .with_transform_adapt(
        # Neural network width, default is half the number of model parameters
        nn_width=None,
        # Number of normalizing flow layers
        num_layers=8,
        # Depth of the neural network in each flow layer
        nn_depth=1,
        # Print status update of the optimizer.
        verbose=False,
        # Number of gradients to use in each training phase
        window_size=5000,
        # Learning rate of the optimizer
        learning_rate=1e-3,
        # Print progress bars for the optimization. Very spammy...
        show_progress=False,
        # Number of initial windows with a diagonal mass matrix
        num_diag_windows=10,
    )
)

trace_ = nutpie.sample(
    compiled,
    transform_adapt=True,
    chains=2,
    tune=1000,
    draws=1000,
    cores=1,
    seed=123,
)

Usage with stan:

import pymc as pm
import nutpie
import numpy as np
import jax
import os

os.environ["TBB_CXX_TYPE"] = "clang"
jax.config.update("jax_enable_x64", True)

code = """
parameters {
    real log_sigma;
    real x;
}
model {
    log_sigma ~ normal(0, 1);
    x ~ normal(0, exp(log_sigma));
}
"""


compiled = nutpie.compile_stan_model(code=code)

compiled = (
    compiled
    .with_transform_adapt(
        # Neural network width, default is half the number of model parameters
        nn_width=None,
        # Number of normalizing flow layers
        num_layers=8,
        # Depth of the neural network in each flow layer
        nn_depth=1,
        # Print status update of the optimizer.
        verbose=False,
        # Number of gradients to use in each training phase
        window_size=5000,
        # Learning rate of the optimizer
        learning_rate=1e-3,
        # Print progress bars for the optimization. Very spammy...
        show_progress=False,
        # Number of initial windows with a diagonal mass matrix
        num_diag_windows=10,
    )
)

trace = nutpie.sample(
    compiled,
    transform_adapt=True,
    chains=2,
    tune=1000,
    draws=1000,
    cores=1,
    seed=123,
)

The optimization can be quite expensive computationally (but luckily doen't need any extra gradient evaluations). A GPU is very helpful here. (Jax should pick up a cuda device automatically)

@aseyboldt aseyboldt added help wanted Extra attention is needed normalizing-flows Needed for adaptation through normalizing-flows labels Oct 17, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
help wanted Extra attention is needed normalizing-flows Needed for adaptation through normalizing-flows
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant