Skip to content

Commit

Permalink
Fix device list of loaded executable in PJRT plugin for multiple GPUs
Browse files Browse the repository at this point in the history
Signed-off-by: PragmaTwice <[email protected]>
  • Loading branch information
PragmaTwice committed Dec 4, 2024
1 parent 939984c commit adf9e55
Show file tree
Hide file tree
Showing 8 changed files with 1,376 additions and 22 deletions.
1 change: 1 addition & 0 deletions integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ iree_cc_library(
iree::vm
iree::vm::bytecode::module
iree_pjrt_deps::headers
iree_pjrt_deps::protos
PUBLIC
)

Expand Down
48 changes: 29 additions & 19 deletions integrations/pjrt/src/iree_pjrt/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree_pjrt/common/api_impl.h"

#include <iterator>
#include <optional>
#include <sstream>
#include <utility>
Expand Down Expand Up @@ -1002,7 +1003,7 @@ iree_status_t DeviceInstance::TransposeBroadcastDeviceBuffer(

// Compile program and check for errors:
LoadedExecutableInstance* executable;
auto* error = this->client().Compile(&program, &executable);
auto* error = this->client().Compile(&program, {}, &executable);
if (error) {
auto errinst = ErrorInstance::FromError(error);
auto ret = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
Expand Down Expand Up @@ -1351,22 +1352,14 @@ void ClientInstance::BindApi(PJRT_Api* api) {
LoadedExecutableInstance* executable;

// Read compilation options.
// TODO: Port CompileOptionsProto into the project or leave ommitted.
// xla::CompileOptionsProto options_proto;
// if (!options_proto.ParseFromArray(args->compile_options,
// args->compile_options_size)) {
// return MakeError(iree_make_status(IREE_STATUS_INTERNAL,
// "could not parse compilation
// options"));
// }
// auto options = xla::CompileOptions::FromProto(options_proto);
// if (!options.ok()) {
// return MakeError(
// iree_make_status(IREE_STATUS_INTERNAL,
// std::string(options.status().message()).c_str()));
// }

auto* error = client->Compile(args->program, /**options,*/ &executable);
xla::CompileOptionsProto options_proto;
if (!options_proto.ParseFromArray(args->compile_options,
args->compile_options_size)) {
return MakeError(iree_make_status(IREE_STATUS_INTERNAL,
"could not parse compilation options"));
}

auto* error = client->Compile(args->program, options_proto, &executable);
if (error) return error;
args->executable = *executable;
return nullptr;
Expand Down Expand Up @@ -1451,7 +1444,7 @@ iree_status_t ClientInstance::PopulateDevices() {
}

PJRT_Error* ClientInstance::Compile(const PJRT_Program* program,
/*xla::CompileOptions options,*/
xla::CompileOptionsProto options,
LoadedExecutableInstance** out_executable) {
std::unique_ptr<ArtifactDumper::Transaction> artifact_tx;
if (platform().artifact_dumper().enabled()) {
Expand Down Expand Up @@ -1570,11 +1563,28 @@ PJRT_Error* ClientInstance::Compile(const PJRT_Program* program,
output->GetDataSize()));
}

// calculate devices for this computation from device assignment
std::vector<DeviceInstance*> devices;

const auto& build_options = options.executable_build_options();
if (build_options.has_device_assignment()) {
const auto& device_assignment = build_options.device_assignment();
for (auto id :
device_assignment.computation_devices(0).replica_device_ids()) {
if (id < addressable_devices_.size())
devices.push_back(addressable_devices_[id]);
}
}

if (devices.empty()) {
devices = addressable_devices_;
}

auto executable = std::make_unique<LoadedExecutableInstance>(
*this,
new ExecutableImage(std::move(output),
std::string(program->code, program->code_size)),
addressable_devices_);
devices);
status = executable->LoadAll();
if (!iree_status_is_ok(status)) {
return MakeError(status);
Expand Down
3 changes: 2 additions & 1 deletion integrations/pjrt/src/iree_pjrt/common/api_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "iree_pjrt/common/layout_utils.h"
#include "iree_pjrt/common/platform.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "compile_options.pb.h"

namespace iree::pjrt {

Expand Down Expand Up @@ -452,7 +453,7 @@ class ClientInstance {
// Compiles `program` into a loaded executable.
//
// `options` is the parsed xla::CompileOptionsProto supplied by the caller;
// its executable_build_options().device_assignment(), when present, is used
// to select the devices the resulting executable is loaded on.
// On success the new executable is returned through `executable`. Callers
// treat a non-null return value as an error.
// See TODOs in PJRT_Client_Compile.
// NOTE(review): `options` is taken by value, which copies the proto on each
// call; consider `const xla::CompileOptionsProto&` if the definition and all
// callers can be updated together.
PJRT_Error* Compile(
    const PJRT_Program* program, xla::CompileOptionsProto options,
    LoadedExecutableInstance** executable);

// ---------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion integrations/pjrt/src/iree_pjrt/vulkan/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ iree_status_t VulkanClientInstance::CreateDriver(
}

// Applies the default compiler flags for the Vulkan client by selecting the
// `vulkan-spirv` HAL target backend on `compiler_job`.
// Returns whatever CompilerJob::SetFlag reports (presumably true on success;
// confirm against CompilerJob).
bool VulkanClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) {
  return compiler_job->SetFlag("--iree-hal-target-backends=vulkan-spirv");
}

} // namespace iree::pjrt::vulkan
23 changes: 23 additions & 0 deletions integrations/pjrt/third_party/pjrt_c_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,26 @@ iree_cc_library(
"xla/pjrt/c/pjrt_c_api.h"
PUBLIC
)

# Protobuf is needed both to generate C++ code from the forked XLA .proto
# schemas and to link the generated code; REQUIRED aborts configuration when
# the package cannot be found.
find_package(Protobuf REQUIRED)

# Generate C++ sources and headers for the two schema files. The generated
# file lists are returned in PJRT_PROTO_SRCS / PJRT_PROTO_HDRS and are consumed
# by the `protos` library below. xla_data.proto is listed because
# compile_options.proto imports it.
protobuf_generate_cpp(
PJRT_PROTO_SRCS
PJRT_PROTO_HDRS
"xla/pjrt/compile_options.proto"
"xla/xla_data.proto"
)

# Library wrapping the generated protobuf code. It exposes the Protobuf include
# directories and links against the Protobuf libraries.
# NOTE(review): the generated *.pb.h files are presumably written to the
# current binary directory; confirm that directory is on the include path of
# targets that do `#include "compile_options.pb.h"`.
iree_cc_library(
NAME
protos
INCLUDES
${Protobuf_INCLUDE_DIRS}
HDRS
${PJRT_PROTO_HDRS}
SRCS
${PJRT_PROTO_SRCS}
DEPS
${Protobuf_LIBRARIES}
PUBLIC
)
3 changes: 2 additions & 1 deletion integrations/pjrt/third_party/pjrt_c_api/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pjrt_c_api

This directory contains a fork of C headers needed to build a PJRT plugin.
This directory contains a fork of C headers and .proto files
needed to build a PJRT plugin.

It is intended to be sync'd with upstream for major/breaking changes and
releases.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
syntax = "proto3";

package xla;

// TODO: To avoid pulling too many XLA source files into IREE PJRT, the
// fields in this file that are not currently in use have been commented out
// so that their message definitions are not required.
// To restore any of these fields, uncomment it and also add the
// corresponding schema file, as in the commented-out imports below.

// import "xla/stream_executor/device_description.proto";
// import "xla/xla.proto";
import "xla_data.proto";

// A serialization of xla::ExecutableBuildOptions.
//
// This is a trimmed copy of the upstream XLA message: fields whose message
// types are not vendored into this project are commented out below, and their
// field numbers are reserved so they cannot be reused by accident.
// Next id: 24.
message ExecutableBuildOptionsProto {
  // Field numbers of the upstream fields that are commented out in this fork
  // (debug_options = 3, comp_envs = 13). Reserving them keeps the wire format
  // compatible with upstream producers: a locally added field can never
  // collide with data written for these numbers. Remove a number from this
  // list when the corresponding field is restored.
  reserved 3, 13;

  // If set, this is the device to build the computation for. Valid
  // device_ordinal values are: 0 to # of devices - 1. These values are
  // identical to the device ordinal values used by StreamExecutor. The built
  // executable will be executable on any device equivalent to the specified
  // device as determined by Backend::devices_equivalent(). A value of -1
  // indicates this option has not been set.
  int64 device_ordinal = 1;

  // If set, this specifies the layout of the result of the computation. If not
  // set, the service will choose the layout of the result. A Shape is used to
  // store the layout to accommodate tuple result shapes. A value of nullptr
  // indicates the option has not been set.
  xla.ShapeProto result_layout = 2;

  // Expose access to the XLA compilation environments, which will be passed to
  // the compilation process.
  // xla.CompilationEnvironmentsProto comp_envs = 13;

  // Expose access to the XLA debug options which will be passed to the
  // compilation process.
  // xla.DebugOptions debug_options = 3;

  // The number of replicas of this computation that are to be executed.
  // Defaults to 1.
  int64 num_replicas = 4;

  // The number of partitions in this computation. Defaults to 1.
  int64 num_partitions = 5;

  // Indicates whether to use SPMD (true) or MPMD (false) partitioning when
  // num_partitions > 1 and XLA is requested to partition the input program.
  bool use_spmd_partitioning = 6;

  // Whether to automatically generate XLA shardings for SPMD partitioner.
  bool use_auto_spmd_partitioning = 7;

  // The amount of effort to spend on optimizing for minimizing program
  // execution time, as a value in [-1.0, +1.0]. The baseline is 0.0, which
  // strongly prioritizes execution time at the cost of longer compile times,
  // suitable for production workloads. A value of -0.5 would be appropriate for
  // research use cases that prefer faster compilations to iterate more quickly.
  // Positive values, on the other hand, might enable costly optimizations that
  // are off by default.
  float exec_time_optimization_effort = 20;

  // The amount of effort to spend on making the program fit in memory (where
  // "fit in memory" here has a backend-dependent meaning), as a value in
  // [-1.0,+1.0]. The baseline is 0.0, which expends significant effort on
  // attempting to make the program fit. A value of -1.0 would be appropriate
  // for use cases that wish to spend minimal effort here and fail as quickly as
  // possible instead. Positive values, on the other hand, might enable costly
  // algorithms to reduce memory usage that are off by default.
  float memory_fitting_effort = 21;

  // Whether HLOs should be deduplicated.
  bool deduplicate_hlo = 8;

  // If set, this specifies a static device assignment for the computation.
  // Otherwise, the computation will be compiled generically and can be run with
  // any device assignment compatible with the computation's replica and
  // partition counts.
  xla.DeviceAssignmentProto device_assignment = 9;

  // Whether input and output buffers are aliased if the associated parameter is
  // passed-through XLA modules without being changed.
  bool alias_passthrough_params = 10;

  // By default, XLA builds an executable by invoking standard compilation, i.e.
  // running Compiler::Compile, or both Compiler::RunHloPasses and
  // Compiler::RunBackend. When run_backend_only is set to true, XLA builds an
  // executable by invoking only RunBackend and skip invoking RunHloPasses,
  // which can be used to compile post-optimizations HLO modules.
  bool run_backend_only = 11;

  // Allows sharding propagation to propagate to the parameters. This changes
  // the input shape of the computation (which is undesirable), but it can be
  // used to allow to run partial compilation to determine what would be the
  // input sharding of a computation if XLA would be allowed to propagate the
  // sharding which can be used by higher level framework as a way to query
  // intermediate sharding of operations when multiple computation would be
  // chained and merged together.
  // This is a vector of bool, because the user can control which parameters can
  // have the sharding substituted. If only one boolean value is passed in the
  // vector that is interpreted as the value to be applied for every parameter.
  repeated bool allow_spmd_sharding_propagation_to_parameters = 18;

  // Allows sharding propagation to propagate to the outputs. This changes the
  // output shape of the computation (which is undesirable), but it can be used
  // to allow to run partial compilation to determine what would be the output
  // sharding of a computation if XLA would be allowed to propagate the sharding
  // which can be used by higher level framework as a way to query intermediate
  // sharding of operations when multiple computation would be chained and
  // merged together.
  // This is a vector of bool, because the user can control (if the output of
  // the computation is a tuple) which elements of the tuple can have the
  // sharding substituted and which don't. If only one boolean value is passed
  // in the vector that's interpreted as the value to be applied for every
  // single element of the output tuple. One value per element of the tuple
  // means that each value is attached to one of the output elements.
  repeated bool allow_spmd_sharding_propagation_to_output = 12;

  // Opaque profile data for any feedback directed optimizations.
  bytes fdo_profile = 14;

  int64 device_memory_size = 15;

  // Mesh shape in auto sharding options.
  repeated int64 auto_spmd_partitioning_mesh_shape = 16;

  // Mesh ids in auto sharding options.
  repeated int64 auto_spmd_partitioning_mesh_ids = 17;

  // Use Shardy, a new partitioner, to replace the existing
  // ShardingPropagation and SpmdPartitioner. See go/xla-sdy-pipeline for
  // details.
  bool use_shardy_partitioner = 19;

  int64 process_index = 22;
  int64 process_count = 23;
}

// A single override value that holds one of a string, bool, int64, or double.
// Used as the value type of CompileOptionsProto.env_option_overrides.
message OptionOverrideProto {
// The typed payload; as a oneof, at most one of these fields is set.
oneof value {
string string_field = 1;
bool bool_field = 2;
int64 int_field = 3;
double double_field = 4;
}
}

// Top-level compile options. This is a trimmed copy of the upstream XLA
// message: fields whose message types are not vendored into this project are
// commented out below, and their field numbers are reserved.
message CompileOptionsProto {
  // Field numbers of the upstream fields that are commented out in this fork
  // (argument_layouts = 1, target_config = 8). Reserving them prevents a
  // locally added field from reusing a number that upstream producers already
  // write. Remove a number from this list when the corresponding field is
  // restored.
  reserved 1, 8;

  // Refer CompileOptions for documentation of fields.
  // repeated ShapeProto argument_layouts = 1;
  bool parameter_is_tupled_arguments = 2;
  // Build options, including the optional static device assignment.
  ExecutableBuildOptionsProto executable_build_options = 3;
  bool compile_portable_executable = 4;
  int64 profile_version = 5;
  bytes serialized_multi_slice_config = 6;
  // Named option overrides; each value holds one typed payload.
  map<string, OptionOverrideProto> env_option_overrides = 7;

  // stream_executor.GpuTargetConfigProto target_config = 8;
}

// Helper for serializing opaque executables alongside CompileOptions.
message ExecutableAndOptionsProto {
// The executable, stored as opaque bytes.
bytes serialized_executable = 1;
// The compile options stored together with the executable.
CompileOptionsProto compile_options = 2;
}
Loading

0 comments on commit adf9e55

Please sign in to comment.