## CUDA Graphs Intro

Link: https://developer.nvidia.com/blog/cuda-graphs/
Instead of launching kernels one by one, CUDA graphs allow users to build a graph of device operations (kernel launches, stream synchronization, etc.) and submit it for execution with a single call to the CUDA API, avoiding the overhead of individual device launches.
CUDA graphs can be constructed in two modes:

- Implicit stream capture: instead of eagerly executing the commands submitted to a stream, they are recorded into a graph that can later be instantiated and executed (see `cudaStreamBeginCapture`).
- Explicit graph building: it is possible to build a graph explicitly by adding graph nodes and edges with a low-level API.
The advantage of implicit stream capture is that it can capture library calls (such as convolutions) that do not expose device kernels through a public API.
With explicit graph construction you get better visibility into what is actually launched on a stream. The proposed XLA + CUDA Graphs integration can support both modes; currently only implicit capture is implemented (because it was a trivial change).
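To make the implicit capture mode concrete, here is a toy Python model (the `Stream` and `Graph` classes are invented for this sketch; this is not the real CUDA API): between begin-capture and end-capture, commands submitted to the stream are recorded instead of executed, and the resulting graph replays all of them with a single launch call.

```python
# Toy model of implicit stream capture (NOT the real CUDA API).
# Between begin_capture() and end_capture(), commands submitted to the
# stream are recorded instead of executed; the resulting Graph replays
# all of them with a single launch() call.

class Graph:
    def __init__(self, commands):
        self.commands = commands

    def launch(self):
        # One "API call" replays every recorded command.
        for fn, args in self.commands:
            fn(*args)

class Stream:
    def __init__(self):
        self._capturing = False
        self._recorded = []

    def begin_capture(self):
        self._capturing = True
        self._recorded = []

    def end_capture(self):
        self._capturing = False
        return Graph(self._recorded)

    def submit(self, fn, *args):
        if self._capturing:
            self._recorded.append((fn, args))  # record, don't execute
        else:
            fn(*args)  # eager execution

# Usage: record two "kernels", then replay them with one launch.
out = []
s = Stream()
s.begin_capture()
s.submit(out.append, "kernel_a")
s.submit(out.append, "kernel_b")
graph = s.end_capture()
assert out == []  # nothing ran during capture
graph.launch()
assert out == ["kernel_a", "kernel_b"]
```

The point of the model is the separation between capture time (cheap recording) and launch time (one submission for the whole graph), which is where CUDA graphs save launch overhead.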
## XLA Runtime Intro
The XLA runtime compiles XLA programs to executables that are essentially native functions executed on the host. These executables can export multiple functions that can be called externally. Currently XLA modules always have a single entry point, but with the XLA runtime this restriction is relaxed; in theory it should be possible to compile an XLA "program" like the one shown below, which is compiled to an executable that exports two functions to the user. We will rely on this feature to implement CUDA graphs support.
## Exporting Graph Capture Functions
When XLA is compiled with CUDA graphs support, in addition to exporting the "main" function (with ordinal 0) it will export special "graph capture" functions that, instead of submitting operations to a stream, add them to an implicit graph that is constructed when these functions are invoked.
A special "graph launch" custom call takes the operands of all graph nodes as arguments, together with a reference to a "graph capture" function. It constructs a hash code from the arguments (shapes, pointers and launch dimensions) to maintain a cache of constructed graphs, and on a hash match it launches the cached graph on a compute stream.
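The caching scheme just described can be sketched in Python (the `GraphCache` class and `capture` helper are invented for illustration; this is not the actual XLA runtime code): the cache key is a hash over argument shapes, pointers and launch dimensions, and a graph is only re-captured on a cache miss.

```python
# Toy sketch of the "graph launch" caching scheme (invented names, not
# the actual XLA runtime implementation). The cache key hashes argument
# shapes, pointers and launch dimensions; on a hit the previously
# captured graph is reused instead of re-capturing.

class GraphCache:
    def __init__(self, capture_fn):
        self.capture_fn = capture_fn  # builds a "graph" (here: a closure)
        self.cache = {}
        self.misses = 0

    @staticmethod
    def key(args):
        # Hash shapes, pointers and launch dimensions of all arguments.
        return hash(tuple((a["shape"], a["ptr"], a["launch_dims"]) for a in args))

    def launch(self, args):
        k = self.key(args)
        graph = self.cache.get(k)
        if graph is None:  # cache miss: capture a new graph
            self.misses += 1
            graph = self.capture_fn(args)
            self.cache[k] = graph
        return graph()  # launch the (possibly cached) graph

# Usage: identical shapes/pointers/launch dims hit the cache.
def capture(args):
    return lambda: sum(a["shape"][0] for a in args)

cache = GraphCache(capture)
arg = {"shape": (1024,), "ptr": 0x7000, "launch_dims": (128, 8)}
cache.launch([arg])
cache.launch([arg])  # same key: no re-capture
assert cache.misses == 1
```

Keying on pointers as well as shapes matters because a captured graph bakes in the device addresses of its buffers; if the buffer assignment changes, the graph must be re-captured.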
The compiled module looks like this:

```mlir
module {
  rt.export @main ordinal 0
  rt.export @graph.capture ordinal 1

  // This function does not execute "add". Once it's lowered to the runtime
  // HAL, it adds an "add" command to the "command buffer" (CUDA graph).
  func @graph.capture(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
      -> tensor<?xf32> {
    %0 = mhlo.add %arg0, %arg1 : tensor<?xf32>
    return %0 : tensor<?xf32>
  }

  func @main(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    // The XLA runtime is responsible for capturing graphs, caching them,
    // and submitting them to the compute stream.
    %0 = call @xla.graph.launch(%arg0, %arg1) { capture = @graph.capture }
        : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    return %0 : tensor<?xf32>
  }
}
```
Currently @xla.graph.launch relies on implicit graph capture: it wraps the capture function invocation in begin-capture/end-capture calls. It should also be possible to export a "graph construction" function that takes a CUDA graph as an argument and builds the graph explicitly, by connecting graph nodes with edges.
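For comparison, explicit graph construction could be modeled roughly like the following toy Python sketch (the `ExplicitGraph` class and its methods are invented; this is not the CUDA graph node API): nodes are added individually, edges express execution-order dependencies, and launch runs the nodes in a dependency-respecting order.

```python
# Toy model of explicit graph construction (invented names, not the CUDA
# graph API): nodes are added one by one, edges encode dependencies, and
# launch() executes nodes in a dependency-respecting (topological) order.
# Assumes the graph is acyclic.

class ExplicitGraph:
    def __init__(self):
        self.nodes = {}  # node id -> callable
        self.deps = {}   # node id -> set of prerequisite node ids

    def add_node(self, node_id, fn):
        self.nodes[node_id] = fn
        self.deps.setdefault(node_id, set())

    def add_edge(self, src, dst):
        # dst may only run after src has completed
        self.deps[dst].add(src)

    def launch(self):
        done, order = set(), []
        while len(done) < len(self.nodes):
            for node_id, prereqs in self.deps.items():
                if node_id not in done and prereqs <= done:
                    self.nodes[node_id]()  # all prerequisites satisfied
                    done.add(node_id)
                    order.append(node_id)
        return order

# Usage: "c" depends on both "a" and "b".
log = []
g = ExplicitGraph()
g.add_node("a", lambda: log.append("a"))
g.add_node("b", lambda: log.append("b"))
g.add_node("c", lambda: log.append("c"))
g.add_edge("a", "c")
g.add_edge("b", "c")
order = g.launch()
assert order[-1] == "c"
```

This is the mode that gives full visibility into the launched nodes, at the cost of having to know every node up front, which is exactly why it cannot cover opaque library calls.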
## Implementation Status
This proposal is implemented as a minimal proof of concept, with a trivial "CUDA graph outlining" pass and minimal runtime support without any caching.
```shell
XLA_GPU_RUNTIME_USE_CUDA_GRAPHS=1 bazel run --config=cuda \
  --test_env=XLA_FLAGS=--xla_gpu_enable_xla_runtime_executable=true \
  third_party/tensorflow/compiler/xla/service/gpu/tests:mnist -- \
  --alsologtostderr --suppress_failure_output --vmodule=graph_launch=1
```
Commit: tensorflow/tensorflow@563c2ab
See prior discussion in: openxla/community#23 (comment)