Thanks for the question! This doesn't seem like a JAX discussion to me, since it builds so heavily on lineax and equinox. I'd recommend trying to isolate the issue to one of those packages and opening an issue or discussion on the corresponding repository. I expect you'll get more useful feedback there, because I don't know enough about the details of those libraries to offer concrete suggestions. If you do find that there's a JAX issue once you get to the bottom of it, please report back with JAX-specific example code.
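One way to start isolating the problem without lineax or equinox is a dot test. This assumes the primitive's transpose rule is meant to implement the adjoint of trace. For a linear map `f` and its transpose `f_T`, `vdot(f(x), y)` must equal `vdot(x, f_T(y))` for all `x` and `y`, so a transpose rule that is off by a factor of 2 fails immediately. The sketch below uses only JAX. `trace_transpose` and `dot_test` are hypothetical helpers for illustration, not part of traceax, lineax or the equinox internal API. `jax.linear_transpose` is the real JAX function that derives a reference transpose.

```python
import jax
import jax.numpy as jnp

# A linear map f and a candidate transpose rule f_T must satisfy
#   vdot(f(x), y) == vdot(x, f_T(y))
# for all x, y. Checking this on random inputs tests the transpose
# rule on its own, separately from the JVP rule.

def trace_transpose(ct, n, dtype):
    # Hypothetical transpose rule for trace: the scalar cotangent of the
    # output is spread over the diagonal of an n x n matrix (ct * I).
    return ct * jnp.eye(n, dtype=dtype)

def dot_test(f, f_T, x, y):
    # Returns both sides of the adjoint identity for comparison.
    lhs = jnp.vdot(f(x), y)
    rhs = jnp.vdot(x, f_T(y))
    return lhs, rhs

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
n = 4
x = jax.random.normal(key_x, (n, n))  # input matrix
y = jax.random.normal(key_y, ())      # cotangent of the scalar output

lhs, rhs = dot_test(jnp.trace,
                    lambda ct: trace_transpose(ct, n, x.dtype),
                    x, y)
print(lhs, rhs)  # should agree up to float rounding

# JAX can derive the transpose of a linear function directly, which
# gives a reference to compare a hand-written rule against.
(reference,) = jax.linear_transpose(jnp.trace, x)(jnp.ones((), x.dtype))
print(jnp.allclose(reference, jnp.eye(n)))  # True
```

If a hand-written rule disagrees with the `jax.linear_transpose` reference, the bug is in the transpose. If both agree but the gradient is still doubled, the extra factor more likely comes from the JVP side.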
Hi 👋,
I'm not sure if this is the best place for my question, but I've been working on a project where we use the [equinox](https://github.com/patrick-kidger/equinox) internal API (`eqxi`) to define a trace primitive for our traceax package, which uses the linear operators from lineax. The issue I'm running into is that our trace estimation implementation produces results that are twice the expected value compared to `jax.numpy.trace`. I suspect the issue is in the `_make_identity` function, but I am not sure why. The issue goes away if I replace this (in the jvp implementation below):
with this:
JVP implementation:
Transpose implementation:
`_make_identity` implementation:
Minimal test case that demonstrates the issue:
The resulting gradient should be an identity matrix, but our code produces a matrix whose diagonal elements are all 2 instead of 1.
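For reference, the expected result can be reproduced in pure JAX. Because trace is linear, its JVP is simply the trace of the tangent, and reverse mode obtains the gradient by transposing that linear map, which gives the identity. The sketch below assumes a `jax.custom_jvp` stand-in (`my_trace`, a hypothetical name) instead of an equinox primitive. It only illustrates the expected behavior, not the traceax implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a custom trace op; not the traceax/eqxi code.
@jax.custom_jvp
def my_trace(a):
    return jnp.trace(a)

@my_trace.defjvp
def my_trace_jvp(primals, tangents):
    (a,) = primals
    (a_dot,) = tangents
    # Trace is linear, so the output tangent is tr(a_dot), counted once.
    # Reverse mode transposes this map: a scalar cotangent ct becomes ct * I.
    return jnp.trace(a), jnp.trace(a_dot)

a = jnp.arange(9.0).reshape(3, 3)

g_ref = jax.grad(jnp.trace)(a)    # d tr(A) / dA = I
g_custom = jax.grad(my_trace)(a)  # should match g_ref

print(g_custom)
```

Comparing a custom rule's gradient against `jax.grad(jnp.trace)` on the same input, as above, gives a pass/fail check for the doubling that does not depend on lineax or equinox.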
I am not quite sure what I am missing here.
Your input would be appreciated.
Thank you so much!