Skip to content

Commit

Permalink
Integrate llvm-project and mlir-hlo. (#2454)
Browse files Browse the repository at this point in the history
Corresponding commits:

* mlir-hlo: 16886a108eff5197f816ca0f1950cc5ff1b078d9
* stablehlo: 77a59815a82b34f7b08ed2d42a711d9920682d0e
* llvm-project: 4acc3ff

* Adapt to ByteCodeOpInterface changes.
* Adapt to RegionBranchPoint changes: https://reviews.llvm.org/D159116
* Adapt inferReturnTypes to get the value from properties.
* Adapt invalid.mlir to properties syntax
* [TOSA] Align with custom assembly format change.
* [TOSA] handle change of axis to int32 type
* [TOSA] Restore improper convert to i32

Landing with Windows broken (it cannot be fixed because of the way the mlir-hlo dep is inserted). Will followup with an untangling.
---------

Co-authored-by: TatWai Chong <[email protected]>
Co-authored-by: Eric Kunze <[email protected]>
  • Loading branch information
3 people authored Sep 12, 2023
1 parent 106b585 commit a00a0d4
Show file tree
Hide file tree
Showing 11 changed files with 199 additions and 192 deletions.
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 9061 files
2 changes: 1 addition & 1 deletion externals/mlir-hlo
Submodule mlir-hlo updated 3994 files
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
Expand Down
6 changes: 3 additions & 3 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2346,7 +2346,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
op.getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(toReduceShape),
inputType.getElementType()),
sumDiv, rewriter.getI64IntegerAttr(i));
sumDiv, rewriter.getI32IntegerAttr(i));
}

return rewriter.create<tosa::ReshapeOp>(
Expand Down Expand Up @@ -3214,7 +3214,7 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
prunedShape.push_back(en.value());
}

auto dimAttr = rewriter.getIntegerAttr(rewriter.getI64Type(), dim);
auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim);
auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);

Value reduceMax = rewriter.create<tosa::ReduceMaxOp>(
Expand Down Expand Up @@ -4787,7 +4787,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);

auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
rewriter, loc, outType, builtinTensors, rewriter.getI64IntegerAttr(dim));
rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim));
rewriter.replaceOp(op, result.getResult());
return success();
}
Expand Down
6 changes: 3 additions & 3 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixReducesumShape,
indicesType.getElementType()),
flattenedIndicesMulOp.getResult(), rewriter.getI64IntegerAttr(1));
flattenedIndicesMulOp.getResult(), rewriter.getI32IntegerAttr(1));

// And reshape to [N, W]
// %7 = "tosa.reshape"(%6) {new_shape = [1, 8]} : (tensor<8x1xi32>) ->
Expand Down Expand Up @@ -648,7 +648,7 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixReducesumShape,
indicesType.getElementType()),
flattenedIndicesMulOp.getResult(), rewriter.getI64IntegerAttr(1));
flattenedIndicesMulOp.getResult(), rewriter.getI32IntegerAttr(1));

// And reshape to [N, W]
// [[1],[2],[3]] -> [[1,2,3]]
Expand Down Expand Up @@ -717,7 +717,7 @@ std::optional<Value> convertReduceOpCommon(
int64_t axis_val = axes_elems.getValues<IntegerAttr>()[i].getInt();
if (axis_val < 0)
axis_val += input_rank;
auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
auto axis_attr = rewriter.getI32IntegerAttr(axis_val);

shape_vec[axis_val] = 1;
RankedTensorType reduce_type = RankedTensorType::get(
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Torch/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ add_mlir_library(TorchMLIRTorchDialect
Core

LINK_LIBS PUBLIC
MLIRBytecodeOpInterface
MLIRBytecodeReader
MLIRBytecodeWriter
MLIRFuncDialect
MLIRIR
MLIRSupport
Expand Down
53 changes: 28 additions & 25 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,21 +301,20 @@ LogicalResult ClassTypeOp::verify() {
// PrimLoopOp
//===----------------------------------------------------------------------===//

OperandRange
PrimLoopOp::getEntrySuccessorOperands(std::optional<unsigned int> index) {
assert(index.has_value() && index.value() == 0);
OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(point == getRegion());
return getIterArgsInit();
}

void PrimLoopOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {

if (!index.has_value()) {
regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1));
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
Region &region = getRegion();
if (!point.getRegionOrNull()) {
regions.emplace_back(&region, region.getArguments().slice(1));
return;
}
assert(*index == 0);
regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1));
assert(point == region);
regions.emplace_back(&region, region.getArguments().slice(1));
regions.emplace_back(getResults());
}

Expand All @@ -328,8 +327,8 @@ bool PrimLoopOp::isForLike() {
// PrimLoopConditionOp
//===----------------------------------------------------------------------===//

MutableOperandRange PrimLoopConditionOp::getMutableSuccessorOperands(
std::optional<unsigned> index) {
MutableOperandRange
PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
// Pass all operands except the condition to the successor which is the
// parent loop op.
return getIterArgsMutable();
Expand Down Expand Up @@ -378,10 +377,10 @@ void PrimIfOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs());
}

void PrimIfOp::getSuccessorRegions(std::optional<unsigned> index,
void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (index.has_value()) {
if (point.getRegionOrNull()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
Expand Down Expand Up @@ -1595,7 +1594,9 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
auto attr = properties.as<Properties *>()
->getValue()
.dyn_cast_or_null<ElementsAttr>();
if (!attr)
return failure();
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
Expand Down Expand Up @@ -1635,7 +1636,9 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
auto attr = properties.as<Properties *>()
->getValue()
.dyn_cast_or_null<ElementsAttr>();
if (!attr)
return failure();
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
Expand Down Expand Up @@ -2768,43 +2771,43 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {

template <typename CalculateOp>
static void
getSuccessorRegionsForCalculateOp(CalculateOp op, std::optional<unsigned> index,
getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
if (!index.has_value()) {
if (!point.getRegionOrNull()) {
// First thing the op does is branch into the calculation.
regions.emplace_back(&op.getCalculation());
return;
}
if (*index == 0) {
if (point == op.getBody()) {
// Body returns control to the outer op, passing through results.
regions.emplace_back(op.getResults());
return;
}
assert(*index == 1);
assert(point == op.getCalculation());
// Calculation branches to the body.
regions.emplace_back(&op.getBody());
}

void ShapeCalculateOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
getSuccessorRegionsForCalculateOp(*this, index, regions);
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
getSuccessorRegionsForCalculateOp(*this, point, regions);
}

//===----------------------------------------------------------------------===//
// DtypeCalculateOp
//===----------------------------------------------------------------------===//

void DtypeCalculateOp::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
getSuccessorRegionsForCalculateOp(*this, index, regions);
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
getSuccessorRegionsForCalculateOp(*this, point, regions);
}

//===----------------------------------------------------------------------===//
// ShapeCalculateYieldShapesOp
//===----------------------------------------------------------------------===//

MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
std::optional<unsigned> index) {
RegionBranchPoint point) {
// The shape operands don't get forwarded to the body.
// MutableOperandRange always has an owning operation, even if empty, so
// create a 0-length range.
Expand All @@ -2823,7 +2826,7 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
//===----------------------------------------------------------------------===//

MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
std::optional<unsigned> index) {
RegionBranchPoint point) {
// The dtype operands don't get forwarded to the body.
// MutableOperandRange always has an owning operation, even if empty, so
// create a 0-length range.
Expand Down
Loading

0 comments on commit a00a0d4

Please sign in to comment.