Skip to content

Commit

Permalink
Remove bool option for scale_cost (#492)
Browse files Browse the repository at this point in the history
* Remove `bool` option for `scale_cost`

* Update quadratic problem

* Fix GW test
  • Loading branch information
michalk8 authored Feb 16, 2024
1 parent 8519182 commit 8a2dbef
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 37 deletions.
11 changes: 5 additions & 6 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class Geometry:
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean' and 'max_cost'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= scale_cost``.
If `True`, use 'mean'.
src_mask: Mask specifying valid rows when computing some statistics of
:attr:`cost_matrix`, see :attr:`src_mask`.
tgt_mask: Mask specifying valid columns when computing some statistics of
Expand All @@ -83,8 +82,8 @@ def __init__(
kernel_matrix: Optional[jnp.ndarray] = None,
epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
relative_epsilon: Optional[bool] = None,
scale_cost: Union[bool, int, float, Literal["mean", "max_cost",
"median"]] = 1.0,
scale_cost: Union[int, float, Literal["mean", "max_cost",
"median"]] = 1.0,
src_mask: Optional[jnp.ndarray] = None,
tgt_mask: Optional[jnp.ndarray] = None,
):
Expand All @@ -97,7 +96,7 @@ def __init__(
) else epsilon_scheduler.Epsilon(epsilon)
self._relative_epsilon = relative_epsilon

self._scale_cost = "mean" if scale_cost is True else scale_cost
self._scale_cost = scale_cost

self._src_mask = src_mask
self._tgt_mask = tgt_mask
Expand Down Expand Up @@ -212,11 +211,11 @@ def inv_scale_cost(self) -> float:
return 1.0 / jnp.nanmedian(self._cost_matrix)
raise ValueError(f"Scaling {self._scale_cost} not implemented.")

def set_scale_cost(self, scale_cost: Union[bool, float, str]) -> "Geometry":
def set_scale_cost(self, scale_cost: Union[float, str]) -> "Geometry":
"""Modify how to rescale the :attr:`cost_matrix`."""
# case when `geom` doesn't have `scale_cost` or doesn't need to be modified
# `False` retains the original scale
if scale_cost is False or scale_cost == self._scale_cost:
if scale_cost == self._scale_cost:
return self
children, aux_data = self.tree_flatten()
aux_data["scale_cost"] = scale_cost
Expand Down
8 changes: 4 additions & 4 deletions src/ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class LRCGeometry(geometry.Geometry):
scale_cost: option to rescale the cost matrix. Implemented scalings are
'max_bound', 'mean' and 'max_cost'. Alternatively, a float
factor can be given to rescale the cost such that
``cost_matrix /= scale_cost``. If `True`, use 'mean'.
``cost_matrix /= scale_cost``.
batch_size: optional size of the batch to compute online (without
instantiating the matrix) the scale factor ``scale_cost`` of the
:attr:`cost_matrix` when ``scale_cost = 'max_cost'``. If `None`, the batch
Expand All @@ -57,8 +57,8 @@ def __init__(
cost_2: jnp.ndarray,
bias: float = 0.0,
scale_factor: float = 1.0,
scale_cost: Union[bool, int, float, Literal["mean", "max_bound",
"max_cost"]] = 1.0,
scale_cost: Union[int, float, Literal["mean", "max_bound",
"max_cost"]] = 1.0,
batch_size: Optional[int] = None,
**kwargs: Any,
):
Expand All @@ -67,7 +67,7 @@ def __init__(
self._cost_2 = cost_2
self._bias = bias
self._scale_factor = scale_factor
self._scale_cost = "mean" if scale_cost is True else scale_cost
self._scale_cost = scale_cost
self.batch_size = batch_size

@property
Expand Down
9 changes: 4 additions & 5 deletions src/ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class PointCloud(geometry.Geometry):
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean', 'max_cost', 'max_norm' and 'max_bound'.
Alternatively, a float factor can be given to rescale the cost such
that ``cost_matrix /= scale_cost``. If `True`, use 'mean'.
that ``cost_matrix /= scale_cost``.
kwargs: keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
"""

Expand All @@ -60,9 +60,8 @@ def __init__(
y: Optional[jnp.ndarray] = None,
cost_fn: Optional[costs.CostFn] = None,
batch_size: Optional[int] = None,
scale_cost: Union[bool, int, float,
Literal["mean", "max_norm", "max_bound", "max_cost",
"median"]] = 1.0,
scale_cost: Union[int, float, Literal["mean", "max_norm", "max_bound",
"max_cost", "median"]] = 1.0,
**kwargs: Any
):
super().__init__(**kwargs)
Expand All @@ -74,7 +73,7 @@ def __init__(
if batch_size is not None:
assert batch_size > 0, f"`batch_size={batch_size}` must be positive."
self._batch_size = batch_size
self._scale_cost = "mean" if scale_cost is True else scale_cost
self._scale_cost = scale_cost

@property
def _norm_x(self) -> Union[float, jnp.ndarray]:
Expand Down
8 changes: 2 additions & 6 deletions src/ott/neural/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def monge_gap(
cost_fn: Optional[costs.CostFn] = None,
epsilon: Optional[float] = None,
relative_epsilon: Optional[bool] = None,
scale_cost: Union[bool, int, float, Literal["mean", "max_cost",
"median"]] = 1.0,
scale_cost: Union[int, float, Literal["mean", "max_cost", "median"]] = 1.0,
return_output: bool = False,
**kwargs: Any
) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]:
Expand Down Expand Up @@ -66,7 +65,6 @@ def monge_gap(
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean' and 'max_cost'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= scale_cost``.
If `True`, use 'mean'.
return_output: boolean to also return the
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`.
kwargs: holds the kwargs to instantiate the or
Expand Down Expand Up @@ -96,8 +94,7 @@ def monge_gap_from_samples(
cost_fn: Optional[costs.CostFn] = None,
epsilon: Optional[float] = None,
relative_epsilon: Optional[bool] = None,
scale_cost: Union[bool, int, float, Literal["mean", "max_cost",
"median"]] = 1.0,
scale_cost: Union[int, float, Literal["mean", "max_cost", "median"]] = 1.0,
return_output: bool = False,
**kwargs: Any
) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]:
Expand Down Expand Up @@ -126,7 +123,6 @@ def monge_gap_from_samples(
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean' and 'max_cost'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= scale_cost``.
If `True`, use 'mean'.
return_output: boolean to also return the
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`.
kwargs: holds the kwargs to instantiate the or
Expand Down
29 changes: 14 additions & 15 deletions src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,10 @@ class QuadraticProblem:
reduces to a plain Gromov-Wasserstein problem :cite:`peyre:16`.
fused_penalty: Multiplier of the linear term in fused Gromov-Wasserstein,
i.e. ``problem = purely quadratic + fused_penalty * linear problem``.
scale_cost: option to rescale the cost matrices:
- if :obj:`True`, use the default for each geometry.
- if :obj:`False`, keep the original scaling in geometries.
- if :class:`str`, use a specific method available in
:class:`~ott.geometry.geometry.Geometry` or
:class:`~ott.geometry.pointcloud.PointCloud`.
- if :obj:`None`, do not scale the cost matrices.
scale_cost: How to rescale the cost matrices. If a :class:`str`,
use specific options available in :class:`~ott.geometry.geometry.Geometry`
or :class:`~ott.geometry.pointcloud.PointCloud`. If :obj:`None`, keep
the original scaling.
a: The first marginal. If :obj:`None`, it will be uniform.
b: The second marginal. If :obj:`None`, it will be uniform.
loss: Gromov-Wasserstein loss function, see
Expand Down Expand Up @@ -90,7 +85,7 @@ def __init__(
geom_yy: geometry.Geometry,
geom_xy: Optional[geometry.Geometry] = None,
fused_penalty: float = 1.0,
scale_cost: Optional[Union[bool, float, str]] = False,
scale_cost: Optional[Union[float, str]] = None,
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl",
Expand All @@ -100,11 +95,15 @@ def __init__(
ranks: Union[int, Tuple[int, ...]] = -1,
tolerances: Union[float, Tuple[float, ...]] = 1e-2,
):
self._geom_xx = geom_xx.set_scale_cost(scale_cost)
self._geom_yy = geom_yy.set_scale_cost(scale_cost)
self._geom_xy = (
None if geom_xy is None else geom_xy.set_scale_cost(scale_cost)
)
if scale_cost is not None:
geom_xx = geom_xx.set_scale_cost(scale_cost)
geom_yy = geom_yy.set_scale_cost(scale_cost)
if geom_xy is not None:
geom_xy = geom_xy.set_scale_cost(scale_cost)

self._geom_xx = geom_xx
self._geom_yy = geom_yy
self._geom_xy = geom_xy
self.fused_penalty = fused_penalty
self.scale_cost = scale_cost
self._a = a
Expand Down
2 changes: 1 addition & 1 deletion tests/solvers/quadratic/fgw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def test_fgw_scale_cost(self, scale_cost: Literal["mean", "max_cost"]):
geom_y_scaled,
geom_xy_scaled,
fused_penalty=fused_penalty,
scale_cost=False
scale_cost=None,
)
prob_scale = quadratic_problem.QuadraticProblem(
geom_x,
Expand Down

0 comments on commit 8a2dbef

Please sign in to comment.