Skip to content

Commit

Permalink
Remove batch_size from LRCGeometry
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Oct 11, 2024
1 parent 47462d2 commit 05630a8
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 39 deletions.
4 changes: 2 additions & 2 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def __init__(
kernel_matrix: Optional[jnp.ndarray] = None,
epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
relative_epsilon: Optional[Union[bool, Literal["mean", "std"]]] = None,
scale_cost: Union[int, float, Literal["mean", "max_cost", "median",
"std"]] = 1.0,
scale_cost: Union[float, Literal["mean", "max_cost", "median",
"std"]] = 1.0,
src_mask: Optional[jnp.ndarray] = None,
tgt_mask: Optional[jnp.ndarray] = None,
):
Expand Down
18 changes: 3 additions & 15 deletions src/ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,6 @@ class LRCGeometry(geometry.Geometry):
'max_bound', 'mean' and 'max_cost'. Alternatively, a float
factor can be given to rescale the cost such that
``cost_matrix /= scale_cost``.
batch_size: optional size of the batch to compute online (without
instantiating the matrix) the scale factor ``scale_cost`` of the
:attr:`cost_matrix` when ``scale_cost = 'max_cost'``. If `None`, the batch
size is set to `1024` or to the largest number of samples between
:attr:`cost_1` and :attr:`cost_2` if smaller than `1024`.
kwargs: keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
"""

Expand All @@ -57,9 +52,7 @@ def __init__(
cost_2: jnp.ndarray,
bias: float = 0.0,
scale_factor: float = 1.0,
scale_cost: Union[int, float, Literal["mean", "max_bound",
"max_cost"]] = 1.0,
batch_size: Optional[int] = None,
scale_cost: Union[float, Literal["mean", "max_bound", "max_cost"]] = 1.0,
**kwargs: Any,
):
super().__init__(**kwargs)
Expand All @@ -68,8 +61,6 @@ def __init__(
self._bias = bias
self._scale_factor = scale_factor
self._scale_cost = scale_cost
# TODO(michalk8): remove?
self.batch_size = batch_size

@property
def cost_1(self) -> jnp.ndarray:
Expand Down Expand Up @@ -116,7 +107,7 @@ def inv_scale_cost(self) -> float: # noqa: D102
if self._scale_cost == "max_bound":
x_norm = self._cost_1[:, 0].max()
y_norm = self._cost_2[:, 1].max()
max_bound = x_norm + y_norm + 2 * jnp.sqrt(x_norm * y_norm)
max_bound = x_norm + y_norm + 2.0 * jnp.sqrt(x_norm * y_norm)
return 1.0 / (max_bound + self._bias)
if self._scale_cost == "mean":
a, b = self._n_normed_ones, self._m_normed_ones
Expand Down Expand Up @@ -183,9 +174,7 @@ def _apply_cost_to_vec_fast(
@property
def _max_cost_matrix(self) -> jnp.ndarray:
fn = utils.batched_vmap(
lambda c1, c2: jnp.max(c1 @ c2.T),
batch_size=self.batch_size or 1024,
in_axes=(0, None)
lambda c1, c2: jnp.max(c1 @ c2.T), batch_size=1024, in_axes=(0, None)
)
return jnp.max(fn(self._cost_1, self._cost_2)) + self._bias

Expand Down Expand Up @@ -291,7 +280,6 @@ def tree_flatten(self): # noqa: D102
self._scale_factor,
), {
"scale_cost": self._scale_cost,
"batch_size": self.batch_size
}

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions src/ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def __init__(
y: Optional[jnp.ndarray] = None,
cost_fn: Optional[costs.CostFn] = None,
batch_size: Optional[int] = None,
scale_cost: Union[int, float, Literal["mean", "max_norm", "max_bound",
"max_cost", "median"]] = 1.0,
**kwargs: Any
scale_cost: Union[float, Literal["mean", "max_norm", "max_bound",
"max_cost", "median"]] = 1.0,
**kwargs: Any,
):
super().__init__(**kwargs)
self.x = x
Expand Down
4 changes: 2 additions & 2 deletions src/ott/neural/methods/monge_gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def monge_gap(
cost_fn: Optional[costs.CostFn] = None,
epsilon: Optional[float] = None,
relative_epsilon: Optional[bool] = None,
scale_cost: Union[int, float, Literal["mean", "max_cost", "median"]] = 1.0,
scale_cost: Union[float, Literal["mean", "max_cost", "median"]] = 1.0,
return_output: bool = False,
**kwargs: Any
) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]:
Expand Down Expand Up @@ -112,7 +112,7 @@ def monge_gap_from_samples(
cost_fn: Optional[costs.CostFn] = None,
epsilon: Optional[float] = None,
relative_epsilon: Optional[bool] = None,
scale_cost: Union[int, float, Literal["mean", "max_cost", "median"]] = 1.0,
scale_cost: Union[float, Literal["mean", "max_cost", "median"]] = 1.0,
return_output: bool = False,
**kwargs: Any
) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]:
Expand Down
2 changes: 1 addition & 1 deletion src/ott/problems/quadratic/gw_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
y_fused: Optional[jnp.ndarray] = None,
fused_penalty: float = 1.0,
gw_loss: Literal["sqeucl", "kl"] = "sqeucl",
scale_cost: Union[int, float, Literal["mean", "max_cost"]] = 1.0,
scale_cost: Union[float, Literal["mean", "max_cost"]] = 1.0,
**kwargs: Any,
):
assert y is None or costs is None, "Cannot specify both `y` and `costs`."
Expand Down
12 changes: 0 additions & 12 deletions tests/geometry/scaling_cost_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,18 +188,6 @@ def apply_sinkhorn(cost1, cost2, scale_cost):
if scale == "max_cost":
np.testing.assert_allclose(1.0, geom.cost_matrix.max(), rtol=1e-4)

@pytest.mark.parametrize("batch_size", [5, 12])
def test_max_scale_cost_low_rank_with_batch(self, batch_size: int):
"""Test max_cost options for low rank with batch_size fixed."""

geom0 = low_rank.LRCGeometry(
self.cost1, self.cost2, scale_cost="max_cost", batch_size=batch_size
)

np.testing.assert_allclose(
geom0.inv_scale_cost, 1.0 / jnp.max(self.cost_lr), rtol=1e-4
)

def test_max_scale_cost_low_rank_large_array(self):
"""Test max_cost options for large matrices."""

Expand Down
10 changes: 6 additions & 4 deletions tests/geometry/subsetting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,17 @@ def test_mask(

if clazz is geometry.Geometry:
geom = clazz(cost_matrix=x @ y.T, scale_cost="mean")
else:
elif clazz is pointcloud.PointCloud:
geom = clazz(x, y, scale_cost="max_cost", batch_size=5)
else:
geom = clazz(x, y, scale_cost="max_cost")
n = geom.shape[0] if src_ixs is None else len(src_ixs)
m = geom.shape[1] if tgt_ixs is None else len(tgt_ixs)

if clazz is geometry.Geometry:
geom_sub = geom.subset(src_ixs, tgt_ixs)
else:
if clazz is pointcloud.PointCloud:
geom_sub = geom.subset(src_ixs, tgt_ixs, batch_size=new_batch_size)
else:
geom_sub = geom.subset(src_ixs, tgt_ixs)

assert type(geom_sub) == type(geom)
np.testing.assert_array_equal(geom_sub.shape, (n, m))
Expand Down

0 comments on commit 05630a8

Please sign in to comment.