Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/fast tests #101

Merged
merged 53 commits into from
Jul 8, 2022
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
3e96ba4
Enable tests coverage
michalk8 Jul 4, 2022
0f36469
Add fast test CI job
michalk8 Jul 4, 2022
742e7db
Add coverage upload in CI
michalk8 Jul 4, 2022
4327531
Add custom fast mark handling
michalk8 Jul 4, 2022
38b5417
Fix typo in tests.yml
michalk8 Jul 4, 2022
7dd9526
Fix wrong test_mark identifier
michalk8 Jul 4, 2022
7c3eeb1
Update continuous barycenter test
michalk8 Jul 6, 2022
04a7314
Update discrete barycenter
michalk8 Jul 6, 2022
e4273dd
Update FGW test, fix fixture handling
michalk8 Jul 6, 2022
7a02e2a
Update GW tests
michalk8 Jul 6, 2022
5352c86
Update/move GW unbalanced test
michalk8 Jul 6, 2022
644c502
Update ICNN test
michalk8 Jul 6, 2022
a2c2a17
Update NeuralDualSolver tests
michalk8 Jul 6, 2022
3d949fd
Update Sinkhorn Anderson accel. test
michalk8 Jul 6, 2022
23bf570
Refactor Sinkhorn Bures test
michalk8 Jul 6, 2022
acd0121
Refactor Sinkhorn differentiability tests
michalk8 Jul 6, 2022
0510a8c
Move and refactor grid diff + precond tests
michalk8 Jul 6, 2022
678143d
Refactor and move Hessian tests
michalk8 Jul 6, 2022
0f3f054
Move SinkhornJacobianApply test
michalk8 Jul 6, 2022
3643f35
Move implicit vs. autodiff tests
michalk8 Jul 6, 2022
9e02779
Update SinkhornGrid test
michalk8 Jul 6, 2022
5fdc509
Update Sinkhorn jitting tests
michalk8 Jul 6, 2022
b141445
Update LRSinkhorn test
michalk8 Jul 6, 2022
f72faa6
Merge branch 'main' into feature/fast-tests
michalk8 Jul 6, 2022
476e96f
Move SinkhornOnline tests
michalk8 Jul 6, 2022
8a6fb7c
Move jacobian test
michalk8 Jul 6, 2022
046ef88
Move sinkhorn unbalanced tests
michalk8 Jul 6, 2022
e3c083b
Move Sinkhorn jitting tests, add chex as dep
michalk8 Jul 7, 2022
7262eb0
Update Sinkhorn tests
michalk8 Jul 7, 2022
5aa2036
Update TestCostFn
michalk8 Jul 7, 2022
40aef0a
Update LRGeom test
michalk8 Jul 7, 2022
73e5e76
Update Geom LSE and PC apply tests
michalk8 Jul 7, 2022
f9e58d3
Update PC apply test
michalk8 Jul 7, 2022
3f93ca5
Update Matrix square root tests
michalk8 Jul 7, 2022
2ba3dca
Update ScaleCost tests
michalk8 Jul 7, 2022
0a7dfa4
Update SinkhornDivGrad test
michalk8 Jul 7, 2022
66a9400
Update SinkhornDivergece output and tests
michalk8 Jul 7, 2022
2a2f697
Update soft sort and transport tests
michalk8 Jul 7, 2022
fcaecbb
Update GMM pair test
michalk8 Jul 7, 2022
7a12beb
Fix typo, update FitGMM tests
michalk8 Jul 7, 2022
dfccacf
Update GMM pair test
michalk8 Jul 7, 2022
5f0fed3
Update Gaussian Mixture test
michalk8 Jul 7, 2022
0d271da
Update TestGaussian
michalk8 Jul 7, 2022
ea5f74b
Update lingalg tests
michalk8 Jul 7, 2022
3963503
Update probabilities test
michalk8 Jul 7, 2022
d3bcf47
Refactor ScaleTriL tests
michalk8 Jul 7, 2022
4327d6d
Remove absl-py from requirements
michalk8 Jul 7, 2022
b6a9f47
Add .codecov.yml
michalk8 Jul 7, 2022
9cc34b5
Fix typo in tests.yml
michalk8 Jul 7, 2022
010bf74
Add coverage badge
michalk8 Jul 7, 2022
5171e77
Fix type hint in 3.8
michalk8 Jul 7, 2022
e6a1aa7
Try using pytest-xdist
michalk8 Jul 7, 2022
03f14d3
Update pytest config, add --cov-append
michalk8 Jul 8, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
codecov:
require_ci_to_pass: no
strict_yaml_branch: main

coverage:
range: 75..100
status:
project:
default:
target: 1
patch: off

comment:
layout: reach, diff, files
behavior: default
require_changes: true
branches:
- main
30 changes: 24 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,42 @@ jobs:
matrix:
python-version: ['3.8']
os: [ubuntu-latest]
test_mark: [fast, all]

steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Build
run: |
set -xe
python -VV
pip install --upgrade pip setuptools wheel
pip install --upgrade pip setuptools
pip install -e '.[test]'
shell: bash
- name: Run tests

- name: Print versions
run: |
set -xe
python -VV
python -c "import jax; print('jax', jax.__version__)"
python -c "import jaxlib; print('jaxlib', jaxlib.__version__)"
pytest tests -v --memray
shell: bash

- name: Run fast tests
if: ${{ matrix.test_mark == 'fast' }}
run: |
pytest -m fast --cov=ott --cov-report=xml --cov-report=term-missing --cov-config=setup.cfg --memray -v tests/

- name: Run all tests
if: ${{ matrix.test_mark == 'all' }}
run: |
pytest --cov=ott --cov-report=xml --cov-report=term-missing --cov-config=setup.cfg --memray -v tests/

