From 183716a082b50cfc41ba38751f55779a786a181a Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 14 Feb 2023 10:07:20 +0100 Subject: [PATCH 1/9] fix D101 and B028 (stack level for warnings) introduced in #219 --- .flake8 | 4 ++-- src/ott/solvers/linear/discrete_barycenter.py | 2 +- src/ott/solvers/nn/neuraldual.py | 7 ++++--- src/ott/tools/sinkhorn_divergence.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/.flake8 b/.flake8 index d5e2e12b0..d252de535 100644 --- a/.flake8 +++ b/.flake8 @@ -42,8 +42,6 @@ ignore = # Missing blank line before section D411 # TODO(michalk8): fix D10{1,2,3} - # D101 Missing docstring in public class - D101 # Missing docstring in public method D102 # Missing docstring in public function @@ -53,3 +51,5 @@ exclude = .git,__pycache__,build,docs/_build,dist per-file-ignores = tests/*: D,C408 */__init__.py: F401 + examples/*: D101 + docs/*: D101 diff --git a/src/ott/solvers/linear/discrete_barycenter.py b/src/ott/solvers/linear/discrete_barycenter.py index a181d6d12..66f7a8267 100644 --- a/src/ott/solvers/linear/discrete_barycenter.py +++ b/src/ott/solvers/linear/discrete_barycenter.py @@ -26,7 +26,7 @@ __all__ = ["SinkhornBarycenterOutput", "discrete_barycenter"] -class SinkhornBarycenterOutput(NamedTuple): +class SinkhornBarycenterOutput(NamedTuple): # noqa: D101 f: jnp.ndarray g: jnp.ndarray histogram: jnp.ndarray diff --git a/src/ott/solvers/nn/neuraldual.py b/src/ott/solvers/nn/neuraldual.py index 1b8bc0058..ac9c24639 100644 --- a/src/ott/solvers/nn/neuraldual.py +++ b/src/ott/solvers/nn/neuraldual.py @@ -175,13 +175,13 @@ def setup( if isinstance( neural_f, models.ICNN ) and neural_f.pos_weights is not self.pos_weights: - warnings.warn(warn_str) + warnings.warn(warn_str, stacklevel=2) neural_f.pos_weights = self.pos_weights if isinstance( neural_g, models.ICNN ) and neural_g.pos_weights is not self.pos_weights: - warnings.warn(warn_str) + warnings.warn(warn_str, stacklevel=2) neural_g.pos_weights = self.pos_weights self.state_f = neural_f.create_train_state( @@ -206,7 +206,8 @@ def setup( else: if self.parallel_updates: warnings.warn( - 'parallel_updates set to True but disabling it because num_inner_iters>1' + 'parallel_updates set to True but disabling it because num_inner_iters>1', + stacklevel=2 ) if self.back_and_forth: raise NotImplementedError( diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py index 9d4da7f93..cd1cd475e 100644 --- a/src/ott/tools/sinkhorn_divergence.py +++ b/src/ott/tools/sinkhorn_divergence.py @@ -27,7 +27,7 @@ ] -class SinkhornDivergenceOutput(NamedTuple): +class SinkhornDivergenceOutput(NamedTuple): # noqa: D101 divergence: float potentials: Tuple[List[jnp.ndarray], List[jnp.ndarray], List[jnp.ndarray]] geoms: Tuple[geometry.Geometry, geometry.Geometry, geometry.Geometry] From 7946c2a738b2c41508a4fb4de2d503a88e678833 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 14 Feb 2023 10:08:00 +0100 Subject: [PATCH 2/9] add vscode to gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index d8f1864bd..397ee11b6 100644 --- a/.gitignore +++ b/.gitignore @@ -159,6 +159,9 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ +# vscode +.vscode/ + # generated documentation docs/html **/_autosummary From aa45bde0a9e855251c9029711c18abdd676660ec Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 14 Feb 2023 10:59:25 +0100 Subject: [PATCH 3/9] fix D102 on costs.py and try to simplify inheritance methods --- .flake8 | 6 +-- src/ott/geometry/costs.py | 92 +++++++++++++++++---------------------- 2 files changed, 41 insertions(+), 57 deletions(-) diff --git a/.flake8 b/.flake8 index d252de535..2922411b9 100644 --- a/.flake8 +++ b/.flake8 @@ -42,8 +42,6 @@ ignore = # Missing blank line before section D411 # TODO(michalk8): fix D10{1,2,3} - # Missing docstring in public method - D102 # Missing docstring in public function D103 exclude = .git,__pycache__,build,docs/_build,dist @@ -51,5 +49,5 @@ exclude = .git,__pycache__,build,docs/_build,dist per-file-ignores = tests/*: D,C408 */__init__.py: F401 - examples/*: D101 - docs/*: D101 + examples/*: D101, D102 + docs/*: D101, D102 diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index 09838a4ef..cd6d56c18 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -126,6 +126,7 @@ class TICost(CostFn): strictly convex, as well as provide the Legendre transform of :math:`h`, whose gradient is necessarily the inverse of the gradient of :math:`h`. """ + p = 1.0 @abc.abstractmethod def h(self, z: jnp.ndarray) -> float: @@ -139,6 +140,16 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute cost as evaluation of :func:`h` on :math:`x-y`.""" return self.h(x - y) + def tree_flatten(self): + """Tree flatten.""" + return (), (self.p,) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Tree unflatten.""" + del children + return cls(aux_data[0]) + @jax.tree_util.register_pytree_node_class class SqPNorm(TICost): @@ -154,7 +165,7 @@ def __init__(self, p: float): self.p = p self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf - def h(self, z: jnp.ndarray) -> float: + def h(self, z: jnp.ndarray) -> float: # noqa: D102 return 0.5 * jnp.linalg.norm(z, self.p) ** 2 def h_legendre(self, z: jnp.ndarray) -> float: @@ -164,14 +175,6 @@ def h_legendre(self, z: jnp.ndarray) -> float: """ return 0.5 * jnp.linalg.norm(z, self.q) ** 2 - def tree_flatten(self): - return (), (self.p,) - - @classmethod - def tree_unflatten(cls, aux_data, children): - del children - return cls(aux_data[0]) - @jax.tree_util.register_pytree_node_class class PNormP(TICost): @@ -188,21 +191,13 @@ def __init__(self, p: float): self.p = p self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf - def h(self, z: jnp.ndarray) -> float: + def h(self, z: jnp.ndarray) -> float: # noqa: D102 return jnp.linalg.norm(z, self.p) ** self.p / self.p - def h_legendre(self, z: jnp.ndarray) -> float: + def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 assert self.q < jnp.inf, "Legendre transform not defined for `p=1.0`" return jnp.linalg.norm(z, self.q) ** self.q / self.q - def tree_flatten(self): - return (), (self.p,) - - @classmethod - def tree_unflatten(cls, aux_data, children): - del children - return cls(aux_data[0]) - @jax.tree_util.register_pytree_node_class class Euclidean(CostFn): @@ -231,10 +226,10 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute minus twice the dot-product between vectors.""" return -2. * jnp.vdot(x, y) - def h(self, z: jnp.ndarray) -> float: + def h(self, z: jnp.ndarray) -> float: # noqa: D102 return jnp.sum(z ** 2) - def h_legendre(self, z: jnp.ndarray) -> float: + def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 return 0.25 * jnp.sum(z ** 2) def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: @@ -279,6 +274,8 @@ class RegTICost(TICost, abc.ABC): where :func:`reg` is the regularization function. """ + gamma = 0 + @abc.abstractmethod def reg(self, z: jnp.ndarray) -> float: """Regularization function.""" @@ -294,6 +291,16 @@ def h_legendre(self, z: jnp.ndarray) -> float: q = jax.lax.stop_gradient(self.prox_reg(z)) return jnp.sum(q * z) - self.h(q) + def tree_flatten(self): + """Flatten tree.""" + return (), (self.gamma,) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Unflatten tree.""" + del children + return cls(*aux_data) + @jax.tree_util.register_pytree_node_class class ElasticL1(RegTICost): @@ -312,20 +319,12 @@ def __init__(self, gamma: float = 1.0): assert gamma >= 0, "Gamma must be non-negative." self.gamma = gamma - def reg(self, z: jnp.ndarray) -> float: + def reg(self, z: jnp.ndarray) -> float: # noqa: D102 return self.gamma * jnp.linalg.norm(z, ord=1) - def prox_reg(self, z: jnp.ndarray) -> float: + def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102 return jnp.sign(z) * jax.nn.relu(jnp.abs(z) - self.gamma) - def tree_flatten(self): - return (), (self.gamma,) - - @classmethod - def tree_unflatten(cls, aux_data, children): - del children - return cls(*aux_data) - @jax.tree_util.register_pytree_node_class class ElasticSTVS(RegTICost): @@ -348,22 +347,14 @@ def __init__(self, gamma: float = 1.0): assert gamma > 0, "Gamma must be positive." self.gamma = gamma - def reg(self, z: jnp.ndarray) -> float: + def reg(self, z: jnp.ndarray) -> float: # noqa: D102 u = jnp.arcsinh(jnp.abs(z) / (2 * self.gamma)) out = u - 0.5 * jnp.exp(-2.0 * u) return (self.gamma ** 2) * jnp.sum(out + 0.5) # make positive - def prox_reg(self, z: jnp.ndarray) -> float: + def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102 return jax.nn.relu(1 - (self.gamma / (jnp.abs(z) + 1e-12)) ** 2) * z - def tree_flatten(self): - return (), (self.gamma,) - - @classmethod - def tree_unflatten(cls, aux_data, children): - del children - return cls(*aux_data) - @jax.tree_util.register_pytree_node_class class ElasticSqKOverlap(RegTICost): @@ -388,7 +379,7 @@ def __init__(self, k: int, gamma: float = 1.0): self.k = k self.gamma = gamma - def reg(self, z: jnp.ndarray) -> float: + def reg(self, z: jnp.ndarray) -> float: # noqa: D102 # Prop 2.1 in :cite:`argyriou:12` k = self.k top_w = jax.lax.top_k(jnp.abs(z), k)[0] # Fetch largest k values @@ -409,7 +400,7 @@ def reg(self, z: jnp.ndarray) -> float: return 0.5 * self.gamma * (s + (r + 1) * cesaro[r] ** 2) - def prox_reg(self, z: jnp.ndarray) -> float: + def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102 @functools.partial(jax.vmap, in_axes=[0, None, None]) def find_indices(r: int, l: jnp.ndarray, @@ -454,14 +445,9 @@ def inner(r: int, l: int, # change sign and reorder return sgn * q[jnp.argsort(z_ixs.astype(float))] - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (), (self.k, self.gamma) - @classmethod - def tree_unflatten(cls, aux_data, children): - del children - return cls(*aux_data) - @jax.tree_util.register_pytree_node_class class Bures(CostFn): @@ -606,11 +592,11 @@ def _padder(cls, dim: int) -> jnp.ndarray: ) return padding[jnp.newaxis, :] - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (), (self._dimension, self._sqrtm_kw) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 del children return cls(aux_data[0], **aux_data[1]) @@ -718,11 +704,11 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: (sig2 + gam) * jnp.exp(log_m_pi), lambda: jnp.nan ) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (), (self._dimension, self._sigma, self._gamma, self._sqrtm_kw) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 del children dim, sigma, gamma, kwargs = aux_data return cls(dim, sigma=sigma, gamma=gamma, **kwargs) From 987b31b148fc0b8ee97d911bd0288de6cc35f9a4 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 14 Feb 2023 14:13:26 +0100 Subject: [PATCH 4/9] try to fix D102 everywhere --- .flake8 | 1 + src/ott/geometry/costs.py | 8 +--- src/ott/geometry/epsilon_scheduler.py | 6 ++- src/ott/geometry/geometry.py | 4 +- src/ott/geometry/graph.py | 26 ++++++++---- src/ott/geometry/grid.py | 12 +++--- src/ott/geometry/low_rank.py | 20 +++++----- src/ott/geometry/pointcloud.py | 34 ++++++++-------- src/ott/initializers/linear/initializers.py | 2 +- .../initializers/linear/initializers_lr.py | 20 +++++----- src/ott/initializers/nn/initializers.py | 4 +- .../initializers/quadratic/initializers.py | 2 +- src/ott/math/decomposition.py | 4 +- src/ott/problems/linear/barycenter_problem.py | 4 +- src/ott/problems/linear/linear_problem.py | 4 +- src/ott/problems/linear/potentials.py | 10 ++--- src/ott/problems/nn/dataset.py | 6 ++- src/ott/problems/quadratic/gw_barycenter.py | 6 +-- .../problems/quadratic/quadratic_problem.py | 4 +- src/ott/solvers/linear/acceleration.py | 3 +- .../solvers/linear/continuous_barycenter.py | 17 +++++++- src/ott/solvers/linear/sinkhorn.py | 26 ++++++------ src/ott/solvers/linear/sinkhorn_lr.py | 40 +++++++++++++------ src/ott/solvers/nn/conjugate_solvers.py | 2 +- src/ott/solvers/nn/models.py | 8 ++-- src/ott/solvers/nn/neuraldual.py | 2 +- .../solvers/quadratic/gromov_wasserstein.py | 2 +- src/ott/solvers/quadratic/gw_barycenter.py | 5 ++- src/ott/solvers/was_solver.py | 4 +- src/ott/tools/gaussian_mixture/gaussian.py | 10 ++++- .../gaussian_mixture/gaussian_mixture.py | 16 ++++++-- .../gaussian_mixture/gaussian_mixture_pair.py | 14 +++---- .../tools/gaussian_mixture/probabilities.py | 12 ++++-- src/ott/tools/gaussian_mixture/scale_tril.py | 4 +- 34 files changed, 202 insertions(+), 140 deletions(-) diff --git a/.flake8 b/.flake8 index 2922411b9..6d5729d78 100644 --- a/.flake8 +++ b/.flake8 @@ -51,3 +51,4 @@ per-file-ignores = */__init__.py: F401 examples/*: D101, D102 docs/*: D101, D102 + src/ott/types.py: D102 diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index cd6d56c18..9fc4a87af 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -141,12 +141,10 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: return self.h(x - y) def tree_flatten(self): - """Tree flatten.""" return (), (self.p,) @classmethod def tree_unflatten(cls, aux_data, children): - """Tree unflatten.""" del children return cls(aux_data[0]) @@ -291,13 +289,11 @@ def h_legendre(self, z: jnp.ndarray) -> float: q = jax.lax.stop_gradient(self.prox_reg(z)) return jnp.sum(q * z) - self.h(q) - def tree_flatten(self): - """Flatten tree.""" + def tree_flatten(self): #noqa: D102 return (), (self.gamma,) @classmethod - def tree_unflatten(cls, aux_data, children): - """Unflatten tree.""" + def tree_unflatten(cls, aux_data, children): #noqa: D102 del children return cls(*aux_data) diff --git a/src/ott/geometry/epsilon_scheduler.py b/src/ott/geometry/epsilon_scheduler.py index e1dd3ad8f..aa2cc713a 100644 --- a/src/ott/geometry/epsilon_scheduler.py +++ b/src/ott/geometry/epsilon_scheduler.py @@ -71,18 +71,20 @@ def at(self, iteration: Optional[int] = 1) -> float: return multiple * self.target def done(self, eps: float) -> bool: + """Return whether the scheduler is done at a given value.""" return eps == self.target def done_at(self, iteration: Optional[int]) -> bool: + """Return whether the scheduler is done at a given iteration.""" return self.done(self.at(iteration)) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return ( self._target_init, self._scale_epsilon, self._init, self._decay ), None @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 del aux_data return cls(*children) diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index b588e3aba..fce21a57c 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -900,7 +900,7 @@ def _normalize_mask(mask: Optional[Union[int, jnp.ndarray]], assert mask.shape == (size,) return mask - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return ( self._cost_matrix, self._kernel_matrix, self._epsilon_init, self._relative_epsilon, self._scale_epsilon, self._src_mask, @@ -910,7 +910,7 @@ def tree_flatten(self): } @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 *args, kwargs = children return cls(*args, **kwargs, **aux_data) diff --git a/src/ott/geometry/graph.py b/src/ott/geometry/graph.py index 5fe7b144f..4b3657859 100644 --- a/src/ott/geometry/graph.py +++ b/src/ott/geometry/graph.py @@ -84,6 +84,16 @@ def apply_kernel( eps: Optional[float] = None, axis: int = 0, ) -> jnp.ndarray: + r"""Apply :attr:`kernel_matrix` on positive scaling vector. + + Args: + scaling: Scaling to apply the kernel to. + eps: passed for consistency, not used yet. + axis: passed for consistency, not used yet. + + Returns: + Kernel applied to ``scaling``. + """ def conf_fn( iteration: int, solver_lap: Tuple[decomposition.CholeskySolver, @@ -145,7 +155,7 @@ def body_fn( state=state, )[1] - def apply_transport_from_scalings( + def apply_transport_from_scalings( # noqa: D102 self, u: jnp.ndarray, v: jnp.ndarray, @@ -171,7 +181,7 @@ def body_fn(carry: None, vec: jnp.ndarray) -> jnp.ndarray: return res @property - def kernel_matrix(self) -> jnp.ndarray: + def kernel_matrix(self) -> jnp.ndarray: # noqa: D102 n, _ = self.shape kernel = self.apply_kernel(jnp.eye(n)) # force symmetry because of numerical imprecisions @@ -179,7 +189,7 @@ def kernel_matrix(self) -> jnp.ndarray: return (kernel + kernel.T) * .5 @property - def cost_matrix(self) -> jnp.ndarray: + def cost_matrix(self) -> jnp.ndarray: # noqa: D102 return -self.t * mu.safe_log(self.kernel_matrix) @property @@ -274,7 +284,7 @@ def solver(self) -> decomposition.CholeskySolver: return self._solver @property - def shape(self) -> Tuple[int, int]: + def shape(self) -> Tuple[int, int]: # noqa: D102 arr = self._graph if self._graph is not None else self._lap return arr.shape @@ -295,12 +305,12 @@ def graph(self) -> Optional[Union[jnp.ndarray, jesp.BCOO]]: return (self._graph + self._graph.T) if self.directed else self._graph @property - def is_symmetric(self) -> bool: + def is_symmetric(self) -> bool: # noqa: D102 # there may be some numerical imprecisions, but it should be symmetric return True @property - def dtype(self) -> jnp.dtype: + def dtype(self) -> jnp.dtype: # noqa: D102 return self._graph.dtype # TODO(michalk8): in future, use mixins for lse/kernel mode @@ -330,7 +340,7 @@ def marginal_from_potentials( """Not implemented.""" raise ValueError("Not implemented.") - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [self._graph, self._lap, self.solver], { "t": self._t, "n_steps": self.n_steps, @@ -342,7 +352,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: } @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "Graph": graph, laplacian, solver = children diff --git a/src/ott/geometry/grid.py b/src/ott/geometry/grid.py index 3bb4a04fd..753f7e299 100644 --- a/src/ott/geometry/grid.py +++ b/src/ott/geometry/grid.py @@ -132,15 +132,15 @@ def median_cost_matrix(self) -> NoReturn: raise NotImplementedError('Median cost not implemented for grids.') @property - def can_LRC(self) -> bool: + def can_LRC(self) -> bool: # noqa: D102 return True @property - def shape(self) -> Tuple[int, int]: + def shape(self) -> Tuple[int, int]: # noqa: D102 return self.num_a, self.num_a @property - def is_symmetric(self) -> bool: + def is_symmetric(self) -> bool: # noqa: D102 return True # Reimplemented functions to be used in regularized OT @@ -341,14 +341,14 @@ def prepare_divergences( return tuple(sep_grid for _ in range(size)) @property - def dtype(self) -> jnp.dtype: + def dtype(self) -> jnp.dtype: # noqa: D102 return self.x[0].dtype - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (self.x, self.cost_fns, self._epsilon), self.kwargs @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls( x=children[0], cost_fns=children[1], epsilon=children[2], **aux_data ) diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index 9da672fe3..1e84aea27 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -91,7 +91,7 @@ def bias(self) -> float: return self._bias * self.inv_scale_cost @property - def cost_rank(self) -> int: + def cost_rank(self) -> int: # noqa: D102 return self._cost_1.shape[1] @property @@ -100,18 +100,18 @@ def cost_matrix(self) -> jnp.ndarray: return jnp.matmul(self.cost_1, self.cost_2.T) + self.bias @property - def shape(self) -> Tuple[int, int]: + def shape(self) -> Tuple[int, int]: # noqa: D102 return self._cost_1.shape[0], self._cost_2.shape[0] @property - def is_symmetric(self) -> bool: + def is_symmetric(self) -> bool: # noqa: D102 return ( self._cost_1.shape[0] == self._cost_2.shape[0] and jnp.all(self._cost_1 == self._cost_2) ) @property - def inv_scale_cost(self) -> float: + def inv_scale_cost(self) -> float: # noqa: D102 if isinstance(self._scale_cost, (int, float)) or utils.is_jax_array(self._scale_cost): return 1.0 / self._scale_cost @@ -239,10 +239,10 @@ def to_LRCGeometry( return self @property - def can_LRC(self): + def can_LRC(self): # noqa: D102 return True - def subset( + def subset( # noqa: D102 self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], **kwargs: Any ) -> "LRCGeometry": @@ -257,7 +257,7 @@ def subset_fn( src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs ) - def mask( + def mask( # noqa: D102 self, src_mask: Optional[jnp.ndarray], tgt_mask: Optional[jnp.ndarray], @@ -311,10 +311,10 @@ def __add__(self, other: 'LRCGeometry') -> 'LRCGeometry': ) @property - def dtype(self) -> jnp.dtype: + def dtype(self) -> jnp.dtype: # noqa: D102 return self._cost_1.dtype - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return ( self._cost_1, self._cost_2, self._src_mask, self._tgt_mask, self._kwargs ), { @@ -325,7 +325,7 @@ def tree_flatten(self): } @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 c1, c2, src_mask, tgt_mask, kwargs = children return cls( c1, c2, src_mask=src_mask, tgt_mask=tgt_mask, **kwargs, **aux_data diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index c2abe6bc9..6bc3b24c6 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -99,7 +99,7 @@ def _norm_y(self) -> Union[float, jnp.ndarray]: return 0. @property - def can_LRC(self): + def can_LRC(self): # noqa: D102 return self.is_squared_euclidean and self._check_LRC_dim @property @@ -108,20 +108,20 @@ def _check_LRC_dim(self): return n * m > (n + m) * d @property - def cost_matrix(self) -> Optional[jnp.ndarray]: + def cost_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102 if self.is_online: return None cost_matrix = self._compute_cost_matrix() return cost_matrix * self.inv_scale_cost @property - def kernel_matrix(self) -> Optional[jnp.ndarray]: + def kernel_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102 if self.is_online: return None return jnp.exp(-self.cost_matrix / self.epsilon) @property - def shape(self) -> Tuple[int, int]: + def shape(self) -> Tuple[int, int]: # noqa: D102 # in the process of flattening/unflattening in vmap, `__init__` # can be called with dummy objects # we optionally access `shape` in order to get the batch size @@ -130,13 +130,13 @@ def shape(self) -> Tuple[int, int]: return self.x.shape[0], self.y.shape[0] @property - def is_symmetric(self) -> bool: + def is_symmetric(self) -> bool: # noqa: D102 return self.y is None or ( jnp.all(self.x.shape == self.y.shape) and jnp.all(self.x == self.y) ) @property - def is_squared_euclidean(self) -> bool: + def is_squared_euclidean(self) -> bool: # noqa: D102 return isinstance(self.cost_fn, costs.SqEuclidean) @property @@ -147,11 +147,11 @@ def is_online(self) -> bool: # TODO(michalk8): when refactoring, consider PC as a subclass of LR? @property - def cost_rank(self) -> int: + def cost_rank(self) -> int: # noqa: D102 return self.x.shape[1] @property - def inv_scale_cost(self) -> float: + def inv_scale_cost(self) -> float: # noqa: D102 if isinstance(self._scale_cost, (int, float)) or utils.is_jax_array(self._scale_cost): return 1.0 / self._scale_cost @@ -201,7 +201,7 @@ def _compute_cost_matrix(self) -> jnp.ndarray: cost_matrix += self._norm_x[:, jnp.newaxis] + self._norm_y[jnp.newaxis, :] return cost_matrix - def apply_lse_kernel( + def apply_lse_kernel( # noqa: D102 self, f: jnp.ndarray, g: jnp.ndarray, @@ -288,7 +288,7 @@ def finalize(i: int): return eps * h_res - jnp.where(jnp.isfinite(v), v, 0), h_sign - def apply_kernel( + def apply_kernel( # noqa: D102 self, scaling: jnp.ndarray, eps: Optional[float] = None, @@ -315,7 +315,7 @@ def apply_kernel( self.cost_fn, self.inv_scale_cost ) - def transport_from_potentials( + def transport_from_potentials( # noqa: D102 self, f: jnp.ndarray, g: jnp.ndarray ) -> jnp.ndarray: if not self.is_online: @@ -329,7 +329,7 @@ def transport_from_potentials( self.cost_fn, self.inv_scale_cost ) - def transport_from_scalings( + def transport_from_scalings( # noqa: D102 self, u: jnp.ndarray, v: jnp.ndarray ) -> jnp.ndarray: if not self.is_online: @@ -569,7 +569,7 @@ def prepare_divergences( for ((x, y), (x_mask, y_mask)) in zip(couples, masks) ) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return ([self.x, self.y, self._src_mask, self._tgt_mask, self.cost_fn], { 'epsilon': self._epsilon_init, 'relative_epsilon': self._relative_epsilon, @@ -579,7 +579,7 @@ def tree_flatten(self): }) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 x, y, src_mask, tgt_mask, cost_fn = children return cls( x, y, cost_fn=cost_fn, src_mask=src_mask, tgt_mask=tgt_mask, **aux_data @@ -649,7 +649,7 @@ def _sqeucl_to_lr(self, scale: float = 1.0) -> low_rank.LRCGeometry: **self._kwargs ) - def subset( + def subset( # noqa: D102 self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], **kwargs: Any ) -> "PointCloud": @@ -664,7 +664,7 @@ def subset_fn( src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs ) - def mask( + def mask( # noqa: D102 self, src_mask: Optional[jnp.ndarray], tgt_mask: Optional[jnp.ndarray], @@ -710,7 +710,7 @@ def _mask_subset_helper( ) @property - def dtype(self) -> jnp.dtype: + def dtype(self) -> jnp.dtype: # noqa: D102 return self.x.dtype @property diff --git a/src/ott/initializers/linear/initializers.py b/src/ott/initializers/linear/initializers.py index 8e7124461..e80079c95 100644 --- a/src/ott/initializers/linear/initializers.py +++ b/src/ott/initializers/linear/initializers.py @@ -198,7 +198,7 @@ def __init__( self.max_iter = max_iter self.vectorized_update = vectorized_update - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return ([], { 'tolerance': self.tolerance, 'max_iter': self.max_iter, diff --git a/src/ott/initializers/linear/initializers_lr.py b/src/ott/initializers/linear/initializers_lr.py index 3d1bdf885..f13d9c3ca 100644 --- a/src/ott/initializers/linear/initializers_lr.py +++ b/src/ott/initializers/linear/initializers_lr.py @@ -224,7 +224,7 @@ class RandomInitializer(LRInitializer): kwargs: Additional keyword arguments. """ - def init_q( + def init_q( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -237,7 +237,7 @@ def init_q( init_q = jnp.abs(jax.random.normal(key, (a.shape[0], self.rank))) return a[:, None] * (init_q / jnp.sum(init_q, axis=1, keepdims=True)) - def init_r( + def init_r( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -250,7 +250,7 @@ def init_r( init_r = jnp.abs(jax.random.normal(key, (b.shape[0], self.rank))) return b[:, None] * (init_r / jnp.sum(init_r, axis=1, keepdims=True)) - def init_g( + def init_g( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -298,7 +298,7 @@ def _compute_factor( return ((lambda_1 * x[:, None] @ g1.reshape(1, -1)) + ((1 - lambda_1) * y[:, None] @ g2.reshape(1, -1))) - def init_q( + def init_q( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -309,7 +309,7 @@ def init_q( del key, kwargs return self._compute_factor(ot_prob, init_g, which="q") - def init_r( + def init_r( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -320,7 +320,7 @@ def init_r( del key, kwargs return self._compute_factor(ot_prob, init_g, which="r") - def init_g( + def init_g( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -409,7 +409,7 @@ def _compute_factor( solver = sinkhorn.Sinkhorn(**self._sinkhorn_kwargs) return solver(prob).matrix - def init_q( + def init_q( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -421,7 +421,7 @@ def init_q( ot_prob, key, init_g=init_g, which="q", **kwargs ) - def init_r( + def init_r( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -433,7 +433,7 @@ def init_r( ot_prob, key, init_g=init_g, which="r", **kwargs ) - def init_g( + def init_g( # noqa: D102 self, ot_prob: Problem_t, key: jnp.ndarray, @@ -442,7 +442,7 @@ def init_g( del key, kwargs return jnp.ones((self.rank,)) / self.rank - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux_data = super().tree_flatten() aux_data["sinkhorn_kwargs"] = self._sinkhorn_kwargs aux_data["min_iterations"] = self._min_iter diff --git a/src/ott/initializers/nn/initializers.py b/src/ott/initializers/nn/initializers.py index f87dce7d0..3e040addf 100644 --- a/src/ott/initializers/nn/initializers.py +++ b/src/ott/initializers/nn/initializers.py @@ -122,7 +122,7 @@ def update( """ return self.update_impl(state, a, b) - def init_dual_a( + def init_dual_a( # noqa: D102 self, ot_prob: 'linear_problem.LinearProblem', lse_mode: bool ) -> jnp.ndarray: # Detect if the problem is batched. @@ -186,7 +186,7 @@ def _compute_f(self, a, b, params): """ return self.meta_model.apply({'params': params}, a, b) - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [self.geom, self.meta_model, self.opt], { 'rng': self.rng, 'state': self.state diff --git a/src/ott/initializers/quadratic/initializers.py b/src/ott/initializers/quadratic/initializers.py index 70a50af4b..cd7f361d0 100644 --- a/src/ott/initializers/quadratic/initializers.py +++ b/src/ott/initializers/quadratic/initializers.py @@ -189,6 +189,6 @@ def rank(self) -> int: """Rank of the transport matrix factorization.""" return self._linear_lr_initializer.rank - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux_data = super().tree_flatten() return children + [self._linear_lr_initializer], aux_data diff --git a/src/ott/math/decomposition.py b/src/ott/math/decomposition.py index e05888e81..593727afb 100644 --- a/src/ott/math/decomposition.py +++ b/src/ott/math/decomposition.py @@ -127,7 +127,7 @@ def _decompose(self, A: T) -> Optional[T]: def _solve(self, L: Optional[T], b: jnp.ndarray) -> jnp.ndarray: return jsp.linalg.solve_triangular(L, b, lower=self._lower) - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux_data = super().tree_flatten() aux_data["lower"] = self._lower return children, aux_data @@ -205,7 +205,7 @@ def clear_factor_cache(cls) -> None: def __hash__(self) -> int: return object.__hash__(self) if self._key is None else self._key - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux_data = super().tree_flatten() return children, { **aux_data, "beta": self._beta, diff --git a/src/ott/problems/linear/barycenter_problem.py b/src/ott/problems/linear/barycenter_problem.py index bcfbc30cd..e995a8e64 100644 --- a/src/ott/problems/linear/barycenter_problem.py +++ b/src/ott/problems/linear/barycenter_problem.py @@ -165,7 +165,7 @@ def weights(self) -> jnp.ndarray: def _is_segmented(self) -> bool: return self._y.ndim == 3 - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return ([self._y, self._b, self._weights], { 'cost_fn': self.cost_fn, 'epsilon': self.epsilon, @@ -174,7 +174,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: }) @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "BarycenterProblem": y, b, weights = children diff --git a/src/ott/problems/linear/linear_problem.py b/src/ott/problems/linear/linear_problem.py index 9ff72aa35..3b1acfeb9 100644 --- a/src/ott/problems/linear/linear_problem.py +++ b/src/ott/problems/linear/linear_problem.py @@ -107,14 +107,14 @@ def get_transport_functions( ) return marginal_a, marginal_b, app_transport - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return ([self.geom, self._a, self._b], { 'tau_a': self.tau_a, 'tau_b': self.tau_b }) @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "LinearProblem": return cls(*children, **aux_data) diff --git a/src/ott/problems/linear/potentials.py b/src/ott/problems/linear/potentials.py index b12a41a90..a6b65025d 100644 --- a/src/ott/problems/linear/potentials.py +++ b/src/ott/problems/linear/potentials.py @@ -147,7 +147,7 @@ def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]: ) return jax.vmap(jax.grad(self.cost_fn.h_legendre)) - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [], { "f": self._f, "g": self._g, @@ -156,7 +156,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: } @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "DualPotentials": return cls(*children, **aux_data) @@ -334,11 +334,11 @@ def __init__( self._g_yy = g_yy @property - def f(self) -> Potential_t: + def f(self) -> Potential_t: # noqa: D102 return self._potential_fn(kind="f") @property - def g(self) -> Potential_t: + def g(self) -> Potential_t: # noqa: D102 return self._potential_fn(kind="g") def _potential_fn(self, *, kind: Literal["f", "g"]) -> Potential_t: @@ -401,5 +401,5 @@ def epsilon(self) -> float: """Entropy regularizer.""" return self._prob.geom.epsilon - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [self._f, self._g, self._prob, self._f_xx, self._g_yy], {} diff --git a/src/ott/problems/nn/dataset.py b/src/ott/problems/nn/dataset.py index 8810ae701..ec89eb1d0 100644 --- a/src/ott/problems/nn/dataset.py +++ b/src/ott/problems/nn/dataset.py @@ -88,7 +88,11 @@ def __iter__(self) -> Iterator[jnp.array]: return self.create_sample_generators() def create_sample_generators(self) -> Iterator[jnp.array]: - # create generator which randomly picks center and adds noise + """Creates a generator of samples from the Gaussian mixture. + + Returns: + A generator of samples from the Gaussian mixture. + """ key = self.init_key while True: k1, k2, key = jax.random.split(key, 3) diff --git a/src/ott/problems/quadratic/gw_barycenter.py b/src/ott/problems/quadratic/gw_barycenter.py index 1eac1d330..ed3ad65c1 100644 --- a/src/ott/problems/quadratic/gw_barycenter.py +++ b/src/ott/problems/quadratic/gw_barycenter.py @@ -269,7 +269,7 @@ def segmented_y_fused(self) -> Optional[jnp.ndarray]: return y_fused @property - def ndim(self) -> Optional[int]: + def ndim(self) -> Optional[int]: # noqa: D102 return None if self._y_as_costs else self._y.shape[-1] @property @@ -292,7 +292,7 @@ def gw_loss(self) -> quadratic_costs.GWLoss: f"Loss `{self._loss_name}` is not yet implemented." ) - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 (y, b, weights), aux = super().tree_flatten() if self._y_as_costs: children = [None, b, weights, y] @@ -304,7 +304,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: return children + [self._y_fused], aux @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "GWBarycenterProblem": y, b, weights, costs, y_fused = children diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py index face28f37..8510310bb 100644 --- a/src/ott/problems/quadratic/quadratic_problem.py +++ b/src/ott/problems/quadratic/quadratic_problem.py @@ -459,7 +459,7 @@ def is_balanced(self) -> bool: return ((not self.gw_unbalanced_correction) or (self.tau_a == 1.0 and self.tau_b == 1.0)) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return ([self.geom_xx, self.geom_yy, self.geom_xy, self._a, self._b], { 'tau_a': self.tau_a, 'tau_b': self.tau_b, @@ -472,7 +472,7 @@ def tree_flatten(self): }) @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 geoms, (a, b) = children[:3], children[3:] return cls(*geoms, a=a, b=b, **aux_data) diff --git a/src/ott/solvers/linear/acceleration.py b/src/ott/solvers/linear/acceleration.py index 802af6947..76a72b83d 100644 --- a/src/ott/solvers/linear/acceleration.py +++ b/src/ott/solvers/linear/acceleration.py @@ -103,6 +103,7 @@ def init_maps( def update_history( self, state: 'sinkhorn.SinkhornState', pb, lse_mode: bool ) -> 'sinkhorn.SinkhornState': + """Update history of mapped dual variables.""" f = state.fu if lse_mode else pb.geom.potential_from_scaling(state.fu) mapped = jnp.concatenate((state.old_mapped_fus[:, 1:], f[:, None]), axis=1) return state.set(old_mapped_fus=mapped) @@ -141,7 +142,7 @@ def lehmann(self, state: 'sinkhorn.SinkhornState') -> float: power = 1.0 / self.inner_iterations return 2.0 / (1.0 + jnp.sqrt(1.0 - error_ratio ** power)) - def __call__( + def __call__( # noqa: D102 self, weight: float, value: jnp.ndarray, diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index 556192ad5..26891b4a3 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -55,6 +55,17 @@ def update( self, iteration: int, bar_prob: barycenter_problem.BarycenterProblem, linear_ot_solver: Any, store_errors: bool ) -> 'BarycenterState': + """Update the state of the solver. + + Args: + iteration: the current iteration of the outer loop. + bar_prob: the barycenter problem. + linear_ot_solver: the linear OT solver to use. + store_errors: whether to store the errors of the inner loop. + + Returns: + The updated state. + """ seg_y, seg_b = bar_prob.segmented_y_b @functools.partial(jax.vmap, in_axes=[None, None, 0, 0]) @@ -120,7 +131,7 @@ def solve_linear_ot( class WassersteinBarycenter(was_solver.WassersteinSolver): """A Continuous Wasserstein barycenter solver, built on generic template.""" - def __call__( + def __call__( # noqa: D102 self, bar_prob: barycenter_problem.BarycenterProblem, bar_size: int = 100, @@ -181,7 +192,9 @@ def init_state( -jnp.ones((num_iter,)), -jnp.ones((num_iter,)), errors, x, a ) - def output_from_state(self, state: BarycenterState) -> BarycenterState: + def output_from_state( # noqa: D102 + self, state: BarycenterState + ) -> BarycenterState: # TODO(michalk8): create an output variable to match rest of the framework return state diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index f91e3cc4b..901bcc95d 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -291,7 +291,7 @@ def set(self, **kwargs: Any) -> 'SinkhornOutput': """Return a copy of self, with potential overwrites.""" return self._replace(**kwargs) - def set_cost( + def set_cost( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, use_danskin: bool ) -> 'SinkhornOutput': @@ -337,27 +337,27 @@ def transport_cost_at_geom( return jnp.sum(self.matrix * other_geom.cost_matrix) @property - def linear(self) -> bool: + def linear(self) -> bool: # noqa: D102 return isinstance(self.ot_prob, linear_problem.LinearProblem) @property - def geom(self) -> geometry.Geometry: + def geom(self) -> geometry.Geometry: # noqa: D102 return self.ot_prob.geom @property - def a(self) -> jnp.ndarray: + def a(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.a @property - def b(self) -> jnp.ndarray: + def b(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.b @property - def linear_output(self) -> bool: + def linear_output(self) -> bool: # noqa: D102 return True @property - def converged(self) -> bool: + def converged(self) -> bool: # noqa: D102 if self.errors is None: return False return jnp.logical_and( @@ -366,13 +366,13 @@ def converged(self) -> bool: # TODO(michalk8): this should be always present @property - def n_iters(self) -> int: + def n_iters(self) -> int: # noqa: D102 if self.errors is None: return -1 return jnp.sum(self.errors > -1) @property - def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]: + def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]: # noqa: D102 u = self.ot_prob.geom.scaling_from_potential(self.f) v = self.ot_prob.geom.scaling_from_potential(self.g) return u, v @@ -396,7 +396,7 @@ def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray: self.f, self.g, inputs, axis=axis ) - def marginal(self, axis: int) -> jnp.ndarray: + def marginal(self, axis: int) -> jnp.ndarray: # noqa: D102 return self.ot_prob.geom.marginal_from_potentials(self.f, self.g, axis=axis) def cost_at_geom(self, other_geom: geometry.Geometry) -> float: @@ -985,7 +985,7 @@ def norm_error(self) -> Tuple[int, ...]: return self._norm_error, # TODO(michalk8): in the future, enforce this (+ in GW) via abstract method - def create_initializer(self) -> init_lib.SinkhornInitializer: + def create_initializer(self) -> init_lib.SinkhornInitializer: # noqa: D102 if isinstance(self.initializer, init_lib.SinkhornInitializer): return self.initializer if self.initializer == "default": @@ -998,14 +998,14 @@ def create_initializer(self) -> init_lib.SinkhornInitializer: f"Initializer `{self.initializer}` is not yet implemented." ) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 aux = vars(self).copy() aux['norm_error'] = aux.pop('_norm_error') aux.pop('threshold') return [self.threshold], aux @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(**aux_data, threshold=children[0]) diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index 22e542b06..87e51423c 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -71,6 +71,19 @@ def compute_reg_ot_cost( ot_prob: linear_problem.LinearProblem, use_danskin: bool = False ) -> float: + """Compute the regularized OT cost. + + Args: + q: first factor of solution + r: second factor of solution + g: weights of solution + ot_prob: linear problem + use_danskin: if True, use Danskin's trick to avoid computing the gradient of + the cost function. + + Returns: + regularized OT cost + """ q = jax.lax.stop_gradient(q) if use_danskin else q r = jax.lax.stop_gradient(r) if use_danskin else r g = jax.lax.stop_gradient(g) if use_danskin else g @@ -86,9 +99,9 @@ def solution_error( Since only balanced case is available for LR, this is marginal deviation. Args: - q: first factor of solution - r: second factor of solution - ot_prob: linear problem + q: first factor of solution. + r: second factor of solution. + ot_prob: linear problem. norm_error: int, p-norm used to compute error. lse_mode: True if log-sum-exp operations, False if kernel vector products. @@ -130,7 +143,7 @@ def set(self, **kwargs: Any) -> 'LRSinkhornOutput': """Return a copy of self, with potential overwrites.""" return self._replace(**kwargs) - def set_cost( + def set_cost( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, @@ -139,7 +152,7 @@ def set_cost( del lse_mode return self.set(reg_ot_cost=self.compute_reg_ot_cost(ot_prob, use_danskin)) - def compute_reg_ot_cost( + def compute_reg_ot_cost( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, use_danskin: bool = False, @@ -147,27 +160,27 @@ def compute_reg_ot_cost( return compute_reg_ot_cost(self.q, self.r, self.g, ot_prob, use_danskin) @property - def linear(self) -> bool: + def linear(self) -> bool: # noqa: D102 return isinstance(self.ot_prob, linear_problem.LinearProblem) @property - def geom(self) -> geometry.Geometry: + def geom(self) -> geometry.Geometry: # noqa: D102 return self.ot_prob.geom @property - def a(self) -> jnp.ndarray: + def a(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.a @property - def b(self) -> jnp.ndarray: + def b(self) -> jnp.ndarray: # noqa: D102 return self.ot_prob.b @property - def linear_output(self) -> bool: + def linear_output(self) -> bool: # noqa: D102 return True @property - def converged(self) -> bool: + def converged(self) -> bool: # noqa: D102 return jnp.logical_and( jnp.any(self.costs == -1), jnp.all(jnp.isfinite(self.costs)) ) @@ -183,7 +196,7 @@ def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray: # for `axis=0`: (batch, m), (m, r), (r,), (r, n) return ((inputs @ r) * self._inv_g) @ q.T - def marginal(self, axis: int) -> jnp.ndarray: + def marginal(self, axis: int) -> jnp.ndarray: # noqa: D102 length = self.q.shape[0] if axis == 0 else self.r.shape[0] return self.apply(jnp.ones(length,), axis=axis) @@ -369,6 +382,7 @@ def dykstra_update( inner_iter: int = 10, max_iter: int = 10000 ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Run Dykstra's algorithm.""" # shortcuts for problem's definition. r = self.rank n, m = ot_prob.geom.shape @@ -539,7 +553,7 @@ def one_iteration( ) @property - def norm_error(self) -> Tuple[int]: + def norm_error(self) -> Tuple[int]: # noqa: D102 return self._norm_error, @property diff --git a/src/ott/solvers/nn/conjugate_solvers.py b/src/ott/solvers/nn/conjugate_solvers.py index 6eb4d5f57..44d7a3d68 100644 --- a/src/ott/solvers/nn/conjugate_solvers.py +++ b/src/ott/solvers/nn/conjugate_solvers.py @@ -88,7 +88,7 @@ class FenchelConjugateLBFGS(FenchelConjugateSolver): decrease_factor: float = 0.66 ls_method: Literal['wolf', 'strong-wolfe'] = 'strong-wolfe' - def solve( + def solve( # noqa: D102 self, f: Callable[[jnp.ndarray], jnp.ndarray], y: jnp.ndarray, diff --git a/src/ott/solvers/nn/models.py b/src/ott/solvers/nn/models.py index e96c70163..a69772ee9 100644 --- a/src/ott/solvers/nn/models.py +++ b/src/ott/solvers/nn/models.py @@ -158,10 +158,10 @@ class ICNN(ModelBase): gaussian_map: Tuple[jnp.ndarray, jnp.ndarray] = None @property - def is_potential(self) -> bool: + def is_potential(self) -> bool: # noqa: D102 return True - def setup(self) -> None: + def setup(self) -> None: # noqa: D102 self.num_hidden = len(self.dim_hidden) if self.pos_weights: @@ -277,7 +277,7 @@ def _compute_identity_map(input_dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]: return A, b @nn.compact - def __call__(self, x: jnp.ndarray) -> float: + def __call__(self, x: jnp.ndarray) -> float: # noqa: D102 z = self.act_fn(self.w_xs[0](x)) for i in range(self.num_hidden): z = jnp.add(self.w_zs[i](z), self.w_xs[i + 1](x)) @@ -320,7 +320,7 @@ class MLP(ModelBase): act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # noqa: D102 squeeze = x.ndim == 1 if squeeze: x = jnp.expand_dims(x, 0) diff --git a/src/ott/solvers/nn/neuraldual.py b/src/ott/solvers/nn/neuraldual.py index 5e0fb5cd1..4630bfc88 100644 --- a/src/ott/solvers/nn/neuraldual.py +++ b/src/ott/solvers/nn/neuraldual.py @@ -220,7 +220,7 @@ def setup( self.valid_step_g = self.get_step_fn(train=False, to_optimize="g") self.train_fn = self.train_neuraldual_alternating - def __call__( + def __call__( # noqa: D102 self, trainloader_source: Iterable[jnp.ndarray], trainloader_target: Iterable[jnp.ndarray], diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index b105c6c81..c171f32ad 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -338,7 +338,7 @@ def warm_start(self) -> bool: """Whether to initialize (low-rank) Sinkhorn using previous solutions.""" return self.is_low_rank if self._warm_start is None else self._warm_start - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux_data = super().tree_flatten() aux_data["warm_start"] = self._warm_start aux_data["quad_initializer"] = self.quad_initializer diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index d382d0ebc..5c769a994 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -191,6 +191,7 @@ def update_state( problem: gw_barycenter.GWBarycenterProblem, store_errors: bool = True, ) -> Tuple[float, bool, jnp.ndarray, Optional[jnp.ndarray]]: + """Solve the (fused) Gromov-Wasserstein barycenter problem.""" def solve_gw( state: GWBarycenterState, b: jnp.ndarray, y: jnp.ndarray, @@ -239,12 +240,12 @@ def output_from_state(self, state: GWBarycenterState) -> GWBarycenterState: # will be refactored in the future to create an output return state - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children, aux = super().tree_flatten() return children + [self._quad_solver], aux @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "GromovWassersteinBarycenter": epsilon, _, threshold, quad_solver = children diff --git a/src/ott/solvers/was_solver.py b/src/ott/solvers/was_solver.py index e392f7bd5..c3ef3a07c 100644 --- a/src/ott/solvers/was_solver.py +++ b/src/ott/solvers/was_solver.py @@ -83,7 +83,7 @@ def is_low_rank(self) -> bool: """Whether the solver is low-rank.""" return self.rank > 0 - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return ([self.epsilon, self.linear_ot_solver, self.threshold], { "min_iterations": self.min_iterations, "max_iterations": self.max_iterations, @@ -94,7 +94,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: }) @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "WassersteinSolver": epsilon, linear_ot_solver, threshold = children diff --git a/src/ott/tools/gaussian_mixture/gaussian.py b/src/ott/tools/gaussian_mixture/gaussian.py index 10f74d423..6dfca0bc8 100644 --- a/src/ott/tools/gaussian_mixture/gaussian.py +++ b/src/ott/tools/gaussian_mixture/gaussian.py @@ -101,23 +101,29 @@ def from_mean_and_cov(cls, mean: jnp.ndarray, cov: jnp.ndarray) -> 'Gaussian': @property def loc(self) -> jnp.ndarray: + """Mean of the Gaussian.""" return self._loc @property def scale(self) -> scale_tril.ScaleTriL: + """Scale of the Gaussian.""" return self._scale @property def n_dimensions(self) -> int: + """Dimensionality of the Gaussian.""" return self.loc.shape[-1] def covariance(self) -> jnp.ndarray: + """Covariance of the Gaussian.""" return self.scale.covariance() def to_z(self, x: jnp.ndarray) -> jnp.ndarray: + """Transform x to z = (x - loc) / scale.""" return self.scale.centered_to_z(x_centered=x - self.loc) def from_z(self, z: jnp.ndarray) -> jnp.ndarray: + """Transform z to x = loc + scale * z.""" return self.scale.z_to_centered(z=z) + self.loc def log_prob( @@ -197,13 +203,13 @@ def transport(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: dest_scale=dest.scale, points=points - self.loc[None] ) + dest.loc[None] - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 children = (self.loc, self.scale) aux_data = {} return children, aux_data @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*children, **aux_data) def __hash__(self): diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture.py b/src/ott/tools/gaussian_mixture/gaussian_mixture.py index fc692f943..f39cb87f9 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture.py @@ -141,26 +141,32 @@ def from_points_and_assignment_probs( @property def dtype(self): + """Get the dtype of the GMM parameters.""" return self.loc.dtype @property def n_dimensions(self): + """Get the number of dimensions of the GMM parameters.""" return self._loc.shape[-1] @property def n_components(self): + """Get the number of components of the GMM parameters.""" return self._loc.shape[-2] @property def loc(self) -> jnp.ndarray: + """Get the location parameters of the GMM.""" return self._loc @property def scale_params(self) -> jnp.ndarray: + """Get the scale parameters of the GMM.""" return self._scale_params @property def cholesky(self) -> jnp.ndarray: + """Get the Cholesky decomposition of the GMM covariance matrices.""" size = self.n_dimensions def _get_cholesky(scale_params): @@ -170,6 +176,7 @@ def _get_cholesky(scale_params): @property def covariance(self) -> jnp.ndarray: + """Get the covariance matrices of the GMM.""" size = self.n_dimensions def _get_covariance(scale_params): @@ -179,13 +186,16 @@ def _get_covariance(scale_params): @property def component_weight_ob(self) -> probabilities.Probabilities: + """Get the component weight object.""" return self._component_weight_ob @property def component_weights(self) -> jnp.ndarray: + """Get the component weights probabilities.""" return self._component_weight_ob.probs() def log_component_weights(self) -> jnp.ndarray: + """Get the log component weights probabilities.""" return self._component_weight_ob.log_probs() def _get_normal( @@ -288,19 +298,19 @@ def get_log_component_posterior(self, x: jnp.ndarray) -> jnp.ndarray: log_prob_unnorm, axis=-1, keepdims=True ) - def has_nans(self) -> bool: + def has_nans(self) -> bool: # noqa: D102 for leaf in jax.tree_util.tree_leaves(self): if jnp.any(~jnp.isfinite(leaf)): return True return False - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 children = (self.loc, self.scale_params, self.component_weight_ob) aux_data = {} return children, aux_data @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*children, **aux_data) def __repr__(self): diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py index e1ec650f4..75b875924 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py @@ -85,31 +85,31 @@ def __init__( self._lock_gmm1 = lock_gmm1 @property - def dtype(self): + def dtype(self): # noqa: D102 return self.gmm0.dtype @property - def gmm0(self): + def gmm0(self): # noqa: D102 return self._gmm0 @property - def gmm1(self): + def gmm1(self): # noqa: D102 return self._gmm1 @property - def epsilon(self): + def epsilon(self): # noqa: D102 return self._epsilon @property - def tau(self): + def tau(self): # noqa: D102 return self._tau @property - def rho(self): + def rho(self): # noqa: D102 return self.epsilon * self.tau / (1. - self.tau) @property - def lock_gmm1(self): + def lock_gmm1(self): # noqa: D102 return self._lock_gmm1 def get_bures_geometry(self) -> pointcloud.PointCloud: diff --git a/src/ott/tools/gaussian_mixture/probabilities.py b/src/ott/tools/gaussian_mixture/probabilities.py index 13369024e..30c8e4f59 100644 --- a/src/ott/tools/gaussian_mixture/probabilities.py +++ b/src/ott/tools/gaussian_mixture/probabilities.py @@ -57,36 +57,40 @@ def from_probs(cls, probs: jnp.ndarray) -> 'Probabilities': return cls(params=log_probs_normalized) @property - def params(self): + def params(self): # noqa: D102 return self._params @property - def dtype(self): + def dtype(self): # noqa: D102 return self._params.dtype def unnormalized_log_probs(self) -> jnp.ndarray: + """Get the unnormalized log probabilities.""" return jnp.concatenate([self._params, jnp.zeros((1,), dtype=self.dtype)], axis=-1) def log_probs(self) -> jnp.ndarray: + """Get the log probabilities.""" return jax.nn.log_softmax(self.unnormalized_log_probs()) def probs(self) -> jnp.ndarray: + """Get the probabilities.""" return jax.nn.softmax(self.unnormalized_log_probs()) def sample(self, key: jnp.ndarray, size: int) -> jnp.ndarray: + """Sample from the distribution.""" return jax.random.categorical( key=key, logits=self.unnormalized_log_probs(), shape=(size,) ) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 children = (self.params,) aux_data = {} return children, aux_data @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*children, **aux_data) def __repr__(self): diff --git a/src/ott/tools/gaussian_mixture/scale_tril.py b/src/ott/tools/gaussian_mixture/scale_tril.py index 52ed9a90b..697c3e045 100644 --- a/src/ott/tools/gaussian_mixture/scale_tril.py +++ b/src/ott/tools/gaussian_mixture/scale_tril.py @@ -196,13 +196,13 @@ def transport( m = self.gaussian_map(dest_scale) return (m @ points.T).T - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 children = (self.params,) aux_data = {'size': self.size} return children, aux_data @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*children, **aux_data) def __repr__(self): From 7d2e5f85d01a315c50550a4d39aa52f1cb5a490c Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 14 Feb 2023 14:17:16 +0100 Subject: [PATCH 5/9] fix D103 --- .flake8 | 3 +-- src/ott/math/matrix_square_root.py | 4 ++-- src/ott/math/unbalanced_functions.py | 2 +- src/ott/math/utils.py | 20 +++++++++++++++++-- src/ott/problems/quadratic/quadratic_costs.py | 4 ++-- src/ott/utils.py | 3 ++- 6 files changed, 26 insertions(+), 10 deletions(-) diff --git a/.flake8 b/.flake8 index 6d5729d78..c1f22492e 100644 --- a/.flake8 +++ b/.flake8 @@ -43,12 +43,11 @@ ignore = D411 # TODO(michalk8): fix D10{1,2,3} # Missing docstring in public function - D103 exclude = .git,__pycache__,build,docs/_build,dist # C409: Unnecessary call - rewrite as a literal. per-file-ignores = tests/*: D,C408 */__init__.py: F401 - examples/*: D101, D102 + examples/*: D101, D102, D103 docs/*: D101, D102 src/ott/types.py: D102 diff --git a/src/ott/math/matrix_square_root.py b/src/ott/math/matrix_square_root.py index 1db246d9f..3bc9c83aa 100644 --- a/src/ott/math/matrix_square_root.py +++ b/src/ott/math/matrix_square_root.py @@ -236,7 +236,7 @@ def sqrtm_bwd( @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) -def sqrtm_only( +def sqrtm_only( # noqa: D103 x: jnp.ndarray, threshold: float = 1e-6, min_iterations: int = 0, @@ -282,7 +282,7 @@ def sqrtm_only_bwd( @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5)) -def inv_sqrtm_only( +def inv_sqrtm_only( # noqa: D103 x: jnp.ndarray, threshold: float = 1e-6, min_iterations: int = 0, diff --git a/src/ott/math/unbalanced_functions.py b/src/ott/math/unbalanced_functions.py index 4a6aacff5..e55f6f81b 100644 --- a/src/ott/math/unbalanced_functions.py +++ b/src/ott/math/unbalanced_functions.py @@ -87,5 +87,5 @@ def diag_jacobian_of_marginal_fit( ) -def rho(epsilon: float, tau: float) -> float: +def rho(epsilon: float, tau: float) -> float: # noqa: D103 return (epsilon * tau) / (1. - tau) diff --git a/src/ott/math/utils.py b/src/ott/math/utils.py index fef5ae667..865d09103 100644 --- a/src/ott/math/utils.py +++ b/src/ott/math/utils.py @@ -17,7 +17,11 @@ Sparse_t = Union[jesp.CSR, jesp.CSC, jesp.COO, jesp.BCOO] -def safe_log(x: jnp.ndarray, *, eps: Optional[float] = None) -> jnp.ndarray: +def safe_log( # noqa: D103 + x: jnp.ndarray, + *, + eps: Optional[float] = None +) -> jnp.ndarray: if eps is None: eps = jnp.finfo(x.dtype).tiny return jnp.where(x > 0., jnp.log(x), jnp.log(eps)) @@ -43,7 +47,9 @@ def sparse_scale(c: float, mat: Sparse_t) -> Sparse_t: @functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 4)) -def logsumexp(mat, axis=None, keepdims=False, b=None, return_sign=False): +def logsumexp( # noqa: D103 + mat, axis=None, keepdims=False, b=None, return_sign=False +): return jax.scipy.special.logsumexp( mat, axis=axis, keepdims=keepdims, b=b, return_sign=return_sign ) @@ -98,4 +104,14 @@ def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents): def barycentric_projection( matrix: jnp.ndarray, y: jnp.ndarray, cost_fn: "costs.CostFn" ) -> jnp.ndarray: + """Compute the barycentric projection of a matrix. + + Args: + matrix: a matrix of shape (n, m) + y: a vector of shape (m,) + cost_fn: a CostFn instance. + + Returns: + a vector of shape (n,) containing the barycentric projection of matrix. + """ return jax.vmap(cost_fn.barycenter, in_axes=[0, None])(matrix, y) diff --git a/src/ott/problems/quadratic/quadratic_costs.py b/src/ott/problems/quadratic/quadratic_costs.py index 8de1b398d..2460dfbf7 100644 --- a/src/ott/problems/quadratic/quadratic_costs.py +++ b/src/ott/problems/quadratic/quadratic_costs.py @@ -18,7 +18,7 @@ class GWLoss(NamedTuple): h2: Loss -def make_square_loss() -> GWLoss: +def make_square_loss() -> GWLoss: # noqa: D103 f1 = Loss(lambda x: x ** 2, is_linear=False) f2 = Loss(lambda y: y ** 2, is_linear=False) h1 = Loss(lambda x: x, is_linear=True) @@ -26,7 +26,7 @@ def make_square_loss() -> GWLoss: return GWLoss(f1, f2, h1, h2) -def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss: +def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss: # noqa: D103 f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False) f2 = Loss(lambda y: y, is_linear=True) h1 = Loss(lambda x: x, is_linear=True) diff --git a/src/ott/utils.py b/src/ott/utils.py index aac51b701..cc0b73f18 100644 --- a/src/ott/utils.py +++ b/src/ott/utils.py @@ -32,7 +32,7 @@ def register_pytree_node(cls: type) -> type: return cls -def deprecate( +def deprecate( # noqa: D103 *, version: Optional[str] = None, alt: Optional[str] = None, @@ -55,6 +55,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: def is_jax_array(obj: Any) -> bool: + """Check if an object is a Jax array.""" if hasattr(jax, "Array"): # https://jax.readthedocs.io/en/latest/jax_array_migration.html return isinstance(obj, (jax.Array, jnp.DeviceArray)) From f16a08d5510aececb68bb4b1925ac969b9c15576 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 14 Feb 2023 14:18:23 +0100 Subject: [PATCH 6/9] remove comment from .flake8 --- .flake8 | 2 -- 1 file changed, 2 deletions(-) diff --git a/.flake8 b/.flake8 index c1f22492e..363a33365 100644 --- a/.flake8 +++ b/.flake8 @@ -41,8 +41,6 @@ ignore = E111, E114 # Missing blank line before section D411 - # TODO(michalk8): fix D10{1,2,3} - # Missing docstring in public function exclude = .git,__pycache__,build,docs/_build,dist # C409: Unnecessary call - rewrite as a literal. per-file-ignores = From f6c590dbe1e7c8427120a2a12a21fb208f682a49 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 14 Feb 2023 15:40:55 +0100 Subject: [PATCH 7/9] address comments --- src/ott/problems/nn/dataset.py | 2 +- src/ott/solvers/linear/sinkhorn_lr.py | 4 ++-- .../gaussian_mixture/gaussian_mixture.py | 24 +++++++++---------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/ott/problems/nn/dataset.py b/src/ott/problems/nn/dataset.py index ec89eb1d0..a6208bcfa 100644 --- a/src/ott/problems/nn/dataset.py +++ b/src/ott/problems/nn/dataset.py @@ -88,7 +88,7 @@ def __iter__(self) -> Iterator[jnp.array]: return self.create_sample_generators() def create_sample_generators(self) -> Iterator[jnp.array]: - """Creates a generator of samples from the Gaussian mixture. + """Random sample generator from Gaussian mixture. Returns: A generator of samples from the Gaussian mixture. diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index 87e51423c..a4976b5af 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -78,8 +78,8 @@ def compute_reg_ot_cost( r: second factor of solution g: weights of solution ot_prob: linear problem - use_danskin: if True, use Danskin's trick to avoid computing the gradient of - the cost function. + use_danskin: if True, use Danskin's theorem :cite:`danskin:67,bertsekas:71` + to avoid computing the gradient of the cost function. Returns: regularized OT cost diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture.py b/src/ott/tools/gaussian_mixture/gaussian_mixture.py index f39cb87f9..aeb3f47b8 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture.py @@ -141,32 +141,32 @@ def from_points_and_assignment_probs( @property def dtype(self): - """Get the dtype of the GMM parameters.""" + """Dtype of the GMM parameters.""" return self.loc.dtype @property def n_dimensions(self): - """Get the number of dimensions of the GMM parameters.""" + """Number of dimensions of the GMM parameters.""" return self._loc.shape[-1] @property def n_components(self): - """Get the number of components of the GMM parameters.""" + """Number of components of the GMM parameters.""" return self._loc.shape[-2] @property def loc(self) -> jnp.ndarray: - """Get the location parameters of the GMM.""" + """Location parameters of the GMM.""" return self._loc @property def scale_params(self) -> jnp.ndarray: - """Get the scale parameters of the GMM.""" + """Scale parameters of the GMM.""" return self._scale_params @property def cholesky(self) -> jnp.ndarray: - """Get the Cholesky decomposition of the GMM covariance matrices.""" + """Cholesky decomposition of the GMM covariance matrices.""" size = self.n_dimensions def _get_cholesky(scale_params): @@ -176,7 +176,7 @@ def _get_cholesky(scale_params): @property def covariance(self) -> jnp.ndarray: - """Get the covariance matrices of the GMM.""" + """Covariance matrices of the GMM.""" size = self.n_dimensions def _get_covariance(scale_params): @@ -186,16 +186,16 @@ def _get_covariance(scale_params): @property def component_weight_ob(self) -> probabilities.Probabilities: - """Get the component weight object.""" + """Component weight object.""" return self._component_weight_ob @property def component_weights(self) -> jnp.ndarray: - """Get the component weights probabilities.""" + """Component weights probabilities.""" return self._component_weight_ob.probs() def log_component_weights(self) -> jnp.ndarray: - """Get the log component weights probabilities.""" + """Log component weights probabilities.""" return self._component_weight_ob.log_probs() def _get_normal( @@ -207,13 +207,13 @@ def _get_normal( ) def get_component(self, index: int) -> gaussian.Gaussian: - """Get the specified GMM component.""" + """Specified GMM component.""" return self._get_normal( loc=self.loc[index], scale_params=self.scale_params[index] ) def components(self) -> List[gaussian.Gaussian]: - """Get a list of all GMM components.""" + """List of all GMM components.""" return [self.get_component(i) for i in range(self.n_components)] def sample(self, key: jnp.ndarray, size: int) -> jnp.ndarray: From 6be84472408faa8b42c3a3fab07599f193ddcf2b Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 14 Feb 2023 15:42:30 +0100 Subject: [PATCH 8/9] more comments --- src/ott/geometry/costs.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index 9fc4a87af..38e23fc85 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -126,7 +126,6 @@ class TICost(CostFn): strictly convex, as well as provide the Legendre transform of :math:`h`, whose gradient is necessarily the inverse of the gradient of :math:`h`. """ - p = 1.0 @abc.abstractmethod def h(self, z: jnp.ndarray) -> float: @@ -272,8 +271,6 @@ class RegTICost(TICost, abc.ABC): where :func:`reg` is the regularization function. """ - gamma = 0 - @abc.abstractmethod def reg(self, z: jnp.ndarray) -> float: """Regularization function.""" From b5a9b047f94e36f7091db40f38edb6f22510bcb9 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 14 Feb 2023 16:00:04 +0100 Subject: [PATCH 9/9] re-introduce tree-flatten in costs --- src/ott/geometry/costs.py | 53 +++++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index 38e23fc85..9a6ab89fb 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -139,14 +139,6 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: """Compute cost as evaluation of :func:`h` on :math:`x-y`.""" return self.h(x - y) - def tree_flatten(self): - return (), (self.p,) - - @classmethod - def tree_unflatten(cls, aux_data, children): - del children - return cls(aux_data[0]) - @jax.tree_util.register_pytree_node_class class SqPNorm(TICost): @@ -172,6 +164,14 @@ def h_legendre(self, z: jnp.ndarray) -> float: """ return 0.5 * jnp.linalg.norm(z, self.q) ** 2 + def tree_flatten(self): # noqa: D102 + return (), (self.p,) + + @classmethod + def tree_unflatten(cls, aux_data, children): # noqa: D102 + del children + return cls(aux_data[0]) + @jax.tree_util.register_pytree_node_class class PNormP(TICost): @@ -195,6 +195,14 @@ def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 assert self.q < jnp.inf, "Legendre transform not defined for `p=1.0`" return jnp.linalg.norm(z, self.q) ** self.q / self.q + def tree_flatten(self): # noqa: D102 + return (), (self.p,) + + @classmethod + def tree_unflatten(cls, aux_data, children): # noqa: D102 + del children + return cls(aux_data[0]) + @jax.tree_util.register_pytree_node_class class Euclidean(CostFn): @@ -286,14 +294,6 @@ def h_legendre(self, z: jnp.ndarray) -> float: q = jax.lax.stop_gradient(self.prox_reg(z)) return jnp.sum(q * z) - self.h(q) - def tree_flatten(self): #noqa: D102 - return (), (self.gamma,) - - @classmethod - def tree_unflatten(cls, aux_data, children): #noqa: D102 - del children - return cls(*aux_data) - @jax.tree_util.register_pytree_node_class class ElasticL1(RegTICost): @@ -318,6 +318,14 @@ def reg(self, z: jnp.ndarray) -> float: # noqa: D102 def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102 return jnp.sign(z) * jax.nn.relu(jnp.abs(z) - self.gamma) + def tree_flatten(self): # noqa: D102 + return (), (self.gamma,) + + @classmethod + def tree_unflatten(cls, aux_data, children): # noqa: D102 + del children + return cls(*aux_data) + @jax.tree_util.register_pytree_node_class class ElasticSTVS(RegTICost): @@ -348,6 +356,14 @@ def reg(self, z: jnp.ndarray) -> float: # noqa: D102 def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102 return jax.nn.relu(1 - (self.gamma / (jnp.abs(z) + 1e-12)) ** 2) * z + def tree_flatten(self): # noqa: D102 + return (), (self.gamma,) + + @classmethod + def tree_unflatten(cls, aux_data, children): # noqa: D102 + del children + return cls(*aux_data) + @jax.tree_util.register_pytree_node_class class ElasticSqKOverlap(RegTICost): @@ -441,6 +457,11 @@ def inner(r: int, l: int, def tree_flatten(self): # noqa: D102 return (), (self.k, self.gamma) + @classmethod + def tree_unflatten(cls, aux_data, children): # noqa: D102 + del children + return cls(*aux_data) + @jax.tree_util.register_pytree_node_class class Bures(CostFn):