Skip to content

Commit

Permalink
Update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Apr 12, 2024
1 parent 42b05a9 commit be8390e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 18 deletions.
34 changes: 21 additions & 13 deletions src/ott/neural/methods/flows/genot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
Expand All @@ -29,12 +29,11 @@

__all__ = ["GENOT"]

# input: (src_lin, tgt_lin, src_quad, tgt_quad), output: (len(src), len(tgt))
# all are optional because the problem can be linear/quadratic/fused
DataMatchFn_t = Callable[[
Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray],
Optional[jnp.ndarray]
], jnp.ndarray]
LinTerm = Tuple[jnp.ndarray, jnp.ndarray]
QuadTerm = Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray],
Optional[jnp.ndarray]]
DataMatchFn = Union[Callable[[LinTerm], jnp.ndarray], Callable[[QuadTerm],
jnp.ndarray]]


class GENOT:
Expand All @@ -49,8 +48,14 @@ class GENOT:
vf: Vector field parameterized by a neural network.
flow: Flow between the latent and the target distributions.
data_match_fn: Function to match samples from the source and the target
distributions with a ``(src_lin, tgt_lin, src_quad, tgt_quad) -> matching``
signature.
distributions. Depending on the data passed in :meth:`__call__`, it has
the following signature:
- ``(src_lin, tgt_lin) -> matching`` - linear matching.
- ``(src_quad, tgt_quad, src_lin, tgt_lin) -> matching`` -
quadratic (fused) GW matching. In the pure GW setting, btoh ``src_lin``
and ``tgt_lin`` will be set to :obj:`None`.
source_dim: Dimensionality of the source distribution.
target_dim: Dimensionality of the target distribution.
condition_dim: Dimension of the conditions. If :obj:`None`, the underlying
Expand All @@ -73,7 +78,7 @@ def __init__(
self,
vf: velocity_field.VelocityField,
flow: dynamics.BaseFlow,
data_match_fn: DataMatchFn_t,
data_match_fn: DataMatchFn,
*,
source_dim: int,
target_dim: int,
Expand Down Expand Up @@ -167,16 +172,19 @@ def prepare_data(
Optional[jnp.ndarray]]]:
src_lin, src_quad = batch.get("src_lin"), batch.get("src_quad")
tgt_lin, tgt_quad = batch.get("tgt_lin"), batch.get("tgt_quad")
arrs = src_lin, tgt_lin, src_quad, tgt_quad

if src_quad is None and tgt_quad is None: # lin
src, tgt = src_lin, tgt_lin
arrs = src_lin, tgt_lin # get rid of src_quad, tgt_quad args
arrs = src_lin, tgt_lin
elif src_lin is None and tgt_lin is None: # quad
src, tgt = src_quad, tgt_quad
elif all(arr is not None for arr in arrs): # fused quad
arrs = src_quad, tgt_quad
elif all(
arr is not None for arr in (src_lin, tgt_lin, src_quad, tgt_quad)
): # fused quad
src = jnp.concatenate([src_lin, src_quad], axis=1)
tgt = jnp.concatenate([tgt_lin, tgt_quad], axis=1)
arrs = src_quad, tgt_quad, src_lin, tgt_lin
else:
raise RuntimeError("Cannot infer OT problem type from data.")

Expand Down
4 changes: 2 additions & 2 deletions src/ott/solvers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def match_linear(


def match_quadratic(
x: Optional[jnp.ndarray],
y: Optional[jnp.ndarray],
xx: jnp.ndarray,
yy: jnp.ndarray,
x: Optional[jnp.ndarray] = None,
y: Optional[jnp.ndarray] = None,
scale_cost: ScaleCost_t = 1.0,
cost_fn: Optional[costs.CostFn] = None,
**kwargs: Any
Expand Down
7 changes: 4 additions & 3 deletions tests/neural/methods/genot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ def get_match_fn(typ: Literal["lin", "quad", "fused"]):
return solver_utils.match_linear
if typ == "quad":
return solver_utils.match_quadratic
# typ == "fused"
return solver_utils.match_quadratic
if typ == "fused":
return solver_utils.match_quadratic
raise NotImplementedError(typ)


class TestGENOT:
Expand Down Expand Up @@ -61,7 +62,7 @@ def test_genot(self, rng: jax.Array, dl: str, request):
model = genot.GENOT(
vf,
flow=dynamics.ConstantNoiseFlow(0.0),
data_match_fn=get_match_fn(typ=problem_type),
data_match_fn=get_match_fn(problem_type),
source_dim=src_dim,
target_dim=tgt_dim,
condition_dim=cond_dim,
Expand Down

0 comments on commit be8390e

Please sign in to comment.