diff --git a/ott/.DS_Store b/ott/.DS_Store new file mode 100644 index 000000000..e84ffab0c Binary files /dev/null and b/ott/.DS_Store differ diff --git a/ott/core/__init__.py b/ott/core/__init__.py index 8b66471d9..045795659 100644 --- a/ott/core/__init__.py +++ b/ott/core/__init__.py @@ -18,14 +18,17 @@ # pytype: disable=import-error # kwargs-checking from . import anderson from . import dataclasses +from . import problems +from . import quad_problems +from . import bar_problems from . import discrete_barycenter +from . import continuous_barycenter from . import gromov_wasserstein from . import implicit_differentiation from . import momentum -from . import problems from . import sinkhorn from . import sinkhorn_lr -from . import neuraldual +#from . import neuraldual from .implicit_differentiation import ImplicitDiff from .problems import LinearProblem from .sinkhorn import Sinkhorn diff --git a/ott/core/bar_problems.py b/ott/core/bar_problems.py new file mode 100644 index 000000000..f932f14ef --- /dev/null +++ b/ott/core/bar_problems.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Copyright 2022 Apple Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes defining OT problem(s) (objective function + utilities).""" + +from typing import Optional, Tuple +import jax +import jax.numpy as jnp +from ott.geometry import geometry +from ott.geometry import costs +from ott.core import segment + + +@jax.tree_util.register_pytree_node_class +class BarycenterProblem: + """Holds the definition of a linear regularized OT problem and some tools.""" + + def __init__(self, + y: Optional[jnp.ndarray] = None, + b: Optional[jnp.ndarray] = None, + weights: Optional[jnp.ndarray] = None, + cost_fn: Optional[costs.CostFn] = None, + epsilon: Optional[jnp.ndarray] = None, + debiased: bool = False, + segment_ids: Optional[jnp.ndarray] = None, + num_segments: Optional[jnp.ndarray] = None, + indices_are_sorted: Optional[bool] = None, + num_per_segment: Optional[jnp.ndarray] = None, + max_measure_size: Optional[int] = None): + """Initializes a discrete BarycenterProblem + + Args: + y: a matrix merging the points of all measures. + b: a vector containing the weights (within each masure) of all the points + weights: weights of the barycenter problem (size num_segments) + cost_fn: cost function used. + epsilon: epsilon regularization used to solve reg-OT problems. + debiased: whether the problem is debiased, in the sense that + the regularized transportation cost of barycenter to itself will + be considered when computing gradient. Note that if the debiased option + is used, the barycenter size (used in call function) needs to be smaller + than the max_measure_size parameter below, for parallelization to + operate efficiently. + segment_ids: describe for each point to which measure it belongs. + num_segments: total number of measures + indices_are_sorted: flag indicating indices in segment_ids are sorted. + num_per_segment: number of points in each segment, if contiguous. + max_measure_size: max number of points in each segment (for efficient jit) + """ + self._y = y + self._b = b + self._weights = weights + self.cost_fn = costs.Euclidean() if cost_fn is None else cost_fn + self.epsilon = epsilon + self.debiased = debiased + self._segment_ids = segment_ids + self._num_segments = num_segments + self._indices_are_sorted = indices_are_sorted + self._num_per_segment = num_per_segment + self._max_measure_size = max_measure_size + + def tree_flatten(self): + return ([self._y, self._b, self._weights], + { + 'cost_fn' : self.cost_fn, + 'epsilon' : self.epsilon, + 'debiased': self.debiased, + 'segment_ids' : self._segment_ids, + 'num_segments' : self._num_segments, + 'indices_are_sorted' : self._indices_are_sorted, + 'num_per_segment' : self._num_per_segment, + 'max_measure_size' : self._max_measure_size}) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(*children, **aux_data) + + @property + def segmented_y_b(self): + if self._y is None or (self._y.ndim == 3 and self._b.ndim == 2): + return self.add_slice_for_debiased(self._y, self._b) + else: + segmented_y, segmented_b, _ = segment.segment_point_cloud( + self._y, self._b, self._segment_ids, self._num_segments, + self._indices_are_sorted, self._num_per_segment, + self.max_measure_size) + return self.add_slice_for_debiased(segmented_y, segmented_b) + + def add_slice_for_debiased(self, y, b): + if y is None or b is None: + return y, b + if self.debiased: + n, dim = y.shape[1], y.shape[2] + y = jnp.concatenate((y, jnp.zeros((1, n, dim))), axis=0) + b = jnp.concatenate((b, jnp.zeros((1, n,))), axis=0) + return y, b + + @property + def flattened_y(self): + if self._y is not None and self._y.ndim == 3: + return self._y.reshape((-1,self._y.shape[-1])) + else: + return self._y + + @property + def flattened_b(self): + if self._b is not None and self._b.ndim == 2: + return self._b.ravel() + else: + return self._b + + @property + def max_measure_size(self): + if self._max_measure_size is not None: + return self._max_measure_size + if self._y is not None and self._y.ndim == 3: + return self._y.shape[1] + else: + if self._num_per_segment is None: + if num_segments is None: + num_segments = jnp.max(self._segment_ids) + 1 + if indices_are_sorted is None: + indices_are_sorted = False + num_per_segment = jax.ops.segment_sum( + jnp.ones_like(self._segment_ids), self._segment_ids, + num_segments=num_segments, indices_are_sorted=indices_are_sorted) + return jnp.max(num_per_segment) + else: + return jnp.max(self._num_per_segment) + + @property + def num_segments(self): + if self._y is None: + return 0 + if self._y.ndim == 3: + if self._b is not None: + assert self._y.shape[0] == self._b.shape[0] + return self._y.shape[0] + else: + _ , _, num_segments = segment.segment_point_cloud( + self._y, self._b, self._segment_ids, self._num_segments, + self._indices_are_sorted, self._num_per_segment, + self.max_measure_size) + return num_segments + + + @property + def weights(self): + if self._weights is None: + weights = jnp.ones((self.num_segments,)) / self.num_segments + else: + assert self.weights.shape[0] == self.num_segments + assert jnp.isclose(jnp.sum(self.weights), 1.0) + weights = self.weights + if self.debiased: + weights = jnp.concatenate((weights, jnp.array([-0.5]))) + return weights \ No newline at end of file diff --git a/ott/core/continuous_barycenter.py b/ott/core/continuous_barycenter.py new file mode 100644 index 000000000..24b32f1f8 --- /dev/null +++ b/ott/core/continuous_barycenter.py @@ -0,0 +1,216 @@ +# coding=utf-8 +# Copyright 2022 Apple. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""A Jax version of the W barycenter algorithm (Cuturi Doucet 2014).""" +import functools +from typing import Any, Dict, NamedTuple, Optional, Union + +import jax +import jax.numpy as jnp +from ott.core import fixed_point_loop +from ott.core import problems +from ott.core import bar_problems +from ott.core import was_solver +from ott.geometry import epsilon_scheduler +from ott.geometry import geometry +from ott.geometry import costs +from ott.geometry import low_rank +from ott.geometry import pointcloud + + +class BarycenterOutput(NamedTuple): + """Holds the output of a Wasserstein Barycenter solver. + + The goal is to approximate the W barycenter of a set of N measures using + a discrete measure described by k locations x. To do so the OT between + each of the input N measures to the barycenter is recomputed and x_bar + adjusted following that result. + + Attributes: + costs: Holds the sequence of weighted sum of N regularized W costs seen + (possibly debiased) through the outer loop of the solver. + linear_convergence: Holds the sequence of bool convergence flags of the + inner N Sinkhorn iterations. + convergence: Bool convergence flag for the outer Barycenter iterations. + errors: Holds sequence of matrices of N x max_iterations errors of the + N Sinkhorn algorithms run at each inner iteration. + x : barycenter locations, k x dimension + a : weights of the barycenter + transports: final N transport objects mapping barycenter to input measures. + reg_gw_cost: Total regularized optimal transport cost upon convergence + """ + costs: Optional[jnp.ndarray] = None + linear_convergence: Optional[jnp.ndarray] = None + convergence: bool = False + errors: Optional[jnp.ndarray] = None + x = None + a = None + transports = None + reg_gw_cost = None + + def set(self, **kwargs) -> 'BarycenterOutput': + """Returns a copy of self, possibly with overwrites.""" + return self._replace(**kwargs) + + +class BarycenterState(NamedTuple): + """Holds the state of the Wasserstein barycenter solver. + + Attributes: + costs: Holds the sequence of regularized GW costs seen through the outer + loop of the solver. + linear_convergence: Holds the sequence of bool convergence flags of the + inner Sinkhorn iterations. + errors: Holds sequence of vectors of errors of the Sinkhorn algorithm + at each iteration. + linear_states: State used to solve and store solutions to the OT problems + from the barycenter to the measures. + x: barycenter points + a: barycenter weights + """ + costs: Optional[jnp.ndarray] = None + linear_convergence: Optional[jnp.ndarray] = None + errors: Optional[jnp.ndarray] = None + x: Optional[jnp.ndarray] = None + a: Optional[jnp.ndarray] = None + + def set(self, **kwargs) -> 'BarycenterState': + """Returns a copy of self, possibly with overwrites.""" + return self._replace(**kwargs) + + def update(self, + iteration: int, + bar_prob: bar_problems.BarycenterProblem, + linear_ot_solver: Any, + store_errors: bool): + segmented_y, segmented_b = bar_prob.segmented_y_b + + @functools.partial(jax.vmap, in_axes=[None, None, 0, 0]) + def solve_linear_ot(a, x, b, y): + out = linear_ot_solver( + problems.LinearProblem(pointcloud.PointCloud( + x, y, cost_fn = bar_prob.cost_fn, epsilon= bar_prob.epsilon), + a, b)) + return (out.reg_ot_cost, out.converged, out.matrix, + out.errors if store_errors else None) + + if bar_prob.debiased: + # Check max size (used to pad) is bigger than barycenter size + n, dim = self.x.shape + max_size = bar_prob.max_measure_size + segmented_y = segmented_y.at[-1,:n,:].set(self.x) + segmented_b = segmented_b.at[-1,:n].set(self.a) + + reg_ot_costs, convergeds, matrices, errors = solve_linear_ot( + self.a, self.x, segmented_b, segmented_y) + + cost = jnp.sum(reg_ot_costs * bar_prob.weights) + updated_costs = self.costs.at[iteration].set(cost) + converged = jnp.all(convergeds) + linear_convergence = self.linear_convergence.at[iteration].set(converged) + + if store_errors and self.errors is not None: + errors = self.errors.at[iteration, :, :].set(errors) + else: + errors = None + + divide_a = jnp.where(self.a > 0, 1.0 / self.a, 1.0) + convex_weights = matrices * divide_a[None, :, None] + x_new = jnp.sum( + barycentric_projection(convex_weights, segmented_y, bar_prob.cost_fn) + * bar_prob.weights[:, None, None], axis=0) + return self.set(costs=updated_costs, + linear_convergence=linear_convergence, + errors=errors, + x=x_new) + +@functools.partial(jax.vmap, in_axes=[0, 0, None]) +def barycentric_projection(matrix, y, cost_fn): + return jax.vmap(cost_fn.barycenter, in_axes=[0, None])(matrix, y) + +@jax.tree_util.register_pytree_node_class +class WassersteinBarycenter(was_solver.WassersteinSolver): + """A Continuous Wasserstein barycenter solver, built on generic template.""" + + def __call__( + self, + bar_prob: bar_problems.BarycenterProblem, + bar_size: int = 100, + x_init: jnp.ndarray = None, + rng: int = 0 + ) -> BarycenterState: + bar_fn = jax.jit(iterations, static_argnums=1) if self.jit else iterations + out = bar_fn(self, bar_size, bar_prob, x_init, rng) + iteration = jnp.sum(out.costs != 0) + convergence = jnp.logical_not(self.not_converged(out, iteration)) + return out + + def init_state(self, bar_prob, bar_size, x_init, rng + ) -> BarycenterState: + """Initializes the state of the Wasserstein barycenter iterations.""" + if x_init is not None: + assert bar_size == x_init.shape[0] + x = x_init + else: + # sample randomly points in the support of the y measures + indices_subset = jax.random.choice(jax.random.PRNGKey(rng), + a=bar_prob.flattened_y.shape[0], + shape=(bar_size,), + replace=False, + p=bar_prob.flattened_b) + x = bar_prob.flattened_y[indices_subset,:] + + # TODO(cuturi) expand to non-uniform weights for barycenter. + a = jnp.ones((bar_size,))/ bar_size + num_iter = self.max_iterations + if self.store_inner_errors: + errors = -jnp.ones( + (self.num_iter, bar_prob.num_segments, + self.linear_ot_solver.outer_iterations)) + else: + errors = None + return BarycenterState(-jnp.ones((num_iter,)), -jnp.ones((num_iter,)), + errors, x, a) + + def output_from_state(self, state): + return state + +def iterations(solver: WassersteinBarycenter, + bar_size, bar_prob, x_init, rng) -> WassersteinBarycenter: + """A jittable Wasserstein barycenter outer loop.""" + def cond_fn(iteration, constants, state): + solver, _ = constants + return solver.not_converged(state, iteration) + + def body_fn(iteration, constants, state, compute_error): + del compute_error # Always assumed True + solver, bar_prob = constants + return state.update( + iteration, + bar_prob, + solver.linear_ot_solver, + solver.store_inner_errors) + + state = fixed_point_loop.fixpoint_iter( + cond_fn=cond_fn, + body_fn=body_fn, + min_iterations=solver.min_iterations, + max_iterations=solver.max_iterations, + inner_iterations=1, + constants=(solver, bar_prob), + state=solver.init_state(bar_prob, bar_size, x_init, rng)) + + return solver.output_from_state(state) \ No newline at end of file diff --git a/ott/core/gromov_wasserstein.py b/ott/core/gromov_wasserstein.py index c57077201..6faa77f6e 100644 --- a/ott/core/gromov_wasserstein.py +++ b/ott/core/gromov_wasserstein.py @@ -25,6 +25,9 @@ from ott.core import quad_problems from ott.core import sinkhorn from ott.core import sinkhorn_lr +from ott.core import was_solver + +from ott.core.was_solver import WassersteinSolver from ott.geometry import epsilon_scheduler from ott.geometry import geometry from ott.geometry import low_rank @@ -118,84 +121,11 @@ def update(self, iteration: int, linear_sol, linear_pb, store_errors: bool, @jax.tree_util.register_pytree_node_class -class GromovWasserstein: - """A Gromov Wasserstein solver.""" - - def __init__(self, - epsilon: Optional[float] = None, - rank: int = -1, - linear_ot_solver: Any = None, - min_iterations: int = 5, - max_iterations: int = 50, - threshold: float = 1e-3, - jit: bool = True, - store_inner_errors: bool = False, - **kwargs): - default_epsilon = 1.0 - # Set epsilon value to default if needed, but keep track of whether None was - # passed to handle the case where a linear_ot_solver is passed or not. - self.epsilon = epsilon if epsilon is not None else default_epsilon - self.rank = rank - self.linear_ot_solver = linear_ot_solver - if self.linear_ot_solver is None: - # Detect if user requests low-rank solver. In that case the - # default_epsilon makes little sense, since it was designed for GW. - if self.is_low_rank: - if epsilon is None: - # Use default entropic regularization in LRSinkhorn if None was passed - self.linear_ot_solver = sinkhorn_lr.LRSinkhorn( - rank=self.rank, **kwargs) - else: - # If epsilon is passed, use it to replace the default LRSinkhorn value - self.linear_ot_solver = sinkhorn_lr.LRSinkhorn( - rank=self.rank, - epsilon=self.epsilon, **kwargs) - else: - # When using Entropic GW, epsilon is not handled inside Sinkhorn, - # but rather added back to the Geometry object reinstantiated - # when linearizing the problem. Therefore no need to pass it to solver. - self.linear_ot_solver = sinkhorn.Sinkhorn(**kwargs) - - self.min_iterations = min_iterations - self.max_iterations = max_iterations - self.threshold = threshold - self.jit = jit - self.store_inner_errors = store_inner_errors - self._kwargs = kwargs - - @property - def is_low_rank(self): - return self.rank > 0 - - def tree_flatten(self): - return ([self.epsilon, self.rank, - self.linear_ot_solver, self.threshold], - dict( - min_iterations=self.min_iterations, - max_iterations=self.max_iterations, - jit=self.jit, - store_inner_errors=self.store_inner_errors, - **self._kwargs)) - - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls( - epsilon=children[0], - rank=children[1], - linear_ot_solver=children[2], - threshold=children[3], - **aux_data) - - def not_converged(self, state, iteration): - costs, i, tol = state.costs, iteration, self.threshold - return jnp.logical_or( - i <= 2, - jnp.logical_and( - jnp.isfinite(costs[i - 1]), - jnp.logical_not(jnp.isclose(costs[i - 2], costs[i - 1], rtol=tol)))) +class GromovWasserstein(was_solver.WassersteinSolver): + """A Gromov Wasserstein solver, built on generic template.""" def __call__(self, prob: quad_problems.QuadraticProblem) -> GWOutput: - # Consider converting problem first is using low-rank solver + # Consider converting problem first if using low-rank solver if self.is_low_rank: convert = ( isinstance(prob.geom_xx, pointcloud.PointCloud) and diff --git a/ott/core/segment.py b/ott/core/segment.py new file mode 100644 index 000000000..61c1eb640 --- /dev/null +++ b/ott/core/segment.py @@ -0,0 +1,70 @@ +"""Prepare point clouds for parallel computations.""" + + +import functools +from typing import Any, Dict, Mapping, Optional, Type + +import jax +from jax import numpy as jnp + +def segment_point_cloud( + x: jnp.ndarray, + a: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + num_segments: Optional[int] = None, + indices_are_sorted: Optional[bool] = None, + num_per_segment: Optional[jnp.ndarray] = None, + max_measure_size: Optional[int] = None + ) -> jnp.ndarray: + """ Segment and pad as needed the entries of a point cloud. + There are two interfaces: either use `segment_ids`, and optionally + `num_segments` and `indices_are_sorted`, to describe for each + data point in the matrix to which segment each point corresponds to, + OR use `num_per_segment`, which describes contiguous segments. + + If using the first interface, `num_segments` is required for JIT compilation. + Assumes range(0, `num_segments`) are the segment ids. + + In both cases, jitting requires defining a max_measure_size, the + upper bound on the maximal size of measures, which will be used for padding. + """ + num, dim = x.shape + use_segment_ids = segment_ids is not None + if use_segment_ids: + if num_segments is None: + num_segments = jnp.max(segment_ids) + 1 + if indices_are_sorted is None: + indices_are_sorted = False + + num_per_segment = jax.ops.segment_sum( + jnp.ones_like(segment_ids), + segment_ids, + num_segments=num_segments, + indices_are_sorted=indices_are_sorted) + else: + assert num_per_segment is not None + assert num_segments is None or num_segments == num_per_segment.shape[0] + num_segments = num_per_segment.shape[0] + segment_ids = jnp.arange(num_segments).repeat( + num_per_segment, total_repeat_length=num) + + if a is None: + a = (1 / num_per_segment).repeat(num_per_segment) + + if max_measure_size is None: + max_measure_size = jnp.max(num_per_segment) + + segmented_a = [] + segmented_x = [] + x = jnp.concatenate((x, jnp.zeros((1, dim)))) + a = jnp.concatenate((a, jnp.zeros((1, )))) + for i in range(num_segments): + idx = jnp.where(segment_ids == i, jnp.arange(num), num+1) + idx = jax.lax.dynamic_slice(jnp.sort(idx), (0,), (max_measure_size,)) + z = a.at[idx].get() + segmented_a.append(z) + z = x.at[idx].get() + segmented_x.append(z) + segmented_a = jnp.stack(segmented_a) + segmented_x = jnp.stack(segmented_x) + return segmented_x, segmented_a, num_segments \ No newline at end of file diff --git a/ott/core/sinkhorn_lr.py b/ott/core/sinkhorn_lr.py index dfc67d0c9..fcce08d82 100644 --- a/ott/core/sinkhorn_lr.py +++ b/ott/core/sinkhorn_lr.py @@ -303,8 +303,8 @@ def dysktra_update(self, c_q, c_r, h, ot_prob, state, iteration, # shortcuts for problem's definition. r = self.rank n, m = ot_prob.geom.shape - a, b = ot_prob.a, ot_prob.b - + loga, logb = jnp.log(ot_prob.a), jnp.log(ot_prob.b) + h_old = h g1_old, g2_old = jnp.zeros(r), jnp.zeros(r) f1, f2 = jnp.zeros(n), jnp.zeros(m) @@ -313,7 +313,7 @@ def dysktra_update(self, c_q, c_r, h, ot_prob, state, iteration, w_q, w_r = jnp.zeros(r), jnp.zeros(r) err = jnp.inf state_inner = f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err - constants = c_q, c_r, a, b + constants = c_q, c_r, loga, logb def cond_fn(iteration, constants, state_inner): del iteration, constants @@ -326,11 +326,15 @@ def _softm(f, g, c, axis): def body_fn(iteration, constants, state_inner, compute_error): f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err = state_inner - c_q, c_r, a, b = constants + c_q, c_r, loga, logb = constants # First Projection - f1 = (jnp.log(a) - _softm(f1, g1_old, c_q, 1)) / self.gamma + f1 - f2 = (jnp.log(b) - _softm(f2, g2_old, c_r, 1)) / self.gamma + f2 + f1 = jnp.where( + jnp.isfinite(loga), + (loga - _softm(f1, g1_old, c_q, 1)) / self.gamma + f1, loga) + f2 = jnp.where( + jnp.isfinite(logb), + (logb - _softm(f2, g2_old, c_r, 1)) / self.gamma + f2, logb) h = h_old + w_gi h = jnp.maximum(jnp.log(min_entry_value) / self.gamma, h) diff --git a/ott/core/was_solver.py b/ott/core/was_solver.py new file mode 100644 index 000000000..147a71ea2 --- /dev/null +++ b/ott/core/was_solver.py @@ -0,0 +1,100 @@ +# coding=utf-8 +# Copyright 2022 Apple Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""A Jax version of the regularised GW Solver (Peyre et al. 2016).""" +import functools +from typing import Any, Dict, NamedTuple, Optional, Union + +import jax +import jax.numpy as jnp +from ott.core import sinkhorn +from ott.core import sinkhorn_lr + +@jax.tree_util.register_pytree_node_class +class WassersteinSolver: + """A generic solver for problems that use a linear reg-OT pb in inner loop.""" + def __init__(self, + epsilon: Optional[float] = None, + rank: int = -1, + linear_ot_solver: Any = None, + min_iterations: int = 5, + max_iterations: int = 50, + threshold: float = 1e-3, + jit: bool = True, + store_inner_errors: bool = False, + **kwargs): + default_epsilon = 1.0 + # Set epsilon value to default if needed, but keep track of whether None was + # passed to handle the case where a linear_ot_solver is passed or not. + self.epsilon = epsilon if epsilon is not None else default_epsilon + self.rank = rank + self.linear_ot_solver = linear_ot_solver + if self.linear_ot_solver is None: + # Detect if user requests low-rank solver. In that case the + # default_epsilon makes little sense, since it was designed for GW. + if self.is_low_rank: + if epsilon is None: + # Use default entropic regularization in LRSinkhorn if None was passed + self.linear_ot_solver = sinkhorn_lr.LRSinkhorn( + rank=self.rank, jit=False, **kwargs) + else: + # If epsilon is passed, use it to replace the default LRSinkhorn value + self.linear_ot_solver = sinkhorn_lr.LRSinkhorn( + rank=self.rank, + epsilon=self.epsilon, **kwargs) + else: + # When using Entropic GW, epsilon is not handled inside Sinkhorn, + # but rather added back to the Geometry object reinstantiated + # when linearizing the problem. Therefore no need to pass it to solver. + self.linear_ot_solver = sinkhorn.Sinkhorn(**kwargs) + + self.min_iterations = min_iterations + self.max_iterations = max_iterations + self.threshold = threshold + self.jit = jit + self.store_inner_errors = store_inner_errors + self._kwargs = kwargs + + @property + def is_low_rank(self): + return self.rank > 0 + + def tree_flatten(self): + return ([self.epsilon, self.rank, + self.linear_ot_solver, self.threshold], + dict( + min_iterations=self.min_iterations, + max_iterations=self.max_iterations, + jit=self.jit, + store_inner_errors=self.store_inner_errors, + **self._kwargs)) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls( + epsilon=children[0], + rank=children[1], + linear_ot_solver=children[2], + threshold=children[3], + **aux_data) + + def not_converged(self, state, iteration): + costs, i, tol = state.costs, iteration, self.threshold + return jnp.logical_or( + i <= 2, + jnp.logical_and( + jnp.isfinite(costs[i - 1]), + jnp.logical_not(jnp.isclose(costs[i - 2], costs[i - 1], rtol=tol)))) diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index 1269ed21d..f0bfcd0f8 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -41,6 +41,9 @@ class CostFn(abc.ABC): def pairwise(self, x, y): pass + def barycenter(self, weights, xs): + pass + def __call__(self, x, y): return self.pairwise(x, y) + ( 0 if self.norm is None else self.norm(x) + self.norm(y)) # pylint: disable=not-callable @@ -85,6 +88,10 @@ def norm(self, x): def pairwise(self, x, y): return -2 * jnp.vdot(x, y) + + def barycenter(self, weights, xs): + return jnp.average(xs, weights=weights, axis=0) + @jax.tree_util.register_pytree_node_class diff --git a/ott/geometry/pointcloud.py b/ott/geometry/pointcloud.py index 51c55cbaf..f2c43ab28 100644 --- a/ott/geometry/pointcloud.py +++ b/ott/geometry/pointcloud.py @@ -468,6 +468,11 @@ def finalize(i: int): raise ValueError( f'Scaling method {summary} does not exist for online mode.') + def barycenter(self, weights): + """Compute barycenter of points in self.x using weights, valid for p=2.0 """ + assert self.power == 2.0 + return self.cost_fn.barycenter(self.x, weights) + @classmethod def prepare_divergences(cls, *args, static_b: bool = False, **kwargs): """Instantiates the geometries used for a divergence computation.""" diff --git a/ott/tools/sinkhorn_divergence.py b/ott/tools/sinkhorn_divergence.py index 6d12db341..9842583c3 100644 --- a/ott/tools/sinkhorn_divergence.py +++ b/ott/tools/sinkhorn_divergence.py @@ -21,6 +21,7 @@ import jax from jax import numpy as jnp from ott.core import sinkhorn +from ott.core import segment from ott.geometry import geometry from ott.geometry import pointcloud @@ -155,17 +156,13 @@ def segment_sinkhorn_divergence( share_epsilon: bool = True, **kwargs) -> jnp.ndarray: """Computes Sinkhorn divergence between subsets of data with pointcloud. - - There are two interfaces: either use `segment_ids_x`, `segment_ids_y`, and - optionally `num_segments` and `indices_are_sorted`, OR use `num_per_segment_x` - and `num_per_segment_y`. If using the first interface, `num_segments` is - required for JIT compilation. Assumes range(0, `num_segments`) are the segment - ids. The second interface assumes `x` and `y` are segmented contiguously. + + The second interface assumes `x` and `y` are segmented contiguously. In all cases, both `x` and `y` should contain the same number of segments. Each segment will be separately run through the sinkhorn divergence using array padding. - + Args: x: Array of input points, of shape [num_x, feature]. Multiple segments are held in this single array. @@ -203,62 +200,24 @@ def segment_sinkhorn_divergence( Returns: An array of sinkhorn divergence values for each segment. """ - use_segment_ids = segment_ids_x is not None - if use_segment_ids: assert segment_ids_y is not None else: assert num_per_segment_x is not None assert num_per_segment_y is not None - if use_segment_ids: - if num_segments is None: - num_segments = jnp.max(segment_ids_x) + 1 - assert num_segments == jnp.max(segment_ids_y) + 1 - - if indices_are_sorted is None: - indices_are_sorted = False - - num_per_segment_x = jax.ops.segment_sum( - jnp.ones_like(segment_ids_x), - segment_ids_x, - num_segments=num_segments, - indices_are_sorted=indices_are_sorted) - num_per_segment_y = jax.ops.segment_sum( - jnp.ones_like(segment_ids_y), - segment_ids_y, - num_segments=num_segments, - indices_are_sorted=indices_are_sorted) - else: - assert num_segments is None - assert num_per_segment_x is not None - num_segments = num_per_segment_x.shape[0] - segment_ids_x = jnp.arange(num_segments).repeat(num_per_segment_x) - segment_ids_y = jnp.arange(num_segments).repeat(num_per_segment_y) - - num_x = x.shape[0] - num_y = y.shape[0] - - if weights_x is None: - weights_x = (1 / num_per_segment_x).repeat( - num_per_segment_x, total_repeat_length=num_x) - - if weights_y is None: - weights_y = (1 / num_per_segment_y).repeat( - num_per_segment_y, total_repeat_length=num_y) - - segmented_x = jnp.stack( - [x * (segment_ids_x == i)[:, None] for i in range(num_segments)]) - - segmented_y = jnp.stack( - [y * (segment_ids_y == i)[:, None] for i in range(num_segments)]) - - segmented_weights_x = jnp.stack( - [weights_x * (segment_ids_x == i) for i in range(num_segments)]) - - segmented_weights_y = jnp.stack( - [weights_y * (segment_ids_y == i) for i in range(num_segments)]) + segmented_x, segmented_weights_x, num_segments_x = segment.segment_point_cloud( + x, weights_x, + segment_ids_x, num_segments, indices_are_sorted, + num_per_segment_x) + + segmented_y, segmented_weights_y, num_segments_y = segment.segment_point_cloud( + y, weights_y, + segment_ids_y, num_segments, indices_are_sorted, + num_per_segment_y) + + assert num_segments_x == num_segments_y def single_segment_sink_div(padded_x, padded_y, padded_weight_x, padded_weight_y): diff --git a/ott/version.py b/ott/version.py index fe025177d..5a0d2ebc3 100644 --- a/ott/version.py +++ b/ott/version.py @@ -15,4 +15,4 @@ """Current ott version.""" -__version__ = "0.2.3" +__version__ = "0.2.4" diff --git a/tests/core/continuous_barycenter_test.py b/tests/core/continuous_barycenter_test.py new file mode 100644 index 000000000..4dd0b62da --- /dev/null +++ b/tests/core/continuous_barycenter_test.py @@ -0,0 +1,100 @@ +# coding=utf-8 +# Copyright 2022 Apple +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for Continuous barycenters.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +from ott.geometry import pointcloud +from ott.core import bar_problems +from ott.core import continuous_barycenter + +class Barycenter(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.rng = jax.random.PRNGKey(0) + self._dim = 4 + self._num_points = 113 + + @parameterized.product( + rank=[-1, 6], + epsilon=[1e-1, 1e-2], + debiased=[True, False], + jit=[True, False], + init_random=[True, False]) + def test_euclidean_barycenter(self, rank, epsilon, debiased, jit, init_random): + print('Rank: ', rank, 'Epsilon: ', epsilon, 'Debiased', debiased) + rngs = jax.random.split(self.rng, 20) + # Sample 2 point clouds, each of size 113, the first around [0,1]^4, + # Second around [2,3]^4. + y1 = jax.random.uniform(rngs[0], (self._num_points, self._dim)) + y2 = jax.random.uniform(rngs[1], (self._num_points, self._dim)) + 2 + # Merge them + y = jnp.concatenate((y1, y2)) + + # Define segments + num_per_segment = jnp.array([33, 29, 24, 27, 27, 31, 30, 25]) + # Set weights for each segment that sum to 1. + b = [] + for i in range(num_per_segment.shape[0]): + c = jax.random.uniform(rngs[i], (num_per_segment[i],)) + b.append(c / jnp.sum(c)) + b = jnp.concatenate(b, axis=0) + print(b.shape) + # Set a barycenter problem with 8 measures, of irregular sizes. + + bar_prob = bar_problems.BarycenterProblem( + y, b, + num_per_segment=num_per_segment, + num_segments=num_per_segment.shape[0], + max_measure_size=jnp.max(num_per_segment)+3, # +3 set with no purpose. + debiased=debiased) + + # Define solver + threshold = 1e-3 + solver = continuous_barycenter.WassersteinBarycenter( + epsilon=epsilon, + rank=rank, + threshold = threshold, jit=jit) + + # Run it, requesting a barycenter of size 31, with or without initializing + # to 0s (when init_zero is False, initialization is taken randomly in + # points constituting the y's). + bar_size=31 + if init_random: + # choose points randomly in entire support. + x_init= 3 * jax.random.uniform(rngs[-1], (bar_size, self._dim)) + out = solver( + bar_prob, bar_size=bar_size, x_init=x_init) + else: + out = solver(bar_prob, bar_size=bar_size) + + costs = out.costs + costs = costs[costs > -1] + # Check shape + self.assertTrue(out.x.shape==(bar_size,self._dim)) + # Check converged + self.assertTrue(jnp.isclose(costs[-2], costs[-1], rtol=threshold)) + + # Check barycenter has points roughly in [1,2]^4. + # (Note sampled points where either in [0,1]^4 or [2,3]^4) + self.assertTrue(jnp.all(out.x.ravel()<2.3)) + self.assertTrue(jnp.all(out.x.ravel()>.7)) + +if __name__ == '__main__': + absltest.main() \ No newline at end of file diff --git a/tests/core/sinkhorn_lr_test.py b/tests/core/sinkhorn_lr_test.py index e0104672e..21e840d8e 100644 --- a/tests/core/sinkhorn_lr_test.py +++ b/tests/core/sinkhorn_lr_test.py @@ -14,7 +14,7 @@ # limitations under the License. # Lint as: python3 -"""Tests for the Policy.""" +"""Tests Sinkhorn Low-Rank solver with various initializations.""" from absl.testing import absltest from absl.testing import parameterized import jax @@ -28,9 +28,9 @@ class SinkhornLRTest(parameterized.TestCase): def setUp(self): super().setUp() self.rng = jax.random.PRNGKey(0) - self.dim = 2 - self.n = 19 - self.m = 17 + self.dim = 4 + self.n = 29 + self.m = 27 self.rng, *rngs = jax.random.split(self.rng, 5) self.x = jax.random.uniform(rngs[0], (self.n, self.dim)) self.y = jax.random.uniform(rngs[1], (self.m, self.dim)) @@ -38,71 +38,72 @@ def setUp(self): b = jax.random.uniform(rngs[3], (self.m,)) # # adding zero weights to test proper handling - # a = a.at[0].set(0) - # b = b.at[3].set(0) + a = a.at[0].set(0) + b = b.at[3].set(0) self.a = a / jnp.sum(a) self.b = b / jnp.sum(b) - @parameterized.parameters([True], [False]) - def test_euclidean_point_cloud(self, use_lrcgeom): - """Two point clouds, tested with various parameters.""" - init_type_arr = ["rank_2", "random"] - for init_type in init_type_arr: - threshold = 1e-9 - gamma = 100 - geom = pointcloud.PointCloud(self.x, self.y) - if use_lrcgeom: - geom = geom.to_LRCGeometry() - ot_prob = problems.LinearProblem(geom, self.a, self.b) - solver = sinkhorn_lr.LRSinkhorn( - threshold=threshold, - gamma=gamma, - rank=2, - epsilon=0.0, - init_type=init_type, - ) - costs = solver(ot_prob).costs - self.assertTrue(jnp.isclose(costs[-2], costs[-1], rtol=threshold)) - cost_1 = costs[costs > -1][-1] + @parameterized.product( + use_lrcgeom=[True, False], + init_type= ["rank_2", "random"]) + def test_euclidean_point_cloud(self, use_lrcgeom, init_type): + """Two point clouds, tested with 2 different initializations.""" + threshold = 1e-6 + geom = pointcloud.PointCloud(self.x, self.y) + # This test to check LR can work both with LRCGeometries and regular ones + if use_lrcgeom: + geom = geom.to_LRCGeometry() + ot_prob = problems.LinearProblem(geom, self.a, self.b) - solver = sinkhorn_lr.LRSinkhorn( - threshold=threshold, - gamma=gamma, - rank=10, - epsilon=0.0, - init_type=init_type, - ) - out = solver(ot_prob) - costs = out.costs - cost_2 = costs[costs > -1][-1] - self.assertGreater(cost_1, cost_2) + # Start with a low rank parameter + solver = sinkhorn_lr.LRSinkhorn( + threshold=threshold, + rank=10, + epsilon=0.0, + init_type=init_type, + ) + solved = solver(ot_prob) + costs = solved.costs + costs= costs[ costs > -1] + + # Check convergence + self.assertTrue(solved.converged) + self.assertTrue(jnp.isclose(costs[-2], costs[-1], rtol=threshold)) + + # Store cost value. + cost_1 = costs[-1] - other_geom = pointcloud.PointCloud(self.x, self.y + 0.3) - cost_other = out.cost_at_geom(other_geom) - self.assertGreater(cost_other, 0.0) + # Try with higher rank + solver = sinkhorn_lr.LRSinkhorn( + threshold=threshold, + rank=14, + epsilon=0.0, + init_type=init_type, + ) + out = solver(ot_prob) + costs = out.costs + cost_2 = costs[costs > -1][-1] + # Ensure solution with more rank budget has lower cost (not guaranteed) + self.assertGreater(cost_1, cost_2) - solver = sinkhorn_lr.LRSinkhorn( - threshold=threshold, - gamma=gamma, - rank=14, - epsilon=1e-1, - init_type=init_type, - ) - out = solver(ot_prob) - costs = out.costs - cost_3 = costs[costs > -1][-1] + # Ensure cost can still be computed on different geometry. + other_geom = pointcloud.PointCloud(self.x, self.y + 0.3) + cost_other = out.cost_at_geom(other_geom) + self.assertGreater(cost_other, 0.0) - solver = sinkhorn_lr.LRSinkhorn( - threshold=threshold, - gamma=gamma, - rank=14, - epsilon=1e-3, - init_type=init_type, - ) - out = solver(ot_prob) - costs = out.costs - cost_4 = costs[costs > -1][-1] - self.assertGreater(cost_3, cost_4) + # Ensure cost is higher when using high entropy. + # (Note that for small entropy regularizers, this can be the opposite + # due to non-convexity of problem and benefit of adding regularizer. + solver = sinkhorn_lr.LRSinkhorn( + threshold=threshold, + rank=14, + epsilon=1e-1, + init_type=init_type, + ) + out = solver(ot_prob) + costs = out.costs + cost_3 = costs[costs > -1][-1] + self.assertGreater(cost_3, cost_2) -if __name__ == "__main__": +if __name__ == '__main__': absltest.main()