diff --git a/MANIFEST.in b/MANIFEST.in index 652d6fd3b..1964149df 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,3 @@ prune docs -prune examples -prune images prune tests prune .github diff --git a/README.md b/README.md index f6e20495a..b73d011ef 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ The call to `solver(prob)` above works out the optimal transport solution. The ` more points from the second, as illustrated in the plot below. We provide more flexibility to define custom cost functions, objectives, and solvers, as detailed in the [full documentation](https://ott-jax.readthedocs.io/en/latest/). -![obtained coupling](https://raw.githubusercontent.com/ott-jax/ott/main/images/couplings.png) +![obtained coupling](https://raw.githubusercontent.com/ott-jax/ott/main/docs/_static/images/couplings.png) ## Citation If you have found this work useful, please consider citing this reference: diff --git a/images/couplings.png b/docs/_static/images/couplings.png similarity index 100% rename from images/couplings.png rename to docs/_static/images/couplings.png diff --git a/src/ott/geometry/epsilon_scheduler.py b/src/ott/geometry/epsilon_scheduler.py index 3a71936cc..3dabed2ff 100644 --- a/src/ott/geometry/epsilon_scheduler.py +++ b/src/ott/geometry/epsilon_scheduler.py @@ -35,7 +35,9 @@ class Epsilon: Args: target: the epsilon regularizer that is targeted. + If ``None``, use :math:`0.05`. scale_epsilon: if passed, used to multiply the regularizer, to rescale it. + If ``None``, use :math:`1`. init: initial value when using epsilon scheduling, understood as multiple of target value. if passed, ``int * decay ** iteration`` will be used to rescale target. @@ -58,9 +60,8 @@ def __init__( def target(self) -> float: """Return the final regularizer value of scheduler.""" target = 5e-2 if self._target_init is None else self._target_init - if self._scale_epsilon is None: - return target - return target * self._scale_epsilon + scale = 1.0 if self._scale_epsilon is None else self._scale_epsilon + return scale * target def at(self, iteration: Optional[int] = 1) -> float: """Return (intermediate) regularizer value at a given iteration.""" @@ -81,7 +82,7 @@ def done_at(self, iteration: Optional[int]) -> bool: return self.done(self.at(iteration)) def set(self, **kwargs: Any) -> "Epsilon": - """TODO.""" + """Return a copy of self, with potential overwrites.""" kwargs = { "target": self._target_init, "scale_epsilon": self._scale_epsilon, diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index f757d8ef5..04452854e 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -45,19 +45,15 @@ class Geometry: basic operations to be run with the Sinkhorn algorithm. Args: - cost_matrix: jnp.ndarray[num_a, num_b]: a cost matrix storing n x m - costs. - kernel_matrix: jnp.ndarray[num_a, num_b]: a kernel matrix storing n - x m kernel values. - epsilon: a regularization parameter. TODO(michalk8): update the docstring - If a :class:`~ott.geometry.epsilon_scheduler.Epsilon` scheduler is passed, - other parameters below are ignored in practice. If the - parameter is a float, then this is understood to be the regularization - that is needed, unless ``relative_epsilon`` below is ``True``, in which - case ``epsilon`` is understood as a normalized quantity, to be scaled by - the :attr:`mean_cost_matrix`. + cost_matrix: Cost matrix of shape ``[n, m]``. + kernel_matrix: Kernel matrix of shape ``[n, m]``. + epsilon: Regularization parameter. If ``scale_epsilon = None`` and either + ``relative_epsilon = True`` or ``relative_epsilon = None`` and + ``epsilon = None`` in :class:`~ott.geometry.epsilon_scheduler.Epsilon` + is used, ``scale_epsilon`` the is :attr:`mean_cost_matrix`. If + ``epsilon = None``, use :math:`0.05`. relative_epsilon: whether epsilon is passed relative to scale of problem, - here understood the value of :attr:`mean_cost_matrix`. + here understood the value of the :attr:`mean_cost_matrix`. scale_cost: option to rescale the cost matrix. Implemented scalings are 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that ``cost_matrix /= scale_cost``. @@ -68,10 +64,11 @@ class Geometry: :attr:`cost_matrix`, see :attr:`tgt_mask`. Note: - When defining a ``Geometry`` through a ``cost_matrix``, it is important to - select an ``epsilon`` regularization parameter that is meaningful. That - parameter can be provided by the user, or assigned a default value through - a simple rule, using the :attr:`mean_cost_matrix`. + When defining a :class:`~ott.geometry.geometry.Geometry` through a + ``cost_matrix``, it is important to select an ``epsilon`` regularization + parameter that is meaningful. That parameter can be provided by the user, + or assigned a default value through a simple rule, + using the :attr:`mean_cost_matrix`. """ def __init__( @@ -674,19 +671,14 @@ def to_LRCGeometry( i_star = jax.random.randint(key1, shape=(), minval=0, maxval=n) j_star = jax.random.randint(key2, shape=(), minval=0, maxval=m) - # force `batch_size=None` since `cost_matrix` would be `None` - ci_star = self.subset( - i_star, None, batch_size=None - ).cost_matrix.ravel() ** 2 # (m,) - cj_star = self.subset( - None, j_star, batch_size=None - ).cost_matrix.ravel() ** 2 # (n,) + ci_star = self.subset(i_star, None).cost_matrix.ravel() ** 2 # (m,) + cj_star = self.subset(None, j_star).cost_matrix.ravel() ** 2 # (n,) p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,) p_row /= jnp.sum(p_row) row_ixs = jax.random.choice(key3, n, shape=(n_subset,), p=p_row) # (n_subset, m) - s = self.subset(row_ixs, None, batch_size=None).cost_matrix + s = self.subset(row_ixs, None).cost_matrix s /= jnp.sqrt(n_subset * p_row[row_ixs][:, None]) p_col = jnp.sum(s ** 2, axis=0) # (m,) @@ -707,9 +699,7 @@ def to_LRCGeometry( col_ixs = jax.random.choice(key5, m, shape=(n_subset,)) # (n_subset,) # (n, n_subset) - A_trans = self.subset( - None, col_ixs, batch_size=None - ).cost_matrix * inv_scale + A_trans = self.subset(None, col_ixs).cost_matrix * inv_scale B = (U[col_ixs, :] @ v * inv_scale) # (n_subset, k) M = jnp.linalg.inv(B.T @ B) # (k, k) V = jnp.linalg.multi_dot([A_trans, B, M.T, v.T]) # (n, k) @@ -754,7 +744,11 @@ def subset_fn( return arr return self._mask_subset_helper( - src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True + src_ixs, + tgt_ixs, + fn=subset_fn, + propagate_mask=True, + **kwargs, ) def mask( diff --git a/src/ott/geometry/grid.py b/src/ott/geometry/grid.py index 8608ea325..f9449d731 100644 --- a/src/ott/geometry/grid.py +++ b/src/ott/geometry/grid.py @@ -409,6 +409,7 @@ def to_LRCGeometry( cost_2=cost_2, scale_factor=scale, epsilon=self._epsilon_init, + relative_epsilon=self._relative_epsilon, scale_cost=self._scale_cost, src_mask=self.src_mask, tgt_mask=self.tgt_mask, diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index a6392f047..17580b43a 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -301,10 +301,15 @@ def _mask_subset_helper( ) def __add__(self, other: 'LRCGeometry') -> 'LRCGeometry': - assert isinstance(other, LRCGeometry), type(other) - return type(self)( + if not isinstance(other, LRCGeometry): + return NotImplemented + return LRCGeometry( cost_1=jnp.concatenate((self.cost_1, other.cost_1), axis=1), cost_2=jnp.concatenate((self.cost_2, other.cost_2), axis=1), + bias=self._bias + other._bias, + # already included in `cost_{1,2}` + scale_factor=1.0, + scale_cost=1.0, ) @property @@ -317,9 +322,9 @@ def tree_flatten(self): # noqa: D102 self._cost_2, self._src_mask, self._tgt_mask, + self._epsilon_init, self._bias, self._scale_factor, - # TODO(michalk8): eps ), { 'scale_cost': self._scale_cost, 'batch_size': self.batch_size @@ -327,12 +332,13 @@ def tree_flatten(self): # noqa: D102 @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - c1, c2, src_mask, tgt_mask, bias, scale_factor = children + c1, c2, src_mask, tgt_mask, epsilon, bias, scale_factor = children return cls( c1, c2, bias=bias, scale_factor=scale_factor, + epsilon=epsilon, src_mask=src_mask, tgt_mask=tgt_mask, **aux_data diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index dbb0ae3ca..22dad3830 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -563,17 +563,17 @@ def prepare_divergences( ) def tree_flatten(self): # noqa: D102 - return ([ + return ( self.x, self.y, self._src_mask, self._tgt_mask, self._epsilon_init, self.cost_fn, - ], { + ), { 'batch_size': self._batch_size, 'scale_cost': self._scale_cost - }) + } @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 diff --git a/tests/geometry/low_rank_test.py b/tests/geometry/low_rank_test.py index ed8ad4b84..b53abf9a0 100644 --- a/tests/geometry/low_rank_test.py +++ b/tests/geometry/low_rank_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Test Low-Rank Geometry.""" -from typing import Callable, Optional, Union +from typing import Callable, Optional, Union, Tuple import pytest @@ -93,7 +93,12 @@ def test_apply_squared(self, rng: jnp.ndarray): geom2.apply_cost(mat, axis=axis), out_lr, rtol=5e-4 ) - def test_add_lr_geoms(self, rng: jnp.ndarray): + @pytest.mark.parametrize("bias", [(0, 0), (4, 5)]) + @pytest.mark.parametrize("scale_factor", [(1, 1), (2, 3)]) + def test_add_lr_geoms( + self, rng: jnp.ndarray, bias: Tuple[float, float], + scale_factor: Tuple[float, float] + ): """Test application of cost to vec or matrix.""" n, m, r, q = 17, 11, 7, 2 keys = jax.random.split(rng, 5) @@ -102,12 +107,15 @@ def test_add_lr_geoms(self, rng: jnp.ndarray): d1 = jax.random.normal(keys[0], (n, q)) d2 = jax.random.normal(keys[1], (m, q)) - c = jnp.matmul(c1, c2.T) - d = jnp.matmul(d1, d2.T) - geom = geometry.Geometry(c + d) + s1, s2 = scale_factor + b1, b2 = bias + + c = jnp.matmul(c1, c2.T) * s1 + d = jnp.matmul(d1, d2.T) * s2 + geom = geometry.Geometry(c + d + b1 + b2) - geom_lr_c = low_rank.LRCGeometry(c1, c2) - geom_lr_d = low_rank.LRCGeometry(d1, d2) + geom_lr_c = low_rank.LRCGeometry(c1, c2, scale_factor=s1, bias=b1) + geom_lr_d = low_rank.LRCGeometry(d1, d2, scale_factor=s2, bias=b2) geom_lr = geom_lr_c + geom_lr_d for dim, axis in ((m, 1), (n, 0)): diff --git a/tests/geometry/subsetting_test.py b/tests/geometry/subsetting_test.py index 8c4023e1a..bc32f643e 100644 --- a/tests/geometry/subsetting_test.py +++ b/tests/geometry/subsetting_test.py @@ -74,7 +74,10 @@ def test_mask( tgt_ixs, int ) else len(tgt_ixs) - geom_sub = geom.subset(src_ixs, tgt_ixs, batch_size=new_batch_size) + if clazz is geometry.Geometry: + geom_sub = geom.subset(src_ixs, tgt_ixs) + else: + geom_sub = geom.subset(src_ixs, tgt_ixs, batch_size=new_batch_size) assert type(geom_sub) == type(geom) np.testing.assert_array_equal(geom_sub.shape, (n, m))