Skip to content

Commit

Permalink
Remove source/target masks (#603)
Browse files Browse the repository at this point in the history
* Remove `{src,tgt}_mask`

* Remove `nan{mean,max,median}`

* Reintroduce subset function

* Bump macOS runners

* Fwd-port `jnp.isscalar`
  • Loading branch information
michalk8 authored Dec 3, 2024
1 parent 04e5abb commit 690b1ae
Show file tree
Hide file tree
Showing 14 changed files with 108 additions and 664 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ jobs:
os: [ubuntu-latest]
include:
- python-version: '3.9'
os: macos-13
- python-version: '3.10'
os: macos-14
- python-version: '3.10'
os: macos-15

steps:
- uses: actions/checkout@v4
Expand Down
237 changes: 43 additions & 194 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
import numpy as np

from ott import utils
from ott.geometry import epsilon_scheduler as eps_scheduler
Expand Down Expand Up @@ -65,10 +64,6 @@ class Geometry:
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean', 'std' and 'max_cost'. Alternatively, a float factor can
be given to rescale the cost such that ``cost_matrix /= scale_cost``.
src_mask: Mask specifying valid rows when computing some statistics of
:attr:`cost_matrix`, see :attr:`src_mask`.
tgt_mask: Mask specifying valid columns when computing some statistics of
:attr:`cost_matrix`, see :attr:`tgt_mask`.
Note:
When defining a :class:`~ott.geometry.geometry.Geometry` through a
Expand All @@ -86,18 +81,13 @@ def __init__(
relative_epsilon: Optional[Literal["mean", "std"]] = None,
scale_cost: Union[float, Literal["mean", "max_cost", "median",
"std"]] = 1.0,
src_mask: Optional[jnp.ndarray] = None,
tgt_mask: Optional[jnp.ndarray] = None,
):
self._cost_matrix = cost_matrix
self._kernel_matrix = kernel_matrix
self._epsilon_init = epsilon
self._relative_epsilon = relative_epsilon
self._scale_cost = scale_cost

self._src_mask = src_mask
self._tgt_mask = tgt_mask

@property
def cost_rank(self) -> Optional[int]:
"""Output rank of cost matrix, if any was provided."""
Expand All @@ -117,14 +107,14 @@ def cost_matrix(self) -> jnp.ndarray:
@property
def median_cost_matrix(self) -> float:
"""Median of the :attr:`cost_matrix`."""
geom = self._masked_geom(mask_value=jnp.nan)
return jnp.nanmedian(geom.cost_matrix) # will fail for online PC
return jnp.median(self.cost_matrix)

@property
def mean_cost_matrix(self) -> float:
"""Mean of the :attr:`cost_matrix`."""
tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
return jnp.sum(tmp * self._m_normed_ones)
n, m = self.shape
tmp = self.apply_cost(jnp.full((n,), fill_value=1.0 / n))
return jnp.sum((1.0 / m) * tmp)

@property
def std_cost_matrix(self) -> float:
Expand All @@ -139,8 +129,9 @@ def std_cost_matrix(self) -> float:
to output :math:`\sigma`.
"""
tmp = self._masked_geom().apply_square_cost(self._n_normed_ones).squeeze()
tmp = jnp.sum(tmp * self._m_normed_ones) - (self.mean_cost_matrix ** 2)
n, m = self.shape
tmp = self.apply_square_cost(jnp.full((n,), fill_value=1.0 / n))
tmp = jnp.sum((1.0 / m) * tmp) - (self.mean_cost_matrix ** 2)
return jnp.sqrt(jax.nn.relu(tmp))

@property
Expand All @@ -158,7 +149,6 @@ def epsilon_scheduler(self) -> eps_scheduler.Epsilon:
"""Epsilon scheduler."""
if isinstance(self._epsilon_init, eps_scheduler.Epsilon):
return self._epsilon_init

# no relative epsilon
if self._relative_epsilon is None:
if self._epsilon_init is not None:
Expand Down Expand Up @@ -217,23 +207,21 @@ def is_online(self) -> bool:
@property
def is_symmetric(self) -> bool:
"""Whether geometry cost/kernel is a symmetric matrix."""
n, m = self.shape
mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
return (
mat.shape[0] == mat.shape[1] and jnp.all(mat == mat.T)
) if mat is not None else False
return (n == m) and jnp.all(mat == mat.T)

@property
def inv_scale_cost(self) -> float:
def inv_scale_cost(self) -> jnp.ndarray:
"""Compute and return inverse of scaling factor for cost matrix."""
if isinstance(self._scale_cost, (int, float, np.number, jax.Array)):
return 1.0 / self._scale_cost
self = self._masked_geom(mask_value=jnp.nan)
if self._scale_cost == "max_cost":
return 1.0 / jnp.nanmax(self._cost_matrix)
return 1.0 / jnp.max(self._cost_matrix)
if self._scale_cost == "mean":
return 1.0 / jnp.nanmean(self._cost_matrix)
return 1.0 / jnp.mean(self._cost_matrix)
if self._scale_cost == "median":
return 1.0 / jnp.nanmedian(self._cost_matrix)
return 1.0 / jnp.median(self._cost_matrix)
if utils.is_scalar(self._scale_cost):
return 1.0 / self._scale_cost
raise ValueError(f"Scaling {self._scale_cost} not implemented.")

def set_scale_cost(self, scale_cost: Union[float, str]) -> "Geometry":
Expand Down Expand Up @@ -692,14 +680,14 @@ def to_LRCGeometry(
i_star = jax.random.randint(rng1, shape=(), minval=0, maxval=n)
j_star = jax.random.randint(rng2, shape=(), minval=0, maxval=m)

ci_star = self.subset([i_star], None).cost_matrix.ravel() ** 2 # (m,)
cj_star = self.subset(None, [j_star]).cost_matrix.ravel() ** 2 # (n,)
ci_star = self.subset(row_ixs=i_star).cost_matrix.ravel() ** 2 # (m,)
cj_star = self.subset(col_ixs=j_star).cost_matrix.ravel() ** 2 # (n,)

p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,)
p_row /= jnp.sum(p_row)
row_ixs = jax.random.choice(rng3, n, shape=(n_subset,), p=p_row)
# (n_subset, m)
s = self.subset(row_ixs, None).cost_matrix
s = self.subset(row_ixs=row_ixs).cost_matrix
s /= jnp.sqrt(n_subset * p_row[row_ixs][:, None])

p_col = jnp.sum(s ** 2, axis=0) # (m,)
Expand All @@ -720,7 +708,7 @@ def to_LRCGeometry(
col_ixs = jax.random.choice(rng5, m, shape=(n_subset,)) # (n_subset,)

# (n, n_subset)
A_trans = self.subset(None, col_ixs).cost_matrix * inv_scale
A_trans = self.subset(col_ixs=col_ixs).cost_matrix * inv_scale
B = (U[col_ixs, :] @ v * inv_scale) # (n_subset, k)
M = jnp.linalg.inv(B.T @ B) # (k, k)
V = jnp.linalg.multi_dot([A_trans, B, M.T, v.T]) # (n, k)
Expand All @@ -737,187 +725,48 @@ def to_LRCGeometry(
)

def subset(
self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray],
**kwargs: Any
) -> "Geometry":
"""Subset rows or columns of a geometry.
Args:
src_ixs: Row indices. If ``None``, use all rows.
tgt_ixs: Column indices. If ``None``, use all columns.
kwargs: Keyword arguments to override the initialization.
Returns:
The modified geometry.
"""

def subset_fn(
arr: Optional[jnp.ndarray],
src_ixs: Optional[jnp.ndarray],
tgt_ixs: Optional[jnp.ndarray],
) -> Optional[jnp.ndarray]:
if arr is None:
return None
if src_ixs is not None:
arr = arr[src_ixs, ...]
if tgt_ixs is not None:
arr = arr[:, tgt_ixs]
return arr # noqa: RET504

return self._mask_subset_helper(
src_ixs,
tgt_ixs,
fn=subset_fn,
propagate_mask=True,
**kwargs,
)

def mask(
self,
src_mask: Optional[jnp.ndarray],
tgt_mask: Optional[jnp.ndarray],
mask_value: float = 0.0,
row_ixs: Optional[jnp.ndarray] = None,
col_ixs: Optional[jnp.ndarray] = None
) -> "Geometry":
"""Mask rows or columns of a geometry.
The mask is used only when computing some statistics of the
:attr:`cost_matrix`.
- :attr:`mean_cost_matrix`
- :attr:`median_cost_matrix`
- :attr:`inv_scale_cost`
"""Subset rows or columns of a geometry.
Args:
src_mask: Row mask. Can be specified either as a boolean array of shape
``[num_a,]`` or as an array of indices. If ``None``, no mask is applied.
tgt_mask: Column mask. Can be specified either as a boolean array of shape
``[num_b,]`` or as an array of indices. If ``None``, no mask is applied.
mask_value: Value to use for masking.
row_ixs: Row indices. If :obj:`None`, use all rows.
col_ixs: Column indices. If :obj:`None`, use all columns.
Returns:
The masked geometry.
The subsetted geometry.
"""

def mask_fn(
arr: Optional[jnp.ndarray],
src_mask: Optional[jnp.ndarray],
tgt_mask: Optional[jnp.ndarray],
) -> Optional[jnp.ndarray]:
if arr is None:
return arr
assert arr.ndim == 2, arr.ndim
if src_mask is not None:
arr = jnp.where(src_mask[:, None], arr, mask_value)
if tgt_mask is not None:
arr = jnp.where(tgt_mask[None, :], arr, mask_value)
return arr # noqa: RET504

src_mask = self._normalize_mask(src_mask, self.shape[0])
tgt_mask = self._normalize_mask(tgt_mask, self.shape[1])
return self._mask_subset_helper(
src_mask, tgt_mask, fn=mask_fn, propagate_mask=False
)

def _mask_subset_helper(
self,
src_ixs: Optional[jnp.ndarray],
tgt_ixs: Optional[jnp.ndarray],
*,
fn: Callable[
[Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]],
Optional[jnp.ndarray]],
propagate_mask: bool,
**kwargs: Any,
) -> "Geometry":
(cost, kernel, eps, src_mask, tgt_mask), aux_data = self.tree_flatten()
cost = fn(cost, src_ixs, tgt_ixs)
kernel = fn(kernel, src_ixs, tgt_ixs)
if propagate_mask:
src_mask = self._normalize_mask(src_mask, self.shape[0])
tgt_mask = self._normalize_mask(tgt_mask, self.shape[1])
src_mask = fn(src_mask, src_ixs, None)
tgt_mask = fn(tgt_mask, tgt_ixs, None)

