Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove bool option for scale_cost #492

Merged
merged 3 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class Geometry:
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean' and 'max_cost'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= scale_cost``.
If `True`, use 'mean'.
src_mask: Mask specifying valid rows when computing some statistics of
:attr:`cost_matrix`, see :attr:`src_mask`.
tgt_mask: Mask specifying valid columns when computing some statistics of
Expand All @@ -83,8 +82,8 @@ def __init__(
kernel_matrix: Optional[jnp.ndarray] = None,
epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
relative_epsilon: Optional[bool] = None,
scale_cost: Union[bool, int, float, Literal["mean", "max_cost",
"median"]] = 1.0,
scale_cost: Union[int, float, Literal["mean", "max_cost",
"median"]] = 1.0,
src_mask: Optional[jnp.ndarray] = None,
tgt_mask: Optional[jnp.ndarray] = None,
):
Expand All @@ -97,7 +96,7 @@ def __init__(
) else epsilon_scheduler.Epsilon(epsilon)
self._relative_epsilon = relative_epsilon

self._scale_cost = "mean" if scale_cost is True else scale_cost
self._scale_cost = scale_cost

self._src_mask = src_mask
self._tgt_mask = tgt_mask
Expand Down Expand Up @@ -212,11 +211,11 @@ def inv_scale_cost(self) -> float:
return 1.0 / jnp.nanmedian(self._cost_matrix)
raise ValueError(f"Scaling {self._scale_cost} not implemented.")

