Skip to content

Commit

Permalink
Fwd-port jnp.isscalar
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Dec 2, 2024
1 parent 41ebe25 commit 4a89427
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def inv_scale_cost(self) -> jnp.ndarray:
return 1.0 / jnp.mean(self._cost_matrix)
if self._scale_cost == "median":
return 1.0 / jnp.median(self._cost_matrix)
if jnp.isscalar(self._scale_cost):
if utils.is_scalar(self._scale_cost):
return 1.0 / self._scale_cost
raise ValueError(f"Scaling {self._scale_cost} not implemented.")

Expand Down
2 changes: 1 addition & 1 deletion src/ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def inv_scale_cost(self) -> jnp.ndarray: # noqa: D102
return 1.0 / (mean + self._bias)
if self._scale_cost == "max_cost":
return 1.0 / self._max_cost_matrix
if jnp.isscalar(self._scale_cost):
if utils.is_scalar(self._scale_cost):
return 1.0 / self._scale_cost
raise ValueError(f"Scaling {self._scale_cost} not implemented.")

Expand Down
2 changes: 1 addition & 1 deletion src/ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def inv_scale_cost(self) -> jnp.ndarray: # noqa: D102
"the cost matrix when the cost is not squared euclidean "
"is not implemented."
)
if jnp.isscalar(self._scale_cost):
if utils.is_scalar(self._scale_cost):
return 1.0 / self._scale_cost
raise ValueError(f"Scaling {self._scale_cost} not implemented.")

Expand Down
11 changes: 11 additions & 0 deletions src/ott/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"default_progress_fn",
"tqdm_progress_fn",
"batched_vmap",
"is_scalar",
]

IOStatus = Tuple[np.ndarray, np.ndarray, np.ndarray, NamedTuple]
Expand Down Expand Up @@ -422,3 +423,13 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
batched_fun = _apply_scan(vmapped_fun, in_axes=in_axes)

return wrapper


# TODO(michalk8): remove when `jax>=0.4.31`
def is_scalar(x: Any) -> bool: # noqa: D103
if (
isinstance(x, (np.ndarray, jax.Array)) or hasattr(x, "__jax_array__") or
np.isscalar(x)
):
return jnp.asarray(x).ndim == 0
return False

0 comments on commit 4a89427

Please sign in to comment.