diff --git a/src/ott/neural/methods/flows/genot.py b/src/ott/neural/methods/flows/genot.py index 7bc989fd7..d0d88fd04 100644 --- a/src/ott/neural/methods/flows/genot.py +++ b/src/ott/neural/methods/flows/genot.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -29,12 +29,11 @@ __all__ = ["GENOT"] -# input: (src_lin, tgt_lin, src_quad, tgt_quad), output: (len(src), len(tgt)) -# all are optional because the problem can be linear/quadratic/fused -DataMatchFn_t = Callable[[ - Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray], - Optional[jnp.ndarray] -], jnp.ndarray] +LinTerm = Tuple[jnp.ndarray, jnp.ndarray] +QuadTerm = Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], + Optional[jnp.ndarray]] +DataMatchFn = Union[Callable[[LinTerm], jnp.ndarray], Callable[[QuadTerm], + jnp.ndarray]] class GENOT: @@ -49,8 +48,14 @@ class GENOT: vf: Vector field parameterized by a neural network. flow: Flow between the latent and the target distributions. data_match_fn: Function to match samples from the source and the target - distributions with a ``(src_lin, tgt_lin, src_quad, tgt_quad) -> matching`` - signature. + distributions. Depending on the data passed in :meth:`__call__`, it has + the following signature: + + - ``(src_lin, tgt_lin) -> matching`` - linear matching. + - ``(src_quad, tgt_quad, src_lin, tgt_lin) -> matching`` - + quadratic (fused) GW matching. In the pure GW setting, btoh ``src_lin`` + and ``tgt_lin`` will be set to :obj:`None`. + source_dim: Dimensionality of the source distribution. target_dim: Dimensionality of the target distribution. condition_dim: Dimension of the conditions. If :obj:`None`, the underlying @@ -73,7 +78,7 @@ def __init__( self, vf: velocity_field.VelocityField, flow: dynamics.BaseFlow, - data_match_fn: DataMatchFn_t, + data_match_fn: DataMatchFn, *, source_dim: int, target_dim: int, @@ -167,16 +172,19 @@ def prepare_data( Optional[jnp.ndarray]]]: src_lin, src_quad = batch.get("src_lin"), batch.get("src_quad") tgt_lin, tgt_quad = batch.get("tgt_lin"), batch.get("tgt_quad") - arrs = src_lin, tgt_lin, src_quad, tgt_quad if src_quad is None and tgt_quad is None: # lin src, tgt = src_lin, tgt_lin - arrs = src_lin, tgt_lin # get rid of src_quad, tgt_quad args + arrs = src_lin, tgt_lin elif src_lin is None and tgt_lin is None: # quad src, tgt = src_quad, tgt_quad - elif all(arr is not None for arr in arrs): # fused quad + arrs = src_quad, tgt_quad + elif all( + arr is not None for arr in (src_lin, tgt_lin, src_quad, tgt_quad) + ): # fused quad src = jnp.concatenate([src_lin, src_quad], axis=1) tgt = jnp.concatenate([tgt_lin, tgt_quad], axis=1) + arrs = src_quad, tgt_quad, src_lin, tgt_lin else: raise RuntimeError("Cannot infer OT problem type from data.") diff --git a/src/ott/solvers/utils.py b/src/ott/solvers/utils.py index 667d74bf5..f7bdae63a 100644 --- a/src/ott/solvers/utils.py +++ b/src/ott/solvers/utils.py @@ -59,10 +59,10 @@ def match_linear( def match_quadratic( - x: Optional[jnp.ndarray], - y: Optional[jnp.ndarray], xx: jnp.ndarray, yy: jnp.ndarray, + x: Optional[jnp.ndarray] = None, + y: Optional[jnp.ndarray] = None, scale_cost: ScaleCost_t = 1.0, cost_fn: Optional[costs.CostFn] = None, **kwargs: Any diff --git a/tests/neural/methods/genot_test.py b/tests/neural/methods/genot_test.py index 78f0aa8cb..d4d8a1399 100644 --- a/tests/neural/methods/genot_test.py +++ b/tests/neural/methods/genot_test.py @@ -32,8 +32,9 @@ def get_match_fn(typ: Literal["lin", "quad", "fused"]): return solver_utils.match_linear if typ == "quad": return solver_utils.match_quadratic - # typ == "fused" - return solver_utils.match_quadratic + if typ == "fused": + return solver_utils.match_quadratic + raise NotImplementedError(typ) class TestGENOT: @@ -61,7 +62,7 @@ def test_genot(self, rng: jax.Array, dl: str, request): model = genot.GENOT( vf, flow=dynamics.ConstantNoiseFlow(0.0), - data_match_fn=get_match_fn(typ=problem_type), + data_match_fn=get_match_fn(problem_type), source_dim=src_dim, target_dim=tgt_dim, condition_dim=cond_dim,