-
Notifications
You must be signed in to change notification settings - Fork 517
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example Torch MLIR LTC Backend (#725)
- Loading branch information
1 parent
a44fb17
commit eee5cb7
Showing
10 changed files
with
434 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -192,3 +192,4 @@ else() | |
endif() | ||
|
||
add_subdirectory(test) | ||
add_subdirectory(examples) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
add_subdirectory(ltc_backend) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
########################################################################### | ||
# Setup PyTorch | ||
########################################################################### | ||
|
||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules") | ||
include(TorchMLIRPyTorch) | ||
TorchMLIRProbeForPyTorchInstall() | ||
find_package(Torch 1.11 REQUIRED) | ||
|
||
TorchMLIRConfigurePyTorch() | ||
|
||
########################################################################### | ||
# Setup Python development | ||
########################################################################### | ||
|
||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/external/llvm-project/mlir/cmake/modules") | ||
include(MLIRDetectPythonEnv) | ||
mlir_configure_python_dev_packages() | ||
|
||
########################################################################### | ||
# Library definition | ||
########################################################################### | ||
|
||
include_directories(BEFORE | ||
${TORCH_INCLUDE_DIRS} | ||
${CMAKE_CURRENT_SOURCE_DIR} | ||
${CMAKE_CURRENT_BINARY_DIR} | ||
${Python3_INCLUDE_DIRS} | ||
${PYTHON_H_DIR} | ||
${PROJECT_SOURCE_DIR}/python | ||
) | ||
link_directories("${TORCH_INSTALL_PREFIX}/lib") | ||
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/lib) | ||
add_link_options(-Wl,-rpath,$ORIGIN/ltc_backend/lib) | ||
|
||
file(GLOB LTC_BACKEND_CSRC CONFIGURE_DEPENDS | ||
"ltc_backend/csrc/*.h" | ||
"ltc_backend/csrc/*.cc" | ||
"ltc_backend/csrc/*.cpp" | ||
"ltc_backend/csrc/*/*.h" | ||
"ltc_backend/csrc/*/*.cc" | ||
"ltc_backend/csrc/*/*.cpp" | ||
) | ||
add_library(example_mlir_ltc_backend SHARED ${LTC_BACKEND_CSRC}) | ||
add_dependencies(example_mlir_ltc_backend | ||
torch_mlir_ltc_backend | ||
) | ||
target_link_libraries(example_mlir_ltc_backend | ||
${TORCH_LIBRARIES} | ||
${Python3_LIBRARIES} | ||
torch_python | ||
torch_mlir_ltc_backend | ||
) | ||
|
||
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic") | ||
set_target_properties(example_mlir_ltc_backend PROPERTIES | ||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/" | ||
OUTPUT_NAME _EXAMPLE_MLIR_BACKEND | ||
PREFIX "${PYTHON_MODULE_PREFIX}" | ||
SUFFIX "${PYTHON_MODULE_EXTENSION}" | ||
CXX_VISIBILITY_PRESET "hidden" | ||
COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic" | ||
) |
Empty file.
140 changes: 140 additions & 0 deletions
140
examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
//===- backend_impl.cpp ---------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include <torch/csrc/lazy/backend/backend_data.h> | ||
#include <torch/csrc/lazy/backend/backend_device.h> | ||
#include <torch/csrc/lazy/backend/lowering_context.h> | ||
#include <torch/csrc/lazy/core/shape.h> | ||
|
||
#include <torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h> | ||
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h> | ||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h> | ||
#include <torch_mlir/csrc/utils/debug.h> | ||
#include <torch_mlir/csrc/utils/exception.h> | ||
|
||
#include "backend_impl.h" | ||
|
||
using namespace torch::lazy; | ||
|
||
namespace torch { | ||
namespace lazy { | ||
|
||
struct ExampleMlirBackendDeviceType : public BackendDeviceType { | ||
ExampleMlirBackendDeviceType(std::string device_type) | ||
: device_type_(device_type) {} | ||
|
||
std::string toString() const override { return device_type_; } | ||
|
||
std::string device_type_; | ||
}; | ||
|
||
class ExampleMlirBackendImpl : public torch::lazy::TorchMlirBackendImpl { | ||
public: | ||
ExampleMlirBackendImpl() : default_device_type_("Magic") {} | ||
|
||
/** | ||
* Configuration | ||
* */ | ||
void SetRngSeed(size_t seed) const override { | ||
std::cout << "RNG Seed Set to: " << seed << std::endl; | ||
} | ||
|
||
/** | ||
* Lowering, Compilation, Execution | ||
* */ | ||
std::vector<std::string> | ||
GetCompilationDevices(const std::string &device, | ||
c10::ArrayRef<std::string> devices) const override { | ||
return std::vector<std::string>(devices.begin(), devices.end()); | ||
}; | ||
|
||
std::vector<ComputationPtr> | ||
Compile(std::vector<ComputationPtr> instances) const override { | ||
PRINT_FUNCTION(); | ||
|
||
// Vendor backend specific lowering can be exec here before returning. | ||
for (const auto &instance : instances) { | ||
std::cout << "Instance received at Compile: \n" | ||
<< GetComputationBackendText(instance) << std::endl; | ||
} | ||
|
||
return instances; | ||
} | ||
|
||
std::vector<BackendDataPtr> | ||
ExecuteComputation(Computation &computation, | ||
c10::ArrayRef<BackendDataPtr> arguments, | ||
const BackendDevice &device) const override { | ||
PRINT_FUNCTION(); | ||
|
||
// `arguments` maps 1:1 with the parameters in the generated MLIR. In this | ||
// function, we will generate a list of BackendData that corresponds to the | ||
// return values in the MLIR. | ||
std::vector<torch::lazy::BackendDataPtr> results; | ||
|
||
// "Borrow" some tensor data from arguments to reuse in return. This ensures | ||
// that the tensor device is correctly configured. | ||
TORCH_CHECK(arguments.size() > 0, | ||
"Need at least one argument for example execution."); | ||
const TorchMlirBackendData *torch_mlir_data = | ||
dynamic_cast<const TorchMlirBackendData *>(arguments[0].get()); | ||
TORCH_CHECK(torch_mlir_data, | ||
"Invalid Backend Data Pointer. Expected TorchMlirBackendData."); | ||
|
||
// For this demo we aren't performing a legitimate execution, so we generate | ||
// some dummy data to return based on the expected number of return values. | ||
auto mlir_computation = static_cast<TorchMlirComputation *>(&computation); | ||
for (unsigned i = 0; i < mlir_computation->num_results(); i++) { | ||
results.push_back(std::make_shared<TorchMlirBackendData>( | ||
torch_mlir_data->mlir_info()->tensor, device, | ||
torch_mlir_data->shape())); | ||
} | ||
|
||
return results; | ||
} | ||
|
||
/** | ||
* Device Configuration | ||
* */ | ||
std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType() const { | ||
return std::make_shared<BackendDeviceType>(default_device_type_); | ||
} | ||
|
||
void SetDefaultDeviceType(std::string device_type) { | ||
default_device_type_ = ExampleMlirBackendDeviceType(device_type); | ||
} | ||
|
||
/** | ||
* Debug/Metrics | ||
* */ | ||
std::string | ||
GetComputationBackendText(const ComputationPtr computation) const override { | ||
auto mlir_computation = | ||
static_cast<TorchMlirComputation *>(computation.get()); | ||
return mlir_computation->to_string(); | ||
} | ||
|
||
private: | ||
ExampleMlirBackendDeviceType default_device_type_; | ||
}; | ||
|
||
BackendImplInterface *GetExampleMlirBackendImpl() { | ||
static ExampleMlirBackendImpl *example_mlir_backend_impl = | ||
new ExampleMlirBackendImpl(); | ||
return example_mlir_backend_impl; | ||
} | ||
|
||
void InitExampleMlirBackend() { | ||
at::RegisterTorchMlirLazyNativeFunctions(); | ||
static std::unique_ptr<BackendRegistrar> g_registrar; | ||
g_registrar.reset(new BackendRegistrar(GetExampleMlirBackendImpl())); | ||
} | ||
|
||
} // namespace lazy | ||
} // namespace torch |
27 changes: 27 additions & 0 deletions
27
examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
//===- backend_impl.h -----------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#pragma once | ||
|
||
#include <torch/csrc/lazy/backend/backend_interface.h> | ||
|
||
namespace at { | ||
// This function is defined in the codegenerated RegisterLazy.cpp file. | ||
TORCH_API void RegisterTorchMlirLazyNativeFunctions(); | ||
} // namespace at | ||
|
||
namespace torch { | ||
namespace lazy { | ||
|
||
torch::lazy::BackendImplInterface *GetExampleMlirBackendImpl(); | ||
|
||
void InitExampleMlirBackend(); | ||
|
||
} // namespace lazy | ||
} // namespace torch |
73 changes: 73 additions & 0 deletions
73
examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
//===- example_mlir_backend_pybind.cpp ------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "torch/csrc/jit/python/pybind.h" | ||
#include "torch/csrc/lazy/backend/backend_interface.h" | ||
|
||
#include <exception> | ||
#include <iostream> | ||
#include <string> | ||
|
||
#include "backend/backend_impl.h" | ||
#include "utils/sys_utils.h" | ||
|
||
namespace py = pybind11; | ||
|
||
namespace { | ||
bool verbose = sys_util::GetEnv("VERBOSE", false); | ||
|
||
struct NoGilSection { | ||
NoGilSection() : state(PyEval_SaveThread()) {} | ||
~NoGilSection() { PyEval_RestoreThread(state); } | ||
PyThreadState *state = nullptr; | ||
}; | ||
|
||
/** | ||
* @brief Install the plugin | ||
*/ | ||
void Initialize() { | ||
// Initialize the Example MLIR LTC Backend | ||
torch::lazy::InitExampleMlirBackend(); | ||
|
||
// sanity check | ||
const torch::lazy::BackendImplInterface *mlir_backend = | ||
torch::lazy::GetExampleMlirBackendImpl(); | ||
const torch::lazy::BackendImplInterface *lazy_backend = | ||
torch::lazy::getBackend(); | ||
if (lazy_backend != mlir_backend) { | ||
std::cout << "Failed to initialize MLIR Lazy Backend" << std::endl; | ||
throw std::runtime_error("Failed to initialize MLIR Lazy Backend"); | ||
} | ||
|
||
if (verbose) { | ||
std::cout << "MLIR LTC PyTorch Plugin Initialized." << std::endl; | ||
} | ||
} | ||
|
||
/** | ||
* @brief Uninstall the plugin | ||
*/ | ||
void Shutdown() { | ||
if (verbose) { | ||
std::cout << "MLIR LTC PyTorch Plugin Shut down." << std::endl; | ||
} | ||
} | ||
} // anonymous namespace | ||
|
||
PYBIND11_MODULE(_EXAMPLE_MLIR_BACKEND, m) { | ||
m.doc() = ("pybind11 for example MLIR LTC backend."); | ||
m.def("_initialize", []() { | ||
NoGilSection gil; | ||
Initialize(); | ||
}); | ||
m.def("_shutdown", []() { | ||
NoGilSection gil; | ||
Shutdown(); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
//===- sys_utils.h --------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#pragma once | ||
|
||
#include <cstdlib> | ||
#include <string> | ||
|
||
namespace sys_util { | ||
|
||
template <typename T> | ||
T GetEnv(const std::string &name, const T &default_value = T(0)) { | ||
const char *env = std::getenv(name.c_str()); | ||
if (!env) { | ||
return default_value; | ||
} | ||
return T(std::atoi(env)); | ||
} | ||
|
||
} // namespace sys_util |
Oops, something went wrong.