Integrate LLVM at llvm/llvm-project@e2402615a5a7
ghpvnist committed Jan 21, 2025
1 parent 23d7f60 commit d6682a0
Showing 12 changed files with 594 additions and 89 deletions.
WORKSPACE.bazel: 4 changes (2 additions & 2 deletions)
@@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "b270525f730be6e7196667925f5a9bfa153262e9"
LLVM_COMMIT = "e2402615a5a76d46a433dfcc1de10b38a1263c9d"

LLVM_SHA256 = "fcf77da395cd5097eac5951471b04aa887f565c3447545239421a0eb7089da7c"
LLVM_SHA256 = "9c22349e1d38555b2f223e49951655f60c04c0c3467e0150aaf6c9f50484cc9f"

http_archive(
name = "llvm-raw",
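Note that the two pins move in lockstep: LLVM_COMMIT selects the llvm-project revision that http_archive fetches, and LLVM_SHA256 is the checksum of that revision's source archive, so bumping the commit without the matching hash fails Bazel's integrity check.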
build_tools/llvm_version.txt: 2 changes (1 addition & 1 deletion)
@@ -1 +1 @@
-b270525f730be6e7196667925f5a9bfa153262e9
+e2402615a5a76d46a433dfcc1de10b38a1263c9d
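build_tools/llvm_version.txt mirrors the WORKSPACE pin for the non-Bazel build scripts, so an integrate updates both files to the identical commit hash.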
examples/c++/ExampleAdd.cpp: 2 changes (1 addition & 1 deletion)
@@ -49,7 +49,7 @@ int main() {
/** create function **/
// create function argument and result types.
auto tensorType =
-      mlir::RankedTensorType::get({3, 4}, mlir::FloatType::getF32(&context));
+      mlir::RankedTensorType::get({3, 4}, mlir::Float32Type::get(&context));
auto func_type =
mlir::FunctionType::get(&context, {tensorType, tensorType}, {tensorType});

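The ExampleAdd.cpp change tracks an upstream MLIR cleanup: builtin float types are now constructed through their concrete classes (mlir::Float32Type::get(&context)) instead of the mlir::FloatType::getF32(&context) helper. A minimal before/after sketch, illustrative rather than part of the commit:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext context;
  // Old spelling, no longer available at this LLVM revision:
  //   mlir::Type f32 = mlir::FloatType::getF32(&context);
  // New spelling:
  mlir::Type f32 = mlir::Float32Type::get(&context);
  // Downstream use is unchanged, e.g. building the 3x4 tensor type above.
  auto tensorType = mlir::RankedTensorType::get({3, 4}, f32);
  (void)tensorType;
  return 0;
}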
stablehlo/conversions/tosa/tests/nullary.mlir: 10 changes (6 additions & 4 deletions)
@@ -17,17 +17,19 @@ func.func @constant_f64() -> tensor<10xf64> {
// CHECK-LABEL: @iota_dimension_0
func.func @iota_dimension_0() -> tensor<4x8xf32> {
// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"()
-// CHECK-SAME{LITERAL}: <{value = dense<[[0.000000e+00], [1.000000e+00], [2.000000e+00], [3.000000e+00]]> : tensor<4x1xf32>}>
-// CHECK-DAG: %[[VAR1:.*]] = tosa.tile %[[VAR0]] {multiples = array<i64: 1, 8>}
+// CHECK-SAME{LITERAL}: <{value = dense<[[0.000000e+00], [1.000000e+00], [2.000000e+00], [3.000000e+00]]> : tensor<4x1xf32>}> : () -> tensor<4x1xf32>
+// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 8]> : vector<2xindex>} : () -> !tosa.shape<2>
+// CHECK-DAG: %[[VAR2:.*]] = tosa.tile %[[VAR0]], %[[VAR1]]
%0 = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> (tensor<4x8xf32>)
return %0 : tensor<4x8xf32>
}

// CHECK-LABEL: @iota_dimension_1
func.func @iota_dimension_1() -> tensor<4x8xi32> {
// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"()
-// CHECK-SAME{LITERAL}: <{value = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi32>}>
-// CHECK-DAG: %[[VAR1:.*]] = tosa.tile %[[VAR0]] {multiples = array<i64: 4, 1>}
+// CHECK-SAME{LITERAL}: <{value = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi32>}> : () -> tensor<1x8xi32>
+// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[4, 1]> : vector<2xindex>} : () -> !tosa.shape<2>
+// CHECK-DAG: %[[VAR2:.*]] = tosa.tile %[[VAR0]], %[[VAR1]]
%0 = "stablehlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<4x8xi32>)
return %0 : tensor<4x8xi32>
}
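Both updates reflect a TOSA dialect change at this LLVM revision: tosa.tile no longer takes its multiples as a {multiples = array<i64: ...>} attribute but as a second operand of type !tosa.shape<2>, produced by a tosa.const_shape op, so each test now expects that extra constant feeding the tile (see the C++ pattern change below).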
stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp
@@ -258,9 +258,13 @@ struct ConvertStablehloIotaOp : public OpRewritePattern<stablehlo::IotaOp> {
}
}

+    auto shapeType = rewriter.getType<tosa::shapeType>(tileMultiples.size());
+    auto shapedMultiples = rewriter.create<tosa::ConstShapeOp>(
+        op.getLoc(), shapeType, rewriter.getIndexVectorAttr(tileMultiples));
+
    // Tile the const array to the result shape of the iota op.
-    rewriter.replaceOpWithNewOp<tosa::TileOp>(
-        op, resultType, constOp, rewriter.getDenseI64ArrayAttr(tileMultiples));
+    rewriter.replaceOpWithNewOp<tosa::TileOp>(op, resultType, constOp,
+                                              shapedMultiples);
return success();
}
};
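The pattern change mirrors the new operand form: the tile multiples are materialized as a tosa.const_shape constant and passed to tosa.tile as a value. A self-contained helper showing the same construction (makeTileMultiples is a hypothetical name for illustration, assuming the TOSA dialect is registered):

#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"

// Materialize tile multiples as a !tosa.shape<N> SSA value, the form the
// updated TileOp builder expects at this LLVM revision.
mlir::Value makeTileMultiples(mlir::OpBuilder &builder, mlir::Location loc,
                              llvm::ArrayRef<int64_t> multiples) {
  // !tosa.shape<N>, where N matches the rank of the tiled tensor.
  auto shapeType =
      mlir::tosa::shapeType::get(builder.getContext(), multiples.size());
  // tosa.const_shape carries the multiples as an index-vector attribute.
  return builder.create<mlir::tosa::ConstShapeOp>(
      loc, shapeType, builder.getIndexVectorAttr(multiples));
}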
stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir: 65 changes (65 additions & 0 deletions)
@@ -3927,3 +3927,68 @@ func.func @square_f32(%arg : tensor<f32>) -> tensor<f32> {
%result = "chlo.square"(%arg) : (tensor<f32>) -> tensor<f32>
func.return %result : tensor<f32>
}

// -----

// CHECK-LABEL: @ragged_dot_mode_1
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<2x11x5xf32>, %[[ARG_1:.*]]: tensor<3x2x5x7xf32>, %[[ARG_2:.*]]: tensor<3xi64>) -> tensor<2x11x7xf32> {
// CHECK: %[[VAL_0:.*]] = stablehlo.iota dim = 1 : tensor<1x11x1xi64>
// CHECK: %[[VAL_C:.*]] = stablehlo.constant dense<0> : tensor<1xi64>
// CHECK: %[[VAL_CST:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2x11x7xf32>
// CHECK: %[[VAL_CST_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2x11x7xf32>
// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[ARG_2]] [0:1] : (tensor<3xi64>) -> tensor<1xi64>
// CHECK: %[[VAL_2:.*]] = stablehlo.broadcast_in_dim %[[VAL_C]], dims = [0] : (tensor<1xi64>) -> tensor<1x11x1xi64>
// CHECK: %[[VAL_3:.*]] = stablehlo.compare LE, %[[VAL_2]], %[[VAL_0]] : (tensor<1x11x1xi64>, tensor<1x11x1xi64>) -> tensor<1x11x1xi1>
// CHECK: %[[VAL_4:.*]] = stablehlo.add %[[VAL_C]], %[[VAL_1]] : tensor<1xi64>
// CHECK: %[[VAL_5:.*]] = stablehlo.broadcast_in_dim %[[VAL_4]], dims = [0] : (tensor<1xi64>) -> tensor<1x11x1xi64>
// CHECK: %[[VAL_6:.*]] = stablehlo.compare LT, %[[VAL_0]], %[[VAL_5]] : (tensor<1x11x1xi64>, tensor<1x11x1xi64>) -> tensor<1x11x1xi1>
// CHECK: %[[VAL_7:.*]] = stablehlo.and %[[VAL_3]], %[[VAL_6]] : tensor<1x11x1xi1>
// CHECK: %[[VAL_8:.*]] = stablehlo.broadcast_in_dim %[[VAL_7]], dims = [0, 1, 2] : (tensor<1x11x1xi1>) -> tensor<2x11x7xi1>
// CHECK: %[[VAL_9:.*]] = stablehlo.slice %[[ARG_1]] [0:1, 0:2, 0:5, 0:7] : (tensor<3x2x5x7xf32>) -> tensor<1x2x5x7xf32>
// CHECK: %[[VAL_10:.*]] = stablehlo.reshape %[[VAL_9]] : (tensor<1x2x5x7xf32>) -> tensor<2x5x7xf32>
// CHECK: %[[VAL_11:.*]] = stablehlo.dot_general %[[ARG_0]], %[[VAL_10]], batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x11x5xf32>, tensor<2x5x7xf32>) -> tensor<2x11x7xf32>
// CHECK: %[[VAL_12:.*]] = stablehlo.select %[[VAL_8]], %[[VAL_11]], %[[VAL_CST_0]] : tensor<2x11x7xi1>, tensor<2x11x7xf32>
// CHECK: %[[VAL_13:.*]] = stablehlo.add %[[VAL_CST]], %[[VAL_12]] : tensor<2x11x7xf32>
// CHECK: %[[VAL_14:.*]] = stablehlo.add %[[VAL_C]], %[[VAL_1]] : tensor<1xi64>
// CHECK: %[[VAL_15:.*]] = stablehlo.slice %[[ARG_2]] [1:2] : (tensor<3xi64>) -> tensor<1xi64>
// CHECK: %[[VAL_16:.*]] = stablehlo.broadcast_in_dim %[[VAL_14]], dims = [0] : (tensor<1xi64>) -> tensor<1x11x1xi64>
// CHECK: %[[VAL_17:.*]] = stablehlo.compare LE, %[[VAL_16]], %[[VAL_0]] : (tensor<1x11x1xi64>, tensor<1x11x1xi64>) -> tensor<1x11x1xi1>
// CHECK: %[[VAL_18:.*]] = stablehlo.add %[[VAL_14]], %[[VAL_15]] : tensor<1xi64>
// CHECK: %[[VAL_19:.*]] = stablehlo.broadcast_in_dim %[[VAL_18]], dims = [0] : (tensor<1xi64>) -> tensor<1x11x1xi64>
// CHECK: %[[VAL_20:.*]] = stablehlo.compare LT, %[[VAL_0]], %[[VAL_19]] : (tensor<1x11x1xi64>, tensor<1x11x1xi64>) -> tensor<1x11x1xi1>
// CHECK: %[[VAL_21:.*]] = stablehlo.and %[[VAL_17]], %[[VAL_20]] : tensor<1x11x1xi1>
// CHECK: %[[VAL_22:.*]] = stablehlo.broadcast_in_dim %[[VAL_21]], dims = [0, 1, 2] : (tensor<1x11x1xi1>) -> tensor<2x11x7xi1>
// CHECK: %[[VAL_23:.*]] = stablehlo.slice %[[ARG_1]] [1:2, 0:2, 0:5, 0:7] : (tensor<3x2x5x7xf32>) -> tensor<1x2x5x7xf32>
// CHECK: %[[VAL_24:.*]] = stablehlo.reshape %[[VAL_23]] : (tensor<1x2x5x7xf32>) -> tensor<2x5x7xf32>
// CHECK: %[[VAL_25:.*]] = stablehlo.dot_general %[[ARG_0]], %[[VAL_24]], batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x11x5xf32>, tensor<2x5x7xf32>) -> tensor<2x11x7xf32>
// CHECK: %[[VAL_26:.*]] = stablehlo.select %[[VAL_22]], %[[VAL_25]], %[[VAL_CST_0]] : tensor<2x11x7xi1>, tensor<2x11x7xf32>
// CHECK: %[[VAL_27:.*]] = stablehlo.add %[[VAL_13]], %[[VAL_26]] : tensor<2x11x7xf32>
// CHECK: %[[VAL_28:.*]] = stablehlo.add %[[VAL_14]], %[[VAL_15]] : tensor<1xi64>
// CHECK: %[[VAL_29:.*]] = stablehlo.slice %[[ARG_2]] [2:3] : (tensor<3xi64>) -> tensor<1xi64>
// CHECK: %[[VAL_30:.*]] = stablehlo.broadcast_in_dim %[[VAL_28]], dims = [0] : (tensor<1xi64>) -> tensor<1x11x1xi64>
// CHECK: %[[VAL_31:.*]] = stablehlo.compare LE, %[[VAL_30]], %[[VAL_0]] : (tensor<1x11x1xi64>, tensor<1x11x1xi64>) -> tensor<1x11x1xi1>
// CHECK: %[[VAL_32:.*]] = stablehlo.add %[[VAL_28]], %[[VAL_29]] : tensor<1xi64>
// CHECK: %[[VAL_33:.*]] = stablehlo.broadcast_in_dim %[[VAL_32]], dims = [0] : (tensor<1xi64>) -> tensor<1x11x1xi64>
// CHECK: %[[VAL_34:.*]] = stablehlo.compare LT, %[[VAL_0]], %[[VAL_33]] : (tensor<1x11x1xi64>, tensor<1x11x1xi64>) -> tensor<1x11x1xi1>
// CHECK: %[[VAL_35:.*]] = stablehlo.and %[[VAL_31]], %[[VAL_34]] : tensor<1x11x1xi1>
// CHECK: %[[VAL_36:.*]] = stablehlo.broadcast_in_dim %[[VAL_35]], dims = [0, 1, 2] : (tensor<1x11x1xi1>) -> tensor<2x11x7xi1>
// CHECK: %[[VAL_37:.*]] = stablehlo.slice %[[ARG_1]] [2:3, 0:2, 0:5, 0:7] : (tensor<3x2x5x7xf32>) -> tensor<1x2x5x7xf32>
// CHECK: %[[VAL_38:.*]] = stablehlo.reshape %[[VAL_37]] : (tensor<1x2x5x7xf32>) -> tensor<2x5x7xf32>
// CHECK: %[[VAL_39:.*]] = stablehlo.dot_general %[[ARG_0]], %[[VAL_38]], batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x11x5xf32>, tensor<2x5x7xf32>) -> tensor<2x11x7xf32>
// CHECK: %[[VAL_40:.*]] = stablehlo.select %[[VAL_36]], %[[VAL_39]], %[[VAL_CST_0]] : tensor<2x11x7xi1>, tensor<2x11x7xf32>
// CHECK: %[[VAL_41:.*]] = stablehlo.add %[[VAL_27]], %[[VAL_40]] : tensor<2x11x7xf32>
// CHECK: %[[VAL_42:.*]] = stablehlo.add %[[VAL_28]], %[[VAL_29]] : tensor<1xi64>
func.func @ragged_dot_mode_1(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> {
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [1],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [2],
lhs_ragged_dimensions = [1],
rhs_group_dimensions = [0]
>,
precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
} : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32>
func.return %0 : tensor<2x11x7xf32>
}
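The new test pins down the unrolled lowering of chlo.ragged_dot in mode 1 (lhs dimension 1 is ragged, rhs carries a leading group dimension). The lowering keeps a running offset into the ragged dimension and, for each of the three groups, builds a mask over the 11 ragged rows checking offset <= iota < offset + group_size (one iota, two compares, an and), slices that group out of rhs and reshapes it to 2x5x7, computes a dense dot_general against the full lhs, zeroes the out-of-group rows with a select, and adds the result into the accumulator. For example, if group_sizes held [4, 5, 2] (hypothetical values; the test leaves them dynamic), rows 0-3 would come from group 0's product, rows 4-8 from group 1's, and rows 9-10 from group 2's.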