Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix flake8 D10{1,2,3} errors #269

Merged
merged 10 commits into from
Feb 14, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,11 @@ ignore =
E111, E114
# Missing blank line before section
D411
# TODO(michalk8): fix D10{1,2,3}
# D101 Missing docstring in public class
D101
# Missing docstring in public method
D102
# Missing docstring in public function
D103
exclude = .git,__pycache__,build,docs/_build,dist
# C409: Unnecessary <dict/list/tuple> call - rewrite as a literal.
per-file-ignores =
tests/*: D,C408
*/__init__.py: F401
examples/*: D101, D102, D103
docs/*: D101, D102
src/ott/types.py: D102
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# vscode
.vscode/

# generated documentation
docs/html
**/_autosummary
Expand Down
85 changes: 32 additions & 53 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute cost as evaluation of :func:`h` on :math:`x-y`."""
return self.h(x - y)

def tree_flatten(self):
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
return (), (self.p,)

@classmethod
def tree_unflatten(cls, aux_data, children):
del children
return cls(aux_data[0])


@jax.tree_util.register_pytree_node_class
class SqPNorm(TICost):
Expand All @@ -154,7 +162,7 @@ def __init__(self, p: float):
self.p = p
self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf

def h(self, z: jnp.ndarray) -> float:
def h(self, z: jnp.ndarray) -> float: # noqa: D102
return 0.5 * jnp.linalg.norm(z, self.p) ** 2

def h_legendre(self, z: jnp.ndarray) -> float:
Expand All @@ -164,14 +172,6 @@ def h_legendre(self, z: jnp.ndarray) -> float:
"""
return 0.5 * jnp.linalg.norm(z, self.q) ** 2

def tree_flatten(self):
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
return (), (self.p,)

@classmethod
def tree_unflatten(cls, aux_data, children):
del children
return cls(aux_data[0])


@jax.tree_util.register_pytree_node_class
class PNormP(TICost):
Expand All @@ -188,21 +188,13 @@ def __init__(self, p: float):
self.p = p
self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf

def h(self, z: jnp.ndarray) -> float:
def h(self, z: jnp.ndarray) -> float: # noqa: D102
return jnp.linalg.norm(z, self.p) ** self.p / self.p

def h_legendre(self, z: jnp.ndarray) -> float:
def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102
assert self.q < jnp.inf, "Legendre transform not defined for `p=1.0`"
return jnp.linalg.norm(z, self.q) ** self.q / self.q

def tree_flatten(self):
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
return (), (self.p,)

@classmethod
def tree_unflatten(cls, aux_data, children):
del children
return cls(aux_data[0])


