From 42b05a9e97eca956b7cb93f3196d32eccefb17d3 Mon Sep 17 00:00:00 2001 From: soerenab Date: Fri, 12 Apr 2024 14:13:38 +0200 Subject: [PATCH] prepare_data() in GENOT now returns a tuple instead of a dict; change order of args in utils.match_quadratic() --- src/ott/neural/methods/flows/genot.py | 13 ++++++------- src/ott/solvers/utils.py | 4 ++-- tests/neural/methods/genot_test.py | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/ott/neural/methods/flows/genot.py b/src/ott/neural/methods/flows/genot.py index c70e5375c..7bc989fd7 100644 --- a/src/ott/neural/methods/flows/genot.py +++ b/src/ott/neural/methods/flows/genot.py @@ -162,26 +162,25 @@ def __call__( def prepare_data( batch: Dict[str, jnp.ndarray] - ) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], Dict[ - str, jnp.ndarray]]: + ) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], + Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], + Optional[jnp.ndarray]]]: src_lin, src_quad = batch.get("src_lin"), batch.get("src_quad") tgt_lin, tgt_quad = batch.get("tgt_lin"), batch.get("tgt_quad") arrs = src_lin, tgt_lin, src_quad, tgt_quad if src_quad is None and tgt_quad is None: # lin src, tgt = src_lin, tgt_lin - arrs_dict = {"x": src_lin, "y": tgt_lin} + arrs = src_lin, tgt_lin # get rid of src_quad, tgt_quad args elif src_lin is None and tgt_lin is None: # quad src, tgt = src_quad, tgt_quad - arrs_dict = {"xx": src_quad, "yy": tgt_quad} elif all(arr is not None for arr in arrs): # fused quad src = jnp.concatenate([src_lin, src_quad], axis=1) tgt = jnp.concatenate([tgt_lin, tgt_quad], axis=1) - arrs_dict = {"x": src_lin, "y": tgt_lin, "xx": src_quad, "yy": tgt_quad} else: raise RuntimeError("Cannot infer OT problem type from data.") - return (src, batch.get("src_condition"), tgt), arrs_dict + return (src, batch.get("src_condition"), tgt), arrs rng = utils.default_prng_key(rng) training_logs = {"loss": []} @@ -196,7 +195,7 @@ def prepare_data( time = self.time_sampler(rng_time, n * self.n_samples_per_src) latent = self.latent_noise_fn(rng_noise, (n, self.n_samples_per_src)) - tmat = self.data_match_fn(**matching_data) # (n, m) + tmat = self.data_match_fn(*matching_data) # (n, m) src_ixs, tgt_ixs = solver_utils.sample_conditional( # (n, k), (m, k) rng_resample, tmat, diff --git a/src/ott/solvers/utils.py b/src/ott/solvers/utils.py index f7bdae63a..667d74bf5 100644 --- a/src/ott/solvers/utils.py +++ b/src/ott/solvers/utils.py @@ -59,10 +59,10 @@ def match_linear( def match_quadratic( + x: Optional[jnp.ndarray], + y: Optional[jnp.ndarray], xx: jnp.ndarray, yy: jnp.ndarray, - x: Optional[jnp.ndarray] = None, - y: Optional[jnp.ndarray] = None, scale_cost: ScaleCost_t = 1.0, cost_fn: Optional[costs.CostFn] = None, **kwargs: Any diff --git a/tests/neural/methods/genot_test.py b/tests/neural/methods/genot_test.py index b3eaae383..78f0aa8cb 100644 --- a/tests/neural/methods/genot_test.py +++ b/tests/neural/methods/genot_test.py @@ -32,7 +32,7 @@ def get_match_fn(typ: Literal["lin", "quad", "fused"]): return solver_utils.match_linear if typ == "quad": return solver_utils.match_quadratic - # typ == "fused": + # typ == "fused" return solver_utils.match_quadratic