-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add custom_jvp to jax.numpy.ldexp. #23923
Conversation
Thanks for looking into this! I'm not sure this is actually the correct gradient behavior for this function. As I mentioned in #11467 (comment), Maybe the best solution here would be to define |
@jakevdp Thanks for your feedback. It seems to me that, if a JAX function extensionally computes a differentiable mathematical function Indeed, my understanding is that one of the primary use cases of Since Thoughts? |
So my question is, if we're just computing def ldexp(x, y):
return x * 2 ** y Then no custom JVP is necessary at all. |
I think (the bit-twiddling implementation of) ldexp is supposed to be a faster way to do that, at least on some platforms. The CUDA Math API has dedicated ldexp functions for single and double precision. The C standard library also has dedicated ldexp functions. Not sure what the performance advantage is for different platforms. |
Sure, but JAX does not dispatch to any of those fast kernels, and I imagine the current bit-twiddling implementation is far slower than just writing |
I guess stepping back, here are the options:
Until now, we've approached this as (1). Which do you think is the right approach? |
You raise an interesting point. If there's indeed no performance advantage to the bit-twiddling implementation of At least for the time being, perhaps it's worth adding a note to the documentation for ldexp stating that there's no performance advantage to its current bit-twiddling implementation over I'd also welcome any additional opinions from people who are more familiar with the hardware side of things. |
7ff34a5
to
acb47ab
Compare
acb47ab
to
4c6dfb3
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
One other thing: we should update the function docs with info about the implementation (this would involve removing the @implements
decorator and writing a full docstring).
If you'd like to do this as part of the PR then go ahead, but I'm happy to update docs in a followup.
I'll let you handle that so you can choose the best wording. |
4c6dfb3
to
65a58d6
Compare
Thanks for putting this together! |
Addresses #11467 (comment).