@jax.tree_util.register_pytree_node_class
class Euclidean(CostFn):
Expand Down Expand Up @@ -231,10 +223,10 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute minus twice the dot-product between vectors."""
return -2. * jnp.vdot(x, y)

def h(self, z: jnp.ndarray) -> float:
def h(self, z: jnp.ndarray) -> float: # noqa: D102
return jnp.sum(z ** 2)

def h_legendre(self, z: jnp.ndarray) -> float:
def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102
return 0.25 * jnp.sum(z ** 2)

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
Expand Down Expand Up @@ -294,6 +286,14 @@ def h_legendre(self, z: jnp.ndarray) -> float:
q = jax.lax.stop_gradient(self.prox_reg(z))
return jnp.sum(q * z) - self.h(q)

def tree_flatten(self): #noqa: D102
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
return (), (self.gamma,)

@classmethod
def tree_unflatten(cls, aux_data, children): #noqa: D102
del children
return cls(*aux_data)


@jax.tree_util.register_pytree_node_class
class ElasticL1(RegTICost):
Expand All @@ -312,20 +312,12 @@ def __init__(self, gamma: float = 1.0):
assert gamma >= 0, "Gamma must be non-negative."
self.gamma = gamma

def reg(self, z: jnp.ndarray) -> float:
def reg(self, z: jnp.ndarray) -> float: # noqa: D102
return self.gamma * jnp.linalg.norm(z, ord=1)

def prox_reg(self, z: jnp.ndarray) -> float:
def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102
return jnp.sign(z) * jax.nn.relu(jnp.abs(z) - self.gamma)

def tree_flatten(self):
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
return (), (self.gamma,)

@classmethod
def tree_unflatten(cls, aux_data, children):
del children
return cls(*aux_data)


@jax.tree_util.register_pytree_node_class
class ElasticSTVS(RegTICost):
Expand All @@ -348,22 +340,14 @@ def __init__(self, gamma: float = 1.0):
assert gamma > 0, "Gamma must be positive."
self.gamma = gamma

def reg(self, z: jnp.ndarray) -> float:
def reg(self, z: jnp.ndarray) -> float: # noqa: D102
u = jnp.arcsinh(jnp.abs(z) / (2 * self.gamma))
out = u - 0.5 * jnp.exp(-2.0 * u)
return (self.gamma ** 2) * jnp.sum(out + 0.5) # make positive

def prox_reg(self, z: jnp.ndarray) -> float:
def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102
return jax.nn.relu(1 - (self.gamma / (jnp.abs(z) + 1e-12)) ** 2) * z

def tree_flatten(self):
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
return (), (self.gamma,)

@classmethod
def tree_unflatten(cls, aux_data, children):
del children
return cls(*aux_data)


@jax.tree_util.register_pytree_node_class
class ElasticSqKOverlap(RegTICost):
Expand All @@ -388,7 +372,7 @@ def __init__(self, k: int, gamma: float = 1.0):
self.k = k
self.gamma = gamma

def reg(self, z: jnp.ndarray) -> float:
def reg(self, z: jnp.ndarray) -> float: # noqa: D102
# Prop 2.1 in :cite:`argyriou:12`
k = self.k
top_w = jax.lax.top_k(jnp.abs(z), k)[0] # Fetch largest k values
Expand All @@ -409,7 +393,7 @@ def reg(self, z: jnp.ndarray) -> float:

return 0.5 * self.gamma * (s + (r + 1) * cesaro[r] ** 2)

def prox_reg(self, z: jnp.ndarray) -> float:
def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102

@functools.partial(jax.vmap, in_axes=[0, None, None])
def find_indices(r: int, l: jnp.ndarray,
Expand Down Expand Up @@ -454,14 +438,9 @@ def inner(r: int, l: int,
# change sign and reorder
return sgn * q[jnp.argsort(z_ixs.astype(float))]

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (), (self.k, self.gamma)

@classmethod
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
def tree_unflatten(cls, aux_data, children):
del children
return cls(*aux_data)


@jax.tree_util.register_pytree_node_class
class Bures(CostFn):
Expand Down Expand Up @@ -606,11 +585,11 @@ def _padder(cls, dim: int) -> jnp.ndarray:
)
return padding[jnp.newaxis, :]

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (), (self._dimension, self._sqrtm_kw)

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(aux_data[0], **aux_data[1])

Expand Down Expand Up @@ -718,11 +697,11 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
(sig2 + gam) * jnp.exp(log_m_pi), lambda: jnp.nan
)

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (), (self._dimension, self._sigma, self._gamma, self._sqrtm_kw)

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
dim, sigma, gamma, kwargs = aux_data
return cls(dim, sigma=sigma, gamma=gamma, **kwargs)
Expand Down
6 changes: 4 additions & 2 deletions src/ott/geometry/epsilon_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,20 @@ def at(self, iteration: Optional[int] = 1) -> float:
return multiple * self.target

def done(self, eps: float) -> bool:
"""Return whether the scheduler is done at a given value."""
return eps == self.target

def done_at(self, iteration: Optional[int]) -> bool:
"""Return whether the scheduler is done at a given iteration."""
return self.done(self.at(iteration))

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (
self._target_init, self._scale_epsilon, self._init, self._decay
), None

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
del aux_data
return cls(*children)

Expand Down
4 changes: 2 additions & 2 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@ def _normalize_mask(mask: Optional[Union[int, jnp.ndarray]],
assert mask.shape == (size,)
return mask

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (
self._cost_matrix, self._kernel_matrix, self._epsilon_init,
self._relative_epsilon, self._scale_epsilon, self._src_mask,
Expand All @@ -910,7 +910,7 @@ def tree_flatten(self):
}

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
*args, kwargs = children
return cls(*args, **kwargs, **aux_data)

Expand Down
26 changes: 18 additions & 8 deletions src/ott/geometry/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ def apply_kernel(
eps: Optional[float] = None,
axis: int = 0,
) -> jnp.ndarray:
r"""Apply :attr:`kernel_matrix` on positive scaling vector.