aux_data = {**aux_data, **kwargs}
return type(self).tree_unflatten(
aux_data, [cost, kernel, eps, src_mask, tgt_mask]
)

@property
def src_mask(self) -> Optional[jnp.ndarray]:
"""Mask of shape ``[num_a,]`` to compute :attr:`cost_matrix` statistics.
Specifically, it is used when computing:
- :attr:`mean_cost_matrix`
- :attr:`median_cost_matrix`
- :attr:`inv_scale_cost`
"""
return self._normalize_mask(self._src_mask, self.shape[0])

@property
def tgt_mask(self) -> Optional[jnp.ndarray]:
"""Mask of shape ``[num_b,]`` to compute :attr:`cost_matrix` statistics.
Specifically, it is used when computing:
- :attr:`mean_cost_matrix`
- :attr:`median_cost_matrix`
- :attr:`inv_scale_cost`
"""
return self._normalize_mask(self._tgt_mask, self.shape[1])
(cost, kernel, *rest), aux_data = self.tree_flatten()
row_ixs = row_ixs if row_ixs is None else jnp.atleast_1d(row_ixs)
col_ixs = col_ixs if col_ixs is None else jnp.atleast_1d(col_ixs)
if cost is not None:
cost = cost if row_ixs is None else cost[row_ixs]
cost = cost if col_ixs is None else cost[:, col_ixs]
if kernel is not None:
kernel = kernel if row_ixs is None else kernel[row_ixs]
kernel = kernel if col_ixs is None else kernel[:, col_ixs]
return type(self).tree_unflatten(aux_data, (cost, kernel, *rest))

