From f221555504d6277212e966f545913c1147ddd855 Mon Sep 17 00:00:00 2001 From: Toli Yevtushenko Date: Tue, 7 Jan 2025 11:28:05 -0800 Subject: [PATCH 01/45] Make FindInstruction methods public. PiperOrigin-RevId: 712983498 --- .../hlo_hardware_independent_test_base.cc | 7 ++-- .../hlo_hardware_independent_test_base.h | 33 ++++++++++--------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/xla/hlo/testlib/hlo_hardware_independent_test_base.cc b/xla/hlo/testlib/hlo_hardware_independent_test_base.cc index bbe1ecea736a3e..d5af349ef6dece 100644 --- a/xla/hlo/testlib/hlo_hardware_independent_test_base.cc +++ b/xla/hlo/testlib/hlo_hardware_independent_test_base.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_replace.h" @@ -119,7 +120,7 @@ HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( allow_mixed_precision_in_hlo_verifier_, ShapeUtil::ByteSizeOfElements, instruction_can_change_layout_func_); TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); - return std::move(module); + return module; } /* static */ @@ -258,9 +259,11 @@ HloHardwareIndependentTestBase::RunAndCheckHloRewrite( VLOG(7) << "Input HLO: " << hlo_string; TF_ASSIGN_OR_RETURN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string)); + VLOG(7) << "Input HLO parsed. Running the pass: + " << hlo_pass.name(); TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(hlo_pass, module.get())); VLOG(7) << "Output HLO: " - << module->ToString(HloPrintOptions::ShortParsable()); + << module->ToString(HloPrintOptions::ShortParsable() + .set_print_control_dependencies(true)); EXPECT_EQ(changed, expect_change); return module; } diff --git a/xla/hlo/testlib/hlo_hardware_independent_test_base.h b/xla/hlo/testlib/hlo_hardware_independent_test_base.h index 2a7f1f488b54e8..e41bcea3e4d828 100644 --- a/xla/hlo/testlib/hlo_hardware_independent_test_base.h +++ b/xla/hlo/testlib/hlo_hardware_independent_test_base.h @@ -55,6 +55,23 @@ class HloHardwareIndependentTestBase : public ::testing::Test { public: static PrecisionConfig DefaultPrecisionConfig(int operands); + // Gets the computation/instruction from the given module with the given name. + // Note that it is encouraged to use these functions directly via the + // hlo_query.h header instead since they are independent from any test-time + // variables or contexts. + + // This is useful for tests which create HLOs from a string and then want to + // inspect a particular computation or instruction. + static HloComputation* FindComputation(HloModule* module, + absl::string_view name); + static HloInstruction* FindInstruction(HloModule* module, + absl::string_view name); + // Gets the instruction from the given module with the given opcode. + static HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode); + // Gets all the instructions from the given module with the given opcode. + static std::vector FindInstructions(HloModule* module, + HloOpcode opcode); + protected: explicit HloHardwareIndependentTestBase( bool verifier_layout_sensitive = false, @@ -199,22 +216,6 @@ class HloHardwareIndependentTestBase : public ::testing::Test { ->Clear(); } - // Gets the computation/instruction from the given module with the given name. - // Note that it is encouraged to use these functions directly via the - // hlo_query.h header instead since they are independent from any test-time - // variables or contexts. - - // This is useful for tests which create HLOs from a string and then want to - // inspect a particular computation or instruction. - static HloComputation* FindComputation(HloModule* module, - absl::string_view name); - static HloInstruction* FindInstruction(HloModule* module, - absl::string_view name); - // Gets the instruction from the given module with the given opcode. - static HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode); - // Gets all the instructions from the given module with the given opcode. - static std::vector FindInstructions(HloModule* module, - HloOpcode opcode); bool verifier_layout_sensitive() const { return verifier_layout_sensitive_; } void set_verifier_layout_sensitive(bool verifier_layout_sensitive) { From 94950f9ab1d6da1475ea190c8ba10bf03ac19557 Mon Sep 17 00:00:00 2001 From: xla authors Date: Tue, 7 Jan 2025 11:57:51 -0800 Subject: [PATCH 02/45] Reverts 1112df1cd1be6e7ca1c496166eface070e9558aa PiperOrigin-RevId: 712993078 --- xla/service/spmd/spmd_partitioner.cc | 60 +++++++++------------ xla/service/spmd/spmd_partitioner_test.cc | 64 +++-------------------- 2 files changed, 30 insertions(+), 94 deletions(-) diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc index 1abdf7359f71b7..9d0912d4b4c5a4 100644 --- a/xla/service/spmd/spmd_partitioner.cc +++ b/xla/service/spmd/spmd_partitioner.cc @@ -3355,48 +3355,36 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) { if (hlo->sharding().IsTileMaximal()) { return DefaultAction(hlo); } - - // Replicate along the slice dims to get temp_sharding. - std::vector slice_dims; for (int64_t i = 0; i < hlo->shape().rank(); ++i) { - if (hlo->dynamic_slice_sizes()[i] != - hlo->operand(0)->shape().dimensions(i)) { - slice_dims.push_back(i); + if (hlo->sharding().tile_assignment().dim(i) != 1 && + hlo->dynamic_slice_sizes()[i] != + hlo->operand(0)->shape().dimensions(i)) { + // We currently do not partition the sliced dimensions. + return DefaultAction(hlo); } } - const HloSharding temp_sharding = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(hlo->sharding(), - slice_dims); - - // Reshard the input to temp_sharding. - HloInstruction* input_with_temp_sharding = - GetPartitionedHlo(hlo->operand(0)).Reshard(temp_sharding).hlo(); - - std::vector new_indices; - new_indices.reserve(hlo->shape().rank()); - for (int64_t i = 0; i < hlo->shape().rank(); ++i) { - if (hlo->dynamic_slice_sizes()[i] != + std::vector new_indices(hlo->shape().rank()); + auto new_input = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + for (int64_t i = 0; i < new_indices.size(); ++i) { + if (hlo->dynamic_slice_sizes()[i] == hlo->operand(0)->shape().dimensions(i)) { - new_indices.push_back( - GetPartitionedHlo(hlo->operand(i + 1)).Replicate().hlo()); - } else { - // Index must be clamped to be 0. - new_indices.push_back(CreateZero(hlo->operand(i + 1)->shape(), &b_)); + // Trivial slice dim: index must be clampped to 0. + new_indices[i] = CreateZero(hlo->operand(i + 1)->shape(), &b_); + continue; } + // Replicate the indices.; + new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1)) + .Reshard(HloSharding::Replicate()) + .hlo(); } - - // Apply dynamic slice with temp_sharding. - Shape temp_sharded_shape = MakePartitionedShape(hlo->shape(), temp_sharding); - HloInstruction* ds_with_temp_sharding = - b_.AddInstruction(HloInstruction::CreateDynamicSlice( - temp_sharded_shape, input_with_temp_sharding, new_indices, - temp_sharded_shape.dimensions())); - ds_with_temp_sharding->set_sharding(temp_sharding); - - // Reshard the output to the final sharding. - SetPartitionedHlo(hlo, PartitionedHlo(ds_with_temp_sharding, hlo->shape(), - MakePartitioningState()) - .Reshard(hlo->sharding())); + SetPartitionedHlo(hlo, [&]() { + auto partitioned_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + partitioned_shape, new_input, new_indices, + partitioned_shape.dimensions())); + }); return absl::OkStatus(); } diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc index 727448674a5e1b..8e9823d413ac41 100644 --- a/xla/service/spmd/spmd_partitioner_test.cc +++ b/xla/service/spmd/spmd_partitioner_test.cc @@ -7531,7 +7531,7 @@ ENTRY entry { EXPECT_THAT(root, op::PartitionId()); } -TEST_P(SpmdPartitioningTest, DynamicSlicePartitionedBatchDimension) { +TEST_P(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) { absl::string_view hlo_string = R"( HloModule module @@ -7539,71 +7539,19 @@ ENTRY entry { %input = s32[128,64] parameter(0), sharding={devices=[2,1]0,1} %index = s32[] parameter(1) %trivial_index = s32[] parameter(2) - ROOT %dynamic-slice = s32[128,16] dynamic-slice(%input, %trivial_index, %index), - dynamic_slice_sizes={128,16}, sharding={devices=[2,1]0,1} + ROOT %dynamic-slice = s32[128,2] dynamic-slice(%input, %trivial_index, %index), + dynamic_slice_sizes={128,2}, sharding={devices=[2,1]0,1} })"; TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/2)); VLOG(1) << module->ToString(); + const auto root = module->entry_computation()->root_instruction(); auto input = AllOf(op::Parameter(0), op::Shape("s32[64,64]")); - EXPECT_THAT(module->entry_computation()->root_instruction(), + EXPECT_THAT(root, AllOf(op::DynamicSlice(input, op::Constant(), op::Parameter(1)), - op::Shape("s32[64,16]"))); -} - -TEST_P(SpmdPartitioningTest, DynamicSlicePartitionedSliceDimension) { - absl::string_view hlo_string = R"( -HloModule module - -ENTRY entry { - %input = s32[128,64] parameter(0), sharding={devices=[1,2]0,1} - %index = s32[] parameter(1) - %trivial_index = s32[] parameter(2) - ROOT %dynamic-slice = s32[128,16] dynamic-slice(%input, %trivial_index, %index), - dynamic_slice_sizes={128,16}, sharding={devices=[1,2]0,1} -})"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - PartitionComputation(hlo_string, /*num_devices=*/2)); - - auto input = AllOf(op::Parameter(0), op::Shape("s32[128,32]")); - auto input_replicated = - AllOf(op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), input, _, _)), - op::Shape("s32[128,64]")); - auto ds_replicated = AllOf( - op::DynamicSlice(input_replicated, op::Constant(), op::Parameter(1)), - op::Shape("s32[128,16]")); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - AllOf(op::DynamicSlice(ds_replicated, _, _), op::Shape("s32[128,8]"))); -} - -TEST_P(SpmdPartitioningTest, DynamicSlicePartitionedBothDimensions) { - absl::string_view hlo_string = R"( -HloModule module - -ENTRY entry { - %input = s32[128,64] parameter(0), sharding={devices=[2,2]<=[4]} - %index = s32[] parameter(1) - %trivial_index = s32[] parameter(2) - ROOT %dynamic-slice = s32[128,16] dynamic-slice(%input, %trivial_index, %index), - dynamic_slice_sizes={128,16}, sharding={devices=[2,2]<=[4]} -})"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - PartitionComputation(hlo_string, /*num_devices=*/4)); - - auto input = AllOf(op::Parameter(0), op::Shape("s32[64,32]")); - auto input_reshard = - AllOf(op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), input, _, _)), - op::Shape("s32[64,64]")); - auto ds = - AllOf(op::DynamicSlice(input_reshard, op::Constant(), op::Parameter(1)), - op::Shape("s32[64,16]")); - EXPECT_THAT(module->entry_computation()->root_instruction(), - AllOf(op::DynamicSlice(ds, _, _), op::Shape("s32[64,8]"))); + op::Shape("s32[64,2]"))); } TEST_P(SpmdPartitioningTest, DynamicUpdateSliceAlongNonPartitionedDimension) { From ade659b1b005db2c1491ae18e53642f98566f476 Mon Sep 17 00:00:00 2001 From: Isha Arkatkar Date: Tue, 7 Jan 2025 12:08:30 -0800 Subject: [PATCH 03/45] Reverts 313d56fc66638fc32abdba49f2614b54df51f900 PiperOrigin-RevId: 712996803 --- xla/pjrt/c/CHANGELOG.md | 6 +++ xla/pjrt/c/pjrt_c_api.h | 40 ++++++++++++++++++- xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 6 +-- xla/pjrt/c/pjrt_c_api_helpers.cc | 38 ++++++++++++++++++ xla/pjrt/c/pjrt_c_api_helpers.h | 17 +++++--- xla/pjrt/c/pjrt_c_api_helpers_test.cc | 8 ++++ xla/pjrt/c/pjrt_c_api_test_base.cc | 4 +- xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 36 +++++++++++++++-- xla/pjrt/c/pjrt_c_api_wrapper_impl.h | 1 + xla/pjrt/distributed/client.cc | 12 ++++++ xla/pjrt/distributed/client.h | 4 ++ xla/pjrt/distributed/client_server_test.cc | 14 +++++++ .../distributed/in_memory_key_value_store.cc | 12 ++++++ .../distributed/in_memory_key_value_store.h | 4 ++ .../distributed/key_value_store_interface.h | 7 ++++ xla/pjrt/pjrt_c_api_client.cc | 2 + xla/python/xla.cc | 15 +++++++ xla/python/xla_extension/__init__.pyi | 2 + 18 files changed, 214 insertions(+), 14 deletions(-) diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index 5852c9a54dcc01..d56741eb3500b0 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,4 +1,10 @@ # PJRT C API changelog + +## 0.61 +* Added ``PJRT_KeyValueTryGet`` to the KV store interface, + which is non-blocking and immediately returns an error if the + key is not found. + ## 0.60 * Added ``PJRT_Client_CreateBuffersForAsyncHostToDevice`` and ``PJRT_AsyncHostToDeviceTransferManager_TransferRawDataToSubBuffer``. diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index 36d82b0787ba41..f2fc3b1c507a3c 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 60 +#define PJRT_API_MINOR 61 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -351,6 +351,35 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueGetCallback_Args, typedef PJRT_Error* (*PJRT_KeyValueGetCallback)( PJRT_KeyValueGetCallback_Args* args); +// Same as KeyValueGet, but returns `NotFoundError` immediately if the key is +// not found. +typedef void (*PJRT_KeyValueTryGetCallback_ValueDeleter)(char* value); + +struct PJRT_KeyValueTryGetCallback_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + const char* key; + size_t key_size; + PJRT_CallbackError* callback_error; + void* user_arg; + char* value; // out + size_t value_size; // out + // The caller needs to set a PJRT_KeyValueTryGetCallback_ValueDeleter to + // delete the value returned by PJRT_KeyValueTryGetCallback. The + // implementation is responsible for copying `value` and then calling + // value_deleter_callback. + PJRT_KeyValueTryGetCallback_ValueDeleter value_deleter_callback; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueTryGetCallback_Args, + value_deleter_callback); + +// Requirements for PJRT_KeyValueTryGetCallback implementation: (1) Thread-safe. +// (2) The caller that provides the two callbacks is responsible for avoiding +// key collisions between different users of key-value store (i.e. between +// different plugins, but not between different nodes in one plugin). +typedef PJRT_Error* (*PJRT_KeyValueTryGetCallback)( + PJRT_KeyValueTryGetCallback_Args* args); + struct PJRT_KeyValuePutCallback_Args { size_t struct_size; PJRT_Extension_Base* extension_start; @@ -389,8 +418,15 @@ struct PJRT_Client_Create_Args { void* kv_put_user_arg; PJRT_Client* client; // out + + // Key-value try-get callback provided by the caller of PJRT_Client_Create. + // Same as key-value get callback, but returns `NotFoundError` immediately if + // the key is not found. + PJRT_KeyValueTryGetCallback kv_try_get_callback; + // Will be passed to `kv_try_get_callback` as `user_arg` argument. + void* kv_try_get_user_arg; }; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, client); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, kv_try_get_user_arg); // Creates and initializes a new PJRT_Client and returns in `client`. typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args); diff --git a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 4f53c640a6a3dc..68d36fdb7f5c86 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -154,9 +154,9 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { options.num_nodes = num_nodes; options.allowed_devices = visible_devices; options.platform_name = platform_name; - options.kv_store = - pjrt::ToCppKeyValueStore(args->kv_get_callback, args->kv_get_user_arg, - args->kv_put_callback, args->kv_put_user_arg); + options.kv_store = pjrt::ToCppKeyValueStore( + args->kv_get_callback, args->kv_get_user_arg, args->kv_try_get_callback, + args->kv_try_get_user_arg, args->kv_put_callback, args->kv_put_user_arg); options.enable_mock_nccl = enable_mock_nccl; options.mock_gpu_topology = mock_gpu_topology; PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index 2060a73a634a48..c5d4b92c1a541e 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -797,6 +797,25 @@ static PJRT_KeyValueGetCFunc ToKVGetCFunc( }; } +static PJRT_KeyValueTryGetCFunc ToKVTryGetCFunc( + xla::KeyValueStoreInterface* kv_store) { + return [kv_store](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* { + absl::StatusOr output = + kv_store->TryGet(absl::string_view(args->key, args->key_size)); + if (!output.ok()) { + absl::string_view message = output.status().message(); + return (*args->callback_error)( + StatusCodeToPjrtErrorCode(output.status().code()), message.data(), + message.size()); + } + args->value = new char[output->size()]; + std::copy(output->begin(), output->end(), args->value); + args->value_size = output->size(); + args->value_deleter_callback = &PjRtValueDeleterCallback; + return nullptr; + }; +} + static PJRT_KeyValuePutCFunc ToKVPutCFunc( xla::KeyValueStoreInterface* kv_store) { return [kv_store](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { @@ -828,6 +847,22 @@ static PJRT_KeyValueGetCallback ToCKVGetCallback( }; } +static PJRT_KeyValueTryGetCallback ToCKVTryGetCallback( + PJRT_KeyValueTryGetCFunc* kv_try_get_c_func) { + return [](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* { + PJRT_KeyValueTryGetCFunc* kv_try_get_c_func = + reinterpret_cast(args->user_arg); + if (kv_try_get_c_func == nullptr) { + absl::Status status = xla::InvalidArgument( + "got nullptr for PJRT_KeyValueTryGet_Args.user_arg"); + return (*args->callback_error)(StatusCodeToPjrtErrorCode(status.code()), + status.message().data(), + status.message().size()); + } + return (*kv_try_get_c_func)(args); + }; +} + static PJRT_KeyValuePutCallback ToCKVPutCallback( PJRT_KeyValuePutCFunc* kv_put_c_func) { return [](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { @@ -848,9 +883,12 @@ std::unique_ptr ConvertToCKeyValueCallbacks( std::shared_ptr kv_store) { auto kv_callback_data = std::make_unique(); kv_callback_data->kv_get_c_func = ToKVGetCFunc(kv_store.get()); + kv_callback_data->kv_try_get_c_func = ToKVTryGetCFunc(kv_store.get()); kv_callback_data->kv_put_c_func = ToKVPutCFunc(kv_store.get()); kv_callback_data->c_kv_get = ToCKVGetCallback(&kv_callback_data->kv_get_c_func); + kv_callback_data->c_kv_try_get = + ToCKVTryGetCallback(&kv_callback_data->kv_try_get_c_func); kv_callback_data->c_kv_put = ToCKVPutCallback(&kv_callback_data->kv_put_c_func); kv_callback_data->kv_store = std::move(kv_store); diff --git a/xla/pjrt/c/pjrt_c_api_helpers.h b/xla/pjrt/c/pjrt_c_api_helpers.h index 709558fba465af..d7a4286571b730 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/xla/pjrt/c/pjrt_c_api_helpers.h @@ -218,6 +218,9 @@ int GetId(const PJRT_Api* api, PJRT_DeviceDescription* device_desc); using PJRT_KeyValueGetCFunc = std::function; +using PJRT_KeyValueTryGetCFunc = + std::function; + using PJRT_KeyValuePutCFunc = std::function; @@ -228,17 +231,21 @@ struct PJRT_KeyValueCallbackData { std::shared_ptr kv_store; - // kv_get_c_func and kv_put_c_func are holding pointers to kv_store. + // kv_get_c_func, kv_try_get_c_func and kv_put_c_func are holding pointers to + // kv_store. pjrt::PJRT_KeyValueGetCFunc kv_get_c_func; pjrt::PJRT_KeyValuePutCFunc kv_put_c_func; - // c_kv_get and c_kv_put are holding pointers to kv_get_c_func and - // kv_put_c_func. + // c_kv_get, c_kv_try_get and c_kv_put are holding pointers to kv_get_c_func, + // kv_try_get_c_func and kv_put_c_func. PJRT_KeyValueGetCallback c_kv_get; PJRT_KeyValuePutCallback c_kv_put; + pjrt::PJRT_KeyValueTryGetCFunc kv_try_get_c_func; + PJRT_KeyValueTryGetCallback c_kv_try_get; }; -// The returned &kv_get_c_func and &kv_put_c_func must be set as -// PJRT_Client_Create_Args.kv_get_user_arg and +// The returned &kv_get_c_func, &kv_try_get_c_func and &kv_put_c_func must be +// set as PJRT_Client_Create_Args.kv_get_user_arg, +// PJRT_Client_Create_Args.kv_try_get_user_arg and // PJRT_Client_Create_Args.kv_put_user_arg, respectively. The entire // PJRT_KeyValueCallbackData must be kept alive as long as c_kv_get and c_kv_put // may be called. diff --git a/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/xla/pjrt/c/pjrt_c_api_helpers_test.cc index 4b8a59287589ed..6dfce81a1e4514 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -108,14 +108,22 @@ TEST(PjRtCApiHelperTest, Callback) { auto kv_callback_data = ConvertToCKeyValueCallbacks(kv_store); auto converted_kv_store = ToCppKeyValueStore( kv_callback_data->c_kv_get, &kv_callback_data->kv_get_c_func, + kv_callback_data->c_kv_try_get, &kv_callback_data->kv_try_get_c_func, kv_callback_data->c_kv_put, &kv_callback_data->kv_put_c_func); + auto v_not_found = converted_kv_store->Get("key", absl::Seconds(1)); + EXPECT_TRUE(absl::IsNotFound(v_not_found.status())) << v_not_found.status(); + auto s = converted_kv_store->Set("key", "value"); TF_EXPECT_OK(s); auto v = converted_kv_store->Get("key", absl::Seconds(1)); TF_EXPECT_OK(v.status()); EXPECT_EQ(*v, "value"); + + auto v_2 = converted_kv_store->TryGet("key"); + TF_EXPECT_OK(v.status()); + EXPECT_EQ(*v, "value"); } TEST(PjRtCApiHelperTest, ConvertToCLayoutFromStrides) { diff --git a/xla/pjrt/c/pjrt_c_api_test_base.cc b/xla/pjrt/c/pjrt_c_api_test_base.cc index 9602813c573c52..f867846ebcbd54 100644 --- a/xla/pjrt/c/pjrt_c_api_test_base.cc +++ b/xla/pjrt/c/pjrt_c_api_test_base.cc @@ -47,9 +47,11 @@ PJRT_Client* CreateClient(const PJRT_Api* api) { create_args.create_options = nullptr; create_args.num_options = 0; create_args.kv_get_callback = nullptr; + create_args.kv_get_user_arg = nullptr; create_args.kv_put_callback = nullptr; create_args.kv_put_user_arg = nullptr; - create_args.kv_get_user_arg = nullptr; + create_args.kv_try_get_callback = nullptr; + create_args.kv_try_get_user_arg = nullptr; PJRT_Error* error = api->PJRT_Client_Create(&create_args); CHECK_EQ(error, nullptr); CHECK_NE(create_args.client, nullptr); diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 64aa20bac3c0e2..f832fad0c997c3 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -235,9 +235,13 @@ static absl::Status PopulateExecutableOutputMemoryKinds( class CApiKeyValueStore : public xla::KeyValueStoreInterface { public: CApiKeyValueStore(PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, + PJRT_KeyValueTryGetCallback c_try_get_callback, + void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) : c_get_callback_(c_get_callback), get_user_arg_(get_user_arg), + c_try_get_callback_(c_try_get_callback), + try_get_user_arg_(try_get_user_arg), c_put_callback_(c_put_callback), put_user_arg_(put_user_arg) {} @@ -264,6 +268,27 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface { return result; } + absl::StatusOr TryGet(absl::string_view key) override { + PJRT_CallbackError callback_error = [](PJRT_Error_Code code, + const char* message, + size_t message_size) { + return new PJRT_Error{absl::Status(static_cast(code), + std::string(message, message_size))}; + }; + PJRT_KeyValueTryGetCallback_Args args; + args.key = key.data(); + args.key_size = key.size(); + args.callback_error = &callback_error; + args.user_arg = try_get_user_arg_; + std::unique_ptr error(c_try_get_callback_(&args)); + if (error != nullptr) { + return error->status; + } + auto result = std::string(args.value, args.value_size); + args.value_deleter_callback(args.value); + return result; + } + absl::Status Set(absl::string_view key, absl::string_view value) override { PJRT_CallbackError callback_error = [](PJRT_Error_Code code, const char* message, @@ -288,18 +313,23 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface { private: PJRT_KeyValueGetCallback c_get_callback_; void* get_user_arg_; + PJRT_KeyValueTryGetCallback c_try_get_callback_; + void* try_get_user_arg_; PJRT_KeyValuePutCallback c_put_callback_; void* put_user_arg_; }; std::shared_ptr ToCppKeyValueStore( PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, + PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) { - if (c_get_callback == nullptr || c_put_callback == nullptr) { + if (c_get_callback == nullptr || c_try_get_callback == nullptr || + c_put_callback == nullptr) { return nullptr; } - return std::make_shared(c_get_callback, get_user_arg, - c_put_callback, put_user_arg); + return std::make_shared( + c_get_callback, get_user_arg, c_try_get_callback, try_get_user_arg, + c_put_callback, put_user_arg); } // ---------------------------------- Errors ----------------------------------- diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 04463410ee7e08..27b1cac051dbd0 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -464,6 +464,7 @@ PJRT_Client* CreateWrapperClient(std::unique_ptr cpp_client); // Helper functions for converting C key-value store callbacks to C++ callbacks. std::shared_ptr ToCppKeyValueStore( PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, + PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg); // A method that does not nothing other than returning a nullptr. Can be used as diff --git a/xla/pjrt/distributed/client.cc b/xla/pjrt/distributed/client.cc index 280c60873e9d07..305afe7ae4c6d4 100644 --- a/xla/pjrt/distributed/client.cc +++ b/xla/pjrt/distributed/client.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" @@ -53,6 +54,7 @@ class DistributedRuntimeCoordinationServiceClient absl::Status Shutdown() override; absl::StatusOr BlockingKeyValueGet( absl::string_view key, absl::Duration timeout) override; + absl::StatusOr KeyValueTryGet(absl::string_view key) override; absl::StatusOr>> KeyValueDirGet(absl::string_view key) override; absl::Status KeyValueSet(absl::string_view key, @@ -144,6 +146,12 @@ DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet( return coord_agent_->GetKeyValue(key, timeout); } +absl::StatusOr +DistributedRuntimeCoordinationServiceClient::KeyValueTryGet( + absl::string_view key) { + return coord_agent_->TryGetKeyValue(key); +} + absl::StatusOr>> DistributedRuntimeCoordinationServiceClient::KeyValueDirGet( absl::string_view key) { @@ -216,6 +224,10 @@ class DistributedKeyValueStore : public KeyValueStoreInterface { return client_->BlockingKeyValueGet(absl::StrCat(prefix_, key), timeout); } + absl::StatusOr TryGet(absl::string_view key) override { + return client_->KeyValueTryGet(absl::StrCat(prefix_, key)); + } + absl::Status Set(absl::string_view key, absl::string_view value) override { return client_->KeyValueSet(absl::StrCat(prefix_, key), value); } diff --git a/xla/pjrt/distributed/client.h b/xla/pjrt/distributed/client.h index e597ff158cc674..58f4fe367681d2 100644 --- a/xla/pjrt/distributed/client.h +++ b/xla/pjrt/distributed/client.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" @@ -116,6 +117,9 @@ class DistributedRuntimeClient { virtual absl::StatusOr BlockingKeyValueGet( absl::string_view key, absl::Duration timeout) = 0; + // Returns `NotFoundError` immediately if the key is not found. + virtual absl::StatusOr KeyValueTryGet(absl::string_view key) = 0; + // Get all key-value pairs under a directory (key). // A value is considered to be in the directory if its key is prefixed with // the directory. diff --git a/xla/pjrt/distributed/client_server_test.cc b/xla/pjrt/distributed/client_server_test.cc index f5b7e656fe69a2..baec103eced933 100644 --- a/xla/pjrt/distributed/client_server_test.cc +++ b/xla/pjrt/distributed/client_server_test.cc @@ -1029,6 +1029,20 @@ TEST_F(ClientServerTest, KeyValueSet_Duplicate_Overwrites) { EXPECT_EQ(result.value(), "overwritten_value"); } +TEST_F(ClientServerTest, KeyValueTryGet) { + StartService(/*num_nodes=*/1); + auto client = GetClient(/*node_id=*/0); + TF_ASSERT_OK(client->Connect()); + + ASSERT_THAT(client->KeyValueTryGet("test_key").status(), + StatusIs(absl::StatusCode::kNotFound)); + + TF_ASSERT_OK(client->KeyValueSet("test_key", "value")); + auto result = client->KeyValueTryGet("test_key"); + TF_ASSERT_OK(result.status()); + EXPECT_EQ(result.value(), "value"); +} + TEST_F(ClientServerTest, KeyValueDelete) { StartService(/*num_nodes=*/1); auto client = GetClient(/*node_id=*/0); diff --git a/xla/pjrt/distributed/in_memory_key_value_store.cc b/xla/pjrt/distributed/in_memory_key_value_store.cc index 70cc5360ecf7b3..49fc73ec87f163 100644 --- a/xla/pjrt/distributed/in_memory_key_value_store.cc +++ b/xla/pjrt/distributed/in_memory_key_value_store.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -40,6 +41,17 @@ absl::StatusOr InMemoryKeyValueStore::Get(absl::string_view key, return kv_store_.find(key)->second; } +absl::StatusOr InMemoryKeyValueStore::TryGet( + absl::string_view key) { + absl::MutexLock lock(&mu_); + auto it = kv_store_.find(key); + if (it == kv_store_.end()) { + return absl::NotFoundError( + absl::StrCat(key, " is not found in the kv store.")); + } + return it->second; +} + absl::Status InMemoryKeyValueStore::Set(absl::string_view key, absl::string_view value) { absl::MutexLock lock(&mu_); diff --git a/xla/pjrt/distributed/in_memory_key_value_store.h b/xla/pjrt/distributed/in_memory_key_value_store.h index 1530633a98b754..13f50c722bd125 100644 --- a/xla/pjrt/distributed/in_memory_key_value_store.h +++ b/xla/pjrt/distributed/in_memory_key_value_store.h @@ -21,7 +21,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/time/time.h" #include "xla/pjrt/distributed/key_value_store_interface.h" namespace xla { @@ -31,6 +33,8 @@ class InMemoryKeyValueStore : public KeyValueStoreInterface { absl::StatusOr Get(absl::string_view key, absl::Duration timeout) override; + absl::StatusOr TryGet(absl::string_view key) override; + absl::Status Set(absl::string_view key, absl::string_view value) override; private: diff --git a/xla/pjrt/distributed/key_value_store_interface.h b/xla/pjrt/distributed/key_value_store_interface.h index 29580fb86847b1..312ebb8abb6463 100644 --- a/xla/pjrt/distributed/key_value_store_interface.h +++ b/xla/pjrt/distributed/key_value_store_interface.h @@ -38,11 +38,18 @@ class KeyValueStoreInterface { virtual ~KeyValueStoreInterface() = default; // Blocking Get(). + // Useful for listening for a key-value pair that may be set later on. // There are no concurrency guarantees. To avoid a race / impose an ordering // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). virtual absl::StatusOr Get(absl::string_view key, absl::Duration timeout) = 0; + // Returns `NotFoundError` immediately if the key is not found. + // Useful for checking key existence. + // There are no concurrency guarantees. To avoid a race / impose an ordering + // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). + virtual absl::StatusOr TryGet(absl::string_view key) = 0; + virtual absl::Status Set(absl::string_view key, absl::string_view value) = 0; }; diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index b7dea23fe13c36..00e242434f4376 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -2599,6 +2599,8 @@ absl::StatusOr> WrapClientAroundCApi( kv_callback_data = pjrt::ConvertToCKeyValueCallbacks(kv_store); init_args.kv_get_callback = kv_callback_data->c_kv_get; init_args.kv_get_user_arg = &kv_callback_data->kv_get_c_func; + init_args.kv_try_get_callback = kv_callback_data->c_kv_try_get; + init_args.kv_try_get_user_arg = &kv_callback_data->kv_try_get_c_func; init_args.kv_put_callback = kv_callback_data->c_kv_put; init_args.kv_put_user_arg = &kv_callback_data->kv_put_c_func; } diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 219d6704b4f791..647fc37f089df7 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -672,6 +672,21 @@ NB_MODULE(xla_extension, m) { return nb::bytes(result.data(), result.size()); }, nb::arg("key"), nb::arg("timeout_in_ms")) + .def( + "key_value_try_get", + [](DistributedRuntimeClient& client, std::string key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueTryGet(key)); + }, + nb::arg("key")) + .def( + "key_value_try_get_bytes", + [](DistributedRuntimeClient& client, std::string key) -> nb::bytes { + nb::gil_scoped_release gil_release; + std::string result = xla::ValueOrThrow(client.KeyValueTryGet(key)); + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key")) .def( "wait_at_barrier", [](DistributedRuntimeClient& client, std::string barrier_id, diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index 2e3862285898f2..5fa885f9f92255 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -830,6 +830,8 @@ class DistributedRuntimeClient: def blocking_key_value_get_bytes( self, key: str, timeout_in_ms: int ) -> _Status: ... + def key_value_try_get(self, key: str) -> _Status: ... + def key_value_try_get_bytes(self, key: str) -> _Status: ... def key_value_dir_get(self, key: str) -> _Status: ... def key_value_dir_get_bytes(self, key: str) -> _Status: ... def key_value_set(self, key: str, value: str, From 2c9e4b2f87b86cd71356d56bb03c8ee3e43e76c5 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 7 Jan 2025 13:06:00 -0800 Subject: [PATCH 04/45] [xla:cpu] FFI: Add support for token arguments and results Fix for https://github.com/jax-ml/jax/issues/25756 PiperOrigin-RevId: 713015117 --- xla/backends/cpu/runtime/custom_call_thunk.cc | 12 ++++++++ xla/tests/custom_call_test.cc | 30 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/xla/backends/cpu/runtime/custom_call_thunk.cc b/xla/backends/cpu/runtime/custom_call_thunk.cc index 8f693a1e3c5378..974a77522ac77d 100644 --- a/xla/backends/cpu/runtime/custom_call_thunk.cc +++ b/xla/backends/cpu/runtime/custom_call_thunk.cc @@ -132,6 +132,12 @@ absl::StatusOr BuildCallFrameForTypedFFI( // memory addresses will be updated at runtime. for (int i = 0; i < op_buffers.arguments_buffers.size(); ++i) { auto& shape = op_buffers.arguments_shapes[i]; + + if (shape.IsToken()) { + builder.AddTokenArg(); + continue; + } + auto elements = absl::c_accumulate(shape.dimensions(), 1ULL, std::multiplies()); auto dtype_bytes = primitive_util::ByteWidth(shape.element_type()); @@ -144,6 +150,12 @@ absl::StatusOr BuildCallFrameForTypedFFI( // memory addresses will be updated at runtime. for (int i = 0; i < op_buffers.results_buffers.size(); ++i) { auto& shape = op_buffers.results_shapes[i]; + + if (shape.IsToken()) { + builder.AddTokenRet(); + continue; + } + auto elements = absl::c_accumulate(shape.dimensions(), 1ULL, std::multiplies()); auto dtype_bytes = primitive_util::ByteWidth(shape.element_type()); diff --git a/xla/tests/custom_call_test.cc b/xla/tests/custom_call_test.cc index 3f264f1996fc63..ff88a0de868cf8 100644 --- a/xla/tests/custom_call_test.cc +++ b/xla/tests/custom_call_test.cc @@ -409,6 +409,18 @@ XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail, XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$always_fail", PLATFORM, kAlwaysFail); +static absl::Status Tokens(ffi::Token, ffi::Result, + ffi::Result) { + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER( + kTokens, Tokens, + ffi::Ffi::Bind().Arg().Ret().Ret()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens", PLATFORM, + kTokens); + static absl::Status FfiR0F32Add2(R0F32Buffer in, R0F32ResultBuffer out) { auto in_data = in.typed_data(); auto out_data = out->typed_data(); @@ -843,6 +855,24 @@ XLA_TEST_F(FfiCustomCallTest, FfiReportsSuccess) { EXPECT_EQ(status, absl::OkStatus()); } +XLA_TEST_F(FfiCustomCallTest, Tokens) { + auto module = CreateNewVerifiedModule(); + auto builder = HloComputation::Builder(TestName()); + + std::vector ret = {ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTokenShape()}; + + auto* token = builder.AddInstruction(HloInstruction::CreateToken()); + builder.AddInstruction(HloInstruction::CreateCustomCall( + ShapeUtil::MakeTupleShape(ret), {token}, "__xla_test$$tokens", "", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + module->AddEntryComputation(builder.Build()); + + auto status = Execute(std::move(module), {}).status(); + EXPECT_EQ(status, absl::OkStatus()); +} + XLA_TEST_F(FfiCustomCallTest, FfiUnknownTarget) { auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); From c65992483df310fa4192df6aaa5cd7cb1614f213 Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Tue, 7 Jan 2025 13:06:45 -0800 Subject: [PATCH 05/45] Add new class xla::ifrt::PjRtMemoryDescription. (This only adds the class, in preparation of plumbing memory descriptions through IFRT. No functional changes yet.) PiperOrigin-RevId: 713015406 --- xla/python/pjrt_ifrt/pjrt_memory.cc | 25 +++++++++++++++++++++++++ xla/python/pjrt_ifrt/pjrt_memory.h | 25 +++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/xla/python/pjrt_ifrt/pjrt_memory.cc b/xla/python/pjrt_ifrt/pjrt_memory.cc index 8edb3bfa29fe2c..5217eb72b1fbdc 100644 --- a/xla/python/pjrt_ifrt/pjrt_memory.cc +++ b/xla/python/pjrt_ifrt/pjrt_memory.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/memory.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" @@ -29,6 +30,7 @@ namespace ifrt { char PjRtCompatibleMemory::ID = 0; char PjRtMemory::ID = 0; +char PjRtMemoryDescription::ID = 0; PjRtMemory::PjRtMemory(PjRtClient* client, xla::PjRtMemorySpace* pjrt_memory) : client_(client), pjrt_memory_(pjrt_memory), kind_(pjrt_memory->kind()) { @@ -51,6 +53,29 @@ absl::string_view PjRtMemory::DebugString() const { absl::Span PjRtMemory::Devices() const { return devices_; } +PjRtMemoryDescription::PjRtMemoryDescription( + PjRtClient* client, absl::Span devices, + const xla::PjRtMemorySpaceDescription* desc) + : desc_(desc), kind_(desc->kind()) { + for (auto device : devices) { + devices_.push_back(device); + } +} + +MemoryId PjRtMemoryDescription::Id() const { + return MemoryId(desc_->kind_id()); +} + +const MemoryKind& PjRtMemoryDescription::Kind() const { return kind_; } + +absl::string_view PjRtMemoryDescription::ToString() const { + return desc_->kind(); +} + +absl::string_view PjRtMemoryDescription::DebugString() const { + return desc_->kind(); +} + MemoryKind CanonicalizeMemoryKindWithPjRtDevice(MemoryKind memory_kind, xla::PjRtDevice* device) { if (memory_kind.memory_kind().has_value()) { diff --git a/xla/python/pjrt_ifrt/pjrt_memory.h b/xla/python/pjrt_ifrt/pjrt_memory.h index 3964ac56b184d5..f6517f9e191d9e 100644 --- a/xla/python/pjrt_ifrt/pjrt_memory.h +++ b/xla/python/pjrt_ifrt/pjrt_memory.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/python/ifrt/memory.h" namespace xla { @@ -60,6 +61,30 @@ class PjRtMemory final std::vector devices_; }; +class PjRtMemoryDescription final + : public llvm::RTTIExtends { + public: + PjRtMemoryDescription(PjRtClient* client, absl::Span devices, + const xla::PjRtMemorySpaceDescription* desc); + + PjRtClient* client() const { return client_; } + xla::PjRtMemorySpace* pjrt_memory() override { return nullptr; } + + MemoryId Id() const override; + const MemoryKind& Kind() const override; + absl::string_view ToString() const override; + absl::string_view DebugString() const override; + absl::Span Devices() const override { return devices_; } + + static char ID; // NOLINT + + private: + PjRtClient* client_; + const xla::PjRtMemorySpaceDescription* desc_; + MemoryKind kind_; + std::vector devices_; +}; + // Canonicalizes `MemoryKind`. If `MemoryKind` has no memory kind chosen, // returns a default `MemoryKind` chosen for the PjRt device. If there is no // default indicated by the device, simply returns `MemoryKind` with no memory From 63eacb06e3a704919d651d95ba087c1e2566b3b0 Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Tue, 7 Jan 2025 13:31:45 -0800 Subject: [PATCH 06/45] Implement infeed and outfeed support for `HloRunnerPjRt`. PiperOrigin-RevId: 713023324 --- xla/service/BUILD | 2 + xla/service/hlo_runner_pjrt.cc | 157 +++++++++++++++++++++++---------- xla/service/hlo_runner_pjrt.h | 24 ++--- 3 files changed, 123 insertions(+), 60 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index 8cd9cac1da809d..b041d7e59211da 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -4636,10 +4636,12 @@ cc_library( "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:die_if_null", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:casts", ], diff --git a/xla/service/hlo_runner_pjrt.cc b/xla/service/hlo_runner_pjrt.cc index b4b9e1cd889c39..9a2d0c72955516 100644 --- a/xla/service/hlo_runner_pjrt.cc +++ b/xla/service/hlo_runner_pjrt.cc @@ -23,11 +23,13 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/log/die_if_null.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" @@ -135,6 +137,27 @@ absl::StatusOr GetStaticDeviceAssignmentOrComputeDefault( module.config().num_partitions()); } +std::vector BufferVecToPointerVec( + const absl::Span> buffer) { + std::vector argument_ptrs; + argument_ptrs.resize(buffer.size()); + for (int i = 0; i < buffer.size(); ++i) { + argument_ptrs[i] = buffer[i].get(); + } + + return argument_ptrs; +} + +std::vector> BufferMatToPointerMat( + const absl::Span>> buffer) { + std::vector> argument_ptrs; + argument_ptrs.reserve(buffer.size()); + for (int i = 0; i < buffer.size(); ++i) { + argument_ptrs.push_back(BufferVecToPointerVec(buffer[i])); + } + return argument_ptrs; +} + } // namespace // TODO(b/245550554): Remove the use of PjRtWrappedExecutable. @@ -314,27 +337,6 @@ absl::StatusOr HloRunnerPjRt::Execute( return ExecuteWithExecutable(executable.get(), arguments, {}); } -std::vector HloRunnerPjRt::BufferVecToPointerVec( - const std::vector>& buffer) { - std::vector argument_ptrs; - argument_ptrs.resize(buffer.size()); - for (int i = 0; i < buffer.size(); ++i) { - argument_ptrs[i] = buffer[i].get(); - } - - return argument_ptrs; -} - -std::vector> HloRunnerPjRt::BufferMatToPointerMat( - std::vector>>& buffer) { - std::vector> argument_ptrs; - argument_ptrs.reserve(buffer.size()); - for (int i = 0; i < buffer.size(); ++i) { - argument_ptrs.push_back(BufferVecToPointerVec(buffer[i])); - } - return argument_ptrs; -} - absl::StatusOr> HloRunnerPjRt::CreateExecutable(HloModule* module, CompileOptions compile_options) { @@ -442,7 +444,7 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( const HloRunnerInterface::ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment, ExecutionProfile* profile) { return ExecuteReplicatedImpl( - [&](absl::Span>& argument_buffer_slices) + [&](absl::Span> argument_buffer_slices) -> absl::StatusOr>> { PjRtWrappedExecutable* wrapped_executable = static_cast(executable); @@ -476,7 +478,7 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( TF_RET_CHECK(device_assignment->computation_count() == 1) << "Only single-computation execution is supported."; return ExecuteReplicatedImpl( - [&](absl::Span>& argument_buffer_slices) + [&](absl::Span> argument_buffer_slices) -> absl::StatusOr>> { TF_RET_CHECK(options.use_threads); @@ -538,26 +540,29 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( std::function>>( - absl::Span>&)> + absl::Span>)> execution_helper, std::function argument_count_provider, std::function argument_provider, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment) { + TF_RET_CHECK(options.infeed_values.empty() || + options.infeed_values.size() == options.num_replicas); + + std::vector replica_devices(options.num_replicas, nullptr); std::vector>> argument_buffer_slices; argument_buffer_slices.reserve(options.num_replicas); - for (int64_t i = 0; i < options.num_replicas; ++i) { - TF_ASSIGN_OR_RETURN(PjRtDevice * device_ptr, + // Amortize device lookup. + TF_ASSIGN_OR_RETURN(PjRtDevice* const device_ptr, pjrt_client_->LookupDevice( DeviceIdForInvocation(*device_assignment, i))); + replica_devices[i] = device_ptr; // Transfer literals to device. const int64_t argument_count = argument_count_provider(i); - std::vector> replica_buffers; replica_buffers.reserve(argument_count); - for (int64_t arg_index = 0; arg_index < argument_count; arg_index++) { const Literal* const argument = argument_provider(i, arg_index); TF_RET_CHECK(argument != nullptr); @@ -570,37 +575,93 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( : pjrt_client_->BufferFromHostLiteral(*argument, device_ptr)); replica_buffers.push_back(std::move(assignment)); } - argument_buffer_slices.push_back(std::move(replica_buffers)); } - TF_RET_CHECK(options.infeed_values.empty() || - options.infeed_values.size() == options.num_replicas); - - if (!options.infeed_values.empty()) { - // TODO(b/245550554): Infeed/Outfeed + // Handle infeed and outfeed. + const bool has_infeed = !options.infeed_values.empty(); + const bool has_outfeed = ShapeUtil::IsInitialized(options.outfeed_shape); + std::unique_ptr pool = nullptr; + absl::Mutex infeed_outfeed_status_mu; + absl::Status infeed_outfeed_status = absl::OkStatus(); + if (has_infeed || has_outfeed) { + // One infeed per infeed value and one outfeed per replica. + const int64_t num_threads = + options.infeed_values.size() + (has_outfeed ? options.num_replicas : 0); + pool = std::make_unique( + tsl::Env::Default(), "infeed_outfeed", num_threads); } - - if (ShapeUtil::IsInitialized(options.outfeed_shape)) { - // TODO(b/245550554): Infeed/Outfeed + if (has_infeed) { + for (int64_t i = 0; i < options.num_replicas; ++i) { + pool->Schedule( + [device = replica_devices[i], + &infeed_literal = *ABSL_DIE_IF_NULL(options.infeed_values[i]), + infeed_steps = options.infeed_steps, &infeed_outfeed_status_mu, + &infeed_outfeed_status]() { + VLOG(1) << "Starting infeed on device " << device->ToString(); + absl::Status per_feed_status = absl::OkStatus(); + for (int64_t step = 1; infeed_steps < 0 || step <= infeed_steps; + ++step) { + per_feed_status.Update(device->TransferToInfeed(infeed_literal)); + if (step % 100 == 0) { + VLOG(1) << "Infeed step " << step; + } + } + absl::MutexLock lock(&infeed_outfeed_status_mu); + infeed_outfeed_status.Update(per_feed_status); + }); + } + } + if (has_outfeed) { + if (options.outfeed_values != nullptr) { + options.outfeed_values->resize(options.num_replicas); + } + for (int64_t i = 0; i < options.num_replicas; ++i) { + pool->Schedule([i, device = replica_devices[i], + outfeed_values = options.outfeed_values, + outfeed_shape = options.outfeed_shape, + infeed_steps = options.infeed_steps, + &infeed_outfeed_status_mu, &infeed_outfeed_status]() { + VLOG(1) << "Starting outfeed on device " << device->ToString(); + absl::Status per_feed_status = absl::OkStatus(); + for (int64_t step = 1; infeed_steps < 0 || step <= infeed_steps; + ++step) { + Literal literal(outfeed_shape); + per_feed_status.Update(device->TransferFromOutfeed(&literal)); + if (outfeed_values != nullptr) { + outfeed_values->at(i) = std::move(literal); + } + if (step % 100 == 0) { + VLOG(1) << "Outfeed step " << step; + } + } + absl::MutexLock lock(&infeed_outfeed_status_mu); + infeed_outfeed_status.Update(per_feed_status); + }); + } } - auto mat = BufferMatToPointerMat(argument_buffer_slices); - - auto span = absl::Span>(mat); - - TF_ASSIGN_OR_RETURN(auto results, execution_helper(span)); - std::vector exec_results; - exec_results.reserve(options.num_replicas); + VLOG(1) << "Replicated execution started"; + TF_ASSIGN_OR_RETURN( + const std::vector> result_buffers, + execution_helper(BufferMatToPointerMat(argument_buffer_slices))); + VLOG(1) << "Replicated execution terminated"; + // Get the result from execution. + std::vector result_literals; + result_literals.reserve(options.num_replicas); for (int64_t i = 0; i < options.num_replicas; ++i) { TF_ASSIGN_OR_RETURN(Literal literal, - TransferLiteralFromDevice(*results[i])); - - exec_results.push_back(std::move(literal)); + TransferLiteralFromDevice(*result_buffers[i])); + result_literals.push_back(std::move(literal)); } - return std::move(exec_results); + // Join infeed and outfeed threads, if they exist. The thread pool's threads + // are joined on destruction. No-op otherwise. + pool = nullptr; + TF_RETURN_IF_ERROR(infeed_outfeed_status); + + return std::move(result_literals); } absl::string_view HloRunnerPjRt::Name() const { return "HloRunnerPjRt"; } diff --git a/xla/service/hlo_runner_pjrt.h b/xla/service/hlo_runner_pjrt.h index dc4ec3921b4a6e..db0f258895866e 100644 --- a/xla/service/hlo_runner_pjrt.h +++ b/xla/service/hlo_runner_pjrt.h @@ -25,7 +25,13 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/service/computation_layout.h" +#include "xla/service/computation_placer.h" +#include "xla/service/executable.h" #include "xla/service/hlo_module_util.h" #include "xla/service/hlo_runner_interface.h" #include "xla/xla_data.pb.h" @@ -118,28 +124,22 @@ class HloRunnerPjRt : public HloRunnerInterface { } private: - std::unique_ptr pjrt_client_; - DeviceShapeRepresentationFn device_shape_representation_fn_; - DeviceShapeSizeFn device_shape_size_fn_; - bool use_parameter_layout_on_device_ = false; - - std::vector BufferVecToPointerVec( - const std::vector>& buffer); - - std::vector> BufferMatToPointerMat( - std::vector>>& buffer); - absl::StatusOr GenerateDefaultCompileOptions( HloModule* module, bool run_hlo_passes); absl::StatusOr> ExecuteReplicatedImpl( std::function>>( - absl::Span>&)> + absl::Span>)> execution_helper, std::function argument_count_provider, std::function argument_provider, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment); + + std::unique_ptr pjrt_client_; + DeviceShapeRepresentationFn device_shape_representation_fn_; + DeviceShapeSizeFn device_shape_size_fn_; + bool use_parameter_layout_on_device_ = false; }; } // namespace xla From 18b3c8df81ad560b63f71938750aa1414da2cf11 Mon Sep 17 00:00:00 2001 From: Shanbin Ke Date: Tue, 7 Jan 2025 13:33:08 -0800 Subject: [PATCH 07/45] PR #20861: [XLA:GPU] add cudnn flash attention sequence packing support Imported from GitHub PR https://github.com/openxla/xla/pull/20861 cudnn flash attention has support for sequence packing, which means multiple batches(segments) could be packed into one batch. It could help save memories and speed up both training and inference workloads. This PR makes following changes: * added 2 extra tensors to cudnn custom call, **q_offsets** and **kv_offsets** which specify the starting position of each segment in one batch and one extra element for ending of last segment. For example, 3 segments of size 80 is packed into one batch with maximum sequence 256, the q_offsets will be [0, 80, 160, 256]. **q_offsets** and **kv_offsets** will be used to indicate the layout of Q, K, V, O, dO, dQ, dK, dV. * added one **max_segment_per_batch** option in backend config which specify the maximum number of segments each batch has, since XLA has static memory allocation and the number of segments can change at runtime, we use this option to compile one cudnn graph and allocate static size for **softmax_stat** tensors. * added one test case. This sequence packing feature essentially has the same effect as using a segment mask. Comparing this feature against passing segment mask as bias to cudnn. Copybara import of the project: -- ae2c14a7c2391f1b343c3721d739a1588360841f by cjkkkk : add cudnn sequence packing support Merging this change closes #20861 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/20861 from Cjkkkk:segment_id ae2c14a7c2391f1b343c3721d739a1588360841f PiperOrigin-RevId: 713023783 --- xla/service/gpu/backend_configs.proto | 5 + xla/service/gpu/tests/gpu_fused_mha_test.cc | 137 +++++++++++++ .../transforms/cudnn_custom_call_compiler.cc | 16 +- xla/stream_executor/cuda/cuda_dnn.cc | 180 ++++++++++++++---- xla/stream_executor/cuda/cuda_dnn.h | 5 +- 5 files changed, 296 insertions(+), 47 deletions(-) diff --git a/xla/service/gpu/backend_configs.proto b/xla/service/gpu/backend_configs.proto index 84f008d3717b3b..906baaa33d512c 100644 --- a/xla/service/gpu/backend_configs.proto +++ b/xla/service/gpu/backend_configs.proto @@ -270,6 +270,11 @@ message CudnnfMHABackendConfig { // Sliding window length // ignored if the value <= 0 int32 sliding_window_length = 24; + + // The maximum number of segments in each batch + // Only used with packed layout + // ignored if the valued <= 1 + int32 max_seg_per_batch = 25; } // Backend config for a general custom call instruction, e.g. XLA FFI. diff --git a/xla/service/gpu/tests/gpu_fused_mha_test.cc b/xla/service/gpu/tests/gpu_fused_mha_test.cc index 33214758e230fd..abdb9f471d1ce1 100644 --- a/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -1263,6 +1263,136 @@ class FlashAttentionBMMScaleSlidingWindowMaskSoftmaxBMM } }; +class FlashAttentionBMMScaleSegmentMaskSoftmaxBMM + : public MultiHeadedAttentionTest { + protected: + const std::string // NOLINT + GetModuleFlash_Attention_Training_Sequence_Packing_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit_impl, entry_computation_layout={(bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0})->(bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + ENTRY main.22 { + Arg_0.1 = bf16[2,512,2,64]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[2,512,2,64]{3,2,1,0} parameter(1) + Arg_2.3 = bf16[2,512,2,64]{3,2,1,0} parameter(2) + constant.5 = s32[] constant(256) + broadcast.6 = s32[4]{0} broadcast(constant.5), dimensions={} + constant.7 = s32[5]{0} constant({0, 32768, 65536, 98304, 131072}) + custom-call.8 = (bf16[2,2,512,64]{3,1,2,0}, f32[4,2,512]{2,1,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, broadcast.6, broadcast.6, /*index=5*/constant.7, constant.7), custom_call_target="__cudnn$fmhaSoftmax", operand_layout_constraints={bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, s32[4]{0}, s32[4]{0}, s32[5]{0}, s32[5]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.1, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["2", "2", "512", "512"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 2}} + get-tuple-element.11 = u8[0]{0} get-tuple-element(custom-call.8), index=2 + get-tuple-element.10 = f32[4,2,512]{2,1,0} get-tuple-element(custom-call.8), index=1 + Arg_3.4 = bf16[2,512,2,64]{3,2,1,0} parameter(3) + get-tuple-element.9 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.8), index=0 + transpose.12 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.9), dimensions={0,2,1,3} + custom-call.13 = (bf16[2,2,512,64]{3,1,2,0}, bf16[2,2,512,64]{3,1,2,0}, bf16[2,2,512,64]{3,1,2,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, get-tuple-element.10, Arg_3.4, /*index=5*/transpose.12, broadcast.6, broadcast.6, constant.7, constant.7), custom_call_target="__cudnn$fmhaSoftmaxBackward", operand_layout_constraints={bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, f32[4,2,512]{2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, s32[4]{0}, s32[4]{0}, s32[5]{0}, s32[5]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.1, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["2", "2", "512", "512"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm1_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 2}} + get-tuple-element.17 = u8[0]{0} get-tuple-element(custom-call.13), index=3 + get-tuple-element.14 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.13), index=0 + transpose.18 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.14), dimensions={0,2,1,3} + get-tuple-element.15 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.13), index=1 + transpose.19 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.15), dimensions={0,2,1,3} + get-tuple-element.16 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.13), index=2 + transpose.20 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.16), dimensions={0,2,1,3} + ROOT tuple.21 = (bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}) tuple(transpose.12, transpose.18, transpose.19, transpose.20) + } // main.22 + )"; + return hlo_text; + } + + const std::string // NOLINT + GetModuleFlash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit_ref, entry_computation_layout={(bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0})->(bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + _where.9 { + Arg_0.10 = pred[512]{0} parameter(0) + Arg_1.11 = s32[512]{0} parameter(1) + Arg_2.12 = s32[512]{0} parameter(2) + ROOT select.13 = s32[512]{0} select(Arg_0.10, Arg_1.11, Arg_2.12) + } + + floor_divide.14 { + Arg_0.15 = s32[512]{0} parameter(0) + sign.23 = s32[512]{0} sign(Arg_0.15) + Arg_1.16 = s32[] parameter(1) + sign.24 = s32[] sign(Arg_1.16) + broadcast.25 = s32[512]{0} broadcast(sign.24), dimensions={} + compare.26 = pred[512]{0} compare(sign.23, broadcast.25), direction=NE + broadcast.27 = s32[512]{0} broadcast(Arg_1.16), dimensions={} + remainder.28 = s32[512]{0} remainder(Arg_0.15, broadcast.27) + constant.19 = s32[] constant(0) + broadcast.20 = s32[512]{0} broadcast(constant.19), dimensions={} + compare.29 = pred[512]{0} compare(remainder.28, broadcast.20), direction=NE + and.30 = pred[512]{0} and(compare.26, compare.29) + broadcast.21 = s32[512]{0} broadcast(Arg_1.16), dimensions={} + divide.22 = s32[512]{0} divide(Arg_0.15, broadcast.21) + constant.17 = s32[] constant(1) + broadcast.18 = s32[512]{0} broadcast(constant.17), dimensions={} + subtract.31 = s32[512]{0} subtract(divide.22, broadcast.18) + ROOT call.32 = s32[512]{0} call(and.30, subtract.31, divide.22), to_apply=_where.9 + } // floor_divide.14 + + ENTRY main.61 { + Arg_0.1 = bf16[2,512,2,64]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[2,512,2,64]{3,2,1,0} parameter(1) + Arg_2.3 = bf16[2,512,2,64]{3,2,1,0} parameter(2) + iota.8 = s32[512]{0} iota(), iota_dimension=0 + constant.7 = s32[] constant(256) + call.33 = s32[512]{0} call(iota.8, constant.7), to_apply=floor_divide.14 + broadcast.34 = s32[2,512]{1,0} broadcast(call.33), dimensions={1} + reshape.35 = s32[2,512,1]{2,1,0} reshape(broadcast.34) + broadcast.37 = s32[2,512,1]{2,1,0} broadcast(reshape.35), dimensions={0,1,2} + reshape.38 = s32[2,512]{1,0} reshape(broadcast.37) + broadcast.39 = s32[2,512,512]{2,1,0} broadcast(reshape.38), dimensions={0,1} + reshape.36 = s32[2,1,512]{2,1,0} reshape(broadcast.34) + broadcast.40 = s32[2,1,512]{2,1,0} broadcast(reshape.36), dimensions={0,1,2} + reshape.41 = s32[2,512]{1,0} reshape(broadcast.40) + broadcast.42 = s32[2,512,512]{2,1,0} broadcast(reshape.41), dimensions={0,2} + compare.43 = pred[2,512,512]{2,1,0} compare(broadcast.39, broadcast.42), direction=NE + convert.44 = bf16[2,512,512]{2,1,0} convert(compare.43) + reshape.45 = bf16[2,1,512,512]{3,2,1,0} reshape(convert.44) + constant.5 = bf16[] constant(-2.199e+12) + broadcast.6 = bf16[2,1,512,512]{3,2,1,0} broadcast(constant.5), dimensions={} + multiply.46 = bf16[2,1,512,512]{3,2,1,0} multiply(reshape.45, broadcast.6) + custom-call.47 = (bf16[2,2,512,64]{3,1,2,0}, f32[2,2,512]{2,1,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, multiply.46), custom_call_target="__cudnn$fmhaScaleBiasSoftmax", operand_layout_constraints={bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,1,512,512]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.1, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["2", "2", "512", "512"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 1}} + get-tuple-element.50 = u8[0]{0} get-tuple-element(custom-call.47), index=2 + get-tuple-element.49 = f32[2,2,512]{2,1,0} get-tuple-element(custom-call.47), index=1 + Arg_3.4 = bf16[2,512,2,64]{3,2,1,0} parameter(3) + get-tuple-element.48 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.47), index=0 + transpose.51 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.48), dimensions={0,2,1,3} + custom-call.52 = (bf16[2,2,512,64]{3,1,2,0}, bf16[2,2,512,64]{3,1,2,0}, bf16[2,2,512,64]{3,1,2,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, get-tuple-element.49, Arg_3.4, /*index=5*/multiply.46, transpose.51), custom_call_target="__cudnn$fmhaScaleBiasSoftmaxBackward", operand_layout_constraints={bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, f32[2,2,512]{2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,1,512,512]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.1, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["2", "2", "512", "512"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm1_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 1}} + get-tuple-element.56 = u8[0]{0} get-tuple-element(custom-call.52), index=3 + get-tuple-element.53 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.52), index=0 + transpose.57 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.53), dimensions={0,2,1,3} + get-tuple-element.54 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.52), index=1 + transpose.58 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.54), dimensions={0,2,1,3} + get-tuple-element.55 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.52), index=2 + transpose.59 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.55), dimensions={0,2,1,3} + ROOT tuple.60 = (bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}) tuple(transpose.51, transpose.57, transpose.58, transpose.59) + } // main.61 + )"; + return hlo_text; + } + + template + void TestImpl_Flash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2() { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < + se::dnn::VersionInfo(9, 6, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.6.0."; + } + XlaBuilder builder(TestName()); + // Cudnn sequence packing packs multiple batches(segments) into one batch + // using offsets and seqlen tensors to indicate where each segment begins + std::string hlo_string = + GetModuleFlash_Attention_Training_Sequence_Packing_HloString_BF16(); // NOLINT + // Reference implementation is regular attention with segment mask + std::string hlo_string_ref = + GetModuleFlash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2_HloString_BF16(); // NOLINT + EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, + ErrorSpec{1e-3, 1e-3})); + } +}; + class FlashAttentionBMMScaleSoftmaxBMMF8 : public MultiHeadedAttentionTest {}; class FlashAttentionBMMScaleSoftmaxDropoutBMM @@ -1378,6 +1508,13 @@ XLA_TEST_F(FlashAttentionBMMScaleSlidingWindowMaskSoftmaxBMM, bfloat16>(); // NOLINT } +// BMM1 - Scale - SegmentMask - Softmax - BMM2 +XLA_TEST_F(FlashAttentionBMMScaleSegmentMaskSoftmaxBMM, + Flash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2< + bfloat16>(); // NOLINT +} + absl::string_view GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonRef() { static constexpr absl::string_view hlo_text = R"( diff --git a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc index 0dc92c47d2cb55..67f33164fa2638 100644 --- a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc +++ b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -149,12 +149,14 @@ absl::StatusOr HloCustomCallToCuDnnGraph( GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); const int sliding_window_length = config.sliding_window_length(); + const int max_seg_per_batch = config.max_seg_per_batch(); TF_ASSIGN_OR_RETURN( se::gpu::CudnnGraph graph, se::gpu::GetCudnnFlashAttentionOperationGraph( dnn_support, lhs_bmm1, rhs_bmm1, rhs_bmm2, output, bias, activation, static_cast(config.fmha_scale()), dropout_rate > 0.0, - dropout_rate, dnn_mask_type, sliding_window_length)); + dropout_rate, dnn_mask_type, sliding_window_length, + max_seg_per_batch)); return graph; } else if (IsFwdCustomCallTofMHAF8(*custom_call)) { TF_ASSIGN_OR_RETURN( @@ -230,12 +232,19 @@ absl::StatusOr HloCustomCallToCuDnnGraph( // Unused fwd_output_shape ++input_index; + const int max_seg_per_batch = config.max_seg_per_batch(); if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING || config.mask_type() == - xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) { + xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL || + max_seg_per_batch > 1) { // skip q_seqlen and kv_seqlen input_index += 2; } + + if (max_seg_per_batch > 1) { + // skip q_offsets and kv_offsets + input_index += 2; + } TF_RET_CHECK(input_index == custom_call->operand_count()); int output_index = 0; @@ -312,7 +321,8 @@ absl::StatusOr HloCustomCallToCuDnnGraph( bmm2_grad_gemm1_lhs, bmm2_grad_gemm2_rhs, d_output, d_bmm1_lhs, d_bmm1_rhs, d_bmm2_rhs, bias, dropout_rate, config.seed(), config.fmha_scale(), dropout_rate > 0.0, bias != std::nullopt, - dnn_mask_type, force_deterministic, sliding_window_length)); + dnn_mask_type, force_deterministic, sliding_window_length, + max_seg_per_batch)); return graph; } else { TF_ASSIGN_OR_RETURN( diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc index 57448f9c01319c..cc1494e5096f65 100644 --- a/xla/stream_executor/cuda/cuda_dnn.cc +++ b/xla/stream_executor/cuda/cuda_dnn.cc @@ -4965,6 +4965,10 @@ static absl::StatusOr RebuildExecutionPlan( } // namespace +void FixDimsForRaggedOffset(std::vector& dims, int max_reg_per_batch) { + dims[0] *= max_reg_per_batch; +} + absl::StatusOr GetCudnnFlashAttentionOperationGraph( dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_descriptor, @@ -4974,7 +4978,8 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( const std::optional bias_descriptor, const std::optional stats_descriptor, double scale, const bool use_dropout, const std::optional dropout_rate, - const dnn::FMHAMaskKind mask_type, const int sliding_window_length) { + const dnn::FMHAMaskKind mask_type, const int sliding_window_length, + const int max_seg_per_batch) { using cudnn_frontend::graph::Tensor_attributes; #if CUDNN_VERSION >= 90000 @@ -5007,23 +5012,34 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + std::vector q_dims = q_descriptor.GetCudnnCompatibleDimensions(true); + std::vector k_dims = k_descriptor.GetCudnnCompatibleDimensions(true); + std::vector v_dims = + v_descriptor.GetCudnnCompatibleDimensions(false); + + if (max_seg_per_batch > 1) { + FixDimsForRaggedOffset(q_dims, max_seg_per_batch); + FixDimsForRaggedOffset(k_dims, max_seg_per_batch); + FixDimsForRaggedOffset(v_dims, max_seg_per_batch); + } + std::shared_ptr q_tensor = graph.tensor(Tensor_attributes() .set_name("Q") - .set_dim(q_descriptor.GetCudnnCompatibleDimensions(true)) + .set_dim(q_dims) .set_stride(q_descriptor.GetCudnnCompatibleStrides(true)) .set_uid(next_uid())); std::shared_ptr k_tensor = graph.tensor(Tensor_attributes() .set_name("K") - .set_dim(k_descriptor.GetCudnnCompatibleDimensions(true)) + .set_dim(k_dims) .set_stride(k_descriptor.GetCudnnCompatibleStrides(true)) .set_uid(next_uid())); std::shared_ptr v_tensor = graph.tensor( Tensor_attributes() .set_name("V") - .set_dim(v_descriptor.GetCudnnCompatibleDimensions(false)) + .set_dim(v_dims) .set_stride(v_descriptor.GetCudnnCompatibleStrides(false)) .set_uid(next_uid())); @@ -5049,9 +5065,9 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( // Setting actual seqlen bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - if (is_padding) { - auto q_dim = q_descriptor.GetCudnnCompatibleDimensions(true); - auto b = q_dim[0]; + if (is_padding || max_seg_per_batch > 1) { + // Get batch size + auto b = q_dims[0]; auto seq_q_tensor = graph.tensor(Tensor_attributes() .set_name("seq_q") @@ -5070,6 +5086,30 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( sdpa_options.set_seq_len_q(seq_q_tensor); sdpa_options.set_seq_len_kv(seq_kv_tensor); } + + std::shared_ptr offset_q; + if (max_seg_per_batch > 1) { + // Get batch size + auto b = q_dims[0]; + offset_q = + graph.tensor(Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::INT32)); + auto offset_kv = + graph.tensor(Tensor_attributes() + .set_name("offset_kv") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::INT32)); + q_tensor->set_ragged_offset(offset_q); + k_tensor->set_ragged_offset(offset_kv); + v_tensor->set_ragged_offset(offset_kv); + } + // Setting seed and offset std::shared_ptr seed_tensor; std::shared_ptr offset_tensor; @@ -5100,10 +5140,16 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( auto [o_tensor, stats_tensor] = graph.sdpa(q_tensor, k_tensor, v_tensor, sdpa_options); + auto o_dims = o_descriptor.dimensions(); + + if (max_seg_per_batch > 1) { + FixDimsForRaggedOffset(o_dims, max_seg_per_batch); + o_tensor->set_ragged_offset(offset_q); + } // Set output attributes. o_tensor->set_name("O") .set_output(true) - .set_dim(o_descriptor.dimensions()) + .set_dim(o_dims) .set_stride(o_descriptor.GetLogicalStrides()) .set_uid(next_uid()); if (stats_descriptor.has_value()) { @@ -5488,7 +5534,8 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( const std::optional bias_descriptor, std::optional dropout_rate, std::optional seed, double scale, bool use_dropout, bool use_bias, dnn::FMHAMaskKind mask_type, - bool force_deterministic, const int sliding_window_length) { + bool force_deterministic, const int sliding_window_length, + const int max_seg_per_batch) { #if CUDNN_VERSION >= 90000 if (VLOG_IS_ON(4)) { VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): " << q_desc.ToString() @@ -5514,19 +5561,38 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) .set_io_data_type(ioDataType); - auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); - auto p_strides = p_desc.GetCudnnCompatibleStrides(false); - std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); - p_reduction_dims.push_back(1); - + // Get dims and strides + std::vector q_dims = q_desc.GetCudnnCompatibleDimensions(false); + std::vector k_dims = k_desc.GetCudnnCompatibleDimensions(false); + std::vector v_dims = v_desc.GetCudnnCompatibleDimensions(true); + std::vector p_dims = p_desc.GetCudnnCompatibleDimensions(false); + std::vector p_strides = p_desc.GetCudnnCompatibleStrides(false); + std::vector do_dims = do_desc.GetCudnnCompatibleDimensions(false); + std::vector dq_dims = dq_desc.dimensions(); + std::vector dk_dims = dk_desc.dimensions(); + std::vector dv_dims = dv_desc.dimensions(); + std::vector stats_dims(p_dims.begin(), p_dims.end() - 1); + stats_dims.push_back(1); // Divide every stride by the last dim value. - std::vector p_reduction_strides; - p_reduction_strides.reserve(p_strides.size()); + std::vector stats_strides; + stats_strides.reserve(p_strides.size()); int64_t p_reduced_dim_len = p_dims.back(); for (auto stride : p_strides) { - p_reduction_strides.push_back(stride / p_reduced_dim_len); + stats_strides.push_back(stride / p_reduced_dim_len); + } + stats_strides[3] = 1; + + if (max_seg_per_batch > 1) { + FixDimsForRaggedOffset(q_dims, max_seg_per_batch); + FixDimsForRaggedOffset(k_dims, max_seg_per_batch); + FixDimsForRaggedOffset(v_dims, max_seg_per_batch); + FixDimsForRaggedOffset(p_dims, max_seg_per_batch); + FixDimsForRaggedOffset(do_dims, max_seg_per_batch); + FixDimsForRaggedOffset(dq_dims, max_seg_per_batch); + FixDimsForRaggedOffset(dk_dims, max_seg_per_batch); + FixDimsForRaggedOffset(dv_dims, max_seg_per_batch); + FixDimsForRaggedOffset(stats_dims, max_seg_per_batch); } - p_reduction_strides[3] = 1; bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; auto sdpa_backward_options = @@ -5541,52 +5607,51 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( std::shared_ptr q = graph.tensor(Tensor_attributes() .set_name("Q") - .set_dim(q_desc.GetCudnnCompatibleDimensions(false)) + .set_dim(q_dims) .set_stride(q_desc.GetCudnnCompatibleStrides(false)) .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr k = graph.tensor(Tensor_attributes() .set_name("K") - .set_dim(k_desc.GetCudnnCompatibleDimensions(false)) + .set_dim(k_dims) .set_stride(k_desc.GetCudnnCompatibleStrides(false)) .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr v = graph.tensor(Tensor_attributes() .set_name("V") - .set_dim(v_desc.GetCudnnCompatibleDimensions(true)) + .set_dim(v_dims) .set_stride(v_desc.GetCudnnCompatibleStrides(true)) .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr stats = graph.tensor(Tensor_attributes() .set_name("stats") - .set_dim(p_reduction_dims) - .set_stride(p_reduction_strides) + .set_dim(stats_dims) + .set_stride(stats_strides) .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::FLOAT)); std::shared_ptr dO = graph.tensor(Tensor_attributes() .set_name("dO") - .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_dim(do_dims) .set_stride(do_desc.GetCudnnCompatibleStrides(false)) .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr d_bias_tensor; if (use_bias) { DCHECK(bias_descriptor != std::nullopt); - auto bias_dim = bias_descriptor->dimensions(); - auto q_dim = q_desc.GetCudnnCompatibleDimensions(false); - auto b = bias_dim[0]; - auto n = bias_dim[1]; - auto q_n = q_dim[1]; - auto bias_tensor = - graph.tensor(Tensor_attributes() - .set_name("bias") - .set_dim(bias_descriptor->dimensions()) - .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(next_uid())); + auto bias_dims = bias_descriptor->dimensions(); + auto bias_strides = bias_descriptor->GetLogicalStrides(); + auto b = bias_dims[0]; + auto n = bias_dims[1]; + auto q_n = q_dims[1]; + auto bias_tensor = graph.tensor(Tensor_attributes() + .set_name("bias") + .set_dim(bias_dims) + .set_stride(bias_strides) + .set_uid(next_uid())); sdpa_backward_options.set_bias(bias_tensor); // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] are not supported for @@ -5604,7 +5669,7 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( std::shared_ptr o = graph.tensor(Tensor_attributes() .set_name("O") - .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_dim(do_dims) .set_stride(do_desc.GetCudnnCompatibleStrides(false)) .set_uid(next_uid()) .set_data_type(ioDataType)); @@ -5612,9 +5677,10 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( // Setting actual seqlen bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - if (is_padding) { - auto q_dim = q_desc.GetCudnnCompatibleDimensions(false); - auto b = q_dim[0]; + + if (is_padding || max_seg_per_batch > 1) { + // Get batch size + auto b = q_dims[0]; auto seq_q_tensor = graph.tensor(Tensor_attributes() .set_name("seq_q") @@ -5633,6 +5699,31 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( sdpa_backward_options.set_seq_len_q(seq_q_tensor); sdpa_backward_options.set_seq_len_kv(seq_kv_tensor); } + + std::shared_ptr offset_q, offset_kv; + if (max_seg_per_batch > 1) { + // Get batch size + auto b = q_dims[0]; + offset_q = + graph.tensor(Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::INT32)); + offset_kv = + graph.tensor(Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::INT32)); + q->set_ragged_offset(offset_q); + k->set_ragged_offset(offset_kv); + v->set_ragged_offset(offset_kv); + o->set_ragged_offset(offset_q); + dO->set_ragged_offset(offset_q); + } // Setting seed and offset std::shared_ptr seed_tensor; std::shared_ptr offset_tensor; @@ -5668,20 +5759,25 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( auto [dQ, dK, dV] = graph.sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options); + if (max_seg_per_batch > 1) { + dQ->set_ragged_offset(offset_q); + dK->set_ragged_offset(offset_kv); + dV->set_ragged_offset(offset_kv); + } dQ->set_output(true) - .set_dim(dq_desc.dimensions()) + .set_dim(dq_dims) .set_stride(dq_desc.GetLogicalStrides()) .set_uid(next_uid()) .set_name("dQ") .set_data_type(ioDataType); dK->set_output(true) - .set_dim(dk_desc.dimensions()) + .set_dim(dk_dims) .set_stride(dk_desc.GetLogicalStrides()) .set_uid(next_uid()) .set_name("dK") .set_data_type(ioDataType); dV->set_output(true) - .set_dim(dv_desc.dimensions()) + .set_dim(dv_dims) .set_stride(dv_desc.GetLogicalStrides()) .set_uid(next_uid()) .set_name("dV") diff --git a/xla/stream_executor/cuda/cuda_dnn.h b/xla/stream_executor/cuda/cuda_dnn.h index 78a43f654b7641..9d46794e2329b8 100644 --- a/xla/stream_executor/cuda/cuda_dnn.h +++ b/xla/stream_executor/cuda/cuda_dnn.h @@ -707,7 +707,8 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( const std::optional bias_descriptor, const std::optional stats_descriptor, double scale, const bool use_dropout, const std::optional dropout_rate, - const dnn::FMHAMaskKind mask_type, const int sliding_window_length); + const dnn::FMHAMaskKind mask_type, const int sliding_window_length, + const int max_seg_per_batch); absl::StatusOr GetCudnnFlashAttentionF8OperationGraph( dnn::DnnSupport& dnn_support, @@ -730,7 +731,7 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( std::optional dropout_rate, std::optional seed, double scale, bool use_dropout, bool use_bias, const dnn::FMHAMaskKind mask_type, bool force_deterministic, - const int sliding_window_length); + const int sliding_window_length, const int max_seg_per_batch); absl::StatusOr GetCudnnFlashAttentionBackwardF8OperationGraph( dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc, From 684ce4ce45d56cf31a7f06bafaccec7b1caa3dc8 Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Tue, 7 Jan 2025 14:08:30 -0800 Subject: [PATCH 08/45] PR #20604: hlo_instruction_utils had no tests. Adding them. Imported from GitHub PR https://github.com/openxla/xla/pull/20604 See title. Copybara import of the project: -- 7bc8052999822b879173448ddc79c949cca10339 by Shraiysh Vaishay : hlo_instruction_utils had no tests. Adding them. -- 318444c8b9cc20301b5584c3b9a926d012a8878e by Shraiysh Vaishay : Addressed comments Merging this change closes #20604 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/20604 from shraiysh:add_tests_for_instruction_utils 318444c8b9cc20301b5584c3b9a926d012a8878e PiperOrigin-RevId: 713036733 --- xla/hlo/ir/BUILD | 13 ++++ xla/hlo/ir/hlo_instruction_utils_test.cc | 89 ++++++++++++++++++++++++ 2 files changed, 102 insertions(+) create mode 100644 xla/hlo/ir/hlo_instruction_utils_test.cc diff --git a/xla/hlo/ir/BUILD b/xla/hlo/ir/BUILD index d1b6e0a409ee4b..560f46517e07ed 100644 --- a/xla/hlo/ir/BUILD +++ b/xla/hlo/ir/BUILD @@ -232,6 +232,19 @@ cc_library( ], ) +xla_cc_test( + name = "hlo_instruction_utils_test", + srcs = ["hlo_instruction_utils_test.cc"], + deps = [ + ":hlo", + ":hlo_instruction_utils", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_query", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "hlo_reachability", hdrs = ["hlo_reachability.h"], diff --git a/xla/hlo/ir/hlo_instruction_utils_test.cc b/xla/hlo/ir/hlo_instruction_utils_test.cc new file mode 100644 index 00000000000000..fe8c488b154e88 --- /dev/null +++ b/xla/hlo/ir/hlo_instruction_utils_test.cc @@ -0,0 +1,89 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/ir/hlo_instruction_utils.h" + +#include +#include +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/utils/hlo_query.h" + +namespace xla { + +namespace hlo_instruction_utils { + +namespace { + +class HloInstructionUtilsTest : public HloHardwareIndependentTestBase {}; + +TEST_F(HloInstructionUtilsTest, TestIsUnstridedSlice) { + const char* hlo_text = R"( + HloModule test + ENTRY main { + param = f32[2,8] parameter(0) + strided_slice = f32[2,2] slice(param), slice={[0:2:1], [4:8:2]} + unstrided_slice = f32[2,4] slice(param), slice={[0:2:1], [4:8:1]} + ROOT tuple = (f32[2,2], f32[2,4]) tuple(strided_slice, unstrided_slice) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + HloInstruction* unstrided_slice = + hlo_query::FindInstruction(m->entry_computation(), "unstrided_slice"); + HloInstruction* strided_slice = + hlo_query::FindInstruction(m->entry_computation(), "strided_slice"); + EXPECT_NE(unstrided_slice, nullptr); + EXPECT_NE(strided_slice, nullptr); + EXPECT_TRUE(IsUnstridedSlice(unstrided_slice)); + EXPECT_FALSE(IsUnstridedSlice(strided_slice)); +} + +TEST_F(HloInstructionUtilsTest, TestAddOrUpdateVectorOfPairsAsAttribute) { + const char* hlo = R"( + HloModule test + ENTRY main { + ROOT param = s32[] parameter(0), frontend_attributes={foo="bar", baz="qux"} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + HloInstruction* param = m->entry_computation()->root_instruction(); + EXPECT_EQ(param->frontend_attributes().map().size(), 2); + EXPECT_EQ(param->frontend_attributes().map().at("foo"), "bar"); + EXPECT_EQ(param->frontend_attributes().map().at("baz"), "qux"); + + std::string new_key = "quux"; + std::vector> value = {{1, 2}, {3, 4}}; + AddOrUpdateVectorOfPairsAsAttribute(param, new_key, value); + EXPECT_EQ(param->frontend_attributes().map().size(), 3); + EXPECT_EQ(param->frontend_attributes().map().at("foo"), "bar"); + EXPECT_EQ(param->frontend_attributes().map().at("baz"), "qux"); + EXPECT_EQ(param->frontend_attributes().map().at("quux"), "{{1,2},{3,4}}"); + + std::vector> new_value = {{5, 6}, {7, 8}}; + AddOrUpdateVectorOfPairsAsAttribute(param, new_key, new_value); + EXPECT_EQ(param->frontend_attributes().map().size(), 3); + EXPECT_EQ(param->frontend_attributes().map().at("foo"), "bar"); + EXPECT_EQ(param->frontend_attributes().map().at("baz"), "qux"); + EXPECT_EQ(param->frontend_attributes().map().at("quux"), "{{5,6},{7,8}}"); +} + +} // namespace + +} // namespace hlo_instruction_utils + +} // namespace xla From 3cfe81f6cbbbc08e452d23c83691d69f0676f74f Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Tue, 7 Jan 2025 15:11:41 -0800 Subject: [PATCH 09/45] Forward `use_spmd_partitioning` in HloRunnerPjRt. This patch also removes an unused and redundant invocation of `GenerateDefaultCompileOptions`. PiperOrigin-RevId: 713056450 --- xla/service/hlo_runner_pjrt.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xla/service/hlo_runner_pjrt.cc b/xla/service/hlo_runner_pjrt.cc index 9a2d0c72955516..dce3bc9e1ca5be 100644 --- a/xla/service/hlo_runner_pjrt.cc +++ b/xla/service/hlo_runner_pjrt.cc @@ -237,6 +237,9 @@ absl::StatusOr HloRunnerPjRt::GenerateDefaultCompileOptions( compile_options.executable_build_options.set_result_layout( module->entry_computation_layout().result_shape()); + compile_options.executable_build_options.set_use_spmd_partitioning( + module->config().use_spmd_partitioning()); + return compile_options; } @@ -328,9 +331,6 @@ absl::StatusOr HloRunnerPjRt::Execute( ExecutionProfile* profile) { // TODO (b/245550554) : Remove UpdateEntryComputationLayout from runner. UpdateEntryComputationLayout(module.get()); - TF_ASSIGN_OR_RETURN(auto compile_options, GenerateDefaultCompileOptions( - module.get(), run_hlo_passes)); - TF_ASSIGN_OR_RETURN(auto executable, CreateExecutable(std::move(module), run_hlo_passes)); From 913c11f28f35c22fe6d8346f3fcab6493db4819c Mon Sep 17 00:00:00 2001 From: xla authors Date: Tue, 7 Jan 2025 15:22:32 -0800 Subject: [PATCH 10/45] Fix resource number calculation in the latency hiding scheduler. PiperOrigin-RevId: 713059583 --- xla/service/latency_hiding_scheduler.cc | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/xla/service/latency_hiding_scheduler.cc b/xla/service/latency_hiding_scheduler.cc index 6532e9c9934079..d199e1f046daa0 100644 --- a/xla/service/latency_hiding_scheduler.cc +++ b/xla/service/latency_hiding_scheduler.cc @@ -384,19 +384,17 @@ AsyncTracker::RecursivelyComputeResourceMap( int64_t AsyncTracker::GetNumResourcesPerInstruction( int64_t resource_type, const HloInstruction& instr) const { - // For instructions not calling a computation then return 1 if the instruction - // has opcode equal to 'async_done' + // For instructions not calling a computation, or async start/done + // instructions, we directly check the resources from the instruction. if (instr.called_computations().empty() || instr.opcode() == HloOpcode::kAsyncStart || instr.opcode() == HloOpcode::kAsyncDone) { - return absl::c_any_of(GetResourcesFromInstruction(instr), - [resource_type](const ResourcePair& resource) { - return resource.second == - ResourceUsageType::kResourceOccupy && - (resource_type == resource.first); - }) - ? 1 - : 0; + return absl::c_count_if(GetResourcesFromInstruction(instr), + [resource_type](const ResourcePair& resource) { + return resource.second == + ResourceUsageType::kResourceOccupy && + (resource_type == resource.first); + }); } int64_t num_resources = 0; for (const HloComputation* computation : instr.called_computations()) { From e6723327914e67d2a291e93f46a189920c885cb1 Mon Sep 17 00:00:00 2001 From: xla authors Date: Tue, 7 Jan 2025 15:58:57 -0800 Subject: [PATCH 11/45] Improve speed and collision/aliasing resistance of Absl::HashOf() on HloModule/HloComputation: * Rather that hashing only opcodes + output/operand shapes (in hlo_instruction.h), build the hash progressively (in hlo_computation.h) walking the instructions in post-order, hashing opcode, shape and other constants (e.g. parameter value, literal value) once per instruction * Add wrapper to support Absh::Hash on Literals * Add tests covering parameter/literal values, instruction reordering etc. PiperOrigin-RevId: 713070440 --- xla/hlo/ir/hlo_computation.h | 18 ++- xla/hlo/ir/hlo_instruction.h | 22 ++-- xla/hlo/ir/hlo_instructions.h | 22 ++++ xla/hlo/ir/hlo_module_test.cc | 229 +++++++++++++++++++++++++++++++++- xla/literal.h | 17 ++- 5 files changed, 286 insertions(+), 22 deletions(-) diff --git a/xla/hlo/ir/hlo_computation.h b/xla/hlo/ir/hlo_computation.h index 757505980a079e..4411e3102b5a26 100644 --- a/xla/hlo/ir/hlo_computation.h +++ b/xla/hlo/ir/hlo_computation.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_COMPUTATION_H_ #define XLA_HLO_IR_HLO_COMPUTATION_H_ +#include #include #include #include @@ -28,6 +29,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -420,11 +422,23 @@ class HloComputation { // with respect to HloComputation::Equal() method. template friend H AbslHashValue(H h, const HloComputation& computation) { + // Walk the computation in post-order, computing (and caching) the + // Absl::Hash after each instruction to use to as an operand for + // subsequent instructions. auto instructions = computation.MakeInstructionPostOrder(); + absl::flat_hash_map instruction_hash_cache; + instruction_hash_cache.reserve(instructions.size()); for (auto* instruction : instructions) { - h = H::combine(std::move(h), *instruction); + absl::InlinedVector operand_hashes; + for (auto* operand : instruction->operands()) { + operand_hashes.push_back(instruction_hash_cache[operand]); + } + instruction_hash_cache.emplace( + instruction, absl::HashOf(*instruction, operand_hashes)); } - return H::combine(std::move(h), instructions.size()); + return H::combine(std::move(h), + instruction_hash_cache[computation.root_instruction()], + instructions.size()); } using InstructionSequence = tsl::gtl::iterator_range< diff --git a/xla/hlo/ir/hlo_instruction.h b/xla/hlo/ir/hlo_instruction.h index cd8d5368cc8320..db3d994215963b 100644 --- a/xla/hlo/ir/hlo_instruction.h +++ b/xla/hlo/ir/hlo_instruction.h @@ -41,6 +41,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -1736,27 +1737,20 @@ class HloInstruction { /*ignore_commutative_operand_order=*/true); } + // Allow subclasses to contribute additional attributes to the hash. + virtual void HashAdditionalAttributes(absl::HashState h) const {}; + // Generates a hash value of an HLO instruction. Hash considers - // information on opcode, shape, operands, and typically a root instruction. - // This function returns the same hash value for equivalent HLO instructions, - // with respect to HloInstruction::Identical() method. - // TODO(majnemer): Make the comment here more crisp & accurate. + // information on opcode, shape, number of operands, and other relevant + // additional attributes (e.g. literal values, parameters, etc.). template friend H AbslHashValue(H h, const HloInstruction& hlo) { h = H::combine(std::move(h), hlo.opcode(), hlo.shape()); - if (!hlo.IsCrossModuleAllReduce()) { - for (size_t i = 0; i < hlo.operands().size(); ++i) { - h = H::combine(std::move(h), hlo.operand(i)->shape()); - } h = H::combine(std::move(h), hlo.operand_count()); } - - if (hlo.opcode() == HloOpcode::kFusion) { - h = H::combine(std::move(h), *hlo.fused_expression_root(), - hlo.fusion_kind(), hlo.fused_instruction_count(), - hlo.fused_parameters().size()); - } + // Allow subclasses to mix additional data into h before returning + hlo.HashAdditionalAttributes(absl::HashState::Create(&h)); return h; } diff --git a/xla/hlo/ir/hlo_instructions.h b/xla/hlo/ir/hlo_instructions.h index 1ca2bfddd55592..c21dddeee907b5 100644 --- a/xla/hlo/ir/hlo_instructions.h +++ b/xla/hlo/ir/hlo_instructions.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_INSTRUCTIONS_H_ #define XLA_HLO_IR_HLO_INSTRUCTIONS_H_ +#include #include #include #include @@ -28,6 +29,7 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -1356,6 +1358,14 @@ class HloConstantInstruction : public HloInstruction { return false; } + // Add literal to the hash state. + void HashAdditionalAttributes(absl::HashState h) const override { + if (HasLiteral()) { + absl::HashState::combine(std::move(h), + Literal::AbslHashable(literal())); + } + } + private: bool IsElementwiseImpl( const std::optional& operand_idx) const override; @@ -1595,6 +1605,13 @@ class HloFusionInstruction : public HloCallableInstruction { return hlo->opcode() == HloOpcode::kFusion; } + // Add various fusion parameters to the hash. + void HashAdditionalAttributes(absl::HashState h) const override { + absl::HashState::combine(std::move(h), *fused_expression_root(), + fusion_kind(), fused_instruction_count(), + fused_parameters().size()); + } + protected: std::string default_called_computation_name() const override { return "fused_computation"; @@ -1714,6 +1731,11 @@ class HloParameterInstruction : public HloInstruction { return hlo->opcode() == HloOpcode::kParameter; } + // Add parameter number to the hash. + void HashAdditionalAttributes(absl::HashState h) const override { + absl::HashState::combine(std::move(h), parameter_number()); + } + private: void PrintExtraAttributesImpl(AttributePrinter& printer, const HloPrintOptions& options) const override; diff --git a/xla/hlo/ir/hlo_module_test.cc b/xla/hlo/ir/hlo_module_test.cc index 226bf5c892a210..01756318c93ec6 100644 --- a/xla/hlo/ir/hlo_module_test.cc +++ b/xla/hlo/ir/hlo_module_test.cc @@ -32,9 +32,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla { namespace { @@ -204,5 +201,231 @@ TEST(HloModuleTest, CloneWithNewConfig) { m1.config().device_memory_size()); } +TEST(HloModuleTest, AbslHashInstructionOrdering) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Add.0 and add.1 are swapped. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.1 = f32[32,32] add(b, c) // Swapped with below + add.0 = f32[32,32] add(a, b) // Swapped with above + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + EXPECT_EQ(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashInstructionOpcodes) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Second add changed to sub + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] subtract(b, c) // Changed from add to subtract + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashInstructionShapes) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Second add has different shape. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + ENTRY main { + // Shapes changed from [32,32] to [16,16] + a = f32[16,16] parameter(0) + b = f32[16,16] parameter(1) + c = f32[16,16] parameter(2) + add.0 = f32[16,16] add(a, b) + add.1 = f32[16,16] add(b, c) + ROOT result = f32[16,16] add(add.0, add.1) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashInstructionNaming) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Add x to all names + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + // All names changed to x + ax = f32[32,32] parameter(0) + bx = f32[32,32] parameter(1) + cx = f32[32,32] parameter(2) + add.0x = f32[32,32] add(ax, bx) + add.1x = f32[32,32] add(bx, cx) + ROOT resultx = f32[32,32] add(add.0x, add.1x) + } + )")); + + EXPECT_EQ(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashGraphChanges) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Changed from (a+b)+(b+c) to ((a+b)+c)+a + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(add.0, c) // Changed from add(b, c) + ROOT result = f32[32,32] add(add.1, a) // Changed from add(add.0, add.1) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashParameterChanges) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Change parameter numbers + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(1) // Changed from parameter(0) + b = f32[32,32] parameter(0) // Changed from parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashConstantValues) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = s32[32,32] parameter(0) + c = s32[] constant(42) + b = s32[32,32] broadcast(c), dimensions={} + ROOT result = s32[32,32] add(a, b) + } + )")); + + // Changed from 42 to 43 + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = s32[32,32] parameter(0) + c = s32[] constant(43) // Changed from constant(42) + b = s32[32,32] broadcast(c), dimensions={} + ROOT result = s32[32,32] add(a, b) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + } // namespace } // namespace xla diff --git a/xla/literal.h b/xla/literal.h index 0c028bd1aa60ea..1b76f2effe6a94 100644 --- a/xla/literal.h +++ b/xla/literal.h @@ -367,9 +367,9 @@ class LiteralBase { static_assert(sizeof(H) == 0, "Do not use Literal directly as a hash key, because it has " "multiple definitions of equality - layout sensitive or " - "insensitive. Instead, provide an external hash function " - "that uses Literal::Hash which allows you to specify layout " - "sensitivity."); + "insensitive. Instead, use AbslHashable<...>() to create a " + "wrapper with layout sensitivity specified suitable for " + "passing to Absl::Hash"); } // Always use this together with the Equal method and not operator== in order @@ -419,6 +419,17 @@ class LiteralBase { return std::move(state); } + // Templated wrapper struct to control layout sensitivity during Absl::Hash. + template + struct AbslHashable { + const LiteralBase& literal; + explicit AbslHashable(const LiteralBase& l) : literal(l) {} + template + friend H AbslHashValue(H h, const AbslHashable& w) { + return LiteralBase::Hash(std::move(h), w.literal); + } + }; + // Converts this literal to the given shape. Returns an error is the // conversion is not possible. absl::StatusOr ConvertToShape(const Shape& dest_shape) const; From d1f63e2f60ee4ccb73a5e06484f4783eae79420a Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Tue, 7 Jan 2025 16:27:41 -0800 Subject: [PATCH 12/45] Fix `HloRunnerAgnosticTestBase` includes. Many of the tests that extend `HloTestBase` rely on symbols included transitively. The main ones are: - `PlatformUtil` - `LiteralUtil` - `LiteralTestUtil` This patch adds includes for these explicitly. PiperOrigin-RevId: 713079045 --- xla/service/BUILD | 41 ++++++++++++++----- xla/service/cpu/BUILD | 2 + xla/service/cpu/conv_canonicalization_test.cc | 1 + .../cpu/cpu_instruction_fusion_test.cc | 1 + xla/service/gpu/BUILD | 2 + xla/service/gpu/gpu_aot_compilation_test.cc | 2 + xla/service/gpu/tests/BUILD | 4 ++ xla/service/gpu/tests/nop_custom_call_test.cc | 5 +++ xla/service/hlo_creation_utils_test.cc | 12 +++++- xla/service/hlo_module_test.cc | 18 ++++++-- xla/service/hlo_schedule_test.cc | 7 ++-- xla/service/triangular_solve_expander_test.cc | 11 +++-- xla/tests/BUILD | 33 +++++++-------- xla/tests/dot_operation_test.cc | 6 +-- xla/tests/hlo_runner_agnostic_test_base.cc | 15 +++---- xla/tests/hlo_runner_agnostic_test_base.h | 37 +++++------------ xla/tests/replicated_io_feed_test.cc | 1 + 17 files changed, 123 insertions(+), 75 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index b041d7e59211da..bf03c3f7929923 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1837,21 +1837,21 @@ xla_cc_test( name = "hlo_schedule_test", srcs = ["hlo_schedule_test.cc"], deps = [ + ":buffer_value", + "//xla:literal_util", "//xla:shape_util", "//xla:test_helpers", - "//xla:types", "//xla:xla_data_proto_cc", - "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_googletest//:gtest", - "@tsl//tsl/platform:statusor", ], ) @@ -2024,14 +2024,22 @@ xla_cc_test( ":hlo_creation_utils", ":pattern_matcher", ":pattern_matcher_gmock", + "//xla:array2d", + "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", - "@tsl//tsl/platform:test", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", ], ) @@ -2230,13 +2238,16 @@ xla_cc_test( shard_count = 12, deps = [ ":triangular_solve_expander", + "//xla:array2d", + "//xla:error_spec", "//xla:literal", + "//xla:literal_util", "//xla:reference_util", - "//xla:test", - "//xla:types", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", ], ) @@ -3493,25 +3504,35 @@ xla_cc_test( name = "hlo_module_test", srcs = ["hlo_module_test.cc"], deps = [ + ":buffer_value", ":computation_placer_hdr", + ":hlo_module_config", ":test_compilation_environment_proto_cc", - "//xla:literal", + "//xla:comparison_util", + "//xla:debug_options_flags", + "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/lib/strings:proto_serialization", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:casts", + "@tsl//tsl/platform:protobuf", ], ) diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 5dbad693e78471..430def53c196bb 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -1386,6 +1386,7 @@ xla_cc_test( tags = ["not_run:arm"], deps = [ ":cpu_instruction_fusion", + "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -1539,6 +1540,7 @@ xla_cc_test( deps = [ ":conv_canonicalization", ":target_machine_features_stub", + "//xla:literal_util", "//xla:test", "//xla:test_helpers", "//xla:util", diff --git a/xla/service/cpu/conv_canonicalization_test.cc b/xla/service/cpu/conv_canonicalization_test.cc index 00c9ee256452c9..6f6ebd96fb64c2 100644 --- a/xla/service/cpu/conv_canonicalization_test.cc +++ b/xla/service/cpu/conv_canonicalization_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/literal_util.h" #include "xla/service/cpu/target_machine_features_stub.h" #include "xla/test.h" #include "xla/test_helpers.h" diff --git a/xla/service/cpu/cpu_instruction_fusion_test.cc b/xla/service/cpu/cpu_instruction_fusion_test.cc index 6b4de145d8e809..787c4d138b3448 100644 --- a/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/literal_util.h" #include "xla/service/transpose_folding.h" #include "xla/shape.h" #include "xla/tests/hlo_test_base.h" diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 3bf55230507d43..9d8d7f90622176 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1960,6 +1960,7 @@ xla_cc_test( ":amdgpu_compiler_impl", ]) + [ ":gpu_transfer_manager", + "//xla:literal_util", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/service:compiler", @@ -1971,6 +1972,7 @@ xla_cc_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", diff --git a/xla/service/gpu/gpu_aot_compilation_test.cc b/xla/service/gpu/gpu_aot_compilation_test.cc index 945f63a1f87c0d..76efde170bca39 100644 --- a/xla/service/gpu/gpu_aot_compilation_test.cc +++ b/xla/service/gpu/gpu_aot_compilation_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/literal_util.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/gpu/fusions/triton/triton_support.h" @@ -32,6 +33,7 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/gpu/tests/BUILD b/xla/service/gpu/tests/BUILD index 1c524b8c35c70f..3140bdd3bbe03d 100644 --- a/xla/service/gpu/tests/BUILD +++ b/xla/service/gpu/tests/BUILD @@ -889,7 +889,11 @@ xla_test( srcs = ["nop_custom_call_test.cc"], backends = ["gpu"], deps = [ + "//xla:literal", + "//xla:literal_util", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", + "//xla/tsl/platform:test", "@tsl//tsl/platform:test_main", ], ) diff --git a/xla/service/gpu/tests/nop_custom_call_test.cc b/xla/service/gpu/tests/nop_custom_call_test.cc index d979d18aa8ac9d..06df6792eb3e9a 100644 --- a/xla/service/gpu/tests/nop_custom_call_test.cc +++ b/xla/service/gpu/tests/nop_custom_call_test.cc @@ -13,9 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace gpu { diff --git a/xla/service/hlo_creation_utils_test.cc b/xla/service/hlo_creation_utils_test.cc index 252345fbbbc5ff..debabe09c3c51e 100644 --- a/xla/service/hlo_creation_utils_test.cc +++ b/xla/service/hlo_creation_utils_test.cc @@ -15,19 +15,29 @@ limitations under the License. #include "xla/service/hlo_creation_utils.h" +#include #include +#include +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "xla/array2d.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { diff --git a/xla/service/hlo_module_test.cc b/xla/service/hlo_module_test.cc index 339feeb8fd2d4e..960f107c9117b9 100644 --- a/xla/service/hlo_module_test.cc +++ b/xla/service/hlo_module_test.cc @@ -24,25 +24,37 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/comparison_util.h" +#include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_original_value.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/buffer_value.h" #include "xla/service/computation_placer.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/test_compilation_environment.pb.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/casts.h" +#include "tsl/platform/protobuf.h" namespace xla { diff --git a/xla/service/hlo_schedule_test.cc b/xla/service/hlo_schedule_test.cc index d18c8527893c81..fd89bcc5b23fc5 100644 --- a/xla/service/hlo_schedule_test.cc +++ b/xla/service/hlo_schedule_test.cc @@ -22,19 +22,20 @@ limitations under the License. #include #include "absl/algorithm/container.h" #include "absl/log/log.h" -#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" +#include "xla/literal_util.h" +#include "xla/service/buffer_value.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/xla/service/triangular_solve_expander_test.cc b/xla/service/triangular_solve_expander_test.cc index fa382b24d0d9db..1a2ba8c71ece6e 100644 --- a/xla/service/triangular_solve_expander_test.cc +++ b/xla/service/triangular_solve_expander_test.cc @@ -15,15 +15,20 @@ limitations under the License. #include "xla/service/triangular_solve_expander.h" +#include #include +#include #include +#include "xla/array2d.h" +#include "xla/error_spec.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/reference_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { diff --git a/xla/tests/BUILD b/xla/tests/BUILD index e8ab69dffb4dce..dbe2e5f2eebe11 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -214,35 +214,27 @@ cc_library( deps = [ ":literal_test_util", ":test_utils", - "//xla:debug_options_flags", "//xla:error_spec", "//xla:literal", - "//xla:literal_util", - "//xla:shape_layout", "//xla:shape_util", "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_module_group", - "//xla/hlo/pass:hlo_pass", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/testlib:verified_hlo_module", - "//xla/hlo/utils:hlo_query", - "//xla/service:backend", - "//xla/service:computation_layout", "//xla/service:computation_placer_hdr", "//xla/service:executable", "//xla/service:hlo_module_config", "//xla/service:hlo_module_util", - "//xla/service:hlo_runner", "//xla/service:hlo_runner_interface", "//xla/service:hlo_verifier", "//xla/service:interpreter_plugin", # reference backend - "//xla/service:platform_util", - "//xla/stream_executor:device_memory_allocator", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "//xla/tsl/platform:test", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -252,10 +244,7 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test", + "@tsl//tsl/platform:protobuf", ], ) @@ -979,6 +968,8 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1024,7 +1015,10 @@ xla_test( "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:matrix", "//xla/hlo/parser:hlo_parser", + "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", @@ -1070,7 +1064,10 @@ xla_test( "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:matrix", "//xla/hlo/parser:hlo_parser", + "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", @@ -1158,7 +1155,10 @@ xla_test( "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:matrix", "//xla/hlo/parser:hlo_parser", + "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", @@ -2560,6 +2560,7 @@ xla_test( backends = ["gpu"], deps = [ ":hlo_test_base", + ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:literal", diff --git a/xla/tests/dot_operation_test.cc b/xla/tests/dot_operation_test.cc index 674ada04d96c30..2acc860804d0d6 100644 --- a/xla/tests/dot_operation_test.cc +++ b/xla/tests/dot_operation_test.cc @@ -22,21 +22,21 @@ limitations under the License. #include "xla/array3d.h" #include "xla/client/local_client.h" #include "xla/error_spec.h" -#include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/lib/matrix.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/reference_util.h" +#include "xla/service/platform_util.h" #include "xla/shape_util.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" #include "tsl/platform/ml_dtypes.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" diff --git a/xla/tests/hlo_runner_agnostic_test_base.cc b/xla/tests/hlo_runner_agnostic_test_base.cc index 402159a1858530..b781a0eebd37d0 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.cc +++ b/xla/tests/hlo_runner_agnostic_test_base.cc @@ -30,19 +30,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" -#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" -#include "xla/debug_options_flags.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module_group.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/verified_hlo_module.h" -#include "xla/hlo/utils/hlo_query.h" #include "xla/literal.h" #include "xla/service/computation_placer.h" #include "xla/service/executable.h" @@ -53,11 +47,12 @@ limitations under the License. #include "xla/shape.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "tsl/platform/protobuf.h" namespace xla { diff --git a/xla/tests/hlo_runner_agnostic_test_base.h b/xla/tests/hlo_runner_agnostic_test_base.h index e43ddec3e28926..9b8ae26f615f45 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.h +++ b/xla/tests/hlo_runner_agnostic_test_base.h @@ -24,7 +24,6 @@ limitations under the License. #include #include -#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -35,31 +34,17 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_module_group.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/verified_hlo_module.h" -#include "xla/layout.h" #include "xla/literal.h" -#include "xla/literal_util.h" -#include "xla/service/backend.h" -#include "xla/service/computation_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/executable.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_runner.h" #include "xla/service/hlo_runner_interface.h" -#include "xla/service/hlo_verifier.h" -#include "xla/service/platform_util.h" -#include "xla/shape_layout.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/test_helpers.h" -#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { @@ -189,7 +174,7 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // backend, but it might need to be tailored so that it is able to run on the // reference backend. Note that the program shape of the module must not be // modified. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( + ::testing::AssertionResult RunAndCompare( std::unique_ptr module, absl::Span arguments, const std::optional& error, const std::function& reference_preprocessor = nullptr, @@ -197,14 +182,14 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // Same as above, except that the module will be executed without Hlo // optimization. - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, absl::Span arguments, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr); // Executes an hlo module with fake inputs and compares the results. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( + ::testing::AssertionResult RunAndCompare( std::unique_ptr module, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr, @@ -212,26 +197,26 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // Same as above, except that the module will be executed without Hlo // optimization. - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr); // Executes an hlo module with fake inputs and checks that the execution is // successful. - [[nodiscard]] ::testing::AssertionResult Run( + ::testing::AssertionResult Run( std::unique_ptr module, bool run_hlo_passes, const std::function& test_preprocessor = nullptr); // Convenient wrappers for executing and comparing an hlo module with fake // input. Module can be passed in directly, or parsed from an hlo_string, // or loaded from a file. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( + ::testing::AssertionResult RunAndCompare( absl::string_view hlo_string, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr, std::optional args_max_bits_of_precision = std::nullopt); - [[nodiscard]] ::testing::AssertionResult Run( + ::testing::AssertionResult Run( absl::string_view hlo_string, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr, const tsl::protobuf::Message* backend_config = nullptr, @@ -299,19 +284,19 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { const std::optional& error, bool run_hlo_passes = true); // Executes an hlo module with fake inputs on multiple replicas. - [[nodiscard]] ::testing::AssertionResult RunReplicated( + ::testing::AssertionResult RunReplicated( absl::string_view hlo_string, bool run_hlo_passes = true, int64_t num_replicas = 1, const tsl::protobuf::Message* backend_config = nullptr); // If assert_determinism is true, the assertion will fail unless all runs // produce exactly the same output. - [[nodiscard]] ::testing::AssertionResult RunMultipleTimes( + ::testing::AssertionResult RunMultipleTimes( absl::string_view hlo_string, bool run_hlo_passes, std::vector* profiles, const tsl::protobuf::Message* backend_config = nullptr, bool assert_determinism = false); - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + ::testing::AssertionResult RunAndCompareNoHloPasses( absl::string_view hlo_string, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr); diff --git a/xla/tests/replicated_io_feed_test.cc b/xla/tests/replicated_io_feed_test.cc index 415faa01ff89e7..194697936e13af 100644 --- a/xla/tests/replicated_io_feed_test.cc +++ b/xla/tests/replicated_io_feed_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tsl/lib/core/status_test_util.h" From 96f079065d65bcf423107e0530d9a6c0b11cb683 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 7 Jan 2025 16:31:26 -0800 Subject: [PATCH 13/45] [XLA:Python] Fix three concurrency problems. These problems can be reproduced even with the GIL enabled, they are not no-GIL bugs. In pmap_lib.cc, defend against a use after free in the following scenario: * thread A misses in the compilation cache and calls `cache_miss()` to populate the cache, relying on the new entry in executables_ remaining alive. * thread B calls `cache_clear()`, which erases the contents of `executables_` Use a std::shared_ptr to keep the entry alive. In pjit.cc, refactor PjitFunctionStore to use a doubly-linked list of PjitFunctionObject entries. When consuming the list of functions in the store, take strong references to them. This prevents a use-after-free if the cache is cleared concurrently multiple times. In pjit.cc, do not add functions to the PjitFunctionStore until executables_ is populated. This avoids a null pointer dereference from a concurrent call to `cache_clear`. Problems found with some upcoming test infrastructure that runs JAX test cases in parallel. PiperOrigin-RevId: 713080199 --- xla/python/BUILD | 1 - xla/python/pjit.cc | 110 ++++++++++++++++++++++++++--------------- xla/python/pmap_lib.cc | 21 ++++---- 3 files changed, 80 insertions(+), 52 deletions(-) diff --git a/xla/python/BUILD b/xla/python/BUILD index 4b024b2c8c3df0..dd7b53028fa6e5 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -735,7 +735,6 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", "@com_google_absl//absl/status", diff --git a/xla/python/pjit.cc b/xla/python/pjit.cc index d492311a81ba45..88c3d7c9bd5fb0 100644 --- a/xla/python/pjit.cc +++ b/xla/python/pjit.cc @@ -34,7 +34,6 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" #include "absl/status/status.h" @@ -325,6 +324,11 @@ class PjitFunction { executables_->Clear(); } + std::shared_ptr executables() { + nb::ft_object_guard lock(cache_); + return executables_; + } + nb::object PythonSignature() { if (!fun_.has_value()) { throw nb::value_error( @@ -362,41 +366,6 @@ class PjitFunction { std::shared_ptr executables_; }; -// Thread-safe. -class PjitFunctionStore { - public: - void Insert(PjitFunction* function) { - nb::ft_lock_guard lock(mu_); - compiled_functions_.insert(function); - } - - void Erase(PjitFunction* function) { - nb::ft_lock_guard lock(mu_); - compiled_functions_.erase(function); - } - - void ClearFunctionCache() { - absl::flat_hash_set functions; - { - nb::ft_lock_guard lock(mu_); - std::swap(functions, compiled_functions_); - } - for (auto* function : functions) { - function->ClearCache(); - } - } - - private: - // Protected by the GIL in GIL mode, and by mu_ in freethreading mode. - nb::ft_mutex mu_; - absl::flat_hash_set compiled_functions_; -}; - -PjitFunctionStore& GetGlobalPjitFunctionStore() { - static auto* const store = new PjitFunctionStore(); - return *store; -} - PjitFunction::PjitFunction( std::string function_name, std::optional fun, nb::callable cache_miss, std::vector static_argnums, @@ -418,8 +387,6 @@ PjitFunction::PjitFunction( PyUnicode_InternInPlace(&s); static_argnames_.push_back(nb::steal(s)); } - - GetGlobalPjitFunctionStore().Insert(this); } void PjitFunction::InitExecutables() { @@ -432,7 +399,7 @@ void PjitFunction::InitExecutables() { } } -PjitFunction::~PjitFunction() { GetGlobalPjitFunctionStore().Erase(this); } +PjitFunction::~PjitFunction() = default; void CallShardArgFallback( nb::handle arg, nb::handle sharding, nb::handle layout, @@ -969,8 +936,64 @@ struct PjitFunctionObject { #endif // PY_VERSION_HEX < 0x030C0000 vectorcallfunc vectorcall; PjitFunction fun; + + // Doubly-linked list of PjitFunctionObjects, protected by + // PjitFunctionStore::mu_ or the GIL in GIL mode. + PjitFunctionObject* next; + PjitFunctionObject* prev; }; +// Contains a list of all PjitFunctionObjects. +// Thread-safe. +class PjitFunctionStore { + public: + void Insert(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + o->next = compiled_functions_; + o->prev = nullptr; + if (o->next) { + o->next->prev = o; + } + compiled_functions_ = o; + } + + void Remove(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + if (o->next) { + o->next->prev = o->prev; + } + if (o->prev) { + o->prev->next = o->next; + } else { + compiled_functions_ = o->next; + } + } + + void ClearCaches() { + std::vector< + std::pair>> + caches; + { + nb::ft_lock_guard lock(mu_); + for (PjitFunctionObject* fn = compiled_functions_; fn != nullptr; + fn = fn->next) { + caches.emplace_back(fn->fun.cache(), fn->fun.executables()); + } + } + for (auto& [cache, executables] : caches) { + nb::ft_object_guard lock(cache); + executables->Clear(); + } + }; + + private: + // Protected by the GIL in GIL mode, and by mu_ in freethreading mode. + nb::ft_mutex mu_; + PjitFunctionObject* compiled_functions_; +}; + +PjitFunctionStore pjit_function_store; + PyObject* PjitFunction_Type = nullptr; bool PjitFunction::IsPjitFunction(nb::handle handle) { @@ -1036,6 +1059,7 @@ void PjitFunction_tp_dealloc(PyObject* self) { PyObject_GC_UnTrack(self); PyTypeObject* tp = Py_TYPE(self); PjitFunctionObject* o = reinterpret_cast(self); + pjit_function_store.Remove(o); PyObject_ClearWeakRefs(self); #if PY_VERSION_HEX < 0x030C0000 Py_CLEAR(o->dict); @@ -1125,6 +1149,7 @@ void InitializePjitFunction( xla::nb_class_ptr pytree_registry, nb::callable shard_arg_fallback, xla::nb_class_ptr cache) { + fn_obj->next = fn_obj->prev = nullptr; if (nb::isinstance(global_cache_key)) { global_cache_key = nb::tuple(global_cache_key); } @@ -1136,6 +1161,10 @@ void InitializePjitFunction( // Handled separately because it is not exception safe to call this // in the constructor because it leaves the object improperly constructed. fn_obj->fun.InitExecutables(); + + // Only add the executable to the store after executables_ has been + // initialized. We want only fully constructed executables in the store. + pjit_function_store.Insert(fn_obj); } nb::object MakePjitFunction( @@ -1201,8 +1230,7 @@ void BuildPjitSubmodule(nb::module_& m) { cache.def("size", &PjitFunctionCache::Size, nb::lock_self()); cache.def("capacity", &PjitFunctionCache::Capacity, nb::lock_self()); cache.def("clear", &PjitFunctionCache::Clear, nb::lock_self()); - cache.def_static("clear_all", - []() { GetGlobalPjitFunctionStore().ClearFunctionCache(); }); + cache.def_static("clear_all", []() { pjit_function_store.ClearCaches(); }); cache.def( "__getstate__", // Pickles as an empty cache; the client can repopulate as needed. diff --git a/xla/python/pmap_lib.cc b/xla/python/pmap_lib.cc index 3999b7b7473a63..609cee2deb46ff 100644 --- a/xla/python/pmap_lib.cc +++ b/xla/python/pmap_lib.cc @@ -432,8 +432,10 @@ class PmapFunction { // passed to the underlying PyLoadedExecutable. In sorted order. std::vector static_argnums_; xla::nb_class_ptr pytree_registry_; - // We need a `unique_ptr` here to ensure value pointer stability. - absl::flat_hash_map> + // We need a `shared_ptr` here to ensure value pointer stability, and to + // ensure that the cache entry remains alive in the presence of concurrent + // removals. + absl::flat_hash_map> executables_; // The fallback function to use with `ShardArgs`. @@ -581,15 +583,14 @@ absl::StatusOr PmapFunction::Call(nb::handle callable, } // Retrieve/Maybe add the executable to the cache. - absl::flat_hash_map>::iterator - it; - bool inserted; - std::tie(it, inserted) = executables_.try_emplace( - call_signature, std::unique_ptr()); - if (inserted) { - it->second = std::make_unique(pytree_registry_.get()); + bool inserted = false; + std::shared_ptr& cache_entry_ptr = + executables_[call_signature]; + if (cache_entry_ptr == nullptr) { + inserted = true; + cache_entry_ptr = std::make_shared(pytree_registry_.get()); } - PmapCacheEntry& cache_entry = *(it->second); + PmapCacheEntry& cache_entry = *cache_entry_ptr; if (!cache_entry.compilation_complete.HasBeenNotified()) { // In case of several threads attempting to compile the executable, only From 60b1fc79dd0df5f6bb08b578a2616207342884bd Mon Sep 17 00:00:00 2001 From: xla authors Date: Tue, 7 Jan 2025 17:04:29 -0800 Subject: [PATCH 14/45] Add a using to make referencing environment option overrides as a parameter later easier. PiperOrigin-RevId: 713088932 --- xla/pjrt/pjrt_executable.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xla/pjrt/pjrt_executable.h b/xla/pjrt/pjrt_executable.h index fc4f76ef4776a8..1244039ede0cd1 100644 --- a/xla/pjrt/pjrt_executable.h +++ b/xla/pjrt/pjrt_executable.h @@ -101,7 +101,9 @@ struct CompileOptions { // Key-value string pairs, parsed in order to set miscellaneous options, // overriding if appropriate. using OptionOverride = std::variant; - std::vector> env_option_overrides; + using EnvironmentOptionOverrides = + std::vector>; + EnvironmentOptionOverrides env_option_overrides; std::optional target_config; From 5eaa9c892a9febd1031ffd714ef794c5ae39be0c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 7 Jan 2025 17:14:15 -0800 Subject: [PATCH 15/45] [xla:collectives] Replace xla::cpu::CollectivesCommunicator with xla::Communicator PiperOrigin-RevId: 713091963 --- xla/backends/cpu/runtime/all_gather_thunk.cc | 2 +- xla/backends/cpu/runtime/all_reduce_thunk.cc | 2 +- xla/backends/cpu/runtime/all_to_all_thunk.cc | 2 +- .../cpu/runtime/collective_permute_thunk.cc | 2 +- xla/backends/cpu/runtime/collective_thunk.cc | 2 +- xla/backends/cpu/runtime/collective_thunk.h | 4 +- .../cpu/runtime/reduce_scatter_thunk.cc | 2 +- xla/core/collectives/BUILD | 1 + xla/core/collectives/communicator.h | 26 +++++---- xla/pjrt/cpu/BUILD | 2 + xla/pjrt/cpu/gloo_collectives.cc | 11 ++-- xla/pjrt/cpu/gloo_collectives.h | 37 +++++++++++-- xla/pjrt/cpu/gloo_collectives_test.cc | 2 +- xla/pjrt/cpu/mpi_collectives.cc | 5 +- xla/pjrt/cpu/mpi_collectives.h | 29 +++++++++- xla/service/cpu/BUILD | 1 + xla/service/cpu/collectives_interface.h | 53 +------------------ xla/service/cpu/in_process_collectives.cc | 6 +-- xla/service/cpu/in_process_collectives.h | 34 ++++++++++-- 19 files changed, 133 insertions(+), 90 deletions(-) diff --git a/xla/backends/cpu/runtime/all_gather_thunk.cc b/xla/backends/cpu/runtime/all_gather_thunk.cc index c56fdf94903b44..9a3c2fff062deb 100644 --- a/xla/backends/cpu/runtime/all_gather_thunk.cc +++ b/xla/backends/cpu/runtime/all_gather_thunk.cc @@ -77,7 +77,7 @@ tsl::AsyncValueRef AllGatherThunk::Execute( return ExecuteWithCommunicator( params.collective_params, - [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + [&](const RendezvousKey& key, Communicator& comm) { CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); for (int32_t i = 0; i < data.source.size(); ++i) { diff --git a/xla/backends/cpu/runtime/all_reduce_thunk.cc b/xla/backends/cpu/runtime/all_reduce_thunk.cc index d9be82226ec347..9dca34f90ceaec 100644 --- a/xla/backends/cpu/runtime/all_reduce_thunk.cc +++ b/xla/backends/cpu/runtime/all_reduce_thunk.cc @@ -102,7 +102,7 @@ tsl::AsyncValueRef AllReduceThunk::Execute( return ExecuteWithCommunicator( params.collective_params, - [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + [&](const RendezvousKey& key, Communicator& comm) { CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); for (int32_t i = 0; i < data.source.size(); ++i) { const Shape& shape = destination_shape(i); diff --git a/xla/backends/cpu/runtime/all_to_all_thunk.cc b/xla/backends/cpu/runtime/all_to_all_thunk.cc index ee18d893c07bdc..37235935754bce 100644 --- a/xla/backends/cpu/runtime/all_to_all_thunk.cc +++ b/xla/backends/cpu/runtime/all_to_all_thunk.cc @@ -76,7 +76,7 @@ tsl::AsyncValueRef AllToAllThunk::Execute( return ExecuteWithCommunicator( params.collective_params, - [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + [&](const RendezvousKey& key, Communicator& comm) { CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); const Shape& shape = destination_shape(0); diff --git a/xla/backends/cpu/runtime/collective_permute_thunk.cc b/xla/backends/cpu/runtime/collective_permute_thunk.cc index 5ee3a8ea2cb456..6387eb31f35be3 100644 --- a/xla/backends/cpu/runtime/collective_permute_thunk.cc +++ b/xla/backends/cpu/runtime/collective_permute_thunk.cc @@ -131,7 +131,7 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) { return ExecuteWithCommunicator( params.collective_params, - [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + [&](const RendezvousKey& key, Communicator& comm) { CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); for (int32_t i = 0; i < data.source.size(); ++i) { diff --git a/xla/backends/cpu/runtime/collective_thunk.cc b/xla/backends/cpu/runtime/collective_thunk.cc index 4bebdd09cd31c1..f838fb0e49acd1 100644 --- a/xla/backends/cpu/runtime/collective_thunk.cc +++ b/xla/backends/cpu/runtime/collective_thunk.cc @@ -183,7 +183,7 @@ CollectiveThunk::ExecuteWithCommunicator( VLOG(3) << absl::StreamFormat(" rank=%d, key=%s", rank, key.ToString()); - TF_ASSIGN_OR_RETURN(std::shared_ptr communicator, + TF_ASSIGN_OR_RETURN(std::shared_ptr communicator, collectives->GetCommunicator(key.global_devices, rank)); TF_RETURN_IF_ERROR(callback(key, *communicator)); diff --git a/xla/backends/cpu/runtime/collective_thunk.h b/xla/backends/cpu/runtime/collective_thunk.h index 8efc767838806d..60c98ce37547c4 100644 --- a/xla/backends/cpu/runtime/collective_thunk.h +++ b/xla/backends/cpu/runtime/collective_thunk.h @@ -86,8 +86,8 @@ class CollectiveThunk : public Thunk { protected: // Callback for collective thunk implementations. - using Callback = absl::AnyInvocable; + using Callback = absl::AnyInvocable; static bool IsDataTypeSupportedByCollectiveReduce(PrimitiveType datatype); diff --git a/xla/backends/cpu/runtime/reduce_scatter_thunk.cc b/xla/backends/cpu/runtime/reduce_scatter_thunk.cc index badeb6a860c3ee..20311adf01b7c7 100644 --- a/xla/backends/cpu/runtime/reduce_scatter_thunk.cc +++ b/xla/backends/cpu/runtime/reduce_scatter_thunk.cc @@ -90,7 +90,7 @@ ReduceScatterThunk::Execute(const ExecuteParams& params) { return ExecuteWithCommunicator( params.collective_params, - [&](const RendezvousKey& key, CollectivesCommunicator& comm) { + [&](const RendezvousKey& key, Communicator& comm) { CpuCollectives::Executor executor(key, DefaultCollectiveTimeout()); for (int32_t i = 0; i < data.source.size(); ++i) { diff --git a/xla/core/collectives/BUILD b/xla/core/collectives/BUILD index 0ab61569ecc1ea..2e9ace8f7aa25f 100644 --- a/xla/core/collectives/BUILD +++ b/xla/core/collectives/BUILD @@ -68,6 +68,7 @@ cc_library( hdrs = ["communicator.h"], deps = [ ":rank_id", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_memory", diff --git a/xla/core/collectives/communicator.h b/xla/core/collectives/communicator.h index b6139dec3684b9..af95f7063fc803 100644 --- a/xla/core/collectives/communicator.h +++ b/xla/core/collectives/communicator.h @@ -28,6 +28,7 @@ limitations under the License. #include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/stream_executor/device_memory.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla { @@ -53,23 +54,24 @@ class Communicator { virtual absl::Status Unregister() = 0; }; + // Register `buffer` for efficient collective operations (i.e. on NCCL backend + // it registers the buffer for zero-copy collective operations). + virtual absl::StatusOr> + RegisterBuffer(stream_executor::DeviceMemoryBase buffer) { + return Unimplemented("User-managed buffer registration is not supported"); + } + // Abort any uncompleted operations and destroys the underlying communicator // object. It is undefined behavior to use the communicator after calling // this method. - virtual absl::Status Abort() = 0; + virtual absl::Status Abort() { + return Unimplemented("Aborting communicator is not implemented"); + } // Checks the health of the communicator. It might return an error from the // previously launched asynchronous collective operations, and it does not // have to wait for the completion of scheduled operations. - virtual absl::Status HealthCheck() const = 0; - - // Returns the number of ranks in the communicator. - virtual absl::StatusOr NumRanks() const = 0; - - // Register `buffer` for efficient collective operations (i.e. on NCCL backend - // it registers the buffer for zero-copy collective operations). - virtual absl::StatusOr> - RegisterBuffer(stream_executor::DeviceMemoryBase buffer) = 0; + virtual absl::Status HealthCheck() const { return absl::OkStatus(); } // Reduce buffers of length `count` in `send_buff` using `reduction_kind` // reduction and leaves identical copies of the result on each `recv_buff`. @@ -129,6 +131,10 @@ class Communicator { PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) = 0; + // Returns the number of ranks in the communicator. + virtual absl::StatusOr NumRanks() const = 0; + + // Returns a human-readable description of the communicator. virtual std::string ToString() const = 0; }; diff --git a/xla/pjrt/cpu/BUILD b/xla/pjrt/cpu/BUILD index ba0265eaed3c28..29a58a216d7f8a 100644 --- a/xla/pjrt/cpu/BUILD +++ b/xla/pjrt/cpu/BUILD @@ -298,6 +298,7 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", "//xla/core/collectives:rank_id", @@ -382,6 +383,7 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", "//xla/service:global_device_id", diff --git a/xla/pjrt/cpu/gloo_collectives.cc b/xla/pjrt/cpu/gloo_collectives.cc index 02e5602dd28f2a..0d479d7bfe2fd1 100644 --- a/xla/pjrt/cpu/gloo_collectives.cc +++ b/xla/pjrt/cpu/gloo_collectives.cc @@ -64,8 +64,8 @@ limitations under the License. namespace xla::cpu { GlooCollectivesCommunicator::GlooCollectivesCommunicator( - std::shared_ptr context) - : context_(std::move(context)) {} + std::shared_ptr context, size_t rank, size_t num_ranks) + : context_(std::move(context)), rank_(rank), num_ranks_(num_ranks) {} GlooCollectivesCommunicator::~GlooCollectivesCommunicator() = default; template @@ -453,8 +453,7 @@ GlooCollectives::GlooCollectives( GlooCollectives::~GlooCollectives() = default; -absl::StatusOr> -GlooCollectives::GetCommunicator( +absl::StatusOr> GlooCollectives::GetCommunicator( absl::Span global_devices, int rank) { Context* context; { @@ -487,8 +486,8 @@ GlooCollectives::GetCommunicator( return absl::UnknownError( absl::StrCat("Gloo context initialization failed: ", e.what())); } - context->communicator = - std::make_shared(std::move(gloo_context)); + context->communicator = std::make_shared( + std::move(gloo_context), rank, global_devices.size()); return context->communicator; } diff --git a/xla/pjrt/cpu/gloo_collectives.h b/xla/pjrt/cpu/gloo_collectives.h index 401ad0c54f7285..7bac8b7d662721 100644 --- a/xla/pjrt/cpu/gloo_collectives.h +++ b/xla/pjrt/cpu/gloo_collectives.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -26,22 +27,27 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "gloo/context.h" #include "gloo/rendezvous/store.h" #include "gloo/transport/device.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::cpu { -class GlooCollectivesCommunicator : public CollectivesCommunicator { +class GlooCollectivesCommunicator : public Communicator { public: - explicit GlooCollectivesCommunicator(std::shared_ptr context); + explicit GlooCollectivesCommunicator(std::shared_ptr context, + size_t rank, size_t num_ranks); ~GlooCollectivesCommunicator() override; absl::Status AllReduce(se::DeviceMemoryBase send_buffer, @@ -67,8 +73,33 @@ class GlooCollectivesCommunicator : public CollectivesCommunicator { ReductionKind reduction_kind, const Executor& executor) override; + absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Broadcast is not implemented"); + } + + absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Send is not implemented"); + } + + absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Recv is not implemented"); + } + + absl::StatusOr NumRanks() const override { return num_ranks_; } + + std::string ToString() const override { + return absl::StrCat("GlooCommunicator [rank: ", rank_, + " num_ranks: ", num_ranks_, "]"); + } + private: std::shared_ptr context_; + size_t rank_; + size_t num_ranks_; }; class GlooCollectives : public CollectivesInterface { @@ -78,7 +109,7 @@ class GlooCollectives : public CollectivesInterface { ~GlooCollectives() override; // Thread-safe. - absl::StatusOr> GetCommunicator( + absl::StatusOr> GetCommunicator( absl::Span devices, int rank) override; private: diff --git a/xla/pjrt/cpu/gloo_collectives_test.cc b/xla/pjrt/cpu/gloo_collectives_test.cc index 4537b1073fb564..e4c79982beeaa6 100644 --- a/xla/pjrt/cpu/gloo_collectives_test.cc +++ b/xla/pjrt/cpu/gloo_collectives_test.cc @@ -59,7 +59,7 @@ constexpr int kNumParticipants = 2; constexpr size_t kBufferSize = 256; constexpr absl::Duration kTimeout = absl::Seconds(5); -absl::StatusOr> GetCommunicator( +absl::StatusOr> GetCommunicator( size_t kNumParticipants, absl::Span global_devices, const std::shared_ptr& kv_store, int rank) { auto collectives = std::make_shared( diff --git a/xla/pjrt/cpu/mpi_collectives.cc b/xla/pjrt/cpu/mpi_collectives.cc index aaf1ebe6bb5815..002f278c79bb63 100644 --- a/xla/pjrt/cpu/mpi_collectives.cc +++ b/xla/pjrt/cpu/mpi_collectives.cc @@ -261,9 +261,8 @@ void MpiCollectives::Finalize() { MPI_Finalize(); } -absl::StatusOr> -MpiCollectives::GetCommunicator(absl::Span global_devices, - int rank) { +absl::StatusOr> MpiCollectives::GetCommunicator( + absl::Span global_devices, int rank) { int flag; MPI_Is_thread_main(&flag); if (!flag) { diff --git a/xla/pjrt/cpu/mpi_collectives.h b/xla/pjrt/cpu/mpi_collectives.h index 8058c5f38077e7..f24537b52d4c51 100644 --- a/xla/pjrt/cpu/mpi_collectives.h +++ b/xla/pjrt/cpu/mpi_collectives.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -32,11 +33,12 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::cpu { -class MpiCollectivesCommunicator : public CollectivesCommunicator { +class MpiCollectivesCommunicator : public Communicator { public: explicit MpiCollectivesCommunicator(int color, int key); ~MpiCollectivesCommunicator() override; @@ -64,6 +66,29 @@ class MpiCollectivesCommunicator : public CollectivesCommunicator { ReductionKind reduction_kind, const Executor& executor) override; + absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Broadcast is not implemented"); + } + + absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Send is not implemented"); + } + + absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Recv is not implemented"); + } + + absl::StatusOr NumRanks() const override { return mpi_size_; } + + std::string ToString() const override { + return absl::StrCat("MpiCommunicator [rank: ", mpi_rank_, + " num_ranks: ", mpi_size_, "]"); + } + private: MPI_Comm comm_; int mpi_rank_; @@ -84,7 +109,7 @@ class MpiCollectives : public CollectivesInterface { void Init(); void Finalize(); - absl::StatusOr> GetCommunicator( + absl::StatusOr> GetCommunicator( absl::Span global_devices, int rank) override; private: diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 430def53c196bb..40dbd27b160ab3 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -1985,6 +1985,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/service:global_device_id", diff --git a/xla/service/cpu/collectives_interface.h b/xla/service/cpu/collectives_interface.h index faba50bc2280af..cfa3b11f36513a 100644 --- a/xla/service/cpu/collectives_interface.h +++ b/xla/service/cpu/collectives_interface.h @@ -32,55 +32,6 @@ limitations under the License. namespace xla::cpu { -// TODO(b/380457503): We are in the middle of migrating this API to the new XLA -// collectives API defined under `xla/core/collectives`. -class CollectivesCommunicator { - public: - using Executor = Communicator::Executor; - - virtual ~CollectivesCommunicator() = default; - - // Performs an all-reduce. - virtual absl::Status AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) = 0; - - // Performs a collective permute. - // Arguments: - // source_rank: the rank from which this rank should receive its data. - // Optional; if absent, then the output is filled with zeros. - // target_rank: the ranks to which this rank should send its data. - virtual absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - std::optional source_rank, - absl::Span target_ranks, - const Executor& executor) = 0; - - // Performs an all-to-all. - // The all-to-all chunks are passed separately and do not have to be - // contiguous in memory. - virtual absl::Status AllToAll( - absl::Span send_buffers, - absl::Span recv_buffers, PrimitiveType dtype, - size_t count, const Executor& executor) = 0; - - // Performs an all-gather. - virtual absl::Status AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - const Executor& executor) = 0; - - // Performs a reduce-scatter - virtual absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) = 0; -}; - class CollectivesInterface { public: virtual ~CollectivesInterface() = default; @@ -89,8 +40,8 @@ class CollectivesInterface { // Args: // devices: the devices participating in this collective. // rank: the rank of this process. - virtual absl::StatusOr> - GetCommunicator(absl::Span devices, int rank) = 0; + virtual absl::StatusOr> GetCommunicator( + absl::Span devices, int rank) = 0; }; } // namespace xla::cpu diff --git a/xla/service/cpu/in_process_collectives.cc b/xla/service/cpu/in_process_collectives.cc index 46e5d47993d15e..b75b557c8525b6 100644 --- a/xla/service/cpu/in_process_collectives.cc +++ b/xla/service/cpu/in_process_collectives.cc @@ -435,8 +435,8 @@ struct InProcessCollectivesState { }; InProcessCollectivesCommunicator::InProcessCollectivesCommunicator( - InProcessCollectivesState* state, int rank, int size) - : state_(state), rank_(rank) {} + InProcessCollectivesState* state, int rank, int num_ranks) + : state_(state), rank_(rank), num_ranks_(num_ranks) {} InProcessCollectivesCommunicator::~InProcessCollectivesCommunicator() = default; absl::Status InProcessCollectivesCommunicator::AllReduce( @@ -576,7 +576,7 @@ InProcessCollectives::InProcessCollectives() : state_(std::make_unique()) {} InProcessCollectives::~InProcessCollectives() = default; -absl::StatusOr> +absl::StatusOr> InProcessCollectives::GetCommunicator(absl::Span devices, int rank) { // We don't care about devices here: we share rendezvous state globally. diff --git a/xla/service/cpu/in_process_collectives.h b/xla/service/cpu/in_process_collectives.h index 9f04e9890eda06..ffabb0cd526aa7 100644 --- a/xla/service/cpu/in_process_collectives.h +++ b/xla/service/cpu/in_process_collectives.h @@ -19,25 +19,29 @@ limitations under the License. #include #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" #include "xla/stream_executor/device_memory.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::cpu::runtime { struct InProcessCollectivesState; -class InProcessCollectivesCommunicator : public CollectivesCommunicator { +class InProcessCollectivesCommunicator : public Communicator { public: InProcessCollectivesCommunicator(InProcessCollectivesState* state, int rank, - int size); + int num_ranks); ~InProcessCollectivesCommunicator() override; absl::Status AllReduce(se::DeviceMemoryBase send_buffer, @@ -67,9 +71,33 @@ class InProcessCollectivesCommunicator : public CollectivesCommunicator { ReductionKind reduction_kind, const Executor& executor) override; + absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Broadcast is not implemented"); + } + + absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Send is not implemented"); + } + + absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Recv is not implemented"); + } + + absl::StatusOr NumRanks() const override { return num_ranks_; } + + std::string ToString() const override { + return absl::StrCat("InProcessCommunicator [rank: ", rank_, + " num_ranks: ", num_ranks_, "]"); + } + private: InProcessCollectivesState* state_; int rank_; + int num_ranks_; }; class InProcessCollectives : public CollectivesInterface { @@ -78,7 +106,7 @@ class InProcessCollectives : public CollectivesInterface { ~InProcessCollectives() override; // Thread-safe. - absl::StatusOr> GetCommunicator( + absl::StatusOr> GetCommunicator( absl::Span devices, int rank) override; private: From 0f122397e69b29aec92688335c6ed1bcddaaf4ec Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Tue, 7 Jan 2025 17:17:48 -0800 Subject: [PATCH 16/45] [XLA] Handle empty leaf nodes in an original value Add a warning when parsing an original value with leaf nodes without values. Issue an error for such cases in HloVerifier. PiperOrigin-RevId: 713093109 --- xla/hlo/ir/hlo_original_value.cc | 5 ++--- xla/hlo/parser/hlo_parser.cc | 23 +++++++++++++++-------- xla/hlo/parser/hlo_parser_test.cc | 30 ++++++++++++++---------------- xla/service/hlo_verifier.cc | 22 ++++++++++++++++++++++ xla/service/hlo_verifier_test.cc | 14 ++++++++++++++ 5 files changed, 67 insertions(+), 27 deletions(-) diff --git a/xla/hlo/ir/hlo_original_value.cc b/xla/hlo/ir/hlo_original_value.cc index c1617888510a4d..e76cd15d989ce0 100644 --- a/xla/hlo/ir/hlo_original_value.cc +++ b/xla/hlo/ir/hlo_original_value.cc @@ -53,15 +53,14 @@ std::string OriginalValueToStringHelper(const OriginalValue& original_value, return result; } - // The original_value may refer to an empty array, such as origin {}, so let's - // check whether that's the case before accessing them. Generally speaking the - // index _should_ be good, but let's double check. const auto& leaf = original_value.element(shape_index); if (leaf.has_value()) { absl::StrAppend( &result, "{", "\"", leaf->instruction_name, "\"", (leaf->shape_index.empty() ? "" : " " + leaf->shape_index.ToString()), "}"); + } else { + absl::StrAppend(&result, "{}"); } return result; } diff --git a/xla/hlo/parser/hlo_parser.cc b/xla/hlo/parser/hlo_parser.cc index 01335cb5ff28dc..3436fd408890f1 100644 --- a/xla/hlo/parser/hlo_parser.cc +++ b/xla/hlo/parser/hlo_parser.cc @@ -6488,18 +6488,25 @@ bool HloParserImpl::ParseOriginalValue( ++leaf_shape_index.back(); } else if (lexer_.GetKind() == TokKind::kLbrace) { lexer_.Lex(); - std::string instruction_name; - ShapeIndex shape_index; - if (!ParseString(&instruction_name)) { - return false; - } if (lexer_.GetKind() != TokKind::kRbrace) { - if (!ParseShapeIndex(&shape_index)) { + std::string instruction_name; + ShapeIndex shape_index; + if (!ParseString(&instruction_name)) { return false; } + if (lexer_.GetKind() != TokKind::kRbrace) { + if (!ParseShapeIndex(&shape_index)) { + return false; + } + } + *(**original_value)->mutable_element(leaf_shape_index) = { + instruction_name, shape_index}; + } else { + // The original_value is not expected to have any leaf without values. + // However we should not fail the execution here. This should + // be done in HloVerifier instead. + LOG(WARNING) << "Found an empty leaf node in an original value"; } - *(**original_value)->mutable_element(leaf_shape_index) = { - instruction_name, shape_index}; if (!ParseToken(TokKind::kRbrace, "Expects '} at end of each OriginalArray'")) { return false; diff --git a/xla/hlo/parser/hlo_parser_test.cc b/xla/hlo/parser/hlo_parser_test.cc index f1ce17e4a57b76..61de9ca31adcd8 100644 --- a/xla/hlo/parser/hlo_parser_test.cc +++ b/xla/hlo/parser/hlo_parser_test.cc @@ -5726,6 +5726,20 @@ ENTRY %test { HasSubstr("expects instruction shape"))); } +TEST_F(HloParserTest, EmptyLeafInOriginalValue) { + const std::string hlo_string = R"(HloModule test + +ENTRY %test { + ROOT op = ((f32[], f32[3]{0}), f32[2,3]) parameter(0), origin={(({}, {"v2"}), {"v3"})} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + ExpectHasSubstr(module->ToString(HloPrintOptions::ShortParsable()), + "origin={(({}, {\"v2\"}), {\"v3\"})}"); +} + TEST_F(HloParserTest, TranscendentalAccuracyMode) { constexpr absl::string_view hlo_string = R"( HloModule exponential_hw @@ -5842,21 +5856,5 @@ ENTRY main { "error: unexpected attribute \"result_accuracy\""); } -TEST_F(HloParserTest, EmptyOriginalValueIsPrintedCorrectly) { - const std::string hlo_string = R"(HloModule test - -ENTRY %test { - ROOT op = f32[] parameter(0), origin={} -} - - -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); - - ExpectHasSubstr(module->ToString(HloPrintOptions::Fingerprint()), - "origin={}"); -} - } // namespace } // namespace xla diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc index 88823f1dd9e5c1..9e84f287beb874 100644 --- a/xla/service/hlo_verifier.cc +++ b/xla/service/hlo_verifier.cc @@ -2483,6 +2483,27 @@ absl::Status VerifyLayoutConstrainedAllReduce(const HloModule& module) { return absl::OkStatus(); } +// Verifies that leaf nodes in an original value contain values. +absl::Status VerifyOriginalValue(const HloModule& module) { + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + if (auto original_value = instruction->original_value()) { + // An original value is expected to have intermediate nodes that are + // always nullopt and leaves with actual values. + for (const auto& leaf : original_value->leaves()) { + if (!leaf.second.has_value()) { + return Internal( + "Leaf nodes in an original value is expected to contain values." + " Instruction: %s.", + instruction->ToString()); + } + } + } + } + } + return absl::OkStatus(); +} + // Checks various invariants of channel instructions (send/recv and // collectives). absl::Status VerifyChannels(const HloModule& module, @@ -3117,6 +3138,7 @@ absl::StatusOr HloVerifier::Run( TF_RETURN_IF_ERROR(module->buffer_donor_config().Verify(*module)); TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module)); + TF_RETURN_IF_ERROR(VerifyOriginalValue(*module)); return false; }(); if (status_or_changed.ok()) { diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc index 419156664e7f46..6e2207726caeb2 100644 --- a/xla/service/hlo_verifier_test.cc +++ b/xla/service/hlo_verifier_test.cc @@ -3635,5 +3635,19 @@ TEST_F(HloVerifierTest, UnaryOpWithResultAccuracy) { EXPECT_TRUE(status.ok()) << status; } +TEST_F(HloVerifierTest, EmptyLeafInOriginalValue) { + const std::string hlo_string = R"( +HloModule module +ENTRY %entry_computation { + ROOT op = ((f32[], f32[3]{0}), f32[2,3]) parameter(0), origin={(({}, {"v2"}), {"v3"})} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + EXPECT_FALSE(status.ok()); +} + } // namespace } // namespace xla From 5a86fb191388e14f7ec5a08ea5261d1a3eb19385 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 7 Jan 2025 18:04:00 -0800 Subject: [PATCH 17/45] [xla:cpu] Move InProcessCommunicator to backends/cpu/collectives PiperOrigin-RevId: 713105459 --- xla/backends/cpu/collectives/BUILD | 28 + .../collectives/in_process_communicator.cc | 576 ++++++++++++++++++ .../cpu/collectives/in_process_communicator.h | 109 ++++ xla/service/cpu/BUILD | 3 + xla/service/cpu/in_process_collectives.cc | 573 +---------------- xla/service/cpu/in_process_collectives.h | 84 +-- 6 files changed, 739 insertions(+), 634 deletions(-) create mode 100644 xla/backends/cpu/collectives/in_process_communicator.cc create mode 100644 xla/backends/cpu/collectives/in_process_communicator.h diff --git a/xla/backends/cpu/collectives/BUILD b/xla/backends/cpu/collectives/BUILD index 947173cc8a7f99..0f03fd72acc90a 100644 --- a/xla/backends/cpu/collectives/BUILD +++ b/xla/backends/cpu/collectives/BUILD @@ -34,3 +34,31 @@ cc_library( "@tsl//tsl/platform:casts", ], ) + +cc_library( + name = "in_process_communicator", + srcs = ["in_process_communicator.cc"], + hdrs = ["in_process_communicator.h"], + deps = [ + ":cpu_collectives", + "//xla:refcounting_hash_map", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/stream_executor:device_memory", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + ], +) diff --git a/xla/backends/cpu/collectives/in_process_communicator.cc b/xla/backends/cpu/collectives/in_process_communicator.cc new file mode 100644 index 00000000000000..a293c1e72672c3 --- /dev/null +++ b/xla/backends/cpu/collectives/in_process_communicator.cc @@ -0,0 +1,576 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/in_process_communicator.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/primitive_util.h" +#include "xla/refcounting_hash_map.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/global_device_id.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { +namespace { + +void FormatGlobalId(std::string* out, const GlobalDeviceId& device) { + absl::StrAppend(out, device.value()); +} + +struct AllReduceParticipantData : ParticipantData { + explicit AllReduceParticipantData(const RendezvousKey& rendezvous_key_p, + int rank) + : ParticipantData(rendezvous_key_p, rank) {} + + int64_t element_count; + const void* source_data; + void* destination_data; + PrimitiveType primitive_type; + + ReductionKind reduction_kind; + + std::string ToString() const override { + return absl::StrFormat( + "AllReduceParticipantData{rank=%d, element_count=%d, type=%s, " + "rendezvous_key=%s}", + local_rank, element_count, PrimitiveType_Name(primitive_type), + rendezvous_key.ToString()); + } +}; + +template +T GetInitialValue(ReductionKind reduction_kind) { + switch (reduction_kind) { + case ReductionKind::SUM: + return static_cast(0); + case ReductionKind::PRODUCT: + return static_cast(1); + case ReductionKind::MIN: + return std::numeric_limits::has_infinity + ? std::numeric_limits::infinity() + : std::numeric_limits::max(); + case ReductionKind::MAX: + return std::numeric_limits::has_infinity + ? -std::numeric_limits::infinity() + : std::numeric_limits::lowest(); + } +} + +// We cannot use static_assert(false), because the C++ standard (prior to +// CWG2518) does not allow the statement discarded by a constexpr if to +// be ill-formed for every possible specialization. +// See https://en.cppreference.com/w/cpp/language/if#Constexpr_if +template +constexpr bool always_false_v = false; + +template +void ReduceHelper(absl::Span acc, absl::Span inputs) { + // TODO(penporn): make sure this gets vectorized. + if constexpr (reduction_kind == ReductionKind::SUM) { + for (size_t j = 0; j < inputs.size(); ++j) { + for (size_t i = 0; i < acc.size(); ++i) { + acc[i] += inputs[j][i]; + } + } + } else if constexpr (reduction_kind == ReductionKind::PRODUCT) { + for (size_t j = 0; j < inputs.size(); ++j) { + for (size_t i = 0; i < acc.size(); ++i) { + acc[i] *= inputs[j][i]; + } + } + } else if constexpr (reduction_kind == ReductionKind::MIN) { + for (size_t j = 0; j < inputs.size(); ++j) { + for (size_t i = 0; i < acc.size(); ++i) { + acc[i] = std::min(acc[i], inputs[j][i]); + } + } + } else if constexpr (reduction_kind == ReductionKind::MAX) { + for (size_t j = 0; j < inputs.size(); ++j) { + for (size_t i = 0; i < acc.size(); ++i) { + acc[i] = std::max(acc[i], inputs[j][i]); + } + } + } else { + static_assert(always_false_v, "Unsupported reduction kind"); + } +} + +template +absl::Status ReduceScatter(ReductionKind reduction_kind, + absl::Span inputs, void* output, + int64_t num_elems) { + using T = primitive_util::NativeTypeOf; + T initial_value = GetInitialValue(reduction_kind); + + absl::Span out_chunk = + absl::MakeSpan(reinterpret_cast(output), num_elems); + for (int64_t i = 0; i < num_elems; ++i) { + out_chunk[i] = initial_value; + } + + absl::Span input_chunks( + reinterpret_cast(inputs.data()), inputs.size()); + switch (reduction_kind) { + case ReductionKind::SUM: + ReduceHelper(out_chunk, input_chunks); + break; + case ReductionKind::PRODUCT: + ReduceHelper(out_chunk, input_chunks); + break; + case ReductionKind::MIN: + if constexpr (!is_complex_v) { + ReduceHelper(out_chunk, input_chunks); + } else { + return absl::InvalidArgumentError( + "Min reductions not supported for complex types"); + } + break; + case ReductionKind::MAX: + if constexpr (!is_complex_v) { + ReduceHelper(out_chunk, input_chunks); + } else { + return absl::InvalidArgumentError( + "Max reductions not supported for complex types"); + } + break; + } + + return absl::OkStatus(); +} + +class CpuAllReduceRendezvous + : public Rendezvous { + public: + explicit CpuAllReduceRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + absl::StatusOr RunCollectiveOp( + const AllReduceParticipantData& me) override { + VLOG(3) << me.ToString(); + int64_t world_size = participants_.size(); + // Divide the buffer up into equal(ish) chunks. Rank r computes the r-th + // chunk of the output. + int64_t chunk_elems = CeilOfRatio(me.element_count, world_size); + + int64_t start_elem = me.local_rank * chunk_elems; + int64_t end_elem = std::min(start_elem + chunk_elems, me.element_count); + chunk_elems = std::max(int64_t{0}, end_elem - start_elem); + if (chunk_elems == 0) { + return nullptr; + } + + auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); + int64_t chunk_offset = start_elem * bytes_per_elem; + int64_t chunk_bytes = chunk_elems * bytes_per_elem; + void* reduce_output = + reinterpret_cast(me.destination_data) + chunk_offset; + + std::vector inputs; + inputs.reserve(world_size); + for (const auto& p : participants_) { + inputs.push_back(reinterpret_cast(p->source_data) + + chunk_offset); + } + + if (primitive_util::IsArrayType(me.primitive_type)) { + TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch( + [&](const auto constant_type) { + return ReduceScatter(me.reduction_kind, inputs, + reduce_output, chunk_elems); + }, + me.primitive_type)); + } else { + return absl::UnimplementedError(absl::StrCat( + "Unexpected datatype: ", + primitive_util::LowercasePrimitiveTypeName(me.primitive_type))); + } + + // All-gather the reduced chunks. + for (const auto& p : participants_) { + if (p->local_rank != me.local_rank) { + std::memcpy(reinterpret_cast(p->destination_data) + chunk_offset, + reduce_output, chunk_bytes); + } + } + return nullptr; + } +}; + +struct CollectivePermuteParticipantData : ParticipantData { + CollectivePermuteParticipantData(const RendezvousKey& rendezvous_key_p, + int rank) + : ParticipantData(rendezvous_key_p, rank) {} + const void* source_buffer; + void* destination_buffer; + size_t num_bytes; + + // From which rank is this participant receiving its data? Optional; if + // absent fill with zeros. + std::optional source_rank; + + std::string ToString() const override { + return absl::StrFormat( + "CollectivePermuteParticipantData{rank=%d, " + "source_buffer=%p, destination_buffer=%p, num_bytes=%d, " + "source_replica_id=%d, " + "devices=[%s]}", + local_rank, source_buffer, destination_buffer, num_bytes, + source_rank.value_or(-1), + absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId)); + } +}; + +class CpuCollectivePermuteRendezvous + : public Rendezvous { + public: + explicit CpuCollectivePermuteRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + absl::StatusOr RunCollectiveOp( + const CollectivePermuteParticipantData& p) override { + VLOG(3) << p.ToString(); + if (p.source_rank) { + std::memcpy(p.destination_buffer, + participants_[*p.source_rank]->source_buffer, p.num_bytes); + } else { + std::memset(p.destination_buffer, 0, p.num_bytes); + } + return nullptr; + } +}; + +struct AllToAllParticipantData : ParticipantData { + AllToAllParticipantData(const RendezvousKey& rendezvous_key_p, int rank) + : ParticipantData(rendezvous_key_p, rank) {} + + std::vector source_buffers; + std::vector destination_buffers; + size_t chunk_size; + + std::string ToString() const override { + auto addr_formatter = [](std::string* out, const void* mem) { + absl::StrAppend(out, absl::StrFormat("%p", mem)); + }; + return absl::StrFormat( + "AllToAllParticipantData{rank=%d, " + "devices=[%s], source_buffers=[%s], " + "destination_buffers=[%s], chunk_size=%d}", + local_rank, + absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), + absl::StrJoin(source_buffers, ", ", addr_formatter), + absl::StrJoin(destination_buffers, ", ", addr_formatter), chunk_size); + } +}; + +class CpuAllToAllRendezvous + : public Rendezvous { + public: + explicit CpuAllToAllRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + absl::StatusOr RunCollectiveOp( + const AllToAllParticipantData& p) override { + int world_size = p.rendezvous_key.global_devices.size(); + for (int i = 0; i < world_size; ++i) { + std::memcpy(participants_[i]->destination_buffers[p.local_rank], + p.source_buffers[i], p.chunk_size); + } + return nullptr; + } +}; + +struct AllGatherParticipantData : ParticipantData { + AllGatherParticipantData(const RendezvousKey& rendezvous_key_p, int rank) + : ParticipantData(rendezvous_key_p, rank) {} + + const void* source_buffer; + void* destination_buffer; + size_t chunk_size; + + std::string ToString() const override { + return absl::StrFormat( + "AllGatherParticipantData{rank=%d, " + "devices=[%s], source_buffer=%p, " + "destination_buffer=%p, chunk_size=%d}", + local_rank, + absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), + source_buffer, destination_buffer, chunk_size); + } +}; + +class CpuAllGatherRendezvous + : public Rendezvous { + public: + explicit CpuAllGatherRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + absl::StatusOr RunCollectiveOp( + const AllGatherParticipantData& p) override { + int world_size = p.rendezvous_key.global_devices.size(); + char* out = static_cast(p.destination_buffer); + for (int i = 0; i < world_size; ++i, out += p.chunk_size) { + std::memcpy(out, participants_[i]->source_buffer, p.chunk_size); + } + return nullptr; + } +}; + +struct ReduceScatterParticipantData : ParticipantData { + ReduceScatterParticipantData(const RendezvousKey& rendezvous_key_p, int rank) + : ParticipantData(rendezvous_key_p, rank) {} + + ReductionKind reduction_kind; + PrimitiveType element_type; + const void* source_buffer; + void* destination_buffer; + size_t chunk_elems; + + std::string ToString() const override { + return absl::StrFormat( + "ReduceScatterParticipantData{rank=%d, " + "devices=[%s], source_buffer=%p, " + "destination_buffer=%p, chunk_elems=%d}", + local_rank, + absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), + source_buffer, destination_buffer, chunk_elems); + } +}; + +class CpuReduceScatterRendezvous + : public Rendezvous { + public: + explicit CpuReduceScatterRendezvous(const RendezvousKey& k) + : Rendezvous(k) {} + + protected: + absl::StatusOr RunCollectiveOp( + const ReduceScatterParticipantData& me) override { + auto bytes_per_elem = primitive_util::ByteWidth(me.element_type); + int64_t chunk_offset = me.local_rank * me.chunk_elems * bytes_per_elem; + + std::vector inputs; + inputs.reserve(participants_.size()); + for (const auto& p : participants_) { + inputs.push_back(reinterpret_cast(p->source_buffer) + + chunk_offset); + } + + if (primitive_util::IsArrayType(me.element_type)) { + TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch( + [&](const auto constant_type) { + return ReduceScatter(me.reduction_kind, inputs, + me.destination_buffer, + me.chunk_elems); + }, + me.element_type)); + } else { + return absl::UnimplementedError(absl::StrCat( + "Unexpected datatype: ", + primitive_util::LowercasePrimitiveTypeName(me.element_type))); + } + return nullptr; + } +}; + +} // namespace + +struct InProcessCommunicator::State { + RefcountingHashMap + all_reduce_rendezvous_map; + RefcountingHashMap + collective_permute_rendezvous_map; + RefcountingHashMap + all_to_all_rendezvous_map; + RefcountingHashMap + all_gather_rendezvous_map; + RefcountingHashMap + reduce_scatter_rendezvous_map; +}; + +InProcessCommunicator::InProcessCommunicator(std::shared_ptr state, + size_t rank, size_t num_ranks) + : state_(std::move(state)), rank_(rank), num_ranks_(num_ranks) {} + +InProcessCommunicator::~InProcessCommunicator() = default; + +std::shared_ptr +InProcessCommunicator::CreateState() { + return std::make_shared(); +} + +absl::Status InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) { + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + const RendezvousKey& key = cpu_executor->rendezvous_key(); + + AllReduceParticipantData participant(key, rank_); + participant.element_count = count; + participant.primitive_type = dtype; + participant.source_data = send_buffer.opaque(); + participant.destination_data = recv_buffer.opaque(); + participant.reduction_kind = reduction_kind; + + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + + return CpuAllReduceRendezvous::SubmitParticipant( + [&] { + return state_->all_reduce_rendezvous_map.GetOrCreateIfAbsent( + key, make_cpu_rendezvous); + }, + participant) + .status(); +} + +absl::Status InProcessCommunicator::CollectivePermute( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, std::optional source_rank, + absl::Span target_ranks, const Executor& executor) { + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + const RendezvousKey& key = cpu_executor->rendezvous_key(); + + CollectivePermuteParticipantData participant(key, rank_); + participant.source_buffer = send_buffer.opaque(); + participant.destination_buffer = recv_buffer.opaque(); + participant.num_bytes = count * primitive_util::ByteWidth(dtype); + participant.source_rank = std::nullopt; + if (source_rank) { + participant.source_rank = source_rank->value(); + } + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + return CpuCollectivePermuteRendezvous::SubmitParticipant( + [&] { + return state_->collective_permute_rendezvous_map + .GetOrCreateIfAbsent(key, make_cpu_rendezvous); + }, + participant) + .status(); +} + +absl::Status InProcessCommunicator::AllToAll( + absl::Span send_buffers, + absl::Span recv_buffers, PrimitiveType dtype, + size_t count, const Executor& executor) { + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + const RendezvousKey& key = cpu_executor->rendezvous_key(); + + AllToAllParticipantData participant(key, rank_); + TF_RET_CHECK(send_buffers.size() == recv_buffers.size()); + + size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); + + participant.chunk_size = chunk_bytes; + participant.source_buffers.reserve(send_buffers.size()); + participant.destination_buffers.reserve(recv_buffers.size()); + for (se::DeviceMemoryBase send_buffer : send_buffers) { + participant.source_buffers.push_back(send_buffer.opaque()); + } + for (se::DeviceMemoryBase recv_buffer : recv_buffers) { + participant.destination_buffers.push_back(recv_buffer.opaque()); + } + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + return CpuAllToAllRendezvous::SubmitParticipant( + [&] { + return state_->all_to_all_rendezvous_map.GetOrCreateIfAbsent( + key, make_cpu_rendezvous); + }, + participant) + .status(); +} + +absl::Status InProcessCommunicator::AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + const Executor& executor) { + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + const RendezvousKey& key = cpu_executor->rendezvous_key(); + + AllGatherParticipantData participant(key, rank_); + participant.chunk_size = count * primitive_util::ByteWidth(dtype); + participant.source_buffer = send_buffer.opaque(); + participant.destination_buffer = recv_buffer.opaque(); + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + return CpuAllGatherRendezvous::SubmitParticipant( + [&] { + return state_->all_gather_rendezvous_map.GetOrCreateIfAbsent( + key, make_cpu_rendezvous); + }, + participant) + .status(); +} + +absl::Status InProcessCommunicator::ReduceScatter( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, ReductionKind reduction_kind, + const Executor& executor) { + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + const RendezvousKey& key = cpu_executor->rendezvous_key(); + + ReduceScatterParticipantData participant(key, rank_); + participant.element_type = dtype; + participant.reduction_kind = reduction_kind; + participant.chunk_elems = count; + participant.source_buffer = send_buffer.opaque(); + participant.destination_buffer = recv_buffer.opaque(); + auto make_cpu_rendezvous = [](const RendezvousKey& k) { + return std::make_unique(k); + }; + return CpuReduceScatterRendezvous::SubmitParticipant( + [&] { + return state_->reduce_scatter_rendezvous_map.GetOrCreateIfAbsent( + key, make_cpu_rendezvous); + }, + participant) + .status(); +} + +} // namespace xla::cpu diff --git a/xla/backends/cpu/collectives/in_process_communicator.h b/xla/backends/cpu/collectives/in_process_communicator.h new file mode 100644 index 00000000000000..abc82c7aba211c --- /dev/null +++ b/xla/backends/cpu/collectives/in_process_communicator.h @@ -0,0 +1,109 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_IN_PROCESS_COMMUNICATOR_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_IN_PROCESS_COMMUNICATOR_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +// XLA communicator that implements collective operations using shared memory +// and works only within a single process. +class InProcessCommunicator : public Communicator { + public: + // A state shared by all InProcessCommunicators in the clique. + struct State; + + // Creates a new State for constructing InProcessCommunicators. + static std::shared_ptr CreateState(); + + InProcessCommunicator(std::shared_ptr state, size_t rank, + size_t num_ranks); + ~InProcessCommunicator() override; + + absl::Status AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, ReductionKind reduction_kind, + const Executor& executor) override; + + absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + std::optional source_rank, + absl::Span target_ranks, + const Executor& executor) override; + + absl::Status AllToAll(absl::Span send_buffers, + absl::Span recv_buffers, + PrimitiveType dtype, size_t count, + const Executor& executor) override; + + absl::Status AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, const Executor& executor) override; + + absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) override; + + absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Broadcast is not implemented"); + } + + absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Send is not implemented"); + } + + absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Recv is not implemented"); + } + + absl::StatusOr NumRanks() const override { return num_ranks_; } + + std::string ToString() const override { + return absl::StrCat("InProcessCommunicator [rank: ", rank_, + " num_ranks: ", num_ranks_, "]"); + } + + private: + std::shared_ptr state_; + size_t rank_; + size_t num_ranks_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_IN_PROCESS_COMMUNICATOR_H_ diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 40dbd27b160ab3..88692c8eb2c6d7 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -1985,17 +1985,20 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/backends/cpu/collectives:in_process_communicator", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/service:global_device_id", "//xla/stream_executor:device_memory", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", diff --git a/xla/service/cpu/in_process_collectives.cc b/xla/service/cpu/in_process_collectives.cc index b75b557c8525b6..a7d759348fefdb 100644 --- a/xla/service/cpu/in_process_collectives.cc +++ b/xla/service/cpu/in_process_collectives.cc @@ -15,575 +15,34 @@ limitations under the License. #include "xla/service/cpu/in_process_collectives.h" -#include -#include -#include -#include -#include #include -#include -#include -#include +#include #include "absl/log/log.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/time/time.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "xla/backends/cpu/collectives/cpu_collectives.h" -#include "xla/core/collectives/rank_id.h" -#include "xla/primitive_util.h" -#include "xla/refcounting_hash_map.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" +#include "xla/backends/cpu/collectives/in_process_communicator.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/global_device_id.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" -namespace xla { -namespace cpu { -namespace runtime { -namespace { - -void FormatGlobalId(std::string* out, const GlobalDeviceId& device) { - absl::StrAppend(out, device.value()); -} - -struct AllReduceParticipantData : ParticipantData { - explicit AllReduceParticipantData(const RendezvousKey& rendezvous_key_p, - int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - int64_t element_count; - const void* source_data; - void* destination_data; - PrimitiveType primitive_type; - - ReductionKind reduction_kind; - - std::string ToString() const override { - return absl::StrFormat( - "AllReduceParticipantData{rank=%d, element_count=%d, type=%s, " - "rendezvous_key=%s}", - local_rank, element_count, PrimitiveType_Name(primitive_type), - rendezvous_key.ToString()); - } -}; - -template -T GetInitialValue(ReductionKind reduction_kind) { - switch (reduction_kind) { - case ReductionKind::SUM: - return static_cast(0); - case ReductionKind::PRODUCT: - return static_cast(1); - case ReductionKind::MIN: - return std::numeric_limits::has_infinity - ? std::numeric_limits::infinity() - : std::numeric_limits::max(); - case ReductionKind::MAX: - return std::numeric_limits::has_infinity - ? -std::numeric_limits::infinity() - : std::numeric_limits::lowest(); - } -} - -// We cannot use static_assert(false), because the C++ standard (prior to -// CWG2518) does not allow the statement discarded by a constexpr if to -// be ill-formed for every possible specialization. -// See https://en.cppreference.com/w/cpp/language/if#Constexpr_if -template -constexpr bool always_false_v = false; - -template -void ReduceHelper(absl::Span acc, absl::Span inputs) { - // TODO(penporn): make sure this gets vectorized. - if constexpr (reduction_kind == ReductionKind::SUM) { - for (size_t j = 0; j < inputs.size(); ++j) { - for (size_t i = 0; i < acc.size(); ++i) { - acc[i] += inputs[j][i]; - } - } - } else if constexpr (reduction_kind == ReductionKind::PRODUCT) { - for (size_t j = 0; j < inputs.size(); ++j) { - for (size_t i = 0; i < acc.size(); ++i) { - acc[i] *= inputs[j][i]; - } - } - } else if constexpr (reduction_kind == ReductionKind::MIN) { - for (size_t j = 0; j < inputs.size(); ++j) { - for (size_t i = 0; i < acc.size(); ++i) { - acc[i] = std::min(acc[i], inputs[j][i]); - } - } - } else if constexpr (reduction_kind == ReductionKind::MAX) { - for (size_t j = 0; j < inputs.size(); ++j) { - for (size_t i = 0; i < acc.size(); ++i) { - acc[i] = std::max(acc[i], inputs[j][i]); - } - } - } else { - static_assert(always_false_v, "Unsupported reduction kind"); - } -} - -template -absl::Status ReduceScatter(ReductionKind reduction_kind, - absl::Span inputs, void* output, - int64_t num_elems) { - using T = primitive_util::NativeTypeOf; - T initial_value = GetInitialValue(reduction_kind); - - absl::Span out_chunk = - absl::MakeSpan(reinterpret_cast(output), num_elems); - for (int64_t i = 0; i < num_elems; ++i) { - out_chunk[i] = initial_value; - } - - absl::Span input_chunks( - reinterpret_cast(inputs.data()), inputs.size()); - switch (reduction_kind) { - case ReductionKind::SUM: - ReduceHelper(out_chunk, input_chunks); - break; - case ReductionKind::PRODUCT: - ReduceHelper(out_chunk, input_chunks); - break; - case ReductionKind::MIN: - if constexpr (!is_complex_v) { - ReduceHelper(out_chunk, input_chunks); - } else { - return absl::InvalidArgumentError( - "Min reductions not supported for complex types"); - } - break; - case ReductionKind::MAX: - if constexpr (!is_complex_v) { - ReduceHelper(out_chunk, input_chunks); - } else { - return absl::InvalidArgumentError( - "Max reductions not supported for complex types"); - } - break; - } - - return absl::OkStatus(); -} - -class CpuAllReduceRendezvous - : public Rendezvous { - public: - explicit CpuAllReduceRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - absl::StatusOr RunCollectiveOp( - const AllReduceParticipantData& me) override { - VLOG(3) << me.ToString(); - int64_t world_size = participants_.size(); - // Divide the buffer up into equal(ish) chunks. Rank r computes the r-th - // chunk of the output. - int64_t chunk_elems = CeilOfRatio(me.element_count, world_size); - - int64_t start_elem = me.local_rank * chunk_elems; - int64_t end_elem = std::min(start_elem + chunk_elems, me.element_count); - chunk_elems = std::max(int64_t{0}, end_elem - start_elem); - if (chunk_elems == 0) { - return nullptr; - } - - auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); - int64_t chunk_offset = start_elem * bytes_per_elem; - int64_t chunk_bytes = chunk_elems * bytes_per_elem; - void* reduce_output = - reinterpret_cast(me.destination_data) + chunk_offset; - - std::vector inputs; - inputs.reserve(world_size); - for (const auto& p : participants_) { - inputs.push_back(reinterpret_cast(p->source_data) + - chunk_offset); - } - - if (primitive_util::IsArrayType(me.primitive_type)) { - TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch( - [&](const auto constant_type) { - return ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems); - }, - me.primitive_type)); - } else { - return absl::UnimplementedError(absl::StrCat( - "Unexpected datatype: ", - primitive_util::LowercasePrimitiveTypeName(me.primitive_type))); - } - - // All-gather the reduced chunks. - for (const auto& p : participants_) { - if (p->local_rank != me.local_rank) { - std::memcpy(reinterpret_cast(p->destination_data) + chunk_offset, - reduce_output, chunk_bytes); - } - } - return nullptr; - } -}; - -struct CollectivePermuteParticipantData : ParticipantData { - CollectivePermuteParticipantData(const RendezvousKey& rendezvous_key_p, - int rank) - : ParticipantData(rendezvous_key_p, rank) {} - const void* source_buffer; - void* destination_buffer; - size_t num_bytes; - - // From which rank is this participant receiving its data? Optional; if - // absent fill with zeros. - std::optional source_rank; - - std::string ToString() const override { - return absl::StrFormat( - "CollectivePermuteParticipantData{rank=%d, " - "source_buffer=%p, destination_buffer=%p, num_bytes=%d, " - "source_replica_id=%d, " - "devices=[%s]}", - local_rank, source_buffer, destination_buffer, num_bytes, - source_rank.value_or(-1), - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId)); - } -}; - -class CpuCollectivePermuteRendezvous - : public Rendezvous { - public: - explicit CpuCollectivePermuteRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - - absl::StatusOr RunCollectiveOp( - const CollectivePermuteParticipantData& p) override { - VLOG(3) << p.ToString(); - if (p.source_rank) { - std::memcpy(p.destination_buffer, - participants_[*p.source_rank]->source_buffer, p.num_bytes); - } else { - std::memset(p.destination_buffer, 0, p.num_bytes); - } - return nullptr; - } -}; - -struct AllToAllParticipantData : ParticipantData { - AllToAllParticipantData(const RendezvousKey& rendezvous_key_p, int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - std::vector source_buffers; - std::vector destination_buffers; - size_t chunk_size; - - std::string ToString() const override { - auto addr_formatter = [](std::string* out, const void* mem) { - absl::StrAppend(out, absl::StrFormat("%p", mem)); - }; - return absl::StrFormat( - "AllToAllParticipantData{rank=%d, " - "devices=[%s], source_buffers=[%s], " - "destination_buffers=[%s], chunk_size=%d}", - local_rank, - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), - absl::StrJoin(source_buffers, ", ", addr_formatter), - absl::StrJoin(destination_buffers, ", ", addr_formatter), chunk_size); - } -}; - -class CpuAllToAllRendezvous - : public Rendezvous { - public: - explicit CpuAllToAllRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - absl::StatusOr RunCollectiveOp( - const AllToAllParticipantData& p) override { - int world_size = p.rendezvous_key.global_devices.size(); - for (int i = 0; i < world_size; ++i) { - std::memcpy(participants_[i]->destination_buffers[p.local_rank], - p.source_buffers[i], p.chunk_size); - } - return nullptr; - } -}; - -struct AllGatherParticipantData : ParticipantData { - AllGatherParticipantData(const RendezvousKey& rendezvous_key_p, int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - const void* source_buffer; - void* destination_buffer; - size_t chunk_size; - - std::string ToString() const override { - return absl::StrFormat( - "AllGatherParticipantData{rank=%d, " - "devices=[%s], source_buffer=%p, " - "destination_buffer=%p, chunk_size=%d}", - local_rank, - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), - source_buffer, destination_buffer, chunk_size); - } -}; - -class CpuAllGatherRendezvous - : public Rendezvous { - public: - explicit CpuAllGatherRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - absl::StatusOr RunCollectiveOp( - const AllGatherParticipantData& p) override { - int world_size = p.rendezvous_key.global_devices.size(); - char* out = static_cast(p.destination_buffer); - for (int i = 0; i < world_size; ++i, out += p.chunk_size) { - std::memcpy(out, participants_[i]->source_buffer, p.chunk_size); - } - return nullptr; - } -}; - -struct ReduceScatterParticipantData : ParticipantData { - ReduceScatterParticipantData(const RendezvousKey& rendezvous_key_p, int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - ReductionKind reduction_kind; - PrimitiveType element_type; - const void* source_buffer; - void* destination_buffer; - size_t chunk_elems; - - std::string ToString() const override { - return absl::StrFormat( - "ReduceScatterParticipantData{rank=%d, " - "devices=[%s], source_buffer=%p, " - "destination_buffer=%p, chunk_elems=%d}", - local_rank, - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), - source_buffer, destination_buffer, chunk_elems); - } -}; - -class CpuReduceScatterRendezvous - : public Rendezvous { - public: - explicit CpuReduceScatterRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - absl::StatusOr RunCollectiveOp( - const ReduceScatterParticipantData& me) override { - auto bytes_per_elem = primitive_util::ByteWidth(me.element_type); - int64_t chunk_offset = me.local_rank * me.chunk_elems * bytes_per_elem; - - std::vector inputs; - inputs.reserve(participants_.size()); - for (const auto& p : participants_) { - inputs.push_back(reinterpret_cast(p->source_buffer) + - chunk_offset); - } - - if (primitive_util::IsArrayType(me.element_type)) { - TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch( - [&](const auto constant_type) { - return ReduceScatter(me.reduction_kind, inputs, - me.destination_buffer, - me.chunk_elems); - }, - me.element_type)); - } else { - return absl::UnimplementedError(absl::StrCat( - "Unexpected datatype: ", - primitive_util::LowercasePrimitiveTypeName(me.element_type))); - } - return nullptr; - } -}; - -} // namespace - -struct InProcessCollectivesState { - RefcountingHashMap - all_reduce_rendezvous_map; - RefcountingHashMap - collective_permute_rendezvous_map; - RefcountingHashMap - all_to_all_rendezvous_map; - RefcountingHashMap - all_gather_rendezvous_map; - RefcountingHashMap - reduce_scatter_rendezvous_map; -}; - -InProcessCollectivesCommunicator::InProcessCollectivesCommunicator( - InProcessCollectivesState* state, int rank, int num_ranks) - : state_(state), rank_(rank), num_ranks_(num_ranks) {} -InProcessCollectivesCommunicator::~InProcessCollectivesCommunicator() = default; - -absl::Status InProcessCollectivesCommunicator::AllReduce( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - const RendezvousKey& key = cpu_executor->rendezvous_key(); - - AllReduceParticipantData participant(key, rank_); - participant.element_count = count; - participant.primitive_type = dtype; - participant.source_data = send_buffer.opaque(); - participant.destination_data = recv_buffer.opaque(); - participant.reduction_kind = reduction_kind; - - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - - return CpuAllReduceRendezvous::SubmitParticipant( - [&] { - return state_->all_reduce_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} - -absl::Status InProcessCollectivesCommunicator::CollectivePermute( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, std::optional source_rank, - absl::Span target_ranks, const Executor& executor) { - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - const RendezvousKey& key = cpu_executor->rendezvous_key(); - - CollectivePermuteParticipantData participant(key, rank_); - participant.source_buffer = send_buffer.opaque(); - participant.destination_buffer = recv_buffer.opaque(); - participant.num_bytes = count * primitive_util::ByteWidth(dtype); - participant.source_rank = std::nullopt; - if (source_rank) { - participant.source_rank = source_rank->value(); - } - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuCollectivePermuteRendezvous::SubmitParticipant( - [&] { - return state_->collective_permute_rendezvous_map - .GetOrCreateIfAbsent(key, make_cpu_rendezvous); - }, - participant) - .status(); -} - -absl::Status InProcessCollectivesCommunicator::AllToAll( - absl::Span send_buffers, - absl::Span recv_buffers, PrimitiveType dtype, - size_t count, const Executor& executor) { - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - const RendezvousKey& key = cpu_executor->rendezvous_key(); - - AllToAllParticipantData participant(key, rank_); - TF_RET_CHECK(send_buffers.size() == recv_buffers.size()); - - size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); - - participant.chunk_size = chunk_bytes; - participant.source_buffers.reserve(send_buffers.size()); - participant.destination_buffers.reserve(recv_buffers.size()); - for (se::DeviceMemoryBase send_buffer : send_buffers) { - participant.source_buffers.push_back(send_buffer.opaque()); - } - for (se::DeviceMemoryBase recv_buffer : recv_buffers) { - participant.destination_buffers.push_back(recv_buffer.opaque()); - } - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuAllToAllRendezvous::SubmitParticipant( - [&] { - return state_->all_to_all_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} - -absl::Status InProcessCollectivesCommunicator::AllGather( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, const Executor& executor) { - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - const RendezvousKey& key = cpu_executor->rendezvous_key(); - - AllGatherParticipantData participant(key, rank_); - participant.chunk_size = count * primitive_util::ByteWidth(dtype); - participant.source_buffer = send_buffer.opaque(); - participant.destination_buffer = recv_buffer.opaque(); - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuAllGatherRendezvous::SubmitParticipant( - [&] { - return state_->all_gather_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} - -absl::Status InProcessCollectivesCommunicator::ReduceScatter( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - const RendezvousKey& key = cpu_executor->rendezvous_key(); - - ReduceScatterParticipantData participant(key, rank_); - participant.element_type = dtype; - participant.reduction_kind = reduction_kind; - participant.chunk_elems = count; - participant.source_buffer = send_buffer.opaque(); - participant.destination_buffer = recv_buffer.opaque(); - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuReduceScatterRendezvous::SubmitParticipant( - [&] { - return state_->reduce_scatter_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} -InProcessCollectives::InProcessCollectives() - : state_(std::make_unique()) {} -InProcessCollectives::~InProcessCollectives() = default; +namespace xla::cpu::runtime { absl::StatusOr> InProcessCollectives::GetCommunicator(absl::Span devices, int rank) { + absl::MutexLock lock(&mu_); + + std::shared_ptr state = state_.lock(); + if (state == nullptr) { + state = InProcessCommunicator::CreateState(); + state_ = state; + } + // We don't care about devices here: we share rendezvous state globally. - return std::make_shared(state_.get(), rank, - devices.size()); + return std::make_shared(std::move(state), rank, + devices.size()); } -} // namespace runtime -} // namespace cpu -} // namespace xla +} // namespace xla::cpu::runtime diff --git a/xla/service/cpu/in_process_collectives.h b/xla/service/cpu/in_process_collectives.h index ffabb0cd526aa7..976470ac07b8a0 100644 --- a/xla/service/cpu/in_process_collectives.h +++ b/xla/service/cpu/in_process_collectives.h @@ -16,101 +16,31 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_ #define XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_ -#include #include -#include -#include -#include "absl/status/status.h" +#include "absl/base/thread_annotations.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/cpu/collectives/in_process_communicator.h" #include "xla/core/collectives/communicator.h" -#include "xla/core/collectives/rank_id.h" -#include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::cpu::runtime { -struct InProcessCollectivesState; - -class InProcessCollectivesCommunicator : public Communicator { - public: - InProcessCollectivesCommunicator(InProcessCollectivesState* state, int rank, - int num_ranks); - ~InProcessCollectivesCommunicator() override; - - absl::Status AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, ReductionKind reduction_kind, - const Executor& executor) override; - - absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - std::optional source_rank, - absl::Span target_ranks, - const Executor& executor) override; - - absl::Status AllToAll(absl::Span send_buffers, - absl::Span recv_buffers, - PrimitiveType dtype, size_t count, - const Executor& executor) override; - - absl::Status AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, const Executor& executor) override; - - absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) override; - - absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, - PrimitiveType, size_t, RankId, - const Executor&) override { - return Unimplemented("Broadcast is not implemented"); - } - - absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, - const Executor&) override { - return Unimplemented("Send is not implemented"); - } - - absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, - const Executor&) override { - return Unimplemented("Recv is not implemented"); - } - - absl::StatusOr NumRanks() const override { return num_ranks_; } - - std::string ToString() const override { - return absl::StrCat("InProcessCommunicator [rank: ", rank_, - " num_ranks: ", num_ranks_, "]"); - } - - private: - InProcessCollectivesState* state_; - int rank_; - int num_ranks_; -}; - class InProcessCollectives : public CollectivesInterface { public: - InProcessCollectives(); - ~InProcessCollectives() override; - // Thread-safe. absl::StatusOr> GetCommunicator( absl::Span devices, int rank) override; private: - std::unique_ptr state_; + absl::Mutex mu_; + + // State shared by all constructed communicators. + std::weak_ptr state_ ABSL_GUARDED_BY(mu_); }; } // namespace xla::cpu::runtime From f93bc747159837c28eb2850243113f4d6273b64c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 7 Jan 2025 18:56:38 -0800 Subject: [PATCH 18/45] [XLA:Python] Add an optional argument to the CPU client factory method that specifies the number of CPU devices. This is more ergonomic than overriding the CPU device count via XLA_FLAGS. PiperOrigin-RevId: 713116916 --- xla/python/xla.cc | 8 +++++--- xla/python/xla_client.py | 6 ++++-- xla/python/xla_client.pyi | 1 + xla/python/xla_client_test.py | 4 +++- xla/python/xla_extension/__init__.pyi | 1 + 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 647fc37f089df7..46ecfb4a6dd4fe 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -337,8 +337,8 @@ NB_MODULE(xla_extension, m) { [](bool asynchronous, std::shared_ptr distributed_client, int node_id, int num_nodes, - std::shared_ptr collectives) - -> nb_class_ptr { + std::shared_ptr collectives, + std::optional num_devices) -> nb_class_ptr { std::unique_ptr ifrt_client; { nb::gil_scoped_release gil_release; @@ -347,6 +347,7 @@ NB_MODULE(xla_extension, m) { options.asynchronous = asynchronous; options.collectives = std::move(collectives); options.process_id = node_id; + options.cpu_device_count = num_devices; std::unique_ptr client = xla::ValueOrThrow(xla::GetXlaPjrtCpuClient(std::move(options))); ifrt::PjRtClient::CreateOptions ifrt_options; @@ -367,7 +368,8 @@ NB_MODULE(xla_extension, m) { nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr, nb::arg("node_id") = 0, nb::arg("num_nodes") = 1, nb::arg("collectives").none() = - std::shared_ptr()); + std::shared_ptr(), + nb::arg("num_devices").none() = std::nullopt); m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { absl::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); return pjrt_api.ok(); diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 040c781cd087d6..46dd4a72edd1e7 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 302 +_version = 303 # Version number for MLIR:Python components. mlir_api_version = 57 @@ -70,7 +70,8 @@ def make_cpu_client( distributed_client=None, node_id=0, num_nodes=1, - collectives=None + collectives=None, + num_devices=None, ) -> ...: register_custom_call_handler('cpu', _xla.register_custom_call_target) register_custom_type_id_handler('cpu', _xla.register_custom_type_id) @@ -80,6 +81,7 @@ def make_cpu_client( node_id=node_id, num_nodes=num_nodes, collectives=collectives, + num_devices=num_devices, ) diff --git a/xla/python/xla_client.pyi b/xla/python/xla_client.pyi index cac63a98c1b2de..efc3d2573b2224 100644 --- a/xla/python/xla_client.pyi +++ b/xla/python/xla_client.pyi @@ -89,6 +89,7 @@ def make_cpu_client( node_id: int = ..., num_nodes: int = ..., collectives: _xla.CpuCollectives | None = ..., + num_devices: int | None = ..., ) -> Client: ... diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index 35b4a1ee77964f..f0cecc9903295e 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -2757,6 +2757,8 @@ def testDevices(self): def testLocalDevices(self): self.assertNotEmpty(self.backend.local_devices()) + if self.backend.platform == "cpu": + self.assertLen(self.backend.local_devices(), 2) def testGetAllDevices(self): # TODO(hyeontaek): Remove this method once we have a unified API for @@ -3692,7 +3694,7 @@ def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): backends = { - "cpu": xla_client.make_cpu_client, + "cpu": functools.partial(xla_client.make_cpu_client, num_devices=2), "gpu": xla_client.make_gpu_client, } diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index 5fa885f9f92255..67eadd44c14a48 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -607,6 +607,7 @@ def get_tfrt_cpu_client( node_id: int = ..., num_nodes: int = ..., collectives: Optional[CpuCollectives] = ..., + num_devices: int | None = ..., ) -> Client: ... def get_gpu_client( asynchronous: bool = ..., From 9d1dbbb54af3850a3fecfbe64ec7b86fed5d076a Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 7 Jan 2025 20:01:48 -0800 Subject: [PATCH 19/45] [xla:cpu] Move GlooCommunicator to backends/cpu/collectives PiperOrigin-RevId: 713129065 --- xla/backends/cpu/collectives/BUILD | 41 ++ .../cpu/collectives/gloo_communicator.cc | 443 ++++++++++++++++++ .../cpu/collectives/gloo_communicator.h | 103 ++++ xla/pjrt/cpu/BUILD | 2 + xla/pjrt/cpu/gloo_collectives.cc | 411 +--------------- xla/pjrt/cpu/gloo_collectives.h | 80 +--- 6 files changed, 600 insertions(+), 480 deletions(-) create mode 100644 xla/backends/cpu/collectives/gloo_communicator.cc create mode 100644 xla/backends/cpu/collectives/gloo_communicator.h diff --git a/xla/backends/cpu/collectives/BUILD b/xla/backends/cpu/collectives/BUILD index 0f03fd72acc90a..7be08f8866dc0c 100644 --- a/xla/backends/cpu/collectives/BUILD +++ b/xla/backends/cpu/collectives/BUILD @@ -35,6 +35,47 @@ cc_library( ], ) +# TODO(b/380457503): Restrict visibility to private. +cc_library( + name = "gloo_communicator", + srcs = ["gloo_communicator.cc"], + hdrs = ["gloo_communicator.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":cpu_collectives", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/service/cpu:collectives_interface", + "//xla/stream_executor:device_memory", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@gloo", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + ], +) + +# TODO(b/380457503): Restrict visibility to private. cc_library( name = "in_process_communicator", srcs = ["in_process_communicator.cc"], diff --git a/xla/backends/cpu/collectives/gloo_communicator.cc b/xla/backends/cpu/collectives/gloo_communicator.cc new file mode 100644 index 00000000000000..e5e19aa3a1cfed --- /dev/null +++ b/xla/backends/cpu/collectives/gloo_communicator.cc @@ -0,0 +1,443 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/gloo_communicator.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "gloo/algorithm.h" +#include "gloo/allgather.h" +#include "gloo/allreduce.h" +#include "gloo/context.h" +#include "gloo/math.h" +#include "gloo/reduce_scatter.h" +#include "gloo/transport/device.h" +#include "gloo/transport/unbound_buffer.h" +#include "gloo/types.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +GlooCommunicator::GlooCommunicator(std::shared_ptr context, + size_t rank, size_t num_ranks) + : context_(std::move(context)), rank_(rank), num_ranks_(num_ranks) {} + +GlooCommunicator::~GlooCommunicator() = default; + +template +static absl::Status SetAllReduceOptions(ReductionKind reduction_kind, + se::DeviceMemoryBase input_buffer, + se::DeviceMemoryBase output_buffer, + size_t num_elements, + gloo::AllreduceOptions& options) { + options.setInput( + reinterpret_cast(const_cast(input_buffer.opaque())), + num_elements); + options.setOutput( + reinterpret_cast(const_cast(output_buffer.opaque())), + num_elements); + + using ReductionFn = void (*)(void*, const void*, const void*, size_t); + + switch (reduction_kind) { + case ReductionKind::SUM: + options.setReduceFunction(static_cast(&gloo::sum)); + break; + case ReductionKind::PRODUCT: + options.setReduceFunction(static_cast(&gloo::product)); + break; + case ReductionKind::MIN: + if constexpr (!is_complex_v) { + options.setReduceFunction(static_cast(&gloo::min)); + } else { + return absl::InvalidArgumentError( + "MIN reduction not supported for complex types"); + } + break; + case ReductionKind::MAX: + if constexpr (!is_complex_v) { + options.setReduceFunction(static_cast(&gloo::max)); + } else { + return absl::InvalidArgumentError( + "MAX reduction not supported for complex types"); + } + break; + } + return absl::OkStatus(); +} + +absl::Status GlooCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) { + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + + gloo::AllreduceOptions options(context_); + // TODO(phawkins): how to do tags? + // options.setTag(tag); + switch (dtype) { + case S8: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case PRED: + case U8: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case S16: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case U16: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case S32: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case U32: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case S64: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case U64: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case F16: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case BF16: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case F32: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case F64: + TF_RETURN_IF_ERROR(SetAllReduceOptions( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case C64: + TF_RETURN_IF_ERROR(SetAllReduceOptions>( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + case C128: + TF_RETURN_IF_ERROR(SetAllReduceOptions>( + reduction_kind, send_buffer, recv_buffer, count, options)); + break; + default: + return absl::InvalidArgumentError("Unknown datatype in allreduce"); + } + options.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING); + options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout())); + + try { + gloo::allreduce(options); + } catch (std::exception& e) { + return absl::UnknownError( + absl::StrCat("Gloo all-reduce failed: ", e.what())); + } + return absl::OkStatus(); +} + +static constexpr uint8_t kCollectivePermuteSlotPrefix = 0x40; + +absl::Status GlooCommunicator::CollectivePermute( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, std::optional source_rank, + absl::Span target_ranks, const Executor& executor) { + uint32_t tag = 0; // TODO(phawkins): come up with better tags. + const auto slot = gloo::Slot::build(kCollectivePermuteSlotPrefix, tag); + + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + size_t num_bytes = count * primitive_util::ByteWidth(dtype); + + try { + std::unique_ptr in; + std::unique_ptr out; + for (RankId target : target_ranks) { + if (target != context_->rank) { + VLOG(1) << "send from " << context_->rank << " to " << target.value(); + if (!in) { + in = context_->createUnboundBuffer(send_buffer.opaque(), num_bytes); + } + in->send(target.value(), slot); + } + } + if (source_rank) { + if (*source_rank == context_->rank) { + std::memcpy(recv_buffer.opaque(), send_buffer.opaque(), num_bytes); + } else { + VLOG(1) << "recv at " << context_->rank << " from " + << source_rank->value(); + out = context_->createUnboundBuffer(recv_buffer.opaque(), num_bytes); + out->recv(source_rank->value(), slot); + } + } else { + std::memset(recv_buffer.opaque(), 0, num_bytes); + } + VLOG(1) << "wait for send at " << context_->rank; + auto deadline = absl::ToChronoTime(absl::Now() + cpu_executor->timeout()); + if (in) { + in->waitSend(deadline); + } + VLOG(1) << "wait for recv at " << context_->rank; + if (out) { + out->waitRecv(deadline); + } + VLOG(1) << "done waiting at " << context_->rank; + } catch (std::exception& e) { + return absl::UnknownError( + absl::StrCat("Gloo collective permute failed: ", e.what())); + } + return absl::OkStatus(); +} + +absl::Status GlooCommunicator::AllToAll( + absl::Span send_buffers, + absl::Span recv_buffers, PrimitiveType dtype, + size_t count, const Executor& executor) { + // We can't use Gloo's all-to-all implementation directly because it assumes + // that the inputs and outputs are contiguous. No big deal; it's just built + // on top of send/recv and we can do the same as it. + uint32_t tag = 0; // TODO(phawkins): use better tags. + int my_rank = context_->rank; + int world_size = context_->size; + + TF_RET_CHECK(world_size == send_buffers.size()); + TF_RET_CHECK(world_size == recv_buffers.size()); + + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); + + try { + const auto slot = gloo::Slot::build(gloo::kAlltoallSlotPrefix, tag); + std::vector> ins( + context_->size); + std::vector> outs( + context_->size); + for (size_t i = 0; i < world_size; ++i) { + if (i != my_rank) { + ins[i] = context_->createUnboundBuffer( + const_cast(send_buffers[i].opaque()), chunk_bytes); + outs[i] = context_->createUnboundBuffer( + const_cast(recv_buffers[i].opaque()), chunk_bytes); + } + } + + for (int i = 1; i < world_size; i++) { + int send_rank = (my_rank + i) % world_size; + int recv_rank = (my_rank + world_size - i) % world_size; + ins[send_rank]->send(send_rank, slot); + outs[recv_rank]->recv(recv_rank, slot); + } + + std::memcpy(const_cast(recv_buffers[my_rank].opaque()), + send_buffers[my_rank].opaque(), chunk_bytes); + + auto deadline = absl::ToChronoTime(absl::Now() + cpu_executor->timeout()); + for (int i = 0; i < world_size; i++) { + if (i != my_rank) { + ins[i]->waitSend(deadline); + outs[i]->waitRecv(deadline); + } + } + } catch (std::exception& e) { + return absl::UnknownError( + absl::StrCat("Gloo all-to-all failed: ", e.what())); + } + return absl::OkStatus(); +} + +absl::Status GlooCommunicator::AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + const Executor& executor) { + uint32_t tag = 0; // TODO(phawkins): use better tags. + + TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); + size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); + + gloo::AllgatherOptions options(context_); + options.setTag(tag); + options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout())); + options.setInput(reinterpret_cast(send_buffer.opaque()), chunk_bytes); + options.setOutput(reinterpret_cast(recv_buffer.opaque()), + chunk_bytes * context_->size); + + try { + gloo::allgather(options); + } catch (std::exception& e) { + return absl::UnknownError( + absl::StrCat("Gloo AllGather failed: ", e.what())); + } + return absl::OkStatus(); +} + +template +absl::Status ReduceScatterHelper(std::shared_ptr context, + ReductionKind reduction_kind, void* buffer, + size_t chunk_elems) { + const gloo::ReductionFunction* reduction_function = nullptr; + if constexpr (is_complex_v) { + switch (reduction_kind) { + case ReductionKind::SUM: + reduction_function = gloo::ReductionFunction::sum; + break; + case ReductionKind::PRODUCT: + reduction_function = gloo::ReductionFunction::product; + break; + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported reduction kind: ", static_cast(reduction_kind))); + } + } else { + switch (reduction_kind) { + case ReductionKind::SUM: + reduction_function = gloo::ReductionFunction::sum; + break; + case ReductionKind::PRODUCT: + reduction_function = gloo::ReductionFunction::product; + break; + case ReductionKind::MAX: + reduction_function = gloo::ReductionFunction::max; + break; + case ReductionKind::MIN: + reduction_function = gloo::ReductionFunction::min; + break; + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported reduction kind: ", static_cast(reduction_kind))); + } + } + try { + std::vector recv_elems(context->size, chunk_elems); + gloo::ReduceScatterHalvingDoubling algorithm( + context, std::vector{reinterpret_cast(buffer)}, + chunk_elems * context->size, recv_elems, reduction_function); + algorithm.run(); + } catch (std::exception& e) { + return absl::UnknownError( + absl::StrCat("Gloo ReduceScatter failed: ", e.what())); + } + return absl::OkStatus(); +} + +absl::Status GlooCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) { + size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); + std::unique_ptr temp(new char[chunk_bytes * context_->size]); + std::memcpy(temp.get(), send_buffer.opaque(), chunk_bytes * context_->size); + switch (dtype) { + case S8: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); + break; + case PRED: + case U8: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); + break; + case S16: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); + break; + case U16: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); + break; + case S32: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); + break; + case U32: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); + break; + case S64: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); + break; + case U64: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); + break; + case BF16: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); + break; + case F16: + TF_RETURN_IF_ERROR(ReduceScatterHelper( + context_, reduction_kind, temp.get(), count)); + break; + case F32: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); + break; + case F64: + TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, + temp.get(), count)); + break; + case C64: + TF_RETURN_IF_ERROR(ReduceScatterHelper>( + context_, reduction_kind, temp.get(), count)); + break; + case C128: + TF_RETURN_IF_ERROR(ReduceScatterHelper>( + context_, reduction_kind, temp.get(), count)); + break; + default: + return absl::InvalidArgumentError("Unknown datatype in reducescatter"); + } + std::memcpy(recv_buffer.opaque(), temp.get(), chunk_bytes); + return absl::OkStatus(); +} + +} // namespace xla::cpu diff --git a/xla/backends/cpu/collectives/gloo_communicator.h b/xla/backends/cpu/collectives/gloo_communicator.h new file mode 100644 index 00000000000000..234716da759340 --- /dev/null +++ b/xla/backends/cpu/collectives/gloo_communicator.h @@ -0,0 +1,103 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_GLOO_COMMUNICATOR_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_GLOO_COMMUNICATOR_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "gloo/context.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +// XLA communicator implemented using Gloo communication library. +class GlooCommunicator : public Communicator { + public: + GlooCommunicator(std::shared_ptr context, size_t rank, + size_t num_ranks); + ~GlooCommunicator() override; + + absl::Status AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, ReductionKind reduction_kind, + const Executor& executor) override; + + absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + std::optional source_rank, + absl::Span target_ranks, + const Executor& executor) override; + + absl::Status AllToAll(absl::Span send_buffers, + absl::Span recv_buffers, + PrimitiveType dtype, size_t count, + const Executor& executor) override; + + absl::Status AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, const Executor& executor) override; + + absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) override; + + absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Broadcast is not implemented"); + } + + absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Send is not implemented"); + } + + absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Recv is not implemented"); + } + + absl::StatusOr NumRanks() const override { return num_ranks_; } + + std::string ToString() const override { + return absl::StrCat("GlooCommunicator [rank: ", rank_, + " num_ranks: ", num_ranks_, "]"); + } + + private: + std::shared_ptr context_; + size_t rank_; + size_t num_ranks_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_GLOO_COMMUNICATOR_H_ diff --git a/xla/pjrt/cpu/BUILD b/xla/pjrt/cpu/BUILD index 29a58a216d7f8a..a9d3f9300d42ff 100644 --- a/xla/pjrt/cpu/BUILD +++ b/xla/pjrt/cpu/BUILD @@ -301,6 +301,8 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/backends/cpu/collectives:gloo_communicator", + "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/service:global_device_id", diff --git a/xla/pjrt/cpu/gloo_collectives.cc b/xla/pjrt/cpu/gloo_collectives.cc index 0d479d7bfe2fd1..09451f220b97d4 100644 --- a/xla/pjrt/cpu/gloo_collectives.cc +++ b/xla/pjrt/cpu/gloo_collectives.cc @@ -15,13 +15,8 @@ limitations under the License. #include "xla/pjrt/cpu/gloo_collectives.h" -#include -#include -#include -#include #include #include -#include #include #include #include @@ -33,419 +28,19 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" #include "absl/types/span.h" -#include "gloo/algorithm.h" -#include "gloo/allgather.h" -#include "gloo/allreduce.h" #include "gloo/context.h" -#include "gloo/math.h" -#include "gloo/reduce_scatter.h" #include "gloo/rendezvous/context.h" #include "gloo/rendezvous/prefix_store.h" #include "gloo/rendezvous/store.h" #include "gloo/transport/device.h" -#include "gloo/transport/unbound_buffer.h" -#include "gloo/types.h" -#include "xla/backends/cpu/collectives/cpu_collectives.h" -#include "xla/core/collectives/rank_id.h" -#include "xla/primitive_util.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" +#include "xla/backends/cpu/collectives/gloo_communicator.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/global_device_id.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" namespace xla::cpu { -GlooCollectivesCommunicator::GlooCollectivesCommunicator( - std::shared_ptr context, size_t rank, size_t num_ranks) - : context_(std::move(context)), rank_(rank), num_ranks_(num_ranks) {} -GlooCollectivesCommunicator::~GlooCollectivesCommunicator() = default; - -template -static absl::Status SetAllReduceOptions(ReductionKind reduction_kind, - se::DeviceMemoryBase input_buffer, - se::DeviceMemoryBase output_buffer, - size_t num_elements, - gloo::AllreduceOptions& options) { - options.setInput( - reinterpret_cast(const_cast(input_buffer.opaque())), - num_elements); - options.setOutput( - reinterpret_cast(const_cast(output_buffer.opaque())), - num_elements); - - using ReductionFn = void (*)(void*, const void*, const void*, size_t); - - switch (reduction_kind) { - case ReductionKind::SUM: - options.setReduceFunction(static_cast(&gloo::sum)); - break; - case ReductionKind::PRODUCT: - options.setReduceFunction(static_cast(&gloo::product)); - break; - case ReductionKind::MIN: - if constexpr (!is_complex_v) { - options.setReduceFunction(static_cast(&gloo::min)); - } else { - return absl::InvalidArgumentError( - "MIN reduction not supported for complex types"); - } - break; - case ReductionKind::MAX: - if constexpr (!is_complex_v) { - options.setReduceFunction(static_cast(&gloo::max)); - } else { - return absl::InvalidArgumentError( - "MAX reduction not supported for complex types"); - } - break; - } - return absl::OkStatus(); -} - -absl::Status GlooCollectivesCommunicator::AllReduce( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - - gloo::AllreduceOptions options(context_); - // TODO(phawkins): how to do tags? - // options.setTag(tag); - switch (dtype) { - case S8: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case PRED: - case U8: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case S16: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case U16: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case S32: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case U32: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case S64: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case U64: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case F16: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case BF16: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case F32: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case F64: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case C64: - TF_RETURN_IF_ERROR(SetAllReduceOptions>( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case C128: - TF_RETURN_IF_ERROR(SetAllReduceOptions>( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - default: - return absl::InvalidArgumentError("Unknown datatype in allreduce"); - } - options.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING); - options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout())); - - try { - gloo::allreduce(options); - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat("Gloo all-reduce failed: ", e.what())); - } - return absl::OkStatus(); -} - -static constexpr uint8_t kCollectivePermuteSlotPrefix = 0x40; - -absl::Status GlooCollectivesCommunicator::CollectivePermute( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, std::optional source_rank, - absl::Span target_ranks, const Executor& executor) { - uint32_t tag = 0; // TODO(phawkins): come up with better tags. - const auto slot = gloo::Slot::build(kCollectivePermuteSlotPrefix, tag); - - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - size_t num_bytes = count * primitive_util::ByteWidth(dtype); - - try { - std::unique_ptr in; - std::unique_ptr out; - for (RankId target : target_ranks) { - if (target != context_->rank) { - VLOG(1) << "send from " << context_->rank << " to " << target.value(); - if (!in) { - in = context_->createUnboundBuffer(send_buffer.opaque(), num_bytes); - } - in->send(target.value(), slot); - } - } - if (source_rank) { - if (*source_rank == context_->rank) { - std::memcpy(recv_buffer.opaque(), send_buffer.opaque(), num_bytes); - } else { - VLOG(1) << "recv at " << context_->rank << " from " - << source_rank->value(); - out = context_->createUnboundBuffer(recv_buffer.opaque(), num_bytes); - out->recv(source_rank->value(), slot); - } - } else { - std::memset(recv_buffer.opaque(), 0, num_bytes); - } - VLOG(1) << "wait for send at " << context_->rank; - auto deadline = absl::ToChronoTime(absl::Now() + cpu_executor->timeout()); - if (in) { - in->waitSend(deadline); - } - VLOG(1) << "wait for recv at " << context_->rank; - if (out) { - out->waitRecv(deadline); - } - VLOG(1) << "done waiting at " << context_->rank; - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat("Gloo collective permute failed: ", e.what())); - } - return absl::OkStatus(); -} - -absl::Status GlooCollectivesCommunicator::AllToAll( - absl::Span send_buffers, - absl::Span recv_buffers, PrimitiveType dtype, - size_t count, const Executor& executor) { - // We can't use Gloo's all-to-all implementation directly because it assumes - // that the inputs and outputs are contiguous. No big deal; it's just built - // on top of send/recv and we can do the same as it. - uint32_t tag = 0; // TODO(phawkins): use better tags. - int my_rank = context_->rank; - int world_size = context_->size; - - TF_RET_CHECK(world_size == send_buffers.size()); - TF_RET_CHECK(world_size == recv_buffers.size()); - - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); - - try { - const auto slot = gloo::Slot::build(gloo::kAlltoallSlotPrefix, tag); - std::vector> ins( - context_->size); - std::vector> outs( - context_->size); - for (size_t i = 0; i < world_size; ++i) { - if (i != my_rank) { - ins[i] = context_->createUnboundBuffer( - const_cast(send_buffers[i].opaque()), chunk_bytes); - outs[i] = context_->createUnboundBuffer( - const_cast(recv_buffers[i].opaque()), chunk_bytes); - } - } - - for (int i = 1; i < world_size; i++) { - int send_rank = (my_rank + i) % world_size; - int recv_rank = (my_rank + world_size - i) % world_size; - ins[send_rank]->send(send_rank, slot); - outs[recv_rank]->recv(recv_rank, slot); - } - - std::memcpy(const_cast(recv_buffers[my_rank].opaque()), - send_buffers[my_rank].opaque(), chunk_bytes); - - auto deadline = absl::ToChronoTime(absl::Now() + cpu_executor->timeout()); - for (int i = 0; i < world_size; i++) { - if (i != my_rank) { - ins[i]->waitSend(deadline); - outs[i]->waitRecv(deadline); - } - } - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat("Gloo all-to-all failed: ", e.what())); - } - return absl::OkStatus(); -} - -absl::Status GlooCollectivesCommunicator::AllGather( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, const Executor& executor) { - uint32_t tag = 0; // TODO(phawkins): use better tags. - - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); - - gloo::AllgatherOptions options(context_); - options.setTag(tag); - options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout())); - options.setInput(reinterpret_cast(send_buffer.opaque()), chunk_bytes); - options.setOutput(reinterpret_cast(recv_buffer.opaque()), - chunk_bytes * context_->size); - - try { - gloo::allgather(options); - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat("Gloo AllGather failed: ", e.what())); - } - return absl::OkStatus(); -} - -template -absl::Status ReduceScatterHelper(std::shared_ptr context, - ReductionKind reduction_kind, void* buffer, - size_t chunk_elems) { - const gloo::ReductionFunction* reduction_function = nullptr; - if constexpr (is_complex_v) { - switch (reduction_kind) { - case ReductionKind::SUM: - reduction_function = gloo::ReductionFunction::sum; - break; - case ReductionKind::PRODUCT: - reduction_function = gloo::ReductionFunction::product; - break; - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported reduction kind: ", static_cast(reduction_kind))); - } - } else { - switch (reduction_kind) { - case ReductionKind::SUM: - reduction_function = gloo::ReductionFunction::sum; - break; - case ReductionKind::PRODUCT: - reduction_function = gloo::ReductionFunction::product; - break; - case ReductionKind::MAX: - reduction_function = gloo::ReductionFunction::max; - break; - case ReductionKind::MIN: - reduction_function = gloo::ReductionFunction::min; - break; - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported reduction kind: ", static_cast(reduction_kind))); - } - } - try { - std::vector recv_elems(context->size, chunk_elems); - gloo::ReduceScatterHalvingDoubling algorithm( - context, std::vector{reinterpret_cast(buffer)}, - chunk_elems * context->size, recv_elems, reduction_function); - algorithm.run(); - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat("Gloo ReduceScatter failed: ", e.what())); - } - return absl::OkStatus(); -} - -absl::Status GlooCollectivesCommunicator::ReduceScatter( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { - size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); - std::unique_ptr temp(new char[chunk_bytes * context_->size]); - std::memcpy(temp.get(), send_buffer.opaque(), chunk_bytes * context_->size); - switch (dtype) { - case S8: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case PRED: - case U8: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case S16: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case U16: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case S32: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case U32: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case S64: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case U64: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case BF16: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case F16: - TF_RETURN_IF_ERROR(ReduceScatterHelper( - context_, reduction_kind, temp.get(), count)); - break; - case F32: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case F64: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case C64: - TF_RETURN_IF_ERROR(ReduceScatterHelper>( - context_, reduction_kind, temp.get(), count)); - break; - case C128: - TF_RETURN_IF_ERROR(ReduceScatterHelper>( - context_, reduction_kind, temp.get(), count)); - break; - default: - return absl::InvalidArgumentError("Unknown datatype in reducescatter"); - } - std::memcpy(recv_buffer.opaque(), temp.get(), chunk_bytes); - return absl::OkStatus(); -} - GlooCollectives::GlooCollectives( std::unique_ptr store, std::shared_ptr device) @@ -486,7 +81,7 @@ absl::StatusOr> GlooCollectives::GetCommunicator( return absl::UnknownError( absl::StrCat("Gloo context initialization failed: ", e.what())); } - context->communicator = std::make_shared( + context->communicator = std::make_shared( std::move(gloo_context), rank, global_devices.size()); return context->communicator; } diff --git a/xla/pjrt/cpu/gloo_collectives.h b/xla/pjrt/cpu/gloo_collectives.h index 7bac8b7d662721..174cdb48accebf 100644 --- a/xla/pjrt/cpu/gloo_collectives.h +++ b/xla/pjrt/cpu/gloo_collectives.h @@ -16,92 +16,26 @@ limitations under the License. #ifndef XLA_PJRT_CPU_GLOO_COLLECTIVES_H_ #define XLA_PJRT_CPU_GLOO_COLLECTIVES_H_ -#include #include -#include -#include #include #include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" -#include "absl/time/time.h" #include "absl/types/span.h" #include "gloo/context.h" #include "gloo/rendezvous/store.h" #include "gloo/transport/device.h" -#include "xla/core/collectives/rank_id.h" -#include "xla/service/collective_ops_utils.h" +#include "xla/backends/cpu/collectives/gloo_communicator.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::cpu { -class GlooCollectivesCommunicator : public Communicator { - public: - explicit GlooCollectivesCommunicator(std::shared_ptr context, - size_t rank, size_t num_ranks); - ~GlooCollectivesCommunicator() override; - - absl::Status AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, ReductionKind reduction_kind, - const Executor& executor) override; - absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - std::optional source_rank, - absl::Span target_ranks, - const Executor& executor) override; - absl::Status AllToAll(absl::Span send_buffers, - absl::Span recv_buffers, - PrimitiveType dtype, size_t count, - const Executor& executor) override; - absl::Status AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, const Executor& executor) override; - absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) override; - - absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, - PrimitiveType, size_t, RankId, - const Executor&) override { - return Unimplemented("Broadcast is not implemented"); - } - - absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, - const Executor&) override { - return Unimplemented("Send is not implemented"); - } - - absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, - const Executor&) override { - return Unimplemented("Recv is not implemented"); - } - - absl::StatusOr NumRanks() const override { return num_ranks_; } - - std::string ToString() const override { - return absl::StrCat("GlooCommunicator [rank: ", rank_, - " num_ranks: ", num_ranks_, "]"); - } - - private: - std::shared_ptr context_; - size_t rank_; - size_t num_ranks_; -}; - class GlooCollectives : public CollectivesInterface { public: GlooCollectives(std::unique_ptr store, @@ -113,13 +47,15 @@ class GlooCollectives : public CollectivesInterface { absl::Span devices, int rank) override; private: - std::unique_ptr store_; - std::shared_ptr device_; - absl::Mutex mu_; struct Context { absl::Mutex mu; - std::shared_ptr communicator; + std::shared_ptr communicator; }; + + std::unique_ptr store_; + std::shared_ptr device_; + + absl::Mutex mu_; absl::flat_hash_map, int>, std::unique_ptr> contexts_ ABSL_GUARDED_BY(mu_); From cd4f277d68de64e048f3b760eda4a2d3b47f1c93 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 7 Jan 2025 21:14:13 -0800 Subject: [PATCH 20/45] [xla:cpu] Move MpiCommunicator to backends/cpu/collectives PiperOrigin-RevId: 713144393 --- xla/backends/cpu/collectives/BUILD | 44 ++++ .../cpu/collectives/mpi_communicator.cc | 242 ++++++++++++++++++ .../cpu/collectives/mpi_communicator.h | 98 +++++++ xla/pjrt/cpu/BUILD | 34 ++- xla/pjrt/cpu/mpi_collectives.cc | 235 +---------------- xla/pjrt/cpu/mpi_collectives.h | 69 +---- xla/python/BUILD | 6 +- 7 files changed, 420 insertions(+), 308 deletions(-) create mode 100644 xla/backends/cpu/collectives/mpi_communicator.cc create mode 100644 xla/backends/cpu/collectives/mpi_communicator.h diff --git a/xla/backends/cpu/collectives/BUILD b/xla/backends/cpu/collectives/BUILD index 7be08f8866dc0c..be10b5cafa1250 100644 --- a/xla/backends/cpu/collectives/BUILD +++ b/xla/backends/cpu/collectives/BUILD @@ -103,3 +103,47 @@ cc_library( "@tsl//tsl/platform:errors", ], ) + +# TODO(b/380457503): Restrict visibility to private. +cc_library( + name = "mpi_communicator", + srcs = ["mpi_communicator.cc"], + hdrs = ["mpi_communicator.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + # copybara:uncomment_begin(google-only) + # "-Ithird_party/openmpi/ompi/include", + # copybara:uncomment_end + ], + features = ["-use_header_modules"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/service/cpu:collectives_interface", + "//xla/stream_executor:device_memory", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@mpitrampoline", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + ], +) diff --git a/xla/backends/cpu/collectives/mpi_communicator.cc b/xla/backends/cpu/collectives/mpi_communicator.cc new file mode 100644 index 00000000000000..0062593da75407 --- /dev/null +++ b/xla/backends/cpu/collectives/mpi_communicator.cc @@ -0,0 +1,242 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/mpi_communicator.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "mpi.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla::cpu { + +absl::StatusOr PrimitiveTypeToMpiType( + PrimitiveType element_type) { + switch (element_type) { + case S8: + return MPI_INT8_T; + case U8: + case PRED: + return MPI_UINT8_T; + case S16: + return MPI_INT16_T; + case U16: + return MPI_UINT16_T; + case S32: + return MPI_INT32_T; + case U32: + return MPI_UINT32_T; + case S64: + return MPI_INT64_T; + case U64: + return MPI_UINT64_T; + case F32: + return MPI_FLOAT; + case F64: + return MPI_DOUBLE; + case C64: + return MPI_C_COMPLEX; + case C128: + return MPI_C_DOUBLE_COMPLEX; + default: + // For implementing the reduction of unsupported types + // see e.g. https://stackoverflow.com/a/29643391 + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported primitive type for reduction: ", + primitive_util::LowercasePrimitiveTypeName(element_type))); + } +} + +bool MpiTypeIsComplex(MPI_Datatype type) { + return type == MPI_C_COMPLEX || type == MPI_C_DOUBLE_COMPLEX; +} + +absl::StatusOr ReductionKindToMpiOp(ReductionKind reduction_kind, + MPI_Datatype type) { + switch (reduction_kind) { + case ReductionKind::SUM: + return MPI_SUM; + case ReductionKind::PRODUCT: + return MPI_PROD; + case ReductionKind::MIN: + if (!MpiTypeIsComplex(type)) { + return MPI_MIN; + } else { + return absl::InvalidArgumentError( + "MIN reduction not supported for complex types"); + } + case ReductionKind::MAX: + if (!MpiTypeIsComplex(type)) { + return MPI_MAX; + } else { + return absl::InvalidArgumentError( + "MAX reduction not supported for complex types"); + } + default: + return absl::InvalidArgumentError( + absl::StrCat("Unknown reduction kind: ", reduction_kind)); + } +} + +static absl::Status MpiErrorToAbslStatus(int error) { + if (error != MPI_SUCCESS) { + char error_str[MPI_MAX_ERROR_STRING]; + int len; + MPI_Error_string(error, error_str, &len); + return absl::UnknownError(absl::StrCat("MPI error: ", error_str)); + } + return absl::OkStatus(); +} + +MpiCommunicator::MpiCommunicator(int color, int key) { + MPI_Comm_split(MPI_COMM_WORLD, color, key, &comm_); + MPI_Comm_rank(comm_, &mpi_rank_); + MPI_Comm_size(comm_, &mpi_size_); +} + +MpiCommunicator::~MpiCommunicator() { MPI_Comm_free(&comm_); }; + +absl::Status MpiCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) { + TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); + TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); + return MpiErrorToAbslStatus(MPI_Allreduce( + send_buffer.opaque(), recv_buffer.opaque(), count, type, op, comm_)); +} + +absl::Status MpiCommunicator::CollectivePermute( + se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, std::optional source_rank, + absl::Span target_ranks, const Executor& executor) { + int tag = 0; // TODO come up with better tags. + + const int rank = mpi_rank_; + + std::vector requests; + + size_t num_bytes = count * primitive_util::ByteWidth(dtype); + + if (source_rank) { + if (source_rank->value() == rank) { + std::memcpy(recv_buffer.opaque(), send_buffer.opaque(), num_bytes); + } else { + VLOG(1) << "recv at " << rank << " from " << source_rank->value(); + requests.emplace_back(); + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Irecv(recv_buffer.opaque(), num_bytes, MPI_BYTE, + source_rank->value(), tag, comm_, &requests.back()))); + } + } else { + std::memset(recv_buffer.opaque(), 0, num_bytes); + } + + for (RankId target : target_ranks) { + if (target != rank) { + VLOG(1) << "send from " << rank << " to " << target.value(); + requests.emplace_back(); + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Isend(send_buffer.opaque(), num_bytes, MPI_BYTE, target.value(), + tag, comm_, &requests.back()))); + } + } + + for (auto& request : requests) { + TF_RETURN_IF_ERROR( + MpiErrorToAbslStatus(MPI_Wait(&request, MPI_STATUS_IGNORE))); + } + + return absl::OkStatus(); +} + +absl::Status MpiCommunicator::AllToAll( + absl::Span send_buffers, + absl::Span recv_buffers, PrimitiveType dtype, + size_t count, const Executor& executor) { + // We can't use MPI_Alltoall directly because it assumes that the inputs and + // outputs are contiguous. Therefore here we implement it using MPI_Sendrecv. + + int tag = 0; // TODO use better tags. + const int rank = mpi_rank_; + const int size = mpi_size_; + TF_RET_CHECK(size == send_buffers.size()); + TF_RET_CHECK(size == recv_buffers.size()); + + size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); + + std::vector input_buffers; + std::vector output_buffers; + + for (int i = 0; i < size; i++) { + input_buffers.push_back(const_cast(send_buffers[i].opaque())); + output_buffers.push_back(const_cast(recv_buffers[i].opaque())); + } + + std::memcpy(output_buffers[rank], input_buffers[rank], chunk_bytes); + + for (int i = 1; i < size; i++) { + int send_rank = (rank + i) % size; + int recv_rank = (rank + size - i) % size; + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Sendrecv(input_buffers[send_rank], chunk_bytes, MPI_BYTE, send_rank, + tag, output_buffers[recv_rank], chunk_bytes, MPI_BYTE, + recv_rank, tag, comm_, MPI_STATUS_IGNORE))); + } + + return absl::OkStatus(); +} + +absl::Status MpiCommunicator::AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + const Executor& executor) { + TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); + return MpiErrorToAbslStatus(MPI_Allgather(send_buffer.opaque(), count, type, + recv_buffer.opaque(), count, type, + comm_)); +} + +absl::Status MpiCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) { + const int size = mpi_size_; + std::vector recvcounts(size, count); + TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); + TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); + return MpiErrorToAbslStatus( + MPI_Reduce_scatter(send_buffer.opaque(), recv_buffer.opaque(), + recvcounts.data(), type, op, comm_)); +} + +} // namespace xla::cpu diff --git a/xla/backends/cpu/collectives/mpi_communicator.h b/xla/backends/cpu/collectives/mpi_communicator.h new file mode 100644 index 00000000000000..cfed534b66bd51 --- /dev/null +++ b/xla/backends/cpu/collectives/mpi_communicator.h @@ -0,0 +1,98 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_MPI_COMMUNICATOR_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_MPI_COMMUNICATOR_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "mpi.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class MpiCommunicator : public Communicator { + public: + explicit MpiCommunicator(int color, int key); + ~MpiCommunicator() override; + + absl::Status AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, ReductionKind reduction_kind, + const Executor& executor) override; + + absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + std::optional source_rank, + absl::Span target_ranks, + const Executor& executor) override; + + absl::Status AllToAll(absl::Span send_buffers, + absl::Span recv_buffers, + PrimitiveType dtype, size_t count, + const Executor& executor) override; + absl::Status AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, const Executor& executor) override; + absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) override; + + absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Broadcast is not implemented"); + } + + absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Send is not implemented"); + } + + absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, + const Executor&) override { + return Unimplemented("Recv is not implemented"); + } + + absl::StatusOr NumRanks() const override { return mpi_size_; } + + std::string ToString() const override { + return absl::StrCat("MpiCommunicator [rank: ", mpi_rank_, + " num_ranks: ", mpi_size_, "]"); + } + + private: + MPI_Comm comm_; + int mpi_rank_; + int mpi_size_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_MPI_COMMUNICATOR_H_ diff --git a/xla/pjrt/cpu/BUILD b/xla/pjrt/cpu/BUILD index a9d3f9300d42ff..1ce663e34dc1d3 100644 --- a/xla/pjrt/cpu/BUILD +++ b/xla/pjrt/cpu/BUILD @@ -1,6 +1,6 @@ load("//xla:xla.bzl", "xla_cc_test") load("//xla/pjrt/cpu:package_groups.bzl", "xla_cpu_internal_packages") -load("//xla/tsl:tsl.bzl", "if_oss", "internal_visibility") +load("//xla/tsl:tsl.bzl", "internal_visibility") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -364,34 +364,42 @@ xla_cc_test( cc_library( name = "mpi_collectives", - srcs = if_oss(["mpi_collectives.cc"]), - hdrs = if_oss(["mpi_collectives.h"]), + srcs = ["mpi_collectives.cc"], + hdrs = ["mpi_collectives.h"], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", + # copybara:uncomment_begin(google-only) + # "-Ithird_party/openmpi/ompi/include", + # copybara:uncomment_end ], features = ["-use_header_modules"], visibility = [ "//xla/pjrt/cpu:legacy_cpu_internal_users", ], - deps = if_oss([ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", + deps = [ "//xla:shape_util", "//xla:status_macros", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:mpi_communicator", + "//xla/core/collectives:communicator", "//xla/service:collective_ops_utils", "//xla/service:global_device_id", "//xla/service/cpu:collectives_interface", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@mpitrampoline", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", - "@mpitrampoline", - ]), + ], ) diff --git a/xla/pjrt/cpu/mpi_collectives.cc b/xla/pjrt/cpu/mpi_collectives.cc index 002f278c79bb63..88dc69a31917d6 100644 --- a/xla/pjrt/cpu/mpi_collectives.cc +++ b/xla/pjrt/cpu/mpi_collectives.cc @@ -15,242 +15,25 @@ limitations under the License. #include "xla/pjrt/cpu/mpi_collectives.h" -#include -#include -#include -#include -#include #include -#include -#include #include -#include #include -#include "mpi.h" // NOLINT +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" #include "absl/types/span.h" -#include "xla/primitive_util.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" +#include "mpi.h" +#include "xla/backends/cpu/collectives/mpi_communicator.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/global_device_id.h" -#include "xla/status_macros.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" namespace xla::cpu { -absl::StatusOr PrimitiveTypeToMpiType( - PrimitiveType element_type) { - switch (element_type) { - case S8: - return MPI_INT8_T; - case U8: - case PRED: - return MPI_UINT8_T; - case S16: - return MPI_INT16_T; - case U16: - return MPI_UINT16_T; - case S32: - return MPI_INT32_T; - case U32: - return MPI_UINT32_T; - case S64: - return MPI_INT64_T; - case U64: - return MPI_UINT64_T; - case F32: - return MPI_FLOAT; - case F64: - return MPI_DOUBLE; - case C64: - return MPI_C_COMPLEX; - case C128: - return MPI_C_DOUBLE_COMPLEX; - default: - // For implementing the reduction of unsupported types - // see e.g. https://stackoverflow.com/a/29643391 - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported primitive type for reduction: ", - primitive_util::LowercasePrimitiveTypeName(element_type))); - } -} - -bool MpiTypeIsComplex(MPI_Datatype type) { - return type == MPI_C_COMPLEX || type == MPI_C_DOUBLE_COMPLEX; -} - -absl::StatusOr ReductionKindToMpiOp(ReductionKind reduction_kind, - MPI_Datatype type) { - switch (reduction_kind) { - case ReductionKind::SUM: - return MPI_SUM; - case ReductionKind::PRODUCT: - return MPI_PROD; - case ReductionKind::MIN: - if (!MpiTypeIsComplex(type)) { - return MPI_MIN; - } else { - return absl::InvalidArgumentError( - "MIN reduction not supported for complex types"); - } - case ReductionKind::MAX: - if (!MpiTypeIsComplex(type)) { - return MPI_MAX; - } else { - return absl::InvalidArgumentError( - "MAX reduction not supported for complex types"); - } - default: - return absl::InvalidArgumentError( - absl::StrCat("Unknown reduction kind: ", reduction_kind)); - } -} - -static absl::Status MpiErrorToAbslStatus(int error) { - if (error != MPI_SUCCESS) { - char error_str[MPI_MAX_ERROR_STRING]; - int len; - MPI_Error_string(error, error_str, &len); - return absl::UnknownError(absl::StrCat("MPI error: ", error_str)); - } - return absl::OkStatus(); -} - -MpiCollectivesCommunicator::MpiCollectivesCommunicator(int color, int key) { - MPI_Comm_split(MPI_COMM_WORLD, color, key, &comm_); - MPI_Comm_rank(comm_, &mpi_rank_); - MPI_Comm_size(comm_, &mpi_size_); -} - -MpiCollectivesCommunicator::~MpiCollectivesCommunicator() { - MPI_Comm_free(&comm_); -}; - -absl::Status MpiCollectivesCommunicator::AllReduce( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { - TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); - TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); - return MpiErrorToAbslStatus(MPI_Allreduce( - send_buffer.opaque(), recv_buffer.opaque(), count, type, op, comm_)); -} - -absl::Status MpiCollectivesCommunicator::CollectivePermute( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, std::optional source_rank, - absl::Span target_ranks, const Executor& executor) { - int tag = 0; // TODO come up with better tags. - - const int rank = mpi_rank_; - - std::vector requests; - - size_t num_bytes = count * primitive_util::ByteWidth(dtype); - - if (source_rank) { - if (source_rank->value() == rank) { - std::memcpy(recv_buffer.opaque(), send_buffer.opaque(), num_bytes); - } else { - VLOG(1) << "recv at " << rank << " from " << source_rank->value(); - requests.emplace_back(); - TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( - MPI_Irecv(recv_buffer.opaque(), num_bytes, MPI_BYTE, - source_rank->value(), tag, comm_, &requests.back()))); - } - } else { - std::memset(recv_buffer.opaque(), 0, num_bytes); - } - - for (RankId target : target_ranks) { - if (target != rank) { - VLOG(1) << "send from " << rank << " to " << target.value(); - requests.emplace_back(); - TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( - MPI_Isend(send_buffer.opaque(), num_bytes, MPI_BYTE, target.value(), - tag, comm_, &requests.back()))); - } - } - - for (auto& request : requests) { - TF_RETURN_IF_ERROR( - MpiErrorToAbslStatus(MPI_Wait(&request, MPI_STATUS_IGNORE))); - } - - return absl::OkStatus(); -} - -absl::Status MpiCollectivesCommunicator::AllToAll( - absl::Span send_buffers, - absl::Span recv_buffers, PrimitiveType dtype, - size_t count, const Executor& executor) { - // We can't use MPI_Alltoall directly because it assumes that the inputs and - // outputs are contiguous. Therefore here we implement it using MPI_Sendrecv. - - int tag = 0; // TODO use better tags. - const int rank = mpi_rank_; - const int size = mpi_size_; - TF_RET_CHECK(size == send_buffers.size()); - TF_RET_CHECK(size == recv_buffers.size()); - - size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); - - std::vector input_buffers; - std::vector output_buffers; - - for (int i = 0; i < size; i++) { - input_buffers.push_back(const_cast(send_buffers[i].opaque())); - output_buffers.push_back(const_cast(recv_buffers[i].opaque())); - } - - std::memcpy(output_buffers[rank], input_buffers[rank], chunk_bytes); - - for (int i = 1; i < size; i++) { - int send_rank = (rank + i) % size; - int recv_rank = (rank + size - i) % size; - TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( - MPI_Sendrecv(input_buffers[send_rank], chunk_bytes, MPI_BYTE, send_rank, - tag, output_buffers[recv_rank], chunk_bytes, MPI_BYTE, - recv_rank, tag, comm_, MPI_STATUS_IGNORE))); - } - - return absl::OkStatus(); -} - -absl::Status MpiCollectivesCommunicator::AllGather( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, const Executor& executor) { - TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); - return MpiErrorToAbslStatus(MPI_Allgather(send_buffer.opaque(), count, type, - recv_buffer.opaque(), count, type, - comm_)); -} - -absl::Status MpiCollectivesCommunicator::ReduceScatter( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { - const int size = mpi_size_; - std::vector recvcounts(size, count); - TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); - TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); - return MpiErrorToAbslStatus( - MPI_Reduce_scatter(send_buffer.opaque(), recv_buffer.opaque(), - recvcounts.data(), type, op, comm_)); -} - void MpiCollectives::Init() { int provided; - MPI_Init_thread(NULL, NULL, MPI_THREAD_FUNNELED, &provided); + MPI_Init_thread(nullptr, nullptr, MPI_THREAD_FUNNELED, &provided); MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank_); MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size_); VLOG(1) << "MPI rank=" << mpi_world_rank_ << " size=" << mpi_world_size_; @@ -267,9 +50,9 @@ absl::StatusOr> MpiCollectives::GetCommunicator( MPI_Is_thread_main(&flag); if (!flag) { return absl::UnknownError( - absl::StrCat("MPI: Communicator requested from a thread that is not " - "the one MPI was initialized from. Multiple " - "threads/devices per process are not yet supported.")); + "MPI: Communicator requested from a thread that is not " + "the one MPI was initialized from. Multiple " + "threads/devices per process are not yet supported."); } auto& context = contexts_[std::make_tuple( @@ -287,7 +70,7 @@ absl::StatusOr> MpiCollectives::GetCommunicator( } else { color = MPI_UNDEFINED; } - context = std::make_shared(color, key); + context = std::make_shared(color, key); return context; } diff --git a/xla/pjrt/cpu/mpi_collectives.h b/xla/pjrt/cpu/mpi_collectives.h index f24537b52d4c51..5db5f13f410bdf 100644 --- a/xla/pjrt/cpu/mpi_collectives.h +++ b/xla/pjrt/cpu/mpi_collectives.h @@ -16,85 +16,22 @@ limitations under the License. #ifndef XLA_PJRT_CPU_MPI_COLLECTIVES_H_ #define XLA_PJRT_CPU_MPI_COLLECTIVES_H_ -#include #include -#include -#include #include #include -#include "mpi.h" // NOLINT -#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/time/time.h" #include "absl/types/span.h" -#include "xla/service/collective_ops_utils.h" +#include "xla/backends/cpu/collectives/mpi_communicator.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::cpu { -class MpiCollectivesCommunicator : public Communicator { - public: - explicit MpiCollectivesCommunicator(int color, int key); - ~MpiCollectivesCommunicator() override; - - absl::Status AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, ReductionKind reduction_kind, - const Executor& executor) override; - absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - std::optional source_rank, - absl::Span target_ranks, - const Executor& executor) override; - absl::Status AllToAll(absl::Span send_buffers, - absl::Span recv_buffers, - PrimitiveType dtype, size_t count, - const Executor& executor) override; - absl::Status AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, const Executor& executor) override; - absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) override; - - absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase, - PrimitiveType, size_t, RankId, - const Executor&) override { - return Unimplemented("Broadcast is not implemented"); - } - - absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, - const Executor&) override { - return Unimplemented("Send is not implemented"); - } - - absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId, - const Executor&) override { - return Unimplemented("Recv is not implemented"); - } - - absl::StatusOr NumRanks() const override { return mpi_size_; } - - std::string ToString() const override { - return absl::StrCat("MpiCommunicator [rank: ", mpi_rank_, - " num_ranks: ", mpi_size_, "]"); - } - - private: - MPI_Comm comm_; - int mpi_rank_; - int mpi_size_; -}; - class MpiCollectives : public CollectivesInterface { public: /* @@ -119,7 +56,7 @@ class MpiCollectives : public CollectivesInterface { int mpi_world_rank_; int mpi_world_size_; absl::flat_hash_map, int>, - std::shared_ptr> + std::shared_ptr> contexts_; }; diff --git a/xla/python/BUILD b/xla/python/BUILD index dd7b53028fa6e5..2b322670477190 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -12,6 +12,7 @@ load( "//xla/tsl:tsl.bzl", "if_cuda_or_rocm", "if_google", + "if_oss", "internal_visibility", ) load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable", "tsl_pybind_extension") @@ -1357,9 +1358,8 @@ tsl_pybind_extension( }) + select({ # mpitrampoline does not build on windows "//xla/tsl:windows": [], - "//conditions:default": [ - "//xla/pjrt/cpu:mpi_collectives", - ], + # we support MPI collectives only in OSS builds + "//conditions:default": if_oss(["//xla/pjrt/cpu:mpi_collectives"]), }), ) From a1991c983a2b37c35fcf9d29bc0231dd8a59daac Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 8 Jan 2025 00:00:55 -0800 Subject: [PATCH 21/45] Remove unused alias rules The last internal users have been migrated. PiperOrigin-RevId: 713178119 --- xla/hlo/transforms/BUILD | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/xla/hlo/transforms/BUILD b/xla/hlo/transforms/BUILD index 774be3834cc413..2c8f45317a59ec 100644 --- a/xla/hlo/transforms/BUILD +++ b/xla/hlo/transforms/BUILD @@ -1107,13 +1107,3 @@ xla_cc_test( "@tsl//tsl/platform:test_main", ], ) - -alias( - name = "hlo_dce", - actual = "//xla/hlo/transforms/simplifiers:hlo_dce", -) - -alias( - name = "dynamic_dimension_simplifier", - actual = "//xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier", -) From 4bde0bab2a277f84fe0f6a9612de1f65f7323135 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Wed, 8 Jan 2025 00:10:03 -0800 Subject: [PATCH 22/45] Load all available dialects in `xla::ifrt::support::RegisterMlirDialects` This avoids lazily loading dialects in a potentially multi-threaded context, which results in the following crash: `LLVM ERROR: Loading a dialect (chlo) while in a multi-threaded execution context (maybe the PassManager): this can indicate a missing `dependentDialects` in a pass for example.`. PiperOrigin-RevId: 713180730 --- xla/python/ifrt/support/module_parsing.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/xla/python/ifrt/support/module_parsing.cc b/xla/python/ifrt/support/module_parsing.cc index b1740cd5cf0ca9..8d6efaf1a4a560 100644 --- a/xla/python/ifrt/support/module_parsing.cc +++ b/xla/python/ifrt/support/module_parsing.cc @@ -52,6 +52,7 @@ void RegisterMlirDialects(mlir::MLIRContext& context) { mlir::DialectRegistry registry; InitializeMlirDialectRegistry(registry); context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); } absl::StatusOr> ParseMlirModuleString( From 74e358f912971aaa2a80e7f4c677c2a2cc2ba88f Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Wed, 8 Jan 2025 00:42:02 -0800 Subject: [PATCH 23/45] NFC: Improve comments for IndexingMap members. Also change GetDimVars to GetDimVar for naming consistency. PiperOrigin-RevId: 713188483 --- xla/backends/gpu/codegen/ir/xla_gpu_ops.cc | 2 +- xla/codegen/ir/xla_ops.cc | 4 ++-- xla/hlo/analysis/indexing_map.h | 14 +++++++------- xla/service/gpu/model/coalescing_analysis.cc | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/xla/backends/gpu/codegen/ir/xla_gpu_ops.cc b/xla/backends/gpu/codegen/ir/xla_gpu_ops.cc index 79efa4e752e9fe..846925a925ce12 100644 --- a/xla/backends/gpu/codegen/ir/xla_gpu_ops.cc +++ b/xla/backends/gpu/codegen/ir/xla_gpu_ops.cc @@ -114,7 +114,7 @@ LogicalResult MaterializeOp::verify() { return emitOpError() << "must have thread_id dimension in both indexing maps"; } - if (map_in.GetDimVars(0).bounds != map_out.GetDimVars(0).bounds) { + if (map_in.GetDimVar(0).bounds != map_out.GetDimVar(0).bounds) { return emitOpError() << "thread_id dimension must have the same bounds in " "both indexing maps"; } diff --git a/xla/codegen/ir/xla_ops.cc b/xla/codegen/ir/xla_ops.cc index 1f48f5bdd5c9c2..1d72b0264b66f9 100644 --- a/xla/codegen/ir/xla_ops.cc +++ b/xla/codegen/ir/xla_ops.cc @@ -323,7 +323,7 @@ absl::StatusOr GetNewIndexingMapAfterFoldingSequence( replacement_expr = getAffineDimExpr(num_dims + added_dim_args.size(), ctx); added_dim_args.push_back(producer_operand.get()); - new_dim_vars.push_back(producer_map.GetDimVars(dim_num)); + new_dim_vars.push_back(producer_map.GetDimVar(dim_num)); } producer_dim_replacements.push_back(replacement_expr); } @@ -529,7 +529,7 @@ struct FoldApplyIndexingOperands } else { new_operands.push_back(operand.get()); dim_replacements.push_back(getAffineDimExpr(new_num_dims++, ctx)); - new_dim_vars.push_back(indexing_map.GetDimVars(operand_id)); + new_dim_vars.push_back(indexing_map.GetDimVar(operand_id)); } } rewriter.replaceOpWithNewOp( diff --git a/xla/hlo/analysis/indexing_map.h b/xla/hlo/analysis/indexing_map.h index 17038aa05f73e0..77ea7ec24f3be4 100644 --- a/xla/hlo/analysis/indexing_map.h +++ b/xla/hlo/analysis/indexing_map.h @@ -286,7 +286,7 @@ class IndexingMap { RangeEvaluator GetRangeEvaluator() const; // Getters for dimension vars. - const Variable& GetDimVars(int64_t id) const { return dim_vars_[id]; } + const Variable& GetDimVar(int64_t id) const { return dim_vars_[id]; } const std::vector& GetDimVars() const { return dim_vars_; } int64_t GetDimVarsCount() const { return dim_vars_.size(); } @@ -407,18 +407,18 @@ class IndexingMap { mlir::AffineMap affine_map_; - // Dimension variable represents a dimension of a tensor or a GPU grid. - // Dimensions correspond to the dimension parameter of `affine_map_`. + // A dimension variable represents a dimension of a tensor or a GPU grid. + // Dimension variables correspond to the dimensions of the `affine_map_`. std::vector dim_vars_; - // RangeSymbol variable represents a range of values, e.g. to compute a single + // A range variable represents a range of values, e.g. to compute a single // element of the reduction's result we need a range of values from the input - // tensor. RangeSymbol variables correspond to the front portion of the + // tensor. Range variables correspond to the front portion of the // symbols in `affine_map_`. std::vector range_vars_; - // RTSymbol variable represents a runtime symbol, e.g. a dynamic offset in - // HLO dynamic-update-slice op. RTSymbol variables correspond to the back + // A runtime variable represents a runtime symbol, e.g. a dynamic offset in of + // a HLO dynamic-update-slice op. Runtime variables correspond to the back // portion of the symbols in `affine_map_`. std::vector rt_vars_; diff --git a/xla/service/gpu/model/coalescing_analysis.cc b/xla/service/gpu/model/coalescing_analysis.cc index a2ceba1f01a29d..a583c692c2d8b5 100644 --- a/xla/service/gpu/model/coalescing_analysis.cc +++ b/xla/service/gpu/model/coalescing_analysis.cc @@ -548,7 +548,7 @@ std::vector FindContiguousIntervals( } // Case 2: f(thread_x) != thread_x * multiplier. auto intervals = FindIntervals(partitioned_expr.func_of_d0, - {indexing_map.GetDimVars(0).bounds}); + {indexing_map.GetDimVar(0).bounds}); // Case 2.1: g(s) != s. if (partitioned_expr.func_of_s0 != range) { return intervals; From 5a95b696dbefdb919377fad0a122b2da95c30eda Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 8 Jan 2025 01:28:55 -0800 Subject: [PATCH 24/45] [XLA:GPU] Inline a call to `ScheduleGpuModuleWithMemoryScheduler`. PiperOrigin-RevId: 713199894 --- xla/service/gpu/BUILD | 10 +--- xla/service/gpu/all_gather_combiner.cc | 4 +- xla/service/gpu/all_reduce_combiner.cc | 4 +- .../gpu/gpu_collective_combiner_utils.cc | 11 ++-- .../gpu/gpu_collective_combiner_utils.h | 5 -- .../gpu/gpu_collective_combiner_utils_test.cc | 52 +------------------ xla/service/gpu/reduce_scatter_combiner.cc | 4 +- 7 files changed, 10 insertions(+), 80 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 9d8d7f90622176..e0c2a1472cf52c 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -3071,13 +3071,13 @@ cc_library( hdrs = ["gpu_collective_combiner_utils.h"], deps = [ ":backend_configs_cc", + ":gpu_hlo_schedule", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/service:collective_utils", "//xla/stream_executor:device_description", "@com_google_absl//absl/log", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:statusor", ], ) @@ -3088,7 +3088,6 @@ xla_cc_test( deps = [ ":backend_configs_cc", ":gpu_collective_combiner_utils", - ":gpu_hlo_schedule", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -3096,16 +3095,12 @@ xla_cc_test( "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/utils:hlo_query", "//xla/service:collective_pipeliner", - "//xla/service:collective_utils", "//xla/service:hlo_module_config", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test", ], ) @@ -3116,7 +3111,6 @@ cc_library( deps = [ ":backend_configs_cc", ":gpu_collective_combiner_utils", - ":gpu_hlo_schedule", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms/collectives:all_gather_combiner", @@ -3155,7 +3149,6 @@ cc_library( deps = [ ":backend_configs_cc", ":gpu_collective_combiner_utils", - ":gpu_hlo_schedule", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_domain_map", @@ -3192,7 +3185,6 @@ cc_library( deps = [ ":backend_configs_cc", ":gpu_collective_combiner_utils", - ":gpu_hlo_schedule", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms/collectives:all_reduce_combiner", diff --git a/xla/service/gpu/all_gather_combiner.cc b/xla/service/gpu/all_gather_combiner.cc index 996d3a1fe83bed..96f10d43113b5c 100644 --- a/xla/service/gpu/all_gather_combiner.cc +++ b/xla/service/gpu/all_gather_combiner.cc @@ -28,7 +28,6 @@ limitations under the License. #include "xla/hlo/transforms/collectives/all_gather_combiner.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_collective_combiner_utils.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/hlo_domain_map.h" #include "tsl/platform/statusor.h" @@ -78,8 +77,7 @@ absl::StatusOr GpuAllGatherCombiner::Run( // Combine as much as possible for pipelined collectives. int previous_combiner_threshold = combine_threshold_in_bytes_; combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold( - *module, device_info_, ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kAllGather, pointer_size_); + *module, device_info_, HloOpcode::kAllGather, pointer_size_); TF_ASSIGN_OR_RETURN( bool combined_pipelined_instructions, RunWithKeyCombiner(module, execution_threads, PipelinedCombinerKey)); diff --git a/xla/service/gpu/all_reduce_combiner.cc b/xla/service/gpu/all_reduce_combiner.cc index 108d10cee3e5d3..5fb3d960bb2371 100644 --- a/xla/service/gpu/all_reduce_combiner.cc +++ b/xla/service/gpu/all_reduce_combiner.cc @@ -28,7 +28,6 @@ limitations under the License. #include "xla/hlo/transforms/collectives/all_reduce_combiner.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_collective_combiner_utils.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/hlo_domain_map.h" #include "tsl/platform/statusor.h" @@ -76,8 +75,7 @@ absl::StatusOr GpuAllReduceCombiner::Run( // Combine as much as possible for pipelined collectives. int previous_combiner_threshold = combine_threshold_in_bytes_; combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold( - *module, device_info_, ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kAllReduce, pointer_size_); + *module, device_info_, HloOpcode::kAllReduce, pointer_size_); TF_ASSIGN_OR_RETURN( bool combined_pipelined_instructions, RunWithKeyCombiner(module, execution_threads, PipelinedCombinerKey)); diff --git a/xla/service/gpu/gpu_collective_combiner_utils.cc b/xla/service/gpu/gpu_collective_combiner_utils.cc index d789b652df6d4a..43a99ea4fe612b 100644 --- a/xla/service/gpu/gpu_collective_combiner_utils.cc +++ b/xla/service/gpu/gpu_collective_combiner_utils.cc @@ -25,14 +25,11 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/collective_utils.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/stream_executor/device_description.h" #include "tsl/platform/statusor.h" namespace xla::gpu { - -using MemoryAwareScheduler = std::function( - const HloModule*, int64_t, int64_t*)>; - namespace { int64_t GetDefaultValue(HloOpcode opcode) { @@ -52,13 +49,13 @@ int64_t GetDefaultValue(HloOpcode opcode) { int64_t ComputeSuggestedCombinerThreshold( const HloModule& module, const se::DeviceDescription& device_info, - MemoryAwareScheduler scheduler, HloOpcode collective_opcode, - int64_t pointer_size) { + HloOpcode collective_opcode, int64_t pointer_size) { int64_t base_limit = module.config().device_memory_size() != 0 ? module.config().device_memory_size() : device_info.device_memory_size(); int64_t peak_memory_bytes = -1; - auto mem_schedule = scheduler(&module, pointer_size, &peak_memory_bytes); + auto mem_schedule = ScheduleGpuModuleWithMemoryScheduler( + &module, pointer_size, &peak_memory_bytes); if (!mem_schedule.ok() || peak_memory_bytes == -1) { VLOG(1) << "Cannot schedule module: " << mem_schedule.status().message(); diff --git a/xla/service/gpu/gpu_collective_combiner_utils.h b/xla/service/gpu/gpu_collective_combiner_utils.h index 38a7890decb59b..d78abf552eeb33 100644 --- a/xla/service/gpu/gpu_collective_combiner_utils.h +++ b/xla/service/gpu/gpu_collective_combiner_utils.h @@ -17,10 +17,8 @@ limitations under the License. #define XLA_SERVICE_GPU_GPU_COLLECTIVE_COMBINER_UTILS_H_ #include -#include #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -36,9 +34,6 @@ namespace xla::gpu { // `collective_opcode`. int64_t ComputeSuggestedCombinerThreshold( const HloModule& module, const se::DeviceDescription& device_info, - std::function(const HloModule*, int64_t, - int64_t*)> - scheduler, HloOpcode collective_opcode, int64_t pointer_size); // Adds information that `instr` has been pipelined to the diff --git a/xla/service/gpu/gpu_collective_combiner_utils_test.cc b/xla/service/gpu/gpu_collective_combiner_utils_test.cc index f0b213f343e587..9d7a9596641618 100644 --- a/xla/service/gpu/gpu_collective_combiner_utils_test.cc +++ b/xla/service/gpu/gpu_collective_combiner_utils_test.cc @@ -19,27 +19,20 @@ limitations under the License. #include #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_pipeliner.h" -#include "xla/service/collective_utils.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla::gpu { namespace { @@ -65,8 +58,7 @@ TEST_F(CollectiveCombinerUtilsTest, device_info.set_device_memory_size(20000); int64_t suggested_threshold = ComputeSuggestedCombinerThreshold( - *module, device_info, gpu::ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kAllReduce, pointer_size); + *module, device_info, HloOpcode::kAllReduce, pointer_size); // device size = 20000 bytes // slop factor = 0.95 @@ -96,8 +88,7 @@ TEST_F(CollectiveCombinerUtilsTest, stream_executor::DeviceDescription device_info; int64_t suggested_threshold = ComputeSuggestedCombinerThreshold( - *module, device_info, gpu::ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kAllReduce, pointer_size); + *module, device_info, HloOpcode::kAllReduce, pointer_size); // device size = 20000 bytes // slop factor = 0.95 @@ -106,45 +97,6 @@ TEST_F(CollectiveCombinerUtilsTest, EXPECT_EQ(suggested_threshold, 6712); } -TEST_F( - CollectiveCombinerUtilsTest, - ComputeSuggestedCombinerThresholdReturnsDefaultValueUponSchedulingFailure) { - absl::string_view kHloText = R"( - HloModule m - - ENTRY ar { - p0 = f32[32,32] parameter(0) - p1 = f32[32,32] parameter(1) - - ROOT _ = f32[32,32]{1,0} custom-call(p0, p1), - custom_call_target="__cublas$gemm" - })"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); - int pointer_size = 4; - stream_executor::DeviceDescription device_info; - device_info.set_device_memory_size(20000); - - auto sched_fun = [](const HloModule* m, int64_t p_sz, - int64_t* p) -> absl::StatusOr { - return absl::UnimplementedError("Fail."); - }; - - int64_t suggested_threshold_all_reduce = ComputeSuggestedCombinerThreshold( - *module, device_info, sched_fun, HloOpcode::kAllReduce, pointer_size); - int64_t suggested_threshold_all_gather = ComputeSuggestedCombinerThreshold( - *module, device_info, sched_fun, HloOpcode::kAllGather, pointer_size); - int64_t suggested_threshold_reduce_scatter = - ComputeSuggestedCombinerThreshold(*module, device_info, sched_fun, - HloOpcode::kReduceScatter, - pointer_size); - - EXPECT_EQ(suggested_threshold_all_reduce, kDefaultAllReduceCombineThreshold); - EXPECT_EQ(suggested_threshold_all_gather, kDefaultAllGatherCombineThreshold); - EXPECT_EQ(suggested_threshold_reduce_scatter, - kDefaultReduceScatterCombineThreshold); -} - TEST_F(CollectiveCombinerUtilsTest, AppendPipelinedInstructionAppendsPipelinedInstructionInfoForward) { // This is just a canonical IR which makes it easy to pipeline a collective diff --git a/xla/service/gpu/reduce_scatter_combiner.cc b/xla/service/gpu/reduce_scatter_combiner.cc index 6b07a79cd4ecd8..2d9813dda1e6a0 100644 --- a/xla/service/gpu/reduce_scatter_combiner.cc +++ b/xla/service/gpu/reduce_scatter_combiner.cc @@ -26,7 +26,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_collective_combiner_utils.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/hlo_domain_map.h" #include "xla/service/reduce_scatter_combiner.h" #include "tsl/platform/statusor.h" @@ -76,8 +75,7 @@ absl::StatusOr GpuReduceScatterCombiner::Run( // Combine as much as possible for pipelined collectives. int previous_combiner_threshold = combine_threshold_in_bytes_; combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold( - *module, device_info_, ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kReduceScatter, pointer_size_); + *module, device_info_, HloOpcode::kReduceScatter, pointer_size_); TF_ASSIGN_OR_RETURN( bool combined_pipelined_instructions, RunWithKeyCombiner(module, execution_threads, PipelinedCombinerKey)); From 6e16da5b7ceddef739ff924d69780507637f75cd Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 8 Jan 2025 01:38:09 -0800 Subject: [PATCH 25/45] Update to match upstream API change (NFC). This method was renamed but staging function kept, switch to renamed variant. PiperOrigin-RevId: 713202168 --- xla/mlir_hlo/transforms/detensorize_scf_ops.cc | 2 +- xla/mlir_hlo/transforms/generic_host_to_llvm.cc | 2 +- xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc | 2 +- xla/mlir_hlo/transforms/lower_index_cast_pass.cc | 3 +-- xla/mlir_hlo/transforms/naive_copy_removal.cc | 2 +- xla/mlir_hlo/transforms/tile_loops_pass.cc | 2 +- xla/mlir_hlo/transforms/vectorize_copy.cc | 2 +- 7 files changed, 7 insertions(+), 8 deletions(-) diff --git a/xla/mlir_hlo/transforms/detensorize_scf_ops.cc b/xla/mlir_hlo/transforms/detensorize_scf_ops.cc index 2a8be4e6b09ae0..12d8b3814646e7 100644 --- a/xla/mlir_hlo/transforms/detensorize_scf_ops.cc +++ b/xla/mlir_hlo/transforms/detensorize_scf_ops.cc @@ -120,7 +120,7 @@ struct DetensorizeScfOpsPass patterns.add, RegionOpPattern, RegionOpPattern>(&getContext()); - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { + if (failed(applyPatternsGreedily(f, std::move(patterns)))) { signalPassFailure(); } } diff --git a/xla/mlir_hlo/transforms/generic_host_to_llvm.cc b/xla/mlir_hlo/transforms/generic_host_to_llvm.cc index 9df69afbaf55aa..8cd4bf99f5133d 100644 --- a/xla/mlir_hlo/transforms/generic_host_to_llvm.cc +++ b/xla/mlir_hlo/transforms/generic_host_to_llvm.cc @@ -86,7 +86,7 @@ class GenericHostToLLVMPass // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. vector::populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } LLVMConversionTarget target(*ctx); diff --git a/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc b/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc index 3e22aa55888327..d490588de4508b 100644 --- a/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc +++ b/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc @@ -96,7 +96,7 @@ void GpuKernelToNVVMPass::runOnOperation() { { RewritePatternSet patterns(&getContext()); populateAllCommonVectorProgressiveLoweringPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } RewritePatternSet patterns(&getContext()); diff --git a/xla/mlir_hlo/transforms/lower_index_cast_pass.cc b/xla/mlir_hlo/transforms/lower_index_cast_pass.cc index 489d8fb4cb811e..b773792e67b5c4 100644 --- a/xla/mlir_hlo/transforms/lower_index_cast_pass.cc +++ b/xla/mlir_hlo/transforms/lower_index_cast_pass.cc @@ -64,8 +64,7 @@ struct LowerIndexCastPass patterns.add, IndexCastConverter>( patterns.getContext()); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/xla/mlir_hlo/transforms/naive_copy_removal.cc b/xla/mlir_hlo/transforms/naive_copy_removal.cc index 55ab2fbb2e0ee5..a13f0396a85e63 100644 --- a/xla/mlir_hlo/transforms/naive_copy_removal.cc +++ b/xla/mlir_hlo/transforms/naive_copy_removal.cc @@ -80,7 +80,7 @@ struct NaiveCopyRemovalPass RewritePatternSet patterns(ctx); patterns.add(removeCopy); memref::AllocOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + if (failed(applyPatternsGreedily(func, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/xla/mlir_hlo/transforms/tile_loops_pass.cc b/xla/mlir_hlo/transforms/tile_loops_pass.cc index ee3b935cff2771..d6efd72d2437c0 100644 --- a/xla/mlir_hlo/transforms/tile_loops_pass.cc +++ b/xla/mlir_hlo/transforms/tile_loops_pass.cc @@ -127,7 +127,7 @@ void TileLoopsPass::runOnOperation() { getContext() .getOrLoadDialect() ->getCanonicalizationPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/xla/mlir_hlo/transforms/vectorize_copy.cc b/xla/mlir_hlo/transforms/vectorize_copy.cc index 1b68cd8b28b74e..5650e83be0c2d4 100644 --- a/xla/mlir_hlo/transforms/vectorize_copy.cc +++ b/xla/mlir_hlo/transforms/vectorize_copy.cc @@ -215,7 +215,7 @@ struct VectorizeCopyPass RewritePatternSet patterns(ctx); patterns.add( ctx, /*numElementsThreshold = */ 8); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { return signalPassFailure(); } } From f3ad216b2bae484e4a8d62fa23f22b6faf3b1a8a Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Wed, 8 Jan 2025 02:50:43 -0800 Subject: [PATCH 26/45] PR #20557: [ds-fusion] Add HandleReducePrecision to algebraic simplifier Imported from GitHub PR https://github.com/openxla/xla/pull/20557 When the mantissa and exponent of the reduce-precision instruction are the same as the mantissa and exponent of the primitive type of the operand, then the reduce-precision operation is a no-op. Copybara import of the project: -- 8b9852bb24ea6dbbc2a6d6dd6cf68c41efde8b30 by Shraiysh Vaishay : Add HandleReducePrecision to algebraic simplifier When the mantissa and exponent of the reduce-precision instruction are the same as the mantissa and exponent of the primitive type of the operant, then the reduce-precision operation is a no-op. -- f54f2d35f2d85913e3d5febdbb12c38468d4e1ea by Shraiysh Vaishay : Addressed comments -- 39c4be640db7a3b8a60483cea7f8f47154c1e691 by Shraiysh Vaishay : Move the pass after the last pass that causes precision changes The last pass to cause precision changes is SimplifyFPConversions. Moved the handling of reduce-precision after that. -- f82bc5c034922ba39c301ee0e173f86917d08da4 by Shraiysh Vaishay : addressed comments -- 34ee3317c45d48fc3904d11db2ad296e90b6f51a by Shraiysh Vaishay : Handle clang-format failure. Merging this change closes #20557 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/20557 from shraiysh:handle_reduce_precision 34ee3317c45d48fc3904d11db2ad296e90b6f51a PiperOrigin-RevId: 713220616 --- .../simplifiers/algebraic_simplifier.cc | 23 ++++++++++++-- .../simplifiers/algebraic_simplifier.h | 13 ++++++++ .../simplifiers/algebraic_simplifier_test.cc | 31 +++++++++++++++++++ xla/service/gpu/gpu_compiler.cc | 16 ++++++++++ xla/service/gpu/gpu_compiler_test.cc | 13 ++++++++ 5 files changed, 94 insertions(+), 2 deletions(-) diff --git a/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc b/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc index 269284b021d5de..4b96bf2a81d502 100644 --- a/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc +++ b/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc @@ -5939,8 +5939,9 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( new_operands.push_back(operand); } } - VLOG(4) << "Sinking broadcast after user:" << "\n old broadcast: " - << broadcast->ToString() << "\n old user: " << user->ToString(); + VLOG(4) << "Sinking broadcast after user:" + << "\n old broadcast: " << broadcast->ToString() + << "\n old user: " << user->ToString(); changed_shape = ShapeUtil::ChangeElementType(operand->shape(), user->shape().element_type()); simplifier_->UpdateLayout(&changed_shape); @@ -8233,6 +8234,24 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { return absl::OkStatus(); } +absl::Status AlgebraicSimplifierVisitor::HandleReducePrecision( + HloInstruction* hlo) { + HloReducePrecisionInstruction* reduce_precision = + Cast(hlo); + PrimitiveType element_type = + reduce_precision->operand(0)->shape().element_type(); + if (options_.enable_remove_no_op_reduce_precision() && + reduce_precision->exponent_bits() == + primitive_util::ExponentWidth(element_type) && + reduce_precision->mantissa_bits() + 1 == + primitive_util::SignificandWidth(element_type)) { + return ReplaceInstruction( + /*old_instruction=*/hlo, + /*new_instruction=*/reduce_precision->mutable_operand(0)); + } + return absl::OkStatus(); +} + absl::Status AlgebraicSimplifierVisitor::HandleReduceWindow( HloInstruction* hlo) { auto* reduce_window = Cast(hlo); diff --git a/xla/hlo/transforms/simplifiers/algebraic_simplifier.h b/xla/hlo/transforms/simplifiers/algebraic_simplifier.h index 96c50ba251a949..f3ded542605dbf 100644 --- a/xla/hlo/transforms/simplifiers/algebraic_simplifier.h +++ b/xla/hlo/transforms/simplifiers/algebraic_simplifier.h @@ -322,6 +322,16 @@ class AlgebraicSimplifierOptions { return enable_broadcast_degenerate_dimension_; } + void set_enable_remove_no_op_reduce_precision( + bool enable_remove_no_op_reduce_precision) { + enable_remove_no_op_reduce_precision_ = + enable_remove_no_op_reduce_precision; + } + + bool enable_remove_no_op_reduce_precision() const { + return enable_remove_no_op_reduce_precision_; + } + private: // Metadata struct can be used to store any metadata information encapsulated // with the AlgebraicSimplifierOptions that can be later used in an @@ -364,6 +374,7 @@ class AlgebraicSimplifierOptions { bool disable_dynamic_slice_to_slice_conversion_{false}; bool enable_fast_math_{false}; bool enable_broadcast_degenerate_dimension_{true}; + bool enable_remove_no_op_reduce_precision_{false}; Metadata metadata_; }; @@ -484,6 +495,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { absl::Status HandleReduce(HloInstruction* hlo) override; + absl::Status HandleReducePrecision(HloInstruction* hlo) override; + absl::Status HandleReduceWindow(HloInstruction* hlo) override; absl::Status HandleReverse(HloInstruction* reverse) override; diff --git a/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc b/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc index 5b0519107ad653..e30822e37f578d 100644 --- a/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc +++ b/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc @@ -12688,5 +12688,36 @@ TEST_F(AlgebraicSimplifierTest, TestNew123) { EXPECT_FALSE(simplifier.Run(module.get()).value()); } +TEST_F(AlgebraicSimplifierTest, + ReducePrecisionWithSamePrecisionAsOperandIsRemovedIfRemoveNoOpIsSet) { + const char* hlo = R"( + HloModule test + ENTRY main { + p0 = bf16[64]{0} parameter(0) + ROOT reduce-precision = bf16[64] reduce-precision(p0), exponent_bits=8, mantissa_bits=7 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + default_options_.set_enable_remove_no_op_reduce_precision(true); + EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter())); +} + +TEST_F(AlgebraicSimplifierTest, + ReducePrecisionWithDifferentPrecisionFromOperandIsNotModifiedByDefault) { + const char* hlo = R"( + HloModule test + ENTRY main { + p0 = bf16[64]{0} parameter(0) + ROOT reduce-precision = bf16[64] reduce-precision(p0), exponent_bits=7, mantissa_bits=8 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + + default_options_.set_enable_remove_no_op_reduce_precision(true); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + } // namespace } // namespace xla diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 705c0eb327e1be..5c6a5ab6ca172e 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1676,6 +1676,22 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(); } + { + // Because of an issue with JAX remat and `SimplifyFPConversions` (see PR: + // https://github.com/jax-ml/jax/pull/22244), we can only eliminate the + // no-op reduce-precision operations after the last call to + // `SimplifyFPConversions`. We are creating a sub-pipeline here because that + // allows us to test this order in a unit test. + HloPassPipeline& remove_no_op_reduce_precision_pipeline = + pipeline.AddPass( + "remove-no-op-reduce-precision-algebraic-simplifier"); + AlgebraicSimplifierOptions simplifier_options_{simplifier_options}; + simplifier_options_.set_enable_remove_no_op_reduce_precision(true); + remove_no_op_reduce_precision_pipeline + .AddPass>(simplifier_options_, + gpu_version); + } + pipeline.AddPass(/*is_layout_sensitive=*/true); pipeline.AddPass( diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index d60a7f5daedcb8..26e8899aa65609 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -1554,6 +1554,19 @@ TEST_F(PassOrderTest, GemmRewriterRunsAfterDotNormalizer) { VerifyNotRunInBetween(pass_range, /*pass_regex=*/"algsimp"); } +TEST_F(PassOrderTest, + ReducePrecisionIsRemovedAfterAllCallsToSimplifyFPConversions) { + // Because of an issue with JAX remat and `SimplifyFPConversions` (see PR: + // https://github.com/jax-ml/jax/pull/22244), we can only eliminate the + // no-op reduce-precision operations after the last call to + // `SimplifyFPConversions`. No-op reduce-precisions are removed within + // algebraic simplifier, if the option to remove them is set. In the compiler + // pipeline, this is done as a subpipeline, which should be after the last + // invocation of SimplifyFPConversions. + VerifyPassOrder("simplify-fp-conversions", + "remove-no-op-reduce-precision-algebraic-simplifier"); +} + } // namespace } // namespace gpu } // namespace xla From 3c1e1f6702545dcb5e04b4d0daa4811ba4d66ed7 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Wed, 8 Jan 2025 02:53:21 -0800 Subject: [PATCH 27/45] [NFC] Polish ScalarOrTensor a little. - Remove std::variant, MLIR's run-time type information already provides the same. - Change `ScalarOrTensor::UnwrapTensor` to return `TypedValue`. - Use `getType()` instead of `Type()` to align the naming. PiperOrigin-RevId: 713221237 --- .../gpu/fusions/triton/emitter_helpers.cc | 13 +++----- .../gpu/fusions/triton/emitter_helpers.h | 33 +++++++------------ .../fusions/triton/triton_fusion_emitter.cc | 8 ++--- 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/xla/service/gpu/fusions/triton/emitter_helpers.cc b/xla/service/gpu/fusions/triton/emitter_helpers.cc index 60f4132b9e7f1b..7f3b990219c231 100644 --- a/xla/service/gpu/fusions/triton/emitter_helpers.cc +++ b/xla/service/gpu/fusions/triton/emitter_helpers.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" @@ -64,13 +65,9 @@ namespace mh = ::mlir::mhlo; namespace mm = ::mlir::math; namespace mt = ::mlir::triton; -ScalarOrTensor::ScalarOrTensor(mlir::Value value) { - if (auto tt = mlir::dyn_cast(value.getType())) { - CHECK_GT(tt.getRank(), 0); - value_ = TensorValue{value}; - } else { - value_ = ScalarValue{value}; - } +ScalarOrTensor::ScalarOrTensor(mlir::Value value) : value_(value) { + CHECK(IsScalar() || UnwrapTensor().getType().getRank() > 0) + << "0D tensors are not supported by Triton"; } SmallVector GetPaddedTileSizes(ArrayRef tile_sizes) { @@ -313,7 +310,7 @@ Value Minimum(EmitterLocOpBuilder& b, const se::DeviceDescription& device_info, ScalarOrTensor Splat(EmitterLocOpBuilder& b, ScalarOrTensor value, ArrayRef shape) { CHECK(!shape.empty()); - auto type = mlir::RankedTensorType::get(shape, value.Type()); + auto type = mlir::RankedTensorType::get(shape, value.getType()); return ScalarOrTensor(b.create(type, value.UnwrapUnsafe())); } diff --git a/xla/service/gpu/fusions/triton/emitter_helpers.h b/xla/service/gpu/fusions/triton/emitter_helpers.h index fe283bada6f5ed..7e20b6b3f6157f 100644 --- a/xla/service/gpu/fusions/triton/emitter_helpers.h +++ b/xla/service/gpu/fusions/triton/emitter_helpers.h @@ -48,6 +48,8 @@ namespace xla::gpu::triton { // non-0D tensor. An attempt to use this class with 0D tensors will CHECK-fail // because 0D tensors are not supported by Triton. class ScalarOrTensor { + using TensorValue = mlir::TypedValue; + public: ScalarOrTensor() = default; @@ -55,17 +57,17 @@ class ScalarOrTensor { // value is a 0D tensor, because Triton does not support 0D tensors. explicit ScalarOrTensor(mlir::Value value); - bool IsScalar() const { return std::holds_alternative(value_); } - bool IsTensor() const { return std::holds_alternative(value_); } + bool IsScalar() const { return !IsTensor(); } + bool IsTensor() const { return mlir::isa(value_); } - mlir::Value UnwrapScalar() { + mlir::Value UnwrapScalar() const { CHECK(IsScalar()); - return std::get(value_).scalar_value; + return value_; } - mlir::Value UnwrapTensor() { + TensorValue UnwrapTensor() const { CHECK(IsTensor()); - return std::get(value_).tensor_value; + return mlir::cast(value_); } // Returns the underlying value regardless of whether it is a scalar or a @@ -73,25 +75,12 @@ class ScalarOrTensor { // both needs to use an `mlir::Value` and functions identically for scalars // and tensors. In other cases, prefer to use the `UnwrapScalar` or // `UnwrapTensor` methods. - mlir::Value UnwrapUnsafe() { - if (auto* scalar = std::get_if(&value_)) { - return scalar->scalar_value; - } - return std::get(value_).tensor_value; - } + mlir::Value UnwrapUnsafe() const { return value_; } - mlir::Type Type() { return UnwrapUnsafe().getType(); } + mlir::Type getType() const { return value_.getType(); } private: - struct ScalarValue { - mlir::Value scalar_value; - }; - - struct TensorValue { - mlir::Value tensor_value; - }; - - std::variant value_; + mlir::Value value_; }; // Triton requires that all block dimensions are a power of 2. diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index 46655c5be86229..d0afa63f721773 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -218,7 +218,7 @@ absl::StatusOr EmitReduce( *::xla::Cast(tiled_hlo_reduce.hlo()); ScalarOrTensor input = values[tiled_hlo_reduce.operand(0)]; llvm::ArrayRef input_shape = - mlir::cast(input.Type()).getShape(); + mlir::cast(input.getType()).getShape(); absl::Span source_tensor_shape = hlo_reduce.operand(0)->shape().dimensions(); @@ -511,7 +511,7 @@ absl::StatusOr EmitTiledReshape(EmitterLocOpBuilder& b, // At this point we know that the input is a non-0D tensor. - auto input_shaped_type = mlir::cast(input.Type()); + auto input_shaped_type = mlir::cast(input.getType()); // Handle the case of reshaping [1,1,1...] to a scalar. if (tile_sizes.empty()) { @@ -621,7 +621,7 @@ absl::StatusOr EmitTiledHloInstruction( // as i8. It's important to type checking that we perform a conversion after // loading if the type of the loaded parameter does not match what is // expected. - Type loaded_element_type = getElementTypeOrSelf(parameter.Type()); + Type loaded_element_type = getElementTypeOrSelf(parameter.getType()); TF_ASSIGN_OR_RETURN(Type expected_element_type, TritonType(b, hlo->shape().element_type())); @@ -976,7 +976,7 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, // as i8. It's important to type checking that we perform a conversion before // storing if the type of the result does not match the type of the output // pointer. - Type result_element_type = getElementTypeOrSelf(result.Type()); + Type result_element_type = getElementTypeOrSelf(result.getType()); Type result_storage_type = StorageType(b, result_element_type); if (result_element_type != result_storage_type) { From dcd6016c1b2b7cf63cb584680d6023479baaa17c Mon Sep 17 00:00:00 2001 From: Will Froom Date: Wed, 8 Jan 2025 03:43:39 -0800 Subject: [PATCH 28/45] [XLA:CPU] Remove old IrEmitter::EmitElementalHostKernel PiperOrigin-RevId: 713232255 --- xla/service/cpu/BUILD | 1 + xla/service/cpu/cpu_compiler.cc | 13 +++- xla/service/cpu/ir_emitter2.cc | 114 +++++++++---------------------- xla/service/cpu/ir_emitter2.h | 13 +--- xla/service/cpu/thunk_emitter.cc | 22 ++++-- 5 files changed, 63 insertions(+), 100 deletions(-) diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 88692c8eb2c6d7..4ea228a0c63300 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -669,6 +669,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 41b3847b50613e..5c28de6021def4 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -41,6 +41,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" @@ -1503,7 +1504,17 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { std::string ir_module_string; if (embed_ir_in_executable) { - ir_module_string = llvm_ir::DumpToString(llvm_module.get()); + std::string emitter2_ir = llvm_ir::DumpToString(llvm_module.get()); + + auto thunk_kernel_fmt = [](std::string* out, + const ThunkEmitter::EmittedKernel& kernel) { + absl::StrAppend( + out, llvm_ir::DumpToString(kernel.module.getModuleUnlocked())); + }; + std::string thunks_ir = + absl::StrJoin(thunk_emitter.kernels(), "\n", thunk_kernel_fmt); + + ir_module_string = absl::StrCat(emitter2_ir, "\n", thunks_ir); } TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); diff --git a/xla/service/cpu/ir_emitter2.cc b/xla/service/cpu/ir_emitter2.cc index ca6f1d26101167..1890d5377bfb49 100644 --- a/xla/service/cpu/ir_emitter2.cc +++ b/xla/service/cpu/ir_emitter2.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -114,46 +115,6 @@ IrEmitter2::KernelInfo::KernelInfo(KernelPrototype prototype, thread_dims(thread_dims), invariant_arguments(std::move(prototype.invariant_arguments)) {} -absl::StatusOr IrEmitter2::EmitElementalHostKernel( - const HloInstruction* instr) { - VLOG(2) << "Emit elemental host kernel: " << instr->name(); - - TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, - EmitKernelPrototype(instr)); - - llvm::IRBuilder<> b(module_->getContext()); - b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); - - IrEmitter::IRBuilderGuard builder_guard = nested_ir_emitter_->WithBuilder(b); - - CpuElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; - for (int64_t i = 0; i < instr->operand_count(); ++i) { - const HloInstruction* operand = instr->operand(i); - operand_to_generator[operand] = [&, i](const llvm_ir::IrArray::Index& idx) { - return kernel_prototype.arguments[i].EmitReadArrayElement(idx, &b); - }; - } - - if (instr->has_to_apply()) { - HloComputation* nested_computation = instr->to_apply(); - bool is_reducer = instr->opcode() == HloOpcode::kReduce || - instr->opcode() == HloOpcode::kReduceWindow; - TF_RETURN_IF_ERROR(nested_ir_emitter_->EmitNestedComputation( - *nested_computation, llvm_ir::IrName(instr), is_reducer)); - } - - CpuElementalIrEmitter elemental_emitter = ElementalIrEmmiterFactory(&b); - llvm_ir::ElementGenerator element_generator = - elemental_emitter.MakeElementGenerator(instr, operand_to_generator); - - TF_ASSIGN_OR_RETURN( - se::ThreadDim thread_dims, - EmitElementalLoops(b, instr, kernel_prototype, element_generator)); - - return kernels_.emplace_back( - KernelInfo(std::move(kernel_prototype), se::BlockDim(), thread_dims)); -} - absl::StatusOr IrEmitter2::EmitPadHostKernel( const HloInstruction* pad) { VLOG(2) << "Emit Pad host kernel."; @@ -247,14 +208,6 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( KernelInfo(std::move(kernel_prototype), se::BlockDim(), thread_dims)); } -absl::StatusOr IrEmitter2::EmitReductionHostKernel( - const HloInstruction* instr) { - VLOG(2) << "Emit reduction host kernel: " << instr->name(); - - // TODO(ezhulenev): Port vectorized reduction emitter from IrEmitter. - return EmitElementalHostKernel(instr); -} - // Dot (fusion) host kernel only supports strategies that emit LLVM IR. static bool IsDotCodegenStrategy(DotImplementationStrategy strategy) { static std::array kDotCodegenStrategies = { @@ -303,25 +256,20 @@ absl::StatusOr IrEmitter2::EmitConcatenateHostKernel( const HloInstruction* instr) { VLOG(2) << "Emit concatenate host kernel: " << instr->name(); - auto fast_impl_reason = CanDoFastConcatenate(instr); - if (fast_impl_reason.ok()) { - VLOG(1) << "Emitting fast concatenate for " << instr->ToString() << ": " - << fast_impl_reason.message(); - TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, - EmitKernelPrototype(instr)); - llvm::IRBuilder<> ir_builder(module_->getContext()); - ir_builder.SetInsertPoint( - kernel_prototype.function->getEntryBlock().getTerminator()); - - llvm_ir::IrArray output_array = kernel_prototype.results[0]; - TF_RETURN_IF_ERROR(::xla::cpu::EmitFastConcatenate( - instr, kernel_prototype.arguments, output_array, module_, ir_builder)); - return kernels_.emplace_back(KernelInfo(std::move(kernel_prototype), - se::BlockDim(), se::ThreadDim())); - } - VLOG(1) << "Could not emit fast concatenate for " << instr->ToString() << ": " - << fast_impl_reason.message(); - return EmitElementalHostKernel(instr); + DCHECK_OK(CanDoFastConcatenate(instr)); + + VLOG(1) << "Emitting fast concatenate for " << instr->ToString(); + TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, + EmitKernelPrototype(instr)); + llvm::IRBuilder<> ir_builder(module_->getContext()); + ir_builder.SetInsertPoint( + kernel_prototype.function->getEntryBlock().getTerminator()); + + llvm_ir::IrArray output_array = kernel_prototype.results[0]; + TF_RETURN_IF_ERROR(::xla::cpu::EmitFastConcatenate( + instr, kernel_prototype.arguments, output_array, module_, ir_builder)); + return kernels_.emplace_back( + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr IrEmitter2::EmitDotFusionHostKernel( @@ -401,26 +349,22 @@ absl::StatusOr IrEmitter2::EmitSliceToDynamicHostKernel( absl::StatusOr IrEmitter2::EmitDynamicUpdateSliceHostKernel(const HloInstruction* instr) { - if (llvm_ir::CanUpdateDynamicSliceInPlace(const_cast(instr), - nested_ir_emitter_->assignment())) { - VLOG(2) << "Emit in-place dynamic-update-slice kernel: " << instr->name(); + DCHECK(CanUpdateDynamicSliceInPlace(instr)); - TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, - EmitKernelPrototype(instr)); + VLOG(2) << "Emit in-place dynamic-update-slice kernel: " << instr->name(); - llvm::IRBuilder<> b(module_->getContext()); - b.SetInsertPoint( - kernel_prototype.function->getEntryBlock().getTerminator()); + TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, + EmitKernelPrototype(instr)); - TF_RETURN_IF_ERROR(llvm_ir::EmitDynamicUpdateSliceInPlace( - kernel_prototype.arguments, kernel_prototype.results.front(), - llvm_ir::IrName(instr, "in_place"), &b)); + llvm::IRBuilder<> b(module_->getContext()); + b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); - return kernels_.emplace_back(KernelInfo(std::move(kernel_prototype), - se::BlockDim(), se::ThreadDim())); - } + TF_RETURN_IF_ERROR(llvm_ir::EmitDynamicUpdateSliceInPlace( + kernel_prototype.arguments, kernel_prototype.results.front(), + llvm_ir::IrName(instr, "in_place"), &b)); - return EmitElementalHostKernel(instr); + return kernels_.emplace_back( + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr IrEmitter2::EmitSortComparator( @@ -499,6 +443,12 @@ absl::Status IrEmitter2::CanDoFastConcatenate( return absl::OkStatus(); }; +bool IrEmitter2::CanUpdateDynamicSliceInPlace( + const HloInstruction* update) const { + return llvm_ir::CanUpdateDynamicSliceInPlace( + const_cast(update), nested_ir_emitter_->assignment()); +} + IrEmitter2::ParallelPartitionBounds IrEmitter2::EmitParallelPartitionBounds( llvm::IRBuilderBase& b, const KernelPrototype& kernel_prototype, const ParallelConfig& parallel_config, const Shape& shape, diff --git a/xla/service/cpu/ir_emitter2.h b/xla/service/cpu/ir_emitter2.h index 2bcb7c1c9316fc..77ea6647d4ec97 100644 --- a/xla/service/cpu/ir_emitter2.h +++ b/xla/service/cpu/ir_emitter2.h @@ -98,10 +98,6 @@ class IrEmitter2 { absl::Span comparators() const { return comparators_; } - // Emits an elemental host kernel for the given HLO instruction. - absl::StatusOr EmitElementalHostKernel( - const HloInstruction* instr); - // Emits a host kernel for the pad instruction. absl::StatusOr EmitPadHostKernel(const HloInstruction* pad); @@ -109,10 +105,6 @@ class IrEmitter2 { absl::StatusOr EmitFusionHostKernel( const HloFusionInstruction* fusion); - // Emits a host kernel for the given reduction instruction. - absl::StatusOr EmitReductionHostKernel( - const HloInstruction* instr); - // Emits a host kernel for the given dot instruction. Small dot operations // are emitted as LLVM IR directly, while larger ones are emitted as a dot // thunk that calls into libraries. @@ -137,6 +129,9 @@ class IrEmitter2 { // Emits a comparator function for the given sort instruction. absl::StatusOr EmitSortComparator(HloComputation* comparator); + absl::Status CanDoFastConcatenate(const HloInstruction* concatenate) const; + bool CanUpdateDynamicSliceInPlace(const HloInstruction* update) const; + private: class ElementalIrEmitter; @@ -160,8 +155,6 @@ class IrEmitter2 { // the instruction has to be compiled to a single threaded loop. std::optional GetParallelConfig(const HloInstruction* instr); - absl::Status CanDoFastConcatenate(const HloInstruction* concatenate) const; - // Emits LLVM IR that computes parallel partition bounds from the call frame's // block and thread dimensions and parallel execution config. ParallelPartitionBounds EmitParallelPartitionBounds( diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index 5a3b848c3db3c8..a5d0aeade482f0 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -526,6 +526,13 @@ absl::StatusOr ThunkEmitter::EmitCallThunk( absl::StatusOr ThunkEmitter::EmitConcatenateKernelThunk( const HloInstruction* instruction) { + if (absl::Status status = ir_emitter_.CanDoFastConcatenate(instruction); + !status.ok()) { + VLOG(1) << "Could not emit fast concatenate for " << instruction->ToString() + << ": " << status.message(); + return EmitElementalKernelThunk(instruction); + } + auto* concatenate = Cast(instruction); TF_ASSIGN_OR_RETURN(auto kernel, ir_emitter_.EmitConcatenateHostKernel(concatenate)); @@ -661,13 +668,8 @@ absl::StatusOr ThunkEmitter::EmitFusionKernelThunk( absl::StatusOr ThunkEmitter::EmitReductionKernelThunk( const HloInstruction* instruction) { - TF_ASSIGN_OR_RETURN(auto kernel, - ir_emitter_.EmitReductionHostKernel(instruction)); - TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - - return MakeKernelThunkSequence( - instruction, buffers, kernel, - /*min_alignment=*/cpu_function_runtime::MinAlign()); + // TODO(ezhulenev): Port vectorized reduction emitter from IrEmitter. + return EmitElementalKernelThunk(instruction); } absl::StatusOr ThunkEmitter::EmitRngThunk( @@ -1041,6 +1043,12 @@ absl::StatusOr ThunkEmitter::EmitSliceThunk( absl::StatusOr ThunkEmitter::EmitDynamicUpdateSliceThunk( const HloInstruction* instruction) { + if (!ir_emitter_.CanUpdateDynamicSliceInPlace(instruction)) { + VLOG(2) << "Could not emit in-place dynamic-update-slice kernel: " + << instruction->name(); + return EmitElementalKernelThunk(instruction); + } + TF_ASSIGN_OR_RETURN( auto kernel, ir_emitter_.EmitDynamicUpdateSliceHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); From 87b8d79836b17e8dbf240f5e9ad6af59dbad0b1b Mon Sep 17 00:00:00 2001 From: xla authors Date: Wed, 8 Jan 2025 03:44:13 -0800 Subject: [PATCH 29/45] Automated Code Change PiperOrigin-RevId: 713232397 --- xla/service/gpu/BUILD | 1 + xla/service/gpu/auto_sharding_gpu_compiler_test.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index e0c2a1472cf52c..76133d52b65b9c 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1728,6 +1728,7 @@ xla_test( "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log", "@com_google_googletest//:gtest", "@tsl//tsl/platform:logging", ], diff --git a/xla/service/gpu/auto_sharding_gpu_compiler_test.cc b/xla/service/gpu/auto_sharding_gpu_compiler_test.cc index 89be2dac856e06..ad5a80d836ea2b 100644 --- a/xla/service/gpu/auto_sharding_gpu_compiler_test.cc +++ b/xla/service/gpu/auto_sharding_gpu_compiler_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" From 7a801635665be5c7208b7271a001d4d6596e56ff Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 8 Jan 2025 04:05:04 -0800 Subject: [PATCH 30/45] Remove unused constructor parameter (NFC). This was forgotten to be removed during an earlier refactoring. PiperOrigin-RevId: 713237188 --- xla/backends/profiler/gpu/device_tracer_cuda.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/xla/backends/profiler/gpu/device_tracer_cuda.cc b/xla/backends/profiler/gpu/device_tracer_cuda.cc index 578d4ab6d3021d..2d675afba107d4 100644 --- a/xla/backends/profiler/gpu/device_tracer_cuda.cc +++ b/xla/backends/profiler/gpu/device_tracer_cuda.cc @@ -46,8 +46,7 @@ using tsl::ReadBoolFromEnvVar; // GpuTracer for GPU. class GpuTracer : public tsl::profiler::ProfilerInterface { public: - GpuTracer(CuptiTracer* cupti_tracer, CuptiInterface* cupti_interface) - : cupti_tracer_(cupti_tracer) { + explicit GpuTracer(CuptiTracer* cupti_tracer) : cupti_tracer_(cupti_tracer) { VLOG(1) << "GpuTracer created."; } ~GpuTracer() override {} @@ -227,8 +226,7 @@ std::unique_ptr CreateGpuTracer( if (!cupti_tracer->IsAvailable()) { return nullptr; } - profiler::CuptiInterface* cupti_interface = profiler::GetCuptiInterface(); - return std::make_unique(cupti_tracer, cupti_interface); + return std::make_unique(cupti_tracer); } auto register_gpu_tracer_factory = [] { From 4eb42a71034c455f9dc64d778f0a3b91d1c4fa51 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 8 Jan 2025 04:25:38 -0800 Subject: [PATCH 31/45] Update to match upstream API change (NFC). This method was renamed but staging function kept, switch to renamed variant. PiperOrigin-RevId: 713242393 --- xla/mlir/framework/transforms/outline_with_xla_framework.cc | 2 +- xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc | 4 ++-- .../broadcast_propagation/broadcast_propagation.cc | 4 ++-- .../collapse_elementwise_map/collapse_elementwise_map.cc | 3 +-- .../legalize_dot_to_dot_general.cc | 3 +-- .../legalize_einsum_to_dot_general.cc | 3 +-- .../legalize_torch_index_select_to_gather.cc | 3 +-- .../legalize_trigonometric_to_approximation.cc | 3 +-- .../mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc | 4 ++-- .../mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc | 3 +-- .../transforms/shape_simplification/shape_simplification.cc | 2 +- .../symbolic_shape_optimization.cc | 4 ++-- .../test_infer_shaped_type/test_infer_shaped_type_pass.cc | 3 +-- .../transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc | 3 +-- .../transforms/stablehlo_canonicalize_dynamism.cpp | 3 +-- xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc | 5 ++--- 16 files changed, 21 insertions(+), 31 deletions(-) diff --git a/xla/mlir/framework/transforms/outline_with_xla_framework.cc b/xla/mlir/framework/transforms/outline_with_xla_framework.cc index b960958a7d6344..7d9b8fc700767a 100644 --- a/xla/mlir/framework/transforms/outline_with_xla_framework.cc +++ b/xla/mlir/framework/transforms/outline_with_xla_framework.cc @@ -164,7 +164,7 @@ class OutlineWithXLAFrameworkPass patterns.add(ctx); // Set target. - if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) { + if (failed(applyPatternsGreedily(m, std::move(patterns)))) { signalPassFailure(); } m->walk([](func::FuncOp f) { diff --git a/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc b/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc index 064978aec3982b..3a09b6e3b33814 100644 --- a/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc +++ b/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc @@ -536,7 +536,7 @@ struct BufferReusePass : public impl::BufferReusePassBase { eliminateCopies(block, /*root=*/block); do { // Eliminate dead code. - (void)applyPatternsAndFoldGreedily(getOperation(), {}); + (void)applyPatternsGreedily(getOperation(), {}); // Only coalesce dealloc/alloc pairs that are immediate neighbors, to // make sure we don't accidentally extend the live range of a buffer. result = reuseBuffers(block, BufferReuseMode::CONSERVATIVE); @@ -547,7 +547,7 @@ struct BufferReusePass : public impl::BufferReusePassBase { // Now we can also coalesce distant dealloc/alloc pairs. reuseBuffers(block, BufferReuseMode::AGGRESSIVE); promoteBuffers(block); - (void)applyPatternsAndFoldGreedily(getOperation(), {}); + (void)applyPatternsGreedily(getOperation(), {}); } }; diff --git a/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc b/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc index da27173913f81e..c8268e4335dca2 100644 --- a/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc +++ b/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc @@ -439,8 +439,8 @@ struct BroadcastPropagationPass GreedyRewriteConfig config; config.useTopDownTraversal = false; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc b/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc index 60fcd198853911..cbe532ba959f76 100644 --- a/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc +++ b/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc @@ -92,8 +92,7 @@ struct CollapseElementwiseMapPass MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc b/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc index e986bdc5ad694c..79e55a4c9f3d53 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc @@ -68,8 +68,7 @@ struct LegalizeDotToDotGeneralPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateDotToDotGeneralPatterns(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc b/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc index c35ce560146dcb..e861dec331848c 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc @@ -179,8 +179,7 @@ struct LegalizeEinsumToDotGeneralPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateEinsumToDotGeneralPatterns(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc b/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc index 8cc65ea23f04c2..865c07fc316d89 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc @@ -139,8 +139,7 @@ struct LegalizeTorchIndexSelectToGatherPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateTorchIndexSelectToGatherPatterns(&getContext(), &patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc b/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc index 2e7018e2fd17c3..ccf2ed1151ccc7 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc @@ -172,8 +172,7 @@ struct LegalizeTrigonometricToApproximationPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateTrigonometricToApproximationPatterns(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc b/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc index 185b2c9d7caa18..d6c4b4767297d6 100644 --- a/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc +++ b/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc @@ -434,8 +434,8 @@ struct MergeAssumingOpsPass mhlo::populateMergeAssumingOpsPatterns(ctx, &patterns); GreedyRewriteConfig config; config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc b/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc index deccadf230d5a3..b86038624c4c24 100644 --- a/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc +++ b/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc @@ -132,8 +132,7 @@ class FlattenTuplePass : public impl::FlattenTuplePassBase { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); patterns.add(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc b/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc index b96370f71cf23c..1747bd93b492ef 100644 --- a/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc +++ b/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc @@ -242,7 +242,7 @@ struct ShapeSimplification ExtractFromBroadcastedTensorCanonicalizationPattern>(context); auto func = getOperation(); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + if (failed(applyPatternsGreedily(func, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc b/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc index 20808e4d12d9e7..961e512d239686 100644 --- a/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc +++ b/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc @@ -793,8 +793,8 @@ class SymbolicShapeOptimizationPass final shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx); shape::ShapeOfOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed( + mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc b/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc index 8bd3bbc1409610..d585ea0b9d1592 100644 --- a/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc @@ -95,8 +95,7 @@ struct TestInferShapedTypeMethodsPass RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc b/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc index 7409def78d770f..285f056008da72 100644 --- a/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc @@ -43,8 +43,7 @@ struct TestUnfuseBatchNormPass RewritePatternSet patterns(&getContext()); populateUnfuseBatchNormInferencePattern(&getContext(), &patterns); populateUnfuseBatchNormTrainingPattern(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp index 0ad3029f96ccf6..9cd3e90e6f5dfb 100644 --- a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp +++ b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp @@ -200,8 +200,7 @@ struct StablehloCanonicalizeDynamismPass patterns.add(&getContext()); auto funcOp = getOperation(); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns), config))) { funcOp.emitError("Failed to converge StablehloCanonicalizeDynamism in ") << config.maxIterations << " iterations"; return signalPassFailure(); diff --git a/xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc b/xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc index d4c84259f2dbd9..08d4bc8894a2ef 100644 --- a/xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc +++ b/xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc @@ -360,7 +360,7 @@ struct SparseBlockedToMMAPass auto pattern = std::make_unique(context, compute_capability); RewritePatternSet patterns(context, std::move(pattern)); - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { return signalPassFailure(); } } @@ -975,8 +975,7 @@ struct SparseWGMMAOpToLLVMPass MLIRContext *context = &getContext(); auto pattern = std::make_unique(context); RewritePatternSet patterns(context, std::move(pattern)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } From b322cfca6de4e8cbb36690b6446ea95da9053cfb Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 8 Jan 2025 04:52:23 -0800 Subject: [PATCH 32/45] PR #19649: [ROCm] Implement hermetic rocm dependency Imported from GitHub PR https://github.com/openxla/xla/pull/19649 This change has as a goal to introduce an external dependency to the rocm library and tools. Building xla with the hermetic rocm is done by using these env variables: --repo_env=OS=ubuntu_20.04 --repo_env=ROCM_VERSION=6.2.0 To use only hermetic libs define this flag: --@local_config_rocm//rocm:use_rocm_hermetic_rpath=True This flag will make rpaths and configs to look inside the sandbox If flag is not set then default installation paths are used e.g /opt/rocm One has to provie OS version and ROCm version to initialize a proper rocm repository. If these flags are not set then default ROCm installation will be used to build XLA. depends-on: https://github.com/openxla/xla/pull/19691 Copybara import of the project: -- cf744eca78f697144e122c6a9d1aa8fc52722b20 by Alexandros Theodoridis : Implement hermetic rocm dependency -- 4f4ad859ec3143fdb04f7792541c61b98c708397 by Alexandros Theodoridis : Add missing dependency -- 8e164f765b45b5e5d118b02695fd6d6e2b0b232d by Alexandros Theodoridis : Add missing dependency and remove so files from data -- 35538f4922b5b28b9debd0ce17bb15b83b5921fc by Alexandros Theodoridis : Rename setting to use_rocm_hermetic_rpath -- 58d140220e9e58572c9a7ae3de2ec1ea189566d3 by Alexandros Theodoridis : Fix build for cuda and cpu Merging this change closes #19649 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/19649 from ROCm:ci_implement_hermetic_rocm_dependency_upstream 58d140220e9e58572c9a7ae3de2ec1ea189566d3 PiperOrigin-RevId: 713248195 --- .../third_party/gpus/crosstool/BUILD.rocm.tpl | 6 +- .../bin/crosstool_wrapper_driver_rocm.tpl | 1 + .../tsl/third_party/gpus/rocm/BUILD.tpl | 425 +++++++++++++++--- .../third_party/gpus/rocm/build_defs.bzl.tpl | 2 + .../tsl/third_party/gpus/rocm/rocm_redist.bzl | 18 + .../gpus/rocm/rocm_redist_ubuntu_20_04.bzl | 183 ++++++++ .../gpus/rocm/rocm_redist_ubuntu_22_04.bzl | 183 ++++++++ .../gpus/rocm/rocm_redist_ubuntu_24_04.bzl | 187 ++++++++ .../tsl/third_party/gpus/rocm_configure.bzl | 208 ++++----- .../tsl/third_party/remote_config/common.bzl | 11 +- xla/service/gpu/BUILD | 9 +- xla/stream_executor/rocm/BUILD | 11 +- xla/tsl/platform/default/BUILD | 15 +- 13 files changed, 1054 insertions(+), 205 deletions(-) create mode 100644 third_party/tsl/third_party/gpus/rocm/rocm_redist.bzl create mode 100644 third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl create mode 100644 third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl create mode 100644 third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl diff --git a/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl b/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl index 03a9dde83cfddc..ac3082fbcb3055 100644 --- a/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl +++ b/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl @@ -111,7 +111,7 @@ filegroup( ) filegroup( - name = "crosstool_wrapper_driver_is_not_gcc", - srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], + name = "crosstool_wrapper_driver_is_not_gcc", + srcs = [":clang/bin/crosstool_wrapper_driver_is_not_gcc"], + data = ["@local_config_rocm//rocm:all_files"], ) - diff --git a/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 3c59884c6f729e..389ffea421035a 100755 --- a/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -186,6 +186,7 @@ def InvokeHipcc(argv, log=False): hipccopts += defines hipccopts += std_options hipccopts += m_options + hipccopts += ' --rocm-path="%{rocm_path}" ' if depfiles: # Generate the dependency file diff --git a/third_party/tsl/third_party/gpus/rocm/BUILD.tpl b/third_party/tsl/third_party/gpus/rocm/BUILD.tpl index aa3688e335df37..7ebf2773eb48b1 100644 --- a/third_party/tsl/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/tsl/third_party/gpus/rocm/BUILD.tpl @@ -1,8 +1,22 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@local_config_rocm//rocm:build_defs.bzl", "rocm_version_number", "select_threshold") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like -package(default_visibility = ["//visibility:public"]) +package(default_visibility = ["//visibility:private"]) + +bool_flag( + name = "use_rocm_hermetic_rpath", + build_setting_default = False, +) + +config_setting( + name = "build_hermetic", + flag_values = { + ":use_rocm_hermetic_rpath": "True", + }, +) config_setting( name = "using_hipcc", @@ -12,171 +26,434 @@ config_setting( ) cc_library( - name = "rocm_headers", + name = "config", hdrs = [ - "rocm/rocm_config.h", - %{rocm_headers} + "rocm_config/rocm_config.h", ], + include_prefix = "rocm", + strip_include_prefix = "rocm_config", +) + +cc_library( + name = "config_hermetic", + hdrs = [ + "rocm_config_hermetic/rocm_config.h", + ], + include_prefix = "rocm", + strip_include_prefix = "rocm_config_hermetic", +) + +cc_library( + name = "rocm_config", + visibility = ["//visibility:public"], + deps = select({ + ":build_hermetic": [ + ":config_hermetic", + ], + "//conditions:default": [ + "config", + ], + }), +) + +cc_library( + name = "rocm_headers", + hdrs = glob([ + "%{rocm_root}/include/**", + "%{rocm_root}/lib/llvm/lib/**/*.h", + ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", - "rocm/include/rocrand", - "rocm/include/roctracer", + "%{rocm_root}/include", + "%{rocm_root}/include/rocrand", + "%{rocm_root}/include/roctracer", ], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [ + ":rocm_rpath", + ], ) cc_library( - name = "hip", - srcs = ["rocm/lib/%{hip_lib}"], - data = ["rocm/lib/%{hip_lib}"], + name = "rocm", + visibility = ["//visibility:public"], + deps = [ + ":hip", + ":hipblas", + ":hipblaslt", + ":hiprand", + ":hipsolver", + ":hipsparse", + ":hsa_rocr", + ":miopen", + ":rocblas", + ":rocm_config", + ":rocprofiler_register", + ":rocsolver", + ":roctracer", + ":rocsparse", + ] + select_threshold( + above_or_eq = [":hipfft"], + below = [":rocfft"], + threshold = 40100, + value = rocm_version_number(), + ), +) + +cc_library( + name = "hsa_rocr", + srcs = glob(["%{rocm_root}/lib/libhsa-runtime*.so*"]), + hdrs = glob(["%{rocm_root}/include/hsa/**"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], linkstatic = 1, + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "rocm_rpath", + linkopts = select({ + ":build_hermetic": [ + "-Wl,-rpath=%{rocm_toolkit_path}/lib", + ], + "//conditions:default": [ + "-Wl,-rpath=/opt/rocm/lib", + ], + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hip", visibility = ["//visibility:public"], + deps = [ + ":rocm_hip", + ":rocm_rpath", + ], +) + +cc_library( + name = "rocm_hip", + srcs = glob(["%{rocm_root}/lib/libamdhip*.so*"]), + hdrs = glob(["%{rocm_root}/include/hip/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [ + ":amd_comgr", + ":hsa_rocr", + ":rocm_config", + ":rocm_smi", + ":rocprofiler_register", + ":system_libs", + ], ) cc_library( name = "rocblas", - srcs = ["rocm/lib/%{rocblas_lib}"], - data = ["rocm/lib/%{rocblas_lib}"], + hdrs = glob(["%{rocm_root}/include/rocblas/**"]), + data = glob([ + "%{rocm_root}/lib/librocblas*.so*", + "%{rocm_root}/lib/rocblas/**", + ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], - linkstatic = 1, + # workaround to bring tensile files to the same fs layout as expected in the lib + # rocblas assumes that tensile files are located in ../roblas/libraries directory + linkopts = ["-Wl,-rpath=local_config_rocm/rocm/rocm_dis/lib"], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "%{hipfft_or_rocfft}", - srcs = ["rocm/lib/%{hipfft_or_rocfft_lib}"], - data = ["rocm/lib/%{hipfft_or_rocfft_lib}"], + name = "rocfft", + srcs = glob(["%{rocm_root}/lib/librocfft*.so*"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], linkstatic = 1, visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "hiprand", - srcs = ["rocm/lib/%{hiprand_lib}"], - data = ["rocm/lib/%{hiprand_lib}"], + name = "hipfft", + srcs = glob(["%{rocm_root}/lib/libhipfft*.so*"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", - "rocm/include/rocrand", + "%{rocm_root}/include", ], linkstatic = 1, - visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "miopen", - srcs = ["rocm/lib/%{miopen_lib}"], - data = ["rocm/lib/%{miopen_lib}"], + name = "hiprand", + srcs = glob(["%{rocm_root}/lib/libhiprand*.so*"]), + hdrs = glob(["%{rocm_root}/include/hiprand/**"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", + "%{rocm_root}/include/rocrand", ], linkstatic = 1, + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "rccl", - srcs = ["rocm/lib/%{rccl_lib}"], - data = ["rocm/lib/%{rccl_lib}"], + name = "miopen", + hdrs = glob(["%{rocm_root}/include/rccl/**"]), + data = glob([ + "%{rocm_root}/lib/libMIOpen*.so*", + "%{rocm_root}/share/miopen/**", + ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], - linkstatic = 1, + # workaround to bring miopen db files to the same fs layout as expected in the lib + # rocblas assumes that miopen db files are located in ../share/miopen/db directory + linkopts = ["-Wl,-rpath=local_config_rocm/rocm/rocm_dis/lib"], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "rocm", - visibility = ["//visibility:public"], - deps = [ - ":rocm_headers", - ":hip", - ":rocblas", - ":hipblas", - ":%{hipfft_or_rocfft}", - ":hiprand", - ":miopen", - ":hipsparse", - ":roctracer", - ":rocsolver", - ":hipsolver", + name = "rccl", + srcs = glob(["%{rocm_root}/lib/librccl*.so*"]), + hdrs = glob(["%{rocm_root}/include/rccl/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", ], + linkstatic = 1, + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) bzl_library( name = "build_defs_bzl", srcs = ["build_defs.bzl"], + visibility = ["//visibility:public"], ) cc_library( name = "rocprim", srcs = [ - "rocm/include/hipcub/hipcub_version.hpp", - "rocm/include/rocprim/rocprim_version.hpp", + "%{rocm_root}/include/hipcub/hipcub_version.hpp", + "%{rocm_root}/include/rocprim/rocprim_version.hpp", ], hdrs = glob([ - "rocm/include/hipcub/**", - "rocm/include/rocprim/**", + "%{rocm_root}/include/hipcub/**", + "%{rocm_root}/include/rocprim/**", ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include/hipcub", - "rocm/include/rocprim", + "%{rocm_root}/include/hipcub", + "%{rocm_root}/include/rocprim", ], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], deps = [ - "@local_config_rocm//rocm:rocm_headers", + ":rocm_config", + ":rocm_headers", ], ) cc_library( name = "hipsparse", - srcs = ["rocm/lib/%{hipsparse_lib}"], - data = ["rocm/lib/%{hipsparse_lib}"], + srcs = glob(["%{rocm_root}/lib/libhipsparse*.so*"]), + hdrs = glob(["%{rocm_root}/include/hipsparse/**"]), + data = glob(["%{rocm_root}/lib/libhipsparse*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "roctracer", - data = ["rocm/lib/%{roctracer_lib}"], + hdrs = glob(["%{rocm_root}/include/roctracer/**"]), + data = glob(["%{rocm_root}/lib/libroctracer*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "rocsolver", - srcs = ["rocm/lib/%{rocsolver_lib}"], - data = ["rocm/lib/%{rocsolver_lib}"], + srcs = glob(["%{rocm_root}/lib/librocsolver*.so*"]), + hdrs = glob(["%{rocm_root}/include/rocsolver/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "rocsparse", + srcs = glob(["%{rocm_root}/lib/librocsparse*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "hipsolver", - srcs = ["rocm/lib/%{hipsolver_lib}"], - data = ["rocm/lib/%{hipsolver_lib}"], + srcs = glob(["%{rocm_root}/lib/libhipsolver*.so*"]), + hdrs = glob(["%{rocm_root}/include/hipsolver/**"]), + data = glob(["%{rocm_root}/lib/libhipsolver*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "hipblas", - srcs = ["rocm/lib/%{hipblas_lib}"], - data = ["rocm/lib/%{hipblas_lib}"], + srcs = glob(["%{rocm_root}/lib/libhipblas.so*"]), + hdrs = glob(["%{rocm_root}/include/hipblas/**"]), + data = glob(["%{rocm_root}/lib/libhipblas.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "hipblaslt", + hdrs = glob(["%{rocm_root}/include/hipblaslt/**"]), + data = glob([ + "%{rocm_root}/lib/hipblaslt/**", + "%{rocm_root}/lib/libhipblaslt.so*", + ]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + # workaround to bring tensile files to the same fs layout as expected in the lib + # hibplatslt assumes that tensile files are located in ../hipblaslt/libraries directory + linkopts = ["-Wl,-rpath=local_config_rocm/rocm/rocm_dis/lib"], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "rocrand", + srcs = glob(["%{rocm_root}/lib/librocrand*.so*"]), + hdrs = glob(["%{rocm_root}/include/rocrand/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "rocprofiler_register", + srcs = glob([ + "%{rocm_root}/lib/librocprofiler-register.so*", + ]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "amd_comgr", + srcs = glob([ + "%{rocm_root}/lib/libamd_comgr.so*", + ]), + hdrs = glob(["%{rocm_root}/include/amd_comgr/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "rocm_smi", + srcs = glob([ + "%{rocm_root}/lib/librocm_smi64.so*", + "%{rocm_root}/lib/libroam.so*", + ]), + hdrs = glob([ + "%{rocm_root}/include/oam/**", + "%{rocm_root}/include/rocm_smi/**", + ]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "system_libs", + srcs = glob([ + "rocm_dist/usr/lib/**/libelf.so*", + "rocm_dist/usr/lib/**/libdrm.so*", + "rocm_dist/usr/lib/**/libnuma.so*", + "rocm_dist/usr/lib/**/libdrm_amdgpu.so*", + ]), + data = glob([ + "rocm_dist/usr/**", + ]), ) filegroup( name = "rocm_root", srcs = [ - "rocm/bin/clang-offload-bundler", + "%{rocm_root}/bin/clang-offload-bundler", ], + visibility = ["//visibility:public"], ) -%{copy_rules} +filegroup( + name = "all_files", + srcs = glob(["%{rocm_root}/**"]), + visibility = ["//visibility:public"], +) diff --git a/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl index 83a7e9dababf38..d327083e4dc8ea 100644 --- a/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl @@ -11,6 +11,8 @@ def if_rocm(if_true, if_false = []): "//conditions:default": if_false }) +def select_threshold(value, above_or_eq, threshold, below): + return below if value < threshold else above_or_eq def rocm_default_copts(): """Default options for all ROCm compilations.""" diff --git a/third_party/tsl/third_party/gpus/rocm/rocm_redist.bzl b/third_party/tsl/third_party/gpus/rocm/rocm_redist.bzl new file mode 100644 index 00000000000000..c1cc501e1a2ded --- /dev/null +++ b/third_party/tsl/third_party/gpus/rocm/rocm_redist.bzl @@ -0,0 +1,18 @@ +load( + "@tsl//third_party/gpus/rocm:rocm_redist_ubuntu_20_04.bzl", + "rocm_redist_ubuntu_20_04", +) +load( + "@tsl//third_party/gpus/rocm:rocm_redist_ubuntu_22_04.bzl", + "rocm_redist_ubuntu_22_04", +) +load( + "@tsl//third_party/gpus/rocm:rocm_redist_ubuntu_24_04.bzl", + "rocm_redist_ubuntu_24_04", +) + +rocm_redist = { + "ubuntu_20.04": rocm_redist_ubuntu_20_04, + "ubuntu_22.04": rocm_redist_ubuntu_22_04, + "ubuntu_24.04": rocm_redist_ubuntu_24_04, +} diff --git a/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl b/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl new file mode 100644 index 00000000000000..ecae2197563b33 --- /dev/null +++ b/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl @@ -0,0 +1,183 @@ +rocm_redist_ubuntu_20_04 = { + "6.2.0": { + "archives": [ + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/c/comgr6.2.0/comgr6.2.0_2.8.0.60200-66~20.04_amd64.deb", + sha256 = "fabf4a831f21b5248932e08654149bc215da2a816613ad8d05b805d4e226171a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-runtime-amd6.2.0/hip-runtime-amd6.2.0_6.2.41133.60200-66~20.04_amd64.deb", + sha256 = "215fae8759742bc048699feaacd6256a3ac2138771b69731dab7779325bb1b41", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev6.2.0/hip-dev6.2.0_6.2.41133.60200-66~20.04_amd64.deb", + sha256 = "e901d66275b3b520ee73250caa4a1836be142823083528b4db6cc31a18bfb94d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas6.2.0/hipblas6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "f8a20128b5c26198bd9ecec894f8a4c74fa28ee668e4ef1bf73d0c3edff8c144", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas-dev6.2.0/hipblas-dev6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "ab3ee54b33eba013fbf3d9aefe64b54e1918b9fb72790ca0b57fb391cb662cf0", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcc6.2.0/hipcc6.2.0_1.1.1.60200-66~20.04_amd64.deb", + sha256 = "a68123c046b8c913705262014463a8a30768167a1b68a78d8455deaf85a802d7", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcub-dev6.2.0/hipcub-dev6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "c71fab59f62ad9d4b60aa4217f4db42c6996d83d5ad7ba29e127cc13bda59afc", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft6.2.0/hipfft6.2.0_1.0.14.60200-66~20.04_amd64.deb", + sha256 = "25887526ea2e955d4c0afa4749f8db55a49e399a349d43ccf66e0ad99ff78b2a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft-dev6.2.0/hipfft-dev6.2.0_1.0.14.60200-66~20.04_amd64.deb", + sha256 = "3cfec840c79c6bce4e83bf6e056e241cc13ff572352b040a952c7642b61d45aa", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver6.2.0/hipsolver6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "cb56dd79ff52eaddfed379831023484d9ec32b9538bc3d02ee34c328457cd20e", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver-dev6.2.0/hipsolver-dev6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "1e968f9405c8b90fbb58dff09d8bab08cf31c8386880fff95e1cb8932320bc37", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse6.2.0/hipsparse6.2.0_3.1.1.60200-66~20.04_amd64.deb", + sha256 = "f08ba25b6b950754b5a2bb64c125a01b9f44280f227ff19eeb78e188f0b17320", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse-dev6.2.0/hipsparse-dev6.2.0_3.1.1.60200-66~20.04_amd64.deb", + sha256 = "e9464369619bbea7299ac83e17b3cbbabdeb16e6d4da116400532e7737332b65", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand6.2.0/hiprand6.2.0_2.11.0.60200-66~20.04_amd64.deb", + sha256 = "2efed49be9413e08e91b3fb67736644bb0e8809fc673d310a0abab65b69eacad", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand-dev6.2.0/hiprand-dev6.2.0_2.11.0.60200-66~20.04_amd64.deb", + sha256 = "19564fb2f9616860234aa8bd69cca324a1a3ec33476581ec57200a1dac1d4dcb", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hsa-rocr6.2.0/hsa-rocr6.2.0_1.14.0.60200-66~20.04_amd64.deb", + sha256 = "e4940a5d47e9e39d603f18936e7921c603fd8dde0e359e0be796f9c1cdacd431", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip6.2.0/miopen-hip6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "638a28c5407c3af7d16e1b0179b7494b0aeb36c314114af148b1bcd52e883db1", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip-dev/miopen-hip-dev_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "77c9d26c4f0053b71fb86f7a6b489655e27053f9605efca3a16344ccf286e313", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl6.2.0/rccl6.2.0_2.20.5.60200-66~20.04_amd64.deb", + sha256 = "2b3ce1ca2e58e891963f26d4bd31ae45894480483f691d371f269e698f75f8eb", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl-dev6.2.0/rccl-dev6.2.0_2.20.5.60200-66~20.04_amd64.deb", + sha256 = "0dedbffa5bb272d656086a9586e3705551345945f35f4f6be6dc8a27b63127a9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas6.2.0/rocblas6.2.0_4.2.0.60200-66~20.04_amd64.deb", + sha256 = "6e5b3caeadf592367f8638db67a70b8dd9231a8257dc2012a9c46e2c5974fff5", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas-dev/rocblas-dev_4.2.0.60200-66~20.04_amd64.deb", + sha256 = "eaefe5a7d75ef61314b83af5bb85d8e652a730deaa58e1d600b1e9c2e673673c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft6.2.0/rocfft6.2.0_1.0.28.60200-66~20.04_amd64.deb", + sha256 = "b2bfe29ab688781bad5bc067ee682658085e22caaf09b18278f2f4b9905081d3", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft-dev6.2.0/rocfft-dev6.2.0_1.0.28.60200-66~20.04_amd64.deb", + sha256 = "e94d50fd6f24d70649ce046dbfe4dda2587d1d82892d4c126a4c3e91d1570071", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-core/rocm-core_6.2.0.60200-66~20.04_amd64.deb", + sha256 = "0e16c9fc58fc904542be4dad63bb2ff34268b5c13957c432e91ec0e4fd149c82", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-hip-libraries/rocm-hip-libraries_6.2.0.60200-66~20.04_amd64.deb", + sha256 = "14f47d79b508eb259bfe4e0e5f360edb5721b908caf3bb981a4eee4181783be9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev/hip-dev_6.2.41133.60200-66~20.04_amd64.deb", + sha256 = "97e6e77eaea56de6cc4ea2c525dd8b9a587546eb99c782c7af46cdc5363b99bf", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-device-libs6.2.0/rocm-device-libs6.2.0_1.0.0.60200-66~20.04_amd64.deb", + sha256 = "ae055b579d319e1a779783ba774f119fb0e1a731d058a03b36dc5c15214d210a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocminfo6.2.0/rocminfo6.2.0_1.0.0.60200-66~20.04_amd64.deb", + sha256 = "3bcf3dc22dbede7da70299cde1484776827808b967d371441f6cf6d3fe8af30d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm6.2.0/rocm-llvm6.2.0_18.0.0.24292.60200-66~20.04_amd64.deb", + sha256 = "ce17d2b85407b9539e0feda513fd360a48ebfd971c19af122dda21d60448c9fc", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm-dev/rocm-llvm-dev_18.0.0.24292.60200-66~20.04_amd64.deb", + sha256 = "322ca8425c3a8f2ec17c551bad606b96d957b0c1eea07196dd66ac9f15460ed5", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-smi-lib6.2.0/rocm-smi-lib6.2.0_7.3.0.60200-66~20.04_amd64.deb", + sha256 = "1bbdb32d21dbc12bf9a736f6ca8726df9673e4401465d2b9b537c47b358b67f1", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprim-dev6.2.0/rocprim-dev6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "e74e1907eb90a692344626e881cb88eeed5565ac3b487eb94ad4ac02ffd838ed", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprofiler-register6.2.0/rocprofiler-register6.2.0_0.4.0.60200-66~20.04_amd64.deb", + sha256 = "4be88c5010c2cf0223c1dd7dc9d4a430fc54ee401ca093de2dcca60dabea763a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocrand-dev/rocrand-dev_3.1.0.60200-66~20.04_amd64.deb", + sha256 = "ddd0ac44b08470dfc128d6f6d2598a9728879f5a78bc5290645baebf22433b63", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer6.2.0/roctracer6.2.0_4.1.60200.60200-66~20.04_amd64.deb", + sha256 = "b94cdf230b372ebcaf97085cf67f01ef7977f814280fdaf1886797f39899ef41", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer-dev6.2.0/roctracer-dev6.2.0_4.1.60200.60200-66~20.04_amd64.deb", + sha256 = "9a85b57eea3790432eae06421081b3e59d3c9841d59646364ecd174f9ed4821a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver6.2.0/rocsolver6.2.0_3.26.0.60200-66~20.04_amd64.deb", + sha256 = "87dcd34a9b50f46161ecdb7781ab03c2b311fb7e13aa167c4a9c5e3bcf24b473", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver-dev6.2.0/rocsolver-dev6.2.0_3.26.0.60200-66~20.04_amd64.deb", + sha256 = "21e4aa1957e7bc5d293a418a983d9b3c3917fb78eb79d3d4d55a253b9bae7743", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsparse6.2.0/rocsparse6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "dacc13278f2be1cd847fca30ce409dcf95749df5f1a27635bc6dbd61be488d14", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm2_2.4.101-2_amd64.deb", + sha256 = "4cd2e10f9486456a2782487f8bfd39f330f35a4d5bd6d693412b9e4ca2a6acbd", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm-amdgpu1_2.4.101-2_amd64.deb", + sha256 = "d4567a30f7d68b4dcf794f8677b96e89083693c94e88279fecf577ceba8b9774", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libelf1_0.176-1.1build1_amd64.deb", + sha256 = "78a8761227efc04a1e37527f2f33ba608c6fb5d6c911616346ada5d7b9b72ee3", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libnuma1_2.0.12-1_amd64.deb", + sha256 = "0b1edf08cf9befecd21fe94e298ac25e476f87fd876ddd4adf42ef713449e637", + ), + ], + "rocm_root": "opt/rocm-6.2.0", + }, +} diff --git a/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl b/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl new file mode 100644 index 00000000000000..88dca226f795b7 --- /dev/null +++ b/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl @@ -0,0 +1,183 @@ +rocm_redist_ubuntu_22_04 = { + "6.2.0": { + "archives": [ + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/c/comgr6.2.0/comgr6.2.0_2.8.0.60200-66~22.04_amd64.deb", + sha256 = "bc5d620e4e0db3746fc6b2279e463f618681f1f95ba973e40b687cef50ca2489", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-runtime-amd6.2.0/hip-runtime-amd6.2.0_6.2.41133.60200-66~22.04_amd64.deb", + sha256 = "38e9670bedc7bbdc0b9f38c7a0fe90f73ef80f161cbf63c98d30e422438ce2c5", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev6.2.0/hip-dev6.2.0_6.2.41133.60200-66~22.04_amd64.deb", + sha256 = "c66cc8c19b57cab740710811457f02a16e24cff761e5c99c3640f63ceefe8281", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas6.2.0/hipblas6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "fbd647e1b13e7aa2c14c9581f9102c069ddab9ecb47a4b226d433ec37b19e92d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas-dev6.2.0/hipblas-dev6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "885cf3f3a52ebde9caadf6348a6cda28fd15e3bc52bab0c90b587d72b29ff7ef", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcc6.2.0/hipcc6.2.0_1.1.1.60200-66~22.04_amd64.deb", + sha256 = "468026fa8eb70121f0c545557a926ddc41228cef9457b4a00d8fc3a36b04310f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcub-dev6.2.0/hipcub-dev6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "c2c7d2ec5a8a31837c0addfc619ee67a374ea967cc6d43900472005489f62722", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft6.2.0/hipfft6.2.0_1.0.14.60200-66~22.04_amd64.deb", + sha256 = "6e649430cc5e247bbd052dff2d681b6bf0ef09d0bc3446a4911f4ab4cd317140", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft-dev6.2.0/hipfft-dev6.2.0_1.0.14.60200-66~22.04_amd64.deb", + sha256 = "389b0c83a39adbeeec442adde3fedba2820ed948179a4a0df03d67560501cd97", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver6.2.0/hipsolver6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "adf9aad1fc062445e34cdddbeca80db9c02f4c5f258e01c45e2a6222d15cb66d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver-dev6.2.0/hipsolver-dev6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "cb46dfbff3943a3167f6173fc381d744eb966a3451bcff49458c696888ec452c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse6.2.0/hipsparse6.2.0_3.1.1.60200-66~22.04_amd64.deb", + sha256 = "8c7a216aeef6ceeb3881d3e443a89a0f5c15a17deb5926cba4b787554c8fab87", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse-dev6.2.0/hipsparse-dev6.2.0_3.1.1.60200-66~22.04_amd64.deb", + sha256 = "501cad72df5f09572f99c11eebbb1eff49afb6ca8c91bcf4966f81068177a95d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand6.2.0/hiprand6.2.0_2.11.0.60200-66~22.04_amd64.deb", + sha256 = "b20c86be57698a944f91048699d0fbde5253bea28ba9d4035ce1de1d3c20f9ac", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand-dev6.2.0/hiprand-dev6.2.0_2.11.0.60200-66~22.04_amd64.deb", + sha256 = "9dab6f44b92b6020e183777f6f07219d68de5d10cad7538c7ddcae0192aa3e33", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hsa-rocr6.2.0/hsa-rocr6.2.0_1.14.0.60200-66~22.04_amd64.deb", + sha256 = "62d280204d8ff642b464dab03fc344442df6dc5f04e152da20604e8050303c41", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip6.2.0/miopen-hip6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "6c2aa042067e51d5b70a264ca83c92ffaa6e81d00d08b55986917da860e66d85", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip-dev/miopen-hip-dev_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "f3452b2bd9c2869c550c7f963cca65fb35a37183ad4a56d96e05c69adb2f1d04", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl6.2.0/rccl6.2.0_2.20.5.60200-66~22.04_amd64.deb", + sha256 = "f3205c0a7d736f457ee2262988260e8dc4c495fa74a394ff73a9dfe002aff335", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl-dev6.2.0/rccl-dev6.2.0_2.20.5.60200-66~22.04_amd64.deb", + sha256 = "953a248cd44f403e5423185918166bfa29a009519c3d7b5b5a8e067fdf672602", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas6.2.0/rocblas6.2.0_4.2.0.60200-66~22.04_amd64.deb", + sha256 = "c306ca3e59b851ebb35872e09e5598adf2e2ebb736c1b200ff4ee204fe262f7e", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas-dev/rocblas-dev_4.2.0.60200-66~22.04_amd64.deb", + sha256 = "115d0e9ec1b93bf7cba5fa1e3de1428f0d999d931c2dd495e4cdad22b5078936", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft6.2.0/rocfft6.2.0_1.0.28.60200-66~22.04_amd64.deb", + sha256 = "0d40fc9aa1da617cd8864258cd1259a0e7444ea0da446297d154b5b3422393af", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft-dev6.2.0/rocfft-dev6.2.0_1.0.28.60200-66~22.04_amd64.deb", + sha256 = "8c1e72cf1c165e20960b0c2f3c499900a809d59340d14a0acff95c543c7087f2", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-core/rocm-core_6.2.0.60200-66~22.04_amd64.deb", + sha256 = "22c80c1a704f4ce7d6a49a8b41acd64f3ed0513cd7f5570a0664a10df5858334", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-hip-libraries/rocm-hip-libraries_6.2.0.60200-66~22.04_amd64.deb", + sha256 = "9c2ff1dc100e342969bd51a7cd4918048c8b25579de709efde56425d969cd50f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev/hip-dev_6.2.41133.60200-66~22.04_amd64.deb", + sha256 = "1101f3edb9dbc9f4914d7f26b5569ec9bde076d52d4125c98d22a99dd730ab51", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-device-libs6.2.0/rocm-device-libs6.2.0_1.0.0.60200-66~22.04_amd64.deb", + sha256 = "d5b660df350130e0ab04ddf3e36dd442bde27ae9cbb8e5f12c047b0d3cb05463", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocminfo6.2.0/rocminfo6.2.0_1.0.0.60200-66~22.04_amd64.deb", + sha256 = "0d06a84ac53d388089b7b8c80133f60c1eea5bfd85155ecc113efb206a747c25", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm6.2.0/rocm-llvm6.2.0_18.0.0.24292.60200-66~22.04_amd64.deb", + sha256 = "4a29539480a7e4b27991ccf533a35526dd3994a457fa84e4c960192c2fa05b46", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm-dev/rocm-llvm-dev_18.0.0.24292.60200-66~22.04_amd64.deb", + sha256 = "febb8614cedd98f13ba0624072ffdd13b9a6dc3431380a17a0eaf87583627890", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprim-dev6.2.0/rocprim-dev6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "3d859bb735ff8bf1962ce680e9257dcc574ab36224f50069f833fa19c6d7e69d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-smi-lib6.2.0/rocm-smi-lib6.2.0_7.3.0.60200-66~22.04_amd64.deb", + sha256 = "ffd4e064e8a1d52b9e72114e8a1d51c78004a960f1d923448af8ed07a1b6f30b", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprofiler-register6.2.0/rocprofiler-register6.2.0_0.4.0.60200-66~22.04_amd64.deb", + sha256 = "66df78d8c5e2d1a0ae43cd4a5e41cf75ec120c870a0bbd7da18a2ba4dec42f9c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocrand-dev/rocrand-dev_3.1.0.60200-66~22.04_amd64.deb", + sha256 = "317c16a6e0b0b456153437406dd92225e17dbd454fc1304b0c3fef5fbfc69bc2", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer6.2.0/roctracer6.2.0_4.1.60200.60200-66~22.04_amd64.deb", + sha256 = "9ddf8835f1e94d5004b4c466091c8110cb72e11eda545d0de395395832076c0a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer-dev6.2.0/roctracer-dev6.2.0_4.1.60200.60200-66~22.04_amd64.deb", + sha256 = "9a9ed0c66d3a9d9ff50f1fc3a9e9105bb8b1a6d93c1f856682625dfb68ab627f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver6.2.0/rocsolver6.2.0_3.26.0.60200-66~22.04_amd64.deb", + sha256 = "5b86bf7b33a3ffa7098878f27d1b119aada69ebb02bd121b47209559c32703be", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver-dev6.2.0/rocsolver-dev6.2.0_3.26.0.60200-66~22.04_amd64.deb", + sha256 = "4573f99191fbe3a2afab84fdf5a05e024bd230ca7866d7eba71a5f2560a3a0bf", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsparse6.2.0/rocsparse6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "4fbc91db9085ecd80a5e051bba56863ae33b22516d727ab3fef15fb500187222", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm2_2.4.110-1ubuntu1_amd64.deb", + sha256 = "e5ea68db36b31aab442c790e1c78ecdf53646c16b0cd83db15966632ba04152c", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm-amdgpu1_2.4.110-1ubuntu1_amd64.deb", + sha256 = "ae1f0d77668d7275d085ba820206ba91e90833dd1a02b8e251af0c73aa119ba3", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libelf1_0.186-1build1_amd64.deb", + sha256 = "8effc4d7a0cc341bcf6cb11af0134f3defa6292376ecfdfc697a9b228606345c", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libnuma1_2.0.14-3ubuntu2_amd64.deb", + sha256 = "0721c89001fbbd1ada23e89da5d60e762763c1a7b3dc814a2e9a518480a8043d", + ), + ], + "rocm_root": "opt/rocm-6.2.0", + }, +} diff --git a/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl b/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl new file mode 100644 index 00000000000000..da9ef00998f936 --- /dev/null +++ b/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl @@ -0,0 +1,187 @@ +rocm_redist_ubuntu_24_04 = { + "6.2.0": { + "archives": [ + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/c/comgr6.2.0/comgr6.2.0_2.8.0.60200-66~24.04_amd64.deb", + sha256 = "7e1ff2d9f2435f5b9db9aa952bb57d1a878a8aa7d96bda61361c107b7e1428e3", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev6.2.0/hip-dev6.2.0_6.2.41133.60200-66~24.04_amd64.deb", + sha256 = "5e6601ada30432ee0dab0473585bdf1fa7c398f0c655538d48eba9c44e6dc77a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas6.2.0/hipblas6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "7ff8f6308c744c71008959b17ab6338de1c6fd3e4581dd94271e6eca9fdc4c13", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas-dev6.2.0/hipblas-dev6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "e9f71e71db600d72dcb2b61e64b965b6c60d47bd4bb699e8abec85edb260b819", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblaslt6.2.0/hipblaslt6.2.0_0.8.0.60200-66~24.04_amd64.deb", + sha256 = "e5dfd8ba9e49f919a96c102d3a652e8ef0c4d1a63b3f3909c856d40b1745e2a9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblaslt-dev6.2.0/hipblaslt-dev6.2.0_0.8.0.60200-66~24.04_amd64.deb", + sha256 = "639bd47010035ee6719425510be33d2f54483004a909dfa4c64f853d7394a22f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcc6.2.0/hipcc6.2.0_1.1.1.60200-66~24.04_amd64.deb", + sha256 = "c2782a98633e4400f46ba732605e56b2821366db60ec06d88db0615e4d1acf3c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcub-dev6.2.0/hipcub-dev6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "48fec4d06aef3159db4117125b728242a1eeb480ea3d55d3901d945d4b883694", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft6.2.0/hipfft6.2.0_1.0.14.60200-66~24.04_amd64.deb", + sha256 = "8dd73cdbd4f0563f4a0481304771e4cbcac5905eea1f2d8ef41f922cdf9aba85", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft-dev6.2.0/hipfft-dev6.2.0_1.0.14.60200-66~24.04_amd64.deb", + sha256 = "e3c0a4ebda8d3aacd44b19c6872f23222513be0a5c04f793605088d9183f1be4", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver6.2.0/hipsolver6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "adbba9ffcf8b5e4202efbe45924d87520bf4100ec5464bd0ba3beb61cb535c6c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver-dev6.2.0/hipsolver-dev6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "01d3dd6195111808b40a5837d3e51d8c27c4700b4bd8bb2d901e39d0474fd98a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse6.2.0/hipsparse6.2.0_3.1.1.60200-66~24.04_amd64.deb", + sha256 = "2ba33a96388cd3edd7b5b8b261fe99cbd569894f4d7db291fc0dd0ff5d7c67ce", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse-dev6.2.0/hipsparse-dev6.2.0_3.1.1.60200-66~24.04_amd64.deb", + sha256 = "6a767f493a722e2d4260a9bc23cf9db66fd275a094b395c768e305f60d6b4fe9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand6.2.0/hiprand6.2.0_2.11.0.60200-66~24.04_amd64.deb", + sha256 = "82f182134b415080ba4a12fd7993b6099ee9b9e549c72bfebee24c8486704078", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand-dev6.2.0/hiprand-dev6.2.0_2.11.0.60200-66~24.04_amd64.deb", + sha256 = "011d5c28f45cd9d756e0cf6ea6a3d37eabd98a3381ffd961c772ab92a37e4ee8", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hsa-rocr6.2.0/hsa-rocr6.2.0_1.14.0.60200-66~24.04_amd64.deb", + sha256 = "fa04f707debb75087ea2bf5e327602034eaa3a6900421f2cf32ad5f5f1c887b9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip6.2.0/miopen-hip6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "2dbf6d126d0de6930e0cd94d0e525e07d3019d90bd7256f3151a7f1fbc2250af", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip-dev/miopen-hip-dev_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "df5fdd2218e4d380b133ba402f3734fbe0589d9cdd8618a101b71b968909b4ba", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl6.2.0/rccl6.2.0_2.20.5.60200-66~24.04_amd64.deb", + sha256 = "4d7efa4ee6aa2bf69b0aab449cc1d01c25ca65814e1b3cb07f6b59fa8b1608b8", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl-dev6.2.0/rccl-dev6.2.0_2.20.5.60200-66~24.04_amd64.deb", + sha256 = "4ab4f880344e04d61b6fa746be5c4bdc2841409fb6987ee61e39c6420b4eca42", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas6.2.0/rocblas6.2.0_4.2.0.60200-66~24.04_amd64.deb", + sha256 = "521c87ce396c6ce10076cc641b6035451fd68ddb36a684c5a9c9538dfc831ade", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas-dev/rocblas-dev_4.2.0.60200-66~24.04_amd64.deb", + sha256 = "00f135ce2ae47c35085ef06248ff7d5ce8c12fd0d5b82e7bd77b1dbc0ce7058e", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft6.2.0/rocfft6.2.0_1.0.28.60200-66~24.04_amd64.deb", + sha256 = "40c936452e84bfec87236f08de5a9d3f232c397a3305b6143c26697ed56ceda1", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft-dev6.2.0/rocfft-dev6.2.0_1.0.28.60200-66~24.04_amd64.deb", + sha256 = "eb3904263b396d46799eeea1081d8e8d1a551a890432a803364db2d013849f92", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-core/rocm-core_6.2.0.60200-66~24.04_amd64.deb", + sha256 = "af5fcbe8dc2b6cbec30e2d39d30736e8a47a0b9d0ca2be7f179f2947f9c98245", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-hip-libraries/rocm-hip-libraries_6.2.0.60200-66~24.04_amd64.deb", + sha256 = "228f07a3caefc41f6efd5345eb9d3630f1db769f9b4abd1313cbcb32d077ce53", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev/hip-dev_6.2.41133.60200-66~24.04_amd64.deb", + sha256 = "cda72054d2011dbb062e75386766d928fd8905c15c88685c3ef87fc963bd88ad", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-device-libs6.2.0/rocm-device-libs6.2.0_1.0.0.60200-66~24.04_amd64.deb", + sha256 = "298544f717dfb236b9257b19a0ab81abaaa770128976d4abfdea546cd32d8b02", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocminfo6.2.0/rocminfo6.2.0_1.0.0.60200-66~24.04_amd64.deb", + sha256 = "8e78ed8e480b55a496153b150acb22bab39c3bb8cf1e62f9aff7eaf75a3a3a92", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm6.2.0/rocm-llvm6.2.0_18.0.0.24292.60200-66~24.04_amd64.deb", + sha256 = "72c388eae7c0f54151b46fbd8fa6e26f1ca81e2b8b415c43411a156b3f25b6e7", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm-dev/rocm-llvm-dev_18.0.0.24292.60200-66~24.04_amd64.deb", + sha256 = "3e85a859c5dafa82a9a57dda096d566b821217bacfac995f7cc45ed460b68999", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-smi-lib6.2.0/rocm-smi-lib6.2.0_7.3.0.60200-66~24.04_amd64.deb", + sha256 = "c094e3022c73fca2aa6c8bb435f93550109531de37fe8de5fbf6cfe1f047b645", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprim-dev6.2.0/rocprim-dev6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "6c832e2feb0885fbe481245825c76a466921b294f530eb0d0da70a44cfe6e608", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprofiler-register6.2.0/rocprofiler-register6.2.0_0.4.0.60200-66~24.04_amd64.deb", + sha256 = "d198d010fedfbe51d3fd19444e2848d430e08f91d19a5b2661b94ac6d1135863", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocrand-dev/rocrand-dev_3.1.0.60200-66~24.04_amd64.deb", + sha256 = "2a2a95185ce0e54df226474b2f5cfcdc9e5ede5a6d88a8a70c2635ea2237abba", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer6.2.0/roctracer6.2.0_4.1.60200.60200-66~24.04_amd64.deb", + sha256 = "2f2fb6f8d06ace89131934c833b0ea359335a4b45aeec1559b293d7bc14b1d1d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer-dev6.2.0/roctracer-dev6.2.0_4.1.60200.60200-66~24.04_amd64.deb", + sha256 = "c6c781ee87c459aed32e943b389137f98ecd402fb83a3d1c98de9a76abadc3a3", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver6.2.0/rocsolver6.2.0_3.26.0.60200-66~24.04_amd64.deb", + sha256 = "5e4b3e38556f0826e5322971635a49a72283d60862ccc4d28efd11c8fb955b47", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver-dev6.2.0/rocsolver-dev6.2.0_3.26.0.60200-66~24.04_amd64.deb", + sha256 = "5bb6ae92a25f33488f2ee5f123ac4f67ad130e18e4949161715451509be3b89d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsparse6.2.0/rocsparse6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "1867833a569fbf3f87b82c81bc47f5d62085ea40f12d1cb33475c1f2dec89bc4", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm2_2.4.120-2build1_amd64.deb", + sha256 = "f5fb4e7ce17921cc466fb7911abf91495ffb181b36772f68e2e82cb621703112", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm-amdgpu1_2.4.120-2build1_amd64.deb", + sha256 = "e149d4daea33f58853b8013fd6c24888429ce7716a4b26d1a1f45181b5a4e73e", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libelf1t64_0.190-1.1build4_amd64.deb", + sha256 = "b277e52769302778bd052376ac6687b52954b6605dd5f781bff8631e3504d58f", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libnuma1_2.0.18-1build1_amd64.deb", + sha256 = "508daa855e99959acaa945e6a89d218e0be6b5727fd28773580942ff37cf5805", + ), + ], + "rocm_root": "opt/rocm-6.2.0", + }, +} diff --git a/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/tsl/third_party/gpus/rocm_configure.bzl index b980f10448ad88..5e2ba436b3710b 100644 --- a/third_party/tsl/third_party/gpus/rocm_configure.bzl +++ b/third_party/tsl/third_party/gpus/rocm_configure.bzl @@ -12,6 +12,10 @@ * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. """ +load( + "//third_party/gpus/rocm:rocm_redist.bzl", + "rocm_redist", +) load( "//third_party/remote_config:common.bzl", "config_repo_label", @@ -33,8 +37,6 @@ load( load( ":cuda_configure.bzl", "enable_cuda", - "make_copy_dir_rule", - "make_copy_files_rule", ) load( ":sycl_configure.bzl", @@ -48,6 +50,9 @@ _TF_SYSROOT = "TF_SYSROOT" _ROCM_TOOLKIT_PATH = "ROCM_PATH" _TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS" _TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO" +_DISTRIBUTION_PATH = "rocm/rocm_dist" +_OS = "OS" +_ROCM_VERSION = "ROCM_VERSION" _DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm" @@ -203,20 +208,8 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): """ inc_dirs = [] - # Add HSA headers (needs to match $HSA_PATH) - inc_dirs.append(rocm_config.rocm_toolkit_path + "/hsa/include") - - # Add HIP headers (needs to match $HIP_PATH) - inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include") - if int(rocm_config.rocm_version_number) >= 50200: - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocprim") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocsolver") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocblas") - - # Add HIP-Clang headers (realpath relative to compiler binary) - rocm_toolkit_path = realpath(repository_ctx, rocm_config.rocm_toolkit_path, bash_bin) + # Add full paths + rocm_toolkit_path = str(repository_ctx.path(rocm_config.rocm_toolkit_path)) inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/8.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include") @@ -367,7 +360,7 @@ def _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin): return libs -def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin): +def _find_libs(repository_ctx, rocm_config, miopen_path, rccl_path, bash_bin): """Returns the ROCm libraries on the system. Args: @@ -383,7 +376,6 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_ for name, path in [ ("amdhip64", rocm_config.rocm_toolkit_path), ("rocblas", rocm_config.rocm_toolkit_path), - (hipfft_or_rocfft, rocm_config.rocm_toolkit_path), ("hiprand", rocm_config.rocm_toolkit_path), ("MIOpen", miopen_path), ("rccl", rccl_path), @@ -401,17 +393,17 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_ libs_paths.append(("hipblaslt", _rocm_lib_paths(repository_ctx, "hipblaslt", rocm_config.rocm_toolkit_path), True)) return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin) -def find_rocm_config(repository_ctx): +def find_rocm_config(repository_ctx, rocm_path): """Returns ROCm config dictionary from running find_rocm_config.py""" python_bin = get_python_bin(repository_ctx) - exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config]) + exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config], env_vars = {"ROCM_PATH": rocm_path}) if exec_result.return_code: auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result)) # Parse the dict from stdout. return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()]) -def _get_rocm_config(repository_ctx, bash_bin): +def _get_rocm_config(repository_ctx, bash_bin, rocm_path, install_path): """Detects and returns information about the ROCm installation on the system. Args: @@ -426,7 +418,7 @@ def _get_rocm_config(repository_ctx, bash_bin): miopen_version_number: The version of MIOpen on the system. hipruntime_version_number: The version of HIP Runtime on the system. """ - config = find_rocm_config(repository_ctx) + config = find_rocm_config(repository_ctx, rocm_path) rocm_toolkit_path = config["rocm_toolkit_path"] rocm_version_number = config["rocm_version_number"] miopen_version_number = config["miopen_version_number"] @@ -437,6 +429,7 @@ def _get_rocm_config(repository_ctx, bash_bin): rocm_version_number = rocm_version_number, miopen_version_number = miopen_version_number, hipruntime_version_number = hipruntime_version_number, + install_path = install_path, ) def _tpl_path(repository_ctx, labelname): @@ -500,15 +493,12 @@ def _create_dummy_repository(repository_ctx): "%{hipblas_lib}": _lib_name("hipblas"), "%{miopen_lib}": _lib_name("miopen"), "%{rccl_lib}": _lib_name("rccl"), - "%{hipfft_or_rocfft}": "hipfft", - "%{hipfft_or_rocfft_lib}": _lib_name("hipfft"), "%{hiprand_lib}": _lib_name("hiprand"), "%{hipsparse_lib}": _lib_name("hipsparse"), "%{roctracer_lib}": _lib_name("roctracer64"), "%{rocsolver_lib}": _lib_name("rocsolver"), "%{hipsolver_lib}": _lib_name("hipsolver"), "%{hipblaslt_lib}": _lib_name("hipblaslt"), - "%{copy_rules}": "", "%{rocm_headers}": "", }, ) @@ -526,7 +516,7 @@ def _create_dummy_repository(repository_ctx): "%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH, "%{hipblaslt_flag}": "0", }, - "rocm/rocm/rocm_config.h", + "rocm/rocm_config/rocm_config.h", ) # If rocm_configure is not configured to build with GPU support, and the user @@ -578,6 +568,53 @@ def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets): amdgpu_target for amdgpu_target in amdgpu_targets] return str(amdgpu_target_flags) +def _get_file_name(url): + last_slash_index = url.rfind("/") + return url[last_slash_index + 1:] + +def _download_package(repository_ctx, archive): + file_name = _get_file_name(archive.url) + tmp_dir = "tmp" + repository_ctx.file(tmp_dir + "/.idx") # create tmp dir + + repository_ctx.report_progress("Downloading and extracting {}, expected hash is {}".format(archive.url, archive.sha256)) # buildifier: disable=print + repository_ctx.download_and_extract( + url = archive.url, + output = tmp_dir if archive.url.endswith(".deb") else _DISTRIBUTION_PATH, + sha256 = archive.sha256, + ) + + all_files = repository_ctx.path(tmp_dir).readdir() + + matched_files = [f for f in all_files if _get_file_name(str(f)).startswith("data.")] + for f in matched_files: + repository_ctx.extract(f, _DISTRIBUTION_PATH) + + repository_ctx.delete(tmp_dir) + repository_ctx.delete(file_name) + +def _remove_root_dir(path, root_dir): + if path.startswith(root_dir + "/"): + return path[len(root_dir) + 1:] + return path + +def _setup_rocm_distro_dir(repository_ctx): + """Sets up the rocm hermetic installation directory to be used in hermetic build""" + bash_bin = get_bash_bin(repository_ctx) + os = repository_ctx.os.environ.get(_OS) + rocm_version = repository_ctx.os.environ.get(_ROCM_VERSION) + if os and rocm_version: + redist = rocm_redist[os][rocm_version] + repository_ctx.file("rocm/.index") + for archive in redist["archives"]: + _download_package(repository_ctx, archive) + return _get_rocm_config(repository_ctx, bash_bin, "{}/{}".format(_DISTRIBUTION_PATH, redist["rocm_root"]), "/{}".format(redist["rocm_root"])) + else: + rocm_path = repository_ctx.os.environ.get(_ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH) + repository_ctx.report_progress("Using local rocm installation {}".format(rocm_path)) # buildifier: disable=print + repository_ctx.symlink(rocm_path, _DISTRIBUTION_PATH) + return _get_rocm_config(repository_ctx, bash_bin, _DISTRIBUTION_PATH, _DEFAULT_ROCM_TOOLKIT_PATH) + def _create_local_rocm_repository(repository_ctx): """Creates the repository containing files set up to build with ROCm.""" @@ -590,12 +627,8 @@ def _create_local_rocm_repository(repository_ctx): "rocm:rocm_config.h", ]} - bash_bin = get_bash_bin(repository_ctx) - rocm_config = _get_rocm_config(repository_ctx, bash_bin) - - # For ROCm 4.1 and above use hipfft, older ROCm versions use rocfft + rocm_config = _setup_rocm_distro_dir(repository_ctx) rocm_version_number = int(rocm_config.rocm_version_number) - hipfft_or_rocfft = "rocfft" if rocm_version_number < 40100 else "hipfft" # For ROCm 5.2 and above, find MIOpen and RCCL in the main rocm lib path miopen_path = rocm_config.rocm_toolkit_path + "/miopen" if rocm_version_number < 50200 else rocm_config.rocm_toolkit_path @@ -603,75 +636,19 @@ def _create_local_rocm_repository(repository_ctx): # Copy header and library files to execroot. # rocm_toolkit_path - rocm_toolkit_path = rocm_config.rocm_toolkit_path - copy_rules = [ - make_copy_dir_rule( - repository_ctx, - name = "rocm-include", - src_dir = rocm_toolkit_path + "/include", - out_dir = "rocm/include", - ), - ] - - # explicitly copy (into the local_config_rocm repo) the $ROCM_PATH/hiprand/include and - # $ROCM_PATH/rocrand/include dirs, only once the softlink to them in $ROCM_PATH/include - # dir has been removed. This removal will happen in a near-future ROCm release. - hiprand_include = "" - hiprand_include_softlink = rocm_config.rocm_toolkit_path + "/include/hiprand" - softlink_exists = files_exist(repository_ctx, [hiprand_include_softlink], bash_bin) - if not softlink_exists[0]: - hiprand_include = '":hiprand-include",\n' - copy_rules.append( - make_copy_dir_rule( - repository_ctx, - name = "hiprand-include", - src_dir = rocm_toolkit_path + "/hiprand/include", - out_dir = "rocm/include/hiprand", - ), - ) - - rocrand_include = "" - rocrand_include_softlink = rocm_config.rocm_toolkit_path + "/include/rocrand" - softlink_exists = files_exist(repository_ctx, [rocrand_include_softlink], bash_bin) - if not softlink_exists[0]: - rocrand_include = '":rocrand-include",\n' - copy_rules.append( - make_copy_dir_rule( - repository_ctx, - name = "rocrand-include", - src_dir = rocm_toolkit_path + "/rocrand/include", - out_dir = "rocm/include/rocrand", - ), - ) + rocm_toolkit_path = _remove_root_dir(rocm_config.rocm_toolkit_path, "rocm") - rocm_libs = _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin) + bash_bin = get_bash_bin(repository_ctx) + rocm_libs = _find_libs(repository_ctx, rocm_config, miopen_path, rccl_path, bash_bin) rocm_lib_srcs = [] rocm_lib_outs = [] for lib in rocm_libs.values(): if lib: rocm_lib_srcs.append(lib.path) rocm_lib_outs.append("rocm/lib/" + lib.file_name) - copy_rules.append(make_copy_files_rule( - repository_ctx, - name = "rocm-lib", - srcs = rocm_lib_srcs, - outs = rocm_lib_outs, - )) clang_offload_bundler_path = rocm_toolkit_path + "/llvm/bin/clang-offload-bundler" - # copy files mentioned in third_party/gpus/rocm/BUILD - copy_rules.append(make_copy_files_rule( - repository_ctx, - name = "rocm-bin", - srcs = [ - clang_offload_bundler_path, - ], - outs = [ - "rocm/bin/" + "clang-offload-bundler", - ], - )) - have_hipblaslt = "1" if rocm_libs["hipblaslt"] != None else "0" # Set up BUILD file for rocm/ @@ -693,20 +670,8 @@ def _create_local_rocm_repository(repository_ctx): ) repository_dict = { - "%{hip_lib}": rocm_libs["amdhip64"].file_name, - "%{rocblas_lib}": rocm_libs["rocblas"].file_name, - "%{hipfft_or_rocfft}": hipfft_or_rocfft, - "%{hipfft_or_rocfft_lib}": rocm_libs[hipfft_or_rocfft].file_name, - "%{hiprand_lib}": rocm_libs["hiprand"].file_name, - "%{miopen_lib}": rocm_libs["MIOpen"].file_name, - "%{rccl_lib}": rocm_libs["rccl"].file_name, - "%{hipsparse_lib}": rocm_libs["hipsparse"].file_name, - "%{roctracer_lib}": rocm_libs["roctracer64"].file_name, - "%{rocsolver_lib}": rocm_libs["rocsolver"].file_name, - "%{copy_rules}": "\n".join(copy_rules), - "%{rocm_headers}": ('":rocm-include",\n' + - hiprand_include + - rocrand_include), + "%{rocm_root}": rocm_toolkit_path, + "%{rocm_toolkit_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), } is_rocm_clang = _use_rocm_clang(repository_ctx) @@ -726,7 +691,6 @@ def _create_local_rocm_repository(repository_ctx): ) # Set up crosstool/ - cc = find_cc(repository_ctx, is_rocm_clang) host_compiler_includes = get_cxx_inc_directories( repository_ctx, @@ -785,6 +749,7 @@ def _create_local_rocm_repository(repository_ctx): repository_ctx.template( "crosstool/cc_toolchain_config.bzl", tpl_paths["crosstool:hipcc_cc_toolchain_config.bzl"], + rocm_defines, ) repository_ctx.template( @@ -792,11 +757,13 @@ def _create_local_rocm_repository(repository_ctx): tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_rocm"], { "%{cpu_compiler}": str(cc), - "%{hipcc_path}": rocm_config.rocm_toolkit_path + "/bin/hipcc", + "%{compiler}": rocm_defines["%{compiler}"], + "%{hipcc_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path + "/bin/hipcc")), "%{hipcc_env}": _hipcc_env(repository_ctx), - "%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", + "%{rocm_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), + "%{rocr_runtime_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path + "/lib")), "%{rocr_runtime_library}": "hsa-runtime64", - "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", + "%{hip_runtime_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path + "/lib")), "%{hip_runtime_library}": "amdhip64", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), @@ -806,13 +773,32 @@ def _create_local_rocm_repository(repository_ctx): # Set up rocm_config.h, which is used by # tensorflow/compiler/xla/stream_executor/dso_loader.cc. repository_ctx.template( - "rocm/rocm/rocm_config.h", + "rocm/rocm_config/rocm_config.h", + tpl_paths["rocm:rocm_config.h"], + { + "%{rocm_amdgpu_targets}": ",".join( + ["\"%s\"" % c for c in rocm_config.amdgpu_targets], + ), + "%{rocm_toolkit_path}": rocm_config.install_path, + "%{rocm_version_number}": rocm_config.rocm_version_number, + "%{miopen_version_number}": rocm_config.miopen_version_number, + "%{hipruntime_version_number}": rocm_config.hipruntime_version_number, + "%{hipblaslt_flag}": have_hipblaslt, + "%{hip_soversion_number}": "6" if int(rocm_config.rocm_version_number) >= 60000 else "5", + "%{rocblas_soversion_number}": "4" if int(rocm_config.rocm_version_number) >= 60000 else "3", + }, + ) + + # Set up rocm_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "rocm/rocm_config_hermetic/rocm_config.h", tpl_paths["rocm:rocm_config.h"], { "%{rocm_amdgpu_targets}": ",".join( ["\"%s\"" % c for c in rocm_config.amdgpu_targets], ), - "%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path, + "%{rocm_toolkit_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), "%{rocm_version_number}": rocm_config.rocm_version_number, "%{miopen_version_number}": rocm_config.miopen_version_number, "%{hipruntime_version_number}": rocm_config.hipruntime_version_number, @@ -888,6 +874,8 @@ _ENVIRONS = [ "TF_NEED_CUDA", # Needed by the `if_gpu_is_configured` macro _ROCM_TOOLKIT_PATH, _TF_ROCM_AMDGPU_TARGETS, + _OS, + _ROCM_VERSION, ] remote_rocm_configure = repository_rule( diff --git a/third_party/tsl/third_party/remote_config/common.bzl b/third_party/tsl/third_party/remote_config/common.bzl index 57fb6fcf7aca9a..c70c0ba5b51db6 100644 --- a/third_party/tsl/third_party/remote_config/common.bzl +++ b/third_party/tsl/third_party/remote_config/common.bzl @@ -212,7 +212,8 @@ def execute( cmdline, error_msg = None, error_details = None, - allow_failure = False): + allow_failure = False, + env_vars = {}): """Executes an arbitrary shell command. Args: @@ -222,10 +223,11 @@ def execute( error_details: string, details about the error or steps to fix it allow_failure: bool, if True, an empty stdout result or output to stderr is fine, otherwise either of these is an error + env_vars: environment variables Returns: The result of repository_ctx.execute(cmdline) """ - result = raw_exec(repository_ctx, cmdline) + result = raw_exec(repository_ctx, cmdline, env_vars) if (result.stderr or not result.stdout) and not allow_failure: fail( "\n".join([ @@ -236,7 +238,7 @@ def execute( ) return result -def raw_exec(repository_ctx, cmdline): +def raw_exec(repository_ctx, cmdline, env_vars = {}): """Executes a command via repository_ctx.execute() and returns the result. This method is useful for debugging purposes. For example, to print all @@ -245,11 +247,12 @@ def raw_exec(repository_ctx, cmdline): Args: repository_ctx: the repository_ctx cmdline: the list of args + env_vars: environment variables Returns: The 'exec_result' of repository_ctx.execute(). """ - return repository_ctx.execute(cmdline) + return repository_ctx.execute(cmdline, environment = env_vars) def files_exist(repository_ctx, paths, bash_bin = None): """Checks which files in paths exists. diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 76133d52b65b9c..605b8f62f84507 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -2319,6 +2319,7 @@ gpu_kernel_library( "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", + "@local_config_rocm//rocm:rocm_config", ]), ) @@ -2479,6 +2480,10 @@ cc_library( "@tsl//tsl/platform:logging", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:statusor", + ]) + if_rocm_is_configured([ + # keep sorted + "@local_config_rocm//rocm:rocm_config", + "@local_config_rocm//rocm:rocm_headers", ]), ) @@ -2489,7 +2494,9 @@ gpu_kernel_library( deps = [ "//xla:shape_util", "//xla:types", - ], + ] + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), ) xla_test( diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index 76342518614502..6b7d44a829faa0 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -820,15 +820,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "rocm_rpath", - linkopts = select({ - "//conditions:default": [ - "-Wl,-rpath,../local_config_rocm/rocm/rocm/lib", - ], - }), -) - cc_library( name = "stream_executor_rocm", tags = [ @@ -837,12 +828,12 @@ cc_library( ], deps = [ ":rocm_platform_id", - ":rocm_rpath", "//xla/stream_executor:dnn", "//xla/stream_executor:platform_manager", "//xla/stream_executor:scratch_allocator", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", + "@local_config_rocm//rocm:rocm_rpath", ] + if_static( [":all_runtime"], ), diff --git a/xla/tsl/platform/default/BUILD b/xla/tsl/platform/default/BUILD index f95ba7897dde37..9f8dc1d79cb598 100644 --- a/xla/tsl/platform/default/BUILD +++ b/xla/tsl/platform/default/BUILD @@ -1,5 +1,6 @@ # Tensorflow default + linux implementations of tensorflow/core/platform libraries. load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load( "//xla/tsl:tsl.bzl", "if_cuda_tools", @@ -103,12 +104,16 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", - "@local_config_rocm//rocm:rocm_headers", "@local_config_tensorrt//:tensorrt_headers", "@tsl//tsl/platform:load_library", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:path", - ] + if_oss(["@local_config_nccl//:nccl_config"]), + ] + if_oss([ + "@local_config_nccl//:nccl_config", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_config", + "@local_config_rocm//rocm:rocm_headers", + ]), ) cc_library( @@ -264,6 +269,7 @@ cc_library( name = "load_library", srcs = ["load_library.cc"], hdrs = ["@tsl//tsl/platform:load_library.h"], + linkstatic = True, tags = [ "manual", "no_oss", @@ -271,7 +277,9 @@ cc_library( ], deps = [ "@com_google_absl//absl/status", - ], + ] + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_rpath", + ]), ) cc_library( @@ -393,6 +401,7 @@ cc_library( "nobuilder", ], deps = [ + "@local_config_rocm//rocm:rocm_config", "@local_config_rocm//rocm:rocm_headers", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:path", From 1c175363743c52b2b1feee93036dacfe365da97c Mon Sep 17 00:00:00 2001 From: xla authors Date: Wed, 8 Jan 2025 06:04:47 -0800 Subject: [PATCH 33/45] Fix undefined behavior of mismatch in coordination service. `std::mismatch` should be called with an end iterator as the second argument if there is no guarantee on element count in the second range. PiperOrigin-RevId: 713264159 --- .../coordination/coordination_service.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/xla/tsl/distributed_runtime/coordination/coordination_service.cc index d6175c1c1d5488..9efc66bdac7a31 100644 --- a/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -1350,8 +1350,9 @@ std::vector CoordinationServiceStandaloneImpl::GetKeyValueDir( for (it = begin; it != kv_store_.end(); ++it) { // Stop once the next key does not have the directory prefix. Since keys are // ordered, none of the other keys would have a matching prefix. - if (std::mismatch(dir.begin(), dir.end(), it->first.begin()).first != - dir.end()) { + if (std::mismatch(dir.begin(), dir.end(), it->first.begin(), + it->first.end()) + .first != dir.end()) { break; } KeyValueEntry kv; @@ -1373,8 +1374,9 @@ absl::Status CoordinationServiceStandaloneImpl::DeleteKeyValue( auto begin = kv_store_.lower_bound(dir); std::map::iterator end; for (end = begin; end != kv_store_.end(); end++) { - if (std::mismatch(dir.begin(), dir.end(), end->first.begin()).first != - dir.end()) + if (std::mismatch(dir.begin(), dir.end(), end->first.begin(), + end->first.end()) + .first != dir.end()) break; } kv_store_.erase(begin, end); From 122cf084581c28b04bbf3d96af826e84e1e3d38a Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Wed, 8 Jan 2025 06:37:21 -0800 Subject: [PATCH 34/45] [xla:gpu] fix bug in counting good autotuner configs Move comparison of executable != nullptr _before_ calling std::move(executable). This is really only used for logging, but definitely adds confusion to the logs when it's always 0 :). PiperOrigin-RevId: 713272260 --- xla/service/gpu/autotuning/gemm_fusion_autotuner.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index ba6743ee4801a4..39f31c0dcf0b5f 100644 --- a/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -1016,11 +1016,11 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, << " with config '" << ConfigToString(config) << "'\nFused HLO computation:\n" << fusion->fused_instructions_computation()->ToString(); + log(*executable != nullptr); if (*executable != nullptr) { absl::MutexLock lock(&results_mu); results[fusion].push_back({config, std::move(*executable)}); } - log(*executable != nullptr); counter.DecrementCount(); }); } @@ -1047,10 +1047,10 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, TF_ASSIGN_OR_RETURN( std::unique_ptr executable, compile(fusion, config, gemm_config_set.size() > 1)); + log(executable != nullptr); if (executable != nullptr) { results[fusion].push_back({config, std::move(executable)}); } - log(executable != nullptr); } } } From b89b28f2da640a429cd405ef116d847c90c74ee1 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 8 Jan 2025 06:55:38 -0800 Subject: [PATCH 35/45] [pjrt] Removed unused CreateDeviceToHostChannelHandle, CreateChannelHandle and SupportsSendRecvCallbacks PiperOrigin-RevId: 713276521 --- xla/pjrt/cpu/cpu_client.h | 7 ------- xla/pjrt/pjrt_c_api_client.h | 16 ---------------- xla/pjrt/pjrt_client.h | 13 ------------- xla/pjrt/pjrt_stream_executor_client.h | 7 ------- xla/pjrt/tf_pjrt_client.h | 6 ------ 5 files changed, 49 deletions(-) diff --git a/xla/pjrt/cpu/cpu_client.h b/xla/pjrt/cpu/cpu_client.h index 2a1517a1b53fc4..e325e15e291373 100644 --- a/xla/pjrt/cpu/cpu_client.h +++ b/xla/pjrt/cpu/cpu_client.h @@ -202,13 +202,6 @@ class TfrtCpuClient final : public PjRtClient { std::function on_delete_callback, std::optional stream) override; - absl::StatusOr CreateChannelHandle() override { - return Unimplemented("CreateChannelHandle not implemented."); - } - absl::StatusOr CreateDeviceToHostChannelHandle() override { - return Unimplemented("CreateDeviceToHostChannelHandle not implemented."); - } - absl::Status Defragment() override { return Unimplemented("Defragment not implemented."); } diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h index 03e41ec3985903..fe98aa5ecce399 100644 --- a/xla/pjrt/pjrt_c_api_client.h +++ b/xla/pjrt/pjrt_c_api_client.h @@ -401,28 +401,12 @@ class PjRtCApiClient : public PjRtClient { "this feature."); } - absl::StatusOr CreateChannelHandle() override { - return Unimplemented( - "PJRT C API does not support CreateChannelHandle. Please report an " - "issue at https://github.com/google/jax/issues if you need this " - "feature."); - } - - absl::StatusOr CreateDeviceToHostChannelHandle() override { - return Unimplemented( - "PJRT C API does not support CreateDeviceToHostChannelHandle. Please " - "report an issue at https://github.com/google/jax/issues if you need " - "this feature."); - } - absl::Status Defragment() override { return Unimplemented( "PJRT C API does not support Defragment. Please report an issue at " "https://github.com/google/jax/issues if you need this feature."); } - bool SupportsSendRecvCallbacks() const override { return true; } - const PJRT_Api* pjrt_c_api() const; PJRT_Client* pjrt_c_client() { return c_client_.get(); } diff --git a/xla/pjrt/pjrt_client.h b/xla/pjrt/pjrt_client.h index 0b1da9ef4660a1..c0a07ae66d4e51 100644 --- a/xla/pjrt/pjrt_client.h +++ b/xla/pjrt/pjrt_client.h @@ -1070,25 +1070,12 @@ class PjRtClient { "MakeCrossHostReceiveBuffersForGather is not implemented."); } - // Create ChannelHandles for XLA send/recv. - virtual absl::StatusOr CreateChannelHandle() { - return Unimplemented("CreateChannelHandle is not implemented."); - } - virtual absl::StatusOr CreateDeviceToHostChannelHandle() { - return Unimplemented("CreateDeviceToHostChannelHandle is not implemented."); - } - // TODO(zhangqiaorjc): Experimental API to be removed. // Defragment device memory. virtual absl::Status Defragment() { return Unimplemented("Defragment is not implemented."); } - // If false, this client does not support send/recv host callbacks, and - // callers should not set the `send_callbacks` and `recv_callbacks` arguments - // in ExecuteOptions. - virtual bool SupportsSendRecvCallbacks() const { return false; } - // Return the PjRtHostMemoryForDeviceManager for this client. It can be // nullptr if the implementation does not provide one. virtual PjRtHostMemoryForDeviceManager* GetPjRtHostMemoryForDeviceManager() diff --git a/xla/pjrt/pjrt_stream_executor_client.h b/xla/pjrt/pjrt_stream_executor_client.h index 394777b07ff477..f753df6d6fcc29 100644 --- a/xla/pjrt/pjrt_stream_executor_client.h +++ b/xla/pjrt/pjrt_stream_executor_client.h @@ -394,13 +394,6 @@ class PjRtStreamExecutorClient : public PjRtClient { std::function on_delete_callback, std::optional stream) override; - absl::StatusOr CreateChannelHandle() override { - return client()->CreateChannelHandle(); - } - absl::StatusOr CreateDeviceToHostChannelHandle() override { - return client()->CreateDeviceToHostChannelHandle(); - } - // TODO(zhangqiaorjc): Experimental. Will be removed. absl::Status Defragment() override { return Unimplemented("Defragment not implemented"); diff --git a/xla/pjrt/tf_pjrt_client.h b/xla/pjrt/tf_pjrt_client.h index 8933a2482c8683..49b8d5db5e92ec 100644 --- a/xla/pjrt/tf_pjrt_client.h +++ b/xla/pjrt/tf_pjrt_client.h @@ -340,12 +340,6 @@ class TfPjRtClient : public PjRtClient { return wrapped_->MakeCrossHostReceiveBuffersForGather( shapes, std::move(gather_details), device, std::move(notifier)); } - absl::StatusOr CreateChannelHandle() override { - return wrapped_->CreateChannelHandle(); - } - absl::StatusOr CreateDeviceToHostChannelHandle() override { - return wrapped_->CreateDeviceToHostChannelHandle(); - } absl::StatusOr GetTopologyDescription() const override { return wrapped_->GetTopologyDescription(); From 0ffff6cd4fdfe6bc10a7e5bd8f46899349a210da Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 8 Jan 2025 06:57:59 -0800 Subject: [PATCH 36/45] [pjrt] Removed unused prefer_to_retain_reference argument from RecordUsage It was always set to false by the callers. PiperOrigin-RevId: 713277020 --- xla/pjrt/pjrt_stream_executor_client.cc | 67 +++---------------------- 1 file changed, 8 insertions(+), 59 deletions(-) diff --git a/xla/pjrt/pjrt_stream_executor_client.cc b/xla/pjrt/pjrt_stream_executor_client.cc index 39b0d9740afc99..35a8267ae14868 100644 --- a/xla/pjrt/pjrt_stream_executor_client.cc +++ b/xla/pjrt/pjrt_stream_executor_client.cc @@ -347,32 +347,11 @@ void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) { // after the usage of device_buffer was enqueued. // usage_stream: the stream the operation using device_buffer // was enqueued on. -// prefer_to_retain_reference: relevant only for the compute synchronous -// allocation model. If true, retain a reference -// to device_buffer until after the operation -// completes. If false then the compute stream -// will have to be synchronized past event before -// device_buffer can be freed. -// -// prefer_to_retain_reference encodes a heuristic set by the caller for the -// compute synchronous model: -// -// Generally when a buffer is the destination of a copy to a device, it will -// subsequently be used on the device's compute stream before being freed. In -// that case, there is no need to retain a reference to the buffer. If the -// buffer is freed before being used on the compute stream, the free will be -// delayed until the host knows that event has completed, but this is expected -// to be uncommon. -// -// When a buffer is the source of a copy from a device, we need to either retain -// a reference to the buffer until the copy completes or serialize the compute -// stream behind the copy. It is often better to retain a reference since while -// that keeps memory alive longer, it avoids stalling the compute stream. void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer, LocalDeviceState* buffer_local_device, LocalDeviceState* stream_local_device, std::shared_ptr event, - se::Stream* usage_stream, bool prefer_to_retain_reference, + se::Stream* usage_stream, std::vector>* buffers_to_release = nullptr) { tsl::profiler::TraceMe traceme("RecordUsage"); @@ -382,11 +361,7 @@ void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer, (stream_local_device != buffer_local_device) || // In the synchronous allocation model, always retain a reference. (stream_local_device->allocation_model() == - LocalDeviceState::kSynchronous) || - // In the compute synchronous model, use the caller's heuristic. - (stream_local_device->allocation_model() == - LocalDeviceState::kComputeSynchronized && - prefer_to_retain_reference); + LocalDeviceState::kSynchronous); if (retain_buffer_until_completion) { if (buffers_to_release) { buffers_to_release->push_back(device_buffer.buffer()); @@ -415,15 +390,8 @@ absl::Status AddDestinationBufferSynchronization( } definition_event->SetSequencingEvent(std::move(event_or).value(), copy_stream); - // prefer_to_retain_reference=false means don't retain a memory reference - // until the transfer is complete when using the ComputeSynchronized - // allocation model. This is a heuristic because in the common case - // destination buffers will be used on the compute stream and therefore don't - // require any synchronization before being freed. If the buffer is allocated - // and never used, the free will take longer and this is assumed to be ok. RecordUsage(std::move(device_buffer), local_device, local_device, - definition_event, copy_stream, - /*prefer_to_retain_reference=*/false); + definition_event, copy_stream); return absl::OkStatus(); } @@ -583,16 +551,9 @@ AllocateDestinationBuffer( if (on_device_shape.IsTuple()) { // Add a usage hold for the tuple table write and immediately convert it to - // the appropriate form of synchronization. prefer_to_retain_reference=false - // means don't retain a memory reference until the transfer is complete when - // using the ComputeSynchronized allocation model. This is a heuristic - // because in the common case destination buffers will be used on the - // compute stream and therefore don't require any synchronization before - // being freed. If the buffer is allocated and never used, the free will - // take longer and this is assumed to be ok. + // the appropriate form of synchronization. RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device, - definition_events.back(), tuple_table_stream, - /*prefer_to_retain_reference=*/false); + definition_events.back(), tuple_table_stream); } return py_buffer; @@ -1954,8 +1915,7 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper( std::move(async_copy_to_device)); RecordUsage(std::move(dst_device_buffer), transfer_local_device, - transfer_local_device, copy_event, transfer_stream, - /*prefer_to_retain_reference=*/false); + transfer_local_device, copy_event, transfer_stream); return std::pair, std::shared_ptr>( @@ -2039,12 +1999,6 @@ PjRtStreamExecutorBuffer::CopyToDeviceMemorySpace( std::unique_ptr& buffer = buffer_and_event.first; std::shared_ptr& event = buffer_and_event.second; - // prefer_to_retain_reference=*/true means that, when using the - // ComputeSynchronized allocation model, retain a reference to the - // src_device_buffer until the copy completes. This is a heuristic; the - // alternative is to ensure, before freeing the buffer, that the compute - // stream is synchronized past the transfer, but it seems better to hold onto - // the buffer too long than to stall the compute stream. src_device_buffer.ConvertUsageHold(transfer_stream, event, /*reference_held=*/true); @@ -2340,7 +2294,7 @@ absl::StatusOr> OutputBufferHelper( memory_space); RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device, definition_event, local_device->compute_stream(), - /*prefer_to_retain_reference=*/false, &buffers_to_release); + &buffers_to_release); return std::unique_ptr(std::move(pjrt_buffer)); } @@ -3118,14 +3072,9 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper( buffers_to_release)); for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) { - // prefer_to_retain_reference=false because when using the - // ComputeSynchronized allocation model we don't need to retain a reference - // to the device_buffer during execution because by definition the compute - // stream is synchronized past the execution. if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) { RecordUsage(std::move(b), device_state, device_state, definition_event, - stream, - /*prefer_to_retain_reference=*/false, &buffers_to_release); + stream, &buffers_to_release); } else { CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation); b.ConfirmDonation(); From 9e1504a532194c0f94186b3c257289e4e82b4de4 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Wed, 8 Jan 2025 07:01:12 -0800 Subject: [PATCH 37/45] #sdy use `applyPatternsGreedily` with `config.fold=false` and `config.cseConstants=false` to avoid constant folding and CSE which is expensive. PiperOrigin-RevId: 713277781 --- xla/service/spmd/shardy/round_trip_common/BUILD | 1 + .../spmd/shardy/round_trip_common/pipeline_passes.cc | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/xla/service/spmd/shardy/round_trip_common/BUILD b/xla/service/spmd/shardy/round_trip_common/BUILD index 48fb0862daa5ff..b3ab4176a0be73 100644 --- a/xla/service/spmd/shardy/round_trip_common/BUILD +++ b/xla/service/spmd/shardy/round_trip_common/BUILD @@ -110,6 +110,7 @@ cc_library( "//xla/mlir_hlo:mhlo_passes", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) diff --git a/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc b/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc index c4d7a13a55bb99..1438d40cf61fc8 100644 --- a/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc +++ b/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc @@ -17,6 +17,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.h" @@ -48,7 +49,13 @@ void addCommonPreImportPasses(mlir::OpPassManager& pm) { // We need to canonicalize redundant mhlo::GetTupleElementOp and // mhlo::GetTupleOp. We also need to canonicalize mhlo::WhileOp before // `createOpenWhileFreeVarsShardingPass`. - pm.addPass(mlir::createCanonicalizerPass()); + mlir::GreedyRewriteConfig config; + config.useTopDownTraversal = true; + config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; + config.fold = false; + config.cseConstants = false; + // TODO(tomnatan): consider only enabling the specific passes we need. + pm.addPass(mlir::createCanonicalizerPass(config)); // Shardy is currently operating on stablehlo, since this is what JAX // emits. Long term shardy will be fully dialect agnostic, and both mhlo // and stablehlo can register their ops for sdy propagation. From 9eb2755078626330beac74ac5d2d8aa099f2c1e1 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Wed, 8 Jan 2025 07:17:32 -0800 Subject: [PATCH 38/45] Moving AtomicRMW utilities out of lower_tensors. These are going to also be used in vectorizing AtomicRMW in follow-up changes. PiperOrigin-RevId: 713281944 --- xla/backends/gpu/codegen/transforms/BUILD | 1 + .../codegen/transforms/atomic_rmw_utils.cc | 120 ++++++++++++++++++ .../gpu/codegen/transforms/lower_tensors.cc | 65 ---------- xla/backends/gpu/codegen/transforms/passes.h | 8 +- 4 files changed, 128 insertions(+), 66 deletions(-) create mode 100644 xla/backends/gpu/codegen/transforms/atomic_rmw_utils.cc diff --git a/xla/backends/gpu/codegen/transforms/BUILD b/xla/backends/gpu/codegen/transforms/BUILD index 3894a53825be86..090cf3d26325ab 100644 --- a/xla/backends/gpu/codegen/transforms/BUILD +++ b/xla/backends/gpu/codegen/transforms/BUILD @@ -38,6 +38,7 @@ gentbl_cc_library( cc_library( name = "passes", srcs = [ + "atomic_rmw_utils.cc", "convert_float_nvidia.cc", "convert_xla_gpu_pure_call_ops.cc", "erase_dead_functions.cc", diff --git a/xla/backends/gpu/codegen/transforms/atomic_rmw_utils.cc b/xla/backends/gpu/codegen/transforms/atomic_rmw_utils.cc new file mode 100644 index 00000000000000..ad1c769447e012 --- /dev/null +++ b/xla/backends/gpu/codegen/transforms/atomic_rmw_utils.cc @@ -0,0 +1,120 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/ilist.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/UseDefLists.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" +#include "xla/codegen/ir/xla_ops.h" + +namespace xla { +namespace gpu { + +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" + +using mlir::Operation; +using mlir::Type; +using mlir::Value; + +namespace ml = ::mlir::LLVM; +namespace arith = ::mlir::arith; + +bool IsAtomicIntegral(Type element_type) { + if (!element_type.isInteger()) { + return false; + } + unsigned element_bitwidth = element_type.getIntOrFloatBitWidth(); + return element_bitwidth == 32 || element_bitwidth == 64; +} + +std::optional GetAtomicBinOp(Operation* modifier_op, + Type element_type) { + return llvm::TypeSwitch>( + modifier_op) + // Floating-point operations. + .Case([](arith::AddFOp op) { return ml::AtomicBinOp::fadd; }) + .Case([](arith::MaximumFOp op) { return ml::AtomicBinOp::fmax; }) + .Case([](arith::MinimumFOp op) { return ml::AtomicBinOp::fmin; }) + // Integer operations. + .Case([&](arith::AddIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::add) + : std::nullopt; + }) + .Case([&](arith::MaxUIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::umax) + : std::nullopt; + }) + .Case([&](arith::MinUIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::umin) + : std::nullopt; + }) + .Case([&](arith::MaxSIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::max) + : std::nullopt; + }) + .Case([&](arith::MinSIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::min) + : std::nullopt; + }) + .Default([](Operation* op) { return std::nullopt; }); +} + +// Returns atomic op modifier and the atomic bin op kind. +std::optional> GetAtomicModifierParameters( + AtomicRMWOp op) { + Type element_type = op.getInput().getType().getElementType(); + auto& operations = op.getBody()->getOperations(); + auto terminator = op.getBody()->getTerminator(); + if (operations.size() > 2) { + return std::nullopt; + } + // If the body contains only the terminator, then it is an atomic store. + if (operations.size() == 1) { + // TODO(b/336367145): Support complex atomic store. + if (element_type.isF32() || IsAtomicIntegral(element_type)) { + return std::make_pair(terminator->getOperand(0), ml::AtomicBinOp::xchg); + } + return std::nullopt; + } + // Match the kind of the atomic op. + mlir::Operation* modifier_op = &operations.front(); + auto kind = GetAtomicBinOp(modifier_op, element_type); + if (!kind.has_value()) { + return std::nullopt; + } + // Find the modifier arg that does not match the argument of `atomic_rmw` + // body. + Value block_arg = op.getBody()->getArgument(0); + Value modifier_arg = modifier_op->getOperand(0) == block_arg + ? modifier_op->getOperand(1) + : modifier_op->getOperand(0); + return std::make_pair(modifier_arg, *kind); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/backends/gpu/codegen/transforms/lower_tensors.cc b/xla/backends/gpu/codegen/transforms/lower_tensors.cc index 822ba8498800eb..0fff3bc811bbca 100644 --- a/xla/backends/gpu/codegen/transforms/lower_tensors.cc +++ b/xla/backends/gpu/codegen/transforms/lower_tensors.cc @@ -755,71 +755,6 @@ class RewriteAtomicRMW : public OpRewritePattern { } private: - // Returns atomic op modifier and the atomic bin op kind. - std::optional> GetAtomicModifierParameters( - AtomicRMWOp op) const { - Type element_type = op.getInput().getType().getElementType(); - auto& operations = op.getBody()->getOperations(); - auto terminator = op.getBody()->getTerminator(); - if (operations.size() > 2) { - return std::nullopt; - } - // If the body contains only the terminator, then it is an atomic store. - if (operations.size() == 1) { - // TODO(b/336367145): Support complex atomic store. - if (element_type.isF32() || IsAtomicIntegral(element_type)) { - return std::make_pair(terminator->getOperand(0), ml::AtomicBinOp::xchg); - } - return std::nullopt; - } - // Match the kind of the atomic op. - mlir::Operation* modifier_op = &operations.front(); - std::optional kind = - llvm::TypeSwitch>( - modifier_op) - // Floating-point operations. - .Case([](arith::AddFOp op) { return ml::AtomicBinOp::fadd; }) - .Case([](arith::MaximumFOp op) { return ml::AtomicBinOp::fmax; }) - .Case([](arith::MinimumFOp op) { return ml::AtomicBinOp::fmin; }) - // Integer operations. - .Case([&](arith::AddIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::add) - : std::nullopt; - }) - .Case([&](arith::MaxUIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::umax) - : std::nullopt; - }) - .Case([&](arith::MinUIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::umin) - : std::nullopt; - }) - .Case([&](arith::MaxSIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::max) - : std::nullopt; - }) - .Case([&](arith::MinSIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::min) - : std::nullopt; - }) - .Default([](Operation* op) { return std::nullopt; }); - if (!kind.has_value()) { - return std::nullopt; - } - // Find the modifier arg that does not match the argument of `atomic_rmw` - // body. - Value block_arg = op.getBody()->getArgument(0); - Value modifier_arg = modifier_op->getOperand(0) == block_arg - ? modifier_op->getOperand(1) - : modifier_op->getOperand(0); - return std::make_pair(modifier_arg, *kind); - } - // Certain computations, such as floating-point addition and integer // maximization, can be simply implemented using an LLVM atomic instruction. // If "computation" is one of this kind, emits code to do that and returns diff --git a/xla/backends/gpu/codegen/transforms/passes.h b/xla/backends/gpu/codegen/transforms/passes.h index db6f75779b93b1..98b6963a18148c 100644 --- a/xla/backends/gpu/codegen/transforms/passes.h +++ b/xla/backends/gpu/codegen/transforms/passes.h @@ -19,10 +19,12 @@ limitations under the License. #include #include #include +#include +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "xla/hlo/analysis/indexing_map.h" +#include "xla/codegen/ir/xla_ops.h" #include "xla/stream_executor/device_description.h" namespace xla { @@ -31,6 +33,10 @@ namespace gpu { #define GEN_PASS_DECL #include "xla/backends/gpu/codegen/transforms/passes.h.inc" +// Returns atomic op modifier and the atomic bin op kind. +std::optional> +GetAtomicModifierParameters(AtomicRMWOp op); + std::unique_ptr CreateConvertFloatNvidiaPass(); std::optional> MaybeCreateConvertFloatNvidiaPass( const se::DeviceDescription& device_description); From 22aad7facf5331b70a696808f9cc587acd047058 Mon Sep 17 00:00:00 2001 From: xla authors Date: Wed, 8 Jan 2025 07:18:49 -0800 Subject: [PATCH 39/45] [XLA:CPU] Remove no thunks tests for exhaustive_binary_test PiperOrigin-RevId: 713282226 --- xla/tests/exhaustive/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/xla/tests/exhaustive/BUILD b/xla/tests/exhaustive/BUILD index 0b0c52554e9a5f..9b60d841dd4808 100644 --- a/xla/tests/exhaustive/BUILD +++ b/xla/tests/exhaustive/BUILD @@ -250,7 +250,6 @@ exhaustive_xla_test( shard_count = 50, tags = [ "optonly", - "test_xla_cpu_no_thunks", # This is a big test that we skip for capacity reasons in OSS testing. "no_oss", ], From cee7d8ee256186dcd12dbafc7e596d2f1012e4f3 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 8 Jan 2025 08:12:48 -0800 Subject: [PATCH 40/45] [XLA:GPU] Fix sorted scatter with imperfectly tiled indices. The algorithm was checking whether to write to the output or not by comparing the current slice index with the number of indices per warp. It works only when we have perfectly tiled indices, e.g. 50 indices per warp with a total of 2000 indices. As soon as we have 2001 indices, the last warp processes 1 update slice, but never writes it down. Also simplified the logic for the update loop that accumulates elements in registers. Instead of having scf.if inside of xla.loop, now we have two different xla.loops in different cases of scf.if, that either overwrite the accumulator or combine it with the new data. PiperOrigin-RevId: 713296321 --- xla/service/gpu/fusions/scatter_mlir.cc | 160 ++++++++++-------- xla/service/gpu/fusions/scatter_mlir.h | 40 +++-- .../fusions/tests/scatter/sorted_indices.hlo | 8 +- 3 files changed, 118 insertions(+), 90 deletions(-) diff --git a/xla/service/gpu/fusions/scatter_mlir.cc b/xla/service/gpu/fusions/scatter_mlir.cc index 5163375e38cdb0..4f98d4bfd61dcd 100644 --- a/xla/service/gpu/fusions/scatter_mlir.cc +++ b/xla/service/gpu/fusions/scatter_mlir.cc @@ -301,8 +301,8 @@ class EmitterHelper { Value write_to_output_required, ValueRange thread_and_block_ids, Value iv, const IndexingMap& slice_indexing, - Value offsets_changed, ValueRange offsets, - Value accumulator, Value output_tensor) const; + ValueRange offsets, Value accumulator, + Value output_tensor) const; private: Value GetElement(ImplicitLocOpBuilder& b, int operand_index, @@ -371,8 +371,8 @@ SmallVector EmitterHelper::WriteAccumulatedElementToOutput( Value EmitterHelper::WriteAccumulatorToOutput( ImplicitLocOpBuilder& b, Value write_to_output_required, ValueRange thread_and_block_ids, Value iv, - const IndexingMap& slice_indexing, Value offsets_changed, - ValueRange offsets, Value accumulator, Value output_tensor) const { + const IndexingMap& slice_indexing, ValueRange offsets, Value accumulator, + Value output_tensor) const { SmallVector dims = Pack({thread_and_block_ids, iv}); return EmitUpdateIf( b, write_to_output_required, output_tensor, @@ -721,11 +721,15 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl( // Prepare loop initial values. Inits are packed as // [index_changed, is_inbounds, index_0, ..., accumulator]. Value is_inbounds_init = b.create(0, b.getI1Type()); + Value slice_id_init = b.create(0); std::vector indices_init(description_.index_vector_length, b.create(-1)); Value accumulator_init = InitializeAccumulator(b); SmallVector inits = - Pack({indices_init, is_inbounds_init, accumulator_init, output_tensor}); + Pack({slice_id_init, indices_init, is_inbounds_init, accumulator_init, + output_tensor}); + + int64_t output_rank = description_.output_shape.size(); auto loop_over_indices_fn = [&](ImplicitLocOpBuilder& nested_b, ValueRange ivs, @@ -733,14 +737,13 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl( ValueRange outer_iter_args) -> SmallVector { // Unpack the iter_args. SmallVector iter_args_unpack = - Unpack(outer_iter_args, {description_.index_vector_length, 1, 1, 1}); - ValueRange trimmed_offsets = iter_args_unpack[0]; - Value iter_is_inbounds = iter_args_unpack[1].front(); - Value iter_acc = iter_args_unpack[2].front(); - Value iter_output = iter_args_unpack[3].front(); + Unpack(outer_iter_args, {1, description_.index_vector_length, 1, 1, 1}); + ValueRange trimmed_offsets = iter_args_unpack[1]; + Value iter_is_inbounds = iter_args_unpack[2].front(); + Value iter_acc = iter_args_unpack[3].front(); + Value iter_output = iter_args_unpack[4].front(); Value iter_slice_id = ivs.front(); - int64_t output_rank = description_.output_shape.size(); SmallVector offsets = PadWithZeros(trimmed_offsets, output_rank, nested_b); @@ -767,78 +770,95 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl( b.create(offsets_changed, iter_is_inbounds)); iter_output = helper.WriteAccumulatorToOutput( b, write_to_output_required, thread_and_block_ids, iter_slice_id, - slice_indexing, offsets_changed, offsets, iter_acc, iter_output); + slice_indexing, offsets, iter_acc, iter_output); // Update `is_inbounds` if the offsets changed. Value new_is_inbounds = UpdateIsInbounds( nested_b, iter_is_inbounds, offsets_changed, new_offsets, description_.slice_shape, description_.output_shape); - // Update accumulator and/or output. - auto is_last_iteration = nested_b.create( - arith::CmpIPredicate::eq, iter_slice_id, - b.create(num_indices_per_warp_ - 1)); - - SmallVector acc_and_output = {iter_acc, iter_output}; - auto loop_over_slices_fn = - [&](ImplicitLocOpBuilder& update_loop_b, ValueRange accumulator_indices, - ValueRange slice_indices, - ValueRange inner_iter_args) -> SmallVector { - Value acc_arg = inner_iter_args.front(); - Value output_arg = inner_iter_args.back(); - auto update_elem = helper.GetUpdateElement(update_loop_b, slice_indices); - auto acc_ind_opfold = mlir::getAsOpFoldResult(accumulator_indices); - // If the index changed, overwrite the accumulator element, otherwise - // apply the scatter computation to reduce with the accumulator element. - auto updated_accumulator = - update_loop_b - .create( - offsets_changed, - [&](OpBuilder& then_b, Location then_loc) -> void { - Value updated_accumulator = then_b.create( - then_loc, update_elem, acc_arg, acc_ind_opfold); - then_b.create(then_loc, updated_accumulator); - }, - [&](OpBuilder& else_b, Location else_loc) -> void { - ImplicitLocOpBuilder implicit_else_b(else_loc, else_b); - Value accumulator_elem = - implicit_else_b.create( - acc_arg, acc_ind_opfold); - auto reduced_val = mlir_converter::InlineBlock( - implicit_else_b, helper.GetReducer().getBody().front(), - {accumulator_elem, update_elem})[0]; - Value updated_ac = implicit_else_b.create( - reduced_val, acc_arg, acc_ind_opfold); - implicit_else_b.create(updated_ac); - }) - .getResult(0); - // If this is the last index, that this warp has to process, then we write - // to the output. - auto updated_output = - EmitUpdateIf(update_loop_b, is_last_iteration, output_arg, - [&](ImplicitLocOpBuilder& nested_b) { - return helper.WriteAccumulatedElementToOutput( - nested_b, updated_accumulator, accumulator_indices, - slice_indices, new_offsets, output_arg); - }) - .front(); - return {updated_accumulator, updated_output}; + // Emits a loop that overwrites the accumulator with the new update elements + // if the offsets changed. + auto emit_overwrite_accumulator_fn = [&](OpBuilder& then_b, + Location then_loc) -> void { + ImplicitLocOpBuilder implicit_then_b(then_loc, then_b); + auto then_results = EmitXlaLoopOp( + implicit_then_b, Pack({thread_and_block_ids, iter_slice_id}), + {iter_acc}, slice_indexing, + [&](ImplicitLocOpBuilder& update_loop_b, + ValueRange accumulator_indices, ValueRange slice_indices, + ValueRange inner_iter_args) -> SmallVector { + Value acc_arg = inner_iter_args.front(); + auto update_elem = + helper.GetUpdateElement(update_loop_b, slice_indices); + auto acc_ind_opfold = mlir::getAsOpFoldResult(accumulator_indices); + return update_loop_b + .create(then_loc, update_elem, acc_arg, + acc_ind_opfold) + ->getResults(); + }); + implicit_then_b.create(then_loc, then_results); + }; + // Emits a loop that combines the accumulator with the new update elements + // if the offsets did not change. + auto emit_combine_accumulator_fn = [&](OpBuilder& else_b, + Location else_loc) -> void { + ImplicitLocOpBuilder implicit_else_b(else_loc, else_b); + auto else_results = EmitXlaLoopOp( + implicit_else_b, Pack({thread_and_block_ids, iter_slice_id}), + {iter_acc}, slice_indexing, + [&](ImplicitLocOpBuilder& update_loop_b, + ValueRange accumulator_indices, ValueRange slice_indices, + ValueRange inner_iter_args) -> SmallVector { + Value acc_arg = inner_iter_args.front(); + auto update_elem = + helper.GetUpdateElement(update_loop_b, slice_indices); + auto acc_ind_opfold = mlir::getAsOpFoldResult(accumulator_indices); + Value accumulator_elem = update_loop_b.create( + acc_arg, acc_ind_opfold); + auto reduced_val = mlir_converter::InlineBlock( + update_loop_b, helper.GetReducer().getBody().front(), + {accumulator_elem, update_elem})[0]; + return update_loop_b + .create(reduced_val, acc_arg, acc_ind_opfold) + ->getResults(); + }); + implicit_else_b.create(else_results); }; - auto updated_accumulator_and_output = - EmitUpdateIf(nested_b, new_is_inbounds, acc_and_output, + auto updated_accumulator = + EmitUpdateIf(nested_b, new_is_inbounds, {iter_acc}, [&](ImplicitLocOpBuilder& if_b) { - return EmitXlaLoopOp( - if_b, Pack({thread_and_block_ids, iter_slice_id}), - acc_and_output, slice_indexing, loop_over_slices_fn); - }); - SmallVector updated_if_loop_results = Pack( - {new_trimmed_offsets, new_is_inbounds, updated_accumulator_and_output}); + return nested_b + .create(offsets_changed, + emit_overwrite_accumulator_fn, + emit_combine_accumulator_fn) + .getResults(); + }) + .front(); + SmallVector updated_if_loop_results = + Pack({iter_slice_id, new_trimmed_offsets, new_is_inbounds, + updated_accumulator, iter_output}); return updated_if_loop_results; }; auto loop_over_indices_results = EmitXlaLoopOp(b, thread_and_block_ids, inits, thread_id_to_update_id_map, loop_over_indices_fn); - b.create(loop_over_indices_results.back()); + + // Write the accumulator to the output tensor. + SmallVector loop_over_indices_results_unpacked = + Unpack(loop_over_indices_results, + {1, description_.index_vector_length, 1, 1, 1}); + Value result_slice_id = loop_over_indices_results_unpacked[0].front(); + auto result_offsets = + PadWithZeros(loop_over_indices_results_unpacked[1], output_rank, b); + Value result_is_inbounds = loop_over_indices_results_unpacked[2].front(); + Value result_acc = loop_over_indices_results_unpacked[3].front(); + Value result_output = loop_over_indices_results_unpacked[4].front(); + result_output = helper.WriteAccumulatorToOutput( + b, result_is_inbounds, thread_and_block_ids, result_slice_id, + slice_indexing, result_offsets, result_acc, result_output); + + b.create(result_output); return absl::OkStatus(); } diff --git a/xla/service/gpu/fusions/scatter_mlir.h b/xla/service/gpu/fusions/scatter_mlir.h index 6b555c17c0490c..676123d74b11a2 100644 --- a/xla/service/gpu/fusions/scatter_mlir.h +++ b/xla/service/gpu/fusions/scatter_mlir.h @@ -147,28 +147,36 @@ class ScatterWithDistributedUpdates : public MlirScatterFusion { %acc = vector // #indices_map - for %i = 0 to %num_indices_per_warp_ step 1 { - %new_indices = PadWithZeros(ExtractOffsets(%indices_operand, %i)) - %indices_changed = EmitInequalityCheck(%new_indices, %indices) - if (%indices_changed && %i != 0) { - %output_tensor = WriteAccumulatorToTheOutput(%acc, %output_tensor); - } - if (%indices_changed) { - %inbounds = EmitBoundsCheck(%new_indices, %slice_shape, %output_shape) - } - if (%inbounds) { + %updated_accumulator, %updated_out = for %i = 0 to %num_indices_per_warp_ { + %new_indices = PadWithZeros(ExtractOffsets(%indices_operand, %i)) + %indices_changed = EmitInequalityCheck(%new_indices, %indices) + if (%indices_changed && %i != 0) { + %output_tensor = WriteAccumulatorToOutput(%current_acc, %current_out); + } + if (%indices_changed) { + %inbounds = EmitBoundsCheck(%new_indices, %slice_shape, %output_shape) + } + if (%inbounds) { + if (%indices_changed) { // updates_map(%i) for %j = 0 to %num_slice_iterations_per_warp step 1 { for %k = 0 to %vector_size step 1 { %update_elem = GetUpdateElement - %acc = %indices_changed ? %update_elem : Reduce(%update_elem, %acc) - if (%i = %num_indices_per_warp - 1) { - %output_tensor = WriteAccumulatorToTheOutput(%acc, %output_tensor); - } + %acc = %update_elem } } - } - } + } else { + // updates_map(%i) + for %j = 0 to %num_slice_iterations_per_warp step 1 { + for %k = 0 to %vector_size step 1 { + %update_elem = GetUpdateElement + %acc = Reduce(%update_elem, %acc) + } + } + } + } +} +%final_out = WriteAccumulatorToOutput(%updated_accumulator, %updated_out); */ class ScatterWithDistributedIndices : public MlirScatterFusion { public: diff --git a/xla/service/gpu/fusions/tests/scatter/sorted_indices.hlo b/xla/service/gpu/fusions/tests/scatter/sorted_indices.hlo index 69fdf05c86cd3e..332eb543af61b0 100644 --- a/xla/service/gpu/fusions/tests/scatter/sorted_indices.hlo +++ b/xla/service/gpu/fusions/tests/scatter/sorted_indices.hlo @@ -9,13 +9,13 @@ add { } scatter { %operand = f32[100] parameter(0) - %indices = s32[2000,1] parameter(1) - %update = f32[2000,32] parameter(2) + %indices = s32[2001,1] parameter(1) + %update = f32[2001,32] parameter(2) ROOT %scatter = f32[100] scatter( f32[100] %operand, - s32[2000,1] %indices, - f32[2000,32] %update + s32[2001,1] %indices, + f32[2001,32] %update ), update_window_dims={1}, inserted_window_dims={}, From d83a3159a630163ada28891c2b51c3c042024979 Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Wed, 8 Jan 2025 08:17:25 -0800 Subject: [PATCH 41/45] Passing device information to Vectorization pass. This will be needed when adding vectorization for AtomicRMW which will only be available for Hopper. PiperOrigin-RevId: 713297711 --- xla/backends/gpu/codegen/transforms/passes.h | 5 ++- xla/backends/gpu/codegen/transforms/passes.td | 5 +++ .../tests/vectorize_loads_stores.mlir | 3 +- .../transforms/vectorize_loads_stores.cc | 43 +++++++++++++++---- .../gpu/fusions/mlir/mlir_fusion_emitter.cc | 2 +- 5 files changed, 47 insertions(+), 11 deletions(-) diff --git a/xla/backends/gpu/codegen/transforms/passes.h b/xla/backends/gpu/codegen/transforms/passes.h index 98b6963a18148c..de12227f94c0cf 100644 --- a/xla/backends/gpu/codegen/transforms/passes.h +++ b/xla/backends/gpu/codegen/transforms/passes.h @@ -62,7 +62,10 @@ std::unique_ptr CreatePropagateSliceIndicesPass(); std::unique_ptr CreateSimplifyAffinePass(); std::unique_ptr CreateSimplifyArithPass(); std::unique_ptr CreateUnswitchLoopsPass(); -std::unique_ptr CreateVectorizeLoadsAndStoresPass(); +std::unique_ptr CreateVectorizeLoadsAndStoresPass( + const std::string& gpu_device_info = ""); +std::unique_ptr CreateVectorizeLoadsAndStoresPass( + const se::DeviceDescription& device_description); #define GEN_PASS_REGISTRATION #include "xla/backends/gpu/codegen/transforms/passes.h.inc" diff --git a/xla/backends/gpu/codegen/transforms/passes.td b/xla/backends/gpu/codegen/transforms/passes.td index 1b5ffbdb24636e..53b20387c62aad 100644 --- a/xla/backends/gpu/codegen/transforms/passes.td +++ b/xla/backends/gpu/codegen/transforms/passes.td @@ -256,6 +256,11 @@ def VectorizeLoadsAndStoresPass : "mlir::vector::VectorDialect", ]; + let options = [ + Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"", + "Serialized stream_executor::GPUDeviceInfo proto.">, + ]; + let constructor = "CreateVectorizeLoadsAndStoresPass()"; } diff --git a/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir b/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir index a3b7e816bb05fb..3f04219d0eeb17 100644 --- a/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir +++ b/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir @@ -1,5 +1,6 @@ // RUN: emitters_opt -allow-unregistered-dialect %s -split-input-file \ -// RUN: -xla-gpu-vectorize-loads-stores -cse -canonicalize | FileCheck %s +// RUN: -xla-gpu-vectorize-loads-stores="gpu_device_info='cuda_compute_capability {major: 6}'" -cse -canonicalize \ +// RUN: | FileCheck %s #map = #xla.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," "domain: d0 in [0, 63], s0 in [0, 1]"> diff --git a/xla/backends/gpu/codegen/transforms/vectorize_loads_stores.cc b/xla/backends/gpu/codegen/transforms/vectorize_loads_stores.cc index 8202ae05e8d076..19e6b7faf5e36a 100644 --- a/xla/backends/gpu/codegen/transforms/vectorize_loads_stores.cc +++ b/xla/backends/gpu/codegen/transforms/vectorize_loads_stores.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include "llvm/ADT/APInt.h" @@ -40,7 +41,9 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" +#include "xla/codegen/ir/xla_ops.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -326,21 +329,45 @@ class VectorizeLoadsAndStoresPass : public impl::VectorizeLoadsAndStoresPassBase< VectorizeLoadsAndStoresPass> { public: + explicit VectorizeLoadsAndStoresPass( + const VectorizeLoadsAndStoresPassOptions& options) + : VectorizeLoadsAndStoresPassBase(options) {} + + explicit VectorizeLoadsAndStoresPass( + const se::DeviceDescription& device_description) + : device_description_(device_description) {} + void runOnOperation() override { - mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (!gpu_device_info_.empty()) { + se::GpuDeviceInfoProto device_info; + CHECK(tsl::protobuf::TextFormat::ParseFromString(gpu_device_info_, + &device_info)); + device_description_ = se::DeviceDescription(device_info); + } + mlir::MLIRContext* mlir_context = &getContext(); + mlir::RewritePatternSet patterns(mlir_context); + patterns.add(mlir_context); + if (mlir::failed( + mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } + + se::DeviceDescription device_description_; }; } // namespace -std::unique_ptr> -CreateVectorizeLoadsAndStoresPass() { - return std::make_unique(); +std::unique_ptr<::mlir::Pass> CreateVectorizeLoadsAndStoresPass( + const std::string& gpu_device_info) { + VectorizeLoadsAndStoresPassOptions options; + options.gpu_device_info_ = gpu_device_info; + return std::make_unique(options); +} + +std::unique_ptr CreateVectorizeLoadsAndStoresPass( + const se::DeviceDescription& device_description) { + return std::make_unique(device_description); } } // namespace gpu diff --git a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc index f859c70af94053..17d79786b802b9 100644 --- a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc +++ b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -608,7 +608,7 @@ void AddLoopTransformationPasses(mlir::OpPassManager& pm, // opportunities for LICM. This would not be necessary if LICM also moved // instructions over ifs. pm.addPass(mlir::createLoopInvariantCodeMotionPass()); - pm.addNestedPass(CreateVectorizeLoadsAndStoresPass()); + pm.addNestedPass(CreateVectorizeLoadsAndStoresPass(device)); pm.addNestedPass(CreateOptimizeLoopsPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); From 61fcf4f4e224364604a8524f0fee32295400f091 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 8 Jan 2025 08:47:16 -0800 Subject: [PATCH 42/45] [xla:cpu] Add CpuClique to XLA:CPU collectives and use generic collectives APIs to acquire communicator in CollectiveThunk Implement Cliques support for XLA:CPU collectives for consistency with XLA:GPU. Further unification will be in followup CLs. PiperOrigin-RevId: 713305764 --- xla/backends/cpu/collectives/BUILD | 56 ++++++++ xla/backends/cpu/collectives/cpu_clique.cc | 59 +++++++++ xla/backends/cpu/collectives/cpu_clique.h | 42 ++++++ .../cpu/collectives/cpu_clique_key.cc | 59 +++++++++ xla/backends/cpu/collectives/cpu_clique_key.h | 44 +++++++ xla/backends/cpu/collectives/cpu_cliques.cc | 122 ++++++++++++++++++ xla/backends/cpu/collectives/cpu_cliques.h | 33 +++++ .../cpu/collectives/cpu_collectives.h | 19 +++ xla/backends/cpu/runtime/BUILD | 18 ++- xla/backends/cpu/runtime/collective_thunk.cc | 20 ++- xla/backends/cpu/runtime/thunk.cc | 7 +- xla/backends/cpu/runtime/thunk.h | 13 +- xla/core/collectives/clique.cc | 11 ++ xla/core/collectives/clique.h | 7 +- xla/core/collectives/clique_key.cc | 3 + xla/core/collectives/clique_key.h | 2 + xla/service/cpu/BUILD | 5 + xla/service/cpu/collectives_interface.h | 109 +++++++++++++++- 18 files changed, 603 insertions(+), 26 deletions(-) create mode 100644 xla/backends/cpu/collectives/cpu_clique.cc create mode 100644 xla/backends/cpu/collectives/cpu_clique.h create mode 100644 xla/backends/cpu/collectives/cpu_clique_key.cc create mode 100644 xla/backends/cpu/collectives/cpu_clique_key.h create mode 100644 xla/backends/cpu/collectives/cpu_cliques.cc create mode 100644 xla/backends/cpu/collectives/cpu_cliques.h diff --git a/xla/backends/cpu/collectives/BUILD b/xla/backends/cpu/collectives/BUILD index be10b5cafa1250..4363e15a7e13f1 100644 --- a/xla/backends/cpu/collectives/BUILD +++ b/xla/backends/cpu/collectives/BUILD @@ -14,6 +14,59 @@ package_group( ], ) +cc_library( + name = "cpu_clique_key", + srcs = ["cpu_clique_key.cc"], + hdrs = ["cpu_clique_key.h"], + deps = [ + "//xla/core/collectives:clique_key", + "//xla/service:global_device_id", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:casts", + ], +) + +cc_library( + name = "cpu_clique", + srcs = ["cpu_clique.cc"], + hdrs = ["cpu_clique.h"], + deps = [ + ":cpu_clique_key", + "//xla/core/collectives:clique", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/tsl/platform:logging", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "cpu_cliques", + srcs = ["cpu_cliques.cc"], + hdrs = ["cpu_cliques.h"], + deps = [ + ":cpu_clique", + ":cpu_clique_key", + ":cpu_collectives", + "//xla:util", + "//xla/core/collectives:clique", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "cpu_collectives", srcs = ["cpu_collectives.cc"], @@ -23,14 +76,17 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/core/collectives", + "//xla/core/collectives:clique_id", "//xla/core/collectives:collectives_registry", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:casts", ], ) diff --git a/xla/backends/cpu/collectives/cpu_clique.cc b/xla/backends/cpu/collectives/cpu_clique.cc new file mode 100644 index 00000000000000..a81dd80392f9f1 --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_clique.cc @@ -0,0 +1,59 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/cpu_clique.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/core/collectives/clique.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/tsl/platform/logging.h" + +namespace xla::cpu { + +CpuClique::CpuClique(CpuCliqueKey key) : Clique({}), key_(std::move(key)) {} + +std::string CpuClique::DebugString() const { + std::string out = + absl::StrFormat("key: %s; size: %d; communicators: ", key_.ToString(), + num_communicators()); + int32_t cnt = 0; + ForEachComm([&](RankId rank, Communicator* comm) { + if (cnt++) absl::StrAppend(&out, ", "); + absl::StrAppendFormat(&out, "[rank=%d, comm=%s]", rank.value(), + comm->ToString()); + }); + return out; +} + +absl::Status CpuClique::HealthCheck() const { + absl::Status health_check = absl::OkStatus(); + ForEachComm([&health_check](RankId rank, Communicator* comm) { + if (auto s = comm->HealthCheck(); !s.ok()) { + LOG(ERROR) << "CPU communicator error (rank " << rank << "): " << s; + if (health_check.ok()) health_check = std::move(s); // return first error + } + }); + return health_check; +} + +} // namespace xla::cpu diff --git a/xla/backends/cpu/collectives/cpu_clique.h b/xla/backends/cpu/collectives/cpu_clique.h new file mode 100644 index 00000000000000..e1ff3025a955b0 --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_clique.h @@ -0,0 +1,42 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_H_ + +#include + +#include "absl/status/status.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/core/collectives/clique.h" + +namespace xla::cpu { + +// A group of CPU communicators making up a clique. +class CpuClique final : public Clique { + public: + explicit CpuClique(CpuCliqueKey key); + + absl::Status HealthCheck() const final; + + std::string DebugString() const final; + + private: + CpuCliqueKey key_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_H_ diff --git a/xla/backends/cpu/collectives/cpu_clique_key.cc b/xla/backends/cpu/collectives/cpu_clique_key.cc new file mode 100644 index 00000000000000..b66c844d4983ed --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_clique_key.cc @@ -0,0 +1,59 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/cpu_clique_key.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_format.h" +#include "xla/core/collectives/clique_key.h" +#include "xla/service/global_device_id.h" +#include "tsl/platform/casts.h" + +namespace xla::cpu { + +bool CpuCliqueKey::IsSubsetOf(const CliqueKey& other) const { + auto* other_cpu = tsl::down_cast(&other); + if (other_cpu == nullptr) return false; + + return absl::c_all_of(devices(), [&](GlobalDeviceId id) { + return absl::c_linear_search(other_cpu->devices(), id); + }); +} + +std::string CpuCliqueKey::ToString() const { + return absl::StrFormat("devices=[%s]", GlobalDeviceIdsToString(devices())); +} + +void CpuCliqueKey::HashValue(absl::HashState state) const { + absl::HashState::combine(std::move(state), devices()); +} + +bool operator==(const CpuCliqueKey& a, const CpuCliqueKey& b) { + return a.devices() == b.devices(); +} + +bool operator<(const CpuCliqueKey& a, const CpuCliqueKey& b) { + return a.devices() < b.devices(); +} + +bool operator>(const CpuCliqueKey& a, const CpuCliqueKey& b) { + return a.devices() > b.devices(); +} + +} // namespace xla::cpu diff --git a/xla/backends/cpu/collectives/cpu_clique_key.h b/xla/backends/cpu/collectives/cpu_clique_key.h new file mode 100644 index 00000000000000..30b257c1a0d0c0 --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_clique_key.h @@ -0,0 +1,44 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_KEY_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_KEY_H_ + +#include + +#include "absl/hash/hash.h" +#include "xla/core/collectives/clique_key.h" + +namespace xla::cpu { + +// Clique key for identifying a particular CPU collectives clique. +class CpuCliqueKey final : public CliqueKey { + public: + using CliqueKey::CliqueKey; + + bool IsSubsetOf(const CliqueKey& other) const final; + std::string ToString() const final; + + friend bool operator==(const CpuCliqueKey& a, const CpuCliqueKey& b); + friend bool operator<(const CpuCliqueKey& a, const CpuCliqueKey& b); + friend bool operator>(const CpuCliqueKey& a, const CpuCliqueKey& b); + + private: + void HashValue(absl::HashState state) const final; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_KEY_H_ diff --git a/xla/backends/cpu/collectives/cpu_cliques.cc b/xla/backends/cpu/collectives/cpu_cliques.cc new file mode 100644 index 00000000000000..6e6c437256ad12 --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_cliques.cc @@ -0,0 +1,122 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/collectives/cpu_cliques.h" + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/backends/cpu/collectives/cpu_clique.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla::cpu { + +//===----------------------------------------------------------------------===// +// ProcessCpuCliques +//===----------------------------------------------------------------------===// + +namespace { + +// CpuClique is not thread-safe, so we wrap it in a thread-safe container as we +// create new communicators lazily and potentially from multiple threads. +struct ThreadSafeClique { + explicit ThreadSafeClique(CpuCliqueKey key) : clique(key) {} + + absl::Mutex mu; + CpuClique clique ABSL_GUARDED_BY(mu); +}; + +// Container for initialized and ready to use CPU cliques. In contrast to GPU +// cliques, CPU cliques are not lockable, and we create communicators lazily +// when needed. +struct ProcessCpuCliques { + absl::Mutex mu; + absl::node_hash_map map ABSL_GUARDED_BY(mu); +}; +} // namespace + +// Returns process-local CPU cliques. +static ProcessCpuCliques& GetProcessCpuCliques() { + static auto* cliques = new ProcessCpuCliques; + return *cliques; +} + +//===----------------------------------------------------------------------===// + +// TODO(b/380457503): Consider switching to a lockable CPU clique model similar +// to GPU cliques, and creating all communicators upfront. +absl::StatusOr AcquireCommunicator( + CpuCollectives* collectives, const CpuCliqueKey& clique_key, RankId rank) { + VLOG(3) << "Acquire communicator for clique key " << clique_key.ToString() + << " and rank " << rank; + + ProcessCpuCliques& cliques = GetProcessCpuCliques(); + + // Synchronize access to the process cliques. + ThreadSafeClique& thread_safe_clique = [&]() -> ThreadSafeClique& { + absl::MutexLock lock(&cliques.mu); + auto [it, emplaced] = cliques.map.try_emplace(clique_key, clique_key); + return it->second; + }(); + + // Check if we already have a communicator for this rank. + std::optional comm = [&]() -> std::optional { + absl::MutexLock lock(&thread_safe_clique.mu); + return thread_safe_clique.clique.comm(rank); + }(); + + if (comm.has_value()) return *comm; + + VLOG(3) << "Create a new communicator for clique key " + << clique_key.ToString() << " and rank " << rank; + + // Create a new communicator and add it to the clique. + CpuCollectives::DeviceRank device_rank(/*device=*/nullptr, rank); + CpuCollectives::Config config; + + TF_ASSIGN_OR_RETURN( + std::vector> communicators, + collectives->CreateCommunicators(clique_key.num_devices(), clique_key, + std::nullopt, {device_rank}, config)); + + // We expect to create communicators lazily on at a time. + if (communicators.size() != 1) { + return Internal( + "Expected to create a single communicator for a clique key %s and rank " + "%d, but got %d", + clique_key.ToString(), rank.value(), communicators.size()); + } + + absl::MutexLock lock(&thread_safe_clique.mu); + TF_RETURN_IF_ERROR(thread_safe_clique.clique.AddComm( + rank, std::move(communicators.front()))); + + return *thread_safe_clique.clique.comm(rank); +} + +} // namespace xla::cpu diff --git a/xla/backends/cpu/collectives/cpu_cliques.h b/xla/backends/cpu/collectives/cpu_cliques.h new file mode 100644 index 00000000000000..b42774619fe4b2 --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_cliques.h @@ -0,0 +1,33 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUES_H_ +#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUES_H_ + +#include "absl/status/statusor.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" + +namespace xla::cpu { + +// Returns a communicator for a given clique key and rank. +absl::StatusOr AcquireCommunicator( + CpuCollectives* collectives, const CpuCliqueKey& clique_key, RankId rank); + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUES_H_ diff --git a/xla/backends/cpu/collectives/cpu_collectives.h b/xla/backends/cpu/collectives/cpu_collectives.h index a728e7cd3a399d..330b35f52146d1 100644 --- a/xla/backends/cpu/collectives/cpu_collectives.h +++ b/xla/backends/cpu/collectives/cpu_collectives.h @@ -16,11 +16,19 @@ limitations under the License. #ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_ #define XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_ +#include +#include +#include + #include "absl/status/statusor.h" #include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/core/collectives/clique_id.h" #include "xla/core/collectives/collectives.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::cpu { @@ -50,6 +58,17 @@ class CpuCollectives : public Collectives { absl::Duration timeout_; }; + absl::StatusOr CreateUniqueCliqueId() const final { + return Unimplemented("CPU collectives do not support clique ids"); + } + + absl::StatusOr>> SplitCommunicators( + absl::Span comms, int32_t color, + absl::Span keys, const Config& config) final { + return Unimplemented( + "CPU collectives do not support communicator splitting"); + } + // Tries to cast a Collectives::Device to a CpuCollectives::Device. static absl::StatusOr TryCast( const Collectives::Device* device); diff --git a/xla/backends/cpu/runtime/BUILD b/xla/backends/cpu/runtime/BUILD index a83a5e51dca28d..cd1e7b89e9c1ae 100644 --- a/xla/backends/cpu/runtime/BUILD +++ b/xla/backends/cpu/runtime/BUILD @@ -145,6 +145,8 @@ cc_library( ":resource_use", "//xla:executable_run_options", "//xla:util", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/core/collectives", "//xla/ffi:execution_context", "//xla/runtime:buffer_use", "//xla/service:global_device_id", @@ -155,11 +157,12 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/profiler/lib:traceme_encode", ], @@ -593,6 +596,11 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_clique_key", + "//xla/backends/cpu/collectives:cpu_cliques", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", @@ -601,6 +609,9 @@ cc_library( "//xla/service/cpu:collectives_interface", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", @@ -610,9 +621,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/backends/cpu/runtime/collective_thunk.cc b/xla/backends/cpu/runtime/collective_thunk.cc index f838fb0e49acd1..35a6f72fb9671d 100644 --- a/xla/backends/cpu/runtime/collective_thunk.cc +++ b/xla/backends/cpu/runtime/collective_thunk.cc @@ -32,23 +32,27 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_clique_key.h" +#include "xla/backends/cpu/collectives/cpu_cliques.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" #include "xla/shape.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" namespace xla::cpu { @@ -172,7 +176,7 @@ CollectiveThunk::ExecuteWithCommunicator( TF_RET_CHECK(params) << "Collective parameters are not set for collective operation"; - CollectivesInterface* collectives = params->collectives; + CpuCollectives* collectives = params->collectives; TF_RET_CHECK(collectives) << "Collectives interface is not set for collective operation"; @@ -183,8 +187,10 @@ CollectiveThunk::ExecuteWithCommunicator( VLOG(3) << absl::StreamFormat(" rank=%d, key=%s", rank, key.ToString()); - TF_ASSIGN_OR_RETURN(std::shared_ptr communicator, - collectives->GetCommunicator(key.global_devices, rank)); + CpuCliqueKey clique_key(key.global_devices); + TF_ASSIGN_OR_RETURN( + Communicator * communicator, + AcquireCommunicator(collectives, clique_key, RankId(rank))); TF_RETURN_IF_ERROR(callback(key, *communicator)); diff --git a/xla/backends/cpu/runtime/thunk.cc b/xla/backends/cpu/runtime/thunk.cc index 8dab085b47fb6b..a17de11724bda3 100644 --- a/xla/backends/cpu/runtime/thunk.cc +++ b/xla/backends/cpu/runtime/thunk.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/executable_run_options.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/cpu_executable_run_options.h" @@ -30,7 +32,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/logging.h" +#include "xla/tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" @@ -121,8 +123,7 @@ Thunk::CollectiveExecuteParams::Create( Thunk::CollectiveExecuteParams::CollectiveExecuteParams( RunId run_id, int64_t local_device_ordinal, GlobalDeviceId global_device_id, - const DeviceAssignment* device_assignment, - CollectivesInterface* collectives) + const DeviceAssignment* device_assignment, CpuCollectives* collectives) : run_id(run_id), local_device_ordinal(local_device_ordinal), global_device_id(global_device_id), diff --git a/xla/backends/cpu/runtime/thunk.h b/xla/backends/cpu/runtime/thunk.h index 38d3f41d6a75b3..2c86db92517745 100644 --- a/xla/backends/cpu/runtime/thunk.h +++ b/xla/backends/cpu/runtime/thunk.h @@ -28,21 +28,20 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/function_library.h" -#include "xla/backends/cpu/runtime/kernel_c_api.h" #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/executable_run_options.h" #include "xla/ffi/execution_context.h" #include "xla/runtime/buffer_use.h" -#include "xla/service/cpu/collectives_interface.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/service/global_device_id.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" -#include "xla/util.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" namespace Eigen { struct ThreadPoolDevice; @@ -164,13 +163,13 @@ class Thunk { GlobalDeviceId global_device_id; const DeviceAssignment* device_assignment = nullptr; - CollectivesInterface* collectives = nullptr; + CpuCollectives* collectives = nullptr; private: CollectiveExecuteParams(RunId run_id, int64_t local_device_ordinal, GlobalDeviceId global_device_id, const DeviceAssignment* device_assignment, - CollectivesInterface* collectives); + CpuCollectives* collectives); }; //===--------------------------------------------------------------------===// diff --git a/xla/core/collectives/clique.cc b/xla/core/collectives/clique.cc index 6eb73c1ea91cba..1a0a5d659aecba 100644 --- a/xla/core/collectives/clique.cc +++ b/xla/core/collectives/clique.cc @@ -21,8 +21,10 @@ limitations under the License. #include "absl/container/btree_map.h" #include "absl/functional/function_ref.h" +#include "absl/status/status.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" +#include "xla/util.h" namespace xla { @@ -44,4 +46,13 @@ void Clique::ForEachComm( } } +absl::Status Clique::AddComm(RankId rank, + std::unique_ptr communicator) { + auto emplaced = communicators_.emplace(rank, std::move(communicator)); + if (!emplaced.second) { + return InvalidArgument("Rank %d already exists in clique", rank.value()); + } + return absl::OkStatus(); +} + } // namespace xla diff --git a/xla/core/collectives/clique.h b/xla/core/collectives/clique.h index 69705ccfa524c5..24f80a3f1682c9 100644 --- a/xla/core/collectives/clique.h +++ b/xla/core/collectives/clique.h @@ -49,6 +49,9 @@ class Clique { // Returns a communicator for a given rank if it's in a clique. std::optional comm(RankId rank) const; + // Adds a communicator to the clique. + absl::Status AddComm(RankId rank, std::unique_ptr communicator); + // Calls `fn` for each communicator in the clique. void ForEachComm(absl::FunctionRef fn) const; @@ -61,8 +64,8 @@ class Clique { size_t num_communicators() const { return communicators_.size(); } private: - // We keep communicators in a sorted order by rank to guarantee deterministic - // traversal order in `ForEachComm`. + // We keep communicators in a sorted order by rank to guarantee + // deterministic traversal order in `ForEachComm`. absl::btree_map> communicators_; }; diff --git a/xla/core/collectives/clique_key.cc b/xla/core/collectives/clique_key.cc index 2da8d6651c3548..92749633bb91ad 100644 --- a/xla/core/collectives/clique_key.cc +++ b/xla/core/collectives/clique_key.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/core/collectives/clique_key.h" +#include #include #include #include @@ -31,6 +32,8 @@ CliqueKey::CliqueKey(std::vector devices) absl::Span CliqueKey::devices() const { return devices_; } +size_t CliqueKey::num_devices() const { return devices_.size(); } + std::optional CliqueKey::rank(GlobalDeviceId id) const { if (auto it = absl::c_find(devices_, id); it != devices_.end()) { return RankId(it - devices_.begin()); diff --git a/xla/core/collectives/clique_key.h b/xla/core/collectives/clique_key.h index 05411773431507..37e16d5fb774ae 100644 --- a/xla/core/collectives/clique_key.h +++ b/xla/core/collectives/clique_key.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_CORE_COLLECTIVES_CLIQUE_KEY_H_ #define XLA_CORE_COLLECTIVES_CLIQUE_KEY_H_ +#include #include #include #include @@ -52,6 +53,7 @@ class CliqueKey { std::optional rank(GlobalDeviceId id) const; absl::Span devices() const; + size_t num_devices() const; // Returns true if this clique is a subset of `other`. virtual bool IsSubsetOf(const CliqueKey& other) const = 0; diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 4ea228a0c63300..012112662640a2 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -1961,12 +1961,17 @@ cc_library( name = "collectives_interface", hdrs = ["collectives_interface.h"], deps = [ + "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/core/collectives:clique_id", + "//xla/core/collectives:clique_key", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/service:global_device_id", "//xla/stream_executor:device_memory", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", diff --git a/xla/service/cpu/collectives_interface.h b/xla/service/cpu/collectives_interface.h index cfa3b11f36513a..77e159e1535bc4 100644 --- a/xla/service/cpu/collectives_interface.h +++ b/xla/service/cpu/collectives_interface.h @@ -17,22 +17,108 @@ limitations under the License. #define XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_ #include +#include #include #include +#include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/clique_key.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/global_device_id.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::cpu { -class CollectivesInterface { +namespace internal { + +// An adapter from a shared_ptr to a Communicator. +class CommunicatorWrapper final : public Communicator { + public: + explicit CommunicatorWrapper(std::shared_ptr comm) + : comm_(std::move(comm)) {} + + absl::Status AllReduce(stream_executor::DeviceMemoryBase send_buffer, + stream_executor::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) final { + return comm_->AllReduce(send_buffer, recv_buffer, dtype, count, + reduction_kind, executor); + } + + absl::Status Broadcast(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, RankId root, + const Executor& executor) final { + return comm_->Broadcast(send_buffer, recv_buffer, dtype, count, root, + executor); + } + + absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + const Executor& executor) final { + return comm_->ReduceScatter(send_buffer, recv_buffer, dtype, count, + reduction_kind, executor); + } + + absl::Status AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, const Executor& executor) final { + return comm_->AllGather(send_buffer, recv_buffer, dtype, count, executor); + } + + absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + std::optional source_rank, + absl::Span target_ranks, + const Executor& executor) final { + return comm_->CollectivePermute(send_buffer, recv_buffer, dtype, count, + source_rank, target_ranks, executor); + } + + absl::Status AllToAll(absl::Span send_buffers, + absl::Span recv_buffers, + PrimitiveType dtype, size_t count, + const Executor& executor) final { + return comm_->AllToAll(send_buffers, recv_buffers, dtype, count, executor); + } + + absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, + size_t count, RankId peer, const Executor& executor) final { + return comm_->Send(send_buffer, dtype, count, peer, executor); + } + + absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + size_t count, RankId peer, const Executor& executor) final { + return comm_->Recv(recv_buffer, dtype, count, peer, executor); + } + + absl::StatusOr NumRanks() const final { return comm_->NumRanks(); } + + std::string ToString() const final { return comm_->ToString(); } + + private: + std::shared_ptr comm_; +}; + +} // namespace internal + +class CollectivesInterface : public CpuCollectives { public: virtual ~CollectivesInterface() = default; @@ -42,6 +128,25 @@ class CollectivesInterface { // rank: the rank of this process. virtual absl::StatusOr> GetCommunicator( absl::Span devices, int rank) = 0; + + absl::StatusOr>> + CreateCommunicators(int32_t nranks, const CliqueKey& clique_key, + const std::optional& clique_id, + absl::Span ranks, + const Config& config) final { + // We expect to create CPU communicators lazily one at a time. + if (ranks.size() != 1) { + return InvalidArgument("Expected 1 rank, got %d", ranks.size()); + } + + TF_ASSIGN_OR_RETURN(auto comm, GetCommunicator(clique_key.devices(), + ranks[0].rank.value())); + + std::vector> comms; + comms.reserve(1); + comms.push_back(std::make_unique(comm)); + return comms; + } }; } // namespace xla::cpu From 1b9969348628c9a435fc1d6973acab2e5d55849a Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 8 Jan 2025 09:23:05 -0800 Subject: [PATCH 43/45] Remove experimental TOSA convert python API In preparation for larger changes, this entry point is being disabled here for now. PiperOrigin-RevId: 713316210 --- .bazelrc | 44 ++++++++++++++++++++-------------------- third_party/tsl/.bazelrc | 44 ++++++++++++++++++++-------------------- 2 files changed, 44 insertions(+), 44 deletions(-) diff --git a/.bazelrc b/.bazelrc index 142ed60871ce3f..f8a9ef174f7eca 100644 --- a/.bazelrc +++ b/.bazelrc @@ -738,42 +738,42 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa # PYTHON TESTS run a suite of Python tests intended for verifying that the Python wheel # will work properly. These are usually run Nightly or upon Release. # CPU WHEEL -test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_cpu_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... # CUDA WHEEL -test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_cuda_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... # ARM64 WHEEL -test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_arm64_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL -test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_arm64_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL -test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_x86_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. # LINUX CPU PYCPP: -test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... # LINUX CUDA PYCPP: -test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 -test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... @@ -786,8 +786,8 @@ test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflo # do not run them. By prefixing the configs with "build", we can run both # `bazel build` and `bazel test` commands with the same config as test configs # inherit from build. -build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/python/tools:aot_compiled_test @@ -796,15 +796,15 @@ build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP -test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS X86 PYCPP # These are defined as build configs so that we can run a build only job. See # the note under "ARM64 PYCPP" for more details. -build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... # CROSS-COMPILE MACOS X86 PYCPP @@ -813,8 +813,8 @@ build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_co # WINDOWS X86-64 CPU PYCPP build:windows_x86_cpu_pycpp_test_build_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off build:windows_x86_cpu_pycpp_test_build_opts_debug --config=windows_x86_cpu_pycpp_test_build_opts --linkopt=/demangle:no --host_linkopt=/demangle:no --linkopt=/errorlimit:0 --host_linkopt=/errorlimit:0 -test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test -test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-benchmark-test test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" test:windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_build_opts --build_tests_only test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... diff --git a/third_party/tsl/.bazelrc b/third_party/tsl/.bazelrc index 086e35096080ab..5998529e822a8b 100644 --- a/third_party/tsl/.bazelrc +++ b/third_party/tsl/.bazelrc @@ -738,42 +738,42 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa # PYTHON TESTS run a suite of Python tests intended for verifying that the Python wheel # will work properly. These are usually run Nightly or upon Release. # CPU WHEEL -test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_cpu_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL -test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_cuda_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL -test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_arm64_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL -test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_arm64_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL -test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_x86_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. # LINUX CPU PYCPP: -test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # LINUX CUDA PYCPP: -test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 -test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... @@ -786,8 +786,8 @@ test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflo # do not run them. By prefixing the configs with "build", we can run both # `bazel build` and `bazel test` commands with the same config as test configs # inherit from build. -build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/python/tools:aot_compiled_test @@ -796,15 +796,15 @@ build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP -test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS X86 PYCPP # These are defined as build configs so that we can run a build only job. See # the note under "ARM64 PYCPP" for more details. -build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... # CROSS-COMPILE MACOS X86 PYCPP @@ -813,8 +813,8 @@ build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_co # WINDOWS X86-64 CPU PYCPP build:windows_x86_cpu_pycpp_test_build_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off build:windows_x86_cpu_pycpp_test_build_opts_debug --config=windows_x86_cpu_pycpp_test_build_opts --linkopt=/demangle:no --host_linkopt=/demangle:no --linkopt=/errorlimit:0 --host_linkopt=/errorlimit:0 -test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test -test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-benchmark-test test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" test:windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_build_opts --build_tests_only test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... From cbcde508791510b6bd0593d3ac198378eb95e478 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 8 Jan 2025 09:29:09 -0800 Subject: [PATCH 44/45] [XLA:GPU][Emitters] Fix a typo in vectorize_loads_stores.mlir PiperOrigin-RevId: 713318085 --- .../gpu/codegen/transforms/tests/vectorize_loads_stores.mlir | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir b/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir index 3f04219d0eeb17..d5d3d0a74fe4a2 100644 --- a/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir +++ b/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir @@ -252,7 +252,7 @@ func.func @layout(%arg0: tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>) -> (f3 func.func @simple_write(%arg0: tensor<64xf32>) -> tensor<64xf32> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c4 = arith.constant 2 : index + %c4 = arith.constant 4 : index %cst = arith.constant 0.0 : f32 %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<64xf32> { %inserted = tensor.insert %cst into %iter[%j] : tensor<64xf32> @@ -264,6 +264,7 @@ func.func @simple_write(%arg0: tensor<64xf32>) -> tensor<64xf32> { // CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}) // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[V:.*]] = scf.for +// CHECK-SAME: (vector<4xf32>) // CHECK-NEXT: vector.insert // CHECK-NEXT: scf.yield // CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[V]], %[[ARG0]][%[[C0]]] From 56e1f4325e3f58eaf0411264c76e59ed3cfceed4 Mon Sep 17 00:00:00 2001 From: xla authors Date: Wed, 8 Jan 2025 09:50:18 -0800 Subject: [PATCH 45/45] IFRT proxy asan fix: Do not call `promise.Set()` twice in error-handling path. PiperOrigin-RevId: 713323821 --- xla/python/ifrt_proxy/client/grpc_host_buffer.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xla/python/ifrt_proxy/client/grpc_host_buffer.cc b/xla/python/ifrt_proxy/client/grpc_host_buffer.cc index ab36e6c0f17f6f..2c8d52e7e7cff2 100644 --- a/xla/python/ifrt_proxy/client/grpc_host_buffer.cc +++ b/xla/python/ifrt_proxy/client/grpc_host_buffer.cc @@ -105,8 +105,10 @@ Future<> GrpcClientHostBufferStore::Store(uint64_t handle, } if (!writer->WritesDone()) { + writer->Finish().IgnoreError(); promise.Set( absl::InternalError("Failed to write all host buffer chunks")); + return; } } @@ -150,6 +152,7 @@ Future<> GrpcClientHostBufferStore::Store(uint64_t handle, } } if (!writer->WritesDone()) { + writer->Finish().IgnoreError(); return Future<>( absl::InternalError("Failed to write all host buffer chunks")); }