Args:
scaling: Scaling to apply the kernel to.
eps: passed for consistency, not used yet.
axis: passed for consistency, not used yet.

Returns:
Kernel applied to ``scaling``.
"""

def conf_fn(
iteration: int, solver_lap: Tuple[decomposition.CholeskySolver,
Expand Down Expand Up @@ -145,7 +155,7 @@ def body_fn(
state=state,
)[1]

def apply_transport_from_scalings(
def apply_transport_from_scalings( # noqa: D102
self,
u: jnp.ndarray,
v: jnp.ndarray,
Expand All @@ -171,15 +181,15 @@ def body_fn(carry: None, vec: jnp.ndarray) -> jnp.ndarray:
return res

@property
def kernel_matrix(self) -> jnp.ndarray:
def kernel_matrix(self) -> jnp.ndarray: # noqa: D102
n, _ = self.shape
kernel = self.apply_kernel(jnp.eye(n))
# force symmetry because of numerical imprecisions
# happens when `numerical_scheme='backward_euler'` and small `t`
return (kernel + kernel.T) * .5

@property
def cost_matrix(self) -> jnp.ndarray:
def cost_matrix(self) -> jnp.ndarray: # noqa: D102
return -self.t * mu.safe_log(self.kernel_matrix)

@property
Expand Down Expand Up @@ -274,7 +284,7 @@ def solver(self) -> decomposition.CholeskySolver:
return self._solver

@property
def shape(self) -> Tuple[int, int]:
def shape(self) -> Tuple[int, int]: # noqa: D102
arr = self._graph if self._graph is not None else self._lap
return arr.shape

Expand All @@ -295,12 +305,12 @@ def graph(self) -> Optional[Union[jnp.ndarray, jesp.BCOO]]:
return (self._graph + self._graph.T) if self.directed else self._graph

@property
def is_symmetric(self) -> bool:
def is_symmetric(self) -> bool: # noqa: D102
# there may be some numerical imprecisions, but it should be symmetric
return True

@property
def dtype(self) -> jnp.dtype:
def dtype(self) -> jnp.dtype: # noqa: D102
return self._graph.dtype

# TODO(michalk8): in future, use mixins for lse/kernel mode
Expand Down Expand Up @@ -330,7 +340,7 @@ def marginal_from_potentials(
"""Not implemented."""
raise ValueError("Not implemented.")

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
return [self._graph, self._lap, self.solver], {
"t": self._t,
"n_steps": self.n_steps,
Expand All @@ -342,7 +352,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
}

@classmethod
def tree_unflatten(
def tree_unflatten( # noqa: D102
cls, aux_data: Dict[str, Any], children: Sequence[Any]
) -> "Graph":
graph, laplacian, solver = children
Expand Down
12 changes: 6 additions & 6 deletions src/ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,15 @@ def median_cost_matrix(self) -> NoReturn:
raise NotImplementedError('Median cost not implemented for grids.')

@property
def can_LRC(self) -> bool:
def can_LRC(self) -> bool: # noqa: D102
return True

@property
def shape(self) -> Tuple[int, int]:
def shape(self) -> Tuple[int, int]: # noqa: D102
return self.num_a, self.num_a

@property
def is_symmetric(self) -> bool:
def is_symmetric(self) -> bool: # noqa: D102
return True

# Reimplemented functions to be used in regularized OT
Expand Down Expand Up @@ -341,14 +341,14 @@ def prepare_divergences(
return tuple(sep_grid for _ in range(size))

@property
def dtype(self) -> jnp.dtype:
def dtype(self) -> jnp.dtype: # noqa: D102
return self.x[0].dtype

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (self.x, self.cost_fns, self._epsilon), self.kwargs

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(
x=children[0], cost_fns=children[1], epsilon=children[2], **aux_data
)
Expand Down
Loading