Implement Variational Inference #143
-
VI was briefly discussed during the last meeting. While we don't have any immediate plan to work on it, design discussion is certainly welcome. I think the first step for implementing VI will be a base class/function that takes multiple meta functions, for example, functions that take approximation variables (VI variables) and generate a callable that computes the variational loss.
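To make the idea concrete, here is a minimal sketch of such a base function. The names (`make_vi_loss`, `sample_fn`, `logq_fn`) are hypothetical, not an existing API; it assumes the meta functions are a sampler for the approximation and its log density, and returns a callable that estimates the negative ELBO by Monte Carlo:

```python
import jax
import jax.numpy as jnp


def make_vi_loss(logdensity_fn, sample_fn, logq_fn):
    """Hypothetical base function: build a variational loss from meta functions.

    logdensity_fn: log p(x), the (possibly unnormalized) target.
    sample_fn:     draws samples from q given the VI variables `params`.
    logq_fn:       log q(x) given the VI variables `params`.
    """
    def loss_fn(params, rng_key, num_samples=32):
        samples = sample_fn(params, rng_key, num_samples)
        logq = jax.vmap(lambda x: logq_fn(params, x))(samples)
        logp = jax.vmap(logdensity_fn)(samples)
        # Monte Carlo estimate of KL(q || p) = E_q[log q(x) - log p(x)],
        # i.e. the negative ELBO up to the (constant) log normalizer of p.
        return jnp.mean(logq - logp)
    return loss_fn


# Example: mean-field normal approximation of a standard normal target.
def logdensity(x):
    return jnp.sum(-0.5 * x**2 - 0.5 * jnp.log(2 * jnp.pi))


def sample_q(params, rng_key, num_samples):
    mu, log_sigma = params
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    return mu + jnp.exp(log_sigma) * eps


def logq(params, x):
    mu, log_sigma = params
    z = (x - mu) / jnp.exp(log_sigma)
    return jnp.sum(-0.5 * z**2 - log_sigma - 0.5 * jnp.log(2 * jnp.pi))


loss_fn = make_vi_loss(logdensity, sample_q, logq)
params = (jnp.zeros(2), jnp.zeros(2))  # here q == p, so the loss is ~0
```

The point of the indirection is that `loss_fn` is a pure function of `(params, rng_key)`, so it can be handed directly to `jax.grad` and any optimizer.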
Beta Was this translation helpful? Give feedback.
-
Start by implementing basic transforms/changes of variables: affine, exp/log, sigmoid/logistic, rotations. These could be implemented either through objects that are created with log pdfs, apply the transformation, and can be called for their log pdf; or through functions that take a log pdf and output that log pdf with the transformation applied to its parameters, plus the log abs det Jacobian. These can be used to unconstrain variables as a preprocessing step before running HMC, and to shift and scale normal distributions for mean-field approximations.
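A sketch of the second (functional) option, under the assumption that each transform is described by its inverse and the log abs det Jacobian of the forward map (the name `transform_logpdf` is illustrative, not an existing API):

```python
import jax.numpy as jnp


def transform_logpdf(logpdf, inverse_fn, logdet_fwd_fn):
    """Change of variables: if Z has density exp(logpdf) and X = f(Z), then
    log p_X(x) = logpdf(f^{-1}(x)) - log|det J_f(f^{-1}(x))|."""
    def new_logpdf(x):
        z = inverse_fn(x)
        return logpdf(z) - logdet_fwd_fn(z)
    return new_logpdf


# Standard normal base density (scalar).
normal_logpdf = lambda z: -0.5 * z**2 - 0.5 * jnp.log(2 * jnp.pi)

# exp transform: X = exp(Z) yields a log-normal; log|d exp(z)/dz| = z.
lognormal_logpdf = transform_logpdf(normal_logpdf, jnp.log, lambda z: z)

# affine transform: X = mu + sigma * Z shifts and scales the normal.
mu, sigma = 1.0, 2.0
shifted_logpdf = transform_logpdf(
    normal_logpdf,
    lambda x: (x - mu) / sigma,
    lambda z: jnp.log(sigma),  # constant Jacobian for an affine map
)
```

Stacking transforms then falls out for free: applying `transform_logpdf` repeatedly composes the maps and accumulates the log abs det Jacobian terms.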
The vi.py module would mimic the kernel implementation of HMC, with a VIInfo and a kernel generator that performs a step of gradient ascent (using jax.example_libraries.optimizers or optax), given a log pdf, a reference pdf, and a parametrized transformation. Then a subfolder with base.py containing VIState and a state generator, and loss.py with an ELBO implementation and potentially others. Finally, implement other transformations such as normalizing flows, polynomial chaos expansions, or iterative Gaussianization, also allowing these transformations to be stacked sequentially, one on top of the other.
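A rough sketch of what the kernel interface could look like, using jax.example_libraries.optimizers (the `VIState`/`VIInfo` fields and the `vi` factory are illustrative assumptions, not a proposed final design):

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers


class VIState(NamedTuple):
    step: int
    params: Any       # the VI variables being optimized
    opt_state: Any    # the optimizer's internal state


class VIInfo(NamedTuple):
    loss: float       # negative ELBO estimate at this step


def vi(loss_fn, step_size=0.1):
    """Return init/step functions mirroring the HMC kernel interface."""
    opt_init, opt_update, get_params = optimizers.adam(step_size)

    def init(initial_params):
        return VIState(0, initial_params, opt_init(initial_params))

    def step(rng_key, state):
        loss, grads = jax.value_and_grad(loss_fn)(state.params, rng_key)
        opt_state = opt_update(state.step, grads, state.opt_state)
        new_params = get_params(opt_state)
        return VIState(state.step + 1, new_params, opt_state), VIInfo(loss)

    return init, step


# Toy check: minimize a quadratic "loss" that ignores the rng key.
toy_loss = lambda params, rng_key: jnp.sum((params - 3.0) ** 2)
init, step = vi(toy_loss)
state = init(jnp.zeros(2))
for key in jax.random.split(jax.random.PRNGKey(0), 200):
    state, info = step(key, state)
```

Because `step` has the same `(rng_key, state) -> (state, info)` shape as the HMC kernel, the same driver loop (and `jax.lax.scan`) could run either.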
Any thoughts?