From 4147676d82384c54bda1816682d671dafb36c6f3 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 6 Oct 2022 17:11:32 +0200 Subject: [PATCH] Support integers of arbitrary width and signedness as ntensor.getitem/setitem indices --- .../imex/Dialect/imex_util/ImexUtilOps.td | 4 +- .../ntensor/Transforms/ResolveArrayOps.cpp | 24 ++++++- .../Dialect/ntensor/resolve-array-ops.mlir | 62 +++++++++++++++++++ 3 files changed, 86 insertions(+), 4 deletions(-) diff --git a/mlir/include/imex/Dialect/imex_util/ImexUtilOps.td b/mlir/include/imex/Dialect/imex_util/ImexUtilOps.td index 7229ad90c..6724173eb 100644 --- a/mlir/include/imex/Dialect/imex_util/ImexUtilOps.td +++ b/mlir/include/imex/Dialect/imex_util/ImexUtilOps.td @@ -125,9 +125,11 @@ def ChangeLayoutOp : ImexUtil_Op<"change_layout", [ViewLikeOpInterface, NoSideEf def SignCastOp : ImexUtil_Op<"sign_cast", [NoSideEffect]> { let arguments = (ins AnyType : $value); - let results = (outs AnyType); + let results = (outs AnyType:$dest); let hasFolder = 1; let hasCanonicalizer = 1; + + let assemblyFormat = "$value attr-dict `:` type($value) `to` type($dest)"; } def ExtractMemrefMetadataOp diff --git a/mlir/lib/Dialect/ntensor/Transforms/ResolveArrayOps.cpp b/mlir/lib/Dialect/ntensor/Transforms/ResolveArrayOps.cpp index 93ebbd867..c2c83b348 100644 --- a/mlir/lib/Dialect/ntensor/Transforms/ResolveArrayOps.cpp +++ b/mlir/lib/Dialect/ntensor/Transforms/ResolveArrayOps.cpp @@ -22,7 +22,8 @@ #include static bool isIndexOrSlice(mlir::Type type) { - return type.isa(); + return type + .isa(); } static bool isValidGetitemIndex(mlir::Type type) { @@ -35,6 +36,23 @@ static bool isValidGetitemIndex(mlir::Type type) { return false; } +static mlir::Value convertIndex(mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value value) { + auto intType = value.getType().dyn_cast(); + if (intType) { + if (intType.getSignedness() != mlir::IntegerType::Signless) { + auto signlessType = + mlir::IntegerType::get(builder.getContext(), intType.getWidth()); + value = builder.create(loc, signlessType, value); + } + + auto indexType = builder.getIndexType(); + value = builder.create(loc, indexType, value); + } + + return value; +} + static mlir::LogicalResult computeIndices(mlir::OpBuilder &builder, mlir::Location loc, mlir::Value value, mlir::Value index, @@ -70,8 +88,8 @@ computeIndices(mlir::OpBuilder &builder, mlir::Location loc, mlir::Value value, auto size = resolved.getCount(); return {foldConst(begin), foldConst(size), foldConst(step), true}; } else { - mlir::Value index = - builder.create(loc, indexVal, len); + mlir::Value index = convertIndex(builder, loc, indexVal); + index = builder.create(loc, index, len); return {index, builder.getIndexAttr(1), builder.getIndexAttr(1), false}; } }; diff --git a/mlir/test/Dialect/ntensor/resolve-array-ops.mlir b/mlir/test/Dialect/ntensor/resolve-array-ops.mlir index d24b8f55d..412badc4b 100644 --- a/mlir/test/Dialect/ntensor/resolve-array-ops.mlir +++ b/mlir/test/Dialect/ntensor/resolve-array-ops.mlir @@ -14,6 +14,37 @@ func.func @test(%arg1: !ntensor.ntensor, %arg2: index) -> f32 { // ----- +func.func @test(%arg1: !ntensor.ntensor, %arg2: i32) -> f32 { + %0 = ntensor.getitem(%arg1 : !ntensor.ntensor) [%arg2 : i32] -> f32 + return %0 : f32 +} +// CHECK-LABEL: func @test +// CHECK-SAME: (%[[ARG1:.*]]: !ntensor.ntensor, %[[ARG2:.*]]: i32) +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[DIM:.*]] = ntensor.dim %[[ARG1]], %[[C0]] : !ntensor.ntensor +// CHECK-NEXT: %[[IDIM:.*]] = arith.index_cast %[[ARG2]] : i32 to index +// CHECK-NEXT: %[[IND:.*]] = ntensor.resolve_index %[[IDIM]], %[[DIM]] +// CHECK-NEXT: %[[RES:.*]] = ntensor.load %[[ARG1]][%[[IND]]] : !ntensor.ntensor +// CHECK-NEXT: return %[[RES]] : f32 + +// ----- + +func.func @test(%arg1: !ntensor.ntensor, %arg2: si32) -> f32 { + %0 = ntensor.getitem(%arg1 : !ntensor.ntensor) [%arg2 : si32] -> f32 + return %0 : f32 +} +// CHECK-LABEL: func @test +// CHECK-SAME: (%[[ARG1:.*]]: !ntensor.ntensor, %[[ARG2:.*]]: si32) +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[DIM:.*]] = ntensor.dim %[[ARG1]], %[[C0]] : !ntensor.ntensor +// CHECK-NEXT: %[[SDIM:.*]] = imex_util.sign_cast %[[ARG2]] : si32 to i32 +// CHECK-NEXT: %[[IDIM:.*]] = arith.index_cast %[[SDIM]] : i32 to index +// CHECK-NEXT: %[[IND:.*]] = ntensor.resolve_index %[[IDIM]], %[[DIM]] +// CHECK-NEXT: %[[RES:.*]] = ntensor.load %[[ARG1]][%[[IND]]] : !ntensor.ntensor +// CHECK-NEXT: return %[[RES]] : f32 + +// ----- + func.func @test(%arg1: !ntensor.ntensor, %arg2: index, %arg3: f32) { ntensor.setitem(%arg1 : !ntensor.ntensor) [%arg2 : index] = (%arg3 : f32) return @@ -28,6 +59,37 @@ func.func @test(%arg1: !ntensor.ntensor, %arg2: index, %arg3: f32) { // ----- +func.func @test(%arg1: !ntensor.ntensor, %arg2: i32, %arg3: f32) { + ntensor.setitem(%arg1 : !ntensor.ntensor) [%arg2 : i32] = (%arg3 : f32) + return +} +// CHECK-LABEL: func @test +// CHECK-SAME: (%[[ARG1:.*]]: !ntensor.ntensor, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: f32) +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[DIM:.*]] = ntensor.dim %[[ARG1]], %[[C0]] : !ntensor.ntensor +// CHECK-NEXT: %[[IDIM:.*]] = arith.index_cast %[[ARG2]] : i32 to index +// CHECK-NEXT: %[[IND:.*]] = ntensor.resolve_index %[[IDIM]], %[[DIM]] +// CHECK-NEXT: ntensor.store %[[ARG3]], %[[ARG1]][%[[IND]]] : !ntensor.ntensor +// CHECK-NEXT: return + +// ----- + +func.func @test(%arg1: !ntensor.ntensor, %arg2: si32, %arg3: f32) { + ntensor.setitem(%arg1 : !ntensor.ntensor) [%arg2 : si32] = (%arg3 : f32) + return +} +// CHECK-LABEL: func @test +// CHECK-SAME: (%[[ARG1:.*]]: !ntensor.ntensor, %[[ARG2:.*]]: si32, %[[ARG3:.*]]: f32) +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[DIM:.*]] = ntensor.dim %[[ARG1]], %[[C0]] : !ntensor.ntensor +// CHECK-NEXT: %[[SDIM:.*]] = imex_util.sign_cast %[[ARG2]] : si32 to i32 +// CHECK-NEXT: %[[IDIM:.*]] = arith.index_cast %[[SDIM]] : i32 to index +// CHECK-NEXT: %[[IND:.*]] = ntensor.resolve_index %[[IDIM]], %[[DIM]] +// CHECK-NEXT: ntensor.store %[[ARG3]], %[[ARG1]][%[[IND]]] : !ntensor.ntensor +// CHECK-NEXT: return + +// ----- + func.func @test(%arg1: !ntensor.ntensor, %arg2: !ntensor.slice) -> !ntensor.ntensor { %0 = ntensor.getitem(%arg1 : !ntensor.ntensor) [%arg2 : !ntensor.slice] -> !ntensor.ntensor return %0 : !ntensor.ntensor