Can Pallas Auto Diff? #19184
Unanswered
karan-dalal asked this question in General
Replies: 0 comments
The Pallas Docs state that transformations should work inside a Pallas kernel, using the following example with jax.grad:
I'm wondering whether it would be possible to differentiate an entire model inside a kernel, since this would reduce to matrix multiplications and other operations that could be executed on Mosaic / Triton.
For example, could I run the forward pass and backprop through a two-layer MLP with a non-linearity inside a single Pallas kernel?
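To make the question concrete, here is a minimal sketch of what I have in mind. It assumes `interpret=True`, so the kernel body is run by the Pallas interpreter rather than lowered to Triton or Mosaic; whether the same kernel compiles on real hardware (tile sizes, matmul shape constraints, memory limits) is exactly the open question and is not verified here. The names `mlp_loss`, `mlp_grad_kernel`, and `mlp_grads_in_kernel`, the `tanh` non-linearity, the MSE loss, and the shapes are all made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def mlp_loss(w1, w2, x, y):
    """Two-layer MLP with a tanh non-linearity and a mean-squared-error loss."""
    h = jnp.tanh(jnp.dot(x, w1))
    pred = jnp.dot(h, w2)
    return jnp.mean((pred - y) ** 2)


def mlp_grad_kernel(x_ref, y_ref, w1_ref, w2_ref, dw1_ref, dw2_ref):
    # Load the full arrays from the input refs (no grid, so one program
    # instance sees everything).
    x, y = x_ref[...], y_ref[...]
    w1, w2 = w1_ref[...], w2_ref[...]
    # Forward and backward pass both happen inside the kernel body.
    dw1, dw2 = jax.grad(mlp_loss, argnums=(0, 1))(w1, w2, x, y)
    # Write the weight gradients to the output refs.
    dw1_ref[...] = dw1
    dw2_ref[...] = dw2


def mlp_grads_in_kernel(x, y, w1, w2):
    return pl.pallas_call(
        mlp_grad_kernel,
        out_shape=(
            jax.ShapeDtypeStruct(w1.shape, w1.dtype),
            jax.ShapeDtypeStruct(w2.shape, w2.dtype),
        ),
        interpret=True,  # assumption: interpreter mode, not a Triton/Mosaic lowering
    )(x, y, w1, w2)


# Small illustrative inputs: batch of 8, 4 features, 16 hidden units, 2 outputs.
kx, ky, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (8, 4), dtype=jnp.float32)
y = jax.random.normal(ky, (8, 2), dtype=jnp.float32)
w1 = 0.1 * jax.random.normal(k1, (4, 16), dtype=jnp.float32)
w2 = 0.1 * jax.random.normal(k2, (16, 2), dtype=jnp.float32)

dw1, dw2 = mlp_grads_in_kernel(x, y, w1, w2)

# Reference gradients computed outside the kernel, for comparison.
ref_dw1, ref_dw2 = jax.grad(mlp_loss, argnums=(0, 1))(w1, w2, x, y)
```

In this sketch the kernel gradients can be compared against `ref_dw1` and `ref_dw2` with `jnp.allclose`. A real version would presumably also need a grid and block specs to tile the matmuls.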