diff --git a/.flake8 b/.flake8 index d5e2e12b0..363a33365 100644 --- a/.flake8 +++ b/.flake8 @@ -41,15 +41,11 @@ ignore = E111, E114 # Missing blank line before section D411 - # TODO(michalk8): fix D10{1,2,3} - # D101 Missing docstring in public class - D101 - # Missing docstring in public method - D102 - # Missing docstring in public function - D103 exclude = .git,__pycache__,build,docs/_build,dist # C409: Unnecessary call - rewrite as a literal. per-file-ignores = tests/*: D,C408 */__init__.py: F401 + examples/*: D101, D102, D103 + docs/*: D101, D102 + src/ott/types.py: D102 diff --git a/.gitignore b/.gitignore index d8f1864bd..397ee11b6 100644 --- a/.gitignore +++ b/.gitignore @@ -159,6 +159,9 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ +# vscode +.vscode/ + # generated documentation docs/html **/_autosummary diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index c0aa7ff92..5fb31241b 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -154,7 +154,7 @@ def __init__(self, p: float): self.p = p self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf - def h(self, z: jnp.ndarray) -> float: + def h(self, z: jnp.ndarray) -> float: # noqa: D102 return 0.5 * jnp.linalg.norm(z, self.p) ** 2 def h_legendre(self, z: jnp.ndarray) -> float: @@ -164,11 +164,11 @@ def h_legendre(self, z: jnp.ndarray) -> float: """ return 0.5 * jnp.linalg.norm(z, self.q) ** 2 - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (), (self.p,) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 del children return cls(aux_data[0]) @@ -188,18 +188,18 @@ def __init__(self, p: float): self.p = p self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf - def h(self, z: jnp.ndarray) -> float: + def h(self, z: jnp.ndarray) -> float: # noqa: D102 return jnp.linalg.norm(z, self.p) ** self.p / self.p - def h_legendre(self, z: jnp.ndarray) -> float: + def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 assert self.q < jnp.inf, "Legendre transform not defined for `p=1.0`" return jnp.linalg.norm(z, self.q) ** self.q / self.q - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (), (self.p,) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 del children return cls(aux_data[0]) @@ -231,10 +231,10 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute minus twice the dot-product between vectors.""" return -2. * jnp.vdot(x, y) - def h(self, z: jnp.ndarray) -> float: + def h(self, z: jnp.ndarray) -> float: # noqa: D102 return jnp.sum(z ** 2) - def h_legendre(self, z: jnp.ndarray) -> float: + def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 return 0.25 * jnp.sum(z ** 2) def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: @@ -312,17 +312,17 @@ def __init__(self, gamma: float = 1.0): assert gamma >= 0, "Gamma must be non-negative." 
self.gamma = gamma - def reg(self, z: jnp.ndarray) -> float: + def reg(self, z: jnp.ndarray) -> float: # noqa: D102 return self.gamma * jnp.linalg.norm(z, ord=1) - def prox_reg(self, z: jnp.ndarray) -> float: + def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102 return jnp.sign(z) * jax.nn.relu(jnp.abs(z) - self.gamma) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (), (self.gamma,) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 del children return cls(*aux_data) @@ -348,19 +348,19 @@ def __init__(self, gamma: float = 1.0): assert gamma > 0, "Gamma must be positive." self.gamma = gamma - def reg(self, z: jnp.ndarray) -> float: + def reg(self, z: jnp.ndarray) -> float: # noqa: D102 u = jnp.arcsinh(jnp.abs(z) / (2 * self.gamma)) out = u - 0.5 * jnp.exp(-2.0 * u) return (self.gamma ** 2) * jnp.sum(out + 0.5) # make positive - def prox_reg(self, z: jnp.ndarray) -> float: + def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102 return jax.nn.relu(1 - (self.gamma / (jnp.abs(z) + 1e-12)) ** 2) * z - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (), (self.gamma,) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 del children return cls(*aux_data) @@ -388,7 +388,7 @@ def __init__(self, k: int, gamma: float = 1.0): self.k = k self.gamma = gamma - def reg(self, z: jnp.ndarray) -> float: + def reg(self, z: jnp.ndarray) -> float: # noqa: D102 # Prop 2.1 in :cite:`argyriou:12` k = self.k top_w = jax.lax.top_k(jnp.abs(z), k)[0] # Fetch largest k values @@ -409,7 +409,7 @@ def reg(self, z: jnp.ndarray) -> float: return 0.5 * self.gamma * (s + (r + 1) * cesaro[r] ** 2) - def prox_reg(self, z: jnp.ndarray) -> float: + def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102 @functools.partial(jax.vmap, in_axes=[0, None, None]) def find_indices(r: int, l: jnp.ndarray, @@ -454,11 +454,11 @@ def inner(r: int, l: int, # change sign and reorder return sgn * q[jnp.argsort(z_ixs.astype(float))] - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (), (self.k, self.gamma) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 del children return cls(*aux_data) @@ -606,11 +606,11 @@ def _padder(cls, dim: int) -> jnp.ndarray: ) return padding[jnp.newaxis, :] - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (), (self._dimension, self._sqrtm_kw) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 del children return cls(aux_data[0], **aux_data[1]) @@ -718,11 +718,11 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: (sig2 + gam) * jnp.exp(log_m_pi), lambda: jnp.nan ) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (), (self._dimension, self._sigma, self._gamma, self._sqrtm_kw) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 del children dim, sigma, gamma, kwargs = aux_data return cls(dim, sigma=sigma, gamma=gamma, **kwargs) diff --git a/src/ott/geometry/epsilon_scheduler.py b/src/ott/geometry/epsilon_scheduler.py index 0a3e2b686..3a3273d0c 100644 --- a/src/ott/geometry/epsilon_scheduler.py +++ b/src/ott/geometry/epsilon_scheduler.py @@ -71,18 +71,20 @@ def at(self, iteration: Optional[int] = 1) -> float: return multiple * self.target def 
done(self, eps: float) -> bool: + """Return whether the scheduler is done at a given value.""" return eps == self.target def done_at(self, iteration: Optional[int]) -> bool: + """Return whether the scheduler is done at a given iteration.""" return self.done(self.at(iteration)) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return ( self._target_init, self._scale_epsilon, self._init, self._decay ), None @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 del aux_data return cls(*children) diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index 1de4962e5..4fef6dcae 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -900,7 +900,7 @@ def _normalize_mask(mask: Optional[Union[int, jnp.ndarray]], assert mask.shape == (size,) return mask - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return ( self._cost_matrix, self._kernel_matrix, self._epsilon_init, self._relative_epsilon, self._scale_epsilon, self._src_mask, @@ -910,7 +910,7 @@ def tree_flatten(self): } @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 *args, kwargs = children return cls(*args, **kwargs, **aux_data) diff --git a/src/ott/geometry/graph.py b/src/ott/geometry/graph.py index d79a6c7e4..5ce929afc 100644 --- a/src/ott/geometry/graph.py +++ b/src/ott/geometry/graph.py @@ -97,6 +97,16 @@ def apply_kernel( eps: Optional[float] = None, axis: int = 0, ) -> jnp.ndarray: + r"""Apply :attr:`kernel_matrix` on positive scaling vector. + + Args: + scaling: Scaling to apply the kernel to. + eps: passed for consistency, not used yet. + axis: passed for consistency, not used yet. + + Returns: + Kernel applied to ``scaling``. 
+ """ def conf_fn( iteration: int, solver_lap: Tuple[decomposition.CholeskySolver, @@ -158,7 +168,7 @@ def body_fn( state=state, )[1] - def apply_transport_from_scalings( + def apply_transport_from_scalings( # noqa: D102 self, u: jnp.ndarray, v: jnp.ndarray, @@ -184,7 +194,7 @@ def body_fn(carry: None, vec: jnp.ndarray) -> jnp.ndarray: return res @property - def kernel_matrix(self) -> jnp.ndarray: + def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 n, _ = self.shape kernel = self.apply_kernel(jnp.eye(n)) # force symmetry because of numerical imprecisions @@ -192,7 +202,7 @@ def kernel_matrix(self) -> jnp.ndarray: return (kernel + kernel.T) * .5 @property - def cost_matrix(self) -> jnp.ndarray: + def cost_matrix(self) -> jnp.ndarray: # noqa: D102 return -self.t * mu.safe_log(self.kernel_matrix) @property @@ -287,7 +297,7 @@ def solver(self) -> decomposition.CholeskySolver: return self._solver @property - def shape(self) -> Tuple[int, int]: + def shape(self) -> Tuple[int, int]: # noqa: D102 arr = self._graph if self._graph is not None else self._lap return arr.shape @@ -308,12 +318,12 @@ def graph(self) -> Optional[Union[jnp.ndarray, jesp.BCOO]]: return (self._graph + self._graph.T) if self.directed else self._graph @property - def is_symmetric(self) -> bool: + def is_symmetric(self) -> bool: # noqa: D102 # there may be some numerical imprecisions, but it should be symmetric return True @property - def dtype(self) -> jnp.dtype: + def dtype(self) -> jnp.dtype: # noqa: D102 return self._graph.dtype # TODO(michalk8): in future, use mixins for lse/kernel mode @@ -343,7 +353,7 @@ def marginal_from_potentials( """Not implemented.""" raise ValueError("Not implemented.") - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [self._graph, self._lap, self.solver], { "t": self._t, "n_steps": self.n_steps, @@ -355,7 +365,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: } @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "Graph": graph, laplacian, solver = children diff --git a/src/ott/geometry/grid.py b/src/ott/geometry/grid.py index 553949fe4..394381d04 100644 --- a/src/ott/geometry/grid.py +++ b/src/ott/geometry/grid.py @@ -132,15 +132,15 @@ def median_cost_matrix(self) -> NoReturn: raise NotImplementedError('Median cost not implemented for grids.') @property - def can_LRC(self) -> bool: + def can_LRC(self) -> bool: # noqa: D102 return True @property - def shape(self) -> Tuple[int, int]: + def shape(self) -> Tuple[int, int]: # noqa: D102 return self.num_a, self.num_a @property - def is_symmetric(self) -> bool: + def is_symmetric(self) -> bool: # noqa: D102 return True # Reimplemented functions to be used in regularized OT @@ -341,14 +341,14 @@ def prepare_divergences( return tuple(sep_grid for _ in range(size)) @property - def dtype(self) -> jnp.dtype: + def dtype(self) -> jnp.dtype: # noqa: D102 return self.x[0].dtype - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (self.x, self.cost_fns, self._epsilon), self.kwargs @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls( x=children[0], cost_fns=children[1], epsilon=children[2], **aux_data ) diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index 453b1cdf8..f55e9a775 100644 --- a/src/ott/geometry/low_rank.py +++ 
b/src/ott/geometry/low_rank.py @@ -91,7 +91,7 @@ def bias(self) -> float: return self._bias * self.inv_scale_cost @property - def cost_rank(self) -> int: + def cost_rank(self) -> int: # noqa: D102 return self._cost_1.shape[1] @property @@ -100,18 +100,18 @@ def cost_matrix(self) -> jnp.ndarray: return jnp.matmul(self.cost_1, self.cost_2.T) + self.bias @property - def shape(self) -> Tuple[int, int]: + def shape(self) -> Tuple[int, int]: # noqa: D102 return self._cost_1.shape[0], self._cost_2.shape[0] @property - def is_symmetric(self) -> bool: + def is_symmetric(self) -> bool: # noqa: D102 return ( self._cost_1.shape[0] == self._cost_2.shape[0] and jnp.all(self._cost_1 == self._cost_2) ) @property - def inv_scale_cost(self) -> float: + def inv_scale_cost(self) -> float: # noqa: D102 if isinstance(self._scale_cost, (int, float)) or utils.is_jax_array(self._scale_cost): return 1.0 / self._scale_cost @@ -239,10 +239,10 @@ def to_LRCGeometry( return self @property - def can_LRC(self): + def can_LRC(self): # noqa: D102 return True - def subset( + def subset( # noqa: D102 self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], **kwargs: Any ) -> "LRCGeometry": @@ -257,7 +257,7 @@ def subset_fn( src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs ) - def mask( + def mask( # noqa: D102 self, src_mask: Optional[jnp.ndarray], tgt_mask: Optional[jnp.ndarray], @@ -311,10 +311,10 @@ def __add__(self, other: 'LRCGeometry') -> 'LRCGeometry': ) @property - def dtype(self) -> jnp.dtype: + def dtype(self) -> jnp.dtype: # noqa: D102 return self._cost_1.dtype - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return ( self._cost_1, self._cost_2, self._src_mask, self._tgt_mask, self._kwargs ), { @@ -325,7 +325,7 @@ def tree_flatten(self): } @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 c1, c2, src_mask, tgt_mask, kwargs = children return cls( c1, c2, src_mask=src_mask, tgt_mask=tgt_mask, **kwargs, **aux_data diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index b4e7e6d9d..82e4e9523 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -99,7 +99,7 @@ def _norm_y(self) -> Union[float, jnp.ndarray]: return 0. 
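Most of the `# noqa: D102` markers in this patch sit on `tree_flatten` / `tree_unflatten`, the method pair JAX requires from a class registered as a custom pytree (here via the `ott.utils.register_pytree_node` decorator), so per-method docstrings would add little. For readers unfamiliar with that contract, a minimal self-contained sketch follows; the `Scaled` class is hypothetical, not part of `ott`:

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Scaled:
  """Toy pytree: one traced array child plus static auxiliary data."""

  def __init__(self, x: jnp.ndarray, scale: float = 1.0):
    self.x = x
    self.scale = scale

  def tree_flatten(self):
    # children are traced through jit/grad/vmap; aux_data must stay static
    return (self.x,), (self.scale,)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children, *aux_data)


# registered instances can cross transformation boundaries:
out = jax.jit(lambda s: s.x * s.scale)(Scaled(jnp.arange(3.0), scale=2.0))

The `(children, aux_data)` split mirrors what the geometries in this patch return: traced arrays (costs, masks) as children, static configuration (scale choices, kwargs) as aux data.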
@property - def can_LRC(self): + def can_LRC(self): # noqa: D102 return self.is_squared_euclidean and self._check_LRC_dim @property @@ -108,20 +108,20 @@ def _check_LRC_dim(self): return n * m > (n + m) * d @property - def cost_matrix(self) -> Optional[jnp.ndarray]: + def cost_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102 if self.is_online: return None cost_matrix = self._compute_cost_matrix() return cost_matrix * self.inv_scale_cost @property - def kernel_matrix(self) -> Optional[jnp.ndarray]: + def kernel_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102 if self.is_online: return None return jnp.exp(-self.cost_matrix / self.epsilon) @property - def shape(self) -> Tuple[int, int]: + def shape(self) -> Tuple[int, int]: # noqa: D102 # in the process of flattening/unflattening in vmap, `__init__` # can be called with dummy objects # we optionally access `shape` in order to get the batch size @@ -130,13 +130,13 @@ def shape(self) -> Tuple[int, int]: return self.x.shape[0], self.y.shape[0] @property - def is_symmetric(self) -> bool: + def is_symmetric(self) -> bool: # noqa: D102 return self.y is None or ( jnp.all(self.x.shape == self.y.shape) and jnp.all(self.x == self.y) ) @property - def is_squared_euclidean(self) -> bool: + def is_squared_euclidean(self) -> bool: # noqa: D102 return isinstance(self.cost_fn, costs.SqEuclidean) @property @@ -147,11 +147,11 @@ def is_online(self) -> bool: # TODO(michalk8): when refactoring, consider PC as a subclass of LR? @property - def cost_rank(self) -> int: + def cost_rank(self) -> int: # noqa: D102 return self.x.shape[1] @property - def inv_scale_cost(self) -> float: + def inv_scale_cost(self) -> float: # noqa: D102 if isinstance(self._scale_cost, (int, float)) or utils.is_jax_array(self._scale_cost): return 1.0 / self._scale_cost @@ -201,7 +201,7 @@ def _compute_cost_matrix(self) -> jnp.ndarray: cost_matrix += self._norm_x[:, jnp.newaxis] + self._norm_y[jnp.newaxis, :] return cost_matrix - def apply_lse_kernel( + def apply_lse_kernel( # noqa: D102 self, f: jnp.ndarray, g: jnp.ndarray, @@ -288,7 +288,7 @@ def finalize(i: int): return eps * h_res - jnp.where(jnp.isfinite(v), v, 0), h_sign - def apply_kernel( + def apply_kernel( # noqa: D102 self, scaling: jnp.ndarray, eps: Optional[float] = None, @@ -315,7 +315,7 @@ def apply_kernel( self.cost_fn, self.inv_scale_cost ) - def transport_from_potentials( + def transport_from_potentials( # noqa: D102 self, f: jnp.ndarray, g: jnp.ndarray ) -> jnp.ndarray: if not self.is_online: @@ -329,7 +329,7 @@ def transport_from_potentials( self.cost_fn, self.inv_scale_cost ) - def transport_from_scalings( + def transport_from_scalings( # noqa: D102 self, u: jnp.ndarray, v: jnp.ndarray ) -> jnp.ndarray: if not self.is_online: @@ -569,7 +569,7 @@ def prepare_divergences( for ((x, y), (x_mask, y_mask)) in zip(couples, masks) ) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return ([self.x, self.y, self._src_mask, self._tgt_mask, self.cost_fn], { 'epsilon': self._epsilon_init, 'relative_epsilon': self._relative_epsilon, @@ -579,7 +579,7 @@ def tree_flatten(self): }) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 x, y, src_mask, tgt_mask, cost_fn = children return cls( x, y, cost_fn=cost_fn, src_mask=src_mask, tgt_mask=tgt_mask, **aux_data @@ -649,7 +649,7 @@ def _sqeucl_to_lr(self, scale: float = 1.0) -> low_rank.LRCGeometry: **self._kwargs ) - def subset( + def subset( # noqa: D102 self, src_ixs: Optional[jnp.ndarray], tgt_ixs: 
Optional[jnp.ndarray], **kwargs: Any ) -> "PointCloud": @@ -664,7 +664,7 @@ def subset_fn( src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs ) - def mask( + def mask( # noqa: D102 self, src_mask: Optional[jnp.ndarray], tgt_mask: Optional[jnp.ndarray], @@ -710,7 +710,7 @@ def _mask_subset_helper( ) @property - def dtype(self) -> jnp.dtype: + def dtype(self) -> jnp.dtype: # noqa: D102 return self.x.dtype @property diff --git a/src/ott/initializers/linear/initializers.py b/src/ott/initializers/linear/initializers.py index 2a5479713..186956bfc 100644 --- a/src/ott/initializers/linear/initializers.py +++ b/src/ott/initializers/linear/initializers.py @@ -198,7 +198,7 @@ def __init__( self.max_iter = max_iter self.vectorized_update = vectorized_update - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return ([], { 'tolerance': self.tolerance, 'max_iter': self.max_iter, diff --git a/src/ott/initializers/linear/initializers_lr.py b/src/ott/initializers/linear/initializers_lr.py index f0ca847bd..378a0ce55 100644 --- a/src/ott/initializers/linear/initializers_lr.py +++ b/src/ott/initializers/linear/initializers_lr.py @@ -237,7 +237,7 @@ class RandomInitializer(LRInitializer): kwargs: Additional keyword arguments. """ - def init_q( + def init_q( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -250,7 +250,7 @@ def init_q( init_q = jnp.abs(jax.random.normal(key, (a.shape[0], self.rank))) return a[:, None] * (init_q / jnp.sum(init_q, axis=1, keepdims=True)) - def init_r( + def init_r( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -263,7 +263,7 @@ def init_r( init_r = jnp.abs(jax.random.normal(key, (b.shape[0], self.rank))) return b[:, None] * (init_r / jnp.sum(init_r, axis=1, keepdims=True)) - def init_g( + def init_g( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -311,7 +311,7 @@ def _compute_factor( return ((lambda_1 * x[:, None] @ g1.reshape(1, -1)) + ((1 - lambda_1) * y[:, None] @ g2.reshape(1, -1))) - def init_q( + def init_q( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -322,7 +322,7 @@ def init_q( del key, kwargs return self._compute_factor(ot_prob, init_g, which="q") - def init_r( + def init_r( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -333,7 +333,7 @@ def init_r( del key, kwargs return self._compute_factor(ot_prob, init_g, which="r") - def init_g( + def init_g( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -422,7 +422,7 @@ def _compute_factor( solver = sinkhorn.Sinkhorn(**self._sinkhorn_kwargs) return solver(prob).matrix - def init_q( + def init_q( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -434,7 +434,7 @@ def init_q( ot_prob, key, init_g=init_g, which="q", **kwargs ) - def init_r( + def init_r( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -446,7 +446,7 @@ def init_r( ot_prob, key, init_g=init_g, which="r", **kwargs ) - def init_g( + def init_g( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -455,7 +455,7 @@ def init_g( del key, kwargs return jnp.ones((self.rank,)) / self.rank - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux_data = super().tree_flatten() aux_data["sinkhorn_kwargs"] = self._sinkhorn_kwargs aux_data["min_iterations"] = self._min_iter diff --git a/src/ott/initializers/nn/initializers.py b/src/ott/initializers/nn/initializers.py index 
95d7d3928..e8d32fa67 100644 --- a/src/ott/initializers/nn/initializers.py +++ b/src/ott/initializers/nn/initializers.py @@ -135,7 +135,7 @@ def update( """ return self.update_impl(state, a, b) - def init_dual_a( + def init_dual_a( # noqa: D102 self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool ) -> jnp.ndarray: # Detect if the problem is batched. @@ -199,7 +199,7 @@ def _compute_f(self, a, b, params): """ return self.meta_model.apply({'params': params}, a, b) - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [self.geom, self.meta_model, self.opt], { 'rng': self.rng, 'state': self.state diff --git a/src/ott/initializers/quadratic/initializers.py b/src/ott/initializers/quadratic/initializers.py index c14f1121b..1d4614d9c 100644 --- a/src/ott/initializers/quadratic/initializers.py +++ b/src/ott/initializers/quadratic/initializers.py @@ -202,6 +202,6 @@ def rank(self) -> int: """Rank of the transport matrix factorization.""" return self._linear_lr_initializer.rank - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux_data = super().tree_flatten() return children + [self._linear_lr_initializer], aux_data diff --git a/src/ott/math/decomposition.py b/src/ott/math/decomposition.py index b9e82595f..1099ffe11 100644 --- a/src/ott/math/decomposition.py +++ b/src/ott/math/decomposition.py @@ -140,7 +140,7 @@ def _decompose(self, A: T) -> Optional[T]: def _solve(self, L: Optional[T], b: jnp.ndarray) -> jnp.ndarray: return jsp.linalg.solve_triangular(L, b, lower=self._lower) - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux_data = super().tree_flatten() aux_data["lower"] = self._lower return children, aux_data @@ -218,7 +218,7 @@ def clear_factor_cache(cls) -> None: def __hash__(self) -> int: return object.__hash__(self) if self._key is None else self._key - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux_data = super().tree_flatten() return children, { **aux_data, "beta": self._beta, diff --git a/src/ott/math/matrix_square_root.py b/src/ott/math/matrix_square_root.py index 0c49460c9..8044d1c16 100644 --- a/src/ott/math/matrix_square_root.py +++ b/src/ott/math/matrix_square_root.py @@ -236,7 +236,7 @@ def sqrtm_bwd( @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) -def sqrtm_only( +def sqrtm_only( # noqa: D103 x: jnp.ndarray, threshold: float = 1e-6, min_iterations: int = 0, @@ -282,7 +282,7 @@ def sqrtm_only_bwd( @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) -def inv_sqrtm_only( +def inv_sqrtm_only( # noqa: D103 x: jnp.ndarray, threshold: float = 1e-6, min_iterations: int = 0, diff --git a/src/ott/math/unbalanced_functions.py b/src/ott/math/unbalanced_functions.py index 92694ff90..776c9b18e 100644 --- a/src/ott/math/unbalanced_functions.py +++ b/src/ott/math/unbalanced_functions.py @@ -87,5 +87,5 @@ def diag_jacobian_of_marginal_fit( ) -def rho(epsilon: float, tau: float) -> float: +def rho(epsilon: float, tau: float) -> float: # noqa: D103 return (epsilon * tau) / (1. 
- tau) diff --git a/src/ott/math/utils.py b/src/ott/math/utils.py index 27e900a09..4dc59a0df 100644 --- a/src/ott/math/utils.py +++ b/src/ott/math/utils.py @@ -30,7 +30,11 @@ Sparse_t = Union[jesp.CSR, jesp.CSC, jesp.COO, jesp.BCOO] -def safe_log(x: jnp.ndarray, *, eps: Optional[float] = None) -> jnp.ndarray: +def safe_log( # noqa: D103 + x: jnp.ndarray, + *, + eps: Optional[float] = None +) -> jnp.ndarray: if eps is None: eps = jnp.finfo(x.dtype).tiny return jnp.where(x > 0., jnp.log(x), jnp.log(eps)) @@ -56,7 +60,9 @@ def sparse_scale(c: float, mat: Sparse_t) -> Sparse_t: @functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 4)) -def logsumexp(mat, axis=None, keepdims=False, b=None, return_sign=False): +def logsumexp( # noqa: D103 + mat, axis=None, keepdims=False, b=None, return_sign=False +): return jax.scipy.special.logsumexp( mat, axis=axis, keepdims=keepdims, b=b, return_sign=return_sign ) @@ -111,4 +117,14 @@ def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents): def barycentric_projection( matrix: jnp.ndarray, y: jnp.ndarray, cost_fn: "costs.CostFn" ) -> jnp.ndarray: + """Compute the barycentric projection of a matrix. + + Args: + matrix: a coupling matrix of shape (n, m). + y: an array of points of shape (m, d). + cost_fn: a CostFn instance. + + Returns: + an array of shape (n, d) with the barycentric projection of each row of matrix. + """ return jax.vmap(cost_fn.barycenter, in_axes=[0, None])(matrix, y) diff --git a/src/ott/problems/linear/barycenter_problem.py b/src/ott/problems/linear/barycenter_problem.py index c8f76c016..4eac8a922 100644 --- a/src/ott/problems/linear/barycenter_problem.py +++ b/src/ott/problems/linear/barycenter_problem.py @@ -165,7 +165,7 @@ def weights(self) -> jnp.ndarray: def _is_segmented(self) -> bool: return self._y.ndim == 3 - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return ([self._y, self._b, self._weights], { 'cost_fn': self.cost_fn, 'epsilon': self.epsilon, @@ -174,7 +174,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: }) @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "BarycenterProblem": y, b, weights = children diff --git a/src/ott/problems/linear/linear_problem.py b/src/ott/problems/linear/linear_problem.py index 4b1bb9709..60baf731d 100644 --- a/src/ott/problems/linear/linear_problem.py +++ b/src/ott/problems/linear/linear_problem.py @@ -107,14 +107,14 @@ def get_transport_functions( ) return marginal_a, marginal_b, app_transport - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return ([self.geom, self._a, self._b], { 'tau_a': self.tau_a, 'tau_b': self.tau_b }) @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "LinearProblem": return cls(*children, **aux_data) diff --git a/src/ott/problems/linear/potentials.py b/src/ott/problems/linear/potentials.py index 4eaf82769..45e4d3339 100644 --- a/src/ott/problems/linear/potentials.py +++ b/src/ott/problems/linear/potentials.py @@ -160,7 +160,7 @@ def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]: ) return jax.vmap(jax.grad(self.cost_fn.h_legendre)) - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [], { "f": self._f, "g":
self._g, @@ -169,7 +169,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: } @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "DualPotentials": return cls(*children, **aux_data) @@ -347,11 +347,11 @@ def __init__( self._g_yy = g_yy @property - def f(self) -> Potential_t: + def f(self) -> Potential_t: # noqa: D102 return self._potential_fn(kind="f") @property - def g(self) -> Potential_t: + def g(self) -> Potential_t: # noqa: D102 return self._potential_fn(kind="g") def _potential_fn(self, *, kind: Literal["f", "g"]) -> Potential_t: @@ -414,5 +414,5 @@ def epsilon(self) -> float: """Entropy regularizer.""" return self._prob.geom.epsilon - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [self._f, self._g, self._prob, self._f_xx, self._g_yy], {} diff --git a/src/ott/problems/nn/dataset.py b/src/ott/problems/nn/dataset.py index 7c17297f7..a39ea3782 100644 --- a/src/ott/problems/nn/dataset.py +++ b/src/ott/problems/nn/dataset.py @@ -90,7 +90,11 @@ def __iter__(self) -> Iterator[jnp.array]: return self.create_sample_generators() def create_sample_generators(self) -> Iterator[jnp.array]: - # create generator which randomly picks center and adds noise + """Random sample generator from Gaussian mixture. + + Returns: + A generator of samples from the Gaussian mixture. + """ key = self.init_key while True: k1, k2, key = jax.random.split(key, 3) diff --git a/src/ott/problems/quadratic/gw_barycenter.py b/src/ott/problems/quadratic/gw_barycenter.py index 1eac1d330..ed3ad65c1 100644 --- a/src/ott/problems/quadratic/gw_barycenter.py +++ b/src/ott/problems/quadratic/gw_barycenter.py @@ -269,7 +269,7 @@ def segmented_y_fused(self) -> Optional[jnp.ndarray]: return y_fused @property - def ndim(self) -> Optional[int]: + def ndim(self) -> Optional[int]: # noqa: D102 return None if self._y_as_costs else self._y.shape[-1] @property @@ -292,7 +292,7 @@ def gw_loss(self) -> quadratic_costs.GWLoss: f"Loss `{self._loss_name}` is not yet implemented." 
) - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 (y, b, weights), aux = super().tree_flatten() if self._y_as_costs: children = [None, b, weights, y] @@ -304,7 +304,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: return children + [self._y_fused], aux @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "GWBarycenterProblem": y, b, weights, costs, y_fused = children diff --git a/src/ott/problems/quadratic/quadratic_costs.py b/src/ott/problems/quadratic/quadratic_costs.py index 8de1b398d..2460dfbf7 100644 --- a/src/ott/problems/quadratic/quadratic_costs.py +++ b/src/ott/problems/quadratic/quadratic_costs.py @@ -18,7 +18,7 @@ class GWLoss(NamedTuple): h2: Loss -def make_square_loss() -> GWLoss: +def make_square_loss() -> GWLoss: # noqa: D103 f1 = Loss(lambda x: x ** 2, is_linear=False) f2 = Loss(lambda y: y ** 2, is_linear=False) h1 = Loss(lambda x: x, is_linear=True) @@ -26,7 +26,7 @@ def make_square_loss() -> GWLoss: return GWLoss(f1, f2, h1, h2) -def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss: +def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss: # noqa: D103 f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False) f2 = Loss(lambda y: y, is_linear=True) h1 = Loss(lambda x: x, is_linear=True) diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py index 5369ce9fd..aa7192ca5 100644 --- a/src/ott/problems/quadratic/quadratic_problem.py +++ b/src/ott/problems/quadratic/quadratic_problem.py @@ -459,7 +459,7 @@ def is_balanced(self) -> bool: return ((not self.gw_unbalanced_correction) or (self.tau_a == 1.0 and self.tau_b == 1.0)) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return ([self.geom_xx, self.geom_yy, self.geom_xy, self._a, self._b], { 'tau_a': self.tau_a, 'tau_b': self.tau_b, @@ -472,7 +472,7 @@ def tree_flatten(self): }) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 geoms, (a, b) = children[:3], children[3:] return cls(*geoms, a=a, b=b, **aux_data) diff --git a/src/ott/solvers/linear/acceleration.py b/src/ott/solvers/linear/acceleration.py index 594c130ee..93f137874 100644 --- a/src/ott/solvers/linear/acceleration.py +++ b/src/ott/solvers/linear/acceleration.py @@ -116,6 +116,7 @@ def init_maps( def update_history( self, state: 'sinkhorn.SinkhornState', pb, lse_mode: bool ) -> 'sinkhorn.SinkhornState': + """Update history of mapped dual variables.""" f = state.fu if lse_mode else pb.geom.potential_from_scaling(state.fu) mapped = jnp.concatenate((state.old_mapped_fus[:, 1:], f[:, None]), axis=1) return state.set(old_mapped_fus=mapped) @@ -154,7 +155,7 @@ def lehmann(self, state: 'sinkhorn.SinkhornState') -> float: power = 1.0 / self.inner_iterations return 2.0 / (1.0 + jnp.sqrt(1.0 - error_ratio ** power)) - def __call__( + def __call__( # noqa: D102 self, weight: float, value: jnp.ndarray, diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index 4fa79b8d0..b899457e4 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -55,6 +55,17 @@ def update( self, iteration: int, bar_prob: barycenter_problem.BarycenterProblem, linear_ot_solver: Any, store_errors: bool ) -> 'BarycenterState': + """Update the 
state of the solver. + + Args: + iteration: the current iteration of the outer loop. + bar_prob: the barycenter problem. + linear_ot_solver: the linear OT solver to use. + store_errors: whether to store the errors of the inner loop. + + Returns: + The updated state. + """ seg_y, seg_b = bar_prob.segmented_y_b @functools.partial(jax.vmap, in_axes=[None, None, 0, 0]) @@ -120,7 +131,7 @@ def solve_linear_ot( class WassersteinBarycenter(was_solver.WassersteinSolver): """A Continuous Wasserstein barycenter solver, built on generic template.""" - def __call__( + def __call__( # noqa: D102 self, bar_prob: barycenter_problem.BarycenterProblem, bar_size: int = 100, @@ -181,7 +192,9 @@ def init_state( -jnp.ones((num_iter,)), -jnp.ones((num_iter,)), errors, x, a ) - def output_from_state(self, state: BarycenterState) -> BarycenterState: + def output_from_state( # noqa: D102 + self, state: BarycenterState + ) -> BarycenterState: # TODO(michalk8): create an output variable to match rest of the framework return state diff --git a/src/ott/solvers/linear/discrete_barycenter.py b/src/ott/solvers/linear/discrete_barycenter.py index 3304fd42d..3014a364a 100644 --- a/src/ott/solvers/linear/discrete_barycenter.py +++ b/src/ott/solvers/linear/discrete_barycenter.py @@ -26,7 +26,7 @@ __all__ = ["SinkhornBarycenterOutput", "discrete_barycenter"] -class SinkhornBarycenterOutput(NamedTuple): +class SinkhornBarycenterOutput(NamedTuple): # noqa: D101 f: jnp.ndarray g: jnp.ndarray histogram: jnp.ndarray diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 19417f3f9..fa2184f41 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -312,7 +312,7 @@ def set(self, **kwargs: Any) -> 'SinkhornOutput': """Return a copy of self, with potential overwrites.""" return self._replace(**kwargs) - def set_cost( + def set_cost( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, use_danskin: bool ) -> 'SinkhornOutput': @@ -358,27 +358,27 @@ def transport_cost_at_geom( return jnp.sum(self.matrix * other_geom.cost_matrix) @property - def linear(self) -> bool: + def linear(self) -> bool: # noqa: D102 return isinstance(self.ot_prob, linear_problem.LinearProblem) @property - def geom(self) -> geometry.Geometry: + def geom(self) -> geometry.Geometry: # noqa: D102 return self.ot_prob.geom @property - def a(self) -> jnp.ndarray: + def a(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.a @property - def b(self) -> jnp.ndarray: + def b(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.b @property - def linear_output(self) -> bool: + def linear_output(self) -> bool: # noqa: D102 return True @property - def converged(self) -> bool: + def converged(self) -> bool: # noqa: D102 if self.errors is None: return False return jnp.logical_and( @@ -387,13 +387,13 @@ def converged(self) -> bool: # TODO(michalk8): this should be always present @property - def n_iters(self) -> int: + def n_iters(self) -> int: # noqa: D102 if self.errors is None: return -1 return jnp.sum(self.errors > -1) @property - def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]: + def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]: # noqa: D102 u = self.ot_prob.geom.scaling_from_potential(self.f) v = self.ot_prob.geom.scaling_from_potential(self.g) return u, v @@ -417,7 +417,7 @@ def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray: self.f, self.g, inputs, axis=axis ) - def marginal(self, axis: int) -> jnp.ndarray: + def marginal(self, axis: int) -> 
jnp.ndarray: # noqa: D102 return self.ot_prob.geom.marginal_from_potentials(self.f, self.g, axis=axis) def cost_at_geom(self, other_geom: geometry.Geometry) -> float: @@ -1006,7 +1006,7 @@ def norm_error(self) -> Tuple[int, ...]: return self._norm_error, # TODO(michalk8): in the future, enforce this (+ in GW) via abstract method - def create_initializer(self) -> init_lib.SinkhornInitializer: + def create_initializer(self) -> init_lib.SinkhornInitializer: # noqa: D102 if isinstance(self.initializer, init_lib.SinkhornInitializer): return self.initializer if self.initializer == "default": @@ -1019,14 +1019,14 @@ def create_initializer(self) -> init_lib.SinkhornInitializer: f"Initializer `{self.initializer}` is not yet implemented." ) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 aux = vars(self).copy() aux['norm_error'] = aux.pop('_norm_error') aux.pop('threshold') return [self.threshold], aux @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(**aux_data, threshold=children[0]) diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index 566f6f61a..c0b824d20 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -71,6 +71,19 @@ def compute_reg_ot_cost( ot_prob: linear_problem.LinearProblem, use_danskin: bool = False ) -> float: + """Compute the regularized OT cost. + + Args: + q: first factor of solution. + r: second factor of solution. + g: weights of solution. + ot_prob: linear problem. + use_danskin: if True, use Danskin's theorem :cite:`danskin:67,bertsekas:71` + to avoid computing the gradient of the cost function. + + Returns: + The regularized OT cost. + """ q = jax.lax.stop_gradient(q) if use_danskin else q r = jax.lax.stop_gradient(r) if use_danskin else r g = jax.lax.stop_gradient(g) if use_danskin else g @@ -86,9 +99,9 @@ def solution_error( Since only balanced case is available for LR, this is marginal deviation. Args: - q: first factor of solution - r: second factor of solution - ot_prob: linear problem + q: first factor of solution. + r: second factor of solution. + ot_prob: linear problem. norm_error: int, p-norm used to compute error. lse_mode: True if log-sum-exp operations, False if kernel vector products.
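For context on the factors named in these docstrings: a low-rank solution stores the coupling implicitly as `q` of shape (n, rank), `r` of shape (m, rank) and weights `g` of shape (rank,), the full matrix being `q @ diag(1/g) @ r.T` (cf. the `_inv_g` contraction in `apply` below). A standalone sketch with dummy factors, not the library's code, of why `apply`/`marginal` never materialize the coupling:

import jax.numpy as jnp

n, m, rank = 5, 4, 2
q = jnp.ones((n, rank)) / rank  # dummy stand-ins for solver factors
r = jnp.ones((m, rank)) / rank
g = jnp.full((rank,), 0.5)

P = (q / g) @ r.T  # explicit coupling: O(n * m) memory

x = jnp.ones((m,))
naive = P @ x  # materializes P first
factored = (q / g) @ (r.T @ x)  # factor by factor: O((n + m) * rank)
assert jnp.allclose(naive, factored)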
@@ -130,7 +143,7 @@ def set(self, **kwargs: Any) -> 'LRSinkhornOutput': """Return a copy of self, with potential overwrites.""" return self._replace(**kwargs) - def set_cost( + def set_cost( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, @@ -139,7 +152,7 @@ def set_cost( del lse_mode return self.set(reg_ot_cost=self.compute_reg_ot_cost(ot_prob, use_danskin)) - def compute_reg_ot_cost( + def compute_reg_ot_cost( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, use_danskin: bool = False, @@ -147,27 +160,27 @@ def compute_reg_ot_cost( return compute_reg_ot_cost(self.q, self.r, self.g, ot_prob, use_danskin) @property - def linear(self) -> bool: + def linear(self) -> bool: # noqa: D102 return isinstance(self.ot_prob, linear_problem.LinearProblem) @property - def geom(self) -> geometry.Geometry: + def geom(self) -> geometry.Geometry: # noqa: D102 return self.ot_prob.geom @property - def a(self) -> jnp.ndarray: + def a(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.a @property - def b(self) -> jnp.ndarray: + def b(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.b @property - def linear_output(self) -> bool: + def linear_output(self) -> bool: # noqa: D102 return True @property - def converged(self) -> bool: + def converged(self) -> bool: # noqa: D102 return jnp.logical_and( jnp.any(self.costs == -1), jnp.all(jnp.isfinite(self.costs)) ) @@ -183,7 +196,7 @@ def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray: # for `axis=0`: (batch, m), (m, r), (r,), (r, n) return ((inputs @ r) * self._inv_g) @ q.T - def marginal(self, axis: int) -> jnp.ndarray: + def marginal(self, axis: int) -> jnp.ndarray: # noqa: D102 length = self.q.shape[0] if axis == 0 else self.r.shape[0] return self.apply(jnp.ones(length,), axis=axis) @@ -369,6 +382,7 @@ def dykstra_update( inner_iter: int = 10, max_iter: int = 10000 ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Run Dykstra's algorithm.""" # shortcuts for problem's definition. 
r = self.rank n, m = ot_prob.geom.shape @@ -539,7 +553,7 @@ def one_iteration( ) @property - def norm_error(self) -> Tuple[int]: + def norm_error(self) -> Tuple[int]: # noqa: D102 return self._norm_error, @property diff --git a/src/ott/solvers/nn/conjugate_solvers.py b/src/ott/solvers/nn/conjugate_solvers.py index e8dfde5f0..adc0db2a6 100644 --- a/src/ott/solvers/nn/conjugate_solvers.py +++ b/src/ott/solvers/nn/conjugate_solvers.py @@ -90,7 +90,7 @@ class FenchelConjugateLBFGS(FenchelConjugateSolver): decrease_factor: float = 0.66 ls_method: Literal['wolf', 'strong-wolfe'] = 'strong-wolfe' - def solve( + def solve( # noqa: D102 self, f: Callable[[jnp.ndarray], jnp.ndarray], y: jnp.ndarray, diff --git a/src/ott/solvers/nn/models.py b/src/ott/solvers/nn/models.py index 35c45325d..92f7d3d48 100644 --- a/src/ott/solvers/nn/models.py +++ b/src/ott/solvers/nn/models.py @@ -160,10 +160,10 @@ class ICNN(ModelBase): gaussian_map: Tuple[jnp.ndarray, jnp.ndarray] = None @property - def is_potential(self) -> bool: + def is_potential(self) -> bool: # noqa: D102 return True - def setup(self) -> None: + def setup(self) -> None: # noqa: D102 self.num_hidden = len(self.dim_hidden) if self.pos_weights: @@ -279,7 +279,7 @@ def _compute_identity_map(input_dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]: return A, b @nn.compact - def __call__(self, x: jnp.ndarray) -> float: + def __call__(self, x: jnp.ndarray) -> float: # noqa: D102 z = self.act_fn(self.w_xs[0](x)) for i in range(self.num_hidden): z = jnp.add(self.w_zs[i](z), self.w_xs[i + 1](x)) @@ -322,7 +322,7 @@ class MLP(ModelBase): act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # noqa: D102 squeeze = x.ndim == 1 if squeeze: x = jnp.expand_dims(x, 0) diff --git a/src/ott/solvers/nn/neuraldual.py b/src/ott/solvers/nn/neuraldual.py index 605fd18c3..5f96b2895 100644 --- a/src/ott/solvers/nn/neuraldual.py +++ b/src/ott/solvers/nn/neuraldual.py @@ -222,7 +222,7 @@ def setup( self.valid_step_g = self.get_step_fn(train=False, to_optimize="g") self.train_fn = self.train_neuraldual_alternating - def __call__( + def __call__( # noqa: D102 self, trainloader_source: Iterable[jnp.ndarray], trainloader_target: Iterable[jnp.ndarray], diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 52da5a255..c6475a1b7 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -338,7 +338,7 @@ def warm_start(self) -> bool: """Whether to initialize (low-rank) Sinkhorn using previous solutions.""" return self.is_low_rank if self._warm_start is None else self._warm_start - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux_data = super().tree_flatten() aux_data["warm_start"] = self._warm_start aux_data["quad_initializer"] = self.quad_initializer diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index 0b8fc1866..3900f6c93 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -204,6 +204,7 @@ def update_state( problem: gw_barycenter.GWBarycenterProblem, store_errors: bool = True, ) -> Tuple[float, bool, jnp.ndarray, Optional[jnp.ndarray]]: + """Solve the (fused) Gromov-Wasserstein barycenter problem.""" def solve_gw( state: 
GWBarycenterState, b: jnp.ndarray, y: jnp.ndarray, @@ -252,12 +253,12 @@ def output_from_state(self, state: GWBarycenterState) -> GWBarycenterState: # will be refactored in the future to create an output return state - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux = super().tree_flatten() return children + [self._quad_solver], aux @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "GromovWassersteinBarycenter": epsilon, _, threshold, quad_solver = children diff --git a/src/ott/solvers/was_solver.py b/src/ott/solvers/was_solver.py index 0771bb198..9226a667d 100644 --- a/src/ott/solvers/was_solver.py +++ b/src/ott/solvers/was_solver.py @@ -83,7 +83,7 @@ def is_low_rank(self) -> bool: """Whether the solver is low-rank.""" return self.rank > 0 - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return ([self.epsilon, self.linear_ot_solver, self.threshold], { "min_iterations": self.min_iterations, "max_iterations": self.max_iterations, @@ -94,7 +94,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: }) @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "WassersteinSolver": epsilon, linear_ot_solver, threshold = children diff --git a/src/ott/tools/gaussian_mixture/gaussian.py b/src/ott/tools/gaussian_mixture/gaussian.py index 07fd2ab75..5b40bfb59 100644 --- a/src/ott/tools/gaussian_mixture/gaussian.py +++ b/src/ott/tools/gaussian_mixture/gaussian.py @@ -101,23 +101,29 @@ def from_mean_and_cov(cls, mean: jnp.ndarray, cov: jnp.ndarray) -> 'Gaussian': @property def loc(self) -> jnp.ndarray: + """Mean of the Gaussian.""" return self._loc @property def scale(self) -> scale_tril.ScaleTriL: + """Scale of the Gaussian.""" return self._scale @property def n_dimensions(self) -> int: + """Dimensionality of the Gaussian.""" return self.loc.shape[-1] def covariance(self) -> jnp.ndarray: + """Covariance of the Gaussian.""" return self.scale.covariance() def to_z(self, x: jnp.ndarray) -> jnp.ndarray: + """Transform x to z = (x - loc) / scale.""" return self.scale.centered_to_z(x_centered=x - self.loc) def from_z(self, z: jnp.ndarray) -> jnp.ndarray: + """Transform z to x = loc + scale * z.""" return self.scale.z_to_centered(z=z) + self.loc def log_prob( @@ -197,13 +203,13 @@ def transport(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: dest_scale=dest.scale, points=points - self.loc[None] ) + dest.loc[None] - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 children = (self.loc, self.scale) aux_data = {} return children, aux_data @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*children, **aux_data) def __hash__(self): diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture.py b/src/ott/tools/gaussian_mixture/gaussian_mixture.py index e038ed361..09b4d9a35 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture.py @@ -141,26 +141,32 @@ def from_points_and_assignment_probs( @property def dtype(self): + """Dtype of the GMM parameters.""" return self.loc.dtype @property def n_dimensions(self): + """Number of dimensions of the GMM parameters.""" return 
self._loc.shape[-1] @property def n_components(self): + """Number of components of the GMM parameters.""" return self._loc.shape[-2] @property def loc(self) -> jnp.ndarray: + """Location parameters of the GMM.""" return self._loc @property def scale_params(self) -> jnp.ndarray: + """Scale parameters of the GMM.""" return self._scale_params @property def cholesky(self) -> jnp.ndarray: + """Cholesky decomposition of the GMM covariance matrices.""" size = self.n_dimensions def _get_cholesky(scale_params): @@ -170,6 +176,7 @@ def _get_cholesky(scale_params): @property def covariance(self) -> jnp.ndarray: + """Covariance matrices of the GMM.""" size = self.n_dimensions def _get_covariance(scale_params): @@ -179,13 +186,16 @@ def _get_covariance(scale_params): @property def component_weight_ob(self) -> probabilities.Probabilities: + """Component weight object.""" return self._component_weight_ob @property def component_weights(self) -> jnp.ndarray: + """Component weights probabilities.""" return self._component_weight_ob.probs() def log_component_weights(self) -> jnp.ndarray: + """Log component weights probabilities.""" return self._component_weight_ob.log_probs() def _get_normal( @@ -197,13 +207,13 @@ def _get_normal( ) def get_component(self, index: int) -> gaussian.Gaussian: - """Get the specified GMM component.""" + """Specified GMM component.""" return self._get_normal( loc=self.loc[index], scale_params=self.scale_params[index] ) def components(self) -> List[gaussian.Gaussian]: - """Get a list of all GMM components.""" + """List of all GMM components.""" return [self.get_component(i) for i in range(self.n_components)] def sample(self, key: jnp.ndarray, size: int) -> jnp.ndarray: @@ -288,19 +298,19 @@ def get_log_component_posterior(self, x: jnp.ndarray) -> jnp.ndarray: log_prob_unnorm, axis=-1, keepdims=True ) - def has_nans(self) -> bool: + def has_nans(self) -> bool: # noqa: D102 for leaf in jax.tree_util.tree_leaves(self): if jnp.any(~jnp.isfinite(leaf)): return True return False - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 children = (self.loc, self.scale_params, self.component_weight_ob) aux_data = {} return children, aux_data @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*children, **aux_data) def __repr__(self): diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py index fbaa0732c..1cc8f8c63 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py @@ -85,31 +85,31 @@ def __init__( self._lock_gmm1 = lock_gmm1 @property - def dtype(self): + def dtype(self): # noqa: D102 return self.gmm0.dtype @property - def gmm0(self): + def gmm0(self): # noqa: D102 return self._gmm0 @property - def gmm1(self): + def gmm1(self): # noqa: D102 return self._gmm1 @property - def epsilon(self): + def epsilon(self): # noqa: D102 return self._epsilon @property - def tau(self): + def tau(self): # noqa: D102 return self._tau @property - def rho(self): + def rho(self): # noqa: D102 return self.epsilon * self.tau / (1. 
- self.tau) @property - def lock_gmm1(self): + def lock_gmm1(self): # noqa: D102 return self._lock_gmm1 def get_bures_geometry(self) -> pointcloud.PointCloud: diff --git a/src/ott/tools/gaussian_mixture/probabilities.py b/src/ott/tools/gaussian_mixture/probabilities.py index a35dddae1..376483756 100644 --- a/src/ott/tools/gaussian_mixture/probabilities.py +++ b/src/ott/tools/gaussian_mixture/probabilities.py @@ -57,36 +57,40 @@ def from_probs(cls, probs: jnp.ndarray) -> 'Probabilities': return cls(params=log_probs_normalized) @property - def params(self): + def params(self): # noqa: D102 return self._params @property - def dtype(self): + def dtype(self): # noqa: D102 return self._params.dtype def unnormalized_log_probs(self) -> jnp.ndarray: + """Get the unnormalized log probabilities.""" return jnp.concatenate([self._params, jnp.zeros((1,), dtype=self.dtype)], axis=-1) def log_probs(self) -> jnp.ndarray: + """Get the log probabilities.""" return jax.nn.log_softmax(self.unnormalized_log_probs()) def probs(self) -> jnp.ndarray: + """Get the probabilities.""" return jax.nn.softmax(self.unnormalized_log_probs()) def sample(self, key: jnp.ndarray, size: int) -> jnp.ndarray: + """Sample from the distribution.""" return jax.random.categorical( key=key, logits=self.unnormalized_log_probs(), shape=(size,) ) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 children = (self.params,) aux_data = {} return children, aux_data @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*children, **aux_data) def __repr__(self): diff --git a/src/ott/tools/gaussian_mixture/scale_tril.py b/src/ott/tools/gaussian_mixture/scale_tril.py index 6b2e98af2..c88379b9d 100644 --- a/src/ott/tools/gaussian_mixture/scale_tril.py +++ b/src/ott/tools/gaussian_mixture/scale_tril.py @@ -196,13 +196,13 @@ def transport( m = self.gaussian_map(dest_scale) return (m @ points.T).T - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 children = (self.params,) aux_data = {'size': self.size} return children, aux_data @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*children, **aux_data) def __repr__(self): diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py index 45a1282eb..efcb2fe01 100644 --- a/src/ott/tools/sinkhorn_divergence.py +++ b/src/ott/tools/sinkhorn_divergence.py @@ -27,7 +27,7 @@ ] -class SinkhornDivergenceOutput(NamedTuple): +class SinkhornDivergenceOutput(NamedTuple): # noqa: D101 divergence: float potentials: Tuple[List[jnp.ndarray], List[jnp.ndarray], List[jnp.ndarray]] geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry] diff --git a/src/ott/utils.py b/src/ott/utils.py index 36e6d4379..d0fec58f6 100644 --- a/src/ott/utils.py +++ b/src/ott/utils.py @@ -32,7 +32,7 @@ def register_pytree_node(cls: type) -> type: return cls -def deprecate( +def deprecate( # noqa: D103 *, version: Optional[str] = None, alt: Optional[str] = None, @@ -55,6 +55,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: def is_jax_array(obj: Any) -> bool: + """Check if an object is a Jax array.""" if hasattr(jax, "Array"): # https://jax.readthedocs.io/en/latest/jax_array_migration.html return isinstance(obj, (jax.Array, jnp.DeviceArray))
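In the last hunk, `is_jax_array` probes `hasattr(jax, "Array")` because the unified `jax.Array` type only exists in newer jax releases (see the migration link in the code); older ones expose `jnp.DeviceArray` alone. A short usage sketch, assuming `ott` is installed:

import jax.numpy as jnp

from ott import utils

assert utils.is_jax_array(jnp.zeros(3))  # jax.Array / DeviceArray
assert not utils.is_jax_array([0.0, 0.0, 0.0])  # plain Python list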