Custom GPU ops #623
Do you want to be able to call your own hand-written CUDA kernels, or instead do you want to be able to control how some of your functions act under transformations (like forward-mode or reverse-mode differentiation), even if they're just implemented in Python in terms of jax.numpy? Could you give an example? |
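For the second option, later JAX releases provide jax.custom_vjp (and jax.custom_jvp) for functions written in terms of jax.numpy. The sketch below assumes a current JAX release. It overrides reverse-mode differentiation for a small function; softplus and its rules are illustrative names, not taken from this thread:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def softplus(x):
    # Primal: log(1 + exp(x)), written with jnp.logaddexp for numerical stability.
    return jnp.logaddexp(0.0, x)

def softplus_fwd(x):
    # Forward pass: return the output plus residuals saved for the backward pass.
    return softplus(x), x

def softplus_bwd(x, g):
    # Backward pass: d/dx softplus(x) = sigmoid(x); return one cotangent per input.
    return (g * jax.nn.sigmoid(x),)

softplus.defvjp(softplus_fwd, softplus_bwd)

grad_at_zero = jax.grad(softplus)(0.0)
```

When jax.grad differentiates softplus, it calls these rules instead of differentiating jnp.logaddexp. The result still composes with jit and vmap.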
Call my own hand-written CUDA kernels. I have a function
Technically, In reality,
can be expanded into
I have custom ops written for
I'd like to be able to re-use them in JAX as a custom op in conjunction with the GEMM. So in JVP notation, the J is the Jacobian of a sum of different energy functions (B, N, etc.), and the V is the supplied dX/dp term. For an example of what … Note that these kernels are designed to compute both terms simultaneously for speed reasons. I'd be okay with separating them out if needed. (There are a lot of peculiarities in optimizing the calculations to be fully warp-asynchronous by abusing the __shfl intrinsic to death.) |
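This sketch reads J as the Jacobian of the summed energy with respect to the coordinates X, and V as the supplied dX/dp. Under that reading, one forward-mode pass in JAX returns both the energy and J·V. The bonded and nonbonded terms here (bond_energy, nonbonded_energy) are hypothetical stand-ins, not the poster's kernels:

```python
import jax
import jax.numpy as jnp

def bond_energy(x):
    # Hypothetical harmonic springs (rest length 1.0) between consecutive particles.
    d = jnp.linalg.norm(x[1:] - x[:-1], axis=-1)
    return 0.5 * jnp.sum((d - 1.0) ** 2)

def nonbonded_energy(x):
    # Hypothetical soft Gaussian repulsion over unique pairs i < j.
    diff = x[:, None, :] - x[None, :, :]
    r2 = jnp.sum(diff ** 2, axis=-1)
    mask = jnp.triu(jnp.ones_like(r2), k=1)
    return jnp.sum(mask * jnp.exp(-r2))

def total_energy(x):
    # The "sum of different energy functions" whose Jacobian plays the role of J.
    return bond_energy(x) + nonbonded_energy(x)

def energy_and_dEdp(x, dxdp):
    # One forward-mode pass returns the primal E(x) and the tangent J @ dX/dp.
    return jax.jvp(total_energy, (x,), (dxdp,))

x = jnp.array([[0.0, 0.0, 0.0],
               [1.1, 0.0, 0.0],
               [2.0, 0.3, 0.0],
               [2.9, 0.5, 0.4]])
dxdp = jnp.ones_like(x) * 0.1
energy, dEdp = energy_and_dEdp(x, dxdp)
```

Because jvp carries the primal and the tangent together, this mirrors kernels that compute both terms simultaneously.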
To add: I'm currently using AD systems as a "reference" platform to check my highly optimized production code against. If it were possible to incorporate the kernels directly as custom ops, then it's one less set of code that I'd have to maintain. |
I would also like to be able to call hand-written CUDA kernels. Say for example that I wanted to use some of Scott Gray's efficient blocksparse kernels (https://github.com/openai/blocksparse). Is there a story for how I would get that working in JAX? |
Not yet, but it's possible. We may need an XLA:GPU feature to be further developed though. There's an XLA HLO named CustomCall that in principle allows jumping into custom code from an XLA computation. @hawkinsp used it with the XLA:CPU backend to set up special CPU-specific translation rules for linear algebra calls so that they jump into LAPACK code on CPU (and currently fall back to HLO-level implementations on GPU and TPU). That same technique works on CPU for jumping into custom cython or whatever. But my very limited understanding is that CustomCall is not yet sufficiently developed on XLA:GPU to support this kind of thing. I believe we already have a feature request in with XLA:GPU to improve CustomCall so that we can jump into cuSOLVER / MAGMA routines for optimized linear algebra on GPU. I'm fuzzy on the details but I think that would also let JAX expose a way to stitch your custom GPU kernels into an XLA:GPU-compiled program. In the JAX layer you could also attach your own rules for the other transformations, like differentiation, so it'd compose nicely. @jlebar and @hawkinsp understand this much better and so may be able to shed more light, but I suspect the bottom line is that CustomCall on XLA:GPU needs to be fleshed out, and that the hard-working XLA:GPU team is aware but is also balancing tons of other important work. |
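Current JAX can attach your own differentiation rule to an opaque op without CustomCall, by calling host-side NumPy code through jax.pure_callback. This is a host callback, not the XLA:GPU CustomCall path discussed here, so it does not launch a CUDA kernel on the device. The functions external_sin_sq and external_sin_sq_deriv are hypothetical stand-ins for hand-written kernels:

```python
import numpy as np
import jax
import jax.numpy as jnp

def external_sin_sq(x):
    # Hypothetical stand-in for an opaque kernel: plain NumPy run on the host.
    return np.asarray(np.sin(x ** 2), dtype=x.dtype)

def external_sin_sq_deriv(x):
    # Hypothetical companion kernel for the derivative, d/dx sin(x^2) = 2x cos(x^2).
    return np.asarray(2.0 * x * np.cos(x ** 2), dtype=x.dtype)

def call_host(fn, x):
    # JAX cannot trace fn, so pure_callback needs the output shape and dtype up front.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(fn, out_spec, x)

@jax.custom_jvp
def sin_sq(x):
    return call_host(external_sin_sq, jnp.asarray(x))

@sin_sq.defjvp
def sin_sq_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    x = jnp.asarray(x)
    y = call_host(external_sin_sq, x)
    dydx = call_host(external_sin_sq_deriv, x)
    # The tangent is linear in t, so JAX can transpose it and jax.grad works too.
    return y, dydx * t

value = sin_sq(3.0)
slope = jax.jit(jax.grad(sin_sq))(3.0)
```

The opaque op never needs to be traced by JAX. The custom rule supplies its derivative, so the op composes with jit and grad.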
By the way, all this development we're talking about is on open-source code, so if you know any crack GPU developers, contributions welcome! :) |
Implementing proper custom ops in XLA:GPU would be pretty simple in theory (famous last words), but as @mattjj says, we have other higher-importance things on our plate at the moment. I would be happy to advise and review patches if anyone wanted to take this on.
I believe the plan of record for these specifically is to implement them in XLA itself, i.e. not using CustomCall. This way everyone who wants to use cuSOLVER doesn't have to reimplement these same custom ops. |
In theory if you guys already support XLA:CPU, and we write our own CPU code that then calls into the various CUDA kernels - would that cause issues? |
would that cause issues?
Yes, because a single XLA computation runs on one type of device (either CPU or a particular GPU model), so you wouldn't be able to mix them. Also, all of the inputs to a CPU custom op would need to be in CPU memory.
|
A little more about our use case. We have an extremely optimized set of GPU kernels for doing physics simulations of molecules that took us many years to write: https://github.com/pandegroup/openmm/tree/master/platforms/cuda/src/kernels Some of these kernels themselves are actually also JITed (on the cuda source level, not PTX level), though we do it purely symbolically on a very restricted set of functional forms. We'd love to be able to bring all that code into our JAX workflows somehow and it'd definitely help convince more of us old-school physics-types to get more involved with the project. Currently JAX/XLA is about 100x slower than hand-written CUDA code (for first and second order derivatives) in forward mode. |
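In JAX, first- and second-order forward-mode derivatives along a direction can be written by nesting jax.jvp. The sketch below does this. pair_energy is a hypothetical Lennard-Jones-style stand-in, not one of the OpenMM kernels:

```python
import jax
import jax.numpy as jnp

def pair_energy(r):
    # Hypothetical Lennard-Jones-style energy summed over a vector of pair distances r.
    inv6 = (1.0 / r) ** 6
    return jnp.sum(4.0 * (inv6 ** 2 - inv6))

def forward_derivs(f, x, v):
    """Return f(x), the directional derivative f'(x)[v], and f''(x)[v, v],
    using forward mode only (a jvp nested inside a jvp)."""
    def first(x):
        return jax.jvp(f, (x,), (v,))[1]
    d1, d2 = jax.jvp(first, (x,), (v,))
    return f(x), d1, d2

r = jnp.array([1.0, 1.2, 1.5])
v = jnp.array([1.0, 0.0, 0.5])
e, de, d2e = forward_derivs(pair_energy, r, v)
```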
The amazing @jlebar just landed this exciting commit in XLA:GPU: |
The docs for CustomCall are really nice. |
I opened #766 for the related issue of JIT for custom CPU ops. |
Just pinging this again to see how hard it would be to be able to
|
You can get custom … All the pieces for custom GPU ops are now there in principle, since XLA now supports … |
To clarify: if I were to import custom Python-wrapped C++ CPU/GPU code in a function decorated with custom_transforms, is the expectation that while it won't be jittable, I should still be able to run it normally in op-by-op mode? |
Is there a way to define JVPs that avoids calling the original f(x) code? I ask because I have a very expensive function that computes the primal and the tangent in one pass, and I would like to use it directly rather than going through a separate pass:

```python
import jax
import jax.numpy as np  # assumed import; early JAX examples aliased jax.numpy as np

# jax.custom_transforms / jax.defjvp are early-JAX APIs, later replaced by jax.custom_jvp.
@jax.custom_transforms
def f(x):
    print("calling f(x)")
    return np.sin(x ** 2)

def jvp_f(g, ans, x):
    print("calling jvp_f")
    # I have an expensive function here that also computes the primals in addition to the tangents
    return 8. * g + ans

jax.defjvp(f, jvp_f)
out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
```
|
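Never mind, I keep posting comments and then finding the right solution immediately afterwards (sorry, guys). It looks like I just needed to use defjvp_all instead:

```python
import jax
import jax.numpy as np  # assumed import; early JAX examples aliased jax.numpy as np

# Early-JAX API (jax.custom_transforms / jax.defjvp_all), later replaced by jax.custom_jvp.
@jax.custom_transforms
def f(x):
    print("calling f(x)")
    return np.sin(x ** 2)

def jvp_f(ps, ts):
    # ps and ts are tuples of primal and tangent inputs; return (primal_out, tangent_out).
    print("calling jvp_f")
    return np.sin(ps[0] ** 2), 8. * ts[0]

jax.defjvp_all(f, jvp_f)
# jax.defjvp(f, jvp_f)
out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
print(out_primal)
# 0.4121185
print(out_tangent)
# 16.0
```
|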
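In current JAX, the defjvp_all pattern corresponds to jax.custom_jvp. Its rule receives the primals and tangents and returns both outputs, so one expensive fused routine can produce them in a single call. The sketch below assumes a current JAX release. fused_sin_sq is a hypothetical stand-in for the expensive one-pass function, and the tangent uses the true derivative of sin(x²) rather than the placeholder 8. * t:

```python
import jax
import jax.numpy as jnp

def fused_sin_sq(x, t):
    # Hypothetical one-pass routine: x**2 is shared by the primal and the tangent.
    x2 = x ** 2
    return jnp.sin(x2), 2.0 * x * jnp.cos(x2) * t

@jax.custom_jvp
def f(x):
    # Plain primal path, used when no derivative is requested.
    return jnp.sin(x ** 2)

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # The rule returns the primal output itself, so the fused routine runs once.
    return fused_sin_sq(x, t)

out_primal, out_tangent = jax.jvp(f, (3.0,), (2.0,))
```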
If you do this using your own Python wrappers and FFI, that would be the case. But if you plug into the XLA CustomCall infrastructure, there should be a way to make it work even under |
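The distinction shows up with a plain NumPy wrapper. Op-by-op, it runs on concrete arrays. Under jax.jit, its argument is traced into an abstract value that NumPy cannot consume. Wrapping the same host code in jax.pure_callback, a host callback in current JAX rather than the XLA CustomCall mechanism, lets it run under jit. numpy_op is a hypothetical stand-in:

```python
import numpy as np
import jax
import jax.numpy as jnp

def numpy_op(x):
    # Hypothetical Python/NumPy wrapper around external code.
    x = np.asarray(x)
    return np.asarray(np.sin(x ** 2), dtype=x.dtype)

x = jnp.array(3.0)
eager_result = numpy_op(x)  # Op-by-op: x is a concrete array, so NumPy accepts it.

try:
    jax.jit(numpy_op)(x)  # Under jit, x is a tracer that cannot become a NumPy array.
    jit_failed = False
except jax.errors.TracerArrayConversionError:
    jit_failed = True

@jax.jit
def callback_op(x):
    # pure_callback hands concrete values to numpy_op on the host at run time.
    return jax.pure_callback(numpy_op, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

jit_result = callback_op(x)
```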
Thank you for the clarification. I realized that I actually ended up implementing not jvps in my custom op, but rather vmap/batched jvps over the parameters (for computational efficiency purposes), so it will take a little more work for me to get this to work within the jax ecosystem. |
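In JAX, batched JVPs over several parameters correspond to jax.vmap over the tangent argument, which vectorizes the per-parameter JVPs instead of looping in Python. In the sketch below, energy and the tangent matrix dxdp_batch (one row of dX/dp per parameter) are hypothetical stand-ins:

```python
import jax
import jax.numpy as jnp

def energy(x):
    # Hypothetical energy: harmonic springs of rest length 1.0 along a 1-D chain.
    return 0.5 * jnp.sum((x[1:] - x[:-1] - 1.0) ** 2)

def dE_dp(x, dxdp_batch):
    """Return dE/dp_k for every parameter k, where dxdp_batch[k] = dX/dp_k."""
    def single(dxdp):
        # One JVP per parameter; vmap batches them over the leading axis.
        return jax.jvp(energy, (x,), (dxdp,))[1]
    return jax.vmap(single)(dxdp_batch)

x = jnp.array([0.0, 1.2, 1.9, 3.1])
dxdp_batch = jnp.array([[1.0, 0.0, 0.0, 0.0],
                        [0.0, 1.0, 1.0, 0.0],
                        [0.5, 0.5, 0.5, 0.5]])
grads = dE_dp(x, dxdp_batch)
```

The last row is a uniform translation. The springs depend only on differences, so its derivative is zero.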
There's no documentation for this yet, but … i.e., all the pieces are now there; the only thing missing now is documentation. |
Awesome thanks - I'll take a look |
#12632 proposes a new JAX API for custom ops. Anyone interested should give us feedback! |
Now that #21925 has been merged we have a recommended/documented workflow for this: https://jax.readthedocs.io/en/latest/ffi.html |
I'm looking to implement custom GPU ops, similar to how TensorFlow allows defining custom JVPs. Is there a similar tutorial or guide on how feasible this would be with JAX?