Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove mlir-hlo (replace with stablehlo). #2460

Merged
merged 4 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[submodule "externals/llvm-project"]
path = externals/llvm-project
url = https://github.com/llvm/llvm-project.git
[submodule "externals/mlir-hlo"]
path = externals/mlir-hlo
url = https://github.com/tensorflow/mlir-hlo.git
[submodule "externals/stablehlo"]
path = externals/stablehlo
url = https://github.com/openxla/stablehlo.git
9 changes: 3 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,10 @@ endif()

if (TORCH_MLIR_ENABLE_STABLEHLO)
set(STABLEHLO_BUILD_EMBEDDED ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo
${CMAKE_CURRENT_BINARY_DIR}/stablehlo
EXCLUDE_FROM_ALL)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo)
endif()

set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
Expand Down
7 changes: 6 additions & 1 deletion docs/development.md
Original file line number Diff line number Diff line change
Expand Up @@ -408,13 +408,18 @@ Torch-MLIR by default builds with the latest nightly PyTorch version. This can b
# Updating the LLVM and MLIR-HLO submodules

Torch-MLIR depends on `llvm-project` (which contains, among other things,
upstream MLIR) and `mlir-hlo`, both of which are submodules in the `externals/`
upstream MLIR) and `stablehlo`, both of which are submodules in the `externals/`
directory. We aim to update these at least weekly to bring in the latest
features and spread out over time the effort of updating our code for MLIR API
breakages.

## Which LLVM commit should I pick?

NOTE: This section is in flux. Specifically, the `mlir-hlo` dep has been
dropped and the project is running off of a `stablehlo` fork which can be
patched for certain OS combinations. As of 2023-09-12, stellaraccident@
is massaging this situation. Please reach out for advice updating.

Since downstream projects may want to build Torch-MLIR (and thus LLVM and
MLIR-HLO) in various configurations (Release versus Debug builds; on Linux,
Windows, or macOS; possibly with Clang, LLD, and LLDB enabled), it is crucial to
Expand Down
8 changes: 1 addition & 7 deletions e2e_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
)
ashay marked this conversation as resolved.
Show resolved Hide resolved

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend

