Skip to content

Commit

Permalink
Fix flake8 D10{1,2,3} errors (ott-jax#269)
Browse files Browse the repository at this point in the history
* fix D101 and B028 (stack level for warnings) introduced in ott-jax#219

* add vscode to gitignore

* fix D102 on costs.py and try to simplify inheritance methods

* try to fix D102 everywhere

* fix D103

* remove comment from .flake8

* address comments

* more comments

* re-introduce tree-flatten in costs
  • Loading branch information
giovp authored and pierreablin committed Feb 14, 2023
1 parent c57de59 commit 8a70047
Show file tree
Hide file tree
Showing 42 changed files with 259 additions and 178 deletions.
10 changes: 3 additions & 7 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,11 @@ ignore =
E111, E114
# Missing blank line before section
D411
# TODO(michalk8): fix D10{1,2,3}
# D101 Missing docstring in public class
D101
# Missing docstring in public method
D102
# Missing docstring in public function
D103
exclude = .git,__pycache__,build,docs/_build,dist
# C409: Unnecessary <dict/list/tuple> call - rewrite as a literal.
per-file-ignores =
tests/*: D,C408
*/__init__.py: F401
examples/*: D101, D102, D103
docs/*: D101, D102
src/ott/types.py: D102
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# vscode
.vscode/

# generated documentation
docs/html
**/_autosummary
Expand Down
50 changes: 25 additions & 25 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def __init__(self, p: float):
self.p = p
self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf

def h(self, z: jnp.ndarray) -> float:
def h(self, z: jnp.ndarray) -> float: # noqa: D102
return 0.5 * jnp.linalg.norm(z, self.p) ** 2

def h_legendre(self, z: jnp.ndarray) -> float:
Expand All @@ -164,11 +164,11 @@ def h_legendre(self, z: jnp.ndarray) -> float:
"""
return 0.5 * jnp.linalg.norm(z, self.q) ** 2

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (), (self.p,)

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(aux_data[0])

Expand All @@ -188,18 +188,18 @@ def __init__(self, p: float):
self.p = p
self.q = 1. / (1. - 1. / self.p) if p > 1.0 else jnp.inf

def h(self, z: jnp.ndarray) -> float:
def h(self, z: jnp.ndarray) -> float: # noqa: D102
return jnp.linalg.norm(z, self.p) ** self.p / self.p

def h_legendre(self, z: jnp.ndarray) -> float:
def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102
assert self.q < jnp.inf, "Legendre transform not defined for `p=1.0`"
return jnp.linalg.norm(z, self.q) ** self.q / self.q

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (), (self.p,)

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(aux_data[0])

Expand Down Expand Up @@ -231,10 +231,10 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute minus twice the dot-product between vectors."""
return -2. * jnp.vdot(x, y)

def h(self, z: jnp.ndarray) -> float:
def h(self, z: jnp.ndarray) -> float: # noqa: D102
return jnp.sum(z ** 2)

def h_legendre(self, z: jnp.ndarray) -> float:
def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102
return 0.25 * jnp.sum(z ** 2)

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
Expand Down Expand Up @@ -312,17 +312,17 @@ def __init__(self, gamma: float = 1.0):
assert gamma >= 0, "Gamma must be non-negative."
self.gamma = gamma

def reg(self, z: jnp.ndarray) -> float:
def reg(self, z: jnp.ndarray) -> float: # noqa: D102
return self.gamma * jnp.linalg.norm(z, ord=1)

def prox_reg(self, z: jnp.ndarray) -> float:
def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102
return jnp.sign(z) * jax.nn.relu(jnp.abs(z) - self.gamma)

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (), (self.gamma,)

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(*aux_data)

Expand All @@ -348,19 +348,19 @@ def __init__(self, gamma: float = 1.0):
assert gamma > 0, "Gamma must be positive."
self.gamma = gamma

def reg(self, z: jnp.ndarray) -> float:
def reg(self, z: jnp.ndarray) -> float: # noqa: D102
u = jnp.arcsinh(jnp.abs(z) / (2 * self.gamma))
out = u - 0.5 * jnp.exp(-2.0 * u)
return (self.gamma ** 2) * jnp.sum(out + 0.5) # make positive

def prox_reg(self, z: jnp.ndarray) -> float:
def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102
return jax.nn.relu(1 - (self.gamma / (jnp.abs(z) + 1e-12)) ** 2) * z

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (), (self.gamma,)

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(*aux_data)

Expand Down Expand Up @@ -388,7 +388,7 @@ def __init__(self, k: int, gamma: float = 1.0):
self.k = k
self.gamma = gamma

def reg(self, z: jnp.ndarray) -> float:
def reg(self, z: jnp.ndarray) -> float: # noqa: D102
# Prop 2.1 in :cite:`argyriou:12`
k = self.k
top_w = jax.lax.top_k(jnp.abs(z), k)[0] # Fetch largest k values
Expand All @@ -409,7 +409,7 @@ def reg(self, z: jnp.ndarray) -> float:

return 0.5 * self.gamma * (s + (r + 1) * cesaro[r] ** 2)

def prox_reg(self, z: jnp.ndarray) -> float:
def prox_reg(self, z: jnp.ndarray) -> float: # noqa: D102

@functools.partial(jax.vmap, in_axes=[0, None, None])
def find_indices(r: int, l: jnp.ndarray,
Expand Down Expand Up @@ -454,11 +454,11 @@ def inner(r: int, l: int,
# change sign and reorder
return sgn * q[jnp.argsort(z_ixs.astype(float))]

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (), (self.k, self.gamma)

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(*aux_data)

Expand Down Expand Up @@ -606,11 +606,11 @@ def _padder(cls, dim: int) -> jnp.ndarray:
)
return padding[jnp.newaxis, :]

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (), (self._dimension, self._sqrtm_kw)

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
return cls(aux_data[0], **aux_data[1])

Expand Down Expand Up @@ -718,11 +718,11 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
(sig2 + gam) * jnp.exp(log_m_pi), lambda: jnp.nan
)

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (), (self._dimension, self._sigma, self._gamma, self._sqrtm_kw)

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
del children
dim, sigma, gamma, kwargs = aux_data
return cls(dim, sigma=sigma, gamma=gamma, **kwargs)
Expand Down
6 changes: 4 additions & 2 deletions src/ott/geometry/epsilon_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,20 @@ def at(self, iteration: Optional[int] = 1) -> float:
return multiple * self.target

def done(self, eps: float) -> bool:
"""Return whether the scheduler is done at a given value."""
return eps == self.target

def done_at(self, iteration: Optional[int]) -> bool:
"""Return whether the scheduler is done at a given iteration."""
return self.done(self.at(iteration))

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (
self._target_init, self._scale_epsilon, self._init, self._decay
), None

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
del aux_data
return cls(*children)

Expand Down
4 changes: 2 additions & 2 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@ def _normalize_mask(mask: Optional[Union[int, jnp.ndarray]],
assert mask.shape == (size,)
return mask

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (
self._cost_matrix, self._kernel_matrix, self._epsilon_init,
self._relative_epsilon, self._scale_epsilon, self._src_mask,
Expand All @@ -910,7 +910,7 @@ def tree_flatten(self):
}

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
*args, kwargs = children
return cls(*args, **kwargs, **aux_data)

Expand Down
26 changes: 18 additions & 8 deletions src/ott/geometry/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ def apply_kernel(
eps: Optional[float] = None,
axis: int = 0,
) -> jnp.ndarray:
r"""Apply :attr:`kernel_matrix` on positive scaling vector.
Args:
scaling: Scaling to apply the kernel to.
eps: passed for consistency, not used yet.
axis: passed for consistency, not used yet.
Returns:
Kernel applied to ``scaling``.
"""

def conf_fn(
iteration: int, solver_lap: Tuple[decomposition.CholeskySolver,
Expand Down Expand Up @@ -158,7 +168,7 @@ def body_fn(
state=state,
)[1]

def apply_transport_from_scalings(
def apply_transport_from_scalings( # noqa: D102
self,
u: jnp.ndarray,
v: jnp.ndarray,
Expand All @@ -184,15 +194,15 @@ def body_fn(carry: None, vec: jnp.ndarray) -> jnp.ndarray:
return res

@property
def kernel_matrix(self) -> jnp.ndarray:
def kernel_matrix(self) -> jnp.ndarray: # noqa: D102
n, _ = self.shape
kernel = self.apply_kernel(jnp.eye(n))
# force symmetry because of numerical imprecisions
# happens when `numerical_scheme='backward_euler'` and small `t`
return (kernel + kernel.T) * .5

@property
def cost_matrix(self) -> jnp.ndarray:
def cost_matrix(self) -> jnp.ndarray: # noqa: D102
return -self.t * mu.safe_log(self.kernel_matrix)

@property
Expand Down Expand Up @@ -287,7 +297,7 @@ def solver(self) -> decomposition.CholeskySolver:
return self._solver

@property
def shape(self) -> Tuple[int, int]:
def shape(self) -> Tuple[int, int]: # noqa: D102
arr = self._graph if self._graph is not None else self._lap
return arr.shape

Expand All @@ -308,12 +318,12 @@ def graph(self) -> Optional[Union[jnp.ndarray, jesp.BCOO]]:
return (self._graph + self._graph.T) if self.directed else self._graph

@property
def is_symmetric(self) -> bool:
def is_symmetric(self) -> bool: # noqa: D102
# there may be some numerical imprecisions, but it should be symmetric
return True

@property
def dtype(self) -> jnp.dtype:
def dtype(self) -> jnp.dtype: # noqa: D102
return self._graph.dtype

# TODO(michalk8): in future, use mixins for lse/kernel mode
Expand Down Expand Up @@ -343,7 +353,7 @@ def marginal_from_potentials(
"""Not implemented."""
raise ValueError("Not implemented.")

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
return [self._graph, self._lap, self.solver], {
"t": self._t,
"n_steps": self.n_steps,
Expand All @@ -355,7 +365,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
}

@classmethod
def tree_unflatten(
def tree_unflatten( # noqa: D102
cls, aux_data: Dict[str, Any], children: Sequence[Any]
) -> "Graph":
graph, laplacian, solver = children
Expand Down
12 changes: 6 additions & 6 deletions src/ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,15 @@ def median_cost_matrix(self) -> NoReturn:
raise NotImplementedError('Median cost not implemented for grids.')

@property
def can_LRC(self) -> bool:
def can_LRC(self) -> bool: # noqa: D102
return True

@property
def shape(self) -> Tuple[int, int]:
def shape(self) -> Tuple[int, int]: # noqa: D102
return self.num_a, self.num_a

@property
def is_symmetric(self) -> bool:
def is_symmetric(self) -> bool: # noqa: D102
return True

# Reimplemented functions to be used in regularized OT
Expand Down Expand Up @@ -341,14 +341,14 @@ def prepare_divergences(
return tuple(sep_grid for _ in range(size))

@property
def dtype(self) -> jnp.dtype:
def dtype(self) -> jnp.dtype: # noqa: D102
return self.x[0].dtype

def tree_flatten(self):
def tree_flatten(self): # noqa: D102
return (self.x, self.cost_fns, self._epsilon), self.kwargs

@classmethod
def tree_unflatten(cls, aux_data, children):
def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(
x=children[0], cost_fns=children[1], epsilon=children[2], **aux_data
)
Expand Down
Loading

0 comments on commit 8a70047

Please sign in to comment.