Skip to content

Commit

Permalink
prepare_data() in GENOT now returns a tuple instead of a dict; change…
Browse files Browse the repository at this point in the history
… order of args in utils.match_quadratic()
  • Loading branch information
soerenab committed Apr 12, 2024
1 parent af1f092 commit 42b05a9
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
13 changes: 6 additions & 7 deletions src/ott/neural/methods/flows/genot.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,26 +162,25 @@ def __call__(

def prepare_data(
batch: Dict[str, jnp.ndarray]
) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], Dict[
str, jnp.ndarray]]:
) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray],
Optional[jnp.ndarray]]]:
src_lin, src_quad = batch.get("src_lin"), batch.get("src_quad")
tgt_lin, tgt_quad = batch.get("tgt_lin"), batch.get("tgt_quad")
arrs = src_lin, tgt_lin, src_quad, tgt_quad

if src_quad is None and tgt_quad is None: # lin
src, tgt = src_lin, tgt_lin
arrs_dict = {"x": src_lin, "y": tgt_lin}
arrs = src_lin, tgt_lin # get rid of src_quad, tgt_quad args
elif src_lin is None and tgt_lin is None: # quad
src, tgt = src_quad, tgt_quad
arrs_dict = {"xx": src_quad, "yy": tgt_quad}
elif all(arr is not None for arr in arrs): # fused quad
src = jnp.concatenate([src_lin, src_quad], axis=1)
tgt = jnp.concatenate([tgt_lin, tgt_quad], axis=1)
arrs_dict = {"x": src_lin, "y": tgt_lin, "xx": src_quad, "yy": tgt_quad}
else:
raise RuntimeError("Cannot infer OT problem type from data.")

return (src, batch.get("src_condition"), tgt), arrs_dict
return (src, batch.get("src_condition"), tgt), arrs

rng = utils.default_prng_key(rng)
training_logs = {"loss": []}
Expand All @@ -196,7 +195,7 @@ def prepare_data(
time = self.time_sampler(rng_time, n * self.n_samples_per_src)
latent = self.latent_noise_fn(rng_noise, (n, self.n_samples_per_src))

tmat = self.data_match_fn(**matching_data) # (n, m)
tmat = self.data_match_fn(*matching_data) # (n, m)
src_ixs, tgt_ixs = solver_utils.sample_conditional( # (n, k), (m, k)
rng_resample,
tmat,
Expand Down
4 changes: 2 additions & 2 deletions src/ott/solvers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def match_linear(


def match_quadratic(
x: Optional[jnp.ndarray],
y: Optional[jnp.ndarray],
xx: jnp.ndarray,
yy: jnp.ndarray,
x: Optional[jnp.ndarray] = None,
y: Optional[jnp.ndarray] = None,
scale_cost: ScaleCost_t = 1.0,
cost_fn: Optional[costs.CostFn] = None,
**kwargs: Any
Expand Down
2 changes: 1 addition & 1 deletion tests/neural/methods/genot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_match_fn(typ: Literal["lin", "quad", "fused"]):
return solver_utils.match_linear
if typ == "quad":
return solver_utils.match_quadratic
# typ == "fused":
# typ == "fused"
return solver_utils.match_quadratic


Expand Down

0 comments on commit 42b05a9

Please sign in to comment.