diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 93424e7cb..14e3d1fb8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -77,9 +77,9 @@ jobs: os: [ubuntu-latest] include: - python-version: '3.9' - os: macos-13 - - python-version: '3.10' os: macos-14 + - python-version: '3.10' + os: macos-15 steps: - uses: actions/checkout@v4 diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index 96b62f798..d7e9c3dcc 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -21,7 +21,6 @@ import jax.numpy as jnp import jax.scipy as jsp import jax.tree_util as jtu -import numpy as np from ott import utils from ott.geometry import epsilon_scheduler as eps_scheduler @@ -65,10 +64,6 @@ class Geometry: scale_cost: option to rescale the cost matrix. Implemented scalings are 'median', 'mean', 'std' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that ``cost_matrix /= scale_cost``. - src_mask: Mask specifying valid rows when computing some statistics of - :attr:`cost_matrix`, see :attr:`src_mask`. - tgt_mask: Mask specifying valid columns when computing some statistics of - :attr:`cost_matrix`, see :attr:`tgt_mask`. Note: When defining a :class:`~ott.geometry.geometry.Geometry` through a @@ -86,8 +81,6 @@ def __init__( relative_epsilon: Optional[Literal["mean", "std"]] = None, scale_cost: Union[float, Literal["mean", "max_cost", "median", "std"]] = 1.0, - src_mask: Optional[jnp.ndarray] = None, - tgt_mask: Optional[jnp.ndarray] = None, ): self._cost_matrix = cost_matrix self._kernel_matrix = kernel_matrix @@ -95,9 +88,6 @@ def __init__( self._relative_epsilon = relative_epsilon self._scale_cost = scale_cost - self._src_mask = src_mask - self._tgt_mask = tgt_mask - @property def cost_rank(self) -> Optional[int]: """Output rank of cost matrix, if any was provided.""" @@ -117,14 +107,14 @@ def cost_matrix(self) -> jnp.ndarray: @property def median_cost_matrix(self) -> float: """Median of the :attr:`cost_matrix`.""" - geom = self._masked_geom(mask_value=jnp.nan) - return jnp.nanmedian(geom.cost_matrix) # will fail for online PC + return jnp.median(self.cost_matrix) @property def mean_cost_matrix(self) -> float: """Mean of the :attr:`cost_matrix`.""" - tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze() - return jnp.sum(tmp * self._m_normed_ones) + n, m = self.shape + tmp = self.apply_cost(jnp.full((n,), fill_value=1.0 / n)) + return jnp.sum((1.0 / m) * tmp) @property def std_cost_matrix(self) -> float: @@ -139,8 +129,9 @@ def std_cost_matrix(self) -> float: to output :math:`\sigma`. """ - tmp = self._masked_geom().apply_square_cost(self._n_normed_ones).squeeze() - tmp = jnp.sum(tmp * self._m_normed_ones) - (self.mean_cost_matrix ** 2) + n, m = self.shape + tmp = self.apply_square_cost(jnp.full((n,), fill_value=1.0 / n)) + tmp = jnp.sum((1.0 / m) * tmp) - (self.mean_cost_matrix ** 2) return jnp.sqrt(jax.nn.relu(tmp)) @property @@ -158,7 +149,6 @@ def epsilon_scheduler(self) -> eps_scheduler.Epsilon: """Epsilon scheduler.""" if isinstance(self._epsilon_init, eps_scheduler.Epsilon): return self._epsilon_init - # no relative epsilon if self._relative_epsilon is None: if self._epsilon_init is not None: @@ -217,23 +207,21 @@ def is_online(self) -> bool: @property def is_symmetric(self) -> bool: """Whether geometry cost/kernel is a symmetric matrix.""" + n, m = self.shape mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix - return ( - mat.shape[0] == mat.shape[1] and jnp.all(mat == mat.T) - ) if mat is not None else False + return (n == m) and jnp.all(mat == mat.T) @property - def inv_scale_cost(self) -> float: + def inv_scale_cost(self) -> jnp.ndarray: """Compute and return inverse of scaling factor for cost matrix.""" - if isinstance(self._scale_cost, (int, float, np.number, jax.Array)): - return 1.0 / self._scale_cost - self = self._masked_geom(mask_value=jnp.nan) if self._scale_cost == "max_cost": - return 1.0 / jnp.nanmax(self._cost_matrix) + return 1.0 / jnp.max(self._cost_matrix) if self._scale_cost == "mean": - return 1.0 / jnp.nanmean(self._cost_matrix) + return 1.0 / jnp.mean(self._cost_matrix) if self._scale_cost == "median": - return 1.0 / jnp.nanmedian(self._cost_matrix) + return 1.0 / jnp.median(self._cost_matrix) + if utils.is_scalar(self._scale_cost): + return 1.0 / self._scale_cost raise ValueError(f"Scaling {self._scale_cost} not implemented.") def set_scale_cost(self, scale_cost: Union[float, str]) -> "Geometry": @@ -692,14 +680,14 @@ def to_LRCGeometry( i_star = jax.random.randint(rng1, shape=(), minval=0, maxval=n) j_star = jax.random.randint(rng2, shape=(), minval=0, maxval=m) - ci_star = self.subset([i_star], None).cost_matrix.ravel() ** 2 # (m,) - cj_star = self.subset(None, [j_star]).cost_matrix.ravel() ** 2 # (n,) + ci_star = self.subset(row_ixs=i_star).cost_matrix.ravel() ** 2 # (m,) + cj_star = self.subset(col_ixs=j_star).cost_matrix.ravel() ** 2 # (n,) p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,) p_row /= jnp.sum(p_row) row_ixs = jax.random.choice(rng3, n, shape=(n_subset,), p=p_row) # (n_subset, m) - s = self.subset(row_ixs, None).cost_matrix + s = self.subset(row_ixs=row_ixs).cost_matrix s /= jnp.sqrt(n_subset * p_row[row_ixs][:, None]) p_col = jnp.sum(s ** 2, axis=0) # (m,) @@ -720,7 +708,7 @@ def to_LRCGeometry( col_ixs = jax.random.choice(rng5, m, shape=(n_subset,)) # (n_subset,) # (n, n_subset) - A_trans = self.subset(None, col_ixs).cost_matrix * inv_scale + A_trans = self.subset(col_ixs=col_ixs).cost_matrix * inv_scale B = (U[col_ixs, :] @ v * inv_scale) # (n_subset, k) M = jnp.linalg.inv(B.T @ B) # (k, k) V = jnp.linalg.multi_dot([A_trans, B, M.T, v.T]) # (n, k) @@ -737,179 +725,42 @@ def to_LRCGeometry( ) def subset( - self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], - **kwargs: Any - ) -> "Geometry": - """Subset rows or columns of a geometry. - - Args: - src_ixs: Row indices. If ``None``, use all rows. - tgt_ixs: Column indices. If ``None``, use all columns. - kwargs: Keyword arguments to override the initialization. - - Returns: - The modified geometry. - """ - - def subset_fn( - arr: Optional[jnp.ndarray], - src_ixs: Optional[jnp.ndarray], - tgt_ixs: Optional[jnp.ndarray], - ) -> Optional[jnp.ndarray]: - if arr is None: - return None - if src_ixs is not None: - arr = arr[src_ixs, ...] - if tgt_ixs is not None: - arr = arr[:, tgt_ixs] - return arr # noqa: RET504 - - return self._mask_subset_helper( - src_ixs, - tgt_ixs, - fn=subset_fn, - propagate_mask=True, - **kwargs, - ) - - def mask( self, - src_mask: Optional[jnp.ndarray], - tgt_mask: Optional[jnp.ndarray], - mask_value: float = 0.0, + row_ixs: Optional[jnp.ndarray] = None, + col_ixs: Optional[jnp.ndarray] = None ) -> "Geometry": - """Mask rows or columns of a geometry. - - The mask is used only when computing some statistics of the - :attr:`cost_matrix`. - - - :attr:`mean_cost_matrix` - - :attr:`median_cost_matrix` - - :attr:`inv_scale_cost` + """Subset rows or columns of a geometry. Args: - src_mask: Row mask. Can be specified either as a boolean array of shape - ``[num_a,]`` or as an array of indices. If ``None``, no mask is applied. - tgt_mask: Column mask. Can be specified either as a boolean array of shape - ``[num_b,]`` or as an array of indices. If ``None``, no mask is applied. - mask_value: Value to use for masking. + row_ixs: Row indices. If :obj:`None`, use all rows. + col_ixs: Column indices. If :obj:`None`, use all columns. Returns: - The masked geometry. + The subsetted geometry. """ - - def mask_fn( - arr: Optional[jnp.ndarray], - src_mask: Optional[jnp.ndarray], - tgt_mask: Optional[jnp.ndarray], - ) -> Optional[jnp.ndarray]: - if arr is None: - return arr - assert arr.ndim == 2, arr.ndim - if src_mask is not None: - arr = jnp.where(src_mask[:, None], arr, mask_value) - if tgt_mask is not None: - arr = jnp.where(tgt_mask[None, :], arr, mask_value) - return arr # noqa: RET504 - - src_mask = self._normalize_mask(src_mask, self.shape[0]) - tgt_mask = self._normalize_mask(tgt_mask, self.shape[1]) - return self._mask_subset_helper( - src_mask, tgt_mask, fn=mask_fn, propagate_mask=False - ) - - def _mask_subset_helper( - self, - src_ixs: Optional[jnp.ndarray], - tgt_ixs: Optional[jnp.ndarray], - *, - fn: Callable[ - [Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]], - Optional[jnp.ndarray]], - propagate_mask: bool, - **kwargs: Any, - ) -> "Geometry": - (cost, kernel, eps, src_mask, tgt_mask), aux_data = self.tree_flatten() - cost = fn(cost, src_ixs, tgt_ixs) - kernel = fn(kernel, src_ixs, tgt_ixs) - if propagate_mask: - src_mask = self._normalize_mask(src_mask, self.shape[0]) - tgt_mask = self._normalize_mask(tgt_mask, self.shape[1]) - src_mask = fn(src_mask, src_ixs, None) - tgt_mask = fn(tgt_mask, tgt_ixs, None) - - aux_data = {**aux_data, **kwargs} - return type(self).tree_unflatten( - aux_data, [cost, kernel, eps, src_mask, tgt_mask] - ) - - @property - def src_mask(self) -> Optional[jnp.ndarray]: - """Mask of shape ``[num_a,]`` to compute :attr:`cost_matrix` statistics. - - Specifically, it is used when computing: - - - :attr:`mean_cost_matrix` - - :attr:`median_cost_matrix` - - :attr:`inv_scale_cost` - """ - return self._normalize_mask(self._src_mask, self.shape[0]) - - @property - def tgt_mask(self) -> Optional[jnp.ndarray]: - """Mask of shape ``[num_b,]`` to compute :attr:`cost_matrix` statistics. - - Specifically, it is used when computing: - - - :attr:`mean_cost_matrix` - - :attr:`median_cost_matrix` - - :attr:`inv_scale_cost` - """ - return self._normalize_mask(self._tgt_mask, self.shape[1]) + (cost, kernel, *rest), aux_data = self.tree_flatten() + row_ixs = row_ixs if row_ixs is None else jnp.atleast_1d(row_ixs) + col_ixs = col_ixs if col_ixs is None else jnp.atleast_1d(col_ixs) + if cost is not None: + cost = cost if row_ixs is None else cost[row_ixs] + cost = cost if col_ixs is None else cost[:, col_ixs] + if kernel is not None: + kernel = kernel if row_ixs is None else kernel[row_ixs] + kernel = kernel if col_ixs is None else kernel[:, col_ixs] + return type(self).tree_unflatten(aux_data, (cost, kernel, *rest)) @property def dtype(self) -> jnp.dtype: """The data type.""" - return ( - self._kernel_matrix if self._cost_matrix is None else self._cost_matrix - ).dtype - - def _masked_geom(self, mask_value: float = 0.0) -> "Geometry": - """Mask geometry based on :attr:`src_mask` and :attr:`tgt_mask`.""" - src_mask, tgt_mask = self.src_mask, self.tgt_mask - if src_mask is None and tgt_mask is None: - return self - return self.mask(src_mask, tgt_mask, mask_value=mask_value) - - @property - def _n_normed_ones(self) -> jnp.ndarray: - """Normalized array of shape ``[num_a,]``.""" - mask = self.src_mask - arr = jnp.ones(self.shape[0]) if mask is None else mask - return arr / jnp.sum(arr) - - @property - def _m_normed_ones(self) -> jnp.ndarray: - """Normalized array of shape ``[num_b,]``.""" - mask = self.tgt_mask - arr = jnp.ones(self.shape[1]) if mask is None else mask - return arr / jnp.sum(arr) - - @staticmethod - def _normalize_mask(mask: Optional[jnp.ndarray], - size: int) -> Optional[jnp.ndarray]: - """Convert array of indices to a boolean mask.""" - if mask is None: - return None - if not jnp.issubdtype(mask, (bool, jnp.bool_)): - mask = jnp.isin(jnp.arange(size), mask) - assert mask.shape == (size,) - return mask + if self._cost_matrix is not None: + return self._cost_matrix.dtype + return self._kernel_matrix.dtype def tree_flatten(self): # noqa: D102 return ( - self._cost_matrix, self._kernel_matrix, self._epsilon_init, - self._src_mask, self._tgt_mask + self._cost_matrix, + self._kernel_matrix, + self._epsilon_init, ), { "scale_cost": self._scale_cost, "relative_epsilon": self._relative_epsilon, @@ -917,7 +768,5 @@ def tree_flatten(self): # noqa: D102 @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - cost, kernel, eps, src_mask, tgt_mask = children - return cls( - cost, kernel, eps, src_mask=src_mask, tgt_mask=tgt_mask, **aux_data - ) + cost, kernel, epsilon = children + return cls(cost, kernel_matrix=kernel, epsilon=epsilon, **aux_data) diff --git a/src/ott/geometry/grid.py b/src/ott/geometry/grid.py index 1755b72f6..97392f160 100644 --- a/src/ott/geometry/grid.py +++ b/src/ott/geometry/grid.py @@ -315,21 +315,6 @@ def transport_from_scalings( "cloud geometry instead." ) - def subset( - self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray] - ) -> NoReturn: - """Not implemented.""" - raise NotImplementedError("Subsetting is not implemented for grids.") - - def mask( - self, - src_mask: Optional[jnp.ndarray], - tgt_mask: Optional[jnp.ndarray], - mask_value: float = 0.0, - ) -> NoReturn: - """Not implemented.""" - raise NotImplementedError("Masking is not implemented for grids.") - @property def cost_matrix(self) -> jnp.ndarray: """Not implemented.""" @@ -425,6 +410,4 @@ def to_LRCGeometry( epsilon=self._epsilon_init, relative_epsilon=self._relative_epsilon, scale_cost=self._scale_cost, - src_mask=self.src_mask, - tgt_mask=self.tgt_mask, ) diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index ec0d7d7b9..f75b3ecc8 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -94,27 +94,26 @@ def shape(self) -> Tuple[int, int]: # noqa: D102 @property def is_symmetric(self) -> bool: # noqa: D102 - return ( - self._cost_1.shape[0] == self._cost_2.shape[0] and - jnp.all(self._cost_1 == self._cost_2) - ) + n, m = self.shape + return (n == m) and jnp.all(self._cost_1 == self._cost_2) @property - def inv_scale_cost(self) -> float: # noqa: D102 - if isinstance(self._scale_cost, (int, float, jax.Array)): - return 1.0 / self._scale_cost - self = self._masked_geom() + def inv_scale_cost(self) -> jnp.ndarray: # noqa: D102 if self._scale_cost == "max_bound": x_norm = self._cost_1[:, 0].max() y_norm = self._cost_2[:, 1].max() max_bound = x_norm + y_norm + 2.0 * jnp.sqrt(x_norm * y_norm) return 1.0 / (max_bound + self._bias) if self._scale_cost == "mean": - a, b = self._n_normed_ones, self._m_normed_ones + n, m = self.shape + a = jnp.full((n,), fill_value=1.0 / n) + b = jnp.full((m,), fill_value=1.0 / m) mean = jnp.linalg.multi_dot([a, self._cost_1, self._cost_2.T, b]) return 1.0 / (mean + self._bias) if self._scale_cost == "max_cost": return 1.0 / self._max_cost_matrix + if utils.is_scalar(self._scale_cost): + return 1.0 / self._scale_cost raise ValueError(f"Scaling {self._scale_cost} not implemented.") def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: @@ -193,66 +192,6 @@ def to_LRCGeometry( def can_LRC(self): # noqa: D102 return True - def subset( # noqa: D102 - self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], - **kwargs: Any - ) -> "LRCGeometry": - - def subset_fn( - arr: Optional[jnp.ndarray], - ixs: Optional[jnp.ndarray], - ) -> jnp.ndarray: - return arr if arr is None or ixs is None else arr[ixs, ...] - - return self._mask_subset_helper( - src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs - ) - - def mask( # noqa: D102 - self, - src_mask: Optional[jnp.ndarray], - tgt_mask: Optional[jnp.ndarray], - mask_value: float = 0.0, - ) -> "LRCGeometry": - - def mask_fn( - arr: Optional[jnp.ndarray], - mask: Optional[jnp.ndarray], - ) -> Optional[jnp.ndarray]: - if arr is None or mask is None: - return arr - return jnp.where(mask[:, None], arr, mask_value) - - src_mask = self._normalize_mask(src_mask, self.shape[0]) - tgt_mask = self._normalize_mask(tgt_mask, self.shape[1]) - return self._mask_subset_helper( - src_mask, tgt_mask, fn=mask_fn, propagate_mask=False - ) - - def _mask_subset_helper( - self, - src_ixs: Optional[jnp.ndarray], - tgt_ixs: Optional[jnp.ndarray], - *, - fn: Callable[[Optional[jnp.ndarray], Optional[jnp.ndarray]], - Optional[jnp.ndarray]], - propagate_mask: bool, - **kwargs: Any, - ) -> "LRCGeometry": - (c1, c2, src_mask, tgt_mask, *children), aux_data = self.tree_flatten() - c1 = fn(c1, src_ixs) - c2 = fn(c2, tgt_ixs) - if propagate_mask: - src_mask = self._normalize_mask(src_mask, self.shape[0]) - tgt_mask = self._normalize_mask(tgt_mask, self.shape[1]) - src_mask = fn(src_mask, src_ixs) - tgt_mask = fn(tgt_mask, tgt_ixs) - - aux_data = {**aux_data, **kwargs} - return type(self).tree_unflatten( - aux_data, [c1, c2, src_mask, tgt_mask] + children - ) - def __add__(self, other: "LRCGeometry") -> "LRCGeometry": if not isinstance(other, LRCGeometry): return NotImplemented @@ -273,8 +212,6 @@ def tree_flatten(self): # noqa: D102 return ( self._cost_1, self._cost_2, - self._src_mask, - self._tgt_mask, self._epsilon_init, self._bias, self._scale_factor, @@ -285,15 +222,13 @@ def tree_flatten(self): # noqa: D102 @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - c1, c2, src_mask, tgt_mask, epsilon, bias, scale_factor = children + c1, c2, epsilon, bias, scale_factor = children return cls( c1, c2, bias=bias, scale_factor=scale_factor, epsilon=epsilon, - src_mask=src_mask, - tgt_mask=tgt_mask, **aux_data ) diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index fafc39b5b..f99073246 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Any, Callable, Literal, Optional, Tuple, Union -import jax import jax.numpy as jnp import jax.tree_util as jtu @@ -203,7 +202,9 @@ def compute_max(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: return jnp.max(jnp.abs(cost)) if summary == "mean": - a, b = self._n_normed_ones, self._m_normed_ones + n, m = self.shape + a = jnp.full((n,), fill_value=1.0 / n) + b = jnp.full((m,), fill_value=1.0 / m) return jnp.sum(self._apply_cost_to_vec(a, scale_cost=1.0) * b) if summary == "max_cost": @@ -226,28 +227,18 @@ def prepare_divergences( x: jnp.ndarray, y: jnp.ndarray, static_b: bool = False, - src_mask: Optional[jnp.ndarray] = None, - tgt_mask: Optional[jnp.ndarray] = None, **kwargs: Any ) -> Tuple["PointCloud", ...]: """Instantiate the geometries used for a divergence computation.""" couples = [(x, y), (x, x)] - masks = [(src_mask, tgt_mask), (src_mask, src_mask)] if not static_b: couples += [(y, y)] - masks += [(tgt_mask, tgt_mask)] - - return tuple( - cls(x, y, src_mask=x_mask, tgt_mask=y_mask, **kwargs) - for ((x, y), (x_mask, y_mask)) in zip(couples, masks) - ) + return tuple(cls(x, y, **kwargs) for (x, y) in couples) def tree_flatten(self): # noqa: D102 return ( self.x, self.y, - self._src_mask, - self._tgt_mask, self._epsilon_init, self.cost_fn, ), { @@ -258,16 +249,8 @@ def tree_flatten(self): # noqa: D102 @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - x, y, src_mask, tgt_mask, epsilon, cost_fn = children - return cls( - x, - y, - cost_fn=cost_fn, - src_mask=src_mask, - tgt_mask=tgt_mask, - epsilon=epsilon, - **aux_data - ) + x, y, epsilon, cost_fn = children + return cls(x, y, cost_fn=cost_fn, epsilon=epsilon, **aux_data) def _cosine_to_sqeucl(self) -> "PointCloud": assert isinstance(self.cost_fn, costs.Cosine), type(self.cost_fn) @@ -328,68 +311,6 @@ def _sqeucl_to_lr(self, scale: float = 1.0) -> low_rank.LRCGeometry: epsilon=self._epsilon_init, relative_epsilon=self._relative_epsilon, scale_cost=self._scale_cost, - src_mask=self.src_mask, - tgt_mask=self.tgt_mask, - ) - - def subset( # noqa: D102 - self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], - **kwargs: Any - ) -> "PointCloud": - - def subset_fn( - arr: Optional[jnp.ndarray], - ixs: Optional[jnp.ndarray], - ) -> jnp.ndarray: - return arr if arr is None or ixs is None else arr[ixs, ...] - - return self._mask_subset_helper( - src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs - ) - - def mask( # noqa: D102 - self, - src_mask: Optional[jnp.ndarray], - tgt_mask: Optional[jnp.ndarray], - mask_value: float = 0.0, - ) -> "PointCloud": - - def mask_fn( - arr: Optional[jnp.ndarray], - mask: Optional[jnp.ndarray], - ) -> Optional[jnp.ndarray]: - if arr is None or mask is None: - return arr - return jnp.where(mask[:, None], arr, mask_value) - - src_mask = self._normalize_mask(src_mask, self.shape[0]) - tgt_mask = self._normalize_mask(tgt_mask, self.shape[1]) - return self._mask_subset_helper( - src_mask, tgt_mask, fn=mask_fn, propagate_mask=False - ) - - def _mask_subset_helper( - self, - src_ixs: Optional[jnp.ndarray], - tgt_ixs: Optional[jnp.ndarray], - *, - fn: Callable[[Optional[jnp.ndarray], Optional[jnp.ndarray]], - Optional[jnp.ndarray]], - propagate_mask: bool, - **kwargs: Any, - ) -> "PointCloud": - (x, y, src_mask, tgt_mask, *children), aux_data = self.tree_flatten() - x = fn(x, src_ixs) - y = fn(y, tgt_ixs) - if propagate_mask: - src_mask = self._normalize_mask(src_mask, self.shape[0]) - tgt_mask = self._normalize_mask(tgt_mask, self.shape[1]) - src_mask = fn(src_mask, src_ixs) - tgt_mask = fn(tgt_mask, tgt_ixs) - aux_data = {**aux_data, **kwargs} - - return type(self).tree_unflatten( - aux_data, [x, y, src_mask, tgt_mask] + children ) @property @@ -401,10 +322,7 @@ def _unscaled_cost_matrix(self) -> jnp.ndarray: return self.cost_fn.all_pairs(self.x, self.y) @property - def inv_scale_cost(self) -> float: # noqa: D102 - if isinstance(self._scale_cost, (int, float, jax.Array)): - return 1.0 / self._scale_cost - self = self._masked_geom() + def inv_scale_cost(self) -> jnp.ndarray: # noqa: D102 if self._scale_cost == "max_cost": if self.is_online: return 1.0 / self._compute_summary_online(self._scale_cost) @@ -412,23 +330,21 @@ def inv_scale_cost(self) -> float: # noqa: D102 if self._scale_cost == "mean": if self.is_online: return 1.0 / self._compute_summary_online(self._scale_cost) - geom = self._masked_geom(mask_value=jnp.nan)._unscaled_cost_matrix - return 1.0 / jnp.nanmean(geom) + return 1.0 / jnp.mean(self._unscaled_cost_matrix) if self._scale_cost == "median": if not self.is_online: - geom = self._masked_geom(mask_value=jnp.nan) - return 1.0 / jnp.nanmedian(geom._unscaled_cost_matrix) + return 1.0 / jnp.median(self._unscaled_cost_matrix) raise NotImplementedError( "Using the median as scaling factor for " "the cost matrix with the online mode is not implemented." ) - if not hasattr(self.cost_fn, "norm"): - raise ValueError("Cost function has no norm method.") - norm_x = self.cost_fn.norm(self.x) - norm_y = self.cost_fn.norm(self.y) if self._scale_cost == "max_norm": + norm_x = self.cost_fn.norm(self.x) + norm_y = self.cost_fn.norm(self.y) return 1.0 / jnp.maximum(norm_x.max(), norm_y.max()) if self._scale_cost == "max_bound": + norm_x = self.cost_fn.norm(self.x) + norm_y = self.cost_fn.norm(self.y) if self.is_squared_euclidean: x_argmax = jnp.argmax(norm_x) y_argmax = jnp.argmax(norm_y) @@ -442,8 +358,22 @@ def inv_scale_cost(self) -> float: # noqa: D102 "the cost matrix when the cost is not squared euclidean " "is not implemented." ) + if utils.is_scalar(self._scale_cost): + return 1.0 / self._scale_cost raise ValueError(f"Scaling {self._scale_cost} not implemented.") + def subset( # noqa: D102 + self, + row_ixs: Optional[jnp.ndarray] = None, + col_ixs: Optional[jnp.ndarray] = None, + ) -> "PointCloud": + (x, y, *rest), aux_data = self.tree_flatten() + if row_ixs is not None: + x = x[jnp.atleast_1d(row_ixs)] + if col_ixs is not None: + y = y[jnp.atleast_1d(col_ixs)] + return type(self).tree_unflatten(aux_data, (x, y, *rest)) + @property def kernel_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102 return jnp.exp(-self.cost_matrix / self.epsilon) @@ -458,9 +388,8 @@ def dtype(self) -> jnp.dtype: # noqa: D102 @property def is_symmetric(self) -> bool: # noqa: D102 - return self.y is None or ( - jnp.all(self.x.shape == self.y.shape) and jnp.all(self.x == self.y) - ) + n, m = self.shape + return self.y is None or ((n == m) and jnp.all(self.x == self.y)) @property def is_squared_euclidean(self) -> bool: # noqa: D102 diff --git a/src/ott/problems/quadratic/gw_barycenter.py b/src/ott/problems/quadratic/gw_barycenter.py index 7498710be..9b38482a0 100644 --- a/src/ott/problems/quadratic/gw_barycenter.py +++ b/src/ott/problems/quadratic/gw_barycenter.py @@ -121,7 +121,7 @@ def project( transport: jnp.ndarray, fn: Optional[quadratic_costs.Loss], ) -> jnp.ndarray: - geom = self._create_y_geometry(y, mask=b > 0.0) + geom = self._create_y_geometry(y) fn, lin = (None, True) if fn is None else (fn.func, fn.is_linear) tmp = geom.apply_cost( @@ -182,12 +182,9 @@ def update_features(self, transports: jnp.ndarray, def _create_bary_geometry( self, cost_matrix: jnp.ndarray, - mask: Optional[jnp.ndarray] = None ) -> geometry.Geometry: return geometry.Geometry( cost_matrix=cost_matrix, - src_mask=mask, - tgt_mask=mask, epsilon=self.epsilon, scale_cost=self.scale_cost ) @@ -195,7 +192,6 @@ def _create_bary_geometry( def _create_y_geometry( self, y: jnp.ndarray, - mask: Optional[jnp.ndarray] = None ) -> geometry.Geometry: if self._y_as_costs: assert y.shape[0] == y.shape[1], y.shape @@ -203,24 +199,18 @@ def _create_y_geometry( y, epsilon=self.epsilon, scale_cost=self.scale_cost, - src_mask=mask, - tgt_mask=mask ) return pointcloud.PointCloud( y, epsilon=self.epsilon, scale_cost=self.scale_cost, cost_fn=self.cost_fn, - src_mask=mask, - tgt_mask=mask ) def _create_fused_geometry( self, x: jnp.ndarray, y: jnp.ndarray, - src_mask: Optional[jnp.ndarray] = None, - tgt_mask: Optional[jnp.ndarray] = None ) -> pointcloud.PointCloud: return pointcloud.PointCloud( x, @@ -228,8 +218,6 @@ def _create_fused_geometry( cost_fn=self.cost_fn, epsilon=self.epsilon, scale_cost=self.scale_cost, - src_mask=src_mask, - tgt_mask=tgt_mask ) def _create_problem( @@ -239,18 +227,12 @@ def _create_problem( b: jnp.ndarray, f: Optional[jnp.ndarray] = None ) -> quadratic_problem.QuadraticProblem: - # TODO(michalk8): in future, mask in the problem for convenience? - bary_mask = state.a > 0.0 - y_mask = b > 0.0 - - geom_xx = self._create_bary_geometry(state.cost, mask=bary_mask) - geom_yy = self._create_y_geometry(y, mask=y_mask) + geom_xx = self._create_bary_geometry(state.cost) + geom_yy = self._create_y_geometry(y) if self.is_fused: assert f is not None assert state.x.shape[1] == f.shape[1] - geom_xy = self._create_fused_geometry( - state.x, f, src_mask=bary_mask, tgt_mask=y_mask - ) + geom_xy = self._create_fused_geometry(state.x, f) else: geom_xy = None diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index 5467c98ee..5c2c08098 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -73,12 +73,7 @@ def solve_linear_ot( a: Optional[jnp.ndarray], x: jnp.ndarray, b: jnp.ndarray, y: jnp.ndarray ): geom = pointcloud.PointCloud( - x, - y, - src_mask=a > 0.0, - tgt_mask=b > 0.0, - cost_fn=bar_prob.cost_fn, - epsilon=bar_prob.epsilon + x, y, cost_fn=bar_prob.cost_fn, epsilon=bar_prob.epsilon ) prob = linear_problem.LinearProblem(geom, a=a, b=b) out = linear_solver(prob) diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index b8469e81c..e2759617d 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -275,9 +275,7 @@ def init_transports( rng1, rng2 = jax.random.split(rng, 2) x = jax.random.normal(rng1, shape=(len(a), 2)) y = jax.random.normal(rng2, shape=(len(b), 2)) - geom = pointcloud.PointCloud( - x, y, epsilon=epsilon, src_mask=a > 0, tgt_mask=b > 0 - ) + geom = pointcloud.PointCloud(x, y, epsilon=epsilon) problem = linear_problem.LinearProblem(geom, a=a, b=b) return solver(problem).matrix diff --git a/src/ott/tools/k_means.py b/src/ott/tools/k_means.py index 2a12f6e3a..c29867955 100644 --- a/src/ott/tools/k_means.py +++ b/src/ott/tools/k_means.py @@ -111,9 +111,9 @@ def _from_state( def _random_init( geom: pointcloud.PointCloud, k: int, rng: jax.Array ) -> jnp.ndarray: - ixs = jnp.arange(geom.shape[0]) - ixs = jax.random.choice(rng, ixs, shape=(k,), replace=False) - return geom.subset(ixs, None).x + n, _ = geom.shape + ixs = jax.random.choice(rng, jnp.arange(n), shape=(k,), replace=False) + return geom.x[ixs] def _k_means_plus_plus( @@ -127,7 +127,7 @@ def init_fn(geom: pointcloud.PointCloud, rng: jax.Array) -> KPPState: rng, next_rng = jax.random.split(rng, 2) ix = jax.random.choice(rng, jnp.arange(geom.shape[0]), shape=()) centroids = jnp.full((k, geom.cost_rank), jnp.inf).at[0].set(geom.x[ix]) - dists = geom.subset([ix], None).cost_matrix[0] + dists = geom.subset(ix).cost_matrix.squeeze(0) # (m,) return KPPState(rng=next_rng, centroids=centroids, centroid_dists=dists) def body_fn( @@ -143,7 +143,7 @@ def body_fn( ixs = jax.random.choice( rng, ixs, shape=(n_local_trials,), p=probs, replace=True ) - geom = geom.subset(ixs, None) + geom = geom.subset(row_ixs=ixs) candidate_dists = jnp.minimum(geom.cost_matrix, state.centroid_dists) best_ix = jnp.argmin(candidate_dists.sum(1)) diff --git a/src/ott/tools/segment_sinkhorn.py b/src/ott/tools/segment_sinkhorn.py index 6fe8e669f..e7b8b597e 100644 --- a/src/ott/tools/segment_sinkhorn.py +++ b/src/ott/tools/segment_sinkhorn.py @@ -108,16 +108,11 @@ def eval_fn( padded_y: jnp.ndarray, padded_weight_x: jnp.ndarray, padded_weight_y: jnp.ndarray, - ) -> float: - mask_x = padded_weight_x > 0.0 - mask_y = padded_weight_y > 0.0 - + ) -> jnp.ndarray: geom = pointcloud.PointCloud( padded_x, padded_y, cost_fn=cost_fn, - src_mask=mask_x, - tgt_mask=mask_y, **kwargs, ) prob = linear_problem.LinearProblem( diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py index c379677f4..16c5b596a 100644 --- a/src/ott/tools/sinkhorn_divergence.py +++ b/src/ott/tools/sinkhorn_divergence.py @@ -429,8 +429,6 @@ def eval_fn( padded_weight_x: jnp.ndarray, padded_weight_y: jnp.ndarray, ) -> float: - mask_x = padded_weight_x > 0.0 - mask_y = padded_weight_y > 0.0 div, _ = sinkhorn_divergence( pointcloud.PointCloud, padded_x, @@ -442,9 +440,7 @@ def eval_fn( share_epsilon=share_epsilon, symmetric_sinkhorn=symmetric_sinkhorn, cost_fn=cost_fn, - src_mask=mask_x, - tgt_mask=mask_y, - **kwargs + **kwargs, ) return div diff --git a/src/ott/utils.py b/src/ott/utils.py index 42c8132a4..5a6b242d3 100644 --- a/src/ott/utils.py +++ b/src/ott/utils.py @@ -49,6 +49,7 @@ "default_progress_fn", "tqdm_progress_fn", "batched_vmap", + "is_scalar", ] IOStatus = Tuple[np.ndarray, np.ndarray, np.ndarray, NamedTuple] @@ -422,3 +423,13 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: batched_fun = _apply_scan(vmapped_fun, in_axes=in_axes) return wrapper + + +# TODO(michalk8): remove when `jax>=0.4.31` +def is_scalar(x: Any) -> bool: # noqa: D103 + if ( + isinstance(x, (np.ndarray, jax.Array)) or hasattr(x, "__jax_array__") or + np.isscalar(x) + ): + return jnp.asarray(x).ndim == 0 + return False diff --git a/tests/geometry/lr_cost_test.py b/tests/geometry/lr_cost_test.py index f17355cbe..79f14a5ac 100644 --- a/tests/geometry/lr_cost_test.py +++ b/tests/geometry/lr_cost_test.py @@ -255,15 +255,10 @@ def test_point_cloud_to_lr( y, cost_fn=costs.SqPNorm(p=2.1), batch_size=batch_size, - scale_cost=scale_cost + scale_cost=scale_cost, ) - if geom.batch_size is not None: - # because `self.assert_upper_bound` tries to instantiate the matrix - geom = geom.subset(None, None, batch_size=None) - geom_lr = geom.to_LRCGeometry(rank=rank, tol=tol) - np.testing.assert_array_equal(geom.shape, geom_lr.shape) assert geom_lr.cost_rank == rank self.assert_upper_bound(geom, geom_lr, rank=rank, tol=tol) diff --git a/tests/geometry/subsetting_test.py b/tests/geometry/subsetting_test.py deleted file mode 100644 index 3a934cb48..000000000 --- a/tests/geometry/subsetting_test.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional, Sequence, Tuple, Type, Union - -import pytest - -import jax -import jax.numpy as jnp -import numpy as np - -from ott.geometry import geometry, low_rank, pointcloud - -Geom_t = Union[pointcloud.PointCloud, geometry.Geometry, low_rank.LRCGeometry] - - -@pytest.fixture() -def pc_masked( - rng: jax.Array -) -> Tuple[pointcloud.PointCloud, pointcloud.PointCloud]: - n, m = 20, 30 - rng1, rng2 = jax.random.split(rng, 2) - # x = jnp.full((n,), fill_value=1.) - # y = jnp.full((m,), fill_value=2.) - x = jax.random.normal(rng1, shape=(n, 3)) - y = jax.random.normal(rng1, shape=(m, 3)) - src_mask = jnp.asarray([0, 1, 2]) - tgt_mask = jnp.asarray([3, 5, 6]) - - pc = pointcloud.PointCloud(x, y, src_mask=src_mask, tgt_mask=tgt_mask) - masked = pointcloud.PointCloud(x[src_mask], y[tgt_mask]) - return pc, masked - - -@pytest.fixture(params=["geometry", "point_cloud", "low_rank"]) -def geom_masked(request, pc_masked) -> Tuple[Geom_t, pointcloud.PointCloud]: - pc, masked = pc_masked - if request.param == "point_cloud": - geom = pc - elif request.param == "geometry": - geom = geometry.Geometry( - cost_matrix=pc.cost_matrix, src_mask=pc.src_mask, tgt_mask=pc.tgt_mask - ) - elif request.param == "low_rank": - geom = pc.to_LRCGeometry() - else: - raise NotImplementedError(request.param) - return geom, masked - - -@pytest.mark.fast() -class TestMaskPointCloud: - - @pytest.mark.parametrize("tgt_ixs", [[1], jnp.arange(5)]) - @pytest.mark.parametrize("src_ixs", [None, (3, 3)]) - @pytest.mark.parametrize( - "clazz", [geometry.Geometry, pointcloud.PointCloud, low_rank.LRCGeometry] - ) - def test_mask( - self, rng: jax.Array, clazz: Type[geometry.Geometry], - src_ixs: Optional[Union[int, Sequence[int]]], - tgt_ixs: Optional[Union[int, Sequence[int]]] - ): - rng1, rng2 = jax.random.split(rng, 2) - new_batch_size = 7 - x = jax.random.normal(rng1, shape=(10, 3)) - y = jax.random.normal(rng2, shape=(20, 3)) - - if clazz is geometry.Geometry: - geom = clazz(cost_matrix=x @ y.T, scale_cost="mean") - elif clazz is pointcloud.PointCloud: - geom = clazz(x, y, scale_cost="max_cost", batch_size=5) - else: - geom = clazz(x, y, scale_cost="max_cost") - n = geom.shape[0] if src_ixs is None else len(src_ixs) - m = geom.shape[1] if tgt_ixs is None else len(tgt_ixs) - - if clazz is pointcloud.PointCloud: - geom_sub = geom.subset(src_ixs, tgt_ixs, batch_size=new_batch_size) - else: - geom_sub = geom.subset(src_ixs, tgt_ixs) - - assert type(geom_sub) is type(geom) - np.testing.assert_array_equal(geom_sub.shape, (n, m)) - assert geom_sub._scale_cost == geom._scale_cost - if clazz is pointcloud.PointCloud: - # test overriding some argument - assert geom_sub._batch_size == new_batch_size - - @pytest.mark.parametrize( - "scale_cost", ["mean", "max_cost", "median", "max_norm", "max_bound"] - ) - def test_mask_inverse_scaling( - self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], scale_cost: str - ): - geom, masked = geom_masked - geom = geom.set_scale_cost(scale_cost) - masked = masked.set_scale_cost(scale_cost) - - try: - actual = geom.inv_scale_cost - desired = masked.inv_scale_cost - except ValueError as e: - if "not implemented" not in str(e): - raise - pytest.mark.xfail(str(e)) - else: - np.testing.assert_allclose(actual, desired, rtol=1e-6, atol=1e-6) - geom_subset = geom.subset(geom.src_mask, geom.tgt_mask) - np.testing.assert_allclose( - geom_subset.cost_matrix, masked.cost_matrix, rtol=1e-6, atol=1e-6 - ) - - @pytest.mark.parametrize("stat", ["mean", "median"]) - def test_masked_summary( - self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], stat: str - ): - geom, masked = geom_masked - if stat == "mean": - np.testing.assert_allclose( - geom.mean_cost_matrix, masked.mean_cost_matrix, rtol=1e-6, atol=1e-6 - ) - else: - np.testing.assert_allclose( - geom.median_cost_matrix, - masked.median_cost_matrix, - rtol=1e-6, - atol=1e-6, - ) - - def test_mask_permutation( - self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], rng: jax.Array - ): - rng1, rng2 = jax.random.split(rng) - geom, _ = geom_masked - n, m = geom.shape - - # nullify the mask - geom._src_mask = None - geom._tgt_mask = None - assert geom._masked_geom() is geom - children, aux_data = geom.tree_flatten() - gt_geom = type(geom).tree_unflatten(aux_data, children) - - geom._src_mask = jax.random.permutation(rng1, jnp.arange(n)) - geom._tgt_mask = jax.random.permutation(rng2, jnp.arange(m)) - - np.testing.assert_allclose(geom.mean_cost_matrix, gt_geom.mean_cost_matrix) - np.testing.assert_allclose( - geom.median_cost_matrix, gt_geom.median_cost_matrix - ) - - def test_boolean_mask( - self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], rng: jax.Array - ): - rng1, rng2 = jax.random.split(rng) - p = jnp.array([0.5, 0.5]) - geom, _ = geom_masked - n, m = geom.shape - - src_mask = jax.random.choice(rng1, jnp.array([False, True]), (n,), p=p) - tgt_mask = jax.random.choice(rng1, jnp.array([False, True]), (m,), p=p) - geom._src_mask = src_mask - geom._tgt_mask = tgt_mask - gt_cost = geom.cost_matrix[src_mask, :][:, tgt_mask] - - np.testing.assert_allclose( - geom.mean_cost_matrix, jnp.mean(gt_cost), rtol=1e-6, atol=1e-6 - ) - np.testing.assert_allclose( - geom.median_cost_matrix, jnp.median(gt_cost), rtol=1e-6, atol=1e-6 - ) - - def test_subset_mask( - self, - geom_masked: Tuple[Geom_t, pointcloud.PointCloud], - ): - geom, masked = geom_masked - assert masked.shape < geom.shape - geom = geom.subset(geom.src_mask, geom.tgt_mask) - - assert geom.shape == masked.shape - assert geom.src_mask.shape == (geom.shape[0],) - assert geom.tgt_mask.shape == (geom.shape[1],) - - np.testing.assert_allclose( - geom.mean_cost_matrix, masked.mean_cost_matrix, rtol=1e-6, atol=1e-6 - ) - np.testing.assert_allclose( - geom.median_cost_matrix, - masked.median_cost_matrix, - rtol=1e-6, - atol=1e-6 - ) - np.testing.assert_allclose( - geom.cost_matrix, masked.cost_matrix, rtol=1e-6, atol=1e-6 - ) - - def test_mask_as_nonunique_indices( - self, - geom_masked: Tuple[Geom_t, pointcloud.PointCloud], - ): - geom, _ = geom_masked - n, m = geom.shape - src_ixs, tgt_ixs = [0, 2], [3, 1] - geom._src_mask = jnp.asarray(src_ixs * 11) # numbers chosen arbitrarily - geom._tgt_mask = jnp.asarray(tgt_ixs * 13) - - np.testing.assert_array_equal( - geom.src_mask, jnp.isin(jnp.arange(n), jnp.asarray(src_ixs)) - ) - np.testing.assert_array_equal( - geom.tgt_mask, jnp.isin(jnp.arange(m), jnp.asarray(tgt_ixs)) - )