@property
def dtype(self) -> jnp.dtype:
"""The data type."""
return (
self._kernel_matrix if self._cost_matrix is None else self._cost_matrix
).dtype

def _masked_geom(self, mask_value: float = 0.0) -> "Geometry":
"""Mask geometry based on :attr:`src_mask` and :attr:`tgt_mask`."""
src_mask, tgt_mask = self.src_mask, self.tgt_mask
if src_mask is None and tgt_mask is None:
return self
return self.mask(src_mask, tgt_mask, mask_value=mask_value)

@property
def _n_normed_ones(self) -> jnp.ndarray:
"""Normalized array of shape ``[num_a,]``."""
mask = self.src_mask
arr = jnp.ones(self.shape[0]) if mask is None else mask
return arr / jnp.sum(arr)

@property
def _m_normed_ones(self) -> jnp.ndarray:
"""Normalized array of shape ``[num_b,]``."""
mask = self.tgt_mask
arr = jnp.ones(self.shape[1]) if mask is None else mask
return arr / jnp.sum(arr)

@staticmethod
def _normalize_mask(mask: Optional[jnp.ndarray],
size: int) -> Optional[jnp.ndarray]:
"""Convert array of indices to a boolean mask."""
if mask is None:
return None
if not jnp.issubdtype(mask, (bool, jnp.bool_)):
mask = jnp.isin(jnp.arange(size), mask)
assert mask.shape == (size,)
return mask
if self._cost_matrix is not None:
return self._cost_matrix.dtype
return self._kernel_matrix.dtype

