Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug fix: avoid mixing up linear and quadratic in genot #517

Merged
merged 5 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions src/ott/neural/methods/flows/genot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
Expand All @@ -29,12 +29,11 @@

__all__ = ["GENOT"]

# input: (src_lin, tgt_lin, src_quad, tgt_quad), output: (len(src), len(tgt))
# all are optional because the problem can be linear/quadratic/fused
DataMatchFn_t = Callable[[
Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray],
Optional[jnp.ndarray]
], jnp.ndarray]
LinTerm = Tuple[jnp.ndarray, jnp.ndarray]
QuadTerm = Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray],
Optional[jnp.ndarray]]
DataMatchFn = Union[Callable[[LinTerm], jnp.ndarray], Callable[[QuadTerm],
jnp.ndarray]]


class GENOT:
Expand All @@ -49,8 +48,14 @@ class GENOT:
vf: Vector field parameterized by a neural network.
flow: Flow between the latent and the target distributions.
data_match_fn: Function to match samples from the source and the target
distributions with a ``(src_lin, tgt_lin, src_quad, tgt_quad) -> matching``
signature.
distributions. Depending on the data passed in :meth:`__call__`, it has
the following signature:

- ``(src_lin, tgt_lin) -> matching`` - linear matching.
- ``(src_quad, tgt_quad, src_lin, tgt_lin) -> matching`` -
quadratic (fused) GW matching. In the pure GW setting, both ``src_lin``
and ``tgt_lin`` will be set to :obj:`None`.

source_dim: Dimensionality of the source distribution.
target_dim: Dimensionality of the target distribution.
condition_dim: Dimension of the conditions. If :obj:`None`, the underlying
Expand All @@ -73,7 +78,7 @@ def __init__(
self,
vf: velocity_field.VelocityField,
flow: dynamics.BaseFlow,
data_match_fn: DataMatchFn_t,
data_match_fn: DataMatchFn,
*,
source_dim: int,
target_dim: int,
Expand Down Expand Up @@ -162,19 +167,24 @@ def __call__(

def prepare_data(
batch: Dict[str, jnp.ndarray]
) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray], Tuple[
jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray],
Optional[jnp.ndarray]]]:
src_lin, src_quad = batch.get("src_lin"), batch.get("src_quad")
tgt_lin, tgt_quad = batch.get("tgt_lin"), batch.get("tgt_quad")
arrs = src_lin, tgt_lin, src_quad, tgt_quad

if src_quad is None and tgt_quad is None: # lin
src, tgt = src_lin, tgt_lin
arrs = src_lin, tgt_lin
elif src_lin is None and tgt_lin is None: # quad
src, tgt = src_quad, tgt_quad
elif all(arr is not None for arr in arrs): # fused quad
arrs = src_quad, tgt_quad
elif all(
arr is not None for arr in (src_lin, tgt_lin, src_quad, tgt_quad)
): # fused quad
src = jnp.concatenate([src_lin, src_quad], axis=1)
tgt = jnp.concatenate([tgt_lin, tgt_quad], axis=1)
arrs = src_quad, tgt_quad, src_lin, tgt_lin
else:
raise RuntimeError("Cannot infer OT problem type from data.")

Expand Down
21 changes: 7 additions & 14 deletions tests/neural/methods/genot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Literal, Optional
from typing import Literal

import pytest

Expand All @@ -28,20 +27,14 @@
from ott.solvers import utils as solver_utils


def data_match_fn(
src_lin: Optional[jnp.ndarray], tgt_lin: Optional[jnp.ndarray],
src_quad: Optional[jnp.ndarray], tgt_quad: Optional[jnp.ndarray], *,
typ: Literal["lin", "quad", "fused"]
) -> jnp.ndarray:
def get_match_fn(typ: Literal["lin", "quad", "fused"]):
if typ == "lin":
return solver_utils.match_linear(x=src_lin, y=tgt_lin)
return solver_utils.match_linear
if typ == "quad":
return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad)
return solver_utils.match_quadratic
if typ == "fused":
return solver_utils.match_quadratic(
xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin
)
raise NotImplementedError(f"Unknown type: {typ}.")
return solver_utils.match_quadratic
raise NotImplementedError(typ)


class TestGENOT:
Expand Down Expand Up @@ -69,7 +62,7 @@ def test_genot(self, rng: jax.Array, dl: str, request):
model = genot.GENOT(
vf,
flow=dynamics.ConstantNoiseFlow(0.0),
data_match_fn=functools.partial(data_match_fn, typ=problem_type),
data_match_fn=get_match_fn(problem_type),
source_dim=src_dim,
target_dim=tgt_dim,
condition_dim=cond_dim,
Expand Down
Loading