Custom GPU ops #623

Closed
proteneer opened this issue Apr 17, 2019 · 24 comments
Labels: enhancement (New feature or request); NVIDIA GPU (Issues specific to NVIDIA GPUs); P2 (eventual) (This ought to be addressed, but has no schedule at the moment. Assignee optional)

Comments

@proteneer
Contributor

I'm looking to implement custom GPU ops, similar to how TensorFlow allows defining custom JVPs. Is there a tutorial or guide on how feasible this would be with JAX?

@mattjj
Collaborator

mattjj commented Apr 17, 2019

Do you want to be able to call your own hand-written CUDA kernels, or instead do you want to be able to control how some of your functions act under transformations (like forward-mode or reverse-mode differentiation), even if they're just implemented in Python in terms of jax.numpy? Could you give an example?

@proteneer
Contributor Author

proteneer commented Apr 17, 2019

Call my own hand-written CUDA kernels.

I have a function F(X(p), p) and I'm interested in using autodiff to compute the total derivative:

dF/dp = dF/dX * dX/dp + dF/dp, where dX/dp is the input grad.

Technically, dF/dX is an NxN symmetric Hessian of some energy function E, and dF/dp is a second-order mixed partial. The * operator calls into cuBLAS L3 symmetric GEMM to make everything super fast.

In reality, F=B(X(p),p) + N(X(p),p) is the result of multiple different types of forces summed together (bonded, non bonded, etc.). So I've written custom CUDA kernels for each:

dF/dp = dF/dX * dX/dp + dF/dp

can be expanded into

dF/dp = d(B+N)/dX * dX/dp + d(B+N)/dp
dF/dp = (dB/dX+dN/dX) * dX/dp + dB/dp + dN/dp <-- much faster than:
dF/dp = (dB/dX * dX/dp + dB/dp) + (dN/dX * dX/dp + dN/dp)

I have custom ops written for

dB/dX
dN/dX
dB/dp
dN/dp

I'd like to be able to re-use them in JAX as a custom op in conjunction with the GEMM.

So in JVP notation, the J is the Jacobian of a sum of different energy functions (B, N, etc.), and the V is the supplied dX/dp term.
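In pure JAX, the same total derivative can be sketched with `jax.jvp`, without materializing the Jacobian. The functions `B` and `N` below are hypothetical toy force terms chosen so the result can be checked by hand; the real ones are the custom CUDA kernels.

```python
import jax
import jax.numpy as jnp

# Hypothetical force terms; the real B and N are custom CUDA kernels.
def B(x, p):
    return -p * x

def N(x, p):
    return p ** 2 * jnp.sin(x)

def F(x, p):
    return B(x, p) + N(x, p)

x = jnp.array([0.1, 0.2, 0.3])
p = jnp.array(1.5)
dx_dp = jnp.array([1.0, -2.0, 0.5])  # the supplied V: input tangent dX/dp
dp_dp = jnp.array(1.0)               # tangent of p with respect to itself

# One JVP through the summed function gives the total derivative dF/dp.
primal, total = jax.jvp(F, (x, p), (dx_dp, dp_dp))

# Closed-form check: (dF/dX) * dX/dp + dF/dp for these toy terms.
expected = (-p + p ** 2 * jnp.cos(x)) * dx_dp + (-x + 2 * p * jnp.sin(x))
```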

For an example of what dN/dX and dN/dp look like

https://github.com/proteneer/timemachine/blob/master/timemachine/cpu_functionals/electrostatics.cuh

Note that these kernels are designed to compute both terms simultaneously for speed reasons. I'd be okay with separating them out if needed.

(There are a lot of peculiarities in terms of optimizing the calculations to be fully warp-asynchronous by abusing the __shfl intrinsic to death.)

@proteneer
Contributor Author

To add: I'm currently using AD systems as a "reference" platform to check my highly optimized production code against. If it were possible to directly incorporate the kernels as custom ops, then it's one less set of code that I'd have to maintain.

@nottombrown

nottombrown commented Apr 20, 2019

I would also like to be able to call hand-written CUDA kernels.

Say for example that I wanted to use some of Scott Gray's efficient blocksparse kernels (https://github.com/openai/blocksparse). Is there a story for how I would get that working in JAX?

@mattjj
Collaborator

mattjj commented Apr 20, 2019

Not yet, but it's possible. We may need an XLA:GPU feature to be further developed though.

There's an XLA HLO named CustomCall that in principle allows jumping into custom code from an XLA computation. @hawkinsp used it with the XLA:CPU backend to set up special CPU-specific translation rules for linear algebra calls so that they jump into LAPACK code on CPU (and currently fall back to HLO-level implementations on GPU and TPU). That same technique works on CPU for jumping into custom cython or whatever. But my very limited understanding is that CustomCall is not yet sufficiently developed on XLA:GPU to support this kind of thing.

I believe we already have a feature request in with XLA:GPU to improve CustomCall so that we can jump into cuSOLVER / MAGMA routines for optimized linear algebra on GPU. I'm fuzzy on the details but I think that would also let JAX expose a way to stitch your custom GPU kernels into an XLA:GPU-compiled program. In the JAX layer you could also attach your own rules for the other transformations, like differentiation, so it'd compose nicely.

@jlebar and @hawkinsp understand this much better and so may be able to shed more light, but I suspect the bottom line is that CustomCall on XLA:GPU needs to be fleshed out, and that the hard-working XLA:GPU team is aware but is also balancing tons of other important work.

@mattjj
Collaborator

mattjj commented Apr 20, 2019

By the way, all this development we're talking about is on open-source code, so if you know any crack GPU developers, contributions welcome! :)

@jlebar
Contributor

jlebar commented Apr 22, 2019

Implementing proper custom ops in XLA:GPU would be pretty simple in theory (famous last words), but as @mattjj says, we have other higher-importance things on our plate at the moment. I would be happy to advise and review patches if anyone wanted to take this on.

> I believe we already have a feature request in with XLA:GPU to improve CustomCall so that we can jump into cuSOLVER / MAGMA routines for optimized linear algebra on GPU.

I believe the plan of record for these specifically is to implement them in XLA itself, i.e. not using CustomCall. This way everyone who wants to use cuSOLVER doesn't have to reimplement these same custom ops.

@proteneer
Contributor Author

In theory if you guys already support XLA:CPU, and we write our own CPU code that then calls into the various CUDA kernels - would that cause issues?

@jlebar
Contributor

jlebar commented Apr 23, 2019 via email

@proteneer
Contributor Author

A little more about our use case. We have an extremely optimized set of GPU kernels for doing physics simulations of molecules that took us many years to write:

https://github.com/pandegroup/openmm/tree/master/platforms/cuda/src/kernels
https://github.com/proteneer/timemachine/tree/master/timemachine/cpu_functionals

Some of these kernels themselves are actually also JITed (on the cuda source level, not PTX level), though we do it purely symbolically on a very restricted set of functional forms.

We'd love to be able to bring all that code into our JAX workflows somehow and it'd definitely help convince more of us old-school physics-types to get more involved with the project. Currently JAX/XLA is about 100x slower than hand-written CUDA code (for first and second order derivatives) in forward mode.

@mattjj
Collaborator

mattjj commented May 10, 2019

The amazing @jlebar just landed this exciting commit in XLA:GPU:

tensorflow/tensorflow@acb84a0

@mattjj
Collaborator

mattjj commented May 10, 2019

The docs for CustomCall are really nice.

@shoyer
Collaborator

shoyer commented May 23, 2019

I opened #766 for the related issue of JIT for custom CPU ops.

@skye changed the title from "Custom ops" to "Custom GPU ops" on May 23, 2019
@proteneer
Contributor Author

Just pinging this again to see how hard it would be to be able to 1) define a custom op and 2) define a custom defjvp (similar to how TensorFlow can do this).

@hawkinsp
Collaborator

hawkinsp commented Jun 19, 2019

You can get custom jvps right now:
https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms

All the pieces for custom GPU ops are now there in principle since XLA now supports CustomCall on GPU, but we need to plug them together and expose them to users.

@proteneer
Contributor Author

To clarify:

If I were to import custom python-wrapped C++ CPU/GPU code in a function decorated with custom_transforms, is the expectation that while they won't be JITTable, I should still be able to run them normally in op-by-op mode?

@proteneer
Contributor Author

Is there a way to define jvps that avoids calling the original f(x) code? I ask because I have a very expensive function where I compute the primal and tangent in one pass, and I would like to use it directly instead of going through a separate pass.

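import jax
import jax.numpy as np  # assumption: `np` here is jax.numpy, as in JAX examples of the time

@jax.custom_transforms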
def f(x):
    print("calling f(x)")
    return np.sin(x ** 2)

def jvp_f(g, ans, x):
    print("calling jvp_f")
    # I have an expensive function here that also computes the primals in addition to the tangents
    return 8. * g + ans

jax.defjvp(f, jvp_f)

out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))

@proteneer
Contributor Author

Never mind - I keep posting comments where I find the right solution immediately afterwards (sorry guys). Looks like I just needed to use defjvp_all instead:

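import jax
import jax.numpy as np  # assumption: `np` here is jax.numpy, as in JAX examples of the time

@jax.custom_transforms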
def f(x):
    print("calling f(x)")
    return np.sin(x ** 2)

def jvp_f(ps, ts):
    print("calling jvp_f")
    return np.sin(ps[0] ** 2), 8. * ts[0]

jax.defjvp_all(f, jvp_f)

# jax.defjvp(f, jvp_f)

out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))

print(out_primal)
# 0.4121185
print(out_tangent)
# 16.0
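`jax.custom_transforms` and `defjvp_all` were removed in later JAX releases. A rough equivalent in current JAX, offered as a sketch rather than a drop-in port, is `jax.custom_jvp`: its rule also receives all primals and tangents and returns the primal output and the tangent together, so a fused kernel could produce both in one pass. The `8. * t` tangent is the same deliberately artificial rule used in the snippet above, not the true derivative.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x ** 2)

@f.defjvp
def f_jvp(primals, tangents):
    (x,) = primals
    (t,) = tangents
    # A fused implementation would compute both outputs here in a single pass.
    primal_out = jnp.sin(x ** 2)
    tangent_out = 8.0 * t  # artificial rule, mirroring the thread's example
    return primal_out, tangent_out

out_primal, out_tangent = jax.jvp(f, (3.0,), (2.0,))
```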

@jekbradbury
Contributor

> To clarify:
>
> If I were to import custom python-wrapped C++ CPU/GPU code in a function decorated with custom_transforms, is the expectation that while they won't be JITTable, I should still be able to run them normally in op-by-op mode?

If you do this using your own Python wrappers and FFI, that would be the case. But if you plug into the XLA CustomCall infrastructure, there should be a way to make it work even under @jit (though it will likely require some new code in JAX itself).
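For readers on current JAX: one way to keep Python-wrapped external code callable under `@jit` without writing a CustomCall is `jax.pure_callback`, which hands the arguments back to the host. This API postdates the discussion and is not what the participants used; `external_kernel` is a hypothetical stand-in for a wrapped C++/CUDA routine, implemented here with NumPy. The host round trip makes this a convenience for testing rather than a replacement for a real GPU custom call.

```python
import numpy as np
import jax
import jax.numpy as jnp

def external_kernel(x):
    # Hypothetical stand-in for Python-wrapped C++/CUDA code; runs on the host.
    x = np.asarray(x)
    return np.sin(x ** 2).astype(x.dtype)

@jax.jit
def f(x):
    # Declare the output shape and dtype so XLA can compile around the callback.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(external_kernel, out_type, x)

y = f(jnp.arange(3.0))
```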

@proteneer
Contributor Author

Thank you for the clarification. I realized that I actually ended up implementing not jvps in my custom op, but rather vmap/batched jvps over the parameters (for computational efficiency purposes), so it will take a little more work for me to get this to work within the jax ecosystem.
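In pure JAX, JVPs batched over parameters can be sketched by mapping `jax.jvp` over a stack of tangent directions with `jax.vmap`. The function `F` and all values below are hypothetical toy choices, picked so the result can be checked in closed form; a custom op would need its own batching rule to take part in this.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy force with two parameters.
def F(x, p):
    return p[0] * x + p[1] * jnp.sin(x)

x = jnp.array([0.1, 0.2, 0.3])
p = jnp.array([1.5, 0.5])
x_tangents = jnp.array([[1.0, 0.0, -1.0],   # dX/dp_0
                        [0.5, 2.0, 0.0]])   # dX/dp_1
p_tangents = jnp.eye(2)                     # one basis direction per parameter

def total_derivative(x_t, p_t):
    # JVP for a single parameter direction.
    return jax.jvp(F, (x, p), (x_t, p_t))[1]

# Row k holds dF/dp_k; vmap evaluates all directions in one batched call.
batched = jax.vmap(total_derivative)(x_tangents, p_tangents)

# Closed-form check for this toy F.
dF_dx = p[0] + p[1] * jnp.cos(x)
expected = dF_dx * x_tangents + jnp.stack([x, jnp.sin(x)])
```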

@hawkinsp mentioned this issue on Aug 1, 2019
@hawkinsp
Collaborator

hawkinsp commented Aug 2, 2019

There's no documentation for this yet, but:
https://github.com/google/jax/blob/master/jaxlib/cusolver.py
https://github.com/google/jax/blob/master/jaxlib/cusolver.cc
and
https://www.tensorflow.org/xla/custom_call#custom-call_on_gpu
should be most of the information you need for defining a custom GPU op, because the cuSOLVER ops there are implemented as custom-call ops.

i.e., all the pieces are now there; all that is missing now is documentation.

@proteneer
Contributor Author

Awesome thanks - I'll take a look

@sudhakarsingh27 added the NVIDIA GPU and P2 (eventual) labels on Aug 10, 2022
@hawkinsp added the enhancement label on Aug 12, 2022
@sharadmv self-assigned this on Sep 29, 2022
@sharadmv
Collaborator

sharadmv commented Oct 3, 2022

#12632 proposes a new JAX API for custom ops. Anyone interested should give us feedback!

@hawkinsp assigned dfm and unassigned sharadmv on Jun 3, 2024
@dfm
Collaborator

dfm commented Aug 2, 2024

Now that #21925 has been merged we have a recommended/documented workflow for this: https://jax.readthedocs.io/en/latest/ffi.html

@dfm closed this as completed on Aug 2, 2024

10 participants