diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index 0ae69efde..286c1ec66 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -63,7 +63,6 @@ class Geometry: scale_cost: option to rescale the cost matrix. Implemented scalings are 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that ``cost_matrix /= scale_cost``. - If `True`, use 'mean'. src_mask: Mask specifying valid rows when computing some statistics of :attr:`cost_matrix`, see :attr:`src_mask`. tgt_mask: Mask specifying valid columns when computing some statistics of @@ -83,8 +82,8 @@ def __init__( kernel_matrix: Optional[jnp.ndarray] = None, epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None, relative_epsilon: Optional[bool] = None, - scale_cost: Union[bool, int, float, Literal["mean", "max_cost", - "median"]] = 1.0, + scale_cost: Union[int, float, Literal["mean", "max_cost", + "median"]] = 1.0, src_mask: Optional[jnp.ndarray] = None, tgt_mask: Optional[jnp.ndarray] = None, ): @@ -97,7 +96,7 @@ def __init__( ) else epsilon_scheduler.Epsilon(epsilon) self._relative_epsilon = relative_epsilon - self._scale_cost = "mean" if scale_cost is True else scale_cost + self._scale_cost = scale_cost self._src_mask = src_mask self._tgt_mask = tgt_mask @@ -212,11 +211,11 @@ def inv_scale_cost(self) -> float: return 1.0 / jnp.nanmedian(self._cost_matrix) raise ValueError(f"Scaling {self._scale_cost} not implemented.") - def set_scale_cost(self, scale_cost: Union[bool, float, str]) -> "Geometry": + def set_scale_cost(self, scale_cost: Union[float, str]) -> "Geometry": """Modify how to rescale of the :attr:`cost_matrix`.""" # case when `geom` doesn't have `scale_cost` or doesn't need to be modified # `False` retains the original scale - if scale_cost is False or scale_cost == self._scale_cost: + if scale_cost == self._scale_cost: return self children, aux_data = self.tree_flatten() aux_data["scale_cost"] = scale_cost diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index c28c31420..2ab5aba27 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -42,7 +42,7 @@ class LRCGeometry(geometry.Geometry): scale_cost: option to rescale the cost matrix. Implemented scalings are 'max_bound', 'mean' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that - ``cost_matrix /= scale_cost``. If `True`, use 'mean'. + ``cost_matrix /= scale_cost``. batch_size: optional size of the batch to compute online (without instantiating the matrix) the scale factor ``scale_cost`` of the :attr:`cost_matrix` when ``scale_cost = 'max_cost'``. If `None`, the batch @@ -57,8 +57,8 @@ def __init__( cost_2: jnp.ndarray, bias: float = 0.0, scale_factor: float = 1.0, - scale_cost: Union[bool, int, float, Literal["mean", "max_bound", - "max_cost"]] = 1.0, + scale_cost: Union[int, float, Literal["mean", "max_bound", + "max_cost"]] = 1.0, batch_size: Optional[int] = None, **kwargs: Any, ): @@ -67,7 +67,7 @@ def __init__( self._cost_2 = cost_2 self._bias = bias self._scale_factor = scale_factor - self._scale_cost = "mean" if scale_cost is True else scale_cost + self._scale_cost = scale_cost self.batch_size = batch_size @property diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index cd01cdc1d..5fdc94474 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -50,7 +50,7 @@ class PointCloud(geometry.Geometry): scale_cost: option to rescale the cost matrix. Implemented scalings are 'median', 'mean', 'max_cost', 'max_norm' and 'max_bound'. Alternatively, a float factor can be given to rescale the cost such - that ``cost_matrix /= scale_cost``. If `True`, use 'mean'. + that ``cost_matrix /= scale_cost``. kwargs: keyword arguments for :class:`~ott.geometry.geometry.Geometry`. """ @@ -60,9 +60,8 @@ def __init__( y: Optional[jnp.ndarray] = None, cost_fn: Optional[costs.CostFn] = None, batch_size: Optional[int] = None, - scale_cost: Union[bool, int, float, - Literal["mean", "max_norm", "max_bound", "max_cost", - "median"]] = 1.0, + scale_cost: Union[int, float, Literal["mean", "max_norm", "max_bound", + "max_cost", "median"]] = 1.0, **kwargs: Any ): super().__init__(**kwargs) @@ -74,7 +73,7 @@ def __init__( if batch_size is not None: assert batch_size > 0, f"`batch_size={batch_size}` must be positive." self._batch_size = batch_size - self._scale_cost = "mean" if scale_cost is True else scale_cost + self._scale_cost = scale_cost @property def _norm_x(self) -> Union[float, jnp.ndarray]: diff --git a/src/ott/neural/losses.py b/src/ott/neural/losses.py index e13d41c2a..f6136bf07 100644 --- a/src/ott/neural/losses.py +++ b/src/ott/neural/losses.py @@ -29,8 +29,7 @@ def monge_gap( cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, relative_epsilon: Optional[bool] = None, - scale_cost: Union[bool, int, float, Literal["mean", "max_cost", - "median"]] = 1.0, + scale_cost: Union[int, float, Literal["mean", "max_cost", "median"]] = 1.0, return_output: bool = False, **kwargs: Any ) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]: @@ -66,7 +65,6 @@ def monge_gap( scale_cost: option to rescale the cost matrix. Implemented scalings are 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that ``cost_matrix /= scale_cost``. - If `True`, use 'mean'. return_output: boolean to also return the :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`. kwargs: holds the kwargs to instantiate the or @@ -96,8 +94,7 @@ def monge_gap_from_samples( cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, relative_epsilon: Optional[bool] = None, - scale_cost: Union[bool, int, float, Literal["mean", "max_cost", - "median"]] = 1.0, + scale_cost: Union[int, float, Literal["mean", "max_cost", "median"]] = 1.0, return_output: bool = False, **kwargs: Any ) -> Union[float, Tuple[float, sinkhorn.SinkhornOutput]]: @@ -126,7 +123,6 @@ def monge_gap_from_samples( scale_cost: option to rescale the cost matrix. Implemented scalings are 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that ``cost_matrix /= scale_cost``. - If `True`, use 'mean'. return_output: boolean to also return the :class:`~ott.solvers.linear.sinkhorn.SinkhornOutput`. kwargs: holds the kwargs to instantiate the or diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py index 5deb4558c..25f1ae821 100644 --- a/src/ott/problems/quadratic/quadratic_problem.py +++ b/src/ott/problems/quadratic/quadratic_problem.py @@ -51,15 +51,10 @@ class QuadraticProblem: reduces to a plain Gromov-Wasserstein problem :cite:`peyre:16`. fused_penalty: Multiplier of the linear term in fused Gromov-Wasserstein, i.e. ``problem = purely quadratic + fused_penalty * linear problem``. - scale_cost: option to rescale the cost matrices: - - - if :obj:`True`, use the default for each geometry. - - if :obj:`False`, keep the original scaling in geometries. - - if :class:`str`, use a specific method available in - :class:`~ott.geometry.geometry.Geometry` or - :class:`~ott.geometry.pointcloud.PointCloud`. - - if :obj:`None`, do not scale the cost matrices. - + scale_cost: How to rescale the cost matrices. If a :class:`str`, + use specific options available in :class:`~ott.geometry.geometry.Geometry` + or :class:`~ott.geometry.pointcloud.PointCloud`. If :obj:`None`, keep + the original scaling. a: The first marginal. If :obj:`None`, it will be uniform. b: The second marginal. If :obj:`None`, it will be uniform. loss: Gromov-Wasserstein loss function, see @@ -90,7 +85,7 @@ def __init__( geom_yy: geometry.Geometry, geom_xy: Optional[geometry.Geometry] = None, fused_penalty: float = 1.0, - scale_cost: Optional[Union[bool, float, str]] = False, + scale_cost: Optional[Union[float, str]] = None, a: Optional[jnp.ndarray] = None, b: Optional[jnp.ndarray] = None, loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl", @@ -100,11 +95,15 @@ def __init__( ranks: Union[int, Tuple[int, ...]] = -1, tolerances: Union[float, Tuple[float, ...]] = 1e-2, ): - self._geom_xx = geom_xx.set_scale_cost(scale_cost) - self._geom_yy = geom_yy.set_scale_cost(scale_cost) - self._geom_xy = ( - None if geom_xy is None else geom_xy.set_scale_cost(scale_cost) - ) + if scale_cost is not None: + geom_xx = geom_xx.set_scale_cost(scale_cost) + geom_yy = geom_yy.set_scale_cost(scale_cost) + if geom_xy is not None: + geom_xy = geom_xy.set_scale_cost(scale_cost) + + self._geom_xx = geom_xx + self._geom_yy = geom_yy + self._geom_xy = geom_xy self.fused_penalty = fused_penalty self.scale_cost = scale_cost self._a = a diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index 9f9cde158..3283ae845 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -296,7 +296,7 @@ def test_fgw_scale_cost(self, scale_cost: Literal["mean", "max_cost"]): geom_y_scaled, geom_xy_scaled, fused_penalty=fused_penalty, - scale_cost=False + scale_cost=None, ) prob_scale = quadratic_problem.QuadraticProblem( geom_x,