PR tensorflow#19571: PJRT: assign process index and count for compilation using device assignment.

Imported from GitHub PR openxla/xla#19571

Only a subset of processes may be participating in the compilation of a module.
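
In essence, the change derives the compilation-time process index and process count from the device assignment itself, rather than from the client's global process count, so that only processes that actually own assigned devices are counted. A minimal standalone C++ sketch of that idea (hypothetical names and types, not the actual PJRT API and not the PR's exact logic, which walks the assignment inside GetExecutableExtras):

```cpp
#include <iterator>
#include <optional>
#include <set>
#include <vector>

// Result of scanning a device assignment for participating processes.
struct CompilationProcessInfo {
  int process_index;  // Dense index of the calling process among participants.
  int process_count;  // Number of distinct processes owning assigned devices.
};

// device_assignment: [replica][partition] -> global device ID.
// device_to_process: global device ID -> owning process index.
// Returns std::nullopt if `this_process` owns none of the assigned devices.
std::optional<CompilationProcessInfo> DeriveProcessInfo(
    const std::vector<std::vector<int>>& device_assignment,
    const std::vector<int>& device_to_process, int this_process) {
  std::set<int> participants;  // Ordered set yields a stable dense numbering.
  for (const auto& row : device_assignment) {
    for (int device_id : row) {
      participants.insert(device_to_process.at(device_id));
    }
  }
  auto it = participants.find(this_process);
  if (it == participants.end()) return std::nullopt;
  return CompilationProcessInfo{
      static_cast<int>(std::distance(participants.begin(), it)),
      static_cast<int>(participants.size())};
}
```

For a 2x1 assignment whose devices live on processes {0, 1}, process 1 gets index 1 and count 2, while a process owning none of the assigned devices is simply not assigned an index.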
Copybara import of the project:

--
15250fc203482cdb17e60db263657df9a192b699 by Ilia Sergachev <[email protected]>:

PJRT: assign process index and count for compilation using device assignment.

Only a subset of processes may be participating in the compilation of a
module.

--
8620919d35f1160d993cb69e5b32c32b77cbba7d by Ilia Sergachev <[email protected]>:

fix functional_hlo_runner_test

Merging this change closes tensorflow#19571

PiperOrigin-RevId: 702231769
sergachev authored and tensorflower-gardener committed Dec 3, 2024
1 parent 43f7231 commit f13c441
Showing 6 changed files with 48 additions and 16 deletions.
30 changes: 25 additions & 5 deletions third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
@@ -1840,20 +1840,27 @@ TEST(StreamExecutorGpuClientTest, AutoLayoutIsSupported) {
EXPECT_NE(layouts[1]->ToString(), "{2,1,0}");
}

class ShardedAutotuningTest : public ::testing::TestWithParam<bool> {
class ShardedAutotuningTest
: public ::testing::TestWithParam<std::tuple<bool, int>> {
public:
static constexpr int kNumNodes = 2;
};

static const char* test_binary_name;

TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) {
bool use_xla_computation;
int num_active_nodes;
std::tie(use_xla_computation, num_active_nodes) = GetParam();

tsl::SubProcess child[ShardedAutotuningTest::kNumNodes];
for (int node_id = 0; node_id < ShardedAutotuningTest::kNumNodes; ++node_id) {
std::vector<std::string> argv;
argv.push_back(test_binary_name);
argv.push_back(absl::StrFormat("--node_id=%d", node_id));
argv.push_back(absl::StrFormat("--use_xla_computation=%d", GetParam()));
argv.push_back(
absl::StrFormat("--use_xla_computation=%d", use_xla_computation));
argv.push_back(absl::StrFormat("--num_active_nodes=%d", num_active_nodes));
child[node_id].SetProgram(test_binary_name, argv);
child[node_id].SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE);
child[node_id].SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE);
@@ -1876,6 +1883,7 @@ TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) {
}

absl::Status ShardedAutotuningWorksTestBody(const int node_id,
const int num_active_nodes,
bool use_xla_computation) {
std::unique_ptr<xla::DistributedRuntimeService> service;
if (node_id == 0) {
@@ -1911,6 +1919,11 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id,
TF_RET_CHECK(client->addressable_device_count() == 1);
TF_RET_CHECK(client->device_count() == ShardedAutotuningTest::kNumNodes);

if (node_id >= num_active_nodes) {
// Inactive nodes connect to the coordination service but don't compile.
return absl::OkStatus();
}

CompileOptions compile_options;
DebugOptions* debug_options =
compile_options.executable_build_options.mutable_debug_options();
@@ -1951,8 +1964,11 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id,
return absl::OkStatus();
}

INSTANTIATE_TEST_SUITE_P(ShardedAutotuningTest, ShardedAutotuningTest,
::testing::Values(false, true));
INSTANTIATE_TEST_SUITE_P(
ShardedAutotuningTest, ShardedAutotuningTest,
::testing::Combine(::testing::Bool(),
::testing::Range(1,
ShardedAutotuningTest::kNumNodes + 1)));

} // namespace
} // namespace xla
@@ -1961,10 +1977,13 @@ int main(int argc, char* argv[]) {
// Save name of binary so that it may invoke itself.
xla::test_binary_name = argv[0];
int node_id = -1;
int num_active_nodes = -1;
bool use_xla_computation = false;
std::vector<tsl::Flag> flag_list = {
tsl::Flag("node_id", &node_id,
"Node ID for ShardedAutotuningWorks test."),
tsl::Flag("num_active_nodes", &num_active_nodes,
"Test parameter for ShardedAutotuningWorks."),
tsl::Flag("use_xla_computation", &use_xla_computation,
"Test parameter for ShardedAutotuningWorks."),
};
@@ -1973,7 +1992,8 @@ int main(int argc, char* argv[]) {
tsl::Flags::Parse(&argc, argv, flag_list);
testing::InitGoogleTest(&argc, argv);
if (node_id >= 0) {
return xla::ShardedAutotuningWorksTestBody(node_id, use_xla_computation)
return xla::ShardedAutotuningWorksTestBody(node_id, num_active_nodes,
use_xla_computation)
.raw_code();
}
return RUN_ALL_TESTS();
14 changes: 9 additions & 5 deletions third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc
@@ -3481,17 +3481,23 @@ PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) {
if (device_assignment != nullptr) {
addressable_device_logical_ids.reserve(num_replicas * num_partitions);
addressable_devices.reserve(num_replicas * num_partitions);
absl::flat_hash_set<int> all_process_indices;
std::optional<int> this_process_index;
for (int replica = 0; replica < num_replicas; ++replica) {
for (int partition = 0; partition < num_partitions; ++partition) {
int64_t device_id = (*device_assignment)(replica, partition);
PjRtGlobalDeviceId global_device_id(device_id);

TF_ASSIGN_OR_RETURN(PjRtDevice * device,
LookupDevice(global_device_id));
all_process_indices.insert(device->process_index());
if (device->process_index() != process_index()) {
VLOG(3) << "Non-local device: " << device_id;
continue;
}
if (!this_process_index.has_value()) {
this_process_index = all_process_indices.size() - 1;
}
PjRtLoadedExecutable::LogicalDeviceIds logica_device_ids;
logica_device_ids.replica = replica;
logica_device_ids.partition = partition;
@@ -3509,6 +3515,9 @@ PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) {
build_options.set_device_ordinal(
addressable_devices.front()->local_hardware_id().value());
}

build_options.set_process_index(*this_process_index);
build_options.set_process_count(all_process_indices.size());
}
return extras;
}
@@ -3525,11 +3534,6 @@ PjRtStreamExecutorClient::CompileInternal(
!options.executable_build_options.key_value_store()) {
options.executable_build_options.set_key_value_store(*key_value_store());
}
options.executable_build_options.set_process_index(process_index());
TF_RET_CHECK(device_count() % addressable_device_count() == 0)
<< "Each process is expected to have the same number of devices";
options.executable_build_options.set_process_count(
device_count() / addressable_device_count());
auto input_options = options;

TF_RETURN_IF_ERROR(options.ApplyAllOptionOverrides());
1 change: 1 addition & 0 deletions third_party/xla/xla/tools/multihost_hlo_runner/BUILD
@@ -200,6 +200,7 @@ xla_test(
":create_client",
":functional_hlo_runner",
"//xla:debug_options_flags",
"//xla:status_macros",
"//xla:xla_proto_cc",
"//xla/hlo/testlib:filecheck",
"//xla/pjrt:pjrt_client",
@@ -544,7 +544,8 @@ absl::Status FunctionalHloRunner::LoadAndCompile(
const PreprocessingOptions& preproc_options,
const RawCompileOptions& raw_compile_options, std::string_view hlo_file,
InputFormat input_format, int task_id, int num_nodes,
std::shared_ptr<xla::KeyValueStoreInterface> kv_store) {
std::shared_ptr<xla::KeyValueStoreInterface> kv_store,
bool use_gpu_count_workaround) {
TF_ASSIGN_OR_RETURN(
CompileOptions compile_options,
FunctionalHloRunner::CreateCompileOptions(client, raw_compile_options,
Expand All @@ -554,7 +555,8 @@ absl::Status FunctionalHloRunner::LoadAndCompile(
int num_partitions =
compile_options.executable_build_options.num_partitions();
int needed_devices = num_replicas * num_partitions;
if (client.addressable_device_count() < needed_devices) {
if (client.addressable_device_count() < needed_devices &&
use_gpu_count_workaround) {
LOG(INFO) << "Applying a workaround to allow compiling multi-device HLOs "
"on machines with fewer devices.";
DeviceAssignment assignment(num_replicas, num_partitions);
@@ -270,7 +270,8 @@ class FunctionalHloRunner {
const PreprocessingOptions& preproc_options,
const RawCompileOptions& raw_compile_options, std::string_view hlo_file,
InputFormat input_format, int task_id = 0, int num_nodes = 1,
std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr);
std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr,
bool use_gpu_count_workaround = true);

// Compiles and runs the given HLO module with the given arguments for each
// device. The given arguments is a map from device ID to a list of arguments.
@@ -30,6 +30,7 @@ limitations under the License.
#include "xla/hlo/testlib/filecheck.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
#include "xla/status_macros.h"
#include "xla/tools/multihost_hlo_runner/create_client.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/util/command_line_flags.h"
@@ -296,7 +297,9 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id) {
PjRtEnvironment env,
xla::GetPjRtEnvironmentForGpu("127.0.0.1:12345", gpu_options,
/*init_timeout=*/absl::Seconds(120)));
CHECK(env.kv_store != nullptr);
TF_RET_CHECK(env.kv_store != nullptr);
TF_RET_CHECK(env.client->device_count() == kNumNodes);
TF_RET_CHECK(env.client->addressable_device_count() == 1);
// Make HLO module IDs of multiple_gemm_fusions.hlo differ: the autotuner
// should not rely on them.
if (node_id == 0) {
@@ -310,9 +313,10 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id) {
TF_RETURN_IF_ERROR(FunctionalHloRunner::LoadAndCompile(
*env.client, GetDebugOptionsFromFlags(),
FunctionalHloRunner::PreprocessingOptions{},
FunctionalHloRunner::RawCompileOptions{},
FunctionalHloRunner::RawCompileOptions{.num_replicas = kNumNodes},
GetHloPath(absl::StrFormat("multiple_gemm_fusions_%d.hlo", node_id + 1)),
InputFormat::kText));
InputFormat::kText, node_id, kNumNodes, /*kv_store=*/nullptr,
/*use_gpu_count_workaround=*/false));
if (node_id == 0) {
TF_ASSIGN_OR_RETURN(
std::string results0,