diff --git a/xla/service/cpu/runtime/kernel_thunk.cc b/xla/service/cpu/runtime/kernel_thunk.cc
index 88847041c3b51f..23a9537ee231c6 100644
--- a/xla/service/cpu/runtime/kernel_thunk.cc
+++ b/xla/service/cpu/runtime/kernel_thunk.cc
@@ -15,19 +15,19 @@ limitations under the License.

 #include "xla/service/cpu/runtime/kernel_thunk.h"

-#include
-
 #define EIGEN_USE_THREADS

 #include
+#include
 #include
 #include
 #include
 #include
+#include
 #include

+#include "absl/base/attributes.h"
 #include "absl/base/optimization.h"
-#include "absl/container/inlined_vector.h"
 #include "absl/memory/memory.h"
 #include "absl/numeric/bits.h"
 #include "absl/status/status.h"
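The next hunk replaces the `CheckBufferAlignment` member with a free function that takes `min_alignment` by value, where `0` means "no check". The test itself assumes `min_alignment` is a power of two (which `KernelThunk::Create` enforces via `absl::has_single_bit`), so alignment reduces to a single bitmask. A minimal standalone sketch of that invariant, outside of XLA:

```cpp
#include <cstdint>

// For min_alignment == 2^k, the low k bits of an aligned address are all zero,
// so `ptr & (min_alignment - 1)` is exactly the remainder of ptr modulo
// min_alignment. This only holds for power-of-two alignments.
inline bool IsAligned(const void* ptr, std::uint64_t min_alignment) {
  return (reinterpret_cast<std::uintptr_t>(ptr) & (min_alignment - 1)) == 0;
}
```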
@@ -51,50 +51,109 @@ limitations under the License.
 #include "tsl/profiler/lib/traceme.h"

 namespace xla::cpu {
+namespace internal {

-absl::StatusOr<std::unique_ptr<KernelThunk>> KernelThunk::Create(
-    Info info, absl::Span<const BufferAllocation::Slice> arguments_buffers,
+// Checks that all buffers are aligned to the minimum alignment. We codegen
+// with the assumption that all buffers are aligned, and if they are not, we
+// will crash with a segmentation fault, or worse, produce incorrect results.
+static absl::Status CheckBufferAlignment(
+    const Thunk::Info& info, uint64_t min_alignment,
+    absl::Span<const SE_HOST_KernelArg> kernel_args) {
+  if (min_alignment == 0) return absl::OkStatus();
+
+  for (int64_t i = 0; i < kernel_args.size(); ++i) {
+    auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
+    if (ABSL_PREDICT_FALSE((ptr & (min_alignment - 1)) != 0)) {
+      return Internal(
+          "Host kernel %s buffer argument #%d (%p) is not aligned to a "
+          "required minimum alignment of %d bytes",
+          info.op_name, i, kernel_args[i].data, min_alignment);
+    }
+  }
+
+  return absl::OkStatus();
+}
+
+// VLOGs kernel arguments resolved from the buffer allocations.
+static void VlogKernelArgs(
+    absl::Span<const BufferAllocation::Slice> arguments_buffers,
     absl::Span<const BufferAllocation::Slice> results_buffers,
-    std::string kernel_name, se::ThreadDim thread_dim,
-    std::optional<uint64_t> min_alignment) {
-  if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) {
-    return Internal("Host kernel %s minimum alignment %d is not a power of 2",
-                    info.op_name, *min_alignment);
+    absl::Span<const SE_HOST_KernelArg> kernel_args) {
+  for (int64_t i = 0; i < arguments_buffers.size(); ++i) {
+    VLOG(3) << absl::StreamFormat("  arg #%d: %s (%p)", i,
+                                  arguments_buffers[i].ToString(),
+                                  kernel_args[i].data);
   }
+  for (int64_t i = 0; i < results_buffers.size(); ++i) {
+    VLOG(3) << absl::StreamFormat(
+        "  res #%d: %s (%p)", i, results_buffers[i].ToString(),
+        kernel_args[arguments_buffers.size() + i].data);
+  }
+}

-  return absl::WrapUnique(
-      new KernelThunk(std::move(info), arguments_buffers, results_buffers,
-                      std::move(kernel_name), thread_dim, min_alignment));
+// Returns kernel buffer uses for given arguments and results buffers.
+static Thunk::BufferUses KernelBufferUses(
+    absl::Span<const BufferAllocation::Slice> arguments_buffers,
+    absl::Span<const BufferAllocation::Slice> results_buffers) {
+  Thunk::BufferUses buffer_uses;
+  for (const BufferAllocation::Slice& buffer : arguments_buffers) {
+    buffer_uses.emplace_back(buffer, BufferUse::kRead);
+  }
+  for (const BufferAllocation::Slice& buffer : results_buffers) {
+    buffer_uses.emplace_back(buffer, BufferUse::kWrite);
+  }
+  return buffer_uses;
 }

-KernelThunk::KernelThunk(
+template <int64_t num_arguments, int64_t num_results>
+KernelThunk<num_arguments, num_results>::KernelThunk(
     Info info, absl::Span<const BufferAllocation::Slice> arguments_buffers,
     absl::Span<const BufferAllocation::Slice> results_buffers,
     std::string kernel_name, se::ThreadDim thread_dim,
     std::optional<uint64_t> min_alignment)
     : Thunk(Kind::kKernel, std::move(info)),
-      arguments_buffers_(arguments_buffers.begin(), arguments_buffers.end()),
-      results_buffers_(results_buffers.begin(), results_buffers.end()),
       num_kernel_args_(arguments_buffers.size() + results_buffers.size()),
       kernel_name_(std::move(kernel_name)),
       thread_dim_(thread_dim),
       min_alignment_(min_alignment),
       call_once_(thread_dim_ == se::ThreadDim()),
       kernel_ptr_(nullptr) {
+  // Resize storage for arguments and results buffers if it is dynamic.
+  if constexpr (IsDynamic(num_arguments)) {
+    arguments_buffers_.resize(arguments_buffers.size());
+  }
+  if constexpr (IsDynamic(num_results)) {
+    results_buffers_.resize(results_buffers.size());
+  }
+
+  // Copy buffers from the arguments and results.
+  for (size_t i = 0; i < arguments_buffers.size(); ++i) {
+    arguments_buffers_[i] = arguments_buffers[i];
+  }
+  for (size_t i = 0; i < results_buffers.size(); ++i) {
+    results_buffers_[i] = results_buffers[i];
+  }
+
+  // Resize storage for kernel arguments if it is dynamic.
+  if constexpr (IsDynamic(num_arguments) || IsDynamic(num_results)) {
+    kernel_args_.resize(num_kernel_args_);
+  }
+
   // Initialize kernel arguments with null pointers and known buffer sizes.
   // We'll use them as a template to resolve buffer addresses at run time.
-  kernel_args_.reserve(num_kernel_args_);
-  for (const BufferAllocation::Slice& buffer : arguments_buffers_) {
-    kernel_args_.emplace_back(
-        SE_HOST_KernelArg{nullptr, static_cast<size_t>(buffer.size())});
+  for (size_t i = 0; i < arguments_buffers.size(); ++i) {
+    kernel_args_[i] = SE_HOST_KernelArg{
+        nullptr, static_cast<size_t>(arguments_buffers_[i].size())};
   }
-  for (const BufferAllocation::Slice& buffer : results_buffers_) {
-    kernel_args_.emplace_back(
-        SE_HOST_KernelArg{nullptr, static_cast<size_t>(buffer.size())});
+  for (size_t i = 0; i < results_buffers.size(); ++i) {
+    kernel_args_[arguments_buffers_.size() + i] = SE_HOST_KernelArg{
+        nullptr, static_cast<size_t>(results_buffers_[i].size())};
   }
 }
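The constructor above compiles for both storage strategies because the `resize` calls are guarded by `if constexpr`: when the counts are static the members are `std::array`, which has no `resize`, so those statements must be discarded at instantiation time rather than at run time. A hedged sketch of the pattern, with hypothetical names:

```cpp
#include <cstddef>

// `kIsDynamic` plays the role of IsDynamic(num_arguments): when it is false,
// Container is a std::array and the resize() call must not even be
// instantiated, because std::array does not have that member.
template <bool kIsDynamic, typename Container>
void EnsureSize(Container& storage, std::size_t n) {
  if constexpr (kIsDynamic) {
    storage.resize(n);  // std::vector / absl::InlinedVector path
  }
  // std::array path: nothing to do, the size is part of the type.
}
```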
-tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
+template <int64_t num_arguments, int64_t num_results>
+ABSL_ATTRIBUTE_ALWAYS_INLINE tsl::AsyncValueRef<Thunk::ExecuteEvent>
+KernelThunk<num_arguments, num_results>::ExecuteInternal(
     const ExecuteParams& params) {
   tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); });

@@ -104,7 +163,7 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
       kernel_name_, arguments_buffers_.size(), results_buffers_.size(),
       thread_dim_.ToString());

-  absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args = kernel_args_;
+  KernelArgs kernel_args = kernel_args_;
   SE_HOST_KernelArg* kernel_args_ptr = kernel_args.data();

   const BufferAllocations* allocations = params.buffer_allocations;
@@ -130,12 +189,13 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
   }

   if (ABSL_PREDICT_FALSE(VLOG_IS_ON(3))) {
-    VlogKernelArgs(kernel_args);
+    VlogKernelArgs(arguments_buffers_, results_buffers_, kernel_args);
   }

   // Check that all resolved buffers are properly aligned.
   if constexpr (ShouldCheckBufferSlices()) {
-    TF_RETURN_IF_ERROR(CheckBufferAlignment(kernel_args));
+    TF_RETURN_IF_ERROR(
+        CheckBufferAlignment(info(), min_alignment_.value_or(0), kernel_args));
   }

   // TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk
@@ -173,45 +233,68 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
   return OkExecuteEvent();
 }

-absl::Status KernelThunk::CheckBufferAlignment(
-    absl::Span<const SE_HOST_KernelArg> kernel_args) {
-  if (min_alignment_.has_value()) {
-    for (int64_t i = 0; i < num_kernel_args_; ++i) {
-      auto ptr = reinterpret_cast<uintptr_t>(kernel_args[i].data);
-      if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) {
-        return Internal(
-            "Host kernel %s buffer argument #%d (%p) is not aligned to a "
-            "required minimum alignment of %d bytes",
-            info().op_name, i, kernel_args[i].data, *min_alignment_);
-      }
-    }
-  }
-  return absl::OkStatus();
+template <int64_t num_arguments, int64_t num_results>
+KernelThunk<num_arguments, num_results>::BufferUses
+KernelThunk<num_arguments, num_results>::buffer_uses() const {
+  return KernelBufferUses(arguments_buffers_, results_buffers_);
 }

-void KernelThunk::VlogKernelArgs(
-    absl::Span<const SE_HOST_KernelArg> kernel_args) {
-  for (int64_t i = 0; i < arguments_buffers_.size(); ++i) {
-    VLOG(3) << absl::StreamFormat("  arg #%d: %s (%p)", i,
-                                  arguments_buffers_[i].ToString(),
-                                  kernel_args[i].data);
-  }
-  for (int64_t i = 0; i < results_buffers_.size(); ++i) {
-    VLOG(3) << absl::StreamFormat(
-        "  res #%d: %s (%p)", i, results_buffers_[i].ToString(),
-        kernel_args[arguments_buffers_.size() + i].data);
-  }
+}  // namespace internal
+
+tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
+    const Thunk::ExecuteParams& params) {
+  return Base::ExecuteInternal(params);
 }

-KernelThunk::BufferUses KernelThunk::buffer_uses() const {
-  BufferUses buffer_uses;
-  for (const BufferAllocation::Slice& buffer : arguments_buffers_) {
-    buffer_uses.emplace_back(buffer, BufferUse::kRead);
-  }
-  for (const BufferAllocation::Slice& buffer : results_buffers_) {
-    buffer_uses.emplace_back(buffer, BufferUse::kWrite);
+template <int64_t num_arguments, int64_t num_results>
+tsl::AsyncValueRef<Thunk::ExecuteEvent>
+SmallKernelThunk<num_arguments, num_results>::Execute(
+    const Thunk::ExecuteParams& params) {
+  return Base::ExecuteInternal(params);
+}
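Note the shape of the dispatch above: each concrete thunk implements the virtual `Execute` as a one-line `final` override that forwards to the base template's `ABSL_ATTRIBUTE_ALWAYS_INLINE` `ExecuteInternal`. Callers pay exactly one virtual call and then run a body that is inlined and specialized per instantiation. A simplified sketch of that structure, with hypothetical names:

```cpp
#include "absl/base/attributes.h"

struct ThunkBase {  // stand-in for xla::cpu::Thunk
  virtual ~ThunkBase() = default;
  virtual int Execute() = 0;
};

template <int num_args>
struct KernelBase : ThunkBase {
 protected:
  // Force-inlined into every final override that calls it; `num_args` is a
  // compile-time constant here, so loops over arguments can be fully unrolled.
  ABSL_ATTRIBUTE_ALWAYS_INLINE int ExecuteInternal() { return num_args; }
};

template <int num_args>
struct SmallKernel final : KernelBase<num_args> {
  int Execute() final { return this->ExecuteInternal(); }
};
```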
+
+absl::StatusOr<std::unique_ptr<Thunk>> KernelThunk::Create(
+    Thunk::Info info,
+    absl::Span<const BufferAllocation::Slice> arguments_buffers,
+    absl::Span<const BufferAllocation::Slice> results_buffers,
+    std::string kernel_name, se::ThreadDim thread_dim,
+    std::optional<uint64_t> min_alignment) {
+  if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) {
+    return Internal("Host kernel %s minimum alignment %d is not a power of 2",
+                    info.op_name, *min_alignment);
   }
-  return buffer_uses;
+
+  auto make_small_kernel_thunk = [&](auto num_arguments, auto num_results) {
+    return absl::WrapUnique(
+        new SmallKernelThunk<num_arguments(), num_results()>(
+            std::move(info), arguments_buffers, results_buffers,
+            std::move(kernel_name), thread_dim, min_alignment));
+  };
+
+  static constexpr auto _0 = std::integral_constant<int64_t, 0>{};
+  static constexpr auto _1 = std::integral_constant<int64_t, 1>{};
+  static constexpr auto _2 = std::integral_constant<int64_t, 2>{};
+  static constexpr auto _3 = std::integral_constant<int64_t, 3>{};
+  static constexpr auto _4 = std::integral_constant<int64_t, 4>{};
+  static constexpr auto _5 = std::integral_constant<int64_t, 5>{};
+  static constexpr auto _6 = std::integral_constant<int64_t, 6>{};
+
+  std::pair<int64_t, int64_t> params(arguments_buffers.size(),
+                                     results_buffers.size());
+
+  // Return SmallKernelThunk specializations for the most common cases.
+  if (params == std::make_pair(_0, _1)) return make_small_kernel_thunk(_0, _1);
+  if (params == std::make_pair(_1, _1)) return make_small_kernel_thunk(_1, _1);
+  if (params == std::make_pair(_2, _1)) return make_small_kernel_thunk(_2, _1);
+  if (params == std::make_pair(_3, _1)) return make_small_kernel_thunk(_3, _1);
+  if (params == std::make_pair(_4, _1)) return make_small_kernel_thunk(_4, _1);
+  if (params == std::make_pair(_5, _1)) return make_small_kernel_thunk(_5, _1);
+  if (params == std::make_pair(_6, _1)) return make_small_kernel_thunk(_6, _1);
+
+  // Return a generic KernelThunk for dynamic numbers of arguments and results.
+  return absl::WrapUnique(
+      new KernelThunk(std::move(info), arguments_buffers, results_buffers,
+                      std::move(kernel_name), thread_dim, min_alignment));
 }

 }  // namespace xla::cpu
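`Create` bridges run time and compile time: the `_N` tokens are `std::integral_constant` values, so passing one into the generic lambda lets it recover the count as a constant expression for the template argument, even though the lambda parameter itself is a runtime object. A one-dimensional sketch of the same trick (hypothetical names):

```cpp
#include <cstdint>
#include <memory>
#include <type_traits>

template <std::int64_t n>
struct FixedKernel {};    // statically sized specialization
struct DynamicKernel {};  // heap-backed fallback

std::shared_ptr<void> MakeKernel(std::int64_t n) {
  auto make = [](auto size) {
    // `size` is a std::integral_constant; ::value is usable as a template
    // argument because it lives in the type, not in the runtime object.
    return std::make_shared<FixedKernel<decltype(size)::value>>();
  };
  if (n == 0) return make(std::integral_constant<std::int64_t, 0>{});
  if (n == 1) return make(std::integral_constant<std::int64_t, 1>{});
  return std::make_shared<DynamicKernel>();  // dynamic fallback
}
```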
diff --git a/xla/service/cpu/runtime/kernel_thunk.h b/xla/service/cpu/runtime/kernel_thunk.h
index 871176ba73ec5b..134602f99537b5 100644
--- a/xla/service/cpu/runtime/kernel_thunk.h
+++ b/xla/service/cpu/runtime/kernel_thunk.h
@@ -16,17 +16,19 @@ limitations under the License.
 #ifndef XLA_SERVICE_CPU_RUNTIME_KERNEL_THUNK_H_
 #define XLA_SERVICE_CPU_RUNTIME_KERNEL_THUNK_H_

+#include
+#include
 #include
 #include
 #include
 #include
 #include
 #include
+#include
 #include

 #include "absl/base/thread_annotations.h"
 #include "absl/container/inlined_vector.h"
-#include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/types/span.h"
@@ -39,36 +41,64 @@ limitations under the License.

 namespace xla::cpu {

-// Launches compiled host kernel on the caller thread.
-class KernelThunk final : public Thunk {
- public:
-  static absl::StatusOr<std::unique_ptr<KernelThunk>> Create(
-      Info info, absl::Span<const BufferAllocation::Slice> arguments_buffers,
-      absl::Span<const BufferAllocation::Slice> results_buffers,
-      std::string kernel_name, se::ThreadDim thread_dim,
-      std::optional<uint64_t> min_alignment = std::nullopt);
+// Forward declare thunk defined below.
+class KernelThunk;

-  tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;
+namespace internal {

+// If the number of kernel parameters (arguments and results) is unknown at
+// compile time, we use this value to indicate that the parameter is dynamic.
+inline constexpr int64_t kDynamicKernelParameter = -1;
+
+// A base template for a KernelThunk that can be specialized for a statically
+// known number of arguments and results. We go the extra mile here to optimize
+// host kernel dispatching on the hot execution path and minimize the XLA
+// runtime overheads for the smallest HLO modules.
+template <int64_t num_arguments = kDynamicKernelParameter,
+          int64_t num_results = kDynamicKernelParameter>
+class KernelThunk : public Thunk {
+ public:
   BufferUses buffer_uses() const final;

+ protected:
+  tsl::AsyncValueRef<ExecuteEvent> ExecuteInternal(const ExecuteParams& params);
+
  private:
+  friend class ::xla::cpu::KernelThunk;
+
+  static constexpr bool IsDynamic(size_t n) {
+    return n == kDynamicKernelParameter;
+  }
+
+  static constexpr size_t Size(int64_t size) {
+    return std::max<int64_t>(size, 0);
+  }
+
+  // If we know the number of arguments and results at compile time, we use
+  // std::array with a fixed size, which allows the compiler to automatically
+  // unroll all the loops on the hot path.
+
+  using ArgumentsBuffers = std::conditional_t<
+      IsDynamic(num_arguments), std::vector<BufferAllocation::Slice>,
+      std::array<BufferAllocation::Slice, Size(num_arguments)>>;
+
+  using ResultsBuffers = std::conditional_t<
+      IsDynamic(num_results), std::vector<BufferAllocation::Slice>,
+      std::array<BufferAllocation::Slice, Size(num_results)>>;
+
+  using KernelArgs = std::conditional_t<
+      IsDynamic(num_arguments) || IsDynamic(num_results),
+      absl::InlinedVector<SE_HOST_KernelArg, 8>,
+      std::array<SE_HOST_KernelArg, Size(num_arguments) + Size(num_results)>>;
+
   KernelThunk(Info info,
               absl::Span<const BufferAllocation::Slice> arguments_buffers,
               absl::Span<const BufferAllocation::Slice> results_buffers,
               std::string kernel_name, se::ThreadDim thread_dim,
               std::optional<uint64_t> min_alignment);

-  // Checks that all buffers are aligned to the minimum alignment. We codegen
-  // with the assumption that all buffers are aligned, and if they are not, we
-  // will crash with a segmentation fault, or worse, produce incorrect results.
-  absl::Status CheckBufferAlignment(
-      absl::Span<const SE_HOST_KernelArg> kernel_args);
-
-  void VlogKernelArgs(absl::Span<const SE_HOST_KernelArg> kernel_args);
-
-  std::vector<BufferAllocation::Slice> arguments_buffers_;
-  std::vector<BufferAllocation::Slice> results_buffers_;
+  ArgumentsBuffers arguments_buffers_;
+  ResultsBuffers results_buffers_;

   size_t num_kernel_args_;

@@ -88,7 +118,41 @@ class KernelThunk final : public Thunk {

   // Pre-initialized kernel arguments that are updated with memory addresses
   // before the kernel launch.
-  absl::InlinedVector<SE_HOST_KernelArg, 8> kernel_args_;
+  KernelArgs kernel_args_;
+};
+
+}  // namespace internal
+
+// Kernel thunk specialization for a small kernel with a statically known number
+// of arguments and results.
+template <int64_t num_arguments, int64_t num_results>
+class SmallKernelThunk final
+    : public internal::KernelThunk<num_arguments, num_results> {
+  using Base = internal::KernelThunk<num_arguments, num_results>;
+
+ public:
+  using Base::Base;
+
+  tsl::AsyncValueRef<Thunk::ExecuteEvent> Execute(
+      const Thunk::ExecuteParams& params) final;
+};
+
+// Kernel thunk specialization for a dynamic number of arguments and results.
+class KernelThunk final : public internal::KernelThunk<> {
+  using Base = internal::KernelThunk<>;
+
+ public:
+  using Base::Base;
+
+  static absl::StatusOr<std::unique_ptr<Thunk>> Create(
+      Thunk::Info info,
+      absl::Span<const BufferAllocation::Slice> arguments_buffers,
+      absl::Span<const BufferAllocation::Slice> results_buffers,
+      std::string kernel_name, se::ThreadDim thread_dim,
+      std::optional<uint64_t> min_alignment = std::nullopt);
+
+  tsl::AsyncValueRef<Thunk::ExecuteEvent> Execute(
+      const Thunk::ExecuteParams& params) final;
 };

 }  // namespace xla::cpu
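The storage aliases above are the heart of the optimization: when both counts are static, the buffers live in `std::array` members (no heap allocation, loop bounds known to the compiler), and `Size()` clamps the `-1` sentinel to `0` so the unused `std::array` arm of `std::conditional_t` stays well-formed. A condensed sketch:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>

inline constexpr std::int64_t kDynamic = -1;  // mirrors kDynamicKernelParameter
constexpr std::size_t Size(std::int64_t n) { return n < 0 ? 0 : n; }

// Both arms of std::conditional_t are instantiated as types, so the std::array
// arm must be valid even when n == kDynamic; hence the clamp above.
template <std::int64_t n>
using Buffers = std::conditional_t<n == kDynamic, std::vector<int>,
                                   std::array<int, Size(n)>>;

static_assert(std::is_same_v<Buffers<kDynamic>, std::vector<int>>);
static_assert(std::is_same_v<Buffers<4>, std::array<int, 4>>);
```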
diff --git a/xla/service/cpu/runtime/thunk.cc b/xla/service/cpu/runtime/thunk.cc
index 5438e60b33d844..9588b02a61a4df 100644
--- a/xla/service/cpu/runtime/thunk.cc
+++ b/xla/service/cpu/runtime/thunk.cc
@@ -85,6 +85,10 @@ std::string_view Thunk::KindToString(Kind kind) {
       return "while";
   }
 }
+Thunk::Thunk(Kind kind, Info info)
+    : kind_(kind),
+      info_(std::move(info)),
+      ok_event_(OkExecuteEventSingleton()) {}

 absl::StatusOr<Thunk::CollectiveExecuteParams>
 Thunk::CollectiveExecuteParams::Create(
@@ -150,13 +154,13 @@ Thunk::CustomCallExecuteParams::CustomCallExecuteParams(
       allocator(allocator),
       ffi_execution_context(ffi_execution_context) {}

-const tsl::AsyncValueOwningRef<Thunk::ExecuteEvent>* Thunk::OkEvent() {
-  static tsl::AsyncValueOwningRef<Thunk::ExecuteEvent>* owner = [] {
+tsl::AsyncValueRef<Thunk::ExecuteEvent> Thunk::OkExecuteEventSingleton() {
+  static tsl::AsyncValueOwningRef<Thunk::ExecuteEvent>* singleton = [] {
     auto* storage = new tsl::internal::AsyncValueStorage<Thunk::ExecuteEvent>();
     return new tsl::AsyncValueOwningRef<Thunk::ExecuteEvent>(
         tsl::MakeAvailableAsyncValueRef<Thunk::ExecuteEvent>(*storage));
   }();
-  return owner;
+  return singleton->AsRef();
 }

 Thunk::ExecuteState::ExecuteState(int64_t num_tasks)
diff --git a/xla/service/cpu/runtime/thunk.h b/xla/service/cpu/runtime/thunk.h
index 5bf8cfb8baf01d..0e645f247776c5 100644
--- a/xla/service/cpu/runtime/thunk.h
+++ b/xla/service/cpu/runtime/thunk.h
@@ -110,7 +110,7 @@ class Thunk {
   using Task = std::function<void()>;
   using TaskRunner = absl::AnyInvocable<void(Task)>;

-  Thunk(Kind kind, Info info) : kind_(kind), info_(std::move(info)) {}
+  Thunk(Kind kind, Info info);

   Thunk(const Thunk&) = delete;
   Thunk& operator=(const Thunk&) = delete;
@@ -286,18 +286,20 @@ class Thunk {
   // An execute event that becomes ready when all tasks are completed.
   using ExecuteEvent = tsl::Chain;

-  // Returns non-reference-counted async value ref for thunks executed in the
-  // caller thread to avoid reference counting overhead.
-  static tsl::AsyncValueRef<ExecuteEvent> OkExecuteEvent() {
-    return OkEvent()->AsRef();
-  }
+  // Returns a non-reference-counted async value ref in the constructed state.
+  // The returned async value is a per-process singleton stored in storage with
+  // static duration, and can be safely compared using pointer equality.
+  static tsl::AsyncValueRef<ExecuteEvent> OkExecuteEventSingleton();
+
+  // Returns `OkExecuteEventSingleton()` cached by this thunk instance.
+  tsl::AsyncValueRef<ExecuteEvent> OkExecuteEvent() const { return ok_event_; }

-  static bool IsOkExecuteEvent(tsl::AsyncValuePtr<ExecuteEvent> event) {
-    return event == OkEvent()->AsPtr();
+  bool IsOkExecuteEvent(const tsl::AsyncValueRef<ExecuteEvent>& event) const {
+    return event == ok_event_;
   }

-  static bool IsOkExecuteEvent(const tsl::AsyncValueRef<ExecuteEvent>& event) {
-    return IsOkExecuteEvent(event.AsPtr());
+  bool IsOkExecuteEvent(tsl::AsyncValuePtr<ExecuteEvent> event) const {
+    return event == ok_event_.AsPtr();
   }

   // Thunk execution must be asynchronous and never block the caller thread,
@@ -339,10 +341,10 @@ class Thunk {
   }

  private:
-  static const tsl::AsyncValueOwningRef<ExecuteEvent>* OkEvent();
-
   Kind kind_;
   Info info_;
+
+  tsl::AsyncValueRef<ExecuteEvent> ok_event_;
 };

 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
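The singleton change keeps the old fast path while making it cheaper to reach: one `ExecuteEvent` is constructed once in leaked static storage, and every `Thunk` now caches a non-owning ref at construction, so the common "finished synchronously, OK" result involves no allocation, no atomic reference counting, and identity can be tested by pointer comparison. A plain-C++ sketch of the lifetime pattern (not the tsl API):

```cpp
struct Event {};  // stand-in for an always-available ExecuteEvent

const Event* OkEventSingleton() {
  // Constructed on first use and intentionally leaked: it must outlive every
  // thunk, and its address doubles as its identity for fast-path checks.
  static const Event* singleton = new Event();
  return singleton;
}

struct Thunk {
  const Event* ok_event = OkEventSingleton();  // cached per instance
  bool IsOkEvent(const Event* e) const { return e == ok_event; }
};
```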
diff --git a/xla/service/cpu/runtime/thunk_executor.cc b/xla/service/cpu/runtime/thunk_executor.cc
index 26c084e7e8c5e4..f25fd6119a284d 100644
--- a/xla/service/cpu/runtime/thunk_executor.cc
+++ b/xla/service/cpu/runtime/thunk_executor.cc
@@ -45,6 +45,7 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence,
                              const ThunkExecutor::Options& options)
     : thunk_sequence_(std::move(thunk_sequence)),
       options_(options),
+      num_thunks_(thunk_sequence_.size()),
       nodes_defs_(std::move(nodes_defs)),
       is_sequential_(true) {
   for (NodeId i = 0; i < nodes_defs_.size(); ++i) {
@@ -143,10 +144,10 @@ ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor,
 tsl::AsyncValueRef<ThunkExecutor::ExecuteEvent> ThunkExecutor::Execute(
     const Thunk::ExecuteParams& params) {
   // Short-circuit execution of trivial thunk sequences.
-  if (ABSL_PREDICT_FALSE(thunk_sequence_.empty())) {
-    return Thunk::OkExecuteEvent();
+  if (ABSL_PREDICT_FALSE(num_thunks_ == 0)) {
+    return Thunk::OkExecuteEventSingleton();
   }
-  if (ABSL_PREDICT_FALSE(thunk_sequence_.size() == 1)) {
+  if (ABSL_PREDICT_FALSE(num_thunks_ == 1)) {
     return thunk_sequence_[0]->Execute(params);
   }

@@ -161,6 +162,12 @@ tsl::AsyncValueRef<ThunkExecutor::ExecuteEvent> ThunkExecutor::Execute(
   Execute(state.get(), params, ReadyQueue(source_.begin(), source_.end()),
           /*lock=*/params.session.Join());

+  // If execution already completed (all kernels executed in the caller thread),
+  // immediately return the result to avoid wasteful reference counting below.
+  if (ABSL_PREDICT_TRUE(state->execute_event.IsAvailable())) {
+    return std::move(state->execute_event);
+  }
+
   // Move execute state to the execute event callback to ensure that it is kept
   // alive while thunk executor has pending tasks.
   auto execute_event = state->execute_event;
@@ -176,12 +183,12 @@ tsl::AsyncValueRef<ThunkExecutor::ExecuteEvent>
 ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) {
-  for (int64_t i = 0; i < thunk_sequence_.size(); ++i) {
-    Thunk& thunk = *thunk_sequence_[i];
+  for (auto it = thunk_sequence_.begin(); it != thunk_sequence_.end(); ++it) {
+    Thunk& thunk = **it;
     auto execute_event = thunk.Execute(params);

     // Fast path for thunks executed inline and returned OkExecuteEvent.
-    if (ABSL_PREDICT_TRUE(Thunk::IsOkExecuteEvent(execute_event))) {
+    if (ABSL_PREDICT_TRUE(thunk.IsOkExecuteEvent(execute_event))) {
       continue;
     }

@@ -189,11 +196,11 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) {
     // resume sequential execution starting from the next thunk.
     if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) {
       auto event = tsl::MakeConstructedAsyncValueRef<ExecuteEvent>();
-      execute_event.AndThen([this, &params, i, event](absl::Status status) {
+      execute_event.AndThen([this, &params, it, event](absl::Status status) {
         if (ABSL_PREDICT_FALSE(!status.ok())) {
           event.SetError(std::move(status));
         } else {
-          ResumeExecuteSequential(i + 1, params, std::move(event));
+          ResumeExecuteSequential(it + 1, params, std::move(event));
         }
       });
       return event;
@@ -207,18 +214,18 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) {

   // If we got to the end of the sequence it means that all thunks have
   // succeeded.
-  return Thunk::OkExecuteEvent();
+  return Thunk::OkExecuteEventSingleton();
 }

 void ThunkExecutor::ResumeExecuteSequential(
-    int64_t index, const Thunk::ExecuteParams& params,
+    ThunkIterator it, const Thunk::ExecuteParams& params,
     tsl::AsyncValueRef<ExecuteEvent> event) {
-  for (int64_t i = index; i < thunk_sequence_.size(); ++i) {
-    Thunk& thunk = *thunk_sequence_[i];
+  for (; it != thunk_sequence_.end(); ++it) {
+    Thunk& thunk = **it;
     auto execute_event = thunk.Execute(params);

     // Fast path for thunks executed inline and returned OkExecuteEvent.
-    if (ABSL_PREDICT_TRUE(Thunk::IsOkExecuteEvent(execute_event))) {
+    if (ABSL_PREDICT_TRUE(thunk.IsOkExecuteEvent(execute_event))) {
       continue;
     }

@@ -226,11 +233,11 @@ void ThunkExecutor::ResumeExecuteSequential(
     // resume sequential execution starting from the next thunk.
     if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) {
       execute_event.AndThen(
-          [this, &params, i, event = std::move(event)](absl::Status status) {
+          [this, &params, it, event = std::move(event)](absl::Status status) {
             if (ABSL_PREDICT_FALSE(!status.ok())) {
               event.SetError(std::move(status));
             } else {
-              ResumeExecuteSequential(i + 1, params, std::move(event));
+              ResumeExecuteSequential(it + 1, params, std::move(event));
             }
           });
       return;
@@ -281,7 +288,7 @@ void ThunkExecutor::Execute(ExecuteState* state,
     Thunk& thunk = *state->executor->thunk_sequence_[id];
     tsl::AsyncValueRef<ExecuteEvent> execute_event =
         ABSL_PREDICT_FALSE(state->abort.load(std::memory_order_relaxed))
-            ? Thunk::OkExecuteEvent()
+            ? Thunk::OkExecuteEventSingleton()
             : thunk.Execute(params);

     if (ABSL_PREDICT_TRUE(execute_event.IsAvailable())) {
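Switching `ResumeExecuteSequential` from an index to an iterator removes the per-step `thunk_sequence_[i]` indexing from the hot loop, and the continuation simply captures `it + 1`. Capturing an iterator across an asynchronous boundary is safe here only because `thunk_sequence_` is immutable once the executor is constructed. A hedged sketch of the suspend/resume shape (hypothetical callback API in place of `tsl::AsyncValueRef`):

```cpp
#include <functional>
#include <vector>

using Step = std::function<bool()>;  // true: completed inline; false: suspended
using Resume = std::function<void(std::vector<Step>::iterator)>;

// Walk steps until one suspends; hand the *next* iterator to the continuation.
// Valid only because `steps` is never mutated while execution is in flight.
void RunFrom(std::vector<Step>& steps, std::vector<Step>::iterator it,
             const Resume& resume_later) {
  for (; it != steps.end(); ++it) {
    if (!(*it)()) {
      resume_later(it + 1);
      return;
    }
  }
}
```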
@@ -471,11 +478,11 @@ int64_t ThunkExecutor::TransitiveReduction() {
 std::string ThunkExecutor::ToString() const {
   std::string str = absl::StrFormat(
-      "ThunkExecutor: #thunks=%d #source_nodes=%d #sink_nodes=%d",
-      thunk_sequence_.size(), source_.size(), sink_.size());
+      "ThunkExecutor: #thunks=%d #source_nodes=%d #sink_nodes=%d", num_thunks_,
+      source_.size(), sink_.size());

   // Collect names of `in_edges`.
-  std::vector<std::vector<std::string>> in_edges(thunk_sequence_.size());
+  std::vector<std::vector<std::string>> in_edges(num_thunks_);
   for (const auto& node_def : nodes_defs_) {
     for (NodeId in_edge : node_def.in_edges) {
       in_edges[node_def.id].push_back(thunk_sequence_[in_edge]->info().op_name);
@@ -483,7 +490,7 @@ std::string ThunkExecutor::ToString() const {
   }

   // Print thunks with a list of their dependencies;
-  for (NodeId i = 0; i < thunk_sequence_.size(); ++i) {
+  for (NodeId i = 0; i < num_thunks_; ++i) {
     const Thunk& thunk = *thunk_sequence_[i];
     bool is_source = absl::c_find(source_, i) != source_.end();
     bool is_sink = absl::c_find(sink_, i) != sink_.end();
diff --git a/xla/service/cpu/runtime/thunk_executor.h b/xla/service/cpu/runtime/thunk_executor.h
index 8965a7a51652a4..67a66c422bf5c6 100644
--- a/xla/service/cpu/runtime/thunk_executor.h
+++ b/xla/service/cpu/runtime/thunk_executor.h
@@ -144,7 +144,8 @@ class ThunkExecutor {
                          const Thunk::ExecuteParams& params);

   // Resumes sequential thunk execution starting from the given index.
-  void ResumeExecuteSequential(int64_t index,
+  using ThunkIterator = typename ThunkSequence::iterator;
+  void ResumeExecuteSequential(ThunkIterator it,
                                const Thunk::ExecuteParams& params,
                                tsl::AsyncValueRef<ExecuteEvent> event);

@@ -173,6 +174,8 @@ class ThunkExecutor {
   ThunkSequence thunk_sequence_;
   Options options_;

+  int64_t num_thunks_;
+
   std::vector<NodeDef> nodes_defs_;

   std::vector<NodeId> source_;
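The other executor fast path added above returns `state->execute_event` directly when every thunk ran inline, instead of unconditionally attaching a callback whose only job is to keep `state` alive; the callback and its refcount traffic are now reserved for the genuinely asynchronous case. In sketch form (assuming the surrounding body of `ThunkExecutor::Execute`; the callback line is an assumption about the elided code, not shown in the hunk):

```cpp
// Fast path: every thunk completed on the caller thread.
if (ABSL_PREDICT_TRUE(state->execute_event.IsAvailable())) {
  return std::move(state->execute_event);
}
// Slow path only: keep `state` alive until the event becomes available.
auto execute_event = state->execute_event;
execute_event.AndThen([state = std::move(state)] {});
return execute_event;
```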
#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" @@ -2500,6 +2501,42 @@ absl::Status HloFunctionImporter::ConvertShapeToMlirLayout( return Internal("Couldn't convert layout."); } +mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, + mlir::Builder* builder) { + llvm::SmallVector element_attrs; + alias.ForEachAlias([&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { + std::string kindToString; + switch (alias.kind) { + case HloInputOutputAliasConfig::AliasKind::kMayAlias: + kindToString = "may_alias"; + break; + case HloInputOutputAliasConfig::AliasKind::kMustAlias: + kindToString = "must_alias"; + break; + default: + kindToString = "undefined_alias"; + } + mlir::NamedAttribute alias_named_attributes[3] = { + builder->getNamedAttr( + "parameter_index", + builder->getDenseI64ArrayAttr(ArrayRef( + alias.parameter_index.begin(), alias.parameter_index.end()))), + builder->getNamedAttr("parameter_number", builder->getI64IntegerAttr( + alias.parameter_number)), + builder->getNamedAttr("kind", builder->getStringAttr(kindToString))}; + + mlir::NamedAttribute named_attributes[2] = { + builder->getNamedAttr("output_index", + builder->getDenseI64ArrayAttr(ArrayRef( + output_index.begin(), output_index.end()))), + builder->getNamedAttr( + "alias", builder->getDictionaryAttr(alias_named_attributes))}; + element_attrs.push_back(builder->getDictionaryAttr(named_attributes)); + }); + return builder->getArrayAttr(element_attrs); +} + mlir::Attribute ConvertSharding(const HloSharding& sharding, mlir::Builder* builder) { return builder->getStringAttr(sharding.ToString(/*include_metadata=*/true)); diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/xla/translate/hlo_to_mhlo/hlo_function_importer.h index cb3953990f4030..5c5a4e309bfbf6 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.h +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.h @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/ValueRange.h" #include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" @@ -297,6 +298,12 @@ class HloFunctionImporter { bool flatten_computation_args_result_; }; +// Returns a StringAttr that carries a prettyprinted representation of the +// given HLO C++ input_output_alias_config. +// Always succeeds and returns a non-empty attribute. +mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, + mlir::Builder* builder); + // Returns a StringAttr that carries a prettyprinted representation of the // given HLO C++ sharding. // Always succeeds and returns a non-empty attribute. 
diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
index 1f2ea997c81e8a..76037442d52099 100644
--- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
@@ -122,6 +122,10 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) {
                 ConvertSharding(hlo_module.spmd_output_sharding(), &builder_));
   }

+  module->setAttr("mhlo.input_output_alias",
+                  ConvertInputOutputAlias(
+                      hlo_module.input_output_alias_config(), &builder_));
+
   if (hlo_module.has_spmd_parameters_shardings()) {
     llvm::SmallVector<mlir::Attribute> parameter_shardings;
     parameter_shardings.reserve(hlo_module.spmd_parameters_shardings().size());
diff --git a/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo b/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo
index 74eaaea5a0e8fe..d3433dce372cbf 100644
--- a/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo
+++ b/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo
@@ -5,6 +5,18 @@
 # FLATTEN-CHECK-LABEL: module @main attributes {
 hlo_module {
   name: "main"
+  input_output_alias {
+    entries {
+      output_shape_index: 0
+      parameter_number: 0
+      kind: MAY_ALIAS
+    }
+    entries {
+      output_shape_index: 1
+      parameter_number: 1
+      kind: MAY_ALIAS
+    }
+  }
   entry_computation_name: "main.5"
   computations {
     name: "main.5"
@@ -217,6 +229,7 @@ hlo_module {
     value: "attr_value"
   }
 }
+# CHECK-SAME: mhlo.input_output_alias = [{alias = {kind = "may_alias", parameter_index = array<i64>, parameter_number = 0 : i64}, output_index = array<i64: 0>}, {alias = {kind = "may_alias", parameter_index = array<i64>, parameter_number = 1 : i64}, output_index = array<i64: 1>}]
 # CHECK-SAME: mhlo.is_dynamic = true
 is_dynamic: true
 # CHECK-SAME: mhlo.use_auto_spmd_partitioning = true
diff --git a/xla/translate/mhlo_to_hlo/BUILD b/xla/translate/mhlo_to_hlo/BUILD
index 92b7265298f6e7..3de8007804af4b 100644
--- a/xla/translate/mhlo_to_hlo/BUILD
+++ b/xla/translate/mhlo_to_hlo/BUILD
@@ -23,6 +23,7 @@ cc_library(
         "//xla:types",
         "//xla:util",
         "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
         "//xla/mlir_hlo",
         "//xla/service:hlo_parser",
         "//xla/service:hlo_proto_cc",
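One asymmetry worth noting between the two directions: the importer sets `mhlo.input_output_alias` unconditionally (an empty `ArrayAttr` when the module has no aliasing), while the exporter below returns `std::nullopt` for an empty array and leaves the proto untouched. On the HLO side, the data the new FileCheck lines assert on is reachable in C++ as a hedged sketch (`hlo_module` is an `xla::HloModuleProto`):

```cpp
// Iterate the alias entries checked by the .hlo test above.
for (const auto& entry : hlo_module.input_output_alias().entries()) {
  // entry.parameter_number(), entry.output_shape_index(),
  // entry.parameter_shape_index(), and entry.kind() mirror the proto text.
}
```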
#include #include "mlir/Support/LLVM.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo_parser.h" #include "xla/shape_util.h" @@ -185,4 +187,99 @@ std::optional ConvertSharding(llvm::StringRef sharding) { return std::nullopt; } +std::optional ConvertInputOutputAlias( + llvm::ArrayRef aliasing) { + if (aliasing.empty()) return std::nullopt; + + xla::HloInputOutputAliasProto input_output_alias_proto; + for (auto attr : aliasing) { + auto entry_attr = mlir::cast(attr); + auto alias_attr = mlir::cast(entry_attr.get("alias")); + mlir::ArrayRef output_index = + mlir::cast(entry_attr.get("output_index")) + .asArrayRef(); + mlir::ArrayRef parameter_index = + mlir::cast(alias_attr.get("parameter_index")) + .asArrayRef(); + HloInputOutputAliasProto::AliasEntryProto entry; + entry.mutable_output_shape_index()->Add(output_index.begin(), + output_index.end()); + entry.set_parameter_number( + mlir::cast(alias_attr.get("parameter_number")) + .getInt()); + entry.mutable_parameter_shape_index()->Add(parameter_index.begin(), + parameter_index.end()); + mlir::StringRef kind = + mlir::cast(alias_attr.get("kind")).getValue(); + if (kind == "may_alias") + entry.set_kind(xla::Kind::MAY_ALIAS); + else if (kind == "must_alias") + entry.set_kind(xla::Kind::MUST_ALIAS); + else + entry.set_kind(xla::Kind::UNDEFINED_ALIAS); + input_output_alias_proto.add_entries()->Swap(&entry); + } + return input_output_alias_proto; +} + +DotDimensionNumbers ConvertDotDimensionNumbers( + mlir::mhlo::DotDimensionNumbersAttr input) { + DotDimensionNumbers output; + + for (auto v : input.getLhsBatchingDimensions()) { + output.add_lhs_batch_dimensions(v); + } + + for (auto v : input.getRhsBatchingDimensions()) { + output.add_rhs_batch_dimensions(v); + } + + for (auto v : input.getLhsContractingDimensions()) { + output.add_lhs_contracting_dimensions(v); + } + + for (auto v : input.getRhsContractingDimensions()) { + output.add_rhs_contracting_dimensions(v); + } + + return output; +} + +DotDimensionNumbers ConvertDotDimensionNumbers( + absl::Span lhs_batch, absl::Span lhs_contract, + absl::Span rhs_batch, + absl::Span rhs_contract) { + DotDimensionNumbers output; + for (auto v : lhs_batch) { + output.add_lhs_batch_dimensions(v); + } + + for (auto v : rhs_batch) { + output.add_rhs_batch_dimensions(v); + } + + for (auto v : lhs_contract) { + output.add_lhs_contracting_dimensions(v); + } + + for (auto v : rhs_contract) { + output.add_rhs_contracting_dimensions(v); + } + + return output; +} + +absl::StatusOr> ConvertMlirArrayAttrToInt64Array( + const mlir::ArrayAttr& array) { + int rank = array.size(); + std::vector converted_array(rank); + for (int i = 0; i < rank; i++) { + mlir::IntegerAttr attr = mlir::dyn_cast(array[i]); + if (!attr) { + return Internal("Type Error: Expected layout integer attribute"); + } + converted_array[i] = attr.getInt(); + } + return converted_array; +} } // namespace xla diff --git a/xla/translate/mhlo_to_hlo/attribute_exporter.h b/xla/translate/mhlo_to_hlo/attribute_exporter.h index e0e0dc9821d21e..49daefe6935650 100644 --- a/xla/translate/mhlo_to_hlo/attribute_exporter.h +++ b/xla/translate/mhlo_to_hlo/attribute_exporter.h @@ -20,6 +20,7 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "mlir/IR/Attributes.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" @@ -59,5 +60,8 @@ ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr); // Will fail if both attempts at parsing failed. std::optional ConvertSharding(mlir::StringRef sharding); +std::optional ConvertInputOutputAlias( + llvm::ArrayRef aliasing); + } // namespace xla #endif // XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ diff --git a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 623080e11fd60d..90eb1a902127bc 100644 --- a/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -3736,6 +3736,13 @@ absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, *hlo_module.mutable_spmd_output_sharding() = *xla::ConvertSharding(spmd_output_sharding.getValue()); } + if (auto input_output_alias = + module->getAttrOfType("mhlo.input_output_alias")) { + if (std::optional input_output_alias_proto = + xla::ConvertInputOutputAlias(input_output_alias.getValue())) { + *hlo_module.mutable_input_output_alias() = *input_output_alias_proto; + } + } if (auto spmd_parameters_sharding = module->getAttrOfType( "mhlo.spmd_parameters_shardings")) { for (const auto& sharding : spmd_parameters_sharding.getValue()) { diff --git a/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir b/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir index 049456bb09e6f7..6ad08374e5d2e6 100644 --- a/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir +++ b/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir @@ -100,3 +100,45 @@ module @ModuleWithFrontendAttributes attributes { func.return %arg0 : tensor<1xf32> } } + + + +// ----- + +module attributes { +// CHECK: input_output_alias { +// CHECK-NEXT: entries { +// CHECK-NEXT: output_shape_index: 0 +// CHECK-NEXT: kind: MAY_ALIAS +// CHECK-NEXT: } +// CHECK-NEXT: entries { +// CHECK-NEXT: output_shape_index: 1 +// CHECK-NEXT: parameter_number: 1 +// CHECK-NEXT: kind: MAY_ALIAS +// CHECK-NEXT: } +// CHECK-NEXT: } + mhlo.input_output_alias = [ + { + alias = + { + kind = "may_alias", + parameter_index = array, + parameter_number = 0 : i64 + }, + output_index = array + }, + { + alias = + { + kind = "may_alias", + parameter_index = array, + parameter_number = 1 : i64 + }, + output_index = array + } +] +} { + func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32> ) -> (tensor<1xf32>, tensor<1xf32>) { + func.return %arg0, %arg1: tensor<1xf32>, tensor<1xf32> + } +} \ No newline at end of file