Conditional Flow Matching (CFM): a simulation-free training objective for continuous normalizing flows. We explore a few different flow matching variants and ODE solvers on a simple dataset. This repo was inspired by, and adapted from, the awesome work in TorchCFM and Torchdyn.
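Training: consider a smooth time-varying vector field $u_t : [0, 1] \times \mathbb{R}^d \to \mathbb{R}^d$ that governs the ODE $dx = u_t(x)\,dt$ and transports a prior $p_0$ to a target $p_1$ along a probability path $p_t(x)$.
However, the target distributions $p_t(x)$ and the vector field $u_t(x)$ are intractable in general, so we express the path as a mixture over a conditioning variable $z$:
$$p_t(x) = \int p_t(x \mid z)\, q(z)\, dz,$$
where $p_t(x \mid z) = \mathcal{N}(x \mid \mu_t(z), \sigma_t^2 I)$ is a conditional probability path generated by a conditional vector field $u_t(x \mid z)$. We regress a neural network $v_\theta(t, x)$ onto the conditional vector field
such that
$$\mathcal{L}_{\mathrm{CFM}}(\theta) = \mathbb{E}_{t,\, z,\, x} \lVert v_\theta(t, x) - u_t(x \mid z) \rVert^2,$$
where $t \sim \mathcal{U}(0, 1)$, $z \sim q(z)$, and $x \sim p_t(x \mid z)$. If we define $q(z) = q(x_0)\, q(x_1)$, $\mu_t = (1 - t)\, x_0 + t\, x_1$, and $\sigma_t = \sigma$,
then we have independent conditional flow matching (Tong et al. 2023) with the resulting conditional probability path and vector field
$$p_t(x \mid z) = \mathcal{N}\big(x \mid (1 - t)\, x_0 + t\, x_1,\ \sigma^2 I\big), \qquad u_t(x \mid z) = x_1 - x_0.$$
Alternatively, the variance-preserving stochastic interpolant (Albergo & Vanden-Eijnden 2023) has the form
$$\mu_t = \cos\!\left(\tfrac{\pi t}{2}\right) x_0 + \sin\!\left(\tfrac{\pi t}{2}\right) x_1, \qquad u_t(x \mid z) = \tfrac{\pi}{2}\left(\cos\!\left(\tfrac{\pi t}{2}\right) x_1 - \sin\!\left(\tfrac{\pi t}{2}\right) x_0\right).$$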
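To make the two variants concrete, here is a minimal sketch of how a training pair $(x_t, u_t)$ might be drawn for each, assuming the standard forms from the cited papers: a linear interpolant with target $x_1 - x_0$ for independent CFM, and a cosine/sine interpolant for the variance-preserving case. It uses NumPy rather than MLX so that it runs anywhere. The function names, the `"icfm"` label, and the default `sigma` are illustrative assumptions, not the API of `cfm.py`.

```python
import numpy as np


def sample_conditional_pair(x0, x1, t, sigma=0.1, method="icfm", rng=None):
    """Draw x_t ~ N(mu_t, sigma^2 I) and return it with the regression target u_t.

    x0, x1: arrays of shape (batch, dim) sampled from the prior and the target.
    t:      array of shape (batch,) with times in [0, 1].
    """
    rng = np.random.default_rng() if rng is None else rng
    t = np.asarray(t, dtype=float).reshape(-1, 1)  # broadcast over the feature axis

    if method == "icfm":
        # Independent CFM: straight-line mean, constant target velocity.
        mu_t = (1.0 - t) * x0 + t * x1
        u_t = x1 - x0
    elif method == "vp":
        # Variance-preserving interpolant: cosine/sine schedule.
        angle = 0.5 * np.pi * t
        mu_t = np.cos(angle) * x0 + np.sin(angle) * x1
        u_t = 0.5 * np.pi * (np.cos(angle) * x1 - np.sin(angle) * x0)
    else:
        raise ValueError(f"unknown method: {method!r}")

    x_t = mu_t + sigma * rng.standard_normal(x0.shape)
    return x_t, u_t


def cfm_loss(v_pred, u_t):
    """Mean over the batch of the squared error between prediction and target."""
    return float(np.mean(np.sum((v_pred - u_t) ** 2, axis=-1)))
```

In a training loop, `v_pred` would be the network output at `(t, x_t)`, and the loss would be minimized with whatever optimizer the framework provides.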
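Sampling: now that we have our vector field, we can sample from our prior $x_0 \sim q(x_0)$ and integrate the ODE $dx = v_\theta(t, x)\,dt$ forward from $t = 0$ to $t = 1$ with a numerical solver. With a fixed-step Euler scheme, for example, the update is $x_{t + \Delta t} = x_t + v_\theta(t, x_t)\,\Delta t$
for $t = 0, \Delta t, \dots, 1 - \Delta t$.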
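As a rough illustration of the sampling step, the sketch below integrates a vector field from $t = 0$ to $t = 1$ with fixed-step Euler updates. The callable `v` stands in for a trained network, and both `euler_sample` and `n_steps` are hypothetical names chosen for this example. It uses NumPy rather than MLX.

```python
import numpy as np


def euler_sample(v, x0, n_steps=100):
    """Integrate dx/dt = v(t, x) from t=0 to t=1 with fixed-step Euler.

    v:  callable (t, x) -> array with the same shape as x.
    x0: array of prior samples, shape (batch, dim).
    """
    x = np.array(x0, dtype=float)
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = k * dt
        x = x + v(t, x) * dt  # x_{t+dt} = x_t + v(t, x_t) * dt
    return x
```

For a constant field the Euler result is exact. For a curved field the error shrinks roughly linearly with the step size, which is one reason to reach for a higher-order or adaptive solver.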
Run with default params and save the result in media/*.png:
python main.py --method vp --solver dopri5
- main.py: training and sampling
- models.py: neural net definition
- datasets.py: generate prior and target data
- cfm.py: flow matching variants
- odeint.py: adaptive and fixed numerical integrators
- solver.py: solver definition for integrator
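To show what separates an adaptive integrator from a fixed-step one, here is a small sketch of step-size control using an embedded Heun-Euler pair. This is deliberately simpler than Dormand-Prince (`dopri5`) and is not the code in `odeint.py`. The function name, tolerances, and step-size limits are assumptions made for the example, which uses NumPy rather than MLX.

```python
import numpy as np


def adaptive_heun(v, x0, t0=0.0, t1=1.0, rtol=1e-4, atol=1e-6, dt=0.1):
    """Integrate dx/dt = v(t, x) from t0 to t1 with an adaptive Heun-Euler pair."""
    x = np.array(x0, dtype=float)
    t = t0
    while t < t1:
        last = dt >= t1 - t  # clip the final step so we land exactly on t1
        if last:
            dt = t1 - t
        k1 = v(t, x)
        k2 = v(t + dt, x + dt * k1)
        x_low = x + dt * k1                # first-order (Euler) estimate
        x_high = x + 0.5 * dt * (k1 + k2)  # second-order (Heun) estimate
        # Error estimate: RMS of the difference, scaled by the tolerances.
        scale = atol + rtol * np.maximum(np.abs(x), np.abs(x_high))
        err = float(np.sqrt(np.mean(((x_high - x_low) / scale) ** 2)))
        if err <= 1.0:
            # Accept the step and advance with the higher-order estimate.
            x = x_high
            t = t1 if last else t + dt
        # Grow or shrink the step based on the error, within safety bounds.
        factor = 0.9 * (1.0 / max(err, 1e-10)) ** 0.5
        dt = dt * min(5.0, max(0.2, factor))
    return x
```

A rejected step costs two extra function evaluations, so the tolerances trade accuracy against the number of network calls during sampling.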
Install the dependencies (optimized for Apple silicon; yay for MLX!):
pip install -r requirements.txt