[LLVMGPUVectorDistribute] Add general support for statically tiled codegen on dynamic shapes #19992

Closed
@@ -20,7 +20,7 @@
static llvm::cl::opt<bool> clEnableBlockedMatmuls(
    "iree-codegen-block-dynamic-dimensions-of-contractions",
-   llvm::cl::desc("developer flag to gaurd blocking dynamic dimensions of "
+   llvm::cl::desc("developer flag to guard blocking dynamic dimensions of "
                   "contraction-like ops"),
    llvm::cl::Hidden, llvm::cl::init(true));
@@ -315,6 +315,13 @@ void BlockDynamicDimensionsPass::runOnOperation() {
  IRRewriter rewriter(context);
  auto walkResult = operation->walk([&](Operation *op) -> WalkResult {
    rewriter.setInsertionPoint(op);
+   // If a lowering config is set, changing the dimensionality of the op
+   // will break the mapping. Therefore, skip operations that have a
+   // lowering config set.
+   if (op->hasAttrOfType<IREE::Codegen::LoweringConfigAttrInterface>(
+           "lowering_config")) {
+     return success();
+   }
Review comment on lines +318 to +324 (Contributor):
I'm guessing this was used while debugging. We should remove this.

Reply (Author):
No. The BlockDynamicDimensions pass deletes the lowering config if it changes the linalg op. Also, the lowering config does not make sense after the dimensionality change.

return blockDynamicDimensions(rewriter, dynamicDimAnalysis, op);
});
if (walkResult.wasInterrupted()) {
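To make the new early-exit concrete, here is a minimal hypothetical example (not from the PR; shapes, values, and the config payload are made up) of an op the walk would now skip, since it already carries a lowering_config:

// Hypothetical: this generic already has a lowering_config attached, so
// BlockDynamicDimensions returns success() without blocking its dynamic dim.
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @already_configured(%in: tensor<?x128xf32>, %out: tensor<?x128xf32>) -> tensor<?x128xf32> {
  %0 = linalg.generic
         {indexing_maps = [#map, #map],
          iterator_types = ["parallel", "parallel"]}
         ins(%in : tensor<?x128xf32>) outs(%out : tensor<?x128xf32>)
         attrs = {lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 128]}>} {
  ^bb0(%a: f32, %b: f32):
    %e = math.exp %a : f32
    linalg.yield %e : f32
  } -> tensor<?x128xf32>
  return %0 : tensor<?x128xf32>
}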
compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp (6 additions, 2 deletions)

@@ -227,7 +227,8 @@ static void tileAndBufferize(OpPassManager &funcPassManager) {
}

static void addGPUVectorizationPasses(OpPassManager &funcPassManager,
-                                     bool vectorizeCopies = true) {
+                                     bool vectorizeCopies = true,
+                                     bool enableMasking = false) {
funcPassManager.addPass(createDecomposeConvolutionToLowerDimOpsPass());
funcPassManager.addPass(IREE::LinalgExt::createDecomposeIm2colPass());
funcPassManager.addPass(
@@ -239,8 +240,10 @@ static void addGPUVectorizationPasses(OpPassManager &funcPassManager,
  options.vectorizeGatherAccesses = true;
  options.enableCleanup = false;
  options.foldCastIntoContract = true;
+ options.enableVectorMasking = enableMasking;
  funcPassManager.addPass(createGenericVectorizationPass(options));
  funcPassManager.addPass(createCanonicalizerPass());
+ funcPassManager.addPass(memref::createResolveShapedTypeResultDimsPass());
  funcPassManager.addPass(createCSEPass());
  // Run subset hoisting to convert iter_args to vectors.
  funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());
@@ -866,7 +869,8 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());

// Linalg -> Vector
- addGPUVectorizationPasses(funcPassManager);
+ addGPUVectorizationPasses(funcPassManager, /*vectorizeCopies=*/true,
+                           /*enableMasking=*/true);

// Allocate tensors for copies to shared memory.
funcPassManager.addPass(createGPUVectorAllocPass());
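For reference, a hedged sketch of the kind of IR GenericVectorization emits when masking is enabled (hypothetical shapes and SSA names, not taken from this PR): the mask is built from the runtime extent, and the transfer itself stays fully in bounds.

// Hypothetical masked vectorization result: lanes past the dynamic extent
// %d1 are masked off and read as the padding value.
func.func @masked_read(%src: memref<1x?x24xf16>, %d1: index) -> vector<1x1x4xf16> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %pad = arith.constant 0.0 : f16
  %mask = vector.create_mask %c1, %d1, %c4 : vector<1x1x4xi1>
  %v = vector.transfer_read %src[%c0, %c0, %c0], %pad, %mask
         {in_bounds = [true, true, true]} : memref<1x?x24xf16>, vector<1x1x4xf16>
  return %v : vector<1x1x4xf16>
}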
@@ -551,17 +551,17 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
// CHECK-DAG: %[[LHS_GLOBAL:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<64x968x1281xf16, #hal.descriptor_type<storage_buffer>>
// CHECK-DAG: %[[RHS_GLOBAL:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : memref<64x1281x1281xf16, #hal.descriptor_type<storage_buffer>>
// CHECK-DAG: %[[OUT_GLOBAL:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : memref<64x968x1281xf16, #hal.descriptor_type<storage_buffer>>
-// CHECK-DAG: %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
-// CHECK-DAG: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
-// CHECK-DAG: %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
-// CHECK-DAG: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
+// CHECK-DAG: %[[MASK0:.+]] = vector.create_mask %c1, %[[M0UB0:.+]], %{{.+}} : vector<1x1x4xi1>
+// CHECK-DAG: %[[MASK1:.+]] = vector.create_mask %c1, %{{.+}}, %[[M1UB1:.+]] : vector<1x1x4xi1>
+// CHECK-DAG: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL]]{{.+}} %[[MASK0]] {in_bounds = [true, true, true]}
+// CHECK-DAG: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL]]{{.+}} %[[MASK1]] {in_bounds = [true, true, true]}
// CHECK: vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]]
// CHECK: vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]]
// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1280 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
-// CHECK-DAG: %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
-// CHECK-DAG: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]
-// CHECK-DAG: %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
-// CHECK-DAG: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
+// CHECK-DAG: %[[MASK0:.+]] = vector.create_mask %c1, %[[M0UB0]], %{{.+}} : vector<1x1x4xi1>
+// CHECK-DAG: %[[MASK1:.+]] = vector.create_mask %c1, %{{.+}}, %[[M1UB1]] : vector<1x1x4xi1>
+// CHECK-DAG: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL]]{{.+}} %[[MASK0]] {in_bounds = [true, true, true]}
+// CHECK-DAG: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL]]{{.+}} %[[MASK1]] {in_bounds = [true, true, true]}
// CHECK: gpu.barrier
// CHECK-DAG: %{{.+}} = vector.transfer_read %[[LHS_SHARED]]
// CHECK-DAG: %{{.+}} = vector.transfer_read %[[RHS_SHARED]]
@@ -629,9 +629,11 @@ hal.executable public @pad_batch_matmul {
// RHS
// The dynamic dimension should be removed after:
// https://github.com/llvm/llvm-project/pull/112236
+// CHECK: %[[MASK:.+]] = vector.create_mask %c1, %{{.+}}, %{{.+}} : vector<1x1x2xi1>
 // CHECK: vector.transfer_read
-// CHECK-SAME: in_bounds = [true, false, false]
-// CHECK-SAME: memref<1x?x24xf32
+// CHECK-SAME: %[[MASK]]
+// CHECK-SAME: in_bounds = [true, true, true]
+// CHECK-SAME: memref<196x24x24xf32
Review comment on lines +632 to +636 (Contributor):
Interesting. So instead of the in_bounds attribute, we are relying on masking. Does it produce the same code?

Reply (Author):
They both produce conditional code.
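To illustrate the equivalence being discussed, a hedged side-by-side sketch (shapes borrowed from the test above; function and variable names are illustrative). Both forms lower to conditional code, one via per-dimension bounds checks and one via an explicit mask:

func.func @two_forms(%sub: memref<1x?x24xf32>, %full: memref<196x24x24xf32>,
                     %ub1: index, %ub2: index) -> (vector<1x1x2xf32>, vector<1x1x2xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %pad = arith.constant 0.0 : f32
  // Form 1: out-of-bounds read from a dynamically sized subview; the dims
  // marked false get bounds checks at lowering time and pad when outside.
  %a = vector.transfer_read %sub[%c0, %c0, %c0], %pad
         {in_bounds = [true, false, false]} : memref<1x?x24xf32>, vector<1x1x2xf32>
  // Form 2: read the untruncated source through an explicit mask; the same
  // conditionality is carried by the mask instead of the in_bounds attr.
  %mask = vector.create_mask %c1, %ub1, %ub2 : vector<1x1x2xi1>
  %b = vector.transfer_read %full[%c0, %c0, %c0], %pad, %mask
         {in_bounds = [true, true, true]} : memref<196x24x24xf32>, vector<1x1x2xf32>
  return %a, %b : vector<1x1x2xf32>, vector<1x1x2xf32>
}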

// CHECK-SAME: vector<1x1x2xf32>
// CHECK: scf.yield
// OUTPUT
@@ -1312,3 +1314,127 @@ module {
// MEMORY-LABEL: func.func @attention_gather_k
// MEMORY-COUNT-3: memref.alloc
// MEMORY-NOT: memref.alloc

// -----

#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64>
#lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 0, 0, 0, 512], workgroup = [1, 1, 1, 32, 0, 0]}>
Review comment (Contributor):
Can you use partial_reduction instead of reduction here?

Review comment (Contributor):
Also, we should be tiling the outer K2 dimension to the number of warps?

Reply (Author, @manupak, Feb 26, 2025):
This is just an example. Would you be able to spell out the config that you'd like tested here?


// {indexing_maps = [
//    affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
//    affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>,
//    affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
//  ],
//  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]
// }
Review comment on lines +1324 to +1330 (Contributor):
Remove commented-out code.

Reply (Author, @manupak, Feb 26, 2025):
Hmmm, I thought it's useful for understanding the lowering config dimensionality of the QK and PV matmul generics. Otherwise it appears as a set of magic numbers that is not represented in the linalg_ext.attention op.

#qk_config = {
lowering_config = #iree_gpu.lowering_config<{
subgroup_basis = [[1, 1, 1, 1, 1], [0, 1, 2, 3, 4]],
thread_basis = [[1, 1, 1, 1, 64], [0, 1, 2, 3, 4]],
thread = [0, 0, 0, 8, 8]
}>
}

// {indexing_maps = [
// affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>,
// affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
// affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
// ],
// iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
// }
#pv_config = {
lowering_config = #iree_gpu.lowering_config<{
subgroup_basis = [[1, 1, 1, 4, 1], [0, 1, 2, 3, 4]],
thread_basis = [[1, 1, 1, 1, 64], [0, 1, 2, 3, 4]],
thread = [0, 0, 0, 8, 8]
}>
}
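To ground those numbers further, here is a hedged sketch (illustrative only, not part of the test) of the decomposed QK matmul generic that the commented maps and #qk_config describe, with operand shapes taken from the attention op below; d3 is the 128-wide head dimension being reduced, and d4 is the dynamic K2 sequence dimension:

// Hypothetical QK generic matching the commented 5-D maps above.
func.func @qk_generic(%q: tensor<4x32x1x128xf16>, %k: tensor<4x32x?x128xf16>,
                      %acc: tensor<4x32x1x?xf32>) -> tensor<4x32x1x?xf32> {
  %s = linalg.generic {
         indexing_maps = [
           affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
           affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>,
           affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>],
         iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]}
       ins(%q, %k : tensor<4x32x1x128xf16>, tensor<4x32x?x128xf16>)
       outs(%acc : tensor<4x32x1x?xf32>) {
  ^bb0(%a: f16, %b: f16, %c: f32):
    // Multiply-accumulate in f32: d3 is reduced, d4 stays parallel.
    %ae = arith.extf %a : f16 to f32
    %be = arith.extf %b : f16 to f32
    %p = arith.mulf %ae, %be : f32
    %n = arith.addf %c, %p : f32
    linalg.yield %n : f32
  } -> tensor<4x32x1x?xf32>
  return %s : tensor<4x32x1x?xf32>
}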

module {
hal.executable public @decode_attn_dispatch_0 {
Review comment (Contributor):
This test should be in pipeline_vector_distribute_gfx942_reduction.mlir (or wherever the reduction test file is called).

Reply (Author):
Why? All other attention tests are here...

hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export public @attention_dynamic_masked ordinal(0) layout(#hal.pipeline.layout<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index):
Review comment (Contributor):
Nit: we don't need pipeline binding flags like "ReadOnly|Indirect" for tests. Check how other tests do it.

Reply (Author):
Done.

%x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @attention_dynamic_masked() attributes {translation_info = #translation} {
%c0 = arith.constant 0 : index
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 8.837890e-02 : f16
%0 = hal.interface.constant.load layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
%1 = hal.interface.constant.load layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
%2 = hal.interface.constant.load layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
%3 = hal.interface.constant.load layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(3) : i32
%4 = hal.interface.constant.load layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(4) : i32
%5 = hal.interface.constant.load layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(5) : i32
%6 = arith.extui %0 : i32 to i64
%7 = arith.extui %1 : i32 to i64
%8 = arith.shli %7, %c32_i64 : i64
%9 = arith.ori %6, %8 : i64
%10 = arith.index_castui %9 : i64 to index
%11 = arith.extui %2 : i32 to i64
%12 = arith.extui %3 : i32 to i64
%13 = arith.shli %12, %c32_i64 : i64
%14 = arith.ori %11, %13 : i64
%15 = arith.index_castui %14 : i64 to index
%16 = arith.extui %4 : i32 to i64
%17 = arith.extui %5 : i32 to i64
%18 = arith.shli %17, %c32_i64 : i64
%19 = arith.ori %16, %18 : i64
%20 = arith.index_castui %19 : i64 to index
%21:3 = util.assume.int
%10<umin = 0, umax = 9007199254740991>,
%15<umin = 0, umax = 9007199254740991>,
%20<umin = 0, umax = 9007199254740991>
: index, index, index
%22 = hal.interface.binding.subspan layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x32x1x128xf16>>
%23 = hal.interface.binding.subspan layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(4) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<4x32x1x128xf16>>
%24 = flow.dispatch.workload.ordinal %21#0, 0 : index
%25 = flow.dispatch.workload.ordinal %21#1, 1 : index
%26 = flow.dispatch.workload.ordinal %21#2, 2 : index
%27 = hal.interface.binding.subspan layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x32x?x128xf16>>{%24}
%28 = hal.interface.binding.subspan layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x32x128x?xf16>>{%25}
%29 = hal.interface.binding.subspan layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(3) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x32x1x?xf16>>{%26}
%30 = flow.dispatch.tensor.load %22, offsets = [0, 0, 0, 0], sizes = [4, 32, 1, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x32x1x128xf16>> -> tensor<4x32x1x128xf16>
Review comment (Contributor):
Can we use a simpler test? I don't think we need all these pipeline.binding flags.

Reply (Author):
I can try; for some reason I thought it was needed for the test, as every other test in the file has it.

Reply (Author):
Oh, you mean remove the flags but keep the HAL?

Reply (Author):
OK, I've removed the flags and made it look more similar to other tests.

%31 = flow.dispatch.tensor.load %27, offsets = [0, 0, 0, 0], sizes = [4, 32, %24, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x32x?x128xf16>>{%24} -> tensor<4x32x?x128xf16>
%32 = flow.dispatch.tensor.load %28, offsets = [0, 0, 0, 0], sizes = [4, 32, 128, %25], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x32x128x?xf16>>{%25} -> tensor<4x32x128x?xf16>
%33 = flow.dispatch.tensor.load %29, offsets = [0, 0, 0, 0], sizes = [4, 32, 1, %26], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x32x1x?xf16>>{%26} -> tensor<4x32x1x?xf16>
%34 = tensor.empty() : tensor<4x32x1x128xf16>
%35 = iree_linalg_ext.attention {
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
]
,lowering_config = #lowering_config
,decomposition_config = { qk_attrs = #qk_config, pv_attrs = #pv_config }
}
ins(%30, %31, %32, %cst, %33 : tensor<4x32x1x128xf16>, tensor<4x32x?x128xf16>, tensor<4x32x128x?xf16>, f16, tensor<4x32x1x?xf16>)
outs(%34 : tensor<4x32x1x128xf16>) {
^bb0(%arg0: f32):
iree_linalg_ext.yield %arg0 : f32
} -> tensor<4x32x1x128xf16>
flow.dispatch.tensor.store %35, %23, offsets = [0, 0, 0, 0], sizes = [4, 32, 1, 128], strides = [1, 1, 1, 1] : tensor<4x32x1x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<4x32x1x128xf16>>
return
}
}
}
}
}

// CHECK-LABEL: func.func @attention_dynamic_masked
// CHECK: scf.for %[[INDVAR:.+]] = %c0 to %24 step %c512
// CHECK: affine.min
// CHECK: %[[MASK0:.+]] = vector.create_mask {{.*}} : vector<8xi1>
// CHECK: %[[MASK1:.+]] = vector.create_mask {{.*}} : vector<8x128xi1>
// CHECK: %[[MASK1_SLICE:.+]] = vector.extract_strided_slice %[[MASK1]] {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x128xi1> to vector<8x8xi1>
// CHECK: vector.transfer_read {{.*}} %[[MASK1_SLICE]]
// CHECK: vector.transfer_read {{.*}} %[[MASK0]]

// MEMORY-LABEL: func.func @attention_dynamic_masked
// MEMORY-NOT: memref.alloc
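For readers unfamiliar with the pattern in the CHECK lines above, a hedged standalone sketch (hypothetical shapes and names, not from the test) of how a single create_mask is built once and then sliced per distributed read:

func.func @sliced_mask(%src: memref<?x?xf16>, %ub0: index, %ub1: index,
                       %pad: f16) -> vector<8x8xf16> {
  %c0 = arith.constant 0 : index
  // One mask covering the whole distributed tile...
  %mask = vector.create_mask %ub0, %ub1 : vector<8x128xi1>
  // ...sliced down to the piece guarding this thread's read.
  %slice = vector.extract_strided_slice %mask
             {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]}
           : vector<8x128xi1> to vector<8x8xi1>
  %v = vector.transfer_read %src[%c0, %c0], %pad, %slice
         {in_bounds = [true, true]} : memref<?x?xf16>, vector<8x8xf16>
  return %v : vector<8x8xf16>
}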