Skip to content

Commit

Permalink
[MRG] fix doc+example lowrank sinkhorn (PythonOT#601)
Browse files Browse the repository at this point in the history
* fix doc+example lowrank sinkhorn

* fix autosummary for lowrank doc

* update release

---------

Co-authored-by: Rémi Flamary <[email protected]>
  • Loading branch information
cedricvincentcuaz and rflamary authored Jan 18, 2024
1 parent 64c8374 commit c84ef33
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 22 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#### Closed issues
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)
- Fix doc and example for lowrank sinkhorn (PR #601)

## 0.9.2
*December 2023*
Expand Down
1 change: 1 addition & 0 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ API and modules
gaussian
gnn
gromov
lowrank
lp
mapping
optim
Expand Down
19 changes: 7 additions & 12 deletions examples/others/plot_lowrank_sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,40 +88,35 @@
#%%

# Plot sinkhorn vs low rank sinkhorn
pl.figure(1, figsize=(10, 4))
pl.figure(1, figsize=(10, 8))

pl.subplot(1, 3, 1)
pl.subplot(2, 3, 1)
pl.imshow(list_P_Sin[0], interpolation='nearest')
pl.axis('off')
pl.title('Sinkhorn (reg=0.05)')

pl.subplot(1, 3, 2)
pl.subplot(2, 3, 2)
pl.imshow(list_P_Sin[1], interpolation='nearest')
pl.axis('off')
pl.title('Sinkhorn (reg=0.005)')

pl.subplot(1, 3, 3)
pl.subplot(2, 3, 3)
pl.imshow(list_P_Sin[2], interpolation='nearest')
pl.axis('off')
pl.title('Sinkhorn (reg=0.001)')
pl.show()


#%%

pl.figure(2, figsize=(10, 4))

pl.subplot(1, 3, 1)
pl.subplot(2, 3, 4)
pl.imshow(list_P_LR[0], interpolation='nearest')
pl.axis('off')
pl.title('Low rank (rank=3)')

pl.subplot(1, 3, 2)
pl.subplot(2, 3, 5)
pl.imshow(list_P_LR[1], interpolation='nearest')
pl.axis('off')
pl.title('Low rank (rank=10)')

pl.subplot(1, 3, 3)
pl.subplot(2, 3, 6)
pl.imshow(list_P_LR[2], interpolation='nearest')
pl.axis('off')
pl.title('Low rank (rank=50)')
Expand Down
5 changes: 3 additions & 2 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
:py:mod:`ot.utils`, :py:mod:`ot.datasets`,
:py:mod:`ot.gromov`, :py:mod:`ot.smooth`
:py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath`
, :py:mod:`ot.unbalanced`, :py:mod`ot.mapping`.
, :py:mod:`ot.unbalanced`, :py:mod:`ot.mapping` .
The following sub-modules are not imported due to additional dependencies:
- :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`.
- :any:`ot.plot` : depends on :code:`matplotlib`
Expand Down Expand Up @@ -71,4 +71,5 @@
'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
'binary_search_circle', 'wasserstein_circle',
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn']
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif',
'lowrank_sinkhorn']
17 changes: 9 additions & 8 deletions ot/lowrank.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,17 +319,18 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=None, alpha=1e-10, re
The function solves the following optimization problem:
.. math::
\mathop{\inf_{(Q,R,g) \in \mathcal{C(a,b,r)}}} \langle C, Q\mathrm{diag}(1/g)R^T \rangle -
\mathrm{reg} \cdot H((Q,R,g))
\mathop{\inf_{(\mathbf{Q},\mathbf{R},\mathbf{g}) \in \mathcal{C}(\mathbf{a},\mathbf{b},r)}} \langle \mathbf{C}, \mathbf{Q}\mathrm{diag}(1/\mathbf{g})\mathbf{R}^\top \rangle -
\mathrm{reg} \cdot H((\mathbf{Q}, \mathbf{R}, \mathbf{g}))
where :
- :math:`C` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`H((Q,R,g))` is the values of the three respective entropies evaluated for each term.
- :math: `Q` and `R` are the low-rank matrix decomposition of the OT plan
- :math: `g` is the weight vector for the low-rank decomposition of the OT plan
- :math:`\mathbf{C}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`H((\mathbf{Q}, \mathbf{R}, \mathbf{g}))` is the values of the three respective entropies evaluated for each term.
- :math:`\mathbf{Q}` and :math:`\mathbf{R}` are the low-rank matrix decomposition of the OT plan
- :math:`\mathbf{g}` is the weight vector for the low-rank decomposition of the OT plan
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
- :math: `r` is the rank of the OT plan
- :math: `\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem
- :math:`r` is the rank of the OT plan
- :math:`\mathcal{C}(\mathbf{a}, \mathbf{b}, r)` are the low-rank couplings of the OT problem
Parameters
Expand Down

0 comments on commit c84ef33

Please sign in to comment.