From 38922c0bd72df3c332d65b5c40b6f6a14c4fc53e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 31 Oct 2024 11:06:08 +0100 Subject: [PATCH] [MRG] pre-commit with ruff,codespell,yamlint (#681) * add file and lint * add ignore words * update tests * codespell commit * test should pass * fix demo flow * upate release file and documentaion for conribution * try to fix doctests * fix tests * remove tets on 3.8 * try other mlegacy option * test correctly --- .github/CONTRIBUTING.md | 31 +- .github/workflows/build_tests.yml | 46 +- .pre-commit-config.yaml | 51 + .yamllint.yml | 10 + RELEASES.md | 1 + benchmarks/benchmark.py | 12 +- benchmarks/emd.py | 20 +- benchmarks/sinkhorn_knopp.py | 20 +- docs/nb_run_conv | 46 +- docs/rtd/conf.py | 6 +- docs/source/conf.py | 203 ++- examples/backends/plot_dual_ot_pytorch.py | 34 +- .../backends/plot_optim_gromov_pytorch.py | 49 +- .../plot_sliced_wass_grad_flow_pytorch.py | 58 +- examples/backends/plot_ssw_unif_torch.py | 51 +- .../plot_stoch_continuous_ot_pytorch.py | 84 +- examples/backends/plot_unmix_optim_torch.py | 32 +- examples/backends/plot_wass1d_torch.py | 30 +- examples/backends/plot_wass2_gan_torch.py | 48 +- examples/barycenters/plot_barycenter_1D.py | 45 +- .../plot_barycenter_lp_vs_entropic.py | 78 +- .../plot_convolutional_barycenter.py | 24 +- .../barycenters/plot_debiased_barycenter.py | 34 +- .../plot_free_support_barycenter.py | 30 +- .../plot_free_support_sinkhorn_barycenter.py | 65 +- .../barycenters/plot_gaussian_barycenter.py | 42 +- ...lot_generalized_free_support_barycenter.py | 58 +- .../domain-adaptation/plot_otda_classes.py | 85 +- .../plot_otda_color_images.py | 52 +- examples/domain-adaptation/plot_otda_d2.py | 101 +- examples/domain-adaptation/plot_otda_jcpot.py | 168 +- .../domain-adaptation/plot_otda_laplacian.py | 67 +- .../plot_otda_linear_mapping.py | 76 +- .../domain-adaptation/plot_otda_mapping.py | 71 +- .../plot_otda_mapping_colors_images.py | 58 +- .../plot_otda_semi_supervised.py | 60 +- examples/gromov/plot_barycenter_fgw.py | 67 +- .../gromov/plot_entropic_semirelaxed_fgw.py | 285 +++- examples/gromov/plot_fgw.py | 52 +- examples/gromov/plot_fgw_solvers.py | 379 +++-- examples/gromov/plot_gnn_TFGW.py | 122 +- examples/gromov/plot_gromov.py | 106 +- examples/gromov/plot_gromov_barycenter.py | 104 +- ..._gromov_wasserstein_dictionary_learning.py | 253 ++- .../plot_quantized_gromov_wasserstein.py | 376 +++-- examples/gromov/plot_semirelaxed_fgw.py | 289 +++- ...mirelaxed_gromov_wasserstein_barycenter.py | 193 ++- examples/others/plot_COOT.py | 34 +- examples/others/plot_EWCA.py | 4 +- examples/others/plot_GMMOT_plan.py | 64 +- examples/others/plot_GMM_flow.py | 106 +- examples/others/plot_SSNB.py | 69 +- examples/others/plot_WDA.py | 40 +- examples/others/plot_WeakOT_VS_OT.py | 44 +- examples/others/plot_dmmot.py | 68 +- examples/others/plot_factored_coupling.py | 44 +- examples/others/plot_logo.py | 107 +- examples/others/plot_lowrank_GW.py | 45 +- examples/others/plot_lowrank_sinkhorn.py | 71 +- ...detection_with_COOT_and_unbalanced_COOT.py | 70 +- examples/others/plot_screenkhorn_1D.py | 16 +- examples/others/plot_stochastic.py | 28 +- examples/plot_Intro_OT.py | 191 ++- examples/plot_OT_1D.py | 20 +- examples/plot_OT_1D_smooth.py | 31 +- examples/plot_OT_2D_samples.py | 58 +- examples/plot_OT_L1_vs_L2.py | 122 +- examples/plot_compute_emd.py | 70 +- examples/plot_compute_wasserstein_circle.py | 28 +- examples/plot_optim_OTreg.py | 36 +- examples/plot_solve_variants.py | 114 +- 
examples/sliced-wasserstein/plot_variance.py | 20 +- .../sliced-wasserstein/plot_variance_ssw.py | 22 +- examples/unbalanced-partial/plot_UOT_1D.py | 28 +- .../plot_UOT_barycenter_1D.py | 47 +- .../plot_conv_sinkhorn_ti.py | 37 +- .../plot_partial_wass_and_gromov.py | 73 +- examples/unbalanced-partial/plot_regpath.py | 192 ++- .../unbalanced-partial/plot_unbalanced_OT.py | 37 +- ignore-words.txt | 9 + ot/__init__.py | 106 +- ot/backend.py | 242 +-- ot/bregman/__init__.py | 112 +- ot/bregman/_barycenter.py | 432 +++-- ot/bregman/_convolutional.py | 281 ++-- ot/bregman/_dictionary.py | 37 +- ot/bregman/_empirical.py | 365 +++-- ot/bregman/_geomloss.py | 79 +- ot/bregman/_screenkhorn.py | 125 +- ot/bregman/_sinkhorn.py | 580 ++++--- ot/bregman/_utils.py | 6 +- ot/coot.py | 186 ++- ot/da.py | 690 +++++--- ot/datasets.py | 65 +- ot/dr.py | 132 +- ot/factored.py | 43 +- ot/gaussian.py | 145 +- ot/gmm.py | 72 +- ot/gnn/__init__.py | 12 +- ot/gnn/_layers.py | 65 +- ot/gnn/_utils.py | 131 +- ot/gromov/__init__.py | 255 +-- ot/gromov/_bregman.py | 781 ++++++--- ot/gromov/_dictionary.py | 503 ++++-- ot/gromov/_estimators.py | 135 +- ot/gromov/_gw.py | 636 ++++++-- ot/gromov/_lowrank.py | 51 +- ot/gromov/_partial.py | 249 ++- ot/gromov/_quantized.py | 361 +++-- ot/gromov/_semirelaxed.py | 898 ++++++++--- ot/gromov/_unbalanced.py | 488 ++++-- ot/gromov/_utils.py | 297 ++-- ot/helpers/openmp_helpers.py | 34 +- ot/helpers/pre_build_helpers.py | 28 +- ot/lowrank.py | 55 +- ot/lp/__init__.py | 229 ++- ot/lp/cvx.py | 31 +- ot/lp/dmmot.py | 53 +- ot/lp/solver_1d.py | 227 ++- ot/mapping.py | 259 ++- ot/optim.py | 294 +++- ot/partial.py | 306 ++-- ot/plot.py | 90 +- ot/regpath.py | 92 +- ot/sliced.py | 96 +- ot/smooth.py | 178 ++- ot/solvers.py | 871 +++++++--- ot/stochastic.py | 131 +- ot/unbalanced/__init__.py | 42 +- ot/unbalanced/_lbfgs.py | 132 +- ot/unbalanced/_mm.py | 102 +- ot/unbalanced/_sinkhorn.py | 724 ++++++--- ot/utils.py | 223 +-- ot/weak.py | 16 +- setup.py | 143 +- test/conftest.py | 18 +- test/gromov/test_bregman.py | 1400 +++++++++++++---- test/gromov/test_dictionary.py | 902 ++++++++--- test/gromov/test_estimators.py | 74 +- test/gromov/test_fugw.py | 834 +++++++--- test/gromov/test_gw.py | 988 +++++++++--- test/gromov/test_lowrank.py | 48 +- test/gromov/test_partial.py | 175 ++- test/gromov/test_quantized.py | 242 ++- test/gromov/test_semirelaxed.py | 1330 ++++++++++++---- test/gromov/test_utils.py | 68 +- test/test_1d_solver.py | 69 +- test/test_backend.py | 199 ++- test/test_bregman.py | 542 ++++--- test/test_coot.py | 60 +- test/test_da.py | 304 ++-- test/test_dmmot.py | 57 +- test/test_dr.py | 44 +- test/test_factored.py | 4 +- test/test_gaussian.py | 82 +- test/test_gmm.py | 28 +- test/test_gnn.py | 60 +- test/test_helpers.py | 3 +- test/test_lowrank.py | 24 +- test/test_mapping.py | 23 +- test/test_optim.py | 42 +- test/test_ot.py | 117 +- test/test_partial.py | 82 +- test/test_plot.py | 22 +- test/test_regpath.py | 16 +- test/test_sliced.py | 61 +- test/test_smooth.py | 69 +- test/test_solvers.py | 261 ++- test/test_stochastic.py | 99 +- test/test_ucoot.py | 1179 ++++++++++---- test/test_utils.py | 177 +-- test/test_weak.py | 2 +- test/unbalanced/test_lbfgs.py | 266 +++- test/unbalanced/test_mm.py | 102 +- test/unbalanced/test_sinkhorn.py | 531 +++++-- 175 files changed, 20385 insertions(+), 9252 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 .yamllint.yml create mode 100644 ignore-words.txt diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md 
index c66ab3e61..94486046f 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -23,6 +23,15 @@ GitHub, clone, and develop on a branch. Steps: $ cd POT ``` +2. Install pre-commit hooks to ensure that your code is properly formatted: + + ```bash + $ pip install pre-commit + $ pre-commit install + ``` + + This will install the pre-commit hooks that will run on every commit. If the hooks fail, the commit will be aborted. + 3. Create a ``feature`` branch to hold your development changes: ```bash @@ -56,7 +65,7 @@ Pull Request Checklist We recommended that your contribution complies with the following rules before you submit a pull request: -- Follow the PEP8 Guidelines. +- Follow the PEP8 Guidelines which should be handled automatically by pre-commit. - If your pull request addresses an issue, please use the pull request title to describe the issue and mention the issue number in the pull request description. This will make sure a link back to the original issue is @@ -101,27 +110,19 @@ following rules before you submit a pull request: You can also check for common programming errors with the following tools: - -- No pyflakes warnings, check with: +- All lint checks pass. You can run the following command to check: ```bash - $ pip install pyflakes - $ pyflakes path/to/module.py + $ pre-commit run --all-files ``` -- No PEP8 warnings, check with: + This will run the pre-commit checks on all files in the repository. - ```bash - $ pip install pep8 - $ pep8 path/to/module.py - ``` - -- AutoPEP8 can help you fix some of the easy redundant errors: +- All tests pass. You can run the following command to check: ```bash - $ pip install autopep8 - $ autopep8 path/to/pep8.py - ``` + $ pytest --durations=20 -v test/ --doctest-modules + ``` Bonus points for contributions that include a performance analysis with a benchmark script and profiling output (please report on the mailing diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index a265c79ec..c06994a1a 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -15,6 +15,30 @@ on: - '**' jobs: + + Lint: + runs-on: ubuntu-latest + strategy: + fail-fast: false + defaults: + run: + shell: bash -l {0} + steps: + + + - name: Checking Out Repository + uses: actions/checkout@v2 + # Install Python & Packages + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + - run: which python + - name: Lint with pre-commit + run: | + pip install pre-commit + pre-commit install --install-hooks + pre-commit run --all-files + linux: runs-on: ubuntu-latest @@ -22,7 +46,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 @@ -44,26 +68,6 @@ jobs: - name: Upload coverage reports to Codecov with GitHub Action uses: codecov/codecov-action@v3 - pep8: - runs-on: ubuntu-latest - if: "!contains(github.event.head_commit.message, 'no pep8')" - steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.x" - - name: Install dependencies - run: | - python -m pip install --upgrade pip setuptools - pip install flake8 - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings.
The GitHub editor is 127 chars wide - flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics - linux-minimal-deps: runs-on: ubuntu-latest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..cd5a35594 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,51 @@ +repos: + # Ruff skada + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.2 + hooks: + - id: ruff + name: ruff lint + args: ["--fix"] + files: ^ot/ + - id: ruff + name: ruff lint preview + args: ["--fix", "--preview", "--select=NPY201"] + files: ^ot/ + - id: ruff + name: ruff lint doc, tutorials, tests and examples + # D103: missing docstring in public function + # D400: docstring first line must end with period + args: ["--ignore=D103,D400", "--fix"] + files: ^docs/|^examples/^test/ + - id: ruff-format + files: ^ot/|^docs/|^examples/| + + # Codespell + - repo: https://github.com/codespell-project/codespell + rev: v2.2.6 + hooks: + - id: codespell + additional_dependencies: + - tomli + files: ^ot/|^docs/|^examples/ + types_or: [python, bib, rst, inc] + args: [ + "--ignore-words", + "ignore-words.txt", + ] + + # yamllint + - repo: https://github.com/adrienverge/yamllint.git + rev: v1.35.1 + hooks: + - id: yamllint + # args: [--strict] + +# # rstcheck + # - repo: https://github.com/rstcheck/rstcheck.git + # rev: v6.2.0 + # hooks: + # - id: rstcheck + # additional_dependencies: + # - tomli + # files: ^docs/source/.*\.(rst|inc)$ diff --git a/.yamllint.yml b/.yamllint.yml new file mode 100644 index 000000000..e7c255105 --- /dev/null +++ b/.yamllint.yml @@ -0,0 +1,10 @@ +extends: default + +ignore: | + .github/workflows/*.yml + .circleci/config.yml + codecov.yml + +rules: + line-length: disable + document-start: disable diff --git a/RELEASES.md b/RELEASES.md index 821432548..d56f9aaa8 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,7 @@ - Custom functions provided as parameter `line_search` to `ot.optim.generic_conditional_gradient` must now have the signature `line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs)`, adding as input `df_G` the gradient of the regularizer evaluated at the transport plan `G`. This change aims at improving speed of solvers having quadratic polynomial functions as regularizer such as the Gromov-Wassertein loss (PR #663). 
#### New features +- New linter based on pre-commit using ruff, codespell and yamllint (PR #681) - Added feature `mass=True` for `nx.kl_div` (PR #654) - Implemented Gaussian Mixture Model OT `ot.gmm` (PR #649) - Added feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659) diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index 7973c6b91..6787b54c9 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -8,10 +8,12 @@ def setup_backends(): if jax: from jax.config import config + config.update("jax_enable_x64", True) if tf: from tensorflow.python.ops.numpy_ops import np_config + np_config.enable_numpy_behavior() @@ -36,10 +38,7 @@ def exec_bench(setup, tested_function, param_list, n_runs, warmup_runs): print(nx, param_list[i]) args = inputs[i] results_nx = nx._bench( - tested_function, - *args, - n_runs=n_runs, - warmup_runs=warmup_runs + tested_function, *args, n_runs=n_runs, warmup_runs=warmup_runs ) gc.collect() results_nx_with_param_in_key = dict() @@ -64,10 +63,11 @@ def convert_to_html_table(results, param_name, main_title=None, comments=None): assert cpus_cols + gpus_cols == len(devices_names) if main_title is not None: - string += f'{str(main_title)}\n' + string += ( + f'{str(main_title)}\n' + ) for i, bitsize in enumerate(bitsizes): - if i != 0: string += f' \n' diff --git a/benchmarks/emd.py b/benchmarks/emd.py index 861dab332..11ef5e322 100644 --- a/benchmarks/emd.py +++ b/benchmarks/emd.py @@ -3,11 +3,7 @@ import numpy as np import ot -from .benchmark import ( - setup_backends, - exec_bench, - convert_to_html_table -) +from .benchmark import setup_backends, exec_bench, convert_to_html_table def setup(n_samples): @@ -31,10 +27,12 @@ def setup(n_samples): tested_function=lambda a, M: ot.emd(a, a, M), param_list=param_list, n_runs=n_runs, - warmup_runs=warmup_runs + warmup_runs=warmup_runs, + ) + print( + convert_to_html_table( + results, + param_name="Sample size", + main_title=f"EMD - Averaged on {n_runs} runs", + ) ) - print(convert_to_html_table( - results, - param_name="Sample size", - main_title=f"EMD - Averaged on {n_runs} runs" - )) diff --git a/benchmarks/sinkhorn_knopp.py b/benchmarks/sinkhorn_knopp.py index ef0f22b90..1eeb87648 100644 --- a/benchmarks/sinkhorn_knopp.py +++ b/benchmarks/sinkhorn_knopp.py @@ -3,11 +3,7 @@ import numpy as np import ot -from .benchmark import ( - setup_backends, - exec_bench, - convert_to_html_table -) +from .benchmark import setup_backends, exec_bench, convert_to_html_table def setup(n_samples): @@ -33,10 +29,12 @@ def setup(n_samples): tested_function=lambda *args: ot.bregman.sinkhorn(*args, reg=1, stopThr=1e-7), param_list=param_list, n_runs=n_runs, - warmup_runs=warmup_runs + warmup_runs=warmup_runs, + ) + print( + convert_to_html_table( + results, + param_name="Sample size", + main_title=f"Sinkhorn Knopp - Averaged on {n_runs} runs", + ) ) - print(convert_to_html_table( - results, - param_name="Sample size", - main_title=f"Sinkhorn Knopp - Averaged on {n_runs} runs" - )) diff --git a/docs/nb_run_conv b/docs/nb_run_conv index adb47ace0..8a133c14e 100755 --- a/docs/nb_run_conv +++ b/docs/nb_run_conv @@ -9,7 +9,6 @@ Created on Fri Sep 1 16:43:45 2017 @author: rflamary """ -import sys import json import glob import hashlib @@ -17,10 +16,10 @@ import subprocess import os -cache_file = 'cache_nbrun' +cache_file = "cache_nbrun" -path_doc = 'source/auto_examples/' -path_nb = '../notebooks/' +path_doc = 
"source/auto_examples/" +path_nb = "../notebooks/" def load_json(fname): @@ -34,7 +33,7 @@ def load_json(fname): def save_json(fname, nb): - f = open(fname, 'w') + f = open(fname, "w") f.write(json.dumps(nb)) f.close() @@ -60,22 +59,45 @@ def to_update(fname, cache): def update(fname, cache): - - # jupyter nbconvert --to notebook --execute mynotebook.ipynb --output targte - subprocess.check_call(['cp', path_doc + fname, path_nb]) - print(' '.join(['jupyter', 'nbconvert', '--to', 'notebook', '--ExecutePreprocessor.timeout=600', '--execute', path_nb + fname, '--inplace'])) - subprocess.check_call(['jupyter', 'nbconvert', '--to', 'notebook', '--ExecutePreprocessor.timeout=600', '--execute', path_nb + fname, '--inplace']) + # jupyter nbconvert --to notebook --execute mynotebook.ipynb --output target + subprocess.check_call(["cp", path_doc + fname, path_nb]) + print( + " ".join( + [ + "jupyter", + "nbconvert", + "--to", + "notebook", + "--ExecutePreprocessor.timeout=600", + "--execute", + path_nb + fname, + "--inplace", + ] + ) + ) + subprocess.check_call( + [ + "jupyter", + "nbconvert", + "--to", + "notebook", + "--ExecutePreprocessor.timeout=600", + "--execute", + path_nb + fname, + "--inplace", + ] + ) cache[fname] = md5(path_doc + fname) cache = load_json(cache_file) -lst_file = glob.glob(path_doc + '*.ipynb') +lst_file = glob.glob(path_doc + "*.ipynb") lst_file = [os.path.basename(name) for name in lst_file] for fname in lst_file: if to_update(fname, cache): - print('Updating file: {}'.format(fname)) + print("Updating file: {}".format(fname)) update(fname, cache) save_json(cache_file, cache) diff --git a/docs/rtd/conf.py b/docs/rtd/conf.py index cf6479bf5..09d04447b 100644 --- a/docs/rtd/conf.py +++ b/docs/rtd/conf.py @@ -1,6 +1,6 @@ from recommonmark.parser import CommonMarkParser -source_parsers = {'.md': CommonMarkParser} +source_parsers = {".md": CommonMarkParser} -source_suffix = ['.md'] -master_doc = 'index' +source_suffix = [".md"] +master_doc = "index" diff --git a/docs/source/conf.py b/docs/source/conf.py index c51b96ec4..f78197c66 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,11 +15,6 @@ import sys import os import re -try: - import sphinx_gallery - -except ImportError: - print("warning sphinx-gallery not installed") # !!!! allow readthedoc compilation @@ -38,7 +33,7 @@ def __getattr__(cls, name): return MagicMock() -MOCK_MODULES = ['cupy'] +MOCK_MODULES = ["cupy"] # 'autograd.numpy','pymanopt.manifolds','pymanopt.solvers', sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) # !!!! @@ -46,32 +41,32 @@ def __getattr__(cls, name): # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('.')) -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath(".")) +sys.path.insert(0, os.path.abspath("..")) sys.path.insert(0, os.path.abspath("../..")) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named #'sphinx.ext.*') or your custom # ones. 
extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'sphinx.ext.ifconfig', - 'sphinx.ext.viewcode', - 'sphinx.ext.napoleon', - 'sphinx_gallery.gen_gallery', - 'myst_parser', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.mathjax", + "sphinx.ext.ifconfig", + "sphinx.ext.viewcode", + "sphinx.ext.napoleon", + "sphinx_gallery.gen_gallery", + "myst_parser", "sphinxcontrib.jquery", ] @@ -80,23 +75,23 @@ def __getattr__(cls, name): napoleon_numpy_docstring = True # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: -source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] # source_suffix = '.rst' # The encoding of source files. -source_encoding = 'utf-8-sig' +source_encoding = "utf-8-sig" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'POT Python Optimal Transport' -copyright = u'2016-2023, POT Contributors' -author = u'Rémi Flamary, POT Contributors' +project = "POT Python Optimal Transport" +copyright = "2016-2023, POT Contributors" +author = "Rémi Flamary, POT Contributors" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -105,7 +100,8 @@ def __getattr__(cls, name): __version__ = re.search( r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]', # It excludes inline comment too - open('../../ot/__init__.py').read()).group(1) + open("../../ot/__init__.py").read(), +).group(1) # The short X.Y version. version = __version__ # The full version, including alpha/beta/rc tags. @@ -120,9 +116,9 @@ def __getattr__(cls, name): # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -130,27 +126,27 @@ def __getattr__(cls, name): # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'default' +pygments_style = "default" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # If true, `todo` and `todoList` produce output, else they produce nothing. 
todo_include_todos = True @@ -160,7 +156,7 @@ def __getattr__(cls, name): # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -169,106 +165,103 @@ def __getattr__(cls, name): html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -html_logo = '_static/images/logo_dark.svg' +html_logo = "_static/images/logo_dark.svg" # The name of an image file (relative to this directory) to use as a favicon of # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] -#html_css_files = ["css/custom.css"] +html_static_path = ["_static"] +# html_css_files = ["css/custom.css"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). 
-#html_file_suffix = None +# html_file_suffix = None # Language to be used for generating the HTML full-text search index. # Sphinx supports the following languages: # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' -#html_search_language = 'en' +# html_search_language = 'en' # A dictionary with options for the search language support, empty by default. # Now only 'ja' uses this config value -#html_search_options = {'type': 'default'} +# html_search_options = {'type': 'default'} # The name of a javascript file (relative to the configuration directory) that # implements a search results scorer. If empty, the default will be used. -#html_search_scorer = 'scorer.js' +# html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. -htmlhelp_basename = 'POTdoc' +htmlhelp_basename = "POTdoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. #'preamble': '', - # Latex figure (float) alignment #'figure_align': 'htbp', } @@ -277,42 +270,38 @@ def __getattr__(cls, name): # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'POT.tex', u'POT Python Optimal Transport library', - author, 'manual'), + (master_doc, "POT.tex", "POT Python Optimal Transport library", author, "manual"), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'pot', u'POT Python Optimal Transport', - [author], 1) -] +man_pages = [(master_doc, "pot", "POT Python Optimal Transport", [author], 1)] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -321,44 +310,50 @@ def __getattr__(cls, name): # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'POT', u'POT Python Optimal Transport', - author, 'POT', 'Python Optimal Transport', - 'Miscellaneous'), + ( + master_doc, + "POT", + "POT Python Optimal Transport", + author, + "POT", + "Python Optimal Transport", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. 
-#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False -autodoc_default_options = {'autosummary': True, - 'autosummary_imported_members': True} +autodoc_default_options = {"autosummary": True, "autosummary_imported_members": True} # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'python': ('https://docs.python.org/3', None), - 'numpy': ('https://numpy.org/doc/stable/', None), - 'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None), - 'matplotlib': ('http://matplotlib.org/', None), - 'torch': ('https://pytorch.org/docs/stable/', None), - 'jax': ('https://jax.readthedocs.io/en/latest/', None)} +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "scipy": ("http://docs.scipy.org/doc/scipy/reference/", None), + "matplotlib": ("http://matplotlib.org/", None), + "torch": ("https://pytorch.org/docs/stable/", None), + "jax": ("https://jax.readthedocs.io/en/latest/", None), +} sphinx_gallery_conf = { - 'examples_dirs': ['../../examples', '../../examples/da'], - 'gallery_dirs': 'auto_examples', - 'filename_pattern': 'plot_', # (?!barycenter_fgw) - 'nested_sections': False, - 'backreferences_dir': 'gen_modules/backreferences', - 'inspect_global_variables': True, - 'doc_module': ('ot', 'numpy', 'scipy', 'pylab'), - 'matplotlib_animations': True, - 'reference_url': { - 'ot': None} + "examples_dirs": ["../../examples", "../../examples/da"], + "gallery_dirs": "auto_examples", + "filename_pattern": "plot_", # (?!barycenter_fgw) + "nested_sections": False, + "backreferences_dir": "gen_modules/backreferences", + "inspect_global_variables": True, + "doc_module": ("ot", "numpy", "scipy", "pylab"), + "matplotlib_animations": True, + "reference_url": {"ot": None}, } diff --git a/examples/backends/plot_dual_ot_pytorch.py b/examples/backends/plot_dual_ot_pytorch.py index 67c7077c9..8449f1f60 100644 --- a/examples/backends/plot_dual_ot_pytorch.py +++ b/examples/backends/plot_dual_ot_pytorch.py @@ -30,10 +30,10 @@ theta = 2 * np.pi / 20 noise_level = 0.1 -Xs, ys = ot.datasets.make_data_classif( - 'gaussrot', n_source_samples, nz=noise_level) +Xs, ys = ot.datasets.make_data_classif("gaussrot", n_source_samples, nz=noise_level) Xt, yt = ot.datasets.make_data_classif( - 'gaussrot', n_target_samples, theta=theta, nz=noise_level) + "gaussrot", n_target_samples, theta=theta, nz=noise_level +) # one of the target mode changes its variance (no linear mapping) Xt[yt == 2] *= 3 @@ -46,10 +46,10 @@ pl.figure(1, (10, 5)) pl.clf() -pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples') -pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples') +pl.scatter(Xs[:, 0], Xs[:, 1], marker="+", label="Source samples") +pl.scatter(Xt[:, 0], Xt[:, 1], marker="o", label="Target samples") pl.legend(loc=0) -pl.title('Source and target distributions') +pl.title("Source and target distributions") # %% # Convert data to torch tensors @@ -76,10 +76,9 @@ losses = [] for i in range(n_iter): - # generate noise samples - # minus because we maximize te dual loss + # minus because we maximize the dual loss loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg) losses.append(float(loss.detach())) @@ -94,7 +93,7 @@ pl.figure(2) pl.plot(losses) pl.grid() -pl.title('Dual objective (negative)') +pl.title("Dual objective (negative)") 
pl.xlabel("Iterations") Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg) @@ -106,10 +105,10 @@ pl.figure(3, (10, 5)) pl.clf() ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1) -pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2) -pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2) +pl.scatter(Xs[:, 0], Xs[:, 1], marker="+", label="Source samples", zorder=2) +pl.scatter(Xt[:, 0], Xt[:, 1], marker="o", label="Target samples", zorder=2) pl.legend(loc=0) -pl.title('Source and target distributions') +pl.title("Source and target distributions") # %% @@ -131,10 +130,9 @@ for i in range(n_iter): - # generate noise samples - # minus because we maximize te dual loss + # minus because we maximize the dual loss loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg) losses.append(float(loss.detach())) @@ -149,7 +147,7 @@ pl.figure(4) pl.plot(losses) pl.grid() -pl.title('Dual objective (negative)') +pl.title("Dual objective (negative)") pl.xlabel("Iterations") Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg) @@ -162,7 +160,7 @@ pl.figure(5, (10, 5)) pl.clf() ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1) -pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2) -pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2) +pl.scatter(Xs[:, 0], Xs[:, 1], marker="+", label="Source samples", zorder=2) +pl.scatter(Xt[:, 0], Xt[:, 1], marker="o", label="Target samples", zorder=2) pl.legend(loc=0) -pl.title('OT plan with quadratic regularization') +pl.title("OT plan with quadratic regularization") diff --git a/examples/backends/plot_optim_gromov_pytorch.py b/examples/backends/plot_optim_gromov_pytorch.py index 0ae28901f..d01cb56c5 100644 --- a/examples/backends/plot_optim_gromov_pytorch.py +++ b/examples/backends/plot_optim_gromov_pytorch.py @@ -50,13 +50,13 @@ def get_sbm(n, nc, ratio, P): for c1 in range(nc): for c2 in range(c1 + 1): if c1 == c2: - for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])): + for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[: c1 + 1])): for j in range(np.sum(nbpc[:c2]), i): if rng.rand() <= P[c1, c2]: C[i, j] = 1 else: - for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])): - for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[:c2 + 1])): + for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[: c1 + 1])): + for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[: c2 + 1])): if rng.rand() <= P[c1, c2]: C[i, j] = 1 @@ -65,30 +65,32 @@ def get_sbm(n, nc, ratio, P): n = 100 nc = 3 -ratio = np.array([.5, .3, .2]) +ratio = np.array([0.5, 0.3, 0.2]) P = np.array(0.6 * np.eye(3) + 0.05 * np.ones((3, 3))) C1 = get_sbm(n, nc, ratio, P) # get 2d position for nodes -x1 = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C1) +x1 = MDS(dissimilarity="precomputed", random_state=0).fit_transform(1 - C1) -def plot_graph(x, C, color='C0', s=None): +def plot_graph(x, C, color="C0", s=None): for j in range(C.shape[0]): for i in range(j): if C[i, j] > 0: - pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k') - pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9) + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color="k") + pl.scatter( + x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors="k", cmap="tab10", vmax=9 + ) pl.figure(1, (10, 5)) pl.clf() pl.subplot(1, 2, 1) -plot_graph(x1, C1, color='C0') +plot_graph(x1, C1, color="C0") pl.title("SBM Graph") pl.axis("off") 
pl.subplot(1, 2, 2) -pl.imshow(C1, interpolation='nearest') +pl.imshow(C1, interpolation="nearest") pl.title("Adjacency matrix") pl.axis("off") @@ -104,7 +106,7 @@ def plot_graph(x, C, color='C0', s=None): def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2): - """ solve min_a GW(C1,C2,a, a2) by gradient descent""" + """solve min_a GW(C1,C2,a, a2) by gradient descent""" # use pyTorch for our data C1_torch = torch.tensor(C1) @@ -118,18 +120,17 @@ def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2): loss_iter = [] for i in range(nb_iter_max): - loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch) loss_iter.append(loss.clone().detach().cpu().numpy()) loss.backward() - #print("{:03d} | {}".format(i, loss_iter[-1])) + # print("{:03d} | {}".format(i, loss_iter[-1])) # performs a step of projected gradient descent with torch.no_grad(): grad = a1_torch.grad - a1_torch -= grad * lr # step + a1_torch -= grad * lr # step a1_torch.grad.zero_() a1_torch.data = ot.utils.proj_simplex(a1_torch) @@ -158,7 +159,7 @@ def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2): # ------------------------------------------------------- # The GW OT plan can be used to perform a clustering of the nodes of a graph # when computing the GW with a simple template like C0 by labeling nodes in -# the original graph using by the index of the noe in the template receiving +# the original graph using by the index of the node in the template receiving # the most mass. # # We show here the result of such a clustering when using uniform weights on @@ -194,7 +195,7 @@ def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2): def graph_compression_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2): - """ solve min_a GW(C1,C2,a, a2) by gradient descent""" + """solve min_a GW(C1,C2,a, a2) by gradient descent""" # use pyTorch for our data @@ -210,23 +211,22 @@ def graph_compression_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2): loss_iter = [] for i in range(nb_iter_max): - loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch) loss_iter.append(loss.clone().detach().cpu().numpy()) loss.backward() - #print("{:03d} | {}".format(i, loss_iter[-1])) + # print("{:03d} | {}".format(i, loss_iter[-1])) # performs a step of projected gradient descent with torch.no_grad(): grad = a1_torch.grad - a1_torch -= grad * lr # step + a1_torch -= grad * lr # step a1_torch.grad.zero_() a1_torch.data = ot.utils.proj_simplex(a1_torch) grad = C1_torch.grad - C1_torch -= grad * lr # step + C1_torch -= grad * lr # step C1_torch.grad.zero_() C1_torch.data = torch.clamp(C1_torch, 0, 1) @@ -237,8 +237,9 @@ def graph_compression_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2): nb_nodes = 3 -a0_est2, C0_est2, loss_iter2 = graph_compression_gw(nb_nodes, C1, ot.unif(n), - nb_iter_max=100, lr=5e-2) +a0_est2, C0_est2, loss_iter2 = graph_compression_gw( + nb_nodes, C1, ot.unif(n), nb_iter_max=100, lr=5e-2 +) pl.figure(4) pl.plot(loss_iter2) @@ -252,8 +253,8 @@ def graph_compression_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2): pl.clf() pl.subplot(1, 2, 1) pl.imshow(P, vmin=0, vmax=1) -pl.title('True SBM P matrix') +pl.title("True SBM P matrix") pl.subplot(1, 2, 2) pl.imshow(C0_est2, vmin=0, vmax=1) -pl.title('Estimated C0 matrix') +pl.title("Estimated C0 matrix") pl.colorbar() diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py index 7cbfd983f..33d8b92be 100644 --- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py +++ 
b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -33,15 +33,14 @@ # %% # Loading the data - import numpy as np import matplotlib.pylab as pl import torch import ot import matplotlib.animation as animation -I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::5, ::5, 2] -I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::5, ::5, 2] +I1 = pl.imread("../../data/redcross.png").astype(np.float64)[::5, ::5, 2] +I2 = pl.imread("../../data/tooth.png").astype(np.float64)[::5, ::5, 2] sz = I2.shape[0] XX, YY = np.meshgrid(np.arange(sz), np.arange(sz)) @@ -78,8 +77,9 @@ gen.manual_seed(42) for i in range(nb_iter_max): - - loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen) + loss = ot.sliced_wasserstein_distance( + x1_torch, x2_torch, n_projections=20, seed=gen + ) loss_iter.append(loss.clone().detach().cpu().numpy()) loss.backward() @@ -94,10 +94,10 @@ xb = x1_torch.clone().detach().cpu().numpy() pl.figure(2, (8, 4)) -pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') -pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') -pl.scatter(xb[:, 0], xb[:, 1], alpha=0.5, label='$\mu^{(100)}$') -pl.title('Sliced Wasserstein gradient flow') +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label="$\mu^{(0)}$") +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r"$\nu$") +pl.scatter(xb[:, 0], xb[:, 1], alpha=0.5, label="$\mu^{(100)}$") +pl.title("Sliced Wasserstein gradient flow") pl.legend() ax = pl.axis() @@ -110,15 +110,17 @@ def _update_plot(i): pl.clf() - pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') - pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') - pl.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$') - pl.title('Sliced Wasserstein gradient flow Iter. {}'.format(i)) + pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label="$\mu^{(0)}$") + pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r"$\nu$") + pl.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, label="$\mu^{(100)}$") + pl.title("Sliced Wasserstein gradient flow Iter. 
{}".format(i)) pl.axis(ax) return 1 -ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000) +ani = animation.FuncAnimation( + pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000 +) # %% # Compute the Sliced Wasserstein Barycenter @@ -142,9 +144,11 @@ def _update_plot(i): alpha = 0.5 for i in range(nb_iter_max): - - loss = alpha * ot.sliced_wasserstein_distance(xbary_torch, x3_torch, n_projections=50, seed=gen) \ - + (1 - alpha) * ot.sliced_wasserstein_distance(xbary_torch, x1_torch, n_projections=50, seed=gen) + loss = alpha * ot.sliced_wasserstein_distance( + xbary_torch, x3_torch, n_projections=50, seed=gen + ) + (1 - alpha) * ot.sliced_wasserstein_distance( + xbary_torch, x1_torch, n_projections=50, seed=gen + ) loss_iter.append(loss.clone().detach().cpu().numpy()) loss.backward() @@ -159,10 +163,10 @@ def _update_plot(i): xb = xbary_torch.clone().detach().cpu().numpy() pl.figure(4, (8, 4)) -pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu$') -pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') -pl.scatter(xb[:, 0] + 30, xb[:, 1], alpha=0.5, label='Barycenter') -pl.title('Sliced Wasserstein barycenter') +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label="$\mu$") +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r"$\nu$") +pl.scatter(xb[:, 0] + 30, xb[:, 1], alpha=0.5, label="Barycenter") +pl.title("Sliced Wasserstein barycenter") pl.legend() ax = pl.axis() @@ -176,12 +180,14 @@ def _update_plot(i): def _update_plot(i): pl.clf() - pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') - pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') - pl.scatter(x_all[i, :, 0] + 30, x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$') - pl.title('Sliced Wasserstein barycenter Iter. {}'.format(i)) + pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label="$\mu^{(0)}$") + pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r"$\nu$") + pl.scatter(x_all[i, :, 0] + 30, x_all[i, :, 1], alpha=0.5, label="$\mu^{(100)}$") + pl.title("Sliced Wasserstein barycenter Iter. 
{}".format(i)) pl.axis(ax) return 1 -ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000) +ani = animation.FuncAnimation( + pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000 +) diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py index 7459cf6f2..5420fea97 100644 --- a/examples/backends/plot_ssw_unif_torch.py +++ b/examples/backends/plot_ssw_unif_torch.py @@ -44,6 +44,7 @@ # Plot data # --------- + def plot_sphere(ax): xlist = np.linspace(-1.0, 1.0, 50) ylist = np.linspace(-1.0, 1.0, 50) @@ -52,16 +53,16 @@ def plot_sphere(ax): Z = np.sqrt(np.maximum(r**2 - X**2 - Y**2, 0)) - ax.plot_wireframe(X, Y, Z, color="gray", alpha=.3) - ax.plot_wireframe(X, Y, -Z, color="gray", alpha=.3) # Now plot the bottom half + ax.plot_wireframe(X, Y, Z, color="gray", alpha=0.3) + ax.plot_wireframe(X, Y, -Z, color="gray", alpha=0.3) # Now plot the bottom half # plot the distributions pl.figure(1) -ax = pl.axes(projection='3d') +ax = pl.axes(projection="3d") plot_sphere(ax) -ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label='Data samples', alpha=0.5) -ax.set_title('Data distribution') +ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label="Data samples", alpha=0.5) +ax.set_title("Data distribution") ax.legend() @@ -94,7 +95,7 @@ def plot_sphere(ax): pl.figure(1) pl.semilogy(losses) pl.grid() -pl.title('SSW') +pl.title("SSW") pl.xlabel("Iterations") @@ -108,11 +109,17 @@ def plot_sphere(ax): for i in range(9): # pl.subplot(3, 3, i + 1) # ax = pl.axes(projection='3d') - ax = fig.add_subplot(3, 3, i + 1, projection='3d') + ax = fig.add_subplot(3, 3, i + 1, projection="3d") plot_sphere(ax) - ax.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], xvisu[ivisu[i], :, 2], label='Data samples', alpha=0.5) - ax.set_title('Iter. {}'.format(ivisu[i])) - #ax.axis("off") + ax.scatter( + xvisu[ivisu[i], :, 0], + xvisu[ivisu[i], :, 1], + xvisu[ivisu[i], :, 2], + label="Data samples", + alpha=0.5, + ) + ax.set_title("Iter. {}".format(ivisu[i])) + # ax.axis("off") if i == 0: ax.legend() @@ -127,27 +134,37 @@ def plot_sphere(ax): def _update_plot(i): i = 3 * i pl.clf() - ax = pl.axes(projection='3d') + ax = pl.axes(projection="3d") plot_sphere(ax) - ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples$', alpha=0.5) + ax.scatter( + xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label="Data samples$", alpha=0.5 + ) ax.axis("off") ax.set_xlim((-1.5, 1.5)) ax.set_ylim((-1.5, 1.5)) - ax.set_title('Iter. {}'.format(i)) + ax.set_title("Iter. {}".format(i)) return 1 print(xvisu.shape) i = 0 -ax = pl.axes(projection='3d') +ax = pl.axes(projection="3d") plot_sphere(ax) -ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples from $G\#\mu_n$', alpha=0.5) +ax.scatter( + xvisu[i, :, 0], + xvisu[i, :, 1], + xvisu[i, :, 2], + label="Data samples from $G\#\mu_n$", + alpha=0.5, +) ax.axis("off") ax.set_xlim((-1.5, 1.5)) ax.set_ylim((-1.5, 1.5)) -ax.set_title('Iter. {}'.format(ivisu[i])) +ax.set_title("Iter. 
{}".format(ivisu[i])) -ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=200, repeat_delay=2000) +ani = animation.FuncAnimation( + pl.gcf(), _update_plot, n_iter // 5, interval=200, repeat_delay=2000 +) # %% diff --git a/examples/backends/plot_stoch_continuous_ot_pytorch.py b/examples/backends/plot_stoch_continuous_ot_pytorch.py index e64298698..d44a86f64 100644 --- a/examples/backends/plot_stoch_continuous_ot_pytorch.py +++ b/examples/backends/plot_stoch_continuous_ot_pytorch.py @@ -45,11 +45,11 @@ nvisu = 300 pl.figure(1, (5, 5)) pl.clf() -pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', label='Source samples', alpha=0.5) -pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', label='Target samples', alpha=0.5) +pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker="+", label="Source samples", alpha=0.5) +pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker="o", label="Target samples", alpha=0.5) pl.legend(loc=0) ax_bounds = pl.axis() -pl.title('Source and target distributions') +pl.title("Source and target distributions") # %% # Convert data to torch tensors @@ -86,7 +86,7 @@ def forward(self, x): reg = 1 -optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005) +optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=0.005) # number of iteration n_iter = 500 @@ -96,7 +96,6 @@ def forward(self, x): losses = [] for i in range(n_iter): - # generate noise samples iperms = torch.randint(0, n_source_samples, (n_batch,)) @@ -105,7 +104,7 @@ def forward(self, x): xsi = xs[iperms] xti = xt[ipermt] - # minus because we maximize te dual loss + # minus because we maximize the dual loss loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg) losses.append(float(loss.detach())) @@ -120,7 +119,7 @@ def forward(self, x): pl.figure(2) pl.plot(losses) pl.grid() -pl.title('Dual objective (negative)') +pl.title("Dual objective (negative)") pl.xlabel("Iterations") @@ -137,7 +136,7 @@ def forward(self, x): xg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1) -wxg = np.exp(-((xg[:, 0] - 4)**2 + (xg[:, 1] - 4)**2) / (2 * 2)) +wxg = np.exp(-((xg[:, 0] - 4) ** 2 + (xg[:, 1] - 4) ** 2) / (2 * 2)) wxg = wxg / np.sum(wxg) xg = torch.tensor(xg) @@ -149,41 +148,74 @@ def forward(self, x): pl.subplot(1, 3, 1) iv = 2 -Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg) +Gg = ot.stochastic.plan_dual_entropic( + u(xs[iv : iv + 1, :]), v(xg), xs[iv : iv + 1, :], xg, reg=reg, wt=wxg +) Gg = Gg.reshape((nv, nv)).detach().numpy() -pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05) -pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05) -pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0') -pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample') +pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker="+", zorder=2, alpha=0.05) +pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker="o", zorder=2, alpha=0.05) +pl.scatter( + Xs[iv : iv + 1, 0], + Xs[iv : iv + 1, 1], + s=100, + marker="+", + label="Source sample", + zorder=2, + alpha=1, + color="C0", +) +pl.pcolormesh(XX, YY, Gg, cmap="Greens", label="Density of transported source sample") pl.legend(loc=0) ax_bounds = pl.axis() -pl.title('Density of transported source sample') +pl.title("Density of transported source sample") pl.subplot(1, 3, 2) iv = 3 -Gg = 
ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg) +Gg = ot.stochastic.plan_dual_entropic( + u(xs[iv : iv + 1, :]), v(xg), xs[iv : iv + 1, :], xg, reg=reg, wt=wxg +) Gg = Gg.reshape((nv, nv)).detach().numpy() -pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05) -pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05) -pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0') -pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample') +pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker="+", zorder=2, alpha=0.05) +pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker="o", zorder=2, alpha=0.05) +pl.scatter( + Xs[iv : iv + 1, 0], + Xs[iv : iv + 1, 1], + s=100, + marker="+", + label="Source sample", + zorder=2, + alpha=1, + color="C0", +) +pl.pcolormesh(XX, YY, Gg, cmap="Greens", label="Density of transported source sample") pl.legend(loc=0) ax_bounds = pl.axis() -pl.title('Density of transported source sample') +pl.title("Density of transported source sample") pl.subplot(1, 3, 3) iv = 6 -Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg) +Gg = ot.stochastic.plan_dual_entropic( + u(xs[iv : iv + 1, :]), v(xg), xs[iv : iv + 1, :], xg, reg=reg, wt=wxg +) Gg = Gg.reshape((nv, nv)).detach().numpy() -pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05) -pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05) -pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0') -pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample') +pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker="+", zorder=2, alpha=0.05) +pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker="o", zorder=2, alpha=0.05) +pl.scatter( + Xs[iv : iv + 1, 0], + Xs[iv : iv + 1, 1], + s=100, + marker="+", + label="Source sample", + zorder=2, + alpha=1, + color="C0", +) +pl.pcolormesh(XX, YY, Gg, cmap="Greens", label="Density of transported source sample") pl.legend(loc=0) ax_bounds = pl.axis() -pl.title('Density of transported source sample') +pl.title("Density of transported source sample") diff --git a/examples/backends/plot_unmix_optim_torch.py b/examples/backends/plot_unmix_optim_torch.py index e47a5e085..1f48cf214 100644 --- a/examples/backends/plot_unmix_optim_torch.py +++ b/examples/backends/plot_unmix_optim_torch.py @@ -45,7 +45,7 @@ # Generate data # ------------- -#%% Data +# %% Data nt = 100 nt1 = 10 # @@ -80,13 +80,13 @@ # Plot data # --------- -#%% plot the distributions +# %% plot the distributions pl.figure(1) -pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5) -pl.scatter(xs1[:, 0], xs1[:, 1], label='Source $\mu^s_1$', alpha=0.5) -pl.scatter(xs2[:, 0], xs2[:, 1], label='Source $\mu^s_2$', alpha=0.5) -pl.title('Sources and Target distributions') +pl.scatter(xt[:, 0], xt[:, 1], label="Target $\mu^t$", alpha=0.5) +pl.scatter(xs1[:, 0], xs1[:, 1], label="Source $\mu^s_1$", alpha=0.5) +pl.scatter(xs2[:, 0], xs2[:, 1], label="Source $\mu^s_2$", alpha=0.5) +pl.title("Sources and Target distributions") pl.legend() @@ -95,7 +95,7 @@ # ------------------------------------------------------ -#%% Weights optimization with gradient descent +# %% Weights optimization with gradient descent # convert numpy arrays to torch tensors H2 = torch.tensor(H) @@ -120,7 +120,6 @@ def get_loss(w): 
for i in range(niter): - loss = get_loss(w) losses.append(float(loss)) @@ -138,12 +137,12 @@ def get_loss(w): # -------------------------------------------------- we = w.detach().numpy() -print('Estimated mixture:', we) +print("Estimated mixture:", we) pl.figure(2) pl.semilogy(losses) pl.grid() -pl.title('Wasserstein distance') +pl.title("Wasserstein distance") pl.xlabel("Iterations") ############################################################################## @@ -155,7 +154,14 @@ def get_loss(w): # compute source weights ws = H.dot(we) -pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5) -pl.scatter(xs[:, 0], xs[:, 1], color='C3', s=ws * 20 * ns, label='Weighted sources $\sum_{k} w_k\mu^s_k$', alpha=0.5) -pl.title('Target and reweighted source distributions') +pl.scatter(xt[:, 0], xt[:, 1], label="Target $\mu^t$", alpha=0.5) +pl.scatter( + xs[:, 0], + xs[:, 1], + color="C3", + s=ws * 20 * ns, + label="Weighted sources $\sum_{k} w_k\mu^s_k$", + alpha=0.5, +) +pl.title("Target and reweighted source distributions") pl.legend() diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py index 5a8579514..75dece1fb 100644 --- a/examples/backends/plot_wass1d_torch.py +++ b/examples/backends/plot_wass1d_torch.py @@ -30,8 +30,8 @@ from ot.datasets import make_1D_gauss as gauss from ot.utils import proj_simplex -red = np.array(mpl.colors.to_rgb('red')) -blue = np.array(mpl.colors.to_rgb('blue')) +red = np.array(mpl.colors.to_rgb("red")) +blue = np.array(mpl.colors.to_rgb("blue")) n = 100 # nb bins @@ -61,8 +61,8 @@ loss_iter = [] pl.figure(1, figsize=(8, 4)) -pl.plot(x, a, 'b', label='Source distribution') -pl.plot(x, b, 'r', label='Target distribution') +pl.plot(x, a, "b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") for i in range(nb_iter_max): # Compute the Wasserstein 1D with torch backend @@ -81,15 +81,17 @@ # plot one curve every 10 iterations if i % 10 == 0: mix = float(i) / nb_iter_max - pl.plot(x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red) + pl.plot( + x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red + ) pl.legend() -pl.title('Distribution along the iterations of the projected gradient descent') +pl.title("Distribution along the iterations of the projected gradient descent") pl.show() pl.figure(2) pl.plot(range(nb_iter_max), loss_iter, lw=3) -pl.title('Evolution of the loss along iterations', fontsize=16) +pl.title("Evolution of the loss along iterations", fontsize=16) pl.show() # %% @@ -126,7 +128,9 @@ for i in range(nb_iter_max): # Compute the Wasserstein 1D with torch backend - loss = (1 - t) * wasserstein_1d(x_torch, x_torch, a_torch.detach(), bary_torch, p=2) + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2) + loss = (1 - t) * wasserstein_1d( + x_torch, x_torch, a_torch.detach(), bary_torch, p=2 + ) + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2) # record the corresponding loss value loss_iter.append(loss.clone().detach().cpu().numpy()) loss.backward() @@ -139,14 +143,14 @@ bary_torch.data = proj_simplex(bary_torch) # projection onto the simplex pl.figure(3, figsize=(8, 4)) -pl.plot(x, a, 'b', label='Source distribution') -pl.plot(x, b, 'r', label='Target distribution') -pl.plot(x, bary_torch.clone().detach().cpu().numpy(), c='green', label='W barycenter') +pl.plot(x, a, "b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") +pl.plot(x, bary_torch.clone().detach().cpu().numpy(), c="green", 
label="W barycenter") pl.legend() -pl.title('Wasserstein barycenter computed by gradient descent') +pl.title("Wasserstein barycenter computed by gradient descent") pl.show() pl.figure(4) pl.plot(range(nb_iter_max), loss_iter, lw=3) -pl.title('Evolution of the loss along iterations', fontsize=16) +pl.title("Evolution of the loss along iterations", fontsize=16) pl.show() diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py index f39d186de..806803715 100644 --- a/examples/backends/plot_wass2_gan_torch.py +++ b/examples/backends/plot_wass2_gan_torch.py @@ -81,8 +81,8 @@ def get_data(n_samples): # plot the distributions x = get_data(500) pl.figure(1) -pl.scatter(x[:, 0], x[:, 1], label='Data samples from $\mu_d$', alpha=0.5) -pl.title('Data distribution') +pl.scatter(x[:, 0], x[:, 1], label="Data samples from $\mu_d$", alpha=0.5) +pl.title("Data distribution") pl.legend() @@ -90,6 +90,7 @@ def get_data(n_samples): # Generator Model # --------------- + # define the MLP model class Generator(torch.nn.Module): def __init__(self): @@ -107,6 +108,7 @@ def forward(self, x): output = self.fc3(output) return output + # %% # Training the model # ------------------ @@ -129,7 +131,6 @@ def forward(self, x): for i in range(n_iter): - # generate noise samples xn = torch.randn(size_batch, n_features) @@ -139,7 +140,7 @@ def forward(self, x): # generate sample along iterations xvisu[i, :, :] = G(xnvisu).detach() - # generate smaples and compte distance matrix + # generate samples and compte distance matrix xg = G(xn) M = ot.dist(xg, xd) @@ -158,7 +159,7 @@ def forward(self, x): pl.figure(2) pl.semilogy(losses) pl.grid() -pl.title('Wasserstein distance') +pl.title("Wasserstein distance") pl.xlabel("Iterations") @@ -173,11 +174,16 @@ def forward(self, x): for i in range(9): pl.subplot(3, 3, i + 1) - pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) - pl.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) + pl.scatter(xd[:, 0], xd[:, 1], label="Data samples from $\mu_d$", alpha=0.1) + pl.scatter( + xvisu[ivisu[i], :, 0], + xvisu[ivisu[i], :, 1], + label="Data samples from $G\#\mu_n$", + alpha=0.5, + ) pl.xticks(()) pl.yticks(()) - pl.title('Iter. {}'.format(ivisu[i])) + pl.title("Iter. {}".format(ivisu[i])) if i == 0: pl.legend() @@ -190,27 +196,33 @@ def forward(self, x): def _update_plot(i): pl.clf() - pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) - pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) + pl.scatter(xd[:, 0], xd[:, 1], label="Data samples from $\mu_d$", alpha=0.1) + pl.scatter( + xvisu[i, :, 0], xvisu[i, :, 1], label="Data samples from $G\#\mu_n$", alpha=0.5 + ) pl.xticks(()) pl.yticks(()) pl.xlim((-1.5, 1.5)) pl.ylim((-1.5, 1.5)) - pl.title('Iter. {}'.format(i)) + pl.title("Iter. {}".format(i)) return 1 i = 0 -pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) -pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) +pl.scatter(xd[:, 0], xd[:, 1], label="Data samples from $\mu_d$", alpha=0.1) +pl.scatter( + xvisu[i, :, 0], xvisu[i, :, 1], label="Data samples from $G\#\mu_n$", alpha=0.5 +) pl.xticks(()) pl.yticks(()) pl.xlim((-1.5, 1.5)) pl.ylim((-1.5, 1.5)) -pl.title('Iter. {}'.format(ivisu[i])) +pl.title("Iter. 
{}".format(ivisu[i])) -ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000) +ani = animation.FuncAnimation( + pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000 +) # %% # Generate and visualize data @@ -222,7 +234,7 @@ def _update_plot(i): x = G(xn).detach().numpy() pl.figure(5) -pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5) -pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) -pl.title('Sources and Target distributions') +pl.scatter(xd[:, 0], xd[:, 1], label="Data samples from $\mu_d$", alpha=0.5) +pl.scatter(x[:, 0], x[:, 1], label="Data samples from $G\#\mu_n$", alpha=0.5) +pl.title("Sources and Target distributions") pl.legend() diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py index 7c17c9b22..c3229b3aa 100644 --- a/examples/barycenters/plot_barycenter_1D.py +++ b/examples/barycenters/plot_barycenter_1D.py @@ -23,6 +23,7 @@ import numpy as np import matplotlib.pyplot as plt import ot + # necessary for 3d plot even if not used from mpl_toolkits.mplot3d import Axes3D # noqa from matplotlib.collections import PolyCollection @@ -31,7 +32,7 @@ # Generate data # ------------- -#%% parameters +# %% parameters n = 100 # nb bins @@ -54,7 +55,7 @@ # Barycenter computation # ---------------------- -#%% barycenter computation +# %% barycenter computation alpha = 0.2 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) @@ -68,11 +69,11 @@ f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1) ax1.plot(x, A, color="black") -ax1.set_title('Distributions') +ax1.set_title("Distributions") -ax2.plot(x, bary_l2, 'r', label='l2') -ax2.plot(x, bary_wass, 'g', label='Wasserstein') -ax2.set_title('Barycenters') +ax2.plot(x, bary_l2, "r", label="l2") +ax2.plot(x, bary_wass, "g", label="Wasserstein") +ax2.set_title("Barycenters") plt.legend() plt.show() @@ -81,7 +82,7 @@ # Barycentric interpolation # ------------------------- -#%% barycenter interpolation +# %% barycenter interpolation n_alpha = 11 alpha_list = np.linspace(0, 1, n_alpha) @@ -97,50 +98,50 @@ B_l2[:, i] = A.dot(weights) B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights) -#%% plot interpolation +# %% plot interpolation plt.figure(2) -cmap = plt.get_cmap('viridis') +cmap = plt.get_cmap("viridis") verts = [] zs = alpha_list for i, z in enumerate(zs): ys = B_l2[:, i] verts.append(list(zip(x, ys))) -ax = plt.gcf().add_subplot(projection='3d') +ax = plt.gcf().add_subplot(projection="3d") poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) -ax.add_collection3d(poly, zs=zs, zdir='y') -ax.set_xlabel('x') +ax.add_collection3d(poly, zs=zs, zdir="y") +ax.set_xlabel("x") ax.set_xlim3d(0, n) -ax.set_ylabel('$\\alpha$') +ax.set_ylabel("$\\alpha$") ax.set_ylim3d(0, 1) -ax.set_zlabel('') +ax.set_zlabel("") ax.set_zlim3d(0, B_l2.max() * 1.01) -plt.title('Barycenter interpolation with l2') +plt.title("Barycenter interpolation with l2") plt.tight_layout() plt.figure(3) -cmap = plt.get_cmap('viridis') +cmap = plt.get_cmap("viridis") verts = [] zs = alpha_list for i, z in enumerate(zs): ys = B_wass[:, i] verts.append(list(zip(x, ys))) -ax = plt.gcf().add_subplot(projection='3d') +ax = plt.gcf().add_subplot(projection="3d") poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) -ax.add_collection3d(poly, zs=zs, zdir='y') -ax.set_xlabel('x') +ax.add_collection3d(poly, zs=zs, zdir="y") +ax.set_xlabel("x") ax.set_xlim3d(0, 
n) -ax.set_ylabel('$\\alpha$') +ax.set_ylabel("$\\alpha$") ax.set_ylim3d(0, 1) -ax.set_zlabel('') +ax.set_zlabel("") ax.set_zlim3d(0, B_l2.max() * 1.01) -plt.title('Barycenter interpolation with Wasserstein') +plt.title("Barycenter interpolation with Wasserstein") plt.tight_layout() plt.show() diff --git a/examples/barycenters/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py index 9c4ca4551..48a6d11a3 100644 --- a/examples/barycenters/plot_barycenter_lp_vs_entropic.py +++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py @@ -26,17 +26,18 @@ import numpy as np import matplotlib.pylab as pl import ot + # necessary for 3d plot even if not used from mpl_toolkits.mplot3d import Axes3D # noqa from matplotlib.collections import PolyCollection # noqa -#import ot.lp.cvx as cvx +# import ot.lp.cvx as cvx ############################################################################## # Gaussian Data # ------------- -#%% parameters +# %% parameters problems = [] @@ -59,15 +60,15 @@ M /= M.max() -#%% plot the distributions +# %% plot the distributions pl.figure(1, figsize=(6.4, 3)) for i in range(n_distributions): pl.plot(x, A[:, i]) -pl.title('Distributions') +pl.title("Distributions") pl.tight_layout() -#%% barycenter computation +# %% barycenter computation alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) @@ -91,14 +92,14 @@ pl.subplot(2, 1, 1) for i in range(n_distributions): pl.plot(x, A[:, i]) -pl.title('Distributions') +pl.title("Distributions") pl.subplot(2, 1, 2) -pl.plot(x, bary_l2, 'r', label='l2') -pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') -pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') +pl.plot(x, bary_l2, "r", label="l2") +pl.plot(x, bary_wass, "g", label="Reg Wasserstein") +pl.plot(x, bary_wass2, "b", label="LP Wasserstein") pl.legend() -pl.title('Barycenters') +pl.title("Barycenters") pl.tight_layout() problems.append([A, [bary_l2, bary_wass, bary_wass2]]) @@ -107,7 +108,7 @@ # Stair Data # ---------- -#%% parameters +# %% parameters a1 = 1.0 * (x > 10) * (x < 50) a2 = 1.0 * (x > 60) * (x < 80) @@ -124,16 +125,16 @@ M /= M.max() -#%% plot the distributions +# %% plot the distributions pl.figure(1, figsize=(6.4, 3)) for i in range(n_distributions): pl.plot(x, A[:, i]) -pl.title('Distributions') +pl.title("Distributions") pl.tight_layout() -#%% barycenter computation +# %% barycenter computation alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) @@ -160,14 +161,14 @@ pl.subplot(2, 1, 1) for i in range(n_distributions): pl.plot(x, A[:, i]) -pl.title('Distributions') +pl.title("Distributions") pl.subplot(2, 1, 2) -pl.plot(x, bary_l2, 'r', label='l2') -pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') -pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') +pl.plot(x, bary_l2, "r", label="l2") +pl.plot(x, bary_wass, "g", label="Reg Wasserstein") +pl.plot(x, bary_wass2, "b", label="LP Wasserstein") pl.legend() -pl.title('Barycenters') +pl.title("Barycenters") pl.tight_layout() @@ -175,14 +176,14 @@ # Dirac Data # ---------- -#%% parameters +# %% parameters a1 = np.zeros(n) a2 = np.zeros(n) -a1[10] = .25 -a1[20] = .5 -a1[30] = .25 +a1[10] = 0.25 +a1[20] = 0.5 +a1[30] = 0.25 a2[80] = 1 @@ -198,16 +199,16 @@ M /= M.max() -#%% plot the distributions +# %% plot the distributions pl.figure(1, figsize=(6.4, 3)) for i in range(n_distributions): pl.plot(x, A[:, i]) -pl.title('Distributions') +pl.title("Distributions") pl.tight_layout() -#%% barycenter computation +# %% barycenter computation alpha = 0.5 # 
0<=alpha<=1 weights = np.array([1 - alpha, alpha]) @@ -234,14 +235,14 @@ pl.subplot(2, 1, 1) for i in range(n_distributions): pl.plot(x, A[:, i]) -pl.title('Distributions') +pl.title("Distributions") pl.subplot(2, 1, 2) -pl.plot(x, bary_l2, 'r', label='l2') -pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') -pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') +pl.plot(x, bary_l2, "r", label="l2") +pl.plot(x, bary_wass, "g", label="Reg Wasserstein") +pl.plot(x, bary_wass2, "b", label="LP Wasserstein") pl.legend() -pl.title('Barycenters') +pl.title("Barycenters") pl.tight_layout() @@ -250,17 +251,16 @@ # ------------ # -#%% plot +# %% plot nbm = len(problems) -nbm2 = (nbm // 2) +nbm2 = nbm // 2 pl.figure(2, (20, 6)) pl.clf() for i in range(nbm): - A = problems[i][0] bary_l2 = problems[i][1][0] bary_wass = problems[i][1][1] @@ -270,19 +270,19 @@ for j in range(n_distributions): pl.plot(x, A[:, j]) if i == nbm2: - pl.title('Distributions') + pl.title("Distributions") pl.xticks(()) pl.yticks(()) pl.subplot(2, nbm, 1 + i + nbm) - pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)') - pl.plot(x, bary_wass, 'g', label='Reg Wasserstein') - pl.plot(x, bary_wass2, 'b', label='LP Wasserstein') + pl.plot(x, bary_l2, "r", label="L2 (Euclidean)") + pl.plot(x, bary_wass, "g", label="Reg Wasserstein") + pl.plot(x, bary_wass2, "b", label="LP Wasserstein") if i == nbm - 1: pl.legend() if i == nbm2: - pl.title('Barycenters') + pl.title("Barycenters") pl.xticks(()) pl.yticks(()) diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py index 143b3a6e5..d9da1a01c 100644 --- a/examples/barycenters/plot_convolutional_barycenter.py +++ b/examples/barycenters/plot_convolutional_barycenter.py @@ -1,5 +1,4 @@ - -#%% +# %% # -*- coding: utf-8 -*- """ ============================================ @@ -26,13 +25,13 @@ # # The four distributions are constructed from 4 simple images -this_file = os.path.realpath('__file__') -data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') +this_file = os.path.realpath("__file__") +data_path = os.path.join(Path(this_file).parent.parent.parent, "data") -f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[::2, ::2, 2] -f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[::2, ::2, 2] -f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[::2, ::2, 2] -f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[::2, ::2, 2] +f1 = 1 - plt.imread(os.path.join(data_path, "redcross.png"))[::2, ::2, 2] +f2 = 1 - plt.imread(os.path.join(data_path, "tooth.png"))[::2, ::2, 2] +f3 = 1 - plt.imread(os.path.join(data_path, "heart.png"))[::2, ::2, 2] +f4 = 1 - plt.imread(os.path.join(data_path, "duck.png"))[::2, ::2, 2] f1 = f1 / np.sum(f1) f2 = f2 / np.sum(f2) @@ -56,8 +55,8 @@ # fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7)) -plt.suptitle('Convolutional Wasserstein Barycenters in POT') -cm = 'Blues' +plt.suptitle("Convolutional Wasserstein Barycenters in POT") +cm = "Blues" # regularization parameter reg = 0.004 for i in range(nb_images): @@ -81,9 +80,8 @@ else: # call to barycenter computation axes[i, j].imshow( - ot.bregman.convolutional_barycenter2d(A, reg, weights), - cmap=cm + ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm ) - axes[i, j].axis('off') + axes[i, j].axis("off") plt.tight_layout() plt.show() diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py index 2a603dd9c..f8f271947 100644 --- 
a/examples/barycenters/plot_debiased_barycenter.py +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -24,15 +24,18 @@ import matplotlib.pyplot as plt import ot -from ot.bregman import (barycenter, barycenter_debiased, - convolutional_barycenter2d, - convolutional_barycenter2d_debiased) +from ot.bregman import ( + barycenter, + barycenter_debiased, + convolutional_barycenter2d, + convolutional_barycenter2d_debiased, +) ############################################################################## # Debiased barycenter of 1D Gaussians # ------------------------------------ -#%% parameters +# %% parameters n = 100 # nb bins @@ -51,7 +54,7 @@ M = ot.utils.dist0(n) M /= M.max() -#%% barycenter computation +# %% barycenter computation alpha = 0.2 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) @@ -64,8 +67,9 @@ labels = ["Sinkhorn barycenter", "Debiased barycenter"] colors = ["indianred", "gold"] -f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True, - figsize=(12, 4), num=1) +f, axes = plt.subplots( + 1, len(epsilons), tight_layout=True, sharey=True, figsize=(12, 4), num=1 +) for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased): ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3) ax.plot(A[:, 1], color="k", ls="--", alpha=0.3) @@ -79,10 +83,10 @@ ############################################################################## # Debiased barycenter of 2D images # --------------------------------- -this_file = os.path.realpath('__file__') -data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') -f1 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2] -f2 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] +this_file = os.path.realpath("__file__") +data_path = os.path.join(Path(this_file).parent.parent.parent, "data") +f1 = 1 - plt.imread(os.path.join(data_path, "heart.png"))[:, :, 2] +f2 = 1 - plt.imread(os.path.join(data_path, "duck.png"))[:, :, 2] A = np.asarray([f1, f2]) + 1e-2 A /= A.sum(axis=(1, 2))[:, None, None] @@ -121,10 +125,10 @@ ax.set_title(r"$\varepsilon = %.3f$" % eps, fontsize=13) ax.set_xticks([]) ax.set_yticks([]) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.spines['bottom'].set_visible(False) - ax.spines['left'].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.spines["left"].set_visible(False) if ii == 0: ax.set_ylabel(method, fontsize=15) fig.tight_layout() diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py index b6a4a113a..b22b1c6ee 100644 --- a/examples/barycenters/plot_free_support_barycenter.py +++ b/examples/barycenters/plot_free_support_barycenter.py @@ -29,8 +29,8 @@ N = 2 d = 2 -I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2] -I2 = pl.imread('../../data/duck.png').astype(np.float64)[::4, ::4, 2] +I1 = pl.imread("../../data/redcross.png").astype(np.float64)[::4, ::4, 2] +I2 = pl.imread("../../data/duck.png").astype(np.float64)[::4, ::4, 2] sz = I2.shape[0] XX, YY = np.meshgrid(np.arange(sz), np.arange(sz)) @@ -45,7 +45,7 @@ pl.figure(1, (12, 4)) pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) -pl.title('Distributions') +pl.title("Distributions") # %% @@ -53,8 +53,10 @@ # ------------------------------------------- k = 200 # number of Diracs of the barycenter -X_init = np.random.normal(0., 1., (k, d)) # 
initial Dirac locations -b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) +X_init = np.random.normal(0.0, 1.0, (k, d)) # initial Dirac locations +b = ( + np.ones((k,)) / k +) # weights of the barycenter (it will not be optimized, only the locations are optimized) X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) @@ -65,8 +67,8 @@ pl.figure(2, (8, 3)) pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) -pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter') -pl.title('Data measures and their barycenter') +pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker="s", label="2-Wasserstein barycenter") +pl.title("Data measures and their barycenter") pl.legend(loc="lower right") pl.show() @@ -74,10 +76,14 @@ # Compute free support Sinkhorn barycenter k = 200 # number of Diracs of the barycenter -X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations -b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) +X_init = np.random.normal(0.0, 1.0, (k, d)) # initial Dirac locations +b = ( + np.ones((k,)) / k +) # weights of the barycenter (it will not be optimized, only the locations are optimized) -X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15) +X = ot.bregman.free_support_sinkhorn_barycenter( + measures_locations, measures_weights, X_init, 20, b, numItermax=15 +) # %% # Plot the Wasserstein barycenter @@ -86,7 +92,7 @@ pl.figure(2, (8, 3)) pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) -pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter') -pl.title('Data measures and their barycenter') +pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker="s", label="2-Wasserstein barycenter") +pl.title("Data measures and their barycenter") pl.legend(loc="lower right") pl.show() diff --git a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py index ebe1f3b75..a8aa50a95 100644 --- a/examples/barycenters/plot_free_support_sinkhorn_barycenter.py +++ b/examples/barycenters/plot_free_support_sinkhorn_barycenter.py @@ -29,15 +29,20 @@ # ------------- X1 = np.random.randn(200, 2) -X2 = 2 * np.concatenate([ - np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1), - np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1), - np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1), - np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1), -], axis=0) +X2 = 2 * np.concatenate( + [ + np.concatenate([-np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1), + np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1), + np.concatenate([np.linspace(1, -1, 50)[:, None], -np.ones([50, 1])], axis=1), + ], + axis=0, +) X3 = np.random.randn(200, 2) X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None]) -X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200) +X4 = np.random.multivariate_normal( + np.array([0, 0]), np.array([[1.0, 0.5], [0.5, 1.0]]), size=200 +) a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)) @@ -47,26 +52,26 @@ fig, axes = plt.subplots(1, 
4, figsize=(16, 4)) -axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k') -axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k') -axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k') -axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k') +axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c="steelblue", edgecolor="k") +axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c="steelblue", edgecolor="k") +axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c="steelblue", edgecolor="k") +axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c="steelblue", edgecolor="k") axes[0].set_xlim([-3, 3]) axes[0].set_ylim([-3, 3]) -axes[0].set_title('Distribution 1') +axes[0].set_title("Distribution 1") axes[1].set_xlim([-3, 3]) axes[1].set_ylim([-3, 3]) -axes[1].set_title('Distribution 2') +axes[1].set_title("Distribution 2") axes[2].set_xlim([-3, 3]) axes[2].set_ylim([-3, 3]) -axes[2].set_title('Distribution 3') +axes[2].set_title("Distribution 3") axes[3].set_xlim([-3, 3]) axes[3].set_ylim([-3, 3]) -axes[3].set_title('Distribution 4') +axes[3].set_title("Distribution 4") plt.tight_layout() plt.show() @@ -77,12 +82,14 @@ fig = plt.figure(figsize=(10, 10)) -weights = np.array([ - [3 / 3, 0 / 3], - [2 / 3, 1 / 3], - [1 / 3, 2 / 3], - [0 / 3, 3 / 3], -]).astype(np.float32) +weights = np.array( + [ + [3 / 3, 0 / 3], + [2 / 3, 1 / 3], + [1 / 3, 2 / 3], + [0 / 3, 3 / 3], + ] +).astype(np.float32) for k in range(4): XB_init = np.random.randn(n_samples, 2) @@ -93,10 +100,10 @@ X_init=XB_init, reg=reg, numItermax=numItermax, - numInnerItermax=numInnerItermax + numInnerItermax=numInnerItermax, ) ax = plt.subplot2grid((4, 4), (0, k)) - ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.scatter(XB[:, 0], XB[:, 1], color="steelblue", edgecolor="k") ax.set_xlim([-3, 3]) ax.set_ylim([-3, 3]) @@ -109,10 +116,10 @@ X_init=XB_init, reg=reg, numItermax=numItermax, - numInnerItermax=numInnerItermax + numInnerItermax=numInnerItermax, ) ax = plt.subplot2grid((4, 4), (k, 0)) - ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.scatter(XB[:, 0], XB[:, 1], color="steelblue", edgecolor="k") ax.set_xlim([-3, 3]) ax.set_ylim([-3, 3]) @@ -125,10 +132,10 @@ X_init=XB_init, reg=reg, numItermax=numItermax, - numInnerItermax=numInnerItermax + numInnerItermax=numInnerItermax, ) ax = plt.subplot2grid((4, 4), (3, k)) - ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.scatter(XB[:, 0], XB[:, 1], color="steelblue", edgecolor="k") ax.set_xlim([-3, 3]) ax.set_ylim([-3, 3]) @@ -141,10 +148,10 @@ X_init=XB_init, reg=reg, numItermax=numItermax, - numInnerItermax=numInnerItermax + numInnerItermax=numInnerItermax, ) ax = plt.subplot2grid((4, 4), (k, 3)) - ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') + ax.scatter(XB[:, 0], XB[:, 1], color="steelblue", edgecolor="k") ax.set_xlim([-3, 3]) ax.set_ylim([-3, 3]) diff --git a/examples/barycenters/plot_gaussian_barycenter.py b/examples/barycenters/plot_gaussian_barycenter.py index c36b5daa9..60e08348a 100644 --- a/examples/barycenters/plot_gaussian_barycenter.py +++ b/examples/barycenters/plot_gaussian_barycenter.py @@ -45,7 +45,6 @@ def draw_cov(mu, C, color=None, label=None, nstd=1): - def eigsorted(cov): vals, vecs = np.linalg.eigh(cov) order = vals.argsort()[::-1] @@ -54,11 +53,19 @@ def eigsorted(cov): vals, vecs = eigsorted(C) theta = np.degrees(np.arctan2(*vecs[:, 0][::-1])) w, h = 2 * nstd * np.sqrt(vals) - ell = Ellipse(xy=(mu[0], mu[1]), - width=w, height=h, alpha=0.5, - angle=theta, facecolor=color, 
edgecolor=color, label=label, fill=True) + ell = Ellipse( + xy=(mu[0], mu[1]), + width=w, + height=h, + alpha=0.5, + angle=theta, + facecolor=color, + edgecolor=color, + label=label, + fill=True, + ) pl.gca().add_artist(ell) - #pl.scatter(mu[0],mu[1],color=color, marker='x') + # pl.scatter(mu[0],mu[1],color=color, marker='x') axis = [-1.5, 5.5, -1.5, 5.5] @@ -67,24 +74,24 @@ def eigsorted(cov): pl.clf() pl.subplot(1, 4, 1) -draw_cov(m1, C1, color='C0') +draw_cov(m1, C1, color="C0") pl.axis(axis) -pl.title('$\mathcal{N}(m_1,\Sigma_1)$') +pl.title("$\mathcal{N}(m_1,\Sigma_1)$") pl.subplot(1, 4, 2) -draw_cov(m2, C2, color='C1') +draw_cov(m2, C2, color="C1") pl.axis(axis) -pl.title('$\mathcal{N}(m_2,\Sigma_2)$') +pl.title("$\mathcal{N}(m_2,\Sigma_2)$") pl.subplot(1, 4, 3) -draw_cov(m3, C3, color='C2') +draw_cov(m3, C3, color="C2") pl.axis(axis) -pl.title('$\mathcal{N}(m_3,\Sigma_3)$') +pl.title("$\mathcal{N}(m_3,\Sigma_3)$") pl.subplot(1, 4, 4) -draw_cov(m4, C4, color='C3') +draw_cov(m4, C4, color="C3") pl.axis(axis) -pl.title('$\mathcal{N}(m_4,\Sigma_4)$') +pl.title("$\mathcal{N}(m_4,\Sigma_4)$") # %% # Compute Bures-Wasserstein barycenters and plot them @@ -97,10 +104,9 @@ def eigsorted(cov): v4 = np.array((0, 0, 0, 1)) -colors = np.stack((colors.to_rgb('C0'), - colors.to_rgb('C1'), - colors.to_rgb('C2'), - colors.to_rgb('C3'))) +colors = np.stack( + (colors.to_rgb("C0"), colors.to_rgb("C1"), colors.to_rgb("C2"), colors.to_rgb("C3")) +) pl.figure(2, (8, 8)) @@ -123,5 +129,5 @@ def eigsorted(cov): draw_cov(mb, Cb, color=color, label=None, nstd=0.3) pl.axis(axis) -pl.axis('off') +pl.axis("off") pl.tight_layout() diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py index a4d081b3f..5b3572bd4 100644 --- a/examples/barycenters/plot_generalized_free_support_barycenter.py +++ b/examples/barycenters/plot_generalized_free_support_barycenter.py @@ -32,9 +32,15 @@ # Input measures sub_sample_factor = 8 -I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2] -I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2] -I3 = pl.imread('../../data/heart.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2] +I1 = pl.imread("../../data/redcross.png").astype(np.float64)[ + ::sub_sample_factor, ::sub_sample_factor, 2 +] +I2 = pl.imread("../../data/tooth.png").astype(np.float64)[ + ::-sub_sample_factor, ::sub_sample_factor, 2 +] +I3 = pl.imread("../../data/heart.png").astype(np.float64)[ + ::-sub_sample_factor, ::sub_sample_factor, 2 +] sz = I1.shape[0] UU, VV = np.meshgrid(np.arange(sz), np.arange(sz)) @@ -64,7 +70,7 @@ fig = plt.figure(figsize=(3, 3)) axis = fig.add_subplot(1, 1, 1, projection="3d") for Xi in X_visu: - axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) + axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker="o", alpha=0.6) axis.view_init(azim=45) axis.set_xticks([]) axis.set_yticks([]) @@ -80,8 +86,8 @@ axis = fig.add_subplot(1, 1, 1, projection="3d") for Xi in X_visu: - axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) -axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) + axis.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker="o", alpha=0.6) +axis.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker="o", alpha=0.6) axis.view_init(azim=45) axis.set_xticks([]) axis.set_yticks([]) @@ -95,28 +101,28 @@ fig = plt.figure(figsize=(9, 3)) -ax = fig.add_subplot(1, 3, 1, 
projection='3d') +ax = fig.add_subplot(1, 3, 1, projection="3d") for Xi in X_visu: - ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) -ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker="o", alpha=0.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker="o", alpha=0.6) ax.view_init(elev=0, azim=0) ax.set_xticks([]) ax.set_yticks([]) ax.set_zticks([]) -ax = fig.add_subplot(1, 3, 2, projection='3d') +ax = fig.add_subplot(1, 3, 2, projection="3d") for Xi in X_visu: - ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) -ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker="o", alpha=0.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker="o", alpha=0.6) ax.view_init(elev=0, azim=90) ax.set_xticks([]) ax.set_yticks([]) ax.set_zticks([]) -ax = fig.add_subplot(1, 3, 3, projection='3d') +ax = fig.add_subplot(1, 3, 3, projection="3d") for Xi in X_visu: - ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) -ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker="o", alpha=0.6) +ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker="o", alpha=0.6) ax.view_init(elev=90, azim=0) ax.set_xticks([]) ax.set_yticks([]) @@ -135,13 +141,13 @@ def _init(): for Xi in X_visu: - ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker='o', alpha=.6) - ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker='o', alpha=.6) + ax.scatter(Xi[:, 0], Xi[:, 1], Xi[:, 2], marker="o", alpha=0.6) + ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], marker="o", alpha=0.6) ax.view_init(elev=0, azim=0) ax.set_xticks([]) ax.set_yticks([]) ax.set_zticks([]) - return fig, + return (fig,) def _update_plot(i): @@ -149,7 +155,15 @@ def _update_plot(i): ax.view_init(elev=0, azim=4 * i) else: ax.view_init(elev=i - 45, azim=4 * i) - return fig, - - -ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=136, interval=50, blit=True, repeat_delay=2000) + return (fig,) + + +ani = animation.FuncAnimation( + fig, + _update_plot, + init_func=_init, + frames=136, + interval=50, + blit=True, + repeat_delay=2000, +) diff --git a/examples/domain-adaptation/plot_otda_classes.py b/examples/domain-adaptation/plot_otda_classes.py index f028022ea..29d199bd0 100644 --- a/examples/domain-adaptation/plot_otda_classes.py +++ b/examples/domain-adaptation/plot_otda_classes.py @@ -24,8 +24,8 @@ n_source_samples = 150 n_target_samples = 150 -Xs, ys = ot.datasets.make_data_classif('3gauss', n_source_samples) -Xt, yt = ot.datasets.make_data_classif('3gauss2', n_target_samples) +Xs, ys = ot.datasets.make_data_classif("3gauss", n_source_samples) +Xt, yt = ot.datasets.make_data_classif("3gauss2", n_target_samples) ############################################################################## @@ -45,8 +45,7 @@ ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt) # Sinkhorn Transport with Group lasso regularization l1l2 -ot_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20, - verbose=True) +ot_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20, verbose=True) ot_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt) # transport source samples onto target samples @@ -62,18 +61,18 @@ pl.figure(1, figsize=(10, 5)) pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples") pl.xticks([]) pl.yticks([]) pl.legend(loc=0) -pl.title('Source samples') +pl.title("Source samples") pl.subplot(1, 2, 2) 
-pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples") pl.xticks([]) pl.yticks([]) pl.legend(loc=0) -pl.title('Target samples') +pl.title("Target samples") pl.tight_layout() @@ -81,69 +80,89 @@ # Fig 2 : plot optimal couplings and transported samples # ------------------------------------------------------ -param_img = {'interpolation': 'nearest'} +param_img = {"interpolation": "nearest"} pl.figure(2, figsize=(15, 8)) pl.subplot(2, 4, 1) pl.imshow(ot_emd.coupling_, **param_img) pl.xticks([]) pl.yticks([]) -pl.title('Optimal coupling\nEMDTransport') +pl.title("Optimal coupling\nEMDTransport") pl.subplot(2, 4, 2) pl.imshow(ot_sinkhorn.coupling_, **param_img) pl.xticks([]) pl.yticks([]) -pl.title('Optimal coupling\nSinkhornTransport') +pl.title("Optimal coupling\nSinkhornTransport") pl.subplot(2, 4, 3) pl.imshow(ot_lpl1.coupling_, **param_img) pl.xticks([]) pl.yticks([]) -pl.title('Optimal coupling\nSinkhornLpl1Transport') +pl.title("Optimal coupling\nSinkhornLpl1Transport") pl.subplot(2, 4, 4) pl.imshow(ot_l1l2.coupling_, **param_img) pl.xticks([]) pl.yticks([]) -pl.title('Optimal coupling\nSinkhornL1l2Transport') +pl.title("Optimal coupling\nSinkhornL1l2Transport") pl.subplot(2, 4, 5) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys, - marker='+', label='Transp samples', s=30) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.3) +pl.scatter( + transp_Xs_emd[:, 0], + transp_Xs_emd[:, 1], + c=ys, + marker="+", + label="Transp samples", + s=30, +) pl.xticks([]) pl.yticks([]) -pl.title('Transported samples\nEmdTransport') +pl.title("Transported samples\nEmdTransport") pl.legend(loc="lower left") pl.subplot(2, 4, 6) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys, - marker='+', label='Transp samples', s=30) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.3) +pl.scatter( + transp_Xs_sinkhorn[:, 0], + transp_Xs_sinkhorn[:, 1], + c=ys, + marker="+", + label="Transp samples", + s=30, +) pl.xticks([]) pl.yticks([]) -pl.title('Transported samples\nSinkhornTransport') +pl.title("Transported samples\nSinkhornTransport") pl.subplot(2, 4, 7) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys, - marker='+', label='Transp samples', s=30) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.3) +pl.scatter( + transp_Xs_lpl1[:, 0], + transp_Xs_lpl1[:, 1], + c=ys, + marker="+", + label="Transp samples", + s=30, +) pl.xticks([]) pl.yticks([]) -pl.title('Transported samples\nSinkhornLpl1Transport') +pl.title("Transported samples\nSinkhornLpl1Transport") pl.subplot(2, 4, 8) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys, - marker='+', label='Transp samples', s=30) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.3) +pl.scatter( + transp_Xs_l1l2[:, 0], + transp_Xs_l1l2[:, 1], + c=ys, + marker="+", + label="Transp samples", + s=30, +) pl.xticks([]) pl.yticks([]) -pl.title('Transported samples\nSinkhornL1l2Transport') +pl.title("Transported samples\nSinkhornL1l2Transport") pl.tight_layout() pl.show() diff 
--git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py index 06dc8ab3d..0a452677d 100644 --- a/examples/domain-adaptation/plot_otda_color_images.py +++ b/examples/domain-adaptation/plot_otda_color_images.py @@ -49,11 +49,11 @@ def minmax(img): # ------------- # Loading images -this_file = os.path.realpath('__file__') -data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') +this_file = os.path.realpath("__file__") +data_path = os.path.join(Path(this_file).parent.parent.parent, "data") -I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 -I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 +I1 = plt.imread(os.path.join(data_path, "ocean_day.jpg")).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, "ocean_sunset.jpg")).astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) @@ -75,13 +75,13 @@ def minmax(img): plt.subplot(1, 2, 1) plt.imshow(I1) -plt.axis('off') -plt.title('Image 1') +plt.axis("off") +plt.title("Image 1") plt.subplot(1, 2, 2) plt.imshow(I2) -plt.axis('off') -plt.title('Image 2') +plt.axis("off") +plt.title("Image 2") ############################################################################## @@ -93,16 +93,16 @@ def minmax(img): plt.subplot(1, 2, 1) plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs) plt.axis([0, 1, 0, 1]) -plt.xlabel('Red') -plt.ylabel('Blue') -plt.title('Image 1') +plt.xlabel("Red") +plt.ylabel("Blue") +plt.title("Image 1") plt.subplot(1, 2, 2) plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt) plt.axis([0, 1, 0, 1]) -plt.xlabel('Red') -plt.ylabel('Blue') -plt.title('Image 2') +plt.xlabel("Red") +plt.ylabel("Blue") +plt.title("Image 2") plt.tight_layout() @@ -140,33 +140,33 @@ def minmax(img): plt.subplot(2, 3, 1) plt.imshow(I1) -plt.axis('off') -plt.title('Image 1') +plt.axis("off") +plt.title("Image 1") plt.subplot(2, 3, 2) plt.imshow(I1t) -plt.axis('off') -plt.title('Image 1 Adapt') +plt.axis("off") +plt.title("Image 1 Adapt") plt.subplot(2, 3, 3) plt.imshow(I1te) -plt.axis('off') -plt.title('Image 1 Adapt (reg)') +plt.axis("off") +plt.title("Image 1 Adapt (reg)") plt.subplot(2, 3, 4) plt.imshow(I2) -plt.axis('off') -plt.title('Image 2') +plt.axis("off") +plt.title("Image 2") plt.subplot(2, 3, 5) plt.imshow(I2t) -plt.axis('off') -plt.title('Image 2 Adapt') +plt.axis("off") +plt.title("Image 2 Adapt") plt.subplot(2, 3, 6) plt.imshow(I2te) -plt.axis('off') -plt.title('Image 2 Adapt (reg)') +plt.axis("off") +plt.title("Image 2 Adapt (reg)") plt.tight_layout() plt.show() diff --git a/examples/domain-adaptation/plot_otda_d2.py b/examples/domain-adaptation/plot_otda_d2.py index d8b2a9350..7b38cf3b5 100644 --- a/examples/domain-adaptation/plot_otda_d2.py +++ b/examples/domain-adaptation/plot_otda_d2.py @@ -4,7 +4,7 @@ OT for domain adaptation on empirical distributions =================================================== -This example introduces a domain adaptation in a 2D setting. It explicits +This example introduces a domain adaptation in a 2D setting. It explicit the problem of domain adaptation and introduces some optimal transport approaches to solve it. 
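# A compact sketch of the domain adaptation pipeline that the example below
# (and plot_otda_classes.py above) walks through, kept to calls already shown
# in these diffs: make_data_classif, ot.dist, ot.da.EMDTransport and
# ot.da.SinkhornTransport. The sample sizes and the reg_e value are
# illustrative assumptions, not the examples' exact settings.
import ot

Xs, ys = ot.datasets.make_data_classif("3gauss", 150)  # labeled source samples
Xt, yt = ot.datasets.make_data_classif("3gauss2", 150)  # target samples
M = ot.dist(Xs, Xt, metric="sqeuclidean")  # pairwise costs, shown as a heatmap below

ot_emd = ot.da.EMDTransport()  # exact OT coupling between empirical measures
ot_emd.fit(Xs=Xs, Xt=Xt)
transp_Xs_emd = ot_emd.transform(Xs=Xs)  # barycentric mapping of source samples

ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)  # entropic-regularized coupling
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)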
@@ -31,11 +31,11 @@ n_samples_source = 150 n_samples_target = 150 -Xs, ys = ot.datasets.make_data_classif('3gauss', n_samples_source) -Xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples_target) +Xs, ys = ot.datasets.make_data_classif("3gauss", n_samples_source) +Xt, yt = ot.datasets.make_data_classif("3gauss2", n_samples_target) # Cost matrix -M = ot.dist(Xs, Xt, metric='sqeuclidean') +M = ot.dist(Xs, Xt, metric="sqeuclidean") ############################################################################## @@ -66,24 +66,24 @@ pl.figure(1, figsize=(10, 10)) pl.subplot(2, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples") pl.xticks([]) pl.yticks([]) pl.legend(loc=0) -pl.title('Source samples') +pl.title("Source samples") pl.subplot(2, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples") pl.xticks([]) pl.yticks([]) pl.legend(loc=0) -pl.title('Target samples') +pl.title("Target samples") pl.subplot(2, 2, 3) -pl.imshow(M, interpolation='nearest') +pl.imshow(M, interpolation="nearest") pl.xticks([]) pl.yticks([]) -pl.title('Matrix of pairwise distances') +pl.title("Matrix of pairwise distances") pl.tight_layout() @@ -93,46 +93,46 @@ pl.figure(2, figsize=(10, 6)) pl.subplot(2, 3, 1) -pl.imshow(ot_emd.coupling_, interpolation='nearest') +pl.imshow(ot_emd.coupling_, interpolation="nearest") pl.xticks([]) pl.yticks([]) -pl.title('Optimal coupling\nEMDTransport') +pl.title("Optimal coupling\nEMDTransport") pl.subplot(2, 3, 2) -pl.imshow(ot_sinkhorn.coupling_, interpolation='nearest') +pl.imshow(ot_sinkhorn.coupling_, interpolation="nearest") pl.xticks([]) pl.yticks([]) -pl.title('Optimal coupling\nSinkhornTransport') +pl.title("Optimal coupling\nSinkhornTransport") pl.subplot(2, 3, 3) -pl.imshow(ot_lpl1.coupling_, interpolation='nearest') +pl.imshow(ot_lpl1.coupling_, interpolation="nearest") pl.xticks([]) pl.yticks([]) -pl.title('Optimal coupling\nSinkhornLpl1Transport') +pl.title("Optimal coupling\nSinkhornLpl1Transport") pl.subplot(2, 3, 4) -ot.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[.5, .5, 1]) -pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +ot.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[0.5, 0.5, 1]) +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples") +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples") pl.xticks([]) pl.yticks([]) -pl.title('Main coupling coefficients\nEMDTransport') +pl.title("Main coupling coefficients\nEMDTransport") pl.subplot(2, 3, 5) -ot.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[.5, .5, 1]) -pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +ot.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[0.5, 0.5, 1]) +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples") +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples") pl.xticks([]) pl.yticks([]) -pl.title('Main coupling coefficients\nSinkhornTransport') +pl.title("Main coupling coefficients\nSinkhornTransport") pl.subplot(2, 3, 6) -ot.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[.5, .5, 1]) -pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') -pl.scatter(Xt[:, 0], Xt[:, 
1], c=yt, marker='o', label='Target samples') +ot.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[0.5, 0.5, 1]) +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples") +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples") pl.xticks([]) pl.yticks([]) -pl.title('Main coupling coefficients\nSinkhornLpl1Transport') +pl.title("Main coupling coefficients\nSinkhornLpl1Transport") pl.tight_layout() @@ -143,30 +143,45 @@ # display transported samples pl.figure(4, figsize=(10, 4)) pl.subplot(1, 3, 1) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.5) -pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys, - marker='+', label='Transp samples', s=30) -pl.title('Transported samples\nEmdTransport') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5) +pl.scatter( + transp_Xs_emd[:, 0], + transp_Xs_emd[:, 1], + c=ys, + marker="+", + label="Transp samples", + s=30, +) +pl.title("Transported samples\nEmdTransport") pl.legend(loc=0) pl.xticks([]) pl.yticks([]) pl.subplot(1, 3, 2) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.5) -pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys, - marker='+', label='Transp samples', s=30) -pl.title('Transported samples\nSinkhornTransport') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5) +pl.scatter( + transp_Xs_sinkhorn[:, 0], + transp_Xs_sinkhorn[:, 1], + c=ys, + marker="+", + label="Transp samples", + s=30, +) +pl.title("Transported samples\nSinkhornTransport") pl.xticks([]) pl.yticks([]) pl.subplot(1, 3, 3) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.5) -pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys, - marker='+', label='Transp samples', s=30) -pl.title('Transported samples\nSinkhornLpl1Transport') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5) +pl.scatter( + transp_Xs_lpl1[:, 0], + transp_Xs_lpl1[:, 1], + c=ys, + marker="+", + label="Transp samples", + s=30, +) +pl.title("Transported samples\nSinkhornLpl1Transport") pl.xticks([]) pl.yticks([]) diff --git a/examples/domain-adaptation/plot_otda_jcpot.py b/examples/domain-adaptation/plot_otda_jcpot.py index 0d974f455..ddc30fb64 100644 --- a/examples/domain-adaptation/plot_otda_jcpot.py +++ b/examples/domain-adaptation/plot_otda_jcpot.py @@ -25,18 +25,18 @@ sigma = 0.3 np.random.seed(1985) -p1 = .2 +p1 = 0.2 dec1 = [0, 2] -p2 = .9 +p2 = 0.9 dec2 = [0, -2] -pt = .4 +pt = 0.4 dect = [4, 0] -xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1) -xs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2) -xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect) +xs1, ys1 = make_data_classif("2gauss_prop", n, nz=sigma, p=p1, bias=dec1) +xs2, ys2 = make_data_classif("2gauss_prop", n + 1, nz=sigma, p=p2, bias=dec2) +xt, yt = make_data_classif("2gauss_prop", n, nz=sigma, p=pt, bias=dect) all_Xr = [xs1, xs2] all_Yr = [ys1, ys2] @@ -46,9 +46,9 @@ def plot_ax(dec, name): - pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5) - pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5) - pl.text(dec[0] - .5, dec[1] + 2, name) + pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], "k", alpha=0.5) + pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], "k", alpha=0.5) + pl.text(dec[0] - 0.5, dec[1] + 2, name) 
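# Ahead of the plotting sections of this example, a short sketch of the JCPOT
# step it builds toward: estimating the target class proportions from several
# sources whose label proportions differ (target shift). The constructor
# arguments and the fit call mirror the ones used further down in this file;
# the number of samples per domain is an illustrative assumption.
import ot
from ot.datasets import make_data_classif

n, sigma = 50, 0.3
xs1, ys1 = make_data_classif("2gauss_prop", n, nz=sigma, p=0.2, bias=[0, 2])
xs2, ys2 = make_data_classif("2gauss_prop", n, nz=sigma, p=0.9, bias=[0, -2])
xt, yt = make_data_classif("2gauss_prop", n, nz=sigma, p=0.4, bias=[4, 0])

otda = ot.da.JCPOTTransport(
    reg_e=1, max_iter=1000, metric="sqeuclidean", tol=1e-9, log=True
)
otda.fit([xs1, xs2], [ys1, ys2], xt)  # multi-source fit: lists of source data and labels
# the estimate should be close to the true target mix (1 - pt, pt) = (0.6, 0.4)
print("estimated target proportions:", otda.proportions_)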
############################################################################## @@ -57,25 +57,49 @@ def plot_ax(dec, name): pl.figure(1) pl.clf() -plot_ax(dec1, 'Source 1') -plot_ax(dec2, 'Source 2') -plot_ax(dect, 'Target') -pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, - label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1)) -pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, - label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2)) -pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, - label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt)) -pl.title('Data') +plot_ax(dec1, "Source 1") +plot_ax(dec2, "Source 2") +plot_ax(dect, "Target") +pl.scatter( + xs1[:, 0], + xs1[:, 1], + c=ys1, + s=35, + marker="x", + cmap="Set1", + vmax=9, + label="Source 1 ({:1.2f}, {:1.2f})".format(1 - p1, p1), +) +pl.scatter( + xs2[:, 0], + xs2[:, 1], + c=ys2, + s=35, + marker="+", + cmap="Set1", + vmax=9, + label="Source 2 ({:1.2f}, {:1.2f})".format(1 - p2, p2), +) +pl.scatter( + xt[:, 0], + xt[:, 1], + c=yt, + s=35, + marker="o", + cmap="Set1", + vmax=9, + label="Target ({:1.2f}, {:1.2f})".format(1 - pt, pt), +) +pl.title("Data") pl.legend() -pl.axis('equal') -pl.axis('off') +pl.axis("equal") +pl.axis("off") ############################################################################## # Instantiate Sinkhorn transport algorithm and fit them for all source domains # ---------------------------------------------------------------------------- -ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean') +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric="sqeuclidean") def print_G(G, xs, ys, xt): @@ -83,10 +107,10 @@ def print_G(G, xs, ys, xt): for j in range(G.shape[1]): if G[i, j] > 5e-4: if ys[i]: - c = 'b' + c = "b" else: - c = 'r' - pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=.2) + c = "r" + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=0.2) ############################################################################## @@ -94,78 +118,84 @@ def print_G(G, xs, ys, xt): # ------------------------------------------------------ pl.figure(2) pl.clf() -plot_ax(dec1, 'Source 1') -plot_ax(dec2, 'Source 2') -plot_ax(dect, 'Target') +plot_ax(dec1, "Source 1") +plot_ax(dec2, "Source 2") +plot_ax(dect, "Target") print_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt) print_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt) -pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) -pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) -pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker="x", cmap="Set1", vmax=9) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker="+", cmap="Set1", vmax=9) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker="o", cmap="Set1", vmax=9) -pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') -pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') +pl.plot([], [], "r", alpha=0.2, label="Mass from Class 1") +pl.plot([], [], "b", alpha=0.2, label="Mass from Class 2") -pl.title('Independent OT') +pl.title("Independent OT") pl.legend() -pl.axis('equal') -pl.axis('off') +pl.axis("equal") +pl.axis("off") ############################################################################## # Instantiate JCPOT adaptation algorithm and fit it # ---------------------------------------------------------------------------- -otda = 
ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True) +otda = ot.da.JCPOTTransport( + reg_e=1, max_iter=1000, metric="sqeuclidean", tol=1e-9, verbose=True, log=True +) otda.fit(all_Xr, all_Yr, xt) -ws1 = otda.proportions_.dot(otda.log_['D2'][0]) -ws2 = otda.proportions_.dot(otda.log_['D2'][1]) +ws1 = otda.proportions_.dot(otda.log_["D2"][0]) +ws2 = otda.proportions_.dot(otda.log_["D2"][1]) pl.figure(3) pl.clf() -plot_ax(dec1, 'Source 1') -plot_ax(dec2, 'Source 2') -plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt) -pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) -pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) -pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) - -pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') -pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') - -pl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1])) +plot_ax(dec1, "Source 1") +plot_ax(dec2, "Source 2") +plot_ax(dect, "Target") +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_["M"][0], reg=1e-1), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_["M"][1], reg=1e-1), xs2, ys2, xt) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker="x", cmap="Set1", vmax=9) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker="+", cmap="Set1", vmax=9) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker="o", cmap="Set1", vmax=9) + +pl.plot([], [], "r", alpha=0.2, label="Mass from Class 1") +pl.plot([], [], "b", alpha=0.2, label="Mass from Class 2") + +pl.title( + "OT with prop estimation ({:1.3f},{:1.3f})".format( + otda.proportions_[0], otda.proportions_[1] + ) +) pl.legend() -pl.axis('equal') -pl.axis('off') +pl.axis("equal") +pl.axis("off") ############################################################################## # Run oracle transport algorithm with known proportions # ---------------------------------------------------------------------------- h_res = np.array([1 - pt, pt]) -ws1 = h_res.dot(otda.log_['D2'][0]) -ws2 = h_res.dot(otda.log_['D2'][1]) +ws1 = h_res.dot(otda.log_["D2"][0]) +ws2 = h_res.dot(otda.log_["D2"][1]) pl.figure(4) pl.clf() -plot_ax(dec1, 'Source 1') -plot_ax(dec2, 'Source 2') -plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt) -pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) -pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) -pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) +plot_ax(dec1, "Source 1") +plot_ax(dec2, "Source 2") +plot_ax(dect, "Target") +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_["M"][0], reg=1e-1), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_["M"][1], reg=1e-1), xs2, ys2, xt) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker="x", cmap="Set1", vmax=9) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker="+", cmap="Set1", vmax=9) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker="o", cmap="Set1", vmax=9) -pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') -pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') +pl.plot([], [], "r", alpha=0.2, label="Mass from Class 1") +pl.plot([], [], "b", 
alpha=0.2, label="Mass from Class 2") -pl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1])) +pl.title("OT with known proportion ({:1.1f},{:1.1f})".format(h_res[0], h_res[1])) pl.legend() -pl.axis('equal') -pl.axis('off') +pl.axis("equal") +pl.axis("off") pl.show() diff --git a/examples/domain-adaptation/plot_otda_laplacian.py b/examples/domain-adaptation/plot_otda_laplacian.py index 67c8f6703..755cfd4be 100644 --- a/examples/domain-adaptation/plot_otda_laplacian.py +++ b/examples/domain-adaptation/plot_otda_laplacian.py @@ -23,8 +23,8 @@ n_source_samples = 150 n_target_samples = 150 -Xs, ys = ot.datasets.make_data_classif('3gauss', n_source_samples) -Xt, yt = ot.datasets.make_data_classif('3gauss2', n_target_samples) +Xs, ys = ot.datasets.make_data_classif("3gauss", n_source_samples) +Xt, yt = ot.datasets.make_data_classif("3gauss2", n_target_samples) ############################################################################## @@ -36,7 +36,7 @@ ot_emd.fit(Xs=Xs, Xt=Xt) # Sinkhorn Transport -ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.01) +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=0.01) ot_sinkhorn.fit(Xs=Xs, Xt=Xt) # EMD Transport with Laplacian regularization @@ -54,18 +54,18 @@ pl.figure(1, figsize=(10, 5)) pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples") pl.xticks([]) pl.yticks([]) pl.legend(loc=0) -pl.title('Source samples') +pl.title("Source samples") pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples") pl.xticks([]) pl.yticks([]) pl.legend(loc=0) -pl.title('Target samples') +pl.title("Target samples") pl.tight_layout() @@ -73,55 +73,70 @@ # Fig 2 : plot optimal couplings and transported samples # ------------------------------------------------------ -param_img = {'interpolation': 'nearest'} +param_img = {"interpolation": "nearest"} pl.figure(2, figsize=(15, 8)) pl.subplot(2, 3, 1) pl.imshow(ot_emd.coupling_, **param_img) pl.xticks([]) pl.yticks([]) -pl.title('Optimal coupling\nEMDTransport') +pl.title("Optimal coupling\nEMDTransport") pl.figure(2, figsize=(15, 8)) pl.subplot(2, 3, 2) pl.imshow(ot_sinkhorn.coupling_, **param_img) pl.xticks([]) pl.yticks([]) -pl.title('Optimal coupling\nSinkhornTransport') +pl.title("Optimal coupling\nSinkhornTransport") pl.subplot(2, 3, 3) pl.imshow(ot_emd_laplace.coupling_, **param_img) pl.xticks([]) pl.yticks([]) -pl.title('Optimal coupling\nEMDLaplaceTransport') +pl.title("Optimal coupling\nEMDLaplaceTransport") pl.subplot(2, 3, 4) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys, - marker='+', label='Transp samples', s=30) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.3) +pl.scatter( + transp_Xs_emd[:, 0], + transp_Xs_emd[:, 1], + c=ys, + marker="+", + label="Transp samples", + s=30, +) pl.xticks([]) pl.yticks([]) -pl.title('Transported samples\nEmdTransport') +pl.title("Transported samples\nEmdTransport") pl.legend(loc="lower left") pl.subplot(2, 3, 5) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys, - marker='+', label='Transp samples', s=30) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.3) +pl.scatter( + 
transp_Xs_sinkhorn[:, 0], + transp_Xs_sinkhorn[:, 1], + c=ys, + marker="+", + label="Transp samples", + s=30, +) pl.xticks([]) pl.yticks([]) -pl.title('Transported samples\nSinkhornTransport') +pl.title("Transported samples\nSinkhornTransport") pl.subplot(2, 3, 6) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(transp_Xs_emd_laplace[:, 0], transp_Xs_emd_laplace[:, 1], c=ys, - marker='+', label='Transp samples', s=30) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.3) +pl.scatter( + transp_Xs_emd_laplace[:, 0], + transp_Xs_emd_laplace[:, 1], + c=ys, + marker="+", + label="Transp samples", + s=30, +) pl.xticks([]) pl.yticks([]) -pl.title('Transported samples\nEMDLaplaceTransport') +pl.title("Transported samples\nEMDLaplaceTransport") pl.tight_layout() pl.show() diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py index 7e7177d37..4795072f3 100644 --- a/examples/domain-adaptation/plot_otda_linear_mapping.py +++ b/examples/domain-adaptation/plot_otda_linear_mapping.py @@ -14,7 +14,7 @@ # sphinx_gallery_thumbnail_number = 2 -#%% +# %% import os from pathlib import Path @@ -28,25 +28,23 @@ n = 1000 d = 2 -sigma = .1 +sigma = 0.1 rng = np.random.RandomState(42) # source samples angles = rng.rand(n, 1) * 2 * np.pi -xs = np.concatenate((np.sin(angles), np.cos(angles)), - axis=1) + sigma * rng.randn(n, 2) -xs[:n // 2, 1] += 2 +xs = np.concatenate((np.sin(angles), np.cos(angles)), axis=1) + sigma * rng.randn(n, 2) +xs[: n // 2, 1] += 2 # target samples anglet = rng.rand(n, 1) * 2 * np.pi -xt = np.concatenate((np.sin(anglet), np.cos(anglet)), - axis=1) + sigma * rng.randn(n, 2) -xt[:n // 2, 1] += 2 +xt = np.concatenate((np.sin(anglet), np.cos(anglet)), axis=1) + sigma * rng.randn(n, 2) +xt[: n // 2, 1] += 2 -A = np.array([[1.5, .7], [.7, 1.5]]) +A = np.array([[1.5, 0.7], [0.7, 1.5]]) b = np.array([[4, 2]]) xt = xt.dot(A) + b @@ -55,10 +53,10 @@ # --------- plt.figure(1, (5, 5)) -plt.plot(xs[:, 0], xs[:, 1], '+') -plt.plot(xt[:, 0], xt[:, 1], 'o') -plt.legend(('Source', 'Target')) -plt.title('Source and target distributions') +plt.plot(xs[:, 0], xs[:, 1], "+") +plt.plot(xt[:, 0], xt[:, 1], "o") +plt.legend(("Source", "Target")) +plt.title("Source and target distributions") plt.show() ############################################################################## @@ -83,17 +81,17 @@ plt.figure(2, (10, 5)) plt.clf() plt.subplot(1, 2, 1) -plt.plot(xs[:, 0], xs[:, 1], '+') -plt.plot(xt[:, 0], xt[:, 1], 'o') -plt.plot(xst[:, 0], xst[:, 1], '+') -plt.legend(('Source', 'Target', 'Transp. Monge'), loc=0) -plt.title('Transported samples with Monge') +plt.plot(xs[:, 0], xs[:, 1], "+") +plt.plot(xt[:, 0], xt[:, 1], "o") +plt.plot(xst[:, 0], xst[:, 1], "+") +plt.legend(("Source", "Target", "Transp. Monge"), loc=0) +plt.title("Transported samples with Monge") plt.subplot(1, 2, 2) -plt.plot(xs[:, 0], xs[:, 1], '+') -plt.plot(xt[:, 0], xt[:, 1], 'o') -plt.plot(xstgw[:, 0], xstgw[:, 1], '+') -plt.legend(('Source', 'Target', 'Transp. GW'), loc=0) -plt.title('Transported samples with Gaussian GW') +plt.plot(xs[:, 0], xs[:, 1], "+") +plt.plot(xt[:, 0], xt[:, 1], "o") +plt.plot(xstgw[:, 0], xstgw[:, 1], "+") +plt.legend(("Source", "Target", "Transp. 
GW"), loc=0) +plt.title("Transported samples with Gaussian GW") plt.show() ############################################################################## @@ -116,11 +114,11 @@ def minmax(img): # Loading images -this_file = os.path.realpath('__file__') -data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') +this_file = os.path.realpath("__file__") +data_path = os.path.join(Path(this_file).parent.parent.parent, "data") -I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 -I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 +I1 = plt.imread(os.path.join(data_path, "ocean_day.jpg")).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, "ocean_sunset.jpg")).astype(np.float64) / 256 X1 = im2mat(I1) @@ -164,30 +162,30 @@ def minmax(img): plt.subplot(2, 3, 1) plt.imshow(I1) -plt.axis('off') -plt.title('Im. 1') +plt.axis("off") +plt.title("Im. 1") plt.subplot(2, 3, 4) plt.imshow(I2) -plt.axis('off') -plt.title('Im. 2') +plt.axis("off") +plt.title("Im. 2") plt.subplot(2, 3, 2) plt.imshow(I1t) -plt.axis('off') -plt.title('Monge mapping Im. 1') +plt.axis("off") +plt.title("Monge mapping Im. 1") plt.subplot(2, 3, 5) plt.imshow(I2t) -plt.axis('off') -plt.title('Inverse Monge mapping Im. 2') +plt.axis("off") +plt.title("Inverse Monge mapping Im. 2") plt.subplot(2, 3, 3) plt.imshow(I1tgw) -plt.axis('off') -plt.title('Gaussian GW mapping Im. 1') +plt.axis("off") +plt.title("Gaussian GW mapping Im. 1") plt.subplot(2, 3, 6) plt.imshow(I2tgw) -plt.axis('off') -plt.title('Inverse Gaussian GW mapping Im. 2') +plt.axis("off") +plt.title("Inverse Gaussian GW mapping Im. 2") diff --git a/examples/domain-adaptation/plot_otda_mapping.py b/examples/domain-adaptation/plot_otda_mapping.py index d21d3c94c..42a89a381 100644 --- a/examples/domain-adaptation/plot_otda_mapping.py +++ b/examples/domain-adaptation/plot_otda_mapping.py @@ -34,12 +34,11 @@ theta = 2 * np.pi / 20 noise_level = 0.1 -Xs, ys = ot.datasets.make_data_classif( - 'gaussrot', n_source_samples, nz=noise_level) -Xs_new, _ = ot.datasets.make_data_classif( - 'gaussrot', n_source_samples, nz=noise_level) +Xs, ys = ot.datasets.make_data_classif("gaussrot", n_source_samples, nz=noise_level) +Xs_new, _ = ot.datasets.make_data_classif("gaussrot", n_source_samples, nz=noise_level) Xt, yt = ot.datasets.make_data_classif( - 'gaussrot', n_target_samples, theta=theta, nz=noise_level) + "gaussrot", n_target_samples, theta=theta, nz=noise_level +) # one of the target mode changes its variance (no linear mapping) Xt[yt == 2] *= 3 @@ -51,10 +50,10 @@ pl.figure(1, (10, 5)) pl.clf() -pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples") +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples") pl.legend(loc=0) -pl.title('Source and target distributions') +pl.title("Source and target distributions") ############################################################################## @@ -63,8 +62,8 @@ # MappingTransport with linear kernel ot_mapping_linear = ot.da.MappingTransport( - kernel="linear", mu=1e0, eta=1e-8, bias=True, - max_iter=20, verbose=True) + kernel="linear", mu=1e0, eta=1e-8, bias=True, max_iter=20, verbose=True +) ot_mapping_linear.fit(Xs=Xs, Xt=Xt) @@ -77,8 +76,8 @@ # MappingTransport with gaussian kernel ot_mapping_gaussian = ot.da.MappingTransport( - kernel="gaussian", eta=1e-5, mu=1e-1, 
bias=True, sigma=1, - max_iter=10, verbose=True) + kernel="gaussian", eta=1e-5, mu=1e-1, bias=True, sigma=1, max_iter=10, verbose=True +) ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt) # for original source samples, transform applies barycentric mapping @@ -95,32 +94,48 @@ pl.figure(2) pl.clf() pl.subplot(2, 2, 1) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=.2) -pl.scatter(transp_Xs_linear[:, 0], transp_Xs_linear[:, 1], c=ys, marker='+', - label='Mapped source samples') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.2) +pl.scatter( + transp_Xs_linear[:, 0], + transp_Xs_linear[:, 1], + c=ys, + marker="+", + label="Mapped source samples", +) pl.title("Bary. mapping (linear)") pl.legend(loc=0) pl.subplot(2, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=.2) -pl.scatter(transp_Xs_linear_new[:, 0], transp_Xs_linear_new[:, 1], - c=ys, marker='+', label='Learned mapping') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.2) +pl.scatter( + transp_Xs_linear_new[:, 0], + transp_Xs_linear_new[:, 1], + c=ys, + marker="+", + label="Learned mapping", +) pl.title("Estim. mapping (linear)") pl.subplot(2, 2, 3) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=.2) -pl.scatter(transp_Xs_gaussian[:, 0], transp_Xs_gaussian[:, 1], c=ys, - marker='+', label='barycentric mapping') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.2) +pl.scatter( + transp_Xs_gaussian[:, 0], + transp_Xs_gaussian[:, 1], + c=ys, + marker="+", + label="barycentric mapping", +) pl.title("Bary. mapping (kernel)") pl.subplot(2, 2, 4) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=.2) -pl.scatter(transp_Xs_gaussian_new[:, 0], transp_Xs_gaussian_new[:, 1], c=ys, - marker='+', label='Learned mapping') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.2) +pl.scatter( + transp_Xs_gaussian_new[:, 0], + transp_Xs_gaussian_new[:, 1], + c=ys, + marker="+", + label="Learned mapping", +) pl.title("Estim. 
mapping (kernel)") pl.tight_layout() diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py index dbece7082..c52bdf121 100644 --- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py +++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py @@ -50,11 +50,11 @@ def minmax(img): # ------------- # Loading images -this_file = os.path.realpath('__file__') -data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') +this_file = os.path.realpath("__file__") +data_path = os.path.join(Path(this_file).parent.parent.parent, "data") -I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 -I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 +I1 = plt.imread(os.path.join(data_path, "ocean_day.jpg")).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, "ocean_sunset.jpg")).astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) @@ -85,14 +85,16 @@ def minmax(img): Image_sinkhorn = minmax(mat2im(transp_Xs_sinkhorn, I1.shape)) ot_mapping_linear = ot.da.MappingTransport( - mu=1e0, eta=1e-8, bias=True, max_iter=20, verbose=True) + mu=1e0, eta=1e-8, bias=True, max_iter=20, verbose=True +) ot_mapping_linear.fit(Xs=Xs, Xt=Xt) X1tl = ot_mapping_linear.transform(Xs=X1) Image_mapping_linear = minmax(mat2im(X1tl, I1.shape)) ot_mapping_gaussian = ot.da.MappingTransport( - mu=1e0, eta=1e-2, sigma=1, bias=False, max_iter=10, verbose=True) + mu=1e0, eta=1e-2, sigma=1, bias=False, max_iter=10, verbose=True +) ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt) X1tn = ot_mapping_gaussian.transform(Xs=X1) # use the estimated mapping @@ -106,13 +108,13 @@ def minmax(img): plt.figure(1, figsize=(6.4, 3)) plt.subplot(1, 2, 1) plt.imshow(I1) -plt.axis('off') -plt.title('Image 1') +plt.axis("off") +plt.title("Image 1") plt.subplot(1, 2, 2) plt.imshow(I2) -plt.axis('off') -plt.title('Image 2') +plt.axis("off") +plt.title("Image 2") plt.tight_layout() @@ -125,16 +127,16 @@ def minmax(img): plt.subplot(1, 2, 1) plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs) plt.axis([0, 1, 0, 1]) -plt.xlabel('Red') -plt.ylabel('Blue') -plt.title('Image 1') +plt.xlabel("Red") +plt.ylabel("Blue") +plt.title("Image 1") plt.subplot(1, 2, 2) plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt) plt.axis([0, 1, 0, 1]) -plt.xlabel('Red') -plt.ylabel('Blue') -plt.title('Image 2') +plt.xlabel("Red") +plt.ylabel("Blue") +plt.title("Image 2") plt.tight_layout() @@ -146,33 +148,33 @@ def minmax(img): plt.subplot(2, 3, 1) plt.imshow(I1) -plt.axis('off') -plt.title('Im. 1') +plt.axis("off") +plt.title("Im. 1") plt.subplot(2, 3, 4) plt.imshow(I2) -plt.axis('off') -plt.title('Im. 2') +plt.axis("off") +plt.title("Im. 
2") plt.subplot(2, 3, 2) plt.imshow(Image_emd) -plt.axis('off') -plt.title('EmdTransport') +plt.axis("off") +plt.title("EmdTransport") plt.subplot(2, 3, 5) plt.imshow(Image_sinkhorn) -plt.axis('off') -plt.title('SinkhornTransport') +plt.axis("off") +plt.title("SinkhornTransport") plt.subplot(2, 3, 3) plt.imshow(Image_mapping_linear) -plt.axis('off') -plt.title('MappingTransport (linear)') +plt.axis("off") +plt.title("MappingTransport (linear)") plt.subplot(2, 3, 6) plt.imshow(Image_mapping_gaussian) -plt.axis('off') -plt.title('MappingTransport (gaussian)') +plt.axis("off") +plt.title("MappingTransport (gaussian)") plt.tight_layout() plt.show() diff --git a/examples/domain-adaptation/plot_otda_semi_supervised.py b/examples/domain-adaptation/plot_otda_semi_supervised.py index 278c8dde0..454e67ec3 100644 --- a/examples/domain-adaptation/plot_otda_semi_supervised.py +++ b/examples/domain-adaptation/plot_otda_semi_supervised.py @@ -5,7 +5,7 @@ ============================================ This example introduces a semi supervised domain adaptation in a 2D setting. -It explicits the problem of semi supervised domain adaptation and introduces +It explicit the problem of semi supervised domain adaptation and introduces some optimal transport approaches to solve it. Quantities such as optimal couplings, greater coupling coefficients and @@ -31,8 +31,8 @@ n_samples_source = 150 n_samples_target = 150 -Xs, ys = ot.datasets.make_data_classif('3gauss', n_samples_source) -Xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples_target) +Xs, ys = ot.datasets.make_data_classif("3gauss", n_samples_source) +Xt, yt = ot.datasets.make_data_classif("3gauss2", n_samples_target) ############################################################################## @@ -69,30 +69,30 @@ pl.figure(1, figsize=(10, 10)) pl.subplot(2, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples") pl.xticks([]) pl.yticks([]) pl.legend(loc=0) -pl.title('Source samples') +pl.title("Source samples") pl.subplot(2, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples") pl.xticks([]) pl.yticks([]) pl.legend(loc=0) -pl.title('Target samples') +pl.title("Target samples") pl.subplot(2, 2, 3) -pl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest') +pl.imshow(ot_sinkhorn_un.cost_, interpolation="nearest") pl.xticks([]) pl.yticks([]) -pl.title('Cost matrix - unsupervised DA') +pl.title("Cost matrix - unsupervised DA") pl.subplot(2, 2, 4) -pl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest') +pl.imshow(ot_sinkhorn_semi.cost_, interpolation="nearest") pl.xticks([]) pl.yticks([]) -pl.title('Cost matrix - semi-supervised DA') +pl.title("Cost matrix - semi-supervised DA") pl.tight_layout() @@ -107,16 +107,16 @@ pl.figure(2, figsize=(8, 4)) pl.subplot(1, 2, 1) -pl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest') +pl.imshow(ot_sinkhorn_un.coupling_, interpolation="nearest") pl.xticks([]) pl.yticks([]) -pl.title('Optimal coupling\nUnsupervised DA') +pl.title("Optimal coupling\nUnsupervised DA") pl.subplot(1, 2, 2) -pl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest') +pl.imshow(ot_sinkhorn_semi.coupling_, interpolation="nearest") pl.xticks([]) pl.yticks([]) -pl.title('Optimal coupling\nSemi-supervised DA') +pl.title("Optimal coupling\nSemi-supervised DA") pl.tight_layout() @@ -128,21 +128,31 @@ # display transported samples 
pl.figure(4, figsize=(8, 4)) pl.subplot(1, 2, 1) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.5) -pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys, - marker='+', label='Transp samples', s=30) -pl.title('Transported samples\nEmdTransport') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5) +pl.scatter( + transp_Xs_sinkhorn_un[:, 0], + transp_Xs_sinkhorn_un[:, 1], + c=ys, + marker="+", + label="Transp samples", + s=30, +) +pl.title("Transported samples\nEmdTransport") pl.legend(loc=0) pl.xticks([]) pl.yticks([]) pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.5) -pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys, - marker='+', label='Transp samples', s=30) -pl.title('Transported samples\nSinkhornTransport') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5) +pl.scatter( + transp_Xs_sinkhorn_semi[:, 0], + transp_Xs_sinkhorn_semi[:, 1], + c=ys, + marker="+", + label="Transp samples", + s=30, +) +pl.title("Transported samples\nSinkhornTransport") pl.xticks([]) pl.yticks([]) diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py index a0cc1852d..865c1e71a 100644 --- a/examples/gromov/plot_barycenter_fgw.py +++ b/examples/gromov/plot_barycenter_fgw.py @@ -20,7 +20,7 @@ # # License: MIT License -#%% load libraries +# %% load libraries import numpy as np import matplotlib.pyplot as plt import networkx as nx @@ -29,11 +29,11 @@ import matplotlib.colors as mcol from matplotlib import cm from ot.gromov import fgw_barycenters -#%% Graph functions +# %% Graph functions def find_thresh(C, inf=0.5, sup=3, step=10): - """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected + """Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected The threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested. The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjacency matrix and the original matrix. @@ -52,14 +52,14 @@ def find_thresh(C, inf=0.5, sup=3, step=10): search = np.linspace(inf, sup, step) for thresh in search: Cprime = sp_to_adjacency(C, 0, thresh) - SC = shortest_path(Cprime, method='D') - SC[SC == float('inf')] = 100 + SC = shortest_path(Cprime, method="D") + SC[SC == float("inf")] = 100 dist.append(np.linalg.norm(SC - C)) return search[np.argmin(dist)], dist def sp_to_adjacency(C, threshinf=0.2, threshsup=1.8): - """ Thresholds the structure matrix in order to compute an adjacency matrix. + """Thresholds the structure matrix in order to compute an adjacency matrix. All values between threshinf and threshsup are considered representing connected nodes and set to 1. 
Else are set to 0 Parameters ---------- @@ -84,9 +84,10 @@ def sp_to_adjacency(C, threshinf=0.2, threshsup=1.8): return C -def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None): - """ Create a noisy circular graph - """ +def build_noisy_circular_graph( + N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None +): + """Create a noisy circular graph""" g = nx.Graph() g.add_nodes_from(list(range(N))) for i in range(N): @@ -116,21 +117,22 @@ def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structur def graph_colors(nx_graph, vmin=0, vmax=7): cnorm = mcol.Normalize(vmin=vmin, vmax=vmax) - cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis') + cpick = cm.ScalarMappable(norm=cnorm, cmap="viridis") cpick.set_array([]) val_map = {} - for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items(): + for k, v in nx.get_node_attributes(nx_graph, "attr_name").items(): val_map[k] = cpick.to_rgba(v) colors = [] for node in nx_graph.nodes(): colors.append(val_map[node]) return colors + ############################################################################## # Generate data # ------------- -#%% circular dataset +# %% circular dataset # We build a dataset of noisy circular graphs. # Noise is added on the structures by random connections and on the features by gaussian noise. @@ -138,32 +140,47 @@ def graph_colors(nx_graph, vmin=0, vmax=7): np.random.seed(30) X0 = [] for k in range(9): - X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3)) + X0.append( + build_noisy_circular_graph( + np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3 + ) + ) ############################################################################## # Plot data # --------- -#%% Plot graphs +# %% Plot graphs plt.figure(figsize=(8, 10)) for i in range(len(X0)): plt.subplot(3, 3, i + 1) g = X0[i] pos = nx.kamada_kawai_layout(g) - nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100) -plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20) + nx.draw( + g, + pos=pos, + node_color=graph_colors(g, vmin=-1, vmax=1), + with_labels=False, + node_size=100, + ) +plt.suptitle("Dataset of noisy graphs. Color indicates the label", fontsize=20) plt.show() ############################################################################## # Barycenter computation # ---------------------- -#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph +# %% We compute the barycenter using FGW. 
Structure matrices are computed using the shortest_path distance in the graph # Features distances are the euclidean distances Cs = [shortest_path(nx.adjacency_matrix(x).todense()) for x in X0] ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] -Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0] +Ys = [ + np.array([v for (k, v) in nx.get_node_attributes(x, "attr_name").items()]).reshape( + -1, 1 + ) + for x in X0 +] lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel() sizebary = 15 # we choose a barycenter with 15 nodes @@ -173,13 +190,17 @@ def graph_colors(nx_graph, vmin=0, vmax=7): # Plot Barycenter # ------------------------- -#%% Create the barycenter -bary = nx.from_numpy_array(sp_to_adjacency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) +# %% Create the barycenter +bary = nx.from_numpy_array( + sp_to_adjacency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]) +) for i, v in enumerate(A.ravel()): bary.add_node(i, attr_name=v) -#%% +# %% pos = nx.kamada_kawai_layout(bary) -nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False) -plt.suptitle('Barycenter', fontsize=20) +nx.draw( + bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False +) +plt.suptitle("Barycenter", fontsize=20) plt.show() diff --git a/examples/gromov/plot_entropic_semirelaxed_fgw.py b/examples/gromov/plot_entropic_semirelaxed_fgw.py index 642baeae6..335fa7572 100644 --- a/examples/gromov/plot_entropic_semirelaxed_fgw.py +++ b/examples/gromov/plot_entropic_semirelaxed_fgw.py @@ -27,7 +27,12 @@ import numpy as np import matplotlib.pylab as pl -from ot.gromov import entropic_semirelaxed_gromov_wasserstein, entropic_semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein +from ot.gromov import ( + entropic_semirelaxed_gromov_wasserstein, + entropic_semirelaxed_fused_gromov_wasserstein, + gromov_wasserstein, + fused_gromov_wasserstein, +) import networkx from networkx.generators.community import stochastic_block_model as sbm @@ -39,11 +44,8 @@ N2 = 20 # 2 communities N3 = 30 # 3 communities -p2 = [[1., 0.1], - [0.1, 0.9]] -p3 = [[1., 0.1, 0.], - [0.1, 0.95, 0.1], - [0., 0.1, 0.9]] +p2 = [[1.0, 0.1], [0.1, 0.9]] +p3 = [[1.0, 0.1, 0.0], [0.1, 0.95, 0.1], [0.0, 0.1, 0.9]] G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2) G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3) @@ -57,11 +59,11 @@ # Add weights on the edges for visualization later on weight_intra_G2 = 5 weight_inter_G2 = 0.5 -weight_intra_G3 = 1. 
+weight_intra_G3 = 1.0 weight_inter_G3 = 1.5 weightedG2 = networkx.Graph() -part_G2 = [G2.nodes[i]['block'] for i in range(N2)] +part_G2 = [G2.nodes[i]["block"] for i in range(N2)] for node in G2.nodes(): weightedG2.add_node(node) @@ -72,7 +74,7 @@ weightedG2.add_edge(i, j, weight=weight_inter_G2) weightedG3 = networkx.Graph() -part_G3 = [G3.nodes[i]['block'] for i in range(N3)] +part_G3 = [G3.nodes[i]["block"] for i in range(N3)] for node in G3.nodes(): weightedG3.add_node(node) @@ -89,22 +91,24 @@ # 0) GW(C2, h2, C3, h3) for reference OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True) -gw = log['gw_dist'] +gw = log["gw_dist"] # 1) srGW_e(C2, h2, C3) OT_23, log_23 = entropic_semirelaxed_gromov_wasserstein( - C2, C3, h2, symmetric=True, epsilon=1., G0=None, log=True) -srgw_23 = log_23['srgw_dist'] + C2, C3, h2, symmetric=True, epsilon=1.0, G0=None, log=True +) +srgw_23 = log_23["srgw_dist"] # 2) srGW_e(C3, h3, C2) OT_32, log_32 = entropic_semirelaxed_gromov_wasserstein( - C3, C2, h3, symmetric=None, epsilon=1., G0=None, log=True) -srgw_32 = log_32['srgw_dist'] + C3, C2, h3, symmetric=None, epsilon=1.0, G0=None, log=True +) +srgw_32 = log_32["srgw_dist"] -print('GW(C2, C3) = ', gw) -print('srGW_e(C2, h2, C3) = ', srgw_23) -print('srGW_e(C3, h3, C2) = ', srgw_32) +print("GW(C2, C3) = ", gw) +print("srGW_e(C2, h2, C3) = ", srgw_23) +print("srGW_e(C3, h3, C2) = ", srgw_32) ############################################################################# @@ -118,12 +122,19 @@ # sent, adding a minimal intensity of 0.1 if mass sent is not zero. -def draw_graph(G, C, nodes_color_part, Gweights=None, - pos=None, edge_color='black', node_size=None, - shiftx=0, seed=0): - - if (pos is None): - pos = networkx.spring_layout(G, scale=1., seed=seed) +def draw_graph( + G, + C, + nodes_color_part, + Gweights=None, + pos=None, + edge_color="black", + node_size=None, + shiftx=0, + seed=0, +): + if pos is None: + pos = networkx.spring_layout(G, scale=1.0, seed=seed) if shiftx != 0: for k, v in pos.items(): @@ -132,7 +143,9 @@ def draw_graph(G, C, nodes_color_part, Gweights=None, alpha_edge = 0.7 width_edge = 1.8 if Gweights is None: - networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color) + networkx.draw_networkx_edges( + G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color + ) else: # We make more visible connections between activated nodes n = len(Gweights) @@ -145,36 +158,69 @@ def draw_graph(G, C, nodes_color_part, Gweights=None, elif C[i, j] > 0: edgelist_deactivated.append((i, j)) - networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated, - width=width_edge, alpha=alpha_edge, - edge_color=edge_color) - networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated, - width=width_edge, alpha=0.1, - edge_color=edge_color) + networkx.draw_networkx_edges( + G, + pos, + edgelist=edgelist_activated, + width=width_edge, + alpha=alpha_edge, + edge_color=edge_color, + ) + networkx.draw_networkx_edges( + G, + pos, + edgelist=edgelist_deactivated, + width=width_edge, + alpha=0.1, + edge_color=edge_color, + ) if Gweights is None: for node, node_color in enumerate(nodes_color_part): - networkx.draw_networkx_nodes(G, pos, nodelist=[node], - node_size=node_size, alpha=1, - node_color=node_color) + networkx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + node_size=node_size, + alpha=1, + node_color=node_color, + ) else: scaled_Gweights = Gweights / (0.5 * Gweights.max()) nodes_size = node_size * scaled_Gweights for node, node_color 
in enumerate(nodes_color_part): - networkx.draw_networkx_nodes(G, pos, nodelist=[node], - node_size=nodes_size[node], alpha=1, - node_color=node_color) + networkx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + node_size=nodes_size[node], + alpha=1, + node_color=node_color, + ) return pos -def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, - p1, p2, T, pos1=None, pos2=None, - shiftx=4, switchx=False, node_size=70, - seed_G1=0, seed_G2=0): +def draw_transp_colored_srGW( + G1, + C1, + G2, + C2, + part_G1, + p1, + p2, + T, + pos1=None, + pos2=None, + shiftx=4, + switchx=False, + node_size=70, + seed_G1=0, + seed_G2=0, +): starting_color = 0 # get graphs partition and their coloring part1 = part_G1.copy() - unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)] + unique_colors = ["C%s" % (starting_color + i) for i in np.unique(part1)] nodes_color_part1 = [] for cluster in part1: nodes_color_part1.append(unique_colors[cluster]) @@ -184,18 +230,38 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, for i in range(len(G2.nodes())): j = np.argmax(T[:, i]) nodes_color_part2.append(nodes_color_part1[j]) - pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1, - pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1) - pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2, - node_size=node_size, shiftx=shiftx, seed=seed_G2) + pos1 = draw_graph( + G1, + C1, + nodes_color_part1, + Gweights=p1, + pos=pos1, + node_size=node_size, + shiftx=0, + seed=seed_G1, + ) + pos2 = draw_graph( + G2, + C2, + nodes_color_part2, + Gweights=p2, + pos=pos2, + node_size=node_size, + shiftx=shiftx, + seed=seed_G2, + ) for k1, v1 in pos1.items(): max_Tk1 = np.max(T[k1, :]) for k2, v2 in pos2.items(): - if (T[k1, k2] > 0): - pl.plot([pos1[k1][0], pos2[k2][0]], - [pos1[k1][1], pos2[k2][1]], - '-', lw=0.6, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.), - color=nodes_color_part1[k1]) + if T[k1, k2] > 0: + pl.plot( + [pos1[k1][0], pos2[k2][0]], + [pos1[k1][1], pos2[k2][1]], + "-", + lw=0.6, + alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.0), + color=nodes_color_part1[k1], + ) return pos1, pos2 @@ -207,21 +273,51 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, pl.figure(1, figsize=(8, 2.5)) pl.clf() pl.subplot(121) -pl.axis('off') +pl.axis("off") pl.axis -pl.title(r'$srGW_e(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize) +pl.title( + r"$srGW_e(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$" % (np.round(srgw_23, 3)), + fontsize=fontsize, +) hbar2 = OT_23.sum(axis=0) pos1, pos2 = draw_transp_colored_srGW( - weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, - shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) + weightedG2, + C2, + weightedG3, + C3, + part_G2, + p1=None, + p2=hbar2, + T=OT_23, + shiftx=1.5, + node_size=node_size, + seed_G1=seed_G2, + seed_G2=seed_G3, +) pl.subplot(122) -pl.axis('off') +pl.axis("off") hbar3 = OT_32.sum(axis=0) -pl.title(r'$srGW_e(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize) +pl.title( + r"$srGW_e(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$" % (np.round(srgw_32, 3)), + fontsize=fontsize, +) pos1, pos2 = draw_transp_colored_srGW( - weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, - pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) + weightedG3, + C3, + weightedG2, + C2, + part_G3, + p1=None, + p2=hbar3, + T=OT_32, + pos1=pos2, + pos2=pos1, + shiftx=3.0, + node_size=node_size, + seed_G1=0, + 
seed_G2=0, +) pl.tight_layout() pl.show() @@ -240,7 +336,7 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, F3 = np.zeros((N3, 1)) for i, c in enumerate(part_G3): - F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01) + F3[i, 0] = np.random.normal(loc=2.0 - c, scale=0.01) ############################################################################# # @@ -249,28 +345,31 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, alpha = 0.5 # Compute pairwise euclidean distance between node features -M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T) +M = (F2**2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3**2).T) - 2 * F2.dot(F3.T) # 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference OT, log = fused_gromov_wasserstein( - M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True) -fgw = log['fgw_dist'] + M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True +) +fgw = log["fgw_dist"] # 1) srFGW_e(C2, F2, h2, C3, F3) OT_23, log_23 = entropic_semirelaxed_fused_gromov_wasserstein( - M, C2, C3, h2, symmetric=True, epsilon=1., alpha=0.5, log=True, G0=None) -srfgw_23 = log_23['srfgw_dist'] + M, C2, C3, h2, symmetric=True, epsilon=1.0, alpha=0.5, log=True, G0=None +) +srfgw_23 = log_23["srfgw_dist"] # 2) srFGW(C3, F3, h3, C2, F2) OT_32, log_32 = entropic_semirelaxed_fused_gromov_wasserstein( - M.T, C3, C2, h3, symmetric=None, epsilon=1., alpha=alpha, log=True, G0=None) -srfgw_32 = log_32['srfgw_dist'] + M.T, C3, C2, h3, symmetric=None, epsilon=1.0, alpha=alpha, log=True, G0=None +) +srfgw_32 = log_32["srfgw_dist"] -print('FGW(C2, F2, C3, F3) = ', fgw) -print(r'$srGW_e$(C2, F2, h2, C3, F3) = ', srfgw_23) -print(r'$srGW_e$(C3, F3, h3, C2, F2) = ', srfgw_32) +print("FGW(C2, F2, C3, F3) = ", fgw) +print(r"$srGW_e$(C2, F2, h2, C3, F3) = ", srfgw_23) +print(r"$srGW_e$(C3, F3, h3, C2, F2) = ", srfgw_32) ############################################################################# # @@ -284,21 +383,53 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, pl.figure(2, figsize=(8, 2.5)) pl.clf() pl.subplot(121) -pl.axis('off') +pl.axis("off") pl.axis -pl.title(r'$srFGW_e(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize) +pl.title( + r"$srFGW_e(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$" + % (np.round(srfgw_23, 3)), + fontsize=fontsize, +) hbar2 = OT_23.sum(axis=0) pos1, pos2 = draw_transp_colored_srGW( - weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, - shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) + weightedG2, + C2, + weightedG3, + C3, + part_G2, + p1=None, + p2=hbar2, + T=OT_23, + shiftx=1.5, + node_size=node_size, + seed_G1=seed_G2, + seed_G2=seed_G3, +) pl.subplot(122) -pl.axis('off') +pl.axis("off") hbar3 = OT_32.sum(axis=0) -pl.title(r'$srFGW_e(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize) +pl.title( + r"$srFGW_e(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$" + % (np.round(srfgw_32, 3)), + fontsize=fontsize, +) pos1, pos2 = draw_transp_colored_srGW( - weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, - pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) + weightedG3, + C3, + weightedG2, + C2, + part_G3, + p1=None, + p2=hbar3, + T=OT_32, + pos1=pos2, + pos2=pos1, + shiftx=3.0, + node_size=node_size, + seed_G1=0, + seed_G2=0, +) pl.tight_layout() pl.show() diff --git 
a/examples/gromov/plot_fgw.py b/examples/gromov/plot_fgw.py index 68ecb13f6..1e4252ad5 100644 --- a/examples/gromov/plot_fgw.py +++ b/examples/gromov/plot_fgw.py @@ -41,11 +41,15 @@ phi = np.arange(n)[:, None] xs = phi + sig * np.random.randn(n, 1) -ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1) +ys = np.vstack( + (np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1))) +) + sig2 * np.random.randn(n, 1) phi2 = np.arange(n2)[:, None] xt = phi2 + sig * np.random.randn(n2, 1) -yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1) +yt = np.vstack( + (np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1))) +) + sig2 * np.random.randn(n2, 1) yt = yt[::-1, :] p = ot.unif(n) @@ -62,15 +66,15 @@ pl.subplot(2, 1, 1) pl.scatter(ys, xs, c=phi, s=70) -pl.ylabel('Feature value a', fontsize=20) -pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, y=1) +pl.ylabel("Feature value a", fontsize=20) +pl.title("$\mu=\sum_i \delta_{x_i,a_i}$", fontsize=25, y=1) pl.xticks(()) pl.yticks(()) pl.subplot(2, 1, 2) pl.scatter(yt, xt, c=phi2, s=70) -pl.xlabel('coordinates x/y', fontsize=25) -pl.ylabel('Feature value b', fontsize=20) -pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, y=1) +pl.xlabel("coordinates x/y", fontsize=25) +pl.ylabel("Feature value b", fontsize=20) +pl.title("$\\nu=\sum_j \delta_{y_j,b_j}$", fontsize=25, y=1) pl.yticks(()) pl.tight_layout() pl.show() @@ -91,7 +95,7 @@ # Plot matrices # ------------- -cmap = 'Reds' +cmap = "Reds" pl.figure(2, (5, 5)) fs = 15 @@ -101,7 +105,7 @@ ax1 = pl.subplot(gs[3:, :2]) -pl.imshow(C1, cmap=cmap, interpolation='nearest') +pl.imshow(C1, cmap=cmap, interpolation="nearest") pl.title("$C_1$", fontsize=fs) pl.xlabel("$k$", fontsize=fs) pl.ylabel("$i$", fontsize=fs) @@ -110,22 +114,22 @@ ax2 = pl.subplot(gs[:3, 2:]) -pl.imshow(C2, cmap=cmap, interpolation='nearest') +pl.imshow(C2, cmap=cmap, interpolation="nearest") pl.title("$C_2$", fontsize=fs) pl.ylabel("$l$", fontsize=fs) pl.xticks(()) pl.yticks(l_y) -ax2.set_aspect('auto') +ax2.set_aspect("auto") ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1) -pl.imshow(M, cmap=cmap, interpolation='nearest') +pl.imshow(M, cmap=cmap, interpolation="nearest") pl.yticks(l_x) pl.xticks(l_y) pl.ylabel("$i$", fontsize=fs) pl.title("$M_{AB}$", fontsize=fs) pl.xlabel("$j$", fontsize=fs) pl.tight_layout() -ax3.set_aspect('auto') +ax3.set_aspect("auto") pl.show() ############################################################################## @@ -136,35 +140,39 @@ alpha = 1e-3 ot.tic() -Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True) +Gwg, logw = fused_gromov_wasserstein( + M, C1, C2, p, q, loss_fun="square_loss", alpha=alpha, verbose=True, log=True +) ot.toc() # reload_ext WGW -Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True) +Gg, log = gromov_wasserstein( + C1, C2, p, q, loss_fun="square_loss", verbose=True, log=True +) ############################################################################## # Visualize transport matrices # ---------------------------- # visu OT matrix -cmap = 'Blues' +cmap = "Blues" fs = 15 pl.figure(3, (13, 5)) pl.clf() pl.subplot(1, 3, 1) -pl.imshow(Got, cmap=cmap, interpolation='nearest') +pl.imshow(Got, cmap=cmap, interpolation="nearest") pl.ylabel("$i$", fontsize=fs) pl.xticks(()) -pl.title('Wasserstein ($M$ only)') +pl.title("Wasserstein ($M$ only)") pl.subplot(1, 3, 2) -pl.imshow(Gg, cmap=cmap, 
interpolation='nearest') -pl.title('Gromov ($C_1,C_2$ only)') +pl.imshow(Gg, cmap=cmap, interpolation="nearest") +pl.title("Gromov ($C_1,C_2$ only)") pl.xticks(()) pl.subplot(1, 3, 3) -pl.imshow(Gwg, cmap=cmap, interpolation='nearest') -pl.title('FGW ($M+C_1,C_2$)') +pl.imshow(Gwg, cmap=cmap, interpolation="nearest") +pl.title("FGW ($M+C_1,C_2$)") pl.xlabel("$j$", fontsize=fs) pl.ylabel("$i$", fontsize=fs) diff --git a/examples/gromov/plot_fgw_solvers.py b/examples/gromov/plot_fgw_solvers.py index 75c12cca0..ab1ccac88 100644 --- a/examples/gromov/plot_fgw_solvers.py +++ b/examples/gromov/plot_fgw_solvers.py @@ -44,9 +44,11 @@ import numpy as np import matplotlib.pylab as pl -from ot.gromov import (fused_gromov_wasserstein, - entropic_fused_gromov_wasserstein, - BAPG_fused_gromov_wasserstein) +from ot.gromov import ( + fused_gromov_wasserstein, + entropic_fused_gromov_wasserstein, + BAPG_fused_gromov_wasserstein, +) import networkx from networkx.generators.community import stochastic_block_model as sbm from time import time @@ -59,15 +61,12 @@ N2 = 20 # 2 communities N3 = 30 # 3 communities -p2 = [[1., 0.1], - [0.1, 0.9]] -p3 = [[1., 0.1, 0.], - [0.1, 0.95, 0.1], - [0., 0.1, 0.9]] +p2 = [[1.0, 0.1], [0.1, 0.9]] +p3 = [[1.0, 0.1, 0.0], [0.1, 0.95, 0.1], [0.0, 0.1, 0.9]] G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2) G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3) -part_G2 = [G2.nodes[i]['block'] for i in range(N2)] -part_G3 = [G3.nodes[i]['block'] for i in range(N3)] +part_G2 = [G2.nodes[i]["block"] for i in range(N2)] +part_G3 = [G3.nodes[i]["block"] for i in range(N3)] C2 = networkx.to_numpy_array(G2) C3 = networkx.to_numpy_array(G3) @@ -82,10 +81,10 @@ F3 = np.zeros((N3, 1)) for i, c in enumerate(part_G3): - F3[i, 0] = np.random.normal(loc=2. 
- c, scale=0.01) + F3[i, 0] = np.random.normal(loc=2.0 - c, scale=0.01) # Compute pairwise euclidean distance between node features -M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T) +M = (F2**2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3**2).T) - 2 * F2.dot(F3.T) h2 = np.ones(C2.shape[0]) / C2.shape[0] h3 = np.ones(C3.shape[0]) / C3.shape[0] @@ -99,51 +98,100 @@ # Conditional Gradient algorithm -print('Conditional Gradient \n') +print("Conditional Gradient \n") start_cg = time() T_cg, log_cg = fused_gromov_wasserstein( - M, C2, C3, h2, h3, 'square_loss', alpha=alpha, tol_rel=1e-9, - verbose=True, log=True) + M, C2, C3, h2, h3, "square_loss", alpha=alpha, tol_rel=1e-9, verbose=True, log=True +) end_cg = time() time_cg = 1000 * (end_cg - start_cg) # Proximal Point algorithm with Kullback-Leibler as proximal operator -print('Proximal Point Algorithm \n') +print("Proximal Point Algorithm \n") start_ppa = time() T_ppa, log_ppa = entropic_fused_gromov_wasserstein( - M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1., solver='PPA', - tol=1e-9, log=True, verbose=True, warmstart=False, numItermax=10) + M, + C2, + C3, + h2, + h3, + "square_loss", + alpha=alpha, + epsilon=1.0, + solver="PPA", + tol=1e-9, + log=True, + verbose=True, + warmstart=False, + numItermax=10, +) end_ppa = time() time_ppa = 1000 * (end_ppa - start_ppa) # Projected Gradient algorithm with entropic regularization -print('Projected Gradient Descent \n') +print("Projected Gradient Descent \n") start_pgd = time() T_pgd, log_pgd = entropic_fused_gromov_wasserstein( - M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=0.01, solver='PGD', - tol=1e-9, log=True, verbose=True, warmstart=False, numItermax=10) + M, + C2, + C3, + h2, + h3, + "square_loss", + alpha=alpha, + epsilon=0.01, + solver="PGD", + tol=1e-9, + log=True, + verbose=True, + warmstart=False, + numItermax=10, +) end_pgd = time() time_pgd = 1000 * (end_pgd - start_pgd) # Alternated Bregman Projected Gradient algorithm with Kullback-Leibler as proximal operator -print('Bregman Alternated Projected Gradient \n') +print("Bregman Alternated Projected Gradient \n") start_bapg = time() T_bapg, log_bapg = BAPG_fused_gromov_wasserstein( - M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1., - tol=1e-9, marginal_loss=True, verbose=True, log=True) + M, + C2, + C3, + h2, + h3, + "square_loss", + alpha=alpha, + epsilon=1.0, + tol=1e-9, + marginal_loss=True, + verbose=True, + log=True, +) end_bapg = time() time_bapg = 1000 * (end_bapg - start_bapg) -print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log_cg['fgw_dist'])) -print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log_ppa['fgw_dist'])) -print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(log_pgd['fgw_dist'])) -print('Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(log_bapg['fgw_dist'])) +print( + "Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: " + + str(log_cg["fgw_dist"]) +) +print( + "Fused Gromov-Wasserstein distance estimated with Proximal Point solver: " + + str(log_ppa["fgw_dist"]) +) +print( + "Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: " + + str(log_pgd["fgw_dist"]) +) +print( + "Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: " + + str(log_bapg["fgw_dist"]) +) # compute OT sparsity level 
-T_cg_sparsity = 100 * (T_cg == 0.).astype(np.float64).sum() / (N2 * N3) -T_ppa_sparsity = 100 * (T_ppa == 0.).astype(np.float64).sum() / (N2 * N3) -T_pgd_sparsity = 100 * (T_pgd == 0.).astype(np.float64).sum() / (N2 * N3) -T_bapg_sparsity = 100 * (T_bapg == 0.).astype(np.float64).sum() / (N2 * N3) +T_cg_sparsity = 100 * (T_cg == 0.0).astype(np.float64).sum() / (N2 * N3) +T_ppa_sparsity = 100 * (T_ppa == 0.0).astype(np.float64).sum() / (N2 * N3) +T_pgd_sparsity = 100 * (T_pgd == 0.0).astype(np.float64).sum() / (N2 * N3) +T_bapg_sparsity = 100 * (T_bapg == 0.0).astype(np.float64).sum() / (N2 * N3) # Methods using Sinkhorn/Bregman projections tend to produce feasibility errors on the # marginal constraints @@ -169,11 +217,11 @@ # Add weights on the edges for visualization later on weight_intra_G2 = 5 weight_inter_G2 = 0.5 -weight_intra_G3 = 1. +weight_intra_G3 = 1.0 weight_inter_G3 = 1.5 weightedG2 = networkx.Graph() -part_G2 = [G2.nodes[i]['block'] for i in range(N2)] +part_G2 = [G2.nodes[i]["block"] for i in range(N2)] for node in G2.nodes(): weightedG2.add_node(node) @@ -184,7 +232,7 @@ weightedG2.add_edge(i, j, weight=weight_inter_G2) weightedG3 = networkx.Graph() -part_G3 = [G3.nodes[i]['block'] for i in range(N3)] +part_G3 = [G3.nodes[i]["block"] for i in range(N3)] for node in G3.nodes(): weightedG3.add_node(node) @@ -195,12 +243,19 @@ weightedG3.add_edge(i, j, weight=weight_inter_G3) -def draw_graph(G, C, nodes_color_part, Gweights=None, - pos=None, edge_color='black', node_size=None, - shiftx=0, seed=0): - - if (pos is None): - pos = networkx.spring_layout(G, scale=1., seed=seed) +def draw_graph( + G, + C, + nodes_color_part, + Gweights=None, + pos=None, + edge_color="black", + node_size=None, + shiftx=0, + seed=0, +): + if pos is None: + pos = networkx.spring_layout(G, scale=1.0, seed=seed) if shiftx != 0: for k, v in pos.items(): @@ -209,7 +264,9 @@ def draw_graph(G, C, nodes_color_part, Gweights=None, alpha_edge = 0.7 width_edge = 1.8 if Gweights is None: - networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color) + networkx.draw_networkx_edges( + G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color + ) else: # We make more visible connections between activated nodes n = len(Gweights) @@ -222,35 +279,69 @@ def draw_graph(G, C, nodes_color_part, Gweights=None, elif C[i, j] > 0: edgelist_deactivated.append((i, j)) - networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated, - width=width_edge, alpha=alpha_edge, - edge_color=edge_color) - networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated, - width=width_edge, alpha=0.1, - edge_color=edge_color) + networkx.draw_networkx_edges( + G, + pos, + edgelist=edgelist_activated, + width=width_edge, + alpha=alpha_edge, + edge_color=edge_color, + ) + networkx.draw_networkx_edges( + G, + pos, + edgelist=edgelist_deactivated, + width=width_edge, + alpha=0.1, + edge_color=edge_color, + ) if Gweights is None: for node, node_color in enumerate(nodes_color_part): - networkx.draw_networkx_nodes(G, pos, nodelist=[node], - node_size=node_size, alpha=1, - node_color=node_color) + networkx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + node_size=node_size, + alpha=1, + node_color=node_color, + ) else: scaled_Gweights = Gweights / (0.5 * Gweights.max()) nodes_size = node_size * scaled_Gweights for node, node_color in enumerate(nodes_color_part): - networkx.draw_networkx_nodes(G, pos, nodelist=[node], - node_size=nodes_size[node], alpha=1, - node_color=node_color) + 
networkx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + node_size=nodes_size[node], + alpha=1, + node_color=node_color, + ) return pos -def draw_transp_colored_GW(G1, C1, G2, C2, part_G1, p1, p2, T, - pos1=None, pos2=None, shiftx=4, switchx=False, - node_size=70, seed_G1=0, seed_G2=0): +def draw_transp_colored_GW( + G1, + C1, + G2, + C2, + part_G1, + p1, + p2, + T, + pos1=None, + pos2=None, + shiftx=4, + switchx=False, + node_size=70, + seed_G1=0, + seed_G2=0, +): starting_color = 0 # get graphs partition and their coloring part1 = part_G1.copy() - unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)] + unique_colors = ["C%s" % (starting_color + i) for i in np.unique(part1)] nodes_color_part1 = [] for cluster in part1: nodes_color_part1.append(unique_colors[cluster]) @@ -260,19 +351,39 @@ def draw_transp_colored_GW(G1, C1, G2, C2, part_G1, p1, p2, T, for i in range(len(G2.nodes())): j = np.argmax(T[:, i]) nodes_color_part2.append(nodes_color_part1[j]) - pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1, - pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1) - pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2, - node_size=node_size, shiftx=shiftx, seed=seed_G2) + pos1 = draw_graph( + G1, + C1, + nodes_color_part1, + Gweights=p1, + pos=pos1, + node_size=node_size, + shiftx=0, + seed=seed_G1, + ) + pos2 = draw_graph( + G2, + C2, + nodes_color_part2, + Gweights=p2, + pos=pos2, + node_size=node_size, + shiftx=shiftx, + seed=seed_G2, + ) for k1, v1 in pos1.items(): max_Tk1 = np.max(T[k1, :]) for k2, v2 in pos2.items(): - if (T[k1, k2] > 0): - pl.plot([pos1[k1][0], pos2[k2][0]], - [pos1[k1][1], pos2[k2][1]], - '-', lw=0.7, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.), - color=nodes_color_part1[k1]) + if T[k1, k2] > 0: + pl.plot( + [pos1[k1][0], pos2[k2][0]], + [pos1[k1][1], pos2[k2][1]], + "-", + lw=0.7, + alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.0), + color=nodes_color_part1[k1], + ) return pos1, pos2 @@ -284,49 +395,127 @@ def draw_transp_colored_GW(G1, C1, G2, C2, part_G1, p1, p2, T, pl.figure(2, figsize=(15, 3.5)) pl.clf() pl.subplot(141) -pl.axis('off') - -pl.title('(CG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % ( - np.round(log_cg['fgw_dist'], 3), str(np.round(T_cg_sparsity, 2)) + ' %', - np.round(err_cg, 4), str(np.round(time_cg, 2)) + ' ms'), fontsize=fontsize) +pl.axis("off") + +pl.title( + "(CG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s" + % ( + np.round(log_cg["fgw_dist"], 3), + str(np.round(T_cg_sparsity, 2)) + " %", + np.round(err_cg, 4), + str(np.round(time_cg, 2)) + " ms", + ), + fontsize=fontsize, +) pos1, pos2 = draw_transp_colored_GW( - weightedG2, C2, weightedG3, C3, part_G2, p1=T_cg.sum(1), p2=T_cg.sum(0), - T=T_cg, shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) + weightedG2, + C2, + weightedG3, + C3, + part_G2, + p1=T_cg.sum(1), + p2=T_cg.sum(0), + T=T_cg, + shiftx=1.5, + node_size=node_size, + seed_G1=seed_G2, + seed_G2=seed_G3, +) pl.subplot(142) -pl.axis('off') - -pl.title('(PPA) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % ( - np.round(log_ppa['fgw_dist'], 3), str(np.round(T_ppa_sparsity, 2)) + ' %', - np.round(err_ppa, 4), str(np.round(time_ppa, 2)) + ' ms'), fontsize=fontsize) +pl.axis("off") + +pl.title( + "(PPA) FGW=%s\n \n OT sparsity = %s \n marg. 
error = %s \n runtime = %s" + % ( + np.round(log_ppa["fgw_dist"], 3), + str(np.round(T_ppa_sparsity, 2)) + " %", + np.round(err_ppa, 4), + str(np.round(time_ppa, 2)) + " ms", + ), + fontsize=fontsize, +) pos1, pos2 = draw_transp_colored_GW( - weightedG2, C2, weightedG3, C3, part_G2, p1=T_ppa.sum(1), p2=T_ppa.sum(0), - T=T_ppa, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0) + weightedG2, + C2, + weightedG3, + C3, + part_G2, + p1=T_ppa.sum(1), + p2=T_ppa.sum(0), + T=T_ppa, + pos1=pos1, + pos2=pos2, + shiftx=0.0, + node_size=node_size, + seed_G1=0, + seed_G2=0, +) pl.subplot(143) -pl.axis('off') - -pl.title('(PGD) Entropic FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % ( - np.round(log_pgd['fgw_dist'], 3), str(np.round(T_pgd_sparsity, 2)) + ' %', - np.round(err_pgd, 4), str(np.round(time_pgd, 2)) + ' ms'), fontsize=fontsize) +pl.axis("off") + +pl.title( + "(PGD) Entropic FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s" + % ( + np.round(log_pgd["fgw_dist"], 3), + str(np.round(T_pgd_sparsity, 2)) + " %", + np.round(err_pgd, 4), + str(np.round(time_pgd, 2)) + " ms", + ), + fontsize=fontsize, +) pos1, pos2 = draw_transp_colored_GW( - weightedG2, C2, weightedG3, C3, part_G2, p1=T_pgd.sum(1), p2=T_pgd.sum(0), - T=T_pgd, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0) + weightedG2, + C2, + weightedG3, + C3, + part_G2, + p1=T_pgd.sum(1), + p2=T_pgd.sum(0), + T=T_pgd, + pos1=pos1, + pos2=pos2, + shiftx=0.0, + node_size=node_size, + seed_G1=0, + seed_G2=0, +) pl.subplot(144) -pl.axis('off') - -pl.title('(BAPG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % ( - np.round(log_bapg['fgw_dist'], 3), str(np.round(T_bapg_sparsity, 2)) + ' %', - np.round(err_bapg, 4), str(np.round(time_bapg, 2)) + ' ms'), fontsize=fontsize) +pl.axis("off") + +pl.title( + "(BAPG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s" + % ( + np.round(log_bapg["fgw_dist"], 3), + str(np.round(T_bapg_sparsity, 2)) + " %", + np.round(err_bapg, 4), + str(np.round(time_bapg, 2)) + " ms", + ), + fontsize=fontsize, +) pos1, pos2 = draw_transp_colored_GW( - weightedG2, C2, weightedG3, C3, part_G2, p1=T_bapg.sum(1), p2=T_bapg.sum(0), - T=T_bapg, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0) + weightedG2, + C2, + weightedG3, + C3, + part_G2, + p1=T_bapg.sum(1), + p2=T_bapg.sum(0), + T=T_bapg, + pos1=pos1, + pos2=pos2, + shiftx=0.0, + node_size=node_size, + seed_G1=0, + seed_G2=0, +) pl.tight_layout() diff --git a/examples/gromov/plot_gnn_TFGW.py b/examples/gromov/plot_gnn_TFGW.py index de745031d..9ec27f47d 100644 --- a/examples/gromov/plot_gnn_TFGW.py +++ b/examples/gromov/plot_gnn_TFGW.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ ============================== -Graph classification with Tempate Based Fused Gromov Wasserstein +Graph classification with Template Based Fused Gromov Wasserstein ============================== This example first illustrates how to train a graph classification gnn based on the Template Fused Gromov Wasserstein layer as proposed in [52] . 
@@ -17,7 +17,7 @@ # sphinx_gallery_thumbnail_number = 1 -#%% +# %% import matplotlib.pyplot as pl import torch @@ -47,15 +47,15 @@ n_nodes = 10 n_node_classes = 2 -#edge probabilities for the SBMs +# edge probabilities for the SBMs P1 = [[0.8]] P2 = [[0.9, 0.1], [0.1, 0.9]] -#block sizes +# block sizes block_sizes1 = [n_nodes] block_sizes2 = [n_nodes // 2, n_nodes // 2] -#node features +# node features x1 = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) x1 = one_hot(x1, num_classes=n_node_classes) x1 = torch.reshape(x1, (n_nodes, n_node_classes)) @@ -64,19 +64,25 @@ x2 = one_hot(x2, num_classes=n_node_classes) x2 = torch.reshape(x2, (n_nodes, n_node_classes)) -graphs1 = [GraphData(x=x1, edge_index=sbm(block_sizes1, P1), y=torch.tensor([0])) for i in range(n_graphs)] -graphs2 = [GraphData(x=x2, edge_index=sbm(block_sizes2, P2), y=torch.tensor([1])) for i in range(n_graphs)] +graphs1 = [ + GraphData(x=x1, edge_index=sbm(block_sizes1, P1), y=torch.tensor([0])) + for i in range(n_graphs) +] +graphs2 = [ + GraphData(x=x2, edge_index=sbm(block_sizes2, P2), y=torch.tensor([1])) + for i in range(n_graphs) +] graphs = graphs1 + graphs2 -#split the data into train and test sets +# split the data into train and test sets train_graphs, test_graphs = random_split(graphs, [n_graphs, n_graphs]) train_loader = DataLoader(train_graphs, batch_size=10, shuffle=True) test_loader = DataLoader(test_graphs, batch_size=10, shuffle=False) -#%% +# %% ############################################################################## # Plot data @@ -89,24 +95,28 @@ pl.figure(0, figsize=(8, 2.5)) pl.clf() pl.subplot(121) -pl.axis('off') -pl.title('Graph of class 1', fontsize=fontsize) +pl.axis("off") +pl.title("Graph of class 1", fontsize=fontsize) G = to_networkx(graphs1[0], to_undirected=True) pos = nx.spring_layout(G, seed=0) nx.draw_networkx(G, pos, with_labels=False, node_color="tab:blue") pl.subplot(122) -pl.axis('off') -pl.title('Graph of class 2', fontsize=fontsize) +pl.axis("off") +pl.title("Graph of class 2", fontsize=fontsize) G = to_networkx(graphs2[0], to_undirected=True) pos = nx.spring_layout(G, seed=0) -nx.draw_networkx(G, pos, with_labels=False, nodelist=[0, 1, 2, 3, 4], node_color="tab:blue") -nx.draw_networkx(G, pos, with_labels=False, nodelist=[5, 6, 7, 8, 9], node_color="tab:red") +nx.draw_networkx( + G, pos, with_labels=False, nodelist=[0, 1, 2, 3, 4], node_color="tab:blue" +) +nx.draw_networkx( + G, pos, with_labels=False, nodelist=[5, 6, 7, 8, 9], node_color="tab:red" +) pl.tight_layout() pl.show() -#%% +# %% ############################################################################## # Pooling architecture using the TFGW layer @@ -118,7 +128,16 @@ class pooling_TFGW(nn.Module): Pooling architecture using the TFGW layer. """ - def __init__(self, n_features, n_templates, n_template_nodes, n_classes, n_hidden_layers, feature_init_mean=0., feature_init_std=1.): + def __init__( + self, + n_features, + n_templates, + n_template_nodes, + n_classes, + n_hidden_layers, + feature_init_mean=0.0, + feature_init_std=1.0, + ): """ Pooling architecture using the TFGW layer. 
""" @@ -131,7 +150,13 @@ def __init__(self, n_features, n_templates, n_template_nodes, n_classes, n_hidde self.conv = GCNConv(self.n_features, self.n_hidden_layers) - self.TFGW = TFGWPooling(self.n_hidden_layers, self.n_templates, self.n_template_nodes, feature_init_mean=feature_init_mean, feature_init_std=feature_init_std) + self.TFGW = TFGWPooling( + self.n_hidden_layers, + self.n_templates, + self.n_template_nodes, + feature_init_mean=feature_init_mean, + feature_init_std=feature_init_std, + ) self.linear = Linear(self.n_templates, n_classes) @@ -154,11 +179,19 @@ def forward(self, x, edge_index, batch=None): n_epochs = 25 -#store latent embeddings and classes for TSNE visualization +# store latent embeddings and classes for TSNE visualization embeddings_for_TSNE = [] classes = [] -model = pooling_TFGW(n_features=2, n_templates=2, n_template_nodes=2, n_classes=2, n_hidden_layers=2, feature_init_mean=0.5, feature_init_std=0.5) +model = pooling_TFGW( + n_features=2, + n_templates=2, + n_template_nodes=2, + n_classes=2, + n_hidden_layers=2, + feature_init_mean=0.5, + feature_init_std=0.5, +) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.0005) criterion = torch.nn.CrossEntropyLoss() @@ -167,7 +200,6 @@ def forward(self, x, edge_index, batch=None): all_loss = [] for epoch in range(n_epochs): - losses = [] accs = [] @@ -184,12 +216,14 @@ def forward(self, x, edge_index, batch=None): accs.append(train_acc) losses.append(loss.item()) - #store last classes and embeddings for TSNE visualization + # store last classes and embeddings for TSNE visualization if epoch == n_epochs - 1: embeddings_for_TSNE.append(latent_embedding) classes.append(data.y) - print(f'Epoch: {epoch:03d}, Loss: {torch.mean(torch.tensor(losses)):.4f},Train Accuracy: {torch.mean(torch.tensor(accs)):.4f}') + print( + f"Epoch: {epoch:03d}, Loss: {torch.mean(torch.tensor(losses)):.4f},Train Accuracy: {torch.mean(torch.tensor(accs)):.4f}" + ) all_accuracy.append(torch.mean(torch.tensor(accs))) all_loss.append(torch.mean(torch.tensor(losses))) @@ -199,18 +233,18 @@ def forward(self, x, edge_index, batch=None): pl.clf() pl.subplot(121) pl.plot(all_loss) -pl.xlabel('epochs') -pl.title('Loss') +pl.xlabel("epochs") +pl.title("Loss") pl.subplot(122) pl.plot(all_accuracy) -pl.xlabel('epochs') -pl.title('Accuracy') +pl.xlabel("epochs") +pl.title("Accuracy") pl.tight_layout() pl.show() -#Test +# Test test_accs = [] @@ -225,17 +259,21 @@ def forward(self, x, edge_index, batch=None): classes = torch.hstack(classes) -print(f'Test Accuracy: {torch.mean(torch.tensor(test_acc)):.4f}') +print(f"Test Accuracy: {torch.mean(torch.tensor(test_acc)):.4f}") -#%% +# %% ############################################################################## # TSNE visualization of graph classification # --------- -indices = torch.randint(2 * n_graphs, (60,)) # select a subset of embeddings for TSNE visualization +indices = torch.randint( + 2 * n_graphs, (60,) +) # select a subset of embeddings for TSNE visualization latent_embeddings = torch.vstack(embeddings_for_TSNE).detach().numpy()[indices, :] -TSNE_embeddings = TSNE(n_components=2, perplexity=20, random_state=1).fit_transform(latent_embeddings) +TSNE_embeddings = TSNE(n_components=2, perplexity=20, random_state=1).fit_transform( + latent_embeddings +) class_0 = classes[indices] == 0 class_1 = classes[indices] == 1 @@ -244,12 +282,22 @@ def forward(self, x, edge_index, batch=None): TSNE_embeddings_1 = TSNE_embeddings[class_1, :] pl.figure(2, figsize=(6, 2.5)) 
-pl.scatter(TSNE_embeddings_0[:, 0], TSNE_embeddings_0[:, 1], - alpha=0.5, marker='o', label='class 1') -pl.scatter(TSNE_embeddings_1[:, 0], TSNE_embeddings_1[:, 1], - alpha=0.5, marker='o', label='class 2') +pl.scatter( + TSNE_embeddings_0[:, 0], + TSNE_embeddings_0[:, 1], + alpha=0.5, + marker="o", + label="class 1", +) +pl.scatter( + TSNE_embeddings_1[:, 0], + TSNE_embeddings_1[:, 1], + alpha=0.5, + marker="o", + label="class 2", +) pl.legend() -pl.title('TSNE in the latent space after training') +pl.title("TSNE in the latent space after training") pl.show() diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py index 252267fd7..b376af642 100644 --- a/examples/gromov/plot_gromov.py +++ b/examples/gromov/plot_gromov.py @@ -74,9 +74,9 @@ fig = pl.figure(1) ax1 = fig.add_subplot(121) -ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -ax2 = fig.add_subplot(122, projection='3d') -ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r') +ax1.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +ax2 = fig.add_subplot(122, projection="3d") +ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color="r") pl.show() ############################################################################# @@ -94,11 +94,11 @@ pl.figure(2) pl.subplot(121) pl.imshow(C1) -pl.title('C1') +pl.title("C1") pl.subplot(122) pl.imshow(C2) -pl.title('C2') +pl.title("C2") pl.show() @@ -112,26 +112,36 @@ # Conditional Gradient algorithm gw0, log0 = ot.gromov.gromov_wasserstein( - C1, C2, p, q, 'square_loss', verbose=True, log=True) + C1, C2, p, q, "square_loss", verbose=True, log=True +) # Proximal Point algorithm with Kullback-Leibler as proximal operator gw, log = ot.gromov.entropic_gromov_wasserstein( - C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PPA', - log=True, verbose=True) + C1, C2, p, q, "square_loss", epsilon=5e-4, solver="PPA", log=True, verbose=True +) # Projected Gradient algorithm with entropic regularization gwe, loge = ot.gromov.entropic_gromov_wasserstein( - C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PGD', - log=True, verbose=True) - -print('Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['gw_dist'])) -print('Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['gw_dist'])) -print('Entropic Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['gw_dist'])) + C1, C2, p, q, "square_loss", epsilon=5e-4, solver="PGD", log=True, verbose=True +) + +print( + "Gromov-Wasserstein distance estimated with Conditional Gradient solver: " + + str(log0["gw_dist"]) +) +print( + "Gromov-Wasserstein distance estimated with Proximal Point solver: " + + str(log["gw_dist"]) +) +print( + "Entropic Gromov-Wasserstein distance estimated with Projected Gradient solver: " + + str(loge["gw_dist"]) +) # compute OT sparsity level -gw0_sparsity = 100 * (gw0 == 0.).astype(np.float64).sum() / (n_samples ** 2) -gw_sparsity = 100 * (gw == 0.).astype(np.float64).sum() / (n_samples ** 2) -gwe_sparsity = 100 * (gwe == 0.).astype(np.float64).sum() / (n_samples ** 2) +gw0_sparsity = 100 * (gw0 == 0.0).astype(np.float64).sum() / (n_samples**2) +gw_sparsity = 100 * (gw == 0.0).astype(np.float64).sum() / (n_samples**2) +gwe_sparsity = 100 * (gwe == 0.0).astype(np.float64).sum() / (n_samples**2) # Methods using Sinkhorn projections tend to produce feasibility errors on the # marginal constraints @@ -141,25 +151,43 @@ erre = np.linalg.norm(gwe.sum(1) - p) + np.linalg.norm(gwe.sum(0) - q) pl.figure(3, (10, 6)) -cmap = 'Blues' 
+cmap = "Blues" fontsize = 12 pl.subplot(131) pl.imshow(gw0, cmap=cmap) -pl.title('(CG algo) GW=%s \n \n OT sparsity=%s \n feasibility error=%s' % ( - np.round(log0['gw_dist'], 4), str(np.round(gw0_sparsity, 2)) + ' %', np.round(np.round(err0, 4))), - fontsize=fontsize) +pl.title( + "(CG algo) GW=%s \n \n OT sparsity=%s \n feasibility error=%s" + % ( + np.round(log0["gw_dist"], 4), + str(np.round(gw0_sparsity, 2)) + " %", + np.round(np.round(err0, 4)), + ), + fontsize=fontsize, +) pl.subplot(132) pl.imshow(gw, cmap=cmap) -pl.title('(PP algo) GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % ( - np.round(log['gw_dist'], 4), str(np.round(gw_sparsity, 2)) + ' %', np.round(err, 4)), - fontsize=fontsize) +pl.title( + "(PP algo) GW=%s \n \n OT sparsity=%s \nfeasibility error=%s" + % ( + np.round(log["gw_dist"], 4), + str(np.round(gw_sparsity, 2)) + " %", + np.round(err, 4), + ), + fontsize=fontsize, +) pl.subplot(133) pl.imshow(gwe, cmap=cmap) -pl.title('Entropic GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % ( - np.round(loge['gw_dist'], 4), str(np.round(gwe_sparsity, 2)) + ' %', np.round(erre, 4)), - fontsize=fontsize) +pl.title( + "Entropic GW=%s \n \n OT sparsity=%s \nfeasibility error=%s" + % ( + np.round(loge["gw_dist"], 4), + str(np.round(gwe_sparsity, 2)) + " %", + np.round(erre, 4), + ), + fontsize=fontsize, +) pl.tight_layout() pl.show() @@ -174,26 +202,30 @@ def loss(x, y): return np.abs(x - y) -pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100, - log=True) +pgw, plog = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, log=True +) -sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100, - log=True) +sgw, slog = ot.gromov.sampled_gromov_wasserstein( + C1, C2, p, q, loss, epsilon=0.1, max_iter=100, log=True +) -print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated'])) -print('Variance estimated: ' + str(plog['gw_dist_std'])) -print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated'])) -print('Variance estimated: ' + str(slog['gw_dist_std'])) +print( + "Pointwise Gromov-Wasserstein distance estimated: " + str(plog["gw_dist_estimated"]) +) +print("Variance estimated: " + str(plog["gw_dist_std"])) +print("Sampled Gromov-Wasserstein distance: " + str(slog["gw_dist_estimated"])) +print("Variance estimated: " + str(slog["gw_dist_std"])) pl.figure(4, (10, 5)) pl.subplot(121) pl.imshow(pgw.toarray(), cmap=cmap) -pl.title('Pointwise Gromov Wasserstein') +pl.title("Pointwise Gromov Wasserstein") pl.subplot(122) pl.imshow(sgw, cmap=cmap) -pl.title('Sampled Gromov Wasserstein') +pl.title("Sampled Gromov Wasserstein") pl.show() diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py index 1b9abbf7b..a06a40837 100755 --- a/examples/gromov/plot_gromov_barycenter.py +++ b/examples/gromov/plot_gromov_barycenter.py @@ -60,11 +60,8 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9): rng = np.random.RandomState(seed=3) mds = manifold.MDS( - dim, - max_iter=max_iter, - eps=1e-9, - dissimilarity='precomputed', - n_init=1) + dim, max_iter=max_iter, eps=1e-9, dissimilarity="precomputed", n_init=1 + ) pos = mds.fit(C).embedding_ nmds = manifold.MDS( @@ -73,7 +70,8 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9): eps=1e-9, dissimilarity="precomputed", random_state=rng, - n_init=1) + n_init=1, + ) npos = nmds.fit_transform(C, init=pos) return npos @@ -91,13 +89,15 @@ def im2mat(img): return 
img.reshape((img.shape[0] * img.shape[1], img.shape[2])) -this_file = os.path.realpath('__file__') -data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') +this_file = os.path.realpath("__file__") +data_path = os.path.join(Path(this_file).parent.parent.parent, "data") -square = plt.imread(os.path.join(data_path, 'square.png')).astype(np.float64)[:, :, 2] -cross = plt.imread(os.path.join(data_path, 'cross.png')).astype(np.float64)[:, :, 2] -triangle = plt.imread(os.path.join(data_path, 'triangle.png')).astype(np.float64)[:, :, 2] -star = plt.imread(os.path.join(data_path, 'star.png')).astype(np.float64)[:, :, 2] +square = plt.imread(os.path.join(data_path, "square.png")).astype(np.float64)[:, :, 2] +cross = plt.imread(os.path.join(data_path, "cross.png")).astype(np.float64)[:, :, 2] +triangle = plt.imread(os.path.join(data_path, "triangle.png")).astype(np.float64)[ + :, :, 2 +] +star = plt.imread(os.path.join(data_path, "star.png")).astype(np.float64)[:, :, 2] shapes = [square, cross, triangle, star] @@ -132,31 +132,55 @@ def im2mat(img): Ct01 = [0 for i in range(2)] for i in range(2): - Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]], - [ps[0], ps[1] - ], p, lambdast[i], 'square_loss', # 5e-4, - max_iter=100, tol=1e-3) + Ct01[i] = ot.gromov.gromov_barycenters( + n_samples, + [Cs[0], Cs[1]], + [ps[0], ps[1]], + p, + lambdast[i], + "square_loss", # 5e-4, + max_iter=100, + tol=1e-3, + ) Ct02 = [0 for i in range(2)] for i in range(2): - Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]], - [ps[0], ps[2] - ], p, lambdast[i], 'square_loss', # 5e-4, - max_iter=100, tol=1e-3) + Ct02[i] = ot.gromov.gromov_barycenters( + n_samples, + [Cs[0], Cs[2]], + [ps[0], ps[2]], + p, + lambdast[i], + "square_loss", # 5e-4, + max_iter=100, + tol=1e-3, + ) Ct13 = [0 for i in range(2)] for i in range(2): - Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]], - [ps[1], ps[3] - ], p, lambdast[i], 'square_loss', # 5e-4, - max_iter=100, tol=1e-3) + Ct13[i] = ot.gromov.gromov_barycenters( + n_samples, + [Cs[1], Cs[3]], + [ps[1], ps[3]], + p, + lambdast[i], + "square_loss", # 5e-4, + max_iter=100, + tol=1e-3, + ) Ct23 = [0 for i in range(2)] for i in range(2): - Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]], - [ps[2], ps[3] - ], p, lambdast[i], 'square_loss', # 5e-4, - max_iter=100, tol=1e-3) + Ct23[i] = ot.gromov.gromov_barycenters( + n_samples, + [Cs[2], Cs[3]], + [ps[2], ps[3]], + p, + lambdast[i], + "square_loss", # 5e-4, + max_iter=100, + tol=1e-3, + ) ############################################################################## @@ -192,59 +216,59 @@ def im2mat(img): ax1 = plt.subplot2grid((4, 4), (0, 0)) plt.xlim((-1, 1)) plt.ylim((-1, 1)) -ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r') +ax1.scatter(npos[0][:, 0], npos[0][:, 1], color="r") ax2 = plt.subplot2grid((4, 4), (0, 1)) plt.xlim((-1, 1)) plt.ylim((-1, 1)) -ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b') +ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color="b") ax3 = plt.subplot2grid((4, 4), (0, 2)) plt.xlim((-1, 1)) plt.ylim((-1, 1)) -ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b') +ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color="b") ax4 = plt.subplot2grid((4, 4), (0, 3)) plt.xlim((-1, 1)) plt.ylim((-1, 1)) -ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r') +ax4.scatter(npos[1][:, 0], npos[1][:, 1], color="r") ax5 = plt.subplot2grid((4, 4), (1, 0)) plt.xlim((-1, 1)) plt.ylim((-1, 1)) -ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], 
color='b') +ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color="b") ax6 = plt.subplot2grid((4, 4), (1, 3)) plt.xlim((-1, 1)) plt.ylim((-1, 1)) -ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b') +ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color="b") ax7 = plt.subplot2grid((4, 4), (2, 0)) plt.xlim((-1, 1)) plt.ylim((-1, 1)) -ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b') +ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color="b") ax8 = plt.subplot2grid((4, 4), (2, 3)) plt.xlim((-1, 1)) plt.ylim((-1, 1)) -ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b') +ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color="b") ax9 = plt.subplot2grid((4, 4), (3, 0)) plt.xlim((-1, 1)) plt.ylim((-1, 1)) -ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r') +ax9.scatter(npos[2][:, 0], npos[2][:, 1], color="r") ax10 = plt.subplot2grid((4, 4), (3, 1)) plt.xlim((-1, 1)) plt.ylim((-1, 1)) -ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b') +ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color="b") ax11 = plt.subplot2grid((4, 4), (3, 2)) plt.xlim((-1, 1)) plt.ylim((-1, 1)) -ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b') +ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color="b") ax12 = plt.subplot2grid((4, 4), (3, 3)) plt.xlim((-1, 1)) plt.ylim((-1, 1)) -ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r') +ax12.scatter(npos[3][:, 0], npos[3][:, 1], color="r") diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py index 8cccf8825..4b94c9d40 100755 --- a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py +++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py @@ -41,7 +41,12 @@ import numpy as np import matplotlib.pylab as pl from sklearn.manifold import MDS -from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_linear_unmixing, fused_gromov_wasserstein_dictionary_learning +from ot.gromov import ( + gromov_wasserstein_linear_unmixing, + gromov_wasserstein_dictionary_learning, + fused_gromov_wasserstein_linear_unmixing, + fused_gromov_wasserstein_dictionary_learning, +) import ot import networkx from networkx.generators.community import stochastic_block_model as sbm @@ -81,16 +86,23 @@ # Visualize samples -def plot_graph(x, C, binary=True, color='C0', s=None): + +def plot_graph(x, C, binary=True, color="C0", s=None): for j in range(C.shape[0]): for i in range(j): if binary: if C[i, j] > 0: - pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k') + pl.plot( + [x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color="k" + ) else: # connection intensity proportional to C[i,j] - pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k') + pl.plot( + [x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color="k" + ) - pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9) + pl.scatter( + x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors="k", cmap="tab10", vmax=9 + ) pl.figure(1, (12, 8)) @@ -98,14 +110,14 @@ def plot_graph(x, C, binary=True, color='C0', s=None): for idx_c, c in enumerate(clusters): C = dataset[(c - 1) * Nc] # sample with c clusters # get 2d position for nodes - x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + x = MDS(dissimilarity="precomputed", random_state=0).fit_transform(1 - C) pl.subplot(2, nlabels, c) - pl.title('(graph) sample from label ' + str(c), 
fontsize=14) - plot_graph(x, C, binary=True, color='C0', s=50.) + pl.title("(graph) sample from label " + str(c), fontsize=14) + plot_graph(x, C, binary=True, color="C0", s=50.0) pl.axis("off") pl.subplot(2, nlabels, nlabels + c) - pl.title('(matrix) sample from label %s \n' % c, fontsize=14) - pl.imshow(C, interpolation='nearest') + pl.title("(matrix) sample from label %s \n" % c, fontsize=14) + pl.imshow(C, interpolation="nearest") pl.axis("off") pl.tight_layout() pl.show() @@ -123,21 +135,34 @@ def plot_graph(x, C, binary=True, color='C0', s=None): nt = 6 # of 6 nodes each q = ot.unif(nt) -reg = 0. # regularization coefficient to promote sparsity of unmixings {w_s} +reg = 0.0 # regularization coefficient to promote sparsity of unmixings {w_s} Cdict_GW, log = gromov_wasserstein_dictionary_learning( - Cs=dataset, D=D, nt=nt, ps=ps, q=q, epochs=10, batch_size=16, - learning_rate=0.1, reg=reg, projection='nonnegative_symmetric', - tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300, - use_log=True, use_adam_optimizer=True, verbose=True + Cs=dataset, + D=D, + nt=nt, + ps=ps, + q=q, + epochs=10, + batch_size=16, + learning_rate=0.1, + reg=reg, + projection="nonnegative_symmetric", + tol_outer=10 ** (-5), + tol_inner=10 ** (-5), + max_iter_outer=30, + max_iter_inner=300, + use_log=True, + use_adam_optimizer=True, + verbose=True, ) # visualize loss evolution over epochs pl.figure(2, (4, 3)) pl.clf() -pl.title('loss evolution by epoch', fontsize=14) -pl.plot(log['loss_epochs']) -pl.xlabel('epochs', fontsize=12) -pl.ylabel('loss', fontsize=12) +pl.title("loss evolution by epoch", fontsize=14) +pl.plot(log["loss_epochs"]) +pl.xlabel("epochs", fontsize=12) +pl.ylabel("loss", fontsize=12) pl.tight_layout() pl.show() @@ -153,14 +178,14 @@ def plot_graph(x, C, binary=True, color='C0', s=None): pl.clf() for idx_atom, atom in enumerate(Cdict_GW): scaled_atom = (atom - atom.min()) / (atom.max() - atom.min()) - x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom) + x = MDS(dissimilarity="precomputed", random_state=0).fit_transform(1 - scaled_atom) pl.subplot(2, D, idx_atom + 1) - pl.title('(graph) atom ' + str(idx_atom + 1), fontsize=14) - plot_graph(x, atom / atom.max(), binary=False, color='C0', s=100.) 
+ pl.title("(graph) atom " + str(idx_atom + 1), fontsize=14) + plot_graph(x, atom / atom.max(), binary=False, color="C0", s=100.0) pl.axis("off") pl.subplot(2, D, D + idx_atom + 1) - pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14) - pl.imshow(scaled_atom, interpolation='nearest') + pl.title("(matrix) atom %s \n" % (idx_atom + 1), fontsize=14) + pl.imshow(scaled_atom, interpolation="nearest") pl.colorbar() pl.axis("off") pl.tight_layout() @@ -176,40 +201,62 @@ def plot_graph(x, C, binary=True, color='C0', s=None): for C in dataset: p = ot.unif(C.shape[0]) unmixing, Cembedded, OT, reconstruction_error = gromov_wasserstein_linear_unmixing( - C, Cdict_GW, p=p, q=q, reg=reg, - tol_outer=10**(-5), tol_inner=10**(-5), - max_iter_outer=30, max_iter_inner=300 + C, + Cdict_GW, + p=p, + q=q, + reg=reg, + tol_outer=10 ** (-5), + tol_inner=10 ** (-5), + max_iter_outer=30, + max_iter_inner=300, ) unmixings.append(unmixing) reconstruction_errors.append(reconstruction_error) unmixings = np.array(unmixings) -print('cumulated reconstruction error:', np.array(reconstruction_errors).sum()) +print("cumulated reconstruction error:", np.array(reconstruction_errors).sum()) # Compute the 2D representation of the unmixing living in the 2-simplex of probability unmixings2D = np.zeros(shape=(N, 2)) for i, w in enumerate(unmixings): - unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. - unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. -x = [0., 0.] -y = [1., 0.] -z = [0.5, np.sqrt(3) / 2.] + unmixings2D[i, 0] = (2.0 * w[1] + w[2]) / 2.0 + unmixings2D[i, 1] = (np.sqrt(3.0) * w[2]) / 2.0 +x = [0.0, 0.0] +y = [1.0, 0.0] +z = [0.5, np.sqrt(3) / 2.0] extremities = np.stack([x, y, z]) pl.figure(4, (4, 4)) pl.clf() -pl.title('Embedding space', fontsize=14) +pl.title("Embedding space", fontsize=14) for cluster in range(nlabels): start, end = Nc * cluster, Nc * (cluster + 1) if cluster == 0: - pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster') + pl.scatter( + unmixings2D[start:end, 0], + unmixings2D[start:end, 1], + c="C" + str(cluster), + marker="o", + s=40.0, + label="1 cluster", + ) else: - pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1)) -pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms') -pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) -pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) -pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) -pl.axis('off') + pl.scatter( + unmixings2D[start:end, 0], + unmixings2D[start:end, 1], + c="C" + str(cluster), + marker="o", + s=40.0, + label="%s clusters" % (cluster + 1), + ) +pl.scatter( + extremities[:, 0], extremities[:, 1], c="black", marker="x", s=80.0, label="atoms" +) +pl.plot([x[0], y[0]], [x[1], y[1]], color="black", linewidth=2.0) +pl.plot([x[0], z[0]], [x[1], z[1]], color="black", linewidth=2.0) +pl.plot([y[0], z[0]], [y[1], z[1]], color="black", linewidth=2.0) +pl.axis("off") pl.legend(fontsize=11) pl.tight_layout() pl.show() @@ -228,11 +275,11 @@ def plot_graph(x, C, binary=True, color='C0', s=None): n = dataset[i].shape[0] F = np.zeros((n, 3)) if i < Nc: # graph with 1 cluster - F[:, 0] = 1. + F[:, 0] = 1.0 elif i < 2 * Nc: # graph with 2 clusters - F[:, 1] = 1. + F[:, 1] = 1.0 else: # graph with 3 clusters - F[:, 2] = 1. 
+ F[:, 2] = 1.0 dataset_features.append(F) pl.figure(5, (12, 8)) @@ -240,16 +287,16 @@ def plot_graph(x, C, binary=True, color='C0', s=None): for idx_c, c in enumerate(clusters): C = dataset[(c - 1) * Nc] # sample with c clusters F = dataset_features[(c - 1) * Nc] - colors = ['C' + str(np.argmax(F[i])) for i in range(F.shape[0])] + colors = ["C" + str(np.argmax(F[i])) for i in range(F.shape[0])] # get 2d position for nodes - x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + x = MDS(dissimilarity="precomputed", random_state=0).fit_transform(1 - C) pl.subplot(2, nlabels, c) - pl.title('(graph) sample from label ' + str(c), fontsize=14) + pl.title("(graph) sample from label " + str(c), fontsize=14) plot_graph(x, C, binary=True, color=colors, s=50) pl.axis("off") pl.subplot(2, nlabels, nlabels + c) - pl.title('(matrix) sample from label %s \n' % c, fontsize=14) - pl.imshow(C, interpolation='nearest') + pl.title("(matrix) sample from label %s \n" % c, fontsize=14) + pl.imshow(C, interpolation="nearest") pl.axis("off") pl.tight_layout() pl.show() @@ -268,18 +315,34 @@ def plot_graph(x, C, binary=True, color='C0', s=None): Cdict_FGW, Ydict_FGW, log = fused_gromov_wasserstein_dictionary_learning( - Cs=dataset, Ys=dataset_features, D=D, nt=nt, ps=ps, q=q, alpha=alpha, - epochs=10, batch_size=16, learning_rate_C=0.1, learning_rate_Y=0.1, reg=reg, - tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300, - projection='nonnegative_symmetric', use_log=True, use_adam_optimizer=True, verbose=True + Cs=dataset, + Ys=dataset_features, + D=D, + nt=nt, + ps=ps, + q=q, + alpha=alpha, + epochs=10, + batch_size=16, + learning_rate_C=0.1, + learning_rate_Y=0.1, + reg=reg, + tol_outer=10 ** (-5), + tol_inner=10 ** (-5), + max_iter_outer=30, + max_iter_inner=300, + projection="nonnegative_symmetric", + use_log=True, + use_adam_optimizer=True, + verbose=True, ) # visualize loss evolution pl.figure(6, (4, 3)) pl.clf() -pl.title('loss evolution by epoch', fontsize=14) -pl.plot(log['loss_epochs']) -pl.xlabel('epochs', fontsize=12) -pl.ylabel('loss', fontsize=12) +pl.title("loss evolution by epoch", fontsize=14) +pl.plot(log["loss_epochs"]) +pl.xlabel("epochs", fontsize=12) +pl.ylabel("loss", fontsize=12) pl.tight_layout() pl.show() @@ -295,16 +358,16 @@ def plot_graph(x, C, binary=True, color='C0', s=None): for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)): scaled_atom = (Catom - Catom.min()) / (Catom.max() - Catom.min()) - #scaled_F = 2 * (Fatom - min_features) / (max_features - min_features) - colors = ['C%s' % np.argmax(Fatom[i]) for i in range(Fatom.shape[0])] - x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom) + # scaled_F = 2 * (Fatom - min_features) / (max_features - min_features) + colors = ["C%s" % np.argmax(Fatom[i]) for i in range(Fatom.shape[0])] + x = MDS(dissimilarity="precomputed", random_state=0).fit_transform(1 - scaled_atom) pl.subplot(2, D, idx_atom + 1) - pl.title('(attributed graph) atom ' + str(idx_atom + 1), fontsize=14) + pl.title("(attributed graph) atom " + str(idx_atom + 1), fontsize=14) plot_graph(x, Catom / Catom.max(), binary=False, color=colors, s=100) pl.axis("off") pl.subplot(2, D, D + idx_atom + 1) - pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14) - pl.imshow(scaled_atom, interpolation='nearest') + pl.title("(matrix) atom %s \n" % (idx_atom + 1), fontsize=14) + pl.imshow(scaled_atom, interpolation="nearest") pl.colorbar() pl.axis("off") pl.tight_layout() @@ -321,40 
+384,68 @@ def plot_graph(x, C, binary=True, color='C0', s=None): C = dataset[i] Y = dataset_features[i] p = ot.unif(C.shape[0]) - unmixing, Cembedded, Yembedded, OT, reconstruction_error = fused_gromov_wasserstein_linear_unmixing( - C, Y, Cdict_FGW, Ydict_FGW, p=p, q=q, alpha=alpha, - reg=reg, tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=30, max_iter_inner=300 + unmixing, Cembedded, Yembedded, OT, reconstruction_error = ( + fused_gromov_wasserstein_linear_unmixing( + C, + Y, + Cdict_FGW, + Ydict_FGW, + p=p, + q=q, + alpha=alpha, + reg=reg, + tol_outer=10 ** (-6), + tol_inner=10 ** (-6), + max_iter_outer=30, + max_iter_inner=300, + ) ) unmixings.append(unmixing) reconstruction_errors.append(reconstruction_error) unmixings = np.array(unmixings) -print('cumulated reconstruction error:', np.array(reconstruction_errors).sum()) +print("cumulated reconstruction error:", np.array(reconstruction_errors).sum()) # Visualize unmixings in the 2-simplex of probability unmixings2D = np.zeros(shape=(N, 2)) for i, w in enumerate(unmixings): - unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. - unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. -x = [0., 0.] -y = [1., 0.] -z = [0.5, np.sqrt(3) / 2.] + unmixings2D[i, 0] = (2.0 * w[1] + w[2]) / 2.0 + unmixings2D[i, 1] = (np.sqrt(3.0) * w[2]) / 2.0 +x = [0.0, 0.0] +y = [1.0, 0.0] +z = [0.5, np.sqrt(3) / 2.0] extremities = np.stack([x, y, z]) pl.figure(8, (4, 4)) pl.clf() -pl.title('Embedding space', fontsize=14) +pl.title("Embedding space", fontsize=14) for cluster in range(nlabels): start, end = Nc * cluster, Nc * (cluster + 1) if cluster == 0: - pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster') + pl.scatter( + unmixings2D[start:end, 0], + unmixings2D[start:end, 1], + c="C" + str(cluster), + marker="o", + s=40.0, + label="1 cluster", + ) else: - pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1)) - -pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms') -pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) -pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) -pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) 
-pl.axis('off') + pl.scatter( + unmixings2D[start:end, 0], + unmixings2D[start:end, 1], + c="C" + str(cluster), + marker="o", + s=40.0, + label="%s clusters" % (cluster + 1), + ) + +pl.scatter( + extremities[:, 0], extremities[:, 1], c="black", marker="x", s=80.0, label="atoms" +) +pl.plot([x[0], y[0]], [x[1], y[1]], color="black", linewidth=2.0) +pl.plot([x[0], z[0]], [x[1], z[1]], color="black", linewidth=2.0) +pl.plot([y[0], z[0]], [y[1], z[1]], color="black", linewidth=2.0) +pl.axis("off") pl.legend(fontsize=11) pl.tight_layout() pl.show() diff --git a/examples/gromov/plot_quantized_gromov_wasserstein.py b/examples/gromov/plot_quantized_gromov_wasserstein.py index 02d777c71..cdfbb3cd5 100644 --- a/examples/gromov/plot_quantized_gromov_wasserstein.py +++ b/examples/gromov/plot_quantized_gromov_wasserstein.py @@ -53,10 +53,14 @@ from scipy.sparse.csgraph import shortest_path from ot.gromov import ( - quantized_fused_gromov_wasserstein_partitioned, quantized_fused_gromov_wasserstein, - get_graph_partition, get_graph_representants, format_partitioned_graph, + quantized_fused_gromov_wasserstein_partitioned, + quantized_fused_gromov_wasserstein, + get_graph_partition, + get_graph_representants, + format_partitioned_graph, quantized_fused_gromov_wasserstein_samples, - get_partition_and_representants_samples) + get_partition_and_representants_samples, +) ############################################################################# # @@ -67,11 +71,8 @@ N1 = 30 # 2 communities N2 = 45 # 3 communities -p1 = [[0.8, 0.1], - [0.1, 0.7]] -p2 = [[0.8, 0.1, 0.], - [0.1, 0.75, 0.1], - [0., 0.1, 0.7]] +p1 = [[0.8, 0.1], [0.1, 0.7]] +p2 = [[0.8, 0.1, 0.0], [0.1, 0.75, 0.1], [0.0, 0.1, 0.7]] G1 = sbm(seed=0, sizes=[N1 // 2, N1 // 2], p=p1) G2 = sbm(seed=0, sizes=[N2 // 3, N2 // 3, N2 // 3], p=p2) @@ -88,11 +89,11 @@ # Add weights on the edges for visualization later on weight_intra_G1 = 5 weight_inter_G1 = 0.5 -weight_intra_G2 = 1. 
+weight_intra_G2 = 1.0 weight_inter_G2 = 1.5 weightedG1 = networkx.Graph() -part_G1 = [G1.nodes[i]['block'] for i in range(N1)] +part_G1 = [G1.nodes[i]["block"] for i in range(N1)] for node in G1.nodes(): weightedG1.add_node(node) @@ -103,7 +104,7 @@ weightedG1.add_edge(i, j, weight=weight_inter_G1) weightedG2 = networkx.Graph() -part_G2 = [G2.nodes[i]['block'] for i in range(N2)] +part_G2 = [G2.nodes[i]["block"] for i in range(N2)] for node in G2.nodes(): weightedG2.add_node(node) @@ -116,10 +117,10 @@ # setup for graph visualization -def node_coloring(part, starting_color=0): +def node_coloring(part, starting_color=0): # get graphs partition and their coloring - unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part)] + unique_colors = ["C%s" % (starting_color + i) for i in np.unique(part)] nodes_color_part = [] for cluster in part: nodes_color_part.append(unique_colors[cluster]) @@ -127,12 +128,22 @@ def node_coloring(part, starting_color=0): return nodes_color_part -def draw_graph(G, C, nodes_color_part, rep_indices, node_alphas=None, pos=None, - edge_color='black', alpha_edge=0.7, node_size=None, - shiftx=0, seed=0, highlight_rep=False): - - if (pos is None): - pos = networkx.spring_layout(G, scale=1., seed=seed) +def draw_graph( + G, + C, + nodes_color_part, + rep_indices, + node_alphas=None, + pos=None, + edge_color="black", + alpha_edge=0.7, + node_size=None, + shiftx=0, + seed=0, + highlight_rep=False, +): + if pos is None: + pos = networkx.spring_layout(G, scale=1.0, seed=seed) if shiftx != 0: for k, v in pos.items(): @@ -142,24 +153,35 @@ def draw_graph(G, C, nodes_color_part, rep_indices, node_alphas=None, pos=None, if not highlight_rep: networkx.draw_networkx_edges( - G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color) + G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color + ) else: for edge in G.edges: if (edge[0] in rep_indices) and (edge[1] in rep_indices): networkx.draw_networkx_edges( - G, pos, edgelist=[edge], width=width_edge, alpha=alpha_edge, - edge_color=edge_color) + G, + pos, + edgelist=[edge], + width=width_edge, + alpha=alpha_edge, + edge_color=edge_color, + ) else: networkx.draw_networkx_edges( - G, pos, edgelist=[edge], width=width_edge, alpha=0.2, - edge_color=edge_color) + G, + pos, + edgelist=[edge], + width=width_edge, + alpha=0.2, + edge_color=edge_color, + ) for node, node_color in enumerate(nodes_color_part): - local_node_shape, local_node_size = 'o', node_size + local_node_shape, local_node_size = "o", node_size if highlight_rep: if node in rep_indices: - local_node_shape, local_node_size = '*', 6 * node_size + local_node_shape, local_node_size = "*", 6 * node_size if node_alphas is None: alpha = 0.9 @@ -169,10 +191,15 @@ def draw_graph(G, C, nodes_color_part, rep_indices, node_alphas=None, pos=None, else: alpha = node_alphas[node] - networkx.draw_networkx_nodes(G, pos, nodelist=[node], alpha=alpha, - node_shape=local_node_shape, - node_size=local_node_size, - node_color=node_color) + networkx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + alpha=alpha, + node_shape=local_node_shape, + node_size=local_node_size, + node_color=node_color, + ) return pos @@ -188,16 +215,18 @@ def draw_graph(G, C, nodes_color_part, rep_indices, node_alphas=None, pos=None, # 1-a) Partition C1 and C2 in 2 and 3 clusters respectively using Louvain # algorithm from Networkx. Then encode these partitions via vectors of assignments. 
-part_method = 'louvain' -rep_method = 'pagerank' +part_method = "louvain" +rep_method = "pagerank" npart_1 = 2 # 2 clusters used to describe C1 npart_2 = 3 # 3 clusters used to describe C2 part1 = get_graph_partition( - C1, npart=npart_1, part_method=part_method, F=None, alpha=1.) + C1, npart=npart_1, part_method=part_method, F=None, alpha=1.0 +) part2 = get_graph_partition( - C2, npart=npart_2, part_method=part_method, F=None, alpha=1.) + C2, npart=npart_2, part_method=part_method, F=None, alpha=1.0 +) # 1-b) Select representant in each partition using the Pagerank algorithm # implementation from networkx. @@ -205,22 +234,33 @@ def draw_graph(G, C, nodes_color_part, rep_indices, node_alphas=None, pos=None, rep_indices1 = get_graph_representants(C1, part1, rep_method=rep_method) rep_indices2 = get_graph_representants(C2, part2, rep_method=rep_method) -# 1-c) Formate partitions such that: +# 1-c) Format partitions such that: # CR contains relations between representants in each space. # list_R contains relations between samples and representants within each partition. # list_h contains samples relative importance within each partition. CR1, list_R1, list_h1 = format_partitioned_graph( - spC1, h1, part1, rep_indices1, F=None, M=None, alpha=1.) + spC1, h1, part1, rep_indices1, F=None, M=None, alpha=1.0 +) CR2, list_R2, list_h2 = format_partitioned_graph( - spC2, h2, part2, rep_indices2, F=None, M=None, alpha=1.) + spC2, h2, part2, rep_indices2, F=None, M=None, alpha=1.0 +) # 1-d) call to partitioned quantized gromov-wasserstein solver OT_global_, OTs_local_, OT_, log_ = quantized_fused_gromov_wasserstein_partitioned( - CR1, CR2, list_R1, list_R2, list_h1, list_h2, MR=None, - alpha=1., build_OT=True, log=True) + CR1, + CR2, + list_R1, + list_R2, + list_h1, + list_h2, + MR=None, + alpha=1.0, + build_OT=True, + log=True, +) # Visualization of the graph pre-processing @@ -235,49 +275,69 @@ def draw_graph(G, C, nodes_color_part, rep_indices, node_alphas=None, pos=None, nodes_color_part1 = node_coloring(part1_, starting_color=0) -nodes_color_part2 = node_coloring(part2_, starting_color=np.unique(nodes_color_part1).shape[0]) +nodes_color_part2 = node_coloring( + part2_, starting_color=np.unique(nodes_color_part1).shape[0] +) pl.figure(1, figsize=(6, 5)) pl.clf() -pl.axis('off') +pl.axis("off") pl.subplot(2, 3, 1) -pl.title(r'Input graph: $\mathbf{spC_1}$', fontsize=fontsize) +pl.title(r"Input graph: $\mathbf{spC_1}$", fontsize=fontsize) pos1 = draw_graph( - G1, C1, ['C0' for _ in part1_], rep_indices1, node_size=node_size, seed=seed_G1) + G1, C1, ["C0" for _ in part1_], rep_indices1, node_size=node_size, seed=seed_G1 +) pl.subplot(2, 3, 2) -pl.title('Partitioning', fontsize=fontsize) +pl.title("Partitioning", fontsize=fontsize) _ = draw_graph( - G1, C1, nodes_color_part1, rep_indices1, pos=pos1, node_size=node_size, seed=seed_G1) + G1, C1, nodes_color_part1, rep_indices1, pos=pos1, node_size=node_size, seed=seed_G1 +) pl.subplot(2, 3, 3) -pl.title('Representant selection', fontsize=fontsize) +pl.title("Representant selection", fontsize=fontsize) _ = draw_graph( - G1, C1, nodes_color_part1, rep_indices1, pos=pos1, node_size=node_size, - seed=seed_G1, highlight_rep=True) + G1, + C1, + nodes_color_part1, + rep_indices1, + pos=pos1, + node_size=node_size, + seed=seed_G1, + highlight_rep=True, +) pl.subplot(2, 3, 4) -pl.title(r'Input graph: $\mathbf{spC_2}$', fontsize=fontsize) +pl.title(r"Input graph: $\mathbf{spC_2}$", fontsize=fontsize) pos2 = draw_graph( - G2, C2, ['C0' for _ in part2_], 
rep_indices2, node_size=node_size, seed=seed_G2) + G2, C2, ["C0" for _ in part2_], rep_indices2, node_size=node_size, seed=seed_G2 +) pl.subplot(2, 3, 5) -pl.title(r'Partitioning', fontsize=fontsize) +pl.title(r"Partitioning", fontsize=fontsize) _ = draw_graph( - G2, C2, nodes_color_part2, rep_indices2, pos=pos2, node_size=node_size, seed=seed_G2) + G2, C2, nodes_color_part2, rep_indices2, pos=pos2, node_size=node_size, seed=seed_G2 +) pl.subplot(2, 3, 6) -pl.title(r'Representant selection', fontsize=fontsize) +pl.title(r"Representant selection", fontsize=fontsize) _ = draw_graph( - G2, C2, nodes_color_part2, rep_indices2, pos=pos2, node_size=node_size, - seed=seed_G2, highlight_rep=True) + G2, + C2, + nodes_color_part2, + rep_indices2, + pos=pos2, + node_size=node_size, + seed=seed_G2, + highlight_rep=True, +) pl.tight_layout() ############################################################################# @@ -295,10 +355,23 @@ def draw_graph(G, C, nodes_color_part, rep_indices, node_alphas=None, pos=None, # no node features are considered on this synthetic dataset. Hence we simply # let F1, F2 = None and set alpha = 1. OT_global, OTs_local, OT, log = quantized_fused_gromov_wasserstein( - spC1, spC2, npart_1, npart_2, h1, h2, C1_aux=C1, C2_aux=C2, F1=None, F2=None, - alpha=1., part_method=part_method, rep_method=rep_method, log=True) - -qGW_dist = log['qFGW_dist'] + spC1, + spC2, + npart_1, + npart_2, + h1, + h2, + C1_aux=C1, + C2_aux=C2, + F1=None, + F2=None, + alpha=1.0, + part_method=part_method, + rep_method=rep_method, + log=True, +) + +qGW_dist = log["qFGW_dist"] ############################################################################# @@ -312,70 +385,139 @@ def draw_graph(G, C, nodes_color_part, rep_indices, node_alphas=None, pos=None, def draw_transp_colored_qGW( - G1, C1, G2, C2, part1, part2, rep_indices1, rep_indices2, T, - pos1=None, pos2=None, shiftx=4, switchx=False, node_size=70, - seed_G1=0, seed_G2=0, highlight_rep=False): + G1, + C1, + G2, + C2, + part1, + part2, + rep_indices1, + rep_indices2, + T, + pos1=None, + pos2=None, + shiftx=4, + switchx=False, + node_size=70, + seed_G1=0, + seed_G2=0, + highlight_rep=False, +): starting_color = 0 # get graphs partition and their coloring - unique_colors1 = ['C%s' % (starting_color + i) for i in np.unique(part1)] + unique_colors1 = ["C%s" % (starting_color + i) for i in np.unique(part1)] nodes_color_part1 = [] for cluster in part1: nodes_color_part1.append(unique_colors1[cluster]) starting_color = len(unique_colors1) + 1 - unique_colors2 = ['C%s' % (starting_color + i) for i in np.unique(part2)] + unique_colors2 = ["C%s" % (starting_color + i) for i in np.unique(part2)] nodes_color_part2 = [] for cluster in part2: nodes_color_part2.append(unique_colors2[cluster]) pos1 = draw_graph( - G1, C1, nodes_color_part1, rep_indices1, pos=pos1, node_size=node_size, - shiftx=0, seed=seed_G1, highlight_rep=highlight_rep) + G1, + C1, + nodes_color_part1, + rep_indices1, + pos=pos1, + node_size=node_size, + shiftx=0, + seed=seed_G1, + highlight_rep=highlight_rep, + ) pos2 = draw_graph( - G2, C2, nodes_color_part2, rep_indices2, pos=pos2, node_size=node_size, - shiftx=shiftx, seed=seed_G1, highlight_rep=highlight_rep) + G2, + C2, + nodes_color_part2, + rep_indices2, + pos=pos2, + node_size=node_size, + shiftx=shiftx, + seed=seed_G1, + highlight_rep=highlight_rep, + ) if not highlight_rep: for k1, v1 in pos1.items(): max_Tk1 = np.max(T[k1, :]) for k2, v2 in pos2.items(): - if (T[k1, k2] > 0): - pl.plot([pos1[k1][0], pos2[k2][0]], - 
[pos1[k1][1], pos2[k2][1]], - '-', lw=0.7, alpha=T[k1, k2] / max_Tk1, - color=nodes_color_part1[k1]) + if T[k1, k2] > 0: + pl.plot( + [pos1[k1][0], pos2[k2][0]], + [pos1[k1][1], pos2[k2][1]], + "-", + lw=0.7, + alpha=T[k1, k2] / max_Tk1, + color=nodes_color_part1[k1], + ) else: # OT is only between representants for id1, node_id1 in enumerate(rep_indices1): max_Tk1 = np.max(T[id1, :]) for id2, node_id2 in enumerate(rep_indices2): - if (T[id1, id2] > 0): - pl.plot([pos1[node_id1][0], pos2[node_id2][0]], - [pos1[node_id1][1], pos2[node_id2][1]], - '-', lw=0.8, alpha=T[id1, id2] / max_Tk1, - color=nodes_color_part1[node_id1]) + if T[id1, id2] > 0: + pl.plot( + [pos1[node_id1][0], pos2[node_id2][0]], + [pos1[node_id1][1], pos2[node_id2][1]], + "-", + lw=0.8, + alpha=T[id1, id2] / max_Tk1, + color=nodes_color_part1[node_id1], + ) return pos1, pos2 pl.figure(2, figsize=(5, 2.5)) pl.clf() -pl.axis('off') +pl.axis("off") pl.subplot(1, 2, 1) -pl.title(r'qGW$(\mathbf{spC_1}, \mathbf{spC_1}) =%s$' % (np.round(qGW_dist, 3)), fontsize=fontsize) +pl.title( + r"qGW$(\mathbf{spC_1}, \mathbf{spC_1}) =%s$" % (np.round(qGW_dist, 3)), + fontsize=fontsize, +) pos1, pos2 = draw_transp_colored_qGW( - weightedG1, C1, weightedG2, C2, part1_, part2_, rep_indices1, rep_indices2, - T=OT_, shiftx=1.5, node_size=node_size, seed_G1=seed_G1, seed_G2=seed_G2) + weightedG1, + C1, + weightedG2, + C2, + part1_, + part2_, + rep_indices1, + rep_indices2, + T=OT_, + shiftx=1.5, + node_size=node_size, + seed_G1=seed_G1, + seed_G2=seed_G2, +) pl.tight_layout() pl.subplot(1, 2, 2) -pl.title(r' GW$(\mathbf{CR_1}, \mathbf{CR_2}) =%s$' % (np.round(log_['global dist'], 3)), fontsize=fontsize) +pl.title( + r" GW$(\mathbf{CR_1}, \mathbf{CR_2}) =%s$" % (np.round(log_["global dist"], 3)), + fontsize=fontsize, +) pos1, pos2 = draw_transp_colored_qGW( - weightedG1, C1, weightedG2, C2, part1_, part2_, rep_indices1, rep_indices2, - T=OT_global, shiftx=1.5, node_size=node_size, seed_G1=seed_G1, seed_G2=seed_G2, - highlight_rep=True) + weightedG1, + C1, + weightedG2, + C2, + part1_, + part2_, + rep_indices1, + rep_indices2, + T=OT_global, + shiftx=1.5, + node_size=node_size, + seed_G1=seed_G1, + seed_G2=seed_G2, + highlight_rep=True, +) pl.tight_layout() pl.show() @@ -406,7 +548,7 @@ def draw_transp_colored_qGW( # Further associated to color intensity features derived from z FX = z - z.min() / (z.max() - z.min()) -FX = np.clip(0.8 * FX + 0.2, a_min=0.2, a_max=1.) # for numerical issues +FX = np.clip(0.8 * FX + 0.2, a_min=0.2, a_max=1.0) # for numerical issues FY = FX @@ -418,10 +560,8 @@ def draw_transp_colored_qGW( # Compute the partitioning and representant selection further used within # qFGW wrapper, both provided by a K-means algorithm. Then visualize partitioned spaces. 
-part1, rep_indices1 = get_partition_and_representants_samples( - X, 4, 'kmeans', 0) -part2, rep_indices2 = get_partition_and_representants_samples( - Y, 4, 'kmeans', 0) +part1, rep_indices1 = get_partition_and_representants_samples(X, 4, "kmeans", 0) +part2, rep_indices2 = get_partition_and_representants_samples(Y, 4, "kmeans", 0) upart1 = np.unique(part1) upart2 = np.unique(part2) @@ -433,7 +573,7 @@ def draw_transp_colored_qGW( ax1 = fig.add_subplot(1, 3, 1) ax1.set_title("2D curve") ax1.scatter(X[:, 0], X[:, 1], color="C0", alpha=FX, s=s) -plt.axis('off') +plt.axis("off") ax2 = fig.add_subplot(1, 3, 2) @@ -441,7 +581,7 @@ def draw_transp_colored_qGW( for i, elem in enumerate(upart1): idx = np.argwhere(part1 == elem)[:, 0] ax2.scatter(X[idx, 0], X[idx, 1], color="C%s" % i, alpha=FX[idx], s=s) -plt.axis('off') +plt.axis("off") ax3 = fig.add_subplot(1, 3, 3) ax3.set_title("Representant selection") @@ -449,8 +589,10 @@ def draw_transp_colored_qGW( idx = np.argwhere(part1 == elem)[:, 0] ax3.scatter(X[idx, 0], X[idx, 1], color="C%s" % i, alpha=FX[idx], s=10) rep_idx = rep_indices1[i] - ax3.scatter([X[rep_idx, 0]], [X[rep_idx, 1]], color="C%s" % i, alpha=1, s=6 * s, marker='*') -plt.axis('off') + ax3.scatter( + [X[rep_idx, 0]], [X[rep_idx, 1]], color="C%s" % i, alpha=1, s=6 * s, marker="*" + ) +plt.axis("off") plt.tight_layout() plt.show() @@ -460,26 +602,34 @@ def draw_transp_colored_qGW( ax4 = fig.add_subplot(1, 3, 1, projection="3d") ax4.set_title("3D curve") -ax4.scatter(Y[:, 0], Y[:, 1], Y[:, 2], c='C0', alpha=FY, s=s) -plt.axis('off') +ax4.scatter(Y[:, 0], Y[:, 1], Y[:, 2], c="C0", alpha=FY, s=s) +plt.axis("off") ax5 = fig.add_subplot(1, 3, 2, projection="3d") ax5.set_title("Partitioning") for i, elem in enumerate(upart2): idx = np.argwhere(part2 == elem)[:, 0] - color = 'C%s' % (start_color + i) + color = "C%s" % (start_color + i) ax5.scatter(Y[idx, 0], Y[idx, 1], Y[idx, 2], c=color, alpha=FY[idx], s=s) -plt.axis('off') +plt.axis("off") ax6 = fig.add_subplot(1, 3, 3, projection="3d") ax6.set_title("Representant selection") for i, elem in enumerate(upart2): idx = np.argwhere(part2 == elem)[:, 0] - color = 'C%s' % (start_color + i) + color = "C%s" % (start_color + i) rep_idx = rep_indices2[i] ax6.scatter(Y[idx, 0], Y[idx, 1], Y[idx, 2], c=color, alpha=FY[idx], s=s) - ax6.scatter([Y[rep_idx, 0]], [Y[rep_idx, 1]], [Y[rep_idx, 2]], c=color, alpha=1, s=6 * s, marker='*') -plt.axis('off') + ax6.scatter( + [Y[rep_idx, 0]], + [Y[rep_idx, 1]], + [Y[rep_idx, 2]], + c=color, + alpha=1, + s=6 * s, + marker="*", + ) +plt.axis("off") plt.tight_layout() plt.show() @@ -494,21 +644,31 @@ def draw_transp_colored_qGW( # the K-means algorithm before computing qFGW. 
T_global, Ts_local, T, log = quantized_fused_gromov_wasserstein_samples( - X, Y, 4, 4, p=None, q=None, F1=FX[:, None], F2=FY[:, None], alpha=0.5, - method='kmeans', log=True) + X, + Y, + 4, + 4, + p=None, + q=None, + F1=FX[:, None], + F2=FY[:, None], + alpha=0.5, + method="kmeans", + log=True, +) # Plot low rank GW with different ranks pl.figure(5, figsize=(6, 3)) pl.subplot(1, 2, 1) -pl.title('OT between distributions') +pl.title("OT between distributions") pl.imshow(T, interpolation="nearest", aspect="auto") pl.colorbar() -pl.axis('off') +pl.axis("off") pl.subplot(1, 2, 2) -pl.title('OT between representants') +pl.title("OT between representants") pl.imshow(T_global, interpolation="nearest", aspect="auto") -pl.axis('off') +pl.axis("off") pl.colorbar() pl.tight_layout() diff --git a/examples/gromov/plot_semirelaxed_fgw.py b/examples/gromov/plot_semirelaxed_fgw.py index 579f23d3b..22015e162 100644 --- a/examples/gromov/plot_semirelaxed_fgw.py +++ b/examples/gromov/plot_semirelaxed_fgw.py @@ -27,7 +27,12 @@ import numpy as np import matplotlib.pylab as pl -from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein +from ot.gromov import ( + semirelaxed_gromov_wasserstein, + semirelaxed_fused_gromov_wasserstein, + gromov_wasserstein, + fused_gromov_wasserstein, +) import networkx from networkx.generators.community import stochastic_block_model as sbm @@ -39,11 +44,8 @@ N2 = 20 # 2 communities N3 = 30 # 3 communities -p2 = [[1., 0.1], - [0.1, 0.9]] -p3 = [[1., 0.1, 0.], - [0.1, 0.95, 0.1], - [0., 0.1, 0.9]] +p2 = [[1.0, 0.1], [0.1, 0.9]] +p3 = [[1.0, 0.1, 0.0], [0.1, 0.95, 0.1], [0.0, 0.1, 0.9]] G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2) G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3) @@ -57,11 +59,11 @@ # Add weights on the edges for visualization later on weight_intra_G2 = 5 weight_inter_G2 = 0.5 -weight_intra_G3 = 1. 
+weight_intra_G3 = 1.0 weight_inter_G3 = 1.5 weightedG2 = networkx.Graph() -part_G2 = [G2.nodes[i]['block'] for i in range(N2)] +part_G2 = [G2.nodes[i]["block"] for i in range(N2)] for node in G2.nodes(): weightedG2.add_node(node) @@ -72,7 +74,7 @@ weightedG2.add_edge(i, j, weight=weight_inter_G2) weightedG3 = networkx.Graph() -part_G3 = [G3.nodes[i]['block'] for i in range(N3)] +part_G3 = [G3.nodes[i]["block"] for i in range(N3)] for node in G3.nodes(): weightedG3.add_node(node) @@ -89,22 +91,24 @@ # 0) GW(C2, h2, C3, h3) for reference OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True) -gw = log['gw_dist'] +gw = log["gw_dist"] # 1) srGW(C2, h2, C3) -OT_23, log_23 = semirelaxed_gromov_wasserstein(C2, C3, h2, symmetric=True, - log=True, G0=None) -srgw_23 = log_23['srgw_dist'] +OT_23, log_23 = semirelaxed_gromov_wasserstein( + C2, C3, h2, symmetric=True, log=True, G0=None +) +srgw_23 = log_23["srgw_dist"] # 2) srGW(C3, h3, C2) -OT_32, log_32 = semirelaxed_gromov_wasserstein(C3, C2, h3, symmetric=None, - log=True, G0=OT.T) -srgw_32 = log_32['srgw_dist'] +OT_32, log_32 = semirelaxed_gromov_wasserstein( + C3, C2, h3, symmetric=None, log=True, G0=OT.T +) +srgw_32 = log_32["srgw_dist"] -print('GW(C2, C3) = ', gw) -print('srGW(C2, h2, C3) = ', srgw_23) -print('srGW(C3, h3, C2) = ', srgw_32) +print("GW(C2, C3) = ", gw) +print("srGW(C2, h2, C3) = ", srgw_23) +print("srGW(C3, h3, C2) = ", srgw_32) ############################################################################# @@ -116,12 +120,19 @@ # based on the optimal transport plan from the srGW matching -def draw_graph(G, C, nodes_color_part, Gweights=None, - pos=None, edge_color='black', node_size=None, - shiftx=0, seed=0): - - if (pos is None): - pos = networkx.spring_layout(G, scale=1., seed=seed) +def draw_graph( + G, + C, + nodes_color_part, + Gweights=None, + pos=None, + edge_color="black", + node_size=None, + shiftx=0, + seed=0, +): + if pos is None: + pos = networkx.spring_layout(G, scale=1.0, seed=seed) if shiftx != 0: for k, v in pos.items(): @@ -130,7 +141,9 @@ def draw_graph(G, C, nodes_color_part, Gweights=None, alpha_edge = 0.7 width_edge = 1.8 if Gweights is None: - networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color) + networkx.draw_networkx_edges( + G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color + ) else: # We make more visible connections between activated nodes n = len(Gweights) @@ -143,36 +156,69 @@ def draw_graph(G, C, nodes_color_part, Gweights=None, elif C[i, j] > 0: edgelist_deactivated.append((i, j)) - networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated, - width=width_edge, alpha=alpha_edge, - edge_color=edge_color) - networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated, - width=width_edge, alpha=0.1, - edge_color=edge_color) + networkx.draw_networkx_edges( + G, + pos, + edgelist=edgelist_activated, + width=width_edge, + alpha=alpha_edge, + edge_color=edge_color, + ) + networkx.draw_networkx_edges( + G, + pos, + edgelist=edgelist_deactivated, + width=width_edge, + alpha=0.1, + edge_color=edge_color, + ) if Gweights is None: for node, node_color in enumerate(nodes_color_part): - networkx.draw_networkx_nodes(G, pos, nodelist=[node], - node_size=node_size, alpha=1, - node_color=node_color) + networkx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + node_size=node_size, + alpha=1, + node_color=node_color, + ) else: scaled_Gweights = Gweights / (0.5 * Gweights.max()) nodes_size = node_size * scaled_Gweights for node, 
node_color in enumerate(nodes_color_part): - networkx.draw_networkx_nodes(G, pos, nodelist=[node], - node_size=nodes_size[node], alpha=1, - node_color=node_color) + networkx.draw_networkx_nodes( + G, + pos, + nodelist=[node], + node_size=nodes_size[node], + alpha=1, + node_color=node_color, + ) return pos -def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, - p1, p2, T, pos1=None, pos2=None, - shiftx=4, switchx=False, node_size=70, - seed_G1=0, seed_G2=0): +def draw_transp_colored_srGW( + G1, + C1, + G2, + C2, + part_G1, + p1, + p2, + T, + pos1=None, + pos2=None, + shiftx=4, + switchx=False, + node_size=70, + seed_G1=0, + seed_G2=0, +): starting_color = 0 # get graphs partition and their coloring part1 = part_G1.copy() - unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)] + unique_colors = ["C%s" % (starting_color + i) for i in np.unique(part1)] nodes_color_part1 = [] for cluster in part1: nodes_color_part1.append(unique_colors[cluster]) @@ -182,17 +228,37 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, for i in range(len(G2.nodes())): j = np.argmax(T[:, i]) nodes_color_part2.append(nodes_color_part1[j]) - pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1, - pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1) - pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2, - node_size=node_size, shiftx=shiftx, seed=seed_G2) + pos1 = draw_graph( + G1, + C1, + nodes_color_part1, + Gweights=p1, + pos=pos1, + node_size=node_size, + shiftx=0, + seed=seed_G1, + ) + pos2 = draw_graph( + G2, + C2, + nodes_color_part2, + Gweights=p2, + pos=pos2, + node_size=node_size, + shiftx=shiftx, + seed=seed_G2, + ) for k1, v1 in pos1.items(): for k2, v2 in pos2.items(): - if (T[k1, k2] > 0): - pl.plot([pos1[k1][0], pos2[k2][0]], - [pos1[k1][1], pos2[k2][1]], - '-', lw=0.8, alpha=0.5, - color=nodes_color_part1[k1]) + if T[k1, k2] > 0: + pl.plot( + [pos1[k1][0], pos2[k2][0]], + [pos1[k1][1], pos2[k2][1]], + "-", + lw=0.8, + alpha=0.5, + color=nodes_color_part1[k1], + ) return pos1, pos2 @@ -204,21 +270,51 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, pl.figure(1, figsize=(8, 2.5)) pl.clf() pl.subplot(121) -pl.axis('off') +pl.axis("off") pl.axis -pl.title(r'srGW$(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize) +pl.title( + r"srGW$(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$" % (np.round(srgw_23, 3)), + fontsize=fontsize, +) hbar2 = OT_23.sum(axis=0) pos1, pos2 = draw_transp_colored_srGW( - weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, - shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) + weightedG2, + C2, + weightedG3, + C3, + part_G2, + p1=None, + p2=hbar2, + T=OT_23, + shiftx=1.5, + node_size=node_size, + seed_G1=seed_G2, + seed_G2=seed_G3, +) pl.subplot(122) -pl.axis('off') +pl.axis("off") hbar3 = OT_32.sum(axis=0) -pl.title(r'srGW$(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize) +pl.title( + r"srGW$(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$" % (np.round(srgw_32, 3)), + fontsize=fontsize, +) pos1, pos2 = draw_transp_colored_srGW( - weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, - pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) + weightedG3, + C3, + weightedG2, + C2, + part_G3, + p1=None, + p2=hbar3, + T=OT_32, + pos1=pos2, + pos2=pos1, + shiftx=3.0, + node_size=node_size, + seed_G1=0, + seed_G2=0, +) pl.tight_layout() pl.show() @@ -237,7 +333,7 @@ def 
draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, F3 = np.zeros((N3, 1)) for i, c in enumerate(part_G3): - F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01) + F3[i, 0] = np.random.normal(loc=2.0 - c, scale=0.01) ############################################################################# # @@ -246,28 +342,31 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, alpha = 0.5 # Compute pairwise euclidean distance between node features -M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T) +M = (F2**2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3**2).T) - 2 * F2.dot(F3.T) # 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference OT, log = fused_gromov_wasserstein( - M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True) -fgw = log['fgw_dist'] + M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True +) +fgw = log["fgw_dist"] # 1) srFGW(C2, F2, h2, C3, F3) OT_23, log_23 = semirelaxed_fused_gromov_wasserstein( - M, C2, C3, h2, symmetric=True, alpha=0.5, log=True, G0=None) -srfgw_23 = log_23['srfgw_dist'] + M, C2, C3, h2, symmetric=True, alpha=0.5, log=True, G0=None +) +srfgw_23 = log_23["srfgw_dist"] # 2) srFGW(C3, F3, h3, C2, F2) OT_32, log_32 = semirelaxed_fused_gromov_wasserstein( - M.T, C3, C2, h3, symmetric=None, alpha=alpha, log=True, G0=None) -srfgw_32 = log_32['srfgw_dist'] + M.T, C3, C2, h3, symmetric=None, alpha=alpha, log=True, G0=None +) +srfgw_32 = log_32["srfgw_dist"] -print('FGW(C2, F2, C3, F3) = ', fgw) -print('srGW(C2, F2, h2, C3, F3) = ', srfgw_23) -print('srGW(C3, F3, h3, C2, F2) = ', srfgw_32) +print("FGW(C2, F2, C3, F3) = ", fgw) +print("srGW(C2, F2, h2, C3, F3) = ", srfgw_23) +print("srGW(C3, F3, h3, C2, F2) = ", srfgw_32) ############################################################################# # @@ -281,21 +380,53 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, pl.figure(2, figsize=(8, 2.5)) pl.clf() pl.subplot(121) -pl.axis('off') +pl.axis("off") pl.axis -pl.title(r'srFGW$(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize) +pl.title( + r"srFGW$(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$" + % (np.round(srfgw_23, 3)), + fontsize=fontsize, +) hbar2 = OT_23.sum(axis=0) pos1, pos2 = draw_transp_colored_srGW( - weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, - shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) + weightedG2, + C2, + weightedG3, + C3, + part_G2, + p1=None, + p2=hbar2, + T=OT_23, + shiftx=1.5, + node_size=node_size, + seed_G1=seed_G2, + seed_G2=seed_G3, +) pl.subplot(122) -pl.axis('off') +pl.axis("off") hbar3 = OT_32.sum(axis=0) -pl.title(r'srFGW$(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize) +pl.title( + r"srFGW$(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$" + % (np.round(srfgw_32, 3)), + fontsize=fontsize, +) pos1, pos2 = draw_transp_colored_srGW( - weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, - pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) + weightedG3, + C3, + weightedG2, + C2, + part_G3, + p1=None, + p2=hbar3, + T=OT_32, + pos1=pos2, + pos2=pos1, + shiftx=3.0, + node_size=node_size, + seed_G1=0, + seed_G2=0, +) pl.tight_layout() pl.show() diff --git a/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py b/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py index e555d1e70..232da0a56 
100644 --- a/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py +++ b/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py @@ -43,8 +43,7 @@ import numpy as np import matplotlib.pylab as pl from sklearn.manifold import MDS -from ot.gromov import ( - semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters) +from ot.gromov import semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters import ot import networkx from networkx.generators.community import stochastic_block_model as sbm @@ -83,7 +82,7 @@ sizes = [n_nodes] G = sbm(sizes, P, seed=i, directed=False) - part = np.array([G.nodes[i]['block'] for i in range(np.sum(sizes))]) + part = np.array([G.nodes[i]["block"] for i in range(np.sum(sizes))]) C = networkx.to_numpy_array(G) dataset.append(C) node_labels.append(part) @@ -92,16 +91,23 @@ # Visualize samples -def plot_graph(x, C, binary=True, color='C0', s=None): + +def plot_graph(x, C, binary=True, color="C0", s=None): for j in range(C.shape[0]): for i in range(j): if binary: if C[i, j] > 0: - pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k') + pl.plot( + [x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color="k" + ) else: # connection intensity proportional to C[i,j] - pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k') + pl.plot( + [x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color="k" + ) - pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9) + pl.scatter( + x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors="k", cmap="tab10", vmax=9 + ) pl.figure(1, (12, 8)) @@ -109,14 +115,14 @@ def plot_graph(x, C, binary=True, color='C0', s=None): for idx_c, c in enumerate(clusters): C = dataset[(c - 1) * Nc] # sample with c clusters # get 2d position for nodes - x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + x = MDS(dissimilarity="precomputed", random_state=0).fit_transform(1 - C) pl.subplot(2, nlabels, c) - pl.title('(graph) sample from label ' + str(c), fontsize=14) - plot_graph(x, C, binary=True, color='C0', s=50.) + pl.title("(graph) sample from label " + str(c), fontsize=14) + plot_graph(x, C, binary=True, color="C0", s=50.0) pl.axis("off") pl.subplot(2, nlabels, nlabels + c) - pl.title('(matrix) sample from label %s \n' % c, fontsize=14) - pl.imshow(C, interpolation='nearest') + pl.title("(matrix) sample from label %s \n" % c, fontsize=14) + pl.imshow(C, interpolation="nearest") pl.axis("off") pl.tight_layout() pl.show() @@ -129,7 +135,7 @@ def plot_graph(x, C, binary=True, color='C0', s=None): np.random.seed(0) ps = [ot.unif(C.shape[0]) for C in dataset] # uniform weights on input nodes -lambdas = [1. / n_samples for _ in range(n_samples)] # uniform barycenter +lambdas = [1.0 / n_samples for _ in range(n_samples)] # uniform barycenter N = 3 # 3 nodes in the barycenter # Here we use the Fluid partitioning method to deduce initial transport plans @@ -137,41 +143,71 @@ def plot_graph(x, C, binary=True, color='C0', s=None): # initial transport plans. Then a warmstart strategy is used iteratively to # init each individual srGW problem within the BCD algorithm. 
-init_plan = 'fluid' # notice that several init options are implemented in `ot.gromov.semirelaxed_init_plan` +init_plan = "fluid" # notice that several init options are implemented in `ot.gromov.semirelaxed_init_plan` warmstartT = True C, log = semirelaxed_gromov_barycenters( - N=N, Cs=dataset, ps=ps, lambdas=lambdas, loss_fun='square_loss', - tol=1e-6, stop_criterion='loss', warmstartT=warmstartT, log=True, - G0=init_plan, verbose=False) - -print('barycenter structure:', C) - -unmixings = log['p'] + N=N, + Cs=dataset, + ps=ps, + lambdas=lambdas, + loss_fun="square_loss", + tol=1e-6, + stop_criterion="loss", + warmstartT=warmstartT, + log=True, + G0=init_plan, + verbose=False, +) + +print("barycenter structure:", C) + +unmixings = log["p"] # Compute the 2D representation of the embeddings living in the 2-simplex of probability unmixings2D = np.zeros(shape=(n_samples, 2)) for i, w in enumerate(unmixings): - unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. - unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. -x = [0., 0.] -y = [1., 0.] -z = [0.5, np.sqrt(3) / 2.] + unmixings2D[i, 0] = (2.0 * w[1] + w[2]) / 2.0 + unmixings2D[i, 1] = (np.sqrt(3.0) * w[2]) / 2.0 +x = [0.0, 0.0] +y = [1.0, 0.0] +z = [0.5, np.sqrt(3) / 2.0] extremities = np.stack([x, y, z]) pl.figure(2, (4, 4)) pl.clf() -pl.title('Embedding space', fontsize=14) +pl.title("Embedding space", fontsize=14) for cluster in range(nlabels): start, end = Nc * cluster, Nc * (cluster + 1) if cluster == 0: - pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=80., label='1 cluster') + pl.scatter( + unmixings2D[start:end, 0], + unmixings2D[start:end, 1], + c="C" + str(cluster), + marker="o", + s=80.0, + label="1 cluster", + ) else: - pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=80., label='%s clusters' % (cluster + 1)) -pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=100., label='bary. nodes') -pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) -pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) -pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) -pl.axis('off') + pl.scatter( + unmixings2D[start:end, 0], + unmixings2D[start:end, 1], + c="C" + str(cluster), + marker="o", + s=80.0, + label="%s clusters" % (cluster + 1), + ) +pl.scatter( + extremities[:, 0], + extremities[:, 1], + c="black", + marker="x", + s=100.0, + label="bary. nodes", +) +pl.plot([x[0], y[0]], [x[1], y[1]], color="black", linewidth=2.0) +pl.plot([x[0], z[0]], [x[1], z[1]], color="black", linewidth=2.0) +pl.plot([y[0], z[0]], [y[1], z[1]], color="black", linewidth=2.0) +pl.axis("off") pl.legend(fontsize=11) pl.tight_layout() pl.show() @@ -187,7 +223,7 @@ def plot_graph(x, C, binary=True, color='C0', s=None): for i in range(len(dataset)): n = dataset[i].shape[0] F = np.zeros((n, 3)) - F[np.arange(n), node_labels[i]] = 1. 
+ F[np.arange(n), node_labels[i]] = 1.0 dataset_features.append(F) pl.figure(3, (12, 8)) @@ -195,16 +231,16 @@ def plot_graph(x, C, binary=True, color='C0', s=None): for idx_c, c in enumerate(clusters): C = dataset[(c - 1) * Nc] # sample with c clusters F = dataset_features[(c - 1) * Nc] - colors = [f'C{labels[i]}' for i in range(F.shape[0])] + colors = [f"C{labels[i]}" for i in range(F.shape[0])] # get 2d position for nodes - x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + x = MDS(dissimilarity="precomputed", random_state=0).fit_transform(1 - C) pl.subplot(2, nlabels, c) - pl.title('(graph) sample from label ' + str(c), fontsize=14) + pl.title("(graph) sample from label " + str(c), fontsize=14) plot_graph(x, C, binary=True, color=colors, s=50) pl.axis("off") pl.subplot(2, nlabels, nlabels + c) - pl.title('(matrix) sample from label %s \n' % c, fontsize=14) - pl.imshow(C, interpolation='nearest') + pl.title("(matrix) sample from label %s \n" % c, fontsize=14) + pl.imshow(C, interpolation="nearest") pl.axis("off") pl.tight_layout() pl.show() @@ -222,45 +258,76 @@ def plot_graph(x, C, binary=True, color='C0', s=None): list_unmixings2D = [] for ialpha, alpha in enumerate(list_alphas): - print('--- alpha:', alpha) + print("--- alpha:", alpha) C, F, log = semirelaxed_fgw_barycenters( - N=N, Ys=dataset_features, Cs=dataset, ps=ps, lambdas=lambdas, - alpha=alpha, loss_fun='square_loss', tol=1e-6, stop_criterion='loss', - warmstartT=warmstartT, log=True, G0=init_plan) - - print('barycenter structure:', C) - print('barycenter features:', F) - - unmixings = log['p'] + N=N, + Ys=dataset_features, + Cs=dataset, + ps=ps, + lambdas=lambdas, + alpha=alpha, + loss_fun="square_loss", + tol=1e-6, + stop_criterion="loss", + warmstartT=warmstartT, + log=True, + G0=init_plan, + ) + + print("barycenter structure:", C) + print("barycenter features:", F) + + unmixings = log["p"] # Compute the 2D representation of the embeddings living in the 2-simplex of probability unmixings2D = np.zeros(shape=(n_samples, 2)) for i, w in enumerate(unmixings): - unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. - unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. + unmixings2D[i, 0] = (2.0 * w[1] + w[2]) / 2.0 + unmixings2D[i, 1] = (np.sqrt(3.0) * w[2]) / 2.0 list_unmixings2D.append(unmixings2D.copy()) -x = [0., 0.] -y = [1., 0.] -z = [0.5, np.sqrt(3) / 2.] +x = [0.0, 0.0] +y = [1.0, 0.0] +z = [0.5, np.sqrt(3) / 2.0] extremities = np.stack([x, y, z]) pl.figure(4, (12, 4)) pl.clf() -pl.suptitle('Embedding spaces', fontsize=14) +pl.suptitle("Embedding spaces", fontsize=14) for ialpha, alpha in enumerate(list_alphas): pl.subplot(1, len(list_alphas), ialpha + 1) - pl.title(f'alpha = {alpha}', fontsize=14) + pl.title(f"alpha = {alpha}", fontsize=14) for cluster in range(nlabels): start, end = Nc * cluster, Nc * (cluster + 1) if cluster == 0: - pl.scatter(list_unmixings2D[ialpha][start:end, 0], list_unmixings2D[ialpha][start:end, 1], c='C' + str(cluster), marker='o', s=80., label='1 cluster') + pl.scatter( + list_unmixings2D[ialpha][start:end, 0], + list_unmixings2D[ialpha][start:end, 1], + c="C" + str(cluster), + marker="o", + s=80.0, + label="1 cluster", + ) else: - pl.scatter(list_unmixings2D[ialpha][start:end, 0], list_unmixings2D[ialpha][start:end, 1], c='C' + str(cluster), marker='o', s=80., label='%s clusters' % (cluster + 1)) - pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=100., label='bary. nodes') - pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) 
- pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) - pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) - pl.axis('off') + pl.scatter( + list_unmixings2D[ialpha][start:end, 0], + list_unmixings2D[ialpha][start:end, 1], + c="C" + str(cluster), + marker="o", + s=80.0, + label="%s clusters" % (cluster + 1), + ) + pl.scatter( + extremities[:, 0], + extremities[:, 1], + c="black", + marker="x", + s=100.0, + label="bary. nodes", + ) + pl.plot([x[0], y[0]], [x[1], y[1]], color="black", linewidth=2.0) + pl.plot([x[0], z[0]], [x[1], z[1]], color="black", linewidth=2.0) + pl.plot([y[0], z[0]], [y[1], z[1]], color="black", linewidth=2.0) + pl.axis("off") pl.legend(fontsize=11) pl.tight_layout() pl.show() diff --git a/examples/others/plot_COOT.py b/examples/others/plot_COOT.py index 98c1ce146..57c963ab8 100644 --- a/examples/others/plot_COOT.py +++ b/examples/others/plot_COOT.py @@ -36,14 +36,14 @@ sigma = 0.2 X1 = ( - np.cos(np.arange(n1) * np.pi / n1)[:, None] + - np.cos(np.arange(d1) * np.pi / d1)[None, :] + - sigma * np.random.randn(n1, d1) + np.cos(np.arange(n1) * np.pi / n1)[:, None] + + np.cos(np.arange(d1) * np.pi / d1)[None, :] + + sigma * np.random.randn(n1, d1) ) X2 = ( - np.cos(np.arange(n2) * np.pi / n2)[:, None] + - np.cos(np.arange(d2) * np.pi / d2)[None, :] + - sigma * np.random.randn(n2, d2) + np.cos(np.arange(n2) * np.pi / n2)[:, None] + + np.cos(np.arange(d2) * np.pi / d2)[None, :] + + sigma * np.random.randn(n2, d2) ) # %% @@ -52,7 +52,7 @@ pl.figure(1, (8, 5)) pl.subplot(1, 2, 1) pl.imshow(X1) -pl.title('$X_1$') +pl.title("$X_1$") pl.subplot(1, 2, 2) pl.imshow(X2) @@ -65,14 +65,14 @@ pi_sample, pi_feature, log = coot(X1, X2, log=True, verbose=True) coot_distance = coot2(X1, X2) -print('CO-Optimal Transport distance = {:.5f}'.format(coot_distance)) +print("CO-Optimal Transport distance = {:.5f}".format(coot_distance)) fig = pl.figure(4, (9, 7)) pl.clf() ax1 = pl.subplot(2, 2, 3) pl.imshow(X1) -pl.xlabel('$X_1$') +pl.xlabel("$X_1$") ax2 = pl.subplot(2, 2, 2) ax2.yaxis.tick_right() @@ -82,16 +82,18 @@ for i in range(n1): j = np.argmax(pi_sample[i, :]) - xyA = (d1 - .5, i) - xyB = (j, d2 - .5) - con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData, - coordsB=ax2.transData, color="black") + xyA = (d1 - 0.5, i) + xyB = (j, d2 - 0.5) + con = ConnectionPatch( + xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="black" + ) fig.add_artist(con) for i in range(d1): j = np.argmax(pi_feature[i, :]) - xyA = (i, -.5) - xyB = (-.5, j) + xyA = (i, -0.5) + xyB = (-0.5, j) con = ConnectionPatch( - xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue") + xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue" + ) fig.add_artist(con) diff --git a/examples/others/plot_EWCA.py b/examples/others/plot_EWCA.py index fb9bd713f..af9192537 100644 --- a/examples/others/plot_EWCA.py +++ b/examples/others/plot_EWCA.py @@ -61,8 +61,8 @@ s=50, ) pl.scatter( - X[n_samples // 2:, 0], - X[n_samples // 2:, 1], + X[n_samples // 2 :, 0], + X[n_samples // 2 :, 1], color=[cmap(y[i] + 1) for i in range(n_samples // 2, n_samples)], alpha=0.4, label="Class 2", diff --git a/examples/others/plot_GMMOT_plan.py b/examples/others/plot_GMMOT_plan.py index 8b6db31ba..7742d496e 100644 --- a/examples/others/plot_GMMOT_plan.py +++ b/examples/others/plot_GMMOT_plan.py @@ -38,51 +38,63 @@ eps = 0.1 m_s = np.array([[1], [2]]) m_t = np.array([[3], [4.2], [5]]) -C_s = np.array([[[.05]], [[.06]]]) -C_t = np.array([[[.03]], [[.07]], 
[[.04]]]) -w_s = np.array([.4, .6]) -w_t = np.array([.4, .2, .4]) +C_s = np.array([[[0.05]], [[0.06]]]) +C_t = np.array([[[0.03]], [[0.07]], [[0.04]]]) +w_s = np.array([0.4, 0.6]) +w_t = np.array([0.4, 0.2, 0.4]) n = 500 a_x, b_x = 0, 3 x = np.linspace(a_x, b_x, n) a_y, b_y = 2, 6 y = np.linspace(a_y, b_y, n) -plan_density = gmm_ot_plan_density(x[:, None], y[:, None], - m_s, m_t, C_s, C_t, w_s, w_t, - plan=None, atol=2e-2) +plan_density = gmm_ot_plan_density( + x[:, None], y[:, None], m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=2e-2 +) a = gmm_pdf(x[:, None], m_s, C_s, w_s) b = gmm_pdf(y[:, None], m_t, C_t, w_t) plt.figure(figsize=(8, 8)) -plot1D_mat(a, b, plan_density, title='GMM OT plan', plot_style='xy', - a_label='Source distribution', b_label='Target distribution') +plot1D_mat( + a, + b, + plan_density, + title="GMM OT plan", + plot_style="xy", + a_label="Source distribution", + b_label="Target distribution", +) ############################################################################## # Generate GMMOT maps and plot them over plan # ------------------------------------------- plt.figure(figsize=(8, 8)) -ax_s, ax_t, ax_M = plot1D_mat(a, b, plan_density, plot_style='xy', - title='GMM OT plan with T_mean and T_rand maps', - a_label='Source distribution', - b_label='Target distribution') -T_mean = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t, - w_s, w_t, method='bary')[:, 0] -x_rescaled, T_mean_rescaled = rescale_for_imshow_plot(x, T_mean, n, - a_y=a_y, b_y=b_y) +ax_s, ax_t, ax_M = plot1D_mat( + a, + b, + plan_density, + plot_style="xy", + title="GMM OT plan with T_mean and T_rand maps", + a_label="Source distribution", + b_label="Target distribution", +) +T_mean = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t, w_s, w_t, method="bary")[:, 0] +x_rescaled, T_mean_rescaled = rescale_for_imshow_plot(x, T_mean, n, a_y=a_y, b_y=b_y) -ax_M.plot(x_rescaled, T_mean_rescaled, label='T_mean', alpha=.5, - linewidth=5, color='aqua') +ax_M.plot( + x_rescaled, T_mean_rescaled, label="T_mean", alpha=0.5, linewidth=5, color="aqua" +) -T_rand = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t, - w_s, w_t, method='rand', seed=0)[:, 0] -x_rescaled, T_rand_rescaled = rescale_for_imshow_plot(x, T_rand, n, - a_y=a_y, b_y=b_y) +T_rand = gmm_ot_apply_map( + x[:, None], m_s, m_t, C_s, C_t, w_s, w_t, method="rand", seed=0 +)[:, 0] +x_rescaled, T_rand_rescaled = rescale_for_imshow_plot(x, T_rand, n, a_y=a_y, b_y=b_y) -ax_M.scatter(x_rescaled, T_rand_rescaled, label='T_rand', alpha=.5, - s=20, color='orange') +ax_M.scatter( + x_rescaled, T_rand_rescaled, label="T_rand", alpha=0.5, s=20, color="orange" +) -ax_M.legend(loc='upper left', fontsize=13) +ax_M.legend(loc="upper left", fontsize=13) # %% diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py index 8cff1cc42..beb675755 100644 --- a/examples/others/plot_GMM_flow.py +++ b/examples/others/plot_GMM_flow.py @@ -54,26 +54,34 @@ w_t = torch.tensor(ot.unif(kt)) -def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=.5): - +def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5): def eigsorted(cov): + if torch.is_tensor(cov): + cov = cov.detach().numpy() vals, vecs = np.linalg.eigh(cov) - order = vals.argsort()[::-1] + order = vals.argsort()[::-1].copy() return vals[order], vecs[:, order] vals, vecs = eigsorted(C) theta = np.degrees(np.arctan2(*vecs[:, 0][::-1])) w, h = 2 * nstd * np.sqrt(vals) - ell = Ellipse(xy=(mu[0], mu[1]), - width=w, height=h, alpha=alpha, - angle=theta, facecolor=color, edgecolor=color, label=label, 
fill=True) + ell = Ellipse( + xy=(mu[0], mu[1]), + width=w, + height=h, + alpha=alpha, + angle=theta, + facecolor=color, + edgecolor=color, + label=label, + fill=True, + ) pl.gca().add_artist(ell) -def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): +def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1): for k in range(ms.shape[0]): - draw_cov(ms[k], Cs[k], color, None, nstd, - alpha * ws[k]) + draw_cov(ms[k], Cs[k], color, None, nstd, alpha * ws[k]) axis = [-3, 3, -3, 3] @@ -81,18 +89,16 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): pl.clf() pl.subplot(1, 2, 1) -pl.scatter(m_s[:, 0].detach(), m_s[:, 1].detach(), color='C0') -draw_gmm(m_s.detach(), C_s.detach(), - torch.softmax(w_s, 0).detach().numpy(), - color='C0') +pl.scatter(m_s[:, 0].detach(), m_s[:, 1].detach(), color="C0") +draw_gmm(m_s.detach(), C_s.detach(), torch.softmax(w_s, 0).detach().numpy(), color="C0") pl.axis(axis) -pl.title('Source GMM') +pl.title("Source GMM") pl.subplot(1, 2, 2) -pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color='C1') -draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color='C1') +pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color="C1") +draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color="C1") pl.axis(axis) -pl.title('Target GMM') +pl.title("Target GMM") ############################################################################## # Gradient descent loop @@ -100,9 +106,13 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): n_gd_its = 100 lr = 3e-2 -opt = Adam([{'params': m_s, 'lr': 2 * lr}, - {'params': C_s, 'lr': lr}, - {'params': w_s, 'lr': lr}]) +opt = Adam( + [ + {"params": m_s, "lr": 2 * lr}, + {"params": C_s, "lr": lr}, + {"params": w_s, "lr": lr}, + ] +) m_list = [m_s.data.numpy().copy()] C_list = [C_s.data.numpy().copy()] w_list = [torch.softmax(w_s, 0).data.numpy().copy()] @@ -110,8 +120,7 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): for _ in range(n_gd_its): opt.zero_grad() - loss = gmm_ot_loss(m_s, m_t, C_s, C_t, - torch.softmax(w_s, 0), w_t) + loss = gmm_ot_loss(m_s, m_t, C_s, C_t, torch.softmax(w_s, 0), w_t) loss.backward() opt.step() with torch.no_grad(): @@ -124,9 +133,9 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): pl.figure(2) pl.clf() pl.plot(loss_list) -pl.title('Loss') -pl.xlabel('its') -pl.ylabel('loss') +pl.title("Loss") +pl.xlabel("its") +pl.ylabel("loss") ############################################################################## @@ -136,18 +145,18 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): axis = [-3, 3, -3, 3] pl.figure(3, (10, 10)) pl.clf() -pl.title('GMM flow, last step') -pl.scatter(m_list[0][:, 0], m_list[0][:, 1], color='C0', label='Source') -draw_gmm(m_list[0], C_list[0], w_list[0], color='C0') +pl.title("GMM flow, last step") +pl.scatter(m_list[0][:, 0], m_list[0][:, 1], color="C0", label="Source") +draw_gmm(m_list[0], C_list[0], w_list[0], color="C0") pl.axis(axis) -pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color='C1', label='Target') -draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color='C1') +pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color="C1", label="Target") +draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color="C1") pl.axis(axis) k = -1 -pl.scatter(m_list[k][:, 0], m_list[k][:, 1], color='C2', alpha=1, label='Last step') -draw_gmm(m_list[k], C_list[k], w_list[0], color='C2', alpha=1) +pl.scatter(m_list[k][:, 0], m_list[k][:, 1], color="C2", alpha=1, label="Last step") +draw_gmm(m_list[k], C_list[k], w_list[0], color="C2", alpha=1) pl.axis(axis) 
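# Standalone illustration (not part of this patch) of what the `draw_cov`
# helper reformatted above computes: ellipse axes and orientation from a 2x2
# covariance. The covariance values below are toy choices.
import numpy as np
from matplotlib.patches import Ellipse

C_toy = np.array([[1.0, 0.6], [0.6, 0.5]])  # toy 2x2 covariance
vals, vecs = np.linalg.eigh(C_toy)
order = vals.argsort()[::-1]  # eigenvalues in decreasing order, as in eigsorted
vals, vecs = vals[order], vecs[:, order]
nstd = 1
width, height = 2 * nstd * np.sqrt(vals)  # axis lengths at nstd standard deviations
angle = np.degrees(np.arctan2(*vecs[:, 0][::-1]))  # orientation of the leading axis
ell_toy = Ellipse(xy=(0.0, 0.0), width=width, height=height, angle=angle, alpha=0.5)
print(width, height, angle)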
pl.legend(fontsize=15) @@ -163,27 +172,32 @@ def index_to_color(i): n_steps_visu = 100 pl.figure(3, (10, 10)) pl.clf() -pl.title('GMM flow, all steps') +pl.title("GMM flow, all steps") its_to_show = [int(x) for x in np.linspace(1, n_gd_its - 1, n_steps_visu)] -cmp = cm['plasma'].resampled(index_to_color(n_steps_visu)) +cmp = cm["plasma"].resampled(index_to_color(n_steps_visu)) -pl.scatter(m_list[0][:, 0], m_list[0][:, 1], - color=cmp(index_to_color(0)), label='Source') -draw_gmm(m_list[0], C_list[0], w_list[0], - color=cmp(index_to_color(0))) +pl.scatter( + m_list[0][:, 0], m_list[0][:, 1], color=cmp(index_to_color(0)), label="Source" +) +draw_gmm(m_list[0], C_list[0], w_list[0], color=cmp(index_to_color(0))) -pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), - color=cmp(index_to_color(n_steps_visu - 1)), label='Target') -draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), - color=cmp(index_to_color(n_steps_visu - 1))) +pl.scatter( + m_t[:, 0].detach(), + m_t[:, 1].detach(), + color=cmp(index_to_color(n_steps_visu - 1)), + label="Target", +) +draw_gmm( + m_t.detach(), C_t.detach(), w_t.numpy(), color=cmp(index_to_color(n_steps_visu - 1)) +) for k in its_to_show: - pl.scatter(m_list[k][:, 0], m_list[k][:, 1], - color=cmp(index_to_color(k)), alpha=0.8) - draw_gmm(m_list[k], C_list[k], w_list[0], - color=cmp(index_to_color(k)), alpha=0.04) + pl.scatter( + m_list[k][:, 0], m_list[k][:, 1], color=cmp(index_to_color(k)), alpha=0.8 + ) + draw_gmm(m_list[k], C_list[k], w_list[0], color=cmp(index_to_color(k)), alpha=0.04) pl.axis(axis) pl.legend(fontsize=15) diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py index ccb241162..fbc343a8a 100644 --- a/examples/others/plot_SSNB.py +++ b/examples/others/plot_SSNB.py @@ -55,30 +55,38 @@ Xs_classes = (Xs[:, 0] < 0).astype(int) Xt = np.stack([Xs[:, 0] + 2 * np.sign(Xs[:, 0]), 2 * Xs[:, 1]], axis=-1) -plt.scatter(Xs[Xs_classes == 0, 0], Xs[Xs_classes == 0, 1], c='blue', label='source class 0') -plt.scatter(Xs[Xs_classes == 1, 0], Xs[Xs_classes == 1, 1], c='dodgerblue', label='source class 1') -plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') -plt.axis('equal') -plt.title('Splitting sphere dataset') -plt.legend(loc='upper right') +plt.scatter( + Xs[Xs_classes == 0, 0], Xs[Xs_classes == 0, 1], c="blue", label="source class 0" +) +plt.scatter( + Xs[Xs_classes == 1, 0], + Xs[Xs_classes == 1, 1], + c="dodgerblue", + label="source class 1", +) +plt.scatter(Xt[:, 0], Xt[:, 1], c="red", label="target") +plt.axis("equal") +plt.title("Splitting sphere dataset") +plt.legend(loc="upper right") plt.show() # %% # Fitting the Nearest Brenier Potential L = 3 # need L > 2 to allow the 2*y term, default is 1.4 -phi, G = ot.mapping.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, init_method='barycentric', - gradient_lipschitz_constant=L) +phi, G = ot.mapping.nearest_brenier_potential_fit( + Xs, Xt, Xs_classes, its=10, init_method="barycentric", gradient_lipschitz_constant=L +) # %% # Plotting the images of the source data plt.clf() -plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source') -plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') +plt.scatter(Xs[:, 0], Xs[:, 1], c="dodgerblue", label="source") +plt.scatter(Xt[:, 0], Xt[:, 1], c="red", label="target") for i in range(n_fitting_samples): - plt.plot([Xs[i, 0], G[i, 0]], [Xs[i, 1], G[i, 1]], color='black', alpha=.5) -plt.title('Images of in-data source samples by the fitted SSNB') -plt.legend(loc='upper right') -plt.axis('equal') + plt.plot([Xs[i, 0], G[i, 0]], [Xs[i, 
1], G[i, 1]], color="black", alpha=0.5) +plt.title("Images of in-data source samples by the fitted SSNB") +plt.legend(loc="upper right") +plt.axis("equal") plt.show() # %% @@ -86,29 +94,34 @@ n_predict_samples = 50 Ys = rng.uniform(-1, 1, size=(n_predict_samples, 2)) Ys_classes = (Ys[:, 0] < 0).astype(int) -phi_lu, G_lu = ot.mapping.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes, - gradient_lipschitz_constant=L) +phi_lu, G_lu = ot.mapping.nearest_brenier_potential_predict_bounds( + Xs, phi, G, Ys, Xs_classes, Ys_classes, gradient_lipschitz_constant=L +) # %% # Plot predictions for the gradient of the lower-bounding potential plt.clf() -plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source') -plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') +plt.scatter(Xs[:, 0], Xs[:, 1], c="dodgerblue", label="source") +plt.scatter(Xt[:, 0], Xt[:, 1], c="red", label="target") for i in range(n_predict_samples): - plt.plot([Ys[i, 0], G_lu[0, i, 0]], [Ys[i, 1], G_lu[0, i, 1]], color='black', alpha=.5) -plt.title('Images of new source samples by $\\nabla \\varphi_l$') -plt.legend(loc='upper right') -plt.axis('equal') + plt.plot( + [Ys[i, 0], G_lu[0, i, 0]], [Ys[i, 1], G_lu[0, i, 1]], color="black", alpha=0.5 + ) +plt.title("Images of new source samples by $\\nabla \\varphi_l$") +plt.legend(loc="upper right") +plt.axis("equal") plt.show() # %% # Plot predictions for the gradient of the upper-bounding potential plt.clf() -plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source') -plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') +plt.scatter(Xs[:, 0], Xs[:, 1], c="dodgerblue", label="source") +plt.scatter(Xt[:, 0], Xt[:, 1], c="red", label="target") for i in range(n_predict_samples): - plt.plot([Ys[i, 0], G_lu[1, i, 0]], [Ys[i, 1], G_lu[1, i, 1]], color='black', alpha=.5) -plt.title('Images of new source samples by $\\nabla \\varphi_u$') -plt.legend(loc='upper right') -plt.axis('equal') + plt.plot( + [Ys[i, 0], G_lu[1, i, 0]], [Ys[i, 1], G_lu[1, i, 1]], color="black", alpha=0.5 + ) +plt.title("Images of new source samples by $\\nabla \\varphi_u$") +plt.legend(loc="upper right") +plt.axis("equal") plt.show() diff --git a/examples/others/plot_WDA.py b/examples/others/plot_WDA.py index bdfa57dec..f1b9342fa 100644 --- a/examples/others/plot_WDA.py +++ b/examples/others/plot_WDA.py @@ -28,7 +28,7 @@ # Generate data # ------------- -#%% parameters +# %% parameters n = 1000 # nb samples in source and target datasets nz = 0.2 @@ -38,14 +38,12 @@ # generate circle dataset t = np.random.rand(n) * 2 * np.pi ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 -xs = np.concatenate( - (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) +xs = np.concatenate((np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2) t = np.random.rand(n) * 2 * np.pi yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 -xt = np.concatenate( - (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) +xt = np.concatenate((np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2) nbnoise = 8 @@ -57,25 +55,25 @@ # Plot data # --------- -#%% plot samples +# %% plot samples pl.figure(1, figsize=(6.4, 3.5)) pl.subplot(1, 2, 1) -pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples') +pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker="+", label="Source samples") pl.legend(loc=0) -pl.title('Discriminant dimensions') +pl.title("Discriminant dimensions") 
pl.subplot(1, 2, 2) -pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples') +pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker="+", label="Source samples") pl.legend(loc=0) -pl.title('Other dimensions') +pl.title("Other dimensions") pl.tight_layout() ############################################################################## # Compute Fisher Discriminant Analysis # ------------------------------------ -#%% Compute FDA +# %% Compute FDA p = 2 Pfda, projfda = fda(xs, ys, p) @@ -84,7 +82,7 @@ # Compute Wasserstein Discriminant Analysis # ----------------------------------------- -#%% Compute WDA +# %% Compute WDA p = 2 reg = 1e0 k = 10 @@ -101,7 +99,7 @@ # Plot 2D projections # ------------------- -#%% plot samples +# %% plot samples xsp = projfda(xs) xtp = projfda(xt) @@ -112,24 +110,24 @@ pl.figure(2) pl.subplot(2, 2, 1) -pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples') +pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker="+", label="Projected samples") pl.legend(loc=0) -pl.title('Projected training samples FDA') +pl.title("Projected training samples FDA") pl.subplot(2, 2, 2) -pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples') +pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker="+", label="Projected samples") pl.legend(loc=0) -pl.title('Projected test samples FDA') +pl.title("Projected test samples FDA") pl.subplot(2, 2, 3) -pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples') +pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker="+", label="Projected samples") pl.legend(loc=0) -pl.title('Projected training samples WDA') +pl.title("Projected training samples WDA") pl.subplot(2, 2, 4) -pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples') +pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker="+", label="Projected samples") pl.legend(loc=0) -pl.title('Projected test samples WDA') +pl.title("Projected test samples WDA") pl.tight_layout() pl.show() diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py index e3164bad7..16636f1ab 100644 --- a/examples/others/plot_WeakOT_VS_OT.py +++ b/examples/others/plot_WeakOT_VS_OT.py @@ -24,7 +24,7 @@ # Generate data an plot it # ------------------------ -#%% parameters and data generation +# %% parameters and data generation n = 50 # nb samples @@ -32,7 +32,7 @@ cov_s = np.array([[1, 0], [0, 1]]) mu_t = np.array([4, 4]) -cov_t = np.array([[1, -.8], [-.8, 1]]) +cov_t = np.array([[1, -0.8], [-0.8, 1]]) xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) @@ -43,28 +43,28 @@ M = ot.dist(xs, xt) M /= M.max() -#%% plot samples +# %% plot samples pl.figure(1) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") pl.legend(loc=0) -pl.title('Source and target distributions') +pl.title("Source and target distributions") pl.figure(2) -pl.imshow(M, interpolation='nearest') -pl.title('Cost matrix M') +pl.imshow(M, interpolation="nearest") +pl.title("Cost matrix M") ############################################################################## # Compute Weak OT and exact OT solutions # -------------------------------------- -#%% EMD +# %% EMD G0 = ot.emd(a, b, M) -#%% Weak OT +# %% Weak OT Gweak = ot.weak_optimal_transport(xs, xt, a, b) @@ -76,23 +76,23 @@ pl.figure(3, (8, 5)) pl.subplot(1, 2, 1) 
-pl.imshow(G0, interpolation='nearest') -pl.title('OT matrix') +pl.imshow(G0, interpolation="nearest") +pl.title("OT matrix") pl.subplot(1, 2, 2) -pl.imshow(Gweak, interpolation='nearest') -pl.title('Weak OT matrix') +pl.imshow(Gweak, interpolation="nearest") +pl.title("Weak OT matrix") pl.figure(4, (8, 5)) pl.subplot(1, 2, 1) -ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.title('OT matrix with samples') +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[0.5, 0.5, 1]) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.title("OT matrix with samples") pl.subplot(1, 2, 2) -ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1]) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.title('Weak OT matrix with samples') +ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[0.5, 0.5, 1]) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.title("Weak OT matrix with samples") diff --git a/examples/others/plot_dmmot.py b/examples/others/plot_dmmot.py index 1548ba470..a493f38fc 100644 --- a/examples/others/plot_dmmot.py +++ b/examples/others/plot_dmmot.py @@ -32,11 +32,11 @@ a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) A = np.vstack((a1, a2)).T x = np.arange(n, dtype=np.float64) -M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski') +M = ot.utils.dist(x.reshape((n, 1)), metric="minkowski") pl.figure(1, figsize=(6.4, 3)) -pl.plot(x, a1, 'b', label='Source distribution') -pl.plot(x, a2, 'r', label='Target distribution') +pl.plot(x, a1, "b", label="Source distribution") +pl.plot(x, a2, "r", label="Target distribution") pl.legend() # %% @@ -49,21 +49,23 @@ weights = np.ones(d) / d l2_bary = A.dot(weights) -print('LP Iterations:') +print("LP Iterations:") weights = np.ones(d) / d lp_bary, lp_log = ot.lp.barycenter( - A, M, weights, solver='interior-point', verbose=False, log=True) -print('Time\t: ', ot.toc('')) -print('Obj\t: ', lp_log['fun']) + A, M, weights, solver="interior-point", verbose=False, log=True +) +print("Time\t: ", ot.toc("")) +print("Obj\t: ", lp_log["fun"]) -print('') -print('Discrete MMOT Algorithm:') +print("") +print("Discrete MMOT Algorithm:") ot.tic() barys, log = ot.lp.dmmot_monge_1dgrid_optimize( - A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True) -dmmot_obj = log['primal objective'] -print('Time\t: ', ot.toc('')) -print('Obj\t: ', dmmot_obj) + A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True +) +dmmot_obj = log["primal objective"] +print("Time\t: ", ot.toc("")) +print("Obj\t: ", dmmot_obj) # %% # Compare Barycenters in both methods @@ -71,15 +73,15 @@ pl.figure(1, figsize=(6.4, 3)) for i in range(len(barys)): if i == 0: - pl.plot(x, barys[i], 'g-*', label='Discrete MMOT') + pl.plot(x, barys[i], "g-*", label="Discrete MMOT") else: continue # pl.plot(x, barys[i], 'g-*') -pl.plot(x, lp_bary, label='LP Barycenter') -pl.plot(x, l2_bary, label='L2 Barycenter') -pl.plot(x, a1, 'b', label='Source distribution') -pl.plot(x, a2, 'r', label='Target distribution') -pl.title('Monge Cost: Barycenters from LP Solver and dmmot solver') +pl.plot(x, lp_bary, label="LP Barycenter") +pl.plot(x, l2_bary, label="L2 Barycenter") +pl.plot(x, a1, "b", label="Source distribution") +pl.plot(x, a2, "r", label="Target distribution") +pl.title("Monge Cost: Barycenters from LP 
Solver and dmmot solver") pl.legend() @@ -98,14 +100,14 @@ data.append(a) x = np.arange(n, dtype=np.float64) -M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski') +M = ot.utils.dist(x.reshape((n, 1)), metric="minkowski") A = np.vstack(data).T pl.figure(1, figsize=(6.4, 3)) for i in range(len(data)): pl.plot(x, data[i]) -pl.title('Distributions') +pl.title("Distributions") pl.legend() # %% @@ -115,8 +117,7 @@ # values cannot be compared. # Perform gradient descent optimization using the d-MMOT method. -barys = ot.lp.dmmot_monge_1dgrid_optimize( - A, niters=3000, lr_init=1e-4, lr_decay=0.997) +barys = ot.lp.dmmot_monge_1dgrid_optimize(A, niters=3000, lr_init=1e-4, lr_decay=0.997) # after minimization, any distribution can be used as a estimate of barycenter. bary = barys[0] @@ -124,17 +125,18 @@ # Compute 1D Wasserstein barycenter using the L2/LP method weights = ot.unif(d) l2_bary = A.dot(weights) -lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point', - verbose=False, log=True) +lp_bary, bary_log = ot.lp.barycenter( + A, M, weights, solver="interior-point", verbose=False, log=True +) # %% # Compare Barycenters in both methods # --------- pl.figure(1, figsize=(6.4, 3)) -pl.plot(x, bary, 'g-*', label='Discrete MMOT') -pl.plot(x, l2_bary, 'k', label='L2 Barycenter') -pl.plot(x, lp_bary, 'k-', label='LP Wasserstein') -pl.title('Barycenters') +pl.plot(x, bary, "g-*", label="Discrete MMOT") +pl.plot(x, l2_bary, "k", label="L2 Barycenter") +pl.plot(x, lp_bary, "k-", label="LP Wasserstein") +pl.title("Barycenters") pl.legend() # %% @@ -145,13 +147,13 @@ pl.plot(x, data[i]) for i in range(len(barys)): if i == 0: - pl.plot(x, barys[i], 'g-*', label='Discrete MMOT') + pl.plot(x, barys[i], "g-*", label="Discrete MMOT") else: continue # pl.plot(x, barys[i], 'g') -pl.plot(x, l2_bary, 'k^', label='L2') -pl.plot(x, lp_bary, 'o', color='grey', label='LP') -pl.title('Barycenters') +pl.plot(x, l2_bary, "k^", label="L2") +pl.plot(x, lp_bary, "o", color="grey", label="LP") +pl.title("Barycenters") pl.legend() pl.show() diff --git a/examples/others/plot_factored_coupling.py b/examples/others/plot_factored_coupling.py index 02074d70b..3aaf5fbf5 100644 --- a/examples/others/plot_factored_coupling.py +++ b/examples/others/plot_factored_coupling.py @@ -29,32 +29,32 @@ n = 100 # nb samples -xs = np.random.rand(n, 2) - .5 +xs = np.random.rand(n, 2) - 0.5 xs = xs + np.sign(xs) -xt = np.random.rand(n, 2) - .5 +xt = np.random.rand(n, 2) - 0.5 a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples -#%% plot samples +# %% plot samples pl.figure(1) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") pl.legend(loc=0) -pl.title('Source and target distributions') +pl.title("Source and target distributions") # %% # Compute Factored OT and exact OT solutions # ------------------------------------------ -#%% EMD +# %% EMD M = ot.dist(xs, xt) G0 = ot.emd(a, b, M) -#%% factored OT OT +# %% factored OT OT Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4) @@ -66,21 +66,21 @@ pl.figure(2, (14, 4)) pl.subplot(1, 3, 1) -ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.title('Exact OT with samples') +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[0.2, 0.2, 0.2], alpha=0.1) 
+pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.title("Exact OT with samples") pl.subplot(1, 3, 2) -ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5) -ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples') -pl.title('Factored OT with template samples') +ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[0.6, 0.6, 0.9], alpha=0.5) +ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[0.9, 0.6, 0.6], alpha=0.5) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.plot(xb[:, 0], xb[:, 1], "og", label="Template samples") +pl.title("Factored OT with template samples") pl.subplot(1, 3, 3) -ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.title('Factored OT low rank OT plan') +ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[0.2, 0.2, 0.2], alpha=0.1) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.title("Factored OT low rank OT plan") diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py index b03280114..2710401a0 100644 --- a/examples/others/plot_logo.py +++ b/examples/others/plot_logo.py @@ -1,4 +1,3 @@ - # -*- coding: utf-8 -*- r""" ======================= @@ -29,12 +28,48 @@ # Letter P -p1 = np.array([[0, 6.], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], ]) -p2 = np.array([[1.5, 6], [2, 4], [2, 5], [1.5, 3], [0.5, 2], [.5, 1], ]) +p1 = np.array( + [ + [0, 6.0], + [0, 5], + [0, 4], + [0, 3], + [0, 2], + [0, 1], + ] +) +p2 = np.array( + [ + [1.5, 6], + [2, 4], + [2, 5], + [1.5, 3], + [0.5, 2], + [0.5, 1], + ] +) # Letter O -o1 = np.array([[0, 6.], [-1, 5], [-1.5, 4], [-1.5, 3], [-1, 2], [0, 1], ]) -o2 = np.array([[1, 6.], [2, 5], [2.5, 4], [2.5, 3], [2, 2], [1, 1], ]) +o1 = np.array( + [ + [0, 6.0], + [-1, 5], + [-1.5, 4], + [-1.5, 3], + [-1, 2], + [0, 1], + ] +) +o2 = np.array( + [ + [1, 6.0], + [2, 5], + [2.5, 4], + [2.5, 3], + [2, 2], + [1, 1], + ] +) # Scaling and translation for letter O o1[:, 0] += 6.4 @@ -43,8 +78,26 @@ o2[:, 0] *= 0.6 # Letter T -t1 = np.array([[-1, 6.], [-1, 5], [0, 4], [0, 3], [0, 2], [0, 1], ]) -t2 = np.array([[1.5, 6.], [1.5, 5], [0.5, 4], [0.5, 3], [0.5, 2], [0.5, 1], ]) +t1 = np.array( + [ + [-1, 6.0], + [-1, 5], + [0, 4], + [0, 3], + [0, 2], + [0, 1], + ] +) +t2 = np.array( + [ + [1.5, 6.0], + [1.5, 5], + [0.5, 4], + [0.5, 3], + [0.5, 2], + [0.5, 1], + ] +) # Translating the T t1[:, 0] += 7.1 @@ -56,7 +109,7 @@ # Horizontal and vertical scaling sx = 1.0 -sy = .5 +sy = 0.5 x1[:, 0] *= sx x1[:, 1] *= sy x2[:, 0] *= sx @@ -67,7 +120,7 @@ # -------------------------------- # Solve OT problem between the points -M = ot.dist(x1, x2, metric='euclidean') +M = ot.dist(x1, x2, metric="euclidean") T = ot.emd([], [], M) pl.figure(1, (3.5, 1.1)) @@ -76,14 +129,21 @@ for i in range(M.shape[0]): for j in range(M.shape[1]): if T[i, j] > 1e-8: - pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='k', alpha=0.6, linewidth=3, zorder=1) + pl.plot( + [x1[i, 0], x2[j, 0]], + [x1[i, 1], x2[j, 1]], + color="k", + alpha=0.6, + linewidth=3, + zorder=1, + ) # plot the samples -pl.plot(x1[:, 0], x1[:, 1], 'o', 
markerfacecolor='C3', markeredgecolor='k') -pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='b', markeredgecolor='k') +pl.plot(x1[:, 0], x1[:, 1], "o", markerfacecolor="C3", markeredgecolor="k") +pl.plot(x2[:, 0], x2[:, 1], "o", markerfacecolor="b", markeredgecolor="k") -pl.axis('equal') -pl.axis('off') +pl.axis("equal") +pl.axis("off") # Save logo file # pl.savefig('logo.svg', dpi=150, transparent=True, bbox_inches='tight') @@ -93,19 +153,26 @@ # Plot the logo (dark background) # -------------------------------- -pl.figure(2, (3.5, 1.1), facecolor='darkgray') +pl.figure(2, (3.5, 1.1), facecolor="darkgray") pl.clf() # plot the OT plan for i in range(M.shape[0]): for j in range(M.shape[1]): if T[i, j] > 1e-8: - pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='w', alpha=0.8, linewidth=3, zorder=1) + pl.plot( + [x1[i, 0], x2[j, 0]], + [x1[i, 1], x2[j, 1]], + color="w", + alpha=0.8, + linewidth=3, + zorder=1, + ) # plot the samples -pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='w', markeredgecolor='w') -pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='w', markeredgecolor='w') +pl.plot(x1[:, 0], x1[:, 1], "o", markerfacecolor="w", markeredgecolor="w") +pl.plot(x2[:, 0], x2[:, 1], "o", markerfacecolor="w", markeredgecolor="w") -pl.axis('equal') -pl.axis('off') +pl.axis("equal") +pl.axis("off") # Save logo file # pl.savefig('logo_dark.svg', dpi=150, transparent=True, bbox_inches='tight') diff --git a/examples/others/plot_lowrank_GW.py b/examples/others/plot_lowrank_GW.py index 02fef6ded..ff1929a68 100644 --- a/examples/others/plot_lowrank_GW.py +++ b/examples/others/plot_lowrank_GW.py @@ -4,7 +4,7 @@ Low rank Gromov-Wasterstein between samples ======================================== -Comparaison between entropic Gromov-Wasserstein and Low Rank Gromov Wasserstein [67] +Comparison between entropic Gromov-Wasserstein and Low Rank Gromov Wasserstein [67] on two curves in 2D and 3D, both sampled with 200 points. The squared Euclidean distance is considered as the ground cost for both samples. 
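# Illustrative standalone sketch (not part of this patch) of the two solvers
# compared in this example, run on small random point clouds instead of the
# curves; the epsilon and rank values below are arbitrary toy choices.
import numpy as np
import ot

rng = np.random.RandomState(0)
X = rng.randn(30, 2)  # toy 2D source samples
Y = rng.randn(30, 3)  # toy 3D target samples

# squared Euclidean ground costs within each space, rescaled for the entropic solver
C1 = ot.dist(X, X, metric="sqeuclidean")
C2 = ot.dist(Y, Y, metric="sqeuclidean")
C1 /= C1.max()
C2 /= C2.max()

gw, log_e = ot.gromov.entropic_gromov_wasserstein(C1, C2, epsilon=5e-2, log=True)
print("entropic GW loss:", log_e["gw_dist"])

Q, R, g, log_lr = ot.lowrank_gromov_wasserstein_samples(X, Y, reg=0, rank=6, log=True)
print("low-rank coupling factors:", Q.shape, R.shape, g.shape)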
@@ -20,7 +20,7 @@ # # sphinx_gallery_thumbnail_number = 3 -#%% +# %% import numpy as np import matplotlib.pylab as pl import ot.plot @@ -30,7 +30,7 @@ # Generate data # ------------- -#%% parameters +# %% parameters n_samples = 200 # Generate 2D and 3D curves @@ -49,20 +49,22 @@ # Plot data # ------------ -#%% +# %% # Plot the source and target samples fig = pl.figure(1, figsize=(10, 4)) ax = fig.add_subplot(121) ax.plot(X[:, 0], X[:, 1], color="blue", linewidth=6) -ax.tick_params(left=False, right=False, labelleft=False, - labelbottom=False, bottom=False) +ax.tick_params( + left=False, right=False, labelleft=False, labelbottom=False, bottom=False +) ax.set_title("2D curve (source)") ax2 = fig.add_subplot(122, projection="3d") -ax2.plot(Y[:, 0], Y[:, 1], Y[:, 2], c='red', linewidth=6) -ax2.tick_params(left=False, right=False, labelleft=False, - labelbottom=False, bottom=False) +ax2.plot(Y[:, 0], Y[:, 1], Y[:, 2], c="red", linewidth=6) +ax2.tick_params( + left=False, right=False, labelleft=False, labelbottom=False, bottom=False +) ax2.view_init(15, -50) ax2.set_title("3D curve (target)") @@ -74,7 +76,7 @@ # Entropic Gromov-Wasserstein # ------------ -#%% +# %% # Compute cost matrices C1 = ot.dist(X, X, metric="sqeuclidean") @@ -93,13 +95,13 @@ start = time.time() gw, log = ot.gromov.entropic_gromov_wasserstein( - C1, C2, tol=1e-3, epsilon=reg, - log=True, verbose=False) + C1, C2, tol=1e-3, epsilon=reg, log=True, verbose=False +) end = time.time() time_entropic = end - start -entropic_gw_loss = np.round(log['gw_dist'], 3) +entropic_gw_loss = np.round(log["gw_dist"], 3) # Plot entropic gw pl.figure(2) @@ -137,8 +139,17 @@ start = time.time() Q, R, g, log = ot.lowrank_gromov_wasserstein_samples( - X, Y, reg=0, rank=rank, rescale_cost=False, cost_factorized_Xs=(A1, A2), - cost_factorized_Xt=(B1, B2), seed_init=49, numItermax=1000, log=True, stopThr=1e-6, + X, + Y, + reg=0, + rank=rank, + rescale_cost=False, + cost_factorized_Xs=(A1, A2), + cost_factorized_Xt=(B1, B2), + seed_init=49, + numItermax=1000, + log=True, + stopThr=1e-6, ) end = time.time() @@ -156,11 +167,11 @@ pl.subplot(1, 2, 1) pl.imshow(list_P_GW[0], interpolation="nearest", aspect="auto") -pl.title('Low rank GW (rank=10, loss={})'.format(list_loss_GW[0])) +pl.title("Low rank GW (rank=10, loss={})".format(list_loss_GW[0])) pl.subplot(1, 2, 2) pl.imshow(list_P_GW[1], interpolation="nearest", aspect="auto") -pl.title('Low rank GW (rank=50, loss={})'.format(list_loss_GW[1])) +pl.title("Low rank GW (rank=50, loss={})".format(list_loss_GW[1])) pl.tight_layout() pl.show() diff --git a/examples/others/plot_lowrank_sinkhorn.py b/examples/others/plot_lowrank_sinkhorn.py index 0664a829c..f48fc873a 100644 --- a/examples/others/plot_lowrank_sinkhorn.py +++ b/examples/others/plot_lowrank_sinkhorn.py @@ -25,16 +25,20 @@ # Generate data # ------------- -#%% parameters +# %% parameters n = 100 m = 120 # Gaussian distribution -a = gauss(n, m=int(n / 3), s=25 / np.sqrt(2)) + 1.5 * gauss(n, m=int(5 * n / 6), s=15 / np.sqrt(2)) +a = gauss(n, m=int(n / 3), s=25 / np.sqrt(2)) + 1.5 * gauss( + n, m=int(5 * n / 6), s=15 / np.sqrt(2) +) a = a / np.sum(a) -b = 2 * gauss(m, m=int(m / 5), s=30 / np.sqrt(2)) + gauss(m, m=int(m / 2), s=35 / np.sqrt(2)) +b = 2 * gauss(m, m=int(m / 5), s=30 / np.sqrt(2)) + gauss( + m, m=int(m / 2), s=35 / np.sqrt(2) +) b = b / np.sum(b) # Source and target distribution @@ -46,12 +50,23 @@ # Solve Low rank sinkhorn # ------------ -#%% +# %% # Solve low rank sinkhorn -Q, R, g, log = ot.lowrank_sinkhorn(X, Y, a, b, rank=10, 
init="random", gamma_init="rescale", rescale_cost=True, warn=False, log=True) +Q, R, g, log = ot.lowrank_sinkhorn( + X, + Y, + a, + b, + rank=10, + init="random", + gamma_init="rescale", + rescale_cost=True, + warn=False, + log=True, +) P = log["lazy_plan"][:] -ot.plot.plot1D_mat(a, b, P, 'OT matrix Low rank') +ot.plot.plot1D_mat(a, b, P, "OT matrix Low rank") ############################################################################## @@ -59,7 +74,7 @@ # ----------------------- # Compare Sinkhorn and Low rank sinkhorn with different regularizations and ranks. -#%% Sinkhorn +# %% Sinkhorn # Compute cost matrix for sinkhorn OT M = ot.dist(X, Y) @@ -73,52 +88,52 @@ P = ot.solve(M, a, b, reg=reg, max_iter=2000, tol=1e-8).plan list_P_Sin.append(P) -#%% Low rank sinkhorn +# %% Low rank sinkhorn # Solve low rank sinkhorn with different ranks using ot.solve_sample list_rank = [3, 10, 50] list_P_LR = [] for rank in list_rank: - P = ot.solve_sample(X, Y, a, b, method='lowrank', rank=rank).plan + P = ot.solve_sample(X, Y, a, b, method="lowrank", rank=rank).plan P = P[:] list_P_LR.append(P) -#%% +# %% # Plot sinkhorn vs low rank sinkhorn pl.figure(1, figsize=(10, 8)) pl.subplot(2, 3, 1) -pl.imshow(list_P_Sin[0], interpolation='nearest') -pl.axis('off') -pl.title('Sinkhorn (reg=0.05)') +pl.imshow(list_P_Sin[0], interpolation="nearest") +pl.axis("off") +pl.title("Sinkhorn (reg=0.05)") pl.subplot(2, 3, 2) -pl.imshow(list_P_Sin[1], interpolation='nearest') -pl.axis('off') -pl.title('Sinkhorn (reg=0.005)') +pl.imshow(list_P_Sin[1], interpolation="nearest") +pl.axis("off") +pl.title("Sinkhorn (reg=0.005)") pl.subplot(2, 3, 3) -pl.imshow(list_P_Sin[2], interpolation='nearest') -pl.axis('off') -pl.title('Sinkhorn (reg=0.001)') +pl.imshow(list_P_Sin[2], interpolation="nearest") +pl.axis("off") +pl.title("Sinkhorn (reg=0.001)") pl.show() pl.subplot(2, 3, 4) -pl.imshow(list_P_LR[0], interpolation='nearest') -pl.axis('off') -pl.title('Low rank (rank=3)') +pl.imshow(list_P_LR[0], interpolation="nearest") +pl.axis("off") +pl.title("Low rank (rank=3)") pl.subplot(2, 3, 5) -pl.imshow(list_P_LR[1], interpolation='nearest') -pl.axis('off') -pl.title('Low rank (rank=10)') +pl.imshow(list_P_LR[1], interpolation="nearest") +pl.axis("off") +pl.title("Low rank (rank=10)") pl.subplot(2, 3, 6) -pl.imshow(list_P_LR[2], interpolation='nearest') -pl.axis('off') -pl.title('Low rank (rank=50)') +pl.imshow(list_P_LR[2], interpolation="nearest") +pl.axis("off") +pl.title("Low rank (rank=50)") pl.tight_layout() diff --git a/examples/others/plot_outlier_detection_with_COOT_and_unbalanced_COOT.py b/examples/others/plot_outlier_detection_with_COOT_and_unbalanced_COOT.py index e1f48f724..5273fdbb1 100644 --- a/examples/others/plot_outlier_detection_with_COOT_and_unbalanced_COOT.py +++ b/examples/others/plot_outlier_detection_with_COOT_and_unbalanced_COOT.py @@ -5,7 +5,7 @@ ====================================================================================================================================== In this example, we consider two point clouds living in different Euclidean spaces, where the outliers -are artifically injected into the target data. We illustrate two methods which allow to filter out +are artificially injected into the target data. We illustrate two methods which allow to filter out these outliers. 
The first method requires learning the sample marginal distribution which minimizes @@ -70,15 +70,15 @@ n = 15 X = ( - torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] + - torch.cos(torch.arange(d1) * torch.pi / d1)[None, :] + torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] + + torch.cos(torch.arange(d1) * torch.pi / d1)[None, :] ) # Generate clean target data mixed with outliers Y_noisy = torch.randn((n, d2)) * 10.0 Y_noisy[:n2, :] = ( - torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] + - torch.cos(torch.arange(d2) * torch.pi / d2)[None, :] + torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] + + torch.cos(torch.arange(d2) * torch.pi / d2)[None, :] ) Y = Y_noisy[:n2, :] @@ -86,13 +86,13 @@ fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5)) axes[0].imshow(X, vmin=-2, vmax=2) -axes[0].set_title('$X$') +axes[0].set_title("$X$") axes[1].imshow(Y, vmin=-2, vmax=2) -axes[1].set_title('Clean $Y$') +axes[1].set_title("Clean $Y$") axes[2].imshow(Y_noisy, vmin=-2, vmax=2) -axes[2].set_title('Noisy $Y$') +axes[2].set_title("Noisy $Y$") pl.tight_layout() @@ -107,7 +107,6 @@ b = torch.tensor(ot.unif(n), requires_grad=True) for i in range(niter): - loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False) losses.append(float(loss)) @@ -121,7 +120,7 @@ # Estimated sample marginal distribution and training loss curve pl.plot(losses[10:]) -pl.title('CO-Optimal Transport distance') +pl.title("CO-Optimal Transport distance") print(f"Marginal distribution = {b.detach().numpy()}") @@ -141,7 +140,7 @@ ax1 = pl.subplot(2, 2, 3) pl.imshow(X, vmin=-2, vmax=2) -pl.xlabel('$X$') +pl.xlabel("$X$") ax2 = pl.subplot(2, 2, 2) ax2.yaxis.tick_right() @@ -151,18 +150,20 @@ for i in range(n1): j = np.argmax(pi_sample[i, :]) - xyA = (d1 - .5, i) - xyB = (j, d2 - .5) - con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData, - coordsB=ax2.transData, color="black") + xyA = (d1 - 0.5, i) + xyB = (j, d2 - 0.5) + con = ConnectionPatch( + xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="black" + ) fig.add_artist(con) for i in range(d1): j = np.argmax(pi_feature[i, :]) - xyA = (i, -.5) - xyB = (-.5, j) + xyA = (i, -0.5) + xyB = (-0.5, j) con = ConnectionPatch( - xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue") + xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue" + ) fig.add_artist(con) # %% @@ -171,9 +172,18 @@ # ----------------------------------------------------------------------------------------- pi_sample, pi_feature = unbalanced_co_optimal_transport( - X=X, Y=Y_noisy, reg_marginals=(10, 10), epsilon=0, divergence="kl", - unbalanced_solver="mm", max_iter=1000, tol=1e-6, - max_iter_ot=1000, tol_ot=1e-6, log=False, verbose=False + X=X, + Y=Y_noisy, + reg_marginals=(10, 10), + epsilon=0, + divergence="kl", + unbalanced_solver="mm", + max_iter=1000, + tol=1e-6, + max_iter_ot=1000, + tol_ot=1e-6, + log=False, + verbose=False, ) # %% @@ -187,7 +197,7 @@ ax1 = pl.subplot(2, 2, 3) pl.imshow(X, vmin=-2, vmax=2) -pl.xlabel('$X$') +pl.xlabel("$X$") ax2 = pl.subplot(2, 2, 2) ax2.yaxis.tick_right() @@ -197,16 +207,18 @@ for i in range(n1): j = np.argmax(pi_sample[i, :]) - xyA = (d1 - .5, i) - xyB = (j, d2 - .5) - con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData, - coordsB=ax2.transData, color="black") + xyA = (d1 - 0.5, i) + xyB = (j, d2 - 0.5) + con = ConnectionPatch( + xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="black" + ) fig.add_artist(con) for i in range(d1): j = 
np.argmax(pi_feature[i, :]) - xyA = (i, -.5) - xyB = (-.5, j) + xyA = (i, -0.5) + xyB = (-0.5, j) con = ConnectionPatch( - xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue") + xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue" + ) fig.add_artist(con) diff --git a/examples/others/plot_screenkhorn_1D.py b/examples/others/plot_screenkhorn_1D.py index 3640b8840..b182f02d7 100644 --- a/examples/others/plot_screenkhorn_1D.py +++ b/examples/others/plot_screenkhorn_1D.py @@ -25,7 +25,7 @@ # Generate data # ------------- -#%% parameters +# %% parameters n = 100 # nb bins @@ -44,17 +44,17 @@ # Plot distributions and loss matrix # ---------------------------------- -#%% plot the distributions +# %% plot the distributions pl.figure(1, figsize=(6.4, 3)) -pl.plot(x, a, 'b', label='Source distribution') -pl.plot(x, b, 'r', label='Target distribution') +pl.plot(x, a, "b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") pl.legend() # plot distributions and loss matrix pl.figure(2, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') +ot.plot.plot1D_mat(a, b, M, "Cost matrix M") ############################################################################## # Solve Screenkhorn @@ -65,7 +65,9 @@ ns_budget = 30 # budget number of points to be kept in the source distribution nt_budget = 30 # budget number of points to be kept in the target distribution -G_screen = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True) +G_screen = screenkhorn( + a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True +) pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, G_screen, 'OT matrix Screenkhorn') +ot.plot.plot1D_mat(a, b, G_screen, "OT matrix Screenkhorn") pl.show() diff --git a/examples/others/plot_stochastic.py b/examples/others/plot_stochastic.py index f3afb0b25..0fa559a78 100644 --- a/examples/others/plot_stochastic.py +++ b/examples/others/plot_stochastic.py @@ -53,8 +53,7 @@ # Call the "SAG" method to find the transportation matrix in the discrete case method = "SAG" -sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, - numItermax) +sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax) print(sag_pi) ############################################################################# @@ -84,9 +83,10 @@ # case. 
method = "ASGD" -asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, - numItermax, log=log) -print(log_asgd['alpha'], log_asgd['beta']) +asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic( + a, b, M, reg, method, numItermax, log=log +) +print(log_asgd["alpha"], log_asgd["beta"]) print(asgd_pi) ############################################################################# @@ -103,7 +103,7 @@ # For SAG pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG') +ot.plot.plot1D_mat(a, b, sag_pi, "semi-dual : OT matrix SAG") pl.show() @@ -111,7 +111,7 @@ # For ASGD pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD') +ot.plot.plot1D_mat(a, b, asgd_pi, "semi-dual : OT matrix ASGD") pl.show() @@ -119,7 +119,7 @@ # For Sinkhorn pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') +ot.plot.plot1D_mat(a, b, sinkhorn_pi, "OT matrix Sinkhorn") pl.show() @@ -154,10 +154,10 @@ # Call the "SGD" dual method to find the transportation matrix in the # semi-continuous case -sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, - batch_size, numItermax, - lr, log=log) -print(log_sgd['alpha'], log_sgd['beta']) +sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic( + a, b, M, reg, batch_size, numItermax, lr, log=log +) +print(log_sgd["alpha"], log_sgd["beta"]) print(sgd_dual_pi) ############################################################################# @@ -177,7 +177,7 @@ # For SGD pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD') +ot.plot.plot1D_mat(a, b, sgd_dual_pi, "dual : OT matrix SGD") pl.show() @@ -185,5 +185,5 @@ # For Sinkhorn pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') +ot.plot.plot1D_mat(a, b, sinkhorn_pi, "OT matrix Sinkhorn") pl.show() diff --git a/examples/plot_Intro_OT.py b/examples/plot_Intro_OT.py index 1c5136051..8eae721f6 100644 --- a/examples/plot_Intro_OT.py +++ b/examples/plot_Intro_OT.py @@ -79,17 +79,17 @@ # # -data = np.load('../data/manhattan.npz') +data = np.load("../data/manhattan.npz") -bakery_pos = data['bakery_pos'] -bakery_prod = data['bakery_prod'] -cafe_pos = data['cafe_pos'] -cafe_prod = data['cafe_prod'] -Imap = data['Imap'] +bakery_pos = data["bakery_pos"] +bakery_prod = data["bakery_prod"] +cafe_pos = data["cafe_pos"] +cafe_prod = data["cafe_prod"] +Imap = data["Imap"] -print('Bakery production: {}'.format(bakery_prod)) -print('Cafe sale: {}'.format(cafe_prod)) -print('Total croissants : {}'.format(cafe_prod.sum())) +print("Bakery production: {}".format(bakery_prod)) +print("Cafe sale: {}".format(cafe_prod)) +print("Total croissants : {}".format(cafe_prod.sum())) ############################################################################## @@ -102,11 +102,13 @@ pl.figure(1, (7, 6)) pl.clf() -pl.imshow(Imap, interpolation='bilinear') # plot the map -pl.scatter(bakery_pos[:, 0], bakery_pos[:, 1], s=bakery_prod, c='r', ec='k', label='Bakeries') -pl.scatter(cafe_pos[:, 0], cafe_pos[:, 1], s=cafe_prod, c='b', ec='k', label='Cafés') +pl.imshow(Imap, interpolation="bilinear") # plot the map +pl.scatter( + bakery_pos[:, 0], bakery_pos[:, 1], s=bakery_prod, c="r", ec="k", label="Bakeries" +) +pl.scatter(cafe_pos[:, 0], cafe_pos[:, 1], s=cafe_prod, c="b", ec="k", label="Cafés") pl.legend() -pl.title('Manhattan Bakeries and Cafés') +pl.title("Manhattan Bakeries and Cafés") 
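# Compact toy version (not part of this patch) of the croissant-transport
# computation in this example; positions, production and sale figures are made
# up, with equal totals as balanced OT requires.
import numpy as np
import ot

rng = np.random.RandomState(42)
bakery_xy = rng.rand(4, 2) * 100  # toy bakery positions
cafe_xy = rng.rand(5, 2) * 100  # toy café positions
prod = np.array([30.0, 20.0, 25.0, 25.0])  # toy production (total 100)
sale = np.array([20.0, 20.0, 20.0, 20.0, 20.0])  # toy sales (total 100)

C_toy = ot.dist(bakery_xy, cafe_xy)  # pairwise costs (squared Euclidean by default)
G_toy = ot.emd(prod, sale, C_toy)  # exact OT plan routing croissants
print("Wasserstein loss (EMD):", np.sum(G_toy * C_toy))

# entropic counterpart, as compared later in this example
G_sink = ot.sinkhorn(prod, sale, reg=0.1, M=C_toy / C_toy.max())
print("min of Sinkhorn plan:", np.min(G_sink))  # dense plan with strictly positive entries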
############################################################################## @@ -127,23 +129,39 @@ f = pl.figure(2, (14, 7)) pl.clf() pl.subplot(121) -pl.imshow(Imap, interpolation='bilinear') # plot the map +pl.imshow(Imap, interpolation="bilinear") # plot the map for i in range(len(cafe_pos)): - pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b', - fontsize=14, fontweight='bold', ha='center', va='center') + pl.text( + cafe_pos[i, 0], + cafe_pos[i, 1], + labels[i], + color="b", + fontsize=14, + fontweight="bold", + ha="center", + va="center", + ) for i in range(len(bakery_pos)): - pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r', - fontsize=14, fontweight='bold', ha='center', va='center') -pl.title('Manhattan Bakeries and Cafés') + pl.text( + bakery_pos[i, 0], + bakery_pos[i, 1], + labels[i], + color="r", + fontsize=14, + fontweight="bold", + ha="center", + va="center", + ) +pl.title("Manhattan Bakeries and Cafés") ax = pl.subplot(122) im = pl.imshow(C, cmap="coolwarm") -pl.title('Cost matrix') +pl.title("Cost matrix") cbar = pl.colorbar(im, ax=ax, shrink=0.5, use_gridspec=True) cbar.ax.set_ylabel("cost", rotation=-90, va="bottom") -pl.xlabel('Cafés') -pl.ylabel('Bakeries') +pl.xlabel("Cafés") +pl.ylabel("Bakeries") pl.tight_layout() @@ -180,29 +198,50 @@ f = pl.figure(3, (14, 7)) pl.clf() pl.subplot(121) -pl.imshow(Imap, interpolation='bilinear') # plot the map +pl.imshow(Imap, interpolation="bilinear") # plot the map for i in range(len(bakery_pos)): for j in range(len(cafe_pos)): - pl.plot([bakery_pos[i, 0], cafe_pos[j, 0]], [bakery_pos[i, 1], cafe_pos[j, 1]], - '-k', lw=3. * ot_emd[i, j] / ot_emd.max()) + pl.plot( + [bakery_pos[i, 0], cafe_pos[j, 0]], + [bakery_pos[i, 1], cafe_pos[j, 1]], + "-k", + lw=3.0 * ot_emd[i, j] / ot_emd.max(), + ) for i in range(len(cafe_pos)): - pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b', fontsize=14, - fontweight='bold', ha='center', va='center') + pl.text( + cafe_pos[i, 0], + cafe_pos[i, 1], + labels[i], + color="b", + fontsize=14, + fontweight="bold", + ha="center", + va="center", + ) for i in range(len(bakery_pos)): - pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r', fontsize=14, - fontweight='bold', ha='center', va='center') -pl.title('Manhattan Bakeries and Cafés') + pl.text( + bakery_pos[i, 0], + bakery_pos[i, 1], + labels[i], + color="r", + fontsize=14, + fontweight="bold", + ha="center", + va="center", + ) +pl.title("Manhattan Bakeries and Cafés") ax = pl.subplot(122) im = pl.imshow(ot_emd) for i in range(len(bakery_prod)): for j in range(len(cafe_prod)): - text = ax.text(j, i, '{0:g}'.format(ot_emd[i, j]), - ha="center", va="center", color="w") -pl.title('Transport matrix') + text = ax.text( + j, i, "{0:g}".format(ot_emd[i, j]), ha="center", va="center", color="w" + ) +pl.title("Transport matrix") -pl.xlabel('Cafés') -pl.ylabel('Bakeries') +pl.xlabel("Cafés") +pl.ylabel("Bakeries") pl.tight_layout() ############################################################################## @@ -224,7 +263,7 @@ # W = np.sum(ot_emd * C) -print('Wasserstein loss (EMD) = {0:.2f}'.format(W)) +print("Wasserstein loss (EMD) = {0:.2f}".format(W)) ############################################################################## # Regularized OT with Sinkhorn @@ -255,51 +294,77 @@ reg = 0.1 K = np.exp(-C / C.max() / reg) nit = 100 -u = np.ones((len(bakery_prod), )) +u = np.ones((len(bakery_prod),)) for i in range(1, nit): v = cafe_prod / np.dot(K.T, u) u = bakery_prod / (np.dot(K, v)) -ot_sink_algo = 
np.atleast_2d(u).T * (K * v.T) # Equivalent to np.dot(np.diag(u), np.dot(K, np.diag(v))) +ot_sink_algo = np.atleast_2d(u).T * ( + K * v.T +) # Equivalent to np.dot(np.diag(u), np.dot(K, np.diag(v))) # Compute Sinkhorn transport matrix with POT ot_sinkhorn = ot.sinkhorn(bakery_prod, cafe_prod, reg=reg, M=C / C.max()) # Difference between the 2 -print('Difference between algo and ot.sinkhorn = {0:.2g}'.format(np.sum(np.power(ot_sink_algo - ot_sinkhorn, 2)))) +print( + "Difference between algo and ot.sinkhorn = {0:.2g}".format( + np.sum(np.power(ot_sink_algo - ot_sinkhorn, 2)) + ) +) ############################################################################## # Plot the matrix and the map # ``````````````````````````` -print('Min. of Sinkhorn\'s transport matrix = {0:.2g}'.format(np.min(ot_sinkhorn))) +print("Min. of Sinkhorn's transport matrix = {0:.2g}".format(np.min(ot_sinkhorn))) f = pl.figure(4, (13, 6)) pl.clf() pl.subplot(121) -pl.imshow(Imap, interpolation='bilinear') # plot the map +pl.imshow(Imap, interpolation="bilinear") # plot the map for i in range(len(bakery_pos)): for j in range(len(cafe_pos)): - pl.plot([bakery_pos[i, 0], cafe_pos[j, 0]], - [bakery_pos[i, 1], cafe_pos[j, 1]], - '-k', lw=3. * ot_sinkhorn[i, j] / ot_sinkhorn.max()) + pl.plot( + [bakery_pos[i, 0], cafe_pos[j, 0]], + [bakery_pos[i, 1], cafe_pos[j, 1]], + "-k", + lw=3.0 * ot_sinkhorn[i, j] / ot_sinkhorn.max(), + ) for i in range(len(cafe_pos)): - pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b', - fontsize=14, fontweight='bold', ha='center', va='center') + pl.text( + cafe_pos[i, 0], + cafe_pos[i, 1], + labels[i], + color="b", + fontsize=14, + fontweight="bold", + ha="center", + va="center", + ) for i in range(len(bakery_pos)): - pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r', - fontsize=14, fontweight='bold', ha='center', va='center') -pl.title('Manhattan Bakeries and Cafés') + pl.text( + bakery_pos[i, 0], + bakery_pos[i, 1], + labels[i], + color="r", + fontsize=14, + fontweight="bold", + ha="center", + va="center", + ) +pl.title("Manhattan Bakeries and Cafés") ax = pl.subplot(122) im = pl.imshow(ot_sinkhorn) for i in range(len(bakery_prod)): for j in range(len(cafe_prod)): - text = ax.text(j, i, np.round(ot_sinkhorn[i, j], 1), - ha="center", va="center", color="w") -pl.title('Transport matrix') + text = ax.text( + j, i, np.round(ot_sinkhorn[i, j], 1), ha="center", va="center", color="w" + ) +pl.title("Transport matrix") -pl.xlabel('Cafés') -pl.ylabel('Bakeries') +pl.xlabel("Cafés") +pl.ylabel("Bakeries") pl.tight_layout() @@ -315,23 +380,25 @@ # reg_parameter = np.logspace(-3, 0, 20) -W_sinkhorn_reg = np.zeros((len(reg_parameter), )) -time_sinkhorn_reg = np.zeros((len(reg_parameter), )) +W_sinkhorn_reg = np.zeros((len(reg_parameter),)) +time_sinkhorn_reg = np.zeros((len(reg_parameter),)) f = pl.figure(5, (14, 5)) pl.clf() max_ot = 100 # plot matrices with the same colorbar for k in range(len(reg_parameter)): start = time.time() - ot_sinkhorn = ot.sinkhorn(bakery_prod, cafe_prod, reg=reg_parameter[k], M=C / C.max()) + ot_sinkhorn = ot.sinkhorn( + bakery_prod, cafe_prod, reg=reg_parameter[k], M=C / C.max() + ) time_sinkhorn_reg[k] = time.time() - start if k % 4 == 0 and k > 0: # we only plot a few ax = pl.subplot(1, 5, k // 4) im = pl.imshow(ot_sinkhorn, vmin=0, vmax=max_ot) - pl.title('reg={0:.2g}'.format(reg_parameter[k])) - pl.xlabel('Cafés') - pl.ylabel('Bakeries') + pl.title("reg={0:.2g}".format(reg_parameter[k])) + pl.xlabel("Cafés") + pl.ylabel("Bakeries") # Compute the 
Wasserstein loss for Sinkhorn, and compare with EMD W_sinkhorn_reg[k] = np.sum(ot_sinkhorn * C) @@ -355,9 +422,9 @@ pl.clf() pl.title("Comparison between Sinkhorn and EMD") -pl.plot(reg_parameter, W_sinkhorn_reg, 'o', label="Sinkhorn") +pl.plot(reg_parameter, W_sinkhorn_reg, "o", label="Sinkhorn") XLim = pl.xlim() -pl.plot(XLim, [W, W], '--k', label="EMD") +pl.plot(XLim, [W, W], "--k", label="EMD") pl.legend() pl.xlabel("reg") pl.ylabel("Wasserstein loss") diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py index 22ddd14e5..d6c8c561f 100644 --- a/examples/plot_OT_1D.py +++ b/examples/plot_OT_1D.py @@ -26,7 +26,7 @@ # ------------- -#%% parameters +# %% parameters n = 100 # nb bins @@ -46,24 +46,24 @@ # Plot distributions and loss matrix # ---------------------------------- -#%% plot the distributions +# %% plot the distributions pl.figure(1, figsize=(6.4, 3)) -pl.plot(x, a, 'b', label='Source distribution') -pl.plot(x, b, 'r', label='Target distribution') +pl.plot(x, a, "b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") pl.legend() -#%% plot distributions and loss matrix +# %% plot distributions and loss matrix pl.figure(2, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') +ot.plot.plot1D_mat(a, b, M, "Cost matrix M") ############################################################################## # Solve EMD # --------- -#%% EMD +# %% EMD # use fast 1D solver G0 = ot.emd_1d(x, x, a, b) @@ -72,19 +72,19 @@ # G0 = ot.emd(a, b, M) pl.figure(3, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') +ot.plot.plot1D_mat(a, b, G0, "OT matrix G0") ############################################################################## # Solve Sinkhorn # -------------- -#%% Sinkhorn +# %% Sinkhorn lambd = 1e-3 Gs = ot.sinkhorn(a, b, M, lambd, verbose=True) pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn') +ot.plot.plot1D_mat(a, b, Gs, "OT matrix Sinkhorn") pl.show() diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index 4f233fe86..f9d3c2d4b 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -27,7 +27,7 @@ # ------------- -#%% parameters +# %% parameters n = 100 # nb bins @@ -47,17 +47,17 @@ # Plot distributions and loss matrix # ---------------------------------- -#%% plot the distributions +# %% plot the distributions pl.figure(1, figsize=(6.4, 3)) -pl.plot(x, a, 'b', label='Source distribution') -pl.plot(x, b, 'r', label='Target distribution') +pl.plot(x, a, "b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") pl.legend() -#%% plot distributions and loss matrix +# %% plot distributions and loss matrix pl.figure(2, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') +ot.plot.plot1D_mat(a, b, M, "Cost matrix M") ############################################################################## @@ -65,35 +65,36 @@ # --------------- -#%% Smooth OT with KL regularization +# %% Smooth OT with KL regularization lambd = 2e-3 -Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl') +Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type="kl") pl.figure(3, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.') +ot.plot.plot1D_mat(a, b, Gsm, "OT matrix Smooth OT KL reg.") pl.show() -#%% Smooth OT with squared l2 regularization +# %% Smooth OT with squared l2 regularization lambd = 1e-1 -Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2') +Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, 
reg_type="l2") pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.') +ot.plot.plot1D_mat(a, b, Gsm, "OT matrix Smooth OT l2 reg.") pl.show() -#%% Sparsity-constrained OT +# %% Sparsity-constrained OT lambd = 1e-1 max_nz = 2 # two non-zero entries are permitted per column of the OT plan Gsc = ot.smooth.smooth_ot_dual( - a, b, M, lambd, reg_type='sparsity_constrained', max_nz=max_nz) + a, b, M, lambd, reg_type="sparsity_constrained", max_nz=max_nz +) pl.figure(5, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, Gsc, 'Sparsity constrained OT matrix; k=2.') +ot.plot.plot1D_mat(a, b, Gsc, "Sparsity constrained OT matrix; k=2.") pl.show() diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index 4b9889255..e51ce1285 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -25,7 +25,7 @@ # Generate data # ------------- -#%% parameters and data generation +# %% parameters and data generation n = 50 # nb samples @@ -33,7 +33,7 @@ cov_s = np.array([[1, 0], [0, 1]]) mu_t = np.array([4, 4]) -cov_t = np.array([[1, -.8], [-.8, 1]]) +cov_t = np.array([[1, -0.8], [-0.8, 1]]) xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) @@ -47,43 +47,43 @@ # Plot data # --------- -#%% plot samples +# %% plot samples pl.figure(1) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") pl.legend(loc=0) -pl.title('Source and target distributions') +pl.title("Source and target distributions") pl.figure(2) -pl.imshow(M, interpolation='nearest') -pl.title('Cost matrix M') +pl.imshow(M, interpolation="nearest") +pl.title("Cost matrix M") ############################################################################## # Compute EMD # ----------- -#%% EMD +# %% EMD G0 = ot.emd(a, b, M) pl.figure(3) -pl.imshow(G0, interpolation='nearest') -pl.title('OT matrix G0') +pl.imshow(G0, interpolation="nearest") +pl.title("OT matrix G0") pl.figure(4) -ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[0.5, 0.5, 1]) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") pl.legend(loc=0) -pl.title('OT matrix with samples') +pl.title("OT matrix with samples") ############################################################################## # Compute Sinkhorn # ---------------- -#%% sinkhorn +# %% sinkhorn # reg term lambd = 1e-1 @@ -91,15 +91,15 @@ Gs = ot.sinkhorn(a, b, M, lambd) pl.figure(5) -pl.imshow(Gs, interpolation='nearest') -pl.title('OT matrix sinkhorn') +pl.imshow(Gs, interpolation="nearest") +pl.title("OT matrix sinkhorn") pl.figure(6) -ot.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1]) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +ot.plot.plot2D_samples_mat(xs, xt, Gs, color=[0.5, 0.5, 1]) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") pl.legend(loc=0) -pl.title('OT matrix Sinkhorn with samples') +pl.title("OT matrix Sinkhorn with samples") pl.show() @@ -108,7 +108,7 @@ # Empirical Sinkhorn # ------------------- -#%% sinkhorn +# %% sinkhorn # 
reg term lambd = 1e-1 @@ -116,14 +116,14 @@ Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd) pl.figure(7) -pl.imshow(Ges, interpolation='nearest') -pl.title('OT matrix empirical sinkhorn') +pl.imshow(Ges, interpolation="nearest") +pl.title("OT matrix empirical sinkhorn") pl.figure(8) -ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1]) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[0.5, 0.5, 1]) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") pl.legend(loc=0) -pl.title('OT matrix Sinkhorn from samples') +pl.title("OT matrix Sinkhorn from samples") pl.show() diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py index e1d102cd9..ffcd93889 100644 --- a/examples/plot_OT_L1_vs_L2.py +++ b/examples/plot_OT_L1_vs_L2.py @@ -38,40 +38,40 @@ a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples # loss matrix -M1 = ot.dist(xs, xt, metric='euclidean') +M1 = ot.dist(xs, xt, metric="euclidean") M1 /= M1.max() # loss matrix -M2 = ot.dist(xs, xt, metric='sqeuclidean') +M2 = ot.dist(xs, xt, metric="sqeuclidean") M2 /= M2.max() # loss matrix -Mp = ot.dist(xs, xt, metric='cityblock') +Mp = ot.dist(xs, xt, metric="cityblock") Mp /= Mp.max() # Data pl.figure(1, figsize=(7, 3)) pl.clf() -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.axis('equal') -pl.title('Source and target distributions') +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.axis("equal") +pl.title("Source and target distributions") # Cost matrices pl.figure(2, figsize=(7, 3)) pl.subplot(1, 3, 1) -pl.imshow(M1, interpolation='nearest') -pl.title('Euclidean cost') +pl.imshow(M1, interpolation="nearest") +pl.title("Euclidean cost") pl.subplot(1, 3, 2) -pl.imshow(M2, interpolation='nearest') -pl.title('Squared Euclidean cost') +pl.imshow(M2, interpolation="nearest") +pl.title("Squared Euclidean cost") pl.subplot(1, 3, 3) -pl.imshow(Mp, interpolation='nearest') -pl.title('L1 (cityblock cost') +pl.imshow(Mp, interpolation="nearest") +pl.title("L1 (cityblock cost") pl.tight_layout() ############################################################################## @@ -79,7 +79,7 @@ # ---------------------------- -#%% EMD +# %% EMD G1 = ot.emd(a, b, M1) G2 = ot.emd(a, b, M2) Gp = ot.emd(a, b, Mp) @@ -88,28 +88,28 @@ pl.figure(3, figsize=(7, 3)) pl.subplot(1, 3, 1) -ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1]) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.axis('equal') +ot.plot.plot2D_samples_mat(xs, xt, G1, c=[0.5, 0.5, 1]) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.axis("equal") # pl.legend(loc=0) -pl.title('OT Euclidean') +pl.title("OT Euclidean") pl.subplot(1, 3, 2) -ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1]) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.axis('equal') +ot.plot.plot2D_samples_mat(xs, xt, G2, c=[0.5, 0.5, 1]) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.axis("equal") # pl.legend(loc=0) -pl.title('OT squared Euclidean') +pl.title("OT squared 
Euclidean") pl.subplot(1, 3, 3) -ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1]) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.axis('equal') +ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[0.5, 0.5, 1]) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.axis("equal") # pl.legend(loc=0) -pl.title('OT L1 (cityblock)') +pl.title("OT L1 (cityblock)") pl.tight_layout() pl.show() @@ -121,10 +121,8 @@ n = 20 # nb samples xtot = np.zeros((n + 1, 2)) -xtot[:, 0] = np.cos( - (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi) -xtot[:, 1] = np.sin( - (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi) +xtot[:, 0] = np.cos((np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi) +xtot[:, 1] = np.sin((np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi) xs = xtot[:n, :] xt = xtot[1:, :] @@ -132,41 +130,41 @@ a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples # loss matrix -M1 = ot.dist(xs, xt, metric='euclidean') +M1 = ot.dist(xs, xt, metric="euclidean") M1 /= M1.max() # loss matrix -M2 = ot.dist(xs, xt, metric='sqeuclidean') +M2 = ot.dist(xs, xt, metric="sqeuclidean") M2 /= M2.max() # loss matrix -Mp = ot.dist(xs, xt, metric='cityblock') +Mp = ot.dist(xs, xt, metric="cityblock") Mp /= Mp.max() # Data pl.figure(4, figsize=(7, 3)) pl.clf() -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.axis('equal') -pl.title('Source and target distributions') +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.axis("equal") +pl.title("Source and target distributions") # Cost matrices pl.figure(5, figsize=(7, 3)) pl.subplot(1, 3, 1) -pl.imshow(M1, interpolation='nearest') -pl.title('Euclidean cost') +pl.imshow(M1, interpolation="nearest") +pl.title("Euclidean cost") pl.subplot(1, 3, 2) -pl.imshow(M2, interpolation='nearest') -pl.title('Squared Euclidean cost') +pl.imshow(M2, interpolation="nearest") +pl.title("Squared Euclidean cost") pl.subplot(1, 3, 3) -pl.imshow(Mp, interpolation='nearest') -pl.title('L1 (cityblock) cost') +pl.imshow(Mp, interpolation="nearest") +pl.title("L1 (cityblock) cost") pl.tight_layout() ############################################################################## @@ -174,7 +172,7 @@ # ----------------------------- # -#%% EMD +# %% EMD G1 = ot.emd(a, b, M1) G2 = ot.emd(a, b, M2) Gp = ot.emd(a, b, Mp) @@ -183,28 +181,28 @@ pl.figure(6, figsize=(7, 3)) pl.subplot(1, 3, 1) -ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1]) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.axis('equal') +ot.plot.plot2D_samples_mat(xs, xt, G1, c=[0.5, 0.5, 1]) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.axis("equal") # pl.legend(loc=0) -pl.title('OT Euclidean') +pl.title("OT Euclidean") pl.subplot(1, 3, 2) -ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1]) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.axis('equal') +ot.plot.plot2D_samples_mat(xs, xt, G2, c=[0.5, 0.5, 1]) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.axis("equal") # pl.legend(loc=0) -pl.title('OT squared Euclidean') +pl.title("OT squared 
Euclidean") pl.subplot(1, 3, 3) -ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1]) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') -pl.axis('equal') +ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[0.5, 0.5, 1]) +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") +pl.axis("equal") # pl.legend(loc=0) -pl.title('OT L1 (cityblock)') +pl.title("OT L1 (cityblock)") pl.tight_layout() pl.show() diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py index 32d63e802..6f18f7e00 100644 --- a/examples/plot_compute_emd.py +++ b/examples/plot_compute_emd.py @@ -26,7 +26,7 @@ # Generate data # ------------- -#%% parameters +# %% parameters n = 100 # nb bins n_target = 20 # nb target distributions @@ -46,26 +46,26 @@ B[:, i] = gauss(n, m=m, s=5) # loss matrix and normalization -M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean') +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), "euclidean") M /= M.max() * 0.1 -M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean') +M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), "sqeuclidean") M2 /= M2.max() * 0.1 ############################################################################## # Plot data # --------- -#%% plot the distributions +# %% plot the distributions pl.figure(1) pl.subplot(2, 1, 1) -pl.plot(x, a, 'r', label='Source distribution') -pl.title('Source distribution') +pl.plot(x, a, "r", label="Source distribution") +pl.title("Source distribution") pl.subplot(2, 1, 2) for i in range(n_target): - pl.plot(x, B[:, i], 'b', alpha=i / n_target) -pl.plot(x, B[:, -1], 'b', label='Target distributions') -pl.title('Target distributions') + pl.plot(x, B[:, i], "b", alpha=i / n_target) +pl.plot(x, B[:, -1], "b", label="Target distributions") +pl.title("Target distributions") pl.tight_layout() @@ -73,7 +73,7 @@ # Compute EMD for the different losses # ------------------------------------ -#%% Compute and plot distributions and loss matrix +# %% Compute and plot distributions and loss matrix d_emd = ot.emd2(a, B, M) # direct computation of OT loss d_emd2 = ot.emd2(a, B, M2) # direct computation of OT loss with metric M2 @@ -81,28 +81,28 @@ pl.figure(2) pl.subplot(2, 1, 1) -pl.plot(x, a, 'r', label='Source distribution') -pl.title('Distributions') +pl.plot(x, a, "r", label="Source distribution") +pl.title("Distributions") for i in range(n_target): - pl.plot(x, B[:, i], 'b', alpha=i / n_target) -pl.plot(x, B[:, -1], 'b', label='Target distributions') -pl.ylim((-.01, 0.13)) + pl.plot(x, B[:, i], "b", alpha=i / n_target) +pl.plot(x, B[:, -1], "b", label="Target distributions") +pl.ylim((-0.01, 0.13)) pl.xticks(()) pl.legend() pl.subplot(2, 1, 2) -pl.plot(d_emd, label='Euclidean OT') -pl.plot(d_emd2, label='Squared Euclidean OT') -pl.plot(d_tv, label='Total Variation (TV)') -#pl.xlim((-7,23)) -pl.xlabel('Displacement') -pl.title('Divergences') +pl.plot(d_emd, label="Euclidean OT") +pl.plot(d_emd2, label="Squared Euclidean OT") +pl.plot(d_tv, label="Total Variation (TV)") +# pl.xlim((-7,23)) +pl.xlabel("Displacement") +pl.title("Divergences") pl.legend() ############################################################################## # Compute Sinkhorn for the different losses # ----------------------------------------- -#%% +# %% reg = 1e-1 d_sinkhorn = ot.sinkhorn2(a, B, M, reg) d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg) @@ -111,22 +111,22 @@ pl.clf() pl.subplot(2, 1, 1) -pl.plot(x, a, 'r', 
label='Source distribution') -pl.title('Distributions') +pl.plot(x, a, "r", label="Source distribution") +pl.title("Distributions") for i in range(n_target): - pl.plot(x, B[:, i], 'b', alpha=i / n_target) -pl.plot(x, B[:, -1], 'b', label='Target distributions') -pl.ylim((-.01, 0.13)) + pl.plot(x, B[:, i], "b", alpha=i / n_target) +pl.plot(x, B[:, -1], "b", label="Target distributions") +pl.ylim((-0.01, 0.13)) pl.xticks(()) pl.legend() pl.subplot(2, 1, 2) -pl.plot(d_emd, label='Euclidean OT') -pl.plot(d_emd2, label='Squared Euclidean OT') -pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn') -pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn') -pl.plot(d_tv, label='Total Variation (TV)') -#pl.xlim((-7,23)) -pl.xlabel('Displacement') -pl.title('Divergences') +pl.plot(d_emd, label="Euclidean OT") +pl.plot(d_emd2, label="Squared Euclidean OT") +pl.plot(d_sinkhorn, "+", label="Euclidean Sinkhorn") +pl.plot(d_sinkhorn2, "+", label="Squared Euclidean Sinkhorn") +pl.plot(d_tv, label="Total Variation (TV)") +# pl.xlim((-7,23)) +pl.xlabel("Displacement") +pl.title("Divergences") pl.legend() pl.show() diff --git a/examples/plot_compute_wasserstein_circle.py b/examples/plot_compute_wasserstein_circle.py index 3ede96f3c..0335fcbe7 100644 --- a/examples/plot_compute_wasserstein_circle.py +++ b/examples/plot_compute_wasserstein_circle.py @@ -25,7 +25,7 @@ # Plot data # --------- -#%% plot the distributions +# %% plot the distributions def pdf_von_Mises(theta, mu, kappa): @@ -51,7 +51,7 @@ def pdf_von_Mises(theta, mu, kappa): label = "Source distributions" else: label = None - pl.plot(t / (2 * np.pi), pdf_t, c='b', label=label) + pl.plot(t / (2 * np.pi), pdf_t, c="b", label=label) pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution") pl.legend() @@ -82,7 +82,7 @@ def pdf_von_Mises(theta, mu, kappa): # and attains its maximum in :math:`\mu_{\mathrm{target}}+1` (the antipodal point) contrary to the # Euclidean version. -#%% Compute and plot distributions +# %% Compute and plot distributions mu_targets = np.linspace(0, 2 * np.pi, 200) xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(500,)) + np.pi @@ -118,10 +118,24 @@ def pdf_von_Mises(theta, mu, kappa): pl.figure(1) pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle") -pl.fill_between(mu_targets / (2 * np.pi), m_w2_circle - 2 * std_w2_circle, m_w2_circle + 2 * std_w2_circle, alpha=0.5) +pl.fill_between( + mu_targets / (2 * np.pi), + m_w2_circle - 2 * std_w2_circle, + m_w2_circle + 2 * std_w2_circle, + alpha=0.5, +) pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein") -pl.fill_between(mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5) -pl.vlines(x=[mu1 / (2 * np.pi)], ymin=0, ymax=np.max(w2), linestyle="--", color="k", label=r"$\mu_{\mathrm{target}}$") +pl.fill_between( + mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5 +) +pl.vlines( + x=[mu1 / (2 * np.pi)], + ymin=0, + ymax=np.max(w2), + linestyle="--", + color="k", + label=r"$\mu_{\mathrm{target}}$", +) pl.legend() pl.xlabel(r"$\mu_{\mathrm{source}}$") pl.show() @@ -132,7 +146,7 @@ def pdf_von_Mises(theta, mu, kappa): # ---------------------------------------------------------------------- # When :math:`\kappa=0`, the von Mises distribution is the uniform distribution on :math:`S^1`. 
-#%% Compute Wasserstein between Von Mises and uniform +# %% Compute Wasserstein between Von Mises and uniform kappas = np.logspace(-5, 2, 100) n_try = 20 diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py index 7b021d22b..f08fd286c 100644 --- a/examples/plot_optim_OTreg.py +++ b/examples/plot_optim_OTreg.py @@ -35,7 +35,7 @@ # Generate data # ------------- -#%% parameters +# %% parameters n = 100 # nb bins @@ -54,18 +54,18 @@ # Solve EMD # --------- -#%% EMD +# %% EMD G0 = ot.emd(a, b, M) pl.figure(1, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') +ot.plot.plot1D_mat(a, b, G0, "OT matrix G0") ############################################################################## # Solve EMD with Frobenius norm regularization # -------------------------------------------- -#%% Example with Frobenius norm regularization +# %% Example with Frobenius norm regularization def f(G): @@ -81,13 +81,13 @@ def df(G): Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True) pl.figure(2) -ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg') +ot.plot.plot1D_mat(a, b, Gl2, "OT matrix Frob. reg") ############################################################################## # Solve EMD with entropic regularization # -------------------------------------- -#%% Example with entropic regularization +# %% Example with entropic regularization def f(G): @@ -95,7 +95,7 @@ def f(G): def df(G): - return np.log(G) + 1. + return np.log(G) + 1.0 reg = 1e-3 @@ -103,13 +103,13 @@ def df(G): Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True) pl.figure(3, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg') +ot.plot.plot1D_mat(a, b, Ge, "OT matrix Entrop. reg") ############################################################################## # Solve EMD with Frobenius norm + entropic regularization # ------------------------------------------------------- -#%% Example with Frobenius norm + entropic regularization with gcg +# %% Example with Frobenius norm + entropic regularization with gcg def f(G): @@ -126,7 +126,7 @@ def df(G): Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True) pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg') +ot.plot.plot1D_mat(a, b, Gel2, "OT entropic + matrix Frob. reg") pl.show() @@ -139,20 +139,20 @@ def df(G): pl.subplot(2, 2, 1) pl.imshow(G0[:nvisu, :]) -pl.axis('off') -pl.title('Exact OT') +pl.axis("off") +pl.title("Exact OT") pl.subplot(2, 2, 2) pl.imshow(Gl2[:nvisu, :]) -pl.axis('off') -pl.title('Frobenius reg.') +pl.axis("off") +pl.title("Frobenius reg.") pl.subplot(2, 2, 3) pl.imshow(Ge[:nvisu, :]) -pl.axis('off') -pl.title('Entropic reg.') +pl.axis("off") +pl.title("Entropic reg.") pl.subplot(2, 2, 4) pl.imshow(Gel2[:nvisu, :]) -pl.axis('off') -pl.title('Entropic + Frobenius reg.') +pl.axis("off") +pl.title("Entropic + Frobenius reg.") diff --git a/examples/plot_solve_variants.py b/examples/plot_solve_variants.py index 82f892a52..aca59ffcb 100644 --- a/examples/plot_solve_variants.py +++ b/examples/plot_solve_variants.py @@ -4,7 +4,7 @@ Optimal Transport solvers comparison ====================================== -This example illustrates the solutions returns for diffrent variants of exact, +This example illustrates the solutions returns for different variants of exact, regularized and unbalanced OT solvers. 
""" @@ -13,7 +13,7 @@ # License: MIT License # sphinx_gallery_thumbnail_number = 3 -#%% +# %% import numpy as np import matplotlib.pylab as pl @@ -26,7 +26,7 @@ # ------------- -#%% parameters +# %% parameters n = 50 # nb bins @@ -46,17 +46,17 @@ # Plot distributions and loss matrix # ---------------------------------- -#%% plot the distributions +# %% plot the distributions pl.figure(1, figsize=(6.4, 3)) -pl.plot(x, a, 'b', label='Source distribution') -pl.plot(x, b, 'r', label='Target distribution') +pl.plot(x, a, "b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") pl.legend() -#%% plot distributions and loss matrix +# %% plot distributions and loss matrix pl.figure(2, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') +ot.plot.plot1D_mat(a, b, M, "Cost matrix M") ############################################################################## # Define Group lasso regularization and gradient @@ -65,16 +65,16 @@ def reg_gl(G): # group lasso + small l2 reg - G1 = G[:n // 2, :]**2 - G2 = G[n // 2:, :]**2 + G1 = G[: n // 2, :] ** 2 + G2 = G[n // 2 :, :] ** 2 gl1 = np.sum(np.sqrt(np.sum(G1, 0))) gl2 = np.sum(np.sqrt(np.sum(G2, 0))) return gl1 + gl2 + 0.1 * np.sum(G**2) def grad_gl(G): # gradient of group lasso + small l2 reg - G1 = G[:n // 2, :] - G2 = G[n // 2:, :] + G1 = G[: n // 2, :] + G2 = G[n // 2 :, :] gl1 = G1 / np.sqrt(np.sum(G1**2, 0, keepdims=True) + 1e-8) gl2 = G2 / np.sqrt(np.sum(G2**2, 0, keepdims=True) + 1e-8) return np.concatenate((gl1, gl2), axis=0) + 0.2 * G @@ -87,38 +87,65 @@ def grad_gl(G): # gradient of group lasso + small l2 reg # --------------------------------------- lst_regs = ["No Reg.", "Entropic", "L2", "Group Lasso + L2"] -lst_unbalanced = ["Balanced", "Unbalanced KL", 'Unbalanced L2', 'Unb. TV (Partial)'] # ["Balanced", "Unb. KL", "Unb. L2", "Unb L1 (partial)"] +lst_unbalanced = [ + "Balanced", + "Unbalanced KL", + "Unbalanced L2", + "Unb. TV (Partial)", +] # ["Balanced", "Unb. KL", "Unb. L2", "Unb L1 (partial)"] lst_solvers = [ # name, param for ot.solve function # balanced OT - ('Exact OT', dict()), - ('Entropic Reg. OT', dict(reg=0.005)), - ('L2 Reg OT', dict(reg=1, reg_type='l2')), - ('Group Lasso Reg. OT', dict(reg=0.1, reg_type=reg_type_gl)), - - + ("Exact OT", dict()), + ("Entropic Reg. OT", dict(reg=0.005)), + ("L2 Reg OT", dict(reg=1, reg_type="l2")), + ("Group Lasso Reg. 
OT", dict(reg=0.1, reg_type=reg_type_gl)), # unbalanced OT KL - ('Unbalanced KL No Reg.', dict(unbalanced=0.005)), - ('Unbalanced KL wit KL Reg.', dict(reg=0.0005, unbalanced=0.005, unbalanced_type='kl', reg_type='kl')), - ('Unbalanced KL with L2 Reg.', dict(reg=0.5, reg_type='l2', unbalanced=0.005, unbalanced_type='kl')), - ('Unbalanced KL with Group Lasso Reg.', dict(reg=0.1, reg_type=reg_type_gl, unbalanced=0.05, unbalanced_type='kl')), - + ("Unbalanced KL No Reg.", dict(unbalanced=0.005)), + ( + "Unbalanced KL with KL Reg.", + dict(reg=0.0005, unbalanced=0.005, unbalanced_type="kl", reg_type="kl"), + ), + ( + "Unbalanced KL with L2 Reg.", + dict(reg=0.5, reg_type="l2", unbalanced=0.005, unbalanced_type="kl"), + ), + ( + "Unbalanced KL with Group Lasso Reg.", + dict(reg=0.1, reg_type=reg_type_gl, unbalanced=0.05, unbalanced_type="kl"), + ), # unbalanced OT L2 - ('Unbalanced L2 No Reg.', dict(unbalanced=0.5, unbalanced_type='l2')), - ('Unbalanced L2 with KL Reg.', dict(reg=0.001, unbalanced=0.2, unbalanced_type='l2')), - ('Unbalanced L2 with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.2, unbalanced_type='l2')), - ('Unbalanced L2 with Group Lasso Reg.', dict(reg=0.05, reg_type=reg_type_gl, unbalanced=0.7, unbalanced_type='l2')), - + ("Unbalanced L2 No Reg.", dict(unbalanced=0.5, unbalanced_type="l2")), + ( + "Unbalanced L2 with KL Reg.", + dict(reg=0.001, unbalanced=0.2, unbalanced_type="l2"), + ), + ( + "Unbalanced L2 with L2 Reg.", + dict(reg=0.1, reg_type="l2", unbalanced=0.2, unbalanced_type="l2"), + ), + ( + "Unbalanced L2 with Group Lasso Reg.", + dict(reg=0.05, reg_type=reg_type_gl, unbalanced=0.7, unbalanced_type="l2"), + ), # unbalanced OT TV - ('Unbalanced TV No Reg.', dict(unbalanced=0.1, unbalanced_type='tv')), - ('Unbalanced TV with KL Reg.', dict(reg=0.001, unbalanced=0.01, unbalanced_type='tv')), - ('Unbalanced TV with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.01, unbalanced_type='tv')), - ('Unbalanced TV with Group Lasso Reg.', dict(reg=0.02, reg_type=reg_type_gl, unbalanced=0.01, unbalanced_type='tv')), - + ("Unbalanced TV No Reg.", dict(unbalanced=0.1, unbalanced_type="tv")), + ( + "Unbalanced TV with KL Reg.", + dict(reg=0.001, unbalanced=0.01, unbalanced_type="tv"), + ), + ( + "Unbalanced TV with L2 Reg.", + dict(reg=0.1, reg_type="l2", unbalanced=0.01, unbalanced_type="tv"), + ), + ( + "Unbalanced TV with Group Lasso Reg.", + dict(reg=0.02, reg_type=reg_type_gl, unbalanced=0.01, unbalanced_type="tv"), + ), ] lst_plans = [] -for (name, param) in lst_solvers: +for name, param in lst_solvers: G = ot.solve(M, a, b, **param).plan lst_plans.append(G) @@ -136,14 +163,15 @@ def grad_gl(G): # gradient of group lasso + small l2 reg m2 = plan.sum(0) m1 = plan.sum(1) m1, m2 = m1 / a.max(), m2 / b.max() - pl.imshow(plan, cmap='Greys') - pl.plot(x, m2 * 10, 'r') - pl.plot(m1 * 10, x, 'b') - pl.plot(x, b / b.max() * 10, 'r', alpha=0.3) - pl.plot(a / a.max() * 10, x, 'b', alpha=0.3) - #pl.axis('off') - pl.tick_params(left=False, right=False, labelleft=False, - labelbottom=False, bottom=False) + pl.imshow(plan, cmap="Greys") + pl.plot(x, m2 * 10, "r") + pl.plot(m1 * 10, x, "b") + pl.plot(x, b / b.max() * 10, "r", alpha=0.3) + pl.plot(a / a.max() * 10, x, "b", alpha=0.3) + # pl.axis('off') + pl.tick_params( + left=False, right=False, labelleft=False, labelbottom=False, bottom=False + ) if i == 0: pl.title(rname) if j == 0: diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index 77df2f5ca..58256fd44 100644 --- 
a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ -35,7 +35,7 @@ cov_s = np.array([[1, 0], [0, 1]]) mu_t = np.array([4, 4]) -cov_t = np.array([[1, -.8], [-.8, 1]]) +cov_t = np.array([[1, -0.8], [-0.8, 1]]) xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) @@ -49,10 +49,10 @@ # %% plot samples pl.figure(1) -pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +pl.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples") pl.legend(loc=0) -pl.title('Source and target distributions') +pl.title("Source and target distributions") ############################################################################### # Sliced Wasserstein distance for different seeds and number of projections @@ -65,7 +65,9 @@ # %% Compute statistics for seed in range(n_seed): for i, n_projections in enumerate(n_projections_arr): - res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed=seed) + res[seed, i] = ot.sliced_wasserstein_distance( + xs, xt, a, b, n_projections, seed=seed + ) res_mean = np.mean(res, axis=0) res_std = np.std(res, axis=0) @@ -76,13 +78,15 @@ pl.figure(2) pl.plot(n_projections_arr, res_mean, label="SWD") -pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) +pl.fill_between( + n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5 +) pl.legend() -pl.xscale('log') +pl.xscale("log") pl.xlabel("Number of projections") pl.ylabel("Distance") -pl.title('Sliced Wasserstein Distance with 95% confidence interval') +pl.title("Sliced Wasserstein Distance with 95% confidence interval") pl.show() diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py index 246b2a8ae..ad31e9cf1 100644 --- a/examples/sliced-wasserstein/plot_variance_ssw.py +++ b/examples/sliced-wasserstein/plot_variance_ssw.py @@ -45,10 +45,10 @@ # %% plot samples fig = pl.figure(figsize=(10, 10)) -ax = pl.axes(projection='3d') +ax = pl.axes(projection="3d") ax.grid(False) -u, v = np.mgrid[0:2 * np.pi:30j, 0:np.pi:30j] +u, v = np.mgrid[0 : 2 * np.pi : 30j, 0 : np.pi : 30j] x = np.cos(u) * np.sin(v) y = np.sin(u) * np.sin(v) z = np.cos(v) @@ -60,9 +60,9 @@ fs = 10 # Labels -ax.set_xlabel('x', fontsize=fs) -ax.set_ylabel('y', fontsize=fs) -ax.set_zlabel('z', fontsize=fs) +ax.set_xlabel("x", fontsize=fs) +ax.set_ylabel("y", fontsize=fs) +ax.set_zlabel("z", fontsize=fs) ax.view_init(20, 120) ax.set_xlim(-1.5, 1.5) @@ -88,7 +88,9 @@ # %% Compute statistics for seed in range(n_seed): for i, n_projections in enumerate(n_projections_arr): - res[seed, i] = ot.sliced_wasserstein_sphere(xs, xt, a, b, n_projections, seed=seed, p=1) + res[seed, i] = ot.sliced_wasserstein_sphere( + xs, xt, a, b, n_projections, seed=seed, p=1 + ) res_mean = np.mean(res, axis=0) res_std = np.std(res, axis=0) @@ -99,13 +101,15 @@ pl.figure(2) pl.plot(n_projections_arr, res_mean, label=r"$SSW_1$") -pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) +pl.fill_between( + n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5 +) pl.legend() -pl.xscale('log') +pl.xscale("log") pl.xlabel("Number of projections") pl.ylabel("Distance") -pl.title('Spherical Sliced Wasserstein Distance with 95% confidence interval') +pl.title("Spherical Sliced Wasserstein 
Distance with 95% confidence interval") pl.show() diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 06dd02d93..ade4bbb0c 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -25,7 +25,7 @@ # ------------- -#%% parameters +# %% parameters n = 100 # nb bins @@ -37,7 +37,7 @@ b = gauss(n, m=60, s=10) # make distributions unbalanced -b *= 5. +b *= 5.0 # loss matrix M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) @@ -48,17 +48,17 @@ # Plot distributions and loss matrix # ---------------------------------- -#%% plot the distributions +# %% plot the distributions pl.figure(1, figsize=(6.4, 3)) -pl.plot(x, a, 'b', label='Source distribution') -pl.plot(x, b, 'r', label='Target distribution') +pl.plot(x, a, "b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") pl.legend() # plot distributions and loss matrix pl.figure(2, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') +ot.plot.plot1D_mat(a, b, M, "Cost matrix M") ############################################################################## @@ -68,11 +68,11 @@ # Sinkhorn epsilon = 0.1 # entropy parameter -alpha = 1. # Unbalanced KL relaxation parameter +alpha = 1.0 # Unbalanced KL relaxation parameter Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) pl.figure(3, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') +ot.plot.plot1D_mat(a, b, Gs, "UOT matrix Sinkhorn") pl.show() @@ -82,9 +82,9 @@ # ------------------------- pl.figure(4, figsize=(6.4, 3)) -pl.plot(x, a, 'b', label='Source distribution') -pl.plot(x, b, 'r', label='Target distribution') -pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source') -pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target') -pl.legend(loc='upper right') -pl.title('Distributions and transported mass for UOT') +pl.plot(x, a, "b", label="Source distribution") +pl.plot(x, b, "r", label="Target distribution") +pl.fill(x, Gs.sum(1), "b", alpha=0.5, label="Transported source") +pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target") +pl.legend(loc="upper right") +pl.title("Distributions and transported mass for UOT") diff --git a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py index de1a3b3d5..89015456e 100644 --- a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py +++ b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py @@ -21,6 +21,7 @@ import numpy as np import matplotlib.pylab as pl import ot + # necessary for 3d plot even if not used from mpl_toolkits.mplot3d import Axes3D # noqa from matplotlib.collections import PolyCollection @@ -41,7 +42,7 @@ a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) # make unbalanced dists -a2 *= 3. +a2 *= 3.0 # creating matrix A containing all distributions A = np.vstack((a1, a2)).T @@ -60,7 +61,7 @@ pl.figure(1, figsize=(6.4, 3)) for i in range(n_distributions): pl.plot(x, A[:, i]) -pl.title('Distributions') +pl.title("Distributions") pl.tight_layout() ############################################################################## @@ -77,7 +78,7 @@ # wasserstein reg = 1e-3 -alpha = 1. 
+alpha = 1.0 bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights=weights) @@ -86,13 +87,13 @@ pl.subplot(2, 1, 1) for i in range(n_distributions): pl.plot(x, A[:, i]) -pl.title('Distributions') +pl.title("Distributions") pl.subplot(2, 1, 2) -pl.plot(x, bary_l2, 'r', label='l2') -pl.plot(x, bary_wass, 'g', label='Wasserstein') +pl.plot(x, bary_l2, "r", label="l2") +pl.plot(x, bary_wass, "g", label="Wasserstein") pl.legend() -pl.title('Barycenters') +pl.title("Barycenters") pl.tight_layout() ############################################################################## @@ -113,54 +114,56 @@ weight = weight_list[i] weights = np.array([1 - weight, weight]) B_l2[:, i] = A.dot(weights) - B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights=weights) + B_wass[:, i] = ot.unbalanced.barycenter_unbalanced( + A, M, reg, alpha, weights=weights + ) # plot interpolation pl.figure(3) -cmap = pl.get_cmap('viridis') +cmap = pl.get_cmap("viridis") verts = [] zs = weight_list for i, z in enumerate(zs): ys = B_l2[:, i] verts.append(list(zip(x, ys))) -ax = pl.gcf().add_subplot(projection='3d') +ax = pl.gcf().add_subplot(projection="3d") poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) poly.set_alpha(0.7) -ax.add_collection3d(poly, zs=zs, zdir='y') -ax.set_xlabel('x') +ax.add_collection3d(poly, zs=zs, zdir="y") +ax.set_xlabel("x") ax.set_xlim3d(0, n) -ax.set_ylabel(r'$\alpha$') +ax.set_ylabel(r"$\alpha$") ax.set_ylim3d(0, 1) -ax.set_zlabel('') +ax.set_zlabel("") ax.set_zlim3d(0, B_l2.max() * 1.01) -pl.title('Barycenter interpolation with l2') +pl.title("Barycenter interpolation with l2") pl.tight_layout() pl.figure(4) -cmap = pl.get_cmap('viridis') +cmap = pl.get_cmap("viridis") verts = [] zs = weight_list for i, z in enumerate(zs): ys = B_wass[:, i] verts.append(list(zip(x, ys))) -ax = pl.gcf().add_subplot(projection='3d') +ax = pl.gcf().add_subplot(projection="3d") poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) poly.set_alpha(0.7) -ax.add_collection3d(poly, zs=zs, zdir='y') -ax.set_xlabel('x') +ax.add_collection3d(poly, zs=zs, zdir="y") +ax.set_xlabel("x") ax.set_xlim3d(0, n) -ax.set_ylabel(r'$\alpha$') +ax.set_ylabel(r"$\alpha$") ax.set_ylim3d(0, 1) -ax.set_zlabel('') +ax.set_zlabel("") ax.set_zlim3d(0, B_l2.max() * 1.01) -pl.title('Barycenter interpolation with Wasserstein') +pl.title("Barycenter interpolation with Wasserstein") pl.tight_layout() pl.show() diff --git a/examples/unbalanced-partial/plot_conv_sinkhorn_ti.py b/examples/unbalanced-partial/plot_conv_sinkhorn_ti.py index 3d49c313c..4e194c4bf 100644 --- a/examples/unbalanced-partial/plot_conv_sinkhorn_ti.py +++ b/examples/unbalanced-partial/plot_conv_sinkhorn_ti.py @@ -40,7 +40,7 @@ cov_s = np.array([[1, 0], [0, 1]]) mu_t = np.array([4, 4]) -cov_t = np.array([[1, -.8], [-.8, 1]]) +cov_t = np.array([[1, -0.8], [-0.8, 1]]) ############################################################################## @@ -56,8 +56,8 @@ xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) - xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0) - xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0) + xs = np.concatenate((xs, (np.random.rand(n_noise, 2) - 4)), axis=0) + xt = np.concatenate((xt, (np.random.rand(n_noise, 2) + 6)), axis=0) n = n + n_noise @@ -67,10 +67,29 @@ M = ot.dist(xs, xt) M /= M.max() - entropic_kl_uot, log_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, 
reg_type="kl", log=True, numItermax=num_iter_max, stopThr=0) - entropic_kl_uot_ti, log_uot_ti = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl", - method="sinkhorn_translation_invariant", log=True, - numItermax=num_iter_max, stopThr=0) + entropic_kl_uot, log_uot = ot.unbalanced.sinkhorn_unbalanced( + a, + b, + M, + reg, + reg_m_kl, + reg_type="kl", + log=True, + numItermax=num_iter_max, + stopThr=0, + ) + entropic_kl_uot_ti, log_uot_ti = ot.unbalanced.sinkhorn_unbalanced( + a, + b, + M, + reg, + reg_m_kl, + reg_type="kl", + method="sinkhorn_translation_invariant", + log=True, + numItermax=num_iter_max, + stopThr=0, + ) err_sinkhorn_uot[seed] = log_uot["err"] err_sinkhorn_uot_ti[seed] = log_uot_ti["err"] @@ -91,7 +110,9 @@ pl.fill_between(absc, mean_sinkh - 2 * std_sinkh, mean_sinkh + 2 * std_sinkh, alpha=0.5) pl.plot(absc, mean_sinkh_ti, label="Translation Invariant Sinkhorn") -pl.fill_between(absc, mean_sinkh_ti - 2 * std_sinkh_ti, mean_sinkh_ti + 2 * std_sinkh_ti, alpha=0.5) +pl.fill_between( + absc, mean_sinkh_ti - 2 * std_sinkh_ti, mean_sinkh_ti + 2 * std_sinkh_ti, alpha=0.5 +) pl.yscale("log") pl.legend() diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py index 5c85a5a22..5ccc197d6 100755 --- a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py +++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py @@ -45,9 +45,9 @@ fig = pl.figure() ax1 = fig.add_subplot(131) -ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +ax1.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") ax2 = fig.add_subplot(132) -ax2.scatter(xt[:, 0], xt[:, 1], color='r') +ax2.scatter(xt[:, 0], xt[:, 1], color="r") ax3 = fig.add_subplot(133) ax3.imshow(M) pl.show() @@ -61,20 +61,18 @@ q = ot.unif(n_samples + n_noise) w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=0.5, log=True) -w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5, - log=True) +w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5, log=True) -print('Partial Wasserstein distance (m = 0.5): ' + str(log0['partial_w_dist'])) -print('Entropic partial Wasserstein distance (m = 0.5): ' + - str(log['partial_w_dist'])) +print("Partial Wasserstein distance (m = 0.5): " + str(log0["partial_w_dist"])) +print("Entropic partial Wasserstein distance (m = 0.5): " + str(log["partial_w_dist"])) pl.figure(1, (10, 5)) pl.subplot(1, 2, 1) -pl.imshow(w0, cmap='jet') -pl.title('Partial Wasserstein') +pl.imshow(w0, cmap="jet") +pl.title("Partial Wasserstein") pl.subplot(1, 2, 2) -pl.imshow(w, cmap='jet') -pl.title('Entropic partial Wasserstein') +pl.imshow(w, cmap="jet") +pl.title("Entropic partial Wasserstein") pl.show() @@ -108,9 +106,9 @@ fig = pl.figure() ax1 = fig.add_subplot(121) -ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') -ax2 = fig.add_subplot(122, projection='3d') -ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r') +ax1.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples") +ax2 = fig.add_subplot(122, projection="3d") +ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color="r") pl.show() @@ -123,46 +121,45 @@ C2 = sp.spatial.distance.cdist(xt, xt) # transport 100% of the mass -print('------m = 1') +print("------m = 1") m = 1 res0, log0 = ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True) -res, log = ot.gromov.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, - m=m, log=True, - verbose=True) +res, log = ot.gromov.entropic_partial_gromov_wasserstein( + C1, C2, 
p, q, 10, m=m, log=True, verbose=True +) -print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist'])) -print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist'])) +print("Wasserstein distance (m = 1): " + str(log0["partial_gw_dist"])) +print("Entropic Wasserstein distance (m = 1): " + str(log["partial_gw_dist"])) pl.figure(1, (10, 5)) pl.title("mass to be transported m = 1") pl.subplot(1, 2, 1) -pl.imshow(res0, cmap='jet') -pl.title('Gromov-Wasserstein') +pl.imshow(res0, cmap="jet") +pl.title("Gromov-Wasserstein") pl.subplot(1, 2, 2) -pl.imshow(res, cmap='jet') -pl.title('Entropic Gromov-Wasserstein') +pl.imshow(res, cmap="jet") +pl.title("Entropic Gromov-Wasserstein") pl.show() # transport 2/3 of the mass -print('------m = 2/3') +print("------m = 2/3") m = 2 / 3 -res0, log0 = ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True, - verbose=True) -res, log = ot.gromov.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, - m=m, log=True, - verbose=True) +res0, log0 = ot.gromov.partial_gromov_wasserstein( + C1, C2, p, q, m=m, log=True, verbose=True +) +res, log = ot.gromov.entropic_partial_gromov_wasserstein( + C1, C2, p, q, 10, m=m, log=True, verbose=True +) -print('Partial Wasserstein distance (m = 2/3): ' + - str(log0['partial_gw_dist'])) -print('Entropic partial Wasserstein distance (m = 2/3): ' + - str(log['partial_gw_dist'])) +print("Partial Wasserstein distance (m = 2/3): " + str(log0["partial_gw_dist"])) +print("Entropic partial Wasserstein distance (m = 2/3): " + str(log["partial_gw_dist"])) pl.figure(1, (10, 5)) pl.title("mass to be transported m = 2/3") pl.subplot(1, 2, 1) -pl.imshow(res0, cmap='jet') -pl.title('Partial Gromov-Wasserstein') +pl.imshow(res0, cmap="jet") +pl.title("Partial Gromov-Wasserstein") pl.subplot(1, 2, 2) -pl.imshow(res, cmap='jet') -pl.title('Entropic partial Gromov-Wasserstein') +pl.imshow(res, cmap="jet") +pl.title("Entropic partial Gromov-Wasserstein") pl.show() diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py index ffedc6e8b..7f3dab6f7 100644 --- a/examples/unbalanced-partial/plot_regpath.py +++ b/examples/unbalanced-partial/plot_regpath.py @@ -25,7 +25,7 @@ # Generate data # ------------- -#%% parameters and data generation +# %% parameters and data generation n = 20 # nb samples @@ -33,7 +33,7 @@ cov_s = np.array([[1, 0], [0, 1]]) mu_t = np.array([4, 4]) -cov_t = np.array([[1, -.8], [-.8, 1]]) +cov_t = np.array([[1, -0.8], [-0.8, 1]]) np.random.seed(0) xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) @@ -49,25 +49,27 @@ # Plot data # --------- -#%% plot 2 distribution samples +# %% plot 2 distribution samples pl.figure(1) -pl.scatter(xs[:, 0], xs[:, 1], c='C0', label='Source') -pl.scatter(xt[:, 0], xt[:, 1], c='C1', label='Target') +pl.scatter(xs[:, 0], xs[:, 1], c="C0", label="Source") +pl.scatter(xt[:, 0], xt[:, 1], c="C1", label="Target") pl.legend(loc=2) -pl.title('Source and target distributions') +pl.title("Source and target distributions") pl.show() ############################################################################## # Compute semi-relaxed and fully relaxed regularization paths # ----------------------------------------------------------- -#%% +# %% final_gamma = 1e-6 -t, t_list, g_list = ot.regpath.regularization_path(a, b, M, reg=final_gamma, - semi_relaxed=False) -t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma, - semi_relaxed=True) +t, t_list, g_list = ot.regpath.regularization_path( + a, b, M, 
reg=final_gamma, semi_relaxed=False +) +t2, t_list2, g_list2 = ot.regpath.regularization_path( + a, b, M, reg=final_gamma, semi_relaxed=True +) ############################################################################## @@ -77,13 +79,12 @@ # The OT plan is plotted as a function of $\gamma$ that is the inverse of the # weight on the marginal relaxations. -#%% fully relaxed l2-penalized UOT +# %% fully relaxed l2-penalized UOT pl.figure(2) selected_gamma = [2e-1, 1e-1, 5e-2, 1e-3] for p in range(4): - tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list, - t_list) + tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list, t_list) P = tp.reshape((n, n)) pl.subplot(2, 2, p + 1) if P.sum() > 0: @@ -91,17 +92,32 @@ for i in range(n): for j in range(n): if P[i, j] > 0: - pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', - alpha=P[i, j] * 0.3) - pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) - pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) - pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2, - label='Re-weighted source', alpha=1) - pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2, - label='Re-weighted target', alpha=1) - pl.plot([], [], color='C2', alpha=0.8, label='OT plan') - pl.title(r'$\ell_2$ UOT $\gamma$={}'.format(selected_gamma[p]), - fontsize=11) + pl.plot( + [xs[i, 0], xt[j, 0]], + [xs[i, 1], xt[j, 1]], + color="C2", + alpha=P[i, j] * 0.3, + ) + pl.scatter(xs[:, 0], xs[:, 1], c="C0", alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c="C1", alpha=0.2) + pl.scatter( + xs[:, 0], + xs[:, 1], + c="C0", + s=P.sum(1).ravel() * (1 + p) * 2, + label="Re-weighted source", + alpha=1, + ) + pl.scatter( + xt[:, 0], + xt[:, 1], + c="C1", + s=P.sum(0).ravel() * (1 + p) * 2, + label="Re-weighted target", + alpha=1, + ) + pl.plot([], [], color="C2", alpha=0.8, label="OT plan") + pl.title(r"$\ell_2$ UOT $\gamma$={}".format(selected_gamma[p]), fontsize=11) if p < 2: pl.xticks(()) pl.show() @@ -112,52 +128,67 @@ # ----------------------------------- nv = 50 -g_list_v = np.logspace(-.5, -2.5, nv) +g_list_v = np.logspace(-0.5, -2.5, nv) pl.figure(3) def _update_plot(iv): pl.clf() - tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list, - t_list) + tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list, t_list) P = tp.reshape((n, n)) if P.sum() > 0: P = P / P.max() for i in range(n): for j in range(n): if P[i, j] > 0: - pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', - alpha=P[i, j] * 0.5) - pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) - pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) - pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4, - label='Re-weighted source', alpha=1) - pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4, - label='Re-weighted target', alpha=1) - pl.plot([], [], color='C2', alpha=0.8, label='OT plan') - pl.title(r'$\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]), - fontsize=11) + pl.plot( + [xs[i, 0], xt[j, 0]], + [xs[i, 1], xt[j, 1]], + color="C2", + alpha=P[i, j] * 0.5, + ) + pl.scatter(xs[:, 0], xs[:, 1], c="C0", alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c="C1", alpha=0.2) + pl.scatter( + xs[:, 0], + xs[:, 1], + c="C0", + s=P.sum(1).ravel() * (1 + p) * 4, + label="Re-weighted source", + alpha=1, + ) + pl.scatter( + xt[:, 0], + xt[:, 1], + c="C1", + s=P.sum(0).ravel() * (1 + p) * 4, + label="Re-weighted target", + alpha=1, + ) + pl.plot([], [], color="C2", alpha=0.8, label="OT plan") + pl.title(r"$\ell_2$ UOT 
$\gamma$={:1.3f}".format(g_list_v[iv]), fontsize=11) return 1 i = 0 _update_plot(i) -ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=100, repeat_delay=2000) +ani = animation.FuncAnimation( + pl.gcf(), _update_plot, nv, interval=100, repeat_delay=2000 +) ############################################################################## # Plot the semi-relaxed regularization path # ----------------------------------------- -#%% semi-relaxed l2-penalized UOT +# %% semi-relaxed l2-penalized UOT pl.figure(4) selected_gamma = [10, 1, 1e-1, 1e-2] for p in range(4): - tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2, - t_list2) + tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2, t_list2) P = tp.reshape((n, n)) pl.subplot(2, 2, p + 1) if P.sum() > 0: @@ -165,15 +196,26 @@ def _update_plot(iv): for i in range(n): for j in range(n): if P[i, j] > 0: - pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', - alpha=P[i, j] * 0.3) - pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) - pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=1, label='Target marginal') - pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * 2 * (1 + p), - label='Source marginal', alpha=1) - pl.plot([], [], color='C2', alpha=0.8, label='OT plan') - pl.title(r'Semi-relaxed $l_2$ UOT $\gamma$={}'.format(selected_gamma[p]), - fontsize=11) + pl.plot( + [xs[i, 0], xt[j, 0]], + [xs[i, 1], xt[j, 1]], + color="C2", + alpha=P[i, j] * 0.3, + ) + pl.scatter(xs[:, 0], xs[:, 1], c="C0", alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c="C1", alpha=1, label="Target marginal") + pl.scatter( + xs[:, 0], + xs[:, 1], + c="C0", + s=P.sum(1).ravel() * 2 * (1 + p), + label="Source marginal", + alpha=1, + ) + pl.plot([], [], color="C2", alpha=0.8, label="OT plan") + pl.title( + r"Semi-relaxed $l_2$ UOT $\gamma$={}".format(selected_gamma[p]), fontsize=11 + ) if p < 2: pl.xticks(()) pl.show() @@ -191,29 +233,47 @@ def _update_plot(iv): def _update_plot(iv): pl.clf() - tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2, - t_list2) + tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2, t_list2) P = tp.reshape((n, n)) if P.sum() > 0: P = P / P.max() for i in range(n): for j in range(n): if P[i, j] > 0: - pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', - alpha=P[i, j] * 0.5) - pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) - pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) - pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4, - label='Re-weighted source', alpha=1) - pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4, - label='Re-weighted target', alpha=1) - pl.plot([], [], color='C2', alpha=0.8, label='OT plan') - pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]), - fontsize=11) + pl.plot( + [xs[i, 0], xt[j, 0]], + [xs[i, 1], xt[j, 1]], + color="C2", + alpha=P[i, j] * 0.5, + ) + pl.scatter(xs[:, 0], xs[:, 1], c="C0", alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c="C1", alpha=0.2) + pl.scatter( + xs[:, 0], + xs[:, 1], + c="C0", + s=P.sum(1).ravel() * (1 + p) * 4, + label="Re-weighted source", + alpha=1, + ) + pl.scatter( + xt[:, 0], + xt[:, 1], + c="C1", + s=P.sum(0).ravel() * (1 + p) * 4, + label="Re-weighted target", + alpha=1, + ) + pl.plot([], [], color="C2", alpha=0.8, label="OT plan") + pl.title( + r"Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}".format(g_list_v[iv]), fontsize=11 + ) return 1 i = 0 _update_plot(i) -ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=100, 
repeat_delay=2000) +ani = animation.FuncAnimation( + pl.gcf(), _update_plot, nv, interval=100, repeat_delay=2000 +) diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py index 03487e7e2..8351eafd6 100644 --- a/examples/unbalanced-partial/plot_unbalanced_OT.py +++ b/examples/unbalanced-partial/plot_unbalanced_OT.py @@ -44,7 +44,7 @@ cov_s = np.array([[1, 0], [0, 1]]) mu_t = np.array([4, 4]) -cov_t = np.array([[1, -.8], [-.8, 1]]) +cov_t = np.array([[1, -0.8], [-0.8, 1]]) np.random.seed(0) xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) @@ -52,8 +52,8 @@ n_noise = 10 -xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0) -xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0) +xs = np.concatenate((xs, (np.random.rand(n_noise, 2) - 4)), axis=0) +xt = np.concatenate((xt, (np.random.rand(n_noise, 2) + 6)), axis=0) n = n + n_noise @@ -74,8 +74,8 @@ mass = 0.7 entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl) -kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl') -l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2') +kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div="kl") +l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div="l2") partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass) ############################################################################## @@ -84,9 +84,12 @@ pl.figure(2) transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot] -title = ["partial OT \n m=" + str(mass), "$\ell_2$-UOT \n $\mathrm{reg_m}$=" + - str(reg_m_l2), "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl), - "entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl)] +title = [ + "partial OT \n m=" + str(mass), + "$\ell_2$-UOT \n $\mathrm{reg_m}$=" + str(reg_m_l2), + "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl), + "entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl), +] for p in range(4): pl.subplot(2, 4, p + 1) @@ -96,19 +99,23 @@ for i in range(n): for j in range(n): if P[i, j] > 0: - pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', - alpha=P[i, j] * 0.3) - pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) - pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) - pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2) - pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2) + pl.plot( + [xs[i, 0], xt[j, 0]], + [xs[i, 1], xt[j, 1]], + color="C2", + alpha=P[i, j] * 0.3, + ) + pl.scatter(xs[:, 0], xs[:, 1], c="C0", alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c="C1", alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c="C0", s=P.sum(1).ravel() * (1 + p) * 2) + pl.scatter(xt[:, 0], xt[:, 1], c="C1", s=P.sum(0).ravel() * (1 + p) * 2) pl.title(title[p]) pl.yticks(()) pl.xticks(()) if p < 1: pl.ylabel("mappings") pl.subplot(2, 4, p + 5) - pl.imshow(P, cmap='jet') + pl.imshow(P, cmap="jet") pl.yticks(()) pl.xticks(()) if p < 1: diff --git a/ignore-words.txt b/ignore-words.txt new file mode 100644 index 000000000..00c1f5edb --- /dev/null +++ b/ignore-words.txt @@ -0,0 +1,9 @@ +embedd +ot +OT +coo +wass +ccompiler +ist +lik +ges \ No newline at end of file diff --git a/ot/__init__.py b/ot/__init__.py index 3334dc229..fa007c78f 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -16,7 +16,6 @@ # # License: MIT License - # All submodules and packages from . import lp from . 
import bregman @@ -40,18 +39,33 @@ # OT functions -from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, - binary_search_circle, wasserstein_circle, - semidiscrete_wasserstein2_unif_circle) +from .lp import ( + emd, + emd2, + emd_1d, + emd2_1d, + wasserstein_1d, + binary_search_circle, + wasserstein_circle, + semidiscrete_wasserstein2_unif_circle, +) from .bregman import sinkhorn, sinkhorn2, barycenter -from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, - sinkhorn_unbalanced2) +from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 from .da import sinkhorn_lpl1_mm -from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance, - sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif) -from .gromov import (gromov_wasserstein, gromov_wasserstein2, - gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2, - lowrank_gromov_wasserstein_samples) +from .sliced import ( + sliced_wasserstein_distance, + max_sliced_wasserstein_distance, + sliced_wasserstein_sphere, + sliced_wasserstein_sphere_unif, +) +from .gromov import ( + gromov_wasserstein, + gromov_wasserstein2, + gromov_barycenters, + fused_gromov_wasserstein, + fused_gromov_wasserstein2, + lowrank_gromov_wasserstein_samples, +) from .weak import weak_optimal_transport from .factored import factored_optimal_transport from .solvers import solve, solve_gromov, solve_sample @@ -62,16 +76,60 @@ __version__ = "0.9.5dev0" -__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', - 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', - 'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian', - 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', - 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', - 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', - 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', - 'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample', - 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', - 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn', - 'lowrank_gromov_wasserstein_samples'] +__all__ = [ + "emd", + "emd2", + "emd_1d", + "sinkhorn", + "sinkhorn2", + "utils", + "datasets", + "bregman", + "lp", + "tic", + "toc", + "toq", + "gromov", + "emd2_1d", + "wasserstein_1d", + "backend", + "gaussian", + "dist", + "unif", + "barycenter", + "sinkhorn_lpl1_mm", + "da", + "optim", + "sinkhorn_unbalanced", + "barycenter_unbalanced", + "sinkhorn_unbalanced2", + "sliced_wasserstein_distance", + "sliced_wasserstein_sphere", + "gromov_wasserstein", + "gromov_wasserstein2", + "gromov_barycenters", + "fused_gromov_wasserstein", + "fused_gromov_wasserstein2", + "max_sliced_wasserstein_distance", + "weak_optimal_transport", + "factored_optimal_transport", + "solve", + "solve_gromov", + "solve_sample", + "smooth", + "stochastic", + "unbalanced", + "partial", + "regpath", + "solvers", + "weak", + "factored", + "lowrank", + "gmm", + "binary_search_circle", + "wasserstein_circle", + "semidiscrete_wasserstein2_unif_circle", + "sliced_wasserstein_sphere_unif", + "lowrank_sinkhorn", + "lowrank_gromov_wasserstein_samples", +] diff --git a/ot/backend.py b/ot/backend.py index 427cf3018..a99639445 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ 
-96,15 +96,16 @@ import scipy.special as special from scipy.sparse import coo_matrix, csr_matrix, issparse -DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH' -DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX' -DISABLE_CUPY_KEY = 'POT_BACKEND_DISABLE_CUPY' -DISABLE_TF_KEY = 'POT_BACKEND_DISABLE_TENSORFLOW' +DISABLE_TORCH_KEY = "POT_BACKEND_DISABLE_PYTORCH" +DISABLE_JAX_KEY = "POT_BACKEND_DISABLE_JAX" +DISABLE_CUPY_KEY = "POT_BACKEND_DISABLE_CUPY" +DISABLE_TF_KEY = "POT_BACKEND_DISABLE_TENSORFLOW" if not os.environ.get(DISABLE_TORCH_KEY, False): try: import torch + torch_type = torch.Tensor except ImportError: torch = False @@ -119,8 +120,9 @@ import jax.numpy as jnp import jax.scipy.special as jspecial from jax.lib import xla_bridge + jax_type = jax.numpy.ndarray - jax_new_version = float('.'.join(jax.__version__.split('.')[1:])) > 4.24 + jax_new_version = float(".".join(jax.__version__.split(".")[1:])) > 4.24 except ImportError: jax = False jax_type = float @@ -132,6 +134,7 @@ try: import cupy as cp import cupyx + cp_type = cp.ndarray except ImportError: cp = False @@ -144,6 +147,7 @@ try: import tensorflow as tf import tensorflow.experimental.numpy as tnp + tf_type = tf.Tensor except ImportError: tf = False @@ -153,7 +157,9 @@ tf_type = float -str_type_error = "All array should be from the same type/backend. Current types are : {}" +str_type_error = ( + "All array should be from the same type/backend. Current types are : {}" +) # Mapping between argument types and the existing backend @@ -194,8 +200,7 @@ def get_backend_list(): """ return [ _get_backend_instance(backend_impl) - for backend_impl - in get_available_backend_implementations() + for backend_impl in get_available_backend_implementations() ] @@ -207,9 +212,9 @@ def get_available_backend_implementations(): def get_backend(*args): """Returns the proper backend for a list of input arrays - Accepts None entries in the arguments, and ignores them + Accepts None entries in the arguments, and ignores them - Also raises TypeError if all arrays are not from the same backend + Also raises TypeError if all arrays are not from the same backend """ args = [arg for arg in args if arg is not None] # exclude None entries @@ -233,7 +238,7 @@ def to_numpy(*args): return [get_backend(a).to_numpy(a) for a in args] -class Backend(): +class Backend: """ Backend abstract class. Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`, @@ -279,7 +284,7 @@ def _from_numpy(self, a, type_as=None): raise NotImplementedError() def set_gradients(self, val, inputs, grads): - """Define the gradients for the value val wrt the inputs """ + """Define the gradients for the value val wrt the inputs""" raise NotImplementedError() def detach(self, *arrays): @@ -406,7 +411,7 @@ def minimum(self, a, b): raise NotImplementedError() def sign(self, a): - r""" Returns an element-wise indication of the sign of a number. + r"""Returns an element-wise indication of the sign of a number. This function follows the api from :any:`numpy.sign` @@ -544,7 +549,7 @@ def argsort(self, a, axis=None): """ raise NotImplementedError() - def searchsorted(self, a, v, side='left'): + def searchsorted(self, a, v, side="left"): r""" Finds indices where elements should be inserted to maintain order in given tensor. @@ -804,7 +809,7 @@ def tocsr(self, a): """ raise NotImplementedError() - def eliminate_zeros(self, a, threshold=0.): + def eliminate_zeros(self, a, threshold=0.0): r""" Removes entries smaller than the given threshold from the sparse tensor. 
@@ -1077,10 +1082,9 @@ class NumpyBackend(Backend): - `__type__` is np.ndarray """ - __name__ = 'numpy' + __name__ = "numpy" __type__ = np.ndarray - __type_list__ = [np.array(1, dtype=np.float32), - np.array(1, dtype=np.float64)] + __type_list__ = [np.array(1, dtype=np.float32), np.array(1, dtype=np.float64)] rng_ = np.random.RandomState() @@ -1190,7 +1194,7 @@ def sort(self, a, axis=-1): def argsort(self, a, axis=-1): return np.argsort(a, axis) - def searchsorted(self, a, v, side='left'): + def searchsorted(self, a, v, side="left"): if a.ndim == 1: return np.searchsorted(a, v, side) else: @@ -1286,7 +1290,7 @@ def tocsr(self, a): else: return csr_matrix(a) - def eliminate_zeros(self, a, threshold=0.): + def eliminate_zeros(self, a, threshold=0.0): if threshold > 0: if self.issparse(a): a.data[self.abs(a.data) <= threshold] = 0 @@ -1360,9 +1364,9 @@ def sqrtm(self, a): L, V = np.linalg.eigh(a) L = np.sqrt(L) # Q[...] = V[...] @ diag(L[...]) - Q = np.einsum('...jk,...k->...jk', V, L) + Q = np.einsum("...jk,...k->...jk", V, L) # R[...] = Q[...] @ V[...].T - return np.einsum('...jk,...kl->...jl', Q, np.swapaxes(V, -1, -2)) + return np.einsum("...jk,...kl->...jl", Q, np.swapaxes(V, -1, -2)) def eigh(self, a): return np.linalg.eigh(a) @@ -1441,7 +1445,7 @@ class JaxBackend(Backend): - `__type__` is jax.numpy.ndarray """ - __name__ = 'jax' + __name__ = "jax" __type__ = jax_type __type_list__ = None @@ -1458,7 +1462,7 @@ def __init__(self): for d in available_devices: self.__type_list__ += [ jax.device_put(jnp.array(1, dtype=jnp.float32), d), - jax.device_put(jnp.array(1, dtype=jnp.float64), d) + jax.device_put(jnp.array(1, dtype=jnp.float64), d), ] self.jax_new_version = jax_new_version @@ -1485,7 +1489,8 @@ def _from_numpy(self, a, type_as=None): def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree - val, = jax.lax.stop_gradient((val,)) + + (val,) = jax.lax.stop_gradient((val,)) ravelled_inputs, _ = ravel_pytree(inputs) ravelled_grads, _ = ravel_pytree(grads) @@ -1493,7 +1498,7 @@ def set_gradients(self, val, inputs, grads): aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2 aux = aux - jax.lax.stop_gradient(aux) - val, = jax.tree_map(lambda z: z + aux, (val,)) + (val,) = jax.tree_map(lambda z: z + aux, (val,)) return val def _detach(self, a): @@ -1518,7 +1523,9 @@ def full(self, shape, fill_value, type_as=None): if type_as is None: return jnp.full(shape, fill_value) else: - return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as) + return self._change_device( + jnp.full(shape, fill_value, dtype=type_as.dtype), type_as + ) def eye(self, N, M=None, type_as=None): if type_as is None: @@ -1586,7 +1593,7 @@ def sort(self, a, axis=-1): def argsort(self, a, axis=-1): return jnp.argsort(a, axis) - def searchsorted(self, a, v, side='left'): + def searchsorted(self, a, v, side="left"): if a.ndim == 1: return jnp.searchsorted(a, v, side) else: @@ -1632,7 +1639,9 @@ def linspace(self, start, stop, num, type_as=None): if type_as is None: return jnp.linspace(start, stop, num) else: - return self._change_device(jnp.linspace(start, stop, num, dtype=type_as.dtype), type_as) + return self._change_device( + jnp.linspace(start, stop, num, dtype=type_as.dtype), type_as + ) def meshgrid(self, a, b): return jnp.meshgrid(a, b) @@ -1688,14 +1697,10 @@ def tocsr(self, a): # Currently, JAX does not support sparse matrices return a - def eliminate_zeros(self, a, threshold=0.): + def eliminate_zeros(self, a, threshold=0.0): # Currently, JAX does not support 
sparse matrices if threshold > 0: - return self.where( - self.abs(a) <= threshold, - self.zeros((1,), type_as=a), - a - ) + return self.where(self.abs(a) <= threshold, self.zeros((1,), type_as=a), a) return a def todense(self, a): @@ -1726,7 +1731,9 @@ def assert_same_dtype_device(self, a, b): b_dtype, b_device = self.dtype_device(b) assert a_dtype == b_dtype, "Dtype discrepancy" - assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + assert ( + a_device == b_device + ), f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" def squeeze(self, a, axis=None): return jnp.squeeze(a, axis=axis) @@ -1767,9 +1774,9 @@ def sqrtm(self, a): L, V = jnp.linalg.eigh(a) L = jnp.sqrt(L) # Q[...] = V[...] @ diag(L[...]) - Q = jnp.einsum('...jk,...k->...jk', V, L) + Q = jnp.einsum("...jk,...k->...jk", V, L) # R[...] = Q[...] @ V[...].T - return jnp.einsum('...jk,...kl->...jl', Q, jnp.swapaxes(V, -1, -2)) + return jnp.einsum("...jk,...kl->...jl", Q, jnp.swapaxes(V, -1, -2)) def eigh(self, a): return jnp.linalg.eigh(a) @@ -1833,25 +1840,30 @@ class TorchBackend(Backend): - `__type__` is torch.Tensor """ - __name__ = 'torch' + __name__ = "torch" __type__ = torch_type __type_list__ = None rng_ = None def __init__(self): - self.rng_ = torch.Generator("cpu") self.rng_.seed() - self.__type_list__ = [torch.tensor(1, dtype=torch.float32), - torch.tensor(1, dtype=torch.float64)] + self.__type_list__ = [ + torch.tensor(1, dtype=torch.float32), + torch.tensor(1, dtype=torch.float64), + ] if torch.cuda.is_available(): self.rng_cuda_ = torch.Generator("cuda") self.rng_cuda_.seed() - self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda')) - self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda')) + self.__type_list__.append( + torch.tensor(1, dtype=torch.float32, device="cuda") + ) + self.__type_list__.append( + torch.tensor(1, dtype=torch.float64, device="cuda") + ) else: self.rng_cuda_ = torch.Generator("cpu") @@ -1860,7 +1872,6 @@ def __init__(self): # define a function that takes inputs val and grads # ad returns a val tensor with proper gradients class ValFunction(Function): - @staticmethod def forward(ctx, val, grads, *inputs): ctx.grads = grads @@ -1879,7 +1890,12 @@ def _to_numpy(self, a): return a.cpu().detach().numpy() def _from_numpy(self, a, type_as=None): - if isinstance(a, float) or isinstance(a, int): + if ( + isinstance(a, float) + or isinstance(a, int) + or isinstance(a, np.float32) + or isinstance(a, np.float64) + ): a = np.array(a) if type_as is None: return torch.from_numpy(a) @@ -1887,7 +1903,6 @@ def _from_numpy(self, a, type_as=None): return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device) def set_gradients(self, val, inputs, grads): - Func = self.ValFunction res = Func.apply(val, grads, *inputs) @@ -1925,7 +1940,9 @@ def full(self, shape, fill_value, type_as=None): if type_as is None: return torch.full(shape, fill_value) else: - return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device) + return torch.full( + shape, fill_value, dtype=type_as.dtype, device=type_as.device + ) def eye(self, N, M=None, type_as=None): if M is None: @@ -2023,8 +2040,8 @@ def argsort(self, a, axis=-1): sorted, indices = torch.sort(a, dim=axis) return indices - def searchsorted(self, a, v, side='left'): - right = (side != 'left') + def searchsorted(self, a, v, side="left"): + right = side != "left" return 
torch.searchsorted(a, v, right=right) def flip(self, a, axis=None): @@ -2082,8 +2099,10 @@ def median(self, a, axis=None): return torch.quantile(a, 0.5, interpolation="midpoint") # Else, use numpy - warnings.warn("The median is being computed using numpy and the array has been detached " - "in the Pytorch backend.") + warnings.warn( + "The median is being computed using numpy and the array has been detached " + "in the Pytorch backend." + ) a_ = self.to_numpy(a) a_median = np.median(a_, axis=axis) return self.from_numpy(a_median, type_as=a) @@ -2098,7 +2117,9 @@ def linspace(self, start, stop, num, type_as=None): if type_as is None: return torch.linspace(start, stop, num) else: - return torch.linspace(start, stop, num, dtype=type_as.dtype, device=type_as.device) + return torch.linspace( + start, stop, num, dtype=type_as.dtype, device=type_as.device + ) def meshgrid(self, a, b): try: @@ -2139,15 +2160,29 @@ def seed(self, seed=None): def rand(self, *size, type_as=None): if type_as is not None: - generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_ - return torch.rand(size=size, generator=generator, dtype=type_as.dtype, device=type_as.device) + generator = ( + self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_ + ) + return torch.rand( + size=size, + generator=generator, + dtype=type_as.dtype, + device=type_as.device, + ) else: return torch.rand(size=size, generator=self.rng_) def randn(self, *size, type_as=None): if type_as is not None: - generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_ - return torch.randn(size=size, dtype=type_as.dtype, generator=generator, device=type_as.device) + generator = ( + self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_ + ) + return torch.randn( + size=size, + dtype=type_as.dtype, + generator=generator, + device=type_as.device, + ) else: return torch.randn(size=size, generator=self.rng_) @@ -2156,8 +2191,11 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None): return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape) else: return torch.sparse_coo_tensor( - torch.stack([rows, cols]), data, size=shape, - dtype=type_as.dtype, device=type_as.device + torch.stack([rows, cols]), + data, + size=shape, + dtype=type_as.dtype, + device=type_as.device, ) def issparse(self, a): @@ -2167,7 +2205,7 @@ def tocsr(self, a): # Versions older than 1.9 do not support CSR tensors. PyTorch 1.9 and 1.10 offer a very limited support return self.todense(a) - def eliminate_zeros(self, a, threshold=0.): + def eliminate_zeros(self, a, threshold=0.0): if self.issparse(a): if threshold > 0: mask = self.abs(a) <= threshold @@ -2209,7 +2247,9 @@ def assert_same_dtype_device(self, a, b): b_dtype, b_device = self.dtype_device(b) assert a_dtype == b_dtype, "Dtype discrepancy" - assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + assert ( + a_device == b_device + ), f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" def squeeze(self, a, axis=None): if axis is None: @@ -2241,7 +2281,7 @@ def _bench(self, callable, *args, n_runs=1, warmup_runs=1): if self.device_type(type_as) == "GPU": # pragma: no cover end.record() torch.cuda.synchronize() - duration = start.elapsed_time(end) / 1000. 
+ duration = start.elapsed_time(end) / 1000.0 else: end = time.perf_counter() duration = end - start @@ -2264,10 +2304,9 @@ def sqrtm(self, a): L, V = torch.linalg.eigh(a) L = torch.sqrt(L) # Q[...] = V[...] @ diag(L[...]) - Q = torch.einsum('...jk,...k->...jk', V, L) + Q = torch.einsum("...jk,...k->...jk", V, L) # R[...] = Q[...] @ V[...].T - return torch.einsum('...jk,...kl->...jl', Q, - torch.transpose(V, -1, -2)) + return torch.einsum("...jk,...kl->...jl", Q, torch.transpose(V, -1, -2)) def eigh(self, a): return torch.linalg.eigh(a) @@ -2334,7 +2373,7 @@ class CupyBackend(Backend): # pragma: no cover - `__type__` is cp.ndarray """ - __name__ = 'cupy' + __name__ = "cupy" __type__ = cp_type __type_list__ = None @@ -2345,7 +2384,7 @@ def __init__(self): self.__type_list__ = [ cp.array(1, dtype=cp.float32), - cp.array(1, dtype=cp.float64) + cp.array(1, dtype=cp.float64), ] def _to_numpy(self, a): @@ -2464,7 +2503,7 @@ def sort(self, a, axis=-1): def argsort(self, a, axis=-1): return cp.argsort(a, axis) - def searchsorted(self, a, v, side='left'): + def searchsorted(self, a, v, side="left"): if a.ndim == 1: return cp.searchsorted(a, v, side) else: @@ -2573,9 +2612,7 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None): rows = self.from_numpy(rows) cols = self.from_numpy(cols) if type_as is None: - return cupyx.scipy.sparse.coo_matrix( - (data, (rows, cols)), shape=shape - ) + return cupyx.scipy.sparse.coo_matrix((data, (rows, cols)), shape=shape) else: with cp.cuda.Device(type_as.device): return cupyx.scipy.sparse.coo_matrix( @@ -2591,7 +2628,7 @@ def tocsr(self, a): else: return cupyx.scipy.sparse.csr_matrix(a) - def eliminate_zeros(self, a, threshold=0.): + def eliminate_zeros(self, a, threshold=0.0): if threshold > 0: if self.issparse(a): a.data[self.abs(a.data) <= threshold] = 0 @@ -2628,7 +2665,9 @@ def assert_same_dtype_device(self, a, b): # cupy has implicit type conversion so # we automatically validate the test for type - assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + assert ( + a_device == b_device + ), f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" def squeeze(self, a, axis=None): return cp.squeeze(a, axis=axis) @@ -2657,7 +2696,7 @@ def _bench(self, callable, *args, n_runs=1, warmup_runs=1): end_gpu.record() end_gpu.synchronize() key = ("Cupy", self.device_type(type_as), self.bitsize(type_as)) - t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu) / 1000. + t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu) / 1000.0 results[key] = t_gpu / n_runs mempool.free_all_blocks() pinned_mempool.free_all_blocks() @@ -2676,10 +2715,9 @@ def sqrtm(self, a): L, V = cp.linalg.eigh(a) L = cp.sqrt(L) # Q[...] = V[...] @ diag(L[...]) - Q = cp.einsum('...jk,...k->...jk', V, L) + Q = cp.einsum("...jk,...k->...jk", V, L) # R[...] = Q[...] @ V[...].T - return cp.einsum('...jk,...kl->...jl', Q, - cp.swapaxes(V, -1, -2)) + return cp.einsum("...jk,...kl->...jl", Q, cp.swapaxes(V, -1, -2)) def eigh(self, a): return cp.linalg.eigh(a) @@ -2736,7 +2774,6 @@ def det(self, x): class TensorflowBackend(Backend): - __name__ = "tf" __type__ = tf_type __type_list__ = None @@ -2748,7 +2785,7 @@ def __init__(self): self.__type_list__ = [ tf.convert_to_tensor([1], dtype=tf.float32), - tf.convert_to_tensor([1], dtype=tf.float64) + tf.convert_to_tensor([1], dtype=tf.float64), ] tmp = self.randn(15, 10) @@ -2760,7 +2797,7 @@ def __init__(self): "numpy API. 
You can activate it by running: \n" "from tensorflow.python.ops.numpy_ops import np_config\n" "np_config.enable_numpy_behavior()", - stacklevel=2 + stacklevel=2, ) def _to_numpy(self, a): @@ -2787,7 +2824,9 @@ def set_gradients(self, val, inputs, grads): def tmp(input): def grad(upstream): return grads + return val, grad + return tmp(inputs) def _detach(self, a): @@ -2891,7 +2930,7 @@ def sort(self, a, axis=-1): def argsort(self, a, axis=-1): return tnp.argsort(a, axis) - def searchsorted(self, a, v, side='left'): + def searchsorted(self, a, v, side="left"): return tf.searchsorted(a, v, side=side) def flip(self, a, axis=None): @@ -2925,8 +2964,10 @@ def mean(self, a, axis=None): return tnp.mean(a, axis=axis) def median(self, a, axis=None): - warnings.warn("The median is being computed using numpy and the array has been detached " - "in the Tensorflow backend.") + warnings.warn( + "The median is being computed using numpy and the array has been detached " + "in the Tensorflow backend." + ) a_ = self.to_numpy(a) a_median = np.median(a_, axis=axis) return self.from_numpy(a_median, type_as=a) @@ -2977,11 +3018,9 @@ def seed(self, seed=None): def rand(self, *size, type_as=None): if type_as is None: - return self.rng_.uniform(size, minval=0., maxval=1.) + return self.rng_.uniform(size, minval=0.0, maxval=1.0) else: - return self.rng_.uniform( - size, minval=0., maxval=1., dtype=type_as.dtype - ) + return self.rng_.uniform(size, minval=0.0, maxval=1.0, dtype=type_as.dtype) def randn(self, *size, type_as=None): if type_as is None: @@ -2999,15 +3038,13 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None): if shape is None: shape = ( self._convert_to_index_for_coo(rows), - self._convert_to_index_for_coo(cols) + self._convert_to_index_for_coo(cols), ) if type_as is not None: data = self.from_numpy(data, type_as=type_as) sparse_tensor = tf.sparse.SparseTensor( - indices=tnp.stack([rows, cols]).T, - values=data, - dense_shape=shape + indices=tnp.stack([rows, cols]).T, values=data, dense_shape=shape ) # if type_as is not None: # sparse_tensor = self.from_numpy(sparse_tensor, type_as=type_as) @@ -3020,7 +3057,7 @@ def issparse(self, a): def tocsr(self, a): return a - def eliminate_zeros(self, a, threshold=0.): + def eliminate_zeros(self, a, threshold=0.0): if self.issparse(a): values = a.values if threshold > 0: @@ -3030,7 +3067,7 @@ def eliminate_zeros(self, a, threshold=0.): return tf.sparse.retain(a, ~mask) else: if threshold > 0: - a = tnp.where(self.abs(a) > threshold, a, 0.) + a = tnp.where(self.abs(a) > threshold, a, 0.0) return a def todense(self, a): @@ -3049,9 +3086,7 @@ def copy(self, a): return tf.identity(a) def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): - return tnp.allclose( - a, b, rtol=rtol, atol=atol, equal_nan=equal_nan - ) + return tnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) def dtype_device(self, a): return a.dtype, a.device.split("device:")[1] @@ -3061,7 +3096,9 @@ def assert_same_dtype_device(self, a, b): b_dtype, b_device = self.dtype_device(b) assert a_dtype == b_dtype, "Dtype discrepancy" - assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + assert ( + a_device == b_device + ), f"Device discrepancy. 
First input is on {str(a_device)}, whereas second input is on {str(b_device)}" def squeeze(self, a, axis=None): return tnp.squeeze(a, axis=axis) @@ -3075,7 +3112,7 @@ def device_type(self, type_as): def _bench(self, callable, *args, n_runs=1, warmup_runs=1): results = dict() device_contexts = [tf.device("/CPU:0")] - if len(tf.config.list_physical_devices('GPU')) > 0: # pragma: no cover + if len(tf.config.list_physical_devices("GPU")) > 0: # pragma: no cover device_contexts.append(tf.device("/GPU:0")) for device_context in device_contexts: @@ -3092,7 +3129,7 @@ def _bench(self, callable, *args, n_runs=1, warmup_runs=1): key = ( "Tensorflow", self.device_type(inputs[0]), - self.bitsize(type_as) + self.bitsize(type_as), ) results[key] = (t1 - t0) / n_runs @@ -3111,10 +3148,11 @@ def sqrtm(self, a): L, V = tf.linalg.eigh(a) L = tf.sqrt(L) # Q[...] = V[...] @ diag(L[...]) - Q = tf.einsum('...jk,...k->...jk', V, L) + Q = tf.einsum("...jk,...k->...jk", V, L) # R[...] = Q[...] @ V[...].T - return tf.einsum('...jk,...kl->...jl', Q, - tf.linalg.matrix_transpose(V, (0, 2, 1))) + return tf.einsum( + "...jk,...kl->...jl", Q, tf.linalg.matrix_transpose(V, (0, 2, 1)) + ) def eigh(self, a): return tf.linalg.eigh(a) diff --git a/ot/bregman/__init__.py b/ot/bregman/__init__.py index 0bcb4214d..54d7eca27 100644 --- a/ot/bregman/__init__.py +++ b/ot/bregman/__init__.py @@ -8,48 +8,70 @@ # # License: MIT License -from ._utils import (geometricBar, - geometricMean, - projR, - projC) - -from ._sinkhorn import (sinkhorn, - sinkhorn2, - sinkhorn_knopp, - sinkhorn_log, - greenkhorn, - sinkhorn_stabilized, - sinkhorn_epsilon_scaling) - -from ._barycenter import (barycenter, - barycenter_sinkhorn, - free_support_sinkhorn_barycenter, - barycenter_stabilized, - barycenter_debiased, - jcpot_barycenter) - -from ._convolutional import (convolutional_barycenter2d, - convolutional_barycenter2d_debiased) - -from ._empirical import (empirical_sinkhorn, - empirical_sinkhorn2, - empirical_sinkhorn_divergence) - -from ._screenkhorn import (screenkhorn) - -from ._dictionary import (unmix) - -from ._geomloss import (empirical_sinkhorn2_geomloss, geomloss) - - -__all__ = ['geometricBar', 'geometricMean', 'projR', 'projC', - 'sinkhorn', 'sinkhorn2', 'sinkhorn_knopp', 'sinkhorn_log', - 'greenkhorn', 'sinkhorn_stabilized', 'sinkhorn_epsilon_scaling', - 'barycenter', 'barycenter_sinkhorn', 'free_support_sinkhorn_barycenter', - 'barycenter_stabilized', 'barycenter_debiased', 'jcpot_barycenter', - 'convolutional_barycenter2d', 'convolutional_barycenter2d_debiased', - 'empirical_sinkhorn', 'empirical_sinkhorn2', 'empirical_sinkhorn2_geomloss' - 'empirical_sinkhorn_divergence', 'geomloss', - 'screenkhorn', - 'unmix' - ] +from ._utils import geometricBar, geometricMean, projR, projC + +from ._sinkhorn import ( + sinkhorn, + sinkhorn2, + sinkhorn_knopp, + sinkhorn_log, + greenkhorn, + sinkhorn_stabilized, + sinkhorn_epsilon_scaling, +) + +from ._barycenter import ( + barycenter, + barycenter_sinkhorn, + free_support_sinkhorn_barycenter, + barycenter_stabilized, + barycenter_debiased, + jcpot_barycenter, +) + +from ._convolutional import ( + convolutional_barycenter2d, + convolutional_barycenter2d_debiased, +) + +from ._empirical import ( + empirical_sinkhorn, + empirical_sinkhorn2, + empirical_sinkhorn_divergence, +) + +from ._screenkhorn import screenkhorn + +from ._dictionary import unmix + +from ._geomloss import empirical_sinkhorn2_geomloss, geomloss + + +__all__ = [ + "geometricBar", + "geometricMean", + "projR", + "projC", + 
"sinkhorn", + "sinkhorn2", + "sinkhorn_knopp", + "sinkhorn_log", + "greenkhorn", + "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "barycenter", + "barycenter_sinkhorn", + "free_support_sinkhorn_barycenter", + "barycenter_stabilized", + "barycenter_debiased", + "jcpot_barycenter", + "convolutional_barycenter2d", + "convolutional_barycenter2d_debiased", + "empirical_sinkhorn", + "empirical_sinkhorn2", + "empirical_sinkhorn2_geomloss", + "empirical_sinkhorn_divergence", + "geomloss", + "screenkhorn", + "unmix", +] diff --git a/ot/bregman/_barycenter.py b/ot/bregman/_barycenter.py index 5d90782fb..77f20f87d 100644 --- a/ot/bregman/_barycenter.py +++ b/ot/bregman/_barycenter.py @@ -20,8 +20,19 @@ from ._sinkhorn import sinkhorn -def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): +def barycenter( + A, + M, + reg, + weights=None, + method="sinkhorn", + numItermax=10000, + stopThr=1e-4, + verbose=False, + log=False, + warn=True, + **kwargs, +): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: @@ -53,7 +64,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, method : str (optional) method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' or 'sinkhorn_log' weights : array-like, shape (n_hists,) - Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -84,28 +95,60 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, """ - if method.lower() == 'sinkhorn': - return barycenter_sinkhorn(A, M, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, - **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - return barycenter_stabilized(A, M, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, **kwargs) - elif method.lower() == 'sinkhorn_log': - return _barycenter_sinkhorn_log(A, M, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, **kwargs) + if method.lower() == "sinkhorn": + return barycenter_sinkhorn( + A, + M, + reg, + weights=weights, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) + elif method.lower() == "sinkhorn_stabilized": + return barycenter_stabilized( + A, + M, + reg, + weights=weights, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) + elif method.lower() == "sinkhorn_log": + return _barycenter_sinkhorn_log( + A, + M, + reg, + weights=weights, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) else: raise ValueError("Unknown method '%s'." 
% method) -def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False, warn=True): +def barycenter_sinkhorn( + A, + M, + reg, + weights=None, + numItermax=1000, + stopThr=1e-4, + verbose=False, + log=False, + warn=True, +): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: @@ -134,7 +177,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, reg : float Regularization term > 0 weights : array-like, shape (n_hists,) - Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -172,10 +215,10 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, if weights is None: weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] else: - assert (len(weights) == A.shape[1]) + assert len(weights) == A.shape[1] if log: - log = {'err': []} + log = {"err": []} K = nx.exp(-M / reg) @@ -186,7 +229,6 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, u = (geometricMean(UKv) / UKv.T).T for ii in range(numItermax): - UKv = u * nx.dot(K.T, A / nx.dot(K, u)) u = (u.T * geometricBar(weights, UKv)).T / UKv @@ -195,30 +237,42 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, # log and verbose print if log: - log['err'].append(err) + log["err"].append(err) if err < stopThr: break if verbose: if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) else: if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." + ) if log: - log['niter'] = ii + log["niter"] = ii return geometricBar(weights, UKv), log else: return geometricBar(weights, UKv) -def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None, - numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None, - **kwargs): +def free_support_sinkhorn_barycenter( + measures_locations, + measures_weights, + X_init, + reg, + b=None, + weights=None, + numItermax=100, + numInnerItermax=1000, + stopThr=1e-7, + verbose=False, + log=None, + **kwargs, +): r""" Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally: @@ -307,18 +361,21 @@ def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_ini log_dict = {} displacement_square_norms = [] - displacement_square_norm = stopThr + 1. 
- - while (displacement_square_norm > stopThr and iter_count < numItermax): + displacement_square_norm = stopThr + 1.0 + while displacement_square_norm > stopThr and iter_count < numItermax: T_sum = nx.zeros((k, d), type_as=X_init) - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): + for measure_locations_i, measure_weights_i, weight_i in zip( + measures_locations, measures_weights, weights + ): M_i = dist(X, measure_locations_i) - T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, - numItermax=numInnerItermax, **kwargs) - T_sum = T_sum + weight_i * 1. / \ - b[:, None] * nx.dot(T_i, measure_locations_i) + T_i = sinkhorn( + b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs + ) + T_sum = T_sum + weight_i * 1.0 / b[:, None] * nx.dot( + T_i, measure_locations_i + ) displacement_square_norm = nx.sum((T_sum - X) ** 2) if log: @@ -327,22 +384,33 @@ def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_ini X = T_sum if verbose: - print('iteration %d, displacement_square_norm=%f\n', - iter_count, displacement_square_norm) + print( + "iteration %d, displacement_square_norm=%f\n", + iter_count, + displacement_square_norm, + ) iter_count += 1 if log: - log_dict['displacement_square_norms'] = displacement_square_norms + log_dict["displacement_square_norms"] = displacement_square_norms return X, log_dict else: return X -def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False, warn=True): - r"""Compute the entropic wasserstein barycenter in log-domain - """ +def _barycenter_sinkhorn_log( + A, + M, + reg, + weights=None, + numItermax=1000, + stopThr=1e-4, + verbose=False, + log=False, + warn=True, +): + r"""Compute the entropic wasserstein barycenter in log-domain""" A, M = list_to_array(A, M) dim, n_hists = A.shape @@ -358,12 +426,12 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists else: - assert (len(weights) == A.shape[1]) + assert len(weights) == A.shape[1] if log: - log = {'err': []} + log = {"err": []} - M = - M / reg + M = -M / reg logA = nx.log(A + 1e-16) log_KU, G = nx.zeros((2, *logA.shape), type_as=A) err = 1 @@ -379,32 +447,43 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, # log and verbose print if log: - log['err'].append(err) + log["err"].append(err) if err < stopThr: break if verbose: if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) G = log_bar[:, None] - log_KU else: if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." 
+ ) if log: - log['niter'] = ii + log["niter"] = ii return nx.exp(log_bar), log else: return nx.exp(log_bar) -def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False, warn=True): +def barycenter_stabilized( + A, + M, + reg, + tau=1e10, + weights=None, + numItermax=1000, + stopThr=1e-4, + verbose=False, + log=False, + warn=True, +): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization. The function solves the following optimization problem: @@ -436,7 +515,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling weights : array-like, shape (n_hists,) - Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -475,17 +554,17 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, if weights is None: weights = nx.ones((n_hists,), type_as=M) / n_hists else: - assert (len(weights) == A.shape[1]) + assert len(weights) == A.shape[1] if log: - log = {'err': []} + log = {"err": []} u = nx.ones((dim, n_hists), type_as=M) / dim v = nx.ones((dim, n_hists), type_as=M) / dim K = nx.exp(-M / reg) - err = 1. + err = 1.0 alpha = nx.zeros((dim,), type_as=M) beta = nx.zeros((dim,), type_as=M) q = nx.ones((dim,), type_as=M) / dim @@ -505,12 +584,16 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) v = nx.ones(tuple(v.shape), type_as=v) Kv = nx.dot(K, v) - if (nx.any(Ktu == 0.) - or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) - or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): + if ( + nx.any(Ktu == 0.0) + or nx.any(nx.isnan(u)) + or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) + or nx.any(nx.isinf(v)) + ): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % ii) + warnings.warn("Numerical errors at iteration %s" % ii) q = qprev break if (ii % 10 == 0 and not absorbing) or ii == 0: @@ -518,31 +601,43 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, # the 10th iterations err = nx.max(nx.abs(u * Kv - A)) if log: - log['err'].append(err) + log["err"].append(err) if err < stopThr: break if verbose: if ii % 50 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) else: if warn: - warnings.warn("Stabilized Sinkhorn did not converge." + - "Try a larger entropy `reg`" + - "Or a larger absorption threshold `tau`.") + warnings.warn( + "Stabilized Sinkhorn did not converge." + + "Try a larger entropy `reg`" + + "Or a larger absorption threshold `tau`." 
+ ) if log: - log['niter'] = ii - log['logu'] = nx.log(u + 1e-16) - log['logv'] = nx.log(v + 1e-16) + log["niter"] = ii + log["logu"] = nx.log(u + 1e-16) + log["logv"] = nx.log(v + 1e-16) return q, log else: return q -def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): +def barycenter_debiased( + A, + M, + reg, + weights=None, + method="sinkhorn", + numItermax=10000, + stopThr=1e-4, + verbose=False, + log=False, + warn=True, + **kwargs, +): r"""Compute the debiased Sinkhorn barycenter of distributions A The function solves the following optimization problem: @@ -573,7 +668,7 @@ def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=1 method : str (optional) method used for the solver either 'sinkhorn' or 'sinkhorn_log' weights : array-like, shape (n_hists,) - Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates) numItermax : int, optional Max number of iterations stopThr : float, optional @@ -601,24 +696,48 @@ def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=1 Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ - if method.lower() == 'sinkhorn': - return _barycenter_debiased(A, M, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, **kwargs) - elif method.lower() == 'sinkhorn_log': - return _barycenter_debiased_log(A, M, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, **kwargs) + if method.lower() == "sinkhorn": + return _barycenter_debiased( + A, + M, + reg, + weights=weights, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) + elif method.lower() == "sinkhorn_log": + return _barycenter_debiased_log( + A, + M, + reg, + weights=weights, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) else: raise ValueError("Unknown method '%s'." % method) -def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False, warn=True): - r"""Compute the debiased sinkhorn barycenter of distributions A. - """ +def _barycenter_debiased( + A, + M, + reg, + weights=None, + numItermax=1000, + stopThr=1e-4, + verbose=False, + log=False, + warn=True, +): + r"""Compute the debiased sinkhorn barycenter of distributions A.""" A, M = list_to_array(A, M) @@ -627,10 +746,10 @@ def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, if weights is None: weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] else: - assert (len(weights) == A.shape[1]) + assert len(weights) == A.shape[1] if log: - log = {'err': []} + log = {"err": []} K = nx.exp(-M / reg) @@ -650,11 +769,11 @@ def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, c = (c * bar / nx.dot(K, c)) ** 0.5 if ii % 10 == 9: - err = abs(bar - bold).max() / max(bar.max(), 1.) 
+ err = abs(bar - bold).max() / max(bar.max(), 1.0) # log and verbose print if log: - log['err'].append(err) + log["err"].append(err) # debiased Sinkhorn does not converge monotonically # guarantee a few iterations are done before stopping @@ -662,26 +781,34 @@ def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, break if verbose: if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) else: if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." + ) if log: - log['niter'] = ii + log["niter"] = ii return bar, log else: return bar -def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False, - warn=True): - r"""Compute the debiased sinkhorn barycenter in log domain. - """ +def _barycenter_debiased_log( + A, + M, + reg, + weights=None, + numItermax=1000, + stopThr=1e-4, + verbose=False, + log=False, + warn=True, +): + r"""Compute the debiased sinkhorn barycenter in log domain.""" A, M = list_to_array(A, M) dim, n_hists = A.shape @@ -696,12 +823,12 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists else: - assert (len(weights) == A.shape[1]) + assert len(weights) == A.shape[1] if log: - log = {'err': []} + log = {"err": []} - M = - M / reg + M = -M / reg logA = nx.log(A + 1e-16) log_KU, G = nx.zeros((2, *logA.shape), type_as=A) c = nx.zeros(dim, type_as=A) @@ -718,15 +845,14 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, # log and verbose print if log: - log['err'].append(err) + log["err"].append(err) if err < stopThr and ii > 20: break if verbose: if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) G = log_bar[:, None] - log_KU for _ in range(10): @@ -734,19 +860,32 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, else: if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." + ) if log: - log['niter'] = ii + log["niter"] = ii return nx.exp(log_bar), log else: return nx.exp(log_bar) -def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, - stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs): - r'''Joint OT and proportion estimation for multi-source target shift as +def jcpot_barycenter( + Xs, + Ys, + Xt, + reg, + metric="sqeuclidean", + numItermax=100, + stopThr=1e-6, + verbose=False, + log=False, + warn=True, + **kwargs, +): + r"""Joint OT and proportion estimation for multi-source target shift as proposed in :ref:`[27] ` The function solves the following optimization problem: @@ -817,7 +956,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, .. 
[27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia "Optimal transport for multi-source domain adaptation under target shift", International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. - ''' + """ Xs = list_to_array(*Xs) Ys = list_to_array(*Ys) @@ -830,7 +969,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # log dictionary if log: - log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': [], 'gamma': []} + log = {"niter": 0, "err": [], "M": [], "D1": [], "D2": [], "gamma": []} K = [] M = [] @@ -841,7 +980,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, for d in range(nbdomains): dom = {} nsk = Xs[d].shape[0] # get number of elements for this domain - dom['nbelem'] = nsk + dom["nbelem"] = nsk classes = nx.unique(Ys[d]) # get number of classes for this domain # format classes to start from 0 for convenience @@ -856,8 +995,8 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, for c in classes: nbelemperclass = float(nx.sum(Ys[d] == c)) if nbelemperclass != 0: - Dtmp1[int(c), nx.to_numpy(Ys[d] == c)] = 1. - Dtmp2[int(c), nx.to_numpy(Ys[d] == c)] = 1. / (nbelemperclass) + Dtmp1[int(c), nx.to_numpy(Ys[d] == c)] = 1.0 + Dtmp2[int(c), nx.to_numpy(Ys[d] == c)] = 1.0 / (nbelemperclass) D1.append(nx.from_numpy(Dtmp1, type_as=Xs[0])) D2.append(nx.from_numpy(Dtmp2, type_as=Xs[0])) @@ -875,7 +1014,6 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, old_bary = nx.ones((nbclasses,), type_as=Xs[0]) for ii in range(numItermax): - bary = nx.zeros((nbclasses,), type_as=Xs[0]) # update coupling matrices for marginal constraints w.r.t. uniform target distribution @@ -896,27 +1034,29 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, old_bary = bary if log: - log['err'].append(err) + log["err"].append(err) if err < stopThr: break if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) else: if warn: - warnings.warn("Algorithm did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Algorithm did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." + ) bary = bary / nx.sum(bary) if log: - log['niter'] = ii - log['M'] = M - log['D1'] = D1 - log['D2'] = D2 - log['gamma'] = K + log["niter"] = ii + log["M"] = M + log["D1"] = D1 + log["D2"] = D2 + log["gamma"] = K return bary, log else: return bary diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py index baea2aec3..0e6548710 100644 --- a/ot/bregman/_convolutional.py +++ b/ot/bregman/_convolutional.py @@ -14,9 +14,18 @@ from ..backend import get_backend -def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, - warn=True, **kwargs): +def convolutional_barycenter2d( + A, + reg, + weights=None, + method="sinkhorn", + numItermax=10000, + stopThr=1e-4, + verbose=False, + log=False, + warn=True, + **kwargs, +): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` where :math:`\mathbf{A}` is a collection of 2D images. 
@@ -43,7 +52,7 @@ def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numIterm reg : float Regularization term >0 weights : array-like, shape (n_hists,) - Weights of each image on the simplex (barycentric coodinates) + Weights of each image on the simplex (barycentric coordinates) method : string, optional method used for the solver either 'sinkhorn' or 'sinkhorn_log' numItermax : int, optional @@ -80,25 +89,45 @@ def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numIterm International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ - if method.lower() == 'sinkhorn': - return _convolutional_barycenter2d(A, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, - **kwargs) - elif method.lower() == 'sinkhorn_log': - return _convolutional_barycenter2d_log(A, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, - **kwargs) + if method.lower() == "sinkhorn": + return _convolutional_barycenter2d( + A, + reg, + weights=weights, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) + elif method.lower() == "sinkhorn_log": + return _convolutional_barycenter2d_log( + A, + reg, + weights=weights, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) else: raise ValueError("Unknown method '%s'." % method) -def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, - stopThr=1e-9, stabThr=1e-30, verbose=False, - log=False, warn=True): +def _convolutional_barycenter2d( + A, + reg, + weights=None, + numItermax=10000, + stopThr=1e-9, + stabThr=1e-30, + verbose=False, + log=False, + warn=True, +): r"""Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images. """ @@ -110,10 +139,10 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, if weights is None: weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0] else: - assert (len(weights) == A.shape[0]) + assert len(weights) == A.shape[0] if log: - log = {'err': []} + log = {"err": []} bar = nx.ones(A.shape[1:], type_as=A) bar /= nx.sum(bar) @@ -125,11 +154,11 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, # this is equivalent to blurring on horizontal then vertical directions t = nx.linspace(0, 1, A.shape[1], type_as=A) [Y, X] = nx.meshgrid(t, t) - K1 = nx.exp(-(X - Y) ** 2 / reg) + K1 = nx.exp(-((X - Y) ** 2) / reg) t = nx.linspace(0, 1, A.shape[2], type_as=A) [Y, X] = nx.meshgrid(t, t) - K2 = nx.exp(-(X - Y) ** 2 / reg) + K2 = nx.exp(-((X - Y) ** 2) / reg) def convol_imgs(imgs): kx = nx.einsum("...ij,kjl->kil", K1, imgs) @@ -142,39 +171,46 @@ def convol_imgs(imgs): KV = convol_imgs(V) U = A / KV KU = convol_imgs(U) - bar = nx.exp( - nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0) - ) + bar = nx.exp(nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)) if ii % 10 == 9: err = nx.sum(nx.std(V * KU, axis=0)) # log and verbose print if log: - log['err'].append(err) + log["err"].append(err) if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) if err < stopThr: break else: if warn: - warnings.warn("Convolutional Sinkhorn did not converge. 
" - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`.") + warnings.warn( + "Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`." + ) if log: - log['niter'] = ii - log['U'] = U + log["niter"] = ii + log["U"] = U return bar, log else: return bar -def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, - stopThr=1e-4, stabThr=1e-30, verbose=False, - log=False, warn=True): +def _convolutional_barycenter2d_log( + A, + reg, + weights=None, + numItermax=10000, + stopThr=1e-4, + stabThr=1e-30, + verbose=False, + log=False, + warn=True, +): r"""Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images in log-domain. """ @@ -193,21 +229,21 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: - assert (len(weights) == n_hists) + assert len(weights) == n_hists if log: - log = {'err': []} + log = {"err": []} err = 1 # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions t = nx.linspace(0, 1, width, type_as=A) [Y, X] = nx.meshgrid(t, t) - M1 = - (X - Y) ** 2 / reg + M1 = -((X - Y) ** 2) / reg t = nx.linspace(0, 1, height, type_as=A) [Y, X] = nx.meshgrid(t, t) - M2 = - (X - Y) ** 2 / reg + M2 = -((X - Y) ** 2) / reg def convol_img(log_img): log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) @@ -228,33 +264,42 @@ def convol_img(log_img): err = nx.exp(G + log_KU).std(axis=0).sum() # log and verbose print if log: - log['err'].append(err) + log["err"].append(err) if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) if err < stopThr: break G = log_bar[None, :, :] - log_KU else: if warn: - warnings.warn("Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`.") + warnings.warn( + "Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`." + ) if log: - log['niter'] = ii + log["niter"] = ii return nx.exp(log_bar), log else: return nx.exp(log_bar) -def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", - numItermax=10000, stopThr=1e-3, - verbose=False, log=False, warn=True, - **kwargs): +def convolutional_barycenter2d_debiased( + A, + reg, + weights=None, + method="sinkhorn", + numItermax=10000, + stopThr=1e-3, + verbose=False, + log=False, + warn=True, + **kwargs, +): r"""Compute the debiased sinkhorn barycenter of distributions :math:`\mathbf{A}` where :math:`\mathbf{A}` is a collection of 2D images. 
@@ -281,7 +326,7 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", reg : float Regularization term >0 weights : array-like, shape (n_hists,) - Weights of each image on the simplex (barycentric coodinates) + Weights of each image on the simplex (barycentric coordinates) method : string, optional method used for the solver either 'sinkhorn' or 'sinkhorn_log' numItermax : int, optional @@ -314,27 +359,46 @@ def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ - if method.lower() == 'sinkhorn': - return _convolutional_barycenter2d_debiased(A, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, - **kwargs) - elif method.lower() == 'sinkhorn_log': - return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, - **kwargs) + if method.lower() == "sinkhorn": + return _convolutional_barycenter2d_debiased( + A, + reg, + weights=weights, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) + elif method.lower() == "sinkhorn_log": + return _convolutional_barycenter2d_debiased_log( + A, + reg, + weights=weights, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) else: raise ValueError("Unknown method '%s'." % method) -def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, - stopThr=1e-3, stabThr=1e-15, verbose=False, - log=False, warn=True): - r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions. - """ +def _convolutional_barycenter2d_debiased( + A, + reg, + weights=None, + numItermax=10000, + stopThr=1e-3, + stabThr=1e-15, + verbose=False, + log=False, + warn=True, +): + r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions.""" A = list_to_array(A) n_hists, width, height = A.shape @@ -344,10 +408,10 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: - assert (len(weights) == n_hists) + assert len(weights) == n_hists if log: - log = {'err': []} + log = {"err": []} bar = nx.ones((width, height), type_as=A) bar /= width * height @@ -360,11 +424,11 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, # this is equivalent to blurring on horizontal then vertical directions t = nx.linspace(0, 1, width, type_as=A) [Y, X] = nx.meshgrid(t, t) - K1 = nx.exp(-(X - Y) ** 2 / reg) + K1 = nx.exp(-((X - Y) ** 2) / reg) t = nx.linspace(0, 1, height, type_as=A) [Y, X] = nx.meshgrid(t, t) - K2 = nx.exp(-(X - Y) ** 2 / reg) + K2 = nx.exp(-((X - Y) ** 2) / reg) def convol_imgs(imgs): kx = nx.einsum("...ij,kjl->kil", K1, imgs) @@ -377,9 +441,7 @@ def convol_imgs(imgs): KV = convol_imgs(V) U = A / KV KU = convol_imgs(U) - bar = c * nx.exp( - nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0) - ) + bar = c * nx.exp(nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)) for _ in range(10): c = (c * bar / nx.squeeze(convol_imgs(c[None]))) ** 0.5 @@ -388,13 +450,12 @@ def convol_imgs(imgs): err = nx.sum(nx.std(V * KU, axis=0)) # log and verbose print if log: - log['err'].append(err) + log["err"].append(err) if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, 
err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) # debiased Sinkhorn does not converge monotonically # guarantee a few iterations are done before stopping @@ -402,22 +463,31 @@ def convol_imgs(imgs): break else: if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." + ) if log: - log['niter'] = ii - log['U'] = U + log["niter"] = ii + log["U"] = U return bar, log else: return bar -def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000, - stopThr=1e-3, stabThr=1e-30, verbose=False, - log=False, warn=True): - r"""Compute the debiased barycenter of 2D images in log-domain. - """ +def _convolutional_barycenter2d_debiased_log( + A, + reg, + weights=None, + numItermax=10000, + stopThr=1e-3, + stabThr=1e-30, + verbose=False, + log=False, + warn=True, +): + r"""Compute the debiased barycenter of 2D images in log-domain.""" A = list_to_array(A) n_hists, width, height = A.shape @@ -430,21 +500,21 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10 if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: - assert (len(weights) == A.shape[0]) + assert len(weights) == A.shape[0] if log: - log = {'err': []} + log = {"err": []} err = 1 # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions t = nx.linspace(0, 1, width, type_as=A) [Y, X] = nx.meshgrid(t, t) - M1 = - (X - Y) ** 2 / reg + M1 = -((X - Y) ** 2) / reg t = nx.linspace(0, 1, height, type_as=A) [Y, X] = nx.meshgrid(t, t) - M2 = - (X - Y) ** 2 / reg + M2 = -((X - Y) ** 2) / reg def convol_img(log_img): log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) @@ -469,24 +539,25 @@ def convol_img(log_img): err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0)) # log and verbose print if log: - log['err'].append(err) + log["err"].append(err) if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) if err < stopThr and ii > 20: break G = log_bar[None, :, :] - log_KU else: if warn: - warnings.warn("Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`.") + warnings.warn( + "Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`." 
+ ) if log: - log['niter'] = ii + log["niter"] = ii return nx.exp(log_bar), log else: return nx.exp(log_bar) diff --git a/ot/bregman/_dictionary.py b/ot/bregman/_dictionary.py index 80f00f762..bb7047b3c 100644 --- a/ot/bregman/_dictionary.py +++ b/ot/bregman/_dictionary.py @@ -17,8 +17,21 @@ from ._utils import projC, projR -def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, - stopThr=1e-3, verbose=False, log=False, warn=True): +def unmix( + a, + D, + M, + M0, + h0, + reg, + reg0, + alpha, + numItermax=1000, + stopThr=1e-3, + verbose=False, + log=False, + warn=True, +): r""" Compute the unmixing of an observation with a given dictionary using Wasserstein distance @@ -111,13 +124,13 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, err = 1 # log = {'niter':0, 'all_err':[]} if log: - log = {'err': []} + log = {"err": []} for ii in range(numItermax): K = projC(K, a) K0 = projC(K0, h0) new = nx.sum(K0, axis=1) - # we recombine the current selection from dictionnary + # we recombine the current selection from dictionary inv_new = nx.dot(D, new) other = nx.sum(K, axis=1) # geometric interpolation @@ -127,21 +140,23 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, err = nx.norm(nx.sum(K0, axis=1) - old) old = new if log: - log['err'].append(err) + log["err"].append(err) if verbose: if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) if err < stopThr: break else: if warn: - warnings.warn("Unmixing algorithm did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Unmixing algorithm did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." + ) if log: - log['niter'] = ii + log["niter"] = ii return nx.sum(K0, axis=1), log else: return nx.sum(K0, axis=1) diff --git a/ot/bregman/_empirical.py b/ot/bregman/_empirical.py index b84c3b389..055a07ef3 100644 --- a/ot/bregman/_empirical.py +++ b/ot/bregman/_empirical.py @@ -17,8 +17,8 @@ from ._sinkhorn import sinkhorn, sinkhorn2 -def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric='sqeuclidean', reg=1e-1, nx=None): - r""" Get a LazyTensor of Sinkhorn solution from the dual potentials +def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric="sqeuclidean", reg=1e-1, nx=None): + r"""Get a LazyTensor of Sinkhorn solution from the dual potentials The returned LazyTensor is :math:`\mathbf{T} = exp( \mathbf{f} \mathbf{1}_b^\top + \mathbf{1}_a \mathbf{g}^\top - \mathbf{C}/reg)`, where :math:`\mathbf{C}` is the pairwise metric matrix between samples :math:`\mathbf{X}_a` and :math:`\mathbf{X}_b`. 
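The docstring formula above is what the lazy batched loops later in this file evaluate block by block. As a reminder of what the LazyTensor encodes, here is a dense counterpart with made-up sizes and placeholder zero potentials (a sketch, not code from this PR):

import numpy as np
import ot

# hypothetical small point clouds and placeholder log-domain dual potentials
X_a, X_b = np.random.randn(5, 2), np.random.randn(7, 2)
f, g = np.zeros(5), np.zeros(7)
reg = 0.1

C = ot.dist(X_a, X_b, metric="sqeuclidean")
# dense equivalent of the lazy entries: T_ij = exp(f_i + g_j - C_ij / reg)
T = np.exp(f[:, None] + g[None, :] - C / reg)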
@@ -61,10 +61,24 @@ def func(i, j, X_a, X_b, f, g, metric, reg): return T -def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, - log=False, warn=True, warmstart=None, **kwargs): - r''' +def empirical_sinkhorn( + X_s, + X_t, + reg, + a=None, + b=None, + metric="sqeuclidean", + numIterMax=10000, + stopThr=1e-9, + isLazy=False, + batchSize=100, + verbose=False, + log=False, + warn=True, + warmstart=None, + **kwargs, +): + r""" Solve the entropic regularization optimal transport problem and return the OT matrix from empirical data @@ -154,7 +168,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - ''' + """ X_s, X_t = list_to_array(X_s, X_t) @@ -181,8 +195,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', elif isinstance(batchSize, tuple) and len(batchSize) == 2: bs, bt = batchSize[0], batchSize[1] else: - raise ValueError( - "Batch size must be in integer or a tuple of two integers") + raise ValueError("Batch size must be in integer or a tuple of two integers") range_s, range_t = range(0, ns, bs), range(0, nt, bt) @@ -193,35 +206,31 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', X_t_np = nx.to_numpy(X_t) for i_ot in range(numIterMax): - lse_f_cols = [] for i in range_s: - M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M = dist(X_s_np[i : i + bs, :], X_t_np, metric=metric) M = nx.from_numpy(M, type_as=a) - lse_f_cols.append( - nx.logsumexp(g[None, :] - M / reg, axis=1) - ) + lse_f_cols.append(nx.logsumexp(g[None, :] - M / reg, axis=1)) lse_f = nx.concatenate(lse_f_cols, axis=0) f = log_a - lse_f lse_g_cols = [] for j in range_t: - M = dist(X_s_np, X_t_np[j:j + bt, :], metric=metric) + M = dist(X_s_np, X_t_np[j : j + bt, :], metric=metric) M = nx.from_numpy(M, type_as=a) - lse_g_cols.append( - nx.logsumexp(f[:, None] - M / reg, axis=0) - ) + lse_g_cols.append(nx.logsumexp(f[:, None] - M / reg, axis=0)) lse_g = nx.concatenate(lse_g_cols, axis=0) g = log_b - lse_g if (i_ot + 1) % 10 == 0: m1_cols = [] for i in range_s: - M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M = dist(X_s_np[i : i + bs, :], X_t_np, metric=metric) M = nx.from_numpy(M, type_as=a) m1_cols.append( - nx.sum(nx.exp(f[i:i + bs, None] + - g[None, :] - M / reg), axis=1) + nx.sum( + nx.exp(f[i : i + bs, None] + g[None, :] - M / reg), axis=1 + ) ) m1 = nx.concatenate(m1_cols, axis=0) err = nx.sum(nx.abs(m1 - a)) @@ -229,16 +238,19 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', dict_log["err"].append(err) if verbose and (i_ot + 1) % 100 == 0: - print("Error in marginal at iteration {} = {}".format( - i_ot + 1, err)) + print( + "Error in marginal at iteration {} = {}".format(i_ot + 1, err) + ) if err <= stopThr: break else: if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." 
+ ) if log: dict_log["u"] = f dict_log["v"] = g @@ -251,19 +263,53 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', else: M = dist(X_s, X_t, metric=metric) if log: - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, - verbose=verbose, log=True, warmstart=warmstart, **kwargs) + pi, log = sinkhorn( + a, + b, + M, + reg, + numItermax=numIterMax, + stopThr=stopThr, + verbose=verbose, + log=True, + warmstart=warmstart, + **kwargs, + ) return pi, log else: - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, - verbose=verbose, log=False, warmstart=warmstart, **kwargs) + pi = sinkhorn( + a, + b, + M, + reg, + numItermax=numIterMax, + stopThr=stopThr, + verbose=verbose, + log=False, + warmstart=warmstart, + **kwargs, + ) return pi -def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, - verbose=False, log=False, warn=True, warmstart=None, **kwargs): - r''' +def empirical_sinkhorn2( + X_s, + X_t, + reg, + a=None, + b=None, + metric="sqeuclidean", + numIterMax=10000, + stopThr=1e-9, + isLazy=False, + batchSize=100, + verbose=False, + log=False, + warn=True, + warmstart=None, + **kwargs, +): + r""" Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -359,7 +405,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - ''' + """ X_s, X_t = list_to_array(X_s, X_t) @@ -373,22 +419,39 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if isLazy: if log: - f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, - numIterMax=numIterMax, - stopThr=stopThr, - isLazy=isLazy, - batchSize=batchSize, - verbose=verbose, log=log, - warn=warn, - warmstart=warmstart) + f, g, dict_log = empirical_sinkhorn( + X_s, + X_t, + reg, + a, + b, + metric, + numIterMax=numIterMax, + stopThr=stopThr, + isLazy=isLazy, + batchSize=batchSize, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart, + ) else: - f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, - numIterMax=numIterMax, - stopThr=stopThr, - isLazy=isLazy, batchSize=batchSize, - verbose=verbose, log=log, - warn=warn, - warmstart=warmstart) + f, g = empirical_sinkhorn( + X_s, + X_t, + reg, + a, + b, + metric, + numIterMax=numIterMax, + stopThr=stopThr, + isLazy=isLazy, + batchSize=batchSize, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart, + ) bs = batchSize if isinstance(batchSize, int) else batchSize[0] range_s = range(0, ns, bs) @@ -399,9 +462,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', X_t_np = nx.to_numpy(X_t) for i in range_s: - M_block = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M_block = dist(X_s_np[i : i + bs, :], X_t_np, metric=metric) M_block = nx.from_numpy(M_block, type_as=a) - pi_block = nx.exp(f[i:i + bs, None] + g[None, :] - M_block / reg) + pi_block = nx.exp(f[i : i + bs, None] + g[None, :] - M_block / reg) loss += nx.sum(M_block * pi_block) if log: @@ -413,21 +476,53 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', M = dist(X_s, X_t, metric=metric) if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, warmstart=warmstart, **kwargs) + 
sinkhorn_loss, log = sinkhorn2( + a, + b, + M, + reg, + numItermax=numIterMax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart, + **kwargs, + ) return sinkhorn_loss, log else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, warmstart=warmstart, **kwargs) + sinkhorn_loss = sinkhorn2( + a, + b, + M, + reg, + numItermax=numIterMax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart, + **kwargs, + ) return sinkhorn_loss -def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, verbose=False, - log=False, warn=True, warmstart=None, **kwargs): - r''' +def empirical_sinkhorn_divergence( + X_s, + X_t, + reg, + a=None, + b=None, + metric="sqeuclidean", + numIterMax=10000, + stopThr=1e-9, + verbose=False, + log=False, + warn=True, + warmstart=None, + **kwargs, +): + r""" Compute the sinkhorn divergence loss from empirical data The function solves the following optimization problems and return the @@ -533,7 +628,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018 - ''' + """ X_s, X_t = list_to_array(X_s, X_t) nx = get_backend(X_s, X_t) @@ -545,50 +640,114 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli warmstart_b = (v, v) if log: - sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, warn=warn, - warmstart=warmstart, **kwargs) - - sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, warn=warn, - warmstart=warmstart_a, **kwargs) - - sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, warn=warn, - warmstart=warmstart_b, **kwargs) - - sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ - (sinkhorn_loss_a + sinkhorn_loss_b) + sinkhorn_loss_ab, log_ab = empirical_sinkhorn2( + X_s, + X_t, + reg, + a, + b, + metric=metric, + numIterMax=numIterMax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart, + **kwargs, + ) + + sinkhorn_loss_a, log_a = empirical_sinkhorn2( + X_s, + X_s, + reg, + a, + a, + metric=metric, + numIterMax=numIterMax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart_a, + **kwargs, + ) + + sinkhorn_loss_b, log_b = empirical_sinkhorn2( + X_t, + X_t, + reg, + b, + b, + metric=metric, + numIterMax=numIterMax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart_b, + **kwargs, + ) + + sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) log = {} - log['sinkhorn_loss_ab'] = sinkhorn_loss_ab - log['sinkhorn_loss_a'] = sinkhorn_loss_a - log['sinkhorn_loss_b'] = sinkhorn_loss_b - log['log_sinkhorn_ab'] = log_ab - log['log_sinkhorn_a'] = log_a - log['log_sinkhorn_b'] = log_b + log["sinkhorn_loss_ab"] = sinkhorn_loss_ab + log["sinkhorn_loss_a"] = sinkhorn_loss_a + log["sinkhorn_loss_b"] = sinkhorn_loss_b + log["log_sinkhorn_ab"] = log_ab + log["log_sinkhorn_a"] = log_a + log["log_sinkhorn_b"] = log_b return nx.maximum(0, sinkhorn_div), log else: - sinkhorn_loss_ab = 
empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, warn=warn, - warmstart=warmstart, **kwargs) - - sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, warn=warn, - warmstart=warmstart_a, **kwargs) - - sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, warn=warn, - warmstart=warmstart_b, **kwargs) - - sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ - (sinkhorn_loss_a + sinkhorn_loss_b) + sinkhorn_loss_ab = empirical_sinkhorn2( + X_s, + X_t, + reg, + a, + b, + metric=metric, + numIterMax=numIterMax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart, + **kwargs, + ) + + sinkhorn_loss_a = empirical_sinkhorn2( + X_s, + X_s, + reg, + a, + a, + metric=metric, + numIterMax=numIterMax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart_a, + **kwargs, + ) + + sinkhorn_loss_b = empirical_sinkhorn2( + X_t, + X_t, + reg, + b, + b, + metric=metric, + numIterMax=numIterMax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart_b, + **kwargs, + ) + + sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) return nx.maximum(0, sinkhorn_div) diff --git a/ot/bregman/_geomloss.py b/ot/bregman/_geomloss.py index 0b1abd1b8..1df423db2 100644 --- a/ot/bregman/_geomloss.py +++ b/ot/bregman/_geomloss.py @@ -8,6 +8,7 @@ # License: MIT License import numpy as np + try: import geomloss from geomloss import SamplesLoss @@ -18,8 +19,10 @@ geomloss = False -def get_sinkhorn_geomloss_lazytensor(X_a, X_b, f, g, a, b, metric='sqeuclidean', blur=0.1, nx=None): - """ Get a LazyTensor of sinkhorn solution T = exp((f+g^T-C)/reg)*(ab^T) +def get_sinkhorn_geomloss_lazytensor( + X_a, X_b, f, g, a, b, metric="sqeuclidean", blur=0.1, nx=None +): + """Get a LazyTensor of sinkhorn solution T = exp((f+g^T-C)/reg)*(ab^T) Parameters ---------- @@ -51,20 +54,35 @@ def get_sinkhorn_geomloss_lazytensor(X_a, X_b, f, g, a, b, metric='sqeuclidean', shape = (X_a.shape[0], X_b.shape[0]) def func(i, j, X_a, X_b, f, g, a, b, metric, blur): - if metric == 'sqeuclidean': + if metric == "sqeuclidean": C = dist(X_a[i], X_b[j], metric=metric) / 2 else: C = dist(X_a[i], X_b[j], metric=metric) - return nx.exp((f[i, None] + g[None, j] - C) / (blur**2)) * (a[i, None] * b[None, j]) + return nx.exp((f[i, None] + g[None, j] - C) / (blur**2)) * ( + a[i, None] * b[None, j] + ) - T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, blur=blur) + T = LazyTensor( + shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, blur=blur + ) return T -def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', scaling=0.95, - verbose=False, debias=False, log=False, backend='auto'): - r""" Solve the entropic regularization optimal transport problem with geomloss +def empirical_sinkhorn2_geomloss( + X_s, + X_t, + reg, + a=None, + b=None, + metric="sqeuclidean", + scaling=0.95, + verbose=False, + debias=False, + log=False, + backend="auto", +): + r"""Solve the entropic regularization optimal transport problem with geomloss The function solves the following optimization problem: @@ -103,7 +121,7 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid b : array-like, shape (n_samples_b,), default=None samples weights in the 
target domain metric : str, default='sqeuclidean' - Metric used for the cost matrix computation Only acepted values are + Metric used for the cost matrix computation Only accepted values are 'sqeuclidean' and 'euclidean'. scaling : float, default=0.95 Scaling parameter used for epsilon scaling. Value close to one promote @@ -142,18 +160,17 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid """ if geomloss: - nx = get_backend(X_s, X_t, a, b) - if nx.__name__ not in ['torch', 'numpy']: - raise ValueError('geomloss only support torch or numpy backend') + if nx.__name__ not in ["torch", "numpy"]: + raise ValueError("geomloss only support torch or numpy backend") if a is None: a = nx.ones(X_s.shape[0], type_as=X_s) / X_s.shape[0] if b is None: b = nx.ones(X_t.shape[0], type_as=X_t) / X_t.shape[0] - if nx.__name__ == 'numpy': + if nx.__name__ == "numpy": X_s_torch = torch.tensor(X_s) X_t_torch = torch.tensor(X_t) @@ -170,42 +187,54 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid # after that we are all in torch # set blur value and p - if metric == 'sqeuclidean': + if metric == "sqeuclidean": p = 2 blur = np.sqrt(reg / 2) # because geomloss divides cost by two - elif metric == 'euclidean': + elif metric == "euclidean": p = 1 blur = np.sqrt(reg) else: - raise ValueError('geomloss only supports sqeuclidean and euclidean metrics') + raise ValueError("geomloss only supports sqeuclidean and euclidean metrics") # force gradients for computing dual a_torch.requires_grad = True b_torch.requires_grad = True - loss = SamplesLoss(loss='sinkhorn', p=p, blur=blur, backend=backend, debias=debias, scaling=scaling, verbose=verbose) + loss = SamplesLoss( + loss="sinkhorn", + p=p, + blur=blur, + backend=backend, + debias=debias, + scaling=scaling, + verbose=verbose, + ) # compute value - value = loss(a_torch, X_s_torch, b_torch, X_t_torch) # linear + entropic/KL reg? + value = loss( + a_torch, X_s_torch, b_torch, X_t_torch + ) # linear + entropic/KL reg? 
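        # clarifying comment, not in the original source: geomloss returns the scalar
        # Sinkhorn value, and by the envelope theorem its gradients with respect to the
        # histogram weights a and b recover the dual potentials; this is why the weights
        # were marked requires_grad above and grad(value, [a_torch, b_torch]) is called
        # just below.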
# get dual potentials f, g = grad(value, [a_torch, b_torch]) - if metric == 'sqeuclidean': + if metric == "sqeuclidean": value *= 2 # because geomloss divides cost by two - if nx.__name__ == 'numpy': + if nx.__name__ == "numpy": f = f.cpu().detach().numpy() g = g.cpu().detach().numpy() value = value.cpu().detach().numpy() if log: log = {} - log['f'] = f - log['g'] = g - log['value'] = value + log["f"] = f + log["g"] = g + log["value"] = value - log['lazy_plan'] = get_sinkhorn_geomloss_lazytensor(X_s, X_t, f, g, a, b, metric=metric, blur=blur, nx=nx) + log["lazy_plan"] = get_sinkhorn_geomloss_lazytensor( + X_s, X_t, f, g, a, b, metric=metric, blur=blur, nx=nx + ) return value, log @@ -213,4 +242,4 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid return value else: - raise ImportError('geomloss not installed') + raise ImportError("geomloss not installed") diff --git a/ot/bregman/_screenkhorn.py b/ot/bregman/_screenkhorn.py index 8c53f73ae..ea00a03cc 100644 --- a/ot/bregman/_screenkhorn.py +++ b/ot/bregman/_screenkhorn.py @@ -17,9 +17,21 @@ from ..backend import get_backend -def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, - restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, - verbose=False, log=False): +def screenkhorn( + a, + b, + M, + reg, + ns_budget=None, + nt_budget=None, + uniform=False, + restricted=True, + maxiter=10000, + maxfun=10000, + pgtol=1e-09, + verbose=False, + log=False, +): r""" Screening Sinkhorn Algorithm for Regularized Optimal Transport @@ -79,7 +91,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, pgtol: `float`, default=1e-09 Final objective function accuracy in LBFGS solver verbose: `bool`, default=False - If `True`, display informations about the cardinals of the active sets + If `True`, display information about the cardinals of the active sets and the parameters kappa and epsilon @@ -119,15 +131,18 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, except ImportError: warnings.warn( "Bottleneck module is not installed. Install it from" - " https://pypi.org/project/Bottleneck/ for better performance.") + " https://pypi.org/project/Bottleneck/ for better performance." + ) bottleneck = np a, b, M = list_to_array(a, b, M) nx = get_backend(M, a, b) if nx.__name__ in ("jax", "tf"): - raise TypeError("JAX or TF arrays have been received but screenkhorn is not " - "compatible with neither JAX nor TF.") + raise TypeError( + "JAX or TF arrays have been received but screenkhorn is not " + "compatible with neither JAX nor TF." + ) ns, nt = M.shape @@ -155,8 +170,8 @@ def projection(u, epsilon): epsilon = 0.0 kappa = 1.0 - cst_u = 0. - cst_v = 0. 
+ cst_u = 0.0 + cst_v = 0.0 bounds_u = [(0.0, np.inf)] * ns bounds_v = [(0.0, np.inf)] * nt @@ -181,9 +196,10 @@ def projection(u, epsilon): epsilon_u_square = a[0] / aK_sort[ns_budget - 1] else: aK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy( - K_sum_cols), ns_budget - 1)[ns_budget - 1], - type_as=M + bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ + ns_budget - 1 + ], + type_as=M, ) epsilon_u_square = a[0] / aK_sort @@ -192,9 +208,10 @@ def projection(u, epsilon): epsilon_v_square = b[0] / bK_sort[nt_budget - 1] else: bK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy( - K_sum_rows), nt_budget - 1)[nt_budget - 1], - type_as=M + bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[ + nt_budget - 1 + ], + type_as=M, ) epsilon_v_square = b[0] / bK_sort else: @@ -215,7 +232,7 @@ def projection(u, epsilon): if uniform: aK = a / K_sum_cols aK_sort = nx.flip(nx.sort(aK), axis=0) - epsilon_u_square = nx.mean(aK_sort[ns_budget - 1:ns_budget + 1]) + epsilon_u_square = nx.mean(aK_sort[ns_budget - 1 : ns_budget + 1]) Isel = a >= epsilon_u_square * K_sum_cols ns_budget = nx.sum(Isel) @@ -223,7 +240,7 @@ def projection(u, epsilon): if uniform: bK = b / K_sum_rows bK_sort = nx.flip(nx.sort(bK), axis=0) - epsilon_v_square = nx.mean(bK_sort[nt_budget - 1:nt_budget + 1]) + epsilon_v_square = nx.mean(bK_sort[nt_budget - 1 : nt_budget + 1]) Jsel = b >= epsilon_v_square * K_sum_rows nt_budget = nx.sum(Jsel) @@ -233,8 +250,10 @@ def projection(u, epsilon): if verbose: print("epsilon = %s\n" % epsilon) print("kappa = %s\n" % kappa) - print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' - % (sum(Isel), sum(Jsel))) + print( + "Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n" + % (sum(Isel), sum(Jsel)) + ) # Ic, Jc: complementary of the active sets I and J Ic = ~Isel @@ -263,26 +282,47 @@ def projection(u, epsilon): b_J_min = b_J[0] # box constraints in L-BFGS-B (see Proposition 1 in [26]) - bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / ( - ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget - - bounds_v = [( - max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), - epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget + bounds_u = [ + ( + max( + a_I_min + / ( + (nt - nt_budget) * epsilon + + nt_budget * (b_J_max / (ns * epsilon * kappa * K_min)) + ), + epsilon / kappa, + ), + a_I_max / (nt * epsilon * K_min), + ) + ] * ns_budget + + bounds_v = [ + ( + max( + b_J_min + / ( + (ns - ns_budget) * epsilon + + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min)) + ), + epsilon * kappa, + ), + b_J_max / (ns * epsilon * K_min), + ) + ] * nt_budget # pre-calculated constants for the objective - vec_eps_IJc = epsilon * kappa * nx.sum( - K_IJc * nx.ones((nt - nt_budget,), type_as=M)[None, :], - axis=1 + vec_eps_IJc = ( + epsilon + * kappa + * nx.sum(K_IJc * nx.ones((nt - nt_budget,), type_as=M)[None, :], axis=1) ) vec_eps_IcJ = (epsilon / kappa) * nx.sum( - nx.ones((ns - ns_budget,), type_as=M)[:, None] * K_IcJ, - axis=0 + nx.ones((ns - ns_budget,), type_as=M)[:, None] * K_IcJ, axis=0 ) # initialisation - u0 = nx.full((ns_budget,), 1. / ns_budget + epsilon / kappa, type_as=M) - v0 = nx.full((nt_budget,), 1. 
/ nt_budget + epsilon * kappa, type_as=M) + u0 = nx.full((ns_budget,), 1.0 / ns_budget + epsilon / kappa, type_as=M) + v0 = nx.full((nt_budget,), 1.0 / nt_budget + epsilon * kappa, type_as=M) # pre-calculed constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26]) if restricted: @@ -322,7 +362,7 @@ def screened_obj(usc, vsc): part_IJ = ( nx.dot(nx.dot(usc, K_IJ), vsc) - kappa * nx.dot(a_I, nx.log(usc)) - - (1. / kappa) * nx.dot(b_J, nx.log(vsc)) + - (1.0 / kappa) * nx.dot(b_J, nx.log(vsc)) ) part_IJc = nx.dot(usc, vec_eps_IJc) part_IcJ = nx.dot(vec_eps_IcJ, vsc) @@ -332,7 +372,7 @@ def screened_obj(usc, vsc): def screened_grad(usc, vsc): # gradients of Psi_(kappa,epsilon) w.r.t u and v grad_u = nx.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc - grad_v = nx.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc + grad_v = nx.dot(K_IJ.T, usc) + vec_eps_IcJ - (1.0 / kappa) * b_J / vsc return grad_u, grad_v def bfgspost(theta): @@ -357,12 +397,9 @@ def bfgspost(theta): def obj(theta): return bfgspost(nx.from_numpy(theta, type_as=M)) - theta, _, _ = fmin_l_bfgs_b(func=obj, - x0=theta0, - bounds=bounds, - maxfun=maxfun, - pgtol=pgtol, - maxiter=maxiter) + theta, _, _ = fmin_l_bfgs_b( + func=obj, x0=theta0, bounds=bounds, maxfun=maxfun, pgtol=pgtol, maxiter=maxiter + ) theta = nx.from_numpy(theta, type_as=M) usc = theta[:ns_budget] @@ -375,10 +412,10 @@ def obj(theta): if log: log = {} - log['u'] = usc_full - log['v'] = vsc_full - log['Isel'] = Isel - log['Jsel'] = Jsel + log["u"] = usc_full + log["v"] = vsc_full + log["Isel"] = Isel + log["Jsel"] = Jsel gamma = usc_full[:, None] * K * vsc_full[None, :] gamma = gamma / nx.sum(gamma) diff --git a/ot/bregman/_sinkhorn.py b/ot/bregman/_sinkhorn.py index beb101fe8..a288830d2 100644 --- a/ot/bregman/_sinkhorn.py +++ b/ot/bregman/_sinkhorn.py @@ -19,8 +19,20 @@ from ..backend import get_backend -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, - verbose=False, log=False, warn=True, warmstart=None, **kwargs): +def sinkhorn( + a, + b, + M, + reg, + method="sinkhorn", + numItermax=1000, + stopThr=1e-9, + verbose=False, + log=False, + warn=True, + warmstart=None, + **kwargs, +): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -150,36 +162,93 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, """ - if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, warmstart=warmstart, - **kwargs) - elif method.lower() == 'sinkhorn_log': - return sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, warmstart=warmstart, - **kwargs) - elif method.lower() == 'greenkhorn': - return greenkhorn(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, warmstart=warmstart) - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, warmstart=warmstart, - verbose=verbose, log=log, warn=warn, - **kwargs) - elif method.lower() == 'sinkhorn_epsilon_scaling': - return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, warmstart=warmstart, - verbose=verbose, log=log, warn=warn, - **kwargs) + if method.lower() == "sinkhorn": + return sinkhorn_knopp( + a, + b, + M, + reg, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + 
warmstart=warmstart, + **kwargs, + ) + elif method.lower() == "sinkhorn_log": + return sinkhorn_log( + a, + b, + M, + reg, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart, + **kwargs, + ) + elif method.lower() == "greenkhorn": + return greenkhorn( + a, + b, + M, + reg, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart, + ) + elif method.lower() == "sinkhorn_stabilized": + return sinkhorn_stabilized( + a, + b, + M, + reg, + numItermax=numItermax, + stopThr=stopThr, + warmstart=warmstart, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) + elif method.lower() == "sinkhorn_epsilon_scaling": + return sinkhorn_epsilon_scaling( + a, + b, + M, + reg, + numItermax=numItermax, + stopThr=stopThr, + warmstart=warmstart, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) else: raise ValueError("Unknown method '%s'." % method) -def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, warn=False, warmstart=None, **kwargs): +def sinkhorn2( + a, + b, + M, + reg, + method="sinkhorn", + numItermax=1000, + stopThr=1e-9, + verbose=False, + log=False, + warn=False, + warmstart=None, + **kwargs, +): r""" Solve the entropic regularization optimal transport problem and return the loss @@ -320,21 +389,48 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, nx = get_backend(M, a, b) if len(b.shape) < 2: - if method.lower() == 'sinkhorn': - res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, warmstart=warmstart, - **kwargs) - elif method.lower() == 'sinkhorn_log': - res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, warmstart=warmstart, - **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, warmstart=warmstart, - verbose=verbose, log=log, warn=warn, - **kwargs) + if method.lower() == "sinkhorn": + res = sinkhorn_knopp( + a, + b, + M, + reg, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart, + **kwargs, + ) + elif method.lower() == "sinkhorn_log": + res = sinkhorn_log( + a, + b, + M, + reg, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart, + **kwargs, + ) + elif method.lower() == "sinkhorn_stabilized": + res = sinkhorn_stabilized( + a, + b, + M, + reg, + numItermax=numItermax, + stopThr=stopThr, + warmstart=warmstart, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) else: raise ValueError("Unknown method '%s'." 
% method) if log: @@ -343,28 +439,65 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, return nx.sum(M * res) else: - - if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, warmstart=warmstart, - **kwargs) - elif method.lower() == 'sinkhorn_log': - return sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, warmstart=warmstart, - **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, warmstart=warmstart, - verbose=verbose, log=log, warn=warn, - **kwargs) + if method.lower() == "sinkhorn": + return sinkhorn_knopp( + a, + b, + M, + reg, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart, + **kwargs, + ) + elif method.lower() == "sinkhorn_log": + return sinkhorn_log( + a, + b, + M, + reg, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warn=warn, + warmstart=warmstart, + **kwargs, + ) + elif method.lower() == "sinkhorn_stabilized": + return sinkhorn_stabilized( + a, + b, + M, + reg, + numItermax=numItermax, + stopThr=stopThr, + warmstart=warmstart, + verbose=verbose, + log=log, + warn=warn, + **kwargs, + ) else: raise ValueError("Unknown method '%s'." % method) -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, - verbose=False, log=False, warn=True, warmstart=None, **kwargs): +def sinkhorn_knopp( + a, + b, + M, + reg, + numItermax=1000, + stopThr=1e-9, + verbose=False, + log=False, + warn=True, + warmstart=None, + **kwargs, +): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -472,7 +605,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, n_hists = 0 if log: - log = {'err': []} + log = {"err": []} # we assume that no distances are null except those of the diagonal of # distances @@ -496,14 +629,18 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, vprev = v KtransposeU = nx.dot(K.T, u) v = b / KtransposeU - u = 1. 
/ nx.dot(Kp, v) - - if (nx.any(KtransposeU == 0) - or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) - or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): + u = 1.0 / nx.dot(Kp, v) + + if ( + nx.any(KtransposeU == 0) + or nx.any(nx.isnan(u)) + or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) + or nx.any(nx.isinf(v)) + ): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Warning: numerical errors at iteration %d' % ii) + warnings.warn("Warning: numerical errors at iteration %d" % ii) u = uprev v = vprev break @@ -511,48 +648,59 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: - tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v) + tmp2 = nx.einsum("ik,ij,jk->jk", u, K, v) else: # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - tmp2 = nx.einsum('i,ij,j->j', u, K, v) + tmp2 = nx.einsum("i,ij,j->j", u, K, v) err = nx.norm(tmp2 - b) # violation of marginal if log: - log['err'].append(err) + log["err"].append(err) if err < stopThr: break if verbose: if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) else: if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." + ) if log: - log['niter'] = ii - log['u'] = u - log['v'] = v + log["niter"] = ii + log["u"] = u + log["v"] = v if n_hists: # return only loss - res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum("ik,ij,jk,ij->k", u, K, v, M) if log: return res, log else: return res else: # return OT matrix - if log: return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log else: return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, - log=False, warn=True, warmstart=None, **kwargs): +def sinkhorn_log( + a, + b, + M, + reg, + numItermax=1000, + stopThr=1e-9, + verbose=False, + log=False, + warn=True, + warmstart=None, + **kwargs, +): r""" Solve the entropic regularization optimal transport problem in log space and return the OT matrix @@ -665,42 +813,52 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, else: n_hists = 0 - # in case of multiple historgrams + # in case of multiple histograms if n_hists > 1 and warmstart is None: warmstart = [None] * n_hists if n_hists: # we do not want to use tensors sor we do a loop - lst_loss = [] lst_u = [] lst_v = [] for k in range(n_hists): - res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, stopThr=stopThr, - verbose=verbose, log=log, warmstart=warmstart[k], **kwargs) + res = sinkhorn_log( + a, + b[:, k], + M, + reg, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + warmstart=warmstart[k], + **kwargs, + ) if log: lst_loss.append(nx.sum(M * res[0])) - lst_u.append(res[1]['log_u']) - lst_v.append(res[1]['log_v']) + lst_u.append(res[1]["log_u"]) + lst_v.append(res[1]["log_v"]) else: lst_loss.append(nx.sum(M * res)) res = nx.stack(lst_loss) if log: - log = {'log_u': nx.stack(lst_u, 1), - 'log_v': nx.stack(lst_v, 1), } - log['u'] = nx.exp(log['log_u']) - log['v'] = 
nx.exp(log['log_v']) + log = { + "log_u": nx.stack(lst_u, 1), + "log_v": nx.stack(lst_v, 1), + } + log["u"] = nx.exp(log["log_u"]) + log["v"] = nx.exp(log["log_v"]) return res, log else: return res else: - if log: - log = {'err': []} + log = {"err": []} - Mr = - M / reg + Mr = -M / reg # we assume that no distances are null except those of the diagonal of # distances @@ -721,7 +879,6 @@ def get_logT(u, v): err = 1 for ii in range(numItermax): - v = logb - nx.logsumexp(Mr + u[:, None], 0) u = loga - nx.logsumexp(Mr + v[None, :], 1) @@ -733,27 +890,28 @@ def get_logT(u, v): tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0) err = nx.norm(tmp2 - b) # violation of marginal if log: - log['err'].append(err) + log["err"].append(err) if verbose: if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) if err < stopThr: break else: if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." + ) if log: - log['niter'] = ii - log['log_u'] = u - log['log_v'] = v - log['u'] = nx.exp(u) - log['v'] = nx.exp(v) + log["niter"] = ii + log["log_u"] = u + log["log_v"] = v + log["u"] = nx.exp(u) + log["v"] = nx.exp(v) return nx.exp(get_logT(u, v)), log @@ -761,8 +919,18 @@ def get_logT(u, v): return nx.exp(get_logT(u, v)) -def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, - log=False, warn=True, warmstart=None): +def greenkhorn( + a, + b, + M, + reg, + numItermax=10000, + stopThr=1e-9, + verbose=False, + log=False, + warn=True, + warmstart=None, +): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -859,8 +1027,10 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, nx = get_backend(M, a, b) if nx.__name__ in ("jax", "tf"): - raise TypeError("JAX or TF arrays have been received. Greenkhorn is not " - "compatible with neither JAX nor TF") + raise TypeError( + "JAX or TF arrays have been received. Greenkhorn is not " + "compatible with neither JAX nor TF" + ) if len(a) == 0: a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] @@ -873,8 +1043,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, K = nx.exp(-M / reg) if warmstart is None: - u = nx.full((dim_a,), 1. / dim_a, type_as=K) - v = nx.full((dim_b,), 1. 
/ dim_b, type_as=K) + u = nx.full((dim_a,), 1.0 / dim_a, type_as=K) + v = nx.full((dim_b,), 1.0 / dim_b, type_as=K) else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) G = u[:, None] * K * v[None, :] @@ -884,8 +1054,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, stopThr_val = 1 if log: log = dict() - log['u'] = u - log['v'] = v + log["u"] = u + log["v"] = v for ii in range(numItermax): i_1 = nx.argmax(nx.abs(viol)) @@ -900,7 +1070,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, G[i_1, :] = new_u * K[i_1, :] * v viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1] - viol_2 += (K[i_1, :].T * (new_u - old_u) * v) + viol_2 += K[i_1, :].T * (new_u - old_u) * v u[i_1] = new_u else: old_v = v[i_2] @@ -916,14 +1086,16 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, break else: if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." + ) if log: log["n_iter"] = ii - log['u'] = u - log['v'] = v + log["u"] = u + log["v"] = v if log: return G, log @@ -931,9 +1103,21 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, return G -def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, - warmstart=None, verbose=False, print_period=20, - log=False, warn=True, **kwargs): +def sinkhorn_stabilized( + a, + b, + M, + reg, + numItermax=1000, + tau=1e3, + stopThr=1e-9, + warmstart=None, + verbose=False, + print_period=20, + log=False, + warn=True, + **kwargs, +): r""" Solve the entropic regularization OT problem with log stabilization @@ -1055,7 +1239,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, dim_b = len(b) if log: - log = {'err': []} + log = {"err": []} # we assume that no distances are null except those of the diagonal of # distances @@ -1074,19 +1258,20 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, def get_K(alpha, beta): """log space computation""" - return nx.exp(-(M - alpha.reshape((dim_a, 1)) - - beta.reshape((1, dim_b))) / reg) + return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg) def get_Gamma(alpha, beta, u, v): """log space gamma computation""" - return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) - / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b)))) + return nx.exp( + -(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg + + nx.log(u.reshape((dim_a, 1))) + + nx.log(v.reshape((1, dim_b))) + ) K = get_K(alpha, beta) transp = K err = 1 for ii in range(numItermax): - uprev = u vprev = v @@ -1097,8 +1282,10 @@ def get_Gamma(alpha, beta, u, v): # remove numerical problems and store them in K if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: if n_hists: - alpha, beta = alpha + reg * \ - nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) + alpha, beta = ( + alpha + reg * nx.max(nx.log(u), 1), + beta + reg * nx.max(nx.log(v)), + ) else: alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v) if n_hists: @@ -1122,13 +1309,12 @@ def get_Gamma(alpha, beta, u, v): transp = get_Gamma(alpha, beta, u, v) err = nx.norm(nx.sum(transp, axis=0) - b) if log: - log['err'].append(err) + log["err"].append(err) if verbose: if ii % (print_period 
* 20) == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) if err <= stopThr: break @@ -1136,15 +1322,17 @@ def get_Gamma(alpha, beta, u, v): if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %d' % ii) + warnings.warn("Numerical errors at iteration %d" % ii) u = uprev v = vprev break else: if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." + ) if log: if n_hists: alpha = alpha[:, None] @@ -1152,35 +1340,52 @@ def get_Gamma(alpha, beta, u, v): logu = alpha / reg + nx.log(u) logv = beta / reg + nx.log(v) log["n_iter"] = ii - log['logu'] = logu - log['logv'] = logv - log['alpha'] = alpha + reg * nx.log(u) - log['beta'] = beta + reg * nx.log(v) - log['warmstart'] = (log['alpha'], log['beta']) + log["logu"] = logu + log["logv"] = logv + log["alpha"] = alpha + reg * nx.log(u) + log["beta"] = beta + reg * nx.log(v) + log["warmstart"] = (log["alpha"], log["beta"]) if n_hists: - res = nx.stack([ - nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) - for i in range(n_hists) - ]) + res = nx.stack( + [ + nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + for i in range(n_hists) + ] + ) return res, log else: return get_Gamma(alpha, beta, u, v), log else: if n_hists: - res = nx.stack([ - nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) - for i in range(n_hists) - ]) + res = nx.stack( + [ + nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + for i in range(n_hists) + ] + ) return res else: return get_Gamma(alpha, beta, u, v) -def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, - numInnerItermax=100, tau=1e3, stopThr=1e-9, - warmstart=None, verbose=False, print_period=10, - log=False, warn=True, **kwargs): +def sinkhorn_epsilon_scaling( + a, + b, + M, + reg, + numItermax=100, + epsilon0=1e4, + numInnerItermax=100, + tau=1e3, + stopThr=1e-9, + warmstart=None, + verbose=False, + print_period=10, + log=False, + warn=True, + **kwargs, +): r""" Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. 
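The hunks above reformat `ot.sinkhorn` / `ot.sinkhorn2` and the underlying solvers (`sinkhorn_knopp`, `sinkhorn_log`, `greenkhorn`, `sinkhorn_stabilized`) without changing their behavior, so the `method` dispatch works as before. An illustrative call, independent of this patch (the toy histograms, cost matrix and `reg` are made up):

import numpy as np
import ot

a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])
M = np.array([[0.0, 1.0], [1.0, 0.0]])

# the same problem routed through different dispatched solvers; the plans should agree
G = ot.sinkhorn(a, b, M, reg=0.1, method="sinkhorn")
G_log = ot.sinkhorn(a, b, M, reg=0.1, method="sinkhorn_log")
loss = ot.sinkhorn2(a, b, M, reg=0.1, method="sinkhorn_stabilized")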
@@ -1292,7 +1497,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, ii = 0 if log: - log = {'err': []} + log = {"err": []} # we assume that no distances are null except those of the diagonal of # distances @@ -1307,44 +1512,55 @@ def get_reg(n): # exponential decreasing err = 1 for ii in range(numItermax): - regi = get_reg(ii) - G, logi = sinkhorn_stabilized(a, b, M, regi, - numItermax=numInnerItermax, stopThr=stopThr, - warmstart=(alpha, beta), verbose=False, - print_period=20, tau=tau, log=True) - - alpha = logi['alpha'] - beta = logi['beta'] + G, logi = sinkhorn_stabilized( + a, + b, + M, + regi, + numItermax=numInnerItermax, + stopThr=stopThr, + warmstart=(alpha, beta), + verbose=False, + print_period=20, + tau=tau, + log=True, + ) + + alpha = logi["alpha"] + beta = logi["beta"] if ii % (print_period) == 0: # spsion nearly converged # we can speed up the process by checking for the error only all # the 10th iterations transp = G - err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + \ - nx.norm(nx.sum(transp, axis=1) - a) ** 2 + err = ( + nx.norm(nx.sum(transp, axis=0) - b) ** 2 + + nx.norm(nx.sum(transp, axis=1) - a) ** 2 + ) if log: - log['err'].append(err) + log["err"].append(err) if verbose: if ii % (print_period * 10) == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(ii, err)) if err <= stopThr and ii > numItermin: break else: if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + warnings.warn( + "Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`." 
+ ) if log: - log['alpha'] = alpha - log['beta'] = beta - log['warmstart'] = (log['alpha'], log['beta']) - log['niter'] = ii + log["alpha"] = alpha + log["beta"] = beta + log["warmstart"] = (log["alpha"], log["beta"]) + log["niter"] = ii return G, log else: return G diff --git a/ot/bregman/_utils.py b/ot/bregman/_utils.py index 9535cd1c0..3b7a0af36 100644 --- a/ot/bregman/_utils.py +++ b/ot/bregman/_utils.py @@ -16,7 +16,7 @@ def geometricBar(weights, alldistribT): """return the weighted geometric mean of distributions""" weights, alldistribT = list_to_array(weights, alldistribT) nx = get_backend(weights, alldistribT) - assert (len(weights) == alldistribT.shape[1]) + assert len(weights) == alldistribT.shape[1] return nx.exp(nx.dot(nx.log(alldistribT), weights.T)) @@ -28,14 +28,14 @@ def geometricMean(alldistribT): def projR(gamma, p): - """return the KL projection on the row constraints """ + """return the KL projection on the row constraints""" gamma, p = list_to_array(gamma, p) nx = get_backend(gamma, p) return (gamma.T * p / nx.maximum(nx.sum(gamma, axis=1), 1e-10)).T def projC(gamma, q): - """return the KL projection on the column constraints """ + """return the KL projection on the column constraints""" gamma, q = list_to_array(gamma, q) nx = get_backend(gamma, q) return gamma * q / nx.maximum(nx.sum(gamma, axis=0), 1e-10) diff --git a/ot/coot.py b/ot/coot.py index 4134e594c..8ea08a2f9 100644 --- a/ot/coot.py +++ b/ot/coot.py @@ -14,11 +14,28 @@ from .bregman import sinkhorn -def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, - epsilon=0, alpha=0, M_samp=None, M_feat=None, - warmstart=None, nits_bcd=100, tol_bcd=1e-7, eval_bcd=1, - nits_ot=500, tol_sinkhorn=1e-7, method_sinkhorn="sinkhorn", - early_stopping_tol=1e-6, log=False, verbose=False): +def co_optimal_transport( + X, + Y, + wx_samp=None, + wx_feat=None, + wy_samp=None, + wy_feat=None, + epsilon=0, + alpha=0, + M_samp=None, + M_feat=None, + warmstart=None, + nits_bcd=100, + tol_bcd=1e-7, + eval_bcd=1, + nits_ot=500, + tol_sinkhorn=1e-7, + method_sinkhorn="sinkhorn", + early_stopping_tol=1e-6, + log=False, + verbose=False, +): r"""Compute the CO-Optimal Transport between two matrices. Return the sample and feature transport plans between @@ -143,7 +160,10 @@ def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]: raise ValueError( - "Method {} is not supported in CO-Optimal Transport.".format(method_sinkhorn)) + "Method {} is not supported in CO-Optimal Transport.".format( + method_sinkhorn + ) + ) X, Y = list_to_array(X, Y) nx = get_backend(X, Y) @@ -152,7 +172,9 @@ def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat eps_samp, eps_feat = epsilon, epsilon else: if len(epsilon) != 2: - raise ValueError("Epsilon must be either a scalar or an indexable object of length 2.") + raise ValueError( + "Epsilon must be either a scalar or an indexable object of length 2." + ) else: eps_samp, eps_feat = epsilon[0], epsilon[1] @@ -160,7 +182,9 @@ def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat alpha_samp, alpha_feat = alpha, alpha else: if len(alpha) != 2: - raise ValueError("Alpha must be either a scalar or an indexable object of length 2.") + raise ValueError( + "Alpha must be either a scalar or an indexable object of length 2." 
+ ) else: alpha_samp, alpha_feat = alpha[0], alpha[1] @@ -187,18 +211,27 @@ def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat wxy_feat = wx_feat[:, None] * wy_feat[None, :] # pre-calculate cost constants - XY_sqr = (X ** 2 @ wx_feat)[:, None] + (Y ** 2 @ - wy_feat)[None, :] + alpha_samp * M_samp - XY_sqr_T = ((X.T)**2 @ wx_samp)[:, None] + ((Y.T) - ** 2 @ wy_samp)[None, :] + alpha_feat * M_feat + XY_sqr = (X**2 @ wx_feat)[:, None] + (Y**2 @ wy_feat)[None, :] + alpha_samp * M_samp + XY_sqr_T = ( + ((X.T) ** 2 @ wx_samp)[:, None] + + ((Y.T) ** 2 @ wy_samp)[None, :] + + alpha_feat * M_feat + ) # initialize coupling and dual vectors if warmstart is None: - pi_samp, pi_feat = wxy_samp, wxy_feat # shape nx_samp x ny_samp and nx_feat x ny_feat - duals_samp = (nx.zeros(nx_samp, type_as=X), nx.zeros( - ny_samp, type_as=Y)) # shape nx_samp, ny_samp - duals_feat = (nx.zeros(nx_feat, type_as=X), nx.zeros( - ny_feat, type_as=Y)) # shape nx_feat, ny_feat + pi_samp, pi_feat = ( + wxy_samp, + wxy_feat, + ) # shape nx_samp x ny_samp and nx_feat x ny_feat + duals_samp = ( + nx.zeros(nx_samp, type_as=X), + nx.zeros(ny_samp, type_as=Y), + ) # shape nx_samp, ny_samp + duals_feat = ( + nx.zeros(nx_feat, type_as=X), + nx.zeros(ny_feat, type_as=Y), + ) # shape nx_feat, ny_feat else: pi_samp, pi_feat = warmstart["pi_sample"], warmstart["pi_feature"] duals_samp, duals_feat = warmstart["duals_sample"], warmstart["duals_feature"] @@ -213,22 +246,42 @@ def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat # update sample coupling ot_cost = XY_sqr - 2 * X @ pi_feat @ Y.T # size nx_samp x ny_samp if eps_samp > 0: - pi_samp, dict_log = sinkhorn(a=wx_samp, b=wy_samp, M=ot_cost, reg=eps_samp, method=method_sinkhorn, - numItermax=nits_ot, stopThr=tol_sinkhorn, log=True, warmstart=duals_samp) + pi_samp, dict_log = sinkhorn( + a=wx_samp, + b=wy_samp, + M=ot_cost, + reg=eps_samp, + method=method_sinkhorn, + numItermax=nits_ot, + stopThr=tol_sinkhorn, + log=True, + warmstart=duals_samp, + ) duals_samp = (nx.log(dict_log["u"]), nx.log(dict_log["v"])) elif eps_samp == 0: pi_samp, dict_log = emd( - a=wx_samp, b=wy_samp, M=ot_cost, numItermax=nits_ot, log=True) + a=wx_samp, b=wy_samp, M=ot_cost, numItermax=nits_ot, log=True + ) duals_samp = (dict_log["u"], dict_log["v"]) # update feature coupling ot_cost = XY_sqr_T - 2 * X.T @ pi_samp @ Y # size nx_feat x ny_feat if eps_feat > 0: - pi_feat, dict_log = sinkhorn(a=wx_feat, b=wy_feat, M=ot_cost, reg=eps_feat, method=method_sinkhorn, - numItermax=nits_ot, stopThr=tol_sinkhorn, log=True, warmstart=duals_feat) + pi_feat, dict_log = sinkhorn( + a=wx_feat, + b=wy_feat, + M=ot_cost, + reg=eps_feat, + method=method_sinkhorn, + numItermax=nits_ot, + stopThr=tol_sinkhorn, + log=True, + warmstart=duals_feat, + ) duals_feat = (nx.log(dict_log["u"]), nx.log(dict_log["v"])) elif eps_feat == 0: pi_feat, dict_log = emd( - a=wx_feat, b=wy_feat, M=ot_cost, numItermax=nits_ot, log=True) + a=wx_feat, b=wy_feat, M=ot_cost, numItermax=nits_ot, log=True + ) duals_feat = (dict_log["u"], dict_log["v"]) if idx % eval_bcd == 0: @@ -251,16 +304,21 @@ def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat if verbose: print( - "CO-Optimal Transport cost at iteration {}: {}".format(idx + 1, coot)) + "CO-Optimal Transport cost at iteration {}: {}".format( + idx + 1, coot + ) + ) # sanity check if nx.sum(nx.isnan(pi_samp)) > 0 or nx.sum(nx.isnan(pi_feat)) > 0: warnings.warn("There is NaN in coupling.") if log: - dict_log = 
{"duals_sample": duals_samp, - "duals_feature": duals_feat, - "distances": list_coot[1:]} + dict_log = { + "duals_sample": duals_samp, + "duals_feature": duals_feat, + "distances": list_coot[1:], + } return pi_samp, pi_feat, dict_log @@ -268,12 +326,28 @@ def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat return pi_samp, pi_feat -def co_optimal_transport2(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, - epsilon=0, alpha=0, M_samp=None, M_feat=None, - warmstart=None, log=False, verbose=False, early_stopping_tol=1e-6, - nits_bcd=100, tol_bcd=1e-7, eval_bcd=1, - nits_ot=500, tol_sinkhorn=1e-7, - method_sinkhorn="sinkhorn"): +def co_optimal_transport2( + X, + Y, + wx_samp=None, + wx_feat=None, + wy_samp=None, + wy_feat=None, + epsilon=0, + alpha=0, + M_samp=None, + M_feat=None, + warmstart=None, + log=False, + verbose=False, + early_stopping_tol=1e-6, + nits_bcd=100, + tol_bcd=1e-7, + eval_bcd=1, + nits_ot=500, + tol_sinkhorn=1e-7, + method_sinkhorn="sinkhorn", +): r"""Compute the CO-Optimal Transport distance between two measures. Returns the CO-Optimal Transport distance between @@ -386,13 +460,28 @@ def co_optimal_transport2(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_fea Advances in Neural Information Processing ny_sampstems, 33 (2020). """ - pi_samp, pi_feat, dict_log = co_optimal_transport(X=X, Y=Y, wx_samp=wx_samp, wx_feat=wx_feat, wy_samp=wy_samp, - wy_feat=wy_feat, epsilon=epsilon, alpha=alpha, M_samp=M_samp, - M_feat=M_feat, warmstart=warmstart, nits_bcd=nits_bcd, - tol_bcd=tol_bcd, eval_bcd=eval_bcd, nits_ot=nits_ot, - tol_sinkhorn=tol_sinkhorn, method_sinkhorn=method_sinkhorn, - early_stopping_tol=early_stopping_tol, - log=True, verbose=verbose) + pi_samp, pi_feat, dict_log = co_optimal_transport( + X=X, + Y=Y, + wx_samp=wx_samp, + wx_feat=wx_feat, + wy_samp=wy_samp, + wy_feat=wy_feat, + epsilon=epsilon, + alpha=alpha, + M_samp=M_samp, + M_feat=M_feat, + warmstart=warmstart, + nits_bcd=nits_bcd, + tol_bcd=tol_bcd, + eval_bcd=eval_bcd, + nits_ot=nits_ot, + tol_sinkhorn=tol_sinkhorn, + method_sinkhorn=method_sinkhorn, + early_stopping_tol=early_stopping_tol, + log=True, + verbose=verbose, + ) X, Y = list_to_array(X, Y) nx = get_backend(X, Y) @@ -413,14 +502,19 @@ def co_optimal_transport2(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_fea vx_samp, vy_samp = dict_log["duals_sample"] vx_feat, vy_feat = dict_log["duals_feature"] - gradX = 2 * X * (wx_samp[:, None] * wx_feat[None, :]) - \ - 2 * pi_samp @ Y @ pi_feat.T # shape (nx_samp, nx_feat) - gradY = 2 * Y * (wy_samp[:, None] * wy_feat[None, :]) - \ - 2 * pi_samp.T @ X @ pi_feat # shape (ny_samp, ny_feat) + gradX = ( + 2 * X * (wx_samp[:, None] * wx_feat[None, :]) - 2 * pi_samp @ Y @ pi_feat.T + ) # shape (nx_samp, nx_feat) + gradY = ( + 2 * Y * (wy_samp[:, None] * wy_feat[None, :]) - 2 * pi_samp.T @ X @ pi_feat + ) # shape (ny_samp, ny_feat) coot = dict_log["distances"][-1] - coot = nx.set_gradients(coot, (wx_samp, wx_feat, wy_samp, wy_feat, X, Y), - (vx_samp, vx_feat, vy_samp, vy_feat, gradX, gradY)) + coot = nx.set_gradients( + coot, + (wx_samp, wx_feat, wy_samp, wy_feat, X, Y), + (vx_samp, vx_feat, vy_samp, vy_feat, gradX, gradY), + ) if log: return coot, dict_log diff --git a/ot/da.py b/ot/da.py index b51b08b3a..7fce1b7eb 100644 --- a/ot/da.py +++ b/ot/da.py @@ -18,19 +18,50 @@ from .backend import get_backend from .bregman import sinkhorn, jcpot_barycenter from .lp import emd -from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots 
-from .utils import BaseEstimator, check_params, deprecated, labels_to_masks, list_to_array +from .utils import ( + unif, + dist, + kernel, + cost_normalization, + label_normalization, + laplacian, + dots, +) +from .utils import ( + BaseEstimator, + check_params, + deprecated, + labels_to_masks, + list_to_array, +) from .unbalanced import sinkhorn_unbalanced -from .gaussian import empirical_bures_wasserstein_mapping, empirical_gaussian_gromov_wasserstein_mapping +from .gaussian import ( + empirical_bures_wasserstein_mapping, + empirical_gaussian_gromov_wasserstein_mapping, +) from .optim import cg from .optim import gcg -from .mapping import nearest_brenier_potential_fit, nearest_brenier_potential_predict_bounds, joint_OT_mapping_linear, \ - joint_OT_mapping_kernel - - -def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, - numInnerItermax=200, stopInnerThr=1e-9, verbose=False, - log=False): +from .mapping import ( + nearest_brenier_potential_fit, + nearest_brenier_potential_predict_bounds, + joint_OT_mapping_linear, + joint_OT_mapping_kernel, +) + + +def sinkhorn_lpl1_mm( + a, + labels_a, + b, + M, + reg, + eta=0.1, + numItermax=10, + numInnerItermax=200, + stopInnerThr=1e-9, + verbose=False, + log=False, +): r""" Solve the entropic regularization optimal transport problem with non-convex group lasso regularization @@ -130,14 +161,25 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, for _ in range(numItermax): Mreg = M + eta * W if log: - transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, - stopThr=stopInnerThr, log=True) + transp, log = sinkhorn( + a, + b, + Mreg, + reg, + numItermax=numInnerItermax, + stopThr=stopInnerThr, + log=True, + ) else: - transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, - stopThr=stopInnerThr) + transp = sinkhorn( + a, b, Mreg, reg, numItermax=numInnerItermax, stopThr=stopInnerThr + ) # the transport has been computed # check if classes are really separated - W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :] + W = ( + nx.repeat(transp.T[:, :, None], n_labels, axis=2) + * unroll_labels_idx[None, :, :] + ) W = nx.sum(W, axis=1) W = nx.dot(W, unroll_labels_idx.T) W = p * ((W.T + epsilon) ** (p - 1)) @@ -148,9 +190,20 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, return transp -def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, - numInnerItermax=200, stopInnerThr=1e-9, eps=1e-12, - verbose=False, log=False): +def sinkhorn_l1l2_gl( + a, + labels_a, + b, + M, + reg, + eta=0.1, + numItermax=10, + numInnerItermax=200, + stopInnerThr=1e-9, + eps=1e-12, + verbose=False, + log=False, +): r""" Solve the entropic regularization optimal transport problem with group lasso regularization @@ -252,17 +305,44 @@ def df(G): G_norm = G_split / nx.clip(W, eps, None) return nx.sum(G_norm, axis=2).T - return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax, - numInnerItermax=numInnerItermax, stopThr=stopInnerThr, - verbose=verbose, log=log) + return gcg( + a, + b, + M, + reg, + eta, + f, + df, + G0=None, + numItermax=numItermax, + numInnerItermax=numInnerItermax, + stopThr=stopInnerThr, + verbose=verbose, + log=log, + ) OT_mapping_linear = deprecated(empirical_bures_wasserstein_mapping) -def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5, - numItermax=100, stopThr=1e-9, numInnerItermax=100000, - stopInnerThr=1e-9, log=False, verbose=False): +def emd_laplace( + a, + b, + xs, + xt, + M, + 
sim="knn", + sim_param=None, + reg="pos", + eta=1, + alpha=0.5, + numItermax=100, + stopThr=1e-9, + numInnerItermax=100000, + stopInnerThr=1e-9, + log=False, + verbose=False, +): r"""Solve the optimal transport problem (OT) with Laplacian regularization .. math:: @@ -359,62 +439,82 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al """ if not isinstance(sim_param, (int, float, type(None))): raise ValueError( - 'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(sim_param).__name__)) + "Similarity parameter should be an int or a float. Got {type} instead.".format( + type=type(sim_param).__name__ + ) + ) a, b, xs, xt, M = list_to_array(a, b, xs, xt, M) nx = get_backend(a, b, xs, xt, M) - if sim == 'gauss': + if sim == "gauss": if sim_param is None: - sim_param = 1 / (2 * (nx.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) + sim_param = 1 / (2 * (nx.mean(dist(xs, xs, "sqeuclidean")) ** 2)) sS = kernel(xs, xs, method=sim, sigma=sim_param) sT = kernel(xt, xt, method=sim, sigma=sim_param) - elif sim == 'knn': + elif sim == "knn": if sim_param is None: sim_param = 3 try: from sklearn.neighbors import kneighbors_graph except ImportError: - raise ValueError('scikit-learn must be installed to use knn similarity. Install with `$pip install scikit-learn`.') + raise ValueError( + "scikit-learn must be installed to use knn similarity. Install with `$pip install scikit-learn`." + ) - sS = nx.from_numpy(kneighbors_graph( - X=nx.to_numpy(xs), n_neighbors=int(sim_param) - ).toarray(), type_as=xs) + sS = nx.from_numpy( + kneighbors_graph(X=nx.to_numpy(xs), n_neighbors=int(sim_param)).toarray(), + type_as=xs, + ) sS = (sS + sS.T) / 2 - sT = nx.from_numpy(kneighbors_graph( - X=nx.to_numpy(xt), n_neighbors=int(sim_param) - ).toarray(), type_as=xt) + sT = nx.from_numpy( + kneighbors_graph(X=nx.to_numpy(xt), n_neighbors=int(sim_param)).toarray(), + type_as=xt, + ) sT = (sT + sT.T) / 2 else: - raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim)) + raise ValueError( + 'Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format( + sim=sim + ) + ) lS = laplacian(sS) lT = laplacian(sT) def f(G): - return ( - alpha * nx.trace(dots(xt.T, G.T, lS, G, xt)) - + (1 - alpha) * nx.trace(dots(xs.T, G, lT, G.T, xs)) + return alpha * nx.trace(dots(xt.T, G.T, lS, G, xt)) + (1 - alpha) * nx.trace( + dots(xs.T, G, lT, G.T, xs) ) ls2 = lS + lS.T lt2 = lT + lT.T xt2 = nx.dot(xt, xt.T) - if reg == 'disp': + if reg == "disp": Cs = -eta * alpha / xs.shape[0] * dots(ls2, xs, xt.T) Ct = -eta * (1 - alpha) / xt.shape[0] * dots(xs, xt.T, lt2) M = M + Cs + Ct def df(G): - return ( - alpha * dots(ls2, G, xt2) - + (1 - alpha) * dots(xs, xs.T, G, lt2) - ) - - return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax, - stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log) + return alpha * dots(ls2, G, xt2) + (1 - alpha) * dots(xs, xs.T, G, lt2) + + return cg( + a, + b, + M, + reg=eta, + f=f, + df=df, + G0=None, + numItermax=numItermax, + numItermaxEmd=numInnerItermax, + stopThr=stopThr, + stopThr2=stopInnerThr, + verbose=verbose, + log=log, + ) def distribution_estimation_uniform(X): @@ -435,7 +535,6 @@ def distribution_estimation_uniform(X): class BaseTransport(BaseEstimator): - """Base class for OTDA objects .. 
note:: @@ -489,13 +588,13 @@ class label # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt): - # pairwise distance self.cost_ = dist(Xs, Xt, metric=self.metric) - self.cost_, self.norm_cost_ = cost_normalization(self.cost_, self.norm, return_value=True) + self.cost_, self.norm_cost_ = cost_normalization( + self.cost_, self.norm, return_value=True + ) if (ys is not None) and (yt is not None): - if self.limit_max != np.inf: self.limit_max = self.limit_max * nx.max(self.cost_) @@ -514,7 +613,7 @@ class label # we suppress potential RuntimeWarning caused by Inf multiplication # (as we explicitly cover potential NANs later) with warnings.catch_warnings(): - warnings.simplefilter('ignore', category=RuntimeWarning) + warnings.simplefilter("ignore", category=RuntimeWarning) cost_correction = label_match * missing_labels * self.limit_max # this operation is necessary because 0 * Inf = NAN # thus is irrelevant when limit_max is finite @@ -588,7 +687,6 @@ class label # check the necessary inputs parameters are here if check_params(Xs=Xs): - if nx.array_equal(self.xs_, Xs): # perform standard barycentric mapping transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None] @@ -602,8 +700,9 @@ class label # perform out of sample mapping indices = nx.arange(Xs.shape[0]) batch_ind = [ - indices[i:i + batch_size] - for i in range(0, len(indices), batch_size)] + indices[i : i + batch_size] + for i in range(0, len(indices), batch_size) + ] transp_Xs = [] for bi in batch_ind: @@ -665,8 +764,7 @@ def transform_labels(self, ys=None): return transp_ys.T - def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, - batch_size=128): + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` Parameters @@ -695,7 +793,6 @@ class label # check the necessary inputs parameters are here if check_params(Xt=Xt): - if nx.array_equal(self.xt_, Xt): # perform standard barycentric mapping transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None] @@ -709,8 +806,9 @@ class label # perform out of sample mapping indices = nx.arange(Xt.shape[0]) batch_ind = [ - indices[i:i + batch_size] - for i in range(0, len(indices), batch_size)] + indices[i : i + batch_size] + for i in range(0, len(indices), batch_size) + ] transp_Xt = [] for bi in batch_ind: @@ -762,7 +860,7 @@ def inverse_transform_labels(self, yt=None): class LinearTransport(BaseTransport): - r""" OT linear operator between empirical distributions + r"""OT linear operator between empirical distributions The function estimates the optimal linear operator that aligns the two empirical distributions. 
This is equivalent to estimating the closed @@ -806,8 +904,13 @@ class LinearTransport(BaseTransport): """ - def __init__(self, reg=1e-8, bias=True, log=False, - distribution_estimation=distribution_estimation_uniform): + def __init__( + self, + reg=1e-8, + bias=True, + log=False, + distribution_estimation=distribution_estimation_uniform, + ): self.bias = bias self.log = log self.reg = reg @@ -844,16 +947,24 @@ class label self.mu_t = self.distribution_estimation(Xt) # coupling estimation - returned_ = empirical_bures_wasserstein_mapping(Xs, Xt, reg=self.reg, - ws=nx.reshape(self.mu_s, (-1, 1)), - wt=nx.reshape(self.mu_t, (-1, 1)), - bias=self.bias, log=self.log) + returned_ = empirical_bures_wasserstein_mapping( + Xs, + Xt, + reg=self.reg, + ws=nx.reshape(self.mu_s, (-1, 1)), + wt=nx.reshape(self.mu_t, (-1, 1)), + bias=self.bias, + log=self.log, + ) # deal with the value of log if self.log: self.A_, self.B_, self.log_ = returned_ else: - self.A_, self.B_, = returned_ + ( + self.A_, + self.B_, + ) = returned_ self.log_ = dict() # re compute inverse mapping @@ -895,8 +1006,7 @@ class label return transp_Xs - def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, - batch_size=128): + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` Parameters @@ -931,7 +1041,7 @@ class label class LinearGWTransport(LinearTransport): - r""" OT Gaussian Gromov-Wasserstein linear operator between empirical distributions + r"""OT Gaussian Gromov-Wasserstein linear operator between empirical distributions The function estimates the optimal linear operator that aligns the two empirical distributions optimally wrt the Gromov-Wasserstein distance. 
This is equivalent to estimating the closed @@ -968,8 +1078,12 @@ class LinearGWTransport(LinearTransport): """ - def __init__(self, log=False, sign_eigs=None, - distribution_estimation=distribution_estimation_uniform): + def __init__( + self, + log=False, + sign_eigs=None, + distribution_estimation=distribution_estimation_uniform, + ): self.sign_eigs = sign_eigs self.log = log self.distribution_estimation = distribution_estimation @@ -1005,36 +1119,47 @@ class label self.mu_t = self.distribution_estimation(Xt) # coupling estimation - returned_ = empirical_gaussian_gromov_wasserstein_mapping(Xs, Xt, - ws=self.mu_s[:, None], - wt=self.mu_t[:, None], - sign_eigs=self.sign_eigs, - log=self.log) + returned_ = empirical_gaussian_gromov_wasserstein_mapping( + Xs, + Xt, + ws=self.mu_s[:, None], + wt=self.mu_t[:, None], + sign_eigs=self.sign_eigs, + log=self.log, + ) # deal with the value of log if self.log: self.A_, self.B_, self.log_ = returned_ else: - self.A_, self.B_, = returned_ + ( + self.A_, + self.B_, + ) = returned_ self.log_ = dict() # re compute inverse mapping - returned_1_ = empirical_gaussian_gromov_wasserstein_mapping(Xt, Xs, - ws=self.mu_t[:, None], - wt=self.mu_s[:, None], - sign_eigs=self.sign_eigs, - log=self.log) + returned_1_ = empirical_gaussian_gromov_wasserstein_mapping( + Xt, + Xs, + ws=self.mu_t[:, None], + wt=self.mu_s[:, None], + sign_eigs=self.sign_eigs, + log=self.log, + ) if self.log: self.A1_, self.B1_, self.log_1_ = returned_1_ else: - self.A1_, self.B1_, = returned_1_ + ( + self.A1_, + self.B1_, + ) = returned_1_ self.log_ = dict() return self class SinkhornTransport(BaseTransport): - """Domain Adaptation OT method based on Sinkhorn Algorithm Parameters @@ -1104,14 +1229,22 @@ class SinkhornTransport(BaseTransport): """ - def __init__(self, reg_e=1., method="sinkhorn_log", max_iter=1000, - tol=10e-9, verbose=False, log=False, - metric="sqeuclidean", norm=None, - distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='continuous', limit_max=np.inf): - - if out_of_sample_map not in ['ferradans', 'continuous']: - raise ValueError('Unknown out_of_sample_map method') + def __init__( + self, + reg_e=1.0, + method="sinkhorn_log", + max_iter=1000, + tol=10e-9, + verbose=False, + log=False, + metric="sqeuclidean", + norm=None, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map="continuous", + limit_max=np.inf, + ): + if out_of_sample_map not in ["ferradans", "continuous"]: + raise ValueError("Unknown out_of_sample_map method") self.reg_e = reg_e self.method = method @@ -1152,17 +1285,26 @@ class label super(SinkhornTransport, self).fit(Xs, ys, Xt, yt) - if self.out_of_sample_map == 'continuous': + if self.out_of_sample_map == "continuous": self.log = True - if not self.method == 'sinkhorn_log': - self.method = 'sinkhorn_log' - warnings.warn("The method has been set to 'sinkhorn_log' as it is the only method available for out_of_sample_map='continuous'") + if not self.method == "sinkhorn_log": + self.method = "sinkhorn_log" + warnings.warn( + "The method has been set to 'sinkhorn_log' as it is the only method available for out_of_sample_map='continuous'" + ) # coupling estimation returned_ = sinkhorn( - a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e, - method=self.method, numItermax=self.max_iter, stopThr=self.tol, - verbose=self.verbose, log=self.log) + a=self.mu_s, + b=self.mu_t, + M=self.cost_, + reg=self.reg_e, + method=self.method, + numItermax=self.max_iter, + stopThr=self.tol, + verbose=self.verbose, + 
log=self.log, + ) # deal with the value of log if self.log: @@ -1200,18 +1342,17 @@ class label """ nx = self.nx - if self.out_of_sample_map == 'ferradans': + if self.out_of_sample_map == "ferradans": return super(SinkhornTransport, self).transform(Xs, ys, Xt, yt, batch_size) else: # self.out_of_sample_map == 'continuous': - # check the necessary inputs parameters are here - g = self.log_['log_v'] + g = self.log_["log_v"] indices = nx.arange(Xs.shape[0]) batch_ind = [ - indices[i:i + batch_size] - for i in range(0, len(indices), batch_size)] + indices[i : i + batch_size] for i in range(0, len(indices), batch_size) + ] transp_Xs = [] for bi in batch_ind: @@ -1258,22 +1399,21 @@ class label nx = self.nx - if self.out_of_sample_map == 'ferradans': - return super(SinkhornTransport, self).inverse_transform(Xs, ys, Xt, yt, batch_size) + if self.out_of_sample_map == "ferradans": + return super(SinkhornTransport, self).inverse_transform( + Xs, ys, Xt, yt, batch_size + ) else: # self.out_of_sample_map == 'continuous': - - f = self.log_['log_u'] + f = self.log_["log_u"] indices = nx.arange(Xt.shape[0]) batch_ind = [ - indices[i:i + batch_size] - for i in range(0, len(indices), batch_size - )] + indices[i : i + batch_size] for i in range(0, len(indices), batch_size) + ] transp_Xt = [] for bi in batch_ind: - M = dist(Xt[bi], self.xs_, metric=self.metric) M = cost_normalization(M, self.norm, value=self.norm_cost_) @@ -1289,7 +1429,6 @@ class label class EMDTransport(BaseTransport): - """Domain Adaptation OT method based on Earth Mover's Distance Parameters @@ -1332,10 +1471,16 @@ class EMDTransport(BaseTransport): Sciences, 7(3), 1853-1882. """ - def __init__(self, metric="sqeuclidean", norm=None, log=False, - distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans', limit_max=10, - max_iter=100000): + def __init__( + self, + metric="sqeuclidean", + norm=None, + log=False, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map="ferradans", + limit_max=10, + max_iter=100000, + ): self.metric = metric self.norm = norm self.log = log @@ -1372,8 +1517,12 @@ class label super(EMDTransport, self).fit(Xs, ys, Xt, yt) returned_ = emd( - a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter, - log=self.log) + a=self.mu_s, + b=self.mu_t, + M=self.cost_, + numItermax=self.max_iter, + log=self.log, + ) # coupling estimation if self.log: @@ -1444,12 +1593,21 @@ class SinkhornLpl1Transport(BaseTransport): Sciences, 7(3), 1853-1882. 
""" - def __init__(self, reg_e=1., reg_cl=0.1, - max_iter=10, max_inner_iter=200, log=False, - tol=10e-9, verbose=False, - metric="sqeuclidean", norm=None, - distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans', limit_max=np.inf): + def __init__( + self, + reg_e=1.0, + reg_cl=0.1, + max_iter=10, + max_inner_iter=200, + log=False, + tol=10e-9, + verbose=False, + metric="sqeuclidean", + norm=None, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map="ferradans", + limit_max=np.inf, + ): self.reg_e = reg_e self.reg_cl = reg_cl self.max_iter = max_iter @@ -1493,10 +1651,18 @@ class label super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt) returned_ = sinkhorn_lpl1_mm( - a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_, - reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter, - numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol, - verbose=self.verbose, log=self.log) + a=self.mu_s, + labels_a=ys, + b=self.mu_t, + M=self.cost_, + reg=self.reg_e, + eta=self.reg_cl, + numItermax=self.max_iter, + numInnerItermax=self.max_inner_iter, + stopInnerThr=self.tol, + verbose=self.verbose, + log=self.log, + ) # deal with the value of log if self.log: @@ -1508,7 +1674,6 @@ class label class EMDLaplaceTransport(BaseTransport): - """Domain Adaptation OT method based on Earth Mover's Distance with Laplacian regularization Parameters @@ -1570,11 +1735,24 @@ class EMDLaplaceTransport(BaseTransport): Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. """ - def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., metric="sqeuclidean", - norm=None, similarity="knn", similarity_param=None, max_iter=100, tol=1e-9, - max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False, - distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans'): + def __init__( + self, + reg_type="pos", + reg_lap=1.0, + reg_src=1.0, + metric="sqeuclidean", + norm=None, + similarity="knn", + similarity_param=None, + max_iter=100, + tol=1e-9, + max_inner_iter=100000, + inner_tol=1e-9, + log=False, + verbose=False, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map="ferradans", + ): self.reg = reg_type self.reg_lap = reg_lap self.reg_src = reg_src @@ -1618,10 +1796,24 @@ class label super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt) - returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_, - xt=self.xt_, M=self.cost_, sim=self.similarity, sim_param=self.sim_param, reg=self.reg, eta=self.reg_lap, - alpha=self.reg_src, numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter, - stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose) + returned_ = emd_laplace( + a=self.mu_s, + b=self.mu_t, + xs=self.xs_, + xt=self.xt_, + M=self.cost_, + sim=self.similarity, + sim_param=self.sim_param, + reg=self.reg, + eta=self.reg_lap, + alpha=self.reg_src, + numItermax=self.max_iter, + stopThr=self.tol, + numInnerItermax=self.max_inner_iter, + stopInnerThr=self.inner_tol, + log=self.log, + verbose=self.verbose, + ) # coupling estimation if self.log: @@ -1633,7 +1825,6 @@ class label class SinkhornL1l2Transport(BaseTransport): - """Domain Adaptation OT method based on sinkhorn algorithm + L1L2 class regularization. @@ -1695,12 +1886,21 @@ class SinkhornL1l2Transport(BaseTransport): Sciences, 7(3), 1853-1882. 
""" - def __init__(self, reg_e=1., reg_cl=0.1, - max_iter=10, max_inner_iter=200, - tol=10e-9, verbose=False, log=False, - metric="sqeuclidean", norm=None, - distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans', limit_max=10): + def __init__( + self, + reg_e=1.0, + reg_cl=0.1, + max_iter=10, + max_inner_iter=200, + tol=10e-9, + verbose=False, + log=False, + metric="sqeuclidean", + norm=None, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map="ferradans", + limit_max=10, + ): self.reg_e = reg_e self.reg_cl = reg_cl self.max_iter = max_iter @@ -1741,14 +1941,21 @@ class label # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt, ys=ys): - super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt) returned_ = sinkhorn_l1l2_gl( - a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_, - reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter, - numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol, - verbose=self.verbose, log=self.log) + a=self.mu_s, + labels_a=ys, + b=self.mu_t, + M=self.cost_, + reg=self.reg_e, + eta=self.reg_cl, + numItermax=self.max_iter, + numInnerItermax=self.max_inner_iter, + stopInnerThr=self.tol, + verbose=self.verbose, + log=self.log, + ) # deal with the value of log if self.log: @@ -1761,7 +1968,6 @@ class label class MappingTransport(BaseEstimator): - """MappingTransport: DA methods that aims at jointly estimating a optimal transport coupling and the associated mapping @@ -1821,10 +2027,23 @@ class MappingTransport(BaseEstimator): """ - def __init__(self, mu=1, eta=0.001, bias=False, metric="sqeuclidean", - norm=None, kernel="linear", sigma=1, max_iter=100, tol=1e-5, - max_inner_iter=10, inner_tol=1e-6, log=False, verbose=False, - verbose2=False): + def __init__( + self, + mu=1, + eta=0.001, + bias=False, + metric="sqeuclidean", + norm=None, + kernel="linear", + sigma=1, + max_iter=100, + tol=1e-5, + max_inner_iter=10, + inner_tol=1e-6, + log=False, + verbose=False, + verbose2=False, + ): self.metric = metric self.norm = norm self.mu = mu @@ -1869,26 +2088,41 @@ class label # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt): - self.xs_ = Xs self.xt_ = Xt if self.kernel == "linear": returned_ = joint_OT_mapping_linear( - Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias, - verbose=self.verbose, verbose2=self.verbose2, + Xs, + Xt, + mu=self.mu, + eta=self.eta, + bias=self.bias, + verbose=self.verbose, + verbose2=self.verbose2, numItermax=self.max_iter, - numInnerItermax=self.max_inner_iter, stopThr=self.tol, - stopInnerThr=self.inner_tol, log=self.log) + numInnerItermax=self.max_inner_iter, + stopThr=self.tol, + stopInnerThr=self.inner_tol, + log=self.log, + ) elif self.kernel == "gaussian": returned_ = joint_OT_mapping_kernel( - Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias, - sigma=self.sigma, verbose=self.verbose, - verbose2=self.verbose, numItermax=self.max_iter, + Xs, + Xt, + mu=self.mu, + eta=self.eta, + bias=self.bias, + sigma=self.sigma, + verbose=self.verbose, + verbose2=self.verbose, + numItermax=self.max_iter, numInnerItermax=self.max_inner_iter, - stopInnerThr=self.inner_tol, stopThr=self.tol, - log=self.log) + stopInnerThr=self.inner_tol, + stopThr=self.tol, + log=self.log, + ) # deal with the value of log if self.log: @@ -1916,7 +2150,6 @@ def transform(self, Xs): # check the necessary inputs parameters are here if check_params(Xs=Xs): - if nx.array_equal(self.xs_, Xs): # perform standard barycentric mapping transp = self.coupling_ / 
nx.sum(self.coupling_, 1)[:, None] @@ -1928,8 +2161,7 @@ def transform(self, Xs): transp_Xs = nx.dot(transp, self.xt_) else: if self.kernel == "gaussian": - K = kernel(Xs, self.xs_, method=self.kernel, - sigma=self.sigma) + K = kernel(Xs, self.xs_, method=self.kernel, sigma=self.sigma) elif self.kernel == "linear": K = Xs if self.bias: @@ -1942,7 +2174,6 @@ def transform(self, Xs): class UnbalancedSinkhornTransport(BaseTransport): - """Domain Adaptation unbalanced OT method based on sinkhorn algorithm Parameters @@ -1999,11 +2230,21 @@ class UnbalancedSinkhornTransport(BaseTransport): Sciences, 7(3), 1853-1882. """ - def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn', - max_iter=10, tol=1e-9, verbose=False, log=False, - metric="sqeuclidean", norm=None, - distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans', limit_max=10): + def __init__( + self, + reg_e=1.0, + reg_m=0.1, + method="sinkhorn", + max_iter=10, + tol=1e-9, + verbose=False, + log=False, + metric="sqeuclidean", + norm=None, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map="ferradans", + limit_max=10, + ): self.reg_e = reg_e self.reg_m = reg_m self.method = method @@ -2044,14 +2285,20 @@ class label # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt): - super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt) returned_ = sinkhorn_unbalanced( - a=self.mu_s, b=self.mu_t, M=self.cost_, - reg=self.reg_e, reg_m=self.reg_m, method=self.method, - numItermax=self.max_iter, stopThr=self.tol, - verbose=self.verbose, log=self.log) + a=self.mu_s, + b=self.mu_t, + M=self.cost_, + reg=self.reg_e, + reg_m=self.reg_m, + method=self.method, + numItermax=self.max_iter, + stopThr=self.tol, + verbose=self.verbose, + log=self.log, + ) # deal with the value of log if self.log: @@ -2064,7 +2311,6 @@ class label class JCPOTTransport(BaseTransport): - """Domain Adaptation OT method for multi-source target shift based on Wasserstein barycenter algorithm. 
Parameters @@ -2117,10 +2363,16 @@ class JCPOTTransport(BaseTransport): """ - def __init__(self, reg_e=.1, max_iter=10, - tol=10e-9, verbose=False, log=False, - metric="sqeuclidean", - out_of_sample_map='ferradans'): + def __init__( + self, + reg_e=0.1, + max_iter=10, + tol=10e-9, + verbose=False, + log=False, + metric="sqeuclidean", + out_of_sample_map="ferradans", + ): self.reg_e = reg_e self.max_iter = max_iter self.tol = tol @@ -2157,15 +2409,22 @@ class label # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt, ys=ys): - self.xs_ = Xs self.xt_ = Xt - returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e, - metric=self.metric, distrinumItermax=self.max_iter, stopThr=self.tol, - verbose=self.verbose, log=True) + returned_ = jcpot_barycenter( + Xs=Xs, + Ys=ys, + Xt=Xt, + reg=self.reg_e, + metric=self.metric, + distrinumItermax=self.max_iter, + stopThr=self.tol, + verbose=self.verbose, + log=True, + ) - self.coupling_ = returned_[1]['gamma'] + self.coupling_ = returned_[1]["gamma"] # deal with the value of log if self.log: @@ -2202,9 +2461,7 @@ class label # check the necessary inputs parameters are here if check_params(Xs=Xs): - if all([nx.allclose(x, y) for x, y in zip(self.xs_, Xs)]): - # perform standard barycentric mapping for each source domain for coupling in self.coupling_: @@ -2216,12 +2473,12 @@ class label # compute transported samples transp_Xs.append(nx.dot(transp, self.xt_)) else: - # perform out of sample mapping indices = nx.arange(Xs.shape[0]) batch_ind = [ - indices[i:i + batch_size] - for i in range(0, len(indices), batch_size)] + indices[i : i + batch_size] + for i in range(0, len(indices), batch_size) + ] transp_Xs = [] @@ -2275,8 +2532,7 @@ def transform_labels(self, ys=None): # check the necessary inputs parameters are here if check_params(ys=ys): yt = nx.zeros( - (len(nx.unique(nx.concatenate(ys))), self.xt_.shape[0]), - type_as=ys[0] + (len(nx.unique(nx.concatenate(ys))), self.xt_.shape[0]), type_as=ys[0] ) for i in range(len(ys)): ysTemp = label_normalization(ys[i]) @@ -2291,7 +2547,7 @@ def transform_labels(self, ys=None): transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0) if self.log: - D1 = self.log_['D1'][i] + D1 = self.log_["D1"][i] else: D1 = nx.zeros((n, ns), type_as=transp) @@ -2331,7 +2587,6 @@ def inverse_transform_labels(self, yt=None): D1[int(c), ytTemp == c] = 1 for i in range(len(self.xs_)): - # perform label propagation transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] @@ -2403,7 +2658,14 @@ class NearestBrenierPotential(BaseTransport): ot.mapping.nearest_brenier_potential_predict_bounds : Predicting SSNB images on new source data """ - def __init__(self, strongly_convex_constant=0.6, gradient_lipschitz_constant=1.4, log=False, its=100, seed=None): + def __init__( + self, + strongly_convex_constant=0.6, + gradient_lipschitz_constant=1.4, + log=False, + its=100, + seed=None, + ): self.strongly_convex_constant = strongly_convex_constant self.gradient_lipschitz_constant = gradient_lipschitz_constant self.log = log @@ -2452,10 +2714,15 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None): """ self.fit_Xs, self.fit_ys, self.fit_Xt = Xs, ys, Xt - returned = nearest_brenier_potential_fit(Xs, Xt, X_classes=ys, - strongly_convex_constant=self.strongly_convex_constant, - gradient_lipschitz_constant=self.gradient_lipschitz_constant, - its=self.its, log=self.log) + returned = nearest_brenier_potential_fit( + Xs, + Xt, + X_classes=ys, + strongly_convex_constant=self.strongly_convex_constant, + 
gradient_lipschitz_constant=self.gradient_lipschitz_constant, + its=self.its, + log=self.log, + ) if self.log: self.phi, self.G, self.fit_log = returned @@ -2504,9 +2771,16 @@ def transform(self, Xs, ys=None): """ returned = nearest_brenier_potential_predict_bounds( - self.fit_Xs, self.phi, self.G, Xs, X_classes=self.fit_ys, Y_classes=ys, + self.fit_Xs, + self.phi, + self.G, + Xs, + X_classes=self.fit_ys, + Y_classes=ys, strongly_convex_constant=self.strongly_convex_constant, - gradient_lipschitz_constant=self.gradient_lipschitz_constant, log=self.log) + gradient_lipschitz_constant=self.gradient_lipschitz_constant, + log=self.log, + ) if self.log: _, G_lu, self.predict_log = returned else: diff --git a/ot/datasets.py b/ot/datasets.py index 35e781618..6e3be518a 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -6,7 +6,6 @@ # # License: MIT License - import numpy as np import scipy as sp from .utils import check_random_state, deprecated @@ -30,13 +29,13 @@ def make_1D_gauss(n, m, s): 1D histogram for a gaussian distribution """ x = np.arange(n, dtype=np.float64) - h = np.exp(-(x - m) ** 2 / (2 * s ** 2)) + h = np.exp(-((x - m) ** 2) / (2 * s**2)) return h / h.sum() @deprecated() def get_1D_gauss(n, m, sigma): - """ Deprecated see make_1D_gauss """ + """Deprecated see make_1D_gauss""" return make_1D_gauss(n, m, sigma) @@ -65,7 +64,11 @@ def make_2D_samples_gauss(n, m, sigma, random_state=None): generator = check_random_state(random_state) if np.isscalar(sigma): - sigma = np.array([sigma, ]) + sigma = np.array( + [ + sigma, + ] + ) if len(sigma) > 1: P = sp.linalg.sqrtm(sigma) res = generator.randn(n, 2).dot(P) + m @@ -76,11 +79,11 @@ def make_2D_samples_gauss(n, m, sigma, random_state=None): @deprecated() def get_2D_samples_gauss(n, m, sigma, random_state=None): - """ Deprecated see make_2D_samples_gauss """ + """Deprecated see make_2D_samples_gauss""" return make_2D_samples_gauss(n, m, sigma, random_state=None) -def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwargs): +def make_data_classif(dataset, n, nz=0.5, theta=0, p=0.5, random_state=None, **kwargs): """Dataset generation for classification problems Parameters @@ -108,38 +111,39 @@ def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwa """ generator = check_random_state(random_state) - if dataset.lower() == '3gauss': + if dataset.lower() == "3gauss": y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 x = np.zeros((n, 2)) # class 1 - x[y == 1, 0] = -1. - x[y == 1, 1] = -1. - x[y == 2, 0] = -1. - x[y == 2, 1] = 1. - x[y == 3, 0] = 1. + x[y == 1, 0] = -1.0 + x[y == 1, 1] = -1.0 + x[y == 2, 0] = -1.0 + x[y == 2, 1] = 1.0 + x[y == 3, 0] = 1.0 x[y == 3, 1] = 0 x[y != 3, :] += 1.5 * nz * generator.randn(sum(y != 3), 2) x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2) - elif dataset.lower() == '3gauss2': + elif dataset.lower() == "3gauss2": y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 x = np.zeros((n, 2)) y[y == 4] = 3 # class 1 - x[y == 1, 0] = -2. - x[y == 1, 1] = -2. - x[y == 2, 0] = -2. - x[y == 2, 1] = 2. - x[y == 3, 0] = 2. 
+ x[y == 1, 0] = -2.0 + x[y == 1, 1] = -2.0 + x[y == 2, 0] = -2.0 + x[y == 2, 1] = 2.0 + x[y == 3, 0] = 2.0 x[y == 3, 1] = 0 x[y != 3, :] += nz * generator.randn(sum(y != 3), 2) x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2) - elif dataset.lower() == 'gaussrot': + elif dataset.lower() == "gaussrot": rot = np.array( - [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]]) + [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]] + ) m1 = np.array([-1, 1]) m2 = np.array([1, -1]) y = np.floor((np.arange(n) * 1.0 / n * 2)) + 1 @@ -152,16 +156,17 @@ def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwa x = x.dot(rot) - elif dataset.lower() == '2gauss_prop': - + elif dataset.lower() == "2gauss_prop": y = np.concatenate((np.ones(int(p * n)), np.zeros(int((1 - p) * n)))) - x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + nz * generator.randn(len(y), 2) + x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + nz * generator.randn( + len(y), 2 + ) - if ('bias' not in kwargs) and ('b' not in kwargs): - kwargs['bias'] = np.array([0, 2]) + if ("bias" not in kwargs) and ("b" not in kwargs): + kwargs["bias"] = np.array([0, 2]) - x[:, 0] += kwargs['bias'][0] - x[:, 1] += kwargs['bias'][1] + x[:, 0] += kwargs["bias"][0] + x[:, 1] += kwargs["bias"][1] else: x = np.array(0) @@ -172,6 +177,6 @@ def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwa @deprecated() -def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): - """ Deprecated see make_data_classif """ - return make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs) +def get_data_classif(dataset, n, nz=0.5, theta=0, random_state=None, **kwargs): + """Deprecated see make_data_classif""" + return make_data_classif(dataset, n, nz=0.5, theta=0, random_state=None, **kwargs) diff --git a/ot/dr.py b/ot/dr.py index e410ff837..c374c440c 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -17,6 +17,7 @@ # License: MIT License from scipy import linalg + try: import autograd.numpy as np from sklearn.decomposition import PCA @@ -25,23 +26,23 @@ import pymanopt.manifolds import pymanopt.optimizers except ImportError: - raise ImportError("Missing dependency for ot.dr. Requires autograd, pymanopt, scikit-learn. You can install with install with 'pip install POT[dr]', or 'conda install autograd pymanopt scikit-learn'") + raise ImportError( + "Missing dependency for ot.dr. Requires autograd, pymanopt, scikit-learn. 
You can install with install with 'pip install POT[dr]', or 'conda install autograd pymanopt scikit-learn'" + ) from .bregman import sinkhorn as sinkhorn_bregman from .utils import dist as dist_utils, check_random_state def dist(x1, x2): - r""" Compute squared euclidean distance between samples (autograd) - """ + r"""Compute squared euclidean distance between samples (autograd)""" x1p2 = np.sum(np.square(x1), 1) x2p2 = np.sum(np.square(x2), 1) return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T) def sinkhorn(w1, w2, M, reg, k): - r"""Sinkhorn algorithm with fixed number of iteration (autograd) - """ + r"""Sinkhorn algorithm with fixed number of iteration (autograd)""" K = np.exp(-M / reg) ui = np.ones((M.shape[0],)) vi = np.ones((M.shape[1],)) @@ -53,15 +54,13 @@ def sinkhorn(w1, w2, M, reg, k): def logsumexp(M, axis): - r"""Log-sum-exp reduction compatible with autograd (no numpy implementation) - """ + r"""Log-sum-exp reduction compatible with autograd (no numpy implementation)""" amax = np.amax(M, axis=axis, keepdims=True) return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis) def sinkhorn_log(w1, w2, M, reg, k): - r"""Sinkhorn algorithm in log-domain with fixed number of iteration (autograd) - """ + r"""Sinkhorn algorithm in log-domain with fixed number of iteration (autograd)""" Mr = -M / reg ui = np.zeros((M.shape[0],)) vi = np.zeros((M.shape[1],)) @@ -75,8 +74,7 @@ def sinkhorn_log(w1, w2, M, reg, k): def split_classes(X, y): - r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}` - """ + r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`""" lstsclass = np.unique(y) return [X[y == i, :].astype(np.float32) for i in lstsclass] @@ -126,8 +124,7 @@ def fda(X, y, p=2, reg=1e-16): mx0 = np.mean(mxc, 1) Cb = 0 for i in range(nc): - Cb += (mxc[:, i] - mx0).reshape((-1, 1)) * \ - (mxc[:, i] - mx0).reshape((1, -1)) + Cb += (mxc[:, i] - mx0).reshape((-1, 1)) * (mxc[:, i] - mx0).reshape((1, -1)) w, V = linalg.eig(Cb, Cw + reg * np.eye(d)) @@ -141,7 +138,19 @@ def proj(X): return Popt, proj -def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False): +def wda( + X, + y, + p=2, + reg=1, + k=10, + solver=None, + sinkhorn_method="sinkhorn", + maxiter=100, + verbose=0, + P0=None, + normalize=False, +): r""" Wasserstein Discriminant Analysis :ref:`[11] ` @@ -202,9 +211,9 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063. """ # noqa - if sinkhorn_method.lower() == 'sinkhorn': + if sinkhorn_method.lower() == "sinkhorn": sinkhorn_solver = sinkhorn - elif sinkhorn_method.lower() == 'sinkhorn_log': + elif sinkhorn_method.lower() == "sinkhorn_log": sinkhorn_solver = sinkhorn_log else: raise ValueError("Unknown Sinkhorn method '%s'." 
% sinkhorn_method) @@ -258,9 +267,13 @@ def cost(P): # declare solver and solve if solver is None: - solver = pymanopt.optimizers.SteepestDescent(max_iterations=maxiter, log_verbosity=verbose) - elif solver in ['tr', 'TrustRegions']: - solver = pymanopt.optimizers.TrustRegions(max_iterations=maxiter, log_verbosity=verbose) + solver = pymanopt.optimizers.SteepestDescent( + max_iterations=maxiter, log_verbosity=verbose + ) + elif solver in ["tr", "TrustRegions"]: + solver = pymanopt.optimizers.TrustRegions( + max_iterations=maxiter, log_verbosity=verbose + ) Popt = solver.run(problem, initial_point=P0) @@ -270,7 +283,20 @@ def proj(X): return Popt.point, proj -def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0, random_state=None): +def projection_robust_wasserstein( + X, + Y, + a, + b, + tau, + U0=None, + reg=0.1, + k=2, + stopThr=1e-3, + maxiter=100, + verbose=0, + random_state=None, +): r""" Projection Robust Wasserstein Distance :ref:`[32] ` @@ -322,7 +348,7 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh References ---------- .. [32] Huang, M. , Ma S. & Lai L. (2021). - A Riemannian Block Coordinate Descent Method for Computing + A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance, ICML. """ # noqa @@ -347,17 +373,24 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh def Vpi(X, Y, a, b, pi): # Return the second order matrix of the displacements: sum_ij { (pi)_ij (X_i-Y_j)(X_i-Y_j)^T }. A = X.T.dot(pi).dot(Y) - return X.T.dot(np.diag(a)).dot(X) + Y.T.dot(np.diag(np.sum(pi, 0))).dot(Y) - A - A.T + return ( + X.T.dot(np.diag(a)).dot(X) + + Y.T.dot(np.diag(np.sum(pi, 0))).dot(Y) + - A + - A.T + ) err = 1 iter = 0 while err > stopThr and iter < maxiter: - # Projected cost matrix UUT = U.dot(U.T) - M = np.diag(np.diag(X.dot(UUT.dot(X.T)))).dot(ones) + ones.dot( - np.diag(np.diag(Y.dot(UUT.dot(Y.T))))) - 2 * X.dot(UUT.dot(Y.T)) + M = ( + np.diag(np.diag(X.dot(UUT.dot(X.T)))).dot(ones) + + ones.dot(np.diag(np.diag(Y.dot(UUT.dot(Y.T))))) + - 2 * X.dot(UUT.dot(Y.T)) + ) A = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=A) @@ -367,7 +400,7 @@ def Vpi(X, Y, a, b, pi): Ap = (1 / a).reshape(-1, 1) * A AtransposeU = np.dot(A.T, u) v = np.divide(b, AtransposeU) - u = 1. / np.dot(Ap, v) + u = 1.0 / np.dot(Ap, v) pi = u.reshape((-1, 1)) * A * v.reshape((1, -1)) V = Vpi(X, Y, a, b, pi) @@ -383,14 +416,26 @@ def Vpi(X, Y, a, b, pi): f_val = np.trace(U.T.dot(V.dot(U))) if verbose: - print('RBCD Iteration: ', iter, ' error', err, '\t fval: ', f_val) + print("RBCD Iteration: ", iter, " error", err, "\t fval: ", f_val) iter = iter + 1 return pi, U -def ewca(X, U0=None, reg=1, k=2, method='BCD', sinkhorn_method='sinkhorn', stopThr=1e-6, maxiter=100, maxiter_sink=1000, maxiter_MM=10, verbose=0): +def ewca( + X, + U0=None, + reg=1, + k=2, + method="BCD", + sinkhorn_method="sinkhorn", + stopThr=1e-6, + maxiter=100, + maxiter_sink=1000, + maxiter_MM=10, + verbose=0, +): r""" Entropic Wasserstein Component Analysis :ref:`[52] `. @@ -452,17 +497,21 @@ def ewca(X, U0=None, reg=1, k=2, method='BCD', sinkhorn_method='sinkhorn', stopT if U0 is None: pca_fitted = PCA(n_components=k).fit(X) U = pca_fitted.components_.T - if method == 'MM': + if method == "MM": lambda_scm = pca_fitted.explained_variance_[0] else: U = U0 # marginals - u0 = (1. 
/ n) * np.ones(n) + u0 = (1.0 / n) * np.ones(n) # print iterations if verbose > 0: - print('{:4s}|{:13s}|{:12s}|{:12s}'.format('It.', 'Loss', 'Crit.', 'Thres.') + '\n' + '-' * 40) + print( + "{:4s}|{:13s}|{:12s}|{:12s}".format("It.", "Loss", "Crit.", "Thres.") + + "\n" + + "-" * 40 + ) def compute_loss(M, pi, reg): return np.sum(M * pi) + reg * np.sum(pi * (np.log(pi) - 1)) @@ -485,12 +534,17 @@ def grassmann_distance(U1, U2): # Solve transport M = dist_utils(X, (X @ U) @ U.T) pi, log_sinkhorn = sinkhorn_bregman( - u0, u0, M, reg, + u0, + u0, + M, + reg, numItermax=maxiter_sink, - method=sinkhorn_method, warmstart=sinkhorn_warmstart, - warn=False, log=True + method=sinkhorn_method, + warmstart=sinkhorn_warmstart, + warn=False, + log=True, ) - key_warmstart = 'warmstart' + key_warmstart = "warmstart" if key_warmstart in log_sinkhorn: sinkhorn_warmstart = log_sinkhorn[key_warmstart] if (pi >= 1e-300).all(): @@ -501,13 +555,13 @@ def grassmann_distance(U1, U2): # Solve PCA pi_sym = (pi + pi.T) / 2 - if method == 'BCD': + if method == "BCD": # block coordinate descent - S = X.T @ (2 * pi_sym - (1. / n) * np.eye(n)) @ X + S = X.T @ (2 * pi_sym - (1.0 / n) * np.eye(n)) @ X _, U = np.linalg.eigh(S) U = U[:, ::-1][:, :k] - elif method == 'MM': + elif method == "MM": # majorization-minimization eig, _ = np.linalg.eigh(pi_sym) lambda_pi = eig[0] @@ -535,6 +589,6 @@ def grassmann_distance(U1, U2): # print if verbose > 0: - print('{:4d}|{:8e}|{:8e}|{:8e}'.format(it, loss, crit, stopThr)) + print("{:4d}|{:8e}|{:8e}|{:8e}".format(it, loss, crit, stopThr)) return pi, U diff --git a/ot/factored.py b/ot/factored.py index 65613d328..f1b9f28c4 100644 --- a/ot/factored.py +++ b/ot/factored.py @@ -11,10 +11,23 @@ from .lp import emd from .bregman import sinkhorn -__all__ = ['factored_optimal_transport'] - - -def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs): +__all__ = ["factored_optimal_transport"] + + +def factored_optimal_transport( + Xa, + Xb, + a=None, + b=None, + reg=0.0, + r=100, + X0=None, + stopThr=1e-7, + numItermax=100, + verbose=False, + log=False, + **kwargs, +): r"""Solves factored OT problem and return OT plans and intermediate distribution This function solve the following OT problem [40]_ @@ -107,7 +120,7 @@ def solve_ot(X1, X2, w1, w2): M = dist(X1, X2) if reg > 0: G, log = sinkhorn(w1, w2, M, reg, log=True, **kwargs) - log['cost'] = nx.sum(G * M) + log["cost"] = nx.sum(G * M) return G, log else: return emd(w1, w2, M, log=True, **kwargs) @@ -116,7 +129,6 @@ def solve_ot(X1, X2, w1, w2): # solve the barycenter for i in range(numItermax): - old_X = X # solve OT with template @@ -132,15 +144,16 @@ def solve_ot(X1, X2, w1, w2): norm_delta.append(delta) if log: - log_dic = {'delta_iter': norm_delta, - 'ua': loga['u'], - 'va': loga['v'], - 'ub': logb['u'], - 'vb': logb['v'], - 'costa': loga['cost'], - 'costb': logb['cost'], - 'lazy_plan': get_lowrank_lazytensor(Ga * r, Gb.T, nx=nx), - } + log_dic = { + "delta_iter": norm_delta, + "ua": loga["u"], + "va": loga["v"], + "ub": logb["u"], + "vb": logb["v"], + "costa": loga["cost"], + "costb": logb["cost"], + "lazy_plan": get_lowrank_lazytensor(Ga * r, Gb.T, nx=nx), + } return Ga, Gb, X, log_dic return Ga, Gb, X diff --git a/ot/gaussian.py b/ot/gaussian.py index 832d193da..4645d5fa4 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -85,15 +85,16 @@ def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): if log: log = {} - log['Cs12'] = Cs12 - 
log['Cs12inv'] = Cs12inv + log["Cs12"] = Cs12 + log["Cs12inv"] = Cs12inv return A, b, log else: return A, b -def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None, - wt=None, bias=True, log=False): +def empirical_bures_wasserstein_mapping( + xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False +): r"""Return OT linear operator between samples. The function estimates the optimal linear operator that aligns the two @@ -187,11 +188,12 @@ def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None, if is_input_finite and not is_all_finite(A, b): warnings.warn( "Numerical errors were encountered in ot.gaussian.empirical_bures_wasserstein_mapping. " - "Consider increasing the regularization parameter `reg`.") + "Consider increasing the regularization parameter `reg`." + ) if log: - log['Cs'] = Cs - log['Ct'] = Ct + log["Cs"] = Cs + log["Ct"] = Ct return A, b, log else: return A, b @@ -249,18 +251,19 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): Cs12 = nx.sqrtm(Cs) B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) - W = nx.sqrt(nx.maximum(nx.norm(ms - mt)**2 + B, 0)) + W = nx.sqrt(nx.maximum(nx.norm(ms - mt) ** 2 + B, 0)) if log: log = {} - log['Cs12'] = Cs12 + log["Cs12"] = Cs12 return W, log else: return W -def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None, - wt=None, bias=True, log=False): +def empirical_bures_wasserstein_distance( + xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False +): r"""Return Bures Wasserstein distance from mean and covariance of distribution. The function estimates the Bures-Wasserstein distance between two @@ -336,15 +339,17 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None, if log: W, log = bures_wasserstein_distance(mxs, mxt, Cs, Ct, log=log) - log['Cs'] = Cs - log['Ct'] = Ct + log["Cs"] = Cs + log["Ct"] = Ct return W, log else: W = bures_wasserstein_distance(mxs, mxt, Cs, Ct) return W -def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, log=False): +def bures_wasserstein_barycenter( + m, C, weights=None, num_iter=1000, eps=1e-7, log=False +): r"""Return OT linear operator between samples. The function estimates the optimal barycenter of the @@ -397,7 +402,10 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924, 2011. """ - nx = get_backend(*C, *m,) + nx = get_backend( + *C, + *m, + ) if weights is None: weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0] @@ -430,16 +438,15 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo if log: log = {} - log['num_iter'] = it - log['final_diff'] = diff + log["num_iter"] = it + log["final_diff"] = diff return mb, Cb, log else: return mb, Cb def empirical_bures_wasserstein_barycenter( - X, reg=1e-6, weights=None, num_iter=1000, eps=1e-7, - w=None, bias=True, log=False + X, reg=1e-6, weights=None, num_iter=1000, eps=1e-7, w=None, bias=True, log=False ): r"""Return OT linear operator between samples. 
@@ -504,7 +511,9 @@ def empirical_bures_wasserstein_barycenter( d = [X[i].shape[1] for i in range(k)] if w is None: - w = [nx.ones((X[i].shape[0], 1), type_as=X[i]) / X[i].shape[0] for i in range(k)] + w = [ + nx.ones((X[i].shape[0], 1), type_as=X[i]) / X[i].shape[0] for i in range(k) + ] if bias: m = [nx.dot(w[i].T, X[i]) / nx.sum(w[i]) for i in range(k)] @@ -519,15 +528,19 @@ def empirical_bures_wasserstein_barycenter( m = nx.stack(m, axis=0) C = nx.stack(C, axis=0) if log: - mb, Cb, log = bures_wasserstein_barycenter(m, C, weights=weights, num_iter=num_iter, eps=eps, log=log) + mb, Cb, log = bures_wasserstein_barycenter( + m, C, weights=weights, num_iter=num_iter, eps=eps, log=log + ) return mb, Cb, log else: - mb, Cb = bures_wasserstein_barycenter(m, C, weights=weights, num_iter=num_iter, eps=eps, log=log) + mb, Cb = bures_wasserstein_barycenter( + m, C, weights=weights, num_iter=num_iter, eps=eps, log=log + ) return mb, Cb def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False): - r""" Return the Gaussian Gromov-Wasserstein value from [57]. + r"""Return the Gaussian Gromov-Wasserstein value from [57]. This function return the closed form value of the Gaussian Gromov-Wasserstein distance between two Gaussian distributions @@ -571,18 +584,21 @@ def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False): d_t = nx.flip(nx.sort(nx.eigh(Cov_t)[0])) # compute the gaussien Gromov-Wasserstein distance - res = 4 * (nx.sum(d_s) - nx.sum(d_t))**2 + 8 * nx.sum((d_s[:n] - d_t)**2) + 8 * nx.sum((d_s[n:])**2) + res = ( + 4 * (nx.sum(d_s) - nx.sum(d_t)) ** 2 + + 8 * nx.sum((d_s[:n] - d_t) ** 2) + + 8 * nx.sum((d_s[n:]) ** 2) + ) if log: log = {} - log['d_s'] = d_s - log['d_t'] = d_t + log["d_s"] = d_s + log["d_t"] = d_t return nx.sqrt(res), log else: return nx.sqrt(res) -def empirical_gaussian_gromov_wasserstein_distance(xs, xt, ws=None, - wt=None, log=False): +def empirical_gaussian_gromov_wasserstein_distance(xs, xt, ws=None, wt=None, log=False): r"""Return Gaussian Gromov-Wasserstein distance between samples. The function estimates the Gaussian Gromov-Wasserstein distance between two @@ -637,16 +653,18 @@ def empirical_gaussian_gromov_wasserstein_distance(xs, xt, ws=None, if log: G, log = gaussian_gromov_wasserstein_distance(Cs, Ct, log=log) - log['Cov_s'] = Cs - log['Cov_t'] = Ct + log["Cov_s"] = Cs + log["Cov_t"] = Ct return G, log else: G = gaussian_gromov_wasserstein_distance(Cs, Ct) return G -def gaussian_gromov_wasserstein_mapping(mu_s, mu_t, Cov_s, Cov_t, sign_eigs=None, log=False): - r""" Return the Gaussian Gromov-Wasserstein mapping from [57]. +def gaussian_gromov_wasserstein_mapping( + mu_s, mu_t, Cov_s, Cov_t, sign_eigs=None, log=False +): + r"""Return the Gaussian Gromov-Wasserstein mapping from [57]. 
This function return the closed form value of the Gaussian Gromov-Wasserstein mapping between two Gaussian distributions @@ -702,9 +720,21 @@ def gaussian_gromov_wasserstein_mapping(mu_s, mu_t, Cov_s, Cov_t, sign_eigs=None sign_eigs = nx.ones(min(m, n), type_as=mu_s) if m >= n: - A = nx.concatenate((nx.diag(sign_eigs * nx.sqrt(d_t) / nx.sqrt(d_s[:n])), nx.zeros((n, m - n), type_as=mu_s)), axis=1).T + A = nx.concatenate( + ( + nx.diag(sign_eigs * nx.sqrt(d_t) / nx.sqrt(d_s[:n])), + nx.zeros((n, m - n), type_as=mu_s), + ), + axis=1, + ).T else: - A = nx.concatenate((nx.diag(sign_eigs * nx.sqrt(d_t[:m]) / nx.sqrt(d_s)), nx.zeros((n - m, m), type_as=mu_s)), axis=0).T + A = nx.concatenate( + ( + nx.diag(sign_eigs * nx.sqrt(d_t[:m]) / nx.sqrt(d_s)), + nx.zeros((n - m, m), type_as=mu_s), + ), + axis=0, + ).T A = nx.dot(nx.dot(U_s, A), U_t.T) @@ -713,17 +743,18 @@ def gaussian_gromov_wasserstein_mapping(mu_s, mu_t, Cov_s, Cov_t, sign_eigs=None if log: log = {} - log['d_s'] = d_s - log['d_t'] = d_t - log['U_s'] = U_s - log['U_t'] = U_t + log["d_s"] = d_s + log["d_t"] = d_t + log["U_s"] = U_s + log["U_t"] = U_t return A, b, log else: return A, b -def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None, - wt=None, sign_eigs=None, log=False): +def empirical_gaussian_gromov_wasserstein_mapping( + xs, xt, ws=None, wt=None, sign_eigs=None, log=False +): r"""Return Gaussian Gromov-Wasserstein mapping between samples. The function estimates the Gaussian Gromov-Wasserstein mapping between two @@ -800,16 +831,28 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None, # select the sign of the eigenvalues if sign_eigs is None: sign_eigs = nx.ones(min(m, n), type_as=mu_s) - elif sign_eigs == 'skewness': + elif sign_eigs == "skewness": size = min(m, n) - skew_s = nx.sum((nx.dot(xs, U_s[:, :size]))**3 * ws, axis=0) - skew_t = nx.sum((nx.dot(xt, U_t[:, :size]))**3 * wt, axis=0) + skew_s = nx.sum((nx.dot(xs, U_s[:, :size])) ** 3 * ws, axis=0) + skew_t = nx.sum((nx.dot(xt, U_t[:, :size])) ** 3 * wt, axis=0) sign_eigs = nx.sign(skew_t * skew_s) if m >= n: - A = nx.concatenate((nx.diag(sign_eigs * nx.sqrt(d_t) / nx.sqrt(d_s[:n])), nx.zeros((n, m - n), type_as=mu_s)), axis=1).T + A = nx.concatenate( + ( + nx.diag(sign_eigs * nx.sqrt(d_t) / nx.sqrt(d_s[:n])), + nx.zeros((n, m - n), type_as=mu_s), + ), + axis=1, + ).T else: - A = nx.concatenate((nx.diag(sign_eigs * nx.sqrt(d_t[:m]) / nx.sqrt(d_s)), nx.zeros((n - m, m), type_as=mu_s)), axis=0).T + A = nx.concatenate( + ( + nx.diag(sign_eigs * nx.sqrt(d_t[:m]) / nx.sqrt(d_s)), + nx.zeros((n - m, m), type_as=mu_s), + ), + axis=0, + ).T A = nx.dot(nx.dot(U_s, A), U_t.T) @@ -818,12 +861,12 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None, if log: log = {} - log['d_s'] = d_s - log['d_t'] = d_t - log['U_s'] = U_s - log['U_t'] = U_t - log['Cov_s'] = Cov_s - log['Cov_t'] = Cov_t + log["d_s"] = d_s + log["d_t"] = d_t + log["U_s"] = U_s + log["U_t"] = U_t + log["Cov_s"] = Cov_s + log["Cov_t"] = Cov_t return A, b, log else: return A, b diff --git a/ot/gmm.py b/ot/gmm.py index 06caf90c7..0a597a70d 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -36,8 +36,9 @@ def gaussian_pdf(x, m, C): The probability density function evaluated at each sample. 
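Similarly, a hedged sketch of the Gaussian Gromov-Wasserstein helpers reformatted above; source and target samples may live in spaces of different dimension, the data are illustrative, and the mapping is assumed to act on samples as `xs @ A + b`:

import numpy as np
from ot.gaussian import (
    empirical_gaussian_gromov_wasserstein_distance,
    empirical_gaussian_gromov_wasserstein_mapping,
)

rng = np.random.RandomState(0)
xs = rng.randn(200, 3)               # 3-D source samples
xt = rng.randn(150, 2) * [2.0, 0.5]  # 2-D target samples

ggw = empirical_gaussian_gromov_wasserstein_distance(xs, xt)
A, b = empirical_gaussian_gromov_wasserstein_mapping(xs, xt)
xs_in_target = xs @ A + b            # expected shape (200, 2)
print(ggw, xs_in_target.shape)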
""" - assert x.shape[-1] == m.shape[-1] == C.shape[-1] == C.shape[-2], \ - "Dimension mismatch" + assert ( + x.shape[-1] == m.shape[-1] == C.shape[-1] == C.shape[-2] + ), "Dimension mismatch" nx = get_backend(x, m, C) d = x.shape[-1] z = (2 * np.pi) ** (-d / 2) * nx.det(C) ** (-0.5) @@ -67,8 +68,9 @@ def gmm_pdf(x, m, C, w): The PDF values at the given points. """ - assert m.shape[0] == C.shape[0] == w.shape[0], \ - "All GMM parameters must have the same amount of components" + assert ( + m.shape[0] == C.shape[0] == w.shape[0] + ), "All GMM parameters must have the same amount of components" nx = get_backend(x, m, C, w) out = nx.zeros((x.shape[:-1])) for k in range(m.shape[0]): @@ -106,27 +108,26 @@ def dist_bures_squared(m_s, m_t, C_s, C_t): """ nx = get_backend(m_s, C_s, m_t, C_t) - assert m_s.shape[0] == C_s.shape[0], \ - "Source GMM has different amount of components" + assert m_s.shape[0] == C_s.shape[0], "Source GMM has different amount of components" - assert m_t.shape[0] == C_t.shape[0], \ - "Target GMM has different amount of components" + assert m_t.shape[0] == C_t.shape[0], "Target GMM has different amount of components" - assert m_s.shape[-1] == m_t.shape[-1] == C_s.shape[-1] == C_t.shape[-1], \ - "All GMMs must have the same dimension" + assert ( + m_s.shape[-1] == m_t.shape[-1] == C_s.shape[-1] == C_t.shape[-1] + ), "All GMMs must have the same dimension" - D_means = dist(m_s, m_t, metric='sqeuclidean') + D_means = dist(m_s, m_t, metric="sqeuclidean") # C2[i, j] = Cs12[i] @ C_t[j] @ Cs12[i], shape (k_s, k_t, d, d) Cs12 = nx.sqrtm(C_s) # broadcasts matrix sqrt over (k_s,) - C2 = nx.einsum('ikl,jlm,imn->ijkn', Cs12, C_t, Cs12) + C2 = nx.einsum("ikl,jlm,imn->ijkn", Cs12, C_t, Cs12) C = nx.sqrtm(C2) # broadcasts matrix sqrt over (k_s, k_t) # D_covs[i,j] = trace(C_s[i] + C_t[j] - 2C[i,j]) - trace_C_s = nx.einsum('ikk->i', C_s)[:, None] # (k_s, 1) - trace_C_t = nx.einsum('ikk->i', C_t)[None, :] # (1, k_t) + trace_C_s = nx.einsum("ikk->i", C_s)[:, None] # (k_s, 1) + trace_C_t = nx.einsum("ikk->i", C_t)[None, :] # (1, k_t) D_covs = trace_C_s + trace_C_t # broadcasts to (k_s, k_t) - D_covs -= 2 * nx.einsum('ijkk->ij', C) + D_covs -= 2 * nx.einsum("ijkk->ij", C) return nx.maximum(D_means + D_covs, 0) @@ -169,11 +170,9 @@ def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t, log=False): """ get_backend(m_s, C_s, w_s, m_t, C_t, w_t) - assert m_s.shape[0] == w_s.shape[0], \ - "Source GMM has different amount of components" + assert m_s.shape[0] == w_s.shape[0], "Source GMM has different amount of components" - assert m_t.shape[0] == w_t.shape[0], \ - "Target GMM has different amount of components" + assert m_t.shape[0] == w_t.shape[0], "Target GMM has different amount of components" D = dist_bures_squared(m_s, m_t, C_s, C_t) return emd2(w_s, w_t, D, log=log) @@ -217,18 +216,17 @@ def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t, log=False): """ get_backend(m_s, C_s, w_s, m_t, C_t, w_t) - assert m_s.shape[0] == w_s.shape[0], \ - "Source GMM has different amount of components" + assert m_s.shape[0] == w_s.shape[0], "Source GMM has different amount of components" - assert m_t.shape[0] == w_t.shape[0], \ - "Target GMM has different amount of components" + assert m_t.shape[0] == w_t.shape[0], "Target GMM has different amount of components" D = dist_bures_squared(m_s, m_t, C_s, C_t) return emd(w_s, w_t, D, log=log) -def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, - method='bary', seed=None): +def gmm_ot_apply_map( + x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, method="bary", seed=None +): 
r""" Apply Gaussian Mixture Model (GMM) optimal transport (OT) mapping to input data. The 'barycentric' mapping corresponds to the barycentric projection @@ -282,13 +280,13 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, d = m_s.shape[1] n_samples = x.shape[0] - if method == 'bary': + if method == "bary": normalization = gmm_pdf(x, m_s, C_s, w_s)[:, None] out = nx.zeros(x.shape) - print('where plan > 0', nx.where(plan > 0)) + print("where plan > 0", nx.where(plan > 0)) # only need to compute for non-zero plan entries - for (i, j) in zip(*nx.where(plan > 0)): + for i, j in zip(*nx.where(plan > 0)): Cs12 = nx.sqrtm(C_s[i]) Cs12inv = nx.inv(Cs12) g = gaussian_pdf(x, m_s[i], C_s[i])[:, None] @@ -312,7 +310,7 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, b = nx.zeros((k_s, k_t, d)) # only need to compute for non-zero plan entries - for (i, j) in zip(*nx.where(plan > 0)): + for i, j in zip(*nx.where(plan > 0)): Cs12 = nx.sqrtm(C_s[i]) Cs12inv = nx.inv(Cs12) @@ -321,8 +319,7 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, b[i, j] = m_t[j] - A[i, j] @ m_s[i] normalization = gmm_pdf(x, m_s, C_s, w_s) # (n_samples,) - gs = np.stack( - [gaussian_pdf(x, m_s[i], C_s[i]) for i in range(k_s)], axis=-1) + gs = np.stack([gaussian_pdf(x, m_s[i], C_s[i]) for i in range(k_s)], axis=-1) # (n_samples, k_s) out = nx.zeros(x.shape) @@ -338,8 +335,7 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, return out -def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, - plan=None, atol=1e-2): +def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=1e-2): """ Compute the density of the Gaussian Mixture Model - Optimal Transport coupling between GMMS at given points, as introduced in [69]. @@ -381,8 +377,9 @@ def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. """ - assert x.shape[-1] == y.shape[-1], \ - "x (n, d) and y (m, d) must have the same dimension d" + assert ( + x.shape[-1] == y.shape[-1] + ), "x (n, d) and y (m, d) must have the same dimension d" n, m = x.shape[0], y.shape[0] nx = get_backend(x, y, m_s, m_t, C_s, C_t, w_s, w_t) @@ -399,12 +396,13 @@ def Tk0k1(k0, k1): g = gaussian_pdf(xx, m_s[k0], C_s[k0]) out = plan[k0, k1] * g norms = nx.norm(Tx - yy, axis=-1) - out = out * ((norms < atol) * 1.) 
+ out = out * ((norms < atol) * 1.0) return out mat = nx.stack( [ nx.stack([Tk0k1(k0, k1) for k1 in range(m_t.shape[0])]) for k0 in range(m_s.shape[0]) - ]) + ] + ) return nx.sum(mat, axis=(0, 1)) diff --git a/ot/gnn/__init__.py b/ot/gnn/__init__.py index af39db6d2..5f3a93fed 100644 --- a/ot/gnn/__init__.py +++ b/ot/gnn/__init__.py @@ -16,9 +16,13 @@ # All submodules and packages +from ._utils import FGW_distance_to_templates, wasserstein_distance_to_templates -from ._utils import (FGW_distance_to_templates, wasserstein_distance_to_templates) +from ._layers import TFGWPooling, TWPooling -from ._layers import (TFGWPooling, TWPooling) - -__all__ = ['FGW_distance_to_templates', 'wasserstein_distance_to_templates', 'TFGWPooling', 'TWPooling'] +__all__ = [ + "FGW_distance_to_templates", + "wasserstein_distance_to_templates", + "TFGWPooling", + "TWPooling", +] diff --git a/ot/gnn/_layers.py b/ot/gnn/_layers.py index 6fddc8254..b3adad489 100644 --- a/ot/gnn/_layers.py +++ b/ot/gnn/_layers.py @@ -10,7 +10,11 @@ import torch import torch.nn as nn -from ._utils import TFGW_template_initialization, FGW_distance_to_templates, wasserstein_distance_to_templates +from ._utils import ( + TFGW_template_initialization, + FGW_distance_to_templates, + wasserstein_distance_to_templates, +) class TFGWPooling(nn.Module): @@ -58,7 +62,17 @@ class TFGWPooling(nn.Module): "Template based graph neural network with optimal transport distances" """ - def __init__(self, n_features, n_tplt=2, n_tplt_nodes=2, alpha=None, train_node_weights=True, multi_alpha=False, feature_init_mean=0., feature_init_std=1.): + def __init__( + self, + n_features, + n_tplt=2, + n_tplt_nodes=2, + alpha=None, + train_node_weights=True, + multi_alpha=False, + feature_init_mean=0.0, + feature_init_std=1.0, + ): r""" Template Fused Gromov-Wasserstein (TFGW) layer. This layer is a pooling layer for graph neural networks. Computes the fused Gromov-Wasserstein distances between the graph and a set of templates. @@ -101,7 +115,7 @@ def __init__(self, n_features, n_tplt=2, n_tplt_nodes=2, alpha=None, train_node_ .. [53] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. "Template based graph neural network with optimal transport distances" - """ + """ super().__init__() self.n_tplt = n_tplt @@ -111,7 +125,13 @@ def __init__(self, n_features, n_tplt=2, n_tplt_nodes=2, alpha=None, train_node_ self.feature_init_mean = feature_init_mean self.feature_init_std = feature_init_std - tplt_adjacencies, tplt_features, self.q0 = TFGW_template_initialization(self.n_tplt, self.n_tplt_nodes, self.n_features, self.feature_init_mean, self.feature_init_std) + tplt_adjacencies, tplt_features, self.q0 = TFGW_template_initialization( + self.n_tplt, + self.n_tplt_nodes, + self.n_features, + self.feature_init_mean, + self.feature_init_std, + ) self.tplt_adjacencies = nn.Parameter(tplt_adjacencies) self.tplt_features = nn.Parameter(tplt_features) @@ -146,13 +166,22 @@ def forward(self, x, edge_index, batch=None): """ alpha = torch.sigmoid(self.alpha0) q = self.softmax(self.q0) - x = FGW_distance_to_templates(edge_index, self.tplt_adjacencies, x, self.tplt_features, q, alpha, self.multi_alpha, batch) + x = FGW_distance_to_templates( + edge_index, + self.tplt_adjacencies, + x, + self.tplt_features, + q, + alpha, + self.multi_alpha, + batch, + ) return x class TWPooling(nn.Module): r""" - Template Wasserstein (TW) layer, also kown as OT-GNN layer. This layer is a pooling layer for graph neural networks. 
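A hedged sketch of the `TFGWPooling` layer reformatted above; it assumes `torch` and `torch_geometric` are installed, and the tiny graph below is illustrative only:

import torch
from ot.gnn import TFGWPooling

n_nodes, n_features = 6, 8
x = torch.randn(n_nodes, n_features)          # node features
edge_index = torch.tensor([[0, 1, 2, 3, 4],   # COO edge list
                           [1, 2, 3, 4, 5]])

pool = TFGWPooling(n_features, n_tplt=3, n_tplt_nodes=4)
emb = pool(x, edge_index)   # FGW distances from the graph to the 3 learned templates
print(emb.shape)            # expected torch.Size([3])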
+ Template Wasserstein (TW) layer, also known as OT-GNN layer. This layer is a pooling layer for graph neural networks. Computes the Wasserstein distances between the features of the graph features and a set of templates. .. math:: @@ -160,7 +189,7 @@ class TWPooling(nn.Module): where : - - :math:`\mathcal{G}=\{(\overline{F}_k,\overline{h}_k) \}_{k \in \{1,...,K \}} \}` is the set of :math:`K` templates charactersised by their feature matrices :math:`\overline{F}_k` and their node weights :math:`\overline{h}_k`. + - :math:`\mathcal{G}=\{(\overline{F}_k,\overline{h}_k) \}_{k \in \{1,...,K \}} \}` is the set of :math:`K` templates characterized by their feature matrices :math:`\overline{F}_k` and their node weights :math:`\overline{h}_k`. - :math:`F` and :math:`h` are respectively the feature matrix and the node weights of the graph. Parameters @@ -185,9 +214,17 @@ class TWPooling(nn.Module): """ - def __init__(self, n_features, n_tplt=2, n_tplt_nodes=2, train_node_weights=True, feature_init_mean=0., feature_init_std=1.): + def __init__( + self, + n_features, + n_tplt=2, + n_tplt_nodes=2, + train_node_weights=True, + feature_init_mean=0.0, + feature_init_std=1.0, + ): r""" - Template Wasserstein (TW) layer, also kown as OT-GNN layer. This layer is a pooling layer for graph neural networks. + Template Wasserstein (TW) layer, also known as OT-GNN layer. This layer is a pooling layer for graph neural networks. Computes the Wasserstein distances between the features of the graph features and a set of templates. .. math:: @@ -195,7 +232,7 @@ def __init__(self, n_features, n_tplt=2, n_tplt_nodes=2, train_node_weights=True where : - - :math:`\mathcal{G}=\{(\overline{F}_k,\overline{h}_k) \}_{k \in \llbracket 1;K \rrbracket }` is the set of :math:`K` templates charactersised by their feature matrices :math:`\overline{F}_k` and their node weights :math:`\overline{h}_k`. + - :math:`\mathcal{G}=\{(\overline{F}_k,\overline{h}_k) \}_{k \in \llbracket 1;K \rrbracket }` is the set of :math:`K` templates characterized by their feature matrices :math:`\overline{F}_k` and their node weights :math:`\overline{h}_k`. - :math:`F` and :math:`h` are respectively the feature matrix and the node weights of the graph. Parameters @@ -226,7 +263,13 @@ def __init__(self, n_features, n_tplt=2, n_tplt_nodes=2, train_node_weights=True self.feature_init_mean = feature_init_mean self.feature_init_std = feature_init_std - _, tplt_features, self.q0 = TFGW_template_initialization(self.n_tplt, self.n_tplt_nodes, self.n_features, self.feature_init_mean, self.feature_init_std) + _, tplt_features, self.q0 = TFGW_template_initialization( + self.n_tplt, + self.n_tplt_nodes, + self.n_features, + self.feature_init_mean, + self.feature_init_std, + ) self.tplt_features = nn.Parameter(tplt_features) self.softmax = nn.Softmax(dim=1) diff --git a/ot/gnn/_utils.py b/ot/gnn/_utils.py index 18e32f627..16487c210 100644 --- a/ot/gnn/_utils.py +++ b/ot/gnn/_utils.py @@ -15,12 +15,14 @@ from torch_geometric.utils import subgraph -def TFGW_template_initialization(n_tplt, n_tplt_nodes, n_features, feature_init_mean=0., feature_init_std=1.): +def TFGW_template_initialization( + n_tplt, n_tplt_nodes, n_features, feature_init_mean=0.0, feature_init_std=1.0 +): """ Initializes templates for the Template Fused Gromov Wasserstein layer. Returns the adjacency matrices and the features of the nodes of the templates. - Adjacency matrices are intialised uniformly with values in :math:`[0,1]`. - Node features are intialized following a normal distribution. 
+ Adjacency matrices are initialized uniformly with values in :math:`[0,1]`. + Node features are initialized following a normal distribution. Parameters ---------- @@ -39,7 +41,7 @@ def TFGW_template_initialization(n_tplt, n_tplt_nodes, n_features, feature_init_ Returns ---------- tplt_adjacencies: torch.Tensor, shape (n_templates, n_template_nodes, n_template_nodes) - Adjancency matrices for the templates. + Adjacency matrices for the templates. tplt_features: torch.Tensor, shape (n_templates, n_template_nodes, n_features) Node features for each template. q: torch.Tensor, shape (n_templates, n_template_nodes) @@ -53,12 +55,23 @@ def TFGW_template_initialization(n_tplt, n_tplt_nodes, n_features, feature_init_ q = torch.zeros(n_tplt, n_tplt_nodes) - tplt_adjacencies = 0.5 * (tplt_adjacencies + torch.transpose(tplt_adjacencies, 1, 2)) + tplt_adjacencies = 0.5 * ( + tplt_adjacencies + torch.transpose(tplt_adjacencies, 1, 2) + ) return tplt_adjacencies, tplt_features, q -def FGW_distance_to_templates(G_edges, tplt_adjacencies, G_features, tplt_features, tplt_weights, alpha=0.5, multi_alpha=False, batch=None): +def FGW_distance_to_templates( + G_edges, + tplt_adjacencies, + G_features, + tplt_features, + tplt_weights, + alpha=0.5, + multi_alpha=False, + batch=None, +): """ Computes the FGW distances between a graph and templates. @@ -89,47 +102,67 @@ def FGW_distance_to_templates(G_edges, tplt_adjacencies, G_features, tplt_featur """ if batch is None: - n, n_feat = G_features.shape n_T, _, n_feat_T = tplt_features.shape weights_G = torch.ones(n) / n - C = torch.sparse_coo_tensor(G_edges, torch.ones(len(G_edges[0])), size=(n, n)).type(torch.float) + C = torch.sparse_coo_tensor( + G_edges, torch.ones(len(G_edges[0])), size=(n, n) + ).type(torch.float) C = C.to_dense() if not n_feat == n_feat_T: - raise ValueError('The templates and the graphs must have the same feature dimension.') + raise ValueError( + "The templates and the graphs must have the same feature dimension." 
+ ) distances = torch.zeros(n_T) for j in range(n_T): - - template_features = tplt_features[j].reshape(len(tplt_features[j]), n_feat_T) + template_features = tplt_features[j].reshape( + len(tplt_features[j]), n_feat_T + ) M = dist(G_features, template_features).type(torch.float) - #if alpha is zero the emd distance is used + # if alpha is zero the emd distance is used if multi_alpha and torch.any(alpha > 0): - embedding = fused_gromov_wasserstein2(M, C, tplt_adjacencies[j], weights_G, tplt_weights[j], alpha=alpha[j], symmetric=True, max_iter=50) + embedding = fused_gromov_wasserstein2( + M, + C, + tplt_adjacencies[j], + weights_G, + tplt_weights[j], + alpha=alpha[j], + symmetric=True, + max_iter=50, + ) elif not multi_alpha and torch.all(alpha == 0): embedding = emd2(weights_G, tplt_weights[j], M, numItermax=50) elif not multi_alpha and alpha > 0: - embedding = fused_gromov_wasserstein2(M, C, tplt_adjacencies[j], weights_G, tplt_weights[j], alpha=alpha, symmetric=True, max_iter=50) + embedding = fused_gromov_wasserstein2( + M, + C, + tplt_adjacencies[j], + weights_G, + tplt_weights[j], + alpha=alpha, + symmetric=True, + max_iter=50, + ) else: embedding = emd2(weights_G, tplt_weights[j], M, numItermax=50) distances[j] = embedding else: - n_T, _, n_feat_T = tplt_features.shape num_graphs = torch.max(batch) + 1 distances = torch.zeros(num_graphs, n_T) - #iterate over the graphs in the batch + # iterate over the graphs in the batch for i in range(num_graphs): - nodes = torch.where(batch == i)[0] G_edges_i, _ = subgraph(nodes, edge_index=G_edges, relabel_nodes=True) @@ -141,24 +174,47 @@ def FGW_distance_to_templates(G_edges, tplt_adjacencies, G_features, tplt_featur n_edges = len(G_edges_i[0]) - C = torch.sparse_coo_tensor(G_edges_i, torch.ones(n_edges), size=(n, n)).type(torch.float) + C = torch.sparse_coo_tensor( + G_edges_i, torch.ones(n_edges), size=(n, n) + ).type(torch.float) C = C.to_dense() if not n_feat == n_feat_T: - raise ValueError('The templates and the graphs must have the same feature dimension.') + raise ValueError( + "The templates and the graphs must have the same feature dimension." 
+ ) for j in range(n_T): - - template_features = tplt_features[j].reshape(len(tplt_features[j]), n_feat_T) + template_features = tplt_features[j].reshape( + len(tplt_features[j]), n_feat_T + ) M = dist(G_features_i, template_features).type(torch.float) - #if alpha is zero the emd distance is used + # if alpha is zero the emd distance is used if multi_alpha and torch.any(alpha > 0): - embedding = fused_gromov_wasserstein2(M, C, tplt_adjacencies[j], weights_G, tplt_weights[j], alpha=alpha[j], symmetric=True, max_iter=50) + embedding = fused_gromov_wasserstein2( + M, + C, + tplt_adjacencies[j], + weights_G, + tplt_weights[j], + alpha=alpha[j], + symmetric=True, + max_iter=50, + ) elif not multi_alpha and torch.all(alpha == 0): embedding = emd2(weights_G, tplt_weights[j], M, numItermax=50) elif not multi_alpha and alpha > 0: - embedding = fused_gromov_wasserstein2(M, C, tplt_adjacencies[j], weights_G, tplt_weights[j], alpha=alpha, symmetric=True, max_iter=50) + embedding = fused_gromov_wasserstein2( + M, + C, + tplt_adjacencies[j], + weights_G, + tplt_weights[j], + alpha=alpha, + symmetric=True, + max_iter=50, + ) else: embedding = emd2(weights_G, tplt_weights[j], M, numItermax=50) @@ -167,7 +223,9 @@ def FGW_distance_to_templates(G_edges, tplt_adjacencies, G_features, tplt_featur return distances -def wasserstein_distance_to_templates(G_features, tplt_features, tplt_weights, batch=None): +def wasserstein_distance_to_templates( + G_features, tplt_features, tplt_weights, batch=None +): """ Computes the Wasserstein distances between a graph and graph templates. @@ -189,34 +247,34 @@ def wasserstein_distance_to_templates(G_features, tplt_features, tplt_weights, b """ if batch is None: - n, n_feat = G_features.shape n_T, _, n_feat_T = tplt_features.shape weights_G = torch.ones(n) / n if not n_feat == n_feat_T: - raise ValueError('The templates and the graphs must have the same feature dimension.') + raise ValueError( + "The templates and the graphs must have the same feature dimension." + ) distances = torch.zeros(n_T) for j in range(n_T): - - template_features = tplt_features[j].reshape(len(tplt_features[j]), n_feat_T) + template_features = tplt_features[j].reshape( + len(tplt_features[j]), n_feat_T + ) M = dist(G_features, template_features).type(torch.float) distances[j] = emd2(weights_G, tplt_weights[j], M, numItermax=50) else: - n_T, _, n_feat_T = tplt_features.shape num_graphs = torch.max(batch) + 1 distances = torch.zeros(num_graphs, n_T) - #iterate over the graphs in the batch + # iterate over the graphs in the batch for i in range(num_graphs): - nodes = torch.where(batch == i)[0] G_features_i = G_features[nodes] @@ -226,11 +284,14 @@ def wasserstein_distance_to_templates(G_features, tplt_features, tplt_weights, b weights_G = torch.ones(n) / n if not n_feat == n_feat_T: - raise ValueError('The templates and the graphs must have the same feature dimension.') + raise ValueError( + "The templates and the graphs must have the same feature dimension." 
+ ) for j in range(n_T): - - template_features = tplt_features[j].reshape(len(tplt_features[j]), n_feat_T) + template_features = tplt_features[j].reshape( + len(tplt_features[j]), n_feat_T + ) M = dist(G_features_i, template_features).type(torch.float) distances[i, j] = emd2(weights_G, tplt_weights[j], M, numItermax=50) diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index 2efd69ccd..f552cb914 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -11,116 +11,169 @@ # License: MIT License # All submodules and packages -from ._utils import (init_matrix, tensor_product, gwloss, gwggrad, - init_matrix_semirelaxed, semirelaxed_init_plan, - update_barycenter_structure, update_barycenter_feature, - div_between_product, div_to_product, fused_unbalanced_across_spaces_cost, - uot_cost_matrix, uot_parameters_and_measures) +from ._utils import ( + init_matrix, + tensor_product, + gwloss, + gwggrad, + init_matrix_semirelaxed, + semirelaxed_init_plan, + update_barycenter_structure, + update_barycenter_feature, + div_between_product, + div_to_product, + fused_unbalanced_across_spaces_cost, + uot_cost_matrix, + uot_parameters_and_measures, +) -from ._gw import (gromov_wasserstein, gromov_wasserstein2, - fused_gromov_wasserstein, fused_gromov_wasserstein2, - solve_gromov_linesearch, gromov_barycenters, fgw_barycenters) +from ._gw import ( + gromov_wasserstein, + gromov_wasserstein2, + fused_gromov_wasserstein, + fused_gromov_wasserstein2, + solve_gromov_linesearch, + gromov_barycenters, + fgw_barycenters, +) -from ._bregman import (entropic_gromov_wasserstein, - entropic_gromov_wasserstein2, - BAPG_gromov_wasserstein, - BAPG_gromov_wasserstein2, - entropic_gromov_barycenters, - entropic_fused_gromov_wasserstein, - entropic_fused_gromov_wasserstein2, - BAPG_fused_gromov_wasserstein, - BAPG_fused_gromov_wasserstein2, - entropic_fused_gromov_barycenters) +from ._bregman import ( + entropic_gromov_wasserstein, + entropic_gromov_wasserstein2, + BAPG_gromov_wasserstein, + BAPG_gromov_wasserstein2, + entropic_gromov_barycenters, + entropic_fused_gromov_wasserstein, + entropic_fused_gromov_wasserstein2, + BAPG_fused_gromov_wasserstein, + BAPG_fused_gromov_wasserstein2, + entropic_fused_gromov_barycenters, +) -from ._estimators import (GW_distance_estimation, pointwise_gromov_wasserstein, - sampled_gromov_wasserstein) +from ._estimators import ( + GW_distance_estimation, + pointwise_gromov_wasserstein, + sampled_gromov_wasserstein, +) -from ._semirelaxed import (semirelaxed_gromov_wasserstein, - semirelaxed_gromov_wasserstein2, - semirelaxed_fused_gromov_wasserstein, - semirelaxed_fused_gromov_wasserstein2, - solve_semirelaxed_gromov_linesearch, - entropic_semirelaxed_gromov_wasserstein, - entropic_semirelaxed_gromov_wasserstein2, - entropic_semirelaxed_fused_gromov_wasserstein, - entropic_semirelaxed_fused_gromov_wasserstein2, - semirelaxed_gromov_barycenters, - semirelaxed_fgw_barycenters) +from ._semirelaxed import ( + semirelaxed_gromov_wasserstein, + semirelaxed_gromov_wasserstein2, + semirelaxed_fused_gromov_wasserstein, + semirelaxed_fused_gromov_wasserstein2, + solve_semirelaxed_gromov_linesearch, + entropic_semirelaxed_gromov_wasserstein, + entropic_semirelaxed_gromov_wasserstein2, + entropic_semirelaxed_fused_gromov_wasserstein, + entropic_semirelaxed_fused_gromov_wasserstein2, + semirelaxed_gromov_barycenters, + semirelaxed_fgw_barycenters, +) -from ._dictionary import (gromov_wasserstein_dictionary_learning, - gromov_wasserstein_linear_unmixing, - 
fused_gromov_wasserstein_dictionary_learning, - fused_gromov_wasserstein_linear_unmixing) +from ._dictionary import ( + gromov_wasserstein_dictionary_learning, + gromov_wasserstein_linear_unmixing, + fused_gromov_wasserstein_dictionary_learning, + fused_gromov_wasserstein_linear_unmixing, +) -from ._lowrank import (_flat_product_operator, - lowrank_gromov_wasserstein_samples) +from ._lowrank import lowrank_gromov_wasserstein_samples -from ._quantized import (quantized_fused_gromov_wasserstein_partitioned, - get_graph_partition, - get_graph_representants, - format_partitioned_graph, - quantized_fused_gromov_wasserstein, - get_partition_and_representants_samples, - format_partitioned_samples, - quantized_fused_gromov_wasserstein_samples - ) +from ._quantized import ( + quantized_fused_gromov_wasserstein_partitioned, + get_graph_partition, + get_graph_representants, + format_partitioned_graph, + quantized_fused_gromov_wasserstein, + get_partition_and_representants_samples, + format_partitioned_samples, + quantized_fused_gromov_wasserstein_samples, +) -from ._unbalanced import (fused_unbalanced_gromov_wasserstein, - fused_unbalanced_gromov_wasserstein2, - unbalanced_co_optimal_transport, - unbalanced_co_optimal_transport2, - fused_unbalanced_across_spaces_divergence) +from ._unbalanced import ( + fused_unbalanced_gromov_wasserstein, + fused_unbalanced_gromov_wasserstein2, + unbalanced_co_optimal_transport, + unbalanced_co_optimal_transport2, + fused_unbalanced_across_spaces_divergence, +) -from ._partial import (partial_gromov_wasserstein, - partial_gromov_wasserstein2, - solve_partial_gromov_linesearch, - entropic_partial_gromov_wasserstein, - entropic_partial_gromov_wasserstein2) +from ._partial import ( + partial_gromov_wasserstein, + partial_gromov_wasserstein2, + solve_partial_gromov_linesearch, + entropic_partial_gromov_wasserstein, + entropic_partial_gromov_wasserstein2, +) -__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', - 'init_matrix_semirelaxed', 'semirelaxed_init_plan', - 'update_barycenter_structure', 'update_barycenter_feature', - 'div_between_product', 'div_to_product', 'fused_unbalanced_across_spaces_cost', - 'uot_cost_matrix', 'uot_parameters_and_measures', - 'gromov_wasserstein', 'gromov_wasserstein2', - 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', - 'solve_gromov_linesearch', 'gromov_barycenters', - 'fgw_barycenters', 'entropic_gromov_wasserstein', - 'entropic_gromov_wasserstein2', 'BAPG_gromov_wasserstein', - 'BAPG_gromov_wasserstein2', 'entropic_gromov_barycenters', - 'entropic_fused_gromov_wasserstein', - 'entropic_fused_gromov_wasserstein2', 'BAPG_fused_gromov_wasserstein', - 'BAPG_fused_gromov_wasserstein2', 'entropic_fused_gromov_barycenters', - 'GW_distance_estimation', 'pointwise_gromov_wasserstein', - 'sampled_gromov_wasserstein', - 'semirelaxed_gromov_wasserstein', 'semirelaxed_gromov_wasserstein2', - 'semirelaxed_fused_gromov_wasserstein', - 'semirelaxed_fused_gromov_wasserstein2', - 'solve_semirelaxed_gromov_linesearch', - 'entropic_semirelaxed_gromov_wasserstein', - 'entropic_semirelaxed_gromov_wasserstein2', - 'entropic_semirelaxed_fused_gromov_wasserstein', - 'entropic_semirelaxed_fused_gromov_wasserstein2', - 'semirelaxed_fgw_barycenters', 'semirelaxed_gromov_barycenters', - 'gromov_wasserstein_dictionary_learning', - 'gromov_wasserstein_linear_unmixing', - 'fused_gromov_wasserstein_dictionary_learning', - 'fused_gromov_wasserstein_linear_unmixing', - 'lowrank_gromov_wasserstein_samples', - 
'quantized_fused_gromov_wasserstein_partitioned', - 'get_graph_partition', 'get_graph_representants', - 'format_partitioned_graph', 'quantized_fused_gromov_wasserstein', - 'get_partition_and_representants_samples', 'format_partitioned_samples', - 'quantized_fused_gromov_wasserstein_samples', - 'fused_unbalanced_gromov_wasserstein', - 'fused_unbalanced_gromov_wasserstein2', - 'unbalanced_co_optimal_transport', - 'unbalanced_co_optimal_transport2', - 'fused_unbalanced_across_spaces_divergence', - 'partial_gromov_wasserstein', 'partial_gromov_wasserstein2', - 'solve_partial_gromov_linesearch', - 'entropic_partial_gromov_wasserstein', - 'entropic_partial_gromov_wasserstein2' - ] +__all__ = [ + "init_matrix", + "tensor_product", + "gwloss", + "gwggrad", + "init_matrix_semirelaxed", + "semirelaxed_init_plan", + "update_barycenter_structure", + "update_barycenter_feature", + "div_between_product", + "div_to_product", + "fused_unbalanced_across_spaces_cost", + "uot_cost_matrix", + "uot_parameters_and_measures", + "gromov_wasserstein", + "gromov_wasserstein2", + "fused_gromov_wasserstein", + "fused_gromov_wasserstein2", + "solve_gromov_linesearch", + "gromov_barycenters", + "fgw_barycenters", + "entropic_gromov_wasserstein", + "entropic_gromov_wasserstein2", + "BAPG_gromov_wasserstein", + "BAPG_gromov_wasserstein2", + "entropic_gromov_barycenters", + "entropic_fused_gromov_wasserstein", + "entropic_fused_gromov_wasserstein2", + "BAPG_fused_gromov_wasserstein", + "BAPG_fused_gromov_wasserstein2", + "entropic_fused_gromov_barycenters", + "GW_distance_estimation", + "pointwise_gromov_wasserstein", + "sampled_gromov_wasserstein", + "semirelaxed_gromov_wasserstein", + "semirelaxed_gromov_wasserstein2", + "semirelaxed_fused_gromov_wasserstein", + "semirelaxed_fused_gromov_wasserstein2", + "solve_semirelaxed_gromov_linesearch", + "entropic_semirelaxed_gromov_wasserstein", + "entropic_semirelaxed_gromov_wasserstein2", + "entropic_semirelaxed_fused_gromov_wasserstein", + "entropic_semirelaxed_fused_gromov_wasserstein2", + "semirelaxed_fgw_barycenters", + "semirelaxed_gromov_barycenters", + "gromov_wasserstein_dictionary_learning", + "gromov_wasserstein_linear_unmixing", + "fused_gromov_wasserstein_dictionary_learning", + "fused_gromov_wasserstein_linear_unmixing", + "lowrank_gromov_wasserstein_samples", + "quantized_fused_gromov_wasserstein_partitioned", + "get_graph_partition", + "get_graph_representants", + "format_partitioned_graph", + "quantized_fused_gromov_wasserstein", + "get_partition_and_representants_samples", + "format_partitioned_samples", + "quantized_fused_gromov_wasserstein_samples", + "fused_unbalanced_gromov_wasserstein", + "fused_unbalanced_gromov_wasserstein2", + "unbalanced_co_optimal_transport", + "unbalanced_co_optimal_transport2", + "fused_unbalanced_across_spaces_divergence", + "partial_gromov_wasserstein", + "partial_gromov_wasserstein2", + "solve_partial_gromov_linesearch", + "entropic_partial_gromov_wasserstein", + "entropic_partial_gromov_wasserstein2", +] diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index c60e786f7..fbc8d4897 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -23,8 +23,22 @@ def entropic_gromov_wasserstein( - C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, - tol=1e-9, solver='PGD', warmstart=False, verbose=False, log=False, **kwargs): + C1, + C2, + p=None, + q=None, + loss_fun="square_loss", + epsilon=0.1, + symmetric=None, + G0=None, + max_iter=1000, + tol=1e-9, + solver="PGD", + 
warmstart=False, + verbose=False, + log=False, + **kwargs, +): r""" Returns the Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` estimated using Sinkhorn projections. @@ -132,11 +146,13 @@ def entropic_gromov_wasserstein( learning for graph matching and node embedding. In International Conference on Machine Learning (ICML), 2019. """ - if solver not in ['PGD', 'PPA']: + if solver not in ["PGD", "PPA"]: raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver) - if loss_fun not in ('square_loss', 'kl_loss'): - raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") + if loss_fun not in ("square_loss", "kl_loss"): + raise ValueError( + f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}." + ) C1, C2 = list_to_array(C1, C2) arr = [C1, C2] @@ -161,7 +177,9 @@ def entropic_gromov_wasserstein( constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx) if symmetric is None: - symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose( + C2, C2.T, atol=1e-10 + ) if not symmetric: constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx) @@ -175,28 +193,38 @@ def entropic_gromov_wasserstein( nu = nx.zeros(N2, type_as=C2) - np.log(N2) if log: - log = {'err': []} - - while (err > tol and cpt < max_iter): + log = {"err": []} + while err > tol and cpt < max_iter: Tprev = T # compute the gradient if symmetric: tens = gwggrad(constC, hC1, hC2, T, nx) else: - tens = 0.5 * (gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx)) + tens = 0.5 * ( + gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx) + ) - if solver == 'PPA': + if solver == "PPA": tens = tens - epsilon * nx.log(T) if warmstart: - T, loginn = sinkhorn(p, q, tens, epsilon, method='sinkhorn', log=True, warmstart=(mu, nu), **kwargs) - mu = epsilon * nx.log(loginn['u']) - nu = epsilon * nx.log(loginn['v']) + T, loginn = sinkhorn( + p, + q, + tens, + epsilon, + method="sinkhorn", + log=True, + warmstart=(mu, nu), + **kwargs, + ) + mu = epsilon * nx.log(loginn["u"]) + nu = epsilon * nx.log(loginn["v"]) else: - T = sinkhorn(p, q, tens, epsilon, method='sinkhorn', **kwargs) + T = sinkhorn(p, q, tens, epsilon, method="sinkhorn", **kwargs) if cpt % 10 == 0: # we can speed up the process by checking for the error only all @@ -204,29 +232,44 @@ def entropic_gromov_wasserstein( err = nx.norm(T - Tprev) if log: - log['err'].append(err) + log["err"].append(err) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err)) cpt += 1 if abs(nx.sum(T) - 1) > 1e-5: - warnings.warn("Solver failed to produce a transport plan. You might " - "want to increase the regularization parameter `epsilon`.") + warnings.warn( + "Solver failed to produce a transport plan. You might " + "want to increase the regularization parameter `epsilon`." 
+ ) if log: - log['gw_dist'] = gwloss(constC, hC1, hC2, T, nx) + log["gw_dist"] = gwloss(constC, hC1, hC2, T, nx) return T, log else: return T def entropic_gromov_wasserstein2( - C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, - tol=1e-9, solver='PGD', warmstart=False, verbose=False, log=False, **kwargs): + C1, + C2, + p=None, + q=None, + loss_fun="square_loss", + epsilon=0.1, + symmetric=None, + G0=None, + max_iter=1000, + tol=1e-9, + solver="PGD", + warmstart=False, + verbose=False, + log=False, + **kwargs, +): r""" Returns the Gromov-Wasserstein loss :math:`\mathbf{GW}` between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` estimated using Sinkhorn projections. To recover the Gromov-Wasserstein distance as defined in [13] compute :math:`d_{GW} = \frac{1}{2} \sqrt{\mathbf{GW}}`. @@ -294,8 +337,8 @@ def entropic_gromov_wasserstein2( G0: array-like, shape (ns,nt), optional If None the initial transport plan of the solver is pq^T. Otherwise G0 will be used as initial transport of the solver. G0 is not - required to satisfy marginal constraints but we strongly recommand it - to correcly estimate the GW distance. + required to satisfy marginal constraints but we strongly recommend it + to correctly estimate the GW distance. max_iter : int, optional Max number of iterations tol : float, optional @@ -332,21 +375,46 @@ def entropic_gromov_wasserstein2( """ T, logv = entropic_gromov_wasserstein( - C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter, - tol, solver, warmstart, verbose, log=True, **kwargs) - - logv['T'] = T + C1, + C2, + p, + q, + loss_fun, + epsilon, + symmetric, + G0, + max_iter, + tol, + solver, + warmstart, + verbose, + log=True, + **kwargs, + ) + + logv["T"] = T if log: - return logv['gw_dist'], logv + return logv["gw_dist"], logv else: - return logv['gw_dist'] + return logv["gw_dist"] def BAPG_gromov_wasserstein( - C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, - symmetric=None, G0=None, max_iter=1000, tol=1e-9, marginal_loss=False, - verbose=False, log=False): + C1, + C2, + p=None, + q=None, + loss_fun="square_loss", + epsilon=0.1, + symmetric=None, + G0=None, + max_iter=1000, + tol=1e-9, + marginal_loss=False, + verbose=False, + log=False, +): r""" Returns the Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` estimated using Bregman Alternated Projected Gradient method. @@ -438,8 +506,10 @@ def BAPG_gromov_wasserstein( in Graph Data". International Conference on Learning Representations (ICLR), 2022. """ - if loss_fun not in ('square_loss', 'kl_loss'): - raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") + if loss_fun not in ("square_loss", "kl_loss"): + raise ValueError( + f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}." 
+ ) C1, C2 = list_to_array(C1, C2) arr = [C1, C2] @@ -464,46 +534,54 @@ def BAPG_gromov_wasserstein( constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx) if symmetric is None: - symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose( + C2, C2.T, atol=1e-10 + ) if not symmetric: constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx) if marginal_loss: if symmetric: + def df(T): return gwggrad(constC, hC1, hC2, T, nx) else: + def df(T): - return 0.5 * (gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx)) + return 0.5 * ( + gwggrad(constC, hC1, hC2, T, nx) + + gwggrad(constCt, hC1t, hC2t, T, nx) + ) else: if symmetric: + def df(T): - A = - nx.dot(nx.dot(hC1, T), hC2.T) + A = -nx.dot(nx.dot(hC1, T), hC2.T) return 2 * A else: + def df(T): - A = - nx.dot(nx.dot(hC1, T), hC2t) - At = - nx.dot(nx.dot(hC1t, T), hC2) + A = -nx.dot(nx.dot(hC1, T), hC2t) + At = -nx.dot(nx.dot(hC1t, T), hC2) return A + At cpt = 0 err = 1e15 if log: - log = {'err': []} - - while (err > tol and cpt < max_iter): + log = {"err": []} + while err > tol and cpt < max_iter: Tprev = T # rows update - T = T * nx.exp(- df(T) / epsilon) + T = T * nx.exp(-df(T) / epsilon) row_scaling = p / nx.sum(T, 1) T = nx.reshape(row_scaling, (-1, 1)) * T # columns update - T = T * nx.exp(- df(T) / epsilon) + T = T * nx.exp(-df(T) / epsilon) column_scaling = q / nx.sum(T, 0) T = nx.reshape(column_scaling, (1, -1)) * T @@ -513,25 +591,26 @@ def df(T): err = nx.norm(T - Tprev) if log: - log['err'].append(err) + log["err"].append(err) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err)) cpt += 1 if nx.any(nx.isnan(T)): - warnings.warn("Solver failed to produce a transport plan. You might " - "want to increase the regularization parameter `epsilon`.", - UserWarning) + warnings.warn( + "Solver failed to produce a transport plan. You might " + "want to increase the regularization parameter `epsilon`.", + UserWarning, + ) if log: - log['gw_dist'] = gwloss(constC, hC1, hC2, T, nx) + log["gw_dist"] = gwloss(constC, hC1, hC2, T, nx) if not marginal_loss: - log['loss'] = log['gw_dist'] - nx.sum(constC * T) + log["loss"] = log["gw_dist"] - nx.sum(constC * T) return T, log else: @@ -539,8 +618,20 @@ def df(T): def BAPG_gromov_wasserstein2( - C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, - tol=1e-9, marginal_loss=False, verbose=False, log=False): + C1, + C2, + p=None, + q=None, + loss_fun="square_loss", + epsilon=0.1, + symmetric=None, + G0=None, + max_iter=1000, + tol=1e-9, + marginal_loss=False, + verbose=False, + log=False, +): r""" Returns the Gromov-Wasserstein loss :math:`\mathbf{GW}` between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` estimated using Bregman Alternated Projected Gradient method. @@ -610,8 +701,8 @@ def BAPG_gromov_wasserstein2( G0: array-like, shape (ns,nt), optional If None the initial transport plan of the solver is pq^T. Otherwise G0 will be used as initial transport of the solver. G0 is not - required to satisfy marginal constraints but we strongly recommand it - to correcly estimate the GW distance. + required to satisfy marginal constraints but we strongly recommend it + to correctly estimate the GW distance. 
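For context, a hedged usage sketch for the entropic and BAPG Gromov-Wasserstein solvers reformatted above, on random toy structures; `epsilon` typically needs tuning per problem:

import numpy as np
import ot
from ot.gromov import (
    entropic_gromov_wasserstein,
    entropic_gromov_wasserstein2,
    BAPG_gromov_wasserstein,
)

rng = np.random.RandomState(0)
xs = rng.randn(20, 2)
xt = rng.randn(30, 3)
C1 = ot.dist(xs, xs)   # intra-domain structure (squared Euclidean by default)
C2 = ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()

T = entropic_gromov_wasserstein(C1, C2, epsilon=5e-2, solver="PPA")
gw_val = entropic_gromov_wasserstein2(C1, C2, epsilon=5e-2)
T_bapg = BAPG_gromov_wasserstein(C1, C2, epsilon=0.1)
print(T.shape, gw_val, T_bapg.shape)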
max_iter : int, optional Max number of iterations tol : float, optional @@ -637,22 +728,48 @@ def BAPG_gromov_wasserstein2( """ T, logv = BAPG_gromov_wasserstein( - C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter, - tol, marginal_loss, verbose, log=True) - - logv['T'] = T + C1, + C2, + p, + q, + loss_fun, + epsilon, + symmetric, + G0, + max_iter, + tol, + marginal_loss, + verbose, + log=True, + ) + + logv["T"] = T if log: - return logv['gw_dist'], logv + return logv["gw_dist"], logv else: - return logv['gw_dist'] + return logv["gw_dist"] def entropic_gromov_barycenters( - N, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss', - epsilon=0.1, symmetric=True, max_iter=1000, tol=1e-9, - stop_criterion='barycenter', warmstartT=False, verbose=False, - log=False, init_C=None, random_state=None, **kwargs): + N, + Cs, + ps=None, + p=None, + lambdas=None, + loss_fun="square_loss", + epsilon=0.1, + symmetric=True, + max_iter=1000, + tol=1e-9, + stop_criterion="barycenter", + warmstartT=False, + verbose=False, + log=False, + init_C=None, + random_state=None, + **kwargs, +): r""" Returns the Gromov-Wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` estimated using Gromov-Wasserstein transports from Sinkhorn projections. @@ -729,19 +846,27 @@ def entropic_gromov_barycenters( "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. """ - if loss_fun not in ('square_loss', 'kl_loss'): - raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") + if loss_fun not in ("square_loss", "kl_loss"): + raise ValueError( + f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}." + ) - if stop_criterion not in ['barycenter', 'loss']: - raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") + if stop_criterion not in ["barycenter", "loss"]: + raise ValueError( + f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}." + ) if isinstance(Cs[0], list): - raise ValueError("Deprecated feature in POT 0.9.4: structures Cs[i] are lists and should be arrays from a supported backend (e.g numpy).") + raise ValueError( + "Deprecated feature in POT 0.9.4: structures Cs[i] are lists and should be arrays from a supported backend (e.g numpy)." + ) arr = [*Cs] if ps is not None: if isinstance(ps[0], list): - raise ValueError("Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy).") + raise ValueError( + "Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy)." + ) arr += [*ps] else: @@ -755,7 +880,7 @@ def entropic_gromov_barycenters( S = len(Cs) if lambdas is None: - lambdas = [1. 
/ S] * S + lambdas = [1.0 / S] * S # Initialization of C : random SPD matrix (if not provided by user) if init_C is None: @@ -773,7 +898,7 @@ def entropic_gromov_barycenters( if warmstartT: T = [None] * S - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": inner_log = False else: inner_log = True @@ -781,58 +906,88 @@ def entropic_gromov_barycenters( if log: log_ = {} - log_['err'] = [] - if stop_criterion == 'loss': - log_['loss'] = [] + log_["err"] = [] + if stop_criterion == "loss": + log_["loss"] = [] while (err > tol) and (cpt < max_iter): - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": Cprev = C else: prev_loss = curr_loss # get transport plans if warmstartT: - res = [entropic_gromov_wasserstein( - C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, T[s], - max_iter, 1e-4, verbose=verbose, log=inner_log, **kwargs) for s in range(S)] + res = [ + entropic_gromov_wasserstein( + C, + Cs[s], + p, + ps[s], + loss_fun, + epsilon, + symmetric, + T[s], + max_iter, + 1e-4, + verbose=verbose, + log=inner_log, + **kwargs, + ) + for s in range(S) + ] else: - res = [entropic_gromov_wasserstein( - C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, None, - max_iter, 1e-4, verbose=verbose, log=inner_log, **kwargs) for s in range(S)] - if stop_criterion == 'barycenter': + res = [ + entropic_gromov_wasserstein( + C, + Cs[s], + p, + ps[s], + loss_fun, + epsilon, + symmetric, + None, + max_iter, + 1e-4, + verbose=verbose, + log=inner_log, + **kwargs, + ) + for s in range(S) + ] + if stop_criterion == "barycenter": T = res else: T = [output[0] for output in res] - curr_loss = np.sum([output[1]['gw_dist'] for output in res]) + curr_loss = np.sum([output[1]["gw_dist"] for output in res]) # update barycenters C = update_barycenter_structure( - T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx) + T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx + ) # update convergence criterion - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": err = nx.norm(C - Cprev) if log: - log_['err'].append(err) + log_["err"].append(err) else: - err = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. 
else np.nan + err = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0.0 else np.nan if log: - log_['loss'].append(curr_loss) - log_['err'].append(err) + log_["loss"].append(curr_loss) + log_["err"].append(err) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err)) cpt += 1 if log: - log_['T'] = T - log_['p'] = p + log_["T"] = T + log_["p"] = p return C, log_ else: @@ -840,9 +995,24 @@ def entropic_gromov_barycenters( def entropic_fused_gromov_wasserstein( - M, C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, - symmetric=None, alpha=0.5, G0=None, max_iter=1000, tol=1e-9, - solver='PGD', warmstart=False, verbose=False, log=False, **kwargs): + M, + C1, + C2, + p=None, + q=None, + loss_fun="square_loss", + epsilon=0.1, + symmetric=None, + alpha=0.5, + G0=None, + max_iter=1000, + tol=1e-9, + solver="PGD", + warmstart=False, + verbose=False, + log=False, + **kwargs, +): r""" Returns the Fused Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{Y_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{Y_2}, \mathbf{q})` with pairwise distance matrix :math:`\mathbf{M}` between node feature matrices :math:`\mathbf{Y_1}` and :math:`\mathbf{Y_2}`, @@ -964,11 +1134,13 @@ def entropic_fused_gromov_wasserstein( application on graphs", International Conference on Machine Learning (ICML). 2019. """ - if solver not in ['PGD', 'PPA']: + if solver not in ["PGD", "PPA"]: raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver) - if loss_fun not in ('square_loss', 'kl_loss'): - raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") + if loss_fun not in ("square_loss", "kl_loss"): + raise ValueError( + f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}." 
+ ) M, C1, C2 = list_to_array(M, C1, C2) arr = [M, C1, C2] @@ -992,7 +1164,9 @@ def entropic_fused_gromov_wasserstein( T = G0 constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx) if symmetric is None: - symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose( + C2, C2.T, atol=1e-10 + ) if not symmetric: constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx) cpt = 0 @@ -1005,28 +1179,38 @@ def entropic_fused_gromov_wasserstein( nu = nx.zeros(N2, type_as=C2) - np.log(N2) if log: - log = {'err': []} - - while (err > tol and cpt < max_iter): + log = {"err": []} + while err > tol and cpt < max_iter: Tprev = T # compute the gradient if symmetric: tens = alpha * gwggrad(constC, hC1, hC2, T, nx) + (1 - alpha) * M else: - tens = (alpha * 0.5) * (gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx)) + (1 - alpha) * M + tens = (alpha * 0.5) * ( + gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx) + ) + (1 - alpha) * M - if solver == 'PPA': + if solver == "PPA": tens = tens - epsilon * nx.log(T) if warmstart: - T, loginn = sinkhorn(p, q, tens, epsilon, method='sinkhorn', log=True, warmstart=(mu, nu), **kwargs) - mu = epsilon * nx.log(loginn['u']) - nu = epsilon * nx.log(loginn['v']) + T, loginn = sinkhorn( + p, + q, + tens, + epsilon, + method="sinkhorn", + log=True, + warmstart=(mu, nu), + **kwargs, + ) + mu = epsilon * nx.log(loginn["u"]) + nu = epsilon * nx.log(loginn["v"]) else: - T = sinkhorn(p, q, tens, epsilon, method='sinkhorn', **kwargs) + T = sinkhorn(p, q, tens, epsilon, method="sinkhorn", **kwargs) if cpt % 10 == 0: # we can speed up the process by checking for the error only all @@ -1034,30 +1218,48 @@ def entropic_fused_gromov_wasserstein( err = nx.norm(T - Tprev) if log: - log['err'].append(err) + log["err"].append(err) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err)) cpt += 1 if abs(nx.sum(T) - 1) > 1e-5: - warnings.warn("Solver failed to produce a transport plan. You might " - "want to increase the regularization parameter `epsilon`.") + warnings.warn( + "Solver failed to produce a transport plan. You might " + "want to increase the regularization parameter `epsilon`." 
+ ) if log: - log['fgw_dist'] = (1 - alpha) * nx.sum(M * T) + alpha * gwloss(constC, hC1, hC2, T, nx) + log["fgw_dist"] = (1 - alpha) * nx.sum(M * T) + alpha * gwloss( + constC, hC1, hC2, T, nx + ) return T, log else: return T def entropic_fused_gromov_wasserstein2( - M, C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, - symmetric=None, alpha=0.5, G0=None, max_iter=1000, tol=1e-9, - solver='PGD', warmstart=False, verbose=False, log=False, **kwargs): + M, + C1, + C2, + p=None, + q=None, + loss_fun="square_loss", + epsilon=0.1, + symmetric=None, + alpha=0.5, + G0=None, + max_iter=1000, + tol=1e-9, + solver="PGD", + warmstart=False, + verbose=False, + log=False, + **kwargs, +): r""" Returns the Fused Gromov-Wasserstein distance between :math:`(\mathbf{C_1}, \mathbf{Y_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{Y_2}, \mathbf{q})` with pairwise distance matrix :math:`\mathbf{M}` between node feature matrices :math:`\mathbf{Y_1}` and :math:`\mathbf{Y_2}`, @@ -1128,7 +1330,7 @@ def entropic_fused_gromov_wasserstein2( symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. - Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) G0: array-like, shape (ns,nt), optional @@ -1170,25 +1372,54 @@ def entropic_fused_gromov_wasserstein2( nx = get_backend(M, C1, C2) T, logv = entropic_fused_gromov_wasserstein( - M, C1, C2, p, q, loss_fun, epsilon, symmetric, alpha, G0, max_iter, - tol, solver, warmstart, verbose, log=True, **kwargs) - - logv['T'] = T + M, + C1, + C2, + p, + q, + loss_fun, + epsilon, + symmetric, + alpha, + G0, + max_iter, + tol, + solver, + warmstart, + verbose, + log=True, + **kwargs, + ) + + logv["T"] = T lin_term = nx.sum(T * M) - logv['quad_loss'] = (logv['fgw_dist'] - (1 - alpha) * lin_term) - logv['lin_loss'] = lin_term * (1 - alpha) + logv["quad_loss"] = logv["fgw_dist"] - (1 - alpha) * lin_term + logv["lin_loss"] = lin_term * (1 - alpha) if log: - return logv['fgw_dist'], logv + return logv["fgw_dist"], logv else: - return logv['fgw_dist'] + return logv["fgw_dist"] def BAPG_fused_gromov_wasserstein( - M, C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, - symmetric=None, alpha=0.5, G0=None, max_iter=1000, tol=1e-9, - marginal_loss=False, verbose=False, log=False): + M, + C1, + C2, + p=None, + q=None, + loss_fun="square_loss", + epsilon=0.1, + symmetric=None, + alpha=0.5, + G0=None, + max_iter=1000, + tol=1e-9, + marginal_loss=False, + verbose=False, + log=False, +): r""" Returns the Fused Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{Y_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{Y_2}, \mathbf{q})` with pairwise distance matrix :math:`\mathbf{M}` between node feature matrices :math:`\mathbf{Y_1}` and :math:`\mathbf{Y_2}`, @@ -1292,8 +1523,10 @@ def BAPG_fused_gromov_wasserstein( "Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications". In Thirty-seventh Conference on Neural Information Processing Systems. """ - if loss_fun not in ('square_loss', 'kl_loss'): - raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") + if loss_fun not in ("square_loss", "kl_loss"): + raise ValueError( + f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}." 
+ ) M, C1, C2 = list_to_array(M, C1, C2) arr = [M, C1, C2] @@ -1317,46 +1550,55 @@ def BAPG_fused_gromov_wasserstein( T = G0 constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx) if symmetric is None: - symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose( + C2, C2.T, atol=1e-10 + ) if not symmetric: constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx) # Define gradients if marginal_loss: if symmetric: + def df(T): return alpha * gwggrad(constC, hC1, hC2, T, nx) + (1 - alpha) * M else: + def df(T): - return (alpha * 0.5) * (gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx)) + (1 - alpha) * M + return (alpha * 0.5) * ( + gwggrad(constC, hC1, hC2, T, nx) + + gwggrad(constCt, hC1t, hC2t, T, nx) + ) + (1 - alpha) * M else: if symmetric: + def df(T): - A = - nx.dot(nx.dot(hC1, T), hC2.T) + A = -nx.dot(nx.dot(hC1, T), hC2.T) return 2 * alpha * A + (1 - alpha) * M else: + def df(T): - A = - nx.dot(nx.dot(hC1, T), hC2t) - At = - nx.dot(nx.dot(hC1t, T), hC2) + A = -nx.dot(nx.dot(hC1, T), hC2t) + At = -nx.dot(nx.dot(hC1t, T), hC2) return alpha * (A + At) + (1 - alpha) * M + cpt = 0 err = 1e15 if log: - log = {'err': []} - - while (err > tol and cpt < max_iter): + log = {"err": []} + while err > tol and cpt < max_iter: Tprev = T # rows update - T = T * nx.exp(- df(T) / epsilon) + T = T * nx.exp(-df(T) / epsilon) row_scaling = p / nx.sum(T, 1) T = nx.reshape(row_scaling, (-1, 1)) * T # columns update - T = T * nx.exp(- df(T) / epsilon) + T = T * nx.exp(-df(T) / epsilon) column_scaling = q / nx.sum(T, 0) T = nx.reshape(column_scaling, (1, -1)) * T @@ -1366,25 +1608,28 @@ def df(T): err = nx.norm(T - Tprev) if log: - log['err'].append(err) + log["err"].append(err) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err)) cpt += 1 if nx.any(nx.isnan(T)): - warnings.warn("Solver failed to produce a transport plan. You might " - "want to increase the regularization parameter `epsilon`.", - UserWarning) + warnings.warn( + "Solver failed to produce a transport plan. 
You might " + "want to increase the regularization parameter `epsilon`.", + UserWarning, + ) if log: - log['fgw_dist'] = (1 - alpha) * nx.sum(M * T) + alpha * gwloss(constC, hC1, hC2, T, nx) + log["fgw_dist"] = (1 - alpha) * nx.sum(M * T) + alpha * gwloss( + constC, hC1, hC2, T, nx + ) if not marginal_loss: - log['loss'] = log['fgw_dist'] - alpha * nx.sum(constC * T) + log["loss"] = log["fgw_dist"] - alpha * nx.sum(constC * T) return T, log else: @@ -1392,9 +1637,22 @@ def df(T): def BAPG_fused_gromov_wasserstein2( - M, C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, - symmetric=None, alpha=0.5, G0=None, max_iter=1000, tol=1e-9, - marginal_loss=False, verbose=False, log=False): + M, + C1, + C2, + p=None, + q=None, + loss_fun="square_loss", + epsilon=0.1, + symmetric=None, + alpha=0.5, + G0=None, + max_iter=1000, + tol=1e-9, + marginal_loss=False, + verbose=False, + log=False, +): r""" Returns the Fused Gromov-Wasserstein loss between :math:`(\mathbf{C_1}, \mathbf{Y_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{Y_2}, \mathbf{q})` with pairwise distance matrix :math:`\mathbf{M}` between node feature matrices :math:`\mathbf{Y_1}` and :math:`\mathbf{Y_2}`, @@ -1500,27 +1758,59 @@ def BAPG_fused_gromov_wasserstein2( nx = get_backend(M, C1, C2) T, logv = BAPG_fused_gromov_wasserstein( - M, C1, C2, p, q, loss_fun, epsilon, symmetric, alpha, G0, max_iter, - tol, marginal_loss, verbose, log=True) - - logv['T'] = T + M, + C1, + C2, + p, + q, + loss_fun, + epsilon, + symmetric, + alpha, + G0, + max_iter, + tol, + marginal_loss, + verbose, + log=True, + ) + + logv["T"] = T lin_term = nx.sum(T * M) - logv['quad_loss'] = (logv['fgw_dist'] - (1 - alpha) * lin_term) - logv['lin_loss'] = lin_term * (1 - alpha) + logv["quad_loss"] = logv["fgw_dist"] - (1 - alpha) * lin_term + logv["lin_loss"] = lin_term * (1 - alpha) if log: - return logv['fgw_dist'], logv + return logv["fgw_dist"], logv else: - return logv['fgw_dist'] + return logv["fgw_dist"] def entropic_fused_gromov_barycenters( - N, Ys, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss', - epsilon=0.1, symmetric=True, alpha=0.5, max_iter=1000, tol=1e-9, - stop_criterion='barycenter', warmstartT=False, verbose=False, - log=False, init_C=None, init_Y=None, fixed_structure=False, - fixed_features=False, random_state=None, **kwargs): + N, + Ys, + Cs, + ps=None, + p=None, + lambdas=None, + loss_fun="square_loss", + epsilon=0.1, + symmetric=True, + alpha=0.5, + max_iter=1000, + tol=1e-9, + stop_criterion="barycenter", + warmstartT=False, + verbose=False, + log=False, + init_C=None, + init_Y=None, + fixed_structure=False, + fixed_features=False, + random_state=None, + **kwargs, +): r""" Returns the Fused Gromov-Wasserstein barycenters of `S` measurable networks with node features :math:`(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s)_{1 \leq s \leq S}` estimated using Fused Gromov-Wasserstein transports from Sinkhorn projections. @@ -1617,19 +1907,27 @@ def entropic_fused_gromov_barycenters( "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ - if loss_fun not in ('square_loss', 'kl_loss'): - raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") + if loss_fun not in ("square_loss", "kl_loss"): + raise ValueError( + f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}." + ) - if stop_criterion not in ['barycenter', 'loss']: - raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. 
Use one of: {'barycenter', 'loss'}.") + if stop_criterion not in ["barycenter", "loss"]: + raise ValueError( + f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}." + ) if isinstance(Cs[0], list) or isinstance(Ys[0], list): - raise ValueError("Deprecated feature in POT 0.9.4: structures Cs[i] and/or features Ys[i] are lists and should be arrays from a supported backend (e.g numpy).") + raise ValueError( + "Deprecated feature in POT 0.9.4: structures Cs[i] and/or features Ys[i] are lists and should be arrays from a supported backend (e.g numpy)." + ) arr = [*Cs, *Ys] if ps is not None: if isinstance(ps[0], list): - raise ValueError("Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy).") + raise ValueError( + "Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy)." + ) arr += [*ps] else: @@ -1642,15 +1940,14 @@ def entropic_fused_gromov_barycenters( nx = get_backend(*arr) S = len(Cs) if lambdas is None: - lambdas = [1. / S] * S + lambdas = [1.0 / S] * S d = Ys[0].shape[1] # dimension on the node features # Initialization of C : random euclidean distance matrix (if not provided by user) if fixed_structure: if init_C is None: - raise UndefinedParameter( - 'If C is fixed it must be provided in init_C') + raise UndefinedParameter("If C is fixed it must be provided in init_C") else: C = init_C else: @@ -1665,8 +1962,7 @@ def entropic_fused_gromov_barycenters( # Initialization of Y if fixed_features: if init_Y is None: - raise UndefinedParameter( - 'If Y is fixed it must be provided in init_Y') + raise UndefinedParameter("If Y is fixed it must be provided in init_Y") else: Y = init_Y else: @@ -1681,7 +1977,7 @@ def entropic_fused_gromov_barycenters( if warmstartT: T = [None] * S - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": inner_log = False else: @@ -1690,17 +1986,16 @@ def entropic_fused_gromov_barycenters( if log: log_ = {} - if stop_criterion == 'barycenter': - log_['err_feature'] = [] - log_['err_structure'] = [] - log_['Ts_iter'] = [] + if stop_criterion == "barycenter": + log_["err_feature"] = [] + log_["err_structure"] = [] + log_["Ts_iter"] = [] else: - log_['loss'] = [] - log_['err_rel_loss'] = [] + log_["loss"] = [] + log_["err_rel_loss"] = [] for cpt in range(max_iter): # break if specified errors are below tol. 
- - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": Cprev = C Yprev = Y else: @@ -1708,72 +2003,108 @@ def entropic_fused_gromov_barycenters( # get transport plans if warmstartT: - res = [entropic_fused_gromov_wasserstein( - Ms[s], C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, alpha, - T[s], max_iter, 1e-4, verbose=verbose, log=inner_log, **kwargs) for s in range(S)] + res = [ + entropic_fused_gromov_wasserstein( + Ms[s], + C, + Cs[s], + p, + ps[s], + loss_fun, + epsilon, + symmetric, + alpha, + T[s], + max_iter, + 1e-4, + verbose=verbose, + log=inner_log, + **kwargs, + ) + for s in range(S) + ] else: - res = [entropic_fused_gromov_wasserstein( - Ms[s], C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, alpha, - None, max_iter, 1e-4, verbose=verbose, log=inner_log, **kwargs) for s in range(S)] - - if stop_criterion == 'barycenter': + res = [ + entropic_fused_gromov_wasserstein( + Ms[s], + C, + Cs[s], + p, + ps[s], + loss_fun, + epsilon, + symmetric, + alpha, + None, + max_iter, + 1e-4, + verbose=verbose, + log=inner_log, + **kwargs, + ) + for s in range(S) + ] + + if stop_criterion == "barycenter": T = res else: T = [output[0] for output in res] - curr_loss = np.sum([output[1]['fgw_dist'] for output in res]) + curr_loss = np.sum([output[1]["fgw_dist"] for output in res]) # update barycenters if not fixed_features: X = update_barycenter_feature( - T, Ys, lambdas, p, target=False, check_zeros=False, nx=nx) + T, Ys, lambdas, p, target=False, check_zeros=False, nx=nx + ) Ms = [dist(X, Ys[s]) for s in range(len(Ys))] if not fixed_structure: C = update_barycenter_structure( - T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx) + T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx + ) # update convergence criterion - if stop_criterion == 'barycenter': - err_feature, err_structure = 0., 0. + if stop_criterion == "barycenter": + err_feature, err_structure = 0.0, 0.0 if not fixed_features: err_feature = nx.norm(Y - Yprev) if not fixed_structure: err_structure = nx.norm(C - Cprev) if log: - log_['err_feature'].append(err_feature) - log_['err_structure'].append(err_structure) - log_['Ts_iter'].append(T) + log_["err_feature"].append(err_feature) + log_["err_structure"].append(err_structure) + log_["Ts_iter"].append(T) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err_structure)) - print('{:5d}|{:8e}|'.format(cpt, err_feature)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err_structure)) + print("{:5d}|{:8e}|".format(cpt, err_feature)) if (err_feature <= tol) or (err_structure <= tol): break else: - err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. 
else np.nan + err_rel_loss = ( + abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0.0 else np.nan + ) if log: - log_['loss'].append(curr_loss) - log_['err_rel_loss'].append(err_rel_loss) + log_["loss"].append(curr_loss) + log_["err_rel_loss"].append(err_rel_loss) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err_rel_loss)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err_rel_loss)) if err_rel_loss <= tol: break if log: - log_['T'] = T - log_['p'] = p - log_['Ms'] = Ms + log_["T"] = T + log_["p"] = p + log_["Ms"] = Ms return Y, C, log_ else: diff --git a/ot/gromov/_dictionary.py b/ot/gromov/_dictionary.py index fbecb706a..6a2b30764 100644 --- a/ot/gromov/_dictionary.py +++ b/ot/gromov/_dictionary.py @@ -16,8 +16,28 @@ from ._gw import gromov_wasserstein, fused_gromov_wasserstein -def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True, - tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, random_state=None, **kwargs): +def gromov_wasserstein_dictionary_learning( + Cs, + D, + nt, + reg=0.0, + ps=None, + q=None, + epochs=20, + batch_size=32, + learning_rate=1.0, + Cdict_init=None, + projection="nonnegative_symmetric", + use_log=True, + tol_outer=10 ** (-5), + tol_inner=10 ** (-5), + max_iter_outer=20, + max_iter_inner=200, + use_adam_optimizer=True, + verbose=False, + random_state=None, + **kwargs, +): r""" Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s` @@ -50,7 +70,7 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e reg : float, optional Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. ps : list of S array-like, shape (ns,), optional - Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. + Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions. q : array-like, shape (nt,), optional Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. epochs: int, optional @@ -118,22 +138,24 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e if Cdict_init is None: # Initialize randomly structures of dictionary atoms based on samples dataset_means = [C.mean() for C in Cs] - Cdict = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) + Cdict = rng.normal( + loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt) + ) else: Cdict = nx.to_numpy(Cdict_init).copy() assert Cdict.shape == (D, nt, nt) - if 'symmetric' in projection: + if "symmetric" in projection: Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) symmetric = True else: symmetric = False - if 'nonnegative' in projection: - Cdict[Cdict < 0.] 
= 0 + if "nonnegative" in projection: + Cdict[Cdict < 0.0] = 0 if use_adam_optimizer: adam_moments = _initialize_adam_optimizer(Cdict) - log = {'loss_batches': [], 'loss_epochs': []} + log = {"loss_batches": [], "loss_epochs": []} const_q = q[:, None] * q[None, :] Cdict_best_state = Cdict.copy() loss_best_state = np.inf @@ -142,77 +164,115 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) for epoch in range(epochs): - cumulated_loss_over_epoch = 0. + cumulated_loss_over_epoch = 0.0 for _ in range(iter_by_epoch): # batch sampling batch = rng.choice(range(dataset_size), size=batch_size, replace=False) - cumulated_loss_over_batch = 0. + cumulated_loss_over_batch = 0.0 unmixings = np.zeros((batch_size, D)) Cs_embedded = np.zeros((batch_size, nt, nt)) Ts = [None] * batch_size for batch_idx, C_idx in enumerate(batch): # BCD solver for Gromov-Wasserstein linear unmixing used independently on each structure of the sampled batch - unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing( - Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner, - max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner, symmetric=symmetric, **kwargs + ( + unmixings[batch_idx], + Cs_embedded[batch_idx], + Ts[batch_idx], + current_loss, + ) = gromov_wasserstein_linear_unmixing( + Cs[C_idx], + Cdict, + reg=reg, + p=ps[C_idx], + q=q, + tol_outer=tol_outer, + tol_inner=tol_inner, + max_iter_outer=max_iter_outer, + max_iter_inner=max_iter_inner, + symmetric=symmetric, + **kwargs, ) cumulated_loss_over_batch += current_loss cumulated_loss_over_epoch += cumulated_loss_over_batch if use_log: - log['loss_batches'].append(cumulated_loss_over_batch) + log["loss_batches"].append(cumulated_loss_over_batch) # Stochastic projected gradient step over dictionary atoms grad_Cdict = np.zeros_like(Cdict) for batch_idx, C_idx in enumerate(batch): - shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) - grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] + shared_term_structures = Cs_embedded[batch_idx] * const_q - ( + Cs[C_idx].dot(Ts[batch_idx]) + ).T.dot(Ts[batch_idx]) + grad_Cdict += ( + unmixings[batch_idx][:, None, None] + * shared_term_structures[None, :, :] + ) grad_Cdict *= 2 / batch_size if use_adam_optimizer: - Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments) + Cdict, adam_moments = _adam_stochastic_updates( + Cdict, grad_Cdict, learning_rate, adam_moments + ) else: Cdict -= learning_rate * grad_Cdict - if 'symmetric' in projection: + if "symmetric" in projection: Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) - if 'nonnegative' in projection: - Cdict[Cdict < 0.] = 0. 
+ if "nonnegative" in projection: + Cdict[Cdict < 0.0] = 0.0 if use_log: - log['loss_epochs'].append(cumulated_loss_over_epoch) + log["loss_epochs"].append(cumulated_loss_over_epoch) if loss_best_state > cumulated_loss_over_epoch: loss_best_state = cumulated_loss_over_epoch Cdict_best_state = Cdict.copy() if verbose: - print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) + print( + "--- epoch =", + epoch, + " cumulated reconstruction error: ", + cumulated_loss_over_epoch, + ) return nx.from_numpy(Cdict_best_state), log def _initialize_adam_optimizer(variable): - # Initialization for our numpy implementation of adam optimizer atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor atoms_adam_count = 1 - return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count} + return {"mean": atoms_adam_m, "var": atoms_adam_v, "count": atoms_adam_count} -def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09): - - adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad - adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2) - unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count']) - unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count']) +def _adam_stochastic_updates( + variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09 +): + adam_moments["mean"] = beta_1 * adam_moments["mean"] + (1 - beta_1) * grad + adam_moments["var"] = beta_2 * adam_moments["var"] + (1 - beta_2) * (grad**2) + unbiased_m = adam_moments["mean"] / (1 - beta_1 ** adam_moments["count"]) + unbiased_v = adam_moments["var"] / (1 - beta_2 ** adam_moments["count"]) variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps) - adam_moments['count'] += 1 + adam_moments["count"] += 1 return variable, adam_moments -def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, symmetric=None, **kwargs): +def gromov_wasserstein_linear_unmixing( + C, + Cdict, + reg=0.0, + p=None, + q=None, + tol_outer=10 ** (-5), + tol_inner=10 ** (-5), + max_iter_outer=20, + max_iter_inner=200, + symmetric=None, + **kwargs, +): r""" Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`. @@ -300,28 +360,68 @@ def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_out previous_loss = current_loss # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w T, log = gromov_wasserstein( - C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, - max_iter=max_iter_inner, tol_rel=tol_inner, tol_abs=0., log=True, armijo=False, symmetric=symmetric, **kwargs) - current_loss = log['gw_dist'] + C1=C, + C2=Cembedded, + p=p, + q=q, + loss_fun="square_loss", + G0=T, + max_iter=max_iter_inner, + tol_rel=tol_inner, + tol_abs=0.0, + log=True, + armijo=False, + symmetric=symmetric, + **kwargs, + ) + current_loss = log["gw_dist"] if reg != 0: current_loss -= reg * np.sum(w**2) # 2. 
Solve linear unmixing problem over w with a fixed transport plan T w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing( - C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T, - starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs + C=C, + Cdict=Cdict, + Cembedded=Cembedded, + w=w, + const_q=const_q, + T=T, + starting_loss=current_loss, + reg=reg, + tol=tol_inner, + max_iter=max_iter_inner, + **kwargs, ) if previous_loss != 0: - convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + convergence_criterion = abs(previous_loss - current_loss) / abs( + previous_loss + ) else: # handle numerical issues around 0 - convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) + convergence_criterion = abs(previous_loss - current_loss) / 10 ** (-15) outer_count += 1 - return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss) - - -def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs): + return ( + nx.from_numpy(w), + nx.from_numpy(Cembedded), + nx.from_numpy(T), + nx.from_numpy(current_loss), + ) + + +def _cg_gromov_wasserstein_unmixing( + C, + Cdict, + Cembedded, + w, + const_q, + T, + starting_loss, + reg=0.0, + tol=10 ** (-5), + max_iter=200, + **kwargs, +): r""" Returns for a fixed admissible transport plan, the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})` @@ -380,10 +480,13 @@ def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting const_TCT = np.transpose(C.dot(T)).dot(T) while (convergence_criterion > tol) and (count < max_iter): - previous_loss = current_loss # 1) Compute gradient at current point w - grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) + grad_w = 2 * np.sum( + Cdict + * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), + axis=(1, 2), + ) grad_w -= 2 * reg * w # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w @@ -392,7 +495,9 @@ def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting x /= np.sum(x) # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c - gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg) + gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing( + w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg + ) # 4) Updates: w <-- (1-gamma)*w + gamma*x w += gamma * (x - w) @@ -400,15 +505,19 @@ def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting current_loss += a * (gamma**2) + b * gamma if previous_loss != 0: # not that the loss can be negative if reg >0 - convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + convergence_criterion = abs(previous_loss - current_loss) / abs( + previous_loss + ) else: # handle numerical issues around 0 - convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) + convergence_criterion = abs(previous_loss - current_loss) / 10 ** (-15) count += 1 return w, Cembedded, current_loss -def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs): +def _linesearch_gromov_wasserstein_unmixing( + w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, 
**kwargs +): r""" Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing .. math:: @@ -459,11 +568,11 @@ def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, cons a = trace_diffx - trace_diffw b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) if reg != 0: - a -= reg * np.sum((x - w)**2) + a -= reg * np.sum((x - w) ** 2) b -= 2 * reg * np.sum(w * (x - w)) if a > 0: - gamma = min(1, max(0, - b / (2 * a))) + gamma = min(1, max(0, -b / (2 * a))) elif a + b < 0: gamma = 1 else: @@ -472,10 +581,32 @@ def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, cons return gamma, a, b, Cembedded_diff -def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1., - Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False, - tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, - random_state=None, **kwargs): +def fused_gromov_wasserstein_dictionary_learning( + Cs, + Ys, + D, + nt, + alpha, + reg=0.0, + ps=None, + q=None, + epochs=20, + batch_size=32, + learning_rate_C=1.0, + learning_rate_Y=1.0, + Cdict_init=None, + Ydict_init=None, + projection="nonnegative_symmetric", + use_log=False, + tol_outer=10 ** (-5), + tol_inner=10 ** (-5), + max_iter_outer=20, + max_iter_inner=200, + use_adam_optimizer=True, + verbose=False, + random_state=None, + **kwargs, +): r""" Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s` @@ -517,7 +648,7 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p reg : float, optional Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. ps : list of S array-like, shape (ns,), optional - Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. + Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions. q : array-like, shape (nt,), optional Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. 
epochs: int, optional @@ -597,31 +728,37 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p if Cdict_init is None: # Initialize randomly structures of dictionary atoms based on samples dataset_means = [C.mean() for C in Cs] - Cdict = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) + Cdict = rng.normal( + loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt) + ) else: Cdict = nx.to_numpy(Cdict_init).copy() assert Cdict.shape == (D, nt, nt) if Ydict_init is None: # Initialize randomly features of dictionary atoms based on samples distribution by feature component dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys]) - Ydict = rng.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d)) + Ydict = rng.normal( + loc=dataset_feature_means.mean(axis=0), + scale=dataset_feature_means.std(axis=0), + size=(D, nt, d), + ) else: Ydict = nx.to_numpy(Ydict_init).copy() assert Ydict.shape == (D, nt, d) - if 'symmetric' in projection: + if "symmetric" in projection: Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) symmetric = True else: symmetric = False - if 'nonnegative' in projection: - Cdict[Cdict < 0.] = 0. + if "nonnegative" in projection: + Cdict[Cdict < 0.0] = 0.0 if use_adam_optimizer: adam_moments_C = _initialize_adam_optimizer(Cdict) adam_moments_Y = _initialize_adam_optimizer(Ydict) - log = {'loss_batches': [], 'loss_epochs': []} + log = {"loss_batches": [], "loss_epochs": []} const_q = q[:, None] * q[None, :] diag_q = np.diag(q) Cdict_best_state = Cdict.copy() @@ -632,13 +769,12 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) for epoch in range(epochs): - cumulated_loss_over_epoch = 0. + cumulated_loss_over_epoch = 0.0 for _ in range(iter_by_epoch): - # Batch iterations batch = rng.choice(range(dataset_size), size=batch_size, replace=False) - cumulated_loss_over_batch = 0. 
+ cumulated_loss_over_batch = 0.0 unmixings = np.zeros((batch_size, D)) Cs_embedded = np.zeros((batch_size, nt, nt)) Ys_embedded = np.zeros((batch_size, nt, d)) @@ -646,53 +782,106 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p for batch_idx, C_idx in enumerate(batch): # BCD solver for Gromov-Wasserstein linear unmixing used independently on each structure of the sampled batch - unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing( - Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q, - tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner, symmetric=symmetric, **kwargs + ( + unmixings[batch_idx], + Cs_embedded[batch_idx], + Ys_embedded[batch_idx], + Ts[batch_idx], + current_loss, + ) = fused_gromov_wasserstein_linear_unmixing( + Cs[C_idx], + Ys[C_idx], + Cdict, + Ydict, + alpha, + reg=reg, + p=ps[C_idx], + q=q, + tol_outer=tol_outer, + tol_inner=tol_inner, + max_iter_outer=max_iter_outer, + max_iter_inner=max_iter_inner, + symmetric=symmetric, + **kwargs, ) cumulated_loss_over_batch += current_loss cumulated_loss_over_epoch += cumulated_loss_over_batch if use_log: - log['loss_batches'].append(cumulated_loss_over_batch) + log["loss_batches"].append(cumulated_loss_over_batch) # Stochastic projected gradient step over dictionary atoms grad_Cdict = np.zeros_like(Cdict) grad_Ydict = np.zeros_like(Ydict) for batch_idx, C_idx in enumerate(batch): - shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) - shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx]) - grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] - grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :] + shared_term_structures = Cs_embedded[batch_idx] * const_q - ( + Cs[C_idx].dot(Ts[batch_idx]) + ).T.dot(Ts[batch_idx]) + shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[ + batch_idx + ].T.dot(Ys[C_idx]) + grad_Cdict += ( + alpha + * unmixings[batch_idx][:, None, None] + * shared_term_structures[None, :, :] + ) + grad_Ydict += ( + (1 - alpha) + * unmixings[batch_idx][:, None, None] + * shared_term_features[None, :, :] + ) grad_Cdict *= 2 / batch_size grad_Ydict *= 2 / batch_size if use_adam_optimizer: - Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C) - Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y) + Cdict, adam_moments_C = _adam_stochastic_updates( + Cdict, grad_Cdict, learning_rate_C, adam_moments_C + ) + Ydict, adam_moments_Y = _adam_stochastic_updates( + Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y + ) else: Cdict -= learning_rate_C * grad_Cdict Ydict -= learning_rate_Y * grad_Ydict - if 'symmetric' in projection: + if "symmetric" in projection: Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) - if 'nonnegative' in projection: - Cdict[Cdict < 0.] = 0. 
+ if "nonnegative" in projection: + Cdict[Cdict < 0.0] = 0.0 if use_log: - log['loss_epochs'].append(cumulated_loss_over_epoch) + log["loss_epochs"].append(cumulated_loss_over_epoch) if loss_best_state > cumulated_loss_over_epoch: loss_best_state = cumulated_loss_over_epoch Cdict_best_state = Cdict.copy() Ydict_best_state = Ydict.copy() if verbose: - print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) + print( + "--- epoch: ", + epoch, + " cumulated reconstruction error: ", + cumulated_loss_over_epoch, + ) return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log -def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), - tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, symmetric=True, **kwargs): +def fused_gromov_wasserstein_linear_unmixing( + C, + Y, + Cdict, + Ydict, + alpha, + reg=0.0, + p=None, + q=None, + tol_outer=10 ** (-5), + tol_inner=10 ** (-5), + max_iter_outer=20, + max_iter_inner=200, + symmetric=True, + **kwargs, +): r""" Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` @@ -796,35 +985,97 @@ def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., convergence_criterion = np.inf current_loss = 10**15 outer_count = 0 - Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix + Ys_constM = (Y**2).dot( + np.ones((d, nt)) + ) # constant in computing euclidean pairwise feature matrix while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): previous_loss = current_loss # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T) - M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features + M = ( + Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) + ) # euclidean distance matrix between features T, log = fused_gromov_wasserstein( - M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, - max_iter=max_iter_inner, tol_rel=tol_inner, tol_abs=0., armijo=False, G0=T, log=True, symmetric=symmetric, **kwargs) - current_loss = log['fgw_dist'] + M, + C, + Cembedded, + p, + q, + loss_fun="square_loss", + alpha=alpha, + max_iter=max_iter_inner, + tol_rel=tol_inner, + tol_abs=0.0, + armijo=False, + G0=T, + log=True, + symmetric=symmetric, + **kwargs, + ) + current_loss = log["fgw_dist"] if reg != 0: current_loss -= reg * np.sum(w**2) # 2. 
Solve linear unmixing problem over w with a fixed transport plan T - w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, - T, p, q, const_q, diag_q, current_loss, alpha, reg, - tol=tol_inner, max_iter=max_iter_inner, **kwargs) + w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing( + C, + Y, + Cdict, + Ydict, + Cembedded, + Yembedded, + w, + T, + p, + q, + const_q, + diag_q, + current_loss, + alpha, + reg, + tol=tol_inner, + max_iter=max_iter_inner, + **kwargs, + ) if previous_loss != 0: - convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + convergence_criterion = abs(previous_loss - current_loss) / abs( + previous_loss + ) else: - convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) + convergence_criterion = abs(previous_loss - current_loss) / 10 ** (-12) outer_count += 1 - return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss) - - -def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs): + return ( + nx.from_numpy(w), + nx.from_numpy(Cembedded), + nx.from_numpy(Yembedded), + nx.from_numpy(T), + nx.from_numpy(current_loss), + ) + + +def _cg_fused_gromov_wasserstein_unmixing( + C, + Y, + Cdict, + Ydict, + Cembedded, + Yembedded, + w, + T, + p, + q, + const_q, + diag_q, + starting_loss, + alpha, + reg, + tol=10 ** (-6), + max_iter=200, + **kwargs, +): r""" Returns for a fixed admissible transport plan, the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d*\mathbf{Y_{dict}[d]}, \mathbf{q})` @@ -901,9 +1152,16 @@ def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedd # 1) Compute gradient at current point w # structure - grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) + grad_w = alpha * np.sum( + Cdict + * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), + axis=(1, 2), + ) # feature - grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2)) + grad_w += (1 - alpha) * np.sum( + Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), + axis=(1, 2), + ) grad_w -= reg * w grad_w *= 2 @@ -913,7 +1171,24 @@ def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedd x /= np.sum(x) # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c - gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg) + gamma, a, b, Cembedded_diff, Yembedded_diff = ( + _linesearch_fused_gromov_wasserstein_unmixing( + w, + grad_w, + x, + Y, + Cdict, + Ydict, + Cembedded, + Yembedded, + T, + const_q, + const_TCT, + ones_ns_d, + alpha, + reg, + ) + ) # 4) Updates: w <-- (1-gamma)*w + gamma*x w += gamma * (x - w) @@ -922,15 +1197,33 @@ def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedd current_loss += a * (gamma**2) + b * gamma if previous_loss != 0: - convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + convergence_criterion = abs(previous_loss 
- current_loss) / abs( + previous_loss + ) else: - convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) + convergence_criterion = abs(previous_loss - current_loss) / 10 ** (-12) count += 1 return w, Cembedded, Yembedded, current_loss -def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs): +def _linesearch_fused_gromov_wasserstein_unmixing( + w, + grad_w, + x, + Y, + Cdict, + Ydict, + Cembedded, + Yembedded, + T, + const_q, + const_TCT, + ones_ns_d, + alpha, + reg, + **kwargs, +): r""" Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing .. math:: @@ -1002,12 +1295,14 @@ def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Yembedded_diff = Yembedded_x - Yembedded # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T) - b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T))) + b_w = 2 * np.sum( + T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T)) + ) a = alpha * a_gw + (1 - alpha) * a_w b = alpha * b_gw + (1 - alpha) * b_w if reg != 0: - a -= reg * np.sum((x - w)**2) + a -= reg * np.sum((x - w) ** 2) b -= 2 * reg * np.sum(w * (x - w)) if a > 0: gamma = min(1, max(0, -b / (2 * a))) diff --git a/ot/gromov/_estimators.py b/ot/gromov/_estimators.py index 7e12ef930..14871bfe3 100644 --- a/ot/gromov/_estimators.py +++ b/ot/gromov/_estimators.py @@ -17,8 +17,18 @@ from ..backend import get_backend -def GW_distance_estimation(C1, C2, p, q, loss_fun, T, - nb_samples_p=None, nb_samples_q=None, std=True, random_state=None): +def GW_distance_estimation( + C1, + C2, + p, + q, + loss_fun, + T, + nb_samples_p=None, + nb_samples_q=None, + std=True, + random_state=None, +): r""" Returns an approximation of the Gromov-Wasserstein loss between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` with a fixed transport plan :math:`\mathbf{T}`. To recover an approximation of the Gromov-Wasserstein distance as defined in [13] compute :math:`d_{GW} = \frac{1}{2} \sqrt{\mathbf{GW}}`. 
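For reference while reviewing this hunk, a minimal usage sketch of `GW_distance_estimation` (illustrative only, not part of the patch; the data and sample sizes are made up). It uses the independent coupling as the fixed plan `T` and passes a pointwise squared loss as the callable `loss_fun`, which is the form the sampling code in `ot/gromov/_estimators.py` expects; with the default `std=True` the call returns both the estimate and its standard deviation, as in the internal call further down.

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(20, 2), rng.randn(30, 2)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)   # intra-domain structure matrices
    p, q = ot.unif(20), ot.unif(30)             # uniform weights
    T = p[:, None] * q[None, :]                 # fixed plan: independent coupling

    def loss_fun(x, y):                         # pointwise loss applied to sampled entries
        return (x - y) ** 2

    gw_est, gw_std = ot.gromov.GW_distance_estimation(
        C1, C2, p, q, loss_fun, T, nb_samples_p=10, nb_samples_q=10, random_state=0
    )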
@@ -122,21 +132,24 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, len_q, size=nb_samples_q, p=nx.to_numpy(T_indexi / nx.sum(T_indexi)), - replace=True + replace=True, ) index_l[i] = generator.choice( len_q, size=nb_samples_q, p=nx.to_numpy(T_indexj / nx.sum(T_indexj)), - replace=True + replace=True, ) - list_value_sample = nx.stack([ - loss_fun( - C1[np.ix_(index_i, index_j)], - C2[np.ix_(index_k[:, n], index_l[:, n])] - ) for n in range(nb_samples_q) - ], axis=2) + list_value_sample = nx.stack( + [ + loss_fun( + C1[np.ix_(index_i, index_j)], C2[np.ix_(index_k[:, n], index_l[:, n])] + ) + for n in range(nb_samples_q) + ], + axis=2, + ) if std: std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5 @@ -145,8 +158,19 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, return nx.mean(list_value_sample) -def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, - alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None): +def pointwise_gromov_wasserstein( + C1, + C2, + p, + q, + loss_fun, + alpha=1, + max_iter=100, + threshold_plan=0, + log=False, + verbose=False, + random_state=None, +): r""" Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe. This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations. @@ -232,21 +256,23 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, ) if alpha == 1: - T = nx.tocsr( - emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) - ) + T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)) else: - new_T = nx.tocsr( - emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False) - ) + new_T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)) T = (1 - alpha) * T + alpha * new_T # To limit the number of non 0, the values below the threshold are set to 0. T = nx.eliminate_zeros(T, threshold=threshold_plan) if cpt % 10 == 0 or cpt == (max_iter - 1): gw_dist_estimated = GW_distance_estimation( - C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, std=False, random_state=generator + C1=C1, + C2=C2, + loss_fun=loss_fun, + p=p, + q=q, + T=T, + std=False, + random_state=generator, ) if gw_dist_estimated < best_gw_dist_estimated: @@ -255,22 +281,35 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated)) + print( + "{:5s}|{:12s}".format("It.", "Best gw estimated") + + "\n" + + "-" * 19 + ) + print("{:5d}|{:8e}|".format(cpt, best_gw_dist_estimated)) if log: log = {} log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( - C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=best_T, random_state=generator + C1=C1, C2=C2, loss_fun=loss_fun, p=p, q=q, T=best_T, random_state=generator ) return best_T, log return best_T -def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, - nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False, - random_state=None): +def sampled_gromov_wasserstein( + C1, + C2, + p, + q, + loss_fun, + nb_samples_grad=100, + epsilon=1, + max_iter=500, + log=False, + verbose=False, + random_state=None, +): r""" Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe. 
This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver. @@ -355,7 +394,9 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, continue_loop = 0 # The gradient of GW is more complex if the two matrices are not symmetric. - C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) + C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose( + C2, C2.T, rtol=1e-10, atol=1e-10 + ) for cpt in range(max_iter): index0 = generator.choice( @@ -364,21 +405,28 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, Lik = 0 for i, index0_i in enumerate(index0): index1 = generator.choice( - len_q, size=nb_samples_grad_q, + len_q, + size=nb_samples_grad_q, p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])), - replace=False + replace=False, ) # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. if (not C_are_symmetric) and generator.rand(1) > 0.5: - Lik += nx.mean(loss_fun( - C1[:, [index0[i]] * nb_samples_grad_q][:, None, :], - C2[:, index1][None, :, :] - ), axis=2) + Lik += nx.mean( + loss_fun( + C1[:, [index0[i]] * nb_samples_grad_q][:, None, :], + C2[:, index1][None, :, :], + ), + axis=2, + ) else: - Lik += nx.mean(loss_fun( - C1[[index0[i]] * nb_samples_grad_q, :][:, :, None], - C2[index1, :][:, None, :] - ), axis=0) + Lik += nx.mean( + loss_fun( + C1[[index0[i]] * nb_samples_grad_q, :][:, :, None], + C2[index1, :][:, None, :], + ), + axis=0, + ) max_Lik = nx.max(Lik) if max_Lik == 0: @@ -395,7 +443,7 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, try: new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon) except (RuntimeWarning, UserWarning): - print("Warning catched in Sinkhorn: Return last stable T") + print("Warning caught in Sinkhorn: Return last stable T") break else: new_T = emd(a=p, b=q, M=Lik) @@ -411,15 +459,16 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, if verbose and cpt % 10 == 0: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, change_T)) + print( + "{:5s}|{:12s}".format("It.", "||T_n - T_{n+1}||") + "\n" + "-" * 19 + ) + print("{:5d}|{:8e}|".format(cpt, change_T)) T = nx.copy(new_T) if log: log = {} log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation( - C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, random_state=generator + C1=C1, C2=C2, loss_fun=loss_fun, p=p, q=q, T=T, random_state=generator ) return T, log return T diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 806e691e1..1c8de7b20 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -24,8 +24,21 @@ from ._utils import update_barycenter_structure, update_barycenter_feature -def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, - max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): +def gromov_wasserstein( + C1, + C2, + p=None, + q=None, + loss_fun="square_loss", + symmetric=None, + log=False, + armijo=False, + G0=None, + max_iter=1e4, + tol_rel=1e-9, + tol_abs=1e-9, + **kwargs, +): r""" Returns the Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`. 
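A minimal call sketch for `gromov_wasserstein` as reformatted here (illustrative data, not part of the patch); normalizing both structure matrices is a common but optional preprocessing step:

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    C1 = ot.dist(rng.randn(15, 2))
    C2 = ot.dist(rng.randn(25, 3))
    C1, C2 = C1 / C1.max(), C2 / C2.max()
    p, q = ot.unif(15), ot.unif(25)

    T, log = ot.gromov.gromov_wasserstein(
        C1, C2, p, q, loss_fun="square_loss", log=True
    )
    print(log["gw_dist"])   # GW loss of the returned plan T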
@@ -140,7 +153,9 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) if symmetric is None: - symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose( + C2, C2.T, atol=1e-10 + ) if G0 is None: G0 = p[:, None] * q[None, :] @@ -158,20 +173,36 @@ def f(G): return gwloss(constC, hC1, hC2, G, np_) if symmetric: + def df(G): return gwggrad(constC, hC1, hC2, G, np_) else: constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, np_) def df(G): - return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) + return 0.5 * ( + gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_) + ) if armijo: + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) else: + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): - return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, symmetric=symmetric, **kwargs) + return solve_gromov_linesearch( + G, + deltaG, + cost_G, + hC1, + hC2, + M=0.0, + reg=1.0, + nx=np_, + symmetric=symmetric, + **kwargs, + ) if not nx.is_floating_point(C10): warnings.warn( @@ -179,21 +210,65 @@ def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): "casted accordingly, possibly resulting in a loss of precision. " "If this behaviour is unwanted, please make sure your input " "structure matrix consists of floating point elements.", - stacklevel=2 + stacklevel=2, ) if log: - res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) - log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) - log['u'] = nx.from_numpy(log['u'], type_as=C10) - log['v'] = nx.from_numpy(log['v'], type_as=C10) + res, log = cg( + p, + q, + 0.0, + 1.0, + f, + df, + G0, + line_search, + log=True, + numItermax=max_iter, + stopThr=tol_rel, + stopThr2=tol_abs, + **kwargs, + ) + log["gw_dist"] = nx.from_numpy(log["loss"][-1], type_as=C10) + log["u"] = nx.from_numpy(log["u"], type_as=C10) + log["v"] = nx.from_numpy(log["v"], type_as=C10) return nx.from_numpy(res, type_as=C10), log else: - return nx.from_numpy(cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10) + return nx.from_numpy( + cg( + p, + q, + 0.0, + 1.0, + f, + df, + G0, + line_search, + log=False, + numItermax=max_iter, + stopThr=tol_rel, + stopThr2=tol_abs, + **kwargs, + ), + type_as=C10, + ) -def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, - max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): +def gromov_wasserstein2( + C1, + C2, + p=None, + q=None, + loss_fun="square_loss", + symmetric=None, + log=False, + armijo=False, + G0=None, + max_iter=1e4, + tol_rel=1e-9, + tol_abs=1e-9, + **kwargs, +): r""" Returns the Gromov-Wasserstein loss :math:`\mathbf{GW}` between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`. To recover the Gromov-Wasserstein distance as defined in [13] compute :math:`d_{GW} = \frac{1}{2} \sqrt{\mathbf{GW}}`. 
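Since `gromov_wasserstein2` sets gradients on the returned value (the `nx.set_gradients` call in this hunk), it can act as a differentiable loss with an autodiff backend. A sketch with the torch backend, assuming torch is installed; the data is made up:

    import numpy as np
    import torch
    import ot

    rng = np.random.RandomState(0)
    C1 = torch.tensor(ot.dist(rng.randn(10, 2)), requires_grad=True)
    C2 = torch.tensor(ot.dist(rng.randn(12, 2)))
    p = torch.ones(10, dtype=torch.float64) / 10
    q = torch.ones(12, dtype=torch.float64) / 12

    gw = ot.gromov.gromov_wasserstein2(C1, C2, p, q, loss_fun="square_loss")
    gw.backward()            # gradient w.r.t. C1 follows the gC1 expression above
    print(C1.grad.shape)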
@@ -304,22 +379,43 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri q = unif(C2.shape[0], type_as=C1) T, log_gw = gromov_wasserstein( - C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0, - max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) - - log_gw['T'] = T - gw = log_gw['gw_dist'] - - if loss_fun == 'square_loss': + C1, + C2, + p, + q, + loss_fun, + symmetric, + log=True, + armijo=armijo, + G0=G0, + max_iter=max_iter, + tol_rel=tol_rel, + tol_abs=tol_abs, + **kwargs, + ) + + log_gw["T"] = T + gw = log_gw["gw_dist"] + + if loss_fun == "square_loss": gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) - elif loss_fun == 'kl_loss': - gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) - - gw = nx.set_gradients(gw, (p, q, C1, C2), - (log_gw['u'] - nx.mean(log_gw['u']), - log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2)) + elif loss_fun == "kl_loss": + gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot( + T, nx.dot(nx.log(C2 + 1e-15), T.T) + ) + gC2 = -nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + + gw = nx.set_gradients( + gw, + (p, q, C1, C2), + ( + log_gw["u"] - nx.mean(log_gw["u"]), + log_gw["v"] - nx.mean(log_gw["v"]), + gC1, + gC2, + ), + ) if log: return gw, log_gw @@ -327,8 +423,23 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri return gw -def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5, - armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): +def fused_gromov_wasserstein( + M, + C1, + C2, + p=None, + q=None, + loss_fun="square_loss", + symmetric=None, + alpha=0.5, + armijo=False, + G0=None, + log=False, + max_iter=1e4, + tol_rel=1e-9, + tol_abs=1e-9, + **kwargs, +): r""" Returns the Fused Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{Y_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{Y_2}, \mathbf{q})` with pairwise distance matrix :math:`\mathbf{M}` between node feature matrices :math:`\mathbf{Y_1}` and :math:`\mathbf{Y_2}` (see :ref:`[24] `). 
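A minimal call sketch for the fused solver (illustrative, not part of the patch): `M` is the feature cost between the two node sets, `C1`/`C2` are the structure matrices, and `alpha` balances structure against features:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    F1, F2 = rng.randn(10, 3), rng.randn(12, 3)   # node features
    C1 = ot.dist(rng.randn(10, 2))                # structure matrices
    C2 = ot.dist(rng.randn(12, 2))
    M = ot.dist(F1, F2)                           # feature cost matrix
    p, q = ot.unif(10), ot.unif(12)

    T, log = ot.gromov.fused_gromov_wasserstein(
        M, C1, C2, p, q, loss_fun="square_loss", alpha=0.5, log=True
    )
    print(log["fgw_dist"])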
@@ -448,7 +559,9 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss', alpha = nx.to_numpy(alpha0) if symmetric is None: - symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose( + C2, C2.T, atol=1e-10 + ) if G0 is None: G0 = p[:, None] * q[None, :] @@ -466,20 +579,36 @@ def f(G): return gwloss(constC, hC1, hC2, G, np_) if symmetric: + def df(G): return gwggrad(constC, hC1, hC2, G, np_) else: constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, np_) def df(G): - return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) + return 0.5 * ( + gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_) + ) if armijo: + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) else: + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): - return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, symmetric=symmetric, **kwargs) + return solve_gromov_linesearch( + G, + deltaG, + cost_G, + hC1, + hC2, + M=(1 - alpha) * M, + reg=alpha, + nx=np_, + symmetric=symmetric, + **kwargs, + ) if not nx.is_floating_point(M0): warnings.warn( @@ -487,20 +616,66 @@ def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): "casted accordingly, possibly resulting in a loss of precision. " "If this behaviour is unwanted, please make sure your input " "feature matrix consists of floating point elements.", - stacklevel=2 + stacklevel=2, ) if log: - res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) - log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=M0) - log['u'] = nx.from_numpy(log['u'], type_as=M0) - log['v'] = nx.from_numpy(log['v'], type_as=M0) + res, log = cg( + p, + q, + (1 - alpha) * M, + alpha, + f, + df, + G0, + line_search, + log=True, + numItermax=max_iter, + stopThr=tol_rel, + stopThr2=tol_abs, + **kwargs, + ) + log["fgw_dist"] = nx.from_numpy(log["loss"][-1], type_as=M0) + log["u"] = nx.from_numpy(log["u"], type_as=M0) + log["v"] = nx.from_numpy(log["v"], type_as=M0) return nx.from_numpy(res, type_as=M0), log else: - return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=M0) + return nx.from_numpy( + cg( + p, + q, + (1 - alpha) * M, + alpha, + f, + df, + G0, + line_search, + log=False, + numItermax=max_iter, + stopThr=tol_rel, + stopThr2=tol_abs, + **kwargs, + ), + type_as=M0, + ) -def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5, - armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): +def fused_gromov_wasserstein2( + M, + C1, + C2, + p=None, + q=None, + loss_fun="square_loss", + symmetric=None, + alpha=0.5, + armijo=False, + G0=None, + log=False, + max_iter=1e4, + tol_rel=1e-9, + tol_abs=1e-9, + **kwargs, +): r""" Returns the Fused Gromov-Wasserstein distance between :math:`(\mathbf{C_1}, \mathbf{Y_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{Y_2}, \mathbf{q})` with pairwise distance matrix :math:`\mathbf{M}` between node feature matrices :math:`\mathbf{Y_1}` and :math:`\mathbf{Y_2}` (see :ref:`[24] `). 
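A short sketch showing the value returned by `fused_gromov_wasserstein2` and the loss decomposition written to the log in this hunk, where `lin_loss + quad_loss` adds back up to `fgw_dist` (illustrative data only):

    import numpy as np
    import ot

    rng = np.random.RandomState(1)
    C1, C2 = ot.dist(rng.randn(8, 2)), ot.dist(rng.randn(9, 2))
    M = ot.dist(rng.randn(8, 3), rng.randn(9, 3))
    p, q = ot.unif(8), ot.unif(9)

    fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=0.5, log=True)
    print(fgw, log["lin_loss"] + log["quad_loss"])   # equal up to numerical error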
@@ -611,37 +786,67 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', q = unif(C2.shape[0], type_as=M) T, log_fgw = fused_gromov_wasserstein( - M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True, - max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) - - fgw_dist = log_fgw['fgw_dist'] - log_fgw['T'] = T + M, + C1, + C2, + p, + q, + loss_fun, + symmetric, + alpha, + armijo, + G0, + log=True, + max_iter=max_iter, + tol_rel=tol_rel, + tol_abs=tol_abs, + **kwargs, + ) + + fgw_dist = log_fgw["fgw_dist"] + log_fgw["T"] = T # compute separate terms for gradients and log lin_term = nx.sum(T * M) - log_fgw['quad_loss'] = (fgw_dist - (1 - alpha) * lin_term) - log_fgw['lin_loss'] = lin_term * (1 - alpha) - gw_term = log_fgw['quad_loss'] / alpha + log_fgw["quad_loss"] = fgw_dist - (1 - alpha) * lin_term + log_fgw["lin_loss"] = lin_term * (1 - alpha) + gw_term = log_fgw["quad_loss"] / alpha - if loss_fun == 'square_loss': + if loss_fun == "square_loss": gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) - elif loss_fun == 'kl_loss': - gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + elif loss_fun == "kl_loss": + gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot( + T, nx.dot(nx.log(C2 + 1e-15), T.T) + ) + gC2 = -nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) if isinstance(alpha, int) or isinstance(alpha, float): - fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M), - (log_fgw['u'] - nx.mean(log_fgw['u']), - log_fgw['v'] - nx.mean(log_fgw['v']), - alpha * gC1, alpha * gC2, (1 - alpha) * T)) + fgw_dist = nx.set_gradients( + fgw_dist, + (p, q, C1, C2, M), + ( + log_fgw["u"] - nx.mean(log_fgw["u"]), + log_fgw["v"] - nx.mean(log_fgw["v"]), + alpha * gC1, + alpha * gC2, + (1 - alpha) * T, + ), + ) else: - fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha), - (log_fgw['u'] - nx.mean(log_fgw['u']), - log_fgw['v'] - nx.mean(log_fgw['v']), - alpha * gC1, alpha * gC2, (1 - alpha) * T, - gw_term - lin_term)) + fgw_dist = nx.set_gradients( + fgw_dist, + (p, q, C1, C2, M, alpha), + ( + log_fgw["u"] - nx.mean(log_fgw["u"]), + log_fgw["v"] - nx.mean(log_fgw["v"]), + alpha * gC1, + alpha * gC2, + (1 - alpha) * T, + gw_term - lin_term, + ), + ) if log: return fgw_dist, log_fgw @@ -649,8 +854,20 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', return fgw_dist -def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, - alpha_min=None, alpha_max=None, nx=None, symmetric=False, **kwargs): +def solve_gromov_linesearch( + G, + deltaG, + cost_G, + C1, + C2, + M, + reg, + alpha_min=None, + alpha_max=None, + nx=None, + symmetric=False, + **kwargs, +): """ Solve the linesearch in the FW iterations for any inner loss that decomposes as in Proposition 1 in :ref:`[12] `. 
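The helper reduces to minimizing the one-dimensional quadratic `a * alpha**2 + b * alpha` over `[0, 1]` and then updating the cost with the same polynomial, as visible at the end of this hunk. A standalone sketch of that closed-form rule (it mirrors the branching in `_linesearch_gromov_wasserstein_unmixing` above and does not call the private API):

    import numpy as np

    def quad_linesearch(a, b):
        # minimize a * t**2 + b * t over t in [0, 1]
        if a > 0:                       # convex: clip the vertex of the parabola to [0, 1]
            return float(np.clip(-b / (2 * a), 0.0, 1.0))
        elif a + b < 0:                 # otherwise compare the two endpoints
            return 1.0
        return 0.0

    a, b, cost_G = 2.0, -1.5, 10.0      # toy coefficients
    alpha = quad_linesearch(a, b)
    cost_G = cost_G + a * alpha**2 + b * alpha   # same cost update as in solve_gromov_linesearch
    print(alpha, cost_G)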
@@ -712,27 +929,43 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, nx = get_backend(G, deltaG, C1, C2, M) dot = nx.dot(nx.dot(C1, deltaG), C2.T) - a = - reg * nx.sum(dot * deltaG) + a = -reg * nx.sum(dot * deltaG) if symmetric: b = nx.sum(M * deltaG) - 2 * reg * nx.sum(dot * G) else: - b = nx.sum(M * deltaG) - reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG)) + b = nx.sum(M * deltaG) - reg * ( + nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG) + ) alpha = solve_1d_linesearch_quad(a, b) if alpha_min is not None or alpha_max is not None: alpha = np.clip(alpha, alpha_min, alpha_max) # the new cost is deduced from the line search quadratic function - cost_G = cost_G + a * (alpha ** 2) + b * alpha + cost_G = cost_G + a * (alpha**2) + b * alpha return alpha, 1, cost_G def gromov_barycenters( - N, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss', - symmetric=True, armijo=False, max_iter=1000, tol=1e-9, - stop_criterion='barycenter', warmstartT=False, verbose=False, - log=False, init_C=None, random_state=None, **kwargs): + N, + Cs, + ps=None, + p=None, + lambdas=None, + loss_fun="square_loss", + symmetric=True, + armijo=False, + max_iter=1000, + tol=1e-9, + stop_criterion="barycenter", + warmstartT=False, + verbose=False, + log=False, + init_C=None, + random_state=None, + **kwargs, +): r""" Returns the Gromov-Wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` @@ -808,22 +1041,27 @@ def gromov_barycenters( International Conference on Machine Learning (ICML). 2016. """ - if stop_criterion not in ['barycenter', 'loss']: - raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") + if stop_criterion not in ["barycenter", "loss"]: + raise ValueError( + f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}." + ) if isinstance(Cs[0], list): - raise ValueError("Deprecated feature in POT 0.9.4: structures Cs[i] are lists and should be arrays from a supported backend (e.g numpy).") + raise ValueError( + "Deprecated feature in POT 0.9.4: structures Cs[i] are lists and should be arrays from a supported backend (e.g numpy)." + ) arr = [*Cs] if ps is not None: if isinstance(ps[0], list): - raise ValueError("Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy).") + raise ValueError( + "Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy)." + ) arr += [*ps] else: ps = [unif(C.shape[0], type_as=C) for C in Cs] if p is not None: - arr.append(list_to_array(p)) else: p = unif(N, type_as=Cs[0]) @@ -832,7 +1070,7 @@ def gromov_barycenters( S = len(Cs) if lambdas is None: - lambdas = [1. 
/ S] * S + lambdas = [1.0 / S] * S # Initialization of C : random SPD matrix (if not provided by user) if init_C is None: @@ -847,7 +1085,7 @@ def gromov_barycenters( if warmstartT: T = [None] * S - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": inner_log = False else: inner_log = True @@ -855,62 +1093,91 @@ def gromov_barycenters( if log: log_ = {} - log_['err'] = [] - if stop_criterion == 'loss': - log_['loss'] = [] + log_["err"] = [] + if stop_criterion == "loss": + log_["loss"] = [] for cpt in range(max_iter): - - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": Cprev = C else: prev_loss = curr_loss # get transport plans if warmstartT: - res = [gromov_wasserstein( - C, Cs[s], p, ps[s], loss_fun, symmetric=symmetric, armijo=armijo, G0=T[s], - max_iter=max_iter, tol_rel=1e-5, tol_abs=0., log=inner_log, verbose=verbose, **kwargs) - for s in range(S)] + res = [ + gromov_wasserstein( + C, + Cs[s], + p, + ps[s], + loss_fun, + symmetric=symmetric, + armijo=armijo, + G0=T[s], + max_iter=max_iter, + tol_rel=1e-5, + tol_abs=0.0, + log=inner_log, + verbose=verbose, + **kwargs, + ) + for s in range(S) + ] else: - res = [gromov_wasserstein( - C, Cs[s], p, ps[s], loss_fun, symmetric=symmetric, armijo=armijo, G0=None, - max_iter=max_iter, tol_rel=1e-5, tol_abs=0., log=inner_log, verbose=verbose, **kwargs) - for s in range(S)] - if stop_criterion == 'barycenter': + res = [ + gromov_wasserstein( + C, + Cs[s], + p, + ps[s], + loss_fun, + symmetric=symmetric, + armijo=armijo, + G0=None, + max_iter=max_iter, + tol_rel=1e-5, + tol_abs=0.0, + log=inner_log, + verbose=verbose, + **kwargs, + ) + for s in range(S) + ] + if stop_criterion == "barycenter": T = res else: T = [output[0] for output in res] - curr_loss = np.sum([output[1]['gw_dist'] for output in res]) + curr_loss = np.sum([output[1]["gw_dist"] for output in res]) # update barycenters C = update_barycenter_structure( - T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx) + T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx + ) # update convergence criterion - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": err = nx.norm(C - Cprev) if log: - log_['err'].append(err) + log_["err"].append(err) else: - err = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. 
else np.nan + err = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0.0 else np.nan if log: - log_['loss'].append(curr_loss) - log_['err'].append(err) + log_["loss"].append(curr_loss) + log_["err"].append(err) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err)) if err <= tol: break if log: - log_['T'] = T - log_['p'] = p + log_["T"] = T + log_["p"] = p return C, log_ else: @@ -918,11 +1185,29 @@ def gromov_barycenters( def fgw_barycenters( - N, Ys, Cs, ps=None, lambdas=None, alpha=0.5, fixed_structure=False, - fixed_features=False, p=None, loss_fun='square_loss', armijo=False, - symmetric=True, max_iter=100, tol=1e-9, stop_criterion='barycenter', - warmstartT=False, verbose=False, log=False, init_C=None, init_X=None, - random_state=None, **kwargs): + N, + Ys, + Cs, + ps=None, + lambdas=None, + alpha=0.5, + fixed_structure=False, + fixed_features=False, + p=None, + loss_fun="square_loss", + armijo=False, + symmetric=True, + max_iter=100, + tol=1e-9, + stop_criterion="barycenter", + warmstartT=False, + verbose=False, + log=False, + init_C=None, + init_X=None, + random_state=None, + **kwargs, +): r""" Returns the Fused Gromov-Wasserstein barycenters of `S` measurable networks with node features :math:`(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s)_{1 \leq s \leq S}` (see eq (5) in :ref:`[24] `), estimated using Fused Gromov-Wasserstein transports from Conditional Gradient solvers. @@ -1016,16 +1301,22 @@ def fgw_barycenters( "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ - if stop_criterion not in ['barycenter', 'loss']: - raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") + if stop_criterion not in ["barycenter", "loss"]: + raise ValueError( + f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}." + ) if isinstance(Cs[0], list) or isinstance(Ys[0], list): - raise ValueError("Deprecated feature in POT 0.9.4: structures Cs[i] and/or features Ys[i] are lists and should be arrays from a supported backend (e.g numpy).") + raise ValueError( + "Deprecated feature in POT 0.9.4: structures Cs[i] and/or features Ys[i] are lists and should be arrays from a supported backend (e.g numpy)." + ) arr = [*Cs, *Ys] if ps is not None: if isinstance(ps[0], list): - raise ValueError("Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy).") + raise ValueError( + "Deprecated feature in POT 0.9.4: weights ps[i] are lists and should be arrays from a supported backend (e.g numpy)." + ) arr += [*ps] else: @@ -1039,14 +1330,13 @@ def fgw_barycenters( S = len(Cs) if lambdas is None: - lambdas = [1. 
/ S] * S + lambdas = [1.0 / S] * S d = Ys[0].shape[1] # dimension on the node features if fixed_structure: if init_C is None: - raise UndefinedParameter( - 'If C is fixed it must be provided in init_C') + raise UndefinedParameter("If C is fixed it must be provided in init_C") else: C = init_C else: @@ -1060,8 +1350,7 @@ def fgw_barycenters( if fixed_features: if init_X is None: - raise UndefinedParameter( - 'If X is fixed it must be provided in init_X') + raise UndefinedParameter("If X is fixed it must be provided in init_X") else: X = init_X else: @@ -1076,7 +1365,7 @@ def fgw_barycenters( if warmstartT: T = [None] * S - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": inner_log = False else: @@ -1085,17 +1374,16 @@ def fgw_barycenters( if log: log_ = {} - if stop_criterion == 'barycenter': - log_['err_feature'] = [] - log_['err_structure'] = [] - log_['Ts_iter'] = [] + if stop_criterion == "barycenter": + log_["err_feature"] = [] + log_["err_structure"] = [] + log_["Ts_iter"] = [] else: - log_['loss'] = [] - log_['err_rel_loss'] = [] + log_["loss"] = [] + log_["err_rel_loss"] = [] for cpt in range(max_iter): # break if specified errors are below tol. - - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": Cprev = C Xprev = X else: @@ -1103,72 +1391,108 @@ def fgw_barycenters( # get transport plans if warmstartT: - res = [fused_gromov_wasserstein( - Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, - G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., log=inner_log, verbose=verbose, **kwargs) - for s in range(S)] + res = [ + fused_gromov_wasserstein( + Ms[s], + C, + Cs[s], + p, + ps[s], + loss_fun=loss_fun, + alpha=alpha, + armijo=armijo, + symmetric=symmetric, + G0=T[s], + max_iter=max_iter, + tol_rel=1e-5, + tol_abs=0.0, + log=inner_log, + verbose=verbose, + **kwargs, + ) + for s in range(S) + ] else: - res = [fused_gromov_wasserstein( - Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, - G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., log=inner_log, verbose=verbose, **kwargs) - for s in range(S)] - if stop_criterion == 'barycenter': + res = [ + fused_gromov_wasserstein( + Ms[s], + C, + Cs[s], + p, + ps[s], + loss_fun=loss_fun, + alpha=alpha, + armijo=armijo, + symmetric=symmetric, + G0=None, + max_iter=max_iter, + tol_rel=1e-5, + tol_abs=0.0, + log=inner_log, + verbose=verbose, + **kwargs, + ) + for s in range(S) + ] + if stop_criterion == "barycenter": T = res else: T = [output[0] for output in res] - curr_loss = np.sum([output[1]['fgw_dist'] for output in res]) + curr_loss = np.sum([output[1]["fgw_dist"] for output in res]) # update barycenters if not fixed_features: X = update_barycenter_feature( - T, Ys, lambdas, p, target=False, check_zeros=False, nx=nx) + T, Ys, lambdas, p, target=False, check_zeros=False, nx=nx + ) Ms = [dist(X, Ys[s]) for s in range(len(Ys))] if not fixed_structure: C = update_barycenter_structure( - T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx) + T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx + ) # update convergence criterion - if stop_criterion == 'barycenter': - err_feature, err_structure = 0., 0. 
+ if stop_criterion == "barycenter": + err_feature, err_structure = 0.0, 0.0 if not fixed_features: err_feature = nx.norm(X - Xprev) if not fixed_structure: err_structure = nx.norm(C - Cprev) if log: - log_['err_feature'].append(err_feature) - log_['err_structure'].append(err_structure) - log_['Ts_iter'].append(T) + log_["err_feature"].append(err_feature) + log_["err_structure"].append(err_structure) + log_["Ts_iter"].append(T) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err_structure)) - print('{:5d}|{:8e}|'.format(cpt, err_feature)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err_structure)) + print("{:5d}|{:8e}|".format(cpt, err_feature)) if (err_feature <= tol) or (err_structure <= tol): break else: - err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan + err_rel_loss = ( + abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0.0 else np.nan + ) if log: - log_['loss'].append(curr_loss) - log_['err_rel_loss'].append(err_rel_loss) + log_["loss"].append(curr_loss) + log_["err_rel_loss"].append(err_rel_loss) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err_rel_loss)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err_rel_loss)) if err_rel_loss <= tol: break if log: - log_['T'] = T - log_['p'] = p - log_['Ms'] = Ms + log_["T"] = T + log_["p"] = p + log_["Ms"] = Ms return X, C, log_ else: diff --git a/ot/gromov/_lowrank.py b/ot/gromov/_lowrank.py index 9aa3faab5..82ff98da4 100644 --- a/ot/gromov/_lowrank.py +++ b/ot/gromov/_lowrank.py @@ -6,7 +6,6 @@ # # License: MIT License - import warnings from ..utils import unif, get_lowrank_lazytensor from ..backend import get_backend @@ -58,9 +57,27 @@ def _flat_product_operator(X, nx=None): return X_flat -def lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, gamma_init="rescale", - rescale_cost=True, cost_factorized_Xs=None, cost_factorized_Xt=None, stopThr=1e-4, numItermax=1000, - stopThr_dykstra=1e-3, numItermax_dykstra=10000, seed_init=49, warn=True, warn_dykstra=False, log=False): +def lowrank_gromov_wasserstein_samples( + X_s, + X_t, + a=None, + b=None, + reg=0, + rank=None, + alpha=1e-10, + gamma_init="rescale", + rescale_cost=True, + cost_factorized_Xs=None, + cost_factorized_Xt=None, + stopThr=1e-4, + numItermax=1000, + stopThr_dykstra=1e-3, + numItermax_dykstra=10000, + seed_init=49, + warn=True, + warn_dykstra=False, + log=False, +): r""" Solve the entropic regularization Gromov-Wasserstein transport problem under low-nonnegative rank constraints on the couplings and cost matrices. 
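Before moving on to the low-rank solver below, a usage sketch for the `fgw_barycenters` loop finished just above (toy attributed graphs, illustration only; it assumes nothing beyond the public `ot.gromov` API):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Cs, Ys, ps = [], [], []
    for n in (8, 9, 10):                 # three toy attributed graphs
        x = rng.randn(n, 2)
        Cs.append(ot.dist(x, x))         # structure matrices
        Ys.append(rng.randn(n, 3))       # node feature matrices
        ps.append(ot.unif(n))

    X, C, log = ot.gromov.fgw_barycenters(
        N=6, Ys=Ys, Cs=Cs, ps=ps, alpha=0.5,
        stop_criterion="loss", log=True, random_state=0
    )
    print(X.shape, C.shape, log["loss"][-1])

With `stop_criterion="loss"`, the returned log carries the `loss` history appended at each iteration of the loop above.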
@@ -180,8 +197,11 @@ def lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0, rank=Non # Dykstra won't converge if 1/rank < alpha (see Section 3.2) if 1 / r < alpha: - raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format( - a=alpha, r=1 / rank)) + raise ValueError( + "alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format( + a=alpha, r=1 / rank + ) + ) if cost_factorized_Xs is not None: A1, A2 = cost_factorized_Xs @@ -204,7 +224,9 @@ def lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0, rank=Non gamma = 1 / (2 * L) if gamma_init not in ["rescale", "theory"]: - raise (NotImplementedError('Not implemented gamma_init="{}"'.format(gamma_init))) + raise ( + NotImplementedError('Not implemented gamma_init="{}"'.format(gamma_init)) + ) # initial value of error err = 1 @@ -217,7 +239,7 @@ def lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0, rank=Non if err > stopThr: # Compute cost matrices C1 = nx.dot(A2.T, Q * (1 / g)[None, :]) - C1 = - 4 * nx.dot(A1, C1) + C1 = -4 * nx.dot(A1, C1) C2 = nx.dot(R.T, B1) C2 = nx.dot(C2, B2.T) diag_g = (1 / g)[None, :] @@ -248,7 +270,16 @@ def lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0, rank=Non # Update couplings with LR Dykstra algorithm Q, R, g = _LR_Dysktra( - K1, K2, K3, a, b, alpha, stopThr_dykstra, numItermax_dykstra, warn_dykstra, nx + K1, + K2, + K3, + a, + b, + alpha, + stopThr_dykstra, + numItermax_dykstra, + warn_dykstra, + nx, ) # Update error with kullback-divergence @@ -274,7 +305,7 @@ def lowrank_gromov_wasserstein_samples(X_s, X_t, a=None, b=None, reg=0, rank=Non # Update low rank costs C1 = nx.dot(A2.T, Q * (1 / g)[None, :]) - C1 = - 4 * nx.dot(A1, C1) + C1 = -4 * nx.dot(A1, C1) C2 = nx.dot(R.T, B1) C2 = nx.dot(C2, B2.T) diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py index f1840655c..e38eeff1c 100644 --- a/ot/gromov/_partial.py +++ b/ot/gromov/_partial.py @@ -9,7 +9,6 @@ # # License: MIT License - from ..utils import list_to_array, unif from ..backend import get_backend, NumpyBackend from ..partial import entropic_partial_wasserstein @@ -21,9 +20,23 @@ def partial_gromov_wasserstein( - C1, C2, p=None, q=None, m=None, loss_fun='square_loss', nb_dummies=1, - G0=None, thres=1, numItermax=1e4, tol=1e-8, symmetric=None, warn=True, - log=False, verbose=False, **kwargs): + C1, + C2, + p=None, + q=None, + m=None, + loss_fun="square_loss", + nb_dummies=1, + G0=None, + thres=1, + numItermax=1e4, + tol=1e-8, + symmetric=None, + warn=True, + log=False, + verbose=False, + **kwargs, +): r""" Returns the Partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`. @@ -165,19 +178,24 @@ def partial_gromov_wasserstein( C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) if symmetric is None: - symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose( + C2, C2.T, atol=1e-10 + ) if m is None: m = min(np.sum(p), np.sum(q)) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" - " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") elif m > min(np.sum(p), np.sum(q)): - raise ValueError("Problem infeasible. Parameter m should lower or" - " equal than min(|a|_1, |b|_1).") + raise ValueError( + "Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1)." 
+ ) if G0 is None: - G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + G0 = ( + np.outer(p, q) * m / (np.sum(p) * np.sum(q)) + ) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. else: G0 = nx.to_numpy(G0_) @@ -207,6 +225,7 @@ def f(G): return gwloss(constC1 + constC2, hC1, hC2, G, np_) if symmetric: + def df(G): pG = G.sum(1) qG = G.sum(0) @@ -224,13 +243,15 @@ def df(G): constC2t = np.outer(ones_p, np.dot(qG, fC2)) return 0.5 * ( - gwggrad(constC1 + constC2, hC1, hC2, G, np_) + - gwggrad(constC1t + constC2t, hC1t, hC2t, G, np_)) + gwggrad(constC1 + constC2, hC1, hC2, G, np_) + + gwggrad(constC1t + constC2t, hC1t, hC2t, G, np_) + ) def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): df_Gc = df(deltaG + G) return solve_partial_gromov_linesearch( - G, deltaG, cost_G, df_G, df_Gc, M=0., reg=1., nx=np_, **kwargs) + G, deltaG, cost_G, df_G, df_Gc, M=0.0, reg=1.0, nx=np_, **kwargs + ) if not nx.is_floating_point(C10): warnings.warn( @@ -238,26 +259,71 @@ def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): "casted accordingly, possibly resulting in a loss of precision. " "If this behaviour is unwanted, please make sure your input " "structure matrix consists of floating point elements.", - stacklevel=2 + stacklevel=2, ) if log: - res, log = partial_cg(p, q, p_extended, q_extended, 0., 1., f, df, G0, - line_search, log=True, numItermax=numItermax, - stopThr=tol, stopThr2=0., warn=warn, **kwargs) - log['partial_gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) + res, log = partial_cg( + p, + q, + p_extended, + q_extended, + 0.0, + 1.0, + f, + df, + G0, + line_search, + log=True, + numItermax=numItermax, + stopThr=tol, + stopThr2=0.0, + warn=warn, + **kwargs, + ) + log["partial_gw_dist"] = nx.from_numpy(log["loss"][-1], type_as=C10) return nx.from_numpy(res, type_as=C10), log else: return nx.from_numpy( - partial_cg(p, q, p_extended, q_extended, 0., 1., f, df, G0, - line_search, log=False, numItermax=numItermax, - stopThr=tol, stopThr2=0., **kwargs), type_as=C10) + partial_cg( + p, + q, + p_extended, + q_extended, + 0.0, + 1.0, + f, + df, + G0, + line_search, + log=False, + numItermax=numItermax, + stopThr=tol, + stopThr2=0.0, + **kwargs, + ), + type_as=C10, + ) def partial_gromov_wasserstein2( - C1, C2, p=None, q=None, m=None, loss_fun='square_loss', nb_dummies=1, G0=None, - thres=1, numItermax=1e4, tol=1e-7, symmetric=None, warn=False, log=False, - verbose=False, **kwargs): + C1, + C2, + p=None, + q=None, + m=None, + loss_fun="square_loss", + nb_dummies=1, + G0=None, + thres=1, + numItermax=1e4, + tol=1e-7, + symmetric=None, + warn=False, + log=False, + verbose=False, + **kwargs, +): r""" Returns the Partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`. 
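A usage sketch for the partial solver whose reformatted body appears above (illustration only; it assumes `partial_gromov_wasserstein` is re-exported under `ot.gromov` together with the rest of this `_partial` module):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(15, 2)
    xt = rng.randn(15, 2)
    xt[10:] += 20                         # a clump of outliers in the target
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    p, q = ot.unif(15), ot.unif(15)

    # transport only 70% of the mass so the outliers can be left out
    T, log = ot.gromov.partial_gromov_wasserstein(
        C1, C2, p, q, m=0.7, loss_fun="square_loss", log=True
    )
    print(T.sum(), log["partial_gw_dist"])

The mass of the returned plan equals `m`, and `partial_gw_dist` is the log key filled in above.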
@@ -391,18 +457,35 @@ def partial_gromov_wasserstein2( q = unif(C2.shape[0], type_as=C1) T, log_pgw = partial_gromov_wasserstein( - C1, C2, p, q, m, loss_fun, nb_dummies, G0, thres, - numItermax, tol, symmetric, warn, True, verbose, **kwargs) - - log_pgw['T'] = T - pgw = log_pgw['partial_gw_dist'] - - if loss_fun == 'square_loss': + C1, + C2, + p, + q, + m, + loss_fun, + nb_dummies, + G0, + thres, + numItermax, + tol, + symmetric, + warn, + True, + verbose, + **kwargs, + ) + + log_pgw["T"] = T + pgw = log_pgw["partial_gw_dist"] + + if loss_fun == "square_loss": gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) - elif loss_fun == 'kl_loss': - gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + elif loss_fun == "kl_loss": + gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot( + T, nx.dot(nx.log(C2 + 1e-15), T.T) + ) + gC2 = -nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) pgw = nx.set_gradients(pgw, (C1, C2), (gC1, gC2)) @@ -413,8 +496,18 @@ def partial_gromov_wasserstein2( def solve_partial_gromov_linesearch( - G, deltaG, cost_G, df_G, df_Gc, M, reg, alpha_min=None, alpha_max=None, - nx=None, **kwargs): + G, + deltaG, + cost_G, + df_G, + df_Gc, + M, + reg, + alpha_min=None, + alpha_max=None, + nx=None, + **kwargs, +): """ Solve the linesearch in the FW iterations of partial (F)GW following eq.5 of :ref:`[29]`. @@ -479,7 +572,7 @@ def solve_partial_gromov_linesearch( alpha = np.clip(alpha, alpha_min, alpha_max) # the new cost is deduced from the line search quadratic function - cost_G = cost_G + a * (alpha ** 2) + b * alpha + cost_G = cost_G + a * (alpha**2) + b * alpha # update the gradient for next cg iteration df_G = df_G + alpha * df_deltaG @@ -487,8 +580,20 @@ def solve_partial_gromov_linesearch( def entropic_partial_gromov_wasserstein( - C1, C2, p=None, q=None, reg=1., m=None, loss_fun='square_loss', G0=None, - numItermax=1000, tol=1e-7, symmetric=None, log=False, verbose=False): + C1, + C2, + p=None, + q=None, + reg=1.0, + m=None, + loss_fun="square_loss", + G0=None, + numItermax=1000, + tol=1e-7, + symmetric=None, + log=False, + verbose=False, +): r""" Returns the partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -621,14 +726,17 @@ def entropic_partial_gromov_wasserstein( if m is None: m = min(nx.sum(p), nx.sum(q)) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" - " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") elif m > min(nx.sum(p), nx.sum(q)): - raise ValueError("Problem infeasible. Parameter m should lower or" - " equal than min(|a|_1, |b|_1).") + raise ValueError( + "Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1)." + ) if G0 is None: - G0 = nx.outer(p, q) * m / (nx.sum(p) * nx.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + G0 = ( + nx.outer(p, q) * m / (nx.sum(p) * nx.sum(q)) + ) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. 
else: # Check marginals of G0 @@ -636,7 +744,9 @@ def entropic_partial_gromov_wasserstein( assert nx.any(nx.sum(G0, 0) <= q) if symmetric is None: - symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10) + symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose( + C2, C2.T, atol=1e-10 + ) # Setup gradient computation fC1, fC2, hC1, hC2 = _transform_matrix(C1, C2, loss_fun, nx) @@ -655,6 +765,7 @@ def f(G): return gwloss(constC1 + constC2, hC1, hC2, G, nx) if symmetric: + def df(G): pG = nx.sum(G, 1) qG = nx.sum(G, 0) @@ -672,40 +783,56 @@ def df(G): constC2t = nx.outer(ones_p, nx.dot(qG, fC2)) return 0.5 * ( - gwggrad(constC1 + constC2, hC1, hC2, G, nx) + - gwggrad(constC1t + constC2t, hC1t, hC2t, G, nx)) + gwggrad(constC1 + constC2, hC1, hC2, G, nx) + + gwggrad(constC1t + constC2t, hC1t, hC2t, G, nx) + ) cpt = 0 err = 1 - loge = {'err': []} + loge = {"err": []} - while (err > tol and cpt < numItermax): + while err > tol and cpt < numItermax: Gprev = G0 M_entr = df(G0) G0 = entropic_partial_wasserstein(p, q, M_entr, reg, m) if cpt % 10 == 0: # to speed up the computations err = np.linalg.norm(G0 - Gprev) if log: - loge['err'].append(err) + loge["err"].append(err) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}|{:12s}'.format( - 'It.', 'Err', 'Loss') + '\n' + '-' * 31) - print('{:5d}|{:8e}|{:8e}'.format(cpt, err, f(G0))) + print( + "{:5s}|{:12s}|{:12s}".format("It.", "Err", "Loss") + + "\n" + + "-" * 31 + ) + print("{:5d}|{:8e}|{:8e}".format(cpt, err, f(G0))) cpt += 1 if log: - loge['partial_gw_dist'] = f(G0) + loge["partial_gw_dist"] = f(G0) return G0, loge else: return G0 def entropic_partial_gromov_wasserstein2( - C1, C2, p=None, q=None, reg=1., m=None, loss_fun='square_loss', G0=None, - numItermax=1000, tol=1e-7, symmetric=None, log=False, verbose=False): + C1, + C2, + p=None, + q=None, + reg=1.0, + m=None, + loss_fun="square_loss", + G0=None, + numItermax=1000, + tol=1e-7, + symmetric=None, + log=False, + verbose=False, +): r""" Returns the partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -810,12 +937,12 @@ def entropic_partial_gromov_wasserstein2( """ partial_gw, log_gw = entropic_partial_gromov_wasserstein( - C1, C2, p, q, reg, m, loss_fun, G0, numItermax, tol, - symmetric, True, verbose) + C1, C2, p, q, reg, m, loss_fun, G0, numItermax, tol, symmetric, True, verbose + ) - log_gw['T'] = partial_gw + log_gw["T"] = partial_gw if log: - return log_gw['partial_gw_dist'], log_gw + return log_gw["partial_gw_dist"], log_gw else: - return log_gw['partial_gw_dist'] + return log_gw["partial_gw_dist"] diff --git a/ot/gromov/_quantized.py b/ot/gromov/_quantized.py index 4b952a965..ac2db5d2d 100644 --- a/ot/gromov/_quantized.py +++ b/ot/gromov/_quantized.py @@ -12,12 +12,14 @@ try: from networkx.algorithms.community import asyn_fluidc, louvain_communities from networkx import from_numpy_array, pagerank + networkx_import = True except ImportError: networkx_import = False try: from sklearn.cluster import SpectralClustering, KMeans + sklearn_import = True except ImportError: sklearn_import = False @@ -32,9 +34,23 @@ def quantized_fused_gromov_wasserstein_partitioned( - CR1, CR2, list_R1, list_R2, list_p1, list_p2, MR=None, - alpha=1., build_OT=False, log=False, armijo=False, max_iter=1e4, - tol_rel=1e-9, tol_abs=1e-9, nx=None, **kwargs): + CR1, + CR2, + list_R1, + list_R2, + list_p1, + list_p2, + MR=None, + alpha=1.0, + build_OT=False, + log=False, + armijo=False, + max_iter=1e4, + 
tol_rel=1e-9, + tol_abs=1e-9, + nx=None, + **kwargs, +): r""" Returns the quantized Fused Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, @@ -155,24 +171,43 @@ def quantized_fused_gromov_wasserstein_partitioned( pR2 = nx.from_numpy(list_to_array([nx.sum(q) for q in list_p2])) # compute global alignment - if alpha == 1.: + if alpha == 1.0: res_global = gromov_wasserstein( - CR1, CR2, pR1, pR2, loss_fun='square_loss', log=log, - armijo=armijo, max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs) + CR1, + CR2, + pR1, + pR2, + loss_fun="square_loss", + log=log, + armijo=armijo, + max_iter=max_iter, + tol_rel=tol_rel, + tol_abs=tol_abs, + ) if log: - T_global, dist_global = res_global[0], res_global[1]['gw_dist'] + T_global, dist_global = res_global[0], res_global[1]["gw_dist"] else: T_global = res_global - elif (alpha < 1.) and (alpha > 0.): - + elif (alpha < 1.0) and (alpha > 0.0): res_global = fused_gromov_wasserstein( - MR, CR1, CR2, pR1, pR2, 'square_loss', alpha=alpha, log=log, - armijo=armijo, max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs) + MR, + CR1, + CR2, + pR1, + pR2, + "square_loss", + alpha=alpha, + log=log, + armijo=armijo, + max_iter=max_iter, + tol_rel=tol_rel, + tol_abs=tol_abs, + ) if log: - T_global, dist_global = res_global[0], res_global[1]['fgw_dist'] + T_global, dist_global = res_global[0], res_global[1]["fgw_dist"] else: T_global = res_global @@ -180,11 +215,12 @@ def quantized_fused_gromov_wasserstein_partitioned( raise ValueError( f""" `alpha='{alpha}'` should be in ]0, 1]. - """) + """ + ) if log: log_ = {} - log_['global dist'] = dist_global + log_["global dist"] = dist_global # compute local alignments Ts_local = {} @@ -193,13 +229,20 @@ def quantized_fused_gromov_wasserstein_partitioned( for i in range(npart1): for j in range(npart2): - if T_global[i, j] != 0.: - res_1d = emd_1d(list_R1[i], list_R2[j], list_p1_norm[i], list_p2_norm[j], - metric='sqeuclidean', p=1., log=log) + if T_global[i, j] != 0.0: + res_1d = emd_1d( + list_R1[i], + list_R2[j], + list_p1_norm[i], + list_p2_norm[j], + metric="sqeuclidean", + p=1.0, + log=log, + ) if log: T_local, log_local = res_1d Ts_local[(i, j)] = T_local - log_[f'local dist ({i},{j})'] = log_local['cost'] + log_[f"local dist ({i},{j})"] = log_local["cost"] else: Ts_local[(i, j)] = res_1d @@ -208,8 +251,10 @@ def quantized_fused_gromov_wasserstein_partitioned( for i in range(npart1): list_Ti = [] for j in range(npart2): - if T_global[i, j] == 0.: - T_local = nx.zeros((list_R1[i].shape[0], list_R2[j].shape[0]), type_as=T_global) + if T_global[i, j] == 0.0: + T_local = nx.zeros( + (list_R1[i].shape[0], list_R2[j].shape[0]), type_as=T_global + ) else: T_local = T_global[i, j] * Ts_local[(i, j)] list_Ti.append(T_local) @@ -228,8 +273,9 @@ def quantized_fused_gromov_wasserstein_partitioned( return T_global, Ts_local, T -def get_graph_partition(C, npart, part_method='random', F=None, alpha=1., - random_state=0, nx=None): +def get_graph_partition( + C, npart, part_method="random", F=None, alpha=1.0, random_state=0, nx=None +): r""" Partitioning a given graph with structure matrix :math:`\mathbf{C} \in R^{n \times n}` into `npart` partitions either 'random', or using one of {'louvain', 'fluid'} @@ -277,14 +323,14 @@ def get_graph_partition(C, npart, part_method='random', F=None, alpha=1., n = C.shape[0] C0 = C - if (alpha != 1.) 
and (F is None): + if (alpha != 1.0) and (F is None): raise ValueError("`alpha != 1` but node features are not provided.") if npart >= n: warnings.warn( "Requested number of partitions higher than the number of nodes" "hence we enforce each node to be a partition.", - stacklevel=2 + stacklevel=2, ) part = np.arange(n) @@ -292,12 +338,12 @@ def get_graph_partition(C, npart, part_method='random', F=None, alpha=1., elif npart == 1: part = np.zeros(n) - elif part_method == 'random': + elif part_method == "random": # randomly partition the space random.seed(random_state) part = list_to_array(random.choices(np.arange(npart), k=C.shape[0])) - elif part_method == 'louvain': + elif part_method == "louvain": C = nx.to_numpy(C0) graph = from_numpy_array(C) part_sets = louvain_communities(graph, seed=random_state) @@ -306,7 +352,7 @@ def get_graph_partition(C, npart, part_method='random', F=None, alpha=1., set_ = list(set_) part[set_] = iset_ - elif part_method == 'fluid': + elif part_method == "fluid": C = nx.to_numpy(C0) graph = from_numpy_array(C) part_sets = asyn_fluidc(graph, npart, seed=random_state) @@ -315,14 +361,14 @@ def get_graph_partition(C, npart, part_method='random', F=None, alpha=1., set_ = list(set_) part[set_] = iset_ - elif part_method == 'spectral': + elif part_method == "spectral": C = nx.to_numpy(C0) - sc = SpectralClustering(n_clusters=npart, - random_state=random_state, - affinity='precomputed').fit(C) + sc = SpectralClustering( + n_clusters=npart, random_state=random_state, affinity="precomputed" + ).fit(C) part = sc.labels_ - elif part_method in ['GW', 'FGW']: + elif part_method in ["GW", "FGW"]: raise ValueError(f"`part_method == {part_method}` not implemented yet.") else: @@ -330,11 +376,12 @@ def get_graph_partition(C, npart, part_method='random', F=None, alpha=1., f""" Unknown `part_method='{part_method}'`. Use one of: {'random', 'louvain', 'fluid', 'spectral', 'GW', 'FGW'}. - """) + """ + ) return nx.from_numpy(part, type_as=C0) -def get_graph_representants(C, part, rep_method='pagerank', random_state=0, nx=None): +def get_graph_representants(C, part, rep_method="pagerank", random_state=0, nx=None): r""" Get representative node for each partition given by :math:`\mathbf{part} \in R^{n}` of a graph with structure matrix :math:`\mathbf{C} \in R^{n \times n}`. @@ -376,13 +423,13 @@ def get_graph_representants(C, part, rep_method='pagerank', random_state=0, nx=N if n_part_ids == C.shape[0]: rep_indices = nx.arange(n_part_ids) - elif rep_method == 'random': + elif rep_method == "random": random.seed(random_state) for id_, part_id in enumerate(part_ids): indices = nx.where(part == part_id)[0] rep_indices.append(random.choice(indices)) - elif rep_method == 'pagerank': + elif rep_method == "pagerank": C0, part0 = C, part C = nx.to_numpy(C0) part = nx.to_numpy(part0) @@ -401,13 +448,15 @@ def get_graph_representants(C, part, rep_method='pagerank', random_state=0, nx=N f""" Unknown `rep_method='{rep_method}'`. Use one of: {'random', 'pagerank'}. 
- """) + """ + ) return rep_indices -def format_partitioned_graph(C, p, part, rep_indices, F=None, M=None, - alpha=1., nx=None): +def format_partitioned_graph( + C, p, part, rep_indices, F=None, M=None, alpha=1.0, nx=None +): r""" Format an attributed graph :math:`(\mathbf{C}, \mathbf{F}, \mathbf{p})` with structure matrix :math:`(\mathbf{C} \in R^{n \times n}`, feature matrix @@ -464,16 +513,17 @@ def format_partitioned_graph(C, p, part, rep_indices, F=None, M=None, nx = get_backend(*arr) - if alpha != 1.: + if alpha != 1.0: if (M is None) or (F is None): raise ValueError( f""" `alpha == {alpha} != 1` but features information is not properly provided. - """) + """ + ) CR = C[rep_indices, :][:, rep_indices] - if alpha != 1.: + if alpha != 1.0: C_new = alpha * C + (1 - alpha) * M else: C_new = C @@ -488,7 +538,6 @@ def format_partitioned_graph(C, p, part, rep_indices, F=None, M=None, list_p.append(p[indices]) if F is None: - return CR, list_R, list_p else: FR = F[rep_indices, :] @@ -497,10 +546,27 @@ def format_partitioned_graph(C, p, part, rep_indices, F=None, M=None, def quantized_fused_gromov_wasserstein( - C1, C2, npart1, npart2, p=None, q=None, C1_aux=None, C2_aux=None, - F1=None, F2=None, alpha=1., part_method='fluid', - rep_method='random', log=False, armijo=False, max_iter=1e4, - tol_rel=1e-9, tol_abs=1e-9, random_state=0, **kwargs): + C1, + C2, + npart1, + npart2, + p=None, + q=None, + C1_aux=None, + C2_aux=None, + F1=None, + F2=None, + alpha=1.0, + part_method="fluid", + rep_method="random", + log=False, + armijo=False, + max_iter=1e4, + tol_rel=1e-9, + tol_abs=1e-9, + random_state=0, + **kwargs, +): r""" Returns the quantized Fused Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, @@ -626,7 +692,9 @@ def quantized_fused_gromov_wasserstein( Quantized gromov-wasserstein. ECML PKDD 2021. Springer International Publishing. """ - if (part_method in ['fluid', 'louvain', 'fluid_fused', 'louvain_fused'] or (rep_method in ['pagerank', 'pagerank_fused'])): + if part_method in ["fluid", "louvain", "fluid_fused", "louvain_fused"] or ( + rep_method in ["pagerank", "pagerank_fused"] + ): if not networkx_import: warnings.warn( f""" @@ -635,10 +703,10 @@ def quantized_fused_gromov_wasserstein( default methods. Consider installing Networkx to fix this. """ ) - part_method = 'random' - rep_method = 'random' + part_method = "random" + rep_method = "random" - if (part_method in ['spectral', 'spectral_fused']) and (not sklearn_import): + if (part_method in ["spectral", "spectral_fused"]) and (not sklearn_import): warnings.warn( f""" Scikit-learn is not installed, so part_method={part_method} and/or @@ -646,16 +714,17 @@ def quantized_fused_gromov_wasserstein( default methods. Consider installing Scikit-learn to fix this. """ ) - part_method = 'random' - rep_method = 'random' + part_method = "random" + rep_method = "random" - if (('fused' in part_method) or ('fused' in rep_method) or (part_method == 'FGW')): + if ("fused" in part_method) or ("fused" in rep_method) or (part_method == "FGW"): if (F1 is None) or (F2 is None): raise ValueError( f""" `part_method='{part_method}'` and/or `rep_method='{rep_method}'` require feature matrices which are not provided as inputs. 
- """) + """ + ) arr = [C1, C2] if C1_aux is not None: @@ -684,22 +753,29 @@ def quantized_fused_gromov_wasserstein( DF1 = None DF2 = None # compute attributed graph partitions potentially using the auxiliary structure - if 'fused' in part_method: - + if "fused" in part_method: DF1 = dist(F1, F1) DF2 = dist(F2, F2) C1_new = alpha * C1_aux + (1 - alpha) * DF1 C2_new = alpha * C2_aux + (1 - alpha) * DF2 part_method_ = part_method[:-6] - part1 = get_graph_partition(C1_new, npart1, part_method_, random_state=random_state, nx=nx) - part2 = get_graph_partition(C2_new, npart2, part_method_, random_state=random_state, nx=nx) + part1 = get_graph_partition( + C1_new, npart1, part_method_, random_state=random_state, nx=nx + ) + part2 = get_graph_partition( + C2_new, npart2, part_method_, random_state=random_state, nx=nx + ) else: - part1 = get_graph_partition(C1_aux, npart1, part_method, F1, alpha, random_state, nx) - part2 = get_graph_partition(C2_aux, npart2, part_method, F2, alpha, random_state, nx) + part1 = get_graph_partition( + C1_aux, npart1, part_method, F1, alpha, random_state, nx + ) + part2 = get_graph_partition( + C2_aux, npart2, part_method, F2, alpha, random_state, nx + ) - if 'fused' in rep_method: + if "fused" in rep_method: if DF1 is None: DF1 = dist(F1, F1) DF2 = dist(F2, F2) @@ -708,17 +784,29 @@ def quantized_fused_gromov_wasserstein( rep_method_ = rep_method[:-6] - rep_indices1 = get_graph_representants(C1_new, part1, rep_method_, random_state, nx) - rep_indices2 = get_graph_representants(C2_new, part2, rep_method_, random_state, nx) + rep_indices1 = get_graph_representants( + C1_new, part1, rep_method_, random_state, nx + ) + rep_indices2 = get_graph_representants( + C2_new, part2, rep_method_, random_state, nx + ) else: - rep_indices1 = get_graph_representants(C1_aux, part1, rep_method, random_state, nx) - rep_indices2 = get_graph_representants(C2_aux, part2, rep_method, random_state, nx) + rep_indices1 = get_graph_representants( + C1_aux, part1, rep_method, random_state, nx + ) + rep_indices2 = get_graph_representants( + C2_aux, part2, rep_method, random_state, nx + ) # format partitions over (C1, F1) and (C2, F2) if (F1 is None) and (F2 is None): - CR1, list_R1, list_p1 = format_partitioned_graph(C1, p, part1, rep_indices1, nx=nx) - CR2, list_R2, list_p2 = format_partitioned_graph(C2, q, part2, rep_indices2, nx=nx) + CR1, list_R1, list_p1 = format_partitioned_graph( + C1, p, part1, rep_indices1, nx=nx + ) + CR2, list_R2, list_p2 = format_partitioned_graph( + C2, q, part2, rep_indices2, nx=nx + ) MR = None else: @@ -726,31 +814,49 @@ def quantized_fused_gromov_wasserstein( DF1 = dist(F1, F1) DF2 = dist(F2, F2) - CR1, list_R1, list_p1, FR1 = format_partitioned_graph(C1, p, part1, rep_indices1, F1, DF1, alpha, nx) - CR2, list_R2, list_p2, FR2 = format_partitioned_graph(C2, q, part2, rep_indices2, F2, DF2, alpha, nx) + CR1, list_R1, list_p1, FR1 = format_partitioned_graph( + C1, p, part1, rep_indices1, F1, DF1, alpha, nx + ) + CR2, list_R2, list_p2, FR2 = format_partitioned_graph( + C2, q, part2, rep_indices2, F2, DF2, alpha, nx + ) MR = dist(FR1, FR2) # call to partitioned quantized fused gromov-wasserstein solver res = quantized_fused_gromov_wasserstein_partitioned( - CR1, CR2, list_R1, list_R2, list_p1, list_p2, MR, alpha, build_OT=True, - log=log, armijo=armijo, max_iter=max_iter, tol_rel=tol_rel, - tol_abs=tol_abs, nx=nx, **kwargs) + CR1, + CR2, + list_R1, + list_R2, + list_p1, + list_p2, + MR, + alpha, + build_OT=True, + log=log, + armijo=armijo, + max_iter=max_iter, + 
tol_rel=tol_rel, + tol_abs=tol_abs, + nx=nx, + **kwargs, + ) if log: T_global, Ts_local, T, log_ = res # compute the transport cost on structures - constC, hC1, hC2 = init_matrix(C1, C2, p, q, 'square_loss', nx) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, "square_loss", nx) structure_cost = gwloss(constC, hC1, hC2, T, nx) - if alpha != 1.: + if alpha != 1.0: M = dist(F1, F2) feature_cost = nx.sum(M * T) else: - feature_cost = 0. + feature_cost = 0.0 - log_['qFGW_dist'] = alpha * structure_cost + (1 - alpha) * feature_cost + log_["qFGW_dist"] = alpha * structure_cost + (1 - alpha) * feature_cost return T_global, Ts_local, T, log_ else: @@ -760,7 +866,8 @@ def quantized_fused_gromov_wasserstein( def get_partition_and_representants_samples( - X, npart, method='kmeans', random_state=0, nx=None): + X, npart, method="kmeans", random_state=0, nx=None +): r""" Compute `npart` partitions and representants over samples :math:`\mathbf{X} \in R^{n \times d}` using either a random or a kmeans algorithm. @@ -807,7 +914,7 @@ def get_partition_and_representants_samples( warnings.warn( "Requested number of partitions higher than the number of nodes" "hence we enforce each node to be a partition.", - stacklevel=2 + stacklevel=2, ) part = nx.arange(n) @@ -818,7 +925,7 @@ def get_partition_and_representants_samples( part = nx.zeros(n) rep_indices = [random.choice(nx.arange(n))] - elif method == 'random': + elif method == "random": # randomly partition the space random.seed(random_state) part = list_to_array(random.choices(np.arange(npart), k=X.shape[0])) @@ -831,7 +938,7 @@ def get_partition_and_representants_samples( indices = nx.where(part == part_id)[0] rep_indices.append(random.choice(indices)) - elif method == 'kmeans': + elif method == "kmeans": X = nx.to_numpy(X0) km = KMeans(n_clusters=npart, random_state=random_state).fit(X) part = nx.from_numpy(km.labels_, type_as=X0) @@ -847,13 +954,13 @@ def get_partition_and_representants_samples( raise ValueError( f""" Unknown `method='{method}'`. Use one of: {'random', 'kmeans'} - """) + """ + ) return part, rep_indices -def format_partitioned_samples( - X, p, part, rep_indices, F=None, alpha=1., nx=None): +def format_partitioned_samples(X, p, part, rep_indices, F=None, alpha=1.0, nx=None): r""" Format an attributed graph :math:`(\mathbf{D}(\mathbf{X}), \mathbf{F}, \mathbf{p})` with euclidean structure matrix :math:`(\mathbf{D}(\mathbf{X}) \in R^{n \times n}`, @@ -906,12 +1013,13 @@ def format_partitioned_samples( nx = get_backend(*arr) - if alpha != 1.: + if alpha != 1.0: if F is None: raise ValueError( f""" `alpha == {alpha} != 1` but features information is not properly provided. - """) + """ + ) XR = X[rep_indices, :] CR = dist(XR, XR) @@ -927,13 +1035,12 @@ def format_partitioned_samples( if alpha != 1: features_R = dist(F[indices], F[rep_indices[id_]][None, :]) else: - features_R = 0. 
+ features_R = 0.0 list_R.append(alpha * structure_R + (1 - alpha) * features_R) list_p.append(p[indices]) if F is None: - return CR, list_R, list_p else: FR = F[rep_indices, :] @@ -942,9 +1049,24 @@ def format_partitioned_samples( def quantized_fused_gromov_wasserstein_samples( - X1, X2, npart1, npart2, p=None, q=None, F1=None, F2=None, alpha=1., - method='kmeans', log=False, armijo=False, max_iter=1e4, - tol_rel=1e-9, tol_abs=1e-9, random_state=0, **kwargs): + X1, + X2, + npart1, + npart2, + p=None, + q=None, + F1=None, + F2=None, + alpha=1.0, + method="kmeans", + log=False, + armijo=False, + max_iter=1e4, + tol_rel=1e-9, + tol_abs=1e-9, + random_state=0, + **kwargs, +): r""" Returns the quantized Fused Gromov-Wasserstein transport between samples endowed with their respective euclidean geometry :math:`(\mathbf{D}(\mathbf{X_1}), \mathbf{F_1}, \mathbf{p})` @@ -1054,7 +1176,7 @@ def quantized_fused_gromov_wasserstein_samples( """ - if (method in ['kmeans', 'kmeans_fused']) and (not sklearn_import): + if (method in ["kmeans", "kmeans_fused"]) and (not sklearn_import): warnings.warn( f""" Scikit-learn is not installed, so method={method} cannot be used @@ -1062,13 +1184,14 @@ def quantized_fused_gromov_wasserstein_samples( Scikit-learn to fix this. """ ) - method = 'random' + method = "random" - if ('fused' in method) and ((F1 is None) or (F2 is None)): + if ("fused" in method) and ((F1 is None) or (F2 is None)): raise ValueError( f""" `method='{method}'` requires feature matrices which are not provided as inputs. - """) + """ + ) arr = [X1, X2] if p is not None: @@ -1087,7 +1210,7 @@ def quantized_fused_gromov_wasserstein_samples( nx = get_backend(*arr) # compute attributed partitions and representants - if ('fused' in method) and (alpha != 1.): + if ("fused" in method) and (alpha != 1.0): X1_new = nx.concatenate([alpha * X1, (1 - alpha) * F1], axis=1) X2_new = nx.concatenate([alpha * X2, (1 - alpha) * F2], axis=1) method_ = method[:-6] @@ -1095,32 +1218,52 @@ def quantized_fused_gromov_wasserstein_samples( X1_new, X2_new = X1, X2 method_ = method part1, rep_indices1 = get_partition_and_representants_samples( - X1_new, npart1, method_, random_state, nx) + X1_new, npart1, method_, random_state, nx + ) part2, rep_indices2 = get_partition_and_representants_samples( - X2_new, npart2, method_, random_state, nx) + X2_new, npart2, method_, random_state, nx + ) # format partitions over (C1, F1) and (C2, F2) if (F1 is None) and (F2 is None): CR1, list_R1, list_p1 = format_partitioned_samples( - X1, p, part1, rep_indices1, nx=nx) + X1, p, part1, rep_indices1, nx=nx + ) CR2, list_R2, list_p2 = format_partitioned_samples( - X2, q, part2, rep_indices2, nx=nx) + X2, q, part2, rep_indices2, nx=nx + ) MR = None else: CR1, list_R1, list_p1, FR1 = format_partitioned_samples( - X1, p, part1, rep_indices1, F1, alpha, nx) + X1, p, part1, rep_indices1, F1, alpha, nx + ) CR2, list_R2, list_p2, FR2 = format_partitioned_samples( - X2, q, part2, rep_indices2, F2, alpha, nx) + X2, q, part2, rep_indices2, F2, alpha, nx + ) MR = dist(FR1, FR2) # call to partitioned quantized fused gromov-wasserstein solver res = quantized_fused_gromov_wasserstein_partitioned( - CR1, CR2, list_R1, list_R2, list_p1, list_p2, MR, alpha, build_OT=True, - log=log, armijo=armijo, max_iter=max_iter, tol_rel=tol_rel, - tol_abs=tol_abs, nx=nx, **kwargs) + CR1, + CR2, + list_R1, + list_R2, + list_p1, + list_p2, + MR, + alpha, + build_OT=True, + log=log, + armijo=armijo, + max_iter=max_iter, + tol_rel=tol_rel, + tol_abs=tol_abs, + nx=nx, + 
**kwargs, + ) if log: T_global, Ts_local, T, log_ = res @@ -1129,16 +1272,16 @@ def quantized_fused_gromov_wasserstein_samples( C2 = dist(X2, X2) # compute the transport cost on structures - constC, hC1, hC2 = init_matrix(C1, C2, p, q, 'square_loss', nx) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, "square_loss", nx) structure_cost = gwloss(constC, hC1, hC2, T, nx) - if alpha != 1.: + if alpha != 1.0: M = dist(F1, F2) feature_cost = nx.sum(M * T) else: - feature_cost = 0. + feature_cost = 0.0 - log_['qFGW_dist'] = alpha * structure_cost + (1 - alpha) * feature_cost + log_["qFGW_dist"] = alpha * structure_cost + (1 - alpha) * feature_cost return T_global, Ts_local, T, log_ else: diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index c509a1046..05ad8b25c 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -11,28 +11,41 @@ import numpy as np -from ..utils import ( - list_to_array, unif, dist, UndefinedParameter -) +from ..utils import list_to_array, unif, dist, UndefinedParameter from ..optim import semirelaxed_cg, solve_1d_linesearch_quad from ..backend import get_backend from ._utils import ( - init_matrix_semirelaxed, gwloss, gwggrad, - update_barycenter_structure, update_barycenter_feature, + init_matrix_semirelaxed, + gwloss, + gwggrad, + update_barycenter_structure, + update_barycenter_feature, semirelaxed_init_plan, ) try: from sklearn.cluster import KMeans + sklearn_import = True except ImportError: sklearn_import = False def semirelaxed_gromov_wasserstein( - C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, - max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, random_state=0, **kwargs): + C1, + C2, + p=None, + loss_fun="square_loss", + symmetric=None, + log=False, + G0=None, + max_iter=1e4, + tol_rel=1e-9, + tol_abs=1e-9, + random_state=0, + **kwargs, +): r""" Returns the semi-relaxed Gromov-Wasserstein divergence transport from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` (see [48]). 
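A usage sketch for the semi-relaxed solver introduced above (toy data, illustration only): only the source marginal `p` is prescribed, since the target marginal is optimized by the solver.

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(10, 2), rng.randn(15, 2)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    p = ot.unif(10)                       # only the source marginal is given

    T, log = ot.gromov.semirelaxed_gromov_wasserstein(
        C1, C2, p, loss_fun="square_loss", log=True
    )
    print(log["srgw_dist"], T.sum(0))     # T.sum(0) is the estimated target marginal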
@@ -127,14 +140,17 @@ def semirelaxed_gromov_wasserstein( nx = get_backend(*arr) if symmetric is None: - symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose( + C2, C2.T, atol=1e-10 + ) if G0 is None: q = unif(C2.shape[0], type_as=p) G0 = nx.outer(p, q) elif isinstance(G0, str): G0 = semirelaxed_init_plan( - C1, C2, p, method=G0, random_state=random_state, nx=nx) + C1, C2, p, method=G0, random_state=random_state, nx=nx + ) q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) @@ -151,6 +167,7 @@ def f(G): return gwloss(constC + marginal_product, hC1, hC2, G, nx) if symmetric: + def df(G): qG = nx.sum(G, 0) marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) @@ -162,23 +179,76 @@ def df(G): qG = nx.sum(G, 0) marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t)) marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) - return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + return 0.5 * ( + gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx) + ) def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): - return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, ones_p, M=0., reg=1., fC2t=fC2t, nx=nx, **kwargs) + return solve_semirelaxed_gromov_linesearch( + G, + deltaG, + cost_G, + hC1, + hC2, + ones_p, + M=0.0, + reg=1.0, + fC2t=fC2t, + nx=nx, + **kwargs, + ) if log: - res, log = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) - log['srgw_dist'] = log['loss'][-1] + res, log = semirelaxed_cg( + p, + q, + 0.0, + 1.0, + f, + df, + G0, + line_search, + log=True, + numItermax=max_iter, + stopThr=tol_rel, + stopThr2=tol_abs, + **kwargs, + ) + log["srgw_dist"] = log["loss"][-1] return res, log else: - return semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + return semirelaxed_cg( + p, + q, + 0.0, + 1.0, + f, + df, + G0, + line_search, + log=False, + numItermax=max_iter, + stopThr=tol_rel, + stopThr2=tol_abs, + **kwargs, + ) def semirelaxed_gromov_wasserstein2( - C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, - G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, random_state=0, - **kwargs): + C1, + C2, + p=None, + loss_fun="square_loss", + symmetric=None, + log=False, + G0=None, + max_iter=1e4, + tol_rel=1e-9, + tol_abs=1e-9, + random_state=0, + **kwargs, +): r""" Returns the semi-relaxed Gromov-Wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` (see [48]). 
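And the corresponding value function, sketched the same way (illustration only, not part of the patch):

    import numpy as np
    import ot

    rng = np.random.RandomState(1)
    x1, x2 = rng.randn(8, 2), rng.randn(11, 2)
    C1, C2 = ot.dist(x1, x1), ot.dist(x2, x2)
    p = ot.unif(8)

    srgw, log = ot.gromov.semirelaxed_gromov_wasserstein2(
        C1, C2, p, loss_fun="square_loss", log=True
    )
    print(srgw, log["T"].shape)           # coupling reaching the returned value

Since the hunk below registers gradients with respect to `(C1, C2)` via `nx.set_gradients`, the returned value can also be backpropagated with the torch backend, in the same way as the FGW sketch earlier.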
@@ -271,21 +341,33 @@ def semirelaxed_gromov_wasserstein2( p = unif(C1.shape[0], type_as=C1) T, log_srgw = semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun, symmetric, log=True, G0=G0, - max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, - random_state=random_state, **kwargs) + C1, + C2, + p, + loss_fun, + symmetric, + log=True, + G0=G0, + max_iter=max_iter, + tol_rel=tol_rel, + tol_abs=tol_abs, + random_state=random_state, + **kwargs, + ) q = nx.sum(T, 0) - log_srgw['T'] = T - srgw = log_srgw['srgw_dist'] + log_srgw["T"] = T + srgw = log_srgw["srgw_dist"] - if loss_fun == 'square_loss': + if loss_fun == "square_loss": gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) - elif loss_fun == 'kl_loss': - gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + elif loss_fun == "kl_loss": + gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot( + T, nx.dot(nx.log(C2 + 1e-15), T.T) + ) + gC2 = -nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2)) @@ -296,9 +378,21 @@ def semirelaxed_gromov_wasserstein2( def semirelaxed_fused_gromov_wasserstein( - M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, - G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, - random_state=0, **kwargs): + M, + C1, + C2, + p=None, + loss_fun="square_loss", + symmetric=None, + alpha=0.5, + G0=None, + log=False, + max_iter=1e4, + tol_rel=1e-9, + tol_abs=1e-9, + random_state=0, + **kwargs, +): r""" Computes the semi-relaxed Fused Gromov-Wasserstein transport between two graphs (see [48]). @@ -399,14 +493,17 @@ def semirelaxed_fused_gromov_wasserstein( nx = get_backend(*arr) if symmetric is None: - symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose( + C2, C2.T, atol=1e-10 + ) if G0 is None: q = unif(C2.shape[0], type_as=p) G0 = nx.outer(p, q) elif isinstance(G0, str): G0 = semirelaxed_init_plan( - C1, C2, p, M, alpha, G0, random_state=random_state, nx=nx) + C1, C2, p, M, alpha, G0, random_state=random_state, nx=nx + ) q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) @@ -423,6 +520,7 @@ def f(G): return gwloss(constC + marginal_product, hC1, hC2, G, nx) if symmetric: + def df(G): qG = nx.sum(G, 0) marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) @@ -434,24 +532,78 @@ def df(G): qG = nx.sum(G, 0) marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t)) marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) - return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + return 0.5 * ( + gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx) + ) def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return solve_semirelaxed_gromov_linesearch( - G, deltaG, cost_G, hC1, hC2, ones_p, M=(1 - alpha) * M, reg=alpha, fC2t=fC2t, nx=nx, **kwargs) + G, + deltaG, + cost_G, + hC1, + hC2, + ones_p, + M=(1 - alpha) * M, + reg=alpha, + fC2t=fC2t, + nx=nx, + **kwargs, + ) if log: - res, log = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) - log['srfgw_dist'] = log['loss'][-1] + res, log = semirelaxed_cg( + p, + q, + (1 - alpha) * M, + alpha, + f, + df, + G0, + 
line_search, + log=True, + numItermax=max_iter, + stopThr=tol_rel, + stopThr2=tol_abs, + **kwargs, + ) + log["srfgw_dist"] = log["loss"][-1] return res, log else: - return semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) + return semirelaxed_cg( + p, + q, + (1 - alpha) * M, + alpha, + f, + df, + G0, + line_search, + log=False, + numItermax=max_iter, + stopThr=tol_rel, + stopThr2=tol_abs, + **kwargs, + ) def semirelaxed_fused_gromov_wasserstein2( - M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, - G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, - random_state=0, **kwargs): + M, + C1, + C2, + p=None, + loss_fun="square_loss", + symmetric=None, + alpha=0.5, + G0=None, + log=False, + max_iter=1e4, + tol_rel=1e-9, + tol_abs=1e-9, + random_state=0, + **kwargs, +): r""" Computes the semi-relaxed FGW divergence between two graphs (see [48]). @@ -551,32 +703,49 @@ def semirelaxed_fused_gromov_wasserstein2( p = unif(C1.shape[0], type_as=C1) T, log_fgw = semirelaxed_fused_gromov_wasserstein( - M, C1, C2, p, loss_fun, symmetric, alpha, G0, log=True, - max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, - random_state=random_state, **kwargs) + M, + C1, + C2, + p, + loss_fun, + symmetric, + alpha, + G0, + log=True, + max_iter=max_iter, + tol_rel=tol_rel, + tol_abs=tol_abs, + random_state=random_state, + **kwargs, + ) q = nx.sum(T, 0) - srfgw_dist = log_fgw['srfgw_dist'] - log_fgw['T'] = T - log_fgw['lin_loss'] = nx.sum(M * T) * (1 - alpha) - log_fgw['quad_loss'] = srfgw_dist - log_fgw['lin_loss'] + srfgw_dist = log_fgw["srfgw_dist"] + log_fgw["T"] = T + log_fgw["lin_loss"] = nx.sum(M * T) * (1 - alpha) + log_fgw["quad_loss"] = srfgw_dist - log_fgw["lin_loss"] - if loss_fun == 'square_loss': + if loss_fun == "square_loss": gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) - elif loss_fun == 'kl_loss': - gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) - gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) + elif loss_fun == "kl_loss": + gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot( + T, nx.dot(nx.log(C2 + 1e-15), T.T) + ) + gC2 = -nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q) if isinstance(alpha, int) or isinstance(alpha, float): - srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M), - (alpha * gC1, alpha * gC2, (1 - alpha) * T)) + srfgw_dist = nx.set_gradients( + srfgw_dist, (C1, C2, M), (alpha * gC1, alpha * gC2, (1 - alpha) * T) + ) else: lin_term = nx.sum(T * M) srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha - srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha), - (alpha * gC1, alpha * gC2, (1 - alpha) * T, - srgw_term - lin_term)) + srfgw_dist = nx.set_gradients( + srfgw_dist, + (C1, C2, M, alpha), + (alpha * gC1, alpha * gC2, (1 - alpha) * T, srgw_term - lin_term), + ) if log: return srfgw_dist, log_fgw @@ -584,8 +753,21 @@ def semirelaxed_fused_gromov_wasserstein2( return srfgw_dist -def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, - M, reg, fC2t=None, alpha_min=None, alpha_max=None, nx=None, **kwargs): +def solve_semirelaxed_gromov_linesearch( + G, + deltaG, + cost_G, + C1, + C2, + ones_p, + M, + reg, + fC2t=None, + alpha_min=None, + alpha_max=None, + nx=None, + **kwargs, +): """ Solve the linesearch in the Conditional Gradient iterations for the semi-relaxed 
Gromov-Wasserstein divergence. @@ -648,27 +830,40 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, qG, qdeltaG = nx.sum(G, 0), nx.sum(deltaG, 0) dot = nx.dot(nx.dot(C1, deltaG), C2.T) if fC2t is None: - fC2t = C2.T ** 2 + fC2t = C2.T**2 dot_qG = nx.dot(nx.outer(ones_p, qG), fC2t) dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), fC2t) a = reg * nx.sum((dot_qdeltaG - dot) * deltaG) - b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - dot) * G) + nx.sum((dot_qG - nx.dot(nx.dot(C1, G), C2.T)) * deltaG)) + b = nx.sum(M * deltaG) + reg * ( + nx.sum((dot_qdeltaG - dot) * G) + + nx.sum((dot_qG - nx.dot(nx.dot(C1, G), C2.T)) * deltaG) + ) alpha = solve_1d_linesearch_quad(a, b) if alpha_min is not None or alpha_max is not None: alpha = np.clip(alpha, alpha_min, alpha_max) # the new cost can be deduced from the line search quadratic function - cost_G = cost_G + a * (alpha ** 2) + b * alpha + cost_G = cost_G + a * (alpha**2) + b * alpha return alpha, 1, cost_G def entropic_semirelaxed_gromov_wasserstein( - C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, - G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, - random_state=0): + C1, + C2, + p=None, + loss_fun="square_loss", + epsilon=0.1, + symmetric=None, + G0=None, + max_iter=1e4, + tol=1e-9, + log=False, + verbose=False, + random_state=0, +): r""" Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence transport plan from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` @@ -711,7 +906,7 @@ def entropic_semirelaxed_gromov_wasserstein( symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. - Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). 
verbose : bool, optional Print information along iterations G0: array-like of shape (ns,nt) or string, optional @@ -760,14 +955,17 @@ def entropic_semirelaxed_gromov_wasserstein( nx = get_backend(*arr) if symmetric is None: - symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose( + C2, C2.T, atol=1e-10 + ) if G0 is None: q = unif(C2.shape[0], type_as=p) G0 = nx.outer(p, q) elif isinstance(G0, str): G0 = semirelaxed_init_plan( - C1, C2, p, method=G0, random_state=random_state, nx=nx) + C1, C2, p, method=G0, random_state=random_state, nx=nx + ) q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) @@ -779,6 +977,7 @@ def entropic_semirelaxed_gromov_wasserstein( ones_p = nx.ones(p.shape[0], type_as=p) if symmetric: + def df(G): qG = nx.sum(G, 0) marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) @@ -790,20 +989,22 @@ def df(G): qG = nx.sum(G, 0) marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t)) marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) - return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + return 0.5 * ( + gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx) + ) cpt = 0 err = 1e15 G = G0 if log: - log = {'err': []} - - while (err > tol and cpt < max_iter): + log = {"err": []} + while err > tol and cpt < max_iter: Gprev = G # compute the kernel - K = G * nx.exp(- df(G) / epsilon) + K = G * nx.exp(-df(G) / epsilon) scaling = p / nx.sum(K, 1) G = nx.reshape(scaling, (-1, 1)) * K if cpt % 10 == 0: @@ -812,29 +1013,39 @@ def df(G): err = nx.norm(G - Gprev) if log: - log['err'].append(err) + log["err"].append(err) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err)) cpt += 1 if log: qG = nx.sum(G, 0) marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) - log['srgw_dist'] = gwloss(constC + marginal_product, hC1, hC2, G, nx) + log["srgw_dist"] = gwloss(constC + marginal_product, hC1, hC2, G, nx) return G, log else: return G def entropic_semirelaxed_gromov_wasserstein2( - C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, - G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, - random_state=0, **kwargs): + C1, + C2, + p=None, + loss_fun="square_loss", + epsilon=0.1, + symmetric=None, + G0=None, + max_iter=1e4, + tol=1e-9, + log=False, + verbose=False, + random_state=0, + **kwargs, +): r""" Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` @@ -879,7 +1090,7 @@ def entropic_semirelaxed_gromov_wasserstein2( symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. - Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). verbose : bool, optional Print information along iterations G0: array-like of shape (ns,nt) or string, optional @@ -916,21 +1127,44 @@ def entropic_semirelaxed_gromov_wasserstein2( International Conference on Learning Representations (ICLR), 2022. 
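A minimal usage sketch of the entropic semi-relaxed wrapper whose signature is expanded above (the structure matrices and epsilon are illustrative, and the import assumes the usual public export under `ot.gromov`):

>>> import numpy as np
>>> from ot.gromov import entropic_semirelaxed_gromov_wasserstein2
>>> # two symmetric structure matrices with different numbers of nodes
>>> C1 = np.array([[0.0, 1.0], [1.0, 0.0]])
>>> C2 = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
>>> # leaving p=None gives uniform source weights; the target marginal stays free
>>> srgw, log = entropic_semirelaxed_gromov_wasserstein2(
...     C1, C2, loss_fun="square_loss", epsilon=0.1, log=True, random_state=0
... )
>>> T = log["T"]          # (2, 3) transport plan
>>> q = T.sum(axis=0)     # relaxed target marginal induced by T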
""" T, log_srgw = entropic_semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun, epsilon, symmetric, G0, max_iter, tol, - log=True, verbose=verbose, random_state=random_state) - - log_srgw['T'] = T + C1, + C2, + p, + loss_fun, + epsilon, + symmetric, + G0, + max_iter, + tol, + log=True, + verbose=verbose, + random_state=random_state, + ) + + log_srgw["T"] = T if log: - return log_srgw['srgw_dist'], log_srgw + return log_srgw["srgw_dist"], log_srgw else: - return log_srgw['srgw_dist'] + return log_srgw["srgw_dist"] def entropic_semirelaxed_fused_gromov_wasserstein( - M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1, - alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, - random_state=0): + M, + C1, + C2, + p=None, + loss_fun="square_loss", + symmetric=None, + epsilon=0.1, + alpha=0.5, + G0=None, + max_iter=1e4, + tol=1e-9, + log=False, + verbose=False, + random_state=0, +): r""" Computes the entropic-regularized semi-relaxed FGW transport between two graphs (see :ref:`[48] `) estimated using a Mirror Descent algorithm following the KL geometry. @@ -976,7 +1210,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein( symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. - Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) G0: array-like of shape (ns,nt) or string, optional @@ -1025,14 +1259,17 @@ def entropic_semirelaxed_fused_gromov_wasserstein( nx = get_backend(*arr) if symmetric is None: - symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose( + C2, C2.T, atol=1e-10 + ) if G0 is None: q = unif(C2.shape[0], type_as=p) G0 = nx.outer(p, q) elif isinstance(G0, str): G0 = semirelaxed_init_plan( - C1, C2, p, M, alpha, G0, random_state=random_state, nx=nx) + C1, C2, p, M, alpha, G0, random_state=random_state, nx=nx + ) q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) @@ -1044,6 +1281,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein( ones_p = nx.ones(p.shape[0], type_as=p) dM = (1 - alpha) * M if symmetric: + def df(G): qG = nx.sum(G, 0) marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) @@ -1055,20 +1293,27 @@ def df(G): qG = nx.sum(G, 0) marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t)) marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) - return 0.5 * alpha * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + dM + return ( + 0.5 + * alpha + * ( + gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx) + ) + + dM + ) cpt = 0 err = 1e15 G = G0 if log: - log = {'err': []} - - while (err > tol and cpt < max_iter): + log = {"err": []} + while err > tol and cpt < max_iter: Gprev = G # compute the kernel - K = G * nx.exp(- df(G) / epsilon) + K = G * nx.exp(-df(G) / epsilon) scaling = p / nx.sum(K, 1) G = nx.reshape(scaling, (-1, 1)) * K if cpt % 10 == 0: @@ -1077,31 +1322,42 @@ def df(G): err = nx.norm(G - Gprev) if log: - log['err'].append(err) + log["err"].append(err) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print("{:5s}|{:12s}".format("It.", "Err") + 
"\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err)) cpt += 1 if log: qG = nx.sum(G, 0) marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) - log['lin_loss'] = nx.sum(M * G) * (1 - alpha) - log['quad_loss'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx) - log['srfgw_dist'] = log['lin_loss'] + log['quad_loss'] + log["lin_loss"] = nx.sum(M * G) * (1 - alpha) + log["quad_loss"] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx) + log["srfgw_dist"] = log["lin_loss"] + log["quad_loss"] return G, log else: return G def entropic_semirelaxed_fused_gromov_wasserstein2( - M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1, - alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, - random_state=0): + M, + C1, + C2, + p=None, + loss_fun="square_loss", + symmetric=None, + epsilon=0.1, + alpha=0.5, + G0=None, + max_iter=1e4, + tol=1e-9, + log=False, + verbose=False, + random_state=0, +): r""" Computes the entropic-regularized semi-relaxed FGW divergence between two graphs (see :ref:`[48] `) estimated using a Mirror Descent algorithm following the KL geometry. @@ -1147,7 +1403,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. - Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) G0: array-like of shape (ns,nt) or string, optional @@ -1185,22 +1441,48 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( International Conference on Learning Representations (ICLR), 2022. """ T, log_srfgw = entropic_semirelaxed_fused_gromov_wasserstein( - M, C1, C2, p, loss_fun, symmetric, epsilon, alpha, G0, max_iter, tol, - log=True, verbose=verbose, random_state=random_state) - - log_srfgw['T'] = T + M, + C1, + C2, + p, + loss_fun, + symmetric, + epsilon, + alpha, + G0, + max_iter, + tol, + log=True, + verbose=verbose, + random_state=random_state, + ) + + log_srfgw["T"] = T if log: - return log_srfgw['srfgw_dist'], log_srfgw + return log_srfgw["srfgw_dist"], log_srfgw else: - return log_srfgw['srfgw_dist'] + return log_srfgw["srfgw_dist"] def semirelaxed_gromov_barycenters( - N, Cs, ps=None, lambdas=None, loss_fun='square_loss', - symmetric=True, max_iter=1000, tol=1e-9, - stop_criterion='barycenter', warmstartT=False, verbose=False, - log=False, init_C=None, G0='product', random_state=None, **kwargs): + N, + Cs, + ps=None, + lambdas=None, + loss_fun="square_loss", + symmetric=True, + max_iter=1000, + tol=1e-9, + stop_criterion="barycenter", + warmstartT=False, + verbose=False, + log=False, + init_C=None, + G0="product", + random_state=None, + **kwargs, +): r""" Returns the Semi-relaxed Gromov-Wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` @@ -1278,8 +1560,10 @@ def semirelaxed_gromov_barycenters( International Conference on Learning Representations (ICLR), 2022. """ - if stop_criterion not in ['barycenter', 'loss']: - raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") + if stop_criterion not in ["barycenter", "loss"]: + raise ValueError( + f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}." 
+ ) arr = [*Cs] if ps is not None: @@ -1299,20 +1583,28 @@ def semirelaxed_gromov_barycenters( # Initialization of transport plans and C (if not provided by user) if init_C is None: init_C = nx.zeros((N, N), type_as=Cs[0]) - if G0 in ['product', 'random_product', 'random']: - T = [semirelaxed_init_plan( - Cs[i], init_C, ps[i], method=G0, use_target=False, - random_state=random_state, nx=nx) for i in range(S)] - C = update_barycenter_structure( - T, Cs, lambdas, loss_fun=loss_fun, nx=nx) - - if G0 in ['product', 'random_product']: + if G0 in ["product", "random_product", "random"]: + T = [ + semirelaxed_init_plan( + Cs[i], + init_C, + ps[i], + method=G0, + use_target=False, + random_state=random_state, + nx=nx, + ) + for i in range(S) + ] + C = update_barycenter_structure(T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + + if G0 in ["product", "random_product"]: # initial structure is constant so we add a small random noise # to avoid getting stuck at init np.random.seed(random_state) noise = np.random.uniform(-0.01, 0.01, size=(N, N)) if symmetric: - noise = (noise + noise.T) / 2. + noise = (noise + noise.T) / 2.0 noise = nx.from_numpy(noise) C = C + noise @@ -1328,29 +1620,49 @@ def semirelaxed_gromov_barycenters( # then use it on graphs to expand for indices in [large_graphs_idx, small_graphs_idx]: if len(indices) > 0: - sub_T = [semirelaxed_init_plan( - Cs[i], init_C, ps[i], method=G0, use_target=False, - random_state=random_state, nx=nx) for i in indices] + sub_T = [ + semirelaxed_init_plan( + Cs[i], + init_C, + ps[i], + method=G0, + use_target=False, + random_state=random_state, + nx=nx, + ) + for i in indices + ] sub_Cs = [Cs[i] for i in indices] sub_lambdas = lambdas[indices] / nx.sum(lambdas[indices]) init_C = update_barycenter_structure( - sub_T, sub_Cs, sub_lambdas, loss_fun=loss_fun, nx=nx) + sub_T, sub_Cs, sub_lambdas, loss_fun=loss_fun, nx=nx + ) for i, idx in enumerate(indices): T[idx] = sub_T[i] list_init_C.append(init_C) if len(list_init_C) == 2: init_C = update_barycenter_structure( - T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + T, Cs, lambdas, loss_fun=loss_fun, nx=nx + ) C = init_C else: C = init_C - T = [semirelaxed_init_plan( - Cs[i], C, ps[i], method=G0, use_target=True, - random_state=random_state, nx=nx) for i in range(S)] - - if stop_criterion == 'barycenter': + T = [ + semirelaxed_init_plan( + Cs[i], + C, + ps[i], + method=G0, + use_target=True, + random_state=random_state, + nx=nx, + ) + for i in range(S) + ] + + if stop_criterion == "barycenter": inner_log = False else: inner_log = True @@ -1358,67 +1670,88 @@ def semirelaxed_gromov_barycenters( if log: log_ = {} - log_['err'] = [] - if stop_criterion == 'loss': - log_['loss'] = [] + log_["err"] = [] + if stop_criterion == "loss": + log_["loss"] = [] for cpt in range(max_iter): - - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": Cprev = C else: prev_loss = curr_loss # get transport plans if warmstartT: - res = [semirelaxed_gromov_wasserstein( - Cs[s], C, ps[s], loss_fun, symmetric, G0=T[s], - max_iter=max_iter, tol_rel=tol, tol_abs=0., log=inner_log, - verbose=verbose, **kwargs) - for s in range(S)] + res = [ + semirelaxed_gromov_wasserstein( + Cs[s], + C, + ps[s], + loss_fun, + symmetric, + G0=T[s], + max_iter=max_iter, + tol_rel=tol, + tol_abs=0.0, + log=inner_log, + verbose=verbose, + **kwargs, + ) + for s in range(S) + ] else: - res = [semirelaxed_gromov_wasserstein( - Cs[s], C, ps[s], loss_fun, symmetric, G0=G0, - max_iter=max_iter, tol_rel=tol, tol_abs=0., log=inner_log, - verbose=verbose, 
**kwargs) - for s in range(S)] + res = [ + semirelaxed_gromov_wasserstein( + Cs[s], + C, + ps[s], + loss_fun, + symmetric, + G0=G0, + max_iter=max_iter, + tol_rel=tol, + tol_abs=0.0, + log=inner_log, + verbose=verbose, + **kwargs, + ) + for s in range(S) + ] - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": T = res else: T = [output[0] for output in res] - curr_loss = np.sum([output[1]['srgw_dist'] for output in res]) + curr_loss = np.sum([output[1]["srgw_dist"] for output in res]) # update barycenters - p = nx.concatenate( - [nx.sum(T[s], 0)[None, :] for s in range(S)], axis=0) + p = nx.concatenate([nx.sum(T[s], 0)[None, :] for s in range(S)], axis=0) C = update_barycenter_structure(T, Cs, lambdas, p, loss_fun, nx=nx) # update convergence criterion - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": err = nx.norm(C - Cprev) if log: - log_['err'].append(err) + log_["err"].append(err) else: - err = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan + err = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0.0 else np.nan if log: - log_['loss'].append(curr_loss) - log_['err'].append(err) + log_["loss"].append(curr_loss) + log_["err"].append(err) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err)) if err <= tol: break if log: - log_['T'] = T - log_['p'] = p + log_["T"] = T + log_["p"] = p return C, log_ else: @@ -1426,11 +1759,29 @@ def semirelaxed_gromov_barycenters( def semirelaxed_fgw_barycenters( - N, Ys, Cs, ps=None, lambdas=None, alpha=0.5, fixed_structure=False, - fixed_features=False, p=None, loss_fun='square_loss', - symmetric=True, max_iter=100, tol=1e-9, stop_criterion='barycenter', - warmstartT=False, verbose=False, log=False, init_C=None, init_X=None, - G0='product', random_state=None, **kwargs): + N, + Ys, + Cs, + ps=None, + lambdas=None, + alpha=0.5, + fixed_structure=False, + fixed_features=False, + p=None, + loss_fun="square_loss", + symmetric=True, + max_iter=100, + tol=1e-9, + stop_criterion="barycenter", + warmstartT=False, + verbose=False, + log=False, + init_C=None, + init_X=None, + G0="product", + random_state=None, + **kwargs, +): r""" Returns the Semi-relaxed Fused Gromov-Wasserstein barycenters of `S` measurable networks with node features :math:`(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s)_{1 \leq s \leq S}` @@ -1523,8 +1874,10 @@ def semirelaxed_fgw_barycenters( "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ - if stop_criterion not in ['barycenter', 'loss']: - raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") + if stop_criterion not in ["barycenter", "loss"]: + raise ValueError( + f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}." 
+ ) arr = [*Cs, *Ys] if ps is not None: @@ -1543,51 +1896,65 @@ def semirelaxed_fgw_barycenters( if fixed_structure: if init_C is None: - raise UndefinedParameter( - 'If C is fixed it must be provided in init_C') + raise UndefinedParameter("If C is fixed it must be provided in init_C") else: C = init_C if fixed_features: if init_X is None: - raise UndefinedParameter( - 'If X is fixed it must be provided in init_X') + raise UndefinedParameter("If X is fixed it must be provided in init_X") else: X = init_X # Initialization of transport plans, C and X (if not provided by user) - if G0 in ['product', 'random_product', 'random']: + if G0 in ["product", "random_product", "random"]: # both init_X and init_C are simply deduced from transport plans # if not initialized if init_C is None: init_C = nx.zeros((N, N), type_as=Cs[0]) # to know the barycenter shape - T = [semirelaxed_init_plan( - Cs[i], init_C, ps[i], method=G0, use_target=False, - random_state=random_state, nx=nx) for i in range(S)] + T = [ + semirelaxed_init_plan( + Cs[i], + init_C, + ps[i], + method=G0, + use_target=False, + random_state=random_state, + nx=nx, + ) + for i in range(S) + ] - C = update_barycenter_structure( - T, Cs, lambdas, loss_fun=loss_fun, nx=nx) - if G0 in ['product', 'random_product']: + C = update_barycenter_structure(T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + if G0 in ["product", "random_product"]: # initial structure is constant so we add a small random noise # to avoid getting stuck at init np.random.seed(random_state) noise = np.random.uniform(-0.01, 0.01, size=(N, N)) if symmetric: - noise = (noise + noise.T) / 2. + noise = (noise + noise.T) / 2.0 noise = nx.from_numpy(noise) C = C + noise else: - T = [semirelaxed_init_plan( - Cs[i], init_C, ps[i], method=G0, use_target=False, - random_state=random_state, nx=nx) for i in range(S)] + T = [ + semirelaxed_init_plan( + Cs[i], + init_C, + ps[i], + method=G0, + use_target=False, + random_state=random_state, + nx=nx, + ) + for i in range(S) + ] C = init_C if init_X is None: - X = update_barycenter_feature( - T, Ys, lambdas, loss_fun=loss_fun, nx=nx) + X = update_barycenter_feature(T, Ys, lambdas, loss_fun=loss_fun, nx=nx) else: X = init_X @@ -1602,8 +1969,9 @@ def semirelaxed_fgw_barycenters( stacked_features = nx.concatenate(Ys, axis=0) if sklearn_import: stacked_features = nx.to_numpy(stacked_features) - km = KMeans(n_clusters=N, random_state=random_state, - n_init=1).fit(stacked_features) + km = KMeans(n_clusters=N, random_state=random_state, n_init=1).fit( + stacked_features + ) X = nx.from_numpy(km.cluster_centers_) else: raise ValueError( @@ -1614,7 +1982,7 @@ def semirelaxed_fgw_barycenters( Ms = [dist(Ys[s], X) for s in range(len(Ys))] - if (init_C is None): + if init_C is None: init_C = nx.zeros((N, N), type_as=Cs[0]) # relies on partitioning of inputs @@ -1629,14 +1997,26 @@ def semirelaxed_fgw_barycenters( # then use it on graphs to expand for indices in [large_graphs_idx, small_graphs_idx]: if len(indices) > 0: - sub_T = [semirelaxed_init_plan( - Cs[i], init_C, ps[i], Ms[i], alpha, method=G0, use_target=False, - random_state=random_state, nx=nx) for i in indices] + sub_T = [ + semirelaxed_init_plan( + Cs[i], + init_C, + ps[i], + Ms[i], + alpha, + method=G0, + use_target=False, + random_state=random_state, + nx=nx, + ) + for i in indices + ] sub_Cs = [Cs[i] for i in indices] sub_lambdas = lambdas[indices] / nx.sum(lambdas[indices]) init_C = update_barycenter_structure( - sub_T, sub_Cs, sub_lambdas, loss_fun=loss_fun, nx=nx) + sub_T, sub_Cs, sub_lambdas, 
loss_fun=loss_fun, nx=nx + ) for i, idx in enumerate(indices): T[idx] = sub_T[i] @@ -1645,15 +2025,27 @@ def semirelaxed_fgw_barycenters( if len(list_init_C) == 2: init_C = update_barycenter_structure( - T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + T, Cs, lambdas, loss_fun=loss_fun, nx=nx + ) C = init_C else: C = init_C - T = [semirelaxed_init_plan( - Cs[i], C, ps[i], Ms[i], alpha, method=G0, use_target=True, - random_state=random_state, nx=nx) for i in range(S)] + T = [ + semirelaxed_init_plan( + Cs[i], + C, + ps[i], + Ms[i], + alpha, + method=G0, + use_target=True, + random_state=random_state, + nx=nx, + ) + for i in range(S) + ] - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": inner_log = False else: @@ -1662,16 +2054,15 @@ def semirelaxed_fgw_barycenters( if log: log_ = {} - if stop_criterion == 'barycenter': - log_['err_feature'] = [] - log_['err_structure'] = [] + if stop_criterion == "barycenter": + log_["err_feature"] = [] + log_["err_structure"] = [] else: - log_['loss'] = [] - log_['err_rel_loss'] = [] + log_["loss"] = [] + log_["err_rel_loss"] = [] for cpt in range(max_iter): # break if specified errors are below tol. - - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": Cprev = C Xprev = X else: @@ -1679,25 +2070,52 @@ def semirelaxed_fgw_barycenters( # get transport plans if warmstartT: - res = [semirelaxed_fused_gromov_wasserstein( - Ms[s], Cs[s], C, ps[s], loss_fun, symmetric, alpha, T[s], - inner_log, max_iter, tol_rel=tol, tol_abs=0., **kwargs) - for s in range(S)] + res = [ + semirelaxed_fused_gromov_wasserstein( + Ms[s], + Cs[s], + C, + ps[s], + loss_fun, + symmetric, + alpha, + T[s], + inner_log, + max_iter, + tol_rel=tol, + tol_abs=0.0, + **kwargs, + ) + for s in range(S) + ] else: - res = [semirelaxed_fused_gromov_wasserstein( - Ms[s], Cs[s], C, ps[s], loss_fun, symmetric, alpha, G0, - inner_log, max_iter, tol_rel=tol, tol_abs=0., **kwargs) - for s in range(S)] + res = [ + semirelaxed_fused_gromov_wasserstein( + Ms[s], + Cs[s], + C, + ps[s], + loss_fun, + symmetric, + alpha, + G0, + inner_log, + max_iter, + tol_rel=tol, + tol_abs=0.0, + **kwargs, + ) + for s in range(S) + ] - if stop_criterion == 'barycenter': + if stop_criterion == "barycenter": T = res else: T = [output[0] for output in res] - curr_loss = np.sum([output[1]['srfgw_dist'] for output in res]) + curr_loss = np.sum([output[1]["srfgw_dist"] for output in res]) # update barycenters - p = nx.concatenate( - [nx.sum(T[s], 0)[None, :] for s in range(S)], axis=0) + p = nx.concatenate([nx.sum(T[s], 0)[None, :] for s in range(S)], axis=0) if not fixed_features: X = update_barycenter_feature(T, Ys, lambdas, p, nx=nx) @@ -1707,44 +2125,44 @@ def semirelaxed_fgw_barycenters( C = update_barycenter_structure(T, Cs, lambdas, p, loss_fun, nx=nx) # update convergence criterion - if stop_criterion == 'barycenter': - err_feature, err_structure = 0., 0. 
+ if stop_criterion == "barycenter": + err_feature, err_structure = 0.0, 0.0 if not fixed_features: err_feature = nx.norm(X - Xprev) if not fixed_structure: err_structure = nx.norm(C - Cprev) if log: - log_['err_feature'].append(err_feature) - log_['err_structure'].append(err_structure) + log_["err_feature"].append(err_feature) + log_["err_structure"].append(err_structure) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err_structure)) - print('{:5d}|{:8e}|'.format(cpt, err_feature)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err_structure)) + print("{:5d}|{:8e}|".format(cpt, err_feature)) if (err_feature <= tol) or (err_structure <= tol): break else: - err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan + err_rel_loss = ( + abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0.0 else np.nan + ) if log: - log_['loss'].append(curr_loss) - log_['err_rel_loss'].append(err_rel_loss) + log_["loss"].append(curr_loss) + log_["err_rel_loss"].append(err_rel_loss) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err_rel_loss)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err_rel_loss)) if err_rel_loss <= tol: break if log: - log_['T'] = T - log_['p'] = p - log_['Ms'] = Ms + log_["T"] = T + log_["p"] = p + log_["Ms"] = Ms return X, C, log_ else: diff --git a/ot/gromov/_unbalanced.py b/ot/gromov/_unbalanced.py index cc7b9e53c..6019c20c8 100644 --- a/ot/gromov/_unbalanced.py +++ b/ot/gromov/_unbalanced.py @@ -13,17 +13,39 @@ import ot from ot.backend import get_backend from ot.utils import list_to_array, get_parameter_pair -from ._utils import fused_unbalanced_across_spaces_cost, uot_cost_matrix, uot_parameters_and_measures +from ._utils import ( + fused_unbalanced_across_spaces_cost, + uot_cost_matrix, + uot_parameters_and_measures, +) def fused_unbalanced_across_spaces_divergence( - X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, - reg_marginals=10, epsilon=0, reg_type="joint", divergence="kl", - unbalanced_solver="sinkhorn", alpha=0, M_samp=None, M_feat=None, - rescale_plan=True, init_pi=None, init_duals=None, max_iter=100, - tol=1e-7, max_iter_ot=500, tol_ot=1e-7, log=False, verbose=False, - **kwargs_solver): - + X, + Y, + wx_samp=None, + wx_feat=None, + wy_samp=None, + wy_feat=None, + reg_marginals=10, + epsilon=0, + reg_type="joint", + divergence="kl", + unbalanced_solver="sinkhorn", + alpha=0, + M_samp=None, + M_feat=None, + rescale_plan=True, + init_pi=None, + init_duals=None, + max_iter=100, + tol=1e-7, + max_iter_ot=500, + tol_ot=1e-7, + log=False, + verbose=False, + **kwargs_solver, +): r"""Compute the fused unbalanced cross-spaces divergence between two matrices equipped with the distributions on rows and columns. We consider two cases of matrix: @@ -109,7 +131,7 @@ def fused_unbalanced_across_spaces_divergence( - If `reg_type` = "joint": then use joint regularization for couplings. - - If `reg_type` = "indepedent": then use independent regularization for couplings. + - If `reg_type` = "independent": then use independent regularization for couplings. divergence : string, optional (default = "kl") - If `divergence` = "kl", then Div is the Kullback-Leibler divergence. 
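Below, `fused_unbalanced_across_spaces_divergence` is the generic solver behind both UCOOT (`reg_type="independent"`) and FUGW (`reg_type="joint"`). A minimal sketch of calling it directly, importing from the patched module path, on toy matrices and with the "mm" inner solver so that `epsilon=0` does not trigger the Sinkhorn fallback warning handled in the next hunk:

>>> import numpy as np
>>> from ot.gromov._unbalanced import fused_unbalanced_across_spaces_divergence
>>> rng = np.random.RandomState(0)
>>> X, Y = rng.rand(4, 3), rng.rand(5, 2)   # rows are samples, columns are features
>>> pi_samp, pi_feat, log = fused_unbalanced_across_spaces_divergence(
...     X, Y, reg_marginals=10, epsilon=0, reg_type="independent",
...     divergence="kl", unbalanced_solver="mm", log=True,
... )
>>> shapes = (pi_samp.shape, pi_feat.shape)   # ((4, 5), (3, 2)): sample and feature couplings
>>> cost = log["ucoot_cost"]                  # value of the cross-spaces divergence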
@@ -195,12 +217,18 @@ def fused_unbalanced_across_spaces_divergence( if reg_type == "joint": # same regularization eps_feat = eps_samp if unbalanced_solver in ["sinkhorn", "sinkhorn_log"] and divergence == "l2": - warnings.warn("Sinkhorn algorithm does not support L2 norm. \ - Divergence is set to 'kl'.") + warnings.warn( + "Sinkhorn algorithm does not support L2 norm. \ + Divergence is set to 'kl'." + ) divergence = "kl" - if unbalanced_solver in ["sinkhorn", "sinkhorn_log"] and (eps_samp == 0 or eps_feat == 0): - warnings.warn("Sinkhorn algorithm does not support unregularized problem. \ - Solver is set to 'mm'.") + if unbalanced_solver in ["sinkhorn", "sinkhorn_log"] and ( + eps_samp == 0 or eps_feat == 0 + ): + warnings.warn( + "Sinkhorn algorithm does not support unregularized problem. \ + Solver is set to 'mm'." + ) unbalanced_solver = "mm" if init_pi is None: @@ -222,20 +250,26 @@ def fused_unbalanced_across_spaces_divergence( if d2 is not None: arr.append(list_to_array(d2)) - nx = get_backend(*arr, wx_samp, wx_feat, wy_samp, wy_feat, M_samp, M_feat, pi_samp, pi_feat) + nx = get_backend( + *arr, wx_samp, wx_feat, wy_samp, wy_feat, M_samp, M_feat, pi_samp, pi_feat + ) # constant input variables if M_samp is None: if alpha_samp > 0: - warnings.warn("M_samp is None but alpha_samp = {} > 0. \ - The algo will treat as if alpha_samp = 0.".format(alpha_samp)) + warnings.warn( + "M_samp is None but alpha_samp = {} > 0. \ + The algo will treat as if alpha_samp = 0.".format(alpha_samp) + ) else: M_samp = alpha_samp * M_samp if M_feat is None: if alpha_feat > 0: - warnings.warn("M_feat is None but alpha_feat = {} > 0. \ - The algo will treat as if alpha_feat = 0.".format(alpha_feat)) + warnings.warn( + "M_feat is None but alpha_feat = {} > 0. \ + The algo will treat as if alpha_feat = 0.".format(alpha_feat) + ) else: M_feat = alpha_feat * M_feat @@ -260,29 +294,31 @@ def fused_unbalanced_across_spaces_divergence( if unbalanced_solver in ["sinkhorn", "sinkhorn_log"]: if duals_samp is None: - duals_samp = (nx.zeros(nx_samp, type_as=X), - nx.zeros(ny_samp, type_as=Y)) + duals_samp = (nx.zeros(nx_samp, type_as=X), nx.zeros(ny_samp, type_as=Y)) if duals_feat is None: - duals_feat = (nx.zeros(nx_feat, type_as=X), - nx.zeros(ny_feat, type_as=Y)) + duals_feat = (nx.zeros(nx_feat, type_as=X), nx.zeros(ny_feat, type_as=Y)) # shortcut functions X_sqr, Y_sqr = X**2, Y**2 - local_cost_samp = partial(uot_cost_matrix, - data=(X_sqr, Y_sqr, X, Y, M_samp), - tuple_p=(wx_feat, wy_feat), - hyperparams=(rho_x, rho_y, eps_feat), - divergence=divergence, - reg_type=reg_type, - nx=nx) - - local_cost_feat = partial(uot_cost_matrix, - data=(X_sqr.T, Y_sqr.T, X.T, Y.T, M_feat), - tuple_p=(wx_samp, wy_samp), - hyperparams=(rho_x, rho_y, eps_samp), - divergence=divergence, - reg_type=reg_type, - nx=nx) + local_cost_samp = partial( + uot_cost_matrix, + data=(X_sqr, Y_sqr, X, Y, M_samp), + tuple_p=(wx_feat, wy_feat), + hyperparams=(rho_x, rho_y, eps_feat), + divergence=divergence, + reg_type=reg_type, + nx=nx, + ) + + local_cost_feat = partial( + uot_cost_matrix, + data=(X_sqr.T, Y_sqr.T, X.T, Y.T, M_feat), + tuple_p=(wx_samp, wy_samp), + hyperparams=(rho_x, rho_y, eps_samp), + divergence=divergence, + reg_type=reg_type, + nx=nx, + ) parameters_uot_l2_samp = partial( uot_parameters_and_measures, @@ -290,7 +326,7 @@ def fused_unbalanced_across_spaces_divergence( hyperparams=(rho_x, rho_y, eps_samp), reg_type=reg_type, divergence=divergence, - nx=nx + nx=nx, ) parameters_uot_l2_feat = partial( @@ -299,7 +335,7 @@ def 
fused_unbalanced_across_spaces_divergence( hyperparams=(rho_x, rho_y, eps_feat), reg_type=reg_type, divergence=divergence, - nx=nx + nx=nx, ) solver = partial( @@ -309,13 +345,12 @@ def fused_unbalanced_across_spaces_divergence( method=unbalanced_solver, max_iter=max_iter_ot, tol=tol_ot, - verbose=False + verbose=False, ) # initialize log if log: - dict_log = {"backend": nx, - "error": []} + dict_log = {"backend": nx, "error": []} for idx in range(max_iter): pi_samp_prev = nx.copy(pi_samp) @@ -332,9 +367,16 @@ def fused_unbalanced_across_spaces_divergence( new_w, new_rho, new_eps = parameters_uot_l2_feat(pi_feat) new_wx, new_wy, new_wxy = new_w - res = solver(M=uot_cost, a=new_wx, b=new_wy, - reg=new_eps, c=new_wxy, unbalanced=new_rho, - plan_init=pi_feat, potentials_init=duals_feat) + res = solver( + M=uot_cost, + a=new_wx, + b=new_wy, + reg=new_eps, + c=new_wxy, + unbalanced=new_rho, + plan_init=pi_feat, + potentials_init=duals_feat, + ) pi_feat, duals_feat = res.plan, res.potentials if rescale_plan: @@ -352,9 +394,16 @@ def fused_unbalanced_across_spaces_divergence( new_w, new_rho, new_eps = parameters_uot_l2_samp(pi_samp) new_wx, new_wy, new_wxy = new_w - res = solver(M=uot_cost, a=new_wx, b=new_wy, - reg=new_eps, c=new_wxy, unbalanced=new_rho, - plan_init=pi_samp, potentials_init=duals_samp) + res = solver( + M=uot_cost, + a=new_wx, + b=new_wy, + reg=new_eps, + c=new_wxy, + unbalanced=new_rho, + plan_init=pi_samp, + potentials_init=duals_samp, + ) pi_samp, duals_samp = res.plan, res.potentials if rescale_plan: @@ -365,14 +414,18 @@ def fused_unbalanced_across_spaces_divergence( if log: dict_log["error"].append(err) if verbose: - print('{:5d}|{:8e}|'.format(idx + 1, err)) + print("{:5d}|{:8e}|".format(idx + 1, err)) if err < tol: break # sanity check if nx.sum(nx.isnan(pi_samp)) > 0 or nx.sum(nx.isnan(pi_feat)) > 0: - raise (ValueError("There is NaN in coupling. \ - Adjust the relaxation or regularization parameters.")) + raise ( + ValueError( + "There is NaN in coupling. \ + Adjust the relaxation or regularization parameters." + ) + ) if log: linear_cost, ucoot_cost = fused_unbalanced_across_spaces_cost( @@ -380,11 +433,12 @@ def fused_unbalanced_across_spaces_divergence( data=(X_sqr, Y_sqr, X, Y), tuple_pxy_samp=(wx_samp, wy_samp, wxy_samp), tuple_pxy_feat=(wx_feat, wy_feat, wxy_feat), - pi_samp=pi_samp, pi_feat=pi_feat, + pi_samp=pi_samp, + pi_feat=pi_feat, hyperparams=(rho_x, rho_y, eps_samp, eps_feat), divergence=divergence, reg_type=reg_type, - nx=nx + nx=nx, ) dict_log["duals_sample"] = duals_samp @@ -399,13 +453,30 @@ def fused_unbalanced_across_spaces_divergence( def unbalanced_co_optimal_transport( - X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, - reg_marginals=10, epsilon=0, divergence="kl", - unbalanced_solver="mm", alpha=0, M_samp=None, M_feat=None, - rescale_plan=True, init_pi=None, init_duals=None, - max_iter=100, tol=1e-7, max_iter_ot=500, tol_ot=1e-7, - log=False, verbose=False, **kwargs_solve): - + X, + Y, + wx_samp=None, + wx_feat=None, + wy_samp=None, + wy_feat=None, + reg_marginals=10, + epsilon=0, + divergence="kl", + unbalanced_solver="mm", + alpha=0, + M_samp=None, + M_feat=None, + rescale_plan=True, + init_pi=None, + init_duals=None, + max_iter=100, + tol=1e-7, + max_iter_ot=500, + tol_ot=1e-7, + log=False, + verbose=False, + **kwargs_solve, +): r"""Compute the unbalanced Co-Optimal Transport between two Euclidean point clouds (represented as matrices whose rows are samples and columns are the features/dimensions). 
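A toy call of the UCOOT solver introduced here (random data purely for illustration; KL divergence and the "mm" inner solver, which accepts the default `epsilon=0`):

>>> import numpy as np
>>> from ot.gromov._unbalanced import unbalanced_co_optimal_transport
>>> rng = np.random.RandomState(42)
>>> X = rng.rand(5, 3)   # 5 samples described by 3 features
>>> Y = rng.rand(6, 4)   # 6 samples described by 4 features
>>> pi_samp, pi_feat = unbalanced_co_optimal_transport(
...     X, Y, reg_marginals=10, epsilon=0, divergence="kl", unbalanced_solver="mm"
... )
>>> # pi_samp couples the rows (samples) and has shape (5, 6);
>>> # pi_feat couples the columns (features) and has shape (3, 4).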
@@ -537,24 +608,58 @@ def unbalanced_co_optimal_transport( """ return fused_unbalanced_across_spaces_divergence( - X=X, Y=Y, wx_samp=wx_samp, wx_feat=wx_feat, - wy_samp=wy_samp, wy_feat=wy_feat, reg_marginals=reg_marginals, - epsilon=epsilon, reg_type="independent", - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M_samp=M_samp, M_feat=M_feat, rescale_plan=rescale_plan, - init_pi=init_pi, init_duals=init_duals, max_iter=max_iter, tol=tol, - max_iter_ot=max_iter_ot, tol_ot=tol_ot, log=log, verbose=verbose, - **kwargs_solve) + X=X, + Y=Y, + wx_samp=wx_samp, + wx_feat=wx_feat, + wy_samp=wy_samp, + wy_feat=wy_feat, + reg_marginals=reg_marginals, + epsilon=epsilon, + reg_type="independent", + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=M_samp, + M_feat=M_feat, + rescale_plan=rescale_plan, + init_pi=init_pi, + init_duals=init_duals, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=log, + verbose=verbose, + **kwargs_solve, + ) def unbalanced_co_optimal_transport2( - X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, - reg_marginals=10, epsilon=0, divergence="kl", - unbalanced_solver="sinkhorn", alpha=0, M_samp=None, M_feat=None, - rescale_plan=True, init_pi=None, init_duals=None, - max_iter=100, tol=1e-7, max_iter_ot=500, tol_ot=1e-7, - log=False, verbose=False, **kwargs_solve): - + X, + Y, + wx_samp=None, + wx_feat=None, + wy_samp=None, + wy_feat=None, + reg_marginals=10, + epsilon=0, + divergence="kl", + unbalanced_solver="sinkhorn", + alpha=0, + M_samp=None, + M_feat=None, + rescale_plan=True, + init_pi=None, + init_duals=None, + max_iter=100, + tol=1e-7, + max_iter_ot=500, + tol_ot=1e-7, + log=False, + verbose=False, + **kwargs_solve, +): r"""Compute the unbalanced Co-Optimal Transport between two Euclidean point clouds (represented as matrices whose rows are samples and columns are the features/dimensions). 
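The scalar variant below returns the UCOOT cost and, in the KL case, registers gradients with respect to the inputs via `nx.set_gradients` (see the gradient hunks further down). A sketch, not taken from the patch, of how that value can be differentiated with the torch backend on illustrative tensors, assuming torch is installed:

>>> import torch
>>> from ot.gromov._unbalanced import unbalanced_co_optimal_transport2
>>> X = torch.rand(5, 3, requires_grad=True)
>>> Y = torch.rand(6, 4, requires_grad=True)
>>> ucoot = unbalanced_co_optimal_transport2(
...     X, Y, reg_marginals=10, epsilon=0, divergence="kl", unbalanced_solver="mm"
... )
>>> ucoot.backward()          # gradients flow back to X and Y
>>> gX, gY = X.grad, Y.grad   # same shapes as X and Y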
@@ -687,16 +792,36 @@ def unbalanced_co_optimal_transport2( """ if divergence != "kl": - warnings.warn("The computation of gradients is only supported for KL divergence, not \ - for {} divergence".format(divergence)) + warnings.warn( + "The computation of gradients is only supported for KL divergence, not \ + for {} divergence".format(divergence) + ) pi_samp, pi_feat, log_ucoot = unbalanced_co_optimal_transport( - X=X, Y=Y, wx_samp=wx_samp, wx_feat=wx_feat, wy_samp=wy_samp, wy_feat=wy_feat, - reg_marginals=reg_marginals, epsilon=epsilon, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, M_samp=M_samp, M_feat=M_feat, - rescale_plan=rescale_plan, init_pi=init_pi, init_duals=init_duals, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=True, verbose=verbose, **kwargs_solve) + X=X, + Y=Y, + wx_samp=wx_samp, + wx_feat=wx_feat, + wy_samp=wy_samp, + wy_feat=wy_feat, + reg_marginals=reg_marginals, + epsilon=epsilon, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=M_samp, + M_feat=M_feat, + rescale_plan=rescale_plan, + init_pi=init_pi, + init_duals=init_duals, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=True, + verbose=verbose, + **kwargs_solve, + ) nx = log_ucoot["backend"] @@ -725,26 +850,33 @@ def unbalanced_co_optimal_transport2( m_wy_feat, m_wy_samp = nx.sum(wy_feat), nx.sum(wy_samp) # calculate subgradients - gradX = 2 * X * (pi1_samp[:, None] * pi1_feat[None, :]) - \ - 2 * nx.dot(nx.dot(pi_samp, Y), pi_feat.T) # shape (nx_samp, nx_feat) - gradY = 2 * Y * (pi2_samp[:, None] * pi2_feat[None, :]) - \ - 2 * nx.dot(nx.dot(pi_samp.T, X), pi_feat) # shape (ny_samp, ny_feat) - - grad_wx_samp = rho_x * (m_wx_feat - m_feat * pi1_samp / wx_samp) + \ - eps_samp * (m_wy_samp - pi1_samp / wx_samp) - grad_wx_feat = rho_x * (m_wx_samp - m_samp * pi1_feat / wx_feat) + \ - eps_feat * (m_wy_feat - pi1_feat / wx_feat) - grad_wy_samp = rho_y * (m_wy_feat - m_feat * pi2_samp / wy_samp) + \ - eps_samp * (m_wx_samp - pi2_samp / wy_samp) - grad_wy_feat = rho_y * (m_wy_samp - m_samp * pi2_feat / wy_feat) + \ - eps_feat * (m_wx_feat - pi2_feat / wy_feat) + gradX = 2 * X * (pi1_samp[:, None] * pi1_feat[None, :]) - 2 * nx.dot( + nx.dot(pi_samp, Y), pi_feat.T + ) # shape (nx_samp, nx_feat) + gradY = 2 * Y * (pi2_samp[:, None] * pi2_feat[None, :]) - 2 * nx.dot( + nx.dot(pi_samp.T, X), pi_feat + ) # shape (ny_samp, ny_feat) + + grad_wx_samp = rho_x * (m_wx_feat - m_feat * pi1_samp / wx_samp) + eps_samp * ( + m_wy_samp - pi1_samp / wx_samp + ) + grad_wx_feat = rho_x * (m_wx_samp - m_samp * pi1_feat / wx_feat) + eps_feat * ( + m_wy_feat - pi1_feat / wx_feat + ) + grad_wy_samp = rho_y * (m_wy_feat - m_feat * pi2_samp / wy_samp) + eps_samp * ( + m_wx_samp - pi2_samp / wy_samp + ) + grad_wy_feat = rho_y * (m_wy_samp - m_samp * pi2_feat / wy_feat) + eps_feat * ( + m_wx_feat - pi2_feat / wy_feat + ) # set gradients ucoot = log_ucoot["ucoot_cost"] - ucoot = nx.set_gradients(ucoot, - (X, Y, wx_samp, wx_feat, wy_samp, wy_feat), - (gradX, gradY, grad_wx_samp, grad_wx_feat, grad_wy_samp, grad_wy_feat) - ) + ucoot = nx.set_gradients( + ucoot, + (X, Y, wx_samp, wx_feat, wy_samp, wy_feat), + (gradX, gradY, grad_wx_samp, grad_wx_feat, grad_wy_samp, grad_wy_feat), + ) if log: return ucoot, log_ucoot @@ -754,12 +886,26 @@ def unbalanced_co_optimal_transport2( def fused_unbalanced_gromov_wasserstein( - Cx, Cy, wx=None, wy=None, reg_marginals=10, epsilon=0, - divergence="kl", unbalanced_solver="mm", alpha=0, - 
M=None, init_duals=None, init_pi=None, max_iter=100, - tol=1e-7, max_iter_ot=500, tol_ot=1e-7, - log=False, verbose=False, **kwargs_solve): - + Cx, + Cy, + wx=None, + wy=None, + reg_marginals=10, + epsilon=0, + divergence="kl", + unbalanced_solver="mm", + alpha=0, + M=None, + init_duals=None, + init_pi=None, + max_iter=100, + tol=1e-7, + max_iter_ot=500, + tol_ot=1e-7, + log=False, + verbose=False, + **kwargs_solve, +): r"""Compute the lower bound of the fused unbalanced Gromov-Wasserstein (FUGW) between two similarity matrices. In practice, this lower bound is used interchangeably with the true FUGW. @@ -881,22 +1027,40 @@ def fused_unbalanced_gromov_wasserstein( alpha = (alpha / 2, alpha / 2) pi_samp, pi_feat, dict_log = fused_unbalanced_across_spaces_divergence( - X=Cx, Y=Cy, wx_samp=wx, wx_feat=wx, wy_samp=wy, wy_feat=wy, - reg_marginals=reg_marginals, epsilon=epsilon, reg_type="joint", - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M_samp=M, M_feat=M, rescale_plan=True, + X=Cx, + Y=Cy, + wx_samp=wx, + wx_feat=wx, + wy_samp=wy, + wy_feat=wy, + reg_marginals=reg_marginals, + epsilon=epsilon, + reg_type="joint", + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=M, + M_feat=M, + rescale_plan=True, init_pi=(init_pi, init_pi), - init_duals=(init_duals, init_duals), max_iter=max_iter, tol=tol, - max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=True, verbose=verbose, **kwargs_solve + init_duals=(init_duals, init_duals), + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=True, + verbose=verbose, + **kwargs_solve, ) if log: - log_fugw = {"error": dict_log["error"], - "duals": dict_log["duals_sample"], - "linear_cost": dict_log["linear_cost"], - "fugw_cost": dict_log["ucoot_cost"], - "backend": dict_log["backend"]} + log_fugw = { + "error": dict_log["error"], + "duals": dict_log["duals_sample"], + "linear_cost": dict_log["linear_cost"], + "fugw_cost": dict_log["ucoot_cost"], + "backend": dict_log["backend"], + } return pi_samp, pi_feat, log_fugw @@ -905,12 +1069,26 @@ def fused_unbalanced_gromov_wasserstein( def fused_unbalanced_gromov_wasserstein2( - Cx, Cy, wx=None, wy=None, reg_marginals=10, epsilon=0, - divergence="kl", unbalanced_solver="mm", alpha=0, - M=None, init_duals=None, init_pi=None, max_iter=100, - tol=1e-7, max_iter_ot=500, tol_ot=1e-7, - log=False, verbose=False, **kwargs_solve): - + Cx, + Cy, + wx=None, + wy=None, + reg_marginals=10, + epsilon=0, + divergence="kl", + unbalanced_solver="mm", + alpha=0, + M=None, + init_duals=None, + init_pi=None, + max_iter=100, + tol=1e-7, + max_iter_ot=500, + tol_ot=1e-7, + log=False, + verbose=False, + **kwargs_solve, +): r"""Compute the lower bound of the fused unbalanced Gromov-Wasserstein (FUGW) between two similarity matrices. In practice, this lower bound is used interchangeably with the true FUGW. @@ -1027,16 +1205,33 @@ def fused_unbalanced_gromov_wasserstein2( """ if divergence != "kl": - warnings.warn("The computation of gradients is only supported for KL divergence, \ - but not for {} divergence. The gradient of the KL case will be used.".format(divergence)) + warnings.warn( + "The computation of gradients is only supported for KL divergence, \ + but not for {} divergence. 
The gradient of the KL case will be used.".format( + divergence + ) + ) pi_samp, pi_feat, log_fugw = fused_unbalanced_gromov_wasserstein( - Cx=Cx, Cy=Cy, wx=wx, wy=wy, reg_marginals=reg_marginals, - epsilon=epsilon, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, M=M, - init_duals=init_duals, init_pi=init_pi, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, - tol_ot=tol_ot, log=True, verbose=verbose, **kwargs_solve + Cx=Cx, + Cy=Cy, + wx=wx, + wy=wy, + reg_marginals=reg_marginals, + epsilon=epsilon, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M, + init_duals=init_duals, + init_pi=init_pi, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=True, + verbose=verbose, + **kwargs_solve, ) nx = log_fugw["backend"] @@ -1055,23 +1250,30 @@ def fused_unbalanced_gromov_wasserstein2( m_wx, m_wy = nx.sum(wx), nx.sum(wy) # calculate subgradients - gradX = 2 * Cx * (pi1_samp[:, None] * pi1_feat[None, :]) - \ - 2 * nx.dot(nx.dot(pi_samp, Cy), pi_feat.T) # shape (nx_samp, nx_feat) - gradY = 2 * Cy * (pi2_samp[:, None] * pi2_feat[None, :]) - \ - 2 * nx.dot(nx.dot(pi_samp.T, Cx), pi_feat) # shape (ny_samp, ny_feat) + gradX = 2 * Cx * (pi1_samp[:, None] * pi1_feat[None, :]) - 2 * nx.dot( + nx.dot(pi_samp, Cy), pi_feat.T + ) # shape (nx_samp, nx_feat) + gradY = 2 * Cy * (pi2_samp[:, None] * pi2_feat[None, :]) - 2 * nx.dot( + nx.dot(pi_samp.T, Cx), pi_feat + ) # shape (ny_samp, ny_feat) gradM = alpha / 2 * (pi_samp + pi_feat) rho_x, rho_y = get_parameter_pair(reg_marginals) - grad_wx = 2 * m_wx * (rho_x + epsilon * m_wy**2) - \ - (rho_x + epsilon) * (m_feat * pi1_samp + m_samp * pi1_feat) / wx - grad_wy = 2 * m_wy * (rho_y + epsilon * m_wx**2) - \ - (rho_y + epsilon) * (m_feat * pi2_samp + m_samp * pi2_feat) / wy + grad_wx = ( + 2 * m_wx * (rho_x + epsilon * m_wy**2) + - (rho_x + epsilon) * (m_feat * pi1_samp + m_samp * pi1_feat) / wx + ) + grad_wy = ( + 2 * m_wy * (rho_y + epsilon * m_wx**2) + - (rho_y + epsilon) * (m_feat * pi2_samp + m_samp * pi2_feat) / wy + ) # set gradients fugw = log_fugw["fugw_cost"] - fugw = nx.set_gradients(fugw, (Cx, Cy, M, wx, wy), - (gradX, gradY, gradM, grad_wx, grad_wy)) + fugw = nx.set_gradients( + fugw, (Cx, Cy, M, wx, wy), (gradX, gradY, gradM, grad_wx, grad_wy) + ) if log: return fugw, log_fugw diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 31cd0fd90..79afaed36 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -12,7 +12,6 @@ # # License: MIT License - from ..utils import list_to_array, euclidean_distances from ..backend import get_backend from ..lp import emd @@ -20,12 +19,14 @@ try: from networkx.algorithms.community import asyn_fluidc from networkx import from_numpy_array + networkx_import = True except ImportError: networkx_import = False try: from sklearn.cluster import SpectralClustering, KMeans + sklearn_import = True except ImportError: sklearn_import = False @@ -34,7 +35,7 @@ import warnings -def _transform_matrix(C1, C2, loss_fun='square_loss', nx=None): +def _transform_matrix(C1, C2, loss_fun="square_loss", nx=None): r"""Return transformed structure matrices for Gromov-Wasserstein fast computation Returns the matrices involved in the computation of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2})` @@ -110,19 +111,21 @@ def _transform_matrix(C1, C2, loss_fun='square_loss', nx=None): C1, C2 = list_to_array(C1, C2) nx = get_backend(C1, C2) - if loss_fun == 'square_loss': + if loss_fun == "square_loss": + def f1(a): - return (a**2) + return a**2 
def f2(b): - return (b**2) + return b**2 def h1(a): return a def h2(b): return 2 * b - elif loss_fun == 'kl_loss': + elif loss_fun == "kl_loss": + def f1(a): return a * nx.log(a + 1e-18) - a @@ -135,7 +138,9 @@ def h1(a): def h2(b): return nx.log(b + 1e-18) else: - raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.") + raise ValueError( + f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}." + ) fC1 = f1(C1) fC2 = f2(C2) @@ -145,7 +150,7 @@ def h2(b): return fC1, fC2, hC1, hC2 -def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): +def init_matrix(C1, C2, p, q, loss_fun="square_loss", nx=None): r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the @@ -226,12 +231,10 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): fC1, fC2, hC1, hC2 = _transform_matrix(C1, C2, loss_fun, nx) constC1 = nx.dot( - nx.dot(fC1, nx.reshape(p, (-1, 1))), - nx.ones((1, len(q)), type_as=q) + nx.dot(fC1, nx.reshape(p, (-1, 1))), nx.ones((1, len(q)), type_as=q) ) constC2 = nx.dot( - nx.ones((len(p), 1), type_as=p), - nx.dot(nx.reshape(q, (1, -1)), fC2.T) + nx.ones((len(p), 1), type_as=p), nx.dot(nx.reshape(q, (1, -1)), fC2.T) ) constC = constC1 + constC2 @@ -271,9 +274,7 @@ def tensor_product(constC, hC1, hC2, T, nx=None): constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T) nx = get_backend(constC, hC1, hC2, T) - A = - nx.dot( - nx.dot(hC1, T), hC2.T - ) + A = -nx.dot(nx.dot(hC1, T), hC2.T) tens = constC + A # tens -= tens.min() return tens @@ -350,11 +351,10 @@ def gwggrad(constC, hC1, hC2, T, nx=None): International Conference on Machine Learning (ICML). 2016. """ - return 2 * tensor_product(constC, hC1, hC2, - T, nx) # [12] Prop. 2 misses a 2 factor + return 2 * tensor_product(constC, hC1, hC2, T, nx) # [12] Prop. 2 misses a 2 factor -def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): +def init_matrix_semirelaxed(C1, C2, p, loss_fun="square_loss", nx=None): r"""Return loss matrices and tensors for semi-relaxed Gromov-Wasserstein fast computation Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the @@ -438,15 +438,25 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): fC1, fC2, hC1, hC2 = _transform_matrix(C1, C2, loss_fun, nx) - constC = nx.dot(nx.dot(fC1, nx.reshape(p, (-1, 1))), - nx.ones((1, C2.shape[0]), type_as=p)) + constC = nx.dot( + nx.dot(fC1, nx.reshape(p, (-1, 1))), nx.ones((1, C2.shape[0]), type_as=p) + ) fC2t = fC2.T return constC, hC1, hC2, fC2t -def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., method='product', - use_target=True, random_state=0, nx=None): +def semirelaxed_init_plan( + C1, + C2, + p, + M=None, + alpha=1.0, + method="product", + use_target=True, + random_state=0, + nx=None, +): r""" Heuristics to initialize the semi-relaxed (F)GW transport plan :math:`\mathbf{T} \in \mathcal{U}_{nt}(\mathbf{p})`, between a graph @@ -469,7 +479,7 @@ def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., method='product', If a metric cost matrix between features across domains :math:`\mathbf{M}` is a provided, it will be used as cost matrix in a semi-relaxed Wasserstein problem providing :math:`\mathbf{T}_M \in \mathcal{U}_{nt}(\mathbf{p})`. Then - the outputed transport plan is :math:`\alpha \mathbf{T} + (1 - \alpha ) \mathbf{T}_{M}`. 
+ the outputted transport plan is :math:`\alpha \mathbf{T} + (1 - \alpha ) \mathbf{T}_{M}`. Parameters ---------- @@ -507,17 +517,26 @@ def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., method='product', """ list_partitioning_methods = [ - 'fluid', 'spectral', 'kmeans', 'fluid_soft', 'spectral_soft', - 'kmeans_soft'] - - if method not in list_partitioning_methods + ['product', 'random_product', 'random']: - raise ValueError(f'Unsupported initialization method = {method}.') - - if (method in ['kmeans', 'kmeans_soft']) and (not sklearn_import): - raise ValueError(f'Scikit-learn must be installed to use method = {method}') - - if (method in ['fluid', 'fluid_soft']) and (not networkx_import): - raise ValueError(f'Networkx must be installed to use method = {method}') + "fluid", + "spectral", + "kmeans", + "fluid_soft", + "spectral_soft", + "kmeans_soft", + ] + + if method not in list_partitioning_methods + [ + "product", + "random_product", + "random", + ]: + raise ValueError(f"Unsupported initialization method = {method}.") + + if (method in ["kmeans", "kmeans_soft"]) and (not sklearn_import): + raise ValueError(f"Scikit-learn must be installed to use method = {method}") + + if (method in ["fluid", "fluid_soft"]) and (not networkx_import): + raise ValueError(f"Networkx must be installed to use method = {method}") if nx is None: nx = get_backend(C1, C2, p, M) @@ -537,7 +556,7 @@ def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., method='product', "Both structures have the same size so no partitioning is" "performed to initialize the transport plan even though" f"initialization method is {method}", - stacklevel=2 + stacklevel=2, ) def get_transport_from_partition(part): @@ -552,7 +571,7 @@ def get_transport_from_partition(part): if use_target: M_structure = euclidean_distances(factored_C1, C2) T_emd = emd(q, q, M_structure) - inv_q = 1. / q + inv_q = 1.0 / q T = nx.dot(T_, inv_q[:, None] * T_emd) else: @@ -567,7 +586,7 @@ def get_transport_from_partition(part): # alignment of both structure seen as feature matrices M_structure = euclidean_distances(factored_C2, C1) T_emd = emd(q, p, M_structure) - inv_q = 1. / q + inv_q = 1.0 / q T = nx.dot(T_, inv_q[:, None] * T_emd).T q = nx.sum(T, 0) # uniform one @@ -581,24 +600,24 @@ def get_transport_from_partition(part): # Handle initialization via structure information - if method == 'product': + if method == "product": q = nx.ones(m, type_as=C1) / m T = nx.outer(p, q) - elif method == 'random_product': + elif method == "random_product": np.random.seed(random_state) q = np.random.uniform(0, m, size=(m,)) q = q / q.sum() q = nx.from_numpy(q, type_as=p) T = nx.outer(p, q) - elif method == 'random': + elif method == "random": np.random.seed(random_state) U = np.random.uniform(0, n * m, size=(n, m)) U = (p / U.sum(1))[:, None] * U T = nx.from_numpy(U, type_as=C1) - elif method in ['fluid', 'fluid_soft']: + elif method in ["fluid", "fluid_soft"]: # compute fluid partitioning on the biggest graph if C_to_partition is None: T, q = get_transport_from_partition(None) @@ -613,38 +632,39 @@ def get_transport_from_partition(part): T, q = get_transport_from_partition(part) - if 'soft' in method: - T = (T + nx.outer(p, q)) / 2. 
+ if "soft" in method: + T = (T + nx.outer(p, q)) / 2.0 - elif method in ['spectral', 'spectral_soft']: + elif method in ["spectral", "spectral_soft"]: # compute spectral partitioning on the biggest graph if C_to_partition is None: T, q = get_transport_from_partition(None) else: - sc = SpectralClustering(n_clusters=min_size, - random_state=random_state, - affinity='precomputed').fit(C_to_partition) + sc = SpectralClustering( + n_clusters=min_size, random_state=random_state, affinity="precomputed" + ).fit(C_to_partition) part = sc.labels_ T, q = get_transport_from_partition(part) - if 'soft' in method: - T = (T + nx.outer(p, q)) / 2. + if "soft" in method: + T = (T + nx.outer(p, q)) / 2.0 - elif method in ['kmeans', 'kmeans_soft']: + elif method in ["kmeans", "kmeans_soft"]: # compute spectral partitioning on the biggest graph if C_to_partition is None: T, q = get_transport_from_partition(None) else: - km = KMeans(n_clusters=min_size, random_state=random_state, - n_init=1).fit(C_to_partition) + km = KMeans(n_clusters=min_size, random_state=random_state, n_init=1).fit( + C_to_partition + ) part = km.labels_ T, q = get_transport_from_partition(part) - if 'soft' in method: - T = (T + nx.outer(p, q)) / 2. + if "soft" in method: + T = (T + nx.outer(p, q)) / 2.0 - if (M is not None): + if M is not None: # Add feature information solving a semi-relaxed Wasserstein problem # get minimum by rows as binary mask TM = nx.ones(1, type_as=p) * (M == nx.reshape(nx.min(M, axis=1), (-1, 1))) @@ -656,8 +676,15 @@ def get_transport_from_partition(part): def update_barycenter_structure( - Ts, Cs, lambdas, p=None, loss_fun='square_loss', target=True, - check_zeros=True, nx=None): + Ts, + Cs, + lambdas, + p=None, + loss_fun="square_loss", + target=True, + check_zeros=True, + nx=None, +): r""" Updates :math:`\mathbf{C}` according to the inner loss L with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration of variants of @@ -729,49 +756,53 @@ def update_barycenter_structure( if p is None: p = nx.concatenate( - [nx.sum(Ts[s], int(not target))[None, :] for s in range(S)], - axis=0) + [nx.sum(Ts[s], int(not target))[None, :] for s in range(S)], axis=0 + ) # compute coefficients for the barycenter coming from marginals if len(p.shape) == 1: # shared target masses potentially with zeros if check_zeros: - inv_p = nx.nan_to_num(1. / p, nan=1., posinf=1., neginf=1.) + inv_p = nx.nan_to_num(1.0 / p, nan=1.0, posinf=1.0, neginf=1.0) else: - inv_p = 1. / p + inv_p = 1.0 / p prod = nx.outer(inv_p, inv_p) else: quotient = sum([lambdas[s] * nx.outer(p[s], p[s]) for s in range(S)]) if check_zeros: - prod = nx.nan_to_num(1. / quotient, nan=1., posinf=1., neginf=1.) + prod = nx.nan_to_num(1.0 / quotient, nan=1.0, posinf=1.0, neginf=1.0) else: - prod = 1. 
/ quotient + prod = 1.0 / quotient # compute coefficients for the barycenter coming from Ts and Cs - if loss_fun == 'square_loss': + if loss_fun == "square_loss": if target: - list_structures = [lambdas[s] * nx.dot( - nx.dot(Ts[s].T, Cs[s]), Ts[s]) for s in range(S)] + list_structures = [ + lambdas[s] * nx.dot(nx.dot(Ts[s].T, Cs[s]), Ts[s]) for s in range(S) + ] else: - list_structures = [lambdas[s] * nx.dot( - nx.dot(Ts[s], Cs[s]), Ts[s].T) for s in range(S)] + list_structures = [ + lambdas[s] * nx.dot(nx.dot(Ts[s], Cs[s]), Ts[s].T) for s in range(S) + ] return sum(list_structures) * prod - elif loss_fun == 'kl_loss': + elif loss_fun == "kl_loss": if target: - list_structures = [lambdas[s] * nx.dot( - nx.dot(Ts[s].T, Cs[s]), Ts[s]) - for s in range(S)] + list_structures = [ + lambdas[s] * nx.dot(nx.dot(Ts[s].T, Cs[s]), Ts[s]) for s in range(S) + ] return sum(list_structures) * prod else: - list_structures = [lambdas[s] * nx.dot( - nx.dot(Ts[s], nx.log(nx.maximum(Cs[s], 1e-16))), Ts[s].T) - for s in range(S)] + list_structures = [ + lambdas[s] + * nx.dot(nx.dot(Ts[s], nx.log(nx.maximum(Cs[s], 1e-16))), Ts[s].T) + for s in range(S) + ] return nx.exp(sum(list_structures) * prod) @@ -780,8 +811,15 @@ def update_barycenter_structure( def update_barycenter_feature( - Ts, Ys, lambdas, p=None, loss_fun='square_loss', target=True, - check_zeros=True, nx=None): + Ts, + Ys, + lambdas, + p=None, + loss_fun="square_loss", + target=True, + check_zeros=True, + nx=None, +): r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration of variants of the FGW barycenter problem with inner wasserstein loss `loss_fun` @@ -826,7 +864,7 @@ def update_barycenter_feature( arr = [*Ts, *Ys, p] nx = get_backend(*arr) - if loss_fun != 'square_loss': + if loss_fun != "square_loss": raise ValueError(f"not supported loss_fun = {loss_fun}") S = len(Ts) @@ -838,20 +876,20 @@ def update_barycenter_feature( if p is None: p = nx.concatenate( - [nx.sum(Ts[s], int(not target))[None, :] for s in range(S)], - axis=0) + [nx.sum(Ts[s], int(not target))[None, :] for s in range(S)], axis=0 + ) if len(p.shape) == 1: # shared target masses potentially with zeros if check_zeros: - inv_p = nx.nan_to_num(1. / p, nan=1., posinf=1., neginf=1.) + inv_p = nx.nan_to_num(1.0 / p, nan=1.0, posinf=1.0, neginf=1.0) else: - inv_p = 1. / p + inv_p = 1.0 / p else: p_sum = sum([lambdas[s] * p[s] for s in range(S)]) if check_zeros: - inv_p = nx.nan_to_num(1. / p_sum, nan=1., posinf=1., neginf=1.) + inv_p = nx.nan_to_num(1.0 / p_sum, nan=1.0, posinf=1.0, neginf=1.0) else: - inv_p = 1. / p_sum + inv_p = 1.0 / p_sum return sum(list_features) * inv_p[:, None] @@ -860,6 +898,7 @@ def update_barycenter_feature( # Methods related to fused unbalanced GW and unbalanced Co-Optimal Transport. ############################################################################ + def div_to_product(pi, a, b, pi1=None, pi2=None, divergence="kl", mass=True, nx=None): r"""Fast computation of the Bregman divergence between an arbitrary measure and a product measure. Only support for Kullback-Leibler and half-squared L2 divergences. 
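Note for reviewers: the KL branch of `div_to_product` shown in this hunk computes the (generalized) Kullback-Leibler divergence between a coupling `pi` and the product measure `a x b`. A minimal NumPy sketch of that quantity, mirroring the formula above and kept independent of the POT backend (illustrative only, not part of the patch):

    import numpy as np

    def kl_to_product(pi, a, b, mass=True):
        # marginals of the coupling
        pi1, pi2 = pi.sum(axis=1), pi.sum(axis=0)
        # sum pi*log(pi) - sum pi1*log(a) - sum pi2*log(b), with 0*log(0) := 0
        res = (np.sum(pi * np.log(pi + 1.0 * (pi == 0)))
               - np.sum(pi1 * np.log(a)) - np.sum(pi2 * np.log(b)))
        if mass:
            # the generalized KL adds the mass-difference term
            res = res - pi1.sum() + a.sum() * b.sum()
        return res

    a, b = np.ones(3) / 3, np.ones(4) / 4
    pi = np.outer(a, b)                      # the product coupling itself
    print(np.isclose(kl_to_product(pi, a, b), 0.0))   # ~0 for pi = a x b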
@@ -918,20 +957,23 @@ def div_to_product(pi, a, b, pi1=None, pi2=None, divergence="kl", mass=True, nx= nx = get_backend(*arr, pi1, pi2) if divergence == "kl": - if pi1 is None: pi1 = nx.sum(pi, 1) if pi2 is None: pi2 = nx.sum(pi, 0) - res = nx.sum(pi * nx.log(pi + 1.0 * (pi == 0))) \ - - nx.sum(pi1 * nx.log(a)) - nx.sum(pi2 * nx.log(b)) + res = ( + nx.sum(pi * nx.log(pi + 1.0 * (pi == 0))) + - nx.sum(pi1 * nx.log(a)) + - nx.sum(pi2 * nx.log(b)) + ) if mass: res = res - nx.sum(pi1) + nx.sum(a) * nx.sum(b) elif divergence == "l2": - res = (nx.sum(pi**2) + nx.sum(a**2) * nx.sum(b**2) - - 2 * nx.dot(a, nx.dot(pi, b))) / 2 + res = ( + nx.sum(pi**2) + nx.sum(a**2) * nx.sum(b**2) - 2 * nx.dot(a, nx.dot(pi, b)) + ) / 2 return res @@ -985,11 +1027,18 @@ def div_between_product(mu, nu, alpha, beta, divergence, nx=None): m_mu, m_nu = nx.sum(mu), nx.sum(nu) m_alpha, m_beta = nx.sum(alpha), nx.sum(beta) const = (m_mu - m_alpha) * (m_nu - m_beta) - res = m_nu * nx.kl_div(mu, alpha, mass=True) + m_mu * nx.kl_div(nu, beta, mass=True) + const + res = ( + m_nu * nx.kl_div(mu, alpha, mass=True) + + m_mu * nx.kl_div(nu, beta, mass=True) + + const + ) elif divergence == "l2": - res = (nx.sum(alpha**2) * nx.sum(beta**2) - 2 * nx.sum(alpha * mu) * nx.sum(beta * nu) - + nx.sum(mu**2) * nx.sum(nu**2)) / 2 + res = ( + nx.sum(alpha**2) * nx.sum(beta**2) + - 2 * nx.sum(alpha * mu) * nx.sum(beta * nu) + + nx.sum(mu**2) * nx.sum(nu**2) + ) / 2 return res @@ -1056,13 +1105,16 @@ def uot_cost_matrix(data, pi, tuple_p, hyperparams, divergence, reg_type, nx=Non if rho_y != float("inf") and rho_y != 0: uot_cost = uot_cost + rho_y * nx.kl_div(pi2, b, mass=False) if reg_type == "joint" and eps > 0: - uot_cost = uot_cost + eps * div_to_product(pi, a, b, pi1, pi2, - divergence, mass=False, nx=nx) + uot_cost = uot_cost + eps * div_to_product( + pi, a, b, pi1, pi2, divergence, mass=False, nx=nx + ) return uot_cost -def uot_parameters_and_measures(pi, tuple_weights, hyperparams, reg_type, divergence, nx): +def uot_parameters_and_measures( + pi, tuple_weights, hyperparams, reg_type, divergence, nx +): r"""The Block Coordinate Descent algorithm for FUGW and UCOOT requires solving an UOT problem in each iteration. 
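The half-squared L2 branch of `div_between_product` above is the expansion of ||mu x nu - alpha x beta||^2_F / 2 using <mu x nu, alpha x beta> = <mu, alpha><nu, beta>. A quick NumPy check of that identity (a sketch for readers, not part of the patch):

    import numpy as np

    rng = np.random.default_rng(0)
    mu, alpha = rng.random(5), rng.random(5)
    nu, beta = rng.random(7), rng.random(7)

    # direct half-squared Frobenius distance between the two product measures
    direct = 0.5 * np.sum((np.outer(mu, nu) - np.outer(alpha, beta)) ** 2)

    # closed form used in the hunk above
    closed = (np.sum(alpha**2) * np.sum(beta**2)
              - 2 * np.sum(alpha * mu) * np.sum(beta * nu)
              + np.sum(mu**2) * np.sum(nu**2)) / 2

    print(np.allclose(direct, closed))   # True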
In particular, we need to specify the following inputs: @@ -1126,8 +1178,18 @@ def uot_parameters_and_measures(pi, tuple_weights, hyperparams, reg_type, diverg return weighted_w, new_rho, new_eps -def fused_unbalanced_across_spaces_cost(M_linear, data, tuple_pxy_samp, tuple_pxy_feat, - pi_samp, pi_feat, hyperparams, divergence, reg_type, nx): +def fused_unbalanced_across_spaces_cost( + M_linear, + data, + tuple_pxy_samp, + tuple_pxy_feat, + pi_samp, + pi_feat, + hyperparams, + divergence, + reg_type, + nx, +): r"""Return the fused unbalanced across-space divergence between two spaces Parameters @@ -1186,26 +1248,43 @@ def fused_unbalanced_across_spaces_cost(M_linear, data, tuple_pxy_samp, tuple_px ucoot_cost = ucoot_cost + nx.sum(pi_feat * M_feat) if rho_x != float("inf") and rho_x != 0: - ucoot_cost = ucoot_cost + \ - rho_x * div_between_product(pi1_samp, pi1_feat, - px_samp, px_feat, divergence, nx) + ucoot_cost = ucoot_cost + rho_x * div_between_product( + pi1_samp, pi1_feat, px_samp, px_feat, divergence, nx + ) if rho_y != float("inf") and rho_y != 0: - ucoot_cost = ucoot_cost + \ - rho_y * div_between_product(pi2_samp, pi2_feat, - py_samp, py_feat, divergence, nx) + ucoot_cost = ucoot_cost + rho_y * div_between_product( + pi2_samp, pi2_feat, py_samp, py_feat, divergence, nx + ) if reg_type == "joint" and eps_samp != 0: - div_cost = div_between_product(pi_samp, pi_feat, - pxy_samp, pxy_feat, divergence, nx) + div_cost = div_between_product( + pi_samp, pi_feat, pxy_samp, pxy_feat, divergence, nx + ) ucoot_cost = ucoot_cost + eps_samp * div_cost elif reg_type == "independent": if eps_samp != 0: - div_samp = div_to_product(pi_samp, pi1_samp, pi2_samp, - px_samp, py_samp, divergence, mass=True, nx=nx) + div_samp = div_to_product( + pi_samp, + pi1_samp, + pi2_samp, + px_samp, + py_samp, + divergence, + mass=True, + nx=nx, + ) ucoot_cost = ucoot_cost + eps_samp * div_samp if eps_feat != 0: - div_feat = div_to_product(pi_feat, pi1_feat, pi2_feat, - px_feat, py_feat, divergence, mass=True, nx=nx) + div_feat = div_to_product( + pi_feat, + pi1_feat, + pi2_feat, + px_feat, + py_feat, + divergence, + mass=True, + nx=nx, + ) ucoot_cost = ucoot_cost + eps_feat * div_feat return linear_cost, ucoot_cost diff --git a/ot/helpers/openmp_helpers.py b/ot/helpers/openmp_helpers.py index 90a2918da..1b66aa0df 100644 --- a/ot/helpers/openmp_helpers.py +++ b/ot/helpers/openmp_helpers.py @@ -3,7 +3,6 @@ # This code is adapted for a large part from the astropy openmp helpers, which # can be found at: https://github.com/astropy/extension-helpers/blob/master/extension_helpers/_openmp_helpers.py # noqa - import os import sys import textwrap @@ -17,22 +16,22 @@ def get_openmp_flag(compiler): """Get openmp flags for a given compiler""" - if hasattr(compiler, 'compiler'): + if hasattr(compiler, "compiler"): compiler = compiler.compiler[0] else: compiler = compiler.__class__.__name__ - if sys.platform == "win32" and ('icc' in compiler or 'icl' in compiler): - omp_flag = ['/Qopenmp'] + if sys.platform == "win32" and ("icc" in compiler or "icl" in compiler): + omp_flag = ["/Qopenmp"] elif sys.platform == "win32": - omp_flag = ['/openmp'] + omp_flag = ["/openmp"] elif sys.platform in ("darwin", "linux") and "icc" in compiler: - omp_flag = ['-qopenmp'] - elif sys.platform == "darwin" and 'openmp' in os.getenv('CPPFLAGS', ''): + omp_flag = ["-qopenmp"] + elif sys.platform == "darwin" and "openmp" in os.getenv("CPPFLAGS", ""): omp_flag = [] else: # Default flag for GCC and clang: - omp_flag = ['-fopenmp'] + omp_flag = 
["-fopenmp"] if sys.platform.startswith("darwin"): omp_flag += ["-Xpreprocessor", "-lomp"] return omp_flag @@ -50,26 +49,27 @@ def check_openmp_support(): printf("nthreads=%d\\n", omp_get_num_threads()); return 0; } - """) + """ + ) - extra_preargs = os.getenv('LDFLAGS', None) + extra_preargs = os.getenv("LDFLAGS", None) if extra_preargs is not None: extra_preargs = extra_preargs.strip().split(" ") extra_preargs = [ - flag for flag in extra_preargs - if flag.startswith(('-L', '-Wl,-rpath', '-l'))] + flag + for flag in extra_preargs + if flag.startswith(("-L", "-Wl,-rpath", "-l")) + ] extra_postargs = get_openmp_flag try: output, compile_flags = compile_test_program( - code, - extra_preargs=extra_preargs, - extra_postargs=extra_postargs + code, extra_preargs=extra_preargs, extra_postargs=extra_postargs ) - if output and 'nthreads=' in output[0]: - nthreads = int(output[0].strip().split('=')[1]) + if output and "nthreads=" in output[0]: + nthreads = int(output[0].strip().split("=")[1]) openmp_supported = len(output) == nthreads elif "PYTHON_CROSSENV" in os.environ: # Since we can't run the test program when cross-compiling diff --git a/ot/helpers/pre_build_helpers.py b/ot/helpers/pre_build_helpers.py index 2930036b7..51b231c42 100644 --- a/ot/helpers/pre_build_helpers.py +++ b/ot/helpers/pre_build_helpers.py @@ -26,35 +26,37 @@ def compile_test_program(code, extra_preargs=[], extra_postargs=[]): if callable(extra_postargs): extra_postargs = extra_postargs(ccompiler) - start_dir = os.path.abspath('.') + start_dir = os.path.abspath(".") with tempfile.TemporaryDirectory() as tmp_dir: try: os.chdir(tmp_dir) # Write test program - with open('test_program.c', 'w') as f: + with open("test_program.c", "w") as f: f.write(code) - os.mkdir('objects') + os.mkdir("objects") # Compile, test program - ccompiler.compile(['test_program.c'], output_dir='objects', - extra_postargs=extra_postargs) + ccompiler.compile( + ["test_program.c"], output_dir="objects", extra_postargs=extra_postargs + ) # Link test program - objects = glob.glob( - os.path.join('objects', '*' + ccompiler.obj_extension)) - ccompiler.link_executable(objects, 'test_program', - extra_preargs=extra_preargs, - extra_postargs=extra_postargs) + objects = glob.glob(os.path.join("objects", "*" + ccompiler.obj_extension)) + ccompiler.link_executable( + objects, + "test_program", + extra_preargs=extra_preargs, + extra_postargs=extra_postargs, + ) if "PYTHON_CROSSENV" not in os.environ: # Run test program if not cross compiling # will raise a CalledProcessError if return code was non-zero - output = subprocess.check_output('./test_program') - output = output.decode( - sys.stdout.encoding or 'utf-8').splitlines() + output = subprocess.check_output("./test_program") + output = output.decode(sys.stdout.encoding or "utf-8").splitlines() else: # Return an empty output if we are cross compiling # as we cannot run the test_program diff --git a/ot/lowrank.py b/ot/lowrank.py index a06c1aaa1..14a10f163 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -6,7 +6,6 @@ # # License: MIT License - import warnings from .utils import unif, dist, get_lowrank_lazytensor from .backend import get_backend @@ -15,6 +14,7 @@ # test if sklearn is installed for linux-minimal-deps try: import sklearn.cluster + sklearn_import = True except ImportError: sklearn_import = False @@ -119,7 +119,9 @@ def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init, random_state, nx=Non g = nx.ones(rank, type_as=X_s) / rank # Init Q - kmeans_Xs = sklearn.cluster.KMeans(n_clusters=rank, 
random_state=random_state, n_init="auto") + kmeans_Xs = sklearn.cluster.KMeans( + n_clusters=rank, random_state=random_state, n_init="auto" + ) kmeans_Xs.fit(X_s) Z_Xs = nx.from_numpy(kmeans_Xs.cluster_centers_) C_Xs = dist(X_s, Z_Xs) # shape (ns, rank) @@ -127,7 +129,9 @@ def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init, random_state, nx=Non Q = sinkhorn(a, g, C_Xs, reg=reg_init, numItermax=10000, stopThr=1e-3) # Init R - kmeans_Xt = sklearn.cluster.KMeans(n_clusters=rank, random_state=random_state, n_init="auto") + kmeans_Xt = sklearn.cluster.KMeans( + n_clusters=rank, random_state=random_state, n_init="auto" + ) kmeans_Xt.fit(X_t) Z_Xt = nx.from_numpy(kmeans_Xt.cluster_centers_) C_Xt = dist(X_t, Z_Xt) # shape (nt, rank) @@ -135,7 +139,9 @@ def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init, random_state, nx=Non R = sinkhorn(b, g, C_Xt, reg=reg_init, numItermax=10000, stopThr=1e-3) else: - raise ImportError("Scikit-learn should be installed to use the 'kmeans' init.") + raise ImportError( + "Scikit-learn should be installed to use the 'kmeans' init." + ) return Q, R, g @@ -250,7 +256,10 @@ def _LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=N r = len(eps3) # rank g_ = nx.copy(eps3) # \tilde{g} q3_1, q3_2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # q^{(3)}_1, q^{(3)}_2 - v1_, v2_ = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # \tilde{v}^{(1)}, \tilde{v}^{(2)} + v1_, v2_ = ( + nx.ones(r, type_as=p1), + nx.ones(r, type_as=p1), + ) # \tilde{v}^{(1)}, \tilde{v}^{(2)} q1, q2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # q^{(1)}, q^{(2)} err = 1 # initial error @@ -309,9 +318,24 @@ def _LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=N return Q, R, g -def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, rescale_cost=True, - init="random", reg_init=1e-1, seed_init=49, gamma_init="rescale", - numItermax=2000, stopThr=1e-7, warn=True, log=False): +def lowrank_sinkhorn( + X_s, + X_t, + a=None, + b=None, + reg=0, + rank=None, + alpha=1e-10, + rescale_cost=True, + init="random", + reg_init=1e-1, + seed_init=49, + gamma_init="rescale", + numItermax=2000, + stopThr=1e-7, + warn=True, + log=False, +): r""" Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints on the couplings. 
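A short usage sketch for `lowrank_sinkhorn` (illustrative; it assumes the `(Q, R, g)` return triple documented for this function and the low-rank parametrization P = Q diag(1/g) R^T, with uniform weights when `a` and `b` are left to None):

    import numpy as np
    from ot.lowrank import lowrank_sinkhorn

    rng = np.random.default_rng(0)
    X_s = rng.standard_normal((200, 2))          # source samples
    X_t = rng.standard_normal((150, 2)) + 1.0    # target samples

    # rank-10 factorization of the coupling, squared Euclidean ground cost
    Q, R, g = lowrank_sinkhorn(X_s, X_t, rank=10)

    P = Q @ np.diag(1.0 / g) @ R.T               # dense plan, for inspection only
    print(P.shape, float(P.sum()))               # (200, 150), ~1.0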
@@ -412,8 +436,11 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, re # Dykstra won't converge if 1/rank < alpha (see Section 3.2) if 1 / r < alpha: - raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format( - a=alpha, r=1 / rank)) + raise ValueError( + "alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format( + a=alpha, r=1 / rank + ) + ) # Low rank decomposition of the sqeuclidean cost matrix M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, rescale_cost, nx) @@ -424,13 +451,15 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, re # Gamma initialization if gamma_init == "theory": L = nx.sqrt( - 3 * (2 / (alpha**4)) * ((nx.norm(M1) * nx.norm(M2)) ** 2) + - (reg + (2 / (alpha**3)) * (nx.norm(M1) * nx.norm(M2))) ** 2 + 3 * (2 / (alpha**4)) * ((nx.norm(M1) * nx.norm(M2)) ** 2) + + (reg + (2 / (alpha**3)) * (nx.norm(M1) * nx.norm(M2))) ** 2 ) gamma = 1 / (2 * L) if gamma_init not in ["rescale", "theory"]: - raise (NotImplementedError('Not implemented gamma_init="{}"'.format(gamma_init))) + raise ( + NotImplementedError('Not implemented gamma_init="{}"'.format(gamma_init)) + ) # -------------------------- Low rank algorithm ------------------------------ # see "Section 3.3, Algorithm 3 LOT" diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 752c5d2d7..2b93e84f3 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -8,10 +8,6 @@ # # License: MIT License -import os -import multiprocessing -import sys - import numpy as np import warnings @@ -21,18 +17,35 @@ # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d, - binary_search_circle, wasserstein_circle, - semidiscrete_wasserstein2_unif_circle) +from .solver_1d import ( + emd_1d, + emd2_1d, + wasserstein_1d, + binary_search_circle, + wasserstein_circle, + semidiscrete_wasserstein2_unif_circle, +) from ..utils import dist, list_to_array -from ..utils import parmap from ..backend import get_backend -__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter', - 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle', - 'dmmot_monge_1dgrid_loss', 'dmmot_monge_1dgrid_optimize'] +__all__ = [ + "emd", + "emd2", + "barycenter", + "free_support_barycenter", + "cvx", + "emd_1d_sorted", + "emd_1d", + "emd2_1d", + "wasserstein_1d", + "generalized_free_support_barycenter", + "binary_search_circle", + "wasserstein_circle", + "semidiscrete_wasserstein2_unif_circle", + "dmmot_monge_1dgrid_loss", + "dmmot_monge_1dgrid_optimize", +] def check_number_threads(numThreads): @@ -48,10 +61,14 @@ def check_number_threads(numThreads): numThreads : int Corrected number of threads """ - if (numThreads is None) or (isinstance(numThreads, str) and numThreads.lower() == 'max'): + if (numThreads is None) or ( + isinstance(numThreads, str) and numThreads.lower() == "max" + ): return -1 if (not isinstance(numThreads, int)) or numThreads < 1: - raise ValueError('numThreads should either be "max" or a strictly positive integer') + raise ValueError( + 'numThreads should either be "max" or a strictly positive integer' + ) return numThreads @@ -202,7 +219,16 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): return center_ot_dual(alpha, beta, a, b) -def emd(a, b, M, numItermax=100000, 
log=False, center_dual=True, numThreads=1, check_marginals=True): +def emd( + a, + b, + M, + numItermax=100000, + log=False, + center_dual=True, + numThreads=1, + check_marginals=True, +): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -318,16 +344,13 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c if len(b) == 0: b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] - # store original tensors - a0, b0, M0 = a, b, M - # convert to numpy M, a, b = nx.to_numpy(M, a, b) # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order='C') + M = np.asarray(M, dtype=np.float64, order="C") # if empty array given then use uniform distributions if len(a) == 0: @@ -335,14 +358,18 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ - "Dimension mismatch, check dimensions of M with a and b" + assert ( + a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + ), "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass if check_marginals: - np.testing.assert_almost_equal(a.sum(0), - b.sum(0), err_msg='a and b vector must have the same sum', - decimal=6) + np.testing.assert_almost_equal( + a.sum(0), + b.sum(0), + err_msg="a and b vector must have the same sum", + decimal=6, + ) b = b * a.sum() / b.sum() asel = a != 0 @@ -365,22 +392,31 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c "casted accordingly, possibly resulting in a loss of precision. " "If this behaviour is unwanted, please make sure your input " "histogram consists of floating point elements.", - stacklevel=2 + stacklevel=2, ) if log: log = {} - log['cost'] = cost - log['u'] = nx.from_numpy(u, type_as=type_as) - log['v'] = nx.from_numpy(v, type_as=type_as) - log['warning'] = result_code_string - log['result_code'] = result_code + log["cost"] = cost + log["u"] = nx.from_numpy(u, type_as=type_as) + log["v"] = nx.from_numpy(v, type_as=type_as) + log["warning"] = result_code_string + log["result_code"] = result_code return nx.from_numpy(G, type_as=type_as), log return nx.from_numpy(G, type_as=type_as) -def emd2(a, b, M, processes=1, - numItermax=100000, log=False, return_matrix=False, - center_dual=True, numThreads=1, check_marginals=True): +def emd2( + a, + b, + M, + processes=1, + numItermax=100000, + log=False, + return_matrix=False, + center_dual=True, + numThreads=1, + check_marginals=True, +): r"""Solves the Earth Movers distance problem and returns the loss .. 
math:: @@ -504,16 +540,20 @@ def emd2(a, b, M, processes=1, a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order='C') + M = np.asarray(M, dtype=np.float64, order="C") - assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ - "Dimension mismatch, check dimensions of M with a and b" + assert ( + a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + ), "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass if check_marginals: - np.testing.assert_almost_equal(a.sum(0), - b.sum(0, keepdims=True), err_msg='a and b vector must have the same sum', - decimal=6) + np.testing.assert_almost_equal( + a.sum(0), + b.sum(0, keepdims=True), + err_msg="a and b vector must have the same sum", + decimal=6, + ) b = b * a.sum(0) / b.sum(0, keepdims=True) asel = a != 0 @@ -521,6 +561,7 @@ def emd2(a, b, M, processes=1, numThreads = check_number_threads(numThreads) if log or return_matrix: + def f(b): bsel = b != 0 @@ -540,20 +581,23 @@ def f(b): "casted accordingly, possibly resulting in a loss of precision. " "If this behaviour is unwanted, please make sure your input " "histogram consists of floating point elements.", - stacklevel=2 + stacklevel=2, ) G = nx.from_numpy(G, type_as=type_as) if return_matrix: - log['G'] = G - log['u'] = nx.from_numpy(u, type_as=type_as) - log['v'] = nx.from_numpy(v, type_as=type_as) - log['warning'] = result_code_string - log['result_code'] = result_code - cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (log['u'] - nx.mean(log['u']), - log['v'] - nx.mean(log['v']), G)) + log["G"] = G + log["u"] = nx.from_numpy(u, type_as=type_as) + log["v"] = nx.from_numpy(v, type_as=type_as) + log["warning"] = result_code_string + log["result_code"] = result_code + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), + (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G), + ) return [cost, log] else: + def f(b): bsel = b != 0 G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) @@ -570,12 +614,18 @@ def f(b): "casted accordingly, possibly resulting in a loss of precision. " "If this behaviour is unwanted, please make sure your input " "histogram consists of floating point elements.", - stacklevel=2 + stacklevel=2, ) G = nx.from_numpy(G, type_as=type_as) - cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as), - nx.from_numpy(v - np.mean(v), type_as=type_as), G)) + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + G, + ), + ) check_result(result_code) return cost @@ -594,8 +644,18 @@ def f(b): return res -def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, - stopThr=1e-7, verbose=False, log=None, numThreads=1): +def free_support_barycenter( + measures_locations, + measures_weights, + X_init, + b=None, + weights=None, + numItermax=100, + stopThr=1e-7, + verbose=False, + log=None, + numThreads=1, +): r""" Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. 
the weighted Frechet mean for the 2-Wasserstein distance), formally: @@ -680,16 +740,19 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None log_dict = {} displacement_square_norms = [] - displacement_square_norm = stopThr + 1. - - while (displacement_square_norm > stopThr and iter_count < numItermax): + displacement_square_norm = stopThr + 1.0 + while displacement_square_norm > stopThr and iter_count < numItermax: T_sum = nx.zeros((k, d), type_as=X_init) - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): + for measure_locations_i, measure_weights_i, weight_i in zip( + measures_locations, measures_weights, weights + ): M_i = dist(X, measure_locations_i) T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) - T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i) + T_sum = T_sum + weight_i * 1.0 / b[:, None] * nx.dot( + T_i, measure_locations_i + ) displacement_square_norm = nx.sum((T_sum - X) ** 2) if log: @@ -698,19 +761,36 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None X = T_sum if verbose: - print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm) + print( + "iteration %d, displacement_square_norm=%f\n", + iter_count, + displacement_square_norm, + ) iter_count += 1 if log: - log_dict['displacement_square_norms'] = displacement_square_norms + log_dict["displacement_square_norms"] = displacement_square_norms return X, log_dict else: return X -def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Y_init=None, b=None, weights=None, - numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1, eps=0): +def generalized_free_support_barycenter( + X_list, + a_list, + P_list, + n_samples_bary, + Y_init=None, + b=None, + weights=None, + numItermax=100, + stopThr=1e-7, + verbose=False, + log=None, + numThreads=1, + eps=0, +): r""" Solves the free support generalized Wasserstein barycenter problem: finding a barycenter (a discrete measure with a fixed amount of points of uniform weights) whose respective projections fit the input measures. 
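A minimal usage sketch of the fixed-point iteration implemented by `free_support_barycenter` (illustrative values, uniform weights, not part of the patch):

    import numpy as np
    import ot
    from ot.lp import free_support_barycenter

    rng = np.random.default_rng(0)
    X1 = rng.standard_normal((30, 2))                          # first measure support
    X2 = rng.standard_normal((40, 2)) + np.array([4.0, 0.0])   # second, shifted
    measures_locations = [X1, X2]
    measures_weights = [ot.unif(30), ot.unif(40)]

    # barycenter supported on 20 free points, random initialization
    X_init = rng.standard_normal((20, 2))
    X_bary = free_support_barycenter(measures_locations, measures_weights, X_init)
    print(X_bary.shape)                                        # (20, 2)

    # each iteration above relies on exact OT plans such as this one
    M = ot.dist(X1, X2)
    G0 = ot.emd(ot.unif(30), ot.unif(40), M)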
@@ -789,12 +869,16 @@ def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, weights = nx.ones(p, type_as=X_list[0]) / p # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB) - A = eps * nx.eye(d, type_as=X_list[0]) # if eps nonzero: will force the invertibility of A - for (P_i, lambda_i) in zip(P_list, weights): + A = eps * nx.eye( + d, type_as=X_list[0] + ) # if eps nonzero: will force the invertibility of A + for P_i, lambda_i in zip(P_list, weights): A = A + lambda_i * P_i.T @ P_i B = nx.inv(nx.sqrtm(A)) - Z_list = [x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list)] # change of variables -> (WB) problem on Z + Z_list = [ + x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list) + ] # change of variables -> (WB) problem on Z if Y_init is None: Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0]) @@ -802,8 +886,17 @@ def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, if b is None: b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimized - out = free_support_barycenter(Z_list, a_list, Y_init, b, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, numThreads=numThreads) + out = free_support_barycenter( + Z_list, + a_list, + Y_init, + b, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + numThreads=numThreads, + ) if log: # unpack Y, log_dict = out diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index f9572962a..01f5e5d87 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -25,7 +25,7 @@ def scipy_sparse_to_spmatrix(A): return SP -def barycenter(A, M, weights=None, verbose=False, log=False, solver='highs-ipm'): +def barycenter(A, M, weights=None, verbose=False, log=False, solver="highs-ipm"): r"""Compute the Wasserstein barycenter of distributions A The function solves the following optimization problem [16]: @@ -95,7 +95,9 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='highs-ipm') lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)] # row constraints - A_eq1 = sps.hstack((sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n)))) + A_eq1 = sps.hstack( + (sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n))) + ) # columns constraints lst_idiag2 = [] @@ -115,28 +117,33 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='highs-ipm') A_eq = sps.vstack((A_eq1, A_eq2)) b_eq = np.concatenate((b_eq1, b_eq2)) - if not cvxopt or solver in ['interior-point', 'highs', 'highs-ipm', 'highs-ds']: + if not cvxopt or solver in ["interior-point", "highs", "highs-ipm", "highs-ds"]: # cvxopt not installed or interior point if solver is None: - solver = 'interior-point' + solver = "interior-point" - options = {'disp': verbose} - sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, - options=options) + options = {"disp": verbose} + sol = sp.optimize.linprog( + c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options + ) x = sol.x b = x[-n:] else: - h = np.zeros((n_distributions * n2 + n)) G = -sps.eye(n_distributions * n2 + n) - sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), - A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), - solver=solver) + sol = solvers.lp( + matrix(c), + scipy_sparse_to_spmatrix(G), + matrix(h), + A=scipy_sparse_to_spmatrix(A_eq), + b=matrix(b_eq), + solver=solver, + ) - x = np.array(sol['x']) + x = np.array(sol["x"]) b = x[-n:].ravel() if log: diff --git a/ot/lp/dmmot.py b/ot/lp/dmmot.py index 8576c3c61..f40a99b6f 100644 --- 
a/ot/lp/dmmot.py +++ b/ot/lp/dmmot.py @@ -51,7 +51,7 @@ def dist_monge_max_min(i): Discrete Applied Mathematics, 58(2):97-109, 1995. ISSN 0166-218X. doi: https://doi.org/10.1016/0166-218X(93)E0121-E. URL https://www.sciencedirect.com/ science/article/pii/0166218X93E0121E. - Workshop on Discrete Algoritms. + Workshop on Discrete Algorithms. """ return max(i) - min(i) @@ -146,11 +146,13 @@ def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): xx = {} dual = [np.zeros(d) for d in dims] - idx = [0, ] * len(AA) + idx = [ + 0, + ] * len(AA) obj = 0 if verbose: - print('i minval oldidx\t\tobj\t\tvals') + print("i minval oldidx\t\tobj\t\tvals") while all([i < _ for _, i in zip(dims, idx)]): vals = [v[i] for v, i in zip(AA, idx)] @@ -164,12 +166,14 @@ def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): oldidx = idx.copy() idx[i] += 1 if idx[i] < dims[i]: - temp = (dist_monge_max_min(idx) - - dist_monge_max_min(oldidx) + - dual[i][idx[i] - 1]) + temp = ( + dist_monge_max_min(idx) + - dist_monge_max_min(oldidx) + + dual[i][idx[i] - 1] + ) dual[i][idx[i]] += temp if verbose: - print(i, minval, oldidx, obj, '\t', vals) + print(i, minval, oldidx, obj, "\t", vals) # the above terminates when any entry in idx equals the corresponding # value in dims this leaves other dimensions incomplete; the remaining @@ -183,10 +187,12 @@ def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): dualobj = sum([np.dot(A[:, i], arr) for i, arr in enumerate(dual)]) obj = nx.from_numpy(obj) - log_dict = {'A': xx, - 'primal objective': obj, - 'dual': dual, - 'dual objective': dualobj} + log_dict = { + "A": xx, + "primal objective": obj, + "dual": dual, + "dual objective": dualobj, + } # define forward/backward relations for pytorch obj = nx.set_gradients(obj, (A_copy), (dual)) @@ -198,13 +204,14 @@ def dmmot_monge_1dgrid_loss(A, verbose=False, log=False): def dmmot_monge_1dgrid_optimize( - A, - niters=100, - lr_init=1e-5, - lr_decay=0.995, - print_rate=100, - verbose=False, - log=False): + A, + niters=100, + lr_init=1e-5, + lr_decay=0.995, + print_rate=100, + verbose=False, + log=False, +): r"""Minimize the d-dimensional EMD using gradient descent. 
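A usage sketch for the d-MMOT loss on a shared 1D grid (illustrative; the histograms are the columns of `A` and are assumed normalized to the same mass):

    import numpy as np
    from ot.datasets import make_1D_gauss
    from ot.lp import dmmot_monge_1dgrid_loss

    n = 60  # number of bins on the shared 1D grid
    # three 1D histograms (columns of A) with different means, each summing to one
    A = np.column_stack([make_1D_gauss(n, m=m, s=8) for m in (15, 30, 45)])

    loss = dmmot_monge_1dgrid_loss(A)   # d-MMOT loss between the columns
    print(float(loss))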
Discrete Multi-Marginal Optimal Transport (d-MMOT): Let :math:`a_1, \ldots, @@ -300,9 +307,8 @@ def dmmot_monge_1dgrid_optimize( n, d = A.shape # n is dim, d is n_hists def dualIter(A, lr): - funcval, log_dict = dmmot_monge_1dgrid_loss( - A, verbose=verbose, log=True) - grad = np.column_stack(log_dict['dual']) + funcval, log_dict = dmmot_monge_1dgrid_loss(A, verbose=verbose, log=True) + grad = np.column_stack(log_dict["dual"]) A_new = np.reshape(A, (n, d)) - grad * lr return funcval, A_new, grad, log_dict @@ -322,16 +328,15 @@ def listify(A): funcval, _, grad, log_dict = dualIter(A, lr) gn = np.linalg.norm(grad) - print(f'Inital:\t\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') + print(f"Initial:\t\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}") for i in range(niters): - A = renormalize(A) funcval, A, grad, log_dict = dualIter(A, lr) gn = np.linalg.norm(grad) if i % print_rate == 0: - print(f'Iter {i:2.0f}:\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}') + print(f"Iter {i:2.0f}:\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}") lr *= lr_decay diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 6d97303e2..e8af20c3c 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -17,7 +17,7 @@ def quantile_function(qs, cws, xs): - r""" Computes the quantile function of an empirical distribution + r"""Computes the quantile function of an empirical distribution Parameters ---------- @@ -35,7 +35,7 @@ def quantile_function(qs, cws, xs): """ nx = get_backend(qs, cws) n = xs.shape[0] - if nx.__name__ == 'torch': + if nx.__name__ == "torch": # this is to ensure the best performance for torch searchsorted # and avoid a warning related to non-contiguous arrays cws = cws.T.contiguous() @@ -47,7 +47,9 @@ def quantile_function(qs, cws, xs): return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) -def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True): +def wasserstein_1d( + u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True +): r""" Computes the 1 dimensional OT loss [15] between two (batched) empirical distributions @@ -100,11 +102,11 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ m = v_values.shape[0] if u_weights is None: - u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) if v_weights is None: - v_weights = nx.full(v_values.shape, 1. 
/ m, type_as=v_values) + v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) elif v_weights.ndim != v_values.ndim: v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) @@ -133,8 +135,17 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ return nx.sum(delta * nx.power(diff_quantiles, p), axis=0) -def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - log=False, check_marginals=True): +def emd_1d( + x_a, + x_b, + a=None, + b=None, + metric="sqeuclidean", + p=1.0, + dense=True, + log=False, + check_marginals=True, +): r"""Solves the Earth Movers distance problem between 1d measures and returns the OT matrix @@ -230,15 +241,17 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, if b is not None: b = list_to_array(b, nx=nx) - assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ - "emd_1d should only be used with monodimensional data" - assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ - "emd_1d should only be used with monodimensional data" - if metric not in ['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']: + assert ( + x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1 + ), "emd_1d should only be used with monodimensional data" + assert ( + x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1 + ), "emd_1d should only be used with monodimensional data" + if metric not in ["sqeuclidean", "minkowski", "cityblock", "euclidean"]: raise ValueError( - "Solver for EMD in 1d only supports metrics " + - "from the following list: " + - "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" + "Solver for EMD in 1d only supports metrics " + + "from the following list: " + + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" ) # if empty array given then use uniform distributions @@ -252,8 +265,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, np.testing.assert_almost_equal( nx.to_numpy(nx.sum(a, axis=0)), nx.to_numpy(nx.sum(b, axis=0)), - err_msg='a and b vector must have the same sum', - decimal=6 + err_msg="a and b vector must have the same sum", + decimal=6, ) b = b * nx.sum(a) / nx.sum(b) @@ -267,7 +280,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, nx.to_numpy(b[perm_b]).astype(np.float64), nx.to_numpy(x_a_1d[perm_a]).astype(np.float64), nx.to_numpy(x_b_1d[perm_b]).astype(np.float64), - metric=metric, p=p + metric=metric, + p=p, ) G = nx.coo_matrix( @@ -275,20 +289,21 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, perm_a[indices[:, 0]], perm_b[indices[:, 1]], shape=(a.shape[0], b.shape[0]), - type_as=x_a + type_as=x_a, ) if dense: G = nx.todense(G) elif str(nx) == "jax": warnings.warn("JAX does not support sparse matrices, converting to dense") if log: - log = {'cost': nx.from_numpy(cost, type_as=x_a)} + log = {"cost": nx.from_numpy(cost, type_as=x_a)} return G, log return G -def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - log=False): +def emd2_1d( + x_a, x_b, a=None, b=None, metric="sqeuclidean", p=1.0, dense=True, log=False +): r"""Solves the Earth Movers distance problem between 1d measures and returns the loss @@ -374,11 +389,12 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, """ # If we do not return G (log==False), then we should not to cast it to dense # (useless overhead) - G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, - dense=dense and log, log=True) - 
cost = log_emd['cost'] + G, log_emd = emd_1d( + x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, dense=dense and log, log=True + ) + cost = log_emd["cost"] if log: - log_emd = {'G': G} + log_emd = {"G": G} return cost, log_emd return cost @@ -417,14 +433,16 @@ def roll_cols(M, shifts): n_rows, n_cols = M.shape - arange1 = nx.tile(nx.reshape(nx.arange(n_cols, type_as=shifts), (1, n_cols)), (n_rows, 1)) + arange1 = nx.tile( + nx.reshape(nx.arange(n_cols, type_as=shifts), (1, n_cols)), (n_rows, 1) + ) arange2 = (arange1 - shifts) % n_cols return nx.take_along_axis(M, arange2, 1) def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): - r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1]) + r"""Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1]) Parameters ---------- @@ -472,13 +490,15 @@ def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): v_cdf_theta2 = nx.copy(v_cdf_theta) v_cdf_theta2[mask_n] = np.inf - shift = (-nx.argmin(v_cdf_theta2, axis=-1)) + shift = -nx.argmin(v_cdf_theta2, axis=-1) v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) - v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + v_values = nx.concatenate( + [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1 + ) - if nx.__name__ == 'torch': + if nx.__name__ == "torch": # this is to ensure the best performance for torch searchsorted # and avoid a warning related to non-contiguous arrays u_cdf = u_cdf.contiguous() @@ -490,9 +510,11 @@ def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): # Deal with 1 u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1) - u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1) + u_valuesm = nx.concatenate( + [u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1 + ) - if nx.__name__ == 'torch': + if nx.__name__ == "torch": # this is to ensure the best performance for torch searchsorted # and avoid a warning related to non-contiguous arrays u_cdfm = u_cdfm.contiguous() @@ -501,17 +523,23 @@ def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right") u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1) - dCp = nx.sum(nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p) - - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), axis=-1) + dCp = nx.sum( + nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p) + - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), + axis=-1, + ) - dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p) - - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), axis=-1) + dCm = nx.sum( + nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p) + - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), + axis=-1, + ) return dCp.reshape(-1, 1), dCm.reshape(-1, 1) def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p): - r""" Computes the the cost (Equation (6.2) of [1]) + r"""Computes the the cost (Equation (6.2) of [1]) Parameters ---------- @@ -558,11 +586,13 @@ def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p): # Put negative values at the end v_cdf_theta2 = nx.copy(v_cdf_theta) v_cdf_theta2[mask_n] = np.inf - shift = (-nx.argmin(v_cdf_theta2, axis=-1)) + shift = -nx.argmin(v_cdf_theta2, axis=-1) v_cdf_theta = 
roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) - v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + v_values = nx.concatenate( + [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1 + ) # Compute absciss cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1) @@ -570,9 +600,9 @@ def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p): delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1] - if nx.__name__ == 'torch': + if nx.__name__ == "torch": # this is to ensure the best performance for torch searchsorted - # and avoid a warninng related to non-contiguous arrays + # and avoid a warning related to non-contiguous arrays u_cdf = u_cdf.contiguous() v_cdf_theta = v_cdf_theta.contiguous() cdf_axis = cdf_axis.contiguous() @@ -581,7 +611,9 @@ def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p): u_index = nx.searchsorted(u_cdf, cdf_axis) u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1) - v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + v_values = nx.concatenate( + [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1 + ) v_index = nx.searchsorted(v_cdf_theta, cdf_axis) v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1) @@ -593,9 +625,20 @@ def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p): return ot_cost -def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, - Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True, - log=False): +def binary_search_circle( + u_values, + v_values, + u_weights=None, + v_weights=None, + p=1, + Lm=10, + Lp=10, + tm=-1, + tp=1, + eps=1e-6, + require_sort=True, + log=False, +): r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, takes the value modulo 1. @@ -681,18 +724,20 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 if u_values.shape[1] != v_values.shape[1]: raise ValueError( - "u and v must have the same number of batches {} and {} respectively given".format(u_values.shape[1], - v_values.shape[1])) + "u and v must have the same number of batches {} and {} respectively given".format( + u_values.shape[1], v_values.shape[1] + ) + ) u_values = u_values % 1 v_values = v_values % 1 if u_weights is None: - u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) if v_weights is None: - v_weights = nx.full(v_values.shape, 1. 
/ m, type_as=v_values) + v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) elif v_weights.ndim != v_values.ndim: v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) @@ -733,18 +778,30 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 if nx.any(mask): # can probably be improved by computing only relevant values - dCptp, dCmtp = derivative_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p) - dCptm, dCmtm = derivative_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p) - Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1) - Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1) + dCptp, dCmtp = derivative_cost_on_circle( + tp, u_values, v_values, u_cdf, v_cdf, p + ) + dCptm, dCmtm = derivative_cost_on_circle( + tm, u_values, v_values, u_cdf, v_cdf, p + ) + Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape( + -1, 1 + ) + Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape( + -1, 1 + ) mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) - tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0] + tc[mask_end > 0] = ( + (Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp) + )[mask_end > 0] done[nx.prod(mask, axis=-1) > 0] = 1 elif nx.any(1 - done): tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0] tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0] - tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2 + tc[((1 - mask) * (1 - done)) > 0] = ( + tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0] + ) / 2 w = ot_cost_on_circle(nx.detach(tc), u_values, v_values, u_cdf, v_cdf, p) @@ -753,7 +810,9 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 return w -def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True): +def wasserstein1_circle( + u_values, v_values, u_weights=None, v_weights=None, require_sort=True +): r"""Computes the 1-Wasserstein distance on the circle using the level median [45]. Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, takes the value modulo 1. @@ -810,18 +869,20 @@ def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, requ if u_values.shape[1] != v_values.shape[1]: raise ValueError( - "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1], - v_values.shape[1])) + "u and v must have the same number of batchs {} and {} respectively given".format( + u_values.shape[1], v_values.shape[1] + ) + ) u_values = u_values % 1 v_values = v_values % 1 if u_weights is None: - u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) if v_weights is None: - v_weights = nx.full(v_values.shape, 1. 
/ m, type_as=v_values) + v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) elif v_weights.ndim != v_values.ndim: v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) @@ -838,7 +899,12 @@ def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, requ # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0) - cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0), 0) + cdf_diff = nx.cumsum( + nx.take_along_axis( + nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0 + ), + 0, + ) cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0) values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1) @@ -854,8 +920,19 @@ def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, requ return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0) -def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, - Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): +def wasserstein_circle( + u_values, + v_values, + u_weights=None, + v_weights=None, + p=1, + Lm=10, + Lp=10, + tm=-1, + tp=1, + eps=1e-6, + require_sort=True, +): r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or the binary search algorithm proposed in [44] otherwise. Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, @@ -927,11 +1004,23 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) if p == 1: - return wasserstein1_circle(u_values, v_values, u_weights, v_weights, require_sort) + return wasserstein1_circle( + u_values, v_values, u_weights, v_weights, require_sort + ) - return binary_search_circle(u_values, v_values, u_weights, v_weights, - p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps, - require_sort=require_sort) + return binary_search_circle( + u_values, + v_values, + u_weights, + v_weights, + p=p, + Lm=Lm, + Lp=Lp, + tm=tm, + tp=tp, + eps=eps, + require_sort=require_sort, + ) def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): @@ -991,7 +1080,7 @@ def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): u_values = nx.reshape(u_values, (n, 1)) if u_weights is None: - u_weights = nx.full(u_values.shape, 1. 
/ n, type_as=u_values) + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) diff --git a/ot/mapping.py b/ot/mapping.py index 13ed55deb..dae059edb 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -20,8 +20,18 @@ from .utils import dist, unif, list_to_array, kernel, dots -def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly_convex_constant=.6, - gradient_lipschitz_constant=1.4, its=100, log=False, init_method='barycentric'): +def nearest_brenier_potential_fit( + X, + V, + X_classes=None, + a=None, + b=None, + strongly_convex_constant=0.6, + gradient_lipschitz_constant=1.4, + its=100, + log=False, + init_method="barycentric", +): r""" Computes optimal values and gradients at X for a strongly convex potential :math:`\varphi` with Lipschitz gradients on the partitions defined by `X_classes`, where :math:`\varphi` is optimal such that @@ -105,9 +115,11 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly try: import cvxpy as cvx except ImportError: - print('Please install CVXPY to use this function') + print("Please install CVXPY to use this function") return - assert X.shape == V.shape, f"point shape should be the same as value shape, yet {X.shape} != {V.shape}" + assert ( + X.shape == V.shape + ), f"point shape should be the same as value shape, yet {X.shape} != {V.shape}" nx = get_backend(X, V, X_classes, a, b) X, V = to_numpy(X), to_numpy(V) n, d = X.shape @@ -118,19 +130,18 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly X_classes = np.zeros(n) a = unif(n) if a is None else nx.to_numpy(a) b = unif(n) if b is None else nx.to_numpy(b) - assert a.shape[-1] == b.shape[-1] == n, 'incorrect measure weight sizes' + assert a.shape[-1] == b.shape[-1] == n, "incorrect measure weight sizes" - assert init_method in ['target', 'barycentric'], f"Unsupported initialization method '{init_method}'" - if init_method == 'target': + assert init_method in [ + "target", + "barycentric", + ], f"Unsupported initialization method '{init_method}'" + if init_method == "target": G_val = V else: # Init G_val with barycentric projection G_val = emd(a, b, dist(X, V)) @ V / a.reshape(n, 1) phi_val = None - log_dict = { - 'G_list': [], - 'phi_list': [], - 'its': [] - } + log_dict = {"G_list": [], "phi_list": [], "its": []} for _ in range(its): # alternate optimisation iterations cost_matrix = dist(G_val, V) @@ -146,29 +157,35 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly for j in range(n): cost += cvx.sum_squares(G[i, :] - V[j, :]) * plan[i, j] objective = cvx.Minimize(cost) # OT cost - c1, c2, c3 = _ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) + c1, c2, c3 = _ssnb_qcqp_constants( + strongly_convex_constant, gradient_lipschitz_constant + ) for k in np.unique(X_classes): # constraints for the convex interpolation for i in np.where(X_classes == k)[0]: for j in np.where(X_classes == k)[0]: constraints += [ - phi[i] >= phi[j] + G[j].T @ (X[i] - X[j]) + c1 * cvx.sum_squares(G[i] - G[j]) - + c2 * cvx.sum_squares(X[i] - X[j]) - c3 * (G[j] - G[i]).T @ (X[j] - X[i]) + phi[i] + >= phi[j] + + G[j].T @ (X[i] - X[j]) + + c1 * cvx.sum_squares(G[i] - G[j]) + + c2 * cvx.sum_squares(X[i] - X[j]) + - c3 * (G[j] - G[i]).T @ (X[j] - X[i]) ] problem = cvx.Problem(objective, constraints) problem.solve(solver=cvx.ECOS) phi_val, G_val = phi.value, G.value it_log_dict = 
{ - 'solve_time': problem.solver_stats.solve_time, - 'setup_time': problem.solver_stats.setup_time, - 'num_iters': problem.solver_stats.num_iters, - 'status': problem.status, - 'value': problem.value + "solve_time": problem.solver_stats.solve_time, + "setup_time": problem.solver_stats.setup_time, + "num_iters": problem.solver_stats.num_iters, + "status": problem.status, + "value": problem.value, } if log: - log_dict['its'].append(it_log_dict) - log_dict['G_list'].append(G_val) - log_dict['phi_list'].append(phi_val) + log_dict["its"].append(it_log_dict) + log_dict["G_list"].append(G_val) + log_dict["phi_list"].append(phi_val) # convert back to backend phi_val = nx.from_numpy(phi_val) @@ -194,7 +211,9 @@ def _ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant): c3 : float """ - assert 0 < strongly_convex_constant < gradient_lipschitz_constant, "incompatible regularity assumption" + assert ( + 0 < strongly_convex_constant < gradient_lipschitz_constant + ), "incompatible regularity assumption" c = 1 / (2 * (1 - strongly_convex_constant / gradient_lipschitz_constant)) c1 = c / gradient_lipschitz_constant c2 = strongly_convex_constant * c @@ -202,8 +221,17 @@ def _ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant): return c1, c2, c3 -def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_classes=None, - strongly_convex_constant=0.6, gradient_lipschitz_constant=1.4, log=False): +def nearest_brenier_potential_predict_bounds( + X, + phi, + G, + Y, + X_classes=None, + Y_classes=None, + strongly_convex_constant=0.6, + gradient_lipschitz_constant=1.4, + log=False, +): r""" Compute the values of the lower and upper bounding potentials at the input points Y, using the potential optimal values phi at X and their gradients G at X. 
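A small usage sketch for the SSNB fit above (illustrative; it requires CVXPY with the ECOS solver, and assumes the `(phi, G)` return pair plus uniform weights and a single class when `X_classes`, `a`, `b` are left to None):

    import numpy as np
    from ot.mapping import nearest_brenier_potential_fit

    rng = np.random.default_rng(0)
    X = rng.standard_normal((10, 2))          # source points
    V = X + np.array([2.0, 0.0])              # target values (a simple translation)

    # fit a strongly convex potential with Lipschitz gradient (QCQP via CVXPY);
    # phi holds potential values at X, G the corresponding gradients
    phi, G = nearest_brenier_potential_fit(X, V, its=3)
    print(phi.shape, G.shape)                 # (10,), (10, 2)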
The 'lower' potential corresponds to the method from :ref:`[58]`, @@ -292,7 +320,7 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla try: import cvxpy as cvx except ImportError: - print('Please install CVXPY to use this function') + print("Please install CVXPY to use this function") return nx = get_backend(X, phi, G, Y) X = to_numpy(X) @@ -302,18 +330,22 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla m, d = Y.shape if Y_classes is not None: Y_classes = to_numpy(Y_classes) - assert Y_classes.size == m, 'wrong number of class items for Y' + assert Y_classes.size == m, "wrong number of class items for Y" else: Y_classes = np.zeros(m) - assert X.shape[1] == d, f'incompatible dimensions between X: {X.shape} and Y: {Y.shape}' + assert ( + X.shape[1] == d + ), f"incompatible dimensions between X: {X.shape} and Y: {Y.shape}" n, _ = X.shape if X_classes is not None: X_classes = to_numpy(X_classes) assert X_classes.size == n, "incorrect number of class items" else: X_classes = np.zeros(n) - assert X_classes.size == n, 'wrong number of class items for X' - c1, c2, c3 = _ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) + assert X_classes.size == n, "wrong number of class items for X" + c1, c2, c3 = _ssnb_qcqp_constants( + strongly_convex_constant, gradient_lipschitz_constant + ) phi_lu = np.zeros((2, m)) G_lu = np.zeros((2, m, d)) log_dict = {} @@ -328,20 +360,24 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla k = Y_classes[y_idx] for j in np.where(X_classes == k)[0]: constraints += [ - phi_l_y >= phi[j] + G[j].T @ (Y[y_idx] - X[j]) + c1 * cvx.sum_squares(G_l_y - G[j]) - + c2 * cvx.sum_squares(Y[y_idx] - X[j]) - c3 * (G[j] - G_l_y).T @ (X[j] - Y[y_idx]) + phi_l_y + >= phi[j] + + G[j].T @ (Y[y_idx] - X[j]) + + c1 * cvx.sum_squares(G_l_y - G[j]) + + c2 * cvx.sum_squares(Y[y_idx] - X[j]) + - c3 * (G[j] - G_l_y).T @ (X[j] - Y[y_idx]) ] problem = cvx.Problem(objective, constraints) problem.solve(solver=cvx.ECOS) phi_lu[0, y_idx] = phi_l_y.value G_lu[0, y_idx] = G_l_y.value if log: - log_item['l'] = { - 'solve_time': problem.solver_stats.solve_time, - 'setup_time': problem.solver_stats.setup_time, - 'num_iters': problem.solver_stats.num_iters, - 'status': problem.status, - 'value': problem.value + log_item["l"] = { + "solve_time": problem.solver_stats.solve_time, + "setup_time": problem.solver_stats.setup_time, + "num_iters": problem.solver_stats.num_iters, + "status": problem.status, + "value": problem.value, } # upper bound @@ -351,20 +387,24 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla constraints = [] for i in np.where(X_classes == k)[0]: constraints += [ - phi[i] >= phi_u_y + G_u_y.T @ (X[i] - Y[y_idx]) + c1 * cvx.sum_squares(G[i] - G_u_y) - + c2 * cvx.sum_squares(X[i] - Y[y_idx]) - c3 * (G_u_y - G[i]).T @ (Y[y_idx] - X[i]) + phi[i] + >= phi_u_y + + G_u_y.T @ (X[i] - Y[y_idx]) + + c1 * cvx.sum_squares(G[i] - G_u_y) + + c2 * cvx.sum_squares(X[i] - Y[y_idx]) + - c3 * (G_u_y - G[i]).T @ (Y[y_idx] - X[i]) ] problem = cvx.Problem(objective, constraints) problem.solve(solver=cvx.ECOS) phi_lu[1, y_idx] = phi_u_y.value G_lu[1, y_idx] = G_u_y.value if log: - log_item['u'] = { - 'solve_time': problem.solver_stats.solve_time, - 'setup_time': problem.solver_stats.setup_time, - 'num_iters': problem.solver_stats.num_iters, - 'status': problem.status, - 'value': problem.value + log_item["u"] = { + "solve_time": problem.solver_stats.solve_time, + 
"setup_time": problem.solver_stats.setup_time, + "num_iters": problem.solver_stats.num_iters, + "status": problem.status, + "value": problem.value, } log_dict[y_idx] = log_item @@ -374,10 +414,21 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla return phi_lu, G_lu, log_dict -def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, - verbose2=False, numItermax=100, numInnerItermax=10, - stopInnerThr=1e-6, stopThr=1e-5, log=False, - **kwargs): +def joint_OT_mapping_linear( + xs, + xt, + mu=1, + eta=0.001, + bias=False, + verbose=False, + verbose2=False, + numItermax=100, + numInnerItermax=10, + stopInnerThr=1e-6, + stopThr=1e-5, + log=False, + **kwargs, +): r"""Joint OT and linear mapping estimation as proposed in :ref:`[8] `. @@ -487,7 +538,7 @@ def sel(x): return x if log: - log = {'err': []} + log = {"err": []} a = unif(ns, type_as=xs) b = unif(nt, type_as=xt) @@ -505,7 +556,7 @@ def loss(L, G): ) def solve_L(G): - """ solve L problem with fixed G (least square)""" + """solve L problem with fixed G (least square)""" xst = ns * nx.dot(G, xt) return nx.solve(xstxs + eta * Id, nx.dot(xs1.T, xst) + eta * I0) @@ -519,8 +570,17 @@ def f(G): def df(G): return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T) - G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, - numItermax=numInnerItermax, stopThr=stopInnerThr) + G = cg( + a, + b, + M, + 1.0 / mu, + f, + df, + G0=G0, + numItermax=numInnerItermax, + stopThr=stopInnerThr, + ) return G L = solve_L(G) @@ -528,9 +588,10 @@ def df(G): vloss.append(loss(L, G)) if verbose: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(0, vloss[-1], 0)) + print( + "{:5s}|{:12s}|{:8s}".format("It.", "Loss", "Delta loss") + "\n" + "-" * 32 + ) + print("{:5d}|{:8e}|{:8e}".format(0, vloss[-1], 0)) # init loop if numItermax > 0: @@ -540,7 +601,6 @@ def df(G): it = 0 while loop: - it += 1 # update G @@ -559,22 +619,40 @@ def df(G): if verbose: if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format( - it, vloss[-1], (vloss[-1] - vloss[-2]) / abs(vloss[-2]))) + print( + "{:5s}|{:12s}|{:8s}".format("It.", "Loss", "Delta loss") + + "\n" + + "-" * 32 + ) + print( + "{:5d}|{:8e}|{:8e}".format( + it, vloss[-1], (vloss[-1] - vloss[-2]) / abs(vloss[-2]) + ) + ) if log: - log['loss'] = vloss + log["loss"] = vloss return G, L, log else: return G, L -def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', - sigma=1, bias=False, verbose=False, verbose2=False, - numItermax=100, numInnerItermax=10, - stopInnerThr=1e-6, stopThr=1e-5, log=False, - **kwargs): +def joint_OT_mapping_kernel( + xs, + xt, + mu=1, + eta=0.001, + kerneltype="gaussian", + sigma=1, + bias=False, + verbose=False, + verbose2=False, + numItermax=100, + numInnerItermax=10, + stopInnerThr=1e-6, + stopThr=1e-5, + log=False, + **kwargs, +): r"""Joint OT and nonlinear mapping estimation with kernels as proposed in :ref:`[8] `. 
@@ -701,7 +779,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', Kreg = K if log: - log = {'err': []} + log = {"err": []} a = unif(ns, type_as=xs) b = unif(nt, type_as=xt) @@ -719,12 +797,12 @@ def loss(L, G): ) def solve_L_nobias(G): - """ solve L problem with fixed G (least square)""" + """solve L problem with fixed G (least square)""" xst = ns * nx.dot(G, xt) return nx.solve(K0, xst) def solve_L_bias(G): - """ solve L problem with fixed G (least square)""" + """solve L problem with fixed G (least square)""" xst = ns * nx.dot(G, xt) return nx.solve(K0, nx.dot(K1.T, xst)) @@ -738,8 +816,17 @@ def f(G): def df(G): return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T) - G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, - numItermax=numInnerItermax, stopThr=stopInnerThr) + G = cg( + a, + b, + M, + 1.0 / mu, + f, + df, + G0=G0, + numItermax=numInnerItermax, + stopThr=stopInnerThr, + ) return G if bias: @@ -752,9 +839,10 @@ def df(G): vloss.append(loss(L, G)) if verbose: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(0, vloss[-1], 0)) + print( + "{:5s}|{:12s}|{:8s}".format("It.", "Loss", "Delta loss") + "\n" + "-" * 32 + ) + print("{:5d}|{:8e}|{:8e}".format(0, vloss[-1], 0)) # init loop if numItermax > 0: @@ -764,7 +852,6 @@ def df(G): it = 0 while loop: - it += 1 # update G @@ -783,12 +870,18 @@ def df(G): if verbose: if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format( - it, vloss[-1], (vloss[-1] - vloss[-2]) / abs(vloss[-2]))) + print( + "{:5s}|{:12s}|{:8s}".format("It.", "Loss", "Delta loss") + + "\n" + + "-" * 32 + ) + print( + "{:5d}|{:8e}|{:8e}".format( + it, vloss[-1], (vloss[-1] - vloss[-2]) / abs(vloss[-2]) + ) + ) if log: - log['loss'] = vloss + log["loss"] = vloss return G, L, log else: return G, L diff --git a/ot/optim.py b/ot/optim.py index d4db59b68..ae8b0ba58 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -26,8 +26,18 @@ def line_search_armijo( - f, xk, pk, gfk, old_fval, args=(), c1=1e-4, - alpha0=0.99, alpha_min=0., alpha_max=None, nx=None, **kwargs + f, + xk, + pk, + gfk, + old_fval, + args=(), + c1=1e-4, + alpha0=0.99, + alpha_min=0.0, + alpha_max=None, + nx=None, + **kwargs, ): r""" Armijo linesearch function that works with matrices @@ -107,7 +117,7 @@ def phi(alpha1): return nx.to_numpy(fval) if old_fval is None: - phi0 = phi(0.) 
+ phi0 = phi(0.0) elif isinstance(old_fval, float): # prevent bug from nx.to_numpy that can look for .cpu or .gpu phi0 = old_fval @@ -116,18 +126,39 @@ def phi(alpha1): derphi0 = np.sum(pk * gfk) # Quickfix for matrices alpha, phi1 = scalar_search_armijo( - phi, phi0, derphi0, c1=c1, alpha0=alpha0, amin=alpha_min) + phi, phi0, derphi0, c1=c1, alpha0=alpha0, amin=alpha_min + ) if alpha is None: - return 0., fc[0], nx.from_numpy(phi0, type_as=xk0) + return 0.0, fc[0], nx.from_numpy(phi0, type_as=xk0) else: alpha = np.clip(alpha, alpha_min, alpha_max) - return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0) - - -def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None, - numItermax=200, stopThr=1e-9, - stopThr2=1e-9, verbose=False, log=False, nx=None, **kwargs): + return ( + nx.from_numpy(alpha, type_as=xk0), + fc[0], + nx.from_numpy(phi1, type_as=xk0), + ) + + +def generic_conditional_gradient( + a, + b, + M, + f, + df, + reg1, + reg2, + lp_solver, + line_search, + G0=None, + numItermax=200, + stopThr=1e-9, + stopThr2=1e-9, + verbose=False, + log=False, + nx=None, + **kwargs, +): r""" Solve the general regularized OT problem or its semi-relaxed version with conditional gradient or generalized conditional gradient depending on the @@ -278,7 +309,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): loop = 1 if log: - log = {'loss': []} + log = {"loss": []} if G0 is None: G = nx.outer(a, b) @@ -287,25 +318,32 @@ def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): G = nx.copy(G0) if reg2 is None: + def cost(G): return nx.sum(M * G) + reg1 * f(G) else: + def cost(G): return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G)) + cost_G = cost(G) if log: - log['loss'].append(cost_G) + log["loss"].append(cost_G) df_G = None it = 0 if verbose: - print('{:5s}|{:12s}|{:8s}|{:8s}'.format( - 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, 0, 0)) + print( + "{:5s}|{:12s}|{:8s}|{:8s}".format( + "It.", "Loss", "Relative loss", "Absolute loss" + ) + + "\n" + + "-" * 48 + ) + print("{:5d}|{:8e}|{:8e}|{:8e}".format(it, cost_G, 0, 0)) while loop: - it += 1 old_cost_G = cost_G # problem linearization @@ -313,7 +351,7 @@ def cost(G): df_G = df(G) Mi = M + reg1 * df_G - if not (reg2 is None): + if reg2 is not None: Mi = Mi + reg2 * (1 + nx.log(G)) # solve linear program @@ -338,18 +376,29 @@ def cost(G): loop = 0 abs_delta_cost_G = abs(cost_G - old_cost_G) - relative_delta_cost_G = abs_delta_cost_G / abs(cost_G) if cost_G != 0. 
else np.nan + relative_delta_cost_G = ( + abs_delta_cost_G / abs(cost_G) if cost_G != 0.0 else np.nan + ) if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2: loop = 0 if log: - log['loss'].append(cost_G) + log["loss"].append(cost_G) if verbose: if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}|{:8s}'.format( - 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) - print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, relative_delta_cost_G, abs_delta_cost_G)) + print( + "{:5s}|{:12s}|{:8s}|{:8s}".format( + "It.", "Loss", "Relative loss", "Absolute loss" + ) + + "\n" + + "-" * 48 + ) + print( + "{:5d}|{:8e}|{:8e}|{:8e}".format( + it, cost_G, relative_delta_cost_G, abs_delta_cost_G + ) + ) if log: log.update(innerlog_) @@ -358,9 +407,24 @@ def cost(G): return G -def cg(a, b, M, reg, f, df, G0=None, line_search=None, - numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9, - verbose=False, log=False, nx=None, **kwargs): +def cg( + a, + b, + M, + reg, + f, + df, + G0=None, + line_search=None, + numItermax=200, + numItermaxEmd=100000, + stopThr=1e-9, + stopThr2=1e-9, + verbose=False, + log=False, + nx=None, + **kwargs, +): r""" Solve the general regularized OT problem with conditional gradient @@ -444,19 +508,51 @@ def cg(a, b, M, reg, f, df, G0=None, line_search=None, nx = get_backend(a, b, M) if line_search is None: + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=nx, **kwargs) def lp_solver(a, b, M, **kwargs): return emd(a, b, M, numItermaxEmd, log=True) - return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, - numItermax=numItermax, stopThr=stopThr, - stopThr2=stopThr2, verbose=verbose, log=log, nx=nx, **kwargs) - - -def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=None, - numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, nx=None, **kwargs): + return generic_conditional_gradient( + a, + b, + M, + f, + df, + reg, + None, + lp_solver, + line_search, + G0=G0, + numItermax=numItermax, + stopThr=stopThr, + stopThr2=stopThr2, + verbose=verbose, + log=log, + nx=nx, + **kwargs, + ) + + +def semirelaxed_cg( + a, + b, + M, + reg, + f, + df, + G0=None, + line_search=None, + numItermax=200, + stopThr=1e-9, + stopThr2=1e-9, + verbose=False, + log=False, + nx=None, + **kwargs, +): r""" Solve the general regularized and semi-relaxed OT problem with conditional gradient @@ -533,6 +629,7 @@ def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=None, nx = get_backend(a, b, M) if line_search is None: + def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=nx, **kwargs) @@ -540,20 +637,53 @@ def lp_solver(a, b, Mi, **kwargs): # get minimum by rows as binary mask min_ = nx.reshape(nx.min(Mi, axis=1), (-1, 1)) # instead of exact elements equal to min_ we consider a small margin (1e-15) - # for float precision issues. Then the mass is splitted uniformly + # for float precision issues. Then the mass is split uniformly # between these elements. 
Gc = nx.ones(1, type_as=a) * (Mi <= min_ + 1e-15) Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1)) # return by default an empty inner_log return Gc, {} - return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, - numItermax=numItermax, stopThr=stopThr, - stopThr2=stopThr2, verbose=verbose, log=log, nx=nx, **kwargs) - - -def partial_cg(a, b, a_extended, b_extended, M, reg, f, df, G0=None, line_search=line_search_armijo, - numItermax=200, stopThr=1e-9, stopThr2=1e-9, warn=True, verbose=False, log=False, **kwargs): + return generic_conditional_gradient( + a, + b, + M, + f, + df, + reg, + None, + lp_solver, + line_search, + G0=G0, + numItermax=numItermax, + stopThr=stopThr, + stopThr2=stopThr2, + verbose=verbose, + log=log, + nx=nx, + **kwargs, + ) + + +def partial_cg( + a, + b, + a_extended, + b_extended, + M, + reg, + f, + df, + G0=None, + line_search=line_search_armijo, + numItermax=200, + stopThr=1e-9, + stopThr2=1e-9, + warn=True, + verbose=False, + log=False, + **kwargs, +): r""" Solve the general regularized partial OT problem with conditional gradient @@ -641,23 +771,57 @@ def lp_solver(a, b, Mi, **kwargs): Mi_extended[:n, :m] = Mi Mi_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2 - G_extended, log_ = emd(a_extended, b_extended, Mi_extended, numItermax, log=True) + G_extended, log_ = emd( + a_extended, b_extended, Mi_extended, numItermax, log=True + ) Gc = G_extended[:n, :m] if warn: - if log_['warning'] is not None: - raise ValueError("Error in the EMD resolution: try to increase the" - " number of dummy points") + if log_["warning"] is not None: + raise ValueError( + "Error in the EMD resolution: try to increase the" + " number of dummy points" + ) return Gc, log_ - return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0, - numItermax=numItermax, stopThr=stopThr, - stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) - - -def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, - numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): + return generic_conditional_gradient( + a, + b, + M, + f, + df, + reg, + None, + lp_solver, + line_search, + G0=G0, + numItermax=numItermax, + stopThr=stopThr, + stopThr2=stopThr2, + verbose=verbose, + log=log, + **kwargs, + ) + + +def gcg( + a, + b, + M, + reg1, + reg2, + f, + df, + G0=None, + numItermax=10, + numInnerItermax=200, + stopThr=1e-9, + stopThr2=1e-9, + verbose=False, + log=False, + **kwargs, +): r""" Solve the general regularized OT problem with the generalized conditional gradient @@ -738,8 +902,24 @@ def lp_solver(a, b, Mi, **kwargs): def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs) - return generic_conditional_gradient(a, b, M, f, df, reg2, reg1, lp_solver, line_search, G0=G0, - numItermax=numItermax, stopThr=stopThr, stopThr2=stopThr2, verbose=verbose, log=log, **kwargs) + return generic_conditional_gradient( + a, + b, + M, + f, + df, + reg2, + reg1, + lp_solver, + line_search, + G0=G0, + numItermax=numItermax, + stopThr=stopThr, + stopThr2=stopThr2, + verbose=verbose, + log=log, + **kwargs, + ) def solve_1d_linesearch_quad(a, b): @@ -761,10 +941,10 @@ def solve_1d_linesearch_quad(a, b): The optimal value which leads to the minimal cost """ if a > 0: # convex - minimum = min(1., max(0., -b / (2.0 * a))) + minimum = min(1.0, max(0.0, -b / (2.0 * a))) return minimum else: # non convex if a + b < 0: - return 1. + return 1.0 else: - return 0. 
+ return 0.0 diff --git a/ot/partial.py b/ot/partial.py index a409a74d1..c11ab228a 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -16,8 +16,9 @@ # License: MIT License -def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, - **kwargs): +def partial_wasserstein_lagrange( + a, b, M, reg_m=None, nb_dummies=1, log=False, **kwargs +): r""" Solves the partial optimal transport problem for the quadratic cost and returns the OT plan @@ -125,8 +126,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, nx = get_backend(a, b, M) if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors - raise ValueError("Problem infeasible. Check that a and b are in the " - "simplex") + raise ValueError("Problem infeasible. Check that a and b are in the " "simplex") if reg_m is None: reg_m = float(nx.max(M)) + 1 @@ -151,27 +151,31 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, # extend a, b, M with "reservoir" or "dummy" points M_extended = np.zeros((len(idx_x) + nb_dummies, len(idx_y) + nb_dummies)) - M_extended[:len(idx_x), :len(idx_y)] = M_star[np.ix_(idx_x, idx_y)] + M_extended[: len(idx_x), : len(idx_y)] = M_star[np.ix_(idx_x, idx_y)] - a_extended = np.append(a[idx_x], [(np.sum(a) - np.sum(a[idx_x]) + - np.sum(b)) / nb_dummies] * nb_dummies) - b_extended = np.append(b[idx_y], [(np.sum(b) - np.sum(b[idx_y]) + - np.sum(a)) / nb_dummies] * nb_dummies) + a_extended = np.append( + a[idx_x], [(np.sum(a) - np.sum(a[idx_x]) + np.sum(b)) / nb_dummies] * nb_dummies + ) + b_extended = np.append( + b[idx_y], [(np.sum(b) - np.sum(b[idx_y]) + np.sum(a)) / nb_dummies] * nb_dummies + ) - gamma_extended, log_emd = emd(a_extended, b_extended, M_extended, log=True, - **kwargs) + gamma_extended, log_emd = emd( + a_extended, b_extended, M_extended, log=True, **kwargs + ) gamma = np.zeros((len(a), len(b))) gamma[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies] # convert back to backend gamma = nx.from_numpy(gamma, type_as=M0) - if log_emd['warning'] is not None: - raise ValueError("Error in the EMD resolution: try to increase the" - " number of dummy points") - log_emd['cost'] = nx.sum(gamma * M0) - log_emd['u'] = nx.from_numpy(log_emd['u'], type_as=a0) - log_emd['v'] = nx.from_numpy(log_emd['v'], type_as=b0) + if log_emd["warning"] is not None: + raise ValueError( + "Error in the EMD resolution: try to increase the" " number of dummy points" + ) + log_emd["cost"] = nx.sum(gamma * M0) + log_emd["u"] = nx.from_numpy(log_emd["u"], type_as=a0) + log_emd["v"] = nx.from_numpy(log_emd["v"], type_as=b0) if log: return gamma, log_emd @@ -283,11 +287,12 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): if m is None: return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" - " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): - raise ValueError("Problem infeasible. Parameter m should lower or" - " equal than min(|a|_1, |b|_1).") + raise ValueError( + "Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1)." 
+ ) b_extension = nx.ones(nb_dummies, type_as=b) * (nx.sum(a) - m) / nb_dummies b_extended = nx.concatenate((b, b_extension)) @@ -295,22 +300,26 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): a_extended = nx.concatenate((a, a_extension)) M_extension = nx.ones((nb_dummies, nb_dummies), type_as=M) * nx.max(M) * 2 M_extended = nx.concatenate( - (nx.concatenate((M, nx.zeros((M.shape[0], M_extension.shape[1]))), axis=1), - nx.concatenate((nx.zeros((M_extension.shape[0], M.shape[1])), M_extension), axis=1)), - axis=0 + ( + nx.concatenate((M, nx.zeros((M.shape[0], M_extension.shape[1]))), axis=1), + nx.concatenate( + (nx.zeros((M_extension.shape[0], M.shape[1])), M_extension), axis=1 + ), + ), + axis=0, ) - gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, - **kwargs) + gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, **kwargs) - gamma = gamma[:len(a), :len(b)] + gamma = gamma[: len(a), : len(b)] - if log_emd['warning'] is not None: - raise ValueError("Error in the EMD resolution: try to increase the" - " number of dummy points") - log_emd['partial_w_dist'] = nx.sum(M * gamma) - log_emd['u'] = log_emd['u'][:len(a)] - log_emd['v'] = log_emd['v'][:len(b)] + if log_emd["warning"] is not None: + raise ValueError( + "Error in the EMD resolution: try to increase the" " number of dummy points" + ) + log_emd["partial_w_dist"] = nx.sum(M * gamma) + log_emd["u"] = log_emd["u"][: len(a)] + log_emd["v"] = log_emd["v"][: len(b)] if log: return gamma, log_emd @@ -404,9 +413,8 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): nx = get_backend(a, b, M) - partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True, - **kwargs) - log_w['T'] = partial_gw + partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True, **kwargs) + log_w["T"] = partial_gw if log: return nx.sum(partial_gw * M), log_w @@ -414,8 +422,9 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): return nx.sum(partial_gw * M) -def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, - stopThr=1e-100, verbose=False, log=False): +def entropic_partial_wasserstein( + a, b, M, reg, m=None, numItermax=1000, stopThr=1e-100, verbose=False, log=False +): r""" Solves the partial optimal transport problem and returns the OT plan @@ -513,13 +522,14 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, if m is None: m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0 if m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" - " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): - raise ValueError("Problem infeasible. Parameter m should lower or" - " equal than min(|a|_1, |b|_1).") + raise ValueError( + "Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1)." 
+ ) - log_e = {'err': []} + log_e = {"err": []} if nx.__name__ == "numpy": # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute @@ -536,7 +546,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, q2 = nx.ones(K.shape, type_as=K) q3 = nx.ones(K.shape, type_as=K) - while (err > stopThr and cpt < numItermax): + while err > stopThr and cpt < numItermax: Kprev = K K = K * q1 K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K) @@ -551,20 +561,19 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, q3 = q3 * K2prev / K if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)): - print('Warning: numerical errors at iteration', cpt) + print("Warning: numerical errors at iteration", cpt) break if cpt % 10 == 0: err = nx.norm(Kprev - K) if log: - log_e['err'].append(err) + log_e["err"].append(err) if verbose: if cpt % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 11) - print('{:5d}|{:8e}|'.format(cpt, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 11) + print("{:5d}|{:8e}|".format(cpt, err)) cpt = cpt + 1 - log_e['partial_w_dist'] = nx.sum(M * K) + log_e["partial_w_dist"] = nx.sum(M * K) if log: return K, log_e else: @@ -605,10 +614,10 @@ def gwgrad_partial(C1, C2, T): warnings.warn( "This function will be deprecated in a near future, please use " "ot.gromov.gwggrad` instead.", - stacklevel=2 + stacklevel=2, ) - cC1 = np.dot(C1 ** 2 / 2, np.dot(T, np.ones(C2.shape[0]).reshape(-1, 1))) - cC2 = np.dot(np.dot(np.ones(C1.shape[0]).reshape(1, -1), T), C2 ** 2 / 2) + cC1 = np.dot(C1**2 / 2, np.dot(T, np.ones(C2.shape[0]).reshape(-1, 1))) + cC2 = np.dot(np.dot(np.ones(C1.shape[0]).reshape(1, -1), T), C2**2 / 2) constC = cC1 + cC2 A = -np.dot(C1, T).dot(C2.T) tens = constC + A @@ -639,16 +648,28 @@ def gwloss_partial(C1, C2, T): warnings.warn( "This function will be deprecated in a near future, please use " "ot.gromov.gwloss` instead.", - stacklevel=2 + stacklevel=2, ) g = gwgrad_partial(C1, C2, T) * 0.5 return np.sum(g * T) -def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, - thres=1, numItermax=1000, tol=1e-7, - log=False, verbose=False, **kwargs): +def partial_gromov_wasserstein( + C1, + C2, + p, + q, + m=None, + nb_dummies=1, + G0=None, + thres=1, + numItermax=1000, + tol=1e-7, + log=False, + verbose=False, + **kwargs, +): r""" Solves the partial optimal transport problem and returns the OT plan @@ -753,20 +774,23 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, warnings.warn( "This function will be deprecated in a near future, please use " "ot.gromov.partial_gromov_wasserstein` instead.", - stacklevel=2 + stacklevel=2, ) if m is None: m = np.min((np.sum(p), np.sum(q))) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" - " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") elif m > np.min((np.sum(p), np.sum(q))): - raise ValueError("Problem infeasible. Parameter m should lower or" - " equal than min(|a|_1, |b|_1).") + raise ValueError( + "Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1)." + ) if G0 is None: - G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + G0 = ( + np.outer(p, q) * m / (np.sum(p) * np.sum(q)) + ) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. 
dim_G_extended = (len(p) + nb_dummies, len(q) + nb_dummies) q_extended = np.append(q, [(np.sum(p) - m) / nb_dummies] * nb_dummies) @@ -776,36 +800,41 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, err = 1 if log: - log = {'err': []} - - while (err > tol and cpt < numItermax): + log = {"err": []} + while err > tol and cpt < numItermax: Gprev = np.copy(G0) - M = 0.5 * gwgrad_partial(C1, C2, G0) # rescaling the gradient with 0.5 for line-search while not changing Gc + M = 0.5 * gwgrad_partial( + C1, C2, G0 + ) # rescaling the gradient with 0.5 for line-search while not changing Gc M_emd = np.zeros(dim_G_extended) - M_emd[:len(p), :len(q)] = M + M_emd[: len(p), : len(q)] = M M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2 M_emd = np.asarray(M_emd, dtype=np.float64) Gc, logemd = emd(p_extended, q_extended, M_emd, log=True, **kwargs) - if logemd['warning'] is not None: - raise ValueError("Error in the EMD resolution: try to increase the" - " number of dummy points") + if logemd["warning"] is not None: + raise ValueError( + "Error in the EMD resolution: try to increase the" + " number of dummy points" + ) - G0 = Gc[:len(p), :len(q)] + G0 = Gc[: len(p), : len(q)] if cpt % 10 == 0: # to speed up the computations err = np.linalg.norm(G0 - Gprev) if log: - log['err'].append(err) + log["err"].append(err) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}|{:12s}'.format( - 'It.', 'Err', 'Loss') + '\n' + '-' * 31) - print('{:5d}|{:8e}|{:8e}'.format(cpt, err, - gwloss_partial(C1, C2, G0))) + print( + "{:5s}|{:12s}|{:12s}".format("It.", "Err", "Loss") + + "\n" + + "-" * 31 + ) + print("{:5d}|{:8e}|{:8e}".format(cpt, err, gwloss_partial(C1, C2, G0))) deltaG = G0 - Gprev a = gwloss_partial(C1, C2, deltaG) @@ -826,15 +855,27 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, cpt += 1 if log: - log['partial_gw_dist'] = gwloss_partial(C1, C2, G0) - return G0[:len(p), :len(q)], log + log["partial_gw_dist"] = gwloss_partial(C1, C2, G0) + return G0[: len(p), : len(q)], log else: - return G0[:len(p), :len(q)] - - -def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, - thres=1, numItermax=1000, tol=1e-7, - log=False, verbose=False, **kwargs): + return G0[: len(p), : len(q)] + + +def partial_gromov_wasserstein2( + C1, + C2, + p, + q, + m=None, + nb_dummies=1, + G0=None, + thres=1, + numItermax=1000, + tol=1e-7, + log=False, + verbose=False, + **kwargs, +): r""" Solves the partial optimal transport problem and returns the partial Gromov-Wasserstein discrepancy @@ -942,25 +983,34 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, warnings.warn( "This function will be deprecated in a near future, please use " "ot.gromov.partial_gromov_wasserstein2` instead.", - stacklevel=2 + stacklevel=2, ) - partial_gw, log_gw = partial_gromov_wasserstein(C1, C2, p, q, m, - nb_dummies, G0, thres, - numItermax, tol, True, - verbose, **kwargs) + partial_gw, log_gw = partial_gromov_wasserstein( + C1, C2, p, q, m, nb_dummies, G0, thres, numItermax, tol, True, verbose, **kwargs + ) - log_gw['T'] = partial_gw + log_gw["T"] = partial_gw if log: - return log_gw['partial_gw_dist'], log_gw + return log_gw["partial_gw_dist"], log_gw else: - return log_gw['partial_gw_dist'] - - -def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, - numItermax=1000, tol=1e-7, log=False, - verbose=False): + return log_gw["partial_gw_dist"] + + +def entropic_partial_gromov_wasserstein( + C1, + C2, + p, + q, + reg, + m=None, 
+ G0=None, + numItermax=1000, + tol=1e-7, + log=False, + verbose=False, +): r""" Returns the partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -1073,7 +1123,7 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, warnings.warn( "This function will be deprecated in a near future, please use " "ot.gromov.entropic_partial_gromov_wasserstein` instead.", - stacklevel=2 + stacklevel=2, ) if G0 is None: @@ -1082,44 +1132,57 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, if m is None: m = np.min((np.sum(p), np.sum(q))) elif m < 0: - raise ValueError("Problem infeasible. Parameter m should be greater" - " than 0.") + raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.") elif m > np.min((np.sum(p), np.sum(q))): - raise ValueError("Problem infeasible. Parameter m should lower or" - " equal than min(|a|_1, |b|_1).") + raise ValueError( + "Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1)." + ) cpt = 0 err = 1 - loge = {'err': []} + loge = {"err": []} - while (err > tol and cpt < numItermax): + while err > tol and cpt < numItermax: Gprev = G0 M_entr = gwgrad_partial(C1, C2, G0) G0 = entropic_partial_wasserstein(p, q, M_entr, reg, m) if cpt % 10 == 0: # to speed up the computations err = np.linalg.norm(G0 - Gprev) if log: - loge['err'].append(err) + loge["err"].append(err) if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}|{:12s}'.format( - 'It.', 'Err', 'Loss') + '\n' + '-' * 31) - print('{:5d}|{:8e}|{:8e}'.format(cpt, err, - gwloss_partial(C1, C2, G0))) + print( + "{:5s}|{:12s}|{:12s}".format("It.", "Err", "Loss") + + "\n" + + "-" * 31 + ) + print("{:5d}|{:8e}|{:8e}".format(cpt, err, gwloss_partial(C1, C2, G0))) cpt += 1 if log: - loge['partial_gw_dist'] = gwloss_partial(C1, C2, G0) + loge["partial_gw_dist"] = gwloss_partial(C1, C2, G0) return G0, loge else: return G0 -def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, - numItermax=1000, tol=1e-7, log=False, - verbose=False): +def entropic_partial_gromov_wasserstein2( + C1, + C2, + p, + q, + reg, + m=None, + G0=None, + numItermax=1000, + tol=1e-7, + log=False, + verbose=False, +): r""" Returns the partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -1219,17 +1282,16 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, warnings.warn( "This function will be deprecated in a near future, please use " "ot.gromov.entropic_partial_gromov_wasserstein2` instead.", - stacklevel=2 + stacklevel=2, ) - partial_gw, log_gw = entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, - m, G0, numItermax, - tol, True, - verbose) + partial_gw, log_gw = entropic_partial_gromov_wasserstein( + C1, C2, p, q, reg, m, G0, numItermax, tol, True, verbose + ) - log_gw['T'] = partial_gw + log_gw["T"] = partial_gw if log: - return log_gw['partial_gw_dist'], log_gw + return log_gw["partial_gw_dist"], log_gw else: - return log_gw['partial_gw_dist'] + return log_gw["partial_gw_dist"] diff --git a/ot/plot.py b/ot/plot.py index 70995633e..88fbc0856 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -17,10 +17,19 @@ from matplotlib import gridspec -def plot1D_mat(a, b, M, title='', plot_style='yx', - a_label='', b_label='', color_source='b', - color_target='r', coupling_cmap='gray_r'): - r""" Plot matrix :math:`\mathbf{M}` with the source and target 1D distributions. 
+def plot1D_mat( + a, + b, + M, + title="", + plot_style="yx", + a_label="", + b_label="", + color_source="b", + color_target="r", + coupling_cmap="gray_r", +): + r"""Plot matrix :math:`\mathbf{M}` with the source and target 1D distributions. Creates a subplot with the source distribution :math:`\mathbf{a}` and target distribution :math:`\mathbf{b}`t. @@ -63,12 +72,12 @@ def plot1D_mat(a, b, M, title='', plot_style='yx', .. seealso:: :func:`rescale_for_imshow_plot` """ - assert plot_style in ['yx', 'xy'], "plot_style should be 'yx' or 'xy'" + assert plot_style in ["yx", "xy"], "plot_style should be 'yx' or 'xy'" na, nb = M.shape - gs = gridspec.GridSpec(3, 3, height_ratios=[1, 1, 1], - width_ratios=[1, 1, 1], - hspace=0, wspace=0) + gs = gridspec.GridSpec( + 3, 3, height_ratios=[1, 1, 1], width_ratios=[1, 1, 1], hspace=0, wspace=0 + ) xa = np.arange(na) xb = np.arange(nb) @@ -79,23 +88,28 @@ def _set_ticks_and_spines(ax, empty_ticks=True, visible_spines=False): ax.set_xticks(()) ax.set_yticks(()) - ax.spines['top'].set_visible(visible_spines) - ax.spines['right'].set_visible(visible_spines) - ax.spines['bottom'].set_visible(visible_spines) - ax.spines['left'].set_visible(visible_spines) + ax.spines["top"].set_visible(visible_spines) + ax.spines["right"].set_visible(visible_spines) + ax.spines["bottom"].set_visible(visible_spines) + ax.spines["left"].set_visible(visible_spines) - if plot_style == 'xy': + if plot_style == "xy": # horizontal source on the bottom, flipped vertically ax1 = pl.subplot(gs[2, 1:]) ax1.plot(xa, np.max(a) - a, color=color_source, linewidth=2) - ax1.fill(xa, np.max(a) - a, np.max(a) * np.ones_like(a), - color=color_source, alpha=.5) - ax1.set_title(a_label, y=-.15) + ax1.fill( + xa, + np.max(a) - a, + np.max(a) * np.ones_like(a), + color=color_source, + alpha=0.5, + ) + ax1.set_title(a_label, y=-0.15) # vertical target on the left ax2 = pl.subplot(gs[0:2, 0]) ax2.plot(b, xb, color=color_target, linewidth=2) - ax2.fill(b, xb, color=color_target, alpha=.5) + ax2.fill(b, xb, color=color_target, alpha=0.5) ax2.invert_xaxis() ax2.invert_yaxis() ax2.set_title(b_label) @@ -105,8 +119,7 @@ def _set_ticks_and_spines(ax, empty_ticks=True, visible_spines=False): # coupling matrix in the middle ax3 = pl.subplot(gs[0:2, 1:], sharey=ax2, sharex=ax1) - ax3.imshow(M.T, interpolation='nearest', origin='lower', - cmap=coupling_cmap) + ax3.imshow(M.T, interpolation="nearest", origin="lower", cmap=coupling_cmap) ax3.set_title(title) _set_ticks_and_spines(ax3, empty_ticks=False, visible_spines=True) @@ -117,14 +130,14 @@ def _set_ticks_and_spines(ax, empty_ticks=True, visible_spines=False): # vertical source on the left ax1 = pl.subplot(gs[1:, 0]) ax1.plot(a, xa, color=color_source, linewidth=2) - ax1.fill(a, xa, color=color_source, alpha=.5) + ax1.fill(a, xa, color=color_source, alpha=0.5) ax1.invert_xaxis() ax1.set_title(a_label) # horizontal target on the top ax2 = pl.subplot(gs[0, 1:]) ax2.plot(xb, b, color=color_target, linewidth=2) - ax2.fill(xb, b, color=color_target, alpha=.5) + ax2.fill(xb, b, color=color_target, alpha=0.5) ax2.set_title(b_label) _set_ticks_and_spines(ax1, empty_ticks=True, visible_spines=False) @@ -132,12 +145,17 @@ def _set_ticks_and_spines(ax, empty_ticks=True, visible_spines=False): # coupling matrix in the middle ax3 = pl.subplot(gs[1:, 1:], sharey=ax1, sharex=ax2) - ax3.imshow(M, interpolation='nearest', cmap=coupling_cmap) + ax3.imshow(M, interpolation="nearest", cmap=coupling_cmap) # Set title below matrix plot - ax3.text(0.5, -0.025, title, - 
ha='center', va='top', - transform=ax3.transAxes, - fontsize='large') + ax3.text( + 0.5, + -0.025, + title, + ha="center", + va="top", + transform=ax3.transAxes, + fontsize="large", + ) _set_ticks_and_spines(ax3, empty_ticks=False, visible_spines=True) pl.subplots_adjust(hspace=0, wspace=0) @@ -190,7 +208,7 @@ def rescale_for_imshow_plot(x, y, n, m=None, a_y=None, b_y=None): def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): - r""" Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values + r"""Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values Plot lines between source and target 2D samples with a color proportional to the value of the matrix :math:`\mathbf{G}` between samples. @@ -211,16 +229,20 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): nothing given) """ - if ('color' not in kwargs) and ('c' not in kwargs): - kwargs['color'] = 'k' + if ("color" not in kwargs) and ("c" not in kwargs): + kwargs["color"] = "k" mx = G.max() - if 'alpha' in kwargs: - scale = kwargs['alpha'] - del kwargs['alpha'] + if "alpha" in kwargs: + scale = kwargs["alpha"] + del kwargs["alpha"] else: scale = 1 for i in range(xs.shape[0]): for j in range(xt.shape[0]): if G[i, j] / mx > thr: - pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], - alpha=G[i, j] / mx * scale, **kwargs) + pl.plot( + [xs[i, 0], xt[j, 0]], + [xs[i, 1], xt[j, 1]], + alpha=G[i, j] / mx * scale, + **kwargs, + ) diff --git a/ot/regpath.py b/ot/regpath.py index 5e32e4fd1..e64ca7c77 100644 --- a/ot/regpath.py +++ b/ot/regpath.py @@ -107,8 +107,9 @@ def recast_ot_as_lasso(a, b, C): iHb = np.tile(np.arange(dim_b), dim_a) + dim_a j = np.concatenate((jHa, jHb)) i = np.concatenate((iHa, iHb)) - H = sp.csc_matrix((np.ones(dim_a * dim_b * 2), (i, j)), - shape=(dim_a + dim_b, dim_a * dim_b)) + H = sp.csc_matrix( + (np.ones(dim_a * dim_b * 2), (i, j)), shape=(dim_a + dim_b, dim_a * dim_b) + ) return H, y, c @@ -201,10 +202,12 @@ def recast_semi_relaxed_as_lasso(a, b, C): jHc = np.arange(dim_a * dim_b) iHc = np.tile(np.arange(dim_b), dim_a) - Hr = sp.csc_matrix((np.ones(dim_a * dim_b), (iHr, jHr)), - shape=(dim_a, dim_a * dim_b)) - Hc = sp.csc_matrix((np.ones(dim_a * dim_b), (iHc, jHc)), - shape=(dim_b, dim_a * dim_b)) + Hr = sp.csc_matrix( + (np.ones(dim_a * dim_b), (iHr, jHr)), shape=(dim_a, dim_a * dim_b) + ) + Hc = sp.csc_matrix( + (np.ones(dim_a * dim_b), (iHc, jHc)), shape=(dim_b, dim_a * dim_b) + ) return Hr, Hc, c @@ -266,15 +269,17 @@ def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma): Unbalanced optimal transport through non-negative penalized linear regression. NeurIPS. """ - M = (HtH[:, active_index].dot(phi) - Hty) / \ - (HtH[:, active_index].dot(delta) - c + 1e-16) + M = (HtH[:, active_index].dot(phi) - Hty) / ( + HtH[:, active_index].dot(delta) - c + 1e-16 + ) M[active_index] = 0 M[M > (current_gamma - 1e-10 * current_gamma)] = 0 return np.max(M), np.argmax(M) -def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra, - c, active_index, current_gamma): +def semi_relaxed_next_gamma( + phi, delta, phi_u, delta_u, HrHr, Hc, Hra, c, active_index, current_gamma +): r""" This function computes the next value of gamma when a variable is active in the regularization path of semi-relaxed UOT. @@ -341,8 +346,9 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra, linear regression. NeurIPS. 
""" - M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / \ - (HrHr[:, active_index].dot(delta) - c + Hc.T.dot(delta_u) + 1e-16) + M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / ( + HrHr[:, active_index].dot(delta) - c + Hc.T.dot(delta_u) + 1e-16 + ) M[active_index] = 0 M[M > (current_gamma - 1e-10 * current_gamma)] = 0 return np.max(M), np.argmax(M) @@ -533,8 +539,7 @@ def construct_augmented_H(active_index, m, Hc, HrHr): return H_augmented -def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, - itmax=50000): +def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, itmax=50000): r"""This function gives the regularization path of l2-penalized UOT problem The problem to optimize is the Lasso reformulation of the l2-penalized UOT: @@ -621,15 +626,14 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, id_pop = -1 while n_iter < itmax and gamma_list[-1] > reg: - H_inv = complement_schur(H_inv, add_col, 2., id_pop) + H_inv = complement_schur(H_inv, add_col, 2.0, id_pop) current_gamma = gamma_list[-1] # compute the intercept and slope of solutions in current iteration # t = phi - gamma * delta phi = H_inv.dot(Hty[active_index]) delta = H_inv.dot(c[active_index]) - gamma, ik = ot_next_gamma(phi, delta, HtH, Hty, c, active_index, - current_gamma) + gamma, ik = ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma) # compute the next lambda when removing a point from the active set alt_gamma, id_pop = compute_next_removal(phi, delta, current_gamma) @@ -643,7 +647,7 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, # compute the solution of current segment tA = phi - gamma * delta - sol = np.zeros((n * m, )) + sol = np.zeros((n * m,)) sol[active_index] = tA if id_pop != -1: @@ -658,23 +662,23 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, n_iter += 1 if itmax <= n_iter: - print('maximum iteration has been reached !') + print("maximum iteration has been reached !") # correct the last solution and gamma if len(t_list) > 1: - t_final = (t_list[-2] + (t_list[-1] - t_list[-2]) * - (reg - gamma_list[-2]) / (gamma_list[-1] - gamma_list[-2])) + t_final = t_list[-2] + (t_list[-1] - t_list[-2]) * (reg - gamma_list[-2]) / ( + gamma_list[-1] - gamma_list[-2] + ) t_list[-1] = t_final gamma_list[-1] = reg else: gamma_list[-1] = reg - print('Regularization path does not exist !') + print("Regularization path does not exist !") return t_list[-1], t_list, gamma_list -def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, - itmax=50000): +def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, itmax=50000): r"""This function gives the regularization path of semi-relaxed l2-UOT problem. 
@@ -771,7 +775,7 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, if n_iter == 1: H_inv = np.linalg.inv(augmented_H0) else: - H_inv = complement_schur(H_inv, add_col, 1., id_pop + m) + H_inv = complement_schur(H_inv, add_col, 1.0, id_pop + m) # compute the intercept and slope of solutions in current iteration augmented_phi = H_inv.dot(np.concatenate((b, Hra[active_index]))) augmented_delta = H_inv[:, m:].dot(c[active_index]) @@ -779,9 +783,9 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, delta = augmented_delta[m:] phi_u = augmented_phi[0:m] delta_u = augmented_delta[0:m] - gamma, ik = semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, - HrHr, Hc, Hra, c, active_index, - current_gamma) + gamma, ik = semi_relaxed_next_gamma( + phi, delta, phi_u, delta_u, HrHr, Hc, Hra, c, active_index, current_gamma + ) # compute the next lambda when removing a point from the active set alt_gamma, id_pop = compute_next_removal(phi, delta, current_gamma) @@ -795,15 +799,16 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, # compute the solution of current segment tA = phi - gamma * delta - sol = np.zeros((n * m, )) + sol = np.zeros((n * m,)) sol[active_index] = tA if id_pop != -1: active_index.pop(id_pop) add_col = None else: active_index.append(ik) - add_col = np.concatenate((Hc.toarray()[:, ik], - HrHr.toarray()[active_index[:-1], ik])) + add_col = np.concatenate( + (Hc.toarray()[:, ik], HrHr.toarray()[active_index[:-1], ik]) + ) add_col = add_col[:, np.newaxis] gamma_list.append(gamma) @@ -812,23 +817,25 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, n_iter += 1 if itmax <= n_iter: - print('maximum iteration has been reached !') + print("maximum iteration has been reached !") # correct the last solution and gamma if len(t_list) > 1: - t_final = (t_list[-2] + (t_list[-1] - t_list[-2]) * - (reg - gamma_list[-2]) / (gamma_list[-1] - gamma_list[-2])) + t_final = t_list[-2] + (t_list[-1] - t_list[-2]) * (reg - gamma_list[-2]) / ( + gamma_list[-1] - gamma_list[-2] + ) t_list[-1] = t_final gamma_list[-1] = reg else: gamma_list[-1] = reg - print('Regularization path does not exist !') + print("Regularization path does not exist !") return t_list[-1], t_list, gamma_list -def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, - semi_relaxed=False, itmax=50000): +def regularization_path( + a: np.array, b: np.array, C: np.array, reg=1e-4, semi_relaxed=False, itmax=50000 +): r"""This function provides all the solutions of the regularization path \ of the l2-UOT problem :ref:`[41] `. @@ -899,16 +906,14 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, linear regression. NeurIPS. """ if semi_relaxed: - t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg, - itmax=itmax) + t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg, itmax=itmax) else: - t, t_list, gamma_list = fully_relaxed_path(a, b, C, reg=reg, - itmax=itmax) + t, t_list, gamma_list = fully_relaxed_path(a, b, C, reg=reg, itmax=itmax) return t, t_list, gamma_list def compute_transport_plan(gamma, gamma_list, Pi_list): - r""" Given the regularization path, this function computes the transport + r"""Given the regularization path, this function computes the transport plan for any value of gamma thanks to the piecewise linearity of the path. .. 
math:: @@ -973,6 +978,5 @@ def compute_transport_plan(gamma, gamma_list, Pi_list): gamma_k1 = gamma_list[idx + 1] pi_k0 = Pi_list[idx] pi_k1 = Pi_list[idx + 1] - Pi = pi_k0 + (pi_k1 - pi_k0) * (gamma - gamma_k0) \ - / (gamma_k1 - gamma_k0) + Pi = pi_k0 + (pi_k1 - pi_k0) * (gamma - gamma_k0) / (gamma_k1 - gamma_k0) return Pi diff --git a/ot/sliced.py b/ot/sliced.py index d5bb0ee08..cd095ed6d 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -9,7 +9,6 @@ # # License: MIT License - import numpy as np from .backend import get_backend, NumpyBackend from .utils import list_to_array, get_coordinate_circle @@ -51,7 +50,7 @@ def get_random_projections(d, n_projections, seed=None, backend=None, type_as=No else: nx = backend - if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + if isinstance(seed, np.random.RandomState) and str(nx) == "numpy": projections = seed.randn(d, n_projections) else: if seed is not None: @@ -62,8 +61,17 @@ def get_random_projections(d, n_projections, seed=None, backend=None, type_as=No return projections -def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, - projections=None, seed=None, log=False): +def sliced_wasserstein_distance( + X_s, + X_t, + a=None, + b=None, + n_projections=50, + p=2, + projections=None, + seed=None, + log=False, +): r""" Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance @@ -135,8 +143,10 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, if X_s.shape[1] != X_t.shape[1]: raise ValueError( - "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], - X_t.shape[1])) + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format( + X_s.shape[1], X_t.shape[1] + ) + ) if a is None: a = nx.full(n, 1 / n, type_as=X_s) @@ -146,7 +156,9 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, d = X_s.shape[1] if projections is None: - projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) + projections = get_random_projections( + d, n_projections, seed, backend=nx, type_as=X_s + ) else: n_projections = projections.shape[1] @@ -161,8 +173,17 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, return res -def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, - projections=None, seed=None, log=False): +def max_sliced_wasserstein_distance( + X_s, + X_t, + a=None, + b=None, + n_projections=50, + p=2, + projections=None, + seed=None, + log=False, +): r""" Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance @@ -235,8 +256,10 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, if X_s.shape[1] != X_t.shape[1]: raise ValueError( - "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], - X_t.shape[1])) + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format( + X_s.shape[1], X_t.shape[1] + ) + ) if a is None: a = nx.full(n, 1 / n, type_as=X_s) @@ -246,7 +269,9 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, d = X_s.shape[1] if projections is None: - projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) + projections = get_random_projections( + d, n_projections, seed, backend=nx, type_as=X_s + ) X_s_projections = nx.dot(X_s, projections) X_t_projections = nx.dot(X_t, 
projections) @@ -259,8 +284,17 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, return res -def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, - p=2, projections=None, seed=None, log=False): +def sliced_wasserstein_sphere( + X_s, + X_t, + a=None, + b=None, + n_projections=50, + p=2, + projections=None, + seed=None, + log=False, +): r""" Compute the spherical sliced-Wasserstein discrepancy. @@ -323,16 +357,18 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, if X_s.shape[1] != X_t.shape[1]: raise ValueError( - "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], - X_t.shape[1])) - if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format( + X_s.shape[1], X_t.shape[1] + ) + ) + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10 ** (-4)): raise ValueError("X_s is not on the sphere.") - if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)): + if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10 ** (-4)): raise ValueError("X_t is not on the sphere.") if projections is None: # Uniforms and independent samples on the Stiefel manifold V_{d,2} - if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + if isinstance(seed, np.random.RandomState) and str(nx) == "numpy": Z = seed.randn(n_projections, d, 2) else: if seed is not None: @@ -353,10 +389,16 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True)) # Get coordinates on [0,1[ - Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) - Xpt_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m)) - - projected_emd = wasserstein_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p) + Xps_coords = nx.reshape( + get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n) + ) + Xpt_coords = nx.reshape( + get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m) + ) + + projected_emd = wasserstein_circle( + Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p + ) res = nx.mean(projected_emd) ** (1 / p) if log: @@ -415,11 +457,11 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log n, d = X_s.shape - if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10 ** (-4)): raise ValueError("X_s is not on the sphere.") # Uniforms and independent samples on the Stiefel manifold V_{d,2} - if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + if isinstance(seed, np.random.RandomState) and str(nx) == "numpy": Z = seed.randn(n_projections, d, 2) else: if seed is not None: @@ -436,7 +478,9 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) # Get coordinates on [0,1[ - Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) + Xps_coords = nx.reshape( + get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n) + ) projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a) res = nx.mean(projected_emd) ** (1 / 2) diff --git a/ot/smooth.py b/ot/smooth.py index 331cfc04e..266fddd53 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -1,26 +1,26 @@ -#Copyright (c) 2018, Mathieu Blondel -#All rights reserved. 
+# Copyright (c) 2018, Mathieu Blondel +# All rights reserved. # -#Redistribution and use in source and binary forms, with or without -#modification, are permitted provided that the following conditions are met: +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: # -#1. Redistributions of source code must retain the above copyright notice, this -#list of conditions and the following disclaimer. +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. # -#2. Redistributions in binary form must reproduce the above copyright notice, -#this list of conditions and the following disclaimer in the documentation and/or -#other materials provided with the distribution. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation and/or +# other materials provided with the distribution. # -#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -#ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -#WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. -#IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -#INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT -#NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, -#OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -#LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR -#OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -#THE POSSIBILITY OF SUCH DAMAGE. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, +# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +# NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. # Author: Mathieu Blondel # Remi Flamary @@ -63,7 +63,7 @@ def projection_simplex(V, z=1, axis=None): - r""" Projection of :math:`\mathbf{V}` onto the simplex, scaled by `z` + r"""Projection of :math:`\mathbf{V}` onto the simplex, scaled by `z` .. math:: P\left(\mathbf{V}, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z}} \quad \|\mathbf{y} - \mathbf{V}\|^2 @@ -104,10 +104,10 @@ def projection_simplex(V, z=1, axis=None): class Regularization(object): r"""Base class for Regularization objects - Notes - ----- - This class is not intended for direct use but as apparent for true - regularization implementation. + Notes + ----- + This class is not intended for direct use but as apparent for true + regularization implementation. 
""" def __init__(self, gamma=1.0): @@ -186,7 +186,7 @@ def Omega(T): class NegEntropy(Regularization): - """ NegEntropy regularization """ + """NegEntropy regularization""" def delta_Omega(self, X): G = np.exp(X / self.gamma - 1) @@ -206,11 +206,11 @@ def Omega(self, T): class SquaredL2(Regularization): - """ Squared L2 regularization """ + """Squared L2 regularization""" def delta_Omega(self, X): max_X = np.maximum(X, 0) - val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma) + val = np.sum(max_X**2, axis=0) / (2 * self.gamma) G = max_X / self.gamma return val, G @@ -221,11 +221,11 @@ def max_Omega(self, X, b): return val, G def Omega(self, T): - return 0.5 * self.gamma * np.sum(T ** 2) + return 0.5 * self.gamma * np.sum(T**2) class SparsityConstrained(Regularization): - """ Squared L2 regularization with sparsity constraints """ + """Squared L2 regularization with sparsity constraints""" def __init__(self, max_nz, gamma=1.0): self.max_nz = max_nz @@ -233,28 +233,28 @@ def __init__(self, max_nz, gamma=1.0): def delta_Omega(self, X): # For each column of X, find entries that are not among the top max_nz. - non_top_indices = np.argpartition( - -X, self.max_nz, axis=0)[self.max_nz:] + non_top_indices = np.argpartition(-X, self.max_nz, axis=0)[self.max_nz :] # Set these entries to -inf. if X.ndim == 1: X[non_top_indices] = 0.0 else: X[non_top_indices, np.arange(X.shape[1])] = 0.0 max_X = np.maximum(X, 0) - val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma) + val = np.sum(max_X**2, axis=0) / (2 * self.gamma) G = max_X / self.gamma return val, G def max_Omega(self, X, b): # Project the scaled X onto the simplex with sparsity constraint. G = ot.utils.projection_sparse_simplex( - X / (b * self.gamma), self.max_nz, axis=0) + X / (b * self.gamma), self.max_nz, axis=0 + ) val = np.sum(X * G, axis=0) val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0) return val, G def Omega(self, T): - return 0.5 * self.gamma * np.sum(T ** 2) + return 0.5 * self.gamma * np.sum(T**2) def dual_obj_grad(alpha, beta, a, b, C, regul): @@ -301,8 +301,9 @@ def dual_obj_grad(alpha, beta, a, b, C, regul): return obj, grad_alpha, grad_beta -def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, - verbose=False): +def solve_dual( + a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, verbose=False +): """ Solve the "smoothed" dual objective. @@ -331,8 +332,8 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, def _func(params): # Unpack alpha and beta. - alpha = params[:len(a)] - beta = params[len(a):] + alpha = params[: len(a)] + beta = params[len(a) :] obj, grad_alpha, grad_beta = dual_obj_grad(alpha, beta, a, b, C, regul) @@ -348,11 +349,17 @@ def _func(params): beta_init = np.zeros(len(b)) params_init = np.concatenate((alpha_init, beta_init)) - res = minimize(_func, params_init, method=method, jac=True, - tol=tol, options=dict(maxiter=max_iter, disp=verbose)) + res = minimize( + _func, + params_init, + method=method, + jac=True, + tol=tol, + options=dict(maxiter=max_iter, disp=verbose), + ) - alpha = res.x[:len(a)] - beta = res.x[len(a):] + alpha = res.x[: len(a)] + beta = res.x[len(a) :] return alpha, beta, res @@ -396,8 +403,9 @@ def semi_dual_obj_grad(alpha, a, b, C, regul): return obj, grad -def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, - verbose=False): +def solve_semi_dual( + a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, verbose=False +): """ Solve the "smoothed" semi-dual objective. 
@@ -430,8 +438,14 @@ def _func(alpha): alpha_init = np.zeros(len(a)) - res = minimize(_func, alpha_init, method=method, jac=True, - tol=tol, options=dict(maxiter=max_iter, disp=verbose)) + res = minimize( + _func, + alpha_init, + method=method, + jac=True, + tol=tol, + options=dict(maxiter=max_iter, disp=verbose), + ) return res.x, res @@ -483,9 +497,19 @@ def get_plan_from_semi_dual(alpha, b, C, regul): return regul.max_Omega(X, b)[1] * b -def smooth_ot_dual(a, b, M, reg, reg_type='l2', - method="L-BFGS-B", stopThr=1e-9, - numItermax=500, verbose=False, log=False, max_nz=None): +def smooth_ot_dual( + a, + b, + M, + reg, + reg_type="l2", + method="L-BFGS-B", + stopThr=1e-9, + numItermax=500, + verbose=False, + log=False, + max_nz=None, +): r""" Solve the regularized OT problem in the dual and return the OT matrix @@ -568,39 +592,53 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', nx = get_backend(a, b, M) - if reg_type.lower() in ['l2', 'squaredl2']: + if reg_type.lower() in ["l2", "squaredl2"]: regul = SquaredL2(gamma=reg) - elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: + elif reg_type.lower() in ["entropic", "negentropy", "kl"]: regul = NegEntropy(gamma=reg) - elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']: + elif reg_type.lower() in ["sparsity_constrained", "sparsity-constrained"]: if not isinstance(max_nz, int): - raise ValueError( - f'max_nz {max_nz} must be an integer') + raise ValueError(f"max_nz {max_nz} must be an integer") regul = SparsityConstrained(gamma=reg, max_nz=max_nz) else: - raise NotImplementedError('Unknown regularization') + raise NotImplementedError("Unknown regularization") a0, b0, M0 = a, b, M # convert to humpy a, b, M = nx.to_numpy(a, b, M) # solve dual - alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, - tol=stopThr, verbose=verbose) + alpha, beta, res = solve_dual( + a, b, M, regul, max_iter=numItermax, tol=stopThr, verbose=verbose + ) # reconstruct transport matrix G = nx.from_numpy(get_plan_from_dual(alpha, beta, M, regul), type_as=M0) if log: - log = {'alpha': nx.from_numpy(alpha, type_as=a0), 'beta': nx.from_numpy(beta, type_as=b0), 'res': res} + log = { + "alpha": nx.from_numpy(alpha, type_as=a0), + "beta": nx.from_numpy(beta, type_as=b0), + "res": res, + } return G, log else: return G -def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', max_nz=None, - method="L-BFGS-B", stopThr=1e-9, - numItermax=500, verbose=False, log=False): +def smooth_ot_semi_dual( + a, + b, + M, + reg, + reg_type="l2", + max_nz=None, + method="L-BFGS-B", + stopThr=1e-9, + numItermax=500, + verbose=False, + log=False, +): r""" Solve the regularized OT problem in the semi-dual and return the OT matrix @@ -682,27 +720,27 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', max_nz=None, ot.optim.cg : General regularized OT """ - if reg_type.lower() in ['l2', 'squaredl2']: + if reg_type.lower() in ["l2", "squaredl2"]: regul = SquaredL2(gamma=reg) - elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: + elif reg_type.lower() in ["entropic", "negentropy", "kl"]: regul = NegEntropy(gamma=reg) - elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']: + elif reg_type.lower() in ["sparsity_constrained", "sparsity-constrained"]: if not isinstance(max_nz, int): - raise ValueError( - f'max_nz {max_nz} must be an integer') + raise ValueError(f"max_nz {max_nz} must be an integer") regul = SparsityConstrained(gamma=reg, max_nz=max_nz) else: - raise NotImplementedError('Unknown regularization') + raise 
NotImplementedError("Unknown regularization") # solve dual - alpha, res = solve_semi_dual(a, b, M, regul, max_iter=numItermax, - tol=stopThr, verbose=verbose) + alpha, res = solve_semi_dual( + a, b, M, regul, max_iter=numItermax, tol=stopThr, verbose=verbose + ) # reconstruct transport matrix G = get_plan_from_semi_dual(alpha, b, M, regul) if log: - log = {'alpha': alpha, 'res': res} + log = {"alpha": alpha, "res": res} return G, log else: return G diff --git a/ot/solvers.py b/ot/solvers.py index 37b1b93df..ec56d1330 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -13,24 +13,54 @@ from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced from .bregman import sinkhorn_log, empirical_sinkhorn2, empirical_sinkhorn2_geomloss from .smooth import smooth_ot_dual -from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2, - entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2, - semirelaxed_gromov_wasserstein2, semirelaxed_fused_gromov_wasserstein2, - entropic_semirelaxed_fused_gromov_wasserstein2, - entropic_semirelaxed_gromov_wasserstein2, - partial_gromov_wasserstein2, - entropic_partial_gromov_wasserstein2) +from .gromov import ( + gromov_wasserstein2, + fused_gromov_wasserstein2, + entropic_gromov_wasserstein2, + entropic_fused_gromov_wasserstein2, + semirelaxed_gromov_wasserstein2, + semirelaxed_fused_gromov_wasserstein2, + entropic_semirelaxed_fused_gromov_wasserstein2, + entropic_semirelaxed_gromov_wasserstein2, + partial_gromov_wasserstein2, + entropic_partial_gromov_wasserstein2, +) from .gaussian import empirical_bures_wasserstein_distance from .factored import factored_optimal_transport from .lowrank import lowrank_sinkhorn from .optim import cg -lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale'] - - -def solve(M, a=None, b=None, reg=None, c=None, reg_type="KL", unbalanced=None, - unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None, - potentials_init=None, tol=None, verbose=False, grad='autodiff'): +lst_method_lazy = [ + "1d", + "gaussian", + "lowrank", + "factored", + "geomloss", + "geomloss_auto", + "geomloss_tensorized", + "geomloss_online", + "geomloss_multiscale", +] + + +def solve( + M, + a=None, + b=None, + reg=None, + c=None, + reg_type="KL", + unbalanced=None, + unbalanced_type="KL", + method=None, + n_threads=1, + max_iter=None, + plan_init=None, + potentials_init=None, + tol=None, + verbose=False, + grad="autodiff", +): r"""Solve the discrete optimal transport problem and return :any:`OTResult` object The function solves the following general optimal transport problem @@ -273,38 +303,52 @@ def solve(M, a=None, b=None, reg=None, c=None, reg_type="KL", unbalanced=None, status = None if reg == 0: # exact OT - if unbalanced is None: # Exact balanced OT - # default values for EMD solver if max_iter is None: max_iter = 1000000 - value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads) + value_linear, log = emd2( + a, + b, + M, + numItermax=max_iter, + log=True, + return_matrix=True, + numThreads=n_threads, + ) value = value_linear - potentials = (log['u'], log['v']) - plan = log['G'] - status = log["warning"] if log["warning"] is not None else 'Converged' - - elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT + potentials = (log["u"], log["v"]) + plan = log["G"] + status = log["warning"] if log["warning"] is not None else "Converged" + 
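# NOTE: illustrative usage sketch, not part of the patch. A minimal example of
# the ot.solve entry point whose signature and exact-OT branch are reformatted
# above; the L2 case dispatches to the smooth_ot_dual wrapper shown earlier in
# this diff. Toy data is made up.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(30, 2), rng.randn(40, 2) + 1.0
M = ot.dist(xs, xt)                          # squared Euclidean cost matrix

res_exact = ot.solve(M)                      # reg=None -> exact EMD, uniform weights
res_ent = ot.solve(M, reg=1e-1)              # entropic (KL) regularization, Sinkhorn
res_l2 = ot.solve(M, reg=1e-1, reg_type="L2")            # quadratic reg via smooth_ot_dual
res_unb = ot.solve(M, reg=1e-1, unbalanced=5.0, unbalanced_type="KL")  # unbalanced OT

print(res_exact.value, res_ent.value)        # OT objective values
plan = res_ent.plan                          # dense transport plan
pots = res_ent.potentials                    # dual potentials (log-scaled for Sinkhorn)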
elif unbalanced_type.lower() in ["kl", "l2"]: # unbalanced exact OT # default values for exact unbalanced OT if max_iter is None: max_iter = 1000 if tol is None: tol = 1e-12 - plan, log = mm_unbalanced(a, b, M, reg_m=unbalanced, c=c, reg=reg, - div=unbalanced_type, numItermax=max_iter, - stopThr=tol, log=True, - verbose=verbose, G0=plan_init) - - value_linear = log['cost'] - value = log['total_cost'] + plan, log = mm_unbalanced( + a, + b, + M, + reg_m=unbalanced, + c=c, + reg=reg, + div=unbalanced_type, + numItermax=max_iter, + stopThr=tol, + log=True, + verbose=verbose, + G0=plan_init, + ) - elif unbalanced_type.lower() == 'tv': + value_linear = log["cost"] + value = log["total_cost"] + elif unbalanced_type.lower() == "tv": if max_iter is None: max_iter = 1000 if tol is None: @@ -313,23 +357,34 @@ def solve(M, a=None, b=None, reg=None, c=None, reg_type="KL", unbalanced=None, reg_type = reg_type.lower() plan, log = lbfgsb_unbalanced( - a, b, M, reg=reg, reg_m=unbalanced, c=c, reg_div=reg_type, - regm_div=unbalanced_type, numItermax=max_iter, - stopThr=tol, verbose=verbose, log=True, G0=plan_init + a, + b, + M, + reg=reg, + reg_m=unbalanced, + c=c, + reg_div=reg_type, + regm_div=unbalanced_type, + numItermax=max_iter, + stopThr=tol, + verbose=verbose, + log=True, + G0=plan_init, ) - value_linear = log['cost'] - value = log['total_cost'] + value_linear = log["cost"] + value = log["total_cost"] else: - raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type))) + raise ( + NotImplementedError( + 'Unknown unbalanced_type="{}"'.format(unbalanced_type) + ) + ) else: # regularized OT - if unbalanced is None: # Balanced regularized OT - if isinstance(reg_type, tuple): # general solver - f, df = reg_type if max_iter is None: @@ -337,16 +392,26 @@ def solve(M, a=None, b=None, reg=None, c=None, reg_type="KL", unbalanced=None, if tol is None: tol = 1e-9 - plan, log = cg(a, b, M, reg=reg, f=f, df=df, numItermax=max_iter, - stopThr=tol, log=True, verbose=verbose, G0=plan_init) + plan, log = cg( + a, + b, + M, + reg=reg, + f=f, + df=df, + numItermax=max_iter, + stopThr=tol, + log=True, + verbose=verbose, + G0=plan_init, + ) value_linear = nx.sum(M * plan) - value = log['loss'][-1] - potentials = (log['u'], log['v']) - - elif reg_type.lower() in ['entropy', 'kl']: + value = log["loss"][-1] + potentials = (log["u"], log["v"]) - if grad == 'envelope': # if envelope then detach the input + elif reg_type.lower() in ["entropy", "kl"]: + if grad == "envelope": # if envelope then detach the input M0, a0, b0 = M, a, b M, a, b = nx.detach(M, a, b) @@ -356,64 +421,103 @@ def solve(M, a=None, b=None, reg=None, c=None, reg_type="KL", unbalanced=None, if tol is None: tol = 1e-9 - plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter, - stopThr=tol, log=True, - verbose=verbose) + plan, log = sinkhorn_log( + a, + b, + M, + reg=reg, + numItermax=max_iter, + stopThr=tol, + log=True, + verbose=verbose, + ) value_linear = nx.sum(M * plan) - if reg_type.lower() == 'entropy': + if reg_type.lower() == "entropy": value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) else: - value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) - - potentials = (log['log_u'], log['log_v']) - - if grad == 'envelope': # set the gradient at convergence - - value = nx.set_gradients(value, (M0, a0, b0), - (plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean()))) - - elif reg_type.lower() == 'l2': - + value = value_linear + reg * nx.kl_div( + plan, 
a[:, None] * b[None, :] + ) + + potentials = (log["log_u"], log["log_v"]) + + if grad == "envelope": # set the gradient at convergence + value = nx.set_gradients( + value, + (M0, a0, b0), + ( + plan, + reg * (potentials[0] - potentials[0].mean()), + reg * (potentials[1] - potentials[1].mean()), + ), + ) + + elif reg_type.lower() == "l2": if max_iter is None: max_iter = 1000 if tol is None: tol = 1e-9 - plan, log = smooth_ot_dual(a, b, M, reg=reg, numItermax=max_iter, stopThr=tol, log=True, verbose=verbose) + plan, log = smooth_ot_dual( + a, + b, + M, + reg=reg, + numItermax=max_iter, + stopThr=tol, + log=True, + verbose=verbose, + ) value_linear = nx.sum(M * plan) value = value_linear + reg * nx.sum(plan**2) - potentials = (log['alpha'], log['beta']) + potentials = (log["alpha"], log["beta"]) else: - raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type))) + raise ( + NotImplementedError( + 'Not implemented reg_type="{}"'.format(reg_type) + ) + ) else: # unbalanced AND regularized OT - - if not isinstance(reg_type, tuple) and reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl': - + if ( + not isinstance(reg_type, tuple) + and reg_type.lower() in ["kl"] + and unbalanced_type.lower() == "kl" + ): if max_iter is None: max_iter = 1000 if tol is None: tol = 1e-9 plan, log = sinkhorn_knopp_unbalanced( - a, b, M, reg=reg, reg_m=unbalanced, - method=method, reg_type=reg_type, c=c, + a, + b, + M, + reg=reg, + reg_m=unbalanced, + method=method, + reg_type=reg_type, + c=c, warmstart=potentials_init, - numItermax=max_iter, stopThr=tol, - verbose=verbose, log=True + numItermax=max_iter, + stopThr=tol, + verbose=verbose, + log=True, ) - value_linear = log['cost'] - value = log['total_cost'] - - potentials = (log['logu'], log['logv']) + value_linear = log["cost"] + value = log["total_cost"] - elif (isinstance(reg_type, tuple) or reg_type.lower() in ['kl', 'l2', 'entropy']) and unbalanced_type.lower() in ['kl', 'l2', 'tv']: + potentials = (log["logu"], log["logv"]) + elif ( + isinstance(reg_type, tuple) + or reg_type.lower() in ["kl", "l2", "entropy"] + ) and unbalanced_type.lower() in ["kl", "l2", "tv"]: if max_iter is None: max_iter = 1000 if tol is None: @@ -422,29 +526,66 @@ def solve(M, a=None, b=None, reg=None, c=None, reg_type="KL", unbalanced=None, reg_type = reg_type.lower() plan, log = lbfgsb_unbalanced( - a, b, M, reg=reg, reg_m=unbalanced, c=c, reg_div=reg_type, - regm_div=unbalanced_type, numItermax=max_iter, - stopThr=tol, verbose=verbose, log=True, G0=plan_init + a, + b, + M, + reg=reg, + reg_m=unbalanced, + c=c, + reg_div=reg_type, + regm_div=unbalanced_type, + numItermax=max_iter, + stopThr=tol, + verbose=verbose, + log=True, + G0=plan_init, ) - value_linear = log['cost'] - value = log['total_cost'] + value_linear = log["cost"] + value = log["total_cost"] else: - raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) + raise ( + NotImplementedError( + 'Not implemented reg_type="{}" and unbalanced_type="{}"'.format( + reg_type, unbalanced_type + ) + ) + ) - res = OTResult(potentials=potentials, value=value, - value_linear=value_linear, plan=plan, status=status, backend=nx) + res = OTResult( + potentials=potentials, + value=value, + value_linear=value_linear, + plan=plan, + status=status, + backend=nx, + ) return res -def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, - alpha=0.5, reg=None, - reg_type="entropy", unbalanced=None, unbalanced_type='KL', - n_threads=1, 
method=None, max_iter=None, plan_init=None, tol=None, - verbose=False): - r""" Solve the discrete (Fused) Gromov-Wasserstein and return :any:`OTResult` object +def solve_gromov( + Ca, + Cb, + M=None, + a=None, + b=None, + loss="L2", + symmetric=None, + alpha=0.5, + reg=None, + reg_type="entropy", + unbalanced=None, + unbalanced_type="KL", + n_threads=1, + method=None, + max_iter=None, + plan_init=None, + tol=None, + verbose=False, +): + r"""Solve the discrete (Fused) Gromov-Wasserstein and return :any:`OTResult` object The function solves the following optimization problem: @@ -686,100 +827,154 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, status = None log = None - loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'} + loss_dict = {"l2": "square_loss", "kl": "kl_loss"} if loss.lower() not in loss_dict.keys(): raise (NotImplementedError('Not implemented GW loss="{}"'.format(loss))) loss_fun = loss_dict[loss.lower()] if reg is None or reg == 0: # exact OT - - if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']: # Exact balanced OT - + if unbalanced is None and unbalanced_type.lower() not in [ + "semirelaxed" + ]: # Exact balanced OT if M is None or alpha == 1: # Gromov-Wasserstein problem - # default values for solver if max_iter is None: max_iter = 10000 if tol is None: tol = 1e-9 - value, log = gromov_wasserstein2(Ca, Cb, a, b, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value, log = gromov_wasserstein2( + Ca, + Cb, + a, + b, + loss_fun=loss_fun, + log=True, + symmetric=symmetric, + max_iter=max_iter, + G0=plan_init, + tol_rel=tol, + tol_abs=tol, + verbose=verbose, + ) value_quad = value if alpha == 1: # set to 0 for FGW with alpha=1 value_linear = 0 - plan = log['T'] - potentials = (log['u'], log['v']) + plan = log["T"] + potentials = (log["u"], log["v"]) elif alpha == 0: # Wasserstein problem - # default values for EMD solver if max_iter is None: max_iter = 1000000 - value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads) + value_linear, log = emd2( + a, + b, + M, + numItermax=max_iter, + log=True, + return_matrix=True, + numThreads=n_threads, + ) value = value_linear - potentials = (log['u'], log['v']) - plan = log['G'] - status = log["warning"] if log["warning"] is not None else 'Converged' + potentials = (log["u"], log["v"]) + plan = log["G"] + status = log["warning"] if log["warning"] is not None else "Converged" value_quad = 0 else: # Fused Gromov-Wasserstein problem - # default values for solver if max_iter is None: max_iter = 10000 if tol is None: tol = 1e-9 - value, log = fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) - - value_linear = log['lin_loss'] - value_quad = log['quad_loss'] - plan = log['T'] - potentials = (log['u'], log['v']) + value, log = fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + b, + loss_fun=loss_fun, + alpha=alpha, + log=True, + symmetric=symmetric, + max_iter=max_iter, + G0=plan_init, + tol_rel=tol, + tol_abs=tol, + verbose=verbose, + ) - elif unbalanced_type.lower() in ['semirelaxed']: # Semi-relaxed OT + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] + potentials = (log["u"], log["v"]) + elif unbalanced_type.lower() in ["semirelaxed"]: # Semi-relaxed OT if M is None or alpha == 1: # Semi relaxed 
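# NOTE: illustrative usage sketch, not part of the patch. The ot.solve_gromov
# entry point reformatted above covers GW, fused GW and the semi-relaxed
# variants through the same keyword interface; the toy graphs are made up.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(20, 2), rng.randn(25, 3)
Ca = ot.dist(xs, xs)                  # intra-domain structure matrices
Cb = ot.dist(xt, xt)
M = rng.rand(20, 25)                  # optional feature cost for fused GW

res_gw = ot.solve_gromov(Ca, Cb)                                   # plain GW
res_fgw = ot.solve_gromov(Ca, Cb, M=M, alpha=0.5)                  # fused GW
res_srgw = ot.solve_gromov(Ca, Cb, unbalanced_type="semirelaxed")  # semi-relaxed GW

print(res_gw.value, res_fgw.value_linear, res_fgw.value_quad)
plan = res_gw.plan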
Gromov-Wasserstein problem - # default values for solver if max_iter is None: max_iter = 10000 if tol is None: tol = 1e-9 - value, log = semirelaxed_gromov_wasserstein2(Ca, Cb, a, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value, log = semirelaxed_gromov_wasserstein2( + Ca, + Cb, + a, + loss_fun=loss_fun, + log=True, + symmetric=symmetric, + max_iter=max_iter, + G0=plan_init, + tol_rel=tol, + tol_abs=tol, + verbose=verbose, + ) value_quad = value if alpha == 1: # set to 0 for FGW with alpha=1 value_linear = 0 - plan = log['T'] + plan = log["T"] # potentials = (log['u'], log['v']) TODO else: # Semi relaxed Fused Gromov-Wasserstein problem - # default values for solver if max_iter is None: max_iter = 10000 if tol is None: tol = 1e-9 - value, log = semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value, log = semirelaxed_fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + loss_fun=loss_fun, + alpha=alpha, + log=True, + symmetric=symmetric, + max_iter=max_iter, + G0=plan_init, + tol_rel=tol, + tol_abs=tol, + verbose=verbose, + ) - value_linear = log['lin_loss'] - value_quad = log['quad_loss'] - plan = log['T'] + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] # potentials = (log['u'], log['v']) TODO - elif unbalanced_type.lower() in ['partial']: # Partial OT - + elif unbalanced_type.lower() in ["partial"]: # Partial OT if M is None: # Partial Gromov-Wasserstein problem - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError('Partial GW mass given in reg is too large')) + raise (ValueError("Partial GW mass given in reg is too large")) # default values for solver if max_iter is None: @@ -787,118 +982,200 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-7 - value, log = partial_gromov_wasserstein2(Ca, Cb, a, b, m=unbalanced, loss_fun=loss_fun, log=True, numItermax=max_iter, G0=plan_init, tol=tol, symmetric=symmetric, verbose=verbose) + value, log = partial_gromov_wasserstein2( + Ca, + Cb, + a, + b, + m=unbalanced, + loss_fun=loss_fun, + log=True, + numItermax=max_iter, + G0=plan_init, + tol=tol, + symmetric=symmetric, + verbose=verbose, + ) value_quad = value - plan = log['T'] + plan = log["T"] # potentials = (log['u'], log['v']) TODO else: # partial FGW + raise (NotImplementedError("Partial FGW not implemented yet")) - raise (NotImplementedError('Partial FGW not implemented yet')) - - elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT - + elif unbalanced_type.lower() in ["kl", "l2"]: # unbalanced exact OT raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type))) else: - raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type))) + raise ( + NotImplementedError( + 'Unknown unbalanced_type="{}"'.format(unbalanced_type) + ) + ) else: # regularized OT - - if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']: # Balanced regularized OT - - if reg_type.lower() in ['entropy'] and (M is None or alpha == 1): # Entropic Gromov-Wasserstein problem - + if unbalanced is None and unbalanced_type.lower() not in [ + "semirelaxed" + ]: # Balanced regularized OT + if reg_type.lower() in ["entropy"] and ( + M is None or alpha == 1 + ): # Entropic Gromov-Wasserstein problem # default values for solver if 
max_iter is None: max_iter = 1000 if tol is None: tol = 1e-9 if method is None: - method = 'PGD' - - value_quad, log = entropic_gromov_wasserstein2(Ca, Cb, a, b, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + method = "PGD" + + value_quad, log = entropic_gromov_wasserstein2( + Ca, + Cb, + a, + b, + epsilon=reg, + loss_fun=loss_fun, + log=True, + symmetric=symmetric, + solver=method, + max_iter=max_iter, + G0=plan_init, + tol_rel=tol, + tol_abs=tol, + verbose=verbose, + ) - plan = log['T'] + plan = log["T"] value_linear = 0 value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16)) # potentials = (log['log_u'], log['log_v']) #TODO - elif reg_type.lower() in ['entropy'] and M is not None and alpha == 0: # Entropic Wasserstein problem - + elif ( + reg_type.lower() in ["entropy"] and M is not None and alpha == 0 + ): # Entropic Wasserstein problem # default values for solver if max_iter is None: max_iter = 1000 if tol is None: tol = 1e-9 - plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter, - stopThr=tol, log=True, - verbose=verbose) + plan, log = sinkhorn_log( + a, + b, + M, + reg=reg, + numItermax=max_iter, + stopThr=tol, + log=True, + verbose=verbose, + ) value_linear = nx.sum(M * plan) value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) - potentials = (log['log_u'], log['log_v']) - - elif reg_type.lower() in ['entropy'] and M is not None: # Entropic Fused Gromov-Wasserstein problem + potentials = (log["log_u"], log["log_v"]) + elif ( + reg_type.lower() in ["entropy"] and M is not None + ): # Entropic Fused Gromov-Wasserstein problem # default values for solver if max_iter is None: max_iter = 1000 if tol is None: tol = 1e-9 if method is None: - method = 'PGD' - - value_noreg, log = entropic_fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + method = "PGD" + + value_noreg, log = entropic_fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + b, + loss_fun=loss_fun, + alpha=alpha, + log=True, + symmetric=symmetric, + solver=method, + max_iter=max_iter, + G0=plan_init, + tol_rel=tol, + tol_abs=tol, + verbose=verbose, + ) - value_linear = log['lin_loss'] - value_quad = log['quad_loss'] - plan = log['T'] + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] # potentials = (log['u'], log['v']) value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) else: - raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type))) - - elif unbalanced_type.lower() in ['semirelaxed']: # Semi-relaxed OT - - if reg_type.lower() in ['entropy'] and (M is None or alpha == 1): # Entropic Semi-relaxed Gromov-Wasserstein problem + raise ( + NotImplementedError( + 'Not implemented reg_type="{}"'.format(reg_type) + ) + ) + elif unbalanced_type.lower() in ["semirelaxed"]: # Semi-relaxed OT + if reg_type.lower() in ["entropy"] and ( + M is None or alpha == 1 + ): # Entropic Semi-relaxed Gromov-Wasserstein problem # default values for solver if max_iter is None: max_iter = 1000 if tol is None: tol = 1e-9 - value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol=tol, verbose=verbose) + value_quad, log = entropic_semirelaxed_gromov_wasserstein2( + Ca, + Cb, + a, + epsilon=reg, + 
loss_fun=loss_fun, + log=True, + symmetric=symmetric, + max_iter=max_iter, + G0=plan_init, + tol=tol, + verbose=verbose, + ) - plan = log['T'] + plan = log["T"] value_linear = 0 value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16)) else: # Entropic Semi-relaxed FGW problem - # default values for solver if max_iter is None: max_iter = 1000 if tol is None: tol = 1e-9 - value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol=tol, verbose=verbose) + value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + loss_fun=loss_fun, + alpha=alpha, + log=True, + symmetric=symmetric, + max_iter=max_iter, + G0=plan_init, + tol=tol, + verbose=verbose, + ) - value_linear = log['lin_loss'] - value_quad = log['quad_loss'] - plan = log['T'] + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) - elif unbalanced_type.lower() in ['partial']: # Partial OT - + elif unbalanced_type.lower() in ["partial"]: # Partial OT if M is None: # Partial Gromov-Wasserstein problem - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError('Partial GW mass given in reg is too large')) + raise (ValueError("Partial GW mass given in reg is too large")) # default values for solver if max_iter is None: @@ -906,31 +1183,77 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-7 - value_quad, log = entropic_partial_gromov_wasserstein2(Ca, Cb, a, b, reg=reg, loss_fun=loss_fun, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, symmetric=symmetric, verbose=verbose) + value_quad, log = entropic_partial_gromov_wasserstein2( + Ca, + Cb, + a, + b, + reg=reg, + loss_fun=loss_fun, + m=unbalanced, + log=True, + numItermax=max_iter, + G0=plan_init, + tol=tol, + symmetric=symmetric, + verbose=verbose, + ) value_quad = value - plan = log['T'] + plan = log["T"] # potentials = (log['u'], log['v']) TODO else: # partial FGW - - raise (NotImplementedError('Partial entropic FGW not implemented yet')) + raise (NotImplementedError("Partial entropic FGW not implemented yet")) else: # unbalanced AND regularized OT + raise ( + NotImplementedError( + 'Not implemented reg_type="{}" and unbalanced_type="{}"'.format( + reg_type, unbalanced_type + ) + ) + ) - raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) - - res = OTResult(potentials=potentials, value=value, - value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx, log=log) + res = OTResult( + potentials=potentials, + value=value, + value_linear=value_linear, + value_quad=value_quad, + plan=plan, + status=status, + backend=nx, + log=log, + ) return res -def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, c=None, reg_type="KL", - unbalanced=None, - unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95, - potentials_init=None, X_init=None, tol=None, verbose=False, - grad='autodiff'): +def solve_sample( + X_a, + X_b, + a=None, + b=None, + metric="sqeuclidean", + reg=None, + c=None, + reg_type="KL", + unbalanced=None, + unbalanced_type="KL", + lazy=False, + batch_size=None, + method=None, + n_threads=1, + max_iter=None, + plan_init=None, + rank=100, + scaling=0.95, + 
potentials_init=None, + X_init=None, + tol=None, + verbose=False, + grad="autodiff", +): r"""Solve the discrete optimal transport problem using the samples in the source and target domains. The function solves the following general optimal transport problem @@ -1193,7 +1516,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, c=Non - **Gaussian Bures-Wasserstein [2]** (when ``method='gaussian'``): This method computes the Gaussian Bures-Wasserstein distance between two - Gaussian distributions estimated from teh empirical distributions + Gaussian distributions estimated from the empirical distributions .. math:: \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} @@ -1273,16 +1596,31 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, c=Non lazy = True if not lazy: # default non lazy solver calls ot.solve - # compute cost matrix M and use solve function M = dist(X_a, X_b, metric) - res = solve(M, a, b, reg, c, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose, grad) + res = solve( + M, + a, + b, + reg, + c, + reg_type, + unbalanced, + unbalanced_type, + method, + n_threads, + max_iter, + plan_init, + potentials_init, + tol, + verbose, + grad, + ) return res else: - # Detect backend nx = get_backend(X_a, X_b, a, b) @@ -1295,35 +1633,41 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, c=Non status = None log = None - method = method.lower() if method is not None else '' + method = method.lower() if method is not None else "" - if method == '1d': # Wasserstein 1d (parallel on all dimensions) - if metric == 'sqeuclidean': + if method == "1d": # Wasserstein 1d (parallel on all dimensions) + if metric == "sqeuclidean": p = 2 - elif metric in ['euclidean', 'cityblock']: + elif metric in ["euclidean", "cityblock"]: p = 1 else: - raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + raise ( + NotImplementedError('Not implemented metric="{}"'.format(metric)) + ) value = wasserstein_1d(X_a, X_b, a, b, p=p) value_linear = value - elif method == 'gaussian': # Gaussian Bures-Wasserstein - - if not metric.lower() in ['sqeuclidean']: - raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + elif method == "gaussian": # Gaussian Bures-Wasserstein + if metric.lower() not in ["sqeuclidean"]: + raise ( + NotImplementedError('Not implemented metric="{}"'.format(metric)) + ) if reg is None: reg = 1e-6 - value, log = empirical_bures_wasserstein_distance(X_a, X_b, reg=reg, log=True) + value, log = empirical_bures_wasserstein_distance( + X_a, X_b, reg=reg, log=True + ) value = value**2 # return the value (squared bures distance) value_linear = value # return the value - elif method == 'factored': # Factored OT - - if not metric.lower() in ['sqeuclidean']: - raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + elif method == "factored": # Factored OT + if metric.lower() not in ["sqeuclidean"]: + raise ( + NotImplementedError('Not implemented metric="{}"'.format(metric)) + ) if max_iter is None: max_iter = 100 @@ -1332,19 +1676,29 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, c=Non if reg is None: reg = 0 - Q, R, X, log = factored_optimal_transport(X_a, X_b, reg=reg, r=rank, log=True, stopThr=tol, numItermax=max_iter, verbose=verbose) - log['X'] = X + Q, R, X, log = factored_optimal_transport( + X_a, + X_b, + 
reg=reg, + r=rank, + log=True, + stopThr=tol, + numItermax=max_iter, + verbose=verbose, + ) + log["X"] = X - value_linear = log['costa'] + log['costb'] + value_linear = log["costa"] + log["costb"] value = value_linear # TODO add reg term - lazy_plan = log['lazy_plan'] + lazy_plan = log["lazy_plan"] if not lazy0: # store plan if not lazy plan = lazy_plan[:] elif method == "lowrank": - - if not metric.lower() in ['sqeuclidean']: - raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + if metric.lower() not in ["sqeuclidean"]: + raise ( + NotImplementedError('Not implemented metric="{}"'.format(metric)) + ) if max_iter is None: max_iter = 2000 @@ -1353,46 +1707,77 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, c=Non if reg is None: reg = 0 - Q, R, g, log = lowrank_sinkhorn(X_a, X_b, rank=rank, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True) - value = log['value'] - value_linear = log['value_linear'] - lazy_plan = log['lazy_plan'] + Q, R, g, log = lowrank_sinkhorn( + X_a, + X_b, + rank=rank, + reg=reg, + a=a, + b=b, + numItermax=max_iter, + stopThr=tol, + log=True, + ) + value = log["value"] + value_linear = log["value_linear"] + lazy_plan = log["lazy_plan"] if not lazy0: # store plan if not lazy plan = lazy_plan[:] - elif method.startswith('geomloss'): # Geomloss solver for entropic OT - - split_method = method.split('_') + elif method.startswith("geomloss"): # Geomloss solver for entropic OT + split_method = method.split("_") if len(split_method) == 2: backend = split_method[1] else: if lazy0 is None: - backend = 'auto' + backend = "auto" elif lazy0: - backend = 'online' + backend = "online" else: - backend = 'tensorized' - - value, log = empirical_sinkhorn2_geomloss(X_a, X_b, reg=reg, a=a, b=b, metric=metric, log=True, verbose=verbose, scaling=scaling, backend=backend) + backend = "tensorized" + + value, log = empirical_sinkhorn2_geomloss( + X_a, + X_b, + reg=reg, + a=a, + b=b, + metric=metric, + log=True, + verbose=verbose, + scaling=scaling, + backend=backend, + ) - lazy_plan = log['lazy_plan'] + lazy_plan = log["lazy_plan"] if not lazy0: # store plan if not lazy plan = lazy_plan[:] # return scaled potentials (to be consistent with other solvers) - potentials = (log['f'] / (lazy_plan.blur**2), log['g'] / (lazy_plan.blur**2)) + potentials = ( + log["f"] / (lazy_plan.blur**2), + log["g"] / (lazy_plan.blur**2), + ) elif reg is None or reg == 0: # exact OT - if unbalanced is None: # balanced EMD solver not available for lazy - raise (NotImplementedError('Exact OT solver with lazy=True not implemented')) + raise ( + NotImplementedError( + "Exact OT solver with lazy=True not implemented" + ) + ) else: - raise (NotImplementedError('Non regularized solver with unbalanced_type="{}" not implemented'.format(unbalanced_type))) + raise ( + NotImplementedError( + 'Non regularized solver with unbalanced_type="{}" not implemented'.format( + unbalanced_type + ) + ) + ) else: if unbalanced is None: - if max_iter is None: max_iter = 1000 if tol is None: @@ -1400,15 +1785,41 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, c=Non if batch_size is None: batch_size = 100 - value_linear, log = empirical_sinkhorn2(X_a, X_b, reg, a, b, metric=metric, numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + value_linear, log = empirical_sinkhorn2( + X_a, + X_b, + reg, + a, + b, + metric=metric, + numIterMax=max_iter, + stopThr=tol, + isLazy=True, + batchSize=batch_size, + 
verbose=verbose, + log=True, + ) # compute potentials potentials = (log["u"], log["v"]) - lazy_plan = log['lazy_plan'] + lazy_plan = log["lazy_plan"] else: - raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) + raise ( + NotImplementedError( + 'Not implemented unbalanced_type="{}" with regularization'.format( + unbalanced_type + ) + ) + ) - res = OTResult(potentials=potentials, value=value, lazy_plan=lazy_plan, - value_linear=value_linear, plan=plan, status=status, backend=nx, log=log) + res = OTResult( + potentials=potentials, + value=value, + lazy_plan=lazy_plan, + value_linear=value_linear, + plan=plan, + status=status, + backend=nx, + log=log, + ) return res diff --git a/ot/stochastic.py b/ot/stochastic.py index fec512ccc..da0639f73 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -19,7 +19,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): - r''' + r""" Compute the coordinate gradient update for regularized discrete distributions for :math:`(i, :)` The function computes the gradient of the semi dual problem: @@ -62,7 +62,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): References ---------- .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). - ''' + """ r = M[i, :] - beta exp_beta = np.exp(-r / reg) * b khi = exp_beta / (np.sum(exp_beta)) @@ -126,7 +126,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None, random_state """ if lr is None: - lr = 1. / max(a / reg) + lr = 1.0 / max(a / reg) n_source = np.shape(M)[0] n_target = np.shape(M)[1] cur_beta = np.zeros(n_target) @@ -135,17 +135,18 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None, random_state rng = check_random_state(random_state) for _ in range(numItermax): i = rng.randint(n_source) - cur_coord_grad = a[i] * coordinate_grad_semi_dual(b, M, reg, - cur_beta, i) - sum_stored_gradient += (cur_coord_grad - stored_gradient[i]) + cur_coord_grad = a[i] * coordinate_grad_semi_dual(b, M, reg, cur_beta, i) + sum_stored_gradient += cur_coord_grad - stored_gradient[i] stored_gradient[i] = cur_coord_grad - cur_beta += lr * (1. / n_source) * sum_stored_gradient + cur_beta += lr * (1.0 / n_source) * sum_stored_gradient return cur_beta -def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None, random_state=None): - r''' - Compute the ASGD algorithm to solve the regularized semi continous measures optimal transport max problem +def averaged_sgd_entropic_transport( + a, b, M, reg, numItermax=300000, lr=None, random_state=None +): + r""" + Compute the ASGD algorithm to solve the regularized semi continuous measures optimal transport max problem The function solves the following optimization problem: @@ -194,10 +195,10 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None, ra References ---------- .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). - ''' + """ if lr is None: - lr = 1. 
/ max(a / reg) + lr = 1.0 / max(a / reg) n_source = np.shape(M)[0] n_target = np.shape(M)[1] cur_beta = np.zeros(n_target) @@ -208,12 +209,12 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None, ra i = rng.randint(n_source) cur_coord_grad = coordinate_grad_semi_dual(b, M, reg, cur_beta, i) cur_beta += (lr / np.sqrt(k)) * cur_coord_grad - ave_beta = (1. / k) * cur_beta + (1 - 1. / k) * ave_beta + ave_beta = (1.0 / k) * cur_beta + (1 - 1.0 / k) * ave_beta return ave_beta def c_transform_entropic(b, M, reg, beta): - r''' + r""" The goal is to recover u from the c-transform. The function computes the c-transform of a dual variable from the other @@ -253,7 +254,7 @@ def c_transform_entropic(b, M, reg, beta): References ---------- .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). - ''' + """ n_source = np.shape(M)[0] alpha = np.zeros(n_source) @@ -265,9 +266,10 @@ def c_transform_entropic(b, M, reg, beta): return alpha -def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, - log=False): - r''' +def solve_semi_dual_entropic( + a, b, M, reg, method, numItermax=10000, lr=None, log=False +): + r""" Compute the transportation matrix to solve the regularized discrete measures optimal transport max problem The function solves the following optimization problem: @@ -303,7 +305,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, cost matrix reg : float Regularization term > 0 - methode : str + method : str used method (SAG or ASGD) numItermax : int number of iteration @@ -327,7 +329,7 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, References ---------- .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). - ''' + """ if method.lower() == "sag": opt_beta = sag_entropic_transport(a, b, M, reg, numItermax, lr) @@ -338,13 +340,16 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, return None opt_alpha = c_transform_entropic(b, M, reg, opt_beta) - pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) * - a[:, None] * b[None, :]) + pi = ( + np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) + * a[:, None] + * b[None, :] + ) if log: log = {} - log['alpha'] = opt_alpha - log['beta'] = opt_beta + log["alpha"] = opt_alpha + log["beta"] = opt_beta return pi, log else: return pi @@ -355,9 +360,8 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, ############################################################################## -def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, - batch_beta): - r''' +def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, batch_beta): + r""" Computes the partial gradient of the dual optimal transport problem. For each :math:`(i,j)` in a batch of coordinates, the partial gradients are : @@ -416,22 +420,33 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, References ---------- .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. 
International Conference on Learning Representation (2018) - ''' - G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] - - M[batch_alpha, :][:, batch_beta]) / reg) * - a[batch_alpha, None] * b[None, batch_beta]) + """ + G = -( + np.exp( + ( + alpha[batch_alpha, None] + + beta[None, batch_beta] + - M[batch_alpha, :][:, batch_beta] + ) + / reg + ) + * a[batch_alpha, None] + * b[None, batch_beta] + ) grad_beta = np.zeros(np.shape(M)[1]) grad_alpha = np.zeros(np.shape(M)[0]) - grad_beta[batch_beta] = (b[batch_beta] * len(batch_alpha) / np.shape(M)[0] - + G.sum(0)) - grad_alpha[batch_alpha] = (a[batch_alpha] * len(batch_beta) - / np.shape(M)[1] + G.sum(1)) + grad_beta[batch_beta] = b[batch_beta] * len(batch_alpha) / np.shape(M)[0] + G.sum(0) + grad_alpha[batch_alpha] = a[batch_alpha] * len(batch_beta) / np.shape(M)[1] + G.sum( + 1 + ) return grad_alpha, grad_beta -def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr, random_state=None): - r''' +def sgd_entropic_regularization( + a, b, M, reg, batch_size, numItermax, lr, random_state=None +): + r""" Compute the sgd algorithm to solve the regularized discrete measures optimal transport dual problem The function solves the following optimization problem: @@ -482,7 +497,7 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr, random References ---------- .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018) - ''' + """ n_source = np.shape(M)[0] n_target = np.shape(M)[1] @@ -493,18 +508,17 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr, random k = np.sqrt(cur_iter + 1) batch_alpha = rng.choice(n_source, batch_size, replace=False) batch_beta = rng.choice(n_target, batch_size, replace=False) - update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha, - cur_beta, batch_size, - batch_alpha, batch_beta) + update_alpha, update_beta = batch_grad_dual( + a, b, M, reg, cur_alpha, cur_beta, batch_size, batch_alpha, batch_beta + ) cur_alpha[batch_alpha] += (lr / k) * update_alpha[batch_alpha] cur_beta[batch_beta] += (lr / k) * update_beta[batch_beta] return cur_alpha, cur_beta -def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, - log=False): - r''' +def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, log=False): + r""" Compute the transportation matrix to solve the regularized discrete measures optimal transport dual problem The function solves the following optimization problem: @@ -554,16 +568,20 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, References ---------- .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. 
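# NOTE: illustrative usage sketch, not part of the patch. The stochastic
# semi-dual (SAG/ASGD) and dual (SGD) entropic solvers reformatted above can
# be called as below; sizes, costs and the learning rate are toy values.
import numpy as np
import ot

rng = np.random.RandomState(0)
n_s, n_t = 7, 4
a, b = ot.unif(n_s), ot.unif(n_t)
M = rng.rand(n_s, n_t)
reg = 1.0

# semi-dual problem: stochastic averaged gradient on the discrete potential
G_sag = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, "sag", numItermax=10000)

# dual problem: SGD on mini-batches of both potentials
G_sgd, log = ot.stochastic.solve_dual_entropic(
    a, b, M, reg, batch_size=3, numItermax=10000, lr=0.1, log=True
)
print(G_sag.sum(), G_sgd.sum())       # both plans should carry mass close to 1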
International Conference on Learning Representation (2018) - ''' + """ - opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size, - numItermax, lr) - pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) * - a[:, None] * b[None, :]) + opt_alpha, opt_beta = sgd_entropic_regularization( + a, b, M, reg, batch_size, numItermax, lr + ) + pi = ( + np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) + * a[:, None] + * b[None, :] + ) if log: log = {} - log['alpha'] = opt_alpha - log['beta'] = opt_beta + log["alpha"] = opt_alpha + log["beta"] = opt_beta return pi, log else: return pi @@ -573,7 +591,8 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, # Losses for stochastic optimization ################################################################################ -def loss_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'): + +def loss_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric="sqeuclidean"): r""" Compute the dual loss of the entropic OT as in equation (6)-(7) of [19] @@ -631,7 +650,7 @@ def loss_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidea return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :]) -def plan_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'): +def plan_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric="sqeuclidean"): r""" Compute the primal OT plan the entropic OT as in equation (8) of [19] @@ -689,7 +708,7 @@ def plan_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidea return ws[:, None] * H * wt[None, :] -def loss_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'): +def loss_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric="sqeuclidean"): r""" Compute the dual loss of the quadratic regularized OT as in equation (6)-(7) of [19] @@ -742,12 +761,12 @@ def loss_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclide else: M = dist(xs, xt, metric=metric) - F = -1.0 / (4 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)**2 + F = -1.0 / (4 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0) ** 2 return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :]) -def plan_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'): +def plan_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric="sqeuclidean"): r""" Compute the primal OT plan the quadratic regularized OT as in equation (8) of [19] diff --git a/ot/unbalanced/__init__.py b/ot/unbalanced/__init__.py index 7d6294424..771452954 100644 --- a/ot/unbalanced/__init__.py +++ b/ot/unbalanced/__init__.py @@ -9,21 +9,33 @@ # License: MIT License # All submodules and packages -from ._sinkhorn import (sinkhorn_knopp_unbalanced, - sinkhorn_unbalanced, - sinkhorn_stabilized_unbalanced, - sinkhorn_unbalanced_translation_invariant, - sinkhorn_unbalanced2, - barycenter_unbalanced_sinkhorn, - barycenter_unbalanced_stabilized, - barycenter_unbalanced) +from ._sinkhorn import ( + sinkhorn_knopp_unbalanced, + sinkhorn_unbalanced, + sinkhorn_stabilized_unbalanced, + sinkhorn_unbalanced_translation_invariant, + sinkhorn_unbalanced2, + barycenter_unbalanced_sinkhorn, + barycenter_unbalanced_stabilized, + barycenter_unbalanced, +) -from ._mm import (mm_unbalanced, mm_unbalanced2) +from ._mm import mm_unbalanced, mm_unbalanced2 -from ._lbfgs import (lbfgsb_unbalanced, lbfgsb_unbalanced2) +from ._lbfgs import lbfgsb_unbalanced, 
lbfgsb_unbalanced2 -__all__ = ['sinkhorn_knopp_unbalanced', 'sinkhorn_unbalanced', 'sinkhorn_stabilized_unbalanced', - 'sinkhorn_unbalanced_translation_invariant', 'sinkhorn_unbalanced2', - 'barycenter_unbalanced_sinkhorn', 'barycenter_unbalanced_stabilized', - 'barycenter_unbalanced', 'mm_unbalanced', 'mm_unbalanced2', '_get_loss_unbalanced', - 'lbfgsb_unbalanced', 'lbfgsb_unbalanced2'] +__all__ = [ + "sinkhorn_knopp_unbalanced", + "sinkhorn_unbalanced", + "sinkhorn_stabilized_unbalanced", + "sinkhorn_unbalanced_translation_invariant", + "sinkhorn_unbalanced2", + "barycenter_unbalanced_sinkhorn", + "barycenter_unbalanced_stabilized", + "barycenter_unbalanced", + "mm_unbalanced", + "mm_unbalanced2", + "_get_loss_unbalanced", + "lbfgsb_unbalanced", + "lbfgsb_unbalanced2", +] diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index a38c00a5e..6ec173ad8 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -17,7 +17,7 @@ from ..utils import list_to_array, get_parameter_pair -def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'): +def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div="kl", regm_div="kl"): """ Return loss function for the L-BFGS-B solver @@ -62,7 +62,7 @@ def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div nx_numpy = get_backend(M, a, b) def reg_l2(G): - return np.sum((G - c)**2) / 2 + return np.sum((G - c) ** 2) / 2 def grad_l2(G): return G - c @@ -79,10 +79,10 @@ def reg_entropy(G): def grad_entropy(G): return np.log(G + 1e-16) - if reg_div == 'kl': + if reg_div == "kl": reg_fun = reg_kl grad_reg_fun = grad_kl - elif reg_div == 'entropy': + elif reg_div == "entropy": reg_fun = reg_entropy grad_reg_fun = grad_entropy elif isinstance(reg_div, tuple): @@ -93,32 +93,39 @@ def grad_entropy(G): grad_reg_fun = grad_l2 def marg_l2(G): - return reg_m1 * 0.5 * np.sum((G.sum(1) - a)**2) + \ - reg_m2 * 0.5 * np.sum((G.sum(0) - b)**2) + return reg_m1 * 0.5 * np.sum((G.sum(1) - a) ** 2) + reg_m2 * 0.5 * np.sum( + (G.sum(0) - b) ** 2 + ) def grad_marg_l2(G): - return reg_m1 * np.outer((G.sum(1) - a), np.ones(n)) + \ - reg_m2 * np.outer(np.ones(m), (G.sum(0) - b)) + return reg_m1 * np.outer((G.sum(1) - a), np.ones(n)) + reg_m2 * np.outer( + np.ones(m), (G.sum(0) - b) + ) def marg_kl(G): - return reg_m1 * nx_numpy.kl_div(G.sum(1), a, mass=True) + reg_m2 * nx_numpy.kl_div(G.sum(0), b, mass=True) + return reg_m1 * nx_numpy.kl_div( + G.sum(1), a, mass=True + ) + reg_m2 * nx_numpy.kl_div(G.sum(0), b, mass=True) def grad_marg_kl(G): - return reg_m1 * np.outer(np.log(G.sum(1) / a + 1e-16), np.ones(n)) + \ - reg_m2 * np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16)) + return reg_m1 * np.outer( + np.log(G.sum(1) / a + 1e-16), np.ones(n) + ) + reg_m2 * np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16)) def marg_tv(G): - return reg_m1 * np.sum(np.abs(G.sum(1) - a)) + \ - reg_m2 * np.sum(np.abs(G.sum(0) - b)) + return reg_m1 * np.sum(np.abs(G.sum(1) - a)) + reg_m2 * np.sum( + np.abs(G.sum(0) - b) + ) def grad_marg_tv(G): - return reg_m1 * np.outer(np.sign(G.sum(1) - a), np.ones(n)) + \ - reg_m2 * np.outer(np.ones(m), np.sign(G.sum(0) - b)) + return reg_m1 * np.outer(np.sign(G.sum(1) - a), np.ones(n)) + reg_m2 * np.outer( + np.ones(m), np.sign(G.sum(0) - b) + ) - if regm_div == 'kl': + if regm_div == "kl": regm_fun = marg_kl grad_regm_fun = grad_marg_kl - elif regm_div == 'tv': + elif regm_div == "tv": regm_fun = marg_tv grad_regm_fun = grad_marg_tv else: @@ -142,8 +149,22 @@ def _func(G): return _func 
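# NOTE: illustrative usage sketch, not part of the patch. Quick calls to two
# of the unbalanced solvers exported in the __all__ above (entropic Sinkhorn
# and the MM solver); marginals and costs are toy values with unequal mass.
import numpy as np
import ot

rng = np.random.RandomState(0)
a = np.array([0.5, 0.5])              # source marginal (mass 1.0)
b = np.array([0.3, 0.3, 0.6])         # target marginal (mass 1.2)
M = rng.rand(2, 3)

# entropic unbalanced OT with KL marginal relaxation, returns the OT plan
G_sink = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=0.1, reg_m=1.0)

# majorization-minimization solver, no entropic term required
G_mm = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=1.0, div="kl")

print(G_sink.sum(), G_mm.sum())       # transported mass need not equal 1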
-def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000, - stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False): +def lbfgsb_unbalanced( + a, + b, + M, + reg, + reg_m, + c=None, + reg_div="kl", + regm_div="kl", + G0=None, + numItermax=1000, + stopThr=1e-15, + method="L-BFGS-B", + verbose=False, + log=False, +): r""" Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B algorithm. The function solves the following optimization problem: @@ -253,7 +274,9 @@ def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', f0(G0) df0(G0) except BaseException: - warnings.warn("The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead") + warnings.warn( + "The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead" + ) def f(x): return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0))) @@ -266,11 +289,17 @@ def df(x): else: reg_div = reg_div.lower() if reg_div not in ["entropy", "kl", "l2"]: - raise ValueError("Unknown reg_div = {}. Must be either 'entropy', 'kl' or 'l2', or a tuple".format(reg_div)) + raise ValueError( + "Unknown reg_div = {}. Must be either 'entropy', 'kl' or 'l2', or a tuple".format( + reg_div + ) + ) regm_div = regm_div.lower() if regm_div not in ["kl", "l2", "tv"]: - raise ValueError("Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div)) + raise ValueError( + "Unknown regm_div = {}. Must be either 'kl', 'l2' or 'tv'".format(regm_div) + ) reg_m1, reg_m2 = get_parameter_pair(reg_m) @@ -292,22 +321,43 @@ def df(x): _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div) - res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf), - tol=stopThr, options=dict(maxiter=numItermax, disp=verbose)) + res = minimize( + _func, + G0.ravel(), + method=method, + jac=True, + bounds=Bounds(0, np.inf), + tol=stopThr, + options=dict(maxiter=numItermax, disp=verbose), + ) G = nx.from_numpy(res.x.reshape(M.shape), type_as=M0) if log: - log = {'cost': nx.sum(G * M), 'res': res} - log['total_cost'] = nx.from_numpy(res.fun, type_as=M0) + log = {"cost": nx.sum(G * M), "res": res} + log["total_cost"] = nx.from_numpy(res.fun, type_as=M0) return G, log else: return G -def lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', - G0=None, returnCost="linear", numItermax=1000, stopThr=1e-15, - method='L-BFGS-B', verbose=False, log=False): +def lbfgsb_unbalanced2( + a, + b, + M, + reg, + reg_m, + c=None, + reg_div="kl", + regm_div="kl", + G0=None, + returnCost="linear", + numItermax=1000, + stopThr=1e-15, + method="L-BFGS-B", + verbose=False, + log=False, +): r""" Solve the unbalanced optimal transport problem and return the OT cost using L-BFGS-B. 
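# NOTE: illustrative usage sketch, not part of the patch. The L-BFGS-B solvers
# above select their regularizers through reg_div / regm_div, and the "2"
# variant returns a cost rather than a plan; toy data as before.
import numpy as np
import ot

rng = np.random.RandomState(0)
a, b = ot.unif(4), ot.unif(5)
M = rng.rand(4, 5)

# OT plan with KL regularization and L2 marginal penalties
G = ot.unbalanced.lbfgsb_unbalanced(
    a, b, M, reg=0.05, reg_m=1.0, reg_div="kl", regm_div="l2"
)

# same problem, but only the linear transport cost is returned
cost = ot.unbalanced.lbfgsb_unbalanced2(
    a, b, M, reg=0.05, reg_m=1.0, reg_div="kl", regm_div="l2", returnCost="linear"
)
print(G.shape, float(cost))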
The function solves the following optimization problem: @@ -411,15 +461,27 @@ def lbfgsb_unbalanced2(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - _, log_lbfgs = lbfgsb_unbalanced(a=a, b=b, M=M, reg=reg, reg_m=reg_m, c=c, - reg_div=reg_div, regm_div=regm_div, G0=G0, - numItermax=numItermax, stopThr=stopThr, - method=method, verbose=verbose, log=True) + _, log_lbfgs = lbfgsb_unbalanced( + a=a, + b=b, + M=M, + reg=reg, + reg_m=reg_m, + c=c, + reg_div=reg_div, + regm_div=regm_div, + G0=G0, + numItermax=numItermax, + stopThr=stopThr, + method=method, + verbose=verbose, + log=True, + ) if returnCost == "linear": - cost = log_lbfgs['cost'] + cost = log_lbfgs["cost"] elif returnCost == "total": - cost = log_lbfgs['total_cost'] + cost = log_lbfgs["total_cost"] else: raise ValueError("Unknown returnCost = {}".format(returnCost)) diff --git a/ot/unbalanced/_mm.py b/ot/unbalanced/_mm.py index b22f234d1..47fb1ca7c 100644 --- a/ot/unbalanced/_mm.py +++ b/ot/unbalanced/_mm.py @@ -13,8 +13,20 @@ from ..utils import list_to_array, get_parameter_pair -def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000, - stopThr=1e-15, verbose=False, log=False): +def mm_unbalanced( + a, + b, + M, + reg_m, + c=None, + reg=0, + div="kl", + G0=None, + numItermax=1000, + stopThr=1e-15, + verbose=False, + log=False, +): r""" Solve the unbalanced optimal transport problem and return the OT plan. The function solves the following optimization problem: @@ -129,14 +141,14 @@ def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1 reg_m1, reg_m2 = get_parameter_pair(reg_m) if log: - log = {'err': [], 'G': []} + log = {"err": [], "G": []} div = div.lower() - if div == 'kl': + if div == "kl": sum_r = reg + reg_m1 + reg_m2 r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r - K = (a[:, None]**r1) * (b[None, :]**r2) * (c**r) * nx.exp(- M / sum_r) - elif div == 'l2': + K = (a[:, None] ** r1) * (b[None, :] ** r2) * (c**r) * nx.exp(-M / sum_r) + elif div == "l2": K = (reg_m1 * a[:, None]) + (reg_m2 * b[None, :]) + reg * c - M K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M)) else: @@ -145,36 +157,50 @@ def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1 for i in range(numItermax): Gprev = G - if div == 'kl': - Gd = (nx.sum(G, 1, keepdims=True)**r1) * (nx.sum(G, 0, keepdims=True)**r2) + 1e-16 - G = K * G**(r1 + r2) / Gd - elif div == 'l2': - Gd = reg_m1 * nx.sum(G, 1, keepdims=True) + \ - reg_m2 * nx.sum(G, 0, keepdims=True) + reg * G + 1e-16 + if div == "kl": + Gd = (nx.sum(G, 1, keepdims=True) ** r1) * ( + nx.sum(G, 0, keepdims=True) ** r2 + ) + 1e-16 + G = K * G ** (r1 + r2) / Gd + elif div == "l2": + Gd = ( + reg_m1 * nx.sum(G, 1, keepdims=True) + + reg_m2 * nx.sum(G, 0, keepdims=True) + + reg * G + + 1e-16 + ) G = K * G / Gd err = nx.sqrt(nx.sum((G - Gprev) ** 2)) if log: - log['err'].append(err) - log['G'].append(G) + log["err"].append(err) + log["G"].append(G) if verbose: - print('{:5d}|{:8e}|'.format(i, err)) + print("{:5d}|{:8e}|".format(i, err)) if err < stopThr: break if log: linear_cost = nx.sum(G * M) - log['cost'] = linear_cost + log["cost"] = linear_cost m1, m2 = nx.sum(G, 1), nx.sum(G, 0) if div == "kl": - cost = linear_cost + reg_m1 * nx.kl_div(m1, a, mass=True) + reg_m2 * nx.kl_div(m2, b, mass=True) + cost = ( + linear_cost + + reg_m1 * nx.kl_div(m1, a, mass=True) + + reg_m2 * nx.kl_div(m2, b, mass=True) + ) if reg > 0: cost = cost + reg * 
nx.kl_div(G, c, mass=True) else: - cost = linear_cost + reg_m1 * 0.5 * nx.sum((m1 - a)**2) + reg_m2 * 0.5 * nx.sum((m2 - b)**2) + cost = ( + linear_cost + + reg_m1 * 0.5 * nx.sum((m1 - a) ** 2) + + reg_m2 * 0.5 * nx.sum((m2 - b) ** 2) + ) if reg > 0: - cost = cost + reg * 0.5 * nx.sum((G - c)**2) + cost = cost + reg * 0.5 * nx.sum((G - c) ** 2) log["total_cost"] = cost @@ -183,8 +209,21 @@ def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1 return G -def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, returnCost="linear", - numItermax=1000, stopThr=1e-15, verbose=False, log=False): +def mm_unbalanced2( + a, + b, + M, + reg_m, + c=None, + reg=0, + div="kl", + G0=None, + returnCost="linear", + numItermax=1000, + stopThr=1e-15, + verbose=False, + log=False, +): r""" Solve the unbalanced optimal transport problem and return the OT cost. The function solves the following optimization problem: @@ -280,14 +319,25 @@ def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, returnCost= ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss """ - _, log_mm = mm_unbalanced(a, b, M, reg_m, c=c, reg=reg, div=div, G0=G0, - numItermax=numItermax, stopThr=stopThr, - verbose=verbose, log=True) + _, log_mm = mm_unbalanced( + a, + b, + M, + reg_m, + c=c, + reg=reg, + div=div, + G0=G0, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=True, + ) if returnCost == "linear": - cost = log_mm['cost'] + cost = log_mm["cost"] elif returnCost == "total": - cost = log_mm['total_cost'] + cost = log_mm["total_cost"] else: raise ValueError("Unknown returnCost = {}".format(returnCost)) diff --git a/ot/unbalanced/_sinkhorn.py b/ot/unbalanced/_sinkhorn.py index 6a7cbb028..810c66ce3 100644 --- a/ot/unbalanced/_sinkhorn.py +++ b/ot/unbalanced/_sinkhorn.py @@ -16,9 +16,22 @@ from ..utils import list_to_array, get_parameter_pair -def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', - reg_type="kl", c=None, warmstart=None, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, **kwargs): +def sinkhorn_unbalanced( + a, + b, + M, + reg, + reg_m, + method="sinkhorn", + reg_type="kl", + c=None, + warmstart=None, + numItermax=1000, + stopThr=1e-6, + verbose=False, + log=False, + **kwargs, +): r""" Solve the unbalanced entropic regularization optimal transport problem and return the OT plan @@ -157,40 +170,95 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', """ - if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) - - elif method.lower() == 'sinkhorn_translation_invariant': - return sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) - - elif method.lower() in ['sinkhorn_reg_scaling']: - warnings.warn('Method not implemented yet. 
Using classic Sinkhorn-Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + if method.lower() == "sinkhorn": + return sinkhorn_knopp_unbalanced( + a, + b, + M, + reg, + reg_m, + reg_type, + c, + warmstart, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + **kwargs, + ) + + elif method.lower() == "sinkhorn_stabilized": + return sinkhorn_stabilized_unbalanced( + a, + b, + M, + reg, + reg_m, + reg_type, + c, + warmstart, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + **kwargs, + ) + + elif method.lower() == "sinkhorn_translation_invariant": + return sinkhorn_unbalanced_translation_invariant( + a, + b, + M, + reg, + reg_m, + reg_type, + c, + warmstart, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + **kwargs, + ) + + elif method.lower() in ["sinkhorn_reg_scaling"]: + warnings.warn("Method not implemented yet. Using classic Sinkhorn-Knopp") + return sinkhorn_knopp_unbalanced( + a, + b, + M, + reg, + reg_m, + reg_type, + c, + warmstart, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + **kwargs, + ) else: raise ValueError("Unknown method '%s'." % method) -def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', - reg_type="kl", c=None, warmstart=None, - returnCost="linear", numItermax=1000, - stopThr=1e-6, verbose=False, log=False, **kwargs): +def sinkhorn_unbalanced2( + a, + b, + M, + reg, + reg_m, + method="sinkhorn", + reg_type="kl", + c=None, + warmstart=None, + returnCost="linear", + numItermax=1000, + stopThr=1e-6, + verbose=False, + log=False, + **kwargs, +): r""" Solve the entropic regularization unbalanced optimal transport problem and return the cost @@ -322,37 +390,81 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', M, a, b = list_to_array(M, a, b) if len(b.shape) < 2: - if method.lower() == 'sinkhorn': - res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=True, **kwargs) - - elif method.lower() == 'sinkhorn_stabilized': - res = sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=True, **kwargs) - - elif method.lower() == 'sinkhorn_translation_invariant': - res = sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=True, **kwargs) - - elif method.lower() in ['sinkhorn_reg_scaling']: - warnings.warn('Method not implemented yet. 
Using classic Sinkhorn-Knopp') - res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=True, **kwargs) + if method.lower() == "sinkhorn": + res = sinkhorn_knopp_unbalanced( + a, + b, + M, + reg, + reg_m, + reg_type, + c, + warmstart, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=True, + **kwargs, + ) + + elif method.lower() == "sinkhorn_stabilized": + res = sinkhorn_stabilized_unbalanced( + a, + b, + M, + reg, + reg_m, + reg_type, + c, + warmstart, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=True, + **kwargs, + ) + + elif method.lower() == "sinkhorn_translation_invariant": + res = sinkhorn_unbalanced_translation_invariant( + a, + b, + M, + reg, + reg_m, + reg_type, + c, + warmstart, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=True, + **kwargs, + ) + + elif method.lower() in ["sinkhorn_reg_scaling"]: + warnings.warn("Method not implemented yet. Using classic Sinkhorn-Knopp") + res = sinkhorn_knopp_unbalanced( + a, + b, + M, + reg, + reg_m, + reg_type, + c, + warmstart, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=True, + **kwargs, + ) else: - raise ValueError('Unknown method %s.' % method) + raise ValueError("Unknown method %s." % method) if returnCost == "linear": - cost = res[1]['cost'] + cost = res[1]["cost"] elif returnCost == "total": - cost = res[1]['total_cost'] + cost = res[1]["total_cost"] else: raise ValueError("Unknown returnCost = {}".format(returnCost)) @@ -363,39 +475,95 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', else: if reg_type == "kl": - warnings.warn('Reg_type not implemented yet. Use entropy.') - - if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - - elif method.lower() == 'sinkhorn_translation_invariant': - return sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - - elif method.lower() in ['sinkhorn_reg_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, c, - warmstart, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - else: - raise ValueError('Unknown method %s.' % method) + warnings.warn("Reg_type not implemented yet. 
Use entropy.") + + if method.lower() == "sinkhorn": + return sinkhorn_knopp_unbalanced( + a, + b, + M, + reg, + reg_m, + reg_type, + c, + warmstart, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + **kwargs, + ) + + elif method.lower() == "sinkhorn_stabilized": + return sinkhorn_stabilized_unbalanced( + a, + b, + M, + reg, + reg_m, + reg_type, + c, + warmstart, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + **kwargs, + ) + elif method.lower() == "sinkhorn_translation_invariant": + return sinkhorn_unbalanced_translation_invariant( + a, + b, + M, + reg, + reg_m, + reg_type, + c, + warmstart, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + **kwargs, + ) -def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, - warmstart=None, numItermax=1000, stopThr=1e-6, - verbose=False, log=False, **kwargs): + elif method.lower() in ["sinkhorn_reg_scaling"]: + warnings.warn("Method not implemented yet. Using classic Sinkhorn-Knopp") + return sinkhorn_knopp_unbalanced( + a, + b, + M, + reg, + reg_m, + reg_type, + c, + warmstart, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + **kwargs, + ) + else: + raise ValueError("Unknown method %s." % method) + + +def sinkhorn_knopp_unbalanced( + a, + b, + M, + reg, + reg_m, + reg_type="kl", + c=None, + warmstart=None, + numItermax=1000, + stopThr=1e-6, + verbose=False, + log=False, + **kwargs, +): r""" Solve the entropic regularization unbalanced optimal transport problem and return the OT plan @@ -530,7 +698,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, reg_m1, reg_m2 = get_parameter_pair(reg_m) if log: - dict_log = {'err': []} + dict_log = {"err": []} # we assume that no distances are null except those of the diagonal of # distances @@ -546,7 +714,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) if reg_type.lower() == "entropy": - warnings.warn('If reg_type = entropy, then the matrix c is overwritten by the one matrix.') + warnings.warn( + "If reg_type = entropy, then the matrix c is overwritten by the one matrix." + ) c = nx.ones((dim_a, dim_b), type_as=M) if n_hists: @@ -558,7 +728,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 - err = 1. + err = 1.0 for i in range(numItermax): uprev = u @@ -569,39 +739,42 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, Ktu = nx.dot(K.T, u) v = (b / Ktu) ** fi_2 - if (nx.any(Ktu == 0.) - or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) - or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): + if ( + nx.any(Ktu == 0.0) + or nx.any(nx.isnan(u)) + or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) + or nx.any(nx.isinf(v)) + ): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % i) + warnings.warn("Numerical errors at iteration %s" % i) u = uprev v = vprev break err_u = nx.max(nx.abs(u - uprev)) / max( - nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0 ) err_v = nx.max(nx.abs(v - vprev)) / max( - nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1. 
+ nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0 ) err = 0.5 * (err_u + err_v) if log: - dict_log['err'].append(err) + dict_log["err"].append(err) if verbose: if i % 50 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(i, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(i, err)) if err < stopThr: break if log: - dict_log['logu'] = nx.log(u + 1e-300) - dict_log['logv'] = nx.log(v + 1e-300) + dict_log["logu"] = nx.log(u + 1e-300) + dict_log["logv"] = nx.log(v + 1e-300) if n_hists: # return only loss - res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum("ik,ij,jk,ij->k", u, K, v, M) if log: return res, dict_log else: @@ -626,10 +799,22 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, return plan -def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, - warmstart=None, tau=1e5, - numItermax=1000, stopThr=1e-6, - verbose=False, log=False, **kwargs): +def sinkhorn_stabilized_unbalanced( + a, + b, + M, + reg, + reg_m, + reg_type="kl", + c=None, + warmstart=None, + tau=1e5, + numItermax=1000, + stopThr=1e-6, + verbose=False, + log=False, + **kwargs, +): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -769,7 +954,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, reg_m1, reg_m2 = get_parameter_pair(reg_m) if log: - dict_log = {'err': []} + dict_log = {"err": []} # we assume that no distances are null except those of the diagonal of # distances @@ -785,7 +970,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) if reg_type == "entropy": - warnings.warn('If reg_type = entropy, then the matrix c is overwritten by the one matrix.') + warnings.warn( + "If reg_type = entropy, then the matrix c is overwritten by the one matrix." + ) c = nx.ones((dim_a, dim_b), type_as=M) if n_hists: @@ -799,19 +986,19 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 cpt = 0 - err = 1. + err = 1.0 alpha = nx.zeros(dim_a, type_as=M) beta = nx.zeros(dim_b, type_as=M) ones_a = nx.ones(dim_a, type_as=M) ones_b = nx.ones(dim_b, type_as=M) - while (err > stopThr and cpt < numItermax): + while err > stopThr and cpt < numItermax: uprev = u vprev = v Kv = nx.dot(K, v) - f_alpha = nx.exp(- alpha / (reg + reg_m1)) if reg_m1 != float("inf") else ones_a - f_beta = nx.exp(- beta / (reg + reg_m2)) if reg_m2 != float("inf") else ones_b + f_alpha = nx.exp(-alpha / (reg + reg_m1)) if reg_m1 != float("inf") else ones_a + f_beta = nx.exp(-beta / (reg + reg_m2)) if reg_m2 != float("inf") else ones_b if n_hists: f_alpha = f_alpha[:, None] @@ -832,12 +1019,16 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, v = nx.ones(v.shape, type_as=v) Kv = nx.dot(K, v) - if (nx.any(Ktu == 0.) 
- or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) - or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): + if ( + nx.any(Ktu == 0.0) + or nx.any(nx.isnan(u)) + or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) + or nx.any(nx.isinf(v)) + ): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn("Numerical errors at iteration %s" % cpt) u = uprev v = vprev break @@ -845,21 +1036,22 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, # we can speed up the process by checking for the error only all # the 10th iterations err = nx.max(nx.abs(u - uprev)) / max( - nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0 ) if log: - dict_log['err'].append(err) + dict_log["err"].append(err) if verbose: if cpt % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err)) cpt = cpt + 1 if err > stopThr: - warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + - "Try a larger entropy `reg` or a lower mass `reg_m`." + - "Or a larger absorption threshold `tau`.") + warnings.warn( + "Stabilized Unbalanced Sinkhorn did not converge." + + "Try a larger entropy `reg` or a lower mass `reg_m`." + + "Or a larger absorption threshold `tau`." + ) if n_hists: logu = alpha[:, None] / reg + nx.log(u) logv = beta[:, None] / reg + nx.log(v) @@ -867,15 +1059,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, logu = alpha / reg + nx.log(u) logv = beta / reg + nx.log(v) if log: - dict_log['logu'] = logu - dict_log['logv'] = logv + dict_log["logu"] = logu + dict_log["logv"] = logv if n_hists: # return only loss res = nx.logsumexp( nx.log(M + 1e-100)[:, :, None] + logu[:, None, :] + logv[None, :, :] - M0[:, :, None] / reg, - axis=(0, 1) + axis=(0, 1), ) res = nx.exp(res) if log: @@ -901,9 +1093,21 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, return plan -def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl", c=None, - warmstart=None, numItermax=1000, stopThr=1e-6, - verbose=False, log=False, **kwargs): +def sinkhorn_unbalanced_translation_invariant( + a, + b, + M, + reg, + reg_m, + reg_type="kl", + c=None, + warmstart=None, + numItermax=1000, + stopThr=1e-6, + verbose=False, + log=False, + **kwargs, +): r""" Solve the entropic regularization unbalanced optimal transport problem and return the OT plan @@ -1021,7 +1225,7 @@ def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl" reg_m1, reg_m2 = get_parameter_pair(reg_m) if log: - dict_log = {'err': []} + dict_log = {"err": []} # we assume that no distances are null except those of the diagonal of # distances @@ -1039,7 +1243,9 @@ def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl" u_, v_ = u, v if reg_type == "entropy": - warnings.warn('If reg_type = entropy, then the matrix c is overwritten by the one matrix.') + warnings.warn( + "If reg_type = entropy, then the matrix c is overwritten by the one matrix." 
+ ) c = nx.ones((dim_a, dim_b), type_as=M) if n_hists: @@ -1052,8 +1258,16 @@ def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl" fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 - k1 = reg * reg_m1 / ((reg + reg_m1) * (reg_m1 + reg_m2)) if reg_m1 != float("inf") else 0 - k2 = reg * reg_m2 / ((reg + reg_m2) * (reg_m1 + reg_m2)) if reg_m2 != float("inf") else 0 + k1 = ( + reg * reg_m1 / ((reg + reg_m1) * (reg_m1 + reg_m2)) + if reg_m1 != float("inf") + else 0 + ) + k2 = ( + reg * reg_m2 / ((reg + reg_m2) * (reg_m1 + reg_m2)) + if reg_m2 != float("inf") + else 0 + ) k_rho1 = k1 * reg_m1 / reg if reg_m1 != float("inf") else 0 k_rho2 = k2 * reg_m2 / reg if reg_m2 != float("inf") else 0 @@ -1080,59 +1294,64 @@ def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl" reg_ratio1 = reg / reg_m1 if reg_m1 != float("inf") else 0 reg_ratio2 = reg / reg_m2 if reg_m2 != float("inf") else 0 - err = 1. + err = 1.0 for i in range(numItermax): uprev = u vprev = v Kv = nx.dot(K, v_) - u_hat = (a / Kv) ** fi_1 * nx.sum(b * v_**reg_ratio2)**k_rho2 - u_ = u_hat * nx.sum(a * u_hat**(-reg_ratio1))**(-xi_rho1) + u_hat = (a / Kv) ** fi_1 * nx.sum(b * v_**reg_ratio2) ** k_rho2 + u_ = u_hat * nx.sum(a * u_hat ** (-reg_ratio1)) ** (-xi_rho1) Ktu = nx.dot(K.T, u_) - v_hat = (b / Ktu) ** fi_2 * nx.sum(a * u_**(-reg_ratio1))**k_rho1 - v_ = v_hat * nx.sum(b * v_hat**(-reg_ratio2))**(-xi_rho2) - - if (nx.any(Ktu == 0.) - or nx.any(nx.isnan(u_)) or nx.any(nx.isnan(v_)) - or nx.any(nx.isinf(u_)) or nx.any(nx.isinf(v_))): + v_hat = (b / Ktu) ** fi_2 * nx.sum(a * u_ ** (-reg_ratio1)) ** k_rho1 + v_ = v_hat * nx.sum(b * v_hat ** (-reg_ratio2)) ** (-xi_rho2) + + if ( + nx.any(Ktu == 0.0) + or nx.any(nx.isnan(u_)) + or nx.any(nx.isnan(v_)) + or nx.any(nx.isinf(u_)) + or nx.any(nx.isinf(v_)) + ): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % i) + warnings.warn("Numerical errors at iteration %s" % i) u = uprev v = vprev break - t = (nx.sum(a * u_**(-reg_ratio1)) / nx.sum(b * v_**(-reg_ratio2)))**(fi_12 / reg) + t = (nx.sum(a * u_ ** (-reg_ratio1)) / nx.sum(b * v_ ** (-reg_ratio2))) ** ( + fi_12 / reg + ) u = u_ * t v = v_ / t err_u = nx.max(nx.abs(u - uprev)) / max( - nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0 ) err_v = nx.max(nx.abs(v - vprev)) / max( - nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1. 
+ nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0 ) err = 0.5 * (err_u + err_v) if log: - dict_log['err'].append(err) + dict_log["err"].append(err) if verbose: if i % 50 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(i, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(i, err)) if err < stopThr: break if log: - dict_log['logu'] = nx.log(u + 1e-300) - dict_log['logv'] = nx.log(v + 1e-300) + dict_log["logu"] = nx.log(u + 1e-300) + dict_log["logv"] = nx.log(v + 1e-300) if n_hists: # return only loss - res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum("ik,ij,jk,ij->k", u, K, v, M) if log: return res, dict_log else: @@ -1157,9 +1376,18 @@ def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl" return plan -def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, - numItermax=1000, stopThr=1e-6, - verbose=False, log=False): +def barycenter_unbalanced_stabilized( + A, + M, + reg, + reg_m, + weights=None, + tau=1e3, + numItermax=1000, + stopThr=1e-6, + verbose=False, + log=False, +): r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}` with stabilization. The function solves the following optimization problem: @@ -1229,10 +1457,10 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists else: - assert (len(weights) == A.shape[1]) + assert len(weights) == A.shape[1] if log: - log = {'err': []} + log = {"err": []} fi = reg_m / (reg_m + reg) @@ -1245,15 +1473,15 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, fi = reg_m / (reg_m + reg) cpt = 0 - err = 1. + err = 1.0 alpha = nx.zeros(dim, type_as=A) beta = nx.zeros(dim, type_as=A) q = nx.ones(dim, type_as=A) / dim for i in range(numItermax): qprev = nx.copy(q) Kv = nx.dot(K, v) - f_alpha = nx.exp(- alpha / (reg + reg_m)) - f_beta = nx.exp(- beta / (reg + reg_m)) + f_alpha = nx.exp(-alpha / (reg + reg_m)) + f_beta = nx.exp(-beta / (reg + reg_m)) f_alpha = f_alpha[:, None] f_beta = f_beta[:, None] u = ((A / (Kv + 1e-16)) ** fi) * f_alpha @@ -1270,46 +1498,59 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) v = nx.ones(v.shape, type_as=v) Kv = nx.dot(K, v) - if (nx.any(Ktu == 0.) - or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) - or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): + if ( + nx.any(Ktu == 0.0) + or nx.any(nx.isnan(u)) + or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) + or nx.any(nx.isinf(v)) + ): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn("Numerical errors at iteration %s" % cpt) q = qprev break if (i % 10 == 0 and not absorbing) or i == 0: # we can speed up the process by checking for the error only all # the 10th iterations err = nx.max(nx.abs(q - qprev)) / max( - nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1. 
+ nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0 ) if log: - log['err'].append(err) + log["err"].append(err) if verbose: if i % 50 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(i, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(i, err)) if err < stopThr: break if err > stopThr: - warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + - "Try a larger entropy `reg` or a lower mass `reg_m`." + - "Or a larger absorption threshold `tau`.") + warnings.warn( + "Stabilized Unbalanced Sinkhorn did not converge." + + "Try a larger entropy `reg` or a lower mass `reg_m`." + + "Or a larger absorption threshold `tau`." + ) if log: - log['niter'] = i - log['logu'] = nx.log(u + 1e-300) - log['logv'] = nx.log(v + 1e-300) + log["niter"] = i + log["logu"] = nx.log(u + 1e-300) + log["logv"] = nx.log(v + 1e-300) return q, log else: return q -def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, - numItermax=1000, stopThr=1e-6, - verbose=False, log=False): +def barycenter_unbalanced_sinkhorn( + A, + M, + reg, + reg_m, + weights=None, + numItermax=1000, + stopThr=1e-6, + verbose=False, + log=False, +): r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`. The function solves the following optimization problem with :math:`\mathbf{a}` @@ -1338,7 +1579,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, reg_m: float Marginal relaxation term > 0 weights : array-like (n_hists,) optional - Weight of each distribution (barycentric coodinates) + Weight of each distribution (barycentric coordinates) If None, uniform weights are used. numItermax : int, optional Max number of iterations @@ -1377,10 +1618,10 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists else: - assert (len(weights) == A.shape[1]) + assert len(weights) == A.shape[1] if log: - log = {'err': []} + log = {"err": []} K = nx.exp(-M / reg) @@ -1389,7 +1630,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, v = nx.ones((dim, n_hists), type_as=A) u = nx.ones((dim, 1), type_as=A) q = nx.ones(dim, type_as=A) - err = 1. + err = 1.0 for i in range(numItermax): uprev = nx.copy(u) @@ -1404,12 +1645,16 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, Q = q[:, None] v = (Q / Ktu) ** fi - if (nx.any(Ktu == 0.) 
- or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) - or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): + if ( + nx.any(Ktu == 0.0) + or nx.any(nx.isnan(u)) + or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) + or nx.any(nx.isinf(v)) + ): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % i) + warnings.warn("Numerical errors at iteration %s" % i) u = uprev v = vprev q = qprev @@ -1419,29 +1664,38 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0 ) if log: - log['err'].append(err) + log["err"].append(err) # if barycenter did not change + at least 10 iterations - stop if err < stopThr and i > 10: break if verbose: if i % 10 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(i, err)) + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(i, err)) if log: - log['niter'] = i - log['logu'] = nx.log(u + 1e-300) - log['logv'] = nx.log(v + 1e-300) + log["niter"] = i + log["logu"] = nx.log(u + 1e-300) + log["logv"] = nx.log(v + 1e-300) return q, log else: return q -def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, - numItermax=1000, stopThr=1e-6, - verbose=False, log=False, **kwargs): +def barycenter_unbalanced( + A, + M, + reg, + reg_m, + method="sinkhorn", + weights=None, + numItermax=1000, + stopThr=1e-6, + verbose=False, + log=False, + **kwargs, +): r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`. The function solves the following optimization problem with :math:`\mathbf{a}` @@ -1470,7 +1724,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, reg_m: float Marginal relaxation term > 0 weights : array-like (n_hists,) optional - Weight of each distribution (barycentric coodinates) + Weight of each distribution (barycentric coordinates) If None, uniform weights are used. numItermax : int, optional Max number of iterations @@ -1502,26 +1756,46 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, """ - if method.lower() == 'sinkhorn': - return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, - weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - - elif method.lower() == 'sinkhorn_stabilized': - return barycenter_unbalanced_stabilized(A, M, reg, reg_m, - weights=weights, - numItermax=numItermax, - stopThr=stopThr, - verbose=verbose, - log=log, **kwargs) - elif method.lower() in ['sinkhorn_reg_scaling', 'sinkhorn_translation_invariant']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return barycenter_unbalanced(A, M, reg, reg_m, - weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + if method.lower() == "sinkhorn": + return barycenter_unbalanced_sinkhorn( + A, + M, + reg, + reg_m, + weights=weights, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + **kwargs, + ) + + elif method.lower() == "sinkhorn_stabilized": + return barycenter_unbalanced_stabilized( + A, + M, + reg, + reg_m, + weights=weights, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + **kwargs, + ) + elif method.lower() in ["sinkhorn_reg_scaling", "sinkhorn_translation_invariant"]: + warnings.warn("Method not implemented yet. 
Using classic Sinkhorn Knopp") + return barycenter_unbalanced( + A, + M, + reg, + reg_m, + weights=weights, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, + **kwargs, + ) else: raise ValueError("Unknown method '%s'." % method) diff --git a/ot/utils.py b/ot/utils.py index 12910c479..a2d328484 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -21,30 +21,30 @@ def tic(): - r""" Python implementation of Matlab tic() function """ + r"""Python implementation of Matlab tic() function""" global __time_tic_toc __time_tic_toc = time.time() -def toc(message='Elapsed time : {} s'): - r""" Python implementation of Matlab toc() function """ +def toc(message="Elapsed time : {} s"): + r"""Python implementation of Matlab toc() function""" t = time.time() print(message.format(t - __time_tic_toc)) return t - __time_tic_toc def toq(): - r""" Python implementation of Julia toc() function """ + r"""Python implementation of Julia toc() function""" t = time.time() return t - __time_tic_toc -def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): +def kernel(x1, x2, method="gaussian", sigma=1, **kwargs): r"""Compute kernel matrix""" nx = get_backend(x1, x2) - if method.lower() in ['gaussian', 'gauss', 'rbf']: + if method.lower() in ["gaussian", "gauss", "rbf"]: K = nx.exp(-dist(x1, x2) / (2 * sigma**2)) return K @@ -57,10 +57,9 @@ def laplacian(x): def list_to_array(*lst, nx=None): - r""" Convert a list if in numpy format """ + r"""Convert a list if in numpy format""" lst_not_empty = [a for a in lst if len(a) > 0 and not isinstance(a, list)] if nx is None: # find backend - if len(lst_not_empty) == 0: type_as = np.zeros(0) nx = get_backend(type_as) @@ -73,8 +72,10 @@ def list_to_array(*lst, nx=None): else: type_as = lst_not_empty[0] if len(lst) > 1: - return [nx.from_numpy(np.array(a), type_as=type_as) - if isinstance(a, list) else a for a in lst] + return [ + nx.from_numpy(np.array(a), type_as=type_as) if isinstance(a, list) else a + for a in lst + ] else: if isinstance(lst[0], list): return nx.from_numpy(np.array(lst[0]), type_as=type_as) @@ -166,10 +167,14 @@ def projection_sparse_simplex(V, max_nz, z=1, axis=None, nx=None): if V.ndim == 1: return projection_sparse_simplex( # V[nx.newaxis, :], max_nz, z, axis=1).ravel() - V[None, :], max_nz, z, axis=1).ravel() + V[None, :], + max_nz, + z, + axis=1, + ).ravel() if V.ndim > 2: - raise ValueError('V.ndim must be <= 2') + raise ValueError("V.ndim must be <= 2") if axis == 1: # For each row of V, find top max_nz values; arrange the @@ -199,9 +204,10 @@ def projection_sparse_simplex(V, max_nz, z=1, axis=None, nx=None): if isinstance(nx, JaxBackend): # in Jax, we need to use the `at` property of `jax.numpy.ndarray` - # to do in-place array modificatons. - sparse_projection = sparse_projection.at[ - row_indices, max_nz_indices].set(nz_projection) + # to do in-place array modifications. 
+ sparse_projection = sparse_projection.at[row_indices, max_nz_indices].set( + nz_projection + ) else: sparse_projection[row_indices, max_nz_indices] = nz_projection return sparse_projection @@ -238,8 +244,7 @@ def unif(n, type_as=None): def clean_zeros(a, b, M): - r""" Remove all components with zeros weights in :math:`\mathbf{a}` and :math:`\mathbf{b}` - """ + r"""Remove all components with zeros weights in :math:`\mathbf{a}` and :math:`\mathbf{b}`""" M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd) a2 = a[a > 0] b2 = b[b > 0] @@ -268,8 +273,8 @@ def euclidean_distances(X, Y, squared=False): nx = get_backend(X, Y) - a2 = nx.einsum('ij,ij->i', X, X) - b2 = nx.einsum('ij,ij->i', Y, Y) + a2 = nx.einsum("ij,ij->i", X, X) + b2 = nx.einsum("ij,ij->i", Y, Y) c = -2 * nx.dot(X, Y.T) c += a2[:, None] @@ -286,7 +291,7 @@ def euclidean_distances(X, Y, squared=False): return c -def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None): +def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None): r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` .. note:: This function is backend-compatible and will work on arrays @@ -326,7 +331,7 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None): elif metric == "euclidean": return euclidean_distances(x1, x2, squared=False) else: - if not get_backend(x1, x2).__name__ == 'numpy': + if not get_backend(x1, x2).__name__ == "numpy": raise NotImplementedError() else: if isinstance(metric, str) and metric.endswith("minkowski"): @@ -336,7 +341,7 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None): return cdist(x1, x2, metric=metric) -def dist0(n, method='lin_square'): +def dist0(n, method="lin_square"): r"""Compute standard cost matrices of size (`n`, `n`) for OT problems Parameters @@ -354,14 +359,14 @@ def dist0(n, method='lin_square'): Distance matrix computed with given metric. """ res = 0 - if method == 'lin_square': + if method == "lin_square": x = np.arange(n, dtype=np.float64).reshape((n, 1)) res = dist(x, x) return res def cost_normalization(C, norm=None, return_value=False, value=None): - r""" Apply normalization to the loss matrix + r"""Apply normalization to the loss matrix Parameters ---------- @@ -394,9 +399,11 @@ def cost_normalization(C, norm=None, return_value=False, value=None): elif norm == "loglog": C = nx.log(1 + nx.log(1 + C)) else: - raise ValueError('Norm %s is not a valid option.\n' - 'Valid options are:\n' - 'median, max, log, loglog' % norm) + raise ValueError( + "Norm %s is not a valid option.\n" + "Valid options are:\n" + "median, max, log, loglog" % norm + ) if return_value: return C, value else: @@ -404,7 +411,7 @@ def cost_normalization(C, norm=None, return_value=False, value=None): def dots(*args): - r""" dots function for multiple matrix multiply """ + r"""dots function for multiple matrix multiply""" nx = get_backend(*args) return reduce(nx.dot, args) @@ -416,7 +423,7 @@ def is_all_finite(*args): def label_normalization(y, start=0, nx=None): - r""" Transform labels to start at a given value + r"""Transform labels to start at a given value Parameters ---------- @@ -466,15 +473,14 @@ def labels_to_masks(y, type_as=None, nx=None): def parmap(f, X, nprocs="default"): - r""" parallel map for multiprocessing. + r"""parallel map for multiprocessing. The function has been deprecated and only performs a regular map. 
""" return list(map(f, X)) def check_params(**kwargs): - r"""check_params: check whether some parameters are missing - """ + r"""check_params: check whether some parameters are missing""" missing_params = [] check = True @@ -510,8 +516,9 @@ def check_random_state(seed): return np.random.RandomState(seed) if isinstance(seed, np.random.RandomState): return seed - raise ValueError('{} cannot be used to seed a numpy.random.RandomState' - ' instance'.format(seed)) + raise ValueError( + "{} cannot be used to seed a numpy.random.RandomState" " instance".format(seed) + ) def get_coordinate_circle(x): @@ -545,7 +552,7 @@ def get_coordinate_circle(x): def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100): - """ Reduce a LazyTensor along an axis with function fun using batches. + """Reduce a LazyTensor along an axis with function fun using batches. When axis=None, reduce the LazyTensor to a scalar as a sum of fun over batches taken along dim. @@ -584,33 +591,33 @@ def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100): if axis is None: res = 0.0 for i in range(0, a.shape[0], batch_size): - res += func(a[i:i + batch_size]) + res += func(a[i : i + batch_size]) return res elif axis == 0: res = nx.zeros(a.shape[1:], type_as=a[0]) if nx.__name__ in ["jax", "tf"]: lst = [] for j in range(0, a.shape[1], batch_size): - lst.append(func(a[:, j:j + batch_size], 0)) + lst.append(func(a[:, j : j + batch_size], 0)) return nx.concatenate(lst, axis=0) else: for j in range(0, a.shape[1], batch_size): - res[j:j + batch_size] = func(a[:, j:j + batch_size], axis=0) + res[j : j + batch_size] = func(a[:, j : j + batch_size], axis=0) return res elif axis == 1: if len(a.shape) == 2: - shape = (a.shape[0]) + shape = a.shape[0] else: shape = (a.shape[0], *a.shape[2:]) res = nx.zeros(shape, type_as=a[0]) if nx.__name__ in ["jax", "tf"]: lst = [] for i in range(0, a.shape[0], batch_size): - lst.append(func(a[i:i + batch_size], 1)) + lst.append(func(a[i : i + batch_size], 1)) return nx.concatenate(lst, axis=0) else: for i in range(0, a.shape[0], batch_size): - res[i:i + batch_size] = func(a[i:i + batch_size], axis=1) + res[i : i + batch_size] = func(a[i : i + batch_size], axis=1) return res else: @@ -618,7 +625,7 @@ def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100): def get_lowrank_lazytensor(Q, R, d=None, nx=None): - """ Get a low rank LazyTensor T=Q@R^T or T=Q@diag(d)@R^T + """Get a low rank LazyTensor T=Q@R^T or T=Q@diag(d)@R^T Parameters ---------- @@ -681,8 +688,10 @@ def get_parameter_pair(parameter): param_1, param_2 = parameter[0], parameter[0] else: if len(parameter) > 2: - raise ValueError("Parameter must be either a scalar, \ - or an indexable object of length 1 or 2.") + raise ValueError( + "Parameter must be either a scalar, \ + or an indexable object of length 1 or 2." + ) else: param_1, param_2 = parameter[0], parameter[1] @@ -715,7 +724,7 @@ class deprecated(object): # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary, # but with many changes. 
- def __init__(self, extra=''): + def __init__(self, extra=""): self.extra = extra def __call__(self, obj): @@ -743,7 +752,7 @@ def wrapped(*args, **kwargs): cls.__init__ = wrapped - wrapped.__name__ = '__init__' + wrapped.__name__ = "__init__" wrapped.__doc__ = self._update_doc(init.__doc__) wrapped.deprecated_original = init @@ -776,16 +785,15 @@ def _update_doc(self, olddoc): def _is_deprecated(func): - r"""Helper to check if func is wraped by our deprecated decorator""" + r"""Helper to check if func is wrapped by our deprecated decorator""" if sys.version_info < (3, 5): - raise NotImplementedError("This is only available for python3.5 " - "or above") - closures = getattr(func, '__closure__', []) + raise NotImplementedError("This is only available for python3.5 " "or above") + closures = getattr(func, "__closure__", []) if closures is None: closures = [] - is_deprecated = ('deprecated' in ''.join([c.cell_contents - for c in closures - if isinstance(c.cell_contents, str)])) + is_deprecated = "deprecated" in "".join( + [c.cell_contents for c in closures if isinstance(c.cell_contents, str)] + ) return is_deprecated @@ -804,9 +812,7 @@ class BaseEstimator(object): nx: Backend = None def _get_backend(self, *arrays): - nx = get_backend( - *[input_ for input_ in arrays if input_ is not None] - ) + nx = get_backend(*[input_ for input_ in arrays if input_ is not None]) if nx.__name__ in ("tf",): raise TypeError("Domain adaptation does not support TF backend.") self.nx = nx @@ -818,7 +824,7 @@ def _get_param_names(cls): # fetch the constructor or the original constructor before # deprecation wrapping if any - init = getattr(cls.__init__, 'deprecated_original', cls.__init__) + init = getattr(cls.__init__, "deprecated_original", cls.__init__) if init is object.__init__: # No explicit constructor to introspect return [] @@ -827,16 +833,20 @@ def _get_param_names(cls): # to represent init_signature = signature(init) # Consider the constructor parameters excluding 'self' - parameters = [p for p in init_signature.parameters.values() - if p.name != 'self' and p.kind != p.VAR_KEYWORD] + parameters = [ + p + for p in init_signature.parameters.values() + if p.name != "self" and p.kind != p.VAR_KEYWORD + ] for p in parameters: if p.kind == p.VAR_POSITIONAL: - raise RuntimeError("POT estimators should always " - "specify their parameters in the signature" - " of their __init__ (no varargs)." - " %s with constructor %s doesn't " - " follow this convention." - % (cls, init_signature)) + raise RuntimeError( + "POT estimators should always " + "specify their parameters in the signature" + " of their __init__ (no varargs)." + " %s with constructor %s doesn't " + " follow this convention." % (cls, init_signature) + ) # Extract and sort argument names excluding 'self' return sorted([p.name for p in parameters]) @@ -864,16 +874,16 @@ def get_params(self, deep=True): try: with warnings.catch_warnings(record=True) as w: value = getattr(self, key, None) - if len(w) and w[0].category == DeprecationWarning: + if len(w) and isinstance(w[0].category, DeprecationWarning): # if the parameter is deprecated, don't show it continue finally: warnings.filters.pop(0) # XXX: should we rather test if instance of estimator? 
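Illustrative only: the get_params / set_params API defined here, exercised through one of the estimators that inherit from BaseEstimator (the parameter name follows ot.da.SinkhornTransport's constructor).

import ot

est = ot.da.SinkhornTransport(reg_e=1e-1)
params = est.get_params()   # dict of constructor parameters, e.g. params["reg_e"] == 0.1
est.set_params(reg_e=1e-2)  # simple parameter; nested objects use the "name__param" form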
- if deep and hasattr(value, 'get_params'): + if deep and hasattr(value, "get_params"): deep_items = value.get_params().items() - out.update((key + '__' + k, val) for k, val in deep_items) + out.update((key + "__" + k, val) for k, val in deep_items) out[key] = value return out @@ -895,24 +905,27 @@ def set_params(self, **params): valid_params = self.get_params(deep=True) # for key, value in iteritems(params): for key, value in params.items(): - split = key.split('__', 1) + split = key.split("__", 1) if len(split) > 1: # nested objects case name, sub_name = split if name not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (name, self)) + raise ValueError( + "Invalid parameter %s for estimator %s. " + "Check the list of available parameters " + "with `estimator.get_params().keys()`." % (name, self) + ) sub_object = valid_params[name] sub_object.set_params(**{sub_name: value}) else: # simple objects case if key not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (key, self.__class__.__name__)) + raise ValueError( + "Invalid parameter %s for estimator %s. " + "Check the list of available parameters " + "with `estimator.get_params().keys()`." + % (key, self.__class__.__name__) + ) setattr(self, key, value) return self @@ -922,11 +935,12 @@ class UndefinedParameter(Exception): Aim at raising an Exception when a undefined parameter is called """ + pass class OTResult: - """ Base class for OT results. + """Base class for OT results. Parameters ---------- @@ -995,8 +1009,20 @@ class OTResult: """ - def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None, batch_size=100): - + def __init__( + self, + potentials=None, + value=None, + value_linear=None, + value_quad=None, + plan=None, + log=None, + backend=None, + sparse_plan=None, + lazy_plan=None, + status=None, + batch_size=100, + ): self._potentials = potentials self._value = value self._value_linear = value_linear @@ -1021,19 +1047,23 @@ def __init__(self, potentials=None, value=None, value_linear=None, value_quad=No # Dual potentials -------------------------------------------- def __repr__(self): - s = 'OTResult(' + s = "OTResult(" if self._value is not None: - s += 'value={},'.format(self._value) + s += "value={},".format(self._value) if self._value_linear is not None: - s += 'value_linear={},'.format(self._value_linear) + s += "value_linear={},".format(self._value_linear) if self._plan is not None: - s += 'plan={}(shape={}),'.format(self._plan.__class__.__name__, self._plan.shape) + s += "plan={}(shape={}),".format( + self._plan.__class__.__name__, self._plan.shape + ) if self._lazy_plan is not None: - s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape) - if s[-1] != '(': - s = s[:-1] + ')' + s += "lazy_plan={}(shape={}),".format( + self._lazy_plan.__class__.__name__, self._lazy_plan.shape + ) + if s[-1] != "(": + s = s[:-1] + ")" else: - s = s + ')' + s = s + ")" return s @property @@ -1204,7 +1234,7 @@ def citation(self): class LazyTensor(object): - """ A lazy tensor is a tensor that is not stored in memory. Instead, it is + """A lazy tensor is a tensor that is not stored in memory. 
Instead, it is defined by a function that computes its values on the fly from slices. Parameters @@ -1241,7 +1271,6 @@ class LazyTensor(object): """ def __init__(self, shape, getitem, **kwargs): - self._getitem = getitem self.shape = shape self.ndim = len(shape) @@ -1262,15 +1291,19 @@ def __getitem__(self, key): for i in range(self.ndim - len(key)): k.append(slice(None)) else: - raise NotImplementedError("Only integer, slice, and tuple indexing is supported") + raise NotImplementedError( + "Only integer, slice, and tuple indexing is supported" + ) return self._getitem(*k, **self.kwargs) def __repr__(self): - return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys())) + return "LazyTensor(shape={},attributes=({}))".format( + self.shape, ",".join(self.kwargs.keys()) + ) -def proj_SDP(S, nx=None, vmin=0.): +def proj_SDP(S, nx=None, vmin=0.0): """ Project a symmetric matrix onto the space of symmetric matrices with eigenvalues larger or equal to `vmin`. @@ -1305,6 +1338,6 @@ def proj_SDP(S, nx=None, vmin=0.): return P @ nx.diag(w) @ P.T else: # input was (n, d, d): broadcasting - Q = nx.einsum('ijk,ik->ijk', P, w) # Q[i] = P[i] @ diag(w[i]) + Q = nx.einsum("ijk,ik->ijk", P, w) # Q[i] = P[i] @ diag(w[i]) # R[i] = Q[i] @ P[i].T - return nx.einsum('ijk,ikl->ijl', Q, nx.transpose(P, (0, 2, 1))) + return nx.einsum("ijk,ikl->ijl", Q, nx.transpose(P, (0, 2, 1))) diff --git a/ot/weak.py b/ot/weak.py index 7364e68ab..aa504f7ac 100644 --- a/ot/weak.py +++ b/ot/weak.py @@ -10,10 +10,12 @@ from .optim import cg import numpy as np -__all__ = ['weak_optimal_transport'] +__all__ = ["weak_optimal_transport"] -def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs): +def weak_optimal_transport( + Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs +): r"""Solves the weak optimal transport problem between two empirical distributions @@ -110,7 +112,7 @@ def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0= # weak OT loss def f(T): - return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None])**2, 1)) + return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None]) ** 2, 1)) # weak OT gradient def df(T): @@ -119,8 +121,10 @@ def df(T): # solve with conditional gradient and return solution if log: res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs) - log['u'] = nx.from_numpy(log['u'], type_as=Xa) - log['v'] = nx.from_numpy(log['v'], type_as=Xb) + log["u"] = nx.from_numpy(log["u"], type_as=Xa) + log["v"] = nx.from_numpy(log["v"], type_as=Xb) return nx.from_numpy(res, type_as=Xa), log else: - return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa) + return nx.from_numpy( + cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa + ) diff --git a/setup.py b/setup.py index c0f75ad0f..9417ee33e 100644 --- a/setup.py +++ b/setup.py @@ -21,19 +21,20 @@ # dirty but working __version__ = re.search( r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]', # It excludes inline comment too - open('ot/__init__.py').read()).group(1) + open("ot/__init__.py").read(), +).group(1) # The beautiful part is, I don't even need to check exceptions here. # If something messes up, let the build process fail noisy, BEFORE my release! 
# thanks PyPI for handling markdown now ROOT = os.path.abspath(os.path.dirname(__file__)) -with open(os.path.join(ROOT, 'README.md'), encoding="utf-8") as f: +with open(os.path.join(ROOT, "README.md"), encoding="utf-8") as f: README = f.read() # clean cython output is clean is called -if 'clean' in sys.argv[1:]: - if os.path.isfile('ot/lp/emd_wrap.cpp'): - os.remove('ot/lp/emd_wrap.cpp') +if "clean" in sys.argv[1:]: + if os.path.isfile("ot/lp/emd_wrap.cpp"): + os.remove("ot/lp/emd_wrap.cpp") # add platform dependant optional compilation argument openmp_supported, flags = check_openmp_support() @@ -44,75 +45,93 @@ compile_args += flags link_args += flags -if sys.platform.startswith('darwin'): +if sys.platform.startswith("darwin"): compile_args.append("-stdlib=libc++") - sdk_path = subprocess.check_output(['xcrun', '--show-sdk-path']) - os.environ['CFLAGS'] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8")) + sdk_path = subprocess.check_output(["xcrun", "--show-sdk-path"]) + os.environ["CFLAGS"] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8")) setup( - name='POT', + name="POT", version=__version__, - description='Python Optimal Transport Library', + description="Python Optimal Transport Library", long_description=README, - long_description_content_type='text/markdown', - author=u'Remi Flamary, Nicolas Courty, POT Contributors', - author_email='remi.flamary@gmail.com, ncourty@gmail.com', - url='https://github.com/PythonOT/POT', + long_description_content_type="text/markdown", + author="Remi Flamary, Nicolas Courty, POT Contributors", + author_email="remi.flamary@gmail.com, ncourty@gmail.com", + url="https://github.com/PythonOT/POT", packages=find_packages(exclude=["benchmarks"]), - ext_modules=cythonize(Extension( - name="ot.lp.emd_wrap", - sources=["ot/lp/emd_wrap.pyx", "ot/lp/EMD_wrapper.cpp"], # cython/c++ src files - language="c++", - include_dirs=[numpy.get_include(), os.path.join(ROOT, 'ot/lp')], - extra_compile_args=compile_args, - extra_link_args=link_args - )), - platforms=['linux', 'macosx', 'windows'], - download_url='https://github.com/PythonOT/POT/archive/{}.tar.gz'.format(__version__), - license='MIT', + ext_modules=cythonize( + Extension( + name="ot.lp.emd_wrap", + sources=[ + "ot/lp/emd_wrap.pyx", + "ot/lp/EMD_wrapper.cpp", + ], # cython/c++ src files + language="c++", + include_dirs=[numpy.get_include(), os.path.join(ROOT, "ot/lp")], + extra_compile_args=compile_args, + extra_link_args=link_args, + ) + ), + platforms=["linux", "macosx", "windows"], + download_url="https://github.com/PythonOT/POT/archive/{}.tar.gz".format( + __version__ + ), + license="MIT", scripts=[], data_files=[], install_requires=["numpy>=1.16", "scipy>=1.6"], extras_require={ - 'backend-numpy': [], # in requirements. - 'backend-jax': ['jax', 'jaxlib'], - 'backend-cupy': [], # should be installed with conda, not pip - 'backend-tf': ['tensorflow'], - 'backend-torch': ['torch'], - 'cvxopt': ['cvxopt'], # on it's own to prevent accidental GPL violations - 'dr': ['scikit-learn', 'pymanopt', 'autograd'], - 'gnn': ['torch', 'torch_geometric'], - 'plot': ['matplotlib'], - 'all': ['jax', 'jaxlib', 'tensorflow', 'torch', 'cvxopt', 'scikit-learn', 'pymanopt', 'autograd', 'torch_geometric', 'matplotlib'] + "backend-numpy": [], # in requirements. 
+ "backend-jax": ["jax", "jaxlib"], + "backend-cupy": [], # should be installed with conda, not pip + "backend-tf": ["tensorflow"], + "backend-torch": ["torch"], + "cvxopt": ["cvxopt"], # on it's own to prevent accidental GPL violations + "dr": ["scikit-learn", "pymanopt", "autograd"], + "gnn": ["torch", "torch_geometric"], + "plot": ["matplotlib"], + "all": [ + "jax", + "jaxlib", + "tensorflow", + "torch", + "cvxopt", + "scikit-learn", + "pymanopt", + "autograd", + "torch_geometric", + "matplotlib", + ], }, python_requires=">=3.7", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: MIT License', - 'Environment :: Console', - 'Operating System :: OS Independent', - 'Operating System :: POSIX :: Linux', - 'Operating System :: MacOS', - 'Operating System :: POSIX', - 'Operating System :: Microsoft :: Windows', - 'Programming Language :: Python', - 'Programming Language :: C++', - 'Programming Language :: C', - 'Programming Language :: Cython', - 'Topic :: Utilities', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Scientific/Engineering :: Information Analysis', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - ] + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Environment :: Console", + "Operating System :: OS Independent", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS", + "Operating System :: POSIX", + "Operating System :: Microsoft :: Windows", + "Programming Language :: Python", + "Programming Language :: C++", + "Programming Language :: C", + "Programming Language :: Cython", + "Topic :: Utilities", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Scientific/Engineering :: Information Analysis", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + ], ) diff --git a/test/conftest.py b/test/conftest.py index 043c8ca70..732fb7e27 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -7,19 +7,27 @@ import functools import os import pytest +import numpy as np +from sys import platform + +# set numpy print options : TODO update tests when all release use modern numpy +if platform == "linux": + np.set_printoptions(legacy="1.25") from ot.backend import get_backend_list, jax, tf if jax: - os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' + os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" from jax import config + config.update("jax_enable_x64", True) if tf: # make sure TF doesn't allocate entire GPU import tensorflow as tf - physical_devices = tf.config.list_physical_devices('GPU') + + physical_devices = tf.config.list_physical_devices("GPU") for device in physical_devices: try: 
tf.config.experimental.set_memory_growth(device, True) @@ -28,6 +36,7 @@ # allow numpy API for TF from tensorflow.python.ops.numpy_ops import np_config + np_config.enable_numpy_behavior() @@ -45,12 +54,12 @@ def skip_arg(arg, value, reason=None, getter=lambda x: x): if isinstance(arg, (tuple, list)): n = len(arg) else: - arg = (arg, ) + arg = (arg,) n = 1 if n != 1 and isinstance(value, (tuple, list)): pass else: - value = (value, ) + value = (value,) if isinstance(getter, (tuple, list)): pass else: @@ -60,7 +69,6 @@ def skip_arg(arg, value, reason=None, getter=lambda x: x): reason = f"Param {arg} should be skipped for value {value}" def wrapper(function): - @functools.wraps(function) def wrapped(*args, **kwargs): if all( diff --git a/test/gromov/test_bregman.py b/test/gromov/test_bregman.py index 71e55b1ce..9cb4a629f 100644 --- a/test/gromov/test_bregman.py +++ b/test/gromov/test_bregman.py @@ -1,4 +1,4 @@ -""" Tests for gromov._bregman.py """ +"""Tests for gromov._bregman.py""" # Author: Rémi Flamary # Titouan Vayer @@ -14,11 +14,14 @@ @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") -@pytest.mark.parametrize('loss_fun', [ - 'square_loss', - 'kl_loss', - pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), -]) +@pytest.mark.parametrize( + "loss_fun", + [ + "square_loss", + "kl_loss", + pytest.param("unknown_loss", marks=pytest.mark.xfail(raises=ValueError)), + ], +) def test_entropic_gromov(nx, loss_fun): n_samples = 10 # nb samples @@ -41,28 +44,50 @@ def test_entropic_gromov(nx, loss_fun): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G, log = ot.gromov.entropic_gromov_wasserstein( - C1, C2, None, q, loss_fun, symmetric=None, G0=G0, - epsilon=1e-2, max_iter=10, verbose=True, log=True) - Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, None, loss_fun, symmetric=True, G0=None, - epsilon=1e-2, max_iter=10, verbose=True, log=False - )) + C1, + C2, + None, + q, + loss_fun, + symmetric=None, + G0=G0, + epsilon=1e-2, + max_iter=10, + verbose=True, + log=True, + ) + Gb = nx.to_numpy( + ot.gromov.entropic_gromov_wasserstein( + C1b, + C2b, + pb, + None, + loss_fun, + symmetric=True, + G0=None, + epsilon=1e-2, + max_iter=10, + verbose=True, + log=False, + ) + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") -@pytest.mark.parametrize('loss_fun', [ - 'square_loss', - 'kl_loss', - pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), -]) +@pytest.mark.parametrize( + "loss_fun", + [ + "square_loss", + "kl_loss", + pytest.param("unknown_loss", marks=pytest.mark.xfail(raises=ValueError)), + ], +) def test_entropic_gromov2(nx, loss_fun): n_samples = 10 # nb samples @@ -85,25 +110,41 @@ def test_entropic_gromov2(nx, loss_fun): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) gw, log = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, None, loss_fun, symmetric=True, G0=None, - max_iter=10, epsilon=1e-2, log=True) + C1, + C2, + p, + None, + loss_fun, + 
symmetric=True, + G0=None, + max_iter=10, + epsilon=1e-2, + log=True, + ) gwb, logb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, None, qb, loss_fun, symmetric=None, G0=G0b, - max_iter=10, epsilon=1e-2, log=True) + C1b, + C2b, + None, + qb, + loss_fun, + symmetric=None, + G0=G0b, + max_iter=10, + epsilon=1e-2, + log=True, + ) gwb = nx.to_numpy(gwb) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + G = log["T"] + Gb = nx.to_numpy(logb["T"]) np.testing.assert_allclose(gw, gwb, atol=1e-06) np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov @pytest.skip_backend("tf", reason="test very slow with tf backend") @@ -129,53 +170,108 @@ def test_entropic_proximal_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) with pytest.raises(ValueError): - loss_fun = 'weird_loss_fun' + loss_fun = "weird_loss_fun" G, log = ot.gromov.entropic_gromov_wasserstein( - C1, C2, None, q, loss_fun, symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=True, numItermax=1) + C1, + C2, + None, + q, + loss_fun, + symmetric=None, + G0=G0, + epsilon=1e-1, + max_iter=10, + solver="PPA", + verbose=True, + log=True, + numItermax=1, + ) G, log = ot.gromov.entropic_gromov_wasserstein( - C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=True, numItermax=1) - Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None, - epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=False, numItermax=1 - )) + C1, + C2, + None, + q, + "square_loss", + symmetric=None, + G0=G0, + epsilon=1e-1, + max_iter=10, + solver="PPA", + verbose=True, + log=True, + numItermax=1, + ) + Gb = nx.to_numpy( + ot.gromov.entropic_gromov_wasserstein( + C1b, + C2b, + pb, + None, + "square_loss", + symmetric=True, + G0=None, + epsilon=1e-1, + max_iter=10, + solver="PPA", + verbose=True, + log=False, + numItermax=1, + ) + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-02) # cf convergence gromov gw, log = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, - max_iter=10, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + C1, + C2, + p, + q, + "kl_loss", + symmetric=True, + G0=None, + max_iter=10, + epsilon=1e-1, + solver="PPA", + warmstart=True, + log=True, + ) gwb, logb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=10, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + C1b, + C2b, + pb, + qb, + "kl_loss", + symmetric=None, + G0=G0b, + max_iter=10, + epsilon=1e-1, + solver="PPA", + warmstart=True, + log=True, + ) gwb = nx.to_numpy(gwb) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + G = log["T"] + Gb = nx.to_numpy(logb["T"]) np.testing.assert_allclose(gw, gwb, atol=1e-06) 
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-02) # cf convergence gromov @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_asymmetric_entropic_gromov(nx): n_samples = 10 # nb samples rng = np.random.RandomState(0) - C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + C1 = rng.uniform(low=0.0, high=10, size=(n_samples, n_samples)) idx = np.arange(n_samples) rng.shuffle(idx) C2 = C1[idx, :][:, idx] @@ -186,25 +282,62 @@ def test_asymmetric_entropic_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G = ot.gromov.entropic_gromov_wasserstein( - C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=5, verbose=True, log=False) - Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None, - epsilon=1e-1, max_iter=5, verbose=True, log=False - )) + C1, + C2, + p, + q, + "square_loss", + symmetric=None, + G0=G0, + epsilon=1e-1, + max_iter=5, + verbose=True, + log=False, + ) + Gb = nx.to_numpy( + ot.gromov.entropic_gromov_wasserstein( + C1b, + C2b, + pb, + qb, + "square_loss", + symmetric=False, + G0=None, + epsilon=1e-1, + max_iter=5, + verbose=True, + log=False, + ) + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, None, None, 'kl_loss', symmetric=False, G0=None, - max_iter=5, epsilon=1e-1, log=False) + C1, + C2, + None, + None, + "kl_loss", + symmetric=False, + G0=None, + max_iter=5, + epsilon=1e-1, + log=False, + ) gwb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=5, epsilon=1e-1, log=False) + C1b, + C2b, + pb, + qb, + "kl_loss", + symmetric=None, + G0=G0b, + max_iter=5, + epsilon=1e-1, + log=False, + ) gwb = nx.to_numpy(gwb) np.testing.assert_allclose(gw, gwb, atol=1e-06) @@ -238,17 +371,21 @@ def test_entropic_gromov_dtype_device(nx): C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q, type_as=tp) - for solver in ['PGD', 'PPA', 'BAPG']: - if solver == 'BAPG': + for solver in ["PGD", "PPA", "BAPG"]: + if solver == "BAPG": Gb = ot.gromov.BAPG_gromov_wasserstein( - C1b, C2b, pb, qb, max_iter=2, verbose=True) + C1b, C2b, pb, qb, max_iter=2, verbose=True + ) gw_valb = ot.gromov.BAPG_gromov_wasserstein2( - C1b, C2b, pb, qb, max_iter=2, verbose=True) + C1b, C2b, pb, qb, max_iter=2, verbose=True + ) else: Gb = ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, qb, max_iter=2, solver=solver, verbose=True) + C1b, C2b, pb, qb, max_iter=2, solver=solver, verbose=True + ) gw_valb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, max_iter=2, solver=solver, verbose=True) + C1b, C2b, pb, qb, max_iter=2, solver=solver, verbose=True + ) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -278,66 +415,143 @@ def 
test_BAPG_gromov(nx): # complete test with marginal loss = True marginal_loss = True with pytest.raises(ValueError): - loss_fun = 'weird_loss_fun' + loss_fun = "weird_loss_fun" G, log = ot.gromov.BAPG_gromov_wasserstein( - C1, C2, None, q, loss_fun, symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, - verbose=True, log=True) + C1, + C2, + None, + q, + loss_fun, + symmetric=None, + G0=G0, + epsilon=1e-1, + max_iter=10, + marginal_loss=marginal_loss, + verbose=True, + log=True, + ) G, log = ot.gromov.BAPG_gromov_wasserstein( - C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, - verbose=True, log=True) - Gb = nx.to_numpy(ot.gromov.BAPG_gromov_wasserstein( - C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True, - log=False - )) + C1, + C2, + None, + q, + "square_loss", + symmetric=None, + G0=G0, + epsilon=1e-1, + max_iter=10, + marginal_loss=marginal_loss, + verbose=True, + log=True, + ) + Gb = nx.to_numpy( + ot.gromov.BAPG_gromov_wasserstein( + C1b, + C2b, + pb, + None, + "square_loss", + symmetric=True, + G0=None, + epsilon=1e-1, + max_iter=10, + marginal_loss=marginal_loss, + verbose=True, + log=False, + ) + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-02) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-02) # cf convergence gromov with pytest.warns(UserWarning): - gw = ot.gromov.BAPG_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, - max_iter=10, epsilon=1e-2, marginal_loss=marginal_loss, log=False) + C1, + C2, + p, + q, + "kl_loss", + symmetric=False, + G0=None, + max_iter=10, + epsilon=1e-2, + marginal_loss=marginal_loss, + log=False, + ) gw, log = ot.gromov.BAPG_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, - max_iter=10, epsilon=1., marginal_loss=marginal_loss, log=True) + C1, + C2, + p, + q, + "kl_loss", + symmetric=False, + G0=None, + max_iter=10, + epsilon=1.0, + marginal_loss=marginal_loss, + log=True, + ) gwb, logb = ot.gromov.BAPG_gromov_wasserstein2( - C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=10, epsilon=1., marginal_loss=marginal_loss, log=True) + C1b, + C2b, + pb, + qb, + "kl_loss", + symmetric=None, + G0=G0b, + max_iter=10, + epsilon=1.0, + marginal_loss=marginal_loss, + log=True, + ) gwb = nx.to_numpy(gwb) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + G = log["T"] + Gb = nx.to_numpy(logb["T"]) np.testing.assert_allclose(gw, gwb, atol=1e-06) np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-02) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-02) # cf convergence gromov marginal_loss = False G, log = ot.gromov.BAPG_gromov_wasserstein( - C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, - verbose=True, log=True) - Gb = nx.to_numpy(ot.gromov.BAPG_gromov_wasserstein( - C1b, C2b, pb, None, 'square_loss', 
symmetric=False, G0=None, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True, - log=False - )) + C1, + C2, + None, + q, + "square_loss", + symmetric=None, + G0=G0, + epsilon=1e-1, + max_iter=10, + marginal_loss=marginal_loss, + verbose=True, + log=True, + ) + Gb = nx.to_numpy( + ot.gromov.BAPG_gromov_wasserstein( + C1b, + C2b, + pb, + None, + "square_loss", + symmetric=False, + G0=None, + epsilon=1e-1, + max_iter=10, + marginal_loss=marginal_loss, + verbose=True, + log=False, + ) + ) @pytest.skip_backend("tf", reason="test very slow with tf backend") @@ -370,46 +584,96 @@ def test_entropic_fgw(nx): Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) with pytest.raises(ValueError): - loss_fun = 'weird_loss_fun' + loss_fun = "weird_loss_fun" G, log = ot.gromov.entropic_fused_gromov_wasserstein( - M, C1, C2, None, None, loss_fun, symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, verbose=True, log=True) + M, + C1, + C2, + None, + None, + loss_fun, + symmetric=None, + G0=G0, + epsilon=1e-1, + max_iter=10, + verbose=True, + log=True, + ) G, log = ot.gromov.entropic_fused_gromov_wasserstein( - M, C1, C2, None, None, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, verbose=True, log=True) - Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, - epsilon=1e-1, max_iter=10, verbose=True, log=False - )) + M, + C1, + C2, + None, + None, + "square_loss", + symmetric=None, + G0=G0, + epsilon=1e-1, + max_iter=10, + verbose=True, + log=True, + ) + Gb = nx.to_numpy( + ot.gromov.entropic_fused_gromov_wasserstein( + Mb, + C1b, + C2b, + pb, + qb, + "square_loss", + symmetric=True, + G0=None, + epsilon=1e-1, + max_iter=10, + verbose=True, + log=False, + ) + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov fgw, log = ot.gromov.entropic_fused_gromov_wasserstein2( - M, C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, - max_iter=10, epsilon=1e-1, log=True) + M, + C1, + C2, + p, + q, + "kl_loss", + symmetric=True, + G0=None, + max_iter=10, + epsilon=1e-1, + log=True, + ) fgwb, logb = ot.gromov.entropic_fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=10, epsilon=1e-1, log=True) + Mb, + C1b, + C2b, + pb, + qb, + "kl_loss", + symmetric=None, + G0=G0b, + max_iter=10, + epsilon=1e-1, + log=True, + ) fgwb = nx.to_numpy(fgwb) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + G = log["T"] + Gb = nx.to_numpy(logb["T"]) np.testing.assert_allclose(fgw, fgwb, atol=1e-06) np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov @pytest.skip_backend("tf", reason="test very slow with tf backend") @@ -442,40 +706,87 @@ def test_entropic_proximal_fgw(nx): Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) G, log = 
ot.gromov.entropic_fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=True, numItermax=1) - Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, - epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=False, numItermax=1 - )) + M, + C1, + C2, + p, + q, + "square_loss", + symmetric=None, + G0=G0, + epsilon=1e-1, + max_iter=10, + solver="PPA", + verbose=True, + log=True, + numItermax=1, + ) + Gb = nx.to_numpy( + ot.gromov.entropic_fused_gromov_wasserstein( + Mb, + C1b, + C2b, + pb, + qb, + "square_loss", + symmetric=True, + G0=None, + epsilon=1e-1, + max_iter=10, + solver="PPA", + verbose=True, + log=False, + numItermax=1, + ) + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov fgw, log = ot.gromov.entropic_fused_gromov_wasserstein2( - M, C1, C2, p, None, 'kl_loss', symmetric=True, G0=None, - max_iter=5, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + M, + C1, + C2, + p, + None, + "kl_loss", + symmetric=True, + G0=None, + max_iter=5, + epsilon=1e-1, + solver="PPA", + warmstart=True, + log=True, + ) fgwb, logb = ot.gromov.entropic_fused_gromov_wasserstein2( - Mb, C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=5, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + Mb, + C1b, + C2b, + None, + qb, + "kl_loss", + symmetric=None, + G0=G0b, + max_iter=5, + epsilon=1e-1, + solver="PPA", + warmstart=True, + log=True, + ) fgwb = nx.to_numpy(fgwb) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + G = log["T"] + Gb = nx.to_numpy(logb["T"]) np.testing.assert_allclose(fgw, fgwb, atol=1e-06) np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov def test_BAPG_fgw(nx): @@ -507,74 +818,149 @@ def test_BAPG_fgw(nx): Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) with pytest.raises(ValueError): - loss_fun = 'weird_loss_fun' + loss_fun = "weird_loss_fun" G, log = ot.gromov.BAPG_fused_gromov_wasserstein( - M, C1, C2, p, q, loss_fun=loss_fun, max_iter=1, log=True) + M, C1, C2, p, q, loss_fun=loss_fun, max_iter=1, log=True + ) # complete test with marginal loss = True marginal_loss = True G, log = ot.gromov.BAPG_fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, log=True) - Gb = nx.to_numpy(ot.gromov.BAPG_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True)) + M, + C1, + C2, + p, + q, + "square_loss", + symmetric=None, + G0=G0, + epsilon=1e-1, + max_iter=10, + marginal_loss=marginal_loss, + log=True, + ) + Gb = nx.to_numpy( + ot.gromov.BAPG_fused_gromov_wasserstein( + Mb, + C1b, + C2b, + pb, 
+ qb, + "square_loss", + symmetric=True, + G0=None, + epsilon=1e-1, + max_iter=10, + marginal_loss=marginal_loss, + verbose=True, + ) + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-02) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-02) # cf convergence gromov with pytest.warns(UserWarning): - fgw = ot.gromov.BAPG_fused_gromov_wasserstein2( - M, C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, - max_iter=10, epsilon=1e-3, marginal_loss=marginal_loss, log=False) + M, + C1, + C2, + p, + q, + "kl_loss", + symmetric=False, + G0=None, + max_iter=10, + epsilon=1e-3, + marginal_loss=marginal_loss, + log=False, + ) fgw, log = ot.gromov.BAPG_fused_gromov_wasserstein2( - M, C1, C2, p, None, 'kl_loss', symmetric=True, G0=None, - max_iter=5, epsilon=1, marginal_loss=marginal_loss, log=True) + M, + C1, + C2, + p, + None, + "kl_loss", + symmetric=True, + G0=None, + max_iter=5, + epsilon=1, + marginal_loss=marginal_loss, + log=True, + ) fgwb, logb = ot.gromov.BAPG_fused_gromov_wasserstein2( - Mb, C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=5, epsilon=1, marginal_loss=marginal_loss, log=True) + Mb, + C1b, + C2b, + None, + qb, + "kl_loss", + symmetric=None, + G0=G0b, + max_iter=5, + epsilon=1, + marginal_loss=marginal_loss, + log=True, + ) fgwb = nx.to_numpy(fgwb) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + G = log["T"] + Gb = nx.to_numpy(logb["T"]) np.testing.assert_allclose(fgw, fgwb, atol=1e-06) np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-02) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-02) # cf convergence gromov # Tests with marginal_loss = False marginal_loss = False G, log = ot.gromov.BAPG_fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', symmetric=False, G0=G0, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, log=True) - Gb = nx.to_numpy(ot.gromov.BAPG_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=None, G0=None, - epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True)) + M, + C1, + C2, + p, + q, + "square_loss", + symmetric=False, + G0=G0, + epsilon=1e-1, + max_iter=10, + marginal_loss=marginal_loss, + log=True, + ) + Gb = nx.to_numpy( + ot.gromov.BAPG_fused_gromov_wasserstein( + Mb, + C1b, + C2b, + pb, + qb, + "square_loss", + symmetric=None, + G0=None, + epsilon=1e-1, + max_iter=10, + marginal_loss=marginal_loss, + verbose=True, + ) + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-02) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-02) # cf convergence gromov def test_asymmetric_entropic_fgw(nx): n_samples = 5 # nb samples rng = np.random.RandomState(0) - C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + C1 = rng.uniform(low=0.0, high=10, size=(n_samples, 
n_samples)) idx = np.arange(n_samples) rng.shuffle(idx) C2 = C1[idx, :][:, idx] @@ -589,25 +975,66 @@ def test_asymmetric_entropic_fgw(nx): Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) G = ot.gromov.entropic_fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, - max_iter=5, epsilon=1e-1, verbose=True, log=False) - Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None, - max_iter=5, epsilon=1e-1, verbose=True, log=False - )) + M, + C1, + C2, + p, + q, + "square_loss", + symmetric=None, + G0=G0, + max_iter=5, + epsilon=1e-1, + verbose=True, + log=False, + ) + Gb = nx.to_numpy( + ot.gromov.entropic_fused_gromov_wasserstein( + Mb, + C1b, + C2b, + pb, + qb, + "square_loss", + symmetric=False, + G0=None, + max_iter=5, + epsilon=1e-1, + verbose=True, + log=False, + ) + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov fgw = ot.gromov.entropic_fused_gromov_wasserstein2( - M, C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, - max_iter=5, epsilon=1e-1, log=False) + M, + C1, + C2, + p, + q, + "kl_loss", + symmetric=False, + G0=None, + max_iter=5, + epsilon=1e-1, + log=False, + ) fgwb = ot.gromov.entropic_fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=5, epsilon=1e-1, log=False) + Mb, + C1b, + C2b, + pb, + qb, + "kl_loss", + symmetric=None, + G0=G0b, + max_iter=5, + epsilon=1e-1, + log=False, + ) fgwb = nx.to_numpy(fgwb) np.testing.assert_allclose(fgw, fgwb, atol=1e-06) @@ -646,18 +1073,22 @@ def test_entropic_fgw_dtype_device(nx): Mb, C1b, C2b, pb, qb = nx.from_numpy(M, C1, C2, p, q, type_as=tp) - for solver in ['PGD', 'PPA', 'BAPG']: - if solver == 'BAPG': + for solver in ["PGD", "PPA", "BAPG"]: + if solver == "BAPG": Gb = ot.gromov.BAPG_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, max_iter=2) + Mb, C1b, C2b, pb, qb, max_iter=2 + ) fgw_valb = ot.gromov.BAPG_fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, qb, max_iter=2) + Mb, C1b, C2b, pb, qb, max_iter=2 + ) else: Gb = ot.gromov.entropic_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, max_iter=2, solver=solver) + Mb, C1b, C2b, pb, qb, max_iter=2, solver=solver + ) fgw_valb = ot.gromov.entropic_fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, qb, max_iter=2, solver=solver) + Mb, C1b, C2b, pb, qb, max_iter=2, solver=solver + ) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, fgw_valb) @@ -668,8 +1099,8 @@ def test_entropic_fgw_barycenter(nx): nt = 10 rng = np.random.RandomState(42) - Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + Xs, ys = ot.datasets.make_data_classif("3gauss", ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif("3gauss2", nt, random_state=42) ys = rng.randn(Xs.shape[0], 2) yt = rng.randn(Xt.shape[0], 2) @@ -684,33 +1115,90 @@ def test_entropic_fgw_barycenter(nx): ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p) with pytest.raises(ValueError): - loss_fun = 'weird_loss_fun' + loss_fun = "weird_loss_fun" X, C, log = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ys, 
yt], [C1, C2], None, p, [.5, .5], loss_fun, 0.1, - max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42, - solver='PPA', numItermax=10, log=True, symmetric=True, + n_samples, + [ys, yt], + [C1, C2], + None, + p, + [0.5, 0.5], + loss_fun, + 0.1, + max_iter=10, + tol=1e-3, + verbose=True, + warmstartT=True, + random_state=42, + solver="PPA", + numItermax=10, + log=True, + symmetric=True, ) with pytest.raises(ValueError): - stop_criterion = 'unknown stop criterion' + stop_criterion = "unknown stop criterion" X, C, log = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], 'square_loss', - 0.1, max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=True, warmstartT=True, random_state=42, - solver='PPA', numItermax=10, log=True, symmetric=True, + n_samples, + [ys, yt], + [C1, C2], + None, + p, + [0.5, 0.5], + "square_loss", + 0.1, + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=True, + warmstartT=True, + random_state=42, + solver="PPA", + numItermax=10, + log=True, + symmetric=True, ) - for stop_criterion in ['barycenter', 'loss']: + for stop_criterion in ["barycenter", "loss"]: X, C, log = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], 'square_loss', - epsilon=0.1, max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=True, warmstartT=True, random_state=42, solver='PPA', - numItermax=10, log=True, symmetric=True + n_samples, + [ys, yt], + [C1, C2], + None, + p, + [0.5, 0.5], + "square_loss", + epsilon=0.1, + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=True, + warmstartT=True, + random_state=42, + solver="PPA", + numItermax=10, + log=True, + symmetric=True, ) Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], None, [.5, .5], - 'square_loss', epsilon=0.1, max_iter=10, tol=1e-3, - stop_criterion=stop_criterion, verbose=False, warmstartT=True, - random_state=42, solver='PPA', numItermax=10, log=False, symmetric=True) + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1b, p2b], + None, + [0.5, 0.5], + "square_loss", + epsilon=0.1, + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + warmstartT=True, + random_state=42, + solver="PPA", + numItermax=10, + log=False, + symmetric=True, + ) Xb, Cb = nx.to_numpy(Xb, Cb) np.testing.assert_allclose(C, Cb, atol=1e-06) @@ -730,23 +1218,59 @@ def test_entropic_fgw_barycenter(nx): init_Yb = nx.from_numpy(init_Y) X, C, log = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ys, yt], [C1, C2], [p1, p2], p, None, 'kl_loss', 0.1, True, - max_iter=10, tol=1e-3, verbose=False, warmstartT=False, random_state=42, - solver='PPA', numItermax=1, init_C=init_C, init_Y=init_Y, log=True + n_samples, + [ys, yt], + [C1, C2], + [p1, p2], + p, + None, + "kl_loss", + 0.1, + True, + max_iter=10, + tol=1e-3, + verbose=False, + warmstartT=False, + random_state=42, + solver="PPA", + numItermax=1, + init_C=init_C, + init_Y=init_Y, + log=True, ) Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss', - 0.1, True, max_iter=10, tol=1e-3, verbose=False, warmstartT=False, - random_state=42, solver='PPA', numItermax=1, init_C=init_Cb, - init_Y=init_Yb, log=True) + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1b, p2b], + pb, + [0.5, 0.5], + "kl_loss", + 0.1, + True, + max_iter=10, + tol=1e-3, + verbose=False, + warmstartT=False, + random_state=42, + solver="PPA", + 
numItermax=1, + init_C=init_Cb, + init_Y=init_Yb, + log=True, + ) Xb, Cb = nx.to_numpy(Xb, Cb) np.testing.assert_allclose(C, Cb, atol=1e-06) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) np.testing.assert_allclose(X, Xb, atol=1e-06) np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) - np.testing.assert_array_almost_equal(log['err_feature'], nx.to_numpy(*logb['err_feature'])) - np.testing.assert_array_almost_equal(log['err_structure'], nx.to_numpy(*logb['err_structure'])) + np.testing.assert_array_almost_equal( + log["err_feature"], nx.to_numpy(*logb["err_feature"]) + ) + np.testing.assert_array_almost_equal( + log["err_structure"], nx.to_numpy(*logb["err_structure"]) + ) # add tests with fixed_structures or fixed_features init_C = ot.utils.dist(xalea, xalea) @@ -757,35 +1281,76 @@ def test_entropic_fgw_barycenter(nx): init_Yb = nx.from_numpy(init_Y) fixed_structure, fixed_features = True, False - with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_structure=True`and `init_C=None` + with pytest.raises( + ot.utils.UndefinedParameter + ): # to raise an error when `fixed_structure=True`and `init_C=None` Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, - fixed_structure=fixed_structure, init_C=None, - fixed_features=fixed_features, p=None, max_iter=10, tol=1e-3 + n_samples, + [ysb, ytb], + [C1b, C2b], + ps=[p1b, p2b], + lambdas=None, + fixed_structure=fixed_structure, + init_C=None, + fixed_features=fixed_features, + p=None, + max_iter=10, + tol=1e-3, ) Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, - fixed_structure=fixed_structure, init_C=init_Cb, - fixed_features=fixed_features, max_iter=10, tol=1e-3 + n_samples, + [ysb, ytb], + [C1b, C2b], + ps=[p1b, p2b], + lambdas=None, + fixed_structure=fixed_structure, + init_C=init_Cb, + fixed_features=fixed_features, + max_iter=10, + tol=1e-3, ) Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(Cb, init_Cb) np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) fixed_structure, fixed_features = False, True - with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_features=True`and `init_X=None` + with pytest.raises( + ot.utils.UndefinedParameter + ): # to raise an error when `fixed_features=True`and `init_X=None` Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], lambdas=[.5, .5], - fixed_structure=fixed_structure, fixed_features=fixed_features, - init_Y=None, p=pb, max_iter=10, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1b, p2b], + lambdas=[0.5, 0.5], + fixed_structure=fixed_structure, + fixed_features=fixed_features, + init_Y=None, + p=pb, + max_iter=10, + tol=1e-3, + warmstartT=True, + log=True, + random_state=98765, + verbose=True, ) Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], lambdas=[.5, .5], - fixed_structure=fixed_structure, fixed_features=fixed_features, - init_Y=init_Yb, p=pb, max_iter=10, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1b, p2b], + lambdas=[0.5, 0.5], + fixed_structure=fixed_structure, + fixed_features=fixed_features, + init_Y=init_Yb, + p=pb, + max_iter=10, + tol=1e-3, + warmstartT=True, + log=True, + 
random_state=98765, + verbose=True, ) X, C = nx.to_numpy(Xb), nx.to_numpy(Cb) @@ -797,10 +1362,21 @@ def test_entropic_fgw_barycenter(nx): with pytest.raises(ValueError): C1_list = [list(c) for c in C1b] _, _, _ = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb], [C1_list], [p1b], lambdas=None, - fixed_structure=False, fixed_features=False, - init_Y=None, p=pb, max_iter=10, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True + n_samples, + [ysb], + [C1_list], + [p1b], + lambdas=None, + fixed_structure=False, + fixed_features=False, + init_Y=None, + p=pb, + max_iter=10, + tol=1e-3, + warmstartT=True, + log=True, + random_state=98765, + verbose=True, ) # p1, p2 as lists @@ -808,25 +1384,58 @@ def test_entropic_fgw_barycenter(nx): p1_list = list(p1b) p2_list = list(p2b) _, _, _ = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1_list, p2_list], lambdas=[0.5, 0.5], - fixed_structure=False, fixed_features=False, - init_Y=None, p=pb, max_iter=10, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1_list, p2_list], + lambdas=[0.5, 0.5], + fixed_structure=False, + fixed_features=False, + init_Y=None, + p=pb, + max_iter=10, + tol=1e-3, + warmstartT=True, + log=True, + random_state=98765, + verbose=True, ) # unique input structure X, C = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ys], [C1], [p1], lambdas=None, - fixed_structure=False, fixed_features=False, - init_Y=init_Y, p=p, max_iter=10, tol=1e-3, - warmstartT=True, log=False, random_state=98765, verbose=True + n_samples, + [ys], + [C1], + [p1], + lambdas=None, + fixed_structure=False, + fixed_features=False, + init_Y=init_Y, + p=p, + max_iter=10, + tol=1e-3, + warmstartT=True, + log=False, + random_state=98765, + verbose=True, ) Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( - n_samples, [ysb], [C1b], [p1b], lambdas=None, - fixed_structure=False, fixed_features=False, - init_Y=init_Yb, p=pb, max_iter=10, tol=1e-3, - warmstartT=True, log=False, random_state=98765, verbose=True + n_samples, + [ysb], + [C1b], + [p1b], + lambdas=None, + fixed_structure=False, + fixed_features=False, + init_Y=init_Yb, + p=pb, + max_iter=10, + tol=1e-3, + warmstartT=True, + log=False, + random_state=98765, + verbose=True, ) np.testing.assert_allclose(C, Cb, atol=1e-06) @@ -838,8 +1447,8 @@ def test_gromov_entropic_barycenter(nx): ns = 5 nt = 10 - Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + Xs, ys = ot.datasets.make_data_classif("3gauss", ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif("3gauss2", nt, random_state=42) C1 = ot.dist(Xs) C2 = ot.dist(Xt) @@ -851,55 +1460,135 @@ def test_gromov_entropic_barycenter(nx): C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) with pytest.raises(ValueError): - loss_fun = 'weird_loss_fun' + loss_fun = "weird_loss_fun" Cb = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], None, p, [.5, .5], loss_fun, 1e-3, - max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 + n_samples, + [C1, C2], + None, + p, + [0.5, 0.5], + loss_fun, + 1e-3, + max_iter=10, + tol=1e-3, + verbose=True, + warmstartT=True, + random_state=42, ) with pytest.raises(ValueError): - stop_criterion = 'unknown stop criterion' + stop_criterion = "unknown stop criterion" Cb = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], None, p, [.5, .5], 
'square_loss', 1e-3, - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=True, warmstartT=True, random_state=42 + n_samples, + [C1, C2], + None, + p, + [0.5, 0.5], + "square_loss", + 1e-3, + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=True, + warmstartT=True, + random_state=42, ) Cb = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', 1e-3, - max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 + n_samples, + [C1, C2], + None, + p, + [0.5, 0.5], + "square_loss", + 1e-3, + max_iter=10, + tol=1e-3, + verbose=True, + warmstartT=True, + random_state=42, + ) + Cbb = nx.to_numpy( + ot.gromov.entropic_gromov_barycenters( + n_samples, + [C1b, C2b], + [p1b, p2b], + None, + [0.5, 0.5], + "square_loss", + 1e-3, + max_iter=10, + tol=1e-3, + verbose=True, + warmstartT=True, + random_state=42, + ) ) - Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', 1e-3, - max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 - )) np.testing.assert_allclose(Cb, Cbb, atol=1e-06) np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) # test of entropic_gromov_barycenters with `log` on - for stop_criterion in ['barycenter', 'loss']: + for stop_criterion in ["barycenter", "loss"]: Cb_, err_ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, None, 'square_loss', 1e-3, - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, verbose=True, - random_state=42, log=True + n_samples, + [C1, C2], + [p1, p2], + p, + None, + "square_loss", + 1e-3, + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=True, + random_state=42, + log=True, ) Cbb_, errb_ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'square_loss', - 1e-3, max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=True, random_state=42, log=True + n_samples, + [C1b, C2b], + [p1b, p2b], + pb, + [0.5, 0.5], + "square_loss", + 1e-3, + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=True, + random_state=42, + log=True, ) Cbb_ = nx.to_numpy(Cbb_) np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) - np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) + np.testing.assert_array_almost_equal(err_["err"], nx.to_numpy(*errb_["err"])) np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) Cb2 = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', 1e-3, max_iter=10, tol=1e-3, random_state=42 + n_samples, + [C1, C2], + [p1, p2], + p, + [0.5, 0.5], + "kl_loss", + 1e-3, + max_iter=10, + tol=1e-3, + random_state=42, + ) + Cb2b = nx.to_numpy( + ot.gromov.entropic_gromov_barycenters( + n_samples, + [C1b, C2b], + [p1b, p2b], + pb, + [0.5, 0.5], + "kl_loss", + 1e-3, + max_iter=10, + tol=1e-3, + random_state=42, + ) ) - Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', 1e-3, max_iter=10, tol=1e-3, random_state=42 - )) np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) @@ -912,18 +1601,40 @@ def test_gromov_entropic_barycenter(nx): init_Cb = nx.from_numpy(init_C) Cb2_, err2_ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', 1e-3, - max_iter=10, tol=1e-3, warmstartT=True, verbose=True, random_state=42, - init_C=init_C, log=True + 
n_samples, + [C1, C2], + [p1, p2], + p, + [0.5, 0.5], + "kl_loss", + 1e-3, + max_iter=10, + tol=1e-3, + warmstartT=True, + verbose=True, + random_state=42, + init_C=init_C, + log=True, ) Cb2b_, err2b_ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', 1e-3, max_iter=10, tol=1e-3, warmstartT=True, verbose=True, - random_state=42, init_Cb=init_Cb, log=True + n_samples, + [C1b, C2b], + [p1b, p2b], + pb, + [0.5, 0.5], + "kl_loss", + 1e-3, + max_iter=10, + tol=1e-3, + warmstartT=True, + verbose=True, + random_state=42, + init_Cb=init_Cb, + log=True, ) Cb2b_ = nx.to_numpy(Cb2b_) np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) - np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) + np.testing.assert_array_almost_equal(err2_["err"], nx.to_numpy(*err2b_["err"])) np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) # test edge cases for gw barycenters: @@ -931,9 +1642,20 @@ def test_gromov_entropic_barycenter(nx): with pytest.raises(ValueError): C1_list = [list(c) for c in C1b] _, _ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1_list], [p1b], pb, None, 'square_loss', 1e-3, - max_iter=10, tol=1e-3, warmstartT=True, verbose=True, - random_state=42, init_C=None, log=True + n_samples, + [C1_list], + [p1b], + pb, + None, + "square_loss", + 1e-3, + max_iter=10, + tol=1e-3, + warmstartT=True, + verbose=True, + random_state=42, + init_C=None, + log=True, ) # p1, p2 as lists @@ -941,21 +1663,55 @@ def test_gromov_entropic_barycenter(nx): p1_list = list(p1b) p2_list = list(p2b) _, _ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1b, C2b], [p1_list, p2_list], pb, None, - 'kl_loss', 1e-3, max_iter=10, tol=1e-3, warmstartT=True, - verbose=True, random_state=42, init_Cb=None, log=True + n_samples, + [C1b, C2b], + [p1_list, p2_list], + pb, + None, + "kl_loss", + 1e-3, + max_iter=10, + tol=1e-3, + warmstartT=True, + verbose=True, + random_state=42, + init_Cb=None, + log=True, ) # unique input structure Cb = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1], [p1], p, None, 'square_loss', 1e-3, - max_iter=10, tol=1e-3, warmstartT=True, verbose=True, random_state=42, - init_C=None, log=False) + n_samples, + [C1], + [p1], + p, + None, + "square_loss", + 1e-3, + max_iter=10, + tol=1e-3, + warmstartT=True, + verbose=True, + random_state=42, + init_C=None, + log=False, + ) Cbb = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1b], [p1b], pb, [1.], 'square_loss', 1e-3, - max_iter=10, tol=1e-3, warmstartT=True, verbose=True, - random_state=42, init_Cb=None, log=False + n_samples, + [C1b], + [p1b], + pb, + [1.0], + "square_loss", + 1e-3, + max_iter=10, + tol=1e-3, + warmstartT=True, + verbose=True, + random_state=42, + init_Cb=None, + log=False, ) np.testing.assert_allclose(Cb, Cbb, atol=1e-06) @@ -984,11 +1740,13 @@ def test_not_implemented_solver(): C2 /= C2.max() M = ot.dist(ys, yt) - solver = 'not_implemented' + solver = "not_implemented" # entropic gw and fgw with pytest.raises(ValueError): ot.gromov.entropic_gromov_wasserstein( - C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver) + C1, C2, p, q, "square_loss", epsilon=1e-1, solver=solver + ) with pytest.raises(ValueError): ot.gromov.entropic_fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver) + M, C1, C2, p, q, "square_loss", epsilon=1e-1, solver=solver + ) diff --git a/test/gromov/test_dictionary.py b/test/gromov/test_dictionary.py index 5b73b0d07..305810dbf 100644 --- 
a/test/gromov/test_dictionary.py +++ b/test/gromov/test_dictionary.py @@ -1,4 +1,4 @@ -""" Tests for gromov._dictionary.py """ +"""Tests for gromov._dictionary.py""" # Author: Cédric Vincent-Cuaz # @@ -12,8 +12,8 @@ def test_gromov_wasserstein_linear_unmixing(nx): n = 4 - X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) - X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) + X1, y1 = ot.datasets.make_data_classif("3gauss", n, random_state=42) + X2, y2 = ot.datasets.make_data_classif("3gauss2", n, random_state=42) C1 = ot.dist(X1) C2 = ot.dist(X2) @@ -22,87 +22,168 @@ def test_gromov_wasserstein_linear_unmixing(nx): C1b, C2b, Cdictb, pb = nx.from_numpy(C1, C2, Cdict, p) - tol = 10**(-5) + tol = 10 ** (-5) # Tests without regularization - reg = 0. - unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( - C1, Cdict, reg=reg, p=p, q=p, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + reg = 0.0 + unmixing1, C1_emb, OT, reconstruction1 = ( + ot.gromov.gromov_wasserstein_linear_unmixing( + C1, + Cdict, + reg=reg, + p=p, + q=p, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=20, + max_iter_inner=200, + ) ) - unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( - C1b, Cdictb, reg=reg, p=None, q=None, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + unmixing1b, C1b_emb, OTb, reconstruction1b = ( + ot.gromov.gromov_wasserstein_linear_unmixing( + C1b, + Cdictb, + reg=reg, + p=None, + q=None, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=20, + max_iter_inner=200, + ) ) - unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( - C2, Cdict, reg=reg, p=None, q=None, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + unmixing2, C2_emb, OT, reconstruction2 = ( + ot.gromov.gromov_wasserstein_linear_unmixing( + C2, + Cdict, + reg=reg, + p=None, + q=None, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=20, + max_iter_inner=200, + ) ) - unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( - C2b, Cdictb, reg=reg, p=pb, q=pb, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + unmixing2b, C2b_emb, OTb, reconstruction2b = ( + ot.gromov.gromov_wasserstein_linear_unmixing( + C2b, + Cdictb, + reg=reg, + p=pb, + q=pb, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=20, + max_iter_inner=200, + ) ) np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=5e-06) - np.testing.assert_allclose(unmixing1, [1., 0.], atol=5e-01) + np.testing.assert_allclose(unmixing1, [1.0, 0.0], atol=5e-01) np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=5e-06) - np.testing.assert_allclose(unmixing2, [0., 1.], atol=5e-01) + np.testing.assert_allclose(unmixing2, [0.0, 1.0], atol=5e-01) np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) - np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) - np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) + np.testing.assert_allclose( + reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06 + ) + np.testing.assert_allclose( + reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06 + ) np.testing.assert_allclose(C1b_emb.shape, (n, n)) np.testing.assert_allclose(C2b_emb.shape, (n, n)) # Tests with regularization reg = 0.001 - unmixing1, C1_emb, 
OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( - C1, Cdict, reg=reg, p=p, q=p, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + unmixing1, C1_emb, OT, reconstruction1 = ( + ot.gromov.gromov_wasserstein_linear_unmixing( + C1, + Cdict, + reg=reg, + p=p, + q=p, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=20, + max_iter_inner=200, + ) ) - unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( - C1b, Cdictb, reg=reg, p=None, q=None, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + unmixing1b, C1b_emb, OTb, reconstruction1b = ( + ot.gromov.gromov_wasserstein_linear_unmixing( + C1b, + Cdictb, + reg=reg, + p=None, + q=None, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=20, + max_iter_inner=200, + ) ) - unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( - C2, Cdict, reg=reg, p=None, q=None, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + unmixing2, C2_emb, OT, reconstruction2 = ( + ot.gromov.gromov_wasserstein_linear_unmixing( + C2, + Cdict, + reg=reg, + p=None, + q=None, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=20, + max_iter_inner=200, + ) ) - unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( - C2b, Cdictb, reg=reg, p=pb, q=pb, - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + unmixing2b, C2b_emb, OTb, reconstruction2b = ( + ot.gromov.gromov_wasserstein_linear_unmixing( + C2b, + Cdictb, + reg=reg, + p=pb, + q=pb, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=20, + max_iter_inner=200, + ) ) np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) - np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing1, [1.0, 0.0], atol=1e-01) np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) - np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(unmixing2, [0.0, 1.0], atol=1e-01) np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) - np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) - np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) + np.testing.assert_allclose( + reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06 + ) + np.testing.assert_allclose( + reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06 + ) np.testing.assert_allclose(C1b_emb.shape, (n, n)) np.testing.assert_allclose(C2b_emb.shape, (n, n)) def test_gromov_wasserstein_dictionary_learning(nx): - # create dataset composed from 2 structures which are repeated 5 times shape = 4 n_samples = 2 n_atoms = 2 - projection = 'nonnegative_symmetric' - X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) - X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) + projection = "nonnegative_symmetric" + X1, y1 = ot.datasets.make_data_classif("3gauss", shape, random_state=42) + X2, y2 = ot.datasets.make_data_classif("3gauss2", shape, random_state=42) C1 = ot.dist(X1) C2 = ot.dist(X2) - Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] + Cs = [C1.copy() for _ in range(n_samples // 2)] + [ + C2.copy() for _ in range(n_samples // 2) + ] ps = [ot.unif(shape) for _ in range(n_samples)] q = ot.unif(shape) @@ -110,11 +191,15 @@ def 
test_gromov_wasserstein_dictionary_learning(nx): # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. dataset_means = [C.mean() for C in Cs] rng = np.random.RandomState(0) - Cdict_init = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape)) + Cdict_init = rng.normal( + loc=np.mean(dataset_means), + scale=np.std(dataset_means), + size=(n_atoms, shape, shape), + ) - if projection == 'nonnegative_symmetric': + if projection == "nonnegative_symmetric": Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) - Cdict_init[Cdict_init < 0.] = 0. + Cdict_init[Cdict_init < 0.0] = 0.0 Csb = nx.from_numpy(*Cs) psb = nx.from_numpy(*ps) @@ -124,30 +209,58 @@ def test_gromov_wasserstein_dictionary_learning(nx): # > Compute initial reconstruction of samples on this random dictionary without backend use_adam_optimizer = True verbose = False - tol = 10**(-5) + tol = 10 ** (-5) epochs = 1 initial_total_reconstruction = 0 for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Cs[i], Cdict_init, p=ps[i], q=q, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Cs[i], + Cdict_init, + p=ps[i], + q=q, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) initial_total_reconstruction += reconstruction # > Learn the dictionary using this init Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning( - Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, - epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + Cs, + D=n_atoms, + nt=shape, + ps=ps, + q=q, + Cdict_init=Cdict_init, + epochs=epochs, + batch_size=2 * n_samples, + learning_rate=1.0, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, + projection=projection, + use_log=False, + use_adam_optimizer=use_adam_optimizer, + verbose=verbose, ) # > Compute reconstruction of samples on learned dictionary without backend total_reconstruction = 0 for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Cs[i], Cdict, p=None, q=None, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Cs[i], + Cdict, + p=None, + q=None, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) total_reconstruction += reconstruction @@ -156,17 +269,38 @@ def test_gromov_wasserstein_dictionary_learning(nx): # Test: Perform same experiments after going through backend Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning( - Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, - epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + Csb, + D=n_atoms, + nt=shape, + ps=None, + q=None, + Cdict_init=Cdict_initb, + epochs=epochs, + batch_size=n_samples, + learning_rate=1.0, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, + projection=projection, + use_log=False, + use_adam_optimizer=use_adam_optimizer, + verbose=verbose, ) # Compute reconstruction of samples on learned dictionary total_reconstruction_b = 0 for i in range(n_samples): _, _, _, 
reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Csb[i], Cdictb, p=psb[i], q=qb, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Csb[i], + Cdictb, + p=psb[i], + q=qb, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) total_reconstruction_b += reconstruction @@ -179,42 +313,88 @@ def test_gromov_wasserstein_dictionary_learning(nx): # Test: Perform same comparison without providing the initial dictionary being an optional input # knowing than the initialization scheme is the same than implemented to set the benchmarked initialization. Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( - Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, - epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose, - random_state=0 + Cs, + D=n_atoms, + nt=shape, + ps=None, + q=None, + Cdict_init=None, + epochs=epochs, + batch_size=n_samples, + learning_rate=1.0, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, + projection=projection, + use_log=False, + use_adam_optimizer=use_adam_optimizer, + verbose=verbose, + random_state=0, ) # > Compute reconstruction of samples on learned dictionary total_reconstruction_bis = 0 for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Cs[i], Cdict_bis, p=ps[i], q=q, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Cs[i], + Cdict_bis, + p=ps[i], + q=q, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) total_reconstruction_bis += reconstruction - np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) + np.testing.assert_allclose( + total_reconstruction_bis, total_reconstruction, atol=1e-05 + ) # Test: Same after going through backend Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( - Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None, - epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, - verbose=verbose, random_state=0 + Csb, + D=n_atoms, + nt=shape, + ps=psb, + q=qb, + Cdict_init=None, + epochs=epochs, + batch_size=n_samples, + learning_rate=1.0, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, + projection=projection, + use_log=False, + use_adam_optimizer=use_adam_optimizer, + verbose=verbose, + random_state=0, ) # > Compute reconstruction of samples on learned dictionary total_reconstruction_b_bis = 0 for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Csb[i], Cdictb_bis, p=None, q=None, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Csb[i], + Cdictb_bis, + p=None, + q=None, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) total_reconstruction_b_bis += reconstruction total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis) - np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) + np.testing.assert_allclose( + total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05 + ) np.testing.assert_allclose(Cdict_bis, 
nx.to_numpy(Cdictb_bis), atol=1e-03) # Test: Perform same comparison without providing the initial dictionary being an optional input @@ -225,18 +405,39 @@ def test_gromov_wasserstein_dictionary_learning(nx): use_log = True Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( - Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, - epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, - verbose=verbose, random_state=0, + Cs, + D=n_atoms, + nt=shape, + ps=ps, + q=q, + Cdict_init=Cdict, + epochs=epochs, + batch_size=n_samples, + learning_rate=10.0, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, + projection=projection, + use_log=use_log, + use_adam_optimizer=use_adam_optimizer, + verbose=verbose, + random_state=0, ) # > Compute reconstruction of samples on learned dictionary total_reconstruction_bis2 = 0 for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Cs[i], Cdict_bis2, p=ps[i], q=q, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Cs[i], + Cdict_bis2, + p=ps[i], + q=q, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) total_reconstruction_bis2 += reconstruction @@ -244,31 +445,53 @@ def test_gromov_wasserstein_dictionary_learning(nx): # Test: Same after going through backend Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( - Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb, - epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, - verbose=verbose, random_state=0, + Csb, + D=n_atoms, + nt=shape, + ps=psb, + q=qb, + Cdict_init=Cdictb, + epochs=epochs, + batch_size=n_samples, + learning_rate=10.0, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, + projection=projection, + use_log=use_log, + use_adam_optimizer=use_adam_optimizer, + verbose=verbose, + random_state=0, ) # > Compute reconstruction of samples on learned dictionary total_reconstruction_b_bis2 = 0 for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( - Csb[i], Cdictb_bis2, p=psb[i], q=qb, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Csb[i], + Cdictb_bis2, + p=psb[i], + q=qb, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) total_reconstruction_b_bis2 += reconstruction total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2) - np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05) + np.testing.assert_allclose( + total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05 + ) def test_fused_gromov_wasserstein_linear_unmixing(nx): - n = 4 - X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) - X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) - F, y = ot.datasets.make_data_classif('3gauss', n, random_state=42) + X1, y1 = ot.datasets.make_data_classif("3gauss", n, random_state=42) + X2, y2 = ot.datasets.make_data_classif("3gauss2", n, random_state=42) + F, y = ot.datasets.make_data_classif("3gauss", n, random_state=42) C1 = ot.dist(X1) C2 = ot.dist(X2) @@ -279,92 
+502,197 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx): C1b, C2b, Fb, Cdictb, Ydictb, pb = nx.from_numpy(C1, C2, F, Cdict, Ydict, p) # Tests without regularization - reg = 0. - - unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + reg = 0.0 + + unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ( + ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1, + F, + Cdict, + Ydict, + p=p, + q=p, + alpha=0.5, + reg=reg, + tol_outer=10 ** (-6), + tol_inner=10 ** (-6), + max_iter_outer=10, + max_iter_inner=50, + ) ) - unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ( + ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1b, + Fb, + Cdictb, + Ydictb, + p=None, + q=None, + alpha=0.5, + reg=reg, + tol_outer=10 ** (-6), + tol_inner=10 ** (-6), + max_iter_outer=10, + max_iter_inner=50, + ) ) - unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ( + ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2, + F, + Cdict, + Ydict, + p=None, + q=None, + alpha=0.5, + reg=reg, + tol_outer=10 ** (-6), + tol_inner=10 ** (-6), + max_iter_outer=10, + max_iter_inner=50, + ) ) - unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ( + ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2b, + Fb, + Cdictb, + Ydictb, + p=pb, + q=pb, + alpha=0.5, + reg=reg, + tol_outer=10 ** (-6), + tol_inner=10 ** (-6), + max_iter_outer=10, + max_iter_inner=50, + ) ) np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=4e-06) - np.testing.assert_allclose(unmixing1, [1., 0.], atol=4e-01) + np.testing.assert_allclose(unmixing1, [1.0, 0.0], atol=4e-01) np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=4e-06) - np.testing.assert_allclose(unmixing2, [0., 1.], atol=4e-01) + np.testing.assert_allclose(unmixing2, [0.0, 1.0], atol=4e-01) np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) - np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) - np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) + np.testing.assert_allclose( + reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06 + ) + np.testing.assert_allclose( + reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06 + ) np.testing.assert_allclose(C1b_emb.shape, (n, n)) np.testing.assert_allclose(C2b_emb.shape, (n, n)) # Tests with regularization reg = 0.001 - unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = 
ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ( + ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1, + F, + Cdict, + Ydict, + p=p, + q=p, + alpha=0.5, + reg=reg, + tol_outer=10 ** (-6), + tol_inner=10 ** (-6), + max_iter_outer=10, + max_iter_inner=50, + ) ) - unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ( + ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1b, + Fb, + Cdictb, + Ydictb, + p=None, + q=None, + alpha=0.5, + reg=reg, + tol_outer=10 ** (-6), + tol_inner=10 ** (-6), + max_iter_outer=10, + max_iter_inner=50, + ) ) - unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ( + ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2, + F, + Cdict, + Ydict, + p=None, + q=None, + alpha=0.5, + reg=reg, + tol_outer=10 ** (-6), + tol_inner=10 ** (-6), + max_iter_outer=10, + max_iter_inner=50, + ) ) - unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 + unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ( + ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2b, + Fb, + Cdictb, + Ydictb, + p=pb, + q=pb, + alpha=0.5, + reg=reg, + tol_outer=10 ** (-6), + tol_inner=10 ** (-6), + max_iter_outer=10, + max_iter_inner=50, + ) ) np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) - np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing1, [1.0, 0.0], atol=1e-01) np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) - np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(unmixing2, [0.0, 1.0], atol=1e-01) np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) - np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) - np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) + np.testing.assert_allclose( + reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06 + ) + np.testing.assert_allclose( + reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06 + ) np.testing.assert_allclose(C1b_emb.shape, (n, n)) np.testing.assert_allclose(C2b_emb.shape, (n, n)) def test_fused_gromov_wasserstein_dictionary_learning(nx): - # create dataset composed from 2 structures which are repeated 5 times shape = 4 n_samples = 2 n_atoms = 2 - projection = 'nonnegative_symmetric' - X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) - X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) - F, y = 
ot.datasets.make_data_classif('3gauss', shape, random_state=42) + projection = "nonnegative_symmetric" + X1, y1 = ot.datasets.make_data_classif("3gauss", shape, random_state=42) + X2, y2 = ot.datasets.make_data_classif("3gauss2", shape, random_state=42) + F, y = ot.datasets.make_data_classif("3gauss", shape, random_state=42) C1 = ot.dist(X1) C2 = ot.dist(X2) - Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] + Cs = [C1.copy() for _ in range(n_samples // 2)] + [ + C2.copy() for _ in range(n_samples // 2) + ] Ys = [F.copy() for _ in range(n_samples)] ps = [ot.unif(shape) for _ in range(n_samples)] q = ot.unif(shape) @@ -373,12 +701,20 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. dataset_structure_means = [C.mean() for C in Cs] rng = np.random.RandomState(0) - Cdict_init = rng.normal(loc=np.mean(dataset_structure_means), scale=np.std(dataset_structure_means), size=(n_atoms, shape, shape)) - if projection == 'nonnegative_symmetric': + Cdict_init = rng.normal( + loc=np.mean(dataset_structure_means), + scale=np.std(dataset_structure_means), + size=(n_atoms, shape, shape), + ) + if projection == "nonnegative_symmetric": Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) - Cdict_init[Cdict_init < 0.] = 0. + Cdict_init[Cdict_init < 0.0] = 0.0 dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys]) - Ydict_init = rng.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2)) + Ydict_init = rng.normal( + loc=dataset_feature_means.mean(axis=0), + scale=dataset_feature_means.std(axis=0), + size=(n_atoms, shape, 2), + ) Csb = nx.from_numpy(*Cs) Ysb = nx.from_numpy(*Ys) @@ -395,25 +731,63 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): initial_total_reconstruction = 0 for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Cs[i], Ys[i], Cdict_init, Ydict_init, p=ps[i], q=q, - alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Cs[i], + Ys[i], + Cdict_init, + Ydict_init, + p=ps[i], + q=q, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) initial_total_reconstruction += reconstruction # > Learn a dictionary using this given initialization and check that the reconstruction loss # on the learned dictionary is lower than the one using its initialization. 
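The hunks above reformat repeated calls to ot.gromov.gromov_wasserstein_linear_unmixing and ot.gromov.fused_gromov_wasserstein_linear_unmixing without changing their behaviour. As a point of reference, here is a minimal standalone sketch (not part of the patch) of the fused unmixing call pattern those tests exercise; the toy data and the two-atom dictionary built from the observed structures are assumptions for illustration only, while the keyword arguments mirror the ones used in the tests.

import numpy as np
import ot

n = 4
X1, _ = ot.datasets.make_data_classif("3gauss", n, random_state=42)
X2, _ = ot.datasets.make_data_classif("3gauss2", n, random_state=42)
F, _ = ot.datasets.make_data_classif("3gauss", n, random_state=42)
C1, C2 = ot.dist(X1), ot.dist(X2)
p = ot.unif(n)

# Hypothetical dictionary: the two observed structures as atoms, sharing the
# same feature matrix (illustration only).
Cdict = np.stack([C1, C2])
Ydict = np.stack([F, F])

unmixing, C_emb, Y_emb, T, reconstruction = (
    ot.gromov.fused_gromov_wasserstein_linear_unmixing(
        C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=0.0,
        tol_outer=1e-6, tol_inner=1e-6, max_iter_outer=10, max_iter_inner=50,
    )
)
# unmixing holds the simplex weights over the atoms; reconstruction is the
# fused Gromov-Wasserstein reconstruction error of (C1, F) on the dictionary.

The same call pattern, minus the feature-related arguments (F, Ydict, alpha), applies to the non-fused gromov_wasserstein_linear_unmixing used earlier in the file.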
Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( - Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init, - epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + Cs, + Ys, + D=n_atoms, + nt=shape, + ps=ps, + q=q, + Cdict_init=Cdict_init, + Ydict_init=Ydict_init, + epochs=epochs, + batch_size=n_samples, + learning_rate_C=1.0, + learning_rate_Y=1.0, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, + projection=projection, + use_log=False, + use_adam_optimizer=use_adam_optimizer, + verbose=verbose, ) # > Compute reconstruction of samples on learned dictionary total_reconstruction = 0 for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Cs[i], Ys[i], Cdict, Ydict, p=None, q=None, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Cs[i], + Ys[i], + Cdict, + Ydict, + p=None, + q=None, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) total_reconstruction += reconstruction # Compare both @@ -421,18 +795,46 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): # Test: Perform same experiments after going through backend Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( - Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb, - epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose, - random_state=0 + Csb, + Ysb, + D=n_atoms, + nt=shape, + ps=None, + q=None, + Cdict_init=Cdict_initb, + Ydict_init=Ydict_initb, + epochs=epochs, + batch_size=2 * n_samples, + learning_rate_C=1.0, + learning_rate_Y=1.0, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, + projection=projection, + use_log=False, + use_adam_optimizer=use_adam_optimizer, + verbose=verbose, + random_state=0, ) # > Compute reconstruction of samples on learned dictionary total_reconstruction_b = 0 for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Csb[i], Ysb[i], Cdictb, Ydictb, p=psb[i], q=qb, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Csb[i], + Ysb[i], + Cdictb, + Ydictb, + p=psb[i], + q=qb, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) total_reconstruction_b += reconstruction @@ -444,43 +846,105 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): # Test: Perform similar experiment without providing the initial dictionary being an optional input Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( - Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, - epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose, 
- random_state=0 + Cs, + Ys, + D=n_atoms, + nt=shape, + ps=None, + q=None, + Cdict_init=None, + Ydict_init=None, + epochs=epochs, + batch_size=n_samples, + learning_rate_C=1.0, + learning_rate_Y=1.0, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, + projection=projection, + use_log=False, + use_adam_optimizer=use_adam_optimizer, + verbose=verbose, + random_state=0, ) # > Compute reconstruction of samples on learned dictionary total_reconstruction_bis = 0 for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Cs[i], Ys[i], Cdict_bis, Ydict_bis, p=ps[i], q=q, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Cs[i], + Ys[i], + Cdict_bis, + Ydict_bis, + p=ps[i], + q=q, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) total_reconstruction_bis += reconstruction - np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) + np.testing.assert_allclose( + total_reconstruction_bis, total_reconstruction, atol=1e-05 + ) # > Same after going through backend - Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( - Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, - epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose, - random_state=0, + Cdictb_bis, Ydictb_bis, log = ( + ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, + Ysb, + D=n_atoms, + nt=shape, + ps=None, + q=None, + Cdict_init=None, + Ydict_init=None, + epochs=epochs, + batch_size=n_samples, + learning_rate_C=1.0, + learning_rate_Y=1.0, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, + projection=projection, + use_log=False, + use_adam_optimizer=use_adam_optimizer, + verbose=verbose, + random_state=0, + ) ) # > Compute reconstruction of samples on learned dictionary total_reconstruction_b_bis = 0 for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Csb[i], Ysb[i], Cdictb_bis, Ydictb_bis, p=psb[i], q=qb, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Csb[i], + Ysb[i], + Cdictb_bis, + Ydictb_bis, + p=psb[i], + q=qb, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) total_reconstruction_b_bis += reconstruction total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis) - np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) + np.testing.assert_allclose( + total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05 + ) # Test: without using adam optimizer, with log and verbose set to True use_adam_optimizer = False @@ -488,42 +952,104 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): use_log = True # > Experiment providing previously estimated dictionary to speed up the test compared to providing initial random init. 
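The fused_gromov_wasserstein_dictionary_learning calls being reformatted above all follow one pattern: learn D (structure, feature) atoms from a list of graphs, then score the samples on the learned atoms with the unmixing routine. A minimal standalone sketch of that pattern follows (not part of the patch); the hyperparameter values are illustrative choices, not the ones the test uses.

import ot

shape, n_samples, n_atoms = 4, 2, 2
X1, _ = ot.datasets.make_data_classif("3gauss", shape, random_state=42)
X2, _ = ot.datasets.make_data_classif("3gauss2", shape, random_state=42)
F, _ = ot.datasets.make_data_classif("3gauss", shape, random_state=42)
Cs = [ot.dist(X1), ot.dist(X2)]            # structure matrices of the dataset
Ys = [F.copy() for _ in range(n_samples)]  # node features of the dataset
ps = [ot.unif(shape) for _ in range(n_samples)]
q = ot.unif(shape)

Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
    Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q,
    Cdict_init=None, Ydict_init=None,  # random init, as when no dictionary is given
    epochs=1, batch_size=n_samples,
    learning_rate_C=1.0, learning_rate_Y=1.0,
    alpha=0.5, reg=0.0,
    tol_outer=1e-5, tol_inner=1e-5, max_iter_outer=10, max_iter_inner=50,
    projection="nonnegative_symmetric",
    use_log=True, use_adam_optimizer=True, verbose=False, random_state=0,
)
# Cdict has shape (n_atoms, nt, nt) and Ydict (n_atoms, nt, n_features); the
# reconstruction error of each sample on (Cdict, Ydict) can then be computed
# with fused_gromov_wasserstein_linear_unmixing, as the tests do.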
- Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( - Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict, - epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, - verbose=verbose, random_state=0, + Cdict_bis2, Ydict_bis2, log = ( + ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, + Ys, + D=n_atoms, + nt=shape, + ps=ps, + q=q, + Cdict_init=Cdict, + Ydict_init=Ydict, + epochs=epochs, + batch_size=n_samples, + learning_rate_C=10.0, + learning_rate_Y=10.0, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, + projection=projection, + use_log=use_log, + use_adam_optimizer=use_adam_optimizer, + verbose=verbose, + random_state=0, + ) ) # > Compute reconstruction of samples on learned dictionary total_reconstruction_bis2 = 0 for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Cs[i], Ys[i], Cdict_bis2, Ydict_bis2, p=ps[i], q=q, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Cs[i], + Ys[i], + Cdict_bis2, + Ydict_bis2, + p=ps[i], + q=q, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) total_reconstruction_bis2 += reconstruction np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) # > Same after going through backend - Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( - Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb, - epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, - projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose, - random_state=0, + Cdictb_bis2, Ydictb_bis2, log = ( + ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, + Ysb, + D=n_atoms, + nt=shape, + ps=None, + q=None, + Cdict_init=Cdictb, + Ydict_init=Ydictb, + epochs=epochs, + batch_size=n_samples, + learning_rate_C=10.0, + learning_rate_Y=10.0, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, + projection=projection, + use_log=use_log, + use_adam_optimizer=use_adam_optimizer, + verbose=verbose, + random_state=0, + ) ) # > Compute reconstruction of samples on learned dictionary total_reconstruction_b_bis2 = 0 for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( - Csb[i], Ysb[i], Cdictb_bis2, Ydictb_bis2, p=None, q=None, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 + Csb[i], + Ysb[i], + Cdictb_bis2, + Ydictb_bis2, + p=None, + q=None, + alpha=alpha, + reg=0.0, + tol_outer=tol, + tol_inner=tol, + max_iter_outer=10, + max_iter_inner=50, ) total_reconstruction_b_bis2 += reconstruction # > Compare results with/without backend total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2) - np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05) + np.testing.assert_allclose( + total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05 + ) diff --git a/test/gromov/test_estimators.py 
b/test/gromov/test_estimators.py index ead427204..2bf8afe9f 100644 --- a/test/gromov/test_estimators.py +++ b/test/gromov/test_estimators.py @@ -1,4 +1,4 @@ -""" Tests for gromov._estimators.py """ +"""Tests for gromov._estimators.py""" # Author: Rémi Flamary # Tanguy Kerdoncuff @@ -41,27 +41,47 @@ def lossb(x, y): return nx.abs(x - y) G, log = ot.gromov.pointwise_gromov_wasserstein( - C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42) + C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42 + ) G = NumpyBackend().todense(G) Gb, logb = ot.gromov.pointwise_gromov_wasserstein( - C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42) + C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42 + ) Gb = nx.to_numpy(nx.todense(Gb)) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.0, atol=1e-08) - np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0, atol=1e-08) + np.testing.assert_allclose(float(logb["gw_dist_estimated"]), 0.0, atol=1e-08) + np.testing.assert_allclose(float(logb["gw_dist_std"]), 0.0, atol=1e-08) G, log = ot.gromov.pointwise_gromov_wasserstein( - C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + C1, + C2, + p, + q, + loss, + max_iter=100, + alpha=0.1, + log=True, + verbose=True, + random_state=42, + ) G = NumpyBackend().todense(G) Gb, logb = ot.gromov.pointwise_gromov_wasserstein( - C1b, C2b, pb, qb, lossb, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + C1b, + C2b, + pb, + qb, + lossb, + max_iter=100, + alpha=0.1, + log=True, + verbose=True, + random_state=42, + ) Gb = nx.to_numpy(nx.todense(Gb)) np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -97,14 +117,34 @@ def lossb(x, y): return nx.abs(x - y) G, log = ot.gromov.sampled_gromov_wasserstein( - C1, C2, p, q, loss, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42) + C1, + C2, + p, + q, + loss, + max_iter=20, + nb_samples_grad=2, + epsilon=1, + log=True, + verbose=True, + random_state=42, + ) Gb, logb = ot.gromov.sampled_gromov_wasserstein( - C1b, C2b, pb, qb, lossb, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42) + C1b, + C2b, + pb, + qb, + lossb, + max_iter=20, + nb_samples_grad=2, + epsilon=1, + log=True, + verbose=True, + random_state=42, + ) Gb = nx.to_numpy(Gb) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov diff --git a/test/gromov/test_fugw.py b/test/gromov/test_fugw.py index 894da1c3b..0485715cf 100644 --- a/test/gromov/test_fugw.py +++ b/test/gromov/test_fugw.py @@ -4,25 +4,29 @@ # # License: MIT License - import itertools import numpy as np import ot import pytest -from ot.gromov._unbalanced import fused_unbalanced_gromov_wasserstein, 
fused_unbalanced_gromov_wasserstein2, fused_unbalanced_across_spaces_divergence +from ot.gromov._unbalanced import ( + fused_unbalanced_gromov_wasserstein, + fused_unbalanced_gromov_wasserstein2, + fused_unbalanced_across_spaces_divergence, +) @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence", itertools.product(["mm", "lbfgsb"], ["kl", "l2"])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence", itertools.product(["mm", "lbfgsb"], ["kl", "l2"]) +) def test_sanity(nx, unbalanced_solver, divergence): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() p = ot.unif(n_samples) @@ -53,19 +57,45 @@ def test_sanity(nx, unbalanced_solver, divergence): anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples pi_sample, pi_feature = fused_unbalanced_gromov_wasserstein( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=G0, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=G0, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx, pi_feature_nx = fused_unbalanced_gromov_wasserstein( - C1b, C2b, wx=pb, wy=qb, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp_nx, init_duals=None, init_pi=G0b, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1b, + C2b, + wx=pb, + wy=qb, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp_nx, + init_duals=None, + init_pi=G0b, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -77,19 +107,45 @@ def test_sanity(nx, unbalanced_solver, divergence): # test divergence fugw = fused_unbalanced_gromov_wasserstein2( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=G0, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=G0, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) fugw_nx = fused_unbalanced_gromov_wasserstein2( - C1b, C2b, wx=pb, wy=qb, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp_nx, init_duals=None, init_pi=G0b, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1b, + C2b, + wx=pb, + wy=qb, + reg_marginals=reg_m, + 
epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp_nx, + init_duals=None, + init_pi=G0b, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) fugw_nx = nx.to_numpy(fugw_nx) @@ -99,15 +155,19 @@ def test_sanity(nx, unbalanced_solver, divergence): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2] + ), +) def test_init_plans(nx, unbalanced_solver, divergence, eps): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() p = ot.unif(n_samples) @@ -131,19 +191,45 @@ def test_init_plans(nx, unbalanced_solver, divergence, eps): tol_ot = 1e-5 pi_sample, pi_feature = fused_unbalanced_gromov_wasserstein( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=G0, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=G0, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx, pi_feature_nx = fused_unbalanced_gromov_wasserstein( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -154,19 +240,45 @@ def test_init_plans(nx, unbalanced_solver, divergence, eps): # test divergence fugw = fused_unbalanced_gromov_wasserstein2( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=G0, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=G0, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) fugw_nx = fused_unbalanced_gromov_wasserstein2( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, 
init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) fugw_nx = nx.to_numpy(fugw_nx) @@ -175,15 +287,19 @@ def test_init_plans(nx, unbalanced_solver, divergence, eps): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2] + ), +) def test_init_duals(nx, unbalanced_solver, divergence, eps): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() p = ot.unif(n_samples) @@ -210,19 +326,45 @@ def test_init_duals(nx, unbalanced_solver, divergence, eps): tol_ot = 1e-5 pi_sample, pi_feature = fused_unbalanced_gromov_wasserstein( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx, pi_feature_nx = fused_unbalanced_gromov_wasserstein( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=init_duals, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=init_duals, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -232,19 +374,45 @@ def test_init_duals(nx, unbalanced_solver, divergence, eps): # test divergence fugw = fused_unbalanced_gromov_wasserstein2( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) fugw_nx = 
fused_unbalanced_gromov_wasserstein2( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=init_duals, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=init_duals, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) fugw_nx = nx.to_numpy(fugw_nx) @@ -253,15 +421,19 @@ def test_init_duals(nx, unbalanced_solver, divergence, eps): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2] + ), +) def test_reg_marginals(nx, unbalanced_solver, divergence, eps): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() p = ot.unif(n_samples) @@ -291,28 +463,67 @@ def test_reg_marginals(nx, unbalanced_solver, divergence, eps): list_options = [full_tuple_reg_m, tuple_reg_m, full_list_reg_m, list_reg_m] pi_sample, pi_feature = fused_unbalanced_gromov_wasserstein( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=G0, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=G0, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) fugw = fused_unbalanced_gromov_wasserstein2( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=G0, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=G0, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) for opt in list_options: pi_sample_nx, pi_feature_nx = fused_unbalanced_gromov_wasserstein( - C1, C2, wx=p, wy=q, reg_marginals=opt, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=opt, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + 
max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -322,11 +533,24 @@ def test_reg_marginals(nx, unbalanced_solver, divergence, eps): # test divergence fugw_nx = fused_unbalanced_gromov_wasserstein2( - C1, C2, wx=p, wy=q, reg_marginals=opt, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=opt, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) fugw_nx = nx.to_numpy(fugw_nx) @@ -335,15 +559,19 @@ def test_reg_marginals(nx, unbalanced_solver, divergence, eps): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2] + ), +) def test_log(nx, unbalanced_solver, divergence, eps): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() p = ot.unif(n_samples) @@ -366,19 +594,45 @@ def test_log(nx, unbalanced_solver, divergence, eps): tol_ot = 1e-5 pi_sample, pi_feature = fused_unbalanced_gromov_wasserstein( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx, pi_feature_nx, log = fused_unbalanced_gromov_wasserstein( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=True, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=True, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -389,19 +643,45 @@ def test_log(nx, unbalanced_solver, divergence, eps): # test divergence fugw = fused_unbalanced_gromov_wasserstein2( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, 
init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) fugw_nx, log = fused_unbalanced_gromov_wasserstein2( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=True, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=True, + verbose=False, ) fugw_nx = nx.to_numpy(fugw_nx) @@ -410,15 +690,19 @@ def test_log(nx, unbalanced_solver, divergence, eps): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2] + ), +) def test_marginals(nx, unbalanced_solver, divergence, eps): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() p = ot.unif(n_samples) @@ -441,19 +725,45 @@ def test_marginals(nx, unbalanced_solver, divergence, eps): tol_ot = 1e-5 pi_sample, pi_feature = fused_unbalanced_gromov_wasserstein( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx, pi_feature_nx = fused_unbalanced_gromov_wasserstein( - C1, C2, wx=None, wy=None, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=None, + wy=None, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -464,19 +774,45 @@ def test_marginals(nx, unbalanced_solver, divergence, eps): # test divergence fugw = 
fused_unbalanced_gromov_wasserstein2( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) fugw_nx = fused_unbalanced_gromov_wasserstein2( - C1, C2, wx=None, wy=None, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver=unbalanced_solver, - alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=None, + wy=None, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M=M_samp, + init_duals=None, + init_pi=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) fugw_nx = nx.to_numpy(fugw_nx) @@ -515,20 +851,46 @@ def test_raise_value_error(nx): # raise error of divergence def fugw_div(divergence): return fused_unbalanced_gromov_wasserstein( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver="mm", - alpha=0, M=None, init_duals=None, init_pi=G0, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver="mm", + alpha=0, + M=None, + init_duals=None, + init_pi=G0, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) def fugw_div_nx(divergence): return fused_unbalanced_gromov_wasserstein( - C1b, C2b, wx=pb, wy=qb, reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver="mm", - alpha=0, M=None, init_duals=None, init_pi=G0b, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1b, + C2b, + wx=pb, + wy=qb, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver="mm", + alpha=0, + M=None, + init_duals=None, + init_pi=G0b, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) np.testing.assert_raises(NotImplementedError, fugw_div, "div_not_existed") @@ -537,29 +899,61 @@ def fugw_div_nx(divergence): # raise error of solver def fugw_solver(unbalanced_solver): return fused_unbalanced_gromov_wasserstein( - C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, - divergence="kl", unbalanced_solver=unbalanced_solver, - alpha=0, M=None, init_duals=None, init_pi=G0, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1, + C2, + wx=p, + wy=q, + reg_marginals=reg_m, + epsilon=eps, + divergence="kl", + unbalanced_solver=unbalanced_solver, + alpha=0, + M=None, + init_duals=None, + init_pi=G0, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) def fugw_solver_nx(unbalanced_solver): return fused_unbalanced_gromov_wasserstein( - C1b, C2b, wx=pb, wy=qb, reg_marginals=reg_m, epsilon=eps, - divergence="kl", unbalanced_solver=unbalanced_solver, - alpha=0, M=None, 
init_duals=None, init_pi=G0b, max_iter=max_iter, - tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + C1b, + C2b, + wx=pb, + wy=qb, + reg_marginals=reg_m, + epsilon=eps, + divergence="kl", + unbalanced_solver=unbalanced_solver, + alpha=0, + M=None, + init_duals=None, + init_pi=G0b, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) np.testing.assert_raises(NotImplementedError, fugw_solver, "solver_not_existed") np.testing.assert_raises(NotImplementedError, fugw_solver_nx, "solver_not_existed") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) -def test_fused_unbalanced_across_spaces_divergence_wrong_reg_type(nx, unbalanced_solver, divergence, eps): - +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2] + ), +) +def test_fused_unbalanced_across_spaces_divergence_wrong_reg_type( + nx, unbalanced_solver, divergence, eps +): n = 100 rng = np.random.RandomState(42) x = rng.randn(n, 2) @@ -571,9 +965,13 @@ def test_fused_unbalanced_across_spaces_divergence_wrong_reg_type(nx, unbalanced def reg_type(reg_type): return fused_unbalanced_across_spaces_divergence( - X=x, Y=y, reg_marginals=reg_m, - epsilon=eps, reg_type=reg_type, - divergence=divergence, unbalanced_solver=unbalanced_solver + X=x, + Y=y, + reg_marginals=reg_m, + epsilon=eps, + reg_type=reg_type, + divergence=divergence, + unbalanced_solver=unbalanced_solver, ) np.testing.assert_raises(NotImplementedError, reg_type, "reg_type_not_existed") @@ -581,15 +979,24 @@ def reg_type(reg_type): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps, reg_type", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2], ["independent", "joint"])) -def test_fused_unbalanced_across_spaces_divergence_log(nx, unbalanced_solver, divergence, eps, reg_type): +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps, reg_type", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], + ["kl", "l2"], + [0, 1e-2], + ["independent", "joint"], + ), +) +def test_fused_unbalanced_across_spaces_divergence_log( + nx, unbalanced_solver, divergence, eps, reg_type +): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() px_s, px_f = ot.unif(n_samples), ot.unif(2) @@ -614,21 +1021,53 @@ def test_fused_unbalanced_across_spaces_divergence_log(nx, unbalanced_solver, di # test couplings pi_sample, pi_feature = fused_unbalanced_across_spaces_divergence( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, reg_type=reg_type, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + 
reg_type=reg_type, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=M_samp_nx, + M_feat=M_feat_nx, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx, pi_feature_nx, log = fused_unbalanced_across_spaces_divergence( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, reg_type=reg_type, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=True, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + reg_type=reg_type, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=M_samp_nx, + M_feat=M_feat_nx, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=True, + verbose=False, ) np.testing.assert_allclose(pi_sample_nx, pi_sample, atol=1e-06) @@ -644,8 +1083,7 @@ def test_fused_unbalanced_across_spaces_divergence_warning(nx, reg_type): mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() px_s, px_f = ot.unif(n_samples), ot.unif(2) @@ -674,12 +1112,28 @@ def test_fused_unbalanced_across_spaces_divergence_warning(nx, reg_type): def raise_warning(): return fused_unbalanced_across_spaces_divergence( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, reg_type=reg_type, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + reg_type=reg_type, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=M_samp_nx, + M_feat=M_feat_nx, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) np.testing.assert_raises(ValueError, raise_warning) diff --git a/test/gromov/test_gw.py b/test/gromov/test_gw.py index 4f3dff14b..d71500d20 100644 --- a/test/gromov/test_gw.py +++ b/test/gromov/test_gw.py @@ -1,4 +1,4 @@ -""" Tests for gromov._gw.py """ +"""Tests for gromov._gw.py""" # Author: Erwan Vautier # Nicolas Courty @@ -36,33 +36,50 @@ def test_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G = ot.gromov.gromov_wasserstein( - C1, C2, None, q, 'square_loss', G0=G0, verbose=True, - alpha_min=0., alpha_max=1.) 
- Gb = nx.to_numpy(ot.gromov.gromov_wasserstein( - C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True)) + C1, + C2, + None, + q, + "square_loss", + G0=G0, + verbose=True, + alpha_min=0.0, + alpha_max=1.0, + ) + Gb = nx.to_numpy( + ot.gromov.gromov_wasserstein( + C1b, C2b, pb, None, "square_loss", symmetric=True, G0=G0b, verbose=True + ) + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) for armijo in [False, True]: - gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=armijo, log=True) - gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=armijo, log=True) + gw, log = ot.gromov.gromov_wasserstein2( + C1, C2, None, q, "kl_loss", armijo=armijo, log=True + ) + gwb, logb = ot.gromov.gromov_wasserstein2( + C1b, C2b, pb, None, "kl_loss", armijo=armijo, log=True + ) gwb = nx.to_numpy(gwb) - gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=armijo, G0=G0, log=False) + gw_val = ot.gromov.gromov_wasserstein2( + C1, C2, p, q, "kl_loss", armijo=armijo, G0=G0, log=False + ) gw_valb = nx.to_numpy( - ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=armijo, G0=G0b, log=False) + ot.gromov.gromov_wasserstein2( + C1b, C2b, pb, qb, "kl_loss", armijo=armijo, G0=G0b, log=False + ) ) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + G = log["T"] + Gb = nx.to_numpy(logb["T"]) np.testing.assert_allclose(gw, gwb, atol=1e-06) np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1) @@ -72,16 +89,14 @@ def test_gromov(nx): # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov def test_asymmetric_gromov(nx): n_samples = 20 # nb samples rng = np.random.RandomState(0) - C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + C1 = rng.uniform(low=0.0, high=10, size=(n_samples, n_samples)) idx = np.arange(n_samples) rng.shuffle(idx) C2 = C1[idx, :][:, idx] @@ -92,33 +107,37 @@ def test_asymmetric_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - G, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True) - Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True) + G, log = ot.gromov.gromov_wasserstein( + C1, C2, p, q, "square_loss", G0=G0, log=True, symmetric=False, verbose=True + ) + Gb, logb = ot.gromov.gromov_wasserstein( + C1b, C2b, pb, qb, "square_loss", log=True, symmetric=False, G0=G0b, verbose=True + ) Gb = nx.to_numpy(Gb) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, 
Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(log["gw_dist"], 0.0, atol=1e-04) + np.testing.assert_allclose(logb["gw_dist"], 0.0, atol=1e-04) - gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True) - gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True) + gw, log = ot.gromov.gromov_wasserstein2( + C1, C2, p, q, "square_loss", G0=G0, log=True, symmetric=False, verbose=True + ) + gwb, logb = ot.gromov.gromov_wasserstein2( + C1b, C2b, pb, qb, "square_loss", log=True, symmetric=False, G0=G0b, verbose=True + ) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + G = log["T"] + Gb = nx.to_numpy(logb["T"]) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(log["gw_dist"], 0.0, atol=1e-04) + np.testing.assert_allclose(logb["gw_dist"], 0.0, atol=1e-04) def test_gromov_integer_warnings(nx): @@ -142,14 +161,25 @@ def test_gromov_integer_warnings(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G = ot.gromov.gromov_wasserstein( - C1, C2, None, q, 'square_loss', G0=G0, verbose=True, - alpha_min=0., alpha_max=1.) 
- Gb = nx.to_numpy(ot.gromov.gromov_wasserstein( - C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True)) + C1, + C2, + None, + q, + "square_loss", + G0=G0, + verbose=True, + alpha_min=0.0, + alpha_max=1.0, + ) + Gb = nx.to_numpy( + ot.gromov.gromov_wasserstein( + C1b, C2b, pb, None, "square_loss", symmetric=True, G0=G0b, verbose=True + ) + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(G, 0., atol=1e-09) + np.testing.assert_allclose(G, 0.0, atol=1e-09) def test_gromov_dtype_device(nx): @@ -179,9 +209,13 @@ def test_gromov_dtype_device(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp) with warnings.catch_warnings(): - warnings.filterwarnings('error') - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) + warnings.filterwarnings("error") + Gb = ot.gromov.gromov_wasserstein( + C1b, C2b, pb, qb, "square_loss", G0=G0b, verbose=True + ) + gw_valb = ot.gromov.gromov_wasserstein2( + C1b, C2b, pb, qb, "kl_loss", armijo=True, G0=G0b, log=False + ) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -206,16 +240,22 @@ def test_gromov_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) + Gb = ot.gromov.gromov_wasserstein( + C1b, C2b, pb, qb, "square_loss", G0=G0b, verbose=True + ) + gw_valb = ot.gromov.gromov_wasserstein2( + C1b, C2b, pb, qb, "kl_loss", armijo=True, G0=G0b, log=False + ) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) - if len(tf.config.list_physical_devices('GPU')) > 0: + if len(tf.config.list_physical_devices("GPU")) > 0: # Check that everything happens on the GPU C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=False) + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, "square_loss", verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2( + C1b, C2b, pb, qb, "kl_loss", armijo=True, log=False + ) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) assert nx.dtype_device(Gb)[1].startswith("GPU") @@ -241,12 +281,10 @@ def test_gromov2_gradients(): C2 /= C2.max() if torch: - devices = [torch.device("cpu")] if torch.cuda.is_available(): devices.append(torch.device("cuda")) for device in devices: - # classical gradients p1 = torch.tensor(p, requires_grad=True, device=device) q1 = torch.tensor(q, requires_grad=True, device=device) @@ -306,10 +344,14 @@ def test_gw_helper_backend(nx): C2 /= C2.max() C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', armijo=False, symmetric=True, G0=G0b, log=True) + Gb, logb = ot.gromov.gromov_wasserstein( + C1b, C2b, pb, qb, "square_loss", armijo=False, symmetric=True, G0=G0b, log=True + ) # calls with nx=None - constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss') + constCb, hC1b, hC2b = ot.gromov.init_matrix( + C1b, C2b, pb, qb, loss_fun="square_loss" + ) def f(G): return 
ot.gromov.gwloss(constCb, hC1b, hC2b, G, None) @@ -318,18 +360,37 @@ def df(G): return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None) def line_search(cost, G, deltaG, Mi, cost_G, df_G): - return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=0., reg=1., nx=None) + return ot.gromov.solve_gromov_linesearch( + G, deltaG, cost_G, C1b, C2b, M=0.0, reg=1.0, nx=None + ) + # feed the precomputed local optimum Gb to cg - res, log = ot.optim.cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + res, log = ot.optim.cg( + pb, + qb, + 0.0, + 1.0, + f, + df, + Gb, + line_search, + log=True, + numItermax=1e4, + stopThr=1e-9, + stopThr2=1e-9, + ) # check constraints np.testing.assert_allclose(res, Gb, atol=1e-06) -@pytest.mark.parametrize('loss_fun', [ - 'square_loss', - 'kl_loss', - pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), -]) +@pytest.mark.parametrize( + "loss_fun", + [ + "square_loss", + "kl_loss", + pytest.param("unknown_loss", marks=pytest.mark.xfail(raises=ValueError)), + ], +) def test_gw_helper_validation(loss_fun): n_samples = 10 # nb samples mu = np.array([0, 0]) @@ -347,8 +408,8 @@ def test_gromov_barycenter(nx): ns = 5 nt = 8 - Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + Xs, ys = ot.datasets.make_data_classif("3gauss", ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif("3gauss2", nt, random_state=42) C1 = ot.dist(Xs) C2 = ot.dist(Xt) @@ -359,51 +420,115 @@ def test_gromov_barycenter(nx): C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) with pytest.raises(ValueError): - stop_criterion = 'unknown stop criterion' + stop_criterion = "unknown stop criterion" Cb = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42 + n_samples, + [C1, C2], + None, + p, + [0.5, 0.5], + "square_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, ) - for stop_criterion in ['barycenter', 'loss']: + for stop_criterion in ["barycenter", "loss"]: Cb = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42 + n_samples, + [C1, C2], + None, + p, + [0.5, 0.5], + "square_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, + ) + Cbb = nx.to_numpy( + ot.gromov.gromov_barycenters( + n_samples, + [C1b, C2b], + [p1b, p2b], + None, + [0.5, 0.5], + "square_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, + ) ) - Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42 - )) np.testing.assert_allclose(Cb, Cbb, atol=1e-06) np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) # test of gromov_barycenters with `log` on Cb_, err_ = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, None, 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - warmstartT=True, random_state=42, log=True + n_samples, + [C1, C2], + [p1, p2], + p, + None, + "square_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + 
warmstartT=True, + random_state=42, + log=True, ) Cbb_, errb_ = ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, warmstartT=True, random_state=42, log=True + n_samples, + [C1b, C2b], + [p1b, p2b], + pb, + [0.5, 0.5], + "square_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + warmstartT=True, + random_state=42, + log=True, ) Cbb_ = nx.to_numpy(Cbb_) np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) - np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) + np.testing.assert_array_almost_equal(err_["err"], nx.to_numpy(*errb_["err"])) np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) Cb2 = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', max_iter=10, tol=1e-3, warmstartT=True, random_state=42 + n_samples, + [C1, C2], + [p1, p2], + p, + [0.5, 0.5], + "kl_loss", + max_iter=10, + tol=1e-3, + warmstartT=True, + random_state=42, + ) + Cb2b = nx.to_numpy( + ot.gromov.gromov_barycenters( + n_samples, + [C1b, C2b], + [p1b, p2b], + pb, + [0.5, 0.5], + "kl_loss", + max_iter=10, + tol=1e-3, + warmstartT=True, + random_state=42, + ) ) - Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', max_iter=10, tol=1e-3, warmstartT=True, random_state=42 - )) np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) @@ -416,17 +541,36 @@ def test_gromov_barycenter(nx): init_Cb = nx.from_numpy(init_C) Cb2_, err2_ = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', max_iter=10, - tol=1e-3, verbose=False, random_state=42, log=True, init_C=init_C + n_samples, + [C1, C2], + [p1, p2], + p, + [0.5, 0.5], + "kl_loss", + max_iter=10, + tol=1e-3, + verbose=False, + random_state=42, + log=True, + init_C=init_C, ) Cb2b_, err2b_ = ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss', - max_iter=10, tol=1e-3, verbose=True, random_state=42, - init_C=init_Cb, log=True + n_samples, + [C1b, C2b], + [p1b, p2b], + pb, + [0.5, 0.5], + "kl_loss", + max_iter=10, + tol=1e-3, + verbose=True, + random_state=42, + init_C=init_Cb, + log=True, ) Cb2b_ = nx.to_numpy(Cb2b_) np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) - np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) + np.testing.assert_array_almost_equal(err2_["err"], nx.to_numpy(*err2b_["err"])) np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) # test edge cases for gw barycenters: @@ -434,9 +578,17 @@ def test_gromov_barycenter(nx): with pytest.raises(ValueError): C1_list = [list(c) for c in C1] _ = ot.gromov.gromov_barycenters( - n_samples, [C1_list], None, p, None, 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42 + n_samples, + [C1_list], + None, + p, + None, + "square_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, ) # p1, p2 as lists @@ -444,22 +596,48 @@ def test_gromov_barycenter(nx): p1_list = list(p1) p2_list = list(p2) _ = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], [p1_list, p2_list], p, None, 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42 + n_samples, + [C1, C2], + [p1_list, p2_list], + p, + None, + "square_loss", + max_iter=10, + tol=1e-3, + 
stop_criterion=stop_criterion, + verbose=False, + random_state=42, ) # unique input structure Cb = ot.gromov.gromov_barycenters( - n_samples, [C1], None, p, None, 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42 - ) - Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( - n_samples, [C1b], None, None, [1.], 'square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42 - )) + n_samples, + [C1], + None, + p, + None, + "square_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, + ) + Cbb = nx.to_numpy( + ot.gromov.gromov_barycenters( + n_samples, + [C1b], + None, + None, + [1.0], + "square_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, + ) + ) np.testing.assert_allclose(Cb, Cbb, atol=1e-06) np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) @@ -493,50 +671,93 @@ def test_fgw(nx): Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, None, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, None, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True) + G, log = ot.gromov.fused_gromov_wasserstein( + M, + C1, + C2, + None, + q, + "square_loss", + alpha=0.5, + armijo=True, + symmetric=None, + G0=G0, + log=True, + ) + Gb, logb = ot.gromov.fused_gromov_wasserstein( + Mb, + C1b, + C2b, + pb, + None, + "square_loss", + alpha=0.5, + armijo=True, + symmetric=True, + G0=G0b, + log=True, + ) Gb = nx.to_numpy(Gb) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence fgw - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence fgw + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence fgw + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence fgw Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) - np.testing.assert_allclose( - Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov - fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, None, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True) - fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, None, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True) + fgw, log = ot.gromov.fused_gromov_wasserstein2( + M, + C1, + C2, + p, + None, + "square_loss", + armijo=True, + symmetric=True, + G0=None, + alpha=0.5, + log=True, + ) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2( + Mb, + C1b, + C2b, + None, + qb, + "square_loss", + armijo=True, + symmetric=None, + G0=G0b, + alpha=0.5, + log=True, + ) fgwb = nx.to_numpy(fgwb) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + G = log["T"] + Gb = nx.to_numpy(logb["T"]) np.testing.assert_allclose(fgw, fgwb, atol=1e-08) np.testing.assert_allclose(fgwb, 0, atol=1e-1, rtol=1e-1) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence 
gromov def test_asymmetric_fgw(nx): n_samples = 20 # nb samples rng = np.random.RandomState(0) - C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + C1 = rng.uniform(low=0.0, high=10, size=(n_samples, n_samples)) idx = np.arange(n_samples) rng.shuffle(idx) C2 = C1[idx, :][:, idx] # add features - F1 = rng.uniform(low=0., high=10, size=(n_samples, 1)) + F1 = rng.uniform(low=0.0, high=10, size=(n_samples, 1)) F2 = F1[idx, :] p = ot.unif(n_samples) q = ot.unif(n_samples) @@ -546,90 +767,164 @@ def test_asymmetric_fgw(nx): Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) G, log = ot.gromov.fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, - symmetric=False, verbose=True) + M, + C1, + C2, + p, + q, + "square_loss", + alpha=0.5, + G0=G0, + log=True, + symmetric=False, + verbose=True, + ) Gb, logb = ot.gromov.fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, - symmetric=None, G0=G0b, verbose=True) + Mb, + C1b, + C2b, + pb, + qb, + "square_loss", + alpha=0.5, + log=True, + symmetric=None, + G0=G0b, + verbose=True, + ) Gb = nx.to_numpy(Gb) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(log["fgw_dist"], 0.0, atol=1e-04) + np.testing.assert_allclose(logb["fgw_dist"], 0.0, atol=1e-04) fgw, log = ot.gromov.fused_gromov_wasserstein2( - M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, - symmetric=None, verbose=True) + M, + C1, + C2, + p, + q, + "square_loss", + alpha=0.5, + G0=G0, + log=True, + symmetric=None, + verbose=True, + ) fgwb, logb = ot.gromov.fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, - symmetric=False, G0=G0b, verbose=True) + Mb, + C1b, + C2b, + pb, + qb, + "square_loss", + alpha=0.5, + log=True, + symmetric=False, + G0=G0b, + verbose=True, + ) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + G = log["T"] + Gb = nx.to_numpy(logb["T"]) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(log["fgw_dist"], 0.0, atol=1e-04) + np.testing.assert_allclose(logb["fgw_dist"], 0.0, atol=1e-04) # Tests with kl-loss: for armijo in [False, True]: G, log = ot.gromov.fused_gromov_wasserstein( - M, C1, C2, p, q, 'kl_loss', alpha=0.5, armijo=armijo, G0=G0, - log=True, symmetric=False, verbose=True) + M, + C1, + C2, + p, + q, + "kl_loss", + alpha=0.5, + armijo=armijo, + G0=G0, + log=True, + symmetric=False, + verbose=True, + ) Gb, logb = ot.gromov.fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, armijo=armijo, - log=True, symmetric=None, G0=G0b, verbose=True) + Mb, + C1b, + 
C2b, + pb, + qb, + "kl_loss", + alpha=0.5, + armijo=armijo, + log=True, + symmetric=None, + G0=G0b, + verbose=True, + ) Gb = nx.to_numpy(Gb) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(log["fgw_dist"], 0.0, atol=1e-04) + np.testing.assert_allclose(logb["fgw_dist"], 0.0, atol=1e-04) fgw, log = ot.gromov.fused_gromov_wasserstein2( - M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, - symmetric=None, verbose=True) + M, + C1, + C2, + p, + q, + "kl_loss", + alpha=0.5, + G0=G0, + log=True, + symmetric=None, + verbose=True, + ) fgwb, logb = ot.gromov.fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, - symmetric=False, G0=G0b, verbose=True) + Mb, + C1b, + C2b, + pb, + qb, + "kl_loss", + alpha=0.5, + log=True, + symmetric=False, + G0=G0b, + verbose=True, + ) - G = log['T'] - Gb = nx.to_numpy(logb['T']) + G = log["T"] + Gb = nx.to_numpy(logb["T"]) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose( - p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose( - q, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(q, Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04) - np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) + np.testing.assert_allclose(log["fgw_dist"], 0.0, atol=1e-04) + np.testing.assert_allclose(logb["fgw_dist"], 0.0, atol=1e-04) def test_fgw_integer_warnings(nx): n_samples = 20 # nb samples rng = np.random.RandomState(0) - C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + C1 = rng.uniform(low=0.0, high=10, size=(n_samples, n_samples)) idx = np.arange(n_samples) rng.shuffle(idx) C2 = C1[idx, :][:, idx] # add features - F1 = rng.uniform(low=0., high=10, size=(n_samples, 1)) + F1 = rng.uniform(low=0.0, high=10, size=(n_samples, 1)) F2 = F1[idx, :] p = ot.unif(n_samples) q = ot.unif(n_samples) @@ -639,15 +934,35 @@ def test_fgw_integer_warnings(nx): Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) G, log = ot.gromov.fused_gromov_wasserstein( - M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, - symmetric=False, verbose=True) + M, + C1, + C2, + p, + q, + "square_loss", + alpha=0.5, + G0=G0, + log=True, + symmetric=False, + verbose=True, + ) Gb, logb = ot.gromov.fused_gromov_wasserstein( - Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, - symmetric=None, G0=G0b, verbose=True) + Mb, + C1b, + C2b, + pb, + qb, + "square_loss", + alpha=0.5, + log=True, + symmetric=None, + G0=G0b, + verbose=True, + ) Gb = nx.to_numpy(Gb) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(G, 0., atol=1e-06) + np.testing.assert_allclose(G, 0.0, atol=1e-06) def test_fgw2_gradients(): @@ -671,7 +986,6 @@ def test_fgw2_gradients(): C2 /= C2.max() if torch: - devices = [torch.device("cpu")] if torch.cuda.is_available(): devices.append(torch.device("cuda")) @@ -740,10 
+1054,24 @@ def test_fgw_helper_backend(nx): Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) alpha = 0.5 - Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein( + Mb, + C1b, + C2b, + pb, + qb, + "square_loss", + alpha=0.5, + armijo=False, + symmetric=True, + G0=G0b, + log=True, + ) # calls with nx=None - constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss') + constCb, hC1b, hC2b = ot.gromov.init_matrix( + C1b, C2b, pb, qb, loss_fun="square_loss" + ) def f(G): return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None) @@ -752,14 +1080,44 @@ def df(G): return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None) def line_search(cost, G, deltaG, Mi, cost_G, df_G): - return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=(1 - alpha) * Mb, reg=alpha, nx=None) + return ot.gromov.solve_gromov_linesearch( + G, deltaG, cost_G, C1b, C2b, M=(1 - alpha) * Mb, reg=alpha, nx=None + ) + # feed the precomputed local optimum Gb to cg - res, log = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + res, log = ot.optim.cg( + pb, + qb, + (1 - alpha) * Mb, + alpha, + f, + df, + Gb, + line_search, + log=True, + numItermax=1e4, + stopThr=1e-9, + stopThr2=1e-9, + ) def line_search(cost, G, deltaG, Mi, cost_G, df_G): return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=None) + # feed the precomputed local optimum Gb to cg - res_armijo, log_armijo = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + res_armijo, log_armijo = ot.optim.cg( + pb, + qb, + (1 - alpha) * Mb, + alpha, + f, + df, + Gb, + line_search, + log=True, + numItermax=1e4, + stopThr=1e-9, + stopThr2=1e-9, + ) # check constraints np.testing.assert_allclose(res, Gb, atol=1e-06) np.testing.assert_allclose(res_armijo, Gb, atol=1e-06) @@ -769,8 +1127,8 @@ def test_fgw_barycenter(nx): ns = 10 nt = 20 - Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + Xs, ys = ot.datasets.make_data_classif("3gauss", ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif("3gauss2", nt, random_state=42) rng = np.random.RandomState(42) ys = rng.randn(Xs.shape[0], 2) @@ -786,19 +1144,32 @@ def test_fgw_barycenter(nx): p = ot.unif(n_samples) ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p) - lambdas = [.5, .5] + lambdas = [0.5, 0.5] Csb = [C1b, C2b] Ysb = [ysb, ytb] Xb, Cb, logb = ot.gromov.fgw_barycenters( - n_samples, Ysb, Csb, None, lambdas, 0.5, fixed_structure=False, - fixed_features=False, p=pb, loss_fun='square_loss', max_iter=10, tol=1e-3, - random_state=12345, log=True + n_samples, + Ysb, + Csb, + None, + lambdas, + 0.5, + fixed_structure=False, + fixed_features=False, + p=pb, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + random_state=12345, + log=True, ) # test correspondance with utils function recovered_Cb = ot.gromov.update_barycenter_structure( - logb['Ts_iter'][-1], Csb, lambdas, pb, target=False, check_zeros=True) + logb["Ts_iter"][-1], Csb, lambdas, pb, target=False, check_zeros=True + ) recovered_Xb = ot.gromov.update_barycenter_feature( - logb['Ts_iter'][-1], Ysb, lambdas, pb, target=False, check_zeros=True) + logb["Ts_iter"][-1], Ysb, lambdas, pb, target=False, 
check_zeros=True + ) np.testing.assert_allclose(Cb, recovered_Cb) np.testing.assert_allclose(Xb, recovered_Xb) @@ -808,17 +1179,39 @@ def test_fgw_barycenter(nx): init_C /= init_C.max() init_Cb = nx.from_numpy(init_C) - with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_structure=True`and `init_C=None` + with pytest.raises( + ot.utils.UndefinedParameter + ): # to raise an error when `fixed_structure=True`and `init_C=None` Xb, Cb = ot.gromov.fgw_barycenters( - n_samples, Ysb, Csb, ps=[p1b, p2b], lambdas=None, - alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False, - p=None, loss_fun='square_loss', max_iter=10, tol=1e-3 + n_samples, + Ysb, + Csb, + ps=[p1b, p2b], + lambdas=None, + alpha=0.5, + fixed_structure=True, + init_C=None, + fixed_features=False, + p=None, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, ) Xb, Cb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, - alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, - p=None, loss_fun='square_loss', max_iter=10, tol=1e-3 + n_samples, + [ysb, ytb], + [C1b, C2b], + ps=[p1b, p2b], + lambdas=None, + alpha=0.5, + fixed_structure=True, + init_C=init_Cb, + fixed_features=False, + p=None, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, ) Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) @@ -827,18 +1220,46 @@ def test_fgw_barycenter(nx): init_X = rng.randn(n_samples, ys.shape[1]) init_Xb = nx.from_numpy(init_X) - with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_features=True`and `init_X=None` + with pytest.raises( + ot.utils.UndefinedParameter + ): # to raise an error when `fixed_features=True`and `init_X=None` Xb, Cb, logb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=None, - p=pb, loss_fun='square_loss', max_iter=10, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=True, + init_X=None, + p=pb, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + warmstartT=True, + log=True, + random_state=98765, + verbose=True, ) Xb, Cb, logb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_Xb, - p=pb, loss_fun='square_loss', max_iter=10, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=True, + init_X=init_Xb, + p=pb, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + warmstartT=True, + log=True, + random_state=98765, + verbose=True, ) X, C = nx.to_numpy(Xb), nx.to_numpy(Cb) @@ -847,28 +1268,63 @@ def test_fgw_barycenter(nx): # add test with 'kl_loss' with pytest.raises(ValueError): - stop_criterion = 'unknown stop criterion' + stop_criterion = "unknown stop criterion" X, C, log = ot.gromov.fgw_barycenters( - n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=C, - init_X=X, warmstartT=True, random_state=12345, log=True + n_samples, + [ys, yt], + [C1, C2], + [p1, p2], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=False, + p=p, 
+ loss_fun="kl_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + init_C=C, + init_X=X, + warmstartT=True, + random_state=12345, + log=True, ) - for stop_criterion in ['barycenter', 'loss']: + for stop_criterion in ["barycenter", "loss"]: X, C, log = ot.gromov.fgw_barycenters( - n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=C, - init_X=X, warmstartT=True, random_state=12345, log=True, verbose=True + n_samples, + [ys, yt], + [C1, C2], + [p1, p2], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=False, + p=p, + loss_fun="kl_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + init_C=C, + init_X=X, + warmstartT=True, + random_state=12345, + log=True, + verbose=True, ) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) # test correspondance with utils function recovered_C = ot.gromov.update_barycenter_structure( - log['T'], [C1, C2], lambdas, p, loss_fun='kl_loss', - target=False, check_zeros=False) + log["T"], + [C1, C2], + lambdas, + p, + loss_fun="kl_loss", + target=False, + check_zeros=False, + ) np.testing.assert_allclose(C, recovered_C) @@ -877,10 +1333,25 @@ def test_fgw_barycenter(nx): with pytest.raises(ValueError): C1b_list = [list(c) for c in C1b] _, _, _ = ot.gromov.fgw_barycenters( - n_samples, [ysb], [C1b_list], [p1b], None, 0.5, - fixed_structure=False, fixed_features=False, p=pb, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=Cb, - init_X=Xb, warmstartT=True, random_state=12345, log=True, verbose=True + n_samples, + [ysb], + [C1b_list], + [p1b], + None, + 0.5, + fixed_structure=False, + fixed_features=False, + p=pb, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + init_C=Cb, + init_X=Xb, + warmstartT=True, + random_state=12345, + log=True, + verbose=True, ) # p1, p2 as lists @@ -888,24 +1359,65 @@ def test_fgw_barycenter(nx): p1_list = list(p1) p2_list = list(p2) _, _, _ = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1_list, p2_list], None, 0.5, - fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=Cb, - init_X=Xb, warmstartT=True, random_state=12345, log=True, verbose=True + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1_list, p2_list], + None, + 0.5, + fixed_structure=False, + fixed_features=False, + p=p, + loss_fun="kl_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + init_C=Cb, + init_X=Xb, + warmstartT=True, + random_state=12345, + log=True, + verbose=True, ) # unique input structure X, C = ot.gromov.fgw_barycenters( - n_samples, [ys], [C1], [p1], None, 0.5, - fixed_structure=False, fixed_features=False, p=p, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - warmstartT=True, random_state=12345, log=False, verbose=False + n_samples, + [ys], + [C1], + [p1], + None, + 0.5, + fixed_structure=False, + fixed_features=False, + p=p, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + warmstartT=True, + random_state=12345, + log=False, + verbose=False, ) Xb, Cb = ot.gromov.fgw_barycenters( - n_samples, [ysb], [C1b], [p1b], [1.], 0.5, - fixed_structure=False, fixed_features=False, p=pb, loss_fun='square_loss', - max_iter=10, tol=1e-3, 
stop_criterion=stop_criterion, - warmstartT=True, random_state=12345, log=False, verbose=False + n_samples, + [ysb], + [C1b], + [p1b], + [1.0], + 0.5, + fixed_structure=False, + fixed_features=False, + p=pb, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + warmstartT=True, + random_state=12345, + log=False, + verbose=False, ) np.testing.assert_allclose(C, Cb, atol=1e-06) diff --git a/test/gromov/test_lowrank.py b/test/gromov/test_lowrank.py index befc5c835..27e0fcdb0 100644 --- a/test/gromov/test_lowrank.py +++ b/test/gromov/test_lowrank.py @@ -1,4 +1,4 @@ -""" Tests for gromov._lowrank.py """ +"""Tests for gromov._lowrank.py""" # Author: Laurène DAVID # @@ -15,8 +15,8 @@ def test__flat_product_operator(): X = np.reshape(1.0 * np.arange(2 * n), (n, d)) A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False) - A1_ = ot.gromov._flat_product_operator(A1) - A2_ = ot.gromov._flat_product_operator(A2) + A1_ = ot.gromov._lowrank._flat_product_operator(A1) + A2_ = ot.gromov._lowrank._flat_product_operator(A2) cost = ot.dist(X, X) # test value @@ -35,7 +35,9 @@ def test_lowrank_gromov_wasserstein_samples(): a = ot.unif(n_samples) b = ot.unif(n_samples) - Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a, b, reg=0.1, log=True, rescale_cost=False) + Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples( + X_s, X_t, a, b, reg=0.1, log=True, rescale_cost=False + ) P = log["lazy_plan"][:] # check constraints for P @@ -49,13 +51,29 @@ def test_lowrank_gromov_wasserstein_samples(): # check warn parameter when low rank GW algorithm doesn't converge with pytest.warns(UserWarning): ot.gromov.lowrank_gromov_wasserstein_samples( - X_s, X_t, a, b, reg=0.1, stopThr=0, numItermax=1, warn=True, warn_dykstra=False + X_s, + X_t, + a, + b, + reg=0.1, + stopThr=0, + numItermax=1, + warn=True, + warn_dykstra=False, ) # check warn parameter when Dykstra algorithm doesn't converge with pytest.warns(UserWarning): ot.gromov.lowrank_gromov_wasserstein_samples( - X_s, X_t, a, b, reg=0.1, stopThr_dykstra=0, numItermax_dykstra=1, warn=False, warn_dykstra=True + X_s, + X_t, + a, + b, + reg=0.1, + stopThr_dykstra=0, + numItermax_dykstra=1, + warn=False, + warn_dykstra=True, ) @@ -73,7 +91,9 @@ def test_lowrank_gromov_wasserstein_samples_alpha_error(alpha, rank): b = ot.unif(n_samples) with pytest.raises(ValueError): - ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False) + ot.gromov.lowrank_gromov_wasserstein_samples( + X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False + ) @pytest.mark.parametrize(("gamma_init"), ("rescale", "theory", "other")) @@ -91,10 +111,14 @@ def test_lowrank_wasserstein_samples_gamma_init(gamma_init): if gamma_init not in ["rescale", "theory"]: with pytest.raises(NotImplementedError): - ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a, b, reg=0.1, gamma_init=gamma_init, log=True) + ot.gromov.lowrank_gromov_wasserstein_samples( + X_s, X_t, a, b, reg=0.1, gamma_init=gamma_init, log=True + ) else: - Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(X_s, X_t, a, b, reg=0.1, gamma_init=gamma_init, log=True) + Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples( + X_s, X_t, a, b, reg=0.1, gamma_init=gamma_init, log=True + ) P = log["lazy_plan"][:] # check constraints for P @@ -102,7 +126,7 @@ def test_lowrank_wasserstein_samples_gamma_init(gamma_init): np.testing.assert_allclose(b, P.sum(0), atol=1e-04) -@pytest.skip_backend('tf') 
+@pytest.skip_backend("tf") def test_lowrank_gromov_wasserstein_samples_backends(nx): # Test low rank sinkhorn for different backends n_samples = 20 # nb samples @@ -117,7 +141,9 @@ def test_lowrank_gromov_wasserstein_samples_backends(nx): ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) - Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples(X_sb, X_tb, ab, bb, reg=0.1, log=True) + Q, R, g, log = ot.gromov.lowrank_gromov_wasserstein_samples( + X_sb, X_tb, ab, bb, reg=0.1, log=True + ) lazy_plan = log["lazy_plan"] P = lazy_plan[:] diff --git a/test/gromov/test_partial.py b/test/gromov/test_partial.py index 9a6712666..1ae4e960f 100644 --- a/test/gromov/test_partial.py +++ b/test/gromov/test_partial.py @@ -1,4 +1,4 @@ -""" Tests for gromov._partial.py """ +"""Tests for gromov._partial.py""" # Author: # Laetitia Chapel @@ -13,7 +13,6 @@ def test_raise_errors(): - n_samples = 20 # nb samples (gaussian) n_noise = 20 # nb of samples (noise) @@ -38,12 +37,10 @@ def test_raise_errors(): ot.gromov.partial_gromov_wasserstein(M, M, p, q, m=-1, log=True) with pytest.raises(ValueError): - ot.gromov.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, - log=True) + ot.gromov.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, log=True) with pytest.raises(ValueError): - ot.gromov.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, - log=True) + ot.gromov.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, log=True) def test_partial_gromov_wasserstein(nx): @@ -76,24 +73,47 @@ def test_partial_gromov_wasserstein(nx): C2 = ot.dist(xt, xt) C3 = ot.dist(xt2, xt2) - m = 2. / 3. + m = 2.0 / 3.0 C1b, C1subb, C2b, C3b, pb, psubb, qb = nx.from_numpy(C1, C1sub, C2, C3, p, psub, q) - G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + G0 = ( + np.outer(p, q) * m / (np.sum(p) * np.sum(q)) + ) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. G0b = nx.from_numpy(G0) # check consistency across backends and stability w.r.t loss/marginals/sym list_sym = [True, None] - for i, loss_fun in enumerate(['square_loss', 'kl_loss']): - + for i, loss_fun in enumerate(["square_loss", "kl_loss"]): res, log = ot.gromov.partial_gromov_wasserstein( - C1, C3, p=p, q=None, m=m, loss_fun=loss_fun, n_dummies=1, - G0=G0, log=True, symmetric=list_sym[i], warn=True, verbose=True) + C1, + C3, + p=p, + q=None, + m=m, + loss_fun=loss_fun, + n_dummies=1, + G0=G0, + log=True, + symmetric=list_sym[i], + warn=True, + verbose=True, + ) resb, logb = ot.gromov.partial_gromov_wasserstein( - C1b, C3b, p=None, q=qb, m=m, loss_fun=loss_fun, n_dummies=1, - G0=G0b, log=True, symmetric=False, warn=True, verbose=True) + C1b, + C3b, + p=None, + q=qb, + m=m, + loss_fun=loss_fun, + n_dummies=1, + G0=G0b, + log=True, + symmetric=False, + warn=True, + verbose=True, + ) resb_ = nx.to_numpy(resb) assert np.all(res.sum(1) <= p) # cf convergence wasserstein @@ -109,36 +129,39 @@ def test_partial_gromov_wasserstein(nx): pass # tests with different number of samples across spaces - m = 2. / 3. 
+ m = 2.0 / 3.0 res, log = ot.gromov.partial_gromov_wasserstein( - C1, C1sub, p=p, q=psub, m=m, log=True) + C1, C1sub, p=p, q=psub, m=m, log=True + ) resb, logb = ot.gromov.partial_gromov_wasserstein( - C1b, C1subb, p=pb, q=psubb, m=m, log=True) + C1b, C1subb, p=pb, q=psubb, m=m, log=True + ) resb_ = nx.to_numpy(resb) np.testing.assert_allclose(res, resb_, rtol=1e-4) assert np.all(res.sum(1) <= p) # cf convergence wasserstein assert np.all(res.sum(0) <= psub) # cf convergence wasserstein - np.testing.assert_allclose( - np.sum(res), m, atol=1e-15) + np.testing.assert_allclose(np.sum(res), m, atol=1e-15) # Edge cases - tests with m=1 set by default (coincide with gw) m = 1 - res0 = ot.gromov.partial_gromov_wasserstein( - C1, C2, p, q, m=m, log=False) + res0 = ot.gromov.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=False) res0b, log0b = ot.gromov.partial_gromov_wasserstein( - C1b, C2b, pb, qb, m=None, log=True) - G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss') + C1b, C2b, pb, qb, m=None, log=True + ) + G = ot.gromov.gromov_wasserstein(C1, C2, p, q, "square_loss") np.testing.assert_allclose(G, res0, rtol=1e-4) np.testing.assert_allclose(res0b, res0, rtol=1e-4) # tests for pGW2 - for loss_fun in ['square_loss', 'kl_loss']: + for loss_fun in ["square_loss", "kl_loss"]: w0, log0 = ot.gromov.partial_gromov_wasserstein2( - C1, C2, p=None, q=q, m=m, loss_fun=loss_fun, log=True) + C1, C2, p=None, q=q, m=m, loss_fun=loss_fun, log=True + ) w0_val = ot.gromov.partial_gromov_wasserstein2( - C1b, C2b, p=pb, q=None, m=m, loss_fun=loss_fun, log=False) + C1b, C2b, p=pb, q=None, m=m, loss_fun=loss_fun, log=False + ) np.testing.assert_allclose(w0, w0_val, rtol=1e-4) # tests integers @@ -148,7 +171,8 @@ def test_partial_gromov_wasserstein(nx): C2b_int = nx.from_numpy(C2_int) res0b, log0b = ot.gromov.partial_gromov_wasserstein( - C1b_int, C2b_int, pb, qb, m=m, log=True) + C1b_int, C2b_int, pb, qb, m=m, log=True + ) assert nx.to_numpy(res0b).dtype == C1_int.dtype @@ -178,18 +202,19 @@ def test_partial_partial_gromov_linesearch(nx): C2 = ot.dist(xt, xt) C3 = ot.dist(xt2, xt2) - m = 2. / 3. + m = 2.0 / 3.0 C1b, C2b, C3b, pb, qb = nx.from_numpy(C1, C2, C3, p, q) - G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + G0 = ( + np.outer(p, q) * m / (np.sum(p) * np.sum(q)) + ) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. G0b = nx.from_numpy(G0) # computing necessary inputs to the line-search - Gb, _ = ot.gromov.partial_gromov_wasserstein( - C1b, C2b, pb, qb, m=m, log=True) + Gb, _ = ot.gromov.partial_gromov_wasserstein(C1b, C2b, pb, qb, m=m, log=True) deltaGb = Gb - G0b - fC1, fC2, hC1, hC2 = ot.gromov._utils._transform_matrix(C1b, C2b, 'square_loss') + fC1, fC2, hC1, hC2 = ot.gromov._utils._transform_matrix(C1b, C2b, "square_loss") fC2t = fC2.T ones_p = nx.ones(p.shape[0], type_as=pb) @@ -204,10 +229,10 @@ def test_partial_partial_gromov_linesearch(nx): # perform line-search alpha, _, cost_Gb, _ = ot.gromov.solve_partial_gromov_linesearch( - G0b, deltaGb, cost_G0b, df_G0b, df_Gb, 0., 1., - alpha_min=0., alpha_max=1.) + G0b, deltaGb, cost_G0b, df_G0b, df_Gb, 0.0, 1.0, alpha_min=0.0, alpha_max=1.0 + ) - np.testing.assert_allclose(alpha, 1., rtol=1e-4) + np.testing.assert_allclose(alpha, 1.0, rtol=1e-4) @pytest.skip_backend("jax", reason="test very slow with jax backend") @@ -242,22 +267,44 @@ def test_entropic_partial_gromov_wasserstein(nx): C2 = ot.dist(xt, xt) C3 = ot.dist(xt2, xt2) - m = 2. / 3. 
+ m = 2.0 / 3.0 C1b, C1subb, C2b, C3b, pb, psubb, qb = nx.from_numpy(C1, C1sub, C2, C3, p, psub, q) - G0 = np.outer(p, q) * m / (np.sum(p) * np.sum(q)) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. + G0 = ( + np.outer(p, q) * m / (np.sum(p) * np.sum(q)) + ) # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q. G0b = nx.from_numpy(G0) # check consistency across backends and stability w.r.t loss/marginals/sym list_sym = [True, None] - for i, loss_fun in enumerate(['square_loss', 'kl_loss']): + for i, loss_fun in enumerate(["square_loss", "kl_loss"]): res, log = ot.gromov.entropic_partial_gromov_wasserstein( - C1, C3, p=p, q=None, reg=1e4, m=m, loss_fun=loss_fun, G0=None, - log=True, symmetric=list_sym[i], verbose=True) + C1, + C3, + p=p, + q=None, + reg=1e4, + m=m, + loss_fun=loss_fun, + G0=None, + log=True, + symmetric=list_sym[i], + verbose=True, + ) resb, logb = ot.gromov.entropic_partial_gromov_wasserstein( - C1b, C3b, p=None, q=qb, reg=1e4, m=m, loss_fun=loss_fun, G0=G0b, - log=True, symmetric=False, verbose=True) + C1b, + C3b, + p=None, + q=qb, + reg=1e4, + m=m, + loss_fun=loss_fun, + G0=G0b, + log=True, + symmetric=False, + verbose=True, + ) resb_ = nx.to_numpy(resb) try: # some instability can occur with kl. to investigate further. @@ -270,37 +317,55 @@ def test_entropic_partial_gromov_wasserstein(nx): # tests with m is None res = ot.gromov.entropic_partial_gromov_wasserstein( - C1, C3, p=p, q=None, reg=1e4, G0=None, log=False, - symmetric=list_sym[i], verbose=True) + C1, + C3, + p=p, + q=None, + reg=1e4, + G0=None, + log=False, + symmetric=list_sym[i], + verbose=True, + ) resb = ot.gromov.entropic_partial_gromov_wasserstein( - C1b, C3b, p=None, q=qb, reg=1e4, G0=None, log=False, - symmetric=False, verbose=True) + C1b, + C3b, + p=None, + q=qb, + reg=1e4, + G0=None, + log=False, + symmetric=False, + verbose=True, + ) resb_ = nx.to_numpy(resb) np.testing.assert_allclose(res, resb_, rtol=1e-4) - np.testing.assert_allclose( - np.sum(res), 1., rtol=1e-4) + np.testing.assert_allclose(np.sum(res), 1.0, rtol=1e-4) # tests with different number of samples across spaces m = 0.5 res, log = ot.gromov.entropic_partial_gromov_wasserstein( - C1, C1sub, p=p, q=psub, reg=1e4, m=m, log=True) + C1, C1sub, p=p, q=psub, reg=1e4, m=m, log=True + ) resb, logb = ot.gromov.entropic_partial_gromov_wasserstein( - C1b, C1subb, p=pb, q=psubb, reg=1e4, m=m, log=True) + C1b, C1subb, p=pb, q=psubb, reg=1e4, m=m, log=True + ) resb_ = nx.to_numpy(resb) np.testing.assert_allclose(res, resb_, rtol=1e-4) assert np.all(res.sum(1) <= p) # cf convergence wasserstein assert np.all(res.sum(0) <= psub) # cf convergence wasserstein - np.testing.assert_allclose( - np.sum(res), m, rtol=1e-4) + np.testing.assert_allclose(np.sum(res), m, rtol=1e-4) # tests for pGW2 - for loss_fun in ['square_loss', 'kl_loss']: + for loss_fun in ["square_loss", "kl_loss"]: w0, log0 = ot.gromov.entropic_partial_gromov_wasserstein2( - C1, C2, p=None, q=q, reg=1e4, m=m, loss_fun=loss_fun, log=True) + C1, C2, p=None, q=q, reg=1e4, m=m, loss_fun=loss_fun, log=True + ) w0_val = ot.gromov.entropic_partial_gromov_wasserstein2( - C1b, C2b, p=pb, q=None, reg=1e4, m=m, loss_fun=loss_fun, log=False) + C1b, C2b, p=pb, q=None, reg=1e4, m=m, loss_fun=loss_fun, log=False + ) np.testing.assert_allclose(w0, w0_val, rtol=1e-8) diff --git a/test/gromov/test_quantized.py b/test/gromov/test_quantized.py index a864a8a46..c3b80bb7d 100644 --- a/test/gromov/test_quantized.py +++ b/test/gromov/test_quantized.py @@ -1,4 +1,4 @@ -"""Tests for gromov._quantized.py """ +"""Tests 
for gromov._quantized.py""" # Author: Cédric Vincent-Cuaz # @@ -9,19 +9,18 @@ import ot -from ot.gromov._quantized import ( - networkx_import, sklearn_import) +from ot.gromov._quantized import networkx_import, sklearn_import def test_quantized_gw(nx): n_samples = 30 # nb samples rng = np.random.RandomState(0) - C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) - C1 = (C1 + C1.T) / 2. + C1 = rng.uniform(low=0.0, high=10, size=(n_samples, n_samples)) + C1 = (C1 + C1.T) / 2.0 - C2 = rng.uniform(low=10., high=20., size=(n_samples, n_samples)) - C2 = (C2 + C2.T) / 2. + C2 = rng.uniform(low=10.0, high=20.0, size=(n_samples, n_samples)) + C2 = (C2 + C2.T) / 2.0 p = ot.unif(n_samples) q = ot.unif(n_samples) @@ -33,11 +32,11 @@ def test_quantized_gw(nx): for npart1 in [1, n_samples + 1, 2]: log_tests = [True, False, False, True, True, False] - pairs_part_rep = [('random', 'random')] + pairs_part_rep = [("random", "random")] if networkx_import: - pairs_part_rep += [('louvain', 'random'), ('fluid', 'pagerank')] + pairs_part_rep += [("louvain", "random"), ("fluid", "pagerank")] if sklearn_import: - pairs_part_rep += [('spectral', 'random')] + pairs_part_rep += [("spectral", "random")] count_mode = 0 @@ -46,12 +45,32 @@ def test_quantized_gw(nx): count_mode += 1 res = ot.gromov.quantized_fused_gromov_wasserstein( - C1, C2, npart1, npart2, p, None, C1, None, part_method=part_method, - rep_method=rep_method, log=log_) + C1, + C2, + npart1, + npart2, + p, + None, + C1, + None, + part_method=part_method, + rep_method=rep_method, + log=log_, + ) resb = ot.gromov.quantized_fused_gromov_wasserstein( - C1b, C2b, npart1, npart2, None, qb, None, C2b, part_method=part_method, - rep_method=rep_method, log=log_) + C1b, + C2b, + npart1, + npart2, + None, + qb, + None, + C2b, + part_method=part_method, + rep_method=rep_method, + log=log_, + ) if log_: T_global, Ts_local, T, log = res @@ -64,9 +83,11 @@ def test_quantized_gw(nx): # check constraints np.testing.assert_allclose(T, Tb, atol=1e-06) np.testing.assert_allclose( - p, Tb.sum(1), atol=1e-06) # cf convergence gromov + p, Tb.sum(1), atol=1e-06 + ) # cf convergence gromov np.testing.assert_allclose( - q, Tb.sum(0), atol=1e-06) # cf convergence gromov + q, Tb.sum(0), atol=1e-06 + ) # cf convergence gromov if log_: for key in log.keys(): @@ -81,15 +102,15 @@ def test_quantized_fgw(nx): n_samples = 30 # nb samples rng = np.random.RandomState(0) - C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) - C1 = (C1 + C1.T) / 2. + C1 = rng.uniform(low=0.0, high=10, size=(n_samples, n_samples)) + C1 = (C1 + C1.T) / 2.0 - F1 = rng.uniform(low=0., high=10, size=(n_samples, 1)) + F1 = rng.uniform(low=0.0, high=10, size=(n_samples, 1)) - C2 = rng.uniform(low=10., high=20., size=(n_samples, n_samples)) - C2 = (C2 + C2.T) / 2. 
+ C2 = rng.uniform(low=10.0, high=20.0, size=(n_samples, n_samples)) + C2 = (C2 + C2.T) / 2.0 - F2 = rng.uniform(low=0., high=10, size=(n_samples, 1)) + F2 = rng.uniform(low=0.0, high=10, size=(n_samples, 1)) p = ot.unif(n_samples) q = ot.unif(n_samples) @@ -103,13 +124,15 @@ def test_quantized_fgw(nx): pairs_part_rep = [] if networkx_import: - pairs_part_rep += [('louvain_fused', 'pagerank'), - ('louvain', 'pagerank_fused'), - ('fluid_fused', 'pagerank_fused')] + pairs_part_rep += [ + ("louvain_fused", "pagerank"), + ("louvain", "pagerank_fused"), + ("fluid_fused", "pagerank_fused"), + ] if sklearn_import: - pairs_part_rep += [('spectral_fused', 'random')] + pairs_part_rep += [("spectral_fused", "random")] - pairs_part_rep += [('random', 'random')] + pairs_part_rep += [("random", "random")] count_mode = 0 alpha = 0.5 @@ -119,12 +142,38 @@ def test_quantized_fgw(nx): count_mode += 1 res = ot.gromov.quantized_fused_gromov_wasserstein( - C1, C2, npart1, npart2, p, None, C1, None, F1, F2, alpha, - part_method, rep_method, log_) + C1, + C2, + npart1, + npart2, + p, + None, + C1, + None, + F1, + F2, + alpha, + part_method, + rep_method, + log_, + ) resb = ot.gromov.quantized_fused_gromov_wasserstein( - C1b, C2b, npart1, npart2, None, qb, None, C2b, F1b, F2b, alpha, - part_method, rep_method, log_) + C1b, + C2b, + npart1, + npart2, + None, + qb, + None, + C2b, + F1b, + F2b, + alpha, + part_method, + rep_method, + log_, + ) if log_: T_global, Ts_local, T, log = res @@ -136,10 +185,8 @@ def test_quantized_fgw(nx): Tb = nx.to_numpy(Tb) # check constraints np.testing.assert_allclose(T, Tb, atol=1e-06) - np.testing.assert_allclose( - p, Tb.sum(1), atol=1e-06) # cf convergence gromov - np.testing.assert_allclose( - q, Tb.sum(0), atol=1e-06) # cf convergence gromov + np.testing.assert_allclose(p, Tb.sum(1), atol=1e-06) # cf convergence gromov + np.testing.assert_allclose(q, Tb.sum(0), atol=1e-06) # cf convergence gromov if log_: for key in log.keys(): @@ -156,24 +203,31 @@ def test_quantized_fgw(nx): C2b_new = alpha * C2b + (1 - alpha) * DF2b part1b = ot.gromov.get_graph_partition( - C1b_new, npart1, part_method=pairs_part_rep[-1][0], random_state=0) + C1b_new, npart1, part_method=pairs_part_rep[-1][0], random_state=0 + ) part2b = ot.gromov._quantized.get_graph_partition( - C2b_new, npart2, part_method=pairs_part_rep[-1][0], random_state=0) + C2b_new, npart2, part_method=pairs_part_rep[-1][0], random_state=0 + ) rep_indices1b = ot.gromov.get_graph_representants( - C1b, part1b, rep_method=pairs_part_rep[-1][1], random_state=0) + C1b, part1b, rep_method=pairs_part_rep[-1][1], random_state=0 + ) rep_indices2b = ot.gromov.get_graph_representants( - C2b, part2b, rep_method=pairs_part_rep[-1][1], random_state=0) + C2b, part2b, rep_method=pairs_part_rep[-1][1], random_state=0 + ) CR1b, list_R1b, list_p1b, FR1b = ot.gromov.format_partitioned_graph( - C1b, pb, part1b, rep_indices1b, F1b, DF1b, alpha) + C1b, pb, part1b, rep_indices1b, F1b, DF1b, alpha + ) CR2b, list_R2b, list_p2b, FR2b = ot.gromov.format_partitioned_graph( - C2b, qb, part2b, rep_indices2b, F2b, DF2b, alpha) + C2b, qb, part2b, rep_indices2b, F2b, DF2b, alpha + ) MRb = ot.dist(FR1b, FR2b) T_globalb, Ts_localb, _ = ot.gromov.quantized_fused_gromov_wasserstein_partitioned( - CR1b, CR2b, list_R1b, list_R2b, list_p1b, list_p2b, MRb, alpha, build_OT=False) + CR1b, CR2b, list_R1b, list_R2b, list_p1b, list_p2b, MRb, alpha, build_OT=False + ) T_globalb = nx.to_numpy(T_globalb) np.testing.assert_allclose(T_global, T_globalb, atol=1e-06) @@ -183,36 +237,54 
@@ def test_quantized_fgw(nx): np.testing.assert_allclose(Ts_local[key], T_localb, atol=1e-06) # tests for edge cases of the graph partitioning - for method in ['unknown_method', 'GW', 'FGW']: + for method in ["unknown_method", "GW", "FGW"]: with pytest.raises(ValueError): ot.gromov.get_graph_partition( - C1b, npart1, part_method=method, random_state=0) + C1b, npart1, part_method=method, random_state=0 + ) with pytest.raises(ValueError): ot.gromov.get_graph_partition( - C1b, npart1, part_method=method, alpha=0.5, F=None, random_state=0) + C1b, npart1, part_method=method, alpha=0.5, F=None, random_state=0 + ) # tests for edge cases of the representant selection with pytest.raises(ValueError): ot.gromov.get_graph_representants( - C1b, part1b, rep_method='unknown_method', random_state=0) + C1b, part1b, rep_method="unknown_method", random_state=0 + ) # tests for edge cases of the format_partitioned_graph function with pytest.raises(ValueError): CR1b, list_R1b, list_p1b, FR1b = ot.gromov.format_partitioned_graph( - C1b, pb, part1b, rep_indices1b, F1b, None, alpha) + C1b, pb, part1b, rep_indices1b, F1b, None, alpha + ) # Tests in qFGW solvers # for non admissible values of alpha with pytest.raises(ValueError): ot.gromov.quantized_fused_gromov_wasserstein_partitioned( - CR1b, CR2b, list_R1b, list_R2b, list_p1b, list_p2b, MRb, 0, build_OT=False) + CR1b, CR2b, list_R1b, list_R2b, list_p1b, list_p2b, MRb, 0, build_OT=False + ) # for non-consistent feature information provided with pytest.raises(ValueError): ot.gromov.quantized_fused_gromov_wasserstein( - C1, C2, npart1, npart2, p, q, None, None, F1, None, 0.5, - 'spectral_fused', 'random', log_) + C1, + C2, + npart1, + npart2, + p, + q, + None, + None, + F1, + None, + 0.5, + "spectral_fused", + "random", + log_, + ) @pytest.skip_backend("jax", reason="test very slow with jax backend") @@ -221,8 +293,8 @@ def test_quantized_gw_samples(nx): n_samples_2 = 20 # nb samples rng = np.random.RandomState(0) - X1 = rng.uniform(low=0., high=10, size=(n_samples_1, 2)) - X2 = rng.uniform(low=0., high=10, size=(n_samples_2, 4)) + X1 = rng.uniform(low=0.0, high=10, size=(n_samples_1, 2)) + X2 = rng.uniform(low=0.0, high=10, size=(n_samples_2, 4)) p = ot.unif(n_samples_1) q = ot.unif(n_samples_2) @@ -233,22 +305,24 @@ def test_quantized_gw_samples(nx): X1b, X2b, pb, qb = nx.from_numpy(X1, X2, p, q) log_tests = [True, False, True] - methods = ['random'] + methods = ["random"] if sklearn_import: - methods += ['kmeans'] + methods += ["kmeans"] count_mode = 0 - alpha = 1. 
+ alpha = 1.0 for method in methods: log_ = log_tests[count_mode] count_mode += 1 res = ot.gromov.quantized_fused_gromov_wasserstein_samples( - X1, X2, npart1, npart2, p, None, None, None, alpha, method, log_) + X1, X2, npart1, npart2, p, None, None, None, alpha, method, log_ + ) resb = ot.gromov.quantized_fused_gromov_wasserstein_samples( - X1b, X2b, npart1, npart2, None, qb, None, None, alpha, method, log_) + X1b, X2b, npart1, npart2, None, qb, None, None, alpha, method, log_ + ) if log_: T_global, Ts_local, T, log = res @@ -260,10 +334,8 @@ def test_quantized_gw_samples(nx): Tb = nx.to_numpy(Tb) # check constraints np.testing.assert_allclose(T, Tb, atol=1e-06) - np.testing.assert_allclose( - p, Tb.sum(1), atol=1e-06) # cf convergence gromov - np.testing.assert_allclose( - q, Tb.sum(0), atol=1e-06) # cf convergence gromov + np.testing.assert_allclose(p, Tb.sum(1), atol=1e-06) # cf convergence gromov + np.testing.assert_allclose(q, Tb.sum(0), atol=1e-06) # cf convergence gromov if log_: for key in log.keys(): @@ -276,7 +348,8 @@ def test_quantized_gw_samples(nx): # tests for edge cases of the representant selection with pytest.raises(ValueError): ot.gromov.get_partition_and_representants_samples( - X1, npart1, method='unknown_method', random_state=0) + X1, npart1, method="unknown_method", random_state=0 + ) @pytest.skip_backend("jax", reason="test very slow with jax backend") @@ -285,11 +358,11 @@ def test_quantized_fgw_samples(nx): n_samples_2 = 30 # nb samples rng = np.random.RandomState(0) - X1 = rng.uniform(low=0., high=10, size=(n_samples_1, 2)) - X2 = rng.uniform(low=0., high=10, size=(n_samples_2, 4)) + X1 = rng.uniform(low=0.0, high=10, size=(n_samples_1, 2)) + X2 = rng.uniform(low=0.0, high=10, size=(n_samples_2, 4)) - F1 = rng.uniform(low=0., high=10, size=(n_samples_1, 3)) - F2 = rng.uniform(low=0., high=10, size=(n_samples_2, 3)) + F1 = rng.uniform(low=0.0, high=10, size=(n_samples_1, 3)) + F2 = rng.uniform(low=0.0, high=10, size=(n_samples_2, 3)) p = ot.unif(n_samples_1) q = ot.unif(n_samples_2) @@ -301,8 +374,8 @@ def test_quantized_fgw_samples(nx): methods = [] if sklearn_import: - methods += ['kmeans', 'kmeans_fused'] - methods += ['random'] + methods += ["kmeans", "kmeans_fused"] + methods += ["random"] alpha = 0.5 @@ -315,10 +388,12 @@ def test_quantized_fgw_samples(nx): count_mode += 1 res = ot.gromov.quantized_fused_gromov_wasserstein_samples( - X1, X2, npart1, npart2, p, None, F1, F2, alpha, method, log_) + X1, X2, npart1, npart2, p, None, F1, F2, alpha, method, log_ + ) resb = ot.gromov.quantized_fused_gromov_wasserstein_samples( - X1b, X2b, npart1, npart2, None, qb, F1b, F2b, alpha, method, log_) + X1b, X2b, npart1, npart2, None, qb, F1b, F2b, alpha, method, log_ + ) if log_: T_global, Ts_local, T, log = res @@ -331,9 +406,11 @@ def test_quantized_fgw_samples(nx): # check constraints np.testing.assert_allclose(T, Tb, atol=1e-06) np.testing.assert_allclose( - p, Tb.sum(1), atol=1e-06) # cf convergence gromov + p, Tb.sum(1), atol=1e-06 + ) # cf convergence gromov np.testing.assert_allclose( - q, Tb.sum(0), atol=1e-06) # cf convergence gromov + q, Tb.sum(0), atol=1e-06 + ) # cf convergence gromov if log_: for key in log.keys(): @@ -345,19 +422,24 @@ def test_quantized_fgw_samples(nx): # complementary tests for utils functions part1b, rep_indices1 = ot.gromov.get_partition_and_representants_samples( - X1b, npart1, method=method, random_state=0) + X1b, npart1, method=method, random_state=0 + ) part2b, rep_indices2 = ot.gromov.get_partition_and_representants_samples( - 
X2b, npart2, method=method, random_state=0) + X2b, npart2, method=method, random_state=0 + ) CR1b, list_R1b, list_p1b, FR1b = ot.gromov.format_partitioned_samples( - X1b, pb, part1b, rep_indices1, F1b, alpha) + X1b, pb, part1b, rep_indices1, F1b, alpha + ) CR2b, list_R2b, list_p2b, FR2b = ot.gromov.format_partitioned_samples( - X2b, qb, part2b, rep_indices2, F2b, alpha) + X2b, qb, part2b, rep_indices2, F2b, alpha + ) MRb = ot.dist(FR1b, FR2b) T_globalb, Ts_localb, _ = ot.gromov.quantized_fused_gromov_wasserstein_partitioned( - CR1b, CR2b, list_R1b, list_R2b, list_p1b, list_p2b, MRb, alpha, build_OT=False) + CR1b, CR2b, list_R1b, list_R2b, list_p1b, list_p2b, MRb, alpha, build_OT=False + ) T_globalb = nx.to_numpy(T_globalb) np.testing.assert_allclose(T_global, T_globalb, atol=1e-06) @@ -369,9 +451,11 @@ def test_quantized_fgw_samples(nx): # tests for edge cases of the format_partitioned_graph function with pytest.raises(ValueError): CR1b, list_R1b, list_p1b, FR1b = ot.gromov.format_partitioned_samples( - X1b, pb, part1b, rep_indices1, None, alpha) + X1b, pb, part1b, rep_indices1, None, alpha + ) # for non-consistent feature information provided with pytest.raises(ValueError): ot.gromov.quantized_fused_gromov_wasserstein_samples( - X1, X2, npart1, npart2, p, None, None, F2, alpha, 'fused_spectral', log_) + X1, X2, npart1, npart2, p, None, None, F2, alpha, "fused_spectral", log_ + ) diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index 0a9e25d17..3f7668bb7 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -1,4 +1,4 @@ -""" Tests for gromov._semirelaxed.py """ +"""Tests for gromov._semirelaxed.py""" # Author: Cédric Vincent-Cuaz # @@ -10,8 +10,7 @@ import ot from ot.backend import torch -from ot.gromov._utils import ( - networkx_import, sklearn_import) +from ot.gromov._utils import networkx_import, sklearn_import def test_semirelaxed_gromov(nx): @@ -22,8 +21,7 @@ def test_semirelaxed_gromov(nx): ns = np.sum(list_n) # create directed sbm with C2 as connectivity matrix C1 = np.zeros((ns, ns), dtype=np.float64) - C2 = np.array([[0.8, 0.1], - [0.1, 1.]], dtype=np.float64) + C2 = np.array([[0.8, 0.1], [0.1, 1.0]], dtype=np.float64) pos = [0, 30, 45] @@ -33,7 +31,7 @@ def test_semirelaxed_gromov(nx): xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) pos_i_min, pos_i_max = pos[i], pos[i + 1] pos_j_min, pos_j_max = pos[j], pos[j + 1] - C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij + C1[pos_i_min:pos_i_max, pos_j_min:pos_j_max] = xij p = ot.unif(ns, type_as=C1) q0 = ot.unif(C2.shape[0], type_as=C1) @@ -41,12 +39,21 @@ def test_semirelaxed_gromov(nx): # asymmetric C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - for loss_fun in ['square_loss', 'kl_loss']: + for loss_fun in ["square_loss", "kl_loss"]: G, log = ot.gromov.semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) + C1, C2, p, loss_fun="square_loss", symmetric=None, log=True, G0=G0 + ) Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein( - C1b, C2b, None, loss_fun='square_loss', symmetric=False, log=True, - G0=None, alpha_min=0., alpha_max=1.) 
+ C1b, + C2b, + None, + loss_fun="square_loss", + symmetric=False, + log=True, + G0=None, + alpha_min=0.0, + alpha_max=1.0, + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -55,19 +62,23 @@ def test_semirelaxed_gromov(nx): np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( - C1, C2, None, loss_fun='square_loss', symmetric=False, log=True, G0=G0) + C1, C2, None, loss_fun="square_loss", symmetric=False, log=True, G0=G0 + ) srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2( - C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + C1b, C2b, pb, loss_fun="square_loss", symmetric=None, log=True, G0=None + ) - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) + G = log2["T"] + Gb = nx.to_numpy(logb2["T"]) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + list_n / ns, Gb.sum(0), atol=1e-04 + ) # cf convergence gromov - np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(log2["srgw_dist"], logb["srgw_dist"], atol=1e-07) + np.testing.assert_allclose(logb2["srgw_dist"], log["srgw_dist"], atol=1e-07) # symmetric - testing various initialization of the OT plan. C1 = 0.5 * (C1 + C1.T) @@ -75,49 +86,66 @@ def test_semirelaxed_gromov(nx): C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) init_plan_list = [ - (None, G0b), ('product', None), ("random_product", "random_product")] + (None, G0b), + ("product", None), + ("random_product", "random_product"), + ] if networkx_import: - init_plan_list += [('fluid', 'fluid'), ('fluid_soft', 'fluid_soft')] + init_plan_list += [("fluid", "fluid"), ("fluid_soft", "fluid_soft")] if sklearn_import: init_plan_list += [ - ("spectral", "spectral"), ("spectral_soft", "spectral_soft"), - ("kmeans", "kmeans"), ("kmeans_soft", "kmeans_soft")] - - for (init, init_b) in init_plan_list: + ("spectral", "spectral"), + ("spectral_soft", "spectral_soft"), + ("kmeans", "kmeans"), + ("kmeans_soft", "kmeans_soft"), + ] + for init, init_b in init_plan_list: G, log = ot.gromov.semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=init) + C1, C2, p, loss_fun="square_loss", symmetric=None, log=True, G0=init + ) Gb = ot.gromov.semirelaxed_gromov_wasserstein( - C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=init_b) + C1b, C2b, pb, loss_fun="square_loss", symmetric=True, log=False, G0=init_b + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + p, nx.sum(Gb, axis=1), atol=1e-04 + ) # cf convergence gromov if not isinstance(init, str): - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose( + list_n / ns, nx.sum(Gb, axis=0), atol=1e-02 + ) # cf convergence gromov else: - if 'spectral' not in init: # issues with spectral clustering related to label switching + if ( + "spectral" not in init + ): # issues with spectral clustering related to label switching np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) srgw, log2 = 
ot.gromov.semirelaxed_gromov_wasserstein2( - C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0) + C1, C2, p, loss_fun="square_loss", symmetric=True, log=True, G0=G0 + ) srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2( - C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + C1b, C2b, pb, loss_fun="square_loss", symmetric=None, log=True, G0=None + ) - srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0) + srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2( + C1, C2, p, loss_fun="square_loss", symmetric=True, log=False, G0=G0 + ) - G = log2['T'] + G = log2["T"] # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) - np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(log2["srgw_dist"], log["srgw_dist"], atol=1e-07) + np.testing.assert_allclose(logb2["srgw_dist"], log["srgw_dist"], atol=1e-07) np.testing.assert_allclose(srgw, srgw_, atol=1e-07) @@ -140,18 +168,19 @@ def test_semirelaxed_gromov2_gradients(): C2 /= C2.max() if torch: - devices = [torch.device("cpu")] if torch.cuda.is_available(): devices.append(torch.device("cuda")) for device in devices: - for loss_fun in ['square_loss', 'kl_loss']: + for loss_fun in ["square_loss", "kl_loss"]: # semirelaxed solvers do not support gradients over masses yet. p1 = torch.tensor(p, requires_grad=False, device=device) C11 = torch.tensor(C1, requires_grad=True, device=device) C12 = torch.tensor(C2, requires_grad=True, device=device) - val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1, loss_fun=loss_fun) + val = ot.gromov.semirelaxed_gromov_wasserstein2( + C11, C12, p1, loss_fun=loss_fun + ) val.backward() @@ -178,12 +207,16 @@ def test_srgw_helper_backend(nx): C1 /= C1.max() C2 /= C2.max() - for loss_fun in ['square_loss', 'kl_loss']: + for loss_fun in ["square_loss", "kl_loss"]: C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) - Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun, armijo=False, symmetric=True, G0=None, log=True) + Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein( + C1b, C2b, pb, loss_fun, armijo=False, symmetric=True, G0=None, log=True + ) # calls with nx=None - constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun) + constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed( + C1b, C2b, pb, loss_fun + ) ones_pb = nx.ones(pb.shape[0], type_as=pb) def f(G): @@ -198,17 +231,36 @@ def df(G): def line_search(cost, G, deltaG, Mi, cost_G, df_G): return ot.gromov.solve_semirelaxed_gromov_linesearch( - G, deltaG, cost_G, hC1b, hC2b, ones_pb, 0., 1., fC2t=fC2tb, nx=None) + G, deltaG, cost_G, hC1b, hC2b, ones_pb, 0.0, 1.0, fC2t=fC2tb, nx=None + ) + # feed the precomputed local optimum Gb to semirelaxed_cg - res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + res, log = ot.optim.semirelaxed_cg( + pb, + qb, + 0.0, + 1.0, + f, + df, + Gb, + line_search, + log=True, + numItermax=1e4, + stopThr=1e-9, + stopThr2=1e-9, + ) # check constraints np.testing.assert_allclose(res, Gb, atol=1e-06) -@pytest.mark.parametrize('loss_fun', [ 
- 'square_loss', 'kl_loss', - pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), -]) +@pytest.mark.parametrize( + "loss_fun", + [ + "square_loss", + "kl_loss", + pytest.param("unknown_loss", marks=pytest.mark.xfail(raises=ValueError)), + ], +) def test_gw_semirelaxed_helper_validation(loss_fun): n_samples = 20 # nb samples mu = np.array([0, 0]) @@ -228,8 +280,7 @@ def test_semirelaxed_fgw(nx): ns = 24 # create directed sbm with C2 as connectivity matrix C1 = np.zeros((ns, ns)) - C2 = np.array([[0.7, 0.05], - [0.05, 0.9]]) + C2 = np.array([[0.7, 0.05], [0.05, 0.9]]) pos = [0, 16, 24] @@ -239,14 +290,18 @@ def test_semirelaxed_fgw(nx): xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) pos_i_min, pos_i_max = pos[i], pos[i + 1] pos_j_min, pos_j_max = pos[j], pos[j + 1] - C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij + C1[pos_i_min:pos_i_max, pos_j_min:pos_j_max] = xij F1 = np.zeros((ns, 1)) - F1[:16] = rng.normal(loc=0., scale=0.01, size=(16, 1)) - F1[16:] = rng.normal(loc=1., scale=0.01, size=(8, 1)) + F1[:16] = rng.normal(loc=0.0, scale=0.01, size=(16, 1)) + F1[16:] = rng.normal(loc=1.0, scale=0.01, size=(8, 1)) F2 = np.zeros((2, 1)) - F2[1, :] = 1. - M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T) + F2[1, :] = 1.0 + M = ( + (F1**2).dot(np.ones((1, nt))) + + np.ones((ns, 1)).dot((F2**2).T) + - 2 * F1.dot(F2.T) + ) p = ot.unif(ns) q0 = ot.unif(C2.shape[0]) @@ -255,68 +310,146 @@ def test_semirelaxed_fgw(nx): # asymmetric structure - checking constraints and values Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein( - M, C1, C2, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + M, + C1, + C2, + None, + loss_fun="square_loss", + alpha=0.5, + symmetric=None, + log=True, + G0=None, + ) Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b) + Mb, + C1b, + C2b, + pb, + loss_fun="square_loss", + alpha=0.5, + symmetric=False, + log=True, + G0=G0b, + ) np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose( + p, nx.sum(Gb, axis=1), atol=1e-04 + ) # cf convergence gromov + np.testing.assert_allclose( + [2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02 + ) # cf convergence gromov # asymmetric - check consistency between srFGW and srFGW2 srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2( - M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0) + M, + C1, + C2, + p, + loss_fun="square_loss", + alpha=0.5, + symmetric=False, + log=True, + G0=G0, + ) srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2( - Mb, C1b, C2b, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + Mb, + C1b, + C2b, + None, + loss_fun="square_loss", + alpha=0.5, + symmetric=None, + log=True, + G0=None, + ) - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) + G = log2["T"] + Gb = nx.to_numpy(logb2["T"]) np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], G.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + [2 / 3, 1 / 3], G.sum(0), atol=1e-04 + ) # cf convergence 
gromov - np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(log2["srfgw_dist"], logb["srfgw_dist"], atol=1e-07) + np.testing.assert_allclose(logb2["srfgw_dist"], log["srfgw_dist"], atol=1e-07) # symmetric structures + checking losses + inits C1 = 0.5 * (C1 + C1.T) Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) init_plan_list = [ - (None, G0b), ('product', None), ("random_product", "random_product")] + (None, G0b), + ("product", None), + ("random_product", "random_product"), + ] if networkx_import: - init_plan_list += [('fluid', 'fluid')] + init_plan_list += [("fluid", "fluid")] if sklearn_import: init_plan_list += [("kmeans", "kmeans")] - for loss_fun in ['square_loss', 'kl_loss']: - for (init, init_b) in init_plan_list: - + for loss_fun in ["square_loss", "kl_loss"]: + for init, init_b in init_plan_list: G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein( - M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=init) + M, + C1, + C2, + p, + loss_fun=loss_fun, + alpha=0.5, + symmetric=None, + log=True, + G0=init, + ) Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=init_b) + Mb, + C1b, + C2b, + pb, + loss_fun=loss_fun, + alpha=0.5, + symmetric=True, + log=False, + G0=init_b, + ) np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose( + p, nx.sum(Gb, axis=1), atol=1e-04 + ) # cf convergence gromov + np.testing.assert_allclose( + [2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02 + ) # cf convergence gromov # checking consistency with srFGW and srFGW2 solvers - srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2( + M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=True, G0=G0 + ) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2( + Mb, + C1b, + C2b, + pb, + loss_fun=loss_fun, + alpha=0.5, + symmetric=None, + log=True, + G0=None, + ) - G2 = log2['T'] - Gb2 = nx.to_numpy(logb2['T']) + G2 = log2["T"] + Gb2 = nx.to_numpy(logb2["T"]) # check constraints np.testing.assert_allclose(G2, Gb2, atol=1e-06) np.testing.assert_allclose(G2, G, atol=1e-06) - np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(log2["srfgw_dist"], log["srfgw_dist"], atol=1e-07) + np.testing.assert_allclose(logb2["srfgw_dist"], log["srfgw_dist"], atol=1e-07) np.testing.assert_allclose(srgw, srgwb, atol=1e-07) @@ -340,19 +473,20 @@ def test_semirelaxed_fgw2_gradients(): C2 /= C2.max() if torch: - devices = [torch.device("cpu")] if torch.cuda.is_available(): devices.append(torch.device("cuda")) for device in devices: # semirelaxed solvers do not support gradients over masses yet. 
- for loss_fun in ['square_loss', 'kl_loss']: + for loss_fun in ["square_loss", "kl_loss"]: p1 = torch.tensor(p, requires_grad=False, device=device) C11 = torch.tensor(C1, requires_grad=True, device=device) C12 = torch.tensor(C2, requires_grad=True, device=device) M1 = torch.tensor(M, requires_grad=True, device=device) - val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, loss_fun=loss_fun) + val = ot.gromov.semirelaxed_fused_gromov_wasserstein2( + M1, C11, C12, p1, loss_fun=loss_fun + ) val.backward() @@ -369,7 +503,9 @@ def test_semirelaxed_fgw2_gradients(): M1 = torch.tensor(M, requires_grad=True, device=device) alpha = torch.tensor(0.5, requires_grad=True, device=device) - val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, loss_fun=loss_fun, alpha=alpha) + val = ot.gromov.semirelaxed_fused_gromov_wasserstein2( + M1, C11, C12, p1, loss_fun=loss_fun, alpha=alpha + ) val.backward() @@ -407,10 +543,23 @@ def test_srfgw_helper_backend(nx): Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) alpha = 0.5 - Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True) + Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein( + Mb, + C1b, + C2b, + pb, + "square_loss", + alpha=0.5, + armijo=False, + symmetric=True, + G0=G0b, + log=True, + ) # calls with nx=None - constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss') + constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed( + C1b, C2b, pb, loss_fun="square_loss" + ) ones_pb = nx.ones(pb.shape[0], type_as=pb) def f(G): @@ -425,9 +574,24 @@ def df(G): def line_search(cost, G, deltaG, Mi, cost_G, df_G): return ot.gromov.solve_semirelaxed_gromov_linesearch( - G, deltaG, cost_G, C1b, C2b, ones_pb, M=(1 - alpha) * Mb, reg=alpha, nx=None) + G, deltaG, cost_G, C1b, C2b, ones_pb, M=(1 - alpha) * Mb, reg=alpha, nx=None + ) + # feed the precomputed local optimum Gb to semirelaxed_cg - res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) + res, log = ot.optim.semirelaxed_cg( + pb, + qb, + (1 - alpha) * Mb, + alpha, + f, + df, + Gb, + line_search, + log=True, + numItermax=1e4, + stopThr=1e-9, + stopThr2=1e-9, + ) # check constraints np.testing.assert_allclose(res, Gb, atol=1e-06) @@ -439,8 +603,7 @@ def test_entropic_semirelaxed_gromov(nx): ns = np.sum(list_n) # create directed sbm with C2 as connectivity matrix C1 = np.zeros((ns, ns), dtype=np.float64) - C2 = np.array([[0.8, 0.1], - [0.1, 0.9]], dtype=np.float64) + C2 = np.array([[0.8, 0.1], [0.1, 0.9]], dtype=np.float64) rng = np.random.RandomState(0) @@ -452,7 +615,7 @@ def test_entropic_semirelaxed_gromov(nx): xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) pos_i_min, pos_i_max = pos[i], pos[i + 1] pos_j_min, pos_j_max = pos[j], pos[j + 1] - C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij + C1[pos_i_min:pos_i_max, pos_j_min:pos_j_max] = xij p = ot.unif(ns, type_as=C1) q0 = ot.unif(C2.shape[0], type_as=C1) @@ -460,9 +623,27 @@ def test_entropic_semirelaxed_gromov(nx): # asymmetric C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) epsilon = 0.1 - for loss_fun in ['square_loss', 'kl_loss']: - G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun=loss_fun, epsilon=epsilon, symmetric=None, log=True, G0=G0) - Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, 
loss_fun=loss_fun, epsilon=epsilon, symmetric=False, log=True, G0=None) + for loss_fun in ["square_loss", "kl_loss"]: + G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1, + C2, + p, + loss_fun=loss_fun, + epsilon=epsilon, + symmetric=None, + log=True, + G0=G0, + ) + Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1b, + C2b, + None, + loss_fun=loss_fun, + epsilon=epsilon, + symmetric=False, + log=True, + G0=None, + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -470,18 +651,38 @@ def test_entropic_semirelaxed_gromov(nx): np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) - srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, None, loss_fun=loss_fun, epsilon=epsilon, symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun=loss_fun, epsilon=epsilon, symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( + C1, + C2, + None, + loss_fun=loss_fun, + epsilon=epsilon, + symmetric=False, + log=True, + G0=G0, + ) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( + C1b, + C2b, + pb, + loss_fun=loss_fun, + epsilon=epsilon, + symmetric=None, + log=True, + G0=None, + ) - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) + G = log2["T"] + Gb = nx.to_numpy(logb2["T"]) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + list_n / ns, Gb.sum(0), atol=1e-04 + ) # cf convergence gromov - np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(log2["srgw_dist"], logb["srgw_dist"], atol=1e-07) + np.testing.assert_allclose(logb2["srgw_dist"], log["srgw_dist"], atol=1e-07) # symmetric - testing various initialization of the OT plan. @@ -491,51 +692,91 @@ def test_entropic_semirelaxed_gromov(nx): init_plan_list = [] # tests longer than with CG so we do not test all inits. 
if networkx_import: - init_plan_list += [('fluid', 'fluid')] + init_plan_list += [("fluid", "fluid")] if sklearn_import: init_plan_list += [("kmeans", "kmeans")] - init_plan_list += [ - ('product', None), (None, G0b)] + init_plan_list += [("product", None), (None, G0b)] - for (init, init_b) in init_plan_list: - print(f'---- init : {init} / init_b : {init_b}') + for init, init_b in init_plan_list: + print(f"---- init : {init} / init_b : {init_b}") G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, - log=True, G0=init) + C1, + C2, + p, + loss_fun="square_loss", + epsilon=epsilon, + symmetric=None, + log=True, + G0=init, + ) Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( - C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=True, - log=True, G0=init_b) + C1b, + C2b, + pb, + loss_fun="square_loss", + epsilon=epsilon, + symmetric=True, + log=True, + G0=init_b, + ) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + p, nx.sum(Gb, axis=1), atol=1e-04 + ) # cf convergence gromov if not isinstance(init, str): - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose( + list_n / ns, nx.sum(Gb, axis=0), atol=1e-02 + ) # cf convergence gromov else: np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # comparison between srGW and srGW2 solvers srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( - C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, - log=True, G0=init) + C1, + C2, + p, + loss_fun="square_loss", + epsilon=epsilon, + symmetric=True, + log=True, + G0=init, + ) srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( - C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, - log=True, G0=init_b) + C1b, + C2b, + pb, + loss_fun="square_loss", + epsilon=epsilon, + symmetric=None, + log=True, + G0=init_b, + ) - srgw_ = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0) + srgw_ = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( + C1, + C2, + p, + loss_fun="square_loss", + epsilon=epsilon, + symmetric=True, + log=False, + G0=G0, + ) - G2 = log2['T'] - G2b = logb2['T'] + G2 = log2["T"] + G2b = logb2["T"] # check constraints np.testing.assert_allclose(G2, G2b, atol=1e-06) np.testing.assert_allclose(G2, G, atol=1e-06) - np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(log2["srgw_dist"], log["srgw_dist"], atol=1e-07) + np.testing.assert_allclose(logb2["srgw_dist"], log["srgw_dist"], atol=1e-07) np.testing.assert_allclose(srgw, srgw_, atol=1e-07) @@ -561,9 +802,8 @@ def test_entropic_semirelaxed_gromov_dtype_device(nx): C2 /= C2.max() for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - for loss_fun in ['square_loss', 'kl_loss']: + for loss_fun in ["square_loss", "kl_loss"]: C1b, C2b, pb = nx.from_numpy(C1, C2, p, type_as=tp) Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( @@ -584,8 +824,7 @@ def test_entropic_semirelaxed_fgw(nx): ns = 24 # create directed sbm with C2 as connectivity matrix C1 = np.zeros((ns, ns)) - C2 = np.array([[0.7, 0.05], - [0.05, 0.9]]) + C2 = np.array([[0.7, 0.05], 
[0.05, 0.9]]) pos = [0, 16, 24] @@ -595,14 +834,18 @@ def test_entropic_semirelaxed_fgw(nx): xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) pos_i_min, pos_i_max = pos[i], pos[i + 1] pos_j_min, pos_j_max = pos[j], pos[j + 1] - C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij + C1[pos_i_min:pos_i_max, pos_j_min:pos_j_max] = xij F1 = np.zeros((ns, 1)) - F1[:16] = rng.normal(loc=0., scale=0.01, size=(16, 1)) - F1[16:] = rng.normal(loc=1., scale=0.01, size=(8, 1)) + F1[:16] = rng.normal(loc=0.0, scale=0.01, size=(16, 1)) + F1[16:] = rng.normal(loc=1.0, scale=0.01, size=(8, 1)) F2 = np.zeros((2, 1)) - F2[1, :] = 1. - M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T) + F2[1, :] = 1.0 + M = ( + (F1**2).dot(np.ones((1, nt))) + + np.ones((ns, 1)).dot((F2**2).T) + - 2 * F1.dot(F2.T) + ) p = ot.unif(ns) q0 = ot.unif(C2.shape[0]) @@ -611,72 +854,164 @@ def test_entropic_semirelaxed_fgw(nx): # asymmetric structure - checking constraints and values Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) - Gb, logb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0b) + G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + M, + C1, + C2, + None, + loss_fun="square_loss", + epsilon=0.1, + alpha=0.5, + symmetric=None, + log=True, + G0=None, + ) + Gb, logb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + Mb, + C1b, + C2b, + pb, + loss_fun="square_loss", + epsilon=0.1, + alpha=0.5, + symmetric=False, + log=True, + G0=G0b, + ) np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose( + p, nx.sum(Gb, axis=1), atol=1e-04 + ) # cf convergence gromov + np.testing.assert_allclose( + [2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02 + ) # cf convergence gromov srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( - M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0) + M, + C1, + C2, + p, + loss_fun="square_loss", + epsilon=0.1, + alpha=0.5, + symmetric=False, + log=True, + G0=G0, + ) srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( - Mb, C1b, C2b, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + Mb, + C1b, + C2b, + None, + loss_fun="square_loss", + epsilon=0.1, + alpha=0.5, + symmetric=None, + log=True, + G0=None, + ) - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) + G = log2["T"] + Gb = nx.to_numpy(logb2["T"]) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + [2 / 3, 1 / 3], Gb.sum(0), atol=1e-04 + ) # cf convergence gromov - np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(log2["srfgw_dist"], logb["srfgw_dist"], atol=1e-07) + np.testing.assert_allclose(logb2["srfgw_dist"], log["srfgw_dist"], 
atol=1e-07) # symmetric structures + checking losses + inits C1 = 0.5 * (C1 + C1.T) Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) init_plan_list = [ - (None, G0b), ('product', None), ("random_product", "random_product")] + (None, G0b), + ("product", None), + ("random_product", "random_product"), + ] if networkx_import: - init_plan_list += [('fluid', 'fluid')] + init_plan_list += [("fluid", "fluid")] if sklearn_import: init_plan_list += [("kmeans", "kmeans")] - for loss_fun in ['square_loss', 'kl_loss']: - for (init, init_b) in init_plan_list: - + for loss_fun in ["square_loss", "kl_loss"]: + for init, init_b in init_plan_list: G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( - M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, - symmetric=None, log=True, G0=init) + M, + C1, + C2, + p, + loss_fun=loss_fun, + epsilon=0.1, + alpha=0.5, + symmetric=None, + log=True, + G0=init, + ) Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( - Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, - symmetric=True, log=False, G0=init_b) + Mb, + C1b, + C2b, + pb, + loss_fun=loss_fun, + epsilon=0.1, + alpha=0.5, + symmetric=True, + log=False, + G0=init_b, + ) np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose( + p, nx.sum(Gb, axis=1), atol=1e-04 + ) # cf convergence gromov + np.testing.assert_allclose( + [2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02 + ) # cf convergence gromov # checking consistency with srFGW and srFGW2 solvers srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( - M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, - symmetric=True, log=True, G0=init) + M, + C1, + C2, + p, + loss_fun=loss_fun, + epsilon=0.1, + alpha=0.5, + symmetric=True, + log=True, + G0=init, + ) srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( - Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, - symmetric=None, log=True, G0=init_b) + Mb, + C1b, + C2b, + pb, + loss_fun=loss_fun, + epsilon=0.1, + alpha=0.5, + symmetric=None, + log=True, + G0=init_b, + ) - G2 = log2['T'] - Gb2 = nx.to_numpy(logb2['T']) + G2 = log2["T"] + Gb2 = nx.to_numpy(logb2["T"]) np.testing.assert_allclose(G2, Gb2, atol=1e-06) np.testing.assert_allclose(G2, G, atol=1e-06) np.testing.assert_allclose(p, Gb2.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb2.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + [2 / 3, 1 / 3], Gb2.sum(0), atol=1e-04 + ) # cf convergence gromov - np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(log2["srfgw_dist"], log["srfgw_dist"], atol=1e-07) + np.testing.assert_allclose(logb2["srfgw_dist"], log["srfgw_dist"], atol=1e-07) np.testing.assert_allclose(srgw, srgwb, atol=1e-07) @@ -710,7 +1045,7 @@ def test_entropic_semirelaxed_fgw_dtype_device(nx): Mb, C1b, C2b, pb = nx.from_numpy(M, C1, C2, p, type_as=tp) - for loss_fun in ['square_loss', 'kl_loss']: + for loss_fun in ["square_loss", "kl_loss"]: Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( Mb, C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True ) @@ -728,8 +1063,8 @@ def test_semirelaxed_gromov_barycenter(nx): ns = 5 nt = 8 - Xs, ys 
= ot.datasets.make_data_classif('3gauss', ns, random_state=42) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + Xs, ys = ot.datasets.make_data_classif("3gauss", ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif("3gauss2", nt, random_state=42) C1 = ot.dist(Xs) C2 = ot.dist(Xt) @@ -741,118 +1076,232 @@ def test_semirelaxed_gromov_barycenter(nx): # test on admissible stopping criterion with pytest.raises(ValueError): - stop_criterion = 'unknown stop criterion' + stop_criterion = "unknown stop criterion" _ = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42 + n_samples, + [C1, C2], + None, + [0.5, 0.5], + "square_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, ) # test consistency of outputs across backends with 'square_loss' # using different losses # + tests on different inits - init_plan_list = [('fluid', 'fluid'), ("kmeans", "kmeans"), - ('random', 'random')] + init_plan_list = [("fluid", "fluid"), ("kmeans", "kmeans"), ("random", "random")] - for (init, init_b) in init_plan_list: + for init, init_b in init_plan_list: + for stop_criterion in ["barycenter", "loss"]: + print("--- stop_criterion:", stop_criterion) - for stop_criterion in ['barycenter', 'loss']: - print('--- stop_criterion:', stop_criterion) - - if (init == 'fluid') and (not networkx_import): + if (init == "fluid") and (not networkx_import): with pytest.raises(ValueError): Cb = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42, G0=init + n_samples, + [C1, C2], + None, + [0.5, 0.5], + "square_loss", + max_iter=5, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, + G0=init, ) - elif (init == 'kmeans') and (not sklearn_import): + elif (init == "kmeans") and (not sklearn_import): with pytest.raises(ValueError): Cb = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42, G0=init + n_samples, + [C1, C2], + None, + [0.5, 0.5], + "square_loss", + max_iter=5, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, + G0=init, ) else: - Cb = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42, G0=init + n_samples, + [C1, C2], + None, + [0.5, 0.5], + "square_loss", + max_iter=5, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, + G0=init, ) - Cbb = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', - max_iter=5, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42, G0=init_b - )) + Cbb = nx.to_numpy( + ot.gromov.semirelaxed_gromov_barycenters( + n_samples, + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + "square_loss", + max_iter=5, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, + G0=init_b, + ) + ) np.testing.assert_allclose(Cb, Cbb, atol=1e-06) np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) # test of gromov_barycenters with `log` on Cb_, err_ = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], 
[p1, p2], None, 'square_loss', max_iter=5, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - warmstartT=True, random_state=42, log=True, G0=init, + n_samples, + [C1, C2], + [p1, p2], + None, + "square_loss", + max_iter=5, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + warmstartT=True, + random_state=42, + log=True, + G0=init, ) Cbb_, errb_ = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', - max_iter=5, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, warmstartT=True, random_state=42, log=True, G0=init_b + n_samples, + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + "square_loss", + max_iter=5, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + warmstartT=True, + random_state=42, + log=True, + G0=init_b, ) Cbb_ = nx.to_numpy(Cbb_) np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) - np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) + np.testing.assert_array_almost_equal( + err_["err"], nx.to_numpy(*errb_["err"]) + ) np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) # test consistency across backends with larger barycenter than inputs if sklearn_import: C = ot.gromov.semirelaxed_gromov_barycenters( - ns, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, - tol=1e-3, stop_criterion='loss', verbose=False, - random_state=42, G0='kmeans' + ns, + [C1, C2], + None, + [0.5, 0.5], + "square_loss", + max_iter=5, + tol=1e-3, + stop_criterion="loss", + verbose=False, + random_state=42, + G0="kmeans", ) Cb = ot.gromov.semirelaxed_gromov_barycenters( - ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', - max_iter=5, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42, G0='kmeans') + ns, + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + "square_loss", + max_iter=5, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, + G0="kmeans", + ) np.testing.assert_allclose(C, nx.to_numpy(Cb), atol=1e-06) # test providing init_C C_ = ot.gromov.semirelaxed_gromov_barycenters( - ns, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, - tol=1e-3, stop_criterion='loss', verbose=False, - random_state=42, G0=init, init_C=C1 + ns, + [C1, C2], + None, + [0.5, 0.5], + "square_loss", + max_iter=5, + tol=1e-3, + stop_criterion="loss", + verbose=False, + random_state=42, + G0=init, + init_C=C1, ) Cb_ = ot.gromov.semirelaxed_gromov_barycenters( - ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', - max_iter=5, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42, G0=init_b, init_C=C1b) + ns, + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + "square_loss", + max_iter=5, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, + G0=init_b, + init_C=C1b, + ) np.testing.assert_allclose(C_, Cb_, atol=1e-06) # test consistency across backends with 'kl_loss' Cb2, err = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], [.5, .5], 'kl_loss', max_iter=5, - tol=1e-3, warmstartT=False, stop_criterion='loss', log=True, - G0=init_b, random_state=42 + n_samples, + [C1, C2], + [p1, p2], + [0.5, 0.5], + "kl_loss", + max_iter=5, + tol=1e-3, + warmstartT=False, + stop_criterion="loss", + log=True, + G0=init_b, + random_state=42, ) Cb2b, errb = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'kl_loss', max_iter=5, - tol=1e-3, warmstartT=False, stop_criterion='loss', log=True, - G0=init_b, random_state=42 + n_samples, + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + 
"kl_loss", + max_iter=5, + tol=1e-3, + warmstartT=False, + stop_criterion="loss", + log=True, + G0=init_b, + random_state=42, ) Cb2b = nx.to_numpy(Cb2b) try: np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) # may differ from permutation except AssertionError: - np.testing.assert_allclose(err['loss'][-1], errb['loss'][-1], atol=1e-06) + np.testing.assert_allclose(err["loss"][-1], errb["loss"][-1], atol=1e-06) np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) @@ -865,31 +1314,64 @@ def test_semirelaxed_gromov_barycenter(nx): init_Cb = nx.from_numpy(init_C) Cb2_, err2_ = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], [.5, .5], 'square_loss', max_iter=10, - tol=1e-3, verbose=False, random_state=42, log=True, init_C=init_C + n_samples, + [C1, C2], + [p1, p2], + [0.5, 0.5], + "square_loss", + max_iter=10, + tol=1e-3, + verbose=False, + random_state=42, + log=True, + init_C=init_C, ) Cb2b_, err2b_ = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', - max_iter=10, tol=1e-3, verbose=True, random_state=42, - init_C=init_Cb, log=True + n_samples, + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + "square_loss", + max_iter=10, + tol=1e-3, + verbose=True, + random_state=42, + init_C=init_Cb, + log=True, ) Cb2b_ = nx.to_numpy(Cb2b_) np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) - np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) + np.testing.assert_array_almost_equal(err2_["err"], nx.to_numpy(*err2b_["err"])) np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) # test edge cases for gw barycenters: # unique input structure Cb = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1], None, None, 'square_loss', max_iter=1, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42 + n_samples, + [C1], + None, + None, + "square_loss", + max_iter=1, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, + ) + Cbb = nx.to_numpy( + ot.gromov.semirelaxed_gromov_barycenters( + n_samples, + [C1b], + None, + [1.0], + "square_loss", + max_iter=1, + tol=1e-3, + stop_criterion=stop_criterion, + verbose=False, + random_state=42, + ) ) - Cbb = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1b], None, [1.], 'square_loss', - max_iter=1, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42 - )) np.testing.assert_allclose(Cb, Cbb, atol=1e-06) np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) @@ -898,8 +1380,8 @@ def test_semirelaxed_fgw_barycenter(nx): ns = 10 nt = 20 - Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + Xs, ys = ot.datasets.make_data_classif("3gauss", ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif("3gauss2", nt, random_state=42) rng = np.random.RandomState(42) ys = rng.randn(Xs.shape[0], 2) @@ -915,19 +1397,27 @@ def test_semirelaxed_fgw_barycenter(nx): ysb, ytb, C1b, C2b, p1b, p2b = nx.from_numpy(ys, yt, C1, C2, p1, p2) - lambdas = [.5, .5] + lambdas = [0.5, 0.5] Csb = [C1b, C2b] Ysb = [ysb, ytb] Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, Ysb, Csb, None, lambdas, 0.5, fixed_structure=False, - fixed_features=False, loss_fun='square_loss', max_iter=10, tol=1e-3, - random_state=12345, log=True + n_samples, + Ysb, + Csb, + None, + lambdas, + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="square_loss", + max_iter=10, + 
tol=1e-3, + random_state=12345, + log=True, ) # test correspondance with utils function - recovered_Cb = ot.gromov.update_barycenter_structure( - logb['T'], Csb, lambdas) - recovered_Xb = ot.gromov.update_barycenter_feature( - logb['T'], Ysb, lambdas) + recovered_Cb = ot.gromov.update_barycenter_structure(logb["T"], Csb, lambdas) + recovered_Xb = ot.gromov.update_barycenter_feature(logb["T"], Ysb, lambdas) np.testing.assert_allclose(Cb, recovered_Cb) np.testing.assert_allclose(Xb, recovered_Xb) @@ -937,17 +1427,37 @@ def test_semirelaxed_fgw_barycenter(nx): init_C /= init_C.max() init_Cb = nx.from_numpy(init_C) - with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_structure=True`and `init_C=None` + with pytest.raises( + ot.utils.UndefinedParameter + ): # to raise an error when `fixed_structure=True`and `init_C=None` Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, Ysb, Csb, ps=[p1b, p2b], lambdas=None, alpha=0.5, - fixed_structure=True, init_C=None, fixed_features=False, - loss_fun='square_loss', max_iter=10, tol=1e-3 + n_samples, + Ysb, + Csb, + ps=[p1b, p2b], + lambdas=None, + alpha=0.5, + fixed_structure=True, + init_C=None, + fixed_features=False, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, ) Xb, Cb = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, - alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, - loss_fun='square_loss', max_iter=10, tol=1e-3 + n_samples, + [ysb, ytb], + [C1b, C2b], + ps=[p1b, p2b], + lambdas=None, + alpha=0.5, + fixed_structure=True, + init_C=init_Cb, + fixed_features=False, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, ) Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) @@ -957,18 +1467,44 @@ def test_semirelaxed_fgw_barycenter(nx): init_Xb = nx.from_numpy(init_X) # Tests with `fixed_structure=False` and `fixed_features=True` - with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_features=True`and `init_X=None` + with pytest.raises( + ot.utils.UndefinedParameter + ): # to raise an error when `fixed_features=True`and `init_X=None` Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=None, - loss_fun='square_loss', max_iter=10, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=True, + init_X=None, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + warmstartT=True, + log=True, + random_state=98765, + verbose=True, ) Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_Xb, - loss_fun='square_loss', max_iter=10, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=True, + init_X=init_Xb, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + warmstartT=True, + log=True, + random_state=98765, + verbose=True, ) X, C = nx.to_numpy(Xb), nx.to_numpy(Cb) @@ -977,20 +1513,47 @@ def test_semirelaxed_fgw_barycenter(nx): # add test with 'kl_loss' with pytest.raises(ValueError): - stop_criterion = 'unknown stop criterion' + stop_criterion = "unknown 
stop criterion" X, C, log = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='kl_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=C, - init_X=X, warmstartT=True, random_state=12345, log=True + n_samples, + [ys, yt], + [C1, C2], + [p1, p2], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="kl_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + init_C=C, + init_X=X, + warmstartT=True, + random_state=12345, + log=True, ) - for stop_criterion in ['barycenter', 'loss']: + for stop_criterion in ["barycenter", "loss"]: X, C, log = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='kl_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=C, - init_X=X, warmstartT=True, random_state=12345, log=True, verbose=True + n_samples, + [ys, yt], + [C1, C2], + [p1, p2], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="kl_loss", + max_iter=10, + tol=1e-3, + stop_criterion=stop_criterion, + init_C=C, + init_X=X, + warmstartT=True, + random_state=12345, + log=True, + verbose=True, ) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) @@ -998,59 +1561,128 @@ def test_semirelaxed_fgw_barycenter(nx): # test correspondance with utils function recovered_C = ot.gromov.update_barycenter_structure( - log['T'], [C1, C2], lambdas, None, 'kl_loss', True) + log["T"], [C1, C2], lambdas, None, "kl_loss", True + ) np.testing.assert_allclose(C, recovered_C) # test consistency of outputs across backends with 'square_loss' # with various initialization of G0 - init_plan_list = [('fluid', 'fluid'), ("kmeans", "kmeans"), - ('product', 'product'), ('random', 'random')] + init_plan_list = [ + ("fluid", "fluid"), + ("kmeans", "kmeans"), + ("product", "product"), + ("random", "random"), + ] - for (init, init_b) in init_plan_list: - print(f'---- init : {init} / init_b : {init_b}') + for init, init_b in init_plan_list: + print(f"---- init : {init} / init_b : {init_b}") - if (init == 'fluid') and (not networkx_import): + if (init == "fluid") and (not networkx_import): with pytest.raises(ValueError): - X, C, log = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, - warmstartT=True, random_state=12345, log=True, verbose=True + n_samples, + [ys, yt], + [C1, C2], + [p1, p2], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + stop_criterion="loss", + G0=init, + warmstartT=True, + random_state=12345, + log=True, + verbose=True, ) - elif (init == 'kmeans') and (not sklearn_import): + elif (init == "kmeans") and (not sklearn_import): with pytest.raises(ValueError): - X, C, log = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, - warmstartT=True, random_state=12345, log=True, verbose=True + n_samples, + [ys, yt], + [C1, C2], + [p1, p2], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="square_loss", + 
max_iter=10, + tol=1e-3, + stop_criterion="loss", + G0=init, + warmstartT=True, + random_state=12345, + log=True, + verbose=True, ) else: X, C, log = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, - warmstartT=True, random_state=12345, log=True, verbose=True + n_samples, + [ys, yt], + [C1, C2], + [p1, p2], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + stop_criterion="loss", + G0=init, + warmstartT=True, + random_state=12345, + log=True, + verbose=True, ) Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0=init_b, - warmstartT=True, random_state=12345, log=True, verbose=True + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + stop_criterion="loss", + G0=init_b, + warmstartT=True, + random_state=12345, + log=True, + verbose=True, ) np.testing.assert_allclose(X, nx.to_numpy(Xb)) np.testing.assert_allclose(C, nx.to_numpy(Cb)) # test while providing advanced T inits and init_X != None, and init_C !=None Xb_, Cb_, logb_ = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0='random', - warmstartT=True, random_state=12345, log=True, verbose=True, - init_C=Cb, init_X=Xb + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + stop_criterion="loss", + G0="random", + warmstartT=True, + random_state=12345, + log=True, + verbose=True, + init_C=Cb, + init_X=Xb, ) np.testing.assert_allclose(Xb, Xb_) np.testing.assert_allclose(Cb, Cb_) @@ -1058,24 +1690,64 @@ def test_semirelaxed_fgw_barycenter(nx): # test consistency of backends while barycenter size not strictly inferior to sizes if sklearn_import: Xb_, Cb_, logb_ = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0='kmeans', - warmstartT=True, random_state=12345, log=True, verbose=True, - init_C=Cb, init_X=Xb + n_samples, + [ysb, ytb], + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + stop_criterion="loss", + G0="kmeans", + warmstartT=True, + random_state=12345, + log=True, + verbose=True, + init_C=Cb, + init_X=Xb, ) X, C, log = ot.gromov.semirelaxed_fgw_barycenters( - ns, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0='kmeans', - warmstartT=True, random_state=12345, log=True, verbose=True + ns, + [ys, yt], + [C1, C2], + [p1, p2], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + stop_criterion="loss", + G0="kmeans", + warmstartT=True, + 
random_state=12345, + log=True, + verbose=True, ) Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( - ns, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0='kmeans', - warmstartT=True, random_state=12345, log=True, verbose=True + ns, + [ysb, ytb], + [C1b, C2b], + [p1b, p2b], + [0.5, 0.5], + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="square_loss", + max_iter=10, + tol=1e-3, + stop_criterion="loss", + G0="kmeans", + warmstartT=True, + random_state=12345, + log=True, + verbose=True, ) np.testing.assert_allclose(X, nx.to_numpy(Xb)) np.testing.assert_allclose(C, nx.to_numpy(Cb)) @@ -1083,16 +1755,40 @@ def test_semirelaxed_fgw_barycenter(nx): # test edge cases for semirelaxed fgw barycenters: # unique input structure X, C = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ys], [C1], [p1], None, 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=2, tol=1e-3, stop_criterion=stop_criterion, - warmstartT=True, random_state=12345, log=False, verbose=False + n_samples, + [ys], + [C1], + [p1], + None, + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="square_loss", + max_iter=2, + tol=1e-3, + stop_criterion=stop_criterion, + warmstartT=True, + random_state=12345, + log=False, + verbose=False, ) Xb, Cb = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ysb], [C1b], [p1b], [1.], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=2, tol=1e-3, stop_criterion=stop_criterion, - warmstartT=True, random_state=12345, log=False, verbose=False + n_samples, + [ysb], + [C1b], + [p1b], + [1.0], + 0.5, + fixed_structure=False, + fixed_features=False, + loss_fun="square_loss", + max_iter=2, + tol=1e-3, + stop_criterion=stop_criterion, + warmstartT=True, + random_state=12345, + log=False, + verbose=False, ) np.testing.assert_allclose(C, Cb, atol=1e-06) diff --git a/test/gromov/test_utils.py b/test/gromov/test_utils.py index b0338c84a..107d0d2c3 100644 --- a/test/gromov/test_utils.py +++ b/test/gromov/test_utils.py @@ -1,4 +1,4 @@ -""" Tests for gromov._utils.py """ +"""Tests for gromov._utils.py""" # Author: Cédric Vincent-Cuaz # @@ -9,16 +9,15 @@ import pytest import ot -from ot.gromov._utils import ( - networkx_import, sklearn_import) +from ot.gromov._utils import networkx_import, sklearn_import def test_update_barycenter(nx): ns = 5 nt = 10 - Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + Xs, ys = ot.datasets.make_data_classif("3gauss", ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif("3gauss2", nt, random_state=42) rng = np.random.RandomState(42) ys = rng.randn(Xs.shape[0], 2) @@ -34,25 +33,28 @@ def test_update_barycenter(nx): ysb, ytb, C1b, C2b, p1b, p2b = nx.from_numpy(ys, yt, C1, C2, p1, p2) - lambdas = [.5, .5] + lambdas = [0.5, 0.5] Csb = [C1b, C2b] Ysb = [ysb, ytb] Tb = [nx.ones((m, n_samples), type_as=C1b) / (m * n_samples) for m in [ns, nt]] - pb = nx.concatenate( - [nx.sum(elem, 0)[None, :] for elem in Tb], axis=0) + pb = nx.concatenate([nx.sum(elem, 0)[None, :] for elem in Tb], axis=0) # test edge cases for the update of the barycenter with `p != None` # and `target=False` Cb = ot.gromov.update_barycenter_structure( - [elem.T for elem in Tb], Csb, lambdas, pb, target=False) + [elem.T for elem in Tb], Csb, lambdas, pb, target=False + ) Xb = 
ot.gromov.update_barycenter_feature( - [elem.T for elem in Tb], Ysb, lambdas, pb, target=False) + [elem.T for elem in Tb], Ysb, lambdas, pb, target=False + ) Cbt = ot.gromov.update_barycenter_structure( - Tb, Csb, lambdas, None, target=True, check_zeros=False) + Tb, Csb, lambdas, None, target=True, check_zeros=False + ) Xbt = ot.gromov.update_barycenter_feature( - Tb, Ysb, lambdas, None, target=True, check_zeros=False) + Tb, Ysb, lambdas, None, target=True, check_zeros=False + ) np.testing.assert_allclose(Cb, Cbt) np.testing.assert_allclose(Xb, Xbt) @@ -60,18 +62,20 @@ def test_update_barycenter(nx): # test not supported metrics with pytest.raises(ValueError): Cbt = ot.gromov.update_barycenter_structure( - Tb, Csb, lambdas, None, loss_fun='unknown', target=True) + Tb, Csb, lambdas, None, loss_fun="unknown", target=True + ) with pytest.raises(ValueError): Xbt = ot.gromov.update_barycenter_feature( - Tb, Ysb, lambdas, None, loss_fun='unknown', target=True) + Tb, Ysb, lambdas, None, loss_fun="unknown", target=True + ) def test_semirelaxed_init_plan(nx): ns = 5 nt = 10 - Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) - Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + Xs, ys = ot.datasets.make_data_classif("3gauss", ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif("3gauss2", nt, random_state=42) rng = np.random.RandomState(42) ys = rng.randn(Xs.shape[0], 2) @@ -88,11 +92,11 @@ def test_semirelaxed_init_plan(nx): # test not supported method with pytest.raises(ValueError): - _ = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method='unknown') + _ = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method="unknown") if sklearn_import: # tests consistency across backends with m > n - for method in ['kmeans', 'spectral']: + for method in ["kmeans", "spectral"]: T = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method=method) Tb = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method=method) np.testing.assert_allclose(T, Tb) @@ -104,13 +108,13 @@ def test_semirelaxed_init_plan(nx): if networkx_import: # tests consistency across backends with m > n - T = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method='fluid') - Tb = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method='fluid') + T = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method="fluid") + Tb = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method="fluid") np.testing.assert_allclose(T, Tb) # tests consistency across backends with m = n - T = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method='fluid') - Tb = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method='fluid') + T = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method="fluid") + Tb = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method="fluid") np.testing.assert_allclose(T, Tb) @@ -129,7 +133,9 @@ def test_div_between_product(nx, divergence): np.testing.assert_allclose(res_nx, res, atol=1e-06) -@pytest.mark.parametrize("divergence, mass", itertools.product(["kl", "l2"], [True, False])) +@pytest.mark.parametrize( + "divergence, mass", itertools.product(["kl", "l2"], [True, False]) +) def test_div_to_product(nx, divergence, mass): ns = 5 nt = 10 @@ -140,10 +146,18 @@ def test_div_to_product(nx, divergence, mass): pi = 2 * a[:, None] * b[None, :] pi1, pi2 = nx.sum(pi, 1), nx.sum(pi, 0) - res = ot.gromov.div_to_product(pi, a, b, pi1=None, pi2=None, divergence=divergence, mass=mass, nx=None) - res1 = ot.gromov.div_to_product(pi, a, b, pi1=None, pi2=None, divergence=divergence, mass=mass, nx=nx) - res2 = 
ot.gromov.div_to_product(pi, a, b, pi1=pi1, pi2=pi2, divergence=divergence, mass=mass, nx=None) - res3 = ot.gromov.div_to_product(pi, a, b, pi1=pi1, pi2=pi2, divergence=divergence, mass=mass, nx=nx) + res = ot.gromov.div_to_product( + pi, a, b, pi1=None, pi2=None, divergence=divergence, mass=mass, nx=None + ) + res1 = ot.gromov.div_to_product( + pi, a, b, pi1=None, pi2=None, divergence=divergence, mass=mass, nx=nx + ) + res2 = ot.gromov.div_to_product( + pi, a, b, pi1=pi1, pi2=pi2, divergence=divergence, mass=mass, nx=None + ) + res3 = ot.gromov.div_to_product( + pi, a, b, pi1=pi1, pi2=pi2, divergence=divergence, mass=mass, nx=nx + ) np.testing.assert_allclose(res1, res, atol=1e-06) np.testing.assert_allclose(res2, res, atol=1e-06) diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 8fec3e346..7ab1009af 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -23,20 +23,20 @@ def test_emd_1d_emd2_1d_with_weights(): u = rng.randn(n, 1) v = rng.randn(m, 1) - w_u = rng.uniform(0., 1., n) + w_u = rng.uniform(0.0, 1.0, n) w_u = w_u / w_u.sum() - w_v = rng.uniform(0., 1., m) + w_v = rng.uniform(0.0, 1.0, m) w_v = w_v / w_v.sum() - M = ot.dist(u, v, metric='sqeuclidean') + M = ot.dist(u, v, metric="sqeuclidean") G, log = ot.emd(w_u, w_v, M, log=True) wass = log["cost"] - G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True) + G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric="sqeuclidean", log=True) wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False) + wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric="sqeuclidean", log=False) + wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric="euclidean", log=False) # check loss is similar np.testing.assert_allclose(wass, wass1d) @@ -51,10 +51,8 @@ def test_emd_1d_emd2_1d_with_weights(): np.testing.assert_allclose(w_v, G.sum(0)) # check that an error is raised if the metric is not a Minkowski one - np.testing.assert_raises(ValueError, ot.emd_1d, - u, v, w_u, w_v, metric='cosine') - np.testing.assert_raises(ValueError, ot.emd2_1d, - u, v, w_u, w_v, metric='cosine') + np.testing.assert_raises(ValueError, ot.emd_1d, u, v, w_u, w_v, metric="cosine") + np.testing.assert_raises(ValueError, ot.emd2_1d, u, v, w_u, w_v, metric="cosine") def test_wasserstein_1d(nx): @@ -70,12 +68,13 @@ def test_wasserstein_1d(nx): xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) # test 1 : wasserstein_1d should be close to scipy W_1 implementation - np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1), - wasserstein_distance(x, x, rho_u, rho_v)) + np.testing.assert_almost_equal( + wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1), + wasserstein_distance(x, x, rho_u, rho_v), + ) # test 2 : wasserstein_1d should be close to one when only translating the support - np.testing.assert_almost_equal(wasserstein_1d(xb, xb + 1, p=2), - 1.) 
+ np.testing.assert_almost_equal(wasserstein_1d(xb, xb + 1, p=2), 1.0) # test 3 : arrays test X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1) @@ -121,7 +120,7 @@ def test_wasserstein_1d_device_tf(): res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) nx.assert_same_dtype_device(xb, res) - if len(tf.config.list_physical_devices('GPU')) > 0: + if len(tf.config.list_physical_devices("GPU")) > 0: # Check that everything happens on the GPU xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) @@ -137,14 +136,14 @@ def test_emd_1d_emd2_1d(): u = rng.randn(n, 1) v = rng.randn(m, 1) - M = ot.dist(u, v, metric='sqeuclidean') + M = ot.dist(u, v, metric="sqeuclidean") G, log = ot.emd([], [], M, log=True) wass = log["cost"] - G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) + G_1d, log = ot.emd_1d(u, v, [], [], metric="sqeuclidean", log=True) wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) + wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric="sqeuclidean", log=False) + wass1d_euc = ot.emd2_1d(u, v, [], [], metric="euclidean", log=False) # check loss is similar np.testing.assert_allclose(wass, wass1d) @@ -209,7 +208,7 @@ def test_emd1d_device_tf(): nx.assert_same_dtype_device(xb, emd) nx.assert_same_dtype_device(xb, emd2) - if len(tf.config.list_physical_devices('GPU')) > 0: + if len(tf.config.list_physical_devices("GPU")) > 0: # Check that everything happens on the GPU xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) @@ -224,13 +223,17 @@ def test_wasserstein_1d_circle(): n = 20 m = 30 rng = np.random.RandomState(0) - u = rng.rand(n,) - v = rng.rand(m,) - - w_u = rng.uniform(0., 1., n) + u = rng.rand( + n, + ) + v = rng.rand( + m, + ) + + w_u = rng.uniform(0.0, 1.0, n) w_u = w_u / w_u.sum() - w_v = rng.uniform(0., 1., m) + w_v = rng.uniform(0.0, 1.0, m) w_v = w_v / w_v.sum() M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) @@ -281,8 +284,12 @@ def test_wasserstein_1d_unif_circle(): m = 1000 rng = np.random.RandomState(0) - u = rng.rand(n,) - v = rng.rand(m,) + u = rng.rand( + n, + ) + v = rng.rand( + m, + ) # w_u = rng.uniform(0., 1., n) # w_u = w_u / w_u.sum() @@ -323,8 +330,12 @@ def test_binary_search_circle_log(): n = 20 m = 30 rng = np.random.RandomState(0) - u = rng.rand(n,) - v = rng.rand(m,) + u = rng.rand( + n, + ) + v = rng.rand( + m, + ) wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True) optimal_thetas = log["optimal_theta"] diff --git a/test/test_backend.py b/test/test_backend.py index 95ec3293f..435c6db8a 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -1,4 +1,4 @@ -"""Tests for backend module """ +"""Tests for backend module""" # Author: Remi Flamary # Nicolas Courty @@ -16,7 +16,6 @@ def test_get_backend_list(): - lst = get_backend_list() assert len(lst) > 0 @@ -24,7 +23,6 @@ def test_get_backend_list(): def test_to_numpy(nx): - v = nx.zeros(10) M = nx.ones((10, 10)) @@ -46,12 +44,11 @@ def test_get_backend_invalid(): def test_get_backend(nx): - A = np.zeros((3, 2)) B = np.zeros((3, 1)) nx_np = get_backend(A) - assert nx_np.__name__ == 'numpy' + assert nx_np.__name__ == "numpy" A2, B2 = nx.from_numpy(A, B) @@ -79,7 +76,6 @@ class nx_subclass(nx.__type__): def test_convert_between_backends(nx): - A = np.zeros((3, 2)) B = np.zeros((3, 1)) @@ -98,7 +94,6 @@ def test_convert_between_backends(nx): 
def test_empty_backend(): - rnd = np.random.RandomState(0) M = rnd.randn(10, 3) v = rnd.randn(3) @@ -156,7 +151,7 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.isinf(M) with pytest.raises(NotImplementedError): - nx.einsum('ij->i', M) + nx.einsum("ij->i", M) with pytest.raises(NotImplementedError): nx.sort(M) with pytest.raises(NotImplementedError): @@ -279,7 +274,6 @@ def test_empty_backend(): def test_func_backends(nx): - rnd = np.random.RandomState(0) M = rnd.randn(10, 3) SquareM = rnd.randn(10, 10) @@ -297,8 +291,7 @@ def test_func_backends(nx): lst_tot = [] for nx in [ot.backend.NumpyBackend(), nx]: - - print('Backend: ', nx.__name__) + print("Backend: ", nx.__name__) lst_b = [] lst_name = [] @@ -319,283 +312,288 @@ def test_func_backends(nx): A = nx.set_gradients(val, v, v) lst_b.append(nx.to_numpy(A)) - lst_name.append('set_gradients') + lst_name.append("set_gradients") A = nx.detach(Mb) A, B = nx.detach(Mb, Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('detach') + lst_name.append("detach") A = nx.zeros((10, 3)) A = nx.zeros((10, 3), type_as=Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('zeros') + lst_name.append("zeros") A = nx.ones((10, 3)) A = nx.ones((10, 3), type_as=Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('ones') + lst_name.append("ones") A = nx.arange(10, 1, 2) lst_b.append(nx.to_numpy(A)) - lst_name.append('arange') + lst_name.append("arange") A = nx.full((10, 3), 3.14) A = nx.full((10, 3), 3.14, type_as=Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('full') + lst_name.append("full") A = nx.eye(10, 3) A = nx.eye(10, 3, type_as=Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('eye') + lst_name.append("eye") A = nx.sum(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('sum') + lst_name.append("sum") A = nx.sum(Mb, axis=1, keepdims=True) lst_b.append(nx.to_numpy(A)) - lst_name.append('sum(axis)') + lst_name.append("sum(axis)") A = nx.cumsum(Mb, 0) lst_b.append(nx.to_numpy(A)) - lst_name.append('cumsum(axis)') + lst_name.append("cumsum(axis)") A = nx.max(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('max') + lst_name.append("max") A = nx.max(Mb, axis=1, keepdims=True) lst_b.append(nx.to_numpy(A)) - lst_name.append('max(axis)') + lst_name.append("max(axis)") A = nx.min(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('min') + lst_name.append("min") A = nx.min(Mb, axis=1, keepdims=True) lst_b.append(nx.to_numpy(A)) - lst_name.append('min(axis)') + lst_name.append("min(axis)") A = nx.maximum(vb, 0) lst_b.append(nx.to_numpy(A)) - lst_name.append('maximum') + lst_name.append("maximum") A = nx.minimum(vb, 0) lst_b.append(nx.to_numpy(A)) - lst_name.append('minimum') + lst_name.append("minimum") A = nx.sign(vb) lst_b.append(nx.to_numpy(A)) - lst_name.append('sign') + lst_name.append("sign") A = nx.abs(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('abs') + lst_name.append("abs") A = nx.log(A) lst_b.append(nx.to_numpy(A)) - lst_name.append('log') + lst_name.append("log") A = nx.exp(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('exp') + lst_name.append("exp") A = nx.sqrt(nx.abs(Mb)) lst_b.append(nx.to_numpy(A)) - lst_name.append('sqrt') + lst_name.append("sqrt") A = nx.power(Mb, 2) lst_b.append(nx.to_numpy(A)) - lst_name.append('power') + lst_name.append("power") A = nx.dot(vb, vb) lst_b.append(nx.to_numpy(A)) - lst_name.append('dot(v,v)') + lst_name.append("dot(v,v)") A = nx.dot(Mb, vb) lst_b.append(nx.to_numpy(A)) - lst_name.append('dot(M,v)') + lst_name.append("dot(M,v)") A = nx.dot(Mb, Mb.T) 
lst_b.append(nx.to_numpy(A)) - lst_name.append('dot(M,M)') + lst_name.append("dot(M,M)") A = nx.norm(vb) lst_b.append(nx.to_numpy(A)) - lst_name.append('norm') + lst_name.append("norm") A = nx.norm(Mb, axis=1) lst_b.append(nx.to_numpy(A)) - lst_name.append('norm(M,axis=1)') + lst_name.append("norm(M,axis=1)") A = nx.norm(Mb, axis=1, keepdims=True) lst_b.append(nx.to_numpy(A)) - lst_name.append('norm(M,axis=1,keepdims=True)') + lst_name.append("norm(M,axis=1,keepdims=True)") A = nx.any(vb > 0) lst_b.append(nx.to_numpy(A)) - lst_name.append('any') + lst_name.append("any") A = nx.isnan(vb) lst_b.append(nx.to_numpy(A)) - lst_name.append('isnan') + lst_name.append("isnan") A = nx.isinf(vb) lst_b.append(nx.to_numpy(A)) - lst_name.append('isinf') + lst_name.append("isinf") - A = nx.einsum('ij->i', Mb) + A = nx.einsum("ij->i", Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('einsum(ij->i)') + lst_name.append("einsum(ij->i)") - A = nx.einsum('ij,j->i', Mb, vb) + A = nx.einsum("ij,j->i", Mb, vb) lst_b.append(nx.to_numpy(A)) - lst_name.append('nx.einsum(ij,j->i)') + lst_name.append("nx.einsum(ij,j->i)") - A = nx.einsum('ij->i', Mb) + A = nx.einsum("ij->i", Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('nx.einsum(ij->i)') + lst_name.append("nx.einsum(ij->i)") A = nx.sort(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('sort') + lst_name.append("sort") A = nx.argsort(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('argsort') + lst_name.append("argsort") tmp = nx.sort(Mb) - A = nx.searchsorted(tmp, tmp, 'right') + A = nx.searchsorted(tmp, tmp, "right") lst_b.append(nx.to_numpy(A)) - lst_name.append('searchsorted') + lst_name.append("searchsorted") A = nx.flip(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('flip') + lst_name.append("flip") A = nx.outer(vb, vb) lst_b.append(nx.to_numpy(A)) - lst_name.append('outer') + lst_name.append("outer") A = nx.clip(vb, 0, 1) lst_b.append(nx.to_numpy(A)) - lst_name.append('clip') + lst_name.append("clip") A = nx.repeat(Mb, 0) A = nx.repeat(Mb, 2, -1) lst_b.append(nx.to_numpy(A)) - lst_name.append('repeat') + lst_name.append("repeat") A = nx.take_along_axis(vb, nx.arange(3), -1) lst_b.append(nx.to_numpy(A)) - lst_name.append('take_along_axis') + lst_name.append("take_along_axis") A = nx.concatenate((Mb, Mb), -1) lst_b.append(nx.to_numpy(A)) - lst_name.append('concatenate') + lst_name.append("concatenate") A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)]) lst_b.append(nx.to_numpy(A)) - lst_name.append('zero_pad') + lst_name.append("zero_pad") A = nx.argmax(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('argmax') + lst_name.append("argmax") A = nx.argmin(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('argmin') + lst_name.append("argmin") A = nx.mean(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('mean') + lst_name.append("mean") A = nx.median(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('median') + lst_name.append("median") A = nx.std(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('std') + lst_name.append("std") A = nx.linspace(0, 1, 50) A = nx.linspace(0, 1, 50, type_as=Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('linspace') + lst_name.append("linspace") X, Y = nx.meshgrid(vb, vb) lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)])) - lst_name.append('meshgrid') + lst_name.append("meshgrid") A = nx.diag(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('diag2D') + lst_name.append("diag2D") A = nx.diag(vb, 1) lst_b.append(nx.to_numpy(A)) - lst_name.append('diag1D') + lst_name.append("diag1D") A = 
nx.unique(nx.from_numpy(np.stack([M, M]))) lst_b.append(nx.to_numpy(A)) - lst_name.append('unique') + lst_name.append("unique") - A, A2 = nx.unique(nx.from_numpy(np.stack([M, M]).reshape(-1)), return_inverse=True) + A, A2 = nx.unique( + nx.from_numpy(np.stack([M, M]).reshape(-1)), return_inverse=True + ) lst_b.append(nx.to_numpy(A)) - lst_name.append('unique(M,return_inverse=True)[0]') + lst_name.append("unique(M,return_inverse=True)[0]") lst_b.append(nx.to_numpy(A2)) - lst_name.append('unique(M,return_inverse=True)[1]') + lst_name.append("unique(M,return_inverse=True)[1]") A = nx.logsumexp(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('logsumexp') + lst_name.append("logsumexp") A = nx.stack([Mb, Mb]) lst_b.append(nx.to_numpy(A)) - lst_name.append('stack') + lst_name.append("stack") A = nx.reshape(Mb, (5, 3, 2)) lst_b.append(nx.to_numpy(A)) - lst_name.append('reshape') + lst_name.append("reshape") sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4)) nx.todense(Mb) lst_b.append(nx.to_numpy(nx.todense(sp_Mb))) - lst_name.append('coo_matrix') + lst_name.append("coo_matrix") - assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)' - assert nx.issparse(sp_Mb) or nx.__name__ in ("jax", "tf"), 'Assert fail on: issparse (expected True)' + assert not nx.issparse(Mb), "Assert fail on: issparse (expected False)" + assert nx.issparse(sp_Mb) or nx.__name__ in ( + "jax", + "tf", + ), "Assert fail on: issparse (expected True)" A = nx.tocsr(sp_Mb) lst_b.append(nx.to_numpy(nx.todense(A))) - lst_name.append('tocsr') + lst_name.append("tocsr") - A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.) + A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.0) lst_b.append(nx.to_numpy(A)) - lst_name.append('eliminate_zeros (dense)') + lst_name.append("eliminate_zeros (dense)") A = nx.eliminate_zeros(sp_Mb) lst_b.append(nx.to_numpy(nx.todense(A))) - lst_name.append('eliminate_zeros (sparse)') + lst_name.append("eliminate_zeros (sparse)") A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0) lst_b.append(nx.to_numpy(A)) - lst_name.append('where (cond, x, y)') + lst_name.append("where (cond, x, y)") A = nx.where(nx.from_numpy(np.array([True, False]))) lst_b.append(nx.to_numpy(nx.stack(A))) - lst_name.append('where (cond)') + lst_name.append("where (cond)") A = nx.copy(Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('copy') + lst_name.append("copy") - assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)' - assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)' + assert nx.allclose(Mb, Mb), "Assert fail on: allclose (expected True)" + assert not nx.allclose(2 * Mb, Mb), "Assert fail on: allclose (expected False)" A = nx.squeeze(nx.zeros((3, 1, 4, 1))) - assert tuple(A.shape) == (3, 4), 'Assert fail on: squeeze' + assert tuple(A.shape) == (3, 4), "Assert fail on: squeeze" A = nx.bitsize(Mb) lst_b.append(float(A)) @@ -608,15 +606,15 @@ def test_func_backends(nx): A = nx.solve(SquareMb, Mb) lst_b.append(nx.to_numpy(A)) - lst_name.append('solve') + lst_name.append("solve") A = nx.trace(SquareMb) lst_b.append(nx.to_numpy(A)) - lst_name.append('trace') + lst_name.append("trace") A = nx.inv(SquareMb) lst_b.append(nx.to_numpy(A)) - lst_name.append('matrix inverse') + lst_name.append("matrix inverse") A = nx.sqrtm(SquareMb.T @ SquareMb) lst_b.append(nx.to_numpy(A)) @@ -711,12 +709,11 @@ def test_func_backends(nx): for a1, a2, name in zip(lst_np, lst_b, lst_name): np.testing.assert_allclose( - a2, a1, atol=1e-7, 
err_msg=f'ASSERT FAILED ON: {name}' + a2, a1, atol=1e-7, err_msg=f"ASSERT FAILED ON: {name}" ) def test_random_backends(nx): - tmp_u = nx.rand() assert tmp_u < 1 @@ -745,14 +742,12 @@ def test_random_backends(nx): def test_gradients_backends(): - rnd = np.random.RandomState(0) v = rnd.randn(10) c = rnd.randn() e = rnd.randn() if torch: - nx = ot.backend.TorchBackend() v2 = torch.tensor(v, requires_grad=True) @@ -770,24 +765,26 @@ def test_gradients_backends(): if jax: nx = ot.backend.JaxBackend() with jax.checking_leaks(): + def fun(a, b, d): - val = b * nx.sum(a ** 4) + d + val = b * nx.sum(a**4) + d return nx.set_gradients(val, (a, b, d), (a, b, 2 * d)) + grad_val = jax.grad(fun, argnums=(0, 1, 2))(v, c, e) - np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4) + np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v**4) + e, decimal=4) np.testing.assert_allclose(grad_val[0], v, atol=1e-4) np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4) if tf: nx = ot.backend.TensorflowBackend() - w = tf.Variable(tf.random.normal((3, 2)), name='w') - b = tf.Variable(tf.random.normal((2,), dtype=tf.float32), name='b') + w = tf.Variable(tf.random.normal((3, 2)), name="w") + b = tf.Variable(tf.random.normal((2,), dtype=tf.float32), name="b") x = tf.random.normal((1, 3), dtype=tf.float32) with tf.GradientTape() as tape: y = x @ w + b - loss = tf.reduce_mean(y ** 2) + loss = tf.reduce_mean(y**2) manipulated_loss = nx.set_gradients(loss, (w, b), (w, b)) [dl_dw, dl_db] = tape.gradient(manipulated_loss, [w, b]) assert nx.allclose(dl_dw, w) @@ -797,6 +794,6 @@ def fun(a, b, d): def test_get_backend_none(): a, b = np.zeros((2, 3)), None nx = get_backend(a, b) - assert str(nx) == 'numpy' + assert str(nx) == "numpy" with pytest.raises(ValueError): get_backend(None, None) diff --git a/test/test_bregman.py b/test/test_bregman.py index 1a92a1037..6c0c0e8f2 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -1,4 +1,4 @@ -"""Tests for module bregman on OT with bregman projections """ +"""Tests for module bregman on OT with bregman projections""" # Author: Remi Flamary # Kilian Fatras @@ -32,19 +32,23 @@ def test_sinkhorn(verbose, warn): G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn) # check constraints - np.testing.assert_allclose( - u, G.sum(1), atol=1e-05) # cf convergence sinkhorn - np.testing.assert_allclose( - u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn with pytest.warns(UserWarning): ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", - "sinkhorn_epsilon_scaling", - "greenkhorn", - "sinkhorn_log"]) +@pytest.mark.parametrize( + "method", + [ + "sinkhorn", + "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log", + ], +) def test_convergence_warning(method): # test sinkhorn n = 100 @@ -54,24 +58,26 @@ def test_convergence_warning(method): M = ot.utils.dist0(n) with pytest.warns(UserWarning): - ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1) + ot.sinkhorn(a1, a2, M, 1.0, method=method, stopThr=0, numItermax=1) if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]: with pytest.warns(UserWarning): ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) with pytest.warns(UserWarning): - ot.sinkhorn2(a1, a2, M, 1, method=method, - stopThr=0, 
numItermax=1, warn=True) + ot.sinkhorn2( + a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=True + ) with warnings.catch_warnings(): warnings.simplefilter("error") - ot.sinkhorn2(a1, a2, M, 1, method=method, - stopThr=0, numItermax=1, warn=False) + ot.sinkhorn2( + a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=False + ) def test_not_implemented_method(): # test sinkhorn w = 10 - n = w ** 2 + n = w**2 rng = np.random.RandomState(42) A_img = rng.rand(2, w, w) A_flat = A_img.reshape(n, 2) @@ -86,14 +92,13 @@ def test_not_implemented_method(): with pytest.raises(ValueError): ot.barycenter(A_flat, M_flat, reg, method=not_implemented) with pytest.raises(ValueError): - ot.bregman.barycenter_debiased(A_flat, M_flat, reg, - method=not_implemented) + ot.bregman.barycenter_debiased(A_flat, M_flat, reg, method=not_implemented) with pytest.raises(ValueError): - ot.bregman.convolutional_barycenter2d(A_img, reg, - method=not_implemented) + ot.bregman.convolutional_barycenter2d(A_img, reg, method=not_implemented) with pytest.raises(ValueError): - ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, - method=not_implemented) + ot.bregman.convolutional_barycenter2d_debiased( + A_img, reg, method=not_implemented + ) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) @@ -119,14 +124,17 @@ def test_sinkhorn_stabilization(): reg = 1e-5 loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log") loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized") - np.testing.assert_allclose( - loss1, loss2, atol=1e-06) # cf convergence sinkhorn + np.testing.assert_allclose(loss1, loss2, atol=1e-06) # cf convergence sinkhorn -@pytest.mark.parametrize("method, verbose, warn", - product(["sinkhorn", "sinkhorn_stabilized", - "sinkhorn_log"], - [True, False], [True, False])) +@pytest.mark.parametrize( + "method, verbose, warn", + product( + ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], + [True, False], + ), +) def test_sinkhorn_multi_b(method, verbose, warn): # test sinkhorn n = 10 @@ -140,14 +148,16 @@ def test_sinkhorn_multi_b(method, verbose, warn): M = ot.dist(x, x) - loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10, - log=True) + loss0, log = ot.sinkhorn(u, b, M, 0.1, method=method, stopThr=1e-10, log=True) - loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10, - verbose=verbose, warn=warn) for k in range(3)] + loss = [ + ot.sinkhorn2( + u, b[:, k], M, 0.1, method=method, stopThr=1e-10, verbose=verbose, warn=warn + ) + for k in range(3) + ] # check constraints - np.testing.assert_allclose( - loss0, loss, atol=1e-4) # cf convergence sinkhorn + np.testing.assert_allclose(loss0, loss, atol=1e-4) # cf convergence sinkhorn def test_sinkhorn_backends(nx): @@ -202,7 +212,6 @@ def test_sinkhorn2_gradients(): M = ot.dist(x, y) if torch: - a1 = torch.tensor(a, requires_grad=True) b1 = torch.tensor(a, requires_grad=True) M1 = torch.tensor(M, requires_grad=True) @@ -226,8 +235,9 @@ def test_sinkhorn_empty(): M = ot.dist(x, x) - G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log", - verbose=True, log=True) + G, log = ot.sinkhorn( + [], [], M, 1, stopThr=1e-10, method="sinkhorn_log", verbose=True, log=True + ) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) @@ -237,24 +247,39 @@ def test_sinkhorn_empty(): np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) - G, log = 
ot.sinkhorn([], [], M, 1, stopThr=1e-10, - method='sinkhorn_stabilized', verbose=True, log=True) + G, log = ot.sinkhorn( + [], + [], + M, + 1, + stopThr=1e-10, + method="sinkhorn_stabilized", + verbose=True, + log=True, + ) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) G, log = ot.sinkhorn( - [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling', - verbose=True, log=True) + [], + [], + M, + 1, + stopThr=1e-10, + method="sinkhorn_epsilon_scaling", + verbose=True, + log=True, + ) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) # test empty weights greenkhorn - ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True) + ot.sinkhorn([], [], M, 1, method="greenkhorn", stopThr=1e-10, log=True) -@pytest.skip_backend('tf') +@pytest.skip_backend("tf") @pytest.skip_backend("jax") def test_sinkhorn_variants(nx): # test sinkhorn @@ -268,17 +293,18 @@ def test_sinkhorn_variants(nx): ub, M_nx = nx.from_numpy(u, M) - G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn( - ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn( - ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn( - ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) - Ges = nx.to_numpy(ot.sinkhorn( - ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) - G_green = nx.to_numpy(ot.sinkhorn( - ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) + G = ot.sinkhorn(u, u, M, 1, method="sinkhorn", stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method="sinkhorn_log", stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method="sinkhorn", stopThr=1e-10)) + Gs = nx.to_numpy( + ot.sinkhorn(ub, ub, M_nx, 1, method="sinkhorn_stabilized", stopThr=1e-10) + ) + Ges = nx.to_numpy( + ot.sinkhorn(ub, ub, M_nx, 1, method="sinkhorn_epsilon_scaling", stopThr=1e-10) + ) + G_green = nx.to_numpy( + ot.sinkhorn(ub, ub, M_nx, 1, method="greenkhorn", stopThr=1e-10) + ) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -288,14 +314,40 @@ def test_sinkhorn_variants(nx): np.testing.assert_allclose(G0, G_green, atol=1e-5) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", - "sinkhorn_epsilon_scaling", - "greenkhorn", - "sinkhorn_log"]) -@pytest.skip_arg(("nx", "method"), ("tf", "sinkhorn_epsilon_scaling"), reason="tf does not support sinkhorn_epsilon_scaling", getter=str) -@pytest.skip_arg(("nx", "method"), ("tf", "greenkhorn"), reason="tf does not support greenkhorn", getter=str) -@pytest.skip_arg(("nx", "method"), ("jax", "sinkhorn_epsilon_scaling"), reason="jax does not support sinkhorn_epsilon_scaling", getter=str) -@pytest.skip_arg(("nx", "method"), ("jax", "greenkhorn"), reason="jax does not support greenkhorn", getter=str) +@pytest.mark.parametrize( + "method", + [ + "sinkhorn", + "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log", + ], +) +@pytest.skip_arg( + ("nx", "method"), + ("tf", "sinkhorn_epsilon_scaling"), + reason="tf does not support sinkhorn_epsilon_scaling", + getter=str, +) +@pytest.skip_arg( + ("nx", "method"), + ("tf", "greenkhorn"), + reason="tf does not support greenkhorn", + getter=str, +) +@pytest.skip_arg( + ("nx", "method"), + ("jax", "sinkhorn_epsilon_scaling"), + reason="jax does not support sinkhorn_epsilon_scaling", + getter=str, +) 
+@pytest.skip_arg( + ("nx", "method"), + ("jax", "greenkhorn"), + reason="jax does not support greenkhorn", + getter=str, +) def test_sinkhorn_variants_dtype_device(nx, method): n = 100 @@ -361,11 +413,11 @@ def test_sinkhorn2_variants_device_tf(method): nx.assert_same_dtype_device(Mb, lossb) # Check this only if GPU is available - if len(tf.config.list_physical_devices('GPU')) > 0: + if len(tf.config.list_physical_devices("GPU")) > 0: assert nx.dtype_device(Gb)[1].startswith("GPU") -@pytest.skip_backend('tf') +@pytest.skip_backend("tf") @pytest.skip_backend("jax") def test_sinkhorn_variants_multi_b(nx): # test sinkhorn @@ -382,13 +434,12 @@ def test_sinkhorn_variants_multi_b(nx): ub, bb, M_nx = nx.from_numpy(u, b, M) - G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn( - ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn( - ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn( - ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + G = ot.sinkhorn(u, b, M, 1, method="sinkhorn", stopThr=1e-10) + Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method="sinkhorn_log", stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method="sinkhorn", stopThr=1e-10)) + Gs = nx.to_numpy( + ot.sinkhorn(ub, bb, M_nx, 1, method="sinkhorn_stabilized", stopThr=1e-10) + ) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -396,7 +447,7 @@ def test_sinkhorn_variants_multi_b(nx): np.testing.assert_allclose(G0, Gs, atol=1e-05) -@pytest.skip_backend('tf') +@pytest.skip_backend("tf") @pytest.skip_backend("jax") def test_sinkhorn2_variants_multi_b(nx): # test sinkhorn @@ -413,13 +464,14 @@ def test_sinkhorn2_variants_multi_b(nx): ub, bb, M_nx = nx.from_numpy(u, b, M) - G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn2( - ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn2( - ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn2( - ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + G = ot.sinkhorn2(u, b, M, 1, method="sinkhorn", stopThr=1e-10) + Gl = nx.to_numpy( + ot.sinkhorn2(ub, bb, M_nx, 1, method="sinkhorn_log", stopThr=1e-10) + ) + G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method="sinkhorn", stopThr=1e-10)) + Gs = nx.to_numpy( + ot.sinkhorn2(ub, bb, M_nx, 1, method="sinkhorn_stabilized", stopThr=1e-10) + ) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -437,16 +489,23 @@ def test_sinkhorn_variants_log(): M = ot.dist(x, x) - G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', - stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) + G0, log0 = ot.sinkhorn(u, u, M, 1, method="sinkhorn", stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, u, M, 1, method="sinkhorn_log", stopThr=1e-10, log=True) Gs, logs = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + u, u, M, 1, method="sinkhorn_stabilized", stopThr=1e-10, log=True + ) Ges, loges = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,) + u, + u, + M, + 1, + method="sinkhorn_epsilon_scaling", + stopThr=1e-10, + log=True, + ) G_green, loggreen = ot.sinkhorn( - u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) + u, u, M, 1, method="greenkhorn", stopThr=1e-10, log=True + ) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) @@ -468,21 
+527,43 @@ def test_sinkhorn_variants_log_multib(verbose, warn): M = ot.dist(x, x) - G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', - stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True, - verbose=verbose, warn=warn) - Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True, - verbose=verbose, warn=warn) + G0, log0 = ot.sinkhorn(u, b, M, 1, method="sinkhorn", stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn( + u, + b, + M, + 1, + method="sinkhorn_log", + stopThr=1e-10, + log=True, + verbose=verbose, + warn=warn, + ) + Gs, logs = ot.sinkhorn( + u, + b, + M, + 1, + method="sinkhorn_stabilized", + stopThr=1e-10, + log=True, + verbose=verbose, + warn=warn, + ) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Gl, atol=1e-05) -@pytest.mark.parametrize("method, verbose, warn", - product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], - [True, False], [True, False])) +@pytest.mark.parametrize( + "method, verbose, warn", + product( + ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], + [True, False], + ), +) def test_barycenter(nx, method, verbose, warn): n_bins = 100 # nb bins @@ -509,9 +590,11 @@ def test_barycenter(nx, method, verbose, warn): else: # wasserstein bary_wass_np = ot.bregman.barycenter( - A, M, reg, weights, method=method, verbose=verbose, warn=warn) + A, M, reg, weights, method=method, verbose=verbose, warn=warn + ) bary_wass, _ = ot.bregman.barycenter( - A_nx, M_nx, reg, weights_nx, method=method, log=True) + A_nx, M_nx, reg, weights_nx, method=method, log=True + ) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) @@ -522,33 +605,39 @@ def test_barycenter(nx, method, verbose, warn): def test_free_support_sinkhorn_barycenter(): measures_locations = [ - np.array([-1.]).reshape((1, 1)), # First dirac support - np.array([1.]).reshape((1, 1)) # Second dirac support + np.array([-1.0]).reshape((1, 1)), # First dirac support + np.array([1.0]).reshape((1, 1)), # Second dirac support ] measures_weights = [ - np.array([1.]), # First dirac sample weights - np.array([1.]) # Second dirac sample weights + np.array([1.0]), # First dirac sample weights + np.array([1.0]), # Second dirac sample weights ] # Barycenter initialization - X_init = np.array([-12.]).reshape((1, 1)) + X_init = np.array([-12.0]).reshape((1, 1)) # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter - bar_locations = np.array([0.]).reshape((1, 1)) + bar_locations = np.array([0.0]).reshape((1, 1)) # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization # term to 1, but this should be, in general, fine-tuned to the problem. 
X = ot.bregman.free_support_sinkhorn_barycenter( - measures_locations, measures_weights, X_init, reg=1) + measures_locations, measures_weights, X_init, reg=1 + ) # Verifies if calculated barycenter matches ground-truth np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) -@pytest.mark.parametrize("method, verbose, warn", - product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], - [True, False], [True, False])) +@pytest.mark.parametrize( + "method, verbose, warn", + product( + ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], + [True, False], + ), +) def test_barycenter_assymetric_cost(nx, method, verbose, warn): n_bins = 20 # nb bins @@ -572,9 +661,9 @@ def test_barycenter_assymetric_cost(nx, method, verbose, warn): else: # wasserstein bary_wass_np = ot.bregman.barycenter( - A, M, reg, method=method, verbose=verbose, warn=warn) - bary_wass, _ = ot.bregman.barycenter( - A_nx, M_nx, reg, method=method, log=True) + A, M, reg, method=method, verbose=verbose, warn=warn + ) + bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass)) @@ -583,9 +672,10 @@ def test_barycenter_assymetric_cost(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, log=True) -@pytest.mark.parametrize("method, verbose, warn", - product(["sinkhorn", "sinkhorn_log"], - [True, False], [True, False])) +@pytest.mark.parametrize( + "method, verbose, warn", + product(["sinkhorn", "sinkhorn_log"], [True, False], [True, False]), +) def test_barycenter_debiased(nx, method, verbose, warn): n_bins = 100 # nb bins @@ -609,26 +699,26 @@ def test_barycenter_debiased(nx, method, verbose, warn): reg = 1e-2 if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): - ot.bregman.barycenter_debiased( - A_nx, M_nx, reg, weights, method=method) + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) else: - bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, - verbose=verbose, warn=warn) + bary_wass_np = ot.bregman.barycenter_debiased( + A, M, reg, weights, method=method, verbose=verbose, warn=warn + ) bary_wass, _ = ot.bregman.barycenter_debiased( - A_nx, M_nx, reg, weights_nx, method=method, log=True) + A_nx, M_nx, reg, weights_nx, method=method, log=True + ) bary_wass = nx.to_numpy(bary_wass) np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) - ot.bregman.barycenter_debiased( - A_nx, M_nx, reg, log=True, verbose=False) + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) def test_convergence_warning_barycenters(method): w = 10 - n_bins = w ** 2 # nb bins + n_bins = w**2 # nb bins # Gaussian distributions a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std @@ -647,16 +737,17 @@ def test_convergence_warning_barycenters(method): weights = np.array([1 - alpha, alpha]) reg = 0.1 with pytest.warns(UserWarning): - ot.bregman.barycenter_debiased( - A, M, reg, weights, method=method, numItermax=1) + ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1) with pytest.warns(UserWarning): ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1) with pytest.warns(UserWarning): - ot.bregman.convolutional_barycenter2d(A_img, reg, weights, - method=method, numItermax=1) + 
ot.bregman.convolutional_barycenter2d( + A_img, reg, weights, method=method, numItermax=1 + ) with pytest.warns(UserWarning): - ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights, - method=method, numItermax=1) + ot.bregman.convolutional_barycenter2d_debiased( + A_img, reg, weights, method=method, numItermax=1 + ) def test_barycenter_stabilization(nx): @@ -681,15 +772,24 @@ def test_barycenter_stabilization(nx): # wasserstein reg = 1e-2 bar_np = ot.bregman.barycenter( - A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) - bar_stable = nx.to_numpy(ot.bregman.barycenter( - A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized", - stopThr=1e-8, verbose=True - )) - bar = nx.to_numpy(ot.bregman.barycenter( - A_nx, M_nx, reg, weights_b, method="sinkhorn", - stopThr=1e-8, verbose=True - )) + A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True + ) + bar_stable = nx.to_numpy( + ot.bregman.barycenter( + A_nx, + M_nx, + reg, + weights_b, + method="sinkhorn_stabilized", + stopThr=1e-8, + verbose=True, + ) + ) + bar = nx.to_numpy( + ot.bregman.barycenter( + A_nx, M_nx, reg, weights_b, method="sinkhorn", stopThr=1e-8, verbose=True + ) + ) np.testing.assert_allclose(bar, bar_stable) np.testing.assert_allclose(bar, bar_np) @@ -743,6 +843,7 @@ def test_wasserstein_bary_2d(nx, method): ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) +@pytest.skip_backend("tf") @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) def test_wasserstein_bary_2d_dtype_device(nx, method): # Create the array of images to test @@ -874,6 +975,7 @@ def test_wasserstein_bary_2d_debiased(nx, method): ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) +@pytest.skip_backend("tf") @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) def test_wasserstein_bary_2d_debiased_dtype_device(nx, method): # Create the array of images to test @@ -1002,15 +1104,13 @@ def test_unmix(nx): # wasserstein reg = 1e-3 um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) - um = nx.to_numpy(ot.bregman.unmix( - ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) + um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) np.testing.assert_allclose(um, um_np) - ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, - 1, alpha=0.01, log=True, verbose=True) + ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01, log=True, verbose=True) def test_empirical_sinkhorn(nx): @@ -1022,7 +1122,7 @@ def test_empirical_sinkhorn(nx): X_s = np.reshape(1.0 * np.arange(n), (n, 1)) X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='euclidean') + M_m = ot.dist(X_s, X_t, metric="euclidean") ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) @@ -1034,32 +1134,32 @@ def test_empirical_sinkhorn(nx): sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) - G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn( - X_sb, X_tb, 1, metric='euclidean')) + G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric="euclidean")) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) - loss_emp_sinkhorn = nx.to_numpy( - ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) + loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, 
M_nx, 1)) # check constraints np.testing.assert_allclose( - sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05 + ) # metric sqeuclidian np.testing.assert_allclose( - sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05 + ) # metric sqeuclidian + np.testing.assert_allclose(sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + np.testing.assert_allclose(sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log np.testing.assert_allclose( - sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05 + ) # metric euclidian np.testing.assert_allclose( - sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian - np.testing.assert_allclose( - sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05 + ) # metric euclidian np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) @pytest.mark.skipif(not geomloss, reason="pytorch not installed") -@pytest.skip_backend('tf') +@pytest.skip_backend("tf") @pytest.skip_backend("cupy") @pytest.skip_backend("jax") @pytest.mark.parametrize("metric", ["sqeuclidean", "euclidean"]) @@ -1076,8 +1176,10 @@ def test_geomloss_solver(nx, metric): G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric=metric)) - value, log = ot.bregman.empirical_sinkhorn2_geomloss(X_sb, X_tb, 1, metric=metric, log=True) - G_geomloss = nx.to_numpy(log['lazy_plan'][:]) + value, log = ot.bregman.empirical_sinkhorn2_geomloss( + X_sb, X_tb, 1, metric=metric, log=True + ) + G_geomloss = nx.to_numpy(log["lazy_plan"][:]) print(value) @@ -1089,7 +1191,7 @@ def test_geomloss_solver(nx, metric): # check error on wrong metric with pytest.raises(ValueError): - ot.bregman.empirical_sinkhorn2_geomloss(X_sb, X_tb, 1, metric='wrong_metric') + ot.bregman.empirical_sinkhorn2_geomloss(X_sb, X_tb, 1, metric="wrong_metric") def test_lazy_empirical_sinkhorn(nx): @@ -1102,51 +1204,70 @@ def test_lazy_empirical_sinkhorn(nx): X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) X_t = np.reshape(np.arange(0, n, dtype=np.float64), (n, 1)) M = ot.dist(X_s, X_t) - M_m = ot.dist(X_s, X_t, metric='euclidean') + M_m = ot.dist(X_s, X_t, metric="euclidean") ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) f, g = ot.bregman.empirical_sinkhorn( - X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) + X_sb, + X_tb, + 1, + numIterMax=numIterMax, + isLazy=True, + batchSize=(1, 3), + verbose=True, + ) f, g = nx.to_numpy(f), nx.to_numpy(g) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) f, g, log_es = ot.bregman.empirical_sinkhorn( - X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=5, log=True) + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=5, log=True + ) f, g = nx.to_numpy(f), nx.to_numpy(g) G_log = np.exp(f[:, None] + g[None, :] - M / 1) sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) f, g = ot.bregman.empirical_sinkhorn( - X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) + X_sb, + X_tb, + 1, + metric="euclidean", + numIterMax=numIterMax, + isLazy=True, + batchSize=1, + ) f, g = nx.to_numpy(f), nx.to_numpy(g) G_m = np.exp(f[:, None] + g[None, :] - M_m / 1) sinkhorn_m = 
nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2( - X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=5, log=True) - G_lazy = nx.to_numpy(log['lazy_plan'][:]) + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=5, log=True + ) + G_lazy = nx.to_numpy(log["lazy_plan"][:]) loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2( - X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=False) + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=False + ) # check constraints np.testing.assert_allclose( - sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05 + ) # metric sqeuclidian np.testing.assert_allclose( - sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log + sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05 + ) # metric sqeuclidian + np.testing.assert_allclose(sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + np.testing.assert_allclose(sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log np.testing.assert_allclose( - sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05 + ) # metric euclidian np.testing.assert_allclose( - sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05 + ) # metric euclidian np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) np.testing.assert_allclose(G_log, G_lazy, atol=1e-05) @@ -1163,27 +1284,27 @@ def test_empirical_sinkhorn_divergence(nx): M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy( - a, b, X_s, X_t, M, M_s, M_t) + ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t) emp_sinkhorn_div = nx.to_numpy( - ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) + ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb) + ) sinkhorn_div = nx.to_numpy( ot.sinkhorn2(ab, bb, M_nx, 1) - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) ) emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence( - X_s, X_t, 1, a=a, b=b) + X_s, X_t, 1, a=a, b=b + ) # check constraints + np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) np.testing.assert_allclose( - emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05) - np.testing.assert_allclose( - emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn + emp_sinkhorn_div, sinkhorn_div, atol=1e-05 + ) # cf conv emp sinkhorn - ot.bregman.empirical_sinkhorn_divergence( - X_sb, X_tb, 1, a=ab, b=bb, log=True) + ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True) @pytest.mark.skipif(not torch, reason="No torch available") @@ -1206,7 +1327,8 @@ def test_empirical_sinkhorn_divergence_gradient(): X_tb.requires_grad = True emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence( - X_sb, X_tb, 1, a=ab, b=bb) + X_sb, X_tb, 1, a=ab, b=bb + ) emp_sinkhorn_div.backward() @@ -1235,14 +1357,12 @@ def test_stabilized_vs_sinkhorn_multidim(nx): ab, bb, M_nx = nx.from_numpy(a, b, M) - G_np, _ = ot.bregman.sinkhorn( - a, b, M, reg=epsilon, method="sinkhorn", log=True) - G, log = 
ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, - method="sinkhorn_stabilized", - log=True) + G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) + G, log = ot.bregman.sinkhorn( + ab, bb, M_nx, reg=epsilon, method="sinkhorn_stabilized", log=True + ) G = nx.to_numpy(G) - G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon, - method="sinkhorn", log=True) + G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon, method="sinkhorn", log=True) G2 = nx.to_numpy(G2) np.testing.assert_allclose(G_np, G2) @@ -1250,9 +1370,9 @@ def test_stabilized_vs_sinkhorn_multidim(nx): def test_implemented_methods(): - IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] - ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling'] - NOT_VALID_TOKENS = ['foo'] + IMPLEMENTED_METHODS = ["sinkhorn", "sinkhorn_stabilized"] + ONLY_1D_methods = ["greenkhorn", "sinkhorn_epsilon_scaling"] + NOT_VALID_TOKENS = ["foo"] # test generalized sinkhorn for unbalanced OT barycenter n = 3 rng = np.random.RandomState(42) @@ -1282,7 +1402,7 @@ def test_implemented_methods(): ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) -@pytest.skip_backend('tf') +@pytest.skip_backend("tf") @pytest.skip_backend("cupy") @pytest.skip_backend("jax") @pytest.mark.filterwarnings("ignore:Bottleneck") @@ -1301,8 +1421,9 @@ def test_screenkhorn(nx): # sinkhorn G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn( - ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) + G_screen = nx.to_numpy( + ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True) + ) # check marginals np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) @@ -1338,16 +1459,21 @@ def test_sinkhorn_warmstart(): # Optimal plan with uniform warmstart pi_unif, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn", log=True, warmstart=None) + a, b, M, reg, method="sinkhorn", log=True, warmstart=None + ) # Optimal plan with warmstart generated from unregularized OT pi_sh, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart) + a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart + ) pi_sh_log, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart) + a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart + ) pi_sh_stab, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart) + a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart + ) pi_sh_sc, _ = ot.bregman.sinkhorn( - a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart) + a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart + ) np.testing.assert_allclose(pi_unif, pi_sh, atol=1e-05) np.testing.assert_allclose(pi_unif, pi_sh_log, atol=1e-05) @@ -1371,14 +1497,17 @@ def test_empirical_sinkhorn_warmstart(): # Optimal plan with uniform warmstart f, g, _ = ot.bregman.empirical_sinkhorn( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None + ) pi_unif = np.exp(f[:, None] + g[None, :] - M / reg) # Optimal plan with warmstart generated from unregularized OT f, g, _ = ot.bregman.empirical_sinkhorn( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart + ) pi_ws_lazy = 
np.exp(f[:, None] + g[None, :] - M / reg) pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn( - X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart) + X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart + ) np.testing.assert_allclose(pi_unif, pi_ws_lazy, atol=1e-05) np.testing.assert_allclose(pi_unif, pi_ws_not_lazy, atol=1e-05) @@ -1400,12 +1529,15 @@ def test_empirical_sinkhorn_divergence_warmstart(): # Optimal plan with uniform warmstart sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None + ) # Optimal plan with warmstart generated from unregularized OT sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence( - X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart) + X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart + ) sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence( - X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart) + X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart + ) np.testing.assert_allclose(sd_unif, sd_ws_lazy, atol=1e-05) np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05) diff --git a/test/test_coot.py b/test/test_coot.py index 853ca793c..a66a32380 100644 --- a/test/test_coot.py +++ b/test/test_coot.py @@ -1,4 +1,4 @@ -"""Tests for module COOT on OT """ +"""Tests for module COOT on OT""" # Author: Quang Huy Tran # @@ -18,8 +18,7 @@ def test_coot(nx, verbose): mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() xs_nx = nx.from_numpy(xs) xt_nx = nx.from_numpy(xt) @@ -66,8 +65,7 @@ def test_entropic_coot(nx): mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() xs_nx = nx.from_numpy(xs) xt_nx = nx.from_numpy(xt) @@ -78,7 +76,8 @@ def test_entropic_coot(nx): # test couplings pi_sample, pi_feature = coot(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot) pi_sample_nx, pi_feature_nx = coot( - X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot) + X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot + ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -102,8 +101,7 @@ def test_entropic_coot(nx): # test entropic COOT distance coot_np = coot2(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot) - coot_nx = nx.to_numpy( - coot2(X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot)) + coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot)) np.testing.assert_allclose(coot_np, coot_nx, atol=1e-08) @@ -114,8 +112,7 @@ def test_coot_with_linear_terms(nx): mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() xs_nx = nx.from_numpy(xs) xt_nx = nx.from_numpy(xt) @@ -132,10 +129,10 @@ def test_coot_with_linear_terms(nx): anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples id_feature = np.eye(2, 2) / 2 - pi_sample, pi_feature = coot( - X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat) + pi_sample, 
pi_feature = coot(X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat) pi_sample_nx, pi_feature_nx = coot( - X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx) + X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx + ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -162,7 +159,8 @@ def test_coot_with_linear_terms(nx): coot_np = coot2(X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat) coot_nx = nx.to_numpy( - coot2(X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx)) + coot2(X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx) + ) np.testing.assert_allclose(coot_np, 0, atol=1e-08) np.testing.assert_allclose(coot_nx, 0, atol=1e-08) @@ -173,8 +171,7 @@ def test_coot_raise_value_error(nx): mu_s = np.array([2, 4]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=43) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=43) xt = xs[::-1].copy() xs_nx = nx.from_numpy(xs) xt_nx = nx.from_numpy(xt) @@ -216,8 +213,7 @@ def test_coot_warmstart(nx): mu_s = np.array([2, 3]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=125) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=125) xt = xs[::-1].copy() xs_nx = nx.from_numpy(xs) xt_nx = nx.from_numpy(xt) @@ -232,34 +228,35 @@ def test_coot_warmstart(nx): init_pi_feature /= init_pi_feature / np.sum(init_pi_feature) init_pi_feature_nx = nx.from_numpy(init_pi_feature) - init_duals_sample = (rng.random(n_samples) * 2 - 1, - rng.random(n_samples) * 2 - 1) - init_duals_sample_nx = (nx.from_numpy(init_duals_sample[0]), - nx.from_numpy(init_duals_sample[1])) + init_duals_sample = (rng.random(n_samples) * 2 - 1, rng.random(n_samples) * 2 - 1) + init_duals_sample_nx = ( + nx.from_numpy(init_duals_sample[0]), + nx.from_numpy(init_duals_sample[1]), + ) - init_duals_feature = (rng.random(2) * 2 - 1, - rng.random(2) * 2 - 1) - init_duals_feature_nx = (nx.from_numpy(init_duals_feature[0]), - nx.from_numpy(init_duals_feature[1])) + init_duals_feature = (rng.random(2) * 2 - 1, rng.random(2) * 2 - 1) + init_duals_feature_nx = ( + nx.from_numpy(init_duals_feature[0]), + nx.from_numpy(init_duals_feature[1]), + ) warmstart = { "pi_sample": init_pi_sample, "pi_feature": init_pi_feature, "duals_sample": init_duals_sample, - "duals_feature": init_duals_feature + "duals_feature": init_duals_feature, } warmstart_nx = { "pi_sample": init_pi_sample_nx, "pi_feature": init_pi_feature_nx, "duals_sample": init_duals_sample_nx, - "duals_feature": init_duals_feature_nx + "duals_feature": init_duals_feature_nx, } # test couplings pi_sample, pi_feature = coot(X=xs, Y=xt, warmstart=warmstart) - pi_sample_nx, pi_feature_nx = coot( - X=xs_nx, Y=xt_nx, warmstart=warmstart_nx) + pi_sample_nx, pi_feature_nx = coot(X=xs_nx, Y=xt_nx, warmstart=warmstart_nx) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -298,8 +295,7 @@ def test_coot_log(nx): mu_s = np.array([-2, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=43) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=43) xt = xs[::-1].copy() xs_nx = nx.from_numpy(xs) xt_nx = nx.from_numpy(xt) diff --git a/test/test_da.py b/test/test_da.py index d3c343242..693c0dff7 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -1,4 +1,4 
@@ -"""Tests for module da on Domain Adaptation """ +"""Tests for module da on Domain Adaptation""" # Author: Remi Flamary # @@ -38,8 +38,8 @@ def test_class_jax_tf(): ns = 150 nt = 200 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif("3gauss", ns) + Xt, yt = make_data_classif("3gauss2", nt) Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) @@ -51,14 +51,22 @@ def test_class_jax_tf(): @pytest.skip_backend("jax") @pytest.skip_backend("tf") -@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, - ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport]) +@pytest.mark.parametrize( + "class_to_test", + [ + ot.da.EMDTransport, + ot.da.SinkhornTransport, + ot.da.SinkhornLpl1Transport, + ot.da.SinkhornL1l2Transport, + ot.da.SinkhornL1l2Transport, + ], +) def test_log_da(nx, class_to_test): ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif("3gauss", ns) + Xt, yt = make_data_classif("3gauss2", nt) Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) @@ -71,14 +79,13 @@ def test_log_da(nx, class_to_test): @pytest.skip_backend("tf") def test_sinkhorn_lpl1_transport_class(nx): - """test_sinkhorn_transport - """ + """test_sinkhorn_transport""" ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns, random_state=42) - Xt, yt = make_data_classif('3gauss2', nt, random_state=43) + Xs, ys = make_data_classif("3gauss", ns, random_state=42) + Xt, yt = make_data_classif("3gauss2", nt, random_state=43) # prepare semi-supervised labels yt_semi = np.copy(yt) yt_semi[np.arange(0, nt, 2)] = -1 @@ -102,15 +109,17 @@ def test_sinkhorn_lpl1_transport_class(nx): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3 + ) assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3 + ) # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1, random_state=44)[0]) + Xs_new = nx.from_numpy(make_data_classif("3gauss", ns + 1, random_state=44)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -120,7 +129,7 @@ def test_sinkhorn_lpl1_transport_class(nx): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) - Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1, random_state=45)[0]) + Xt_new = nx.from_numpy(make_data_classif("3gauss2", nt + 1, random_state=45)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -143,35 +152,39 @@ def test_sinkhorn_lpl1_transport_class(nx): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornLpl1Transport() otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) - assert np.all(np.isfinite(nx.to_numpy(otda_unsup.coupling_))), "unsup coupling is finite" + assert np.all( + np.isfinite(nx.to_numpy(otda_unsup.coupling_)) + ), "unsup coupling is finite" n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornLpl1Transport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi) - assert np.all(np.isfinite(nx.to_numpy(otda_semi.coupling_))), "semi coupling is finite" + assert np.all( + np.isfinite(nx.to_numpy(otda_semi.coupling_)) + ), "semi coupling is 
finite" assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different - assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working" + assert np.allclose( + n_unsup, n_semisup, atol=1e-7 + ), "semisupervised mode is not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples - mass_semi = nx.sum( - otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) + mass_semi = nx.sum(otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) assert mass_semi == 0, "semisupervised mode not working" @pytest.skip_backend("tf") def test_sinkhorn_l1l2_transport_class(nx): - """test_sinkhorn_transport - """ + """test_sinkhorn_transport""" ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns, random_state=42) - Xt, yt = make_data_classif('3gauss2', nt, random_state=43) + Xs, ys = make_data_classif("3gauss", ns, random_state=42) + Xt, yt = make_data_classif("3gauss2", nt, random_state=43) # prepare semi-supervised labels yt_semi = np.copy(yt) yt_semi[np.arange(0, nt, 2)] = -1 @@ -194,15 +207,17 @@ def test_sinkhorn_l1l2_transport_class(nx): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3 + ) assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3 + ) # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) + Xs_new = nx.from_numpy(make_data_classif("3gauss", ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -222,7 +237,7 @@ def test_sinkhorn_l1l2_transport_class(nx): assert_equal(transp_ys.shape[0], ys.shape[0]) assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) + Xt_new = nx.from_numpy(make_data_classif("3gauss2", nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -243,13 +258,17 @@ def test_sinkhorn_l1l2_transport_class(nx): n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different - assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working" + assert np.allclose( + n_unsup, n_semisup, atol=1e-7 + ), "semisupervised mode is not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples mass_semi = nx.sum(otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max] - assert_allclose(nx.to_numpy(mass_semi), np.zeros_like(mass_semi), rtol=1e-9, atol=1e-9) + assert_allclose( + nx.to_numpy(mass_semi), np.zeros_like(mass_semi), rtol=1e-9, atol=1e-9 + ) # check everything runs well with log=True otda = ot.da.SinkhornL1l2Transport(log=True) @@ -260,14 +279,13 @@ def test_sinkhorn_l1l2_transport_class(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_sinkhorn_transport_class(nx): - """test_sinkhorn_transport - """ + """test_sinkhorn_transport""" ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif("3gauss", ns) + Xt, yt = 
make_data_classif("3gauss2", nt) Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) @@ -287,15 +305,17 @@ def test_sinkhorn_transport_class(nx): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3 + ) assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3 + ) # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) + Xs_new = nx.from_numpy(make_data_classif("3gauss", ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -315,7 +335,7 @@ def test_sinkhorn_transport_class(nx): assert_equal(transp_ys.shape[0], ys.shape[0]) assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) + Xt_new = nx.from_numpy(make_data_classif("3gauss2", nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -336,12 +356,13 @@ def test_sinkhorn_transport_class(nx): n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different - assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working" + assert np.allclose( + n_unsup, n_semisup, atol=1e-7 + ), "semisupervised mode is not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples - mass_semi = nx.sum( - otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) + mass_semi = nx.sum(otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) assert mass_semi == 0, "semisupervised mode not working" # check everything runs well with log=True @@ -350,37 +371,40 @@ def test_sinkhorn_transport_class(nx): assert len(otda.log_.keys()) != 0 # test diffeernt transform and inverse transform - otda = ot.da.SinkhornTransport(out_of_sample_map='ferradans') + otda = ot.da.SinkhornTransport(out_of_sample_map="ferradans") transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) # test diffeernt transform - otda = ot.da.SinkhornTransport(out_of_sample_map='continuous', method='sinkhorn') + otda = ot.da.SinkhornTransport(out_of_sample_map="continuous", method="sinkhorn") transp_Xs2 = otda.fit_transform(Xs=Xs, Xt=Xt) assert_equal(transp_Xs2.shape, Xs.shape) transp_Xt2 = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt2.shape, Xt.shape) - np.testing.assert_almost_equal(nx.to_numpy(transp_Xs), nx.to_numpy(transp_Xs2), decimal=5) - np.testing.assert_almost_equal(nx.to_numpy(transp_Xt), nx.to_numpy(transp_Xt2), decimal=5) + np.testing.assert_almost_equal( + nx.to_numpy(transp_Xs), nx.to_numpy(transp_Xs2), decimal=5 + ) + np.testing.assert_almost_equal( + nx.to_numpy(transp_Xt), nx.to_numpy(transp_Xt2), decimal=5 + ) with pytest.raises(ValueError): - otda = ot.da.SinkhornTransport(out_of_sample_map='unknown') + otda = ot.da.SinkhornTransport(out_of_sample_map="unknown") @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_unbalanced_sinkhorn_transport_class(nx): - """test_sinkhorn_transport - """ + """test_sinkhorn_transport""" ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = 
make_data_classif("3gauss", ns) + Xt, yt = make_data_classif("3gauss2", nt) Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) @@ -415,7 +439,7 @@ def test_unbalanced_sinkhorn_transport_class(nx): assert_equal(transp_ys.shape[0], ys.shape[0]) assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) + Xs_new = nx.from_numpy(make_data_classif("3gauss", ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -425,7 +449,7 @@ def test_unbalanced_sinkhorn_transport_class(nx): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) - Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) + Xt_new = nx.from_numpy(make_data_classif("3gauss2", nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -448,7 +472,9 @@ def test_unbalanced_sinkhorn_transport_class(nx): n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different - assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working" + assert np.allclose( + n_unsup, n_semisup, atol=1e-7 + ), "semisupervised mode is not working" # check everything runs well with log=True otda = ot.da.SinkhornTransport(log=True) @@ -460,14 +486,13 @@ def test_unbalanced_sinkhorn_transport_class(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_emd_transport_class(nx): - """test_sinkhorn_transport - """ + """test_sinkhorn_transport""" ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif("3gauss", ns) + Xt, yt = make_data_classif("3gauss2", nt) Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) @@ -488,15 +513,17 @@ def test_emd_transport_class(nx): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3 + ) assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3 + ) # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) + Xs_new = nx.from_numpy(make_data_classif("3gauss", ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -516,7 +543,7 @@ def test_emd_transport_class(nx): assert_equal(transp_ys.shape[0], ys.shape[0]) assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) + Xt_new = nx.from_numpy(make_data_classif("3gauss2", nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -544,17 +571,19 @@ def test_emd_transport_class(nx): n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different - assert np.allclose(n_unsup, n_semisup, atol=1e-7), "semisupervised mode is not working" + assert np.allclose( + n_unsup, n_semisup, atol=1e-7 + ), "semisupervised mode is not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples - mass_semi = nx.sum( - otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) + mass_semi = nx.sum(otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max] # we need 
to use a small tolerance here, otherwise the test breaks - assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)), - rtol=1e-2, atol=1e-2) + assert_allclose( + nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)), rtol=1e-2, atol=1e-2 + ) @pytest.skip_backend("jax") @@ -562,15 +591,14 @@ def test_emd_transport_class(nx): @pytest.mark.parametrize("kernel", ["linear", "gaussian"]) @pytest.mark.parametrize("bias", ["unbiased", "biased"]) def test_mapping_transport_class(nx, kernel, bias): - """test_mapping_transport - """ + """test_mapping_transport""" ns = 20 nt = 30 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs, ys = make_data_classif("3gauss", ns) + Xt, yt = make_data_classif("3gauss2", nt) + Xs_new, _ = make_data_classif("3gauss", ns + 1) Xs, Xt, Xs_new = nx.from_numpy(Xs, Xt, Xs_new) @@ -592,9 +620,11 @@ def test_mapping_transport_class(nx, kernel, bias): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3 + ) assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3 + ) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -618,8 +648,8 @@ def test_mapping_transport_class_specific_seed(nx): ns = 20 nt = 30 rng = np.random.RandomState(39) - Xs, ys = make_data_classif('3gauss', ns, random_state=rng) - Xt, yt = make_data_classif('3gauss2', nt, random_state=rng) + Xs, ys = make_data_classif("3gauss", ns, random_state=rng) + Xt, yt = make_data_classif("3gauss2", nt, random_state=rng) otda = ot.da.MappingTransport(kernel="gaussian", bias=False) otda.fit(Xs=nx.from_numpy(Xs), Xt=nx.from_numpy(Xt)) @@ -630,8 +660,8 @@ def test_linear_mapping_class(nx): ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif("3gauss", ns) + Xt, yt = make_data_classif("3gauss2", nt) Xsb, Xtb = nx.from_numpy(Xs, Xt) @@ -665,8 +695,8 @@ def test_linear_gw_mapping_class(nx): ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif("3gauss", ns) + Xt, yt = make_data_classif("3gauss2", nt) Xsb, Xtb = nx.from_numpy(Xs, Xt) @@ -690,17 +720,16 @@ def test_linear_gw_mapping_class(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_jcpot_transport_class(nx): - """test_jcpot_transport - """ + """test_jcpot_transport""" ns1 = 50 ns2 = 50 nt = 50 - Xs1, ys1 = make_data_classif('3gauss', ns1) - Xs2, ys2 = make_data_classif('3gauss', ns2) + Xs1, ys1 = make_data_classif("3gauss", ns1) + Xs2, ys2 = make_data_classif("3gauss", ns2) - Xt, yt = make_data_classif('3gauss2', nt) + Xt, yt = make_data_classif("3gauss2", nt) Xs1, ys1, Xs2, ys2, Xt, yt = nx.from_numpy(Xs1, ys1, Xs2, ys2, Xt, yt) @@ -708,7 +737,9 @@ def test_jcpot_transport_class(nx): ys = [ys1, ys2] for log in [True, False]: - otda = ot.da.JCPOTTransport(reg_e=1, max_iter=10000, tol=1e-9, verbose=True, log=log) + otda = ot.da.JCPOTTransport( + reg_e=1, max_iter=10000, tol=1e-9, verbose=True, log=log + ) # test its computed otda.fit(Xs=Xs, ys=ys, Xt=Xt) @@ -727,25 +758,29 @@ def test_jcpot_transport_class(nx): for i in range(len(Xs)): # test margin constraints w.r.t. 
uniform target weights for each coupling matrix assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_[i], axis=0)), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_[i], axis=0)), + mu_t, + rtol=1e-3, + atol=1e-3, + ) if log: # test margin constraints w.r.t. modified source weights for each source domain assert_allclose( nx.to_numpy( - nx.dot(otda.log_['D1'][i], nx.sum(otda.coupling_[i], axis=1)) + nx.dot(otda.log_["D1"][i], nx.sum(otda.coupling_[i], axis=1)) ), nx.to_numpy(otda.proportions_), rtol=1e-3, - atol=1e-3 + atol=1e-3, ) # test transform transp_Xs = otda.transform(Xs=Xs) [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] - Xs_new = nx.from_numpy(make_data_classif('3gauss', ns1 + 1)[0]) + Xs_new = nx.from_numpy(make_data_classif("3gauss", ns1 + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -764,8 +799,7 @@ def test_jcpot_transport_class(nx): def test_jcpot_barycenter(nx): - """test_jcpot_barycenter - """ + """test_jcpot_barycenter""" ns1 = 50 ns2 = 50 @@ -773,21 +807,30 @@ def test_jcpot_barycenter(nx): sigma = 0.1 - ps1 = .2 - ps2 = .9 - pt = .4 + ps1 = 0.2 + ps2 = 0.9 + pt = 0.4 - Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1) - Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2) - Xt, _ = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt) + Xs1, ys1 = make_data_classif("2gauss_prop", ns1, nz=sigma, p=ps1) + Xs2, ys2 = make_data_classif("2gauss_prop", ns2, nz=sigma, p=ps2) + Xt, _ = make_data_classif("2gauss_prop", nt, nz=sigma, p=pt) Xs1b, ys1b, Xs2b, ys2b, Xtb = nx.from_numpy(Xs1, ys1, Xs2, ys2, Xt) Xsb = [Xs1b, Xs2b] ysb = [ys1b, ys2b] - prop = ot.bregman.jcpot_barycenter(Xsb, ysb, Xtb, reg=.5, metric='sqeuclidean', - numItermax=10000, stopThr=1e-9, verbose=False, log=False) + prop = ot.bregman.jcpot_barycenter( + Xsb, + ysb, + Xtb, + reg=0.5, + metric="sqeuclidean", + numItermax=10000, + stopThr=1e-9, + verbose=False, + log=False, + ) np.testing.assert_allclose(nx.to_numpy(prop), [1 - pt, pt], rtol=1e-3, atol=1e-3) @@ -796,18 +839,19 @@ def test_jcpot_barycenter(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_emd_laplace_class(nx): - """test_emd_laplace_transport - """ + """test_emd_laplace_transport""" ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif("3gauss", ns) + Xt, yt = make_data_classif("3gauss2", nt) Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) for log in [True, False]: - otda = ot.da.EMDLaplaceTransport(reg_lap=0.01, max_iter=1000, tol=1e-9, verbose=False, log=log) + otda = ot.da.EMDLaplaceTransport( + reg_lap=0.01, max_iter=1000, tol=1e-9, verbose=False, log=log + ) # test its computed otda.fit(Xs=Xs, ys=ys, Xt=Xt) @@ -823,15 +867,17 @@ def test_emd_laplace_class(nx): mu_t = unif(nt) assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3 + ) assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3 + ) # test transform transp_Xs = otda.transform(Xs=Xs) [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] - Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) + Xs_new = nx.from_numpy(make_data_classif("3gauss", ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -841,7 +887,7 @@ def 
test_emd_laplace_class(nx): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) - Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) + Xt_new = nx.from_numpy(make_data_classif("3gauss2", nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -865,7 +911,10 @@ def test_emd_laplace_class(nx): @pytest.mark.skipif(nocvxpy, reason="No CVXPY available") def test_nearest_brenier_potential(nx): X = nx.ones((2, 2)) - for ssnb in [ot.da.NearestBrenierPotential(log=True), ot.da.NearestBrenierPotential(log=False)]: + for ssnb in [ + ot.da.NearestBrenierPotential(log=True), + ot.da.NearestBrenierPotential(log=False), + ]: ssnb.fit(Xs=X, Xt=X) G_lu = ssnb.transform(Xs=X) # 'new' input isn't new, so should be equal to target @@ -881,23 +930,37 @@ def test_emd_laplace(nx): ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif("3gauss", ns) + Xt, yt = make_data_classif("3gauss2", nt) Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) M = ot.dist(Xs, Xt) with pytest.raises(ValueError): - ot.da.emd_laplace(ot.unif(ns), ot.unif(nt), Xs, Xt, M, sim_param=['INVALID', 'INPUT', 2]) + ot.da.emd_laplace( + ot.unif(ns), ot.unif(nt), Xs, Xt, M, sim_param=["INVALID", "INPUT", 2] + ) with pytest.raises(ValueError): - ot.da.emd_laplace(ot.unif(ns), ot.unif(nt), Xs, Xt, M, sim=['INVALID', 'INPUT', 2]) + ot.da.emd_laplace( + ot.unif(ns), ot.unif(nt), Xs, Xt, M, sim=["INVALID", "INPUT", 2] + ) # test all margin constraints with gaussian similarity and disp regularisation - coupling = ot.da.emd_laplace(ot.unif(ns, type_as=Xs), ot.unif(nt, type_as=Xs), Xs, Xt, M, sim='gauss', reg='disp') + coupling = ot.da.emd_laplace( + ot.unif(ns, type_as=Xs), + ot.unif(nt, type_as=Xs), + Xs, + Xt, + M, + sim="gauss", + reg="disp", + ) assert_allclose( - nx.to_numpy(nx.sum(coupling, axis=0)), unif(nt), rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(coupling, axis=0)), unif(nt), rtol=1e-3, atol=1e-3 + ) assert_allclose( - nx.to_numpy(nx.sum(coupling, axis=1)), unif(ns), rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(coupling, axis=1)), unif(ns), rtol=1e-3, atol=1e-3 + ) @pytest.skip_backend("jax") @@ -966,10 +1029,10 @@ def unvectorized(transp): indices_labels = [] classes = nx.unique(labels_a) for c in classes: - idxc, = nx.where(labels_a == c) + (idxc,) = nx.where(labels_a == c) indices_labels.append(idxc) W = nx.ones(M.shape, type_as=M) - for (i, c) in enumerate(classes): + for i, c in enumerate(classes): majs = nx.sum(transp[indices_labels[i]], axis=0) majs = p * ((majs + epsilon) ** (p - 1)) W[indices_labels[i]] = majs @@ -979,7 +1042,10 @@ def vectorized(transp): labels_u, labels_idx = nx.unique(labels_a, return_inverse=True) n_labels = labels_u.shape[0] unroll_labels_idx = nx.eye(n_labels, type_as=transp)[labels_idx] - W = nx.repeat(transp.T[:, :, None], n_labels, axis=2) * unroll_labels_idx[None, :, :] + W = ( + nx.repeat(transp.T[:, :, None], n_labels, axis=2) + * unroll_labels_idx[None, :, :] + ) W = nx.sum(W, axis=1) W = p * ((W + epsilon) ** (p - 1)) W = nx.dot(W, unroll_labels_idx.T) diff --git a/test/test_dmmot.py b/test/test_dmmot.py index dcc313755..67ac985dd 100644 --- a/test/test_dmmot.py +++ b/test/test_dmmot.py @@ -1,4 +1,4 @@ -"""Tests for ot.lp.dmmot module """ +"""Tests for ot.lp.dmmot module""" # Author: Ronak Mehta # Xizheng Yu @@ -27,15 +27,17 @@ def test_dmmot_monge_1dgrid_loss(nx): primal_obj = nx.to_numpy(primal_obj) expected_primal_obj = 
0.13667759626298503 - np.testing.assert_allclose(primal_obj, - expected_primal_obj, - rtol=1e-7, - err_msg="Test failed: \ - Expected different primal objective value") + np.testing.assert_allclose( + primal_obj, + expected_primal_obj, + rtol=1e-7, + err_msg="Test failed: \ + Expected different primal objective value", + ) # Compute loss using exact OT solver with absolute ground metric A, x = nx.to_numpy(A, x) - M = ot.utils.dist(x, metric='cityblock') # absolute ground metric + M = ot.utils.dist(x, metric="cityblock") # absolute ground metric bary, _ = ot.barycenter(A, M, 1e-2, weights=None, verbose=False, log=True) ot_obj = 0.0 for x in A.T: @@ -43,13 +45,15 @@ def test_dmmot_monge_1dgrid_loss(nx): x = np.ascontiguousarray(x) # compute loss _, log = ot.lp.emd(x, np.array(bary / np.sum(bary)), M, log=True) - ot_obj += log['cost'] + ot_obj += log["cost"] - np.testing.assert_allclose(primal_obj, - ot_obj, - rtol=1e-7, - err_msg="Test failed: \ - Expected different primal objective value") + np.testing.assert_allclose( + primal_obj, + ot_obj, + rtol=1e-7, + err_msg="Test failed: \ + Expected different primal objective value", + ) def test_dmmot_monge_1dgrid_optimize(nx): @@ -57,19 +61,22 @@ def test_dmmot_monge_1dgrid_optimize(nx): A, _ = create_test_data(nx) d = 2 niters = 10 - result = ot.lp.dmmot_monge_1dgrid_optimize(A, - niters, - lr_init=1e-3, - lr_decay=1) + result = ot.lp.dmmot_monge_1dgrid_optimize(A, niters, lr_init=1e-3, lr_decay=1) - expected_obj = np.array([[0.05553516, 0.13082618, 0.27327479, 0.54036388], - [0.04185365, 0.09570724, 0.24384705, 0.61859206]]) + expected_obj = np.array( + [ + [0.05553516, 0.13082618, 0.27327479, 0.54036388], + [0.04185365, 0.09570724, 0.24384705, 0.61859206], + ] + ) assert len(result) == d, "Test failed: Expected a list of length n" for i in range(d): - np.testing.assert_allclose(result[i], - expected_obj[i], - atol=1e-7, - rtol=1e-7, - err_msg="Test failed: \ - Expected vectors of all zeros") + np.testing.assert_allclose( + result[i], + expected_obj[i], + atol=1e-7, + rtol=1e-7, + err_msg="Test failed: \ + Expected vectors of all zeros", + ) diff --git a/test/test_dr.py b/test/test_dr.py index 3680547db..2ac026d23 100644 --- a/test/test_dr.py +++ b/test/test_dr.py @@ -1,4 +1,4 @@ -"""Tests for module dr on Dimensionality Reduction """ +"""Tests for module dr on Dimensionality Reduction""" # Author: Remi Flamary # Minhui Huang @@ -12,6 +12,7 @@ try: # test if autograd and pymanopt are installed import ot.dr + nogo = False except ImportError: nogo = True @@ -19,12 +20,11 @@ @pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") def test_fda(): - n_samples = 90 # nb samples in source and target datasets rng = np.random.RandomState(0) # generate gaussian dataset - xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples, random_state=rng) + xs, ys = ot.datasets.make_data_classif("gaussrot", n_samples, random_state=rng) n_features_noise = 8 @@ -41,12 +41,11 @@ def test_fda(): @pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") def test_wda(): - n_samples = 100 # nb samples in source and target datasets rng = np.random.RandomState(0) # generate gaussian dataset - xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples, random_state=rng) + xs, ys = ot.datasets.make_data_classif("gaussrot", n_samples, random_state=rng) n_features_noise = 8 @@ -63,12 +62,11 @@ def test_wda(): @pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") def test_wda_low_reg(): - n_samples = 100 # nb 
samples in source and target datasets rng = np.random.RandomState(0) # generate gaussian dataset - xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples, random_state=rng) + xs, ys = ot.datasets.make_data_classif("gaussrot", n_samples, random_state=rng) n_features_noise = 8 @@ -76,7 +74,9 @@ def test_wda_low_reg(): p = 2 - Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10, sinkhorn_method='sinkhorn_log') + Pwda, projwda = ot.dr.wda( + xs, ys, p, reg=0.01, maxiter=10, sinkhorn_method="sinkhorn_log" + ) projwda(xs) @@ -85,12 +85,11 @@ def test_wda_low_reg(): @pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") def test_wda_normalized(): - n_samples = 100 # nb samples in source and target datasets rng = np.random.RandomState(0) # generate gaussian dataset - xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples, random_state=rng) + xs, ys = ot.datasets.make_data_classif("gaussrot", n_samples, random_state=rng) n_features_noise = 8 @@ -120,8 +119,8 @@ def fragmented_hypercube(n, d, dim, rng): assert dim >= 1 assert dim == int(dim) - a = (1. / n) * np.ones(n) - b = (1. / n) * np.ones(n) + a = (1.0 / n) * np.ones(n) + b = (1.0 / n) * np.ones(n) # First measure : uniform on the hypercube X = rng.uniform(-1, 1, size=(n, d)) @@ -137,17 +136,20 @@ def fragmented_hypercube(n, d, dim, rng): tau = 0.002 reg = 0.2 - pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, reg=reg, k=k, maxiter=1000, verbose=1) + pi, U = ot.dr.projection_robust_wasserstein( + X, Y, a, b, tau, reg=reg, k=k, maxiter=1000, verbose=1 + ) U0 = rng.randn(d, k) U0, _ = np.linalg.qr(U0) - pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, U0=U0, reg=reg, k=k, maxiter=1000, verbose=1) + pi, U = ot.dr.projection_robust_wasserstein( + X, Y, a, b, tau, U0=U0, reg=reg, k=k, maxiter=1000, verbose=1 + ) @pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") def test_ewca(): - d = 5 n_samples = 50 k = 3 @@ -164,7 +166,9 @@ def test_ewca(): assert X.shape == (n_samples, d) # compute first 3 components with BCD - pi, U = ot.dr.ewca(X, reg=0.01, method='BCD', k=k, verbose=1, sinkhorn_method='sinkhorn_log') + pi, U = ot.dr.ewca( + X, reg=0.01, method="BCD", k=k, verbose=1, sinkhorn_method="sinkhorn_log" + ) assert pi.shape == (n_samples, n_samples) assert (pi >= 0).all() assert np.allclose(pi.sum(0), 1 / n_samples, atol=1e-3) @@ -178,7 +182,9 @@ def test_ewca(): assert np.allclose(cos, np.ones(k), atol=1e-3) # compute first 3 components with MM - pi, U = ot.dr.ewca(X, reg=0.01, method='MM', k=k, verbose=1, sinkhorn_method='sinkhorn_log') + pi, U = ot.dr.ewca( + X, reg=0.01, method="MM", k=k, verbose=1, sinkhorn_method="sinkhorn_log" + ) assert pi.shape == (n_samples, n_samples) assert (pi >= 0).all() assert np.allclose(pi.sum(0), 1 / n_samples, atol=1e-3) @@ -192,7 +198,9 @@ def test_ewca(): assert np.allclose(cos, np.ones(k), atol=1e-3) # compute last 3 components - pi, U = ot.dr.ewca(X, reg=100000, method='MM', k=k, verbose=1, sinkhorn_method='sinkhorn_log') + pi, U = ot.dr.ewca( + X, reg=100000, method="MM", k=k, verbose=1, sinkhorn_method="sinkhorn_log" + ) # test that U contains the last principal components U_last_eigvec = np.linalg.svd(X.T, full_matrices=False)[0][:, -k:] diff --git a/test/test_factored.py b/test/test_factored.py index 5cfc997ef..04cdc874f 100644 --- a/test/test_factored.py +++ b/test/test_factored.py @@ -1,4 +1,4 @@ -"""Tests for main module ot.weak """ +"""Tests for main module ot.weak""" # Author: Remi Flamary # @@ -28,7 +28,7 @@ def 
test_factored_ot(): # check constraints np.testing.assert_allclose(u, Ga.sum(1)) np.testing.assert_allclose(u, Gb.sum(0)) - np.testing.assert_allclose(1, log['lazy_plan'][:].sum()) + np.testing.assert_allclose(1, log["lazy_plan"][:].sum()) def test_factored_ot_backends(nx): diff --git a/test/test_gaussian.py b/test/test_gaussian.py index c66d5908c..eed562d15 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -18,8 +18,8 @@ def test_bures_wasserstein_mapping(nx): ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif("3gauss", ns) + Xt, yt = make_data_classif("3gauss2", nt) ms = np.mean(Xs, axis=0)[None, :] mt = np.mean(Xt, axis=0)[None, :] Cs = np.cov(Xs.T) @@ -27,7 +27,9 @@ def test_bures_wasserstein_mapping(nx): Xsb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, ms, mt, Cs, Ct) - A_log, b_log, log = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True) + A_log, b_log, log = ot.gaussian.bures_wasserstein_mapping( + msb, mtb, Csb, Ctb, log=True + ) A, b = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=False) Xst = nx.to_numpy(nx.dot(Xsb, A) + b) @@ -45,8 +47,8 @@ def test_empirical_bures_wasserstein_mapping(nx, bias): ns = 50 nt = 50 - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys = make_data_classif("3gauss", ns) + Xt, yt = make_data_classif("3gauss2", nt) if not bias: ms = np.mean(Xs, axis=0)[None, :] @@ -57,8 +59,12 @@ def test_empirical_bures_wasserstein_mapping(nx, bias): Xsb, Xtb = nx.from_numpy(Xs, Xt) - A, b, log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=True, bias=bias) - A_log, b_log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=False, bias=bias) + A, b, log = ot.gaussian.empirical_bures_wasserstein_mapping( + Xsb, Xtb, log=True, bias=bias + ) + A_log, b_log = ot.gaussian.empirical_bures_wasserstein_mapping( + Xsb, Xtb, log=False, bias=bias + ) Xst = nx.to_numpy(nx.dot(Xsb, A) + b) Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log) @@ -87,7 +93,9 @@ def test_bures_wasserstein_distance(nx): Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True) Wb = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=False) - np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + np.testing.assert_allclose( + nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2 + ) np.testing.assert_allclose(10, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) @@ -101,10 +109,16 @@ def test_empirical_bures_wasserstein_distance(nx, bias): Xt = rng.normal(10 * bias, 1, nt)[:, np.newaxis] Xsb, Xtb = nx.from_numpy(Xs, Xt) - Wb_log, log = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=True, bias=bias) - Wb = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=False, bias=bias) - - np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + Wb_log, log = ot.gaussian.empirical_bures_wasserstein_distance( + Xsb, Xtb, log=True, bias=bias + ) + Wb = ot.gaussian.empirical_bures_wasserstein_distance( + Xsb, Xtb, log=False, bias=bias + ) + + np.testing.assert_allclose( + nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2 + ) np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) @@ -116,7 +130,7 @@ def test_bures_wasserstein_barycenter(nx): m = [] C = [] for _ in range(k): - X_, y_ = make_data_classif('3gauss', n) + X_, y_ = make_data_classif("3gauss", n) m_ = np.mean(X_, 
axis=0)[None, :] C_ = np.cov(X_.T) X.append(X_) @@ -137,7 +151,9 @@ def test_bures_wasserstein_barycenter(nx): # Test weights argument weights = nx.ones(k) / k - mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter(m, C, weights=weights, log=False) + mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter( + m, C, weights=weights, log=False + ) np.testing.assert_allclose(Cbw, Cb, rtol=1e-2, atol=1e-2) # test with closed form for diagonal covariance matrices @@ -160,13 +176,15 @@ def test_empirical_bures_wasserstein_barycenter(nx, bias): X = [] y = [] for _ in range(k): - X_, y_ = make_data_classif('3gauss', n) + X_, y_ = make_data_classif("3gauss", n) X.append(X_) y.append(y_) X = nx.from_numpy(*X) - mblog, Cblog, log = ot.gaussian.empirical_bures_wasserstein_barycenter(X, log=True, bias=bias) + mblog, Cblog, log = ot.gaussian.empirical_bures_wasserstein_barycenter( + X, log=True, bias=bias + ) mb, Cb = ot.gaussian.empirical_bures_wasserstein_barycenter(X, log=False, bias=bias) np.testing.assert_allclose(Cb, Cblog, rtol=1e-2, atol=1e-2) @@ -179,8 +197,8 @@ def test_gaussian_gromov_wasserstein_distance(nx, d_target): nt = 400 rng = np.random.RandomState(10) - Xs, ys = make_data_classif('3gauss', ns, random_state=rng) - Xt, yt = make_data_classif('3gauss2', nt, random_state=rng) + Xs, ys = make_data_classif("3gauss", ns, random_state=rng) + Xt, yt = make_data_classif("3gauss2", nt, random_state=rng) Xt = np.concatenate((Xt, rng.normal(0, 1, (nt, 8))), axis=1) Xt = Xt[:, 0:d_target].reshape((nt, d_target)) @@ -192,10 +210,14 @@ def test_gaussian_gromov_wasserstein_distance(nx, d_target): Xsb, Xtb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, Xt, ms, mt, Cs, Ct) Gb, log = ot.gaussian.gaussian_gromov_wasserstein_distance(Csb, Ctb, log=True) - Ge, log = ot.gaussian.empirical_gaussian_gromov_wasserstein_distance(Xsb, Xtb, log=True) + Ge, log = ot.gaussian.empirical_gaussian_gromov_wasserstein_distance( + Xsb, Xtb, log=True + ) # no log - Ge0 = ot.gaussian.empirical_gaussian_gromov_wasserstein_distance(Xsb, Xtb, log=False) + Ge0 = ot.gaussian.empirical_gaussian_gromov_wasserstein_distance( + Xsb, Xtb, log=False + ) np.testing.assert_allclose(nx.to_numpy(Gb), nx.to_numpy(Ge), rtol=1e-2, atol=1e-2) np.testing.assert_allclose(nx.to_numpy(Ge), nx.to_numpy(Ge0), rtol=1e-2, atol=1e-2) @@ -207,8 +229,8 @@ def test_gaussian_gromov_wasserstein_mapping(nx, d_target): nt = 400 rng = np.random.RandomState(10) - Xs, ys = make_data_classif('3gauss', ns, random_state=rng) - Xt, yt = make_data_classif('3gauss2', nt, random_state=rng) + Xs, ys = make_data_classif("3gauss", ns, random_state=rng) + Xt, yt = make_data_classif("3gauss2", nt, random_state=rng) Xt = np.concatenate((Xt, rng.normal(0, 1, (nt, 8))), axis=1) Xt = Xt[:, 0:d_target].reshape((nt, d_target)) @@ -219,11 +241,17 @@ def test_gaussian_gromov_wasserstein_mapping(nx, d_target): Xsb, Xtb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, Xt, ms, mt, Cs, Ct) - A, b, log = ot.gaussian.gaussian_gromov_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True) - Ae, be, loge = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(Xsb, Xtb, log=True) + A, b, log = ot.gaussian.gaussian_gromov_wasserstein_mapping( + msb, mtb, Csb, Ctb, log=True + ) + Ae, be, loge = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping( + Xsb, Xtb, log=True + ) # no log + skewness - Ae0, be0 = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(Xsb, Xtb, log=False, sign_eigs='skewness') + Ae0, be0 = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping( + Xsb, Xtb, log=False, 
sign_eigs="skewness" + ) Xst = nx.to_numpy(nx.dot(Xsb, A) + b) Cst = np.cov(Xst.T) @@ -233,7 +261,9 @@ def test_gaussian_gromov_wasserstein_mapping(nx, d_target): np.testing.assert_allclose(Ct, Cst) # test the other way around (target to source) - Ai, bi, logi = ot.gaussian.gaussian_gromov_wasserstein_mapping(mtb, msb, Ctb, Csb, log=True) + Ai, bi, logi = ot.gaussian.gaussian_gromov_wasserstein_mapping( + mtb, msb, Ctb, Csb, log=True + ) Xtt = nx.to_numpy(nx.dot(Xtb, Ai) + bi) Ctt = np.cov(Xtt.T) diff --git a/test/test_gmm.py b/test/test_gmm.py index 5280b2c14..5f1a92965 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -9,7 +9,15 @@ import numpy as np import pytest from ot.utils import proj_simplex -from ot.gmm import gaussian_pdf, gmm_pdf, dist_bures_squared, gmm_ot_loss, gmm_ot_plan, gmm_ot_apply_map, gmm_ot_plan_density +from ot.gmm import ( + gaussian_pdf, + gmm_pdf, + dist_bures_squared, + gmm_ot_loss, + gmm_ot_plan, + gmm_ot_apply_map, + gmm_ot_plan_density, +) try: import torch @@ -51,7 +59,10 @@ def test_gaussian_pdf(nx): x = nx.from_numpy(rng.randn(n, n, d)) pdf = gaussian_pdf(x, m[0], C[0]) - assert pdf.shape == (n, n,) + assert pdf.shape == ( + n, + n, + ) with pytest.raises(AssertionError): gaussian_pdf(x, m[0, :-1], C[0]) @@ -68,13 +79,16 @@ def test_gmm_pdf(nx): x = nx.from_numpy(rng.randn(n, n, d)) pdf = gmm_pdf(x, m, C, w) - assert pdf.shape == (n, n,) + assert pdf.shape == ( + n, + n, + ) with pytest.raises(AssertionError): gmm_pdf(x, m[:-1], C, w) -@pytest.skip_backend('tf') # skips because of array assignment +@pytest.skip_backend("tf") # skips because of array assignment @pytest.skip_backend("jax") def test_dist_bures_squared(nx): m_s, m_t, C_s, C_t, _, _ = get_gmms(nx) @@ -93,7 +107,7 @@ def test_dist_bures_squared(nx): dist_bures_squared(m_s, m_t[1:], C_s, C_t) -@pytest.skip_backend('tf') # skips because of array assignment +@pytest.skip_backend("tf") # skips because of array assignment @pytest.skip_backend("jax") def test_gmm_ot_loss(nx): m_s, m_t, C_s, C_t, w_s, w_t = get_gmms(nx) @@ -112,7 +126,7 @@ def test_gmm_ot_loss(nx): gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t[1:]) -@pytest.skip_backend('tf') # skips because of array assignment +@pytest.skip_backend("tf") # skips because of array assignment @pytest.skip_backend("jax") def test_gmm_ot_plan(nx): m_s, m_t, C_s, C_t, w_s, w_t = get_gmms(nx) @@ -138,7 +152,7 @@ def test_gmm_apply_map(): rng = np.random.RandomState(seed=42) x = rng.randn(7, 3) - for method in ['bary', 'rand']: + for method in ["bary", "rand"]: gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, method=method) plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) diff --git a/test/test_gnn.py b/test/test_gnn.py index f84a435e9..150dc6b75 100644 --- a/test/test_gnn.py +++ b/test/test_gnn.py @@ -5,7 +5,6 @@ # # License: MIT License - import pytest try: # test if pytorch_geometric is installed @@ -40,12 +39,13 @@ def __init__(self, n_features, n_templates, n_template_nodes): self.n_templates = n_templates self.n_template_nodes = n_template_nodes - self.TFGW = TFGWPooling(self.n_templates, self.n_template_nodes, self.n_features) + self.TFGW = TFGWPooling( + self.n_templates, self.n_template_nodes, self.n_features + ) self.linear = Linear(self.n_templates, 1) def forward(self, x, edge_index): - x = self.TFGW(x, edge_index) x = self.linear(x) @@ -69,8 +69,8 @@ def forward(self, x, edge_index): x1 = torch.rand(n_nodes, n_features) x2 = torch.rand(n_nodes, n_features) - graph1 = GraphData(x=x1, edge_index=edge_index1, y=torch.tensor([0.])) - graph2 = 
GraphData(x=x2, edge_index=edge_index2, y=torch.tensor([1.])) + graph1 = GraphData(x=x1, edge_index=edge_index1, y=torch.tensor([0.0])) + graph2 = GraphData(x=x2, edge_index=edge_index2, y=torch.tensor([1.0])) dataset = DataLoader([graph1, graph2], batch_size=1) @@ -83,7 +83,6 @@ def forward(self, x, edge_index): for i in range(n_epochs): for data in dataset: - out = model_FGW(data.x, data.edge_index) loss = criterion(out, data.y) loss.backward() @@ -115,7 +114,6 @@ def __init__(self, n_features, n_templates, n_template_nodes, pooling_layer): self.linear = Linear(self.n_templates, 1) def forward(self, x, edge_index, batch=None): - x = self.TFGW(x, edge_index, batch=batch) x = self.linear(x) @@ -132,17 +130,28 @@ def forward(self, x, edge_index, batch=None): C1 = torch.randint(0, 2, size=(n_nodes, n_nodes)) edge_index1 = torch.stack(torch.where(C1 == 1)) x1 = torch.rand(n_nodes, n_features) - graph1 = GraphData(x=x1, edge_index=edge_index1, y=torch.tensor([0.])) + graph1 = GraphData(x=x1, edge_index=edge_index1, y=torch.tensor([0.0])) batch1 = torch.tensor([1] * n_nodes) - batch1[:n_nodes // 2] = 0 + batch1[: n_nodes // 2] = 0 criterion = torch.nn.CrossEntropyLoss() for train_node_weights in [True, False]: for alpha in [None, 0, 0.5]: for multi_alpha in [True, False]: - model = GNN_pooling(n_features, n_templates, n_template_nodes, - pooling_layer=TFGWPooling(n_templates, n_template_nodes, n_features, alpha=alpha, multi_alpha=multi_alpha, train_node_weights=train_node_weights)) + model = GNN_pooling( + n_features, + n_templates, + n_template_nodes, + pooling_layer=TFGWPooling( + n_templates, + n_template_nodes, + n_features, + alpha=alpha, + multi_alpha=multi_alpha, + train_node_weights=train_node_weights, + ), + ) # predict out1 = model(graph1.x, graph1.edge_index) @@ -177,7 +186,6 @@ def __init__(self, n_features, n_templates, n_template_nodes, pooling_layer): self.linear = Linear(self.n_templates, 1) def forward(self, x, edge_index, batch=None): - x = self.TFGW(x, edge_index, batch=batch) x = self.linear(x) @@ -194,16 +202,24 @@ def forward(self, x, edge_index, batch=None): C1 = torch.randint(0, 2, size=(n_nodes, n_nodes)) edge_index1 = torch.stack(torch.where(C1 == 1)) x1 = torch.rand(n_nodes, n_features) - graph1 = GraphData(x=x1, edge_index=edge_index1, y=torch.tensor([0.])) + graph1 = GraphData(x=x1, edge_index=edge_index1, y=torch.tensor([0.0])) batch1 = torch.tensor([1] * n_nodes) - batch1[:n_nodes // 2] = 0 + batch1[: n_nodes // 2] = 0 criterion = torch.nn.CrossEntropyLoss() for train_node_weights in [True, False]: - - model = GNN_pooling(n_features, n_templates, n_template_nodes, - pooling_layer=TWPooling(n_templates, n_template_nodes, n_features, train_node_weights=train_node_weights)) + model = GNN_pooling( + n_features, + n_templates, + n_template_nodes, + pooling_layer=TWPooling( + n_templates, + n_template_nodes, + n_features, + train_node_weights=train_node_weights, + ), + ) out1 = model(graph1.x, graph1.edge_index) loss = criterion(out1, graph1.y) @@ -232,12 +248,13 @@ def __init__(self, n_features, n_templates, n_template_nodes): self.n_templates = n_templates self.n_template_nodes = n_template_nodes - self.TFGW = TWPooling(self.n_templates, self.n_template_nodes, self.n_features) + self.TFGW = TWPooling( + self.n_templates, self.n_template_nodes, self.n_features + ) self.linear = Linear(self.n_templates, 1) def forward(self, x, edge_index): - x = self.TFGW(x, edge_index) x = self.linear(x) @@ -261,8 +278,8 @@ def forward(self, x, edge_index): x1 = torch.rand(n_nodes, 
n_features) x2 = torch.rand(n_nodes, n_features) - graph1 = GraphData(x=x1, edge_index=edge_index1, y=torch.tensor([0.])) - graph2 = GraphData(x=x2, edge_index=edge_index2, y=torch.tensor([1.])) + graph1 = GraphData(x=x1, edge_index=edge_index1, y=torch.tensor([0.0])) + graph2 = GraphData(x=x2, edge_index=edge_index2, y=torch.tensor([1.0])) dataset = DataLoader([graph1, graph2], batch_size=1) @@ -275,7 +292,6 @@ def forward(self, x, edge_index): for i in range(n_epochs): for data in dataset: - out = model_W(data.x, data.edge_index) loss = criterion(out, data.y) loss.backward() diff --git a/test/test_helpers.py b/test/test_helpers.py index cc4c90eaf..7a605f46f 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -1,4 +1,4 @@ -"""Tests for helpers functions """ +"""Tests for helpers functions""" # Author: Remi Flamary # @@ -14,7 +14,6 @@ def test_helpers(): - compiler = _get_compiler() get_openmp_flag(compiler) diff --git a/test/test_lowrank.py b/test/test_lowrank.py index 60b2d633f..4c755d3e9 100644 --- a/test/test_lowrank.py +++ b/test/test_lowrank.py @@ -1,4 +1,4 @@ -""" Test for low rank sinkhorn solvers """ +"""Test for low rank sinkhorn solvers""" # Author: Laurène DAVID # @@ -31,7 +31,9 @@ def test_lowrank_sinkhorn(): X_s = np.reshape(1.0 * np.arange(n), (n, 1)) X_t = np.reshape(1.0 * np.arange(n), (n, 1)) - Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, log=True, rescale_cost=False) + Q, R, g, log = ot.lowrank.lowrank_sinkhorn( + X_s, X_t, a, b, reg=0.1, log=True, rescale_cost=False + ) P = log["lazy_plan"][:] value_linear = log["value_linear"] @@ -64,8 +66,12 @@ def test_lowrank_sinkhorn_init(init): X_t = np.reshape(1.0 * np.arange(n), (n, 1)) # test ImportError if init="kmeans" and sklearn not imported - if init in ["random", "deterministic"] or ((init == "kmeans") and (sklearn_import is True)): - Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, init=init, log=True) + if init in ["random", "deterministic"] or ( + (init == "kmeans") and (sklearn_import is True) + ): + Q, R, g, log = ot.lowrank.lowrank_sinkhorn( + X_s, X_t, a, b, reg=0.1, init=init, log=True + ) P = log["lazy_plan"][:] # check constraints for P @@ -88,7 +94,9 @@ def test_lowrank_sinkhorn_alpha_error(alpha, rank): X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) with pytest.raises(ValueError): - ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False) + ot.lowrank.lowrank_sinkhorn( + X_s, X_t, a, b, reg=0.1, rank=rank, alpha=alpha, warn=False + ) @pytest.mark.parametrize(("gamma_init"), ("rescale", "theory")) @@ -101,7 +109,9 @@ def test_lowrank_sinkhorn_gamma_init(gamma_init): X_s = np.reshape(1.0 * np.arange(n), (n, 1)) X_t = np.reshape(1.0 * np.arange(n), (n, 1)) - Q, R, g, log = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, reg=0.1, gamma_init=gamma_init, log=True) + Q, R, g, log = ot.lowrank.lowrank_sinkhorn( + X_s, X_t, a, b, reg=0.1, gamma_init=gamma_init, log=True + ) P = log["lazy_plan"][:] # check constraints for P @@ -109,7 +119,7 @@ def test_lowrank_sinkhorn_gamma_init(gamma_init): np.testing.assert_allclose(b, P.sum(0), atol=1e-05) -@pytest.skip_backend('tf') +@pytest.skip_backend("tf") def test_lowrank_sinkhorn_backends(nx): # Test low rank sinkhorn for different backends n = 100 diff --git a/test/test_mapping.py b/test/test_mapping.py index 991b2374c..d08ec031e 100644 --- a/test/test_mapping.py +++ b/test/test_mapping.py @@ -11,6 +11,7 @@ try: # test if cvxpy is installed import cvxpy # noqa: F401 + nocvxpy = False except 
ImportError: nocvxpy = True @@ -18,9 +19,9 @@ @pytest.mark.skipif(nocvxpy, reason="No CVXPY available") def test_ssnb_qcqp_constants(): - c1, c2, c3 = ot.mapping._ssnb_qcqp_constants(.5, 1) + c1, c2, c3 = ot.mapping._ssnb_qcqp_constants(0.5, 1) np.testing.assert_almost_equal(c1, 1) - np.testing.assert_almost_equal(c2, .5) + np.testing.assert_almost_equal(c2, 0.5) np.testing.assert_almost_equal(c3, 1) @@ -28,29 +29,37 @@ def test_ssnb_qcqp_constants(): def test_nearest_brenier_potential_fit(nx): X = nx.ones((2, 2)) phi, G, log = ot.mapping.nearest_brenier_potential_fit(X, X, its=3, log=True) - np.testing.assert_almost_equal(to_numpy(G), to_numpy(X)) # image of source should be close to target + np.testing.assert_almost_equal( + to_numpy(G), to_numpy(X) + ) # image of source should be close to target # test without log but with X_classes, a, b and other init method a = nx.ones(2) / 2 - ot.mapping.nearest_brenier_potential_fit(X, X, X_classes=nx.ones(2), a=a, b=a, its=1, init_method='target') + ot.mapping.nearest_brenier_potential_fit( + X, X, X_classes=nx.ones(2), a=a, b=a, its=1, init_method="target" + ) @pytest.mark.skipif(nocvxpy, reason="No CVXPY available") def test_brenier_potential_predict_bounds(nx): X = nx.ones((2, 2)) phi, G = ot.mapping.nearest_brenier_potential_fit(X, X, its=3) - phi_lu, G_lu, log = ot.mapping.nearest_brenier_potential_predict_bounds(X, phi, G, X, log=True) + phi_lu, G_lu, log = ot.mapping.nearest_brenier_potential_predict_bounds( + X, phi, G, X, log=True + ) # 'new' input isn't new, so should be equal to target np.testing.assert_almost_equal(to_numpy(G_lu[0]), to_numpy(X)) np.testing.assert_almost_equal(to_numpy(G_lu[1]), to_numpy(X)) # test with no log but classes - ot.mapping.nearest_brenier_potential_predict_bounds(X, phi, G, X, X_classes=nx.ones(2), Y_classes=nx.ones(2)) + ot.mapping.nearest_brenier_potential_predict_bounds( + X, phi, G, X, X_classes=nx.ones(2), Y_classes=nx.ones(2) + ) def test_joint_OT_mapping(): """ Complements the tests in test_da, for verbose, log and bias options """ - xs = np.array([[.1, .2], [-.1, .3]]) + xs = np.array([[0.1, 0.2], [-0.1, 0.3]]) ot.mapping.joint_OT_mapping_kernel(xs, xs, verbose=True) ot.mapping.joint_OT_mapping_linear(xs, xs, verbose=True) ot.mapping.joint_OT_mapping_kernel(xs, xs, log=True, bias=True) diff --git a/test/test_optim.py b/test/test_optim.py index cf90350d5..2d4b4d0ff 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -1,4 +1,4 @@ -"""Tests for module optim fro OT optimization """ +"""Tests for module optim fro OT optimization""" # Author: Remi Flamary # @@ -9,7 +9,6 @@ def test_conditional_gradient(nx): - n_bins = 100 # nb bins # bin positions x = np.arange(n_bins, dtype=np.float64) @@ -29,7 +28,7 @@ def df(G): return G def fb(G): - return 0.5 * nx.sum(G ** 2) + return 0.5 * nx.sum(G**2) ab, bb, Mb = nx.from_numpy(a, b, M) @@ -51,7 +50,7 @@ def test_conditional_gradient_itermax(nx): cov_s = np.array([[1, 0], [0, 1]]) mu_t = np.array([4, 4]) - cov_t = np.array([[1, -.8], [-.8, 1]]) + cov_t = np.array([[1, -0.8], [-0.8, 1]]) xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) @@ -69,16 +68,18 @@ def df(G): return G def fb(G): - return 0.5 * nx.sum(G ** 2) + return 0.5 * nx.sum(G**2) ab, bb, Mb = nx.from_numpy(a, b, M) reg = 1e-1 - G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000, - verbose=True, log=True) - Gb, log = ot.optim.cg(ab, bb, Mb, reg, fb, df, numItermaxEmd=10000, - verbose=True, log=True) + G, log = ot.optim.cg( + 
a, b, M, reg, f, df, numItermaxEmd=10000, verbose=True, log=True + ) + Gb, log = ot.optim.cg( + ab, bb, Mb, reg, fb, df, numItermaxEmd=10000, verbose=True, log=True + ) Gb = nx.to_numpy(Gb) np.testing.assert_allclose(Gb, G) @@ -87,7 +88,6 @@ def fb(G): def test_generalized_conditional_gradient(nx): - n_bins = 100 # nb bins # bin positions x = np.arange(n_bins, dtype=np.float64) @@ -107,7 +107,7 @@ def df(G): return G def fb(G): - return 0.5 * nx.sum(G ** 2) + return 0.5 * nx.sum(G**2) reg1 = 1e-3 reg2 = 1e-1 @@ -133,22 +133,19 @@ def test_line_search_armijo(nx): xk = np.array([[0.25, 0.25], [0.25, 0.25]]) pk = np.array([[-0.25, 0.25], [0.25, -0.25]]) gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]]) - old_fval = -123. + old_fval = -123.0 xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk) def f(x): - return 1. + return 1.0 + # Should not throw an exception and return 0. for alpha - alpha, a, b = ot.optim.line_search_armijo( - f, xkb, pkb, gfkb, old_fval - ) - alpha_np, anp, bnp = ot.optim.line_search_armijo( - f, xk, pk, gfk, old_fval - ) + alpha, a, b = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval) + alpha_np, anp, bnp = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval) assert a == anp assert b == bnp - assert alpha == 0. + assert alpha == 0.0 # check line search armijo def f(x): @@ -186,6 +183,7 @@ def grad(x): def test_line_search_armijo_dtype_device(nx): for tp in nx.__type_list__: + def f(x): return nx.sum((x - 5.0) ** 2) @@ -207,7 +205,9 @@ def grad(x): # check the case where the direction is not far enough pk = np.array([[[3.0, 3.0]]]) pkb = nx.from_numpy(pk, type_as=tp) - alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval, alpha0=1.0) + alpha, _, fval = ot.optim.line_search_armijo( + f, xkb, pkb, gfkb, old_fval, alpha0=1.0 + ) alpha = nx.to_numpy(alpha) np.testing.assert_allclose(alpha, 1.0) nx.assert_same_dtype_device(old_fval, fval) diff --git a/test/test_ot.py b/test/test_ot.py index a90321d5f..da0ec746e 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -1,4 +1,4 @@ -"""Tests for main module ot """ +"""Tests for main module ot""" # Author: Remi Flamary # @@ -125,7 +125,7 @@ def test_emd_emd2_devices_tf(): nx.assert_same_dtype_device(Mb, Gb) nx.assert_same_dtype_device(Mb, w) - if len(tf.config.list_physical_devices('GPU')) > 0: + if len(tf.config.list_physical_devices("GPU")) > 0: # Check that everything happens on the GPU ab, Mb = nx.from_numpy(a, M) Gb = ot.emd(ab, ab, Mb) @@ -147,7 +147,6 @@ def test_emd2_gradients(): M = ot.dist(x, y) if torch: - a1 = torch.tensor(a, requires_grad=True) b1 = torch.tensor(a, requires_grad=True) M1 = torch.tensor(M, requires_grad=True) @@ -160,11 +159,15 @@ def test_emd2_gradients(): assert b1.shape == b1.grad.shape assert M1.shape == M1.grad.shape - assert np.allclose(a1.grad.cpu().detach().numpy(), - log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean()) + assert np.allclose( + a1.grad.cpu().detach().numpy(), + log["u"].cpu().detach().numpy() - log["u"].cpu().detach().numpy().mean(), + ) - assert np.allclose(b1.grad.cpu().detach().numpy(), - log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean()) + assert np.allclose( + b1.grad.cpu().detach().numpy(), + log["v"].cpu().detach().numpy() - log["v"].cpu().detach().numpy().mean(), + ) # Testing for bug #309, checking for scaling of gradient a2 = torch.tensor(a, requires_grad=True) @@ -175,12 +178,15 @@ def test_emd2_gradients(): val.backward() - assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(), - 
a2.grad.cpu().detach().numpy()) - assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(), - b2.grad.cpu().detach().numpy()) - assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(), - M2.grad.cpu().detach().numpy()) + assert np.allclose( + 10.0 * a1.grad.cpu().detach().numpy(), a2.grad.cpu().detach().numpy() + ) + assert np.allclose( + 10.0 * b1.grad.cpu().detach().numpy(), b2.grad.cpu().detach().numpy() + ) + assert np.allclose( + 10.0 * M1.grad.cpu().detach().numpy(), M2.grad.cpu().detach().numpy() + ) def test_emd_emd2(): @@ -264,30 +270,30 @@ def test_emd2_multi(): M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) # M/=M.max() - print('Computing {} EMD '.format(nb)) + print("Computing {} EMD ".format(nb)) # emd loss 1 proc ot.tic() emd1 = ot.emd2(a, b, M, 1) - ot.toc('1 proc : {} s') + ot.toc("1 proc : {} s") # emd loss multipro proc ot.tic() emdn = ot.emd2(a, b, M) - ot.toc('multi proc : {} s') + ot.toc("multi proc : {} s") np.testing.assert_allclose(emd1, emdn) # emd loss multipro proc with log ot.tic() emdn = ot.emd2(a, b, M, log=True, return_matrix=True) - ot.toc('multi proc : {} s') + ot.toc("multi proc : {} s") for i in range(len(emdn)): emd = emdn[i] log = emd[1] cost = emd[0] - check_duality_gap(a, b[:, i], M, log['G'], log['u'], log['v'], cost) + check_duality_gap(a, b[:, i], M, log["G"], log["u"], log["v"], cost) emdn[i] = cost emdn = np.array(emdn) @@ -304,20 +310,23 @@ def test_lp_barycenter(): # obvious barycenter between two Diracs bary0 = np.array([0, 1.0, 0]) - bary = ot.lp.barycenter(A, M, [.5, .5]) + bary = ot.lp.barycenter(A, M, [0.5, 0.5]) np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) np.testing.assert_allclose(bary.sum(), 1) def test_free_support_barycenter(): - measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] - measures_weights = [np.array([1.]), np.array([1.])] + measures_locations = [ + np.array([-1.0]).reshape((1, 1)), + np.array([1.0]).reshape((1, 1)), + ] + measures_weights = [np.array([1.0]), np.array([1.0])] - X_init = np.array([-12.]).reshape((1, 1)) + X_init = np.array([-12.0]).reshape((1, 1)) # obvious barycenter location between two Diracs - bar_locations = np.array([0.]).reshape((1, 1)) + bar_locations = np.array([0.0]).reshape((1, 1)) X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init) @@ -325,10 +334,12 @@ def test_free_support_barycenter(): def test_free_support_barycenter_backends(nx): - - measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] - measures_weights = [np.array([1.]), np.array([1.])] - X_init = np.array([-12.]).reshape((1, 1)) + measures_locations = [ + np.array([-1.0]).reshape((1, 1)), + np.array([1.0]).reshape((1, 1)), + ] + measures_weights = [np.array([1.0]), np.array([1.0])] + X_init = np.array([-12.0]).reshape((1, 1)) X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init) @@ -342,30 +353,35 @@ def test_free_support_barycenter_backends(nx): def test_generalised_free_support_barycenter(): - X = [np.array([-1., -1.]).reshape((1, 2)), np.array([1., 1.]).reshape((1, 2))] # two 2D points bar is obviously 0 - a = [np.array([1.]), np.array([1.])] + X = [ + np.array([-1.0, -1.0]).reshape((1, 2)), + np.array([1.0, 1.0]).reshape((1, 2)), + ] # two 2D points bar is obviously 0 + a = [np.array([1.0]), np.array([1.0])] P = [np.eye(2), np.eye(2)] - Y_init = np.array([-12., 7.]).reshape((1, 2)) + Y_init = np.array([-12.0, 7.0]).reshape((1, 2)) # obvious barycenter location between two 2D Diracs - Y_true = 
np.array([0., .0]).reshape((1, 2)) + Y_true = np.array([0.0, 0.0]).reshape((1, 2)) # test without log and no init Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1) np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7) # test with log and init - Y, _ = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init, b=np.array([1.]), log=True) + Y, _ = ot.lp.generalized_free_support_barycenter( + X, a, P, 1, Y_init=Y_init, b=np.array([1.0]), log=True + ) np.testing.assert_allclose(Y, Y_true, rtol=1e-5, atol=1e-7) def test_generalised_free_support_barycenter_backends(nx): - X = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] - a = [np.array([1.]), np.array([1.])] - P = [np.array([1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] - Y_init = np.array([-12.]).reshape((1, 1)) + X = [np.array([-1.0]).reshape((1, 1)), np.array([1.0]).reshape((1, 1))] + a = [np.array([1.0]), np.array([1.0])] + P = [np.array([1.0]).reshape((1, 1)), np.array([1.0]).reshape((1, 1))] + Y_init = np.array([-12.0]).reshape((1, 1)) Y = ot.lp.generalized_free_support_barycenter(X, a, P, 1, Y_init=Y_init) @@ -390,7 +406,7 @@ def test_lp_barycenter_cvxopt(): # obvious barycenter between two Diracs bary0 = np.array([0, 1.0, 0]) - bary = ot.lp.barycenter(A, M, [.5, .5], solver=None) + bary = ot.lp.barycenter(A, M, [0.5, 0.5], solver=None) np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7) np.testing.assert_allclose(bary.sum(), 1) @@ -413,15 +429,15 @@ def test_warnings(): b = gauss(m, m=mean2, s=10) # loss matrix - M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2) + M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1.0 / 2) - print('Computing {} EMD '.format(1)) + print("Computing {} EMD ".format(1)) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - print('Computing {} EMD '.format(1)) + print("Computing {} EMD ".format(1)) ot.emd(a, b, M, numItermax=1) assert "numItermax" in str(w[-1].message) - #assert len(w) == 1 + # assert len(w) == 1 def test_dual_variables(): @@ -441,18 +457,18 @@ def test_dual_variables(): b = gauss(m, m=mean2, s=10) # loss matrix - M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. 
/ 2) + M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1.0 / 2) - print('Computing {} EMD '.format(1)) + print("Computing {} EMD ".format(1)) # emd loss 1 proc ot.tic() G, log = ot.emd(a, b, M, log=True) - ot.toc('1 proc : {} s') + ot.toc("1 proc : {} s") ot.tic() G2 = ot.emd(b, a, np.ascontiguousarray(M.T)) - ot.toc('1 proc : {} s') + ot.toc("1 proc : {} s") cost1 = (G * M).sum() # Check symmetry @@ -461,10 +477,10 @@ def test_dual_variables(): np.testing.assert_almost_equal(cost1, np.abs(mean1 - mean2)) # Check that both cost computations are equivalent - np.testing.assert_almost_equal(cost1, log['cost']) - check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost']) + np.testing.assert_almost_equal(cost1, log["cost"]) + check_duality_gap(a, b, M, G, log["u"], log["v"], log["cost"]) - constraint_violation = log['u'][:, None] + log['v'][None, :] - M + constraint_violation = log["u"][:, None] + log["v"][None, :] - M assert constraint_violation.max() < 1e-8 @@ -477,5 +493,6 @@ def check_duality_gap(a, b, M, G, u, v, cost): [ind1, ind2] = np.nonzero(G) # Check that reduced cost is zero on transport arcs - np.testing.assert_array_almost_equal((M - u.reshape(-1, 1) - v.reshape(1, -1))[ind1, ind2], - np.zeros(ind1.size)) + np.testing.assert_array_almost_equal( + (M - u.reshape(-1, 1) - v.reshape(1, -1))[ind1, ind2], np.zeros(ind1.size) + ) diff --git a/test/test_partial.py b/test/test_partial.py index 0b49b2892..464003a2e 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -1,4 +1,4 @@ -"""Tests for module partial """ +"""Tests for module partial""" # Author: # Laetitia Chapel @@ -13,7 +13,6 @@ def test_raise_errors(): - n_samples = 20 # nb samples (gaussian) n_noise = 20 # nb of samples (noise) @@ -53,16 +52,15 @@ def test_raise_errors(): ot.partial.partial_gromov_wasserstein(M, M, p, q, m=-1, log=True) with pytest.raises(ValueError): - ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, - log=True) + ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, log=True) with pytest.raises(ValueError): - ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, - log=True) + ot.partial.entropic_partial_gromov_wasserstein( + M, M, p, q, reg=1, m=-1, log=True + ) def test_partial_wasserstein_lagrange(): - n_samples = 20 # nb samples (gaussian) n_noise = 20 # nb of samples (noise) @@ -86,7 +84,6 @@ def test_partial_wasserstein_lagrange(): def test_partial_wasserstein(nx): - n_samples = 20 # nb samples (gaussian) n_noise = 20 # nb of samples (noise) @@ -109,7 +106,9 @@ def test_partial_wasserstein(nx): p, q, M = nx.from_numpy(p, q, M) w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True) - w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True) + w, log = ot.partial.entropic_partial_wasserstein( + p, q, M, reg=1, m=m, log=True, verbose=True + ) # check constraints np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) @@ -124,7 +123,7 @@ def test_partial_wasserstein(nx): w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True) w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False) - G = log0['T'] + G = log0["T"] np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1) @@ -145,7 +144,9 @@ def test_partial_wasserstein(nx): # check transported mass np.testing.assert_allclose(np.sum(to_numpy(w)), 1, atol=1e-04) - w0 = ot.partial.entropic_partial_wasserstein(empty_array, empty_array, M=M, reg=10, m=None) + w0 = 
ot.partial.entropic_partial_wasserstein( + empty_array, empty_array, M=M, reg=10, m=None + ) # check constraints np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) @@ -200,9 +201,11 @@ def test_entropic_partial_wasserstein_gradient(): m = 0.5 reg = 1 - _, log = ot.partial.entropic_partial_wasserstein(p, q, M, m=m, reg=reg, log=True) + _, log = ot.partial.entropic_partial_wasserstein( + p, q, M, m=m, reg=reg, log=True + ) - log['partial_w_dist'].backward() + log["partial_w_dist"].backward() assert M.grad is not None assert p.grad is not None @@ -238,50 +241,49 @@ def test_partial_gromov_wasserstein(): C3 = ot.dist(xt2, xt2) m = 2 / 3 - res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C3, p, q, m=m, - log=True, verbose=True) + res0, log0 = ot.partial.partial_gromov_wasserstein( + C1, C3, p, q, m=m, log=True, verbose=True + ) np.testing.assert_allclose(res0, 0, atol=1e-1, rtol=1e-1) C1 = sp.spatial.distance.cdist(xs, xs) C2 = sp.spatial.distance.cdist(xt, xt) m = 1 - res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, - log=True) - G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss') + res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True) + G = ot.gromov.gromov_wasserstein(C1, C2, p, q, "square_loss") np.testing.assert_allclose(G, res0, atol=1e-04) - res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, - m=m, log=True) - G = ot.gromov.entropic_gromov_wasserstein( - C1, C2, p, q, 'square_loss', epsilon=10) + res, log = ot.partial.entropic_partial_gromov_wasserstein( + C1, C2, p, q, 10, m=m, log=True + ) + G = ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, "square_loss", epsilon=10) np.testing.assert_allclose(G, res, atol=1e-02) - w0, log0 = ot.partial.partial_gromov_wasserstein2(C1, C2, p, q, m=m, - log=True) - w0_val = ot.partial.partial_gromov_wasserstein2(C1, C2, p, q, m=m, - log=False) - G = log0['T'] + w0, log0 = ot.partial.partial_gromov_wasserstein2(C1, C2, p, q, m=m, log=True) + w0_val = ot.partial.partial_gromov_wasserstein2(C1, C2, p, q, m=m, log=False) + G = log0["T"] np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1) m = 2 / 3 - res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, - log=True) - res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, - 100, m=m, - log=True) + res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True) + res, log = ot.partial.entropic_partial_gromov_wasserstein( + C1, C2, p, q, 100, m=m, log=True + ) # check constraints np.testing.assert_equal( - res0.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein + res0.sum(1) <= p, [True] * len(p) + ) # cf convergence wasserstein np.testing.assert_equal( - res0.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein - np.testing.assert_allclose( - np.sum(res0), m, atol=1e-04) + res0.sum(0) <= q, [True] * len(q) + ) # cf convergence wasserstein + np.testing.assert_allclose(np.sum(res0), m, atol=1e-04) np.testing.assert_equal( - res.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein + res.sum(1) <= p, [True] * len(p) + ) # cf convergence wasserstein np.testing.assert_equal( - res.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein - np.testing.assert_allclose( - np.sum(res), m, atol=1e-04) + res.sum(0) <= q, [True] * len(q) + ) # cf convergence wasserstein + np.testing.assert_allclose(np.sum(res), m, atol=1e-04) diff --git a/test/test_plot.py b/test/test_plot.py index a3aade5f8..15b9d2c58 100644 --- 
a/test/test_plot.py +++ b/test/test_plot.py @@ -1,4 +1,4 @@ -"""Tests for module plot for visualization """ +"""Tests for module plot for visualization""" # Author: Remi Flamary # @@ -10,7 +10,8 @@ try: # test if matplotlib is installed import matplotlib - matplotlib.use('Agg') + + matplotlib.use("Agg") nogo = False except ImportError: nogo = True @@ -18,7 +19,6 @@ @pytest.mark.skipif(nogo, reason="Matplotlib not installed") def test_plot1D_mat(): - import ot import ot.plot @@ -36,15 +36,14 @@ def test_plot1D_mat(): M /= M.max() ot.plot.plot1D_mat(a, b, M) - ot.plot.plot1D_mat(a, b, M, plot_style='xy') + ot.plot.plot1D_mat(a, b, M, plot_style="xy") with pytest.raises(AssertionError): - ot.plot.plot1D_mat(a, b, M, plot_style='NotAValidStyle') + ot.plot.plot1D_mat(a, b, M, plot_style="NotAValidStyle") @pytest.mark.skipif(nogo, reason="Matplotlib not installed") def test_rescale_for_imshow_plot(): - import ot import ot.plot @@ -55,10 +54,12 @@ def test_rescale_for_imshow_plot(): y = np.linspace(a_y, b_y, n) x_rescaled, y_rescaled = ot.plot.rescale_for_imshow_plot(x, y, n) - assert x_rescaled.shape == (n, ) - assert y_rescaled.shape == (n, ) + assert x_rescaled.shape == (n,) + assert y_rescaled.shape == (n,) - x_rescaled, y_rescaled = ot.plot.rescale_for_imshow_plot(x, y, n, m=n, a_y=a_y + 1, b_y=b_y - 1) + x_rescaled, y_rescaled = ot.plot.rescale_for_imshow_plot( + x, y, n, m=n, a_y=a_y + 1, b_y=b_y - 1 + ) assert x_rescaled.shape[0] <= n assert y_rescaled.shape[0] <= n with pytest.raises(AssertionError): @@ -67,7 +68,6 @@ def test_rescale_for_imshow_plot(): @pytest.mark.skipif(nogo, reason="Matplotlib not installed") def test_plot2D_samples_mat(): - import ot import ot.plot @@ -77,7 +77,7 @@ def test_plot2D_samples_mat(): cov_s = np.array([[1, 0], [0, 1]]) mu_t = np.array([4, 4]) - cov_t = np.array([[1, -.8], [-.8, 1]]) + cov_t = np.array([[1, -0.8], [-0.8, 1]]) rng = np.random.RandomState(42) xs = ot.datasets.make_2D_samples_gauss(n_bins, mu_s, cov_s, random_state=rng) diff --git a/test/test_regpath.py b/test/test_regpath.py index 76be39caf..95d8d9374 100644 --- a/test/test_regpath.py +++ b/test/test_regpath.py @@ -9,9 +9,8 @@ def test_fully_relaxed_path(): - - n_source = 50 # nb source samples (gaussian) - n_target = 40 # nb target samples (gaussian) + n_source = 50 # nb source samples (gaussian) + n_target = 40 # nb target samples (gaussian) mu = np.array([0, 0]) cov = np.array([[1, 0], [0, 2]]) @@ -28,8 +27,7 @@ def test_fully_relaxed_path(): M = ot.dist(xs, xt) M /= M.max() - t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8, - semi_relaxed=False) + t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8, semi_relaxed=False) G = t.reshape((n_source, n_target)) np.testing.assert_allclose(a, G.sum(1), atol=1e-05) @@ -37,9 +35,8 @@ def test_fully_relaxed_path(): def test_semi_relaxed_path(): - - n_source = 50 # nb source samples (gaussian) - n_target = 40 # nb target samples (gaussian) + n_source = 50 # nb source samples (gaussian) + n_target = 40 # nb target samples (gaussian) mu = np.array([0, 0]) cov = np.array([[1, 0], [0, 2]]) @@ -56,8 +53,7 @@ def test_semi_relaxed_path(): M = ot.dist(xs, xt) M /= M.max() - t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8, - semi_relaxed=True) + t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8, semi_relaxed=True) G = t.reshape((n_source, n_target)) np.testing.assert_allclose(a, G.sum(1), atol=1e-05) diff --git a/test/test_sliced.py b/test/test_sliced.py index 0062e12a0..566a7fdf6 100644 --- a/test/test_sliced.py +++ 
b/test/test_sliced.py @@ -16,7 +16,7 @@ def test_get_random_projections(): rng = np.random.RandomState(0) projections = get_random_projections(1000, 50, rng) - np.testing.assert_almost_equal(np.sum(projections ** 2, 0), 1.) + np.testing.assert_almost_equal(np.sum(projections**2, 0), 1.0) def test_sliced_same_dist(): @@ -27,7 +27,7 @@ def test_sliced_same_dist(): u = ot.utils.unif(n) res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng) - np.testing.assert_almost_equal(res, 0.) + np.testing.assert_almost_equal(res, 0.0) def test_sliced_bad_shapes(): @@ -69,7 +69,7 @@ def test_sliced_different_dists(): y = rng.randn(n, 2) res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) - assert res > 0. + assert res > 0.0 def test_1d_sliced_equals_emd(): @@ -84,7 +84,7 @@ def test_1d_sliced_equals_emd(): u = ot.utils.unif(m) res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42) expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u) - np.testing.assert_almost_equal(res ** 2, expected) + np.testing.assert_almost_equal(res**2, expected) def test_max_sliced_same_dist(): @@ -95,7 +95,7 @@ def test_max_sliced_same_dist(): u = ot.utils.unif(n) res = ot.max_sliced_wasserstein_distance(x, x, u, u, 10, seed=rng) - np.testing.assert_almost_equal(res, 0.) + np.testing.assert_almost_equal(res, 0.0) def test_max_sliced_different_dists(): @@ -107,7 +107,7 @@ def test_max_sliced_different_dists(): y = rng.randn(n, 2) res, log = ot.max_sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) - assert res > 0. + assert res > 0.0 def test_sliced_same_proj(): @@ -116,16 +116,17 @@ def test_sliced_same_proj(): rng = np.random.RandomState(0) X = rng.randn(8, 2) Y = rng.randn(8, 2) - cost1, log1 = ot.sliced_wasserstein_distance(X, Y, seed=seed, n_projections=n_projections, log=True) + cost1, log1 = ot.sliced_wasserstein_distance( + X, Y, seed=seed, n_projections=n_projections, log=True + ) P = get_random_projections(X.shape[1], n_projections=10, seed=seed) cost2, log2 = ot.sliced_wasserstein_distance(X, Y, projections=P, log=True) - assert np.allclose(log1['projections'], log2['projections']) + assert np.allclose(log1["projections"], log2["projections"]) assert np.isclose(cost1, cost2) def test_sliced_backend(nx): - n = 100 rng = np.random.RandomState(0) @@ -188,7 +189,7 @@ def test_sliced_backend_device_tf(): valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) - if len(tf.config.list_physical_devices('GPU')) > 0: + if len(tf.config.list_physical_devices("GPU")) > 0: # Check that everything happens on the GPU xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) @@ -197,7 +198,6 @@ def test_sliced_backend_device_tf(): def test_max_sliced_backend(nx): - n = 100 rng = np.random.RandomState(0) @@ -213,8 +213,12 @@ def test_max_sliced_backend(nx): val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P) - val = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) - val2 = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + val = ot.max_sliced_wasserstein_distance( + xb, yb, n_projections=n_projections, seed=0 + ) + val2 = ot.max_sliced_wasserstein_distance( + xb, yb, n_projections=n_projections, seed=0 + ) assert val > 0 assert val == val2 @@ -260,7 +264,7 @@ def test_max_sliced_backend_device_tf(): valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) - if 
len(tf.config.list_physical_devices('GPU')) > 0: + if len(tf.config.list_physical_devices("GPU")) > 0: # Check that everything happens on the GPU xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) @@ -275,12 +279,15 @@ def test_projections_stiefel(): x = rng.randn(100, 3) x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) - ssw, log = ot.sliced_wasserstein_sphere(x, x, n_projections=n_projs, - seed=rng, log=True) + ssw, log = ot.sliced_wasserstein_sphere( + x, x, n_projections=n_projs, seed=rng, log=True + ) P = log["projections"] P_T = np.transpose(P, [0, 2, 1]) - np.testing.assert_almost_equal(np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)])) + np.testing.assert_almost_equal( + np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)]) + ) def test_sliced_sphere_same_dist(): @@ -292,7 +299,7 @@ def test_sliced_sphere_same_dist(): u = ot.utils.unif(n) res = ot.sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng) - np.testing.assert_almost_equal(res, 0.) + np.testing.assert_almost_equal(res, 0.0) def test_sliced_sphere_same_proj(): @@ -308,10 +315,14 @@ def test_sliced_sphere_same_proj(): seed = 42 - cost1, log1 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True) - cost2, log2 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True) + cost1, log1 = ot.sliced_wasserstein_sphere( + x, y, seed=seed, n_projections=n_projections, log=True + ) + cost2, log2 = ot.sliced_wasserstein_sphere( + x, y, seed=seed, n_projections=n_projections, log=True + ) - assert np.allclose(log1['projections'], log2['projections']) + assert np.allclose(log1["projections"], log2["projections"]) assert np.isclose(cost1, cost2) @@ -378,7 +389,7 @@ def test_sliced_sphere_different_dists(): y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) res = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) - assert res > 0. 
+ assert res > 0.0 def test_1d_sliced_sphere_equals_emd(): @@ -403,7 +414,7 @@ def test_1d_sliced_sphere_equals_emd(): res1 = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=1) expected1 = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=1) - np.testing.assert_almost_equal(res ** 2, expected) + np.testing.assert_almost_equal(res**2, expected) np.testing.assert_almost_equal(res1, expected1, decimal=3) @@ -426,7 +437,9 @@ def test_sliced_sphere_backend_type_devices(nx): xb, yb = nx.from_numpy(x, y, type_as=tp) - valb = ot.sliced_wasserstein_sphere(xb, yb, projections=nx.from_numpy(P, type_as=tp)) + valb = ot.sliced_wasserstein_sphere( + xb, yb, projections=nx.from_numpy(P, type_as=tp) + ) nx.assert_same_dtype_device(xb, valb) np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb)) diff --git a/test/test_smooth.py b/test/test_smooth.py index dbdd40541..a18fe3592 100644 --- a/test/test_smooth.py +++ b/test/test_smooth.py @@ -1,4 +1,4 @@ -"""Tests for ot.smooth model """ +"""Tests for ot.smooth model""" # Author: Remi Flamary # @@ -11,7 +11,6 @@ def test_smooth_ot_dual(): - # get data n = 100 rng = np.random.RandomState(0) @@ -22,25 +21,23 @@ def test_smooth_ot_dual(): M = ot.dist(x, x) with pytest.raises(NotImplementedError): - Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='none') + Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type="none") # squared l2 regularisation - Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) + Gl2, log = ot.smooth.smooth_ot_dual( + u, u, M, 1, reg_type="l2", log=True, stopThr=1e-10 + ) # check constraints - np.testing.assert_allclose( - u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn - np.testing.assert_allclose( - u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn # kl regularisation - G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10) + G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type="kl", stopThr=1e-10) # check constraints - np.testing.assert_allclose( - u, G.sum(1), atol=1e-05) # cf convergence sinkhorn - np.testing.assert_allclose( - u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) np.testing.assert_allclose(G, G2, atol=1e-05) @@ -48,24 +45,25 @@ def test_smooth_ot_dual(): # sparsity-constrained regularisation max_nz = 2 Gsc, log = ot.smooth.smooth_ot_dual( - u, u, M, 1, + u, + u, + M, + 1, max_nz=max_nz, log=True, - reg_type='sparsity_constrained', - stopThr=1e-10) + reg_type="sparsity_constrained", + stopThr=1e-10, + ) # check marginal constraints np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03) np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03) # check sparsity constraints - np.testing.assert_array_less( - np.sum(Gsc > 0, axis=0), - np.ones(n) * max_nz + 1) + np.testing.assert_array_less(np.sum(Gsc > 0, axis=0), np.ones(n) * max_nz + 1) def test_smooth_ot_semi_dual(): - # get data n = 100 rng = np.random.RandomState(0) @@ -76,25 +74,23 @@ def test_smooth_ot_semi_dual(): M = ot.dist(x, x) with pytest.raises(NotImplementedError): - Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='none') + Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type="none") # squared l2 
regularisation - Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) + Gl2, log = ot.smooth.smooth_ot_semi_dual( + u, u, M, 1, reg_type="l2", log=True, stopThr=1e-10 + ) # check constraints - np.testing.assert_allclose( - u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn - np.testing.assert_allclose( - u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn # kl regularisation - G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10) + G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type="kl", stopThr=1e-10) # check constraints - np.testing.assert_allclose( - u, G.sum(1), atol=1e-05) # cf convergence sinkhorn - np.testing.assert_allclose( - u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) np.testing.assert_allclose(G, G2, atol=1e-05) @@ -102,23 +98,24 @@ def test_smooth_ot_semi_dual(): # sparsity-constrained regularisation max_nz = 2 Gsc = ot.smooth.smooth_ot_semi_dual( - u, u, M, 1, reg_type='sparsity_constrained', - max_nz=max_nz, stopThr=1e-10) + u, u, M, 1, reg_type="sparsity_constrained", max_nz=max_nz, stopThr=1e-10 + ) # check marginal constraints np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03) np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03) # check sparsity constraints - np.testing.assert_array_less(np.sum(Gsc > 0, axis=0), - np.ones(n) * max_nz + 1) + np.testing.assert_array_less(np.sum(Gsc > 0, axis=0), np.ones(n) * max_nz + 1) def test_sparsity_constrained_gradient(): max_nz = 5 regularizer = ot.smooth.SparsityConstrained(max_nz=max_nz) rng = np.random.RandomState(0) - X = rng.randn(10,) + X = rng.randn( + 10, + ) b = 0.5 def delta_omega_func(X): diff --git a/test/test_solvers.py b/test/test_solvers.py index a338f93a6..82a402df1 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -4,7 +4,6 @@ # # License: MIT License - import itertools import numpy as np import pytest @@ -16,44 +15,68 @@ lst_reg = [None, 1] -lst_reg_type = ['KL', 'entropy', 'L2', 'tuple'] +lst_reg_type = ["KL", "entropy", "L2", "tuple"] lst_unbalanced = [None, 0.9] -lst_unbalanced_type = ['KL', 'L2', 'TV'] +lst_unbalanced_type = ["KL", "L2", "TV"] -lst_reg_type_gromov = ['entropy'] -lst_gw_losses = ['L2', 'KL'] -lst_unbalanced_type_gromov = ['KL', 'semirelaxed', 'partial'] +lst_reg_type_gromov = ["entropy"] +lst_gw_losses = ["L2", "KL"] +lst_unbalanced_type_gromov = ["KL", "semirelaxed", "partial"] lst_unbalanced_gromov = [None, 0.9] lst_alpha = [0, 0.4, 0.9, 1] lst_method_params_solve_sample = [ - {'method': '1d'}, - {'method': '1d', 'metric': 'euclidean'}, - {'method': 'gaussian'}, - {'method': 'gaussian', 'reg': 1}, - {'method': 'factored', 'rank': 10}, - {'method': 'lowrank', 'rank': 10} + {"method": "1d"}, + {"method": "1d", "metric": "euclidean"}, + {"method": "gaussian"}, + {"method": "gaussian", "reg": 1}, + {"method": "factored", "rank": 10}, + {"method": "lowrank", "rank": 10}, ] lst_parameters_solve_sample_NotImplemented = [ - {'method': '1d', 'metric': 'any other one'}, # fail 1d on weird metrics - {'method': 'gaussian', 'metric': 'euclidean'}, # fail gaussian on metric not euclidean - {'method': 'factored', 'metric': 'euclidean'}, # fail factored on 
metric not euclidean - {"method": 'lowrank', 'metric': 'euclidean'}, # fail lowrank on metric not euclidean - {'lazy': True}, # fail lazy for non regularized - {'lazy': True, 'unbalanced': 1}, # fail lazy for non regularized unbalanced - {'lazy': True, 'reg': 1, 'unbalanced': 1}, # fail lazy for unbalanced and regularized + {"method": "1d", "metric": "any other one"}, # fail 1d on weird metrics + { + "method": "gaussian", + "metric": "euclidean", + }, # fail gaussian on metric not euclidean + { + "method": "factored", + "metric": "euclidean", + }, # fail factored on metric not euclidean + { + "method": "lowrank", + "metric": "euclidean", + }, # fail lowrank on metric not euclidean + {"lazy": True}, # fail lazy for non regularized + {"lazy": True, "unbalanced": 1}, # fail lazy for non regularized unbalanced + { + "lazy": True, + "reg": 1, + "unbalanced": 1, + }, # fail lazy for unbalanced and regularized ] # set readable ids for each param -lst_method_params_solve_sample = [pytest.param(param, id=str(param)) for param in lst_method_params_solve_sample] -lst_parameters_solve_sample_NotImplemented = [pytest.param(param, id=str(param)) for param in lst_parameters_solve_sample_NotImplemented] +lst_method_params_solve_sample = [ + pytest.param(param, id=str(param)) for param in lst_method_params_solve_sample +] +lst_parameters_solve_sample_NotImplemented = [ + pytest.param(param, id=str(param)) + for param in lst_parameters_solve_sample_NotImplemented +] def assert_allclose_sol(sol1, sol2): - - lst_attr = ['value', 'value_linear', 'plan', - 'potential_a', 'potential_b', 'marginal_a', 'marginal_b'] + lst_attr = [ + "value", + "value_linear", + "plan", + "potential_a", + "potential_b", + "marginal_a", + "marginal_b", + ] nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() @@ -61,7 +84,11 @@ def assert_allclose_sol(sol1, sol2): for attr in lst_attr: if getattr(sol1, attr) is not None and getattr(sol2, attr) is not None: try: - np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr)), equal_nan=True) + np.allclose( + nx1.to_numpy(getattr(sol1, attr)), + nx2.to_numpy(getattr(sol2, attr)), + equal_nan=True, + ) except NotImplementedError: pass elif getattr(sol1, attr) is None and getattr(sol2, attr) is None: @@ -109,16 +136,15 @@ def test_solve(nx): # test not implemented unbalanced and check raise with pytest.raises(NotImplementedError): - sol0 = ot.solve(M, unbalanced=1, unbalanced_type='cryptic divergence') + sol0 = ot.solve(M, unbalanced=1, unbalanced_type="cryptic divergence") # test not implemented reg_type and check raise with pytest.raises(NotImplementedError): - sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence') + sol0 = ot.solve(M, reg=1, reg_type="cryptic divergence") @pytest.mark.skipif(not torch, reason="torch no installed") def test_solve_envelope(): - n_samples_s = 10 n_samples_t = 7 n_features = 2 @@ -134,7 +160,7 @@ def test_solve_envelope(): b = torch.tensor(b, requires_grad=True) M = torch.tensor(M, requires_grad=True) - sol0 = ot.solve(M, a, b, reg=10, grad='envelope') + sol0 = ot.solve(M, a, b, reg=10, grad="envelope") sol0.value.backward() gM0 = M.grad.clone() @@ -145,7 +171,7 @@ def test_solve_envelope(): b = torch.tensor(b, requires_grad=True) M = torch.tensor(M, requires_grad=True) - sol = ot.solve(M, a, b, reg=10, grad='autodiff') + sol = ot.solve(M, a, b, reg=10, grad="autodiff") sol.value.backward() gM = M.grad.clone() @@ -158,7 +184,10 
@@ def test_solve_envelope(): assert torch.allclose(gb0 - gb0.mean(), gb - gb.mean()) -@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) +@pytest.mark.parametrize( + "reg,reg_type,unbalanced,unbalanced_type", + itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type), +) def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type): n_samples_s = 10 n_samples_t = 7 @@ -173,8 +202,8 @@ def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type): M = ot.dist(x, y) try: + if reg_type == "tuple": - if reg_type == 'tuple': def f(G): return np.sum(G**2) @@ -184,10 +213,24 @@ def df(G): reg_type = (f, df) # solve unif weights - sol0 = ot.solve(M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) + sol0 = ot.solve( + M, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + ) # solve signe weights - sol = ot.solve(M, a, b, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) + sol = ot.solve( + M, + a, + b, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + ) assert_allclose_sol(sol0, sol) @@ -195,6 +238,7 @@ def df(G): ab, bb, Mb = nx.from_numpy(a, b, M) if isinstance(reg_type, tuple): + def f(G): return nx.sum(G**2) @@ -203,7 +247,15 @@ def df(G): reg_type = (f, df) - solb = ot.solve(Mb, ab, bb, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) + solb = ot.solve( + Mb, + ab, + bb, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + ) assert_allclose_sol(sol, solb) @@ -212,7 +264,6 @@ def df(G): def test_solve_not_implemented(nx): - n_samples_s = 10 n_samples_t = 7 n_features = 2 @@ -225,13 +276,12 @@ def test_solve_not_implemented(nx): # test not implemented and check raise with pytest.raises(NotImplementedError): - ot.solve(M, reg=1.0, reg_type='cryptic divergence') + ot.solve(M, reg=1.0, reg_type="cryptic divergence") with pytest.raises(NotImplementedError): - ot.solve(M, unbalanced=1.0, unbalanced_type='cryptic divergence') + ot.solve(M, unbalanced=1.0, unbalanced_type="cryptic divergence") def test_solve_gromov(nx): - np.random.seed(0) n_samples_s = 3 @@ -268,9 +318,18 @@ def test_solve_gromov(nx): assert_allclose_sol(sol0_fgw, solx_fgw) -@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type,alpha,loss", itertools.product(lst_reg, lst_reg_type_gromov, lst_unbalanced_gromov, lst_unbalanced_type_gromov, lst_alpha, lst_gw_losses)) +@pytest.mark.parametrize( + "reg,reg_type,unbalanced,unbalanced_type,alpha,loss", + itertools.product( + lst_reg, + lst_reg_type_gromov, + lst_unbalanced_gromov, + lst_unbalanced_type_gromov, + lst_alpha, + lst_gw_losses, + ), +) def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha, loss): - np.random.seed(0) n_samples_s = 3 @@ -288,15 +347,50 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha M = np.random.rand(n_samples_s, n_samples_t) try: - - sol0 = ot.solve_gromov(Ca, Cb, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, loss=loss) # GW - sol0_fgw = ot.solve_gromov(Ca, Cb, M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, alpha=alpha, loss=loss) # FGW + sol0 = ot.solve_gromov( + Ca, + Cb, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + 
unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + sol0_fgw = ot.solve_gromov( + Ca, + Cb, + M, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW # solve in backend ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) - solx = ot.solve_gromov(Cax, Cbx, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, loss=loss) # GW - solx_fgw = ot.solve_gromov(Cax, Cbx, Mx, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, alpha=alpha, loss=loss) # FGW + solx = ot.solve_gromov( + Cax, + Cbx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + solx_fgw = ot.solve_gromov( + Cax, + Cbx, + Mx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW solx.value_quad @@ -308,7 +402,6 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha def test_solve_gromov_not_implemented(nx): - np.random.seed(0) n_samples_s = 3 @@ -329,19 +422,19 @@ def test_solve_gromov_not_implemented(nx): # test not implemented and check raise with pytest.raises(NotImplementedError): - ot.solve_gromov(Ca, Cb, loss='weird loss') + ot.solve_gromov(Ca, Cb, loss="weird loss") with pytest.raises(NotImplementedError): - ot.solve_gromov(Ca, Cb, unbalanced=1, unbalanced_type='cryptic divergence') + ot.solve_gromov(Ca, Cb, unbalanced=1, unbalanced_type="cryptic divergence") with pytest.raises(NotImplementedError): - ot.solve_gromov(Ca, Cb, reg=1, reg_type='cryptic divergence') + ot.solve_gromov(Ca, Cb, reg=1, reg_type="cryptic divergence") # detect partial not implemented and error detect in value with pytest.raises(ValueError): - ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=1.5) + ot.solve_gromov(Ca, Cb, unbalanced_type="partial", unbalanced=1.5) with pytest.raises(NotImplementedError): - ot.solve_gromov(Ca, Cb, M, unbalanced_type='partial', unbalanced=0.5) + ot.solve_gromov(Ca, Cb, M, unbalanced_type="partial", unbalanced=0.5) with pytest.raises(ValueError): - ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5) + ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type="partial", unbalanced=1.5) def test_solve_sample(nx): @@ -381,11 +474,13 @@ def test_solve_sample(nx): # test not implemented unbalanced and check raise with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') + sol0 = ot.solve_sample( + X_s, X_t, unbalanced=1, unbalanced_type="cryptic divergence" + ) # test not implemented reg_type and check raise with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') + sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type="cryptic divergence") def test_solve_sample_lazy(nx): @@ -416,7 +511,7 @@ def test_solve_sample_lazy(nx): @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") @pytest.mark.skipif(not geomloss, reason="pytorch not installed") -@pytest.skip_backend('tf') +@pytest.skip_backend("tf") @pytest.skip_backend("cupy") @pytest.skip_backend("jax") @pytest.mark.parametrize("metric", ["sqeuclidean", "euclidean"]) @@ -437,28 +532,51 @@ def test_solve_sample_geomloss(nx, metric): sol0 = ot.solve_sample(xb, yb, ab, bb, reg=1) # solve signe weights - sol = ot.solve_sample(xb, yb, ab, bb, reg=1, method='geomloss') + sol = ot.solve_sample(xb, 
yb, ab, bb, reg=1, method="geomloss") assert_allclose_sol(sol0, sol) - sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=False, method='geomloss') + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=False, method="geomloss") assert_allclose_sol(sol0, sol) - sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method='geomloss_tensorized') - np.testing.assert_allclose(nx.to_numpy(sol1.lazy_plan[:]), nx.to_numpy(sol.lazy_plan[:]), rtol=1e-5, atol=1e-5) - - sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method='geomloss_online') - np.testing.assert_allclose(nx.to_numpy(sol1.lazy_plan[:]), nx.to_numpy(sol.lazy_plan[:]), rtol=1e-5, atol=1e-5) - - sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method='geomloss_multiscale') - np.testing.assert_allclose(nx.to_numpy(sol1.lazy_plan[:]), nx.to_numpy(sol.lazy_plan[:]), rtol=1e-5, atol=1e-5) - - sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method='geomloss') - np.testing.assert_allclose(nx.to_numpy(sol1.lazy_plan[:]), nx.to_numpy(sol.lazy_plan[:]), rtol=1e-5, atol=1e-5) + sol1 = ot.solve_sample( + xb, yb, ab, bb, reg=1, lazy=True, method="geomloss_tensorized" + ) + np.testing.assert_allclose( + nx.to_numpy(sol1.lazy_plan[:]), + nx.to_numpy(sol.lazy_plan[:]), + rtol=1e-5, + atol=1e-5, + ) + + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method="geomloss_online") + np.testing.assert_allclose( + nx.to_numpy(sol1.lazy_plan[:]), + nx.to_numpy(sol.lazy_plan[:]), + rtol=1e-5, + atol=1e-5, + ) + + sol1 = ot.solve_sample( + xb, yb, ab, bb, reg=1, lazy=True, method="geomloss_multiscale" + ) + np.testing.assert_allclose( + nx.to_numpy(sol1.lazy_plan[:]), + nx.to_numpy(sol.lazy_plan[:]), + rtol=1e-5, + atol=1e-5, + ) + + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method="geomloss") + np.testing.assert_allclose( + nx.to_numpy(sol1.lazy_plan[:]), + nx.to_numpy(sol.lazy_plan[:]), + rtol=1e-5, + atol=1e-5, + ) @pytest.mark.parametrize("method_params", lst_method_params_solve_sample) def test_solve_sample_methods(nx, method_params): - n_samples_s = 20 n_samples_t = 7 n_features = 2 @@ -478,13 +596,12 @@ def test_solve_sample_methods(nx, method_params): assert_allclose_sol(sol, solb) sol2 = ot.solve_sample(x, x, **method_params) - if method_params['method'] not in ['factored', 'lowrank']: + if method_params["method"] not in ["factored", "lowrank"]: np.testing.assert_allclose(sol2.value, 0) @pytest.mark.parametrize("method_params", lst_parameters_solve_sample_NotImplemented) def test_solve_sample_NotImplemented(nx, method_params): - n_samples_s = 20 n_samples_t = 7 n_features = 2 diff --git a/test/test_stochastic.py b/test/test_stochastic.py index 2b5c0fb3e..911d52b92 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -41,14 +41,13 @@ def test_stochastic_sag(): M = ot.dist(x, x) - G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag", - numItermax=numItermax) + G = ot.stochastic.solve_semi_dual_entropic( + u, u, M, reg, "sag", numItermax=numItermax + ) # check constraints - np.testing.assert_allclose( - u, G.sum(1), atol=1e-03) # cf convergence sag - np.testing.assert_allclose( - u, G.sum(0), atol=1e-03) # cf convergence sag + np.testing.assert_allclose(u, G.sum(1), atol=1e-03) # cf convergence sag + np.testing.assert_allclose(u, G.sum(0), atol=1e-03) # cf convergence sag ############################################################################# @@ -71,14 +70,13 @@ def test_stochastic_asgd(): M = ot.dist(x, x) - G, log = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, 
"asgd", - numItermax=numItermax, log=True) + G, log = ot.stochastic.solve_semi_dual_entropic( + u, u, M, reg, "asgd", numItermax=numItermax, log=True + ) # check constraints - np.testing.assert_allclose( - u, G.sum(1), atol=1e-02) # cf convergence asgd - np.testing.assert_allclose( - u, G.sum(0), atol=1e-02) # cf convergence asgd + np.testing.assert_allclose(u, G.sum(1), atol=1e-02) # cf convergence asgd + np.testing.assert_allclose(u, G.sum(0), atol=1e-02) # cf convergence asgd ############################################################################# @@ -100,25 +98,21 @@ def test_sag_asgd_sinkhorn(): u = ot.utils.unif(n) M = ot.dist(x, x) - G_asgd = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd", - numItermax=nb_iter) - G_sag = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag", - numItermax=nb_iter) + G_asgd = ot.stochastic.solve_semi_dual_entropic( + u, u, M, reg, "asgd", numItermax=nb_iter + ) + G_sag = ot.stochastic.solve_semi_dual_entropic( + u, u, M, reg, "sag", numItermax=nb_iter + ) G_sinkhorn = ot.sinkhorn(u, u, M, reg) # check constraints - np.testing.assert_allclose( - G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02) - np.testing.assert_allclose( - G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-02) - np.testing.assert_allclose( - G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-02) - np.testing.assert_allclose( - G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-02) - np.testing.assert_allclose( - G_sag, G_sinkhorn, atol=1e-02) # cf convergence sag - np.testing.assert_allclose( - G_asgd, G_sinkhorn, atol=1e-02) # cf convergence asgd + np.testing.assert_allclose(G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02) + np.testing.assert_allclose(G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-02) + np.testing.assert_allclose(G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-02) + np.testing.assert_allclose(G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-02) + np.testing.assert_allclose(G_sag, G_sinkhorn, atol=1e-02) # cf convergence sag + np.testing.assert_allclose(G_asgd, G_sinkhorn, atol=1e-02) # cf convergence asgd ############################################################################# @@ -146,14 +140,13 @@ def test_stochastic_dual_sgd(): M = ot.dist(x, x) - G, log = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size, - numItermax=numItermax, log=True) + G, log = ot.stochastic.solve_dual_entropic( + u, u, M, reg, batch_size, numItermax=numItermax, log=True + ) # check constraints - np.testing.assert_allclose( - u, G.sum(1), atol=1e-03) # cf convergence sgd - np.testing.assert_allclose( - u, G.sum(0), atol=1e-03) # cf convergence sgd + np.testing.assert_allclose(u, G.sum(1), atol=1e-03) # cf convergence sgd + np.testing.assert_allclose(u, G.sum(0), atol=1e-03) # cf convergence sgd ############################################################################# @@ -172,25 +165,23 @@ def test_dual_sgd_sinkhorn(): batch_size = 10 rng = np.random.RandomState(0) -# Test uniform + # Test uniform x = rng.randn(n, 2) u = ot.utils.unif(n) M = ot.dist(x, x) - G_sgd = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size, - numItermax=nb_iter) + G_sgd = ot.stochastic.solve_dual_entropic( + u, u, M, reg, batch_size, numItermax=nb_iter + ) G_sinkhorn = ot.sinkhorn(u, u, M, reg) # check constraints - np.testing.assert_allclose( - G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-02) - np.testing.assert_allclose( - G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-02) - np.testing.assert_allclose( - G_sgd, G_sinkhorn, atol=1e-02) # cf convergence sgd - -# Test gaussian + np.testing.assert_allclose(G_sgd.sum(1), 
G_sinkhorn.sum(1), atol=1e-02) + np.testing.assert_allclose(G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-02) + np.testing.assert_allclose(G_sgd, G_sinkhorn, atol=1e-02) # cf convergence sgd + + # Test gaussian n = 30 reg = 1 batch_size = 30 @@ -202,22 +193,19 @@ def test_dual_sgd_sinkhorn(): M = ot.dist(X_source.reshape((n, 1)), Y_target.reshape((n, 1))) M /= M.max() - G_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, - numItermax=nb_iter) + G_sgd = ot.stochastic.solve_dual_entropic( + a, b, M, reg, batch_size, numItermax=nb_iter + ) G_sinkhorn = ot.sinkhorn(a, b, M, reg) # check constraints - np.testing.assert_allclose( - G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03) - np.testing.assert_allclose( - G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03) - np.testing.assert_allclose( - G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd + np.testing.assert_allclose(G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03) + np.testing.assert_allclose(G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03) + np.testing.assert_allclose(G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd def test_loss_dual_entropic(nx): - nx.seed(0) xs = nx.randn(50, 2) @@ -241,7 +229,6 @@ def metric(x, y): def test_plan_dual_entropic(nx): - nx.seed(0) xs = nx.randn(50, 2) @@ -273,7 +260,6 @@ def metric(x, y): def test_loss_dual_quadratic(nx): - nx.seed(0) xs = nx.randn(50, 2) @@ -297,7 +283,6 @@ def metric(x, y): def test_plan_dual_quadratic(nx): - nx.seed(0) xs = nx.randn(50, 2) diff --git a/test/test_ucoot.py b/test/test_ucoot.py index fcace4178..a9e3d6dd8 100644 --- a/test/test_ucoot.py +++ b/test/test_ucoot.py @@ -4,25 +4,28 @@ # # License: MIT License - import itertools import numpy as np import ot import pytest -from ot.gromov._unbalanced import unbalanced_co_optimal_transport, unbalanced_co_optimal_transport2 +from ot.gromov._unbalanced import ( + unbalanced_co_optimal_transport, + unbalanced_co_optimal_transport2, +) @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence", itertools.product(["mm", "lbfgsb"], ["kl", "l2"])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence", itertools.product(["mm", "lbfgsb"], ["kl", "l2"]) +) def test_sanity(nx, unbalanced_solver, divergence): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() px_s, px_f = ot.unif(n_samples), ot.unif(2) @@ -43,21 +46,51 @@ def test_sanity(nx, unbalanced_solver, divergence): id_feature = np.eye(2, 2) / 2 pi_sample, pi_feature = unbalanced_co_optimal_transport( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=0, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=0, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx, pi_feature_nx = 
unbalanced_co_optimal_transport( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=0, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=0, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -69,21 +102,51 @@ def test_sanity(nx, unbalanced_solver, divergence): # test divergence ucoot = unbalanced_co_optimal_transport2( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=0, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=0, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = unbalanced_co_optimal_transport2( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=0, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=0, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = nx.to_numpy(ucoot_nx) @@ -93,15 +156,19 @@ def test_sanity(nx, unbalanced_solver, divergence): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1] + ), +) def test_init_plans(nx, unbalanced_solver, divergence, eps): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() px_s, px_f = ot.unif(n_samples), ot.unif(2) @@ -121,21 +188,51 @@ def test_init_plans(nx, unbalanced_solver, divergence, eps): # test couplings pi_sample, pi_feature = 
unbalanced_co_optimal_transport( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=(G0_samp_nx, G0_feat_nx), init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=(G0_samp_nx, G0_feat_nx), + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -145,21 +242,51 @@ def test_init_plans(nx, unbalanced_solver, divergence, eps): # test divergence ucoot = unbalanced_co_optimal_transport2( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = unbalanced_co_optimal_transport2( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=(G0_samp_nx, G0_feat_nx), init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=(G0_samp_nx, G0_feat_nx), + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = nx.to_numpy(ucoot_nx) @@ -168,15 +295,19 @@ def test_init_plans(nx, unbalanced_solver, divergence, eps): @pytest.skip_backend("jax", reason="test very slow with jax backend") 
@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1] + ), +) def test_init_duals(nx, unbalanced_solver, divergence, eps): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() px_s, px_f = ot.unif(n_samples), ot.unif(2) @@ -198,21 +329,51 @@ def test_init_duals(nx, unbalanced_solver, divergence, eps): # test couplings pi_sample, pi_feature = unbalanced_co_optimal_transport( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=init_duals, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=init_duals, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -222,21 +383,51 @@ def test_init_duals(nx, unbalanced_solver, divergence, eps): # test divergence ucoot = unbalanced_co_optimal_transport2( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = unbalanced_co_optimal_transport2( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, 
divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=init_duals, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=init_duals, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = nx.to_numpy(ucoot_nx) @@ -245,15 +436,19 @@ def test_init_duals(nx, unbalanced_solver, divergence, eps): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2] + ), +) def test_linear_part(nx, unbalanced_solver, divergence, eps): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() px_s, px_f = ot.unif(n_samples), ot.unif(2) @@ -278,21 +473,51 @@ def test_linear_part(nx, unbalanced_solver, divergence, eps): # test couplings pi_sample, pi_feature = unbalanced_co_optimal_transport( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=M_samp, M_feat=M_feat, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=M_samp, + M_feat=M_feat, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=M_samp_nx, + M_feat=M_feat_nx, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -302,21 +527,51 @@ def test_linear_part(nx, unbalanced_solver, divergence, eps): # test divergence ucoot = unbalanced_co_optimal_transport2( - X=xs, Y=xt, wx_samp=px_s, 
wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=M_samp, M_feat=M_feat, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=M_samp, + M_feat=M_feat, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = unbalanced_co_optimal_transport2( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=M_samp_nx, + M_feat=M_feat_nx, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = nx.to_numpy(ucoot_nx) @@ -325,15 +580,19 @@ def test_linear_part(nx, unbalanced_solver, divergence, eps): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1] + ), +) def test_reg_marginals(nx, unbalanced_solver, divergence, eps): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() px_s, px_f = ot.unif(n_samples), ot.unif(2) @@ -357,33 +616,77 @@ def test_reg_marginals(nx, unbalanced_solver, divergence, eps): # test couplings pi_sample, pi_feature = unbalanced_co_optimal_transport( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) # test divergence ucoot = unbalanced_co_optimal_transport2( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, 
alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) for opt in list_options: - pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=opt, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=opt, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -393,12 +696,28 @@ def test_reg_marginals(nx, unbalanced_solver, divergence, eps): # test divergence ucoot_nx = unbalanced_co_optimal_transport2( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=opt, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - method_sinkhorn="sinkhorn", log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=opt, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + method_sinkhorn="sinkhorn", + log=False, + verbose=False, ) ucoot_nx = nx.to_numpy(ucoot_nx) @@ -407,15 +726,19 @@ def test_reg_marginals(nx, unbalanced_solver, divergence, eps): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, alpha", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, alpha", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1] + ), +) def test_eps(nx, unbalanced_solver, divergence, alpha): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() px_s, px_f = ot.unif(n_samples), ot.unif(2) @@ -440,33 +763,77 @@ def test_eps(nx, unbalanced_solver, divergence, alpha): # test couplings pi_sample, pi_feature = unbalanced_co_optimal_transport( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, 
wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) # test divergence ucoot = unbalanced_co_optimal_transport2( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) for opt in list_options: - pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=opt, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=opt, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -476,12 +843,27 @@ def test_eps(nx, unbalanced_solver, divergence, alpha): # test divergence ucoot_nx = unbalanced_co_optimal_transport2( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=opt, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=opt, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = nx.to_numpy(ucoot_nx) @@ -490,15 +872,19 @@ def test_eps(nx, unbalanced_solver, divergence, alpha): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", 
itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2] + ), +) def test_alpha(nx, unbalanced_solver, divergence, eps): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() px_s, px_f = ot.unif(n_samples), ot.unif(2) @@ -529,32 +915,77 @@ def test_alpha(nx, unbalanced_solver, divergence, eps): # test couplings pi_sample, pi_feature = unbalanced_co_optimal_transport( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=M_samp, M_feat=M_feat, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=M_samp, + M_feat=M_feat, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) # test divergence ucoot = unbalanced_co_optimal_transport2( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=M_samp, M_feat=M_feat, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=M_samp, + M_feat=M_feat, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) for opt in list_options: pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=opt, - M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=opt, + M_samp=M_samp_nx, + M_feat=M_feat_nx, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -563,12 +994,27 @@ def test_alpha(nx, unbalanced_solver, divergence, eps): np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) ucoot_nx = unbalanced_co_optimal_transport2( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - 
unbalanced_solver=unbalanced_solver, alpha=opt, - M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=opt, + M_samp=M_samp_nx, + M_feat=M_feat_nx, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = nx.to_numpy(ucoot_nx) @@ -577,15 +1023,19 @@ def test_alpha(nx, unbalanced_solver, divergence, eps): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1] + ), +) def test_log(nx, unbalanced_solver, divergence, eps): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() px_s, px_f = ot.unif(n_samples), ot.unif(2) @@ -603,21 +1053,51 @@ def test_log(nx, unbalanced_solver, divergence, eps): # test couplings pi_sample, pi_feature = unbalanced_co_optimal_transport( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx, pi_feature_nx, log = unbalanced_co_optimal_transport( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=True, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=True, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -627,21 +1107,51 @@ def test_log(nx, unbalanced_solver, divergence, eps): # test divergence ucoot = unbalanced_co_optimal_transport2( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, 
divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = unbalanced_co_optimal_transport2( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = nx.to_numpy(ucoot_nx) @@ -650,15 +1160,19 @@ def test_log(nx, unbalanced_solver, divergence, eps): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tensorflow backend") -@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1])) +@pytest.mark.parametrize( + "unbalanced_solver, divergence, eps", + itertools.product( + ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1] + ), +) def test_marginals(nx, unbalanced_solver, divergence, eps): n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() px_s, px_f = ot.unif(n_samples), ot.unif(2) @@ -676,21 +1190,51 @@ def test_marginals(nx, unbalanced_solver, divergence, eps): # test couplings pi_sample, pi_feature = unbalanced_co_optimal_transport( - X=xs, Y=xt, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=None, + wx_feat=None, + wy_samp=None, + wy_feat=None, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, 
max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) pi_sample_nx = nx.to_numpy(pi_sample_nx) pi_feature_nx = nx.to_numpy(pi_feature_nx) @@ -700,21 +1244,51 @@ def test_marginals(nx, unbalanced_solver, divergence, eps): # test divergence ucoot = unbalanced_co_optimal_transport2( - X=xs, Y=xt, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=None, + wx_feat=None, + wy_samp=None, + wy_feat=None, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = unbalanced_co_optimal_transport2( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver=unbalanced_solver, alpha=alpha, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver=unbalanced_solver, + alpha=alpha, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) ucoot_nx = nx.to_numpy(ucoot_nx) @@ -729,8 +1303,7 @@ def test_raise_value_error(nx): mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) - xs = ot.datasets.make_2D_samples_gauss( - n_samples, mu_s, cov_s, random_state=4) + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) xt = xs[::-1].copy() px_s, px_f = ot.unif(n_samples), ot.unif(2) @@ -749,22 +1322,52 @@ def test_raise_value_error(nx): # raise error of divergence def ucoot_div(divergence): return unbalanced_co_optimal_transport( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence=divergence, - unbalanced_solver="mm", alpha=0, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver="mm", + alpha=0, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) def ucoot_div_nx(divergence): return unbalanced_co_optimal_transport( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, - wy_feat=py_f_nx, 
reg_marginals=reg_m, epsilon=eps, - divergence=divergence, unbalanced_solver="mm", alpha=0, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence=divergence, + unbalanced_solver="mm", + alpha=0, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) np.testing.assert_raises(NotImplementedError, ucoot_div, "div_not_existed") @@ -773,22 +1376,52 @@ def ucoot_div_nx(divergence): # raise error of solver def ucoot_solver(unbalanced_solver): return unbalanced_co_optimal_transport( - X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, - reg_marginals=reg_m, epsilon=eps, divergence="kl", - unbalanced_solver=unbalanced_solver, alpha=0, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs, + Y=xt, + wx_samp=px_s, + wx_feat=px_f, + wy_samp=py_s, + wy_feat=py_f, + reg_marginals=reg_m, + epsilon=eps, + divergence="kl", + unbalanced_solver=unbalanced_solver, + alpha=0, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) def ucoot_solver_nx(unbalanced_solver): return unbalanced_co_optimal_transport( - X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, - wy_feat=py_f_nx, reg_marginals=reg_m, epsilon=eps, - divergence="kl", unbalanced_solver=unbalanced_solver, alpha=0, - M_samp=None, M_feat=None, init_pi=None, init_duals=None, - max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, - log=False, verbose=False + X=xs_nx, + Y=xt_nx, + wx_samp=px_s_nx, + wx_feat=px_f_nx, + wy_samp=py_s_nx, + wy_feat=py_f_nx, + reg_marginals=reg_m, + epsilon=eps, + divergence="kl", + unbalanced_solver=unbalanced_solver, + alpha=0, + M_samp=None, + M_feat=None, + init_pi=None, + init_duals=None, + max_iter=max_iter, + tol=tol, + max_iter_ot=max_iter_ot, + tol_ot=tol_ot, + log=False, + verbose=False, ) np.testing.assert_raises(NotImplementedError, ucoot_solver, "solver_not_existed") diff --git a/test/test_utils.py b/test/test_utils.py index 82f514574..d50f29915 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,4 +1,4 @@ -"""Tests for module utils for timing and parallel computation """ +"""Tests for module utils for timing and parallel computation""" # Author: Remi Flamary # @@ -63,7 +63,6 @@ def test_proj_simplex(nx): def test_projection_sparse_simplex(): - def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None): r"""This is an equivalent but less efficient version of ot.utils.projection_sparse_simplex, as it uses two @@ -73,17 +72,13 @@ def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None): if axis == 0: # For each column of X, find top max_nz values and # their corresponding indices. This incurs a sort. - max_nz_indices = np.argpartition( - X, - kth=-max_nz, - axis=0)[-max_nz:] + max_nz_indices = np.argpartition(X, kth=-max_nz, axis=0)[-max_nz:] max_nz_values = X[max_nz_indices, np.arange(X.shape[1])] # Project the top max_nz values onto the simplex. # This incurs a second sort. 
- G_nz_values = ot.smooth.projection_simplex( - max_nz_values, z=z, axis=0) + G_nz_values = ot.smooth.projection_simplex(max_nz_values, z=z, axis=0) # Put the projection of max_nz_values to their original indices # and set all other values zero. @@ -91,13 +86,11 @@ def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None): G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values return G elif axis == 1: - return double_sort_projection_sparse_simplex( - X.T, max_nz, z, axis=0).T + return double_sort_projection_sparse_simplex(X.T, max_nz, z, axis=0).T else: X = X.ravel().reshape(-1, 1) - return double_sort_projection_sparse_simplex( - X, max_nz, z, axis=0).ravel() + return double_sort_projection_sparse_simplex(X, max_nz, z, axis=0).ravel() m, n = 5, 10 rng = np.random.RandomState(0) @@ -105,18 +98,14 @@ def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None): max_nz = 3 for axis in [0, 1, None]: - slow_sparse_proj = double_sort_projection_sparse_simplex( - X, max_nz, axis=axis) - fast_sparse_proj = ot.utils.projection_sparse_simplex( - X, max_nz, axis=axis) + slow_sparse_proj = double_sort_projection_sparse_simplex(X, max_nz, axis=axis) + fast_sparse_proj = ot.utils.projection_sparse_simplex(X, max_nz, axis=axis) # check that two versions produce consistent results - np.testing.assert_allclose( - slow_sparse_proj, fast_sparse_proj) + np.testing.assert_allclose(slow_sparse_proj, fast_sparse_proj) def test_parmap(): - n = 10 def f(i): @@ -132,7 +121,6 @@ def f(i): def test_tic_toc(): - import time ot.tic() @@ -150,7 +138,6 @@ def test_tic_toc(): def test_kernel(): - n = 100 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -162,7 +149,6 @@ def test_kernel(): def test_unif(): - n = 100 u = ot.unif(n) @@ -171,7 +157,6 @@ def test_unif(): def test_unif_backend(nx): - n = 100 for tp in nx.__type_list__: @@ -183,7 +168,6 @@ def test_unif_backend(nx): def test_dist(): - n = 10 rng = np.random.RandomState(0) @@ -197,7 +181,7 @@ def test_dist(): D2 = ot.dist(x, x) D3 = ot.dist(x) - D4 = ot.dist(x, x, metric='minkowski', p=2) + D4 = ot.dist(x, x, metric="minkowski", p=2) assert D4[0, 1] == D4[1, 0] @@ -207,17 +191,36 @@ def test_dist(): # tests that every metric runs correctly metrics_w = [ - 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', - 'euclidean', 'hamming', 'jaccard', - 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', - 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule' + "braycurtis", + "canberra", + "chebyshev", + "cityblock", + "correlation", + "cosine", + "dice", + "euclidean", + "hamming", + "jaccard", + "matching", + "minkowski", + "rogerstanimoto", + "russellrao", + "sokalmichener", + "sokalsneath", + "sqeuclidean", + "yule", ] # those that support weights - metrics = ['mahalanobis', 'seuclidean'] # do not support weights depending on scipy's version + metrics = [ + "mahalanobis", + "seuclidean", + ] # do not support weights depending on scipy's version for metric in metrics_w: print(metric) - ot.dist(x, x, metric=metric, p=3, w=rng.random((2, ))) - ot.dist(x, x, metric=metric, p=3, w=None) # check that not having any weight does not cause issues + ot.dist(x, x, metric=metric, p=3, w=rng.random((2,))) + ot.dist( + x, x, metric=metric, p=3, w=None + ) # check that not having any weight does not cause issues for metric in metrics: print(metric) ot.dist(x, x, metric=metric, p=3) @@ -228,16 +231,14 @@ def test_dist(): def test_dist_backends(nx): - n = 100 rng = np.random.RandomState(0) x = rng.randn(n, 2) x1 = 
nx.from_numpy(x) - lst_metric = ['euclidean', 'sqeuclidean'] + lst_metric = ["euclidean", "sqeuclidean"] for metric in lst_metric: - D = ot.dist(x, x, metric=metric) D1 = ot.dist(x1, x1, metric=metric) @@ -246,16 +247,14 @@ def test_dist_backends(nx): def test_dist0(): - n = 100 - M = ot.utils.dist0(n, method='lin_square') + M = ot.utils.dist0(n, method="lin_square") # dist0 default to linear sampling with quadratic loss np.testing.assert_allclose(M[0, -1], (n - 1) * (n - 1)) def test_dots(): - n1, n2, n3, n4 = 100, 50, 200, 100 rng = np.random.RandomState(0) @@ -272,7 +271,6 @@ def test_dots(): def test_clean_zeros(): - n = 100 nz = 50 nz2 = 20 @@ -302,28 +300,27 @@ def test_cost_normalization(nx): M1 = nx.to_numpy(M0) np.testing.assert_allclose(C, M1) - M = ot.utils.cost_normalization(C1, 'median') + M = ot.utils.cost_normalization(C1, "median") M1 = nx.to_numpy(M) np.testing.assert_allclose(np.median(M1), 1) - M = ot.utils.cost_normalization(C1, 'max') + M = ot.utils.cost_normalization(C1, "max") M1 = nx.to_numpy(M) np.testing.assert_allclose(M1.max(), 1) - M = ot.utils.cost_normalization(C1, 'log') + M = ot.utils.cost_normalization(C1, "log") M1 = nx.to_numpy(M) np.testing.assert_allclose(M1.max(), np.log(1 + C).max()) - M = ot.utils.cost_normalization(C1, 'loglog') + M = ot.utils.cost_normalization(C1, "loglog") M1 = nx.to_numpy(M) np.testing.assert_allclose(M1.max(), np.log(1 + np.log(1 + C)).max()) with pytest.raises(ValueError): - ot.utils.cost_normalization(C1, 'error') + ot.utils.cost_normalization(C1, "error") def test_list_to_array(nx): - lst = [np.array([1, 2, 3]), np.array([4, 5, 6])] a1, a2 = ot.utils.list_to_array(*lst) @@ -335,17 +332,16 @@ def test_list_to_array(nx): def test_check_params(): - - res1 = ot.utils.check_params(first='OK', second=20) + res1 = ot.utils.check_params(first="OK", second=20) assert res1 is True - res0 = ot.utils.check_params(first='OK', second=None) + res0 = ot.utils.check_params(first="OK", second=None) assert res0 is False def test_check_random_state_error(): with pytest.raises(ValueError): - ot.utils.check_random_state('error') + ot.utils.check_random_state("error") def test_get_parameter_pair_error(): @@ -354,16 +350,15 @@ def test_get_parameter_pair_error(): def test_deprecated_func(): - - @ot.utils.deprecated('deprecated text for fun') + @ot.utils.deprecated("deprecated text for fun") def fun(): pass def fun2(): pass - @ot.utils.deprecated('deprecated text for class') - class Class(): + @ot.utils.deprecated("deprecated text for class") + class Class: pass with pytest.warns(DeprecationWarning): @@ -374,7 +369,7 @@ class Class(): print(cl) if sys.version_info < (3, 5): - print('Not tested') + print("Not tested") else: assert ot.utils._is_deprecated(fun) is True @@ -382,35 +377,31 @@ class Class(): def test_BaseEstimator(): - class Class(ot.utils.BaseEstimator): - - def __init__(self, first='spam', second='eggs'): - + def __init__(self, first="spam", second="eggs"): self.first = first self.second = second cl = Class() names = cl._get_param_names() - assert 'first' in names - assert 'second' in names + assert "first" in names + assert "second" in names params = cl.get_params() - assert 'first' in params - assert 'second' in params + assert "first" in params + assert "second" in params - params['first'] = 'spam again' + params["first"] = "spam again" cl.set_params(**params) with pytest.raises(ValueError): cl.set_params(bibi=10) - assert cl.first == 'spam again' + assert cl.first == "spam again" def test_OTResult(): - res = ot.utils.OTResult() # 
test print @@ -419,25 +410,27 @@ def test_OTResult(): # tets get citation print(res.citation) - lst_attributes = ['lazy_plan', - 'marginal_a', - 'marginal_b', - 'marginals', - 'plan', - 'potential_a', - 'potential_b', - 'potentials', - 'sparse_plan', - 'status', - 'value', - 'value_linear', - 'value_quad', - 'log'] + lst_attributes = [ + "lazy_plan", + "marginal_a", + "marginal_b", + "marginals", + "plan", + "potential_a", + "potential_b", + "potentials", + "sparse_plan", + "status", + "value", + "value_linear", + "value_quad", + "log", + ] for at in lst_attributes: print(at) assert getattr(res, at) is None - list_not_implemented = ['a_to_b', 'b_to_a'] + list_not_implemented = ["a_to_b", "b_to_a"] for at in list_not_implemented: print(at) with pytest.raises(NotImplementedError): @@ -455,7 +448,6 @@ def test_get_coordinate_circle(): def test_LazyTensor(nx): - n1 = 100 n2 = 200 shape = (n1, n2) @@ -496,7 +488,6 @@ def getitem(i, j, x1, x2): def test_OTResult_LazyTensor(nx): - T, a, b = get_LazyTensor(nx) res = ot.utils.OTResult(lazy_plan=T, batch_size=9, backend=nx) @@ -506,7 +497,6 @@ def test_OTResult_LazyTensor(nx): def test_LazyTensor_reduce(nx): - T, a, b = get_LazyTensor(nx) T0 = T[:] @@ -558,7 +548,6 @@ def getitem(i, j, k, a, b, c): def test_lowrank_LazyTensor(nx): - p = 5 n1 = 100 n2 = 200 @@ -602,13 +591,15 @@ def test_lowrank_LazyTensor(nx): def test_labels_to_mask_helper(nx): y = np.array([1, 0, 2, 2, 1]) - out = np.array([ - [0, 1, 0], - [1, 0, 0], - [0, 0, 1], - [0, 0, 1], - [0, 1, 0], - ]) + out = np.array( + [ + [0, 1, 0], + [1, 0, 0], + [0, 0, 1], + [0, 0, 1], + [0, 1, 0], + ] + ) y = nx.from_numpy(y) masks = ot.utils.labels_to_masks(y) np.testing.assert_array_equal(out, masks) @@ -628,12 +619,12 @@ def test_label_normalization(nx): def test_proj_SDP(nx): t = np.pi / 8 U = np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]]) - w = np.array([1., -1.]) + w = np.array([1.0, -1.0]) S = np.stack([U @ np.diag(w) @ U.T] * 2, axis=0) S_nx = nx.from_numpy(S) R = ot.utils.proj_SDP(S_nx) - w_expected = np.array([1., 0.]) + w_expected = np.array([1.0, 0.0]) S_expected = np.stack([U @ np.diag(w_expected) @ U.T] * 2, axis=0) assert np.allclose(nx.to_numpy(R), S_expected) diff --git a/test/test_weak.py b/test/test_weak.py index 945efb1d2..60041e5c0 100644 --- a/test/test_weak.py +++ b/test/test_weak.py @@ -1,4 +1,4 @@ -"""Tests for main module ot.weak """ +"""Tests for main module ot.weak""" # Author: Remi Flamary # diff --git a/test/unbalanced/test_lbfgs.py b/test/unbalanced/test_lbfgs.py index 4b33fc526..11435266a 100644 --- a/test/unbalanced/test_lbfgs.py +++ b/test/unbalanced/test_lbfgs.py @@ -6,16 +6,17 @@ # # License: MIT License - import itertools import numpy as np import ot import pytest -@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2', 'tv'], ['linear', 'total'])) +@pytest.mark.parametrize( + "reg_div,regm_div,returnCost", + itertools.product(["kl", "l2", "entropy"], ["kl", "l2", "tv"], ["linear", "total"]), +) def test_lbfgsb_unbalanced(nx, reg_div, regm_div, returnCost): - np.random.seed(42) xs = np.random.randn(5, 2) @@ -26,29 +27,49 @@ def test_lbfgsb_unbalanced(nx, reg_div, regm_div, returnCost): a = ot.unif(5) b = ot.unif(6) - G, log = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, 10, - reg_div=reg_div, regm_div=regm_div, - log=True, verbose=False) - loss, _ = ot.unbalanced.lbfgsb_unbalanced2(a, b, M, 1, 10, - reg_div=reg_div, regm_div=regm_div, - returnCost=returnCost, log=True, verbose=False) + G, log = 
ot.unbalanced.lbfgsb_unbalanced( + a, b, M, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False + ) + loss, _ = ot.unbalanced.lbfgsb_unbalanced2( + a, + b, + M, + 1, + 10, + reg_div=reg_div, + regm_div=regm_div, + returnCost=returnCost, + log=True, + verbose=False, + ) ab, bb, Mb = nx.from_numpy(a, b, M) - Gb, log = ot.unbalanced.lbfgsb_unbalanced(ab, bb, Mb, 1, 10, - reg_div=reg_div, regm_div=regm_div, - log=True, verbose=False) - loss0, log = ot.unbalanced.lbfgsb_unbalanced2(ab, bb, Mb, 1, 10, - reg_div=reg_div, regm_div=regm_div, - returnCost=returnCost, log=True, verbose=False) + Gb, log = ot.unbalanced.lbfgsb_unbalanced( + ab, bb, Mb, 1, 10, reg_div=reg_div, regm_div=regm_div, log=True, verbose=False + ) + loss0, log = ot.unbalanced.lbfgsb_unbalanced2( + ab, + bb, + Mb, + 1, + 10, + reg_div=reg_div, + regm_div=regm_div, + returnCost=returnCost, + log=True, + verbose=False, + ) np.testing.assert_allclose(G, nx.to_numpy(Gb)) np.testing.assert_allclose(loss, nx.to_numpy(loss0), atol=1e-06) -@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2', 'tv'], ['linear', 'total'])) +@pytest.mark.parametrize( + "reg_div,regm_div,returnCost", + itertools.product(["kl", "l2", "entropy"], ["kl", "l2", "tv"], ["linear", "total"]), +) def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div, returnCost): - np.random.seed(42) xs = np.random.randn(5, 2) @@ -68,34 +89,73 @@ def test_lbfgsb_unbalanced_relaxation_parameters(nx, reg_div, regm_div, returnCo np1_reg_m = reg_m * np.ones(1) np2_reg_m = reg_m * np.ones(2) - list_options = [np1_reg_m, np2_reg_m, full_tuple_reg_m, - tuple_reg_m, full_list_reg_m, list_reg_m] - - G = ot.unbalanced.lbfgsb_unbalanced(a, b, M, 1, reg_m=reg_m, - reg_div=reg_div, regm_div=regm_div, - log=False, verbose=False) - loss = ot.unbalanced.lbfgsb_unbalanced2(a, b, M, 1, - reg_m=reg_m, reg_div=reg_div, regm_div=regm_div, - returnCost=returnCost, log=False, verbose=False) + list_options = [ + np1_reg_m, + np2_reg_m, + full_tuple_reg_m, + tuple_reg_m, + full_list_reg_m, + list_reg_m, + ] + + G = ot.unbalanced.lbfgsb_unbalanced( + a, + b, + M, + 1, + reg_m=reg_m, + reg_div=reg_div, + regm_div=regm_div, + log=False, + verbose=False, + ) + loss = ot.unbalanced.lbfgsb_unbalanced2( + a, + b, + M, + 1, + reg_m=reg_m, + reg_div=reg_div, + regm_div=regm_div, + returnCost=returnCost, + log=False, + verbose=False, + ) for opt in list_options: G0 = ot.unbalanced.lbfgsb_unbalanced( - a, b, M, 1, reg_m=opt, reg_div=reg_div, - regm_div=regm_div, log=False, verbose=False + a, + b, + M, + 1, + reg_m=opt, + reg_div=reg_div, + regm_div=regm_div, + log=False, + verbose=False, ) loss0 = ot.unbalanced.lbfgsb_unbalanced2( - a, b, M, 1, reg_m=opt, reg_div=reg_div, - regm_div=regm_div, returnCost=returnCost, - log=False, verbose=False + a, + b, + M, + 1, + reg_m=opt, + reg_div=reg_div, + regm_div=regm_div, + returnCost=returnCost, + log=False, + verbose=False, ) np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-06) -@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2', 'tv'], ['linear', 'total'])) +@pytest.mark.parametrize( + "reg_div,regm_div,returnCost", + itertools.product(["kl", "l2", "entropy"], ["kl", "l2", "tv"], ["linear", "total"]), +) def test_lbfgsb_reference_measure(nx, reg_div, regm_div, returnCost): - np.random.seed(42) xs = np.random.randn(5, 2) @@ -107,28 
+167,68 @@ def test_lbfgsb_reference_measure(nx, reg_div, regm_div, returnCost): a, b, M = nx.from_numpy(a, b, M) c = a[:, None] * b[None, :] - G, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=None, - reg_div=reg_div, regm_div=regm_div, - log=True, verbose=False) - loss, _ = ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=1, reg_m=10, c=None, - reg_div=reg_div, regm_div=regm_div, - returnCost=returnCost, log=True, verbose=False) - - G0, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, c=c, - reg_div=reg_div, regm_div=regm_div, - log=True, verbose=False) - - loss0, _ = ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=1, reg_m=10, c=c, - reg_div=reg_div, regm_div=regm_div, - returnCost=returnCost, log=True, verbose=False) + G, _ = ot.unbalanced.lbfgsb_unbalanced( + a, + b, + M, + reg=1, + reg_m=10, + c=None, + reg_div=reg_div, + regm_div=regm_div, + log=True, + verbose=False, + ) + loss, _ = ot.unbalanced.lbfgsb_unbalanced2( + a, + b, + M, + reg=1, + reg_m=10, + c=None, + reg_div=reg_div, + regm_div=regm_div, + returnCost=returnCost, + log=True, + verbose=False, + ) + + G0, _ = ot.unbalanced.lbfgsb_unbalanced( + a, + b, + M, + reg=1, + reg_m=10, + c=c, + reg_div=reg_div, + regm_div=regm_div, + log=True, + verbose=False, + ) + + loss0, _ = ot.unbalanced.lbfgsb_unbalanced2( + a, + b, + M, + reg=1, + reg_m=10, + c=c, + reg_div=reg_div, + regm_div=regm_div, + returnCost=returnCost, + log=True, + verbose=False, + ) np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06) np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-06) -@pytest.mark.parametrize("reg_div,regm_div,returnCost", itertools.product(['kl', 'l2', 'entropy'], ['kl', 'l2', 'tv'], ['linear', 'total'])) +@pytest.mark.parametrize( + "reg_div,regm_div,returnCost", + itertools.product(["kl", "l2", "entropy"], ["kl", "l2", "tv"], ["linear", "total"]), +) def test_lbfgsb_marginals(nx, reg_div, regm_div, returnCost): - np.random.seed(42) xs = np.random.randn(5, 2) @@ -139,30 +239,63 @@ def test_lbfgsb_marginals(nx, reg_div, regm_div, returnCost): a, b, M = nx.from_numpy(a, b, M) - G, _ = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1, reg_m=10, - reg_div=reg_div, regm_div=regm_div, - log=True, verbose=False) - loss, _ = ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=1, reg_m=10, - reg_div=reg_div, regm_div=regm_div, - returnCost=returnCost, log=True, verbose=False) + G, _ = ot.unbalanced.lbfgsb_unbalanced( + a, + b, + M, + reg=1, + reg_m=10, + reg_div=reg_div, + regm_div=regm_div, + log=True, + verbose=False, + ) + loss, _ = ot.unbalanced.lbfgsb_unbalanced2( + a, + b, + M, + reg=1, + reg_m=10, + reg_div=reg_div, + regm_div=regm_div, + returnCost=returnCost, + log=True, + verbose=False, + ) a_empty, b_empty = np.array([]), np.array([]) a_empty, b_empty = nx.from_numpy(a_empty, b_empty) - G0, _ = ot.unbalanced.lbfgsb_unbalanced(a_empty, b_empty, M, reg=1, reg_m=10, - reg_div=reg_div, regm_div=regm_div, - log=True, verbose=False) - - loss0, _ = ot.unbalanced.lbfgsb_unbalanced2(a_empty, b_empty, M, reg=1, reg_m=10, - reg_div=reg_div, regm_div=regm_div, - returnCost=returnCost, log=True, verbose=False) + G0, _ = ot.unbalanced.lbfgsb_unbalanced( + a_empty, + b_empty, + M, + reg=1, + reg_m=10, + reg_div=reg_div, + regm_div=regm_div, + log=True, + verbose=False, + ) + + loss0, _ = ot.unbalanced.lbfgsb_unbalanced2( + a_empty, + b_empty, + M, + reg=1, + reg_m=10, + reg_div=reg_div, + regm_div=regm_div, + returnCost=returnCost, + log=True, + verbose=False, + ) 
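# For context, the hunks above only re-wrap existing arguments onto separate
# lines; a minimal standalone sketch of the call shape they exercise follows,
# including the uniform fallback for empty marginals that test_lbfgsb_marginals
# checks (toy data; keyword names are taken from the hunks themselves, so this
# is illustrative rather than canonical usage):
import numpy as np
import ot

rng = np.random.RandomState(42)
xs, xt = rng.randn(5, 2), rng.randn(6, 2)
a, b = ot.unif(5), ot.unif(6)  # uniform histograms
M = ot.dist(xs, xt)  # (5, 6) squared Euclidean cost matrix

# transport plan and scalar cost with explicit marginals
G, _ = ot.unbalanced.lbfgsb_unbalanced(
    a, b, M, reg=1, reg_m=10, reg_div="kl", regm_div="kl", log=True
)
loss = ot.unbalanced.lbfgsb_unbalanced2(
    a, b, M, reg=1, reg_m=10, reg_div="kl", regm_div="kl"
)

# empty marginals fall back to uniform weights, which is exactly what the
# test above asserts when comparing the two plans
G0, _ = ot.unbalanced.lbfgsb_unbalanced(
    np.array([]), np.array([]), M, reg=1, reg_m=10,
    reg_div="kl", regm_div="kl", log=True
)
np.testing.assert_allclose(G, G0, atol=1e-6)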
     np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-06)
     np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-06)


 def test_lbfgsb_wrong_divergence(nx):
-
     n = 100
     rng = np.random.RandomState(42)
     x = rng.randn(n, 2)
@@ -186,7 +319,6 @@ def lbfgsb2_div(div):


 def test_lbfgsb_wrong_marginal_divergence(nx):
-
     n = 100
     rng = np.random.RandomState(42)
     x = rng.randn(n, 2)
@@ -210,7 +342,6 @@ def lbfgsb2_div(div):


 def test_lbfgsb_wrong_returnCost(nx):
-
     n = 100
     rng = np.random.RandomState(42)
     x = rng.randn(n, 2)
@@ -224,7 +355,8 @@ def test_lbfgsb_wrong_returnCost(nx):
     a, b, M = nx.from_numpy(a_np, b_np, M)

     def lbfgsb2(returnCost):
-        return ot.unbalanced.lbfgsb_unbalanced2(a, b, M, reg=1, reg_m=10,
-                                                returnCost=returnCost, verbose=True)
+        return ot.unbalanced.lbfgsb_unbalanced2(
+            a, b, M, reg=1, reg_m=10, returnCost=returnCost, verbose=True
+        )

     np.testing.assert_raises(ValueError, lbfgsb2, "invalid_returnCost")
diff --git a/test/unbalanced/test_mm.py b/test/unbalanced/test_mm.py
index ea9f00869..cd8b303aa 100644
--- a/test/unbalanced/test_mm.py
+++ b/test/unbalanced/test_mm.py
@@ -6,7 +6,6 @@
 #
 # License: MIT License

-
 import numpy as np
 import ot
 import pytest
@@ -27,9 +26,12 @@ def test_mm_convergence(nx, div):
     reg_m = 100
     a, b, M = nx.from_numpy(a_np, b_np, M)

-    G, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div=div,
-                                       verbose=False, log=True)
-    _, log = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div=div, verbose=True, log=True)
+    G, _ = ot.unbalanced.mm_unbalanced(
+        a, b, M, reg_m=reg_m, div=div, verbose=False, log=True
+    )
+    _, log = ot.unbalanced.mm_unbalanced2(
+        a, b, M, reg_m, div=div, verbose=True, log=True
+    )
     linear_cost = nx.to_numpy(log["cost"])

     # check if the marginals come close to the true ones when large reg
@@ -77,23 +79,32 @@ def test_mm_relaxation_parameters(nx, div):
     nx1_reg_m = reg_m * nx.ones(1)
     nx2_reg_m = reg_m * nx.ones(2)

-    list_options = [nx1_reg_m, nx2_reg_m, full_tuple_reg_m,
-                    tuple_reg_m, full_list_reg_m, list_reg_m]
-
-    G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg,
-                                        div=div, verbose=False, log=True)
+    list_options = [
+        nx1_reg_m,
+        nx2_reg_m,
+        full_tuple_reg_m,
+        tuple_reg_m,
+        full_list_reg_m,
+        list_reg_m,
+    ]
+
+    G0, _ = ot.unbalanced.mm_unbalanced(
+        a, b, M, reg_m=reg_m, reg=reg, div=div, verbose=False, log=True
+    )
     loss_0 = nx.to_numpy(
-        ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg,
-                                     div=div, verbose=True)
+        ot.unbalanced.mm_unbalanced2(
+            a, b, M, reg_m=reg_m, reg=reg, div=div, verbose=True
+        )
     )

     for opt in list_options:
-        G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=opt,
-                                            reg=reg, div=div,
-                                            verbose=False, log=True)
+        G1, _ = ot.unbalanced.mm_unbalanced(
+            a, b, M, reg_m=opt, reg=reg, div=div, verbose=False, log=True
+        )
         loss_1 = nx.to_numpy(
-            ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=opt,
-                                         reg=reg, div=div, verbose=True)
+            ot.unbalanced.mm_unbalanced2(
+                a, b, M, reg_m=opt, reg=reg, div=div, verbose=True
+            )
         )

         np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05)
@@ -118,17 +129,20 @@ def test_mm_reference_measure(nx, div):
     reg = 1e-2
     reg_m = 100

-    G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=None, reg=reg,
-                                        div=div, verbose=False, log=True)
-    loss_0 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=None, reg=reg,
-                                          div=div, verbose=True)
+    G0, _ = ot.unbalanced.mm_unbalanced(
+        a, b, M, reg_m=reg_m, c=None, reg=reg, div=div, verbose=False, log=True
+    )
+    loss_0 = ot.unbalanced.mm_unbalanced2(
+        a, b, M, reg_m=reg_m, c=None, reg=reg, div=div, verbose=True
+    )
     loss_0 = nx.to_numpy(loss_0)

-    G1, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=c,
-                                        reg=reg, div=div,
-                                        verbose=False, log=True)
-    loss_1 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=c,
-                                          reg=reg, div=div, verbose=True)
+    G1, _ = ot.unbalanced.mm_unbalanced(
+        a, b, M, reg_m=reg_m, c=c, reg=reg, div=div, verbose=False, log=True
+    )
+    loss_1 = ot.unbalanced.mm_unbalanced2(
+        a, b, M, reg_m=reg_m, c=c, reg=reg, div=div, verbose=True
+    )
     loss_1 = nx.to_numpy(loss_1)

     np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05)
@@ -152,20 +166,23 @@ def test_mm_marginals(nx, div):
     reg = 1e-2
     reg_m = 100

-    G0, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, c=None, reg=reg,
-                                        div=div, verbose=False, log=True)
-    loss_0 = ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, c=None, reg=reg,
-                                          div=div, verbose=True)
+    G0, _ = ot.unbalanced.mm_unbalanced(
+        a, b, M, reg_m=reg_m, c=None, reg=reg, div=div, verbose=False, log=True
+    )
+    loss_0 = ot.unbalanced.mm_unbalanced2(
+        a, b, M, reg_m=reg_m, c=None, reg=reg, div=div, verbose=True
+    )
     loss_0 = nx.to_numpy(loss_0)

     a_empty, b_empty = np.array([]), np.array([])
     a_empty, b_empty = nx.from_numpy(a_empty, b_empty)
-    G1, _ = ot.unbalanced.mm_unbalanced(a_empty, b_empty, M, reg_m=reg_m,
-                                        reg=reg, div=div,
-                                        verbose=False, log=True)
-    loss_1 = ot.unbalanced.mm_unbalanced2(a_empty, b_empty, M, reg_m=reg_m,
-                                          reg=reg, div=div, verbose=True)
+    G1, _ = ot.unbalanced.mm_unbalanced(
+        a_empty, b_empty, M, reg_m=reg_m, reg=reg, div=div, verbose=False, log=True
+    )
+    loss_1 = ot.unbalanced.mm_unbalanced2(
+        a_empty, b_empty, M, reg_m=reg_m, reg=reg, div=div, verbose=True
+    )
     loss_1 = nx.to_numpy(loss_1)

     np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05)
@@ -173,7 +190,6 @@ def test_mm_marginals(nx, div):


 def test_mm_wrong_divergence(nx):
-
     n = 100
     rng = np.random.RandomState(42)
     x = rng.randn(n, 2)
@@ -190,19 +206,20 @@ def test_mm_wrong_divergence(nx):
     reg_m = 100

     def mm_div(div):
-        return ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, reg=reg,
-                                           div=div, verbose=False, log=True)
+        return ot.unbalanced.mm_unbalanced(
+            a, b, M, reg_m=reg_m, reg=reg, div=div, verbose=False, log=True
+        )

     def mm2_div(div):
-        return ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=reg_m, reg=reg,
-                                            div=div, verbose=True)
+        return ot.unbalanced.mm_unbalanced2(
+            a, b, M, reg_m=reg_m, reg=reg, div=div, verbose=True
+        )

     np.testing.assert_raises(ValueError, mm_div, "div_not_existed")
     np.testing.assert_raises(ValueError, mm2_div, "div_not_existed")


 def test_mm_wrong_returnCost(nx):
-
     n = 100
     rng = np.random.RandomState(42)
     x = rng.randn(n, 2)
@@ -216,7 +233,8 @@ def test_mm_wrong_returnCost(nx):
     a, b, M = nx.from_numpy(a_np, b_np, M)

     def mm2(returnCost):
-        return ot.unbalanced.mm_unbalanced2(a, b, M, reg_m=100, reg=1e-2,
-                                            returnCost=returnCost, verbose=True)
+        return ot.unbalanced.mm_unbalanced2(
+            a, b, M, reg_m=100, reg=1e-2, returnCost=returnCost, verbose=True
+        )

     np.testing.assert_raises(ValueError, mm2, "invalid_returnCost")
diff --git a/test/unbalanced/test_sinkhorn.py b/test/unbalanced/test_sinkhorn.py
index 8c84ebd65..be7694309 100644
--- a/test/unbalanced/test_sinkhorn.py
+++ b/test/unbalanced/test_sinkhorn.py
@@ -6,7 +6,6 @@
 #
 # License: MIT License

-
 import itertools
 import numpy as np
 import ot
@@ -14,7 +13,18 @@
 from ot.unbalanced import barycenter_unbalanced


-@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"], ["kl", "entropy"]))
+@pytest.mark.parametrize(
+    "method,reg_type",
+    itertools.product(
+        [
+            "sinkhorn",
+            "sinkhorn_stabilized",
+            "sinkhorn_reg_scaling",
+            "sinkhorn_translation_invariant",
+        ],
+        ["kl", "entropy"],
+    ),
+)
 def test_unbalanced_convergence(nx, method, reg_type):
     # test generalized sinkhorn for unbalanced OT
     n = 100
@@ -28,17 +38,32 @@ def test_unbalanced_convergence(nx, method, reg_type):
     M = ot.dist(x, x)

     a, b, M = nx.from_numpy(a, b, M)
-    epsilon = 1.
-    reg_m = 1.
+    epsilon = 1.0
+    reg_m = 1.0

     G, log = ot.unbalanced.sinkhorn_unbalanced(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, log=True, verbose=True
+        a,
+        b,
+        M,
+        reg=epsilon,
+        reg_m=reg_m,
+        method=method,
+        reg_type=reg_type,
+        log=True,
+        verbose=True,
+    )
+    loss = nx.to_numpy(
+        ot.unbalanced.sinkhorn_unbalanced2(
+            a,
+            b,
+            M,
+            reg=epsilon,
+            reg_m=reg_m,
+            method=method,
+            reg_type=reg_type,
+            verbose=True,
+        )
     )
-    loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, verbose=True
-    ))
     # check fixed point equations
     # in log-domain
     fi = reg_m / (reg_m + epsilon)
@@ -55,15 +80,28 @@ def test_unbalanced_convergence(nx, method, reg_type):
     u_final = fi * (loga - logKv)

     np.testing.assert_allclose(
-        nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
+        nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05
+    )
     np.testing.assert_allclose(
-        nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
+        nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05
+    )

     # check if sinkhorn_unbalanced2 returns the correct loss
     np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5)


-@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"], ["kl", "entropy"]))
+@pytest.mark.parametrize(
+    "method,reg_type",
+    itertools.product(
+        [
+            "sinkhorn",
+            "sinkhorn_stabilized",
+            "sinkhorn_reg_scaling",
+            "sinkhorn_translation_invariant",
+        ],
+        ["kl", "entropy"],
+    ),
+)
 def test_unbalanced_marginals(nx, method, reg_type):
     # test generalized sinkhorn for unbalanced OT
     n = 100
@@ -75,15 +113,20 @@ def test_unbalanced_marginals(nx, method, reg_type):
     M = ot.dist(x, x)

     a, b, M = nx.from_numpy(a, b, M)
-    epsilon = 1.
-    reg_m = 1.
+    epsilon = 1.0
+    reg_m = 1.0

     G0, log0 = ot.unbalanced.sinkhorn_unbalanced(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, log=True
+        a, b, M, reg=epsilon, reg_m=reg_m, method=method, reg_type=reg_type, log=True
     )
     loss0 = ot.unbalanced.sinkhorn_unbalanced2(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method, reg_type=reg_type,
+        a,
+        b,
+        M,
+        reg=epsilon,
+        reg_m=reg_m,
+        method=method,
+        reg_type=reg_type,
     )

     # check in case no histogram is provided or histogram is None
@@ -91,23 +134,41 @@ def test_unbalanced_marginals(nx, method, reg_type):
     a_empty, b_empty = nx.from_numpy(a_empty, b_empty)

     G_empty, log_empty = ot.unbalanced.sinkhorn_unbalanced(
-        a_empty, b_empty, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, log=True
+        a_empty,
+        b_empty,
+        M,
+        reg=epsilon,
+        reg_m=reg_m,
+        method=method,
+        reg_type=reg_type,
+        log=True,
     )
     loss_empty = ot.unbalanced.sinkhorn_unbalanced2(
-        a_empty, b_empty, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type
+        a_empty, b_empty, M, reg=epsilon, reg_m=reg_m, method=method, reg_type=reg_type
     )

     np.testing.assert_allclose(
-        nx.to_numpy(log_empty["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05)
+        nx.to_numpy(log_empty["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05
+    )
     np.testing.assert_allclose(
-        nx.to_numpy(log_empty["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05)
+        nx.to_numpy(log_empty["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05
+    )
     np.testing.assert_allclose(nx.to_numpy(G_empty), nx.to_numpy(G0), atol=1e-05)
     np.testing.assert_allclose(nx.to_numpy(loss_empty), nx.to_numpy(loss0), atol=1e-5)


-@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"], ["kl", "entropy"]))
+@pytest.mark.parametrize(
+    "method,reg_type",
+    itertools.product(
+        [
+            "sinkhorn",
+            "sinkhorn_stabilized",
+            "sinkhorn_reg_scaling",
+            "sinkhorn_translation_invariant",
+        ],
+        ["kl", "entropy"],
+    ),
+)
 def test_unbalanced_warmstart(nx, method, reg_type):
     # test generalized sinkhorn for unbalanced OT
     n = 100
@@ -119,48 +180,97 @@ def test_unbalanced_warmstart(nx, method, reg_type):
     M = ot.dist(x, x)

     a, b, M = nx.from_numpy(a, b, M)
-    epsilon = 1.
-    reg_m = 1.
+    epsilon = 1.0
+    reg_m = 1.0

     G0, log0 = ot.unbalanced.sinkhorn_unbalanced(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, warmstart=None, log=True, verbose=True
+        a,
+        b,
+        M,
+        reg=epsilon,
+        reg_m=reg_m,
+        method=method,
+        reg_type=reg_type,
+        warmstart=None,
+        log=True,
+        verbose=True,
     )
     loss0 = ot.unbalanced.sinkhorn_unbalanced2(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, warmstart=None, verbose=True
+        a,
+        b,
+        M,
+        reg=epsilon,
+        reg_m=reg_m,
+        method=method,
+        reg_type=reg_type,
+        warmstart=None,
+        verbose=True,
     )

     dim_a, dim_b = M.shape
     warmstart = (nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M))
     G, log = ot.unbalanced.sinkhorn_unbalanced(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, warmstart=warmstart, log=True, verbose=True
+        a,
+        b,
+        M,
+        reg=epsilon,
+        reg_m=reg_m,
+        method=method,
+        reg_type=reg_type,
+        warmstart=warmstart,
+        log=True,
+        verbose=True,
     )
     loss = ot.unbalanced.sinkhorn_unbalanced2(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, warmstart=warmstart, verbose=True
+        a,
+        b,
+        M,
+        reg=epsilon,
+        reg_m=reg_m,
+        method=method,
+        reg_type=reg_type,
+        warmstart=warmstart,
+        verbose=True,
     )

     _, log_emd = ot.lp.emd(a, b, M, log=True)
     warmstart1 = (log_emd["u"], log_emd["v"])
     G1, log1 = ot.unbalanced.sinkhorn_unbalanced(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, warmstart=warmstart1, log=True, verbose=True
+        a,
+        b,
+        M,
+        reg=epsilon,
+        reg_m=reg_m,
+        method=method,
+        reg_type=reg_type,
+        warmstart=warmstart1,
+        log=True,
+        verbose=True,
     )
     loss1 = ot.unbalanced.sinkhorn_unbalanced2(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, warmstart=warmstart1, verbose=True
+        a,
+        b,
+        M,
+        reg=epsilon,
+        reg_m=reg_m,
+        method=method,
+        reg_type=reg_type,
+        warmstart=warmstart1,
+        verbose=True,
     )

     np.testing.assert_allclose(
-        nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05)
+        nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05
+    )
     np.testing.assert_allclose(
-        nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05)
+        nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05
+    )
     np.testing.assert_allclose(
-        nx.to_numpy(log0["logu"]), nx.to_numpy(log1["logu"]), atol=1e-05)
+        nx.to_numpy(log0["logu"]), nx.to_numpy(log1["logu"]), atol=1e-05
+    )
     np.testing.assert_allclose(
-        nx.to_numpy(log0["logv"]), nx.to_numpy(log1["logv"]), atol=1e-05)
+        nx.to_numpy(log0["logv"]), nx.to_numpy(log1["logv"]), atol=1e-05
+    )
     np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05)
     np.testing.assert_allclose(nx.to_numpy(G0), nx.to_numpy(G1), atol=1e-05)
@@ -169,7 +279,18 @@ def test_unbalanced_warmstart(nx, method, reg_type):
     np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5)


-@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"], ["kl", "entropy"]))
+@pytest.mark.parametrize(
+    "method,reg_type",
+    itertools.product(
+        [
+            "sinkhorn",
+            "sinkhorn_stabilized",
+            "sinkhorn_reg_scaling",
+            "sinkhorn_translation_invariant",
+        ],
+        ["kl", "entropy"],
+    ),
+)
 def test_unbalanced_reference_measure(nx, method, reg_type):
     # test generalized sinkhorn for unbalanced OT
     n = 100
@@ -181,12 +302,19 @@ def test_unbalanced_reference_measure(nx, method, reg_type):
     M = ot.dist(x, x)

     a, b, M = nx.from_numpy(a, b, M)
-    epsilon = 1.
-    reg_m = 1.
+    epsilon = 1.0
+    reg_m = 1.0

     G0, log0 = ot.unbalanced.sinkhorn_unbalanced(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, c=None, log=True
+        a,
+        b,
+        M,
+        reg=epsilon,
+        reg_m=reg_m,
+        method=method,
+        reg_type=reg_type,
+        c=None,
+        log=True,
     )
     loss0 = ot.unbalanced.sinkhorn_unbalanced2(
         a, b, M, reg=epsilon, reg_m=reg_m, method=method, reg_type=reg_type, c=None
@@ -198,23 +326,42 @@ def test_unbalanced_reference_measure(nx, method, reg_type):
     c = nx.ones(M.shape, type_as=M)

     G, log = ot.unbalanced.sinkhorn_unbalanced(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, c=c, log=True
+        a,
+        b,
+        M,
+        reg=epsilon,
+        reg_m=reg_m,
+        method=method,
+        reg_type=reg_type,
+        c=c,
+        log=True,
     )
     loss = ot.unbalanced.sinkhorn_unbalanced2(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        reg_type=reg_type, c=c
+        a, b, M, reg=epsilon, reg_m=reg_m, method=method, reg_type=reg_type, c=c
     )

     np.testing.assert_allclose(
-        nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05)
+        nx.to_numpy(log["logu"]), nx.to_numpy(log0["logu"]), atol=1e-05
+    )
     np.testing.assert_allclose(
-        nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05)
+        nx.to_numpy(log["logv"]), nx.to_numpy(log0["logv"]), atol=1e-05
+    )
     np.testing.assert_allclose(nx.to_numpy(G), nx.to_numpy(G0), atol=1e-05)
     np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5)


-@pytest.mark.parametrize("method, log", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"], [True, False]))
+@pytest.mark.parametrize(
+    "method, log",
+    itertools.product(
+        [
+            "sinkhorn",
+            "sinkhorn_stabilized",
+            "sinkhorn_reg_scaling",
+            "sinkhorn_translation_invariant",
+        ],
+        [True, False],
+    ),
+)
 def test_sinkhorn_unbalanced2(nx, method, log):
     n = 100
     rng = np.random.RandomState(42)
@@ -227,24 +374,43 @@ def test_sinkhorn_unbalanced2(nx, method, log):
     M = ot.dist(x, x)

     a, b, M = nx.from_numpy(a, b, M)
-    epsilon = 1.
-    reg_m = 1.
+    epsilon = 1.0
+    reg_m = 1.0

-    loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        c=None, log=False, verbose=True
-    ))
+    loss = nx.to_numpy(
+        ot.unbalanced.sinkhorn_unbalanced2(
+            a,
+            b,
+            M,
+            reg=epsilon,
+            reg_m=reg_m,
+            method=method,
+            c=None,
+            log=False,
+            verbose=True,
+        )
+    )
     res = ot.unbalanced.sinkhorn_unbalanced2(
-        a, b, M, reg=epsilon, reg_m=reg_m, method=method,
-        c=None, log=log, verbose=True
+        a, b, M, reg=epsilon, reg_m=reg_m, method=method, c=None, log=log, verbose=True
     )
     loss0 = res[0] if log else res

     np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss0), atol=1e-5)


-@pytest.mark.parametrize("method,reg_m", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"], [1, float("inf")]))
+@pytest.mark.parametrize(
+    "method,reg_m",
+    itertools.product(
+        [
+            "sinkhorn",
+            "sinkhorn_stabilized",
+            "sinkhorn_reg_scaling",
+            "sinkhorn_translation_invariant",
+        ],
+        [1, float("inf")],
+    ),
+)
 def test_unbalanced_relaxation_parameters(nx, method, reg_m):
     # test generalized sinkhorn for unbalanced OT
     n = 100
@@ -257,7 +423,7 @@ def test_unbalanced_relaxation_parameters(nx, method, reg_m):
     b = rng.rand(n, 2)
     M = ot.dist(x, x)

-    epsilon = 1.
+    epsilon = 1.0

     a, b, M = nx.from_numpy(a, b, M)

@@ -266,29 +432,45 @@ def test_unbalanced_relaxation_parameters(nx, method, reg_m):
     full_tuple_reg_m = (reg_m, reg_m)
     tuple_reg_m, list_reg_m = (reg_m), [reg_m]
     nx_reg_m = reg_m * nx.ones(1)
-    list_options = [nx_reg_m, full_tuple_reg_m,
-                    tuple_reg_m, full_list_reg_m, list_reg_m]
+    list_options = [
+        nx_reg_m,
+        full_tuple_reg_m,
+        tuple_reg_m,
+        full_list_reg_m,
+        list_reg_m,
+    ]

     loss, log = ot.unbalanced.sinkhorn_unbalanced(
-        a, b, M, reg=epsilon, reg_m=reg_m,
-        method=method, log=True, verbose=True
+        a, b, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True
     )

     for opt in list_options:
         loss_opt, log_opt = ot.unbalanced.sinkhorn_unbalanced(
-            a, b, M, reg=epsilon, reg_m=opt,
-            method=method, log=True, verbose=True
+            a, b, M, reg=epsilon, reg_m=opt, method=method, log=True, verbose=True
         )

         np.testing.assert_allclose(
-            nx.to_numpy(log["logu"]), nx.to_numpy(log_opt["logu"]), atol=1e-05)
-        np.testing.assert_allclose(
-            nx.to_numpy(log["logv"]), nx.to_numpy(log_opt["logv"]), atol=1e-05)
+            nx.to_numpy(log["logu"]), nx.to_numpy(log_opt["logu"]), atol=1e-05
+        )
         np.testing.assert_allclose(
-            nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05)
-
-
-@pytest.mark.parametrize("method, reg_m1, reg_m2", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"], [1, float("inf")], [1, float("inf")]))
+            nx.to_numpy(log["logv"]), nx.to_numpy(log_opt["logv"]), atol=1e-05
+        )
+        np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05)
+
+
+@pytest.mark.parametrize(
+    "method, reg_m1, reg_m2",
+    itertools.product(
+        [
+            "sinkhorn",
+            "sinkhorn_stabilized",
+            "sinkhorn_reg_scaling",
+            "sinkhorn_translation_invariant",
+        ],
+        [1, float("inf")],
+        [1, float("inf")],
+    ),
+)
 def test_unbalanced_relaxation_parameters_pair(nx, method, reg_m1, reg_m2):
     # test generalized sinkhorn for unbalanced OT
     n = 100
@@ -301,7 +483,7 @@ def test_unbalanced_relaxation_parameters_pair(nx, method, reg_m1, reg_m2):
     b = rng.rand(n, 2)
     M = ot.dist(x, x)

-    epsilon = 1.
+    epsilon = 1.0

     a, b, M = nx.from_numpy(a, b, M)

@@ -311,25 +493,39 @@ def test_unbalanced_relaxation_parameters_pair(nx, method, reg_m1, reg_m2):
     list_options = [full_tuple_reg_m, full_list_reg_m]

     loss, log = ot.unbalanced.sinkhorn_unbalanced(
-        a, b, M, reg=epsilon, reg_m=(reg_m1, reg_m2),
-        method=method, log=True, verbose=True
+        a,
+        b,
+        M,
+        reg=epsilon,
+        reg_m=(reg_m1, reg_m2),
+        method=method,
+        log=True,
+        verbose=True,
     )

     for opt in list_options:
         loss_opt, log_opt = ot.unbalanced.sinkhorn_unbalanced(
-            a, b, M, reg=epsilon, reg_m=opt,
-            method=method, log=True, verbose=True
+            a, b, M, reg=epsilon, reg_m=opt, method=method, log=True, verbose=True
         )

         np.testing.assert_allclose(
-            nx.to_numpy(log["logu"]), nx.to_numpy(log_opt["logu"]), atol=1e-05)
-        np.testing.assert_allclose(
-            nx.to_numpy(log["logv"]), nx.to_numpy(log_opt["logv"]), atol=1e-05)
+            nx.to_numpy(log["logu"]), nx.to_numpy(log_opt["logu"]), atol=1e-05
+        )
         np.testing.assert_allclose(
-            nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05)
-
-
-@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"])
+            nx.to_numpy(log["logv"]), nx.to_numpy(log_opt["logv"]), atol=1e-05
+        )
+        np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05)
+
+
+@pytest.mark.parametrize(
+    "method",
+    [
+        "sinkhorn",
+        "sinkhorn_stabilized",
+        "sinkhorn_reg_scaling",
+        "sinkhorn_translation_invariant",
+    ],
+)
 def test_unbalanced_multiple_inputs(nx, method):
     # test generalized sinkhorn for unbalanced OT
     n = 100
@@ -342,23 +538,21 @@ def test_unbalanced_multiple_inputs(nx, method):
     b = rng.rand(n, 2)
     M = ot.dist(x, x)

-    epsilon = 1.
-    reg_m = 1.
+    epsilon = 1.0
+    reg_m = 1.0

     a, b, M = nx.from_numpy(a, b, M)

-    G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
-                                               reg_m=reg_m, method=method,
-                                               log=True, verbose=True)
+    G, log = ot.unbalanced.sinkhorn_unbalanced(
+        a, b, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True
+    )
     # check fixed point equations
     # in log-domain
     fi = reg_m / (reg_m + epsilon)
     logb = nx.log(b + 1e-16)
     loga = nx.log(a + 1e-16)[:, None]
-    logKtu = nx.logsumexp(
-        log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0
-    )
+    logKtu = nx.logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0)
     logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
     v_final = fi * (logb - logKtu)
     u_final = fi * (loga - logKv)
@@ -369,24 +563,25 @@ def test_unbalanced_multiple_inputs(nx, method):
     print("logv shape = {}".format(log["logv"].shape))

     np.testing.assert_allclose(
-        nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
+        nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05
+    )
     np.testing.assert_allclose(
-        nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
+        nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05
+    )

     # reg_type="entropy" as multiple inputs does not work for KL yet
-    losses = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
-                                                reg_m=reg_m, method=method,
-                                                reg_type="entropy")
+    losses = ot.unbalanced.sinkhorn_unbalanced2(
+        a, b, M, reg=epsilon, reg_m=reg_m, method=method, reg_type="entropy"
+    )

-    loss1 = ot.unbalanced.sinkhorn_unbalanced2(a, b[:, 0], M, reg=epsilon,
-                                               reg_m=reg_m, method=method,
-                                               reg_type="entropy")
-    loss2 = ot.unbalanced.sinkhorn_unbalanced2(a, b[:, 1], M, reg=epsilon,
-                                               reg_m=reg_m, method=method,
-                                               reg_type="entropy")
+    loss1 = ot.unbalanced.sinkhorn_unbalanced2(
+        a, b[:, 0], M, reg=epsilon, reg_m=reg_m, method=method, reg_type="entropy"
+    )
+    loss2 = ot.unbalanced.sinkhorn_unbalanced2(
+        a, b[:, 1], M, reg=epsilon, reg_m=reg_m, method=method, reg_type="entropy"
+    )

-    np.testing.assert_allclose(
-        nx.to_numpy(losses), np.array([loss1, loss2]), atol=1e-4)
+    np.testing.assert_allclose(nx.to_numpy(losses), np.array([loss1, loss2]), atol=1e-4)


 def test_stabilized_vs_sinkhorn(nx):
@@ -404,13 +599,20 @@ def test_stabilized_vs_sinkhorn(nx):
     M = ot.utils.dist0(n)
     M /= np.median(M)
     epsilon = 1
-    reg_m = 1.
+    reg_m = 1.0
     stopThr = 1e-12
     ab, bb, Mb = nx.from_numpy(a, b, M)

     G, _ = ot.unbalanced.sinkhorn_unbalanced2(
-        ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True, stopThr=stopThr,
+        ab,
+        bb,
+        Mb,
+        epsilon,
+        reg_m,
+        method="sinkhorn_stabilized",
+        log=True,
+        stopThr=stopThr,
     )
     G2, _ = ot.unbalanced.sinkhorn_unbalanced2(
         ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True, stopThr=stopThr
     )
@@ -419,7 +621,14 @@
         a, b, M, epsilon, reg_m, method="sinkhorn", log=True, stopThr=stopThr
     )
     G3, _ = ot.unbalanced.sinkhorn_unbalanced2(
-        ab, bb, Mb, epsilon, reg_m, method="sinkhorn_translation_invariant", log=True, stopThr=stopThr
+        ab,
+        bb,
+        Mb,
+        epsilon,
+        reg_m,
+        method="sinkhorn_translation_invariant",
+        log=True,
+        stopThr=stopThr,
     )

     G = nx.to_numpy(G)
@@ -432,7 +641,6 @@ def test_stabilized_vs_sinkhorn(nx):


 def test_sinkhorn_wrong_returnCost(nx):
-
     n = 100
     rng = np.random.RandomState(42)
     x = rng.randn(n, 2)
@@ -445,15 +653,19 @@ def test_sinkhorn_wrong_returnCost(nx):
     M = M / M.max()
     a, b, M = nx.from_numpy(a_np, b_np, M)
     epsilon = 1
-    reg_m = 1.
+    reg_m = 1.0

     def sinkhorn2(returnCost):
-        return ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, returnCost=returnCost)
+        return ot.unbalanced.sinkhorn_unbalanced2(
+            a, b, M, epsilon, reg_m, returnCost=returnCost
+        )

     np.testing.assert_raises(ValueError, sinkhorn2, "invalid_returnCost")


-@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"])
+@pytest.mark.parametrize(
+    "method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"]
+)
 def test_unbalanced_barycenter(nx, method):
     # test generalized sinkhorn for unbalanced OT barycenter
     n = 100
@@ -465,8 +677,8 @@ def test_unbalanced_barycenter(nx, method):
     # make dists unbalanced
     A = A * np.array([1, 2])[None, :]
     M = ot.dist(x, x)
-    epsilon = 1.
-    reg_m = 1.
+    epsilon = 1.0
+    reg_m = 1.0

     A, M = nx.from_numpy(A, M)

@@ -477,17 +689,17 @@ def test_unbalanced_barycenter(nx, method):
     fi = reg_m / (reg_m + epsilon)
     logA = nx.log(A + 1e-16)
     logq = nx.log(q + 1e-16)[:, None]
-    logKtu = nx.logsumexp(
-        log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0
-    )
+    logKtu = nx.logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0)
     logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
     v_final = fi * (logq - logKtu)
     u_final = fi * (logA - logKv)

     np.testing.assert_allclose(
-        nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
+        nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05
+    )
     np.testing.assert_allclose(
-        nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
+        nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05
+    )


 def test_barycenter_stabilized_vs_sinkhorn(nx):
@@ -507,8 +719,14 @@ def test_barycenter_stabilized_vs_sinkhorn(nx):
     Ab, Mb = nx.from_numpy(A, M)

     qstable, _ = barycenter_unbalanced(
-        Ab, Mb, reg=epsilon, reg_m=reg_m, log=True, tau=100,
-        method="sinkhorn_stabilized", verbose=True
+        Ab,
+        Mb,
+        reg=epsilon,
+        reg_m=reg_m,
+        log=True,
+        tau=100,
+        method="sinkhorn_stabilized",
+        verbose=True,
     )
     q, _ = barycenter_unbalanced(
         Ab, Mb, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True
@@ -522,7 +740,6 @@ def test_barycenter_stabilized_vs_sinkhorn(nx):


 def test_wrong_method(nx):
-
     n = 10
     rng = np.random.RandomState(42)

@@ -533,26 +750,35 @@ def test_wrong_method(nx):
     b = ot.utils.unif(n) * 1.5
     M = ot.dist(x, x)

-    epsilon = 1.
-    reg_m = 1.
+    epsilon = 1.0
+    reg_m = 1.0

     a, b, M = nx.from_numpy(a, b, M)

     with pytest.raises(ValueError):
         ot.unbalanced.sinkhorn_unbalanced(
-            a, b, M, reg=epsilon, reg_m=reg_m, method='badmethod',
-            log=True, verbose=True
+            a,
+            b,
+            M,
+            reg=epsilon,
+            reg_m=reg_m,
+            method="badmethod",
+            log=True,
+            verbose=True,
         )
     with pytest.raises(ValueError):
         ot.unbalanced.sinkhorn_unbalanced2(
-            a, b, M, epsilon, reg_m, method='badmethod', verbose=True
+            a, b, M, epsilon, reg_m, method="badmethod", verbose=True
         )


 def test_implemented_methods(nx):
-    IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
-    TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling', 'sinkhorn_translation_invariant']
-    NOT_VALID_TOKENS = ['foo']
+    IMPLEMENTED_METHODS = ["sinkhorn", "sinkhorn_stabilized"]
+    TO_BE_IMPLEMENTED_METHODS = [
+        "sinkhorn_reg_scaling",
+        "sinkhorn_translation_invariant",
+    ]
+    NOT_VALID_TOKENS = ["foo"]
     # test generalized sinkhorn for unbalanced OT barycenter
     n = 3
     rng = np.random.RandomState(42)
@@ -564,31 +790,22 @@ def test_implemented_methods(nx):
     b = ot.utils.unif(n) * 1.5
     A = rng.rand(n, 2)
     M = ot.dist(x, x)
-    epsilon = 1.
-    reg_m = 1.
+    epsilon = 1.0
+    reg_m = 1.0

     a, b, M, A = nx.from_numpy(a, b, M, A)

     for method in IMPLEMENTED_METHODS:
-        ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
-                                          method=method)
-        ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
-                                           method=method)
-        barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
-                              method=method)
-    with pytest.warns(UserWarning, match='not implemented'):
+        ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, method=method)
+        ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, method=method)
+        barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, method=method)
+    with pytest.warns(UserWarning, match="not implemented"):
         for method in set(TO_BE_IMPLEMENTED_METHODS):
-            ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
-                                              method=method)
-            ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
-                                               method=method)
-            barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
-                                  method=method)
+            ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, method=method)
+            ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, method=method)
+            barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, method=method)
     with pytest.raises(ValueError):
         for method in set(NOT_VALID_TOKENS):
-            ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
-                                              method=method)
-            ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
-                                               method=method)
-            barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
-                                  method=method)
+            ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, method=method)
+            ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, method=method)
+            barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, method=method)