from .xfail_sets import (
Expand All @@ -44,15 +43,14 @@
register_all_tests()

def _get_argparse():
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"]
config_choices = ["native_torch", "torchscript", "linalg", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"]
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
parser.add_argument("-c", "--config",
choices=config_choices,
default="linalg",
help=f"""
Meaning of options:
"linalg": run through torch-mlir"s default Linalg-on-Tensors backend.
"stablehlo": run through torch-mlir"s default StableHLO backend.
"tosa": run through torch-mlir"s default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
Expand Down Expand Up @@ -100,10 +98,6 @@ def main():
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend(), use_make_fx=True)
xfail_set = all_test_unique_names - MAKE_FX_TOSA_PASS_SET
crashing_set = set()
elif args.config == "stablehlo":
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
crashing_set = STABLEHLO_CRASHING_SET
elif args.config == "native_torch":
config = NativeTorchTestConfig()
xfail_set = set()
Expand Down
1 change: 0 additions & 1 deletion externals/mlir-hlo
Submodule mlir-hlo deleted from 16886a
1 change: 1 addition & 0 deletions externals/stablehlo
Submodule stablehlo added at 77a598
8 changes: 0 additions & 8 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@ set(LinkedLibs
TorchMLIRRefBackend
)

if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs
MhloPasses
MhloToLinalg
StablehloToMhlo
)
endif()

add_mlir_library(TorchMLIRInitAll
InitAll.cpp

Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "transforms/passes.h"
#endif // TORCH_MLIR_ENABLE_STABLEHLO

#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
Expand Down
51 changes: 39 additions & 12 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "PopulatePatterns.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
Expand All @@ -25,7 +26,6 @@
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "utils/hlo_utils.h"
#include <iostream>
#include <numeric>

Expand All @@ -34,6 +34,34 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;

namespace {

template <typename T>
static Value getConstantLike(OpBuilder &b, Location loc, T constant,
Value val) {
Type ty = getElementTypeOrSelf(val.getType());
auto getAttr = [&]() -> Attribute {
if (ty.isa<mlir::IntegerType>())
return b.getIntegerAttr(ty, constant);
if (ty.isa<mlir::FloatType>())
return b.getFloatAttr(ty, constant);
if (auto complexTy = ty.dyn_cast<mlir::ComplexType>())
return complex::NumberAttr::get(complexTy, constant, 0);
llvm_unreachable("unhandled element type");
};
return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
val);
}

Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
Value val) {
Type ty = getElementTypeOrSelf(val.getType());
return b.create<mlir::chlo::ConstantLikeOp>(loc, b.getFloatAttr(ty, constant),
val);
}

} // namespace

LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
mlir::Value &self, mlir::Value &other,
size_t dimSizeIndexBits) {
Expand Down Expand Up @@ -836,7 +864,7 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
"for AtenReciprocalOp");
}

Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
Value oneTensor = getConstantLike(rewriter, op->getLoc(), 1, input);
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
return success();
}
Expand Down Expand Up @@ -945,7 +973,7 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
}

Value zeroTensor;
zeroTensor = chlo::getConstantLike(
zeroTensor = getConstantLike(
rewriter, op->getLoc(),
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
false),
Expand All @@ -967,9 +995,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
return op.emitError("only ranked tensor type is supported.");
}

Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, input);
Value half = chlo::getConstantLike(rewriter, loc, 0.5, input);
Value one = getConstantLike(rewriter, loc, 1.0, input);
Value two = getConstantLike(rewriter, loc, 2.0, input);
Value half = getConstantLike(rewriter, loc, 0.5, input);
auto rsqrtTwo = rewriter.create<mlir::stablehlo::RsqrtOp>(loc, two);
auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo);
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
Expand Down Expand Up @@ -1485,13 +1513,12 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
}
// Create constant value
Value kAlpha =
chlo::getConstantLike(rewriter, loc, 0.70710678118654752440, input);
Value kAlpha = getConstantLike(rewriter, loc, 0.70710678118654752440, input);
Value cstAlpha0 =
chlo::getConstantLike(rewriter, loc, 1.12837916709551257390, input);
Value half = chlo::getConstantLike(rewriter, loc, .5, input);
Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input);
getConstantLike(rewriter, loc, 1.12837916709551257390, input);
Value half = getConstantLike(rewriter, loc, .5, input);
Value one = getConstantLike(rewriter, loc, 1.0, input);
Value negHalf = getConstantLike(rewriter, loc, -0.5, input);

// Compute
Value kBeta0 =
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TorchToStablehlo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRBufferTransforms
MLIRComplexDialect
ChloOps
StablehloOps
TorchMLIRTorchDialect
TorchMLIRConversionUtils
Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ set(LinkedLibs
)

if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs ChloPasses)
list(APPEND LinkedLibs
StablehloOps
)
endif()

add_mlir_library(TorchMLIRTorchConversionPasses
Expand Down
12 changes: 0 additions & 12 deletions lib/InitAll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "torch-mlir/RefBackend/Passes.h"

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "mhlo/transforms/passes.h"
#endif

void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::func::FuncDialect>();
registry.insert<mlir::torch::Torch::TorchDialect>();
Expand All @@ -40,12 +36,4 @@ void mlir::torch::registerAllPasses() {
mlir::torch::registerConversionPasses();
mlir::torch::RefBackend::registerRefBackendPasses();
mlir::torch::TMTensor::registerPasses();

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::mhlo::registerSymbolicShapeOptimizationPass();
mlir::mhlo::registerStablehloLegalizeToHloPass();
mlir::mhlo::registerChloLegalizeToHloPass();
mlir::mhlo::registerHloLegalizeToLinalgPass();
mlir::mhlo::registerTestUnfuseBatchNormPass();
#endif // TORCH_MLIR_ENABLE_STABLEHLO
}
50 changes: 0 additions & 50 deletions python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py

This file was deleted.

8 changes: 0 additions & 8 deletions tools/torch-mlir-opt/torch-mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
#include "torch-mlir/InitAll.h"

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "mhlo/IR/hlo_ops.h"
#include "mhlo/transforms/passes.h"
#include "stablehlo/dialect/Register.h"
#endif

Expand All @@ -32,12 +30,6 @@ int main(int argc, char **argv) {

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::stablehlo::registerAllDialects(registry);
registry.insert<mlir::mhlo::MhloDialect>();
mlir::mhlo::registerSymbolicShapeOptimizationPass();
mlir::mhlo::registerStablehloLegalizeToHloPass();
mlir::mhlo::registerChloLegalizeToHloPass();
mlir::mhlo::registerHloLegalizeToLinalgPass();
mlir::mhlo::registerTestUnfuseBatchNormPass();
#endif
return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "MLIR modular optimizer driver\n", registry));
Expand Down
Loading