
Adding StatsBase.predict to the API #466

Open
sethaxen opened this issue Feb 20, 2023 · 7 comments

Comments

@sethaxen
Member

In Turing, StatsBase.predict is overloaded to dispatch on DynamicPPL.Model and MCMCChains.Chains (https://github.com/TuringLang/Turing.jl/blob/d76d914231db0198b99e5ca5d69d80934ee016b3/src/inference/Inference.jl#L532-L564). This effectively does batch prediction: it conditions the model on each draw in the chains and calls rand on the conditioned model. We also want to do the same thing for InferenceData (see #465).

It would be convenient if StatsBase.predict were added to the DynamicPPL API. StatsBase is already an indirect dependency of this package. As suggested by @devmotion in #465 (comment), its default implementation could simply call rand on a conditioned model:

```julia
using DynamicPPL, Random, StatsBase
StatsBase.predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) = rand(rng, DynamicPPL.condition(model, x))
StatsBase.predict(model::DynamicPPL.Model, x) = StatsBase.predict(Random.default_rng(), model, x)
```

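
For illustration, here is a minimal self-contained sketch of the batch behaviour described above, under stated assumptions: a hypothetical `ToyModel` stands in for `DynamicPPL.Model`, a plain vector of `NamedTuple` draws stands in for `MCMCChains.Chains`, and the local `condition` and `predict` functions are illustrative stand-ins rather than the AbstractPPL, DynamicPPL, or StatsBase functions. It only uses the `Random` standard library:

```julia
using Random

# Hypothetical stand-in for DynamicPPL.Model: y[i] ~ Normal(μ, 1) for i in 1:n.
struct ToyModel
    n::Int
end

# Hypothetical conditioned model: μ is fixed to the value from one posterior draw.
struct ConditionedToy
    model::ToyModel
    μ::Float64
end

# Local stand-in for `condition`, not the AbstractPPL/DynamicPPL function.
condition(m::ToyModel, draw::NamedTuple) = ConditionedToy(m, draw.μ)

# Sampling a conditioned model draws only the remaining variables (here `y`).
Base.rand(rng::AbstractRNG, c::ConditionedToy) = (y = c.μ .+ randn(rng, c.model.n),)

# Single-draw predict: condition on `x`, then call `rand`.
predict(rng::AbstractRNG, m::ToyModel, x) = rand(rng, condition(m, x))
predict(m::ToyModel, x) = predict(Random.default_rng(), m, x)

# Batch predict: one conditioned sample per draw, analogous to Turing's method for Chains.
predict(rng::AbstractRNG, m::ToyModel, draws::AbstractVector{<:NamedTuple}) =
    [predict(rng, m, d) for d in draws]
```

Because `AbstractVector{<:NamedTuple}` is more specific than the untyped `x`, `predict(m, draws)` falls through to the batch method, while a single `NamedTuple` hits the one-draw method.
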
@devmotion
Member

Maybe this could even be part of AbstractPPL and be defined on AbstractPPL.AbstractProbabilisticProgram: condition is already part of its API, and only rand is not clearly specified there yet (which should probably be done anyway).
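
A rough sketch of what a generic fallback on the abstract type could look like, assuming a hypothetical `AbstractProgram` in place of `AbstractPPL.AbstractProbabilisticProgram` and a local `condition` function in place of the real one. Any subtype that implements `condition` and `rand` then gets `predict` without further code; `TwoStep` is a made-up example subtype:

```julia
using Random

# Hypothetical stand-in for AbstractPPL.AbstractProbabilisticProgram.
abstract type AbstractProgram end

# Interface every subtype is assumed to implement: `condition` and `Base.rand`.
function condition end

# Generic fallbacks, written once against the abstract type.
predict(rng::AbstractRNG, p::AbstractProgram, x) = rand(rng, condition(p, x))
predict(p::AbstractProgram, x) = predict(Random.default_rng(), p, x)

# Hypothetical subtype: a ~ Normal(0, 1), b ~ Normal(a, 1).
struct TwoStep <: AbstractProgram
    a::Union{Nothing,Float64}  # `nothing` means `a` is not conditioned on
end
TwoStep() = TwoStep(nothing)

condition(::TwoStep, x::NamedTuple) = TwoStep(Float64(x.a))

function Base.rand(rng::AbstractRNG, p::TwoStep)
    if p.a === nothing
        a = randn(rng)
        return (a = a, b = a + randn(rng))
    end
    return (b = p.a + randn(rng),)  # only the unconditioned variable is sampled
end
```
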

@sethaxen
Member Author

Yeah, makes sense.

@torfjelde
Member

I'm down with this, but it's worth pointing out that just calling rand(rng, condition(model, x)) is probably not the greatest idea, because it defaults to returning a NamedTuple, which can blow up compilation times for many models.

And regarding adding it to APPL (AbstractPPL): we would then need to backport that change to v0.5 too, because v0.6 is currently not compatible with DPPL (see #440).

@sethaxen
Member Author

> I'm down with this, but it's worth pointing out that just calling rand(rng, condition(model, x)) is probably not the greatest idea as it defaults to NamedTuple which can blow up compilation times for many models.

Would rand(rng, OrderedDict, condition(model, x)) be the way to go then?

@torfjelde
Member

> Would rand(rng, OrderedDict, condition(model, x)) be the way to go then?

For maximal model-compat, yes. But you do of course take a performance hit as a result 😕

@sethaxen
Member Author

Hrm. Then maybe predict should return a NamedTuple if x is a NamedTuple (imperfect, because you can have few parameters but many data points). Or we could provide an API for specifying the return type, like rand does (though supporting two optional positional arguments, rng and T, complicates the interface).

@devmotion
Member

> Or provide an API for specifying the return type, like rand does (but supporting two optional positional parameters rng and T complicates the interface)

Adding T to predict (with some default) would be in line with our API for rand, though: there, the type T can already be specified.
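
One possible shape for such an API, sketched with hypothetical `ToyModel`, `ConditionedToy`, `condition`, and `predict` definitions (not the DynamicPPL or StatsBase API). It also shows the interface cost mentioned above: with both `rng` and `T` optional, `predict` needs four methods instead of two. `Dict` is used as the alternative container only because it is in `Base`; the `OrderedDict` discussed above would need OrderedCollections.jl:

```julia
using Random

# Hypothetical toy model: y[i] ~ Normal(μ, 1) for i in 1:n.
struct ToyModel
    n::Int
end

# Hypothetical conditioned model with μ fixed.
struct ConditionedToy
    model::ToyModel
    μ::Float64
end
condition(m::ToyModel, x::NamedTuple) = ConditionedToy(m, x.μ)

# 3-arg rand: T selects the container type of the returned sample.
Base.rand(rng::AbstractRNG, ::Type{NamedTuple}, c::ConditionedToy) =
    (y = c.μ .+ randn(rng, c.model.n),)
Base.rand(rng::AbstractRNG, ::Type{Dict}, c::ConditionedToy) =
    Dict(:y => c.μ .+ randn(rng, c.model.n))

# predict with two optional leading arguments (rng and T) needs four methods.
# NamedTuple as the default T is an assumption, not a settled choice.
predict(rng::AbstractRNG, ::Type{T}, m::ToyModel, x) where {T} = rand(rng, T, condition(m, x))
predict(rng::AbstractRNG, m::ToyModel, x) = predict(rng, NamedTuple, m, x)
predict(::Type{T}, m::ToyModel, x) where {T} = predict(Random.default_rng(), T, m, x)
predict(m::ToyModel, x) = predict(Random.default_rng(), NamedTuple, m, x)
```

The two 3-argument methods do not clash because a type such as `Dict` is never an `AbstractRNG`, so dispatch on the first argument alone tells them apart.
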

bors bot pushed a commit to TuringLang/AbstractPPL.jl that referenced this issue Feb 25, 2023
This PR adds a 3-arg form of `rand` (suggested by @devmotion in TuringLang/DynamicPPL.jl#466 (comment)) to the interface for `AbstractProbabilisticProgram` and implements the default 1- and 2-arg methods that dispatch to this.

Currently tests fail because this breaks the fallbacks for `GraphPPL.Model`, which expects `rand` to forward to its `rand!` method. I'm not certain how we want to define the interface for this `Model`.

Co-authored-by: Xianda Sun <[email protected]>
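
A minimal sketch of the pattern described in that commit message: the 3-arg `rand(rng, T, program)` is the one method a subtype implements, and the 1- and 2-arg forms forward to it. `AbstractProgram` and `StdNormalProgram` are hypothetical, and defaulting `T` to `NamedTuple` is an assumption here rather than a description of what AbstractPPL.jl actually ships:

```julia
using Random

# Hypothetical stand-in for AbstractPPL.AbstractProbabilisticProgram.
abstract type AbstractProgram end

# Fallbacks: each shorter form forwards to the 3-arg rand(rng, T, program),
# so a subtype only has to implement that one method.
# NamedTuple as the default T is an assumption for this sketch.
Base.rand(rng::AbstractRNG, p::AbstractProgram) = rand(rng, NamedTuple, p)
Base.rand(::Type{T}, p::AbstractProgram) where {T} = rand(Random.default_rng(), T, p)
Base.rand(p::AbstractProgram) = rand(Random.default_rng(), NamedTuple, p)

# Hypothetical subtype: x ~ Normal(0, 1), implementing only the 3-arg form.
struct StdNormalProgram <: AbstractProgram end
Base.rand(rng::AbstractRNG, ::Type{NamedTuple}, ::StdNormalProgram) = (x = randn(rng),)
Base.rand(rng::AbstractRNG, ::Type{Dict}, ::StdNormalProgram) = Dict(:x => randn(rng))
```
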

3 participants