[NVIDIA GPU SPMD] Add runtime support to run windowed einsum in multiple streams #8707

Closed
wants to merge 2 commits

Conversation

Tixxx
Contributor

@Tixxx Tixxx commented Jan 23, 2024

This PR contains the runtime changes needed to run windowed einsum in multiple CUDA streams.
This PR follows #7854, which adds stream attributes to the HLO graph.
We take the stream attributes and dispatch the corresponding kernels to separate CUDA streams.

We do this by wrapping the kernel in an async-start/async-done pair with a non-default stream id.
For the AsyncStart, the emitter emits a SyncOnStreamsThunk followed by the kernel's thunk; for the AsyncDone, it emits only a SyncOnStreamsThunk.
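As a rough sketch, the emitted thunk sequence for one kernel annotated with stream id 1 would look like the following (argument names are illustrative, not the exact SyncOnStreamsThunk signature):

// AsyncStart of the wrapped kernel: stream 1 first waits for the main stream (0),
// then the kernel's own thunk is dispatched on stream 1.
SyncOnStreamsThunk(/*waiting_stream=*/1, /*streams_to_wait_on=*/{0});
KernelThunk(/*execution_stream_id=*/1);

// AsyncDone: the main stream waits for stream 1 before dependent work runs.
SyncOnStreamsThunk(/*waiting_stream=*/0, /*streams_to_wait_on=*/{1});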

Detailed discussion here: openxla/xla#8865.

@Tixxx Tixxx requested a review from jurahul January 23, 2024 00:49
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jan 23, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jan 23, 2024
@cheshire
Contributor

Could you add a bit more context/benchmarks to the commit description on what workflows this makes faster?

@ezhulenev
Member

With CUDA graphs this should happen automatically, without any need to add user annotations to operations; the xla_gpu_graph_enable_concurrent_region flag controls it, and today it's disabled by default.

I'd suggest trying it first to see whether you get any perf improvement from graphs alone.
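A minimal sketch of flipping that flag for an experiment, assuming it is exposed through DebugOptions like other xla_gpu_* options (it can equivalently be set in the environment with XLA_FLAGS=--xla_gpu_graph_enable_concurrent_region=true):

#include "xla/debug_options_flags.h"

xla::DebugOptions debug_options = xla::GetDebugOptionsFromFlags();
debug_options.set_xla_gpu_graph_enable_concurrent_region(true);
// Attach debug_options to the module/build options used for compilation.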

In general it's OK to add an execution stream to Thunk (but without AsyncExecutors) and instead add a new thunk:

// Define a strongly typed alias (see GlobalDeviceId).
ExecutionThreadId id;

// dst threads wait for completion of all work items launched on src threads.
class JoinExecutionThreadsThunk {
 public:
  JoinExecutionThreadsThunk(absl::Span<const ExecutionThreadId> dsts,
                            absl::Span<const ExecutionThreadId> srcs);
};

// In its ExecuteOnStream:
for (auto dst : dsts) {
  for (auto src : srcs) {
    GetStream(dst).WaitFor(GetStream(src));
  }
}


And in ir_emitter_unnested, detect when execution threads should be joined, e.g. in your example the extra compute streams should be joined with the main one before the unrolled loop, and then joined back after the loop.
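For concreteness, a minimal C++ sketch of such a thunk, fleshing out the pseudocode above; the Kind value and the GetStreamForId helper are illustrative names rather than existing XLA APIs, and WaitFor is assumed to return absl::Status as in the sketch:

#include <utility>
#include <vector>

#include "absl/status/status.h"

// Hypothetical thunk that makes every dst stream wait for all work already
// enqueued on the src streams; it records dependencies only, launching nothing.
class JoinExecutionThreadsThunk : public Thunk {
 public:
  JoinExecutionThreadsThunk(ThunkInfo thunk_info,
                            std::vector<ExecutionThreadId> dsts,
                            std::vector<ExecutionThreadId> srcs)
      : Thunk(Kind::kJoinExecutionThreads, std::move(thunk_info)),
        dsts_(std::move(dsts)),
        srcs_(std::move(srcs)) {}

  absl::Status ExecuteOnStream(const ExecuteParams& params) override {
    for (ExecutionThreadId dst : dsts_) {
      for (ExecutionThreadId src : srcs_) {
        // Look up the streams backing each execution thread and add the edge.
        TF_ASSIGN_OR_RETURN(se::Stream* dst_stream, GetStreamForId(params, dst));
        TF_ASSIGN_OR_RETURN(se::Stream* src_stream, GetStreamForId(params, src));
        TF_RETURN_IF_ERROR(dst_stream->WaitFor(src_stream));
      }
    }
    return absl::OkStatus();
  }

 private:
  std::vector<ExecutionThreadId> dsts_;
  std::vector<ExecutionThreadId> srcs_;
};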

@Tixxx
Contributor Author

Tixxx commented Jan 23, 2024

> xla_gpu_graph_enable_concurrent_region

Yes, I have experimented with this before, but it doesn't give as much perf benefit as explicitly running kernels in separate streams.
For gemms with these lhs and rhs sizes:
bf16[1024,24576]{1,0} @ bf16[24576,24576]{1,0}
the total runtime of back-to-back gemms with overlap is 8.2 ms using the multi-stream mechanism in this PR,
versus 9.56 ms with the concurrent-region CUDA graph.

I'm trying to understand the new thunk mechanism: are you proposing to use JoinExecutionThreadsThunk to wrap the execution of the thunk sequence within the while thunk, or is it just a synchronization scheme in the graph?

Maybe a pseudo HLO graph would also be helpful.

@ezhulenev
Member

Basically the proposal is:

Thunk::ExecutionThreadId id;

ThunkInfo {
  Thunk::ExecutionThreadId id;
};

// Thunk implementations should be execution-thread-aware.
Thunk::ExecuteOnStream(ExecuteParams params) {
  GetStream(execution_thread_id)->Launch(...);
}

In ir_emitter_unnested

%0 = custom-call(gemm), exec_thread_id=1
%1 = custom-call(gemm), exec_thread_id=2

emits

// Thread {0} is the "main" thread.
JoinExecutionThreadsThunk(1, {0});
GemmThunk(exec_id=1);
JoinExecutionThreadsThunk(2, {0});
GemmThunk(exec_id=2);
JoinExecutionThreadsThunk(0, {1, 2});

The tricky part is how to tell buffer assignment not to reuse buffers:

%0 = (result0, s8[1024]) custom-call(gemm), exec_id=1
%1 = (result1, s8[1024]) custom-call(gemm), exec_id=2

BufferAssignment will assign the same buffer for the scratch allocations, and it is unsafe to run these gemms in parallel.

VLOG level 100 for gpu_command_buffer will print which CUDA graphs were constructed; maybe you don't get concurrency there for some reason? Because of the scratch allocation?

@Tixxx
Contributor Author

Tixxx commented Jan 24, 2024

> Basically the proposal is: [...]

Ah OK, so we basically turn the async executor into a thunk to be more explicit. I think I can give this a try. For buffer assignment, can we tell it to reuse buffers only for kernels with the same exec_id?

As for the concurrent region, I dumped the graph and it looks like each of the gemms has its own graph instance. It looks like the concurrent region won't work with scratch allocations.

@ezhulenev
Member

The proper XLA solution at the HLO level would be to put async-start and async-done around the gemm custom call (dot instruction). At the thunk level, ExecutionThreadId looks like a reasonable abstraction. With async wrappers, buffer assignment will work out of the box.

@ezhulenev
Member

The correct XLA-way of doing this would be something like this:

%gemm0 {
  ROOT gemm0 = custom-call(...)
}

%gemm1 {
  ROOT gemm1 = custom-call(...)
}

ENTRY main {

 %gemm0-start = async-start(), called=%gemm0
 %gemm1-start = async-start(), called=%gemm1

 ...
 %gemm0 = async-done(%gemm0-start)
 %gemm1 = async-done(%gemm1-start)
}

This IR has syntactic sugar in XLA, and will be printed like:

ENTRY main {

 %gemm0-start = custom-call-start(...)
 %gemm1-start = custom-call-start(...)

 ...
 %gemm0 = async-done(%gemm0-start)
 %gemm1 = async-done(%gemm1-start)
}

With this representation you can assign execution threads to individual gemms, and at the IR emitter stage we'll have a lowering for the AsyncStart of a CustomCall operation like:

thunks.add(JoinExecutionThread(/* join gemm thread with "main" */));
thunks.add(GemmThunk("foo"));

and AsyncDone will be lowered to

thunks.add(JoinExecutionThread(/* join "main" with gemm execution thread */));

This is the direction XLA is going with async instructions, and it will be supported out of the box by buffer assignment. See more details here: https://github.com/openxla/xla/blob/main/docs/async_ops.md

Stream id attributes in backend configs are hacks that will require even more hacks to make buffer assignment work properly. With the correct async representation, it will be lowered to the correct thunk sequence and will also have a correct lowering to commands (CUDA graphs) with a proper DAG.

@Tixxx
Contributor Author

Tixxx commented Jan 26, 2024

> Could you add a bit more context/benchmarks to the commit description on what workflows this makes faster?

I have created a separate discussion on the OpenXLA forum here to describe the motivation and design.
As for benchmarking, the preliminary number achieved on an allgather+gemm pattern unit test is a ~10% speedup. The gemm has the same size as the one used in the 175B GPT-3 model.

@Tixxx
Contributor Author

Tixxx commented Jan 26, 2024

> The correct XLA-way of doing this would be something like this: [...]

I have included your suggestion in the discussion here. I think the AsyncExecutor is deprecated, so I removed that alternative. I have the basic functionality working for the HLO graph part and am trying to get the async thunks working to run it.

@Tixxx Tixxx force-pushed the tixxx/collective_matmul_runtime branch from 391ff8a to 11136d5 Compare January 31, 2024 07:37
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jan 31, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jan 31, 2024
@Tixxx Tixxx force-pushed the tixxx/collective_matmul_runtime branch from 11136d5 to bf1341f Compare January 31, 2024 07:46
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Jan 31, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jan 31, 2024
@Tixxx Tixxx requested a review from ezhulenev January 31, 2024 22:00
Member

@ezhulenev ezhulenev left a comment

I didn't add notes for all namespaces, but with C++17 we prefer namespace xla::gpu for new code.
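For example, the C++17 nested-namespace form for new files:

namespace xla::gpu {
// ... new code ...
}  // namespace xla::gpu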

xla/service/gpu/thunk.h (3 outdated review threads, resolved)
return params.stream;
}
auto iter = params.additional_compute_streams.find(stream_id);
CHECK(iter != params.additional_compute_streams.end());
Member

New CHECKs are strongly discouraged in XLA; absl::StatusOr<se::Stream*> is the preferred result type here.

Contributor Author

I changed it to return an InvalidArgument error instead of a CHECK.
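A minimal sketch of that change, reusing the field names from the snippet above (the function name, the id type, and the stream-id-0 default case are illustrative, not the exact code in the PR):

absl::StatusOr<se::Stream*> GetStreamForExecution(int64_t stream_id,
                                                  const ExecuteParams& params) {
  if (stream_id == 0) {
    // Stream id 0 maps to the default compute stream.
    return params.stream;
  }
  auto iter = params.additional_compute_streams.find(stream_id);
  if (iter == params.additional_compute_streams.end()) {
    // Return a status instead of CHECK-failing when the id is unknown.
    return absl::InvalidArgumentError(
        "Requested stream id was not found in additional_compute_streams.");
  }
  return iter->second;
}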

xla/service/gpu/runtime3/sync_on_streams_thunk.h (2 outdated review threads, resolved)
xla/service/gpu/runtime3/sync_on_streams_thunk.cc (2 outdated review threads, resolved)
@Tixxx Tixxx force-pushed the tixxx/collective_matmul_runtime branch from bf1341f to e9c44f7 Compare February 1, 2024 06:28
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Feb 1, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Feb 1, 2024
@Tixxx Tixxx requested a review from ezhulenev February 1, 2024 19:02
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Feb 2, 2024
@Tixxx Tixxx force-pushed the tixxx/collective_matmul_runtime branch from 81b450f to 2a87790 Compare February 2, 2024 05:30
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Feb 2, 2024
@Tixxx Tixxx requested a review from ezhulenev February 2, 2024 05:30
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Feb 2, 2024
@Tixxx Tixxx force-pushed the tixxx/collective_matmul_runtime branch from 2a87790 to 81119b3 Compare February 2, 2024 20:03
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Feb 2, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Feb 2, 2024
@Tixxx Tixxx force-pushed the tixxx/collective_matmul_runtime branch from 81119b3 to 1fdad5c Compare February 2, 2024 22:30
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Feb 2, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Feb 2, 2024
@ezhulenev
Member

There are a few more build-time errors:

In file included from third_party/tensorflow/compiler/xla/service/gpu/thunk.cc:16:
./third_party/tensorflow/compiler/xla/service/gpu/thunk.h:46:10: error: module //third_party/tensorflow/compiler/xla/service/gpu:thunk does not depend on a module exporting 'third_party/tensorflow/tsl/lib/gtl/int_type.h'
see http://go/cpp-features#layering_check; to fix run:
	build_cleaner //third_party/tensorflow/compiler/xla/service/gpu:thunk
   46 | #include "third_party/tensorflow/tsl/lib/gtl/int_type.h"
      |          ^
third_party/tensorflow/compiler/xla/service/gpu/thunk.cc:37:10: error: module //third_party/tensorflow/compiler/xla/service/gpu:thunk does not directly depend on a module exporting 'third_party/tensorflow/compiler/xla/service/gpu/backend_configs.proto.h', which is part of indirectly-used module blaze-out/k8-opt-cuda12/bin/third_party/tensorflow/compiler/xla/service/gpu/backend_configs.proto.h
see http://go/cpp-features#layering_check; to fix run:
	build_cleaner //third_party/tensorflow/compiler/xla/service/gpu:thunk
   37 | #include "third_party/tensorflow/compiler/xla/service/gpu/backend_configs.proto.h"
      |          ^
third_party/tensorflow/compiler/xla/service/gpu/thunk.cc:189:7: error: field 'recv_device_memory_function' will be initialized after field 'additional_compute_streams' [-Werror,-Wreorder-ctor]
third_party/tensorflow/compiler/xla/service/gpu/runtime3/wait_for_streams_thunk.cc:22:10: error: module //third_party/tensorflow/compiler/xla/service/gpu/runtime3:wait_for_streams_thunk does not directly depend on a module exporting 'third_party/absl/strings/str_cat.h', which is part of indirectly-used module third_party/absl/strings/str_cat.h
see http://go/cpp-features#layering_check; to fix run:
	build_cleaner //third_party/tensorflow/compiler/xla/service/gpu/runtime3:wait_for_streams_thunk
   22 | #include "third_party/absl/strings/str_cat.h"
      |          ^
third_party/tensorflow/compiler/xla/service/gpu/runtime3/wait_for_streams_thunk.cc:24:10: error: module //third_party/tensorflow/compiler/xla/service/gpu/runtime3:wait_for_streams_thunk does not depend on a module exporting 'third_party/tensorflow/tsl/platform/errors.h'
see http://go/cpp-features#layering_check; to fix run:
	build_cleaner //third_party/tensorflow/compiler/xla/service/gpu/runtime3:wait_for_streams_thunk
   24 | #include "third_party/tensorflow/tsl/platform/errors.h"
third_party/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc:3800:25: error: unused variable 'wrapped' [-Werror,-Wunused-variable]
 3800 |   const HloInstruction* wrapped = inst->async_wrapped_instruction();

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Feb 3, 2024
@Tixxx
Contributor Author

Tixxx commented Feb 3, 2024

> const HloInstruction* wrapped = inst->async_wrapped_instruction();

Hopefully the latest commit will fix them.

@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Feb 3, 2024
@Tixxx Tixxx force-pushed the tixxx/collective_matmul_runtime branch from 725ab0f to 70068f9 Compare February 4, 2024 05:42
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Feb 4, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Feb 4, 2024
@copybara-service copybara-service bot closed this in 490c2f0 Feb 5, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Feb 5, 2024
[NVIDIA GPU SPMD] Add runtime support to run windowed einsum in multiple streams

Imported from GitHub PR openxla/xla#8707

This PR contains the runtime changes needed to run windowed einsum in multiple CUDA streams.
This PR follows openxla/xla#7854, which adds stream attributes to the HLO graph.
We take the stream attributes and dispatch the corresponding kernels to separate CUDA streams.

We do this by wrapping the kernel in an async-start/async-done pair with a non-default stream id.
For the AsyncStart, the emitter emits a SyncOnStreamsThunk followed by the kernel's thunk; for the AsyncDone, it emits only a SyncOnStreamsThunk.

Detailed discussion [here](openxla/xla#8865).
Copybara import of the project:

--
22320215797434b0f507aeb93872c174cd4499c5 by TJ <[email protected]>:

Added a thunk for stream synchronization
added emitter logic for SyncOnStreamsThunk

--
70068f9fc81343a2e0c2b48d05ad98df1c9d70ab by TJ <[email protected]>:

attempting to fix build errors

Merging this change closes #8707

PiperOrigin-RevId: 604215385