From f2b7f9f7d5d0824f52deb197400f6cbfb231ea94 Mon Sep 17 00:00:00 2001 From: soerenab Date: Tue, 9 Apr 2024 09:51:44 +0200 Subject: [PATCH 1/5] bug fix: avoid mixing up linear and quadratic part by returning Dict in genot prepare_data() --- src/ott/neural/methods/flows/genot.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/ott/neural/methods/flows/genot.py b/src/ott/neural/methods/flows/genot.py index ce200d376..c70e5375c 100644 --- a/src/ott/neural/methods/flows/genot.py +++ b/src/ott/neural/methods/flows/genot.py @@ -162,23 +162,26 @@ def __call__( def prepare_data( batch: Dict[str, jnp.ndarray] - ) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], Tuple[ - jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]]: + ) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], Dict[ + str, jnp.ndarray]]: src_lin, src_quad = batch.get("src_lin"), batch.get("src_quad") tgt_lin, tgt_quad = batch.get("tgt_lin"), batch.get("tgt_quad") arrs = src_lin, tgt_lin, src_quad, tgt_quad if src_quad is None and tgt_quad is None: # lin src, tgt = src_lin, tgt_lin + arrs_dict = {"x": src_lin, "y": tgt_lin} elif src_lin is None and tgt_lin is None: # quad src, tgt = src_quad, tgt_quad + arrs_dict = {"xx": src_quad, "yy": tgt_quad} elif all(arr is not None for arr in arrs): # fused quad src = jnp.concatenate([src_lin, src_quad], axis=1) tgt = jnp.concatenate([tgt_lin, tgt_quad], axis=1) + arrs_dict = {"x": src_lin, "y": tgt_lin, "xx": src_quad, "yy": tgt_quad} else: raise RuntimeError("Cannot infer OT problem type from data.") - return (src, batch.get("src_condition"), tgt), arrs + return (src, batch.get("src_condition"), tgt), arrs_dict rng = utils.default_prng_key(rng) training_logs = {"loss": []} @@ -193,7 +196,7 @@ def prepare_data( time = self.time_sampler(rng_time, n * self.n_samples_per_src) latent = self.latent_noise_fn(rng_noise, (n, self.n_samples_per_src)) - tmat = self.data_match_fn(*matching_data) # (n, m) + tmat = self.data_match_fn(**matching_data) # (n, m) src_ixs, tgt_ixs = solver_utils.sample_conditional( # (n, k), (m, k) rng_resample, tmat, From af1f09236a8e98de21a2b7c3b644afee3aebbe69 Mon Sep 17 00:00:00 2001 From: soerenab Date: Fri, 12 Apr 2024 13:49:45 +0200 Subject: [PATCH 2/5] fix data_match_fn() setup in genot tests --- tests/neural/methods/genot_test.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/tests/neural/methods/genot_test.py b/tests/neural/methods/genot_test.py index 2c746596c..b3eaae383 100644 --- a/tests/neural/methods/genot_test.py +++ b/tests/neural/methods/genot_test.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools -from typing import Literal, Optional +from typing import Literal import pytest @@ -28,20 +27,13 @@ from ott.solvers import utils as solver_utils -def data_match_fn( - src_lin: Optional[jnp.ndarray], tgt_lin: Optional[jnp.ndarray], - src_quad: Optional[jnp.ndarray], tgt_quad: Optional[jnp.ndarray], *, - typ: Literal["lin", "quad", "fused"] -) -> jnp.ndarray: +def get_match_fn(typ: Literal["lin", "quad", "fused"]): if typ == "lin": - return solver_utils.match_linear(x=src_lin, y=tgt_lin) + return solver_utils.match_linear if typ == "quad": - return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad) - if typ == "fused": - return solver_utils.match_quadratic( - xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin - ) - raise NotImplementedError(f"Unknown type: {typ}.") + return solver_utils.match_quadratic + # typ == "fused": + return solver_utils.match_quadratic class TestGENOT: @@ -69,7 +61,7 @@ def test_genot(self, rng: jax.Array, dl: str, request): model = genot.GENOT( vf, flow=dynamics.ConstantNoiseFlow(0.0), - data_match_fn=functools.partial(data_match_fn, typ=problem_type), + data_match_fn=get_match_fn(typ=problem_type), source_dim=src_dim, target_dim=tgt_dim, condition_dim=cond_dim, From 42b05a9e97eca956b7cb93f3196d32eccefb17d3 Mon Sep 17 00:00:00 2001 From: soerenab Date: Fri, 12 Apr 2024 14:13:38 +0200 Subject: [PATCH 3/5] prepare_data() in GENOT now returns a tuple instead of a dict; change order of args in utils.match_quadratic() --- src/ott/neural/methods/flows/genot.py | 13 ++++++------- src/ott/solvers/utils.py | 4 ++-- tests/neural/methods/genot_test.py | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/ott/neural/methods/flows/genot.py b/src/ott/neural/methods/flows/genot.py index c70e5375c..7bc989fd7 100644 --- a/src/ott/neural/methods/flows/genot.py +++ b/src/ott/neural/methods/flows/genot.py @@ -162,26 +162,25 @@ def __call__( def prepare_data( batch: Dict[str, jnp.ndarray] - ) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], Dict[ - str, jnp.ndarray]]: + ) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], + Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], + Optional[jnp.ndarray]]]: src_lin, src_quad = batch.get("src_lin"), batch.get("src_quad") tgt_lin, tgt_quad = batch.get("tgt_lin"), batch.get("tgt_quad") arrs = src_lin, tgt_lin, src_quad, tgt_quad if src_quad is None and tgt_quad is None: # lin src, tgt = src_lin, tgt_lin - arrs_dict = {"x": src_lin, "y": tgt_lin} + arrs = src_lin, tgt_lin # get rid of src_quad, tgt_quad args elif src_lin is None and tgt_lin is None: # quad src, tgt = src_quad, tgt_quad - arrs_dict = {"xx": src_quad, "yy": tgt_quad} elif all(arr is not None for arr in arrs): # fused quad src = jnp.concatenate([src_lin, src_quad], axis=1) tgt = jnp.concatenate([tgt_lin, tgt_quad], axis=1) - arrs_dict = {"x": src_lin, "y": tgt_lin, "xx": src_quad, "yy": tgt_quad} else: raise RuntimeError("Cannot infer OT problem type from data.") - return (src, batch.get("src_condition"), tgt), arrs_dict + return (src, batch.get("src_condition"), tgt), arrs rng = utils.default_prng_key(rng) training_logs = {"loss": []} @@ -196,7 +195,7 @@ def prepare_data( time = self.time_sampler(rng_time, n * self.n_samples_per_src) latent = self.latent_noise_fn(rng_noise, (n, self.n_samples_per_src)) - tmat = self.data_match_fn(**matching_data) # (n, m) + tmat = self.data_match_fn(*matching_data) # (n, m) src_ixs, tgt_ixs = solver_utils.sample_conditional( # (n, k), (m, k) rng_resample, tmat, diff --git a/src/ott/solvers/utils.py b/src/ott/solvers/utils.py index f7bdae63a..667d74bf5 100644 --- a/src/ott/solvers/utils.py +++ b/src/ott/solvers/utils.py @@ -59,10 +59,10 @@ def match_linear( def match_quadratic( + x: Optional[jnp.ndarray], + y: Optional[jnp.ndarray], xx: jnp.ndarray, yy: jnp.ndarray, - x: Optional[jnp.ndarray] = None, - y: Optional[jnp.ndarray] = None, scale_cost: ScaleCost_t = 1.0, cost_fn: Optional[costs.CostFn] = None, **kwargs: Any diff --git a/tests/neural/methods/genot_test.py b/tests/neural/methods/genot_test.py index b3eaae383..78f0aa8cb 100644 --- a/tests/neural/methods/genot_test.py +++ b/tests/neural/methods/genot_test.py @@ -32,7 +32,7 @@ def get_match_fn(typ: Literal["lin", "quad", "fused"]): return solver_utils.match_linear if typ == "quad": return solver_utils.match_quadratic - # typ == "fused": + # typ == "fused" return solver_utils.match_quadratic From be8390e46cd5b4f8ed28c7dec7b83a0847dbdab5 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 12 Apr 2024 17:21:38 +0200 Subject: [PATCH 4/5] Update docs --- src/ott/neural/methods/flows/genot.py | 34 +++++++++++++++++---------- src/ott/solvers/utils.py | 4 ++-- tests/neural/methods/genot_test.py | 7 +++--- 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/src/ott/neural/methods/flows/genot.py b/src/ott/neural/methods/flows/genot.py index 7bc989fd7..d0d88fd04 100644 --- a/src/ott/neural/methods/flows/genot.py +++ b/src/ott/neural/methods/flows/genot.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -29,12 +29,11 @@ __all__ = ["GENOT"] -# input: (src_lin, tgt_lin, src_quad, tgt_quad), output: (len(src), len(tgt)) -# all are optional because the problem can be linear/quadratic/fused -DataMatchFn_t = Callable[[ - Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray], - Optional[jnp.ndarray] -], jnp.ndarray] +LinTerm = Tuple[jnp.ndarray, jnp.ndarray] +QuadTerm = Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], + Optional[jnp.ndarray]] +DataMatchFn = Union[Callable[[LinTerm], jnp.ndarray], Callable[[QuadTerm], + jnp.ndarray]] class GENOT: @@ -49,8 +48,14 @@ class GENOT: vf: Vector field parameterized by a neural network. flow: Flow between the latent and the target distributions. data_match_fn: Function to match samples from the source and the target - distributions with a ``(src_lin, tgt_lin, src_quad, tgt_quad) -> matching`` - signature. + distributions. Depending on the data passed in :meth:`__call__`, it has + the following signature: + + - ``(src_lin, tgt_lin) -> matching`` - linear matching. + - ``(src_quad, tgt_quad, src_lin, tgt_lin) -> matching`` - + quadratic (fused) GW matching. In the pure GW setting, btoh ``src_lin`` + and ``tgt_lin`` will be set to :obj:`None`. + source_dim: Dimensionality of the source distribution. target_dim: Dimensionality of the target distribution. condition_dim: Dimension of the conditions. If :obj:`None`, the underlying @@ -73,7 +78,7 @@ def __init__( self, vf: velocity_field.VelocityField, flow: dynamics.BaseFlow, - data_match_fn: DataMatchFn_t, + data_match_fn: DataMatchFn, *, source_dim: int, target_dim: int, @@ -167,16 +172,19 @@ def prepare_data( Optional[jnp.ndarray]]]: src_lin, src_quad = batch.get("src_lin"), batch.get("src_quad") tgt_lin, tgt_quad = batch.get("tgt_lin"), batch.get("tgt_quad") - arrs = src_lin, tgt_lin, src_quad, tgt_quad if src_quad is None and tgt_quad is None: # lin src, tgt = src_lin, tgt_lin - arrs = src_lin, tgt_lin # get rid of src_quad, tgt_quad args + arrs = src_lin, tgt_lin elif src_lin is None and tgt_lin is None: # quad src, tgt = src_quad, tgt_quad - elif all(arr is not None for arr in arrs): # fused quad + arrs = src_quad, tgt_quad + elif all( + arr is not None for arr in (src_lin, tgt_lin, src_quad, tgt_quad) + ): # fused quad src = jnp.concatenate([src_lin, src_quad], axis=1) tgt = jnp.concatenate([tgt_lin, tgt_quad], axis=1) + arrs = src_quad, tgt_quad, src_lin, tgt_lin else: raise RuntimeError("Cannot infer OT problem type from data.") diff --git a/src/ott/solvers/utils.py b/src/ott/solvers/utils.py index 667d74bf5..f7bdae63a 100644 --- a/src/ott/solvers/utils.py +++ b/src/ott/solvers/utils.py @@ -59,10 +59,10 @@ def match_linear( def match_quadratic( - x: Optional[jnp.ndarray], - y: Optional[jnp.ndarray], xx: jnp.ndarray, yy: jnp.ndarray, + x: Optional[jnp.ndarray] = None, + y: Optional[jnp.ndarray] = None, scale_cost: ScaleCost_t = 1.0, cost_fn: Optional[costs.CostFn] = None, **kwargs: Any diff --git a/tests/neural/methods/genot_test.py b/tests/neural/methods/genot_test.py index 78f0aa8cb..d4d8a1399 100644 --- a/tests/neural/methods/genot_test.py +++ b/tests/neural/methods/genot_test.py @@ -32,8 +32,9 @@ def get_match_fn(typ: Literal["lin", "quad", "fused"]): return solver_utils.match_linear if typ == "quad": return solver_utils.match_quadratic - # typ == "fused" - return solver_utils.match_quadratic + if typ == "fused": + return solver_utils.match_quadratic + raise NotImplementedError(typ) class TestGENOT: @@ -61,7 +62,7 @@ def test_genot(self, rng: jax.Array, dl: str, request): model = genot.GENOT( vf, flow=dynamics.ConstantNoiseFlow(0.0), - data_match_fn=get_match_fn(typ=problem_type), + data_match_fn=get_match_fn(problem_type), source_dim=src_dim, target_dim=tgt_dim, condition_dim=cond_dim, From 2bbf90444f630829a420d8ffdf66d3e16bb9d5b1 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 12 Apr 2024 17:27:56 +0200 Subject: [PATCH 5/5] Fix typo --- src/ott/neural/methods/flows/genot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/neural/methods/flows/genot.py b/src/ott/neural/methods/flows/genot.py index d0d88fd04..e7ca5c1bc 100644 --- a/src/ott/neural/methods/flows/genot.py +++ b/src/ott/neural/methods/flows/genot.py @@ -53,7 +53,7 @@ class GENOT: - ``(src_lin, tgt_lin) -> matching`` - linear matching. - ``(src_quad, tgt_quad, src_lin, tgt_lin) -> matching`` - - quadratic (fused) GW matching. In the pure GW setting, btoh ``src_lin`` + quadratic (fused) GW matching. In the pure GW setting, both ``src_lin`` and ``tgt_lin`` will be set to :obj:`None`. source_dim: Dimensionality of the source distribution.