Skip to content

Commit

Permalink
[PJRT] Allow to pass extra compile options via env variables (#19418)
Browse files Browse the repository at this point in the history
Sometime it's useful to pass some extra IREE compiler options to the
PJRT plugin by environment variables to debug/do some
experiment/performance tuning without recompilation.

This is a rewrite to the following code which was commented as a TODO.


https://github.com/iree-org/iree/blob/b68c535ece28e139492606f391493f3e95242420/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc#L231-L245

ci-exactly: build_packages, test_pjrt

---------

Signed-off-by: PragmaTwice <[email protected]>
Co-authored-by: Scott Todd <[email protected]>
  • Loading branch information
PragmaTwice and ScottTodd authored Dec 13, 2024
1 parent 0cafee9 commit ffa0f42
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 15 deletions.
13 changes: 13 additions & 0 deletions integrations/pjrt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ pip install -v --no-deps -e python_packages/iree_cpu_plugin
JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
```

## Advanced settings

To pass additional compile options to IREE during JIT compilation, you can use
the `IREE_PJRT_IREE_COMPILER_OPTIONS` environment variable. This variable can
be set to a space-delimited list of flags that would be passed to the
`iree-compile` command-line tool.

For example:
```shell
export IREE_PJRT_IREE_COMPILER_OPTIONS=--iree-scheduling-dump-statistics-format=csv
JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
```

## Incrementally developing

If you did an editable install (`-e`) above, then you should be able to incrementally
Expand Down
13 changes: 13 additions & 0 deletions integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ iree_cc_library(
common
HDRS
"api_impl.h"
"command_line_utils.h"
"dylib_entry_point.cc.inc"
"iree_helpers.h"
"layout_utils.h"
"platform.h"
"tensor_utils.h"
SRCS
"api_impl.cc"
"command_line_utils.cc"
"layout_utils.cc"
"platform.cc"
"tensor_utils.cc"
Expand Down Expand Up @@ -60,6 +62,17 @@ iree_cc_library(
PUBLIC
)

iree_cc_test(
NAME
command_line_utils_test
SRCS
"command_line_utils_test.cc"
DEPS
::common
iree::testing::gtest
iree::testing::gtest_main
)

iree_cc_library(
NAME
debugging
Expand Down
54 changes: 54 additions & 0 deletions integrations/pjrt/src/iree_pjrt/common/command_line_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "command_line_utils.h"

namespace iree {
namespace pjrt {

// TODO: currently this function doesn't handle escape sequences,
// it just ensure that single/double quotes are interpreted corrently.
std::optional<std::vector<std::string>> ParseOptionsFromCommandLine(
std::string_view options_str) {
std::vector<std::string> options;
std::string current;

enum { NORMAL, SINGLE_QUOTE, DOUBLE_QUOTE } state = NORMAL;
for (auto it = options_str.begin(); it != options_str.end(); ++it) {
if (std::isspace(*it) && state == NORMAL) {
if (!current.empty()) {
options.push_back(std::move(current));
current.clear();
}
} else if (*it == '"' && state != SINGLE_QUOTE) {
if (state == NORMAL)
state = DOUBLE_QUOTE;
else if (state == DOUBLE_QUOTE)
state = NORMAL;
} else if (*it == '\'' && state != DOUBLE_QUOTE) {
if (state == NORMAL)
state = SINGLE_QUOTE;
else if (state == SINGLE_QUOTE)
state = NORMAL;
} else {
current.push_back(*it);
}
}

if (!current.empty()) {
options.push_back(std::move(current));
}

// if it's still in a quote, then return nullopt
if (state != NORMAL) {
return std::nullopt;
}

return options;
}

} // namespace pjrt
} // namespace iree
26 changes: 26 additions & 0 deletions integrations/pjrt/src/iree_pjrt/common/command_line_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_PJRT_PLUGIN_PJRT_COMMON_COMMAND_LINE_UTILS_H_
#define IREE_PJRT_PLUGIN_PJRT_COMMON_COMMAND_LINE_UTILS_H_

#include <optional>
#include <string>
#include <string_view>
#include <vector>

namespace iree {
namespace pjrt {

// parse command line options (maybe with quotes) to an array of options
// e.g. `a b "c d"` -> {"a", "b", "c d"}
std::optional<std::vector<std::string>> ParseOptionsFromCommandLine(
std::string_view options_str);

} // namespace pjrt
} // namespace iree

#endif
24 changes: 24 additions & 0 deletions integrations/pjrt/src/iree_pjrt/common/command_line_utils_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree_pjrt/common/command_line_utils.h"

#include <gtest/gtest.h>

using namespace iree::pjrt;

TEST(CommandLineUtils, ParseOptionsFromCommandLine) {
EXPECT_EQ(ParseOptionsFromCommandLine("--help --verbose"),
(std::vector<std::string>{"--help", "--verbose"}));
EXPECT_EQ(ParseOptionsFromCommandLine("-a='x y' -b \"n m\""),
(std::vector<std::string>{"-a=x y", "-b", "n m"}));
EXPECT_EQ(ParseOptionsFromCommandLine("'\"' \"'\""),
(std::vector<std::string>{"\"", "'"}));
EXPECT_EQ(ParseOptionsFromCommandLine("ab abc d 'e f g' h "),
(std::vector<std::string>{"ab", "abc", "d", "e f g", "h"}));
EXPECT_EQ(ParseOptionsFromCommandLine("a 'b"), std::nullopt);
EXPECT_EQ(ParseOptionsFromCommandLine("x\"y"), std::nullopt);
}
4 changes: 4 additions & 0 deletions integrations/pjrt/src/iree_pjrt/common/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,16 @@ class AbstractCompiler {
// An AbstractCompiler based on IREE.
class IREECompiler : public AbstractCompiler {
public:
IREECompiler(std::vector<std::string> extra_options = {})
: extra_options_(std::move(extra_options)) {}

std::unique_ptr<CompilerJob> StartJob() override;
std::string GetRevision() override;
std::string GetErrorMessage() override { return error_message_; }

private:
std::string error_message_;
std::vector<std::string> extra_options_;
};

// An AbstractCompiler based on the HLO partitioner.
Expand Down
11 changes: 10 additions & 1 deletion integrations/pjrt/src/iree_pjrt/common/dylib_platform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "iree/base/internal/path.h"
#include "iree/compiler/embedding_api.h"
#include "iree/compiler/loader.h"
#include "iree_pjrt/common/command_line_utils.h"
#include "iree_pjrt/partitioner_api/embedding_api.h"
#include "iree_pjrt/partitioner_api/loader.h"

Expand Down Expand Up @@ -98,7 +99,15 @@ iree_status_t DylibPlatform::SubclassInitialize() {
message.append(*loaded_compiler);
logger().debug(message);
}
compiler_ = std::make_unique<IREECompiler>();

std::vector<std::string> extra_compiler_options;
if (auto options_str = config_vars().Lookup("IREE_COMPILER_OPTIONS")) {
if (auto options = ParseOptionsFromCommandLine(*options_str)) {
extra_compiler_options = std::move(*options);
logger().debug("Extra compile options: " + *options_str);
}
}
compiler_ = std::make_unique<IREECompiler>(std::move(extra_compiler_options));
{
std::string message("Compiler Version: ");
message.append(compiler_->GetRevision());
Expand Down
20 changes: 6 additions & 14 deletions integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,20 +228,12 @@ std::unique_ptr<CompilerJob> IREECompiler::StartJob() {
}

// Propagate all options set via environment variable.
// TODO: Excise/translate to something that doesn't rely on LLVM.
// if (std::optional<std::string> env_value = llvm::sys::Process::GetEnv(
// llvm::StringRef("IREE_COMPILER_OPTIONS"))) {
// llvm::SmallVector<const char*, 20> new_argv;
// llvm::BumpPtrAllocator a;
// llvm::StringSaver saver(a);

// llvm::cl::TokenizeGNUCommandLine(*env_value, saver, new_argv);
// for (auto arg : new_argv)
// if (!job->SetFlag(arg)) {
// error_message_ = job->GetErrorMessage();
// return nullptr;
// }
// }
for (auto arg : extra_options_) {
if (!job->SetFlag(arg.c_str())) {
error_message_ = job->GetErrorMessage();
return nullptr;
}
}

return job;
}
Expand Down

0 comments on commit ffa0f42

Please sign in to comment.