def tree_flatten(self): # noqa: D102
return (
self._cost_matrix, self._kernel_matrix, self._epsilon_init,
self._src_mask, self._tgt_mask
self._cost_matrix,
self._kernel_matrix,
self._epsilon_init,
), {
"scale_cost": self._scale_cost,
"relative_epsilon": self._relative_epsilon,
}

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
cost, kernel, eps, src_mask, tgt_mask = children
return cls(
cost, kernel, eps, src_mask=src_mask, tgt_mask=tgt_mask, **aux_data
)
cost, kernel, epsilon = children
return cls(cost, kernel_matrix=kernel, epsilon=epsilon, **aux_data)
17 changes: 0 additions & 17 deletions src/ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,21 +315,6 @@ def transport_from_scalings(
"cloud geometry instead."
)

def subset(
self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray]
) -> NoReturn:
"""Not implemented."""
raise NotImplementedError("Subsetting is not implemented for grids.")

def mask(
self,
src_mask: Optional[jnp.ndarray],
tgt_mask: Optional[jnp.ndarray],
mask_value: float = 0.0,
) -> NoReturn:
"""Not implemented."""
raise NotImplementedError("Masking is not implemented for grids.")

@property
def cost_matrix(self) -> jnp.ndarray:
"""Not implemented."""
Expand Down Expand Up @@ -425,6 +410,4 @@ def to_LRCGeometry(
epsilon=self._epsilon_init,
relative_epsilon=self._relative_epsilon,
scale_cost=self._scale_cost,
src_mask=self.src_mask,
tgt_mask=self.tgt_mask,
)
Loading

0 comments on commit 690b1ae

Please sign in to comment.