Skip to content

Commit

Permalink
Added a thunk for stream synchronization
Browse files Browse the repository at this point in the history
added emitter logic for SyncOnStreamsThunk
  • Loading branch information
Tixxx committed Feb 1, 2024
1 parent 33ad1d8 commit e9c44f7
Show file tree
Hide file tree
Showing 11 changed files with 303 additions and 20 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ cc_library(
"//xla/service/gpu/runtime3:replica_id_thunk",
"//xla/service/gpu/runtime3:send_recv_thunk",
"//xla/service/gpu/runtime3:sequential_thunk",
"//xla/service/gpu/runtime3:wait_for_streams_thunk",
"//xla/service/gpu/runtime3:while_thunk",
"//xla/service/llvm_ir:buffer_assignment_util",
"//xla/service/llvm_ir:dynamic_update_slice_util",
Expand Down
48 changes: 45 additions & 3 deletions xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,26 @@ absl::Status MaybeRendezvousAfterInitialization(
const ServiceExecutableRunOptions* run_options,
std::atomic<int64_t>* thunks_initialized);

// Collects the execution stream ids requested by the instructions of `module`.
//
// Every instruction in every computation is inspected; if it carries a
// backend config that parses as a GpuBackendConfig and its
// `operation_queue_id` is greater than zero, that id is added to the result.
// Instructions without a backend config, or whose config does not parse as a
// GpuBackendConfig, are skipped.
absl::flat_hash_set<ExecutionStreamId> ExtractAdditionalComputeStreamIds(
    const HloModule& module) {
  absl::flat_hash_set<ExecutionStreamId> stream_ids;
  for (const HloComputation* comp : module.computations()) {
    for (const HloInstruction* hlo : comp->instructions()) {
      if (!hlo->has_backend_config()) {
        continue;
      }
      // Parse the backend config once and reuse the result instead of
      // deserializing it a second time just to read the queue id.
      auto gpu_config = hlo->backend_config<GpuBackendConfig>();
      if (!gpu_config.ok()) {
        continue;
      }
      const int64_t op_queue_id = gpu_config->operation_queue_id();
      if (op_queue_id > 0) {
        stream_ids.insert(ExecutionStreamId(op_queue_id));
      }
    }
  }
  return stream_ids;
}

absl::Status ExecuteThunks(const std::string& module_name,
ModuleIdentifier module_id,
const ThunkSequence& thunk_sequence,
Expand All @@ -314,7 +334,9 @@ absl::Status ExecuteThunks(const std::string& module_name,
const BufferAllocations& buffer_allocations,
bool block_host_until_done,
bool use_highest_priority_for_async_stream,
std::atomic<int64_t>* thunks_initialized) {
std::atomic<int64_t>* thunks_initialized,
absl::flat_hash_set<ExecutionStreamId>
additional_compute_stream_ids) {
se::Stream* main_stream = run_options->stream();
se::StreamExecutor* executor = main_stream->parent();
stream_executor::StreamPriority stream_priority =
Expand Down Expand Up @@ -343,6 +365,23 @@ absl::Status ExecuteThunks(const std::string& module_name,
command_buffer_trace_stream = borrowed_command_buffer_trace_stream->get();
}

// Borrow stream for additional compute streams
absl::flat_hash_map<ExecutionStreamId, se::Stream*> additional_compute_streams;
if (additional_compute_stream_ids.size() > 0) {
int num_streams = additional_compute_stream_ids.size();
absl::StatusOr<std::vector<StreamPool::Ptr>> additional_streams =
run_options->BorrowStreams(executor->device_ordinal(), num_streams);
if (streams.ok()) {
int64_t i = 0;
for (auto& stream : additional_compute_stream_ids) {
additional_compute_streams[stream] =
additional_streams->at(i).get();
i++;
}
VLOG(2) << "Using " << num_streams << " additional compute streams.";
}
}

tsl::profiler::TraceMe hlo_module_activity(
[&] { return absl::StrCat(module_name, ":XLA GPU module"); },
tsl::profiler::TraceMeLevel::kInfo);
Expand Down Expand Up @@ -404,7 +443,7 @@ absl::Status ExecuteThunks(const std::string& module_name,
Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create(
*run_options, buffer_allocations, main_stream,
command_buffer_trace_stream, async_comms_streams, &collective_params,
&collective_cliques);
&collective_cliques, additional_compute_streams);

for (const std::unique_ptr<Thunk>& thunk : thunk_sequence) {
// Annotate execution of this op if tracing was enabled when we started
Expand Down Expand Up @@ -987,8 +1026,11 @@ absl::Status GpuExecutable::ExecuteThunksOrXlaRuntime(

// There isn't always an HLO module.
ModuleIdentifier unique_id = -1;
absl::flat_hash_set<ExecutionStreamId> additional_compute_stream_ids;
if (has_module()) {
unique_id = module().unique_id();
additional_compute_stream_ids =
ExtractAdditionalComputeStreamIds(module());
}

ScopedAnnotationAlways annotation([&]() -> ModuleAnnotation {
Expand All @@ -1011,7 +1053,7 @@ absl::Status GpuExecutable::ExecuteThunksOrXlaRuntime(
.debug_options()
.xla_gpu_enable_highest_priority_async_stream()
: false,
&thunks_initialized_);
&thunks_initialized_, additional_compute_stream_ids);
}

// Match IrEmitter's temp buffer allocation for kernel launches. See
Expand Down
67 changes: 65 additions & 2 deletions xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ limitations under the License.
#include "xla/service/gpu/runtime3/replica_id_thunk.h"
#include "xla/service/gpu/runtime3/send_recv_thunk.h"
#include "xla/service/gpu/runtime3/sequential_thunk.h"
#include "xla/service/gpu/runtime3/wait_for_streams_thunk.h"
#include "xla/service/gpu/runtime3/while_thunk.h"
#include "xla/service/gpu/thunk.h"
#include "xla/service/llvm_ir/buffer_assignment_util.h"
Expand Down Expand Up @@ -3641,6 +3642,37 @@ absl::Status IrEmitterUnnested::EmitNcclAsyncDone(Thunk::Kind kind,
return absl::OkStatus();
}

// Appends a WaitForStreamsThunk for the async instruction `inst` to the thunk
// sequence.
//
// The thunk makes a "source" stream wait on a list of other streams. Which
// streams are used depends on `is_async_done` and on `gpu_config`:
//   * is_async_done == true: the main compute stream waits on the stream
//     given by `gpu_config.operation_queue_id()`.
//   * otherwise, with an empty `wait_on_operation_queues`: the stream given by
//     `operation_queue_id()` waits on the main compute stream.
//   * otherwise: the stream given by `operation_queue_id()` waits on every
//     stream listed in `wait_on_operation_queues`.
//
// Always returns OkStatus.
absl::Status IrEmitterUnnested::EmitWaitForStreamsThunk(
    const HloInstruction* inst, GpuBackendConfig& gpu_config,
    bool is_async_done) {
  std::vector<ExecutionStreamId> wait_on_streams;
  ExecutionStreamId source_stream_id = Thunk::GetMainComputeStreamId();
  if (is_async_done) {
    // For an async done, the main compute stream synchronizes on the
    // execution stream of the instruction.
    wait_on_streams.push_back(
        ExecutionStreamId(gpu_config.operation_queue_id()));
  } else if (gpu_config.wait_on_operation_queues().empty()) {
    // No explicit wait list: the execution stream just synchronizes on the
    // main compute stream.
    wait_on_streams.push_back(Thunk::GetMainComputeStreamId());
    source_stream_id = gpu_config.operation_queue_id();
  } else {
    // Otherwise the execution stream synchronizes on all specified streams.
    wait_on_streams.reserve(gpu_config.wait_on_operation_queues().size());
    for (int64_t stream_id : gpu_config.wait_on_operation_queues()) {
      wait_on_streams.push_back(ExecutionStreamId(stream_id));
    }
    source_stream_id = gpu_config.operation_queue_id();
  }

  AddThunkToThunkSequence(std::make_unique<WaitForStreamsThunk>(
      Thunk::ThunkInfo::WithProfileAnnotation(inst), source_stream_id,
      wait_on_streams));
  return absl::OkStatus();
}

absl::StatusOr<std::vector<ShapedSlice>> IrEmitterUnnested::GetShapedSlices(
mlir::Operation::operand_range operands) {
std::vector<ShapedSlice> shaped_slices;
Expand Down Expand Up @@ -4395,9 +4427,23 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
return EmitNcclAsyncDone(Thunk::kNcclReduceScatterDone, instr);
case HloOpcode::kAllToAll:
return EmitNcclAsyncDone(Thunk::kNcclAllToAllDone, instr);
default:
default: {
// Fallback for wrapped opcodes that have no dedicated async-done lowering
// in the cases above.
if (wrapped->has_backend_config()) {
TF_ASSIGN_OR_RETURN(
xla::gpu::GpuBackendConfig gpu_config,
wrapped->backend_config<xla::gpu::GpuBackendConfig>());
if (gpu_config.operation_queue_id() != 0) {
// If this async-done instruction wraps an instruction that runs on
// a non-default stream (non-zero operation queue id), we just emit
// a WaitForStreamsThunk for it.
return EmitWaitForStreamsThunk(instr, gpu_config,
/*is_async_done=*/true);
}
}

// No backend config, or the wrapped instruction uses operation queue 0:
// this async-done is not supported.
return Internal("Unsupported async done wrapped instruction: %s",
HloOpcodeString(wrapped->opcode()));
}
}
}
case HloOpcode::kAsyncStart: {
Expand All @@ -4415,9 +4461,26 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
return EmitNcclThunk<NcclAllToAllStartThunk, HloAllToAllInstruction>(
Thunk::kNcclAllToAll, instr, all_to_all, std::nullopt);
}
default:
default: {
// Fallback for wrapped opcodes that have no dedicated async-start lowering
// in the cases above.
if (wrapped->has_backend_config()) {
TF_ASSIGN_OR_RETURN(
xla::gpu::GpuBackendConfig gpu_config,
wrapped->backend_config<xla::gpu::GpuBackendConfig>());
if (gpu_config.operation_queue_id() != 0) {
// If this async-start instruction wraps an instruction that runs on
// a non-default stream (non-zero operation queue id), we first emit
// a WaitForStreamsThunk (is_async_done=false, so the execution
// stream is the waiting side), and then emit the thunk(s) of the
// wrapped instruction itself.
TF_RETURN_IF_ERROR(
EmitWaitForStreamsThunk(instr, gpu_config,
/*is_async_done=*/false));
return EmitHloInstruction(wrapped);
}
}
// No backend config, or the wrapped instruction uses operation queue 0:
// this async-start is not supported.
return Internal("Unsupported async start wrapped instruction: %s",
HloOpcodeString(wrapped->opcode()));
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions xla/service/gpu/ir_emitter_unnested.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ class IrEmitterUnnested : public IrEmitter {

absl::Status EmitNcclAsyncDone(Thunk::Kind kind, const HloInstruction* instr);

absl::Status EmitWaitForStreamsThunk(const HloInstruction* inst,
GpuBackendConfig& gpu_config,
bool is_async_done);
template <typename ThunkType, typename OpT>
absl::Status EmitReplicaOrPartitionId(mlir::Operation* op);
template <typename ThunkType>
Expand Down
13 changes: 13 additions & 0 deletions xla/service/gpu/runtime3/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ cc_library(
":memset_thunk",
":replica_id_thunk",
":sequential_thunk",
":wait_for_streams_thunk",
":while_thunk",
"//xla:status",
"//xla:statusor",
Expand Down Expand Up @@ -587,3 +588,15 @@ cc_library(
"@tsl//tsl/platform:logging",
],
)

# Thunk that makes one execution stream wait on a set of other streams.
cc_library(
    name = "wait_for_streams_thunk",
    srcs = ["wait_for_streams_thunk.cc"],
    hdrs = ["wait_for_streams_thunk.h"],
    deps = [
        "//xla/service:global_device_id",
        "//xla/service/gpu:thunk",
        "@com_google_absl//absl/status",
        # wait_for_streams_thunk.cc includes absl/strings/str_cat.h.
        "@com_google_absl//absl/strings",
        # wait_for_streams_thunk.cc includes tsl/platform/errors.h.
        "@tsl//tsl/platform:errors",
        "@tsl//tsl/platform:statusor",
    ],
)
2 changes: 2 additions & 0 deletions xla/service/gpu/runtime3/command_buffer_cmd_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ limitations under the License.
#include "xla/service/gpu/runtime3/memset_thunk.h"
#include "xla/service/gpu/runtime3/replica_id_thunk.h"
#include "xla/service/gpu/runtime3/sequential_thunk.h"
#include "xla/service/gpu/runtime3/wait_for_streams_thunk.h"
#include "xla/service/gpu/runtime3/while_thunk.h"
#include "xla/service/gpu/thunk.h"
#include "xla/util.h"
Expand Down Expand Up @@ -226,6 +227,7 @@ static absl::Status AppendCommands(CommandBufferCmdSequence& cmd_sequence,
case Thunk::Kind::kNcclAllGatherDone:
case Thunk::Kind::kNcclAllReduceDone:
case Thunk::Kind::kNcclReduceScatterDone:
case Thunk::Kind::kWaitForStreams:
return absl::OkStatus();

default:
Expand Down
6 changes: 5 additions & 1 deletion xla/service/gpu/runtime3/gemm_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,14 @@ absl::Status GemmThunk::ExecuteOnStream(const ExecuteParams& params) {
if (workspace_.has_value()) {
workspace = allocs.GetDeviceAddress(workspace_.value());
}
TF_ASSIGN_OR_RETURN(se::Stream* stream, GetStreamForExecution(
Thunk::execution_stream_id(),
params));

return RunGemm(config_, allocs.GetDeviceAddress(lhs_buffer_),
allocs.GetDeviceAddress(rhs_buffer_),
allocs.GetDeviceAddress(output_buffer_), workspace,
deterministic_, params.stream);
deterministic_, stream);
}

absl::Status GemmThunk::Initialize(const InitializeParams& params) {
Expand Down
46 changes: 46 additions & 0 deletions xla/service/gpu/runtime3/wait_for_streams_thunk.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/runtime3/wait_for_streams_thunk.h"

#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "xla/service/gpu/thunk.h"
#include "tsl/platform/errors.h"

namespace xla::gpu {

// Resolves the source stream and every stream in `wait_on_stream_ids_` via
// Thunk::GetStreamForExecution, and calls ThenWaitFor on the source stream for
// each of them in order. Returns the lookup error if any stream id cannot be
// resolved; otherwise returns OkStatus.
absl::Status WaitForStreamsThunk::ExecuteOnStream(const ExecuteParams& params) {
  TF_ASSIGN_OR_RETURN(se::Stream* waiting_stream,
                      Thunk::GetStreamForExecution(source_stream_id_, params));

  // Formatter used only for the debug log line below.
  auto append_stream_id = [&](std::string* out, const ExecutionStreamId& id) {
    absl::StrAppend(out, id.value());
  };
  VLOG(5) << "Waiting for stream ids: "
          << absl::StrJoin(wait_on_stream_ids_, ", ", append_stream_id);

  for (const ExecutionStreamId& awaited_id : wait_on_stream_ids_) {
    TF_ASSIGN_OR_RETURN(se::Stream* awaited_stream,
                        Thunk::GetStreamForExecution(awaited_id, params));
    waiting_stream->ThenWaitFor(awaited_stream);
  }
  return absl::OkStatus();
}

} // namespace xla::gpu
55 changes: 55 additions & 0 deletions xla/service/gpu/runtime3/wait_for_streams_thunk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_RUNTIME3_WAIT_FOR_STREAMS_THUNK_H_
#define XLA_SERVICE_GPU_RUNTIME3_WAIT_FOR_STREAMS_THUNK_H_

#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "xla/service/gpu/thunk.h"

namespace xla::gpu {

// A thunk that synchronizes execution streams: when executed, it makes the
// stream identified by `source_stream_id` wait for each of the streams
// identified by `wait_on_stream_ids`.
class WaitForStreamsThunk : public Thunk {
 public:
  // `source_stream_id` is the stream that waits; `wait_on_stream_ids` are the
  // streams it waits on.
  WaitForStreamsThunk(ThunkInfo thunk_info, ExecutionStreamId source_stream_id,
                      std::vector<ExecutionStreamId> wait_on_stream_ids)
      : Thunk(Kind::kWaitForStreams, std::move(thunk_info)),
        source_stream_id_(source_stream_id),
        // The vector is taken by value, so move it instead of copying again.
        wait_on_stream_ids_(std::move(wait_on_stream_ids)) {}

  WaitForStreamsThunk(const WaitForStreamsThunk&) = delete;
  WaitForStreamsThunk& operator=(const WaitForStreamsThunk&) = delete;

  // Id of the stream that performs the wait.
  const ExecutionStreamId& source_stream_id() const {
    return source_stream_id_;
  }

  // Ids of the streams being waited on.
  const std::vector<ExecutionStreamId>& wait_on_stream_ids() const {
    return wait_on_stream_ids_;
  }

  absl::Status ExecuteOnStream(const ExecuteParams& params) override;

 private:
  ExecutionStreamId source_stream_id_;
  std::vector<ExecutionStreamId> wait_on_stream_ids_;
};

} // namespace xla::gpu

#endif // XLA_SERVICE_GPU_RUNTIME3_WAIT_FOR_STREAMS_THUNK_H_
Loading

0 comments on commit e9c44f7

Please sign in to comment.