Skip to content

Commit

Permalink
Support for tt.bitcast and tt.int_to_ptr in pointer analysis and lowe…
Browse files Browse the repository at this point in the history
…ring (#232)

Add support for tt.bitcast and tt.int_to_ptr in pointer analysis and
lowering

## Summary
This diff adds support for tt.bitcast and tt.int_to_ptr in pointer
analysis and lowering. Kernels that does indirect memory accesses
(horizontal fusion, jagged tensor, etc.) could potentially use this
feature.

## Testing

Added Lit test

Co-authored-by: Jan Szczepaniec <[email protected]>
  • Loading branch information
Myrthan and Jan Szczepaniec authored Feb 25, 2025
1 parent b8ac4f5 commit ed25ab4
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 0 deletions.
12 changes: 12 additions & 0 deletions include/triton-shared/AnalysisStructured/PtrAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,18 @@ class PtrAnalysis {
PtrState &state, const Location loc,
OpBuilder &builder);

// Operand is the result of tt.int_to_ptr.
// Expected result:
// Directly grab op result
LogicalResult visitOperandIntToPtr(triton::IntToPtrOp intToPtrOp, PtrState &state,
const Location loc, OpBuilder &builder);

// Operand is the result of tt.bitcast.
// Expected result:
// Directly grab op result
LogicalResult visitOperandBitcast(triton::BitcastOp bitcastOp, PtrState &state,
const Location loc, OpBuilder &builder);

// Get the computed PtrState for the forOp's init-arg at the provided index.
FailureOr<PtrState> getLoopInitArgPtrState(scf::ForOp forOp, size_t index);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,11 @@ struct BitcastConverter : public OpConversionPattern<triton::BitcastOp> {
LogicalResult
matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// arith::bitcast does not support casting pointers
if (isa<triton::PointerType>(op.getSrc().getType())) {
return failure();
}

auto arithBitcast = rewriter.create<arith::BitcastOp>(
op.getLoc(), op.getType(), op.getOperand());

Expand Down
24 changes: 24 additions & 0 deletions lib/AnalysisStructured/PtrAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,26 @@ LogicalResult PtrAnalysis::visitOperandForOp(scf::ForOp forOp, Value operand,
return success();
}

LogicalResult PtrAnalysis::visitOperandIntToPtr(triton::IntToPtrOp op,
PtrState &state,
const Location loc,
OpBuilder &builder) {
state.source = op.getResult();
return success();
}

LogicalResult PtrAnalysis::visitOperandBitcast(triton::BitcastOp op,
PtrState &state,
const Location loc,
OpBuilder &builder) {
auto resType = op.getResult().getType();
if (isa<ShapedType>(resType)) {
return visitOperand(op.getSrc(), state, loc, builder);
}
state.source = op.getResult();
return success();
}

LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state,
const Location loc,
OpBuilder &builder) {
Expand Down Expand Up @@ -684,6 +704,10 @@ LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state,
if (auto addPtrOp = dyn_cast<triton::AddPtrOp>(op)) {
return visitOperandAddptr(cast<triton::AddPtrOp>(op), state, loc,
builder);
} else if (auto castOp = dyn_cast<triton::BitcastOp>(op)) {
return visitOperandBitcast(castOp, state, loc, builder);
} else if (auto intToPtrOp = dyn_cast<triton::IntToPtrOp>(op)) {
return visitOperandIntToPtr(intToPtrOp, state, loc, builder);
} else if (auto makeTensorOp = dyn_cast<triton::MakeTensorPtrOp>(op)) {
llvm_unreachable("Unexpected operand defining operation tts.make_tptr");
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ class TritonArithToLinalgPass

target.addLegalOp<triton::FuncOp, triton::ReturnOp>();

target.addDynamicallyLegalOp<triton::BitcastOp>([](triton::BitcastOp op) {
return isa<triton::PointerType>(op.getSrc().getType());
});

target.addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect>(
[](Operation *op) {
// Lower dense constant to linalg.fill
Expand Down
22 changes: 22 additions & 0 deletions test/Conversion/TritonToStructured/addptr_bitcast.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: triton-shared-opt --triton-to-structured %s | FileCheck %s

module {
tt.func @test(%arg0: !tt.ptr<i64>, %arg1: !tt.ptr<f32>) {
%0 = tt.load %arg0 : !tt.ptr<i64>
%1 = tt.int_to_ptr %0 : i64 -> !tt.ptr<f16>
%2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%3 = tt.splat %1 : !tt.ptr<f16> -> tensor<32x!tt.ptr<f16>>
%4 = tt.addptr %3, %2 : tensor<32x!tt.ptr<f16>>, tensor<32xi32>
%5 = tt.load %4 : tensor<32x!tt.ptr<f16>>
%6 = tt.bitcast %arg1 : !tt.ptr<f32> -> !tt.ptr<f16>
%7 = tt.splat %6 : !tt.ptr<f16> -> tensor<32x!tt.ptr<f16>>
%8 = tt.addptr %7, %2 : tensor<32x!tt.ptr<f16>>, tensor<32xi32>
tt.store %8, %5 : tensor<32x!tt.ptr<f16>>
tt.return
}
}

// CHECK: [[IN_SRC:%.+]] = tt.int_to_ptr
// CHECK: [[IN_PTR:%.+]] = tts.make_tptr [[IN_SRC]]
// CHECK: [[OUT_SRC:%.+]] = tt.bitcast
// CHECK: [[OUT_PTR:%.+]] = tts.make_tptr [[OUT_SRC]]

0 comments on commit ed25ab4

Please sign in to comment.