Skip to content

Commit

Permalink
Fix small TODOs (#311)
Browse files Browse the repository at this point in the history
* Fix small TODOs

* Update ``Geometry`` docs

* Update MANIFEST.in

* Fix linter

* Fix minor issues

* Fix not passing bias in `LRCGeometry.__add__`

* Update `LRCGeometry` add test

* Fix `LRCGeometry` subsetting bug

* Fix subsetting tests

* [ci skip] Fix typo
  • Loading branch information
michalk8 authored Feb 21, 2023
1 parent a4c76aa commit 5155dc8
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 50 deletions.
2 changes: 0 additions & 2 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
prune docs
prune examples
prune images
prune tests
prune .github
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ The call to `solver(prob)` above works out the optimal transport solution. The `
more points from the second, as illustrated in the plot below. We provide more flexibility to define custom cost
functions, objectives, and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/).

![obtained coupling](https://raw.githubusercontent.com/ott-jax/ott/main/images/couplings.png)
![obtained coupling](https://raw.githubusercontent.com/ott-jax/ott/main/docs/_static/images/couplings.png)

## Citation
If you have found this work useful, please consider citing this reference:
Expand Down
File renamed without changes
9 changes: 5 additions & 4 deletions src/ott/geometry/epsilon_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ class Epsilon:
Args:
target: the epsilon regularizer that is targeted.
If ``None``, use :math:`0.05`.
scale_epsilon: if passed, used to multiply the regularizer, to rescale it.
If ``None``, use :math:`1`.
init: initial value when using epsilon scheduling, understood as multiple
of target value. if passed, ``int * decay ** iteration`` will be used
to rescale target.
Expand All @@ -58,9 +60,8 @@ def __init__(
def target(self) -> float:
"""Return the final regularizer value of scheduler."""
target = 5e-2 if self._target_init is None else self._target_init
if self._scale_epsilon is None:
return target
return target * self._scale_epsilon
scale = 1.0 if self._scale_epsilon is None else self._scale_epsilon
return scale * target

def at(self, iteration: Optional[int] = 1) -> float:
"""Return (intermediate) regularizer value at a given iteration."""
Expand All @@ -81,7 +82,7 @@ def done_at(self, iteration: Optional[int]) -> bool:
return self.done(self.at(iteration))

def set(self, **kwargs: Any) -> "Epsilon":
"""TODO."""
"""Return a copy of self, with potential overwrites."""
kwargs = {
"target": self._target_init,
"scale_epsilon": self._scale_epsilon,
Expand Down
50 changes: 22 additions & 28 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,15 @@ class Geometry:
basic operations to be run with the Sinkhorn algorithm.
Args:
cost_matrix: jnp.ndarray<float>[num_a, num_b]: a cost matrix storing n x m
costs.
kernel_matrix: jnp.ndarray<float>[num_a, num_b]: a kernel matrix storing n
x m kernel values.
epsilon: a regularization parameter. TODO(michalk8): update the docstring
If a :class:`~ott.geometry.epsilon_scheduler.Epsilon` scheduler is passed,
other parameters below are ignored in practice. If the
parameter is a float, then this is understood to be the regularization
that is needed, unless ``relative_epsilon`` below is ``True``, in which
case ``epsilon`` is understood as a normalized quantity, to be scaled by
the :attr:`mean_cost_matrix`.
cost_matrix: Cost matrix of shape ``[n, m]``.
kernel_matrix: Kernel matrix of shape ``[n, m]``.
epsilon: Regularization parameter. If ``scale_epsilon = None`` and either
``relative_epsilon = True`` or ``relative_epsilon = None`` and
``epsilon = None`` in :class:`~ott.geometry.epsilon_scheduler.Epsilon`
is used, ``scale_epsilon`` the is :attr:`mean_cost_matrix`. If
``epsilon = None``, use :math:`0.05`.
relative_epsilon: whether epsilon is passed relative to scale of problem,
here understood the value of :attr:`mean_cost_matrix`.
here understood the value of the :attr:`mean_cost_matrix`.
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean' and 'max_cost'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= scale_cost``.
Expand All @@ -68,10 +64,11 @@ class Geometry:
:attr:`cost_matrix`, see :attr:`tgt_mask`.
Note:
When defining a ``Geometry`` through a ``cost_matrix``, it is important to
select an ``epsilon`` regularization parameter that is meaningful. That
parameter can be provided by the user, or assigned a default value through
a simple rule, using the :attr:`mean_cost_matrix`.
When defining a :class:`~ott.geometry.geometry.Geometry` through a
``cost_matrix``, it is important to select an ``epsilon`` regularization
parameter that is meaningful. That parameter can be provided by the user,
or assigned a default value through a simple rule,
using the :attr:`mean_cost_matrix`.
"""

def __init__(
Expand Down Expand Up @@ -674,19 +671,14 @@ def to_LRCGeometry(
i_star = jax.random.randint(key1, shape=(), minval=0, maxval=n)
j_star = jax.random.randint(key2, shape=(), minval=0, maxval=m)

# force `batch_size=None` since `cost_matrix` would be `None`
ci_star = self.subset(
i_star, None, batch_size=None
).cost_matrix.ravel() ** 2 # (m,)
cj_star = self.subset(
None, j_star, batch_size=None
).cost_matrix.ravel() ** 2 # (n,)
ci_star = self.subset(i_star, None).cost_matrix.ravel() ** 2 # (m,)
cj_star = self.subset(None, j_star).cost_matrix.ravel() ** 2 # (n,)

p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,)
p_row /= jnp.sum(p_row)
row_ixs = jax.random.choice(key3, n, shape=(n_subset,), p=p_row)
# (n_subset, m)
s = self.subset(row_ixs, None, batch_size=None).cost_matrix
s = self.subset(row_ixs, None).cost_matrix
s /= jnp.sqrt(n_subset * p_row[row_ixs][:, None])

p_col = jnp.sum(s ** 2, axis=0) # (m,)
Expand All @@ -707,9 +699,7 @@ def to_LRCGeometry(
col_ixs = jax.random.choice(key5, m, shape=(n_subset,)) # (n_subset,)

# (n, n_subset)
A_trans = self.subset(
None, col_ixs, batch_size=None
).cost_matrix * inv_scale
A_trans = self.subset(None, col_ixs).cost_matrix * inv_scale
B = (U[col_ixs, :] @ v * inv_scale) # (n_subset, k)
M = jnp.linalg.inv(B.T @ B) # (k, k)
V = jnp.linalg.multi_dot([A_trans, B, M.T, v.T]) # (n, k)
Expand Down Expand Up @@ -754,7 +744,11 @@ def subset_fn(
return arr

return self._mask_subset_helper(
src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True
src_ixs,
tgt_ixs,
fn=subset_fn,
propagate_mask=True,
**kwargs,
)

def mask(
Expand Down
1 change: 1 addition & 0 deletions src/ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def to_LRCGeometry(
cost_2=cost_2,
scale_factor=scale,
epsilon=self._epsilon_init,
relative_epsilon=self._relative_epsilon,
scale_cost=self._scale_cost,
src_mask=self.src_mask,
tgt_mask=self.tgt_mask,
Expand Down
14 changes: 10 additions & 4 deletions src/ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,15 @@ def _mask_subset_helper(
)

def __add__(self, other: 'LRCGeometry') -> 'LRCGeometry':
assert isinstance(other, LRCGeometry), type(other)
return type(self)(
if not isinstance(other, LRCGeometry):
return NotImplemented
return LRCGeometry(
cost_1=jnp.concatenate((self.cost_1, other.cost_1), axis=1),
cost_2=jnp.concatenate((self.cost_2, other.cost_2), axis=1),
bias=self._bias + other._bias,
# already included in `cost_{1,2}`
scale_factor=1.0,
scale_cost=1.0,
)

@property
Expand All @@ -317,22 +322,23 @@ def tree_flatten(self): # noqa: D102
self._cost_2,
self._src_mask,
self._tgt_mask,
self._epsilon_init,
self._bias,
self._scale_factor,
# TODO(michalk8): eps
), {
'scale_cost': self._scale_cost,
'batch_size': self.batch_size
}

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
c1, c2, src_mask, tgt_mask, bias, scale_factor = children
c1, c2, src_mask, tgt_mask, epsilon, bias, scale_factor = children
return cls(
c1,
c2,
bias=bias,
scale_factor=scale_factor,
epsilon=epsilon,
src_mask=src_mask,
tgt_mask=tgt_mask,
**aux_data
Expand Down
6 changes: 3 additions & 3 deletions src/ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,17 +563,17 @@ def prepare_divergences(
)

def tree_flatten(self): # noqa: D102
return ([
return (
self.x,
self.y,
self._src_mask,
self._tgt_mask,
self._epsilon_init,
self.cost_fn,
], {
), {
'batch_size': self._batch_size,
'scale_cost': self._scale_cost
})
}

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
Expand Down
22 changes: 15 additions & 7 deletions tests/geometry/low_rank_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Low-Rank Geometry."""
from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Tuple

import pytest

Expand Down Expand Up @@ -93,7 +93,12 @@ def test_apply_squared(self, rng: jnp.ndarray):
geom2.apply_cost(mat, axis=axis), out_lr, rtol=5e-4
)

def test_add_lr_geoms(self, rng: jnp.ndarray):
@pytest.mark.parametrize("bias", [(0, 0), (4, 5)])
@pytest.mark.parametrize("scale_factor", [(1, 1), (2, 3)])
def test_add_lr_geoms(
self, rng: jnp.ndarray, bias: Tuple[float, float],
scale_factor: Tuple[float, float]
):
"""Test application of cost to vec or matrix."""
n, m, r, q = 17, 11, 7, 2
keys = jax.random.split(rng, 5)
Expand All @@ -102,12 +107,15 @@ def test_add_lr_geoms(self, rng: jnp.ndarray):
d1 = jax.random.normal(keys[0], (n, q))
d2 = jax.random.normal(keys[1], (m, q))

c = jnp.matmul(c1, c2.T)
d = jnp.matmul(d1, d2.T)
geom = geometry.Geometry(c + d)
s1, s2 = scale_factor
b1, b2 = bias

c = jnp.matmul(c1, c2.T) * s1
d = jnp.matmul(d1, d2.T) * s2
geom = geometry.Geometry(c + d + b1 + b2)

geom_lr_c = low_rank.LRCGeometry(c1, c2)
geom_lr_d = low_rank.LRCGeometry(d1, d2)
geom_lr_c = low_rank.LRCGeometry(c1, c2, scale_factor=s1, bias=b1)
geom_lr_d = low_rank.LRCGeometry(d1, d2, scale_factor=s2, bias=b2)
geom_lr = geom_lr_c + geom_lr_d

for dim, axis in ((m, 1), (n, 0)):
Expand Down
5 changes: 4 additions & 1 deletion tests/geometry/subsetting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def test_mask(
tgt_ixs, int
) else len(tgt_ixs)

geom_sub = geom.subset(src_ixs, tgt_ixs, batch_size=new_batch_size)
if clazz is geometry.Geometry:
geom_sub = geom.subset(src_ixs, tgt_ixs)
else:
geom_sub = geom.subset(src_ixs, tgt_ixs, batch_size=new_batch_size)

assert type(geom_sub) == type(geom)
np.testing.assert_array_equal(geom_sub.shape, (n, m))
Expand Down

0 comments on commit 5155dc8

Please sign in to comment.