linear and quadratic part get mixed up in genot #514
Comments
Also, I am wondering if the current implementation (aside from the above) is compatible with the linear match function:
def match_linear(
    x: jnp.ndarray,
    y: Optional[jnp.ndarray],
    cost_fn: Optional[costs.CostFn] = None,
    epsilon: Optional[float] = None,
    scale_cost: ScaleCost_t = 1.0,
    **kwargs: Any
) -> jnp.ndarray:
which takes only Maybe it would be change
Hi @soerenab, thanks for catching this! Yes, agree, would be best if
I already changed it locally, I can prepare a pull request :)
That would be great, thanks!
Describe the bug
I am trying to run GENOT for the fused Gromov case and get a shape mismatch error that is raised by this line:
tmat = self.data_match_fn(*matching_data)
matching_data is returned by prepare_data() and consists (in the fused case) of four arrays, arrs = src_lin, tgt_lin, src_quad, tgt_quad; see https://github.com/ott-jax/ott/blob/main/src/ott/neural/methods/flows/genot.py#L169.
However, the arguments of the quadratic match function are the other way around, i.e., first it takes the quadratic terms, then the linear ones. Hence there is a mismatch, and this (presumably) results in the shape error I am getting.
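To illustrate the ordering problem, here is a minimal pure-Python sketch. It is not the ott code: match_quadratic below is a hypothetical stand-in whose parameter names (xx, yy, x, y) and shape check are assumptions made only for this example, and the arrays are plain nested lists. It shows how unpacking arrs positionally puts the linear terms into the quadratic slots, and how passing the arrays by keyword avoids depending on the order.

```python
def match_quadratic(xx, yy, x=None, y=None):
    """Hypothetical stand-in: quadratic terms first, then optional linear terms."""
    # Linear terms are compared across spaces, so they must share a feature
    # dimension. Quadratic terms are compared within each space and may differ.
    if x is not None and y is not None and len(x[0]) != len(y[0]):
        raise ValueError("linear terms must share a feature dimension")
    n, m = len(xx), len(yy)
    # Placeholder result: a uniform coupling of shape (n, m).
    return [[1.0 / (n * m)] * m for _ in range(n)]


def zeros(rows, cols):
    """Build a rows x cols nested list, standing in for an array."""
    return [[0.0] * cols for _ in range(rows)]


# Order used in the issue: linear terms first, then quadratic terms.
src_lin, tgt_lin = zeros(4, 2), zeros(4, 2)    # shared feature dimension
src_quad, tgt_quad = zeros(4, 3), zeros(4, 5)  # different feature dimensions
arrs = src_lin, tgt_lin, src_quad, tgt_quad


def call_with_keywords(match_fn, arrs):
    """Pass each array by name so the call does not depend on argument order."""
    src_lin, tgt_lin, src_quad, tgt_quad = arrs
    return match_fn(xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin)


try:
    # Positional unpacking: the quadratic arrays land in the linear slots.
    match_quadratic(*arrs)
except ValueError as err:
    print("positional call failed:", err)

tmat = call_with_keywords(match_quadratic, arrs)
print(len(tmat), len(tmat[0]))
```

Whether the actual fix should reorder the output of prepare_data() or change the signature of the match function is a design choice for the maintainers; the keyword-based call above is just one way to make the mismatch impossible.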
To Reproduce
I am using ott-jax==0.4.6.dev18+ge1bbd34, which corresponds to pip install -e . of the latest ott code on GitHub (cloned & installed today).