- name: Upload coverage
uses: codecov/codecov-action@v3
with:
files: ./coverage.xml
flags: unittests-${{ matrix.test_mark }}
env_vars: OS,PYTHON
fail_ci_if_error: true
verbose: true
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

# Optimal Transport Tools (OTT).

![Tests](https://github.com/ott-jax/ott/actions/workflows/tests.yml/badge.svg)
![Tests](https://img.shields.io/github/workflow/status/ott-jax/ott/tests/main)
![Coverage](https://img.shields.io/codecov/c/github/ott-jax/ott/main)

**See [full documentation](https://ott-jax.readthedocs.io/en/latest/).**

Expand Down
30 changes: 18 additions & 12 deletions ott/tools/sinkhorn_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@
# limitations under the License.
"""Implements the sinkhorn divergence."""

import collections
from typing import Any, Dict, Mapping, Optional
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Tuple

import jax
from jax import numpy as jnp

from ott.core import segment, sinkhorn
from ott.geometry import geometry, pointcloud

SinkhornDivergenceOutput = collections.namedtuple(
'SinkhornDivergenceOutput',
['divergence', 'potentials', 'geoms', 'errors', 'converged']
)

class SinkhornDivergenceOutput(NamedTuple):
divergence: float
potentials: Tuple[List[jnp.ndarray], List[jnp.ndarray], List[jnp.ndarray]]
geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry]
errors: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray],
Optional[jnp.ndarray]]
converged: Tuple[bool, bool, bool]


def sinkhorn_divergence(
Expand All @@ -37,7 +40,7 @@ def sinkhorn_divergence(
static_b: bool = False,
share_epsilon: bool = True,
**kwargs: Any,
):
) -> SinkhornDivergenceOutput:
"""Compute Sinkhorn divergence defined by a geometry, weights, parameters.

Args:
Expand Down Expand Up @@ -82,10 +85,13 @@ def sinkhorn_divergence(


def _sinkhorn_divergence(
geometry_xy: geometry.Geometry, geometry_xx: geometry.Geometry,
geometry_yy: Optional[geometry.Geometry], a: jnp.ndarray, b: jnp.ndarray,
**kwargs
):
geometry_xy: geometry.Geometry,
geometry_xx: geometry.Geometry,
geometry_yy: Optional[geometry.Geometry],
a: jnp.ndarray,
b: jnp.ndarray,
**kwargs: Any,
) -> SinkhornDivergenceOutput:
"""Compute the (unbalanced) sinkhorn divergence for the wrapper function.

This definition includes a correction depending on the total masses of each
Expand Down Expand Up @@ -159,7 +165,7 @@ def segment_sinkhorn_divergence(
sinkhorn_kwargs: Optional[Mapping[str, Any]] = None,
static_b: bool = False,
share_epsilon: bool = True,
**kwargs
**kwargs: Any
) -> jnp.ndarray:
"""Compute Sinkhorn divergence between subsets of data with point cloud.

Expand Down
24 changes: 23 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ test =
pytest
pytest-xdist
pytest-memray
absl-py
pytest-cov
chex
docs =
sphinx>=4.0
nbsphinx>=0.8.0
Expand All @@ -63,3 +64,24 @@ docs =
sphinx-book-theme
dev =
pre-commit

[coverage:run]
branch = true
source = ott
omit = */__init__.py

[coverage:report]
exclude_lines =
\#.*pragma:\s*no.?cover
^if __name__ == .__main__.:$
^\s*raise AssertionError\b
^\s*raise NotImplementedError\b
^\s*return NotImplemented\b
precision = 2
show_missing = True
skip_empty = True
sort = Miss

[tool:pytest]
markers =
fast: Mark tests as fast.
68 changes: 68 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import collections.abc
import itertools
from typing import Any, Mapping, Optional, Sequence

import jax
import pytest
from _pytest.python import Metafunc


def pytest_generate_tests(metafunc: Metafunc) -> None:
if not hasattr(metafunc.function, "pytestmark"):
# no annotation
return

fast_marks = [m for m in metafunc.function.pytestmark if m.name == "fast"]
if fast_marks:
mark, = fast_marks
selected: Optional[Mapping[str, Any]] = mark.kwargs.pop("only_fast", None)
ids: Optional[Sequence[str]] = mark.kwargs.pop("ids", None)

if mark.args:
argnames, argvalues = mark.args
else:
argnames = tuple(mark.kwargs.keys())
argvalues = [
(vs,) if not isinstance(vs, (str, collections.abc.Iterable)) else vs
for vs in mark.kwargs.values()
]
argvalues = list(itertools.product(*argvalues))

opt = str(metafunc.config.getoption("-m"))
if "fast" in opt: # filter if `-m fast` was passed
if selected is None:
combinations = argvalues
elif isinstance(selected, dict):
combinations = []
for vs in argvalues:
if selected == dict(zip(argnames, vs)):
combinations.append(vs)
elif isinstance(selected, (tuple, list)):
# TODO(michalk8): support passing ids?
combinations = [argvalues[s] for s in selected]
ids = None if ids is None else [ids[s] for s in selected]
elif isinstance(selected, int):
combinations = [argvalues[selected]]
ids = None if ids is None else [ids[selected]]
else:
raise TypeError(f"Invalid fast selection type `{type(selected)}`.")
else:
combinations = argvalues

if argnames:
metafunc.parametrize(argnames, combinations, ids=ids)


@pytest.fixture(scope="session")
def rng():
return jax.random.PRNGKey(0)


@pytest.fixture()
def enable_x64():
previous_value = jax.config.jax_enable_x64
jax.config.update("jax_enable_x64", True)
try:
yield
finally:
jax.config.update("jax_enable_x64", previous_value)
Loading