Skip to content

Commit

Permalink
[PJRT] Add support of passing per-compilation compile options (#19438)
Browse files Browse the repository at this point in the history
As discussed in
#19418 (comment),
#19418 (review)
and #19418 (comment),
here we support to read `env_option_overrides` as IREE compile flags
from `compile_options` passed by frontends like JAX in a per-compilation
basis.

Most of these code already exists but has been commented due to some
problems: `compile_options` was not yet available in that time, but it's
now introduced by #19369.

A simple use case is shown below, also as a test case:

https://github.com/iree-org/iree/blob/c37a80212dd4a541762fc9fdaaa615b6d0a62829/integrations/pjrt/test/test_compile_options.py#L9-L15

ci-exactly: build_packages, test_pjrt

---------

Signed-off-by: PragmaTwice <[email protected]>
Co-authored-by: Scott Todd <[email protected]>
  • Loading branch information
PragmaTwice and ScottTodd authored Dec 12, 2024
1 parent 6b686c7 commit 9b96886
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pkgci_test_pjrt.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ jobs:
source ${VENV_DIR}/bin/activate
python -m pip install -v --no-deps -e integrations/pjrt/python_packages/iree_${{ matrix.pjrt_platform }}_plugin
# install
python -m pip install jax==0.4.35
python -m pip install jax==0.4.36
- name: Run tests
run: |
source ${VENV_DIR}/bin/activate
Expand Down
7 changes: 7 additions & 0 deletions build_tools/testing/run_jax_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ diff_jax_test test/test_add.py
diff_jax_test test/test_degenerate.py
diff_jax_test test/test_simple.py

# here we test if the compile options is passed to IREE PJRT plugin successfully.
# we pass --iree-scheduling-dump-statistics-format=csv via jax.jit,
# and see if there's statistics in the output
compile_options_test_tmp_out=$(mktemp /tmp/jax_test_result_compile_options.XXXXXX)
JAX_PLATFORMS=$actual_jax_platform python test/test_compile_options.py 2>&1 | tee $compile_options_test_tmp_out
cat $compile_options_test_tmp_out | grep '@main_dispatch'


# FIXME: we can also utilize the native test cases from JAX,
# e.g. `tests/nn_test.py` from the JAX repo, as below,
Expand Down
1 change: 1 addition & 0 deletions integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ iree_cc_library(
iree::compiler::bindings::c::loader
iree_pjrt::partitioner_api
iree_pjrt::partitioner_api::loader
iree_pjrt_deps::protos
PUBLIC
)

Expand Down
8 changes: 4 additions & 4 deletions integrations/pjrt/src/iree_pjrt/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1487,8 +1487,8 @@ PJRT_Error* ClientInstance::Compile(const PJRT_Program* program,
}

// Set flags.
// TODO: Plumb CompileOptions through.
// if (!job->SetFlags(options)) return MakeCompilerError(*job);
if (!job->SetFlags(options)) return MakeCompilerError(*job);

if (artifact_tx) {
artifact_tx->WriteArtifact(
/*label=*/"partitioner_flags", /*extension=*/"txt", /*index=*/-1,
Expand Down Expand Up @@ -1538,8 +1538,8 @@ PJRT_Error* ClientInstance::Compile(const PJRT_Program* program,
if (!SetDefaultCompilerFlags(job.get())) {
return MakeCompilerError(*job);
}
// TODO: Plumb CompileOptions through.
// if (!job->SetFlags(options)) return MakeCompilerError(*job);
if (!job->SetFlags(options)) return MakeCompilerError(*job);

if (artifact_tx) {
artifact_tx->WriteArtifact(
/*label=*/"flags", /*extension=*/"txt", /*index=*/-1,
Expand Down
6 changes: 2 additions & 4 deletions integrations/pjrt/src/iree_pjrt/common/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
#include <string>

#include "iree_pjrt/common/debugging.h"
// TODO: Excise.
// #include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/compile_options.pb.h"

namespace iree::pjrt {

Expand All @@ -37,8 +36,7 @@ class CompilerJob {
// setup of a job (or if the underlying session will not be re-used).
// Returns false on failure.
virtual bool SetFlag(const char* flag) = 0;
// TODO: Excise.
// virtual bool SetFlags(xla::CompileOptions options) = 0;
virtual bool SetFlags(xla::CompileOptionsProto options) = 0;

// Gets all flags as a string. This is intended for debug printing a plausible
// command line to reproduce compilation.
Expand Down
73 changes: 38 additions & 35 deletions integrations/pjrt/src/iree_pjrt/common/hlo_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,41 +97,44 @@ class OpenXLAPartitionerJob : public CompilerJob {
return true;
}

// TODO: Find another way to deal with this.
// bool SetFlags(xla::CompileOptions options) override {
// int num_partitions = options.executable_build_options.num_partitions();
// int num_replicas = options.executable_build_options.num_replicas();
// bool use_spmd_partitioning =
// options.executable_build_options.use_spmd_partitioning();
// auto allow_spmd_sharding_propagation_to_output =
// options.executable_build_options
// .allow_spmd_sharding_propagation_to_output();
// if (!SetFlag(absl::StrCat("--openxla-partitioner-gspmd-num-partitions=",
// num_partitions)
// .c_str())) {
// return false;
// }
// if (!SetFlag(absl::StrCat("--openxla-partitioner-gspmd-replica-count=",
// num_replicas)
// .c_str())) {
// return false;
// }
// if (!SetFlag(
// absl::StrCat("--openxla-partitioner-gspmd-use-spmd-partitioning=",
// use_spmd_partitioning)
// .c_str())) {
// return false;
// }
// if (!SetFlag(
// absl::StrCat(
// "--openxla-partitioner-gspmd-allow-spmd-"
// "sharding-propagation-to-output=",
// absl::StrJoin(allow_spmd_sharding_propagation_to_output,
// ",")) .c_str())) {
// return false;
// }
// return true;
// }
bool SetFlags(xla::CompileOptionsProto options) override {
int num_partitions = options.executable_build_options().num_partitions();
int num_replicas = options.executable_build_options().num_replicas();
bool use_spmd_partitioning =
options.executable_build_options().use_spmd_partitioning();
auto allow_spmd_sharding_propagation_to_output =
options.executable_build_options()
.allow_spmd_sharding_propagation_to_output();
if (!SetFlag(("--openxla-partitioner-gspmd-num-partitions=" +
std::to_string(num_partitions))
.c_str())) {
return false;
}
if (!SetFlag(("--openxla-partitioner-gspmd-replica-count=" +
std::to_string(num_replicas))
.c_str())) {
return false;
}
if (!SetFlag(("--openxla-partitioner-gspmd-use-spmd-partitioning=" +
std::to_string(use_spmd_partitioning))
.c_str())) {
return false;
}
std::string allow_spmd_sharding_propagation_to_output_str;
for (size_t i = 0; i < allow_spmd_sharding_propagation_to_output.size();
++i) {
if (i != 0) allow_spmd_sharding_propagation_to_output_str += ",";
allow_spmd_sharding_propagation_to_output_str +=
std::to_string(allow_spmd_sharding_propagation_to_output[i]);
}
if (!SetFlag(("--openxla-partitioner-gspmd-allow-spmd-"
"sharding-propagation-to-output=" +
allow_spmd_sharding_propagation_to_output_str)
.c_str())) {
return false;
}
return true;
}

std::string GetFlags() override {
std::string flags;
Expand Down
45 changes: 22 additions & 23 deletions integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,29 +95,28 @@ class IREECompilerJob : public CompilerJob {
return true;
}

// TODO: Excise: Cannot dep on an internal XLA structure.
// bool SetFlags(xla::CompileOptions options) override {
// // Set extra options, overriding env variables if appropriate.
// for (auto [option, option_override] : options.env_option_overrides) {
// std::string override_string;
// if (auto override_val = std::get_if<std::string>(&option_override)) {
// override_string = *override_val;
// } else if (auto override_val = std::get_if<bool>(&option_override)) {
// override_string = *override_val ? "true" : "false";
// } else if (auto override_val = std::get_if<int64_t>(&option_override))
// {
// override_string = std::to_string(*override_val);
// } else {
// assert(false &&
// "option value should be of type string, bool, or int64");
// }
// if (!SetFlag(absl::StrCat("--", option, "=", override_string).c_str()))
// {
// return false;
// }
// }
// return true;
// }
bool SetFlags(xla::CompileOptionsProto options) override {
// Set extra options, overriding env variables if appropriate.
for (auto [option, option_override] : options.env_option_overrides()) {
std::string override_string;
if (option_override.has_string_field()) {
override_string = option_override.string_field();
} else if (option_override.has_bool_field()) {
override_string = option_override.bool_field() ? "true" : "false";
} else if (option_override.has_int_field()) {
override_string = std::to_string(option_override.int_field());
} else if (option_override.has_double_field()) {
override_string = std::to_string(option_override.double_field());
} else {
assert(false &&
"option value should be of type string, bool, int, or double");
}
if (!SetFlag(("--" + option + "=" + override_string).c_str())) {
return false;
}
}
return true;
}

std::string GetFlags() override {
std::string flags;
Expand Down
19 changes: 19 additions & 0 deletions integrations/pjrt/test/test_compile_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from functools import partial
import jax.numpy as jnp
from jax import jit

a = jnp.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9])


@partial(jit, compiler_options={"iree-scheduling-dump-statistics-format": "csv"})
def f(a, b):
return a + b


print(f(a, a))

0 comments on commit 9b96886

Please sign in to comment.