def set_scale_cost(self, scale_cost: Union[bool, float, str]) -> "Geometry":
def set_scale_cost(self, scale_cost: Union[float, str]) -> "Geometry":
"""Modify how to rescale the :attr:`cost_matrix`."""
# case when `geom` doesn't have `scale_cost` or doesn't need to be modified
# `False` retains the original scale
if scale_cost is False or scale_cost == self._scale_cost:
if scale_cost == self._scale_cost:
return self
children, aux_data = self.tree_flatten()
aux_data["scale_cost"] = scale_cost
Expand Down
8 changes: 4 additions & 4 deletions src/ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class LRCGeometry(geometry.Geometry):
scale_cost: option to rescale the cost matrix. Implemented scalings are
'max_bound', 'mean' and 'max_cost'. Alternatively, a float
factor can be given to rescale the cost such that
``cost_matrix /= scale_cost``. If `True`, use 'mean'.
``cost_matrix /= scale_cost``.
batch_size: optional size of the batch to compute online (without
instantiating the matrix) the scale factor ``scale_cost`` of the
:attr:`cost_matrix` when ``scale_cost = 'max_cost'``. If `None`, the batch
Expand All @@ -57,8 +57,8 @@ def __init__(
cost_2: jnp.ndarray,
bias: float = 0.0,
scale_factor: float = 1.0,
scale_cost: Union[bool, int, float, Literal["mean", "max_bound",
"max_cost"]] = 1.0,
scale_cost: Union[int, float, Literal["mean", "max_bound",
"max_cost"]] = 1.0,
batch_size: Optional[int] = None,
**kwargs: Any,
):
Expand All @@ -67,7 +67,7 @@ def __init__(
self._cost_2 = cost_2
self._bias = bias
self._scale_factor = scale_factor
self._scale_cost = "mean" if scale_cost is True else scale_cost
self._scale_cost = scale_cost
self.batch_size = batch_size

@property
Expand Down
9 changes: 4 additions & 5 deletions src/ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class PointCloud(geometry.Geometry):
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean', 'max_cost', 'max_norm' and 'max_bound'.
Alternatively, a float factor can be given to rescale the cost such
that ``cost_matrix /= scale_cost``. If `True`, use 'mean'.
that ``cost_matrix /= scale_cost``.
kwargs: keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
"""

Expand All @@ -60,9 +60,8 @@ def __init__(
y: Optional[jnp.ndarray] = None,
cost_fn: Optional[costs.CostFn] = None,
batch_size: Optional[int] = None,
scale_cost: Union[bool, int, float,
Literal["mean", "max_norm", "max_bound", "max_cost",
"median"]] = 1.0,
scale_cost: Union[int, float, Literal["mean", "max_norm", "max_bound",
"max_cost", "median"]] = 1.0,
**kwargs: Any
):
super().__init__(**kwargs)
Expand All @@ -74,7 +73,7 @@ def __init__(
if batch_size is not None:
assert batch_size > 0, f"`batch_size={batch_size}` must be positive."
self._batch_size = batch_size
self._scale_cost = "mean" if scale_cost is True else scale_cost
self._scale_cost = scale_cost

@property
def _norm_x(self) -> Union[float, jnp.ndarray]:
Expand Down
8 changes: 2 additions & 6 deletions src/ott/neural/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def monge_gap(
cost_fn: Optional[costs.CostFn] = None,
epsilon: Optional[float] = None,
relative_epsilon: Optional[bool] = None,
scale_cost: Union[bool, int, float, Literal["mean", "max_cost",
"median"]] = 1.0,
scale_cost: Union[int, float, Literal["mean", "max_cost", "median"]] = 1.0,
return_output: bool = False,
**kwargs: Any
) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]:
Expand Down Expand Up @@ -66,7 +65,6 @@ def monge_gap(
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean' and 'max_cost'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= scale_cost``.
If `True`, use 'mean'.
return_output: boolean to also return the
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`.
kwargs: holds the kwargs to instantiate the or
Expand Down Expand Up @@ -96,8 +94,7 @@ def monge_gap_from_samples(
cost_fn: Optional[costs.CostFn] = None,
epsilon: Optional[float] = None,
relative_epsilon: Optional[bool] = None,
scale_cost: Union[bool, int, float, Literal["mean", "max_cost",
"median"]] = 1.0,
scale_cost: Union[int, float, Literal["mean", "max_cost", "median"]] = 1.0,
return_output: bool = False,
**kwargs: Any
) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]:
Expand Down Expand Up @@ -126,7 +123,6 @@ def monge_gap_from_samples(
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean' and 'max_cost'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= scale_cost``.
If `True`, use 'mean'.
return_output: boolean to also return the
:class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`.
kwargs: holds the kwargs to instantiate the or
Expand Down
29 changes: 14 additions & 15 deletions src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,10 @@ class QuadraticProblem:
reduces to a plain Gromov-Wasserstein problem :cite:`peyre:16`.
fused_penalty: Multiplier of the linear term in fused Gromov-Wasserstein,
i.e. ``problem = purely quadratic + fused_penalty * linear problem``.
scale_cost: option to rescale the cost matrices:

- if :obj:`True`, use the default for each geometry.
- if :obj:`False`, keep the original scaling in geometries.
- if :class:`str`, use a specific method available in
:class:`~ott.geometry.geometry.Geometry` or
:class:`~ott.geometry.pointcloud.PointCloud`.
- if :obj:`None`, do not scale the cost matrices.

scale_cost: How to rescale the cost matrices. If a :class:`str`,
use specific options available in :class:`~ott.geometry.geometry.Geometry`
or :class:`~ott.geometry.pointcloud.PointCloud`. If :obj:`None`, keep
the original scaling.
a: The first marginal. If :obj:`None`, it will be uniform.
b: The second marginal. If :obj:`None`, it will be uniform.
loss: Gromov-Wasserstein loss function, see
Expand Down Expand Up @@ -90,7 +85,7 @@ def __init__(
geom_yy: geometry.Geometry,
geom_xy: Optional[geometry.Geometry] = None,
fused_penalty: float = 1.0,
scale_cost: Optional[Union[bool, float, str]] = False,
scale_cost: Optional[Union[float, str]] = None,
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl",
Expand All @@ -100,11 +95,15 @@ def __init__(
ranks: Union[int, Tuple[int, ...]] = -1,
tolerances: Union[float, Tuple[float, ...]] = 1e-2,
):
self._geom_xx = geom_xx.set_scale_cost(scale_cost)
self._geom_yy = geom_yy.set_scale_cost(scale_cost)
self._geom_xy = (
None if geom_xy is None else geom_xy.set_scale_cost(scale_cost)
)
if scale_cost is not None:
geom_xx = geom_xx.set_scale_cost(scale_cost)
geom_yy = geom_yy.set_scale_cost(scale_cost)
if geom_xy is not None:
geom_xy = geom_xy.set_scale_cost(scale_cost)

self._geom_xx = geom_xx
self._geom_yy = geom_yy
self._geom_xy = geom_xy
self.fused_penalty = fused_penalty
self.scale_cost = scale_cost
self._a = a
Expand Down
2 changes: 1 addition & 1 deletion tests/solvers/quadratic/fgw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def test_fgw_scale_cost(self, scale_cost: Literal["mean", "max_cost"]):
geom_y_scaled,
geom_xy_scaled,
fused_penalty=fused_penalty,
scale_cost=False
scale_cost=None,
)
prob_scale = quadratic_problem.QuadraticProblem(
geom_x,
Expand Down
Loading