Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues in pathfinder wrapper #238

Merged
merged 1 commit into from
Nov 1, 2023

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Sep 11, 2023

Closes #97
Closes #250

We were doing senseless things, like drawing from an isometric MvNormal with mean two as the initial point of the VI fit and then discarding such fit? This killed my poor RAM :)

Also I am not sure what the inference_loop was doing. I am not familiar with the internals of blackjax (and the documentation references a version that hasn't yet been released), but I think we just want to do init -> sample? Can someone confirm this (@junpenglao @zaxtax?)

Pathfinder doesn't seem to work particularly well with the 8 school example. It's very sensitive to the initial value of mu, and the default converges to something that looks wrong, see gist. Because of this I don't know what better thing to assert in the test.

Also the pymc-examples is broken (it is not using mu nor tau :P): https://www.pymc.io/projects/examples/en/latest/variational_inference/pathfinder.html

@twiecki
Copy link
Member

twiecki commented Sep 11, 2023

Ouch, well these all sounds like fixes to me.

@junpenglao
Copy link
Member

I think it is better to depends on blackjax-nightly (which would soon be release as blackjax v1), and use

pathfinder = blackjax.kernels.pathfinder(logdensity_fn)
state, _ = pathfinder.approximate(rng_key, ...)
_, rng_key = random.split(rng_key)
samples, _ = pathfinder.sample(rng_key, state, 5_000)

per https://blackjax-devs.github.io/sampling-book/algorithms/pathfinder.html

@ricardoV94
Copy link
Member Author

I think it is better to depends on blackjax-nightly (which would soon be release as blackjax v1), and use

How soon is that soon? Would it be too much of a PITA to get another pre v1 release if that's faster?

@junpenglao
Copy link
Member

We just release Blackjax v1

@ricardoV94 ricardoV94 force-pushed the pathfinder_fixes branch 3 times, most recently from 2cb7f7f to 0301d0d Compare September 21, 2023 14:04
@twiecki
Copy link
Member

twiecki commented Oct 8, 2023

Can we merge?

@ricardoV94
Copy link
Member Author

ricardoV94 commented Oct 8, 2023

I would like to have a model where this actually works as a test case, but since we are in experimental maybe that's fine.

@ricardoV94
Copy link
Member Author

Needs to be rebased for tests to pass

@ricardoV94 ricardoV94 force-pushed the pathfinder_fixes branch 2 times, most recently from b86e0ff to c8c601d Compare November 1, 2023 12:28
@ricardoV94 ricardoV94 merged commit ebba64e into pymc-devs:main Nov 1, 2023
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

pm.fit("pathfinder") raises AttributeError` Kernel crashes during pathfinder fitting
3 participants