
linear and quadratic part get mixed up in genot #514

Closed
soerenab opened this issue Apr 8, 2024 · 4 comments · Fixed by #517

Comments

@soerenab
Contributor

soerenab commented Apr 8, 2024

Describe the bug
I am trying to run GENOT for the fused Gromov-Wasserstein case and get a shape-mismatch error raised by this line:
tmat = self.data_match_fn(*matching_data)

matching_data is returned by prepare_data() and, in the fused case, consists of four arrays in the order
arrs = src_lin, tgt_lin, src_quad, tgt_quad (see https://github.com/ott-jax/ott/blob/main/src/ott/neural/methods/flows/genot.py#L169).

However, the quadratic match function expects its arguments in the opposite order:

def match_quadratic(
    xx: jnp.ndarray,
    yy: jnp.ndarray,
    x: Optional[jnp.ndarray] = None,
    y: Optional[jnp.ndarray] = None,
    scale_cost: ScaleCost_t = 1.0,
    cost_fn: Optional[costs.CostFn] = None,
    **kwargs: Any
)

That is, it takes the quadratic terms first and the linear terms second. Passing *matching_data positionally therefore binds src_lin to xx, tgt_lin to yy, src_quad to x, and tgt_quad to y, which presumably causes the shape error I am getting.
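The mismatch can be reproduced without ott by binding the four-tuple to a stand-in that copies the parameter order of match_quadratic quoted above. This is a minimal sketch: match_quadratic_stub is a hypothetical placeholder (its body is irrelevant), and plain strings stand in for the arrays.

```python
import inspect


def match_quadratic_stub(xx, yy, x=None, y=None, scale_cost=1.0, cost_fn=None, **kwargs):
    """Hypothetical stand-in with the same parameter order as match_quadratic."""
    raise NotImplementedError  # only the argument binding matters here


# Order in which prepare_data() returns the arrays in the fused case.
arrs = ("src_lin", "tgt_lin", "src_quad", "tgt_quad")

# Bind the tuple positionally, exactly as self.data_match_fn(*matching_data) does.
bound = dict(inspect.signature(match_quadratic_stub).bind(*arrs).arguments)
print(bound)  # {'xx': 'src_lin', 'yy': 'tgt_lin', 'x': 'src_quad', 'y': 'tgt_quad'}
```

The linear arrays end up in the quadratic slots and vice versa, so any shape check between xx/yy and x/y fails.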

To Reproduce
I am using ott-jax==0.4.6.dev18+ge1bbd34, which corresponds to pip install -e . of the latest ott code on GitHub (cloned and installed today).

@soerenab
Contributor Author

soerenab commented Apr 8, 2024

Also, I am wondering whether the current implementation (aside from the above) is compatible with the linear match function:

def match_linear(
    x: jnp.ndarray,
    y: Optional[jnp.ndarray],
    cost_fn: Optional[costs.CostFn] = None,
    epsilon: Optional[float] = None,
    scale_cost: ScaleCost_t = 1.0,
    **kwargs: Any
) -> jnp.ndarray:

which takes only x and y, whereas, according to prepare_data(), arrs always consists of four arrays.
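For the linear case, here is a sketch of what a four-tuple does to a stand-in with match_linear's parameter order. It assumes (this is not confirmed in the issue) that the two quadratic entries are None in the linear case and, for the second part, that the match function is pre-configured through functools.partial. match_linear_stub and the epsilon value are hypothetical.

```python
import functools
import inspect


def match_linear_stub(x, y, cost_fn=None, epsilon=None, scale_cost=1.0, **kwargs):
    """Hypothetical stand-in with the same parameter order as match_linear."""
    return {"x": x, "y": y, "cost_fn": cost_fn, "epsilon": epsilon}


# Assumed four-tuple in the linear case: the quadratic entries are None.
arrs = ("src_lin", "tgt_lin", None, None)

# The trailing None values are silently bound to cost_fn and epsilon.
bound = dict(inspect.signature(match_linear_stub).bind(*arrs).arguments)
print(bound)  # {'x': 'src_lin', 'y': 'tgt_lin', 'cost_fn': None, 'epsilon': None}

# Hypothetical configuration: if epsilon were pre-bound as a keyword,
# the positional None would collide with it and raise a TypeError.
match_fn = functools.partial(match_linear_stub, epsilon=0.1)
try:
    match_fn(*arrs)
except TypeError as err:
    print("TypeError:", err)
```

So even when no error is raised, the extra positional values land in keyword slots they were never meant for, which is an argument for passing the data by name.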

Maybe it would be better to change prepare_data() to return a dict containing only the non-None terms and to pass the arguments to the matching function via **matching_data.
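A minimal sketch of that suggestion, assuming the dict keys are chosen to match the parameter names of the quoted match functions (x and y for the linear terms, xx and yy for the quadratic ones). prepare_matching_data and the two stubs are hypothetical and not necessarily what was merged in #517.

```python
from typing import Any, Dict, Optional


def prepare_matching_data(
    src_lin: Optional[Any] = None,
    tgt_lin: Optional[Any] = None,
    src_quad: Optional[Any] = None,
    tgt_quad: Optional[Any] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: name each array after the match-function parameter, drop None."""
    named = {"x": src_lin, "y": tgt_lin, "xx": src_quad, "yy": tgt_quad}
    return {k: v for k, v in named.items() if v is not None}


def match_linear_stub(x, y, cost_fn=None, epsilon=None, scale_cost=1.0, **kwargs):
    """Stand-in with match_linear's parameter names; reports what it received."""
    return ("linear", x, y)


def match_quadratic_stub(xx, yy, x=None, y=None, scale_cost=1.0, cost_fn=None, **kwargs):
    """Stand-in with match_quadratic's parameter names; reports what it received."""
    return ("quadratic", xx, yy, x, y)


# Fused case: keyword unpacking routes each array to the right parameter.
fused = prepare_matching_data("src_lin", "tgt_lin", "src_quad", "tgt_quad")
print(match_quadratic_stub(**fused))  # ('quadratic', 'src_quad', 'tgt_quad', 'src_lin', 'tgt_lin')

# Linear case: only x and y survive, so match_linear receives exactly two arrays.
linear = prepare_matching_data(src_lin="src_lin", tgt_lin="tgt_lin")
print(match_linear_stub(**linear))  # ('linear', 'src_lin', 'tgt_lin')
```

Because the keys reuse the parameter names, the same dict works for the linear, pure quadratic, and fused match functions, and the order in which the arrays are stored no longer matters.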

@michalk8
Collaborator

michalk8 commented Apr 8, 2024

Hi @soerenab, thanks for catching this! Yes, agreed, it would be best if prepare_data returned a dict!

@soerenab
Contributor Author

soerenab commented Apr 8, 2024

I already changed it locally; I can prepare a pull request :)

@michalk8
Collaborator

michalk8 commented Apr 8, 2024

That would be great, thanks!
