Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

restructuring neural models + addition of OT-FM and GENOT #468

Merged
merged 193 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 72 commits
Commits
Show all changes
193 commits
Select commit Hold shift + click to select a range
0f91857
draft of BaseSolver and UnbalancedMixin
MUCDK Nov 22, 2023
3706970
draft of BaseSolver and UnbalancedMixin
MUCDK Nov 22, 2023
42dd2b8
[ci skip] continue flow matching implementation
MUCDK Nov 22, 2023
6583281
[ci skip] continue flow matching implementation
MUCDK Nov 22, 2023
f5a043c
[ci skip] add neural networks
MUCDK Nov 22, 2023
34bb10f
[ci skip] add test
MUCDK Nov 22, 2023
374d051
[ci skip] resolve import errors
MUCDK Nov 22, 2023
a9e9a8c
[ci skip] MRO not working
MUCDK Nov 22, 2023
e4f8991
[ci skip] basic test for flow matching passes
MUCDK Nov 23, 2023
7869e37
[ci skip] add tests for FM with conditions and conditional OT with FM
MUCDK Nov 23, 2023
5a90dc1
[ci skip] add genot outline
MUCDK Nov 23, 2023
c843758
[ci skip] restructure genot
MUCDK Nov 23, 2023
ef86c54
[ci skip] restructure genot
MUCDK Nov 24, 2023
70a6173
[ci skip] fix transport
MUCDK Nov 24, 2023
40570e6
[ci skip] flow matching tests passing
MUCDK Nov 24, 2023
b0910ea
[ci skip] add more tests genot
MUCDK Nov 24, 2023
542f512
[ci skip] add more tests genot
MUCDK Nov 24, 2023
c067f45
[ci skip] add TimeSampler
MUCDK Nov 26, 2023
2546afc
[ci skip] add docs for TimeSampler and Flow
MUCDK Nov 26, 2023
579852f
[ci skip] add docs for OTFlowMatching and replace jnp.ndarray by jax.…
MUCDK Nov 26, 2023
b075758
[ci skip] change init arguments of GENOT and add docstrings to GENOT
MUCDK Nov 26, 2023
95e8707
[ci skip] split nets into base_models and models
MUCDK Nov 26, 2023
3b1791d
[ci skip] add references
MUCDK Nov 26, 2023
eca77c0
add tests for learning the rescaling factors
MUCDK Nov 26, 2023
62b2666
[ci skip] partially fix rescaling factor learning
MUCDK Nov 26, 2023
2ceceea
[ci skip] fix rescaling factor learning
MUCDK Nov 26, 2023
e8f8171
[ci skip] all tests passing but k_samples_per_x in genot
MUCDK Nov 27, 2023
add1348
k_samples_per_x working in GENOT
MUCDK Nov 27, 2023
993d1de
[ci skip] changed dataloaders to numpy and dict return
lucaeyring Nov 27, 2023
beee22d
[ci skip] changed dataloaders to numpy and dict return
lucaeyring Nov 27, 2023
f26e072
revert jax.Array to jnp.ndarray
MUCDK Nov 28, 2023
8fa3683
move dataloader from tests to module
MUCDK Nov 28, 2023
2e2f9f3
add docstrings to neurcal networks
MUCDK Nov 28, 2023
8c71deb
[ci skip] adapt type of scale_cost and cost_fn
MUCDK Nov 28, 2023
a25b6c2
[ci skip] clean code
MUCDK Nov 28, 2023
75437db
[ci skip] fix genot tests
MUCDK Nov 28, 2023
bfcfcbd
[ci skip] fix otfm tests
MUCDK Nov 28, 2023
f27bc22
[ci skip] fix otfm tests
MUCDK Nov 28, 2023
612c2f1
merge main
MUCDK Nov 28, 2023
384e8fc
add scale cost to otfm
MUCDK Nov 28, 2023
ef204e6
incorporate feedback partially
MUCDK Nov 29, 2023
8894ce2
merge main
MUCDK Nov 29, 2023
2b1ab92
resolve circular import errors
MUCDK Nov 29, 2023
e1be6ca
resolve a few pre-commit errors
MUCDK Nov 29, 2023
a307bf8
resolve pre-commit errors
MUCDK Nov 29, 2023
ffec70c
resolve pre-commit errors
MUCDK Nov 29, 2023
10d70f2
fix rng bug
MUCDK Nov 29, 2023
9fb308b
Update pre-commit
michalk8 Nov 29, 2023
aa0bdc5
fix import error
MUCDK Nov 29, 2023
b48dfdc
Run linter
michalk8 Nov 29, 2023
be94668
Merge branch 'draft/neural_base_solver' of ssh://github.com/MUCDK/ott…
michalk8 Nov 29, 2023
4371e74
replace rng jnp.ndarray type by jax.array
MUCDK Nov 29, 2023
2bc683a
replace rng jnp.ndarray type by jax.array
MUCDK Nov 29, 2023
542dd0a
fix import error
MUCDK Nov 29, 2023
f585c24
[ci skip] start to incorporate feedback
MUCDK Dec 1, 2023
3c07009
restructure neural module
MUCDK Dec 4, 2023
8f404f8
fix import errors
MUCDK Dec 4, 2023
0b81135
incorporate feedback partially
MUCDK Dec 5, 2023
fccdeef
make time encoder a layer
MUCDK Dec 5, 2023
2a279c1
make conditions Optional and minor feedback
MUCDK Dec 5, 2023
e6f0049
revert faulty jax.array / jnp.ndarray conversions
MUCDK Dec 5, 2023
f23497f
make formatting in neural nets nicer
MUCDK Dec 5, 2023
9f96583
add description to Velocity Field
MUCDK Dec 5, 2023
86fe886
replace time sampler class by function
MUCDK Dec 5, 2023
58e3d29
add citations
MUCDK Dec 5, 2023
2f5fa52
add more references
MUCDK Dec 5, 2023
9ad9924
rename keys_model to rng
MUCDK Dec 5, 2023
0addc7a
fix tests regarding time sampling
MUCDK Dec 5, 2023
be68393
fix typo in tests
MUCDK Dec 5, 2023
b5bdc4a
rename neural_vector_field to velocity_field everywhere
MUCDK Dec 5, 2023
bebbbd0
fix OTFlowMatching.transport
MUCDK Dec 5, 2023
f4c05c4
fix rescaling networks
MUCDK Dec 5, 2023
4d9992e
Update src/ott/neural/flows/flows.py
MUCDK Jan 5, 2024
51221dd
Update src/ott/neural/flows/flows.py
MUCDK Jan 5, 2024
6c56dfe
test for scale_cost
MUCDK Jan 8, 2024
cc045fa
update test for scale_cost
MUCDK Jan 9, 2024
f4de339
fix bug for scale_cost
MUCDK Jan 9, 2024
5db4c73
fix bug for scale_cost
MUCDK Jan 9, 2024
72885ac
jit solve_ode in genot
MUCDK Jan 10, 2024
937fffc
incorporate changes partially
MUCDK Feb 7, 2024
a94b585
[ci skip] intermediate save
MUCDK Feb 9, 2024
78b5e10
[ci skip] neural base solver update
MUCDK Feb 9, 2024
592564f
make resamlpemixin a class
MUCDK Feb 9, 2024
5e05bfc
incorporate more changes
MUCDK Feb 9, 2024
831d3ea
move noise sampling to flows
MUCDK Feb 9, 2024
c18c461
fix bug in passing rngs in otfm
MUCDK Feb 9, 2024
8341821
introduce otmatcher in otfm
MUCDK Feb 9, 2024
3cae628
[ci skip] split GENOT into GENOTLin and GENOTQuad
MUCDK Feb 9, 2024
20fbbb8
remove dictionaries in OTFM and GENOT classes
MUCDK Feb 11, 2024
525ef64
change logic in match_latent_to_data in genot
MUCDK Feb 11, 2024
1b30c11
change data loaders / data sets
MUCDK Feb 11, 2024
e2ebb19
finish data loader refactoring
MUCDK Feb 11, 2024
acc782f
Merge remote-tracking branch 'upstream/main' into draft/neural_base_s…
michalk8 Feb 12, 2024
8644fd9
Update linter
michalk8 Feb 12, 2024
460bf90
fix bug in _resample_data`
MUCDK Feb 14, 2024
ce42c1a
incorporate more changes
MUCDK Feb 16, 2024
1e21afb
add docs
MUCDK Feb 16, 2024
1afb922
incorporate more changes
MUCDK Feb 16, 2024
dc436f4
problem with custom type
MUCDK Feb 16, 2024
8bfe1a3
fix scale cost bug
MUCDK Feb 16, 2024
2a1f23a
fix bugs
MUCDK Feb 16, 2024
a46405c
fux bug in unbalancedness/rescalingMlp
MUCDK Feb 18, 2024
7afcac4
unify unbalancedness step in GENOT
MUCDK Feb 18, 2024
4fc8fe6
change OTDataSet and OTFlowMatching to 4 data loaderes
MUCDK Feb 18, 2024
7919051
Merge remote-tracking branch 'upstream/main' into draft/neural_base_s…
michalk8 Feb 19, 2024
43d37f7
Fix bug in the `ConditionalOTDataset`
michalk8 Feb 19, 2024
86f6e7a
Polish docs in the `flows.py`
michalk8 Feb 19, 2024
ae37132
Update `OTFM`
michalk8 Feb 19, 2024
de323d2
Fix small bugs in `OTFM`
michalk8 Feb 19, 2024
4408cc2
Polish layers
michalk8 Feb 19, 2024
451f069
Fix typo in citation
michalk8 Feb 19, 2024
5e10d3a
More polish for the docs
michalk8 Feb 19, 2024
5edc66d
remove print statements and unbalancednesshandler
MUCDK Mar 6, 2024
23eca2c
remove tests
MUCDK Mar 6, 2024
85427ba
make genot training loops more similar to otfm training loop
MUCDK Mar 6, 2024
5a2424a
adapt tests to the extent possible
MUCDK Mar 6, 2024
c4a187e
Add weights to sampling
michalk8 Mar 11, 2024
30f2324
Start cleaning matchers
michalk8 Mar 11, 2024
82bc7e6
Add conditional sampling + resampling
michalk8 Mar 11, 2024
f430c29
Add initial quad matcher
michalk8 Mar 11, 2024
4b41f0c
Improve typing
michalk8 Mar 11, 2024
cc2746b
Remove `base_solver.py`
michalk8 Mar 11, 2024
1068410
Add TODO
michalk8 Mar 11, 2024
e559740
Update datasets, fix OTFM tests
michalk8 Mar 13, 2024
a9fe618
Start cleaning GENOT
michalk8 Mar 14, 2024
abca4f7
Update GENOT
michalk8 Mar 15, 2024
f2c20a4
Remove old GENOTLin/GENOTQuad
michalk8 Mar 15, 2024
693ecc4
Remove axis swapping
michalk8 Mar 15, 2024
3d9c702
Remove old todo
michalk8 Mar 15, 2024
f27d209
Fix OTFM tests
michalk8 Mar 15, 2024
4688998
Remove `MLPBlock` and `RescalingMLP`
michalk8 Mar 15, 2024
52c5de9
Add forgotten license
michalk8 Mar 15, 2024
0b417d7
Remove `__post_init__` from `VF`
michalk8 Mar 15, 2024
fe74a57
Move cyclical time encoder
michalk8 Mar 15, 2024
4affc14
Move more stuff to `utils`
michalk8 Mar 15, 2024
21ce523
Remove `samplers.py`
michalk8 Mar 15, 2024
aa636ef
Rename `cond_dim` -> `condition_dim`
michalk8 Mar 15, 2024
da0ef92
Nicer formatting
michalk8 Mar 15, 2024
de1c264
Fix bug when sampling from the target
michalk8 Mar 15, 2024
ce763f0
Fix another bug when sampling from the data
michalk8 Mar 15, 2024
f9db2db
Add initial test for GW
michalk8 Mar 15, 2024
8bc9b10
Remove old GENOT tests
michalk8 Mar 15, 2024
6f4f864
Remove old dataloaders
michalk8 Mar 15, 2024
11911c4
Add more todos
michalk8 Mar 17, 2024
a8de2ea
add docs to dataloader
MUCDK Mar 19, 2024
dfaf042
expose args in GENOT
MUCDK Mar 19, 2024
2734c60
add docs and adapt data_match_fn
MUCDK Mar 19, 2024
428ad06
Merge branch 'main' into draft/neural_base_solver
MUCDK Mar 19, 2024
08e24d8
fix linting
MUCDK Mar 19, 2024
7d7da3a
fix data loading and add genot fused tests
MUCDK Mar 19, 2024
4c8477a
genot tests passing
MUCDK Mar 19, 2024
001d21d
adapt docs
MUCDK Mar 19, 2024
52d8466
adapt docs
MUCDK Mar 19, 2024
9f230c8
add error message
MUCDK Mar 19, 2024
6c81678
clean docs
MUCDK Mar 19, 2024
e77cc34
comprise genot tests
MUCDK Mar 19, 2024
d8603f7
change reference for GENOT
MUCDK Mar 19, 2024
7813f83
add missing docstring
MUCDK Mar 20, 2024
212ee01
Modify behaviour of `ConditionalLoader`
michalk8 Mar 25, 2024
95c7142
Update docstring
michalk8 Mar 25, 2024
52a54d3
Clean GENOT docs
michalk8 Mar 25, 2024
de2e4ac
Improve VF
michalk8 Mar 26, 2024
9b89fd7
Simplify GENOT test
michalk8 Mar 26, 2024
433da0c
Better metadata wrapper in tests
michalk8 Mar 26, 2024
f8fcba7
Fix condition in GENOT test
michalk8 Mar 26, 2024
49a07a0
Add quad cond dl
michalk8 Mar 26, 2024
d1ae1de
Add conf fused DL
michalk8 Mar 26, 2024
f6c69bd
Polish docs
michalk8 Mar 26, 2024
3b69c0f
Remove conditional loader
michalk8 Mar 26, 2024
0ff3ad6
Fix link in the docs
michalk8 Mar 26, 2024
c3ce786
Improve VF
michalk8 Mar 26, 2024
161dd4a
Fix GENOT test
michalk8 Mar 26, 2024
69c3a4d
Polish docs
michalk8 Mar 26, 2024
65f2ab3
Remove `uniform_marginals` argument
michalk8 Mar 27, 2024
ba64056
Fix undefined variable
michalk8 Mar 27, 2024
80d2924
Update `GENOT.transport` docs
michalk8 Mar 27, 2024
e4aae7f
Add `diffrax` to `conf.py`
michalk8 Mar 27, 2024
128e085
Merge remote-tracking branch 'upstream/main' into draft/neural_base_s…
michalk8 Mar 28, 2024
1d96fac
Restructure files
michalk8 Mar 29, 2024
ef6afd1
Fix neural init tests import
michalk8 Mar 29, 2024
73c2527
Update `docs/`
michalk8 Mar 29, 2024
0418e78
Update Monge Gap
michalk8 Apr 2, 2024
b34b886
Update MetaOT and NeuralDual
michalk8 Apr 2, 2024
67202c2
Update ICNN inits
michalk8 Apr 2, 2024
982d20b
Fix links to neural in the docs
michalk8 Apr 2, 2024
7b61e05
Check for condition dim in VF
michalk8 Apr 2, 2024
8819d5e
Don't use activation fn in the last layer of VF
michalk8 Apr 2, 2024
6f9cbcc
Update assertions
michalk8 Apr 2, 2024
9e1499b
Try skipping OTFM/GENOT tests temporarily
michalk8 Apr 2, 2024
b37da2a
Be extra verbose when intalling packages
michalk8 Apr 2, 2024
9c561a5
Remove `torch` dependency
michalk8 Apr 2, 2024
f227d54
Remove `torch` from tests in `pyproject.toml`
michalk8 Apr 2, 2024
6f9a77c
[ci skip] Update docstrings
michalk8 Apr 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,8 @@ default_stages:
- push
minimum_pre_commit_version: 3.0.0
repos:
- repo: https://github.com/google/yapf
rev: v0.40.0
hooks:
- id: yapf
additional_dependencies: [toml]
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.0
hooks:
- id: nbqa-pyupgrade
args: [--py38-plus]
- id: nbqa-black
- id: nbqa-isort
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.10.0
hooks:
- id: pretty-format-yaml
args: [--autofix, --indent, '2']
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: detect-private-key
- id: check-ast
Expand All @@ -37,13 +20,34 @@ repos:
- id: trailing-whitespace
- id: check-case-conflict
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: v0.0.285
rev: v0.1.6
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
name: isort
- repo: https://github.com/google/yapf
rev: v0.40.2
hooks:
- id: yapf
additional_dependencies: [toml]
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.1
hooks:
- id: nbqa-pyupgrade
args: [--py38-plus]
- id: nbqa-black
- id: nbqa-isort
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.11.0
hooks:
- id: pretty-format-yaml
args: [--autofix, --indent, '2']
- repo: https://github.com/rstcheck/rstcheck
rev: v6.1.2
rev: v6.2.0
hooks:
- id: rstcheck
additional_dependencies: [tomli]
Expand Down
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
import logging
from datetime import datetime

import ott
from sphinx.util import logging as sphinx_logging

import ott

# -- Project information -----------------------------------------------------
needs_sphinx = "4.0"

Expand Down
50 changes: 50 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,53 @@ @misc{huguet:2023
title = {Geodesic Sinkhorn for Fast and Accurate Optimal Transport on Manifolds},
year = {2023},
}

MUCDK marked this conversation as resolved.
Show resolved Hide resolved
@misc{eyring:23,
author={Eyring, Luca and Klein, Dominik and Uscidda, Th{\'e}o and Palla, Giovanni and Kilbertus, Niki and Akata, Zeynep and Theis, Fabian},
doi = {10.48550/arXiv.2311.15100},
eprint = {2311.15100},
eprintclass = {stat.ML},
eprinttype = {arXiv},
title={Unbalancedness in Neural Monge Maps Improves Unpaired Domain Translation},
year={2023}
}

@misc{klein_uscidda:23,
author={Dominik Klein and Théo Uscidda and Fabian Theis and Marco Cuturi},
doi = {10.48550/arXiv.2310.09254},
eprint={2310.09254},
eprintclass = {stat.ML},
eprinttype = {arXiv},
title={Generative Entropic Neural Optimal Transport To Map Within and Across Spaces},
year={2023},
}

@misc{lipman:22,
author={Lipman, Yaron and Chen, Ricky TQ and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt},
doi = {10.48550/arXiv.2210.02747},
eprint={2210.02747},
eprintclass = {stat.ML},
eprinttype = {arXiv},
title={Flow matching for generative modeling},
year={2022},
}

@misc{tong:23,
author={Tong, Alexander and Malkin, Nikolay and Huguet, Guillaume and Zhang, Yanlei and {Rector-Brooks}, Jarrid and Fatras, Kilian and Wolf, Guy and Bengio, Yoshua},
doi={10.48550/arXiv.2302.00482},
eprint={2302.00482},
eprintclass = {stat.ML},
eprinttype = {arXiv},
title={Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport},
year={2023},
}

@misc{pooladian:23,
author={Pooladian, Aram-Alexandre and Ben-Hamu, Heli and Domingo-Enrich, Carles and Amos, Brandon and Lipman, Yaron and Chen, Ricky},
doi={10.48550/arXiv.2304.14772},
eprint={2304.14772},
eprintclass = {stat.ML},
eprinttype = {arXiv},
title={Multisample flow matching: Straightening flows with minibatch couplings},
year={2023}
}
2 changes: 2 additions & 0 deletions docs/tutorials/MetaOT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
"outputs": [],
"source": [
"# Obtain the MNIST dataset and flatten the images into discrete measures.\n",
"\n",
"\n",
"def get_mnist_flat(train):\n",
" dataset = torchvision.datasets.MNIST(\n",
" \"/tmp/mnist/\",\n",
Expand Down
3 changes: 2 additions & 1 deletion docs/tutorials/Monge_Gap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"import sklearn.datasets\n",
"\n",
"import optax\n",
"from flax import linen as nn\n",
"\n",
"from matplotlib import pyplot as plt\n",
Expand Down
1 change: 1 addition & 0 deletions docs/tutorials/icnn_inits.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"import optax\n",
"\n",
"import matplotlib.pyplot as plt\n",
Expand Down
3 changes: 2 additions & 1 deletion docs/tutorials/neural_dual.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import optax\n",
"from torch.utils.data import DataLoader, IterableDataset\n",
"\n",
"import optax\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import clear_output, display\n",
"\n",
Expand Down
6 changes: 5 additions & 1 deletion docs/tutorials/point_clouds.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
},
"outputs": [],
"source": [
"def create_points(rng: jax.random.PRNGKeyArray, n: int, m: int, d: int):\n",
"def create_points(rng: jax.Array, n: int, m: int, d: int):\n",
" rngs = jax.random.split(rng, 3)\n",
" x = jax.random.normal(rngs[0], (n, d)) + 1\n",
" y = jax.random.uniform(rngs[1], (m, d))\n",
Expand Down Expand Up @@ -279,6 +279,8 @@
"outputs": [],
"source": [
"# Helper function to plot successively the optimal transports\n",
"\n",
"\n",
"def plot_ots(ots):\n",
" fig = plt.figure(figsize=(8, 5))\n",
" plott = ott.tools.plot.Plot(fig=fig)\n",
Expand Down Expand Up @@ -366973,6 +366975,8 @@
"outputs": [],
"source": [
"# Plotting utility\n",
"\n",
"\n",
"def plot_map(x, y, z, forward: bool = True):\n",
" plt.figure(figsize=(10, 8))\n",
" marker_t = \"o\" if forward else \"X\"\n",
Expand Down
7 changes: 4 additions & 3 deletions docs/tutorials/soft_sort.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,17 @@
"\n",
"from tqdm.notebook import tqdm\n",
"\n",
"import flax.linen as nn\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import optax\n",
"import torchvision\n",
"from flax import struct\n",
"from scipy import ndimage\n",
"from torch.utils import data\n",
"\n",
"import flax.linen as nn\n",
"import optax\n",
"from flax import struct\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from ott.tools import soft_sort"
Expand Down
2 changes: 2 additions & 0 deletions docs/tutorials/sparse_monge_displacements.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@
"outputs": [],
"source": [
"# Plotting utility\n",
"\n",
"\n",
"def plot_map(x, y, x_new=None, z=None, ax=None, title=None):\n",
" if ax is None:\n",
" f, ax = plt.subplots(figsize=(10, 8))\n",
Expand Down
2 changes: 2 additions & 0 deletions docs/tutorials/tracking_progress.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@
"outputs": [],
"source": [
"# Samples spiral\n",
"\n",
"\n",
"def sample_spiral(\n",
" n, min_radius, max_radius, key, min_angle=0, max_angle=10, noise=1.0\n",
"):\n",
Expand Down
13 changes: 8 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Changelog = "https://github.com/ott-jax/ott/releases"
neural = [
"flax>=0.6.6",
"optax>=0.1.1",
"diffrax>=0.4.1",
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
]
dev = [
"pre-commit>=2.16.0",
Expand Down Expand Up @@ -102,11 +103,14 @@ include = '\.ipynb$'

[tool.isort]
profile = "black"
line_length = 80
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
include_trailing_comma = true
multi_line_output = 3
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "NUMERIC", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"]
# also contains what we import in notebooks
known_numeric = ["numpy", "scipy", "jax", "flax", "optax", "jaxopt", "torch", "ot", "torchvision", "pandas", "sklearn"]
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TEST", "NUMERIC", "NEURAL", "PLOTTING", "FIRSTPARTY", "LOCALFOLDER"]
# also contains what we import in notebooks/tests
known_neural = ["flax", "optax", "diffrax", "orbax"]
known_numeric = ["numpy", "scipy", "jax", "flax", "optax", "jaxopt", "torch", "ot", "torchvision", "pandas", "sklearn", "tslearn"]
known_test = ["pytest"]
known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"]

[tool.pytest.ini_options]
Expand Down Expand Up @@ -285,7 +289,6 @@ ignore = [
line-length = 80
select = [
"D", # flake8-docstrings
"I", # isort
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
"E", # pycodestyle
"F", # pyflakes
"W", # pycodestyle
Expand All @@ -301,7 +304,7 @@ select = [
"T20", # flake8-print
"RET", # flake8-raise
]
unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"]
unfixable = ["I", "B", "UP", "C4", "BLE", "T20", "RET"]
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
target-version = "py38"
[tool.ruff.per-file-ignores]
# TODO(michalk8): PO004 - remove `self.initialize`
Expand Down
22 changes: 7 additions & 15 deletions src/ott/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ class GaussianMixture:
rectangle

batch_size: batch size of the samples
init_rng: initial PRNG key
rng: initial PRNG key
scale: scale of the Gaussian means
std: the standard deviation of the individual Gaussian samples
"""
name: Name_t
batch_size: int
init_rng: jax.Array
rng: jax.Array
scale: float = 5.0
std: float = 0.5

Expand Down Expand Up @@ -96,7 +96,7 @@ def __iter__(self) -> Iterator[jnp.array]:
return self._create_sample_generators()

def _create_sample_generators(self) -> Iterator[jnp.array]:
rng = self.init_rng
rng = self.rng
while True:
rng1, rng2, rng = jax.random.split(rng, 3)
means = jax.random.choice(rng1, self.centers, (self.batch_size,))
Expand Down Expand Up @@ -128,26 +128,18 @@ def create_gaussian_mixture_samplers(
rng1, rng2, rng3, rng4 = jax.random.split(rng, 4)
train_dataset = Dataset(
source_iter=iter(
GaussianMixture(
name_source, batch_size=train_batch_size, init_rng=rng1
)
GaussianMixture(name_source, batch_size=train_batch_size, rng=rng1)
),
target_iter=iter(
GaussianMixture(
name_target, batch_size=train_batch_size, init_rng=rng2
)
GaussianMixture(name_target, batch_size=train_batch_size, rng=rng2)
)
)
valid_dataset = Dataset(
source_iter=iter(
GaussianMixture(
name_source, batch_size=valid_batch_size, init_rng=rng3
)
GaussianMixture(name_source, batch_size=valid_batch_size, rng=rng3)
),
target_iter=iter(
GaussianMixture(
name_target, batch_size=valid_batch_size, init_rng=rng4
)
GaussianMixture(name_target, batch_size=valid_batch_size, rng=rng4)
)
)
dim_data = 2
Expand Down
7 changes: 1 addition & 6 deletions src/ott/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import (
fixed_point_loop,
matrix_square_root,
unbalanced_functions,
utils,
)
from . import fixed_point_loop, matrix_square_root, unbalanced_functions, utils
2 changes: 1 addition & 1 deletion src/ott/neural/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import layers, losses, models, solvers
from . import data, duality, flows, gaps, models
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import conjugate, map_estimator, neuraldual
from . import dataloaders
Loading
Loading