Integrate llvm-project and mlir-hlo. #2454

Merged: 9 commits, Sep 12, 2023
externals/llvm-project (2 changes: 1 addition & 1 deletion)
Submodule llvm-project updated 9061 files
externals/mlir-hlo (2 changes: 1 addition & 1 deletion)
Submodule mlir-hlo updated 3994 files
include/torch-mlir/Dialect/Torch/IR/TorchOps.h (1 change: 1 addition & 0 deletions)
@@ -10,6 +10,7 @@
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
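The added include mirrors an upstream change: tablegen-generated ops now implement BytecodeOpInterface, so the interface header has to be visible before the generated op declarations are pulled in. A minimal sketch of the required ordering, assuming the usual tablegen header layout (the GET_OP_CLASSES guard and the .h.inc name follow the standard MLIR convention and are not shown in this hunk):

// Sketch: BytecodeOpInterface.h must precede the generated declarations,
// because the generated op classes now derive from that interface.
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/OpDefinition.h"

#define GET_OP_CLASSES
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h.inc" // generated op classes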
lib/Conversion/TorchToTosa/TorchToTosa.cpp (6 changes: 3 additions & 3 deletions)
@@ -2346,7 +2346,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
op.getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(toReduceShape),
inputType.getElementType()),
- sumDiv, rewriter.getI64IntegerAttr(i));
+ sumDiv, rewriter.getI32IntegerAttr(i));
}

return rewriter.create<tosa::ReshapeOp>(
@@ -3214,7 +3214,7 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
prunedShape.push_back(en.value());
}

- auto dimAttr = rewriter.getIntegerAttr(rewriter.getI64Type(), dim);
+ auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim);
auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);

Value reduceMax = rewriter.create<tosa::ReduceMaxOp>(
@@ -4787,7 +4787,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);

auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
- rewriter, loc, outType, builtinTensors, rewriter.getI64IntegerAttr(dim));
+ rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim));
rewriter.replaceOp(op, result.getResult());
return success();
}
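All three hunks in this file make the same fix: with the new llvm-project pin, TOSA ops declare their axis/dim attribute as i32, so every call site switches from getI64IntegerAttr to getI32IntegerAttr. A minimal sketch of the post-bump convention; buildReduceSum is a hypothetical helper, not part of this patch:

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: after the bump, the reduction axis is an i32
// attribute, so it is built with getI32IntegerAttr; an i64 attribute in
// this position no longer matches the op definition.
static Value buildReduceSum(PatternRewriter &rewriter, Location loc,
                            RankedTensorType resultType, Value input,
                            int32_t axis) {
  return rewriter.create<tosa::ReduceSumOp>(
      loc, resultType, input, rewriter.getI32IntegerAttr(axis));
}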
lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp (6 changes: 3 additions & 3 deletions)
@@ -382,7 +382,7 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixReducesumShape,
indicesType.getElementType()),
- flattenedIndicesMulOp.getResult(), rewriter.getI64IntegerAttr(1));
+ flattenedIndicesMulOp.getResult(), rewriter.getI32IntegerAttr(1));

// And reshape to [N, W]
// %7 = "tosa.reshape"(%6) {new_shape = [1, 8]} : (tensor<8x1xi32>) ->
@@ -648,7 +648,7 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixReducesumShape,
indicesType.getElementType()),
- flattenedIndicesMulOp.getResult(), rewriter.getI64IntegerAttr(1));
+ flattenedIndicesMulOp.getResult(), rewriter.getI32IntegerAttr(1));

// And reshape to [N, W]
// [[1],[2],[3]] -> [[1,2,3]]
@@ -717,7 +717,7 @@ std::optional<Value> convertReduceOpCommon(
int64_t axis_val = axes_elems.getValues<IntegerAttr>()[i].getInt();
if (axis_val < 0)
axis_val += input_rank;
- auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
+ auto axis_attr = rewriter.getI32IntegerAttr(axis_val);

shape_vec[axis_val] = 1;
RankedTensorType reduce_type = RankedTensorType::get(
lib/Dialect/Torch/IR/CMakeLists.txt (3 changes: 3 additions & 0 deletions)
@@ -16,6 +16,9 @@ add_mlir_library(TorchMLIRTorchDialect
Core

LINK_LIBS PUBLIC
+ MLIRBytecodeOpInterface
+ MLIRBytecodeReader
+ MLIRBytecodeWriter
MLIRFuncDialect
MLIRIR
MLIRSupport
lib/Dialect/Torch/IR/TorchOps.cpp (53 changes: 28 additions & 25 deletions)
@@ -301,21 +301,20 @@ LogicalResult ClassTypeOp::verify() {
// PrimLoopOp
//===----------------------------------------------------------------------===//

- OperandRange
- PrimLoopOp::getEntrySuccessorOperands(std::optional<unsigned int> index) {
- assert(index.has_value() && index.value() == 0);
+ OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+ assert(point == getRegion());
return getIterArgsInit();
}

void PrimLoopOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
-
- if (!index.has_value()) {
- regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1));
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ Region &region = getRegion();
+ if (!point.getRegionOrNull()) {
+ regions.emplace_back(&region, region.getArguments().slice(1));
return;
}
- assert(*index == 0);
- regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1));
+ assert(point == region);
+ regions.emplace_back(&region, region.getArguments().slice(1));
regions.emplace_back(getResults());
}

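This hunk and the ones below are one mechanical migration: upstream replaced the std::optional<unsigned> region index in RegionBranchOpInterface with a RegionBranchPoint, which compares directly against a Region and carries a null region when control enters from the parent op. A compact sketch of the old-to-new mapping on a hypothetical single-region op (MyLoopOp is illustrative only, not part of this patch):

#include "mlir/Interfaces/ControlFlowInterfaces.h"

using namespace mlir;

void MyLoopOp::getSuccessorRegions(RegionBranchPoint point,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
  // Old: !index.has_value() meant "entering from the parent op".
  // New: the branch point carries a null region in that case.
  if (!point.getRegionOrNull()) {
    regions.emplace_back(&getRegion());
    return;
  }
  // Old: assert(*index == 0) checked the numeric region index.
  // New: the point is compared against the region itself.
  assert(point == getRegion());
  regions.emplace_back(getResults()); // branch back to the parent's results
}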
@@ -328,8 +327,8 @@ bool PrimLoopOp::isForLike() {
// PrimLoopConditionOp
//===----------------------------------------------------------------------===//

- MutableOperandRange PrimLoopConditionOp::getMutableSuccessorOperands(
- std::optional<unsigned> index) {
+ MutableOperandRange
+ PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
// Pass all operands except the condition to the successor which is the
// parent loop op.
return getIterArgsMutable();
@@ -378,10 +377,10 @@ void PrimIfOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs());
}

- void PrimIfOp::getSuccessorRegions(std::optional<unsigned> index,
+ void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
- if (index.has_value()) {
+ if (point.getRegionOrNull()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
@@ -1595,7 +1594,9 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
- auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
+ auto attr = properties.as<Properties *>()
+                 ->getValue()
+                 .dyn_cast_or_null<ElementsAttr>();
if (!attr)
return failure();
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
@@ -1635,7 +1636,9 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
- auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
+ auto attr = properties.as<Properties *>()
+                 ->getValue()
+                 .dyn_cast_or_null<ElementsAttr>();
if (!attr)
return failure();
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
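Both literal ops get the same fix: with the new pin, inherent attributes such as `value` live in a tablegen-generated Properties struct and reach inferReturnTypes through the OpaqueProperties argument rather than the DictionaryAttr. A minimal sketch of the lookup, assuming it runs in the context of ValueTensorLiteralOp::inferReturnTypes (getLiteralValue is a hypothetical wrapper, not part of this patch):

// Hypothetical wrapper around the lookup these two hunks switch to.
static ElementsAttr getLiteralValue(OpaqueProperties properties) {
  // Pre-bump: attributes.get("value").dyn_cast_or_null<ElementsAttr>();
  // Post-bump: cast the opaque pointer back to the op's generated
  // Properties struct and read the attribute through its accessor.
  return properties.as<ValueTensorLiteralOp::Properties *>()
      ->getValue()
      .dyn_cast_or_null<ElementsAttr>();
}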
@@ -2768,43 +2771,43 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {

template <typename CalculateOp>
static void
- getSuccessorRegionsForCalculateOp(CalculateOp op, std::optional<unsigned> index,
+ getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (!index.has_value()) {
+ if (!point.getRegionOrNull()) {
// First thing the op does is branch into the calculation.
regions.emplace_back(&op.getCalculation());
return;
}
- if (*index == 0) {
+ if (point == op.getBody()) {
// Body returns control to the outer op, passing through results.
regions.emplace_back(op.getResults());
return;
}
- assert(*index == 1);
+ assert(point == op.getCalculation());
// Calculation branches to the body.
regions.emplace_back(&op.getBody());
}

void ShapeCalculateOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
- getSuccessorRegionsForCalculateOp(*this, index, regions);
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ getSuccessorRegionsForCalculateOp(*this, point, regions);
}

//===----------------------------------------------------------------------===//
// DtypeCalculateOp
//===----------------------------------------------------------------------===//

void DtypeCalculateOp::getSuccessorRegions(
- std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
- getSuccessorRegionsForCalculateOp(*this, index, regions);
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ getSuccessorRegionsForCalculateOp(*this, point, regions);
}

//===----------------------------------------------------------------------===//
// ShapeCalculateYieldShapesOp
//===----------------------------------------------------------------------===//

MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
- std::optional<unsigned> index) {
+ RegionBranchPoint point) {
// The shape operands don't get forwarded to the body.
// MutableOperandRange always has an owning operation, even if empty, so
// create a 0-length range.
@@ -2823,7 +2826,7 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
//===----------------------------------------------------------------------===//

MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
- std::optional<unsigned> index) {
+ RegionBranchPoint point) {
// The dtype operands don't get forwarded to the body.
// MutableOperandRange always has an owning operation, even if empty, so
// create a 0-length range.