-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
restructuring neural models + addition of OT-FM and GENOT #468
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #468 +/- ##
==========================================
+ Coverage 90.64% 90.71% +0.07%
==========================================
Files 60 68 +8
Lines 6682 7046 +364
Branches 956 996 +40
==========================================
+ Hits 6057 6392 +335
- Misses 477 494 +17
- Partials 148 160 +12
|
co-authored by @lucaeyring |
BaseNeuralSolver
and UnbalancednessMixin
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
src/ott/neural/flow_models/flows.py
Outdated
return jnp.full_like(t, fill_value=self.sigma) | ||
|
||
|
||
class BrownianNoiseFlow(StraightFlow): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"BrownianNoiseFlow" is ambiguous to me, it could correspond to Brownian motion and associated flow e.g. VE Flow in eqn 16 of https://arxiv.org/pdf/2210.02747.pdf
TODOs left:
|
tests/neural/conftest.py
Outdated
|
||
|
||
@pytest.fixture() | ||
def lin_cond_dl() -> DataLoader: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not a "conditional OT data loader".
a0fed3c
to
7b61e05
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @MUCDK and everyone, LGTM, merging this!
* draft of BaseSolver and UnbalancedMixin * draft of BaseSolver and UnbalancedMixin * [ci skip] continue flow matching implementation * [ci skip] continue flow matching implementation * [ci skip] add neural networks * [ci skip] add test * [ci skip] resolve import errors * [ci skip] MRO not working * [ci skip] basic test for flow matching passes * [ci skip] add tests for FM with conditions and conditional OT with FM * [ci skip] add genot outline * [ci skip] restructure genot * [ci skip] restructure genot * [ci skip] fix transport * [ci skip] flow matching tests passing * [ci skip] add more tests genot * [ci skip] add more tests genot * [ci skip] add TimeSampler * [ci skip] add docs for TimeSampler and Flow * [ci skip] add docs for OTFlowMatching and replace jnp.ndarray by jax.Array * [ci skip] change init arguments of GENOT and add docstrings to GENOT * [ci skip] split nets into base_models and models * [ci skip] add references * add tests for learning the rescaling factors * [ci skip] partially fix rescaling factor learning * [ci skip] fix rescaling factor learning * [ci skip] all tests passing but k_samples_per_x in genot * k_samples_per_x working in GENOT * [ci skip] changed dataloaders to numpy and dict return * [ci skip] changed dataloaders to numpy and dict return * revert jax.Array to jnp.ndarray * move dataloader from tests to module * add docstrings to neurcal networks * [ci skip] adapt type of scale_cost and cost_fn * [ci skip] clean code * [ci skip] fix genot tests * [ci skip] fix otfm tests * [ci skip] fix otfm tests * add scale cost to otfm * incorporate feedback partially * resolve circular import errors * resolve a few pre-commit errors * resolve pre-commit errors * resolve pre-commit errors * fix rng bug * Update pre-commit * fix import error * Run linter * replace rng jnp.ndarray type by jax.array * replace rng jnp.ndarray type by jax.array * fix import error * [ci skip] start to incorporate feedback * restructure neural module * fix import errors * incorporate feedback partially * make time encoder a layer * make conditions Optional and minor feedback * revert faulty jax.array / jnp.ndarray conversions * make formatting in neural nets nicer * add description to Velocity Field * replace time sampler class by function * add citations * add more references * rename keys_model to rng * fix tests regarding time sampling * fix typo in tests * rename neural_vector_field to velocity_field everywhere * fix OTFlowMatching.transport * fix rescaling networks * Update src/ott/neural/flows/flows.py Co-authored-by: nvesseron <[email protected]> * Update src/ott/neural/flows/flows.py Co-authored-by: nvesseron <[email protected]> * test for scale_cost * update test for scale_cost * fix bug for scale_cost * fix bug for scale_cost * jit solve_ode in genot * incorporate changes partially * [ci skip] intermediate save * [ci skip] neural base solver update * make resamlpemixin a class * incorporate more changes * move noise sampling to flows * fix bug in passing rngs in otfm * introduce otmatcher in otfm * [ci skip] split GENOT into GENOTLin and GENOTQuad * remove dictionaries in OTFM and GENOT classes * change logic in match_latent_to_data in genot * change data loaders / data sets * finish data loader refactoring * Update linter * fix bug in _resample_data` * incorporate more changes * add docs * incorporate more changes * problem with custom type * fix scale cost bug * fix bugs * fux bug in unbalancedness/rescalingMlp * unify unbalancedness step in GENOT * change OTDataSet and OTFlowMatching to 4 data loaderes * Fix bug in the `ConditionalOTDataset` * Polish docs in the `flows.py` * Update `OTFM` * Fix small bugs in `OTFM` * Polish layers * Fix typo in citation * More polish for the docs * remove print statements and unbalancednesshandler * remove tests * make genot training loops more similar to otfm training loop * adapt tests to the extent possible * Add weights to sampling * Start cleaning matchers * Add conditional sampling + resampling * Add initial quad matcher * Improve typing * Remove `base_solver.py` * Add TODO * Update datasets, fix OTFM tests * Start cleaning GENOT * Update GENOT * Remove old GENOTLin/GENOTQuad * Remove axis swapping * Remove old todo * Fix OTFM tests * Remove `MLPBlock` and `RescalingMLP` * Add forgotten license * Remove `__post_init__` from `VF` * Move cyclical time encoder * Move more stuff to `utils` * Remove `samplers.py` * Rename `cond_dim` -> `condition_dim` * Nicer formatting * Fix bug when sampling from the target * Fix another bug when sampling from the data * Add initial test for GW * Remove old GENOT tests * Remove old dataloaders * Add more todos * add docs to dataloader * expose args in GENOT * add docs and adapt data_match_fn * fix linting * fix data loading and add genot fused tests * genot tests passing * adapt docs * adapt docs * add error message * clean docs * comprise genot tests * change reference for GENOT * add missing docstring * Modify behaviour of `ConditionalLoader` * Update docstring * Clean GENOT docs * Improve VF * Simplify GENOT test * Better metadata wrapper in tests * Fix condition in GENOT test * Add quad cond dl * Add conf fused DL * Polish docs * Remove conditional loader * Fix link in the docs * Improve VF * Fix GENOT test * Polish docs * Remove `uniform_marginals` argument * Fix undefined variable * Update `GENOT.transport` docs * Add `diffrax` to `conf.py` * Restructure files * Fix neural init tests import * Update `docs/` * Update Monge Gap * Update MetaOT and NeuralDual * Update ICNN inits * Fix links to neural in the docs * Check for condition dim in VF * Don't use activation fn in the last layer of VF * Update assertions * Try skipping OTFM/GENOT tests temporarily * Be extra verbose when intalling packages * Remove `torch` dependency * Remove `torch` from tests in `pyproject.toml` * [ci skip] Update docstrings --------- Co-authored-by: lucaeyring <[email protected]> Co-authored-by: Michal Klein <[email protected]> Co-authored-by: nvesseron <[email protected]> Co-authored-by: Dominik Klein <[email protected]>
This is a PR for
OTFlowMatching
and, related to this, classes for flows and time samplersGENOT
(with extension to conditional GENOT)Following this PR, the implementations of ICNN-based solvers and the Monge Gap model should be adapted and extended to the unbalanced setting.
Moreover, wrt typing, I replacedjnp.ndarray
byjax.Array
What remains to be done, but I would prefer to do in a separate PR
OTFM
andGENOT
, i.e. functions which compute batch-wise graphs, and compute costs, e.g. geodesic Sinkhorn or convolutional Wasserstein from this.