[Codegen] Add vector transfer + slice foldings in GenericVectorization #17613

Merged

Conversation

Max191
Contributor

@Max191 Max191 commented Jun 7, 2024

Vectorizing a linalg.copy op can result in a sequence of

%extract = tensor.extract_slice %source
%read = vector.transfer_read %extract
%write = vector.transfer_write %read, %dest
%insert = tensor.insert_slice %write into %dest

This sequence is folded by the transfer_write folder into

%extract = tensor.extract_slice %source
%insert = tensor.insert_slice %extract into %dest

In order to conserve the vector transfers, this PR adds folding patterns for vector transfer ops acting on insert/extract slice ops. These patterns fold the insert_slice into the transfer_write and the extract_slice into the transfer_read, so the vector transfers themselves are not folded away.
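
As a rough sketch of the intended outcome (hypothetical static shapes; %c4/%c0 and %cst stand in for the folded slice offsets and the padding value, not values from this PR), the slices fold into the transfer indices instead of the transfers folding away:

%read = vector.transfer_read %source[%c4, %c0], %cst {in_bounds = [true, true]} : tensor<128x128xf32>, vector<8x16xf32>
%write = vector.transfer_write %read, %dest[%c4, %c0] {in_bounds = [true, true]} : vector<8x16xf32>, tensor<128x128xf32>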

This is turned off for the vector distribution pipeline because it causes distribution to fail in some cases.

Also removes Codegen/LLVMGPU/test/conv_pipeline_test_rocm.mlir; this completes a TODO to remove the test once some undesired extra buffers were eliminated.

@Max191 Max191 requested a review from hanhanW as a code owner June 7, 2024 20:11
@Max191 Max191 added benchmarks:cuda Run default CUDA benchmarks benchmarks:x86_64 Run default x86_64 benchmarks benchmarks:comp-stats Run default compilation statistics benchmarks benchmarks:android-cpu Run default Android CPU benchmarks benchmarks:android-gpu Run default Android GPU benchmarks benchmarks:vulkan-nvidia Run default Vulkan benchmarks on NVIDIA GPU labels Jun 7, 2024
@Max191
Contributor Author

Max191 commented Jun 7, 2024

needs tests, but I wanted to run benchmarks

EDIT: I see this is breaking several tests. I'll need to look into this more

@Max191 Max191 marked this pull request as draft June 7, 2024 20:31
@Max191 Max191 force-pushed the generic-vectorization-transfer-slice-folding branch from 4179cc3 to 810abaa Compare June 7, 2024 20:36

github-actions bot commented Jun 7, 2024

Abbreviated Benchmark Summary

@ commit 780cfc42654569975d94f85536a0f58c0de7742a (no previous benchmark results to compare)

Data-Tiling Comparison Table

Name No-DT (baseline) DT-Only DT-UK
BertLargeTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 750.134 (1.0X) N/A 226.458 (3.3X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 6.950 (1.0X) N/A 8.513 (0.8X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 35.611 (1.0X) N/A 34.462 (1.0X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.781 (1.0X) N/A 4.996 (1.2X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 9.196 (1.0X) N/A 8.529 (1.1X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.118 (1.0X) N/A 8.969 (1.2X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.971 (1.0X) N/A 13.694 (0.9X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 33.412 (1.0X) N/A 61.637 (0.5X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.315 (1.0X) N/A 61.791 (0.6X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 68.462 (1.0X) N/A 64.878 (1.1X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.442 (1.0X) N/A 4.616 (1.0X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 3.719 (1.0X) N/A 4.910 (0.8X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.827 (1.0X) N/A 5.407 (1.1X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 2.840 (1.0X) N/A 2.857 (1.0X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 8.374 (1.0X) N/A 9.909 (0.8X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 0.774 (1.0X) N/A 0.613 (1.3X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.060 (1.0X) N/A 5.234 (0.8X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 7.599 (1.0X) N/A 7.580 (1.0X)
matmul_256x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 6.726 (1.0X) N/A 1.804 (3.7X)
BertForMaskedLMTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 222.711 (1.0X) N/A 108.573 (2.1X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 32.269 (1.0X) N/A 30.053 (1.1X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 263.358 (1.0X) N/A 231.222 (1.1X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 26.704 (1.0X) N/A 13.185 (2.0X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 70.590 (1.0X) N/A 40.528 (1.7X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 89.278 (1.0X) N/A 42.134 (2.1X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 71.440 (1.0X) N/A 57.371 (1.2X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 173.751 (1.0X) N/A 186.396 (0.9X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 182.898 (1.0X) N/A 190.931 (1.0X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 516.111 (1.0X) N/A 241.190 (2.1X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 23.917 (1.0X) N/A 18.286 (1.3X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.779 (1.0X) N/A 11.647 (1.0X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 21.543 (1.0X) N/A 11.914 (1.8X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 2.763 (1.0X) N/A 2.724 (1.0X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.375 (1.0X) N/A 31.989 (1.1X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.707 (1.0X) N/A 0.548 (1.3X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 17.128 (1.0X) N/A 19.107 (0.9X)
matmul_1x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.054 (1.0X) N/A 0.054 (1.0X)
matmul_1x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.043 (1.0X) N/A 0.021 (2.0X)

Raw Latencies

Benchmark Name Average Latency (ms) Median Latency (ms) Latency Standard Deviation (ms)
BertLargeTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu][default-flags,dt-uk] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 226.458 226.129 5.275
BertLargeTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu][experimental-flags,no-dt] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 750.134 740.843 40.480
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu][default-flags,dt-uk] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 8.513 8.523 0.042

[Top 3 out of 92 results shown]

No improved or regressed compilation metrics 🏖️

For more information:

Source Workflow Run

@Max191
Contributor Author

Max191 commented Jun 10, 2024

This needs to wait on a bug fix in llvm/llvm-project#95020.

@Max191 Max191 force-pushed the generic-vectorization-transfer-slice-folding branch from 2ad77bf to 03ab8a2 Compare July 8, 2024 20:37
@Max191 Max191 force-pushed the generic-vectorization-transfer-slice-folding branch from df56e9e to 5681df6 Compare July 17, 2024 18:56
@Max191 Max191 marked this pull request as ready for review July 17, 2024 18:59
Contributor

@hanhanW hanhanW left a comment

LGTM. We have to be careful about folding tensor.extract_slice and tensor.insert_slice ops away because some transformations heavily rely on these ops. It looks good on the CPU side because there is no further tiling after vectorization; the rest is vector lowering.

I'm not sure how it works on the GPU side w.r.t. scf.forall lowering. I think it is okay because you provide an option to turn it off.

(cc @banach-space because of the changes of vectorize_with_masking_and_hoist.mlir.)

@Max191 Max191 removed benchmarks:android-cpu Run default Android CPU benchmarks benchmarks:android-gpu Run default Android GPU benchmarks labels Jul 19, 2024
@Max191 Max191 force-pushed the generic-vectorization-transfer-slice-folding branch from 93f2925 to 12e5c75 Compare July 19, 2024 13:48
@Max191
Contributor Author

Max191 commented Jul 19, 2024

cc @benmxwl-arm since I just had to update a generic_vectorization.mlir test that you recently added. It seems like a harmless change, but pinging just in case.

@Max191
Contributor Author

Max191 commented Jul 19, 2024

Sorry, wrong account. @MacDue

@MacDue
Member

MacDue commented Jul 19, 2024

Looks harmless (it still infers the right vector size :))

@Max191 Max191 merged commit 8b83425 into iree-org:main Jul 19, 2024
60 checks passed
@ScottTodd
Member

Heads up - I see test failures in these modified tests after merge:

The following tests FAILED:
        141 - iree/compiler/Codegen/Common/test/generic_vectorization.mlir.test (Failed)
        217 - iree/compiler/Codegen/LLVMCPU/test/pipeline_pad_tests.mlir.test (Failed)
        242 - iree/compiler/Codegen/LLVMCPU/test/vectorize_with_masking_and_hoist.mlir.test (Failed)

Postsubmit CI jobs are running behind a 12h queue right now, so I kicked off a presubmit run at https://github.com/iree-org/iree/actions/runs/10014109365/job/27683126524?pr=17971 on #17971 to see if the CI machines can repro.

Local logs: https://gist.github.com/ScottTodd/cfceb22a41ca80257918d1d468b05ddb

@hanhanW
Contributor

hanhanW commented Jul 19, 2024

Thanks for the heads-up! I'm taking a look and will provide a fix.

@ScottTodd
Member

Linux CI seemed to pass... I can try my local Windows at an earlier commit 🤔

Do the logs give any clues?

@hanhanW
Contributor

hanhanW commented Jul 19, 2024

Linux CI seemed to pass... I can try my local Windows at an earlier commit 🤔

That's weird. I think it should generate deterministic IRs... compiling IREE on Linux now.

Do the logs give any clues?

I don't see issues, I need to repro it on my local machine.

@ScottTodd
Member

On Windows/MSVC, those tests passed at 4a13331

so something in this range caused them to fail: 4a13331...5b112cb

Ben suggests:

Usually is an issue with relying on undefined collection behavior

@ScottTodd
Member

Weeeeeird. I can't repro now, after trying to bisect through that history. Friday afternoon doing Friday afternoon things...

Sorry for the noise :P

@MacDue
Member

MacDue commented Jul 22, 2024

I've yet to track down exactly what's wrong, but this change breaks one of our internal tests, resulting in an output that is all NaNs.
It also seems to prevent the hoisting of reads/writes (which for SME tiles is very costly). I'm trying to find a small reproducer now...

@MacDue
Member

MacDue commented Jul 22, 2024

Update:

  1. The bad hoisting of reads/writes is real, so we may want to consider disabling this (for the CPU backend?) or at least having it off by default. However, I'm not quite sure I understand the motivation for this change (and where it needs to be enabled). Why do we want to "conserve the vector transfers"?

  2. The incorrect results don't seem to come directly from this change, but instead come from some stack alignment addressing mode/calculation issues exposed by it (which was quite the pain to track down 😅). See [LSR] Fix matching vscale immediates llvm/llvm-project#100080.

@hanhanW
Contributor

hanhanW commented Jul 22, 2024

Why do we want to "conserve the vector transfers"?

Say that the original input is:

%extract = tensor.extract_slice %source
%read = vector.transfer_read %extract
%write = vector.transfer_write %read, %dest
%insert = tensor.insert_slice %write into %dest

Without the change, it becomes a (tensor.extract_slice, tensor.insert_slice) pair. There are no compute (or load/store) ops at the tensor level. Then we bufferize the dispatch, which yields:

%src = memref.subview (...)
%dest = memref.subview (...)
memref.copy (...)

There is no vector code, which leads to scalar loads/stores. Since we don't rely on the LLVM vectorizer, these remain scalar unless we add another round of vectorization at the buffer level.

The trend is to move vectorization into the tensor world, so we want to get:

%read = vector.transfer_read %src
vector.transfer_write %read, %dest

This preserves the vector code after bufferization:

%src = memref.subview (...)
%read = vector.transfer_read %src ...
%dest = memref.subview (...)
vector.transfer_write %read, %dest ...

Does it make sense?

I don't remember the hoisting issues that we hit on the ARM path. Could you share the expected IR before and after the hoisting?

@MacDue
Member

MacDue commented Jul 23, 2024

So before this change we'd get:

%15 = scf.for %arg4 = %c0 to %c352 step %c1 iter_args(%arg5 = %14) -> (tensor<?x?xf32>) {
  %extracted_slice_7 = tensor.extract_slice %extracted_slice[%arg4, 0] [1, %10] [1, 1] : tensor<352x?xf32> to tensor<1x?xf32>
  %extracted_slice_8 = tensor.extract_slice %extracted_slice_2[%arg4, 0] [1, %11] [1, 1] : tensor<352x?xf32> to tensor<1x?xf32>
  %extracted_slice_9 = tensor.extract_slice %arg5[0, 0] [%10, %11] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %17 = vector.create_mask %c1, %10 : vector<1x[8]xi1>
  %18 = vector.transfer_read %extracted_slice_7[%c0, %c0], %cst_1, %17 {in_bounds = [true, true]} : tensor<1x?xf32>, vector<1x[8]xf32>
  %19 = vector.create_mask %c1, %11 : vector<1x[8]xi1>
  %20 = vector.transfer_read %extracted_slice_8[%c0, %c0], %cst_1, %19 {in_bounds = [true, true]} : tensor<1x?xf32>, vector<1x[8]xf32>
  %21 = vector.create_mask %10, %11 : vector<[8]x[8]xi1>
  %22 = vector.transfer_read %extracted_slice_9[%c0, %c0], %cst_1, %21 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[8]x[8]xf32>
  %23 = vector.create_mask %10, %11, %c1 : vector<[8]x[8]x1xi1>
  %24 = vector.mask %23 { vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %18, %20, %22 : vector<1x[8]xf32>, vector<1x[8]xf32> into vector<[8]x[8]xf32> } : vector<[8]x[8]x1xi1> -> vector<[8]x[8]xf32>
  %25 = vector.transfer_write %24, %extracted_slice_9[%c0, %c0], %21 {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
  %inserted_slice_10 = tensor.insert_slice %25 into %arg5[0, 0] [%10, %11] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
  scf.yield %inserted_slice_10 : tensor<?x?xf32>
}

And iree-codegen-optimize-tensor-insert-extract-slices would transform that to:

%21 = vector.transfer_read %extracted_slice_5[%c0, %c0], %cst_1, %19 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[8]x[8]xf32>
%22 = scf.for %arg4 = %c0 to %c352 step %c1 iter_args(%arg5 = %21) -> (vector<[8]x[8]xf32>) {
  %extracted_slice_10 = tensor.extract_slice %extracted_slice[%arg4, 0] [1, %14] [1, 1] : tensor<352x?xf32> to tensor<1x?xf32>
  %25 = vector.transfer_read %4[%arg4, %arg0], %cst_1 {in_bounds = [true, true]} : tensor<352x128xf32>, vector<1x[8]xf32>
  %26 = vector.transfer_read %extracted_slice_10[%c0, %c0], %cst_1, %18 {in_bounds = [true, true]} : tensor<1x?xf32>, vector<1x[8]xf32>
  %27 = vector.mask %20 { vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %25, %26, %arg5 : vector<1x[8]xf32>, vector<1x[8]xf32> into vector<[8]x[8]xf32> } : vector<[8]x[8]x1xi1> -> vector<[8]x[8]xf32>
  scf.yield %27 : vector<[8]x[8]xf32>
}
%23 = vector.transfer_write %22, %extracted_slice_5[%c0, %c0], %19 {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>

Which is super nice, and eventually means the initial transfer_read will become just a zero constant.
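
Roughly, and assuming the accumulator init comes from a zero fill so the later folds can kick in (a sketch, not the exact IR):

%acc_init = arith.constant dense<0.000000e+00> : vector<[8]x[8]xf32>
%22 = scf.for %arg4 = %c0 to %c352 step %c1 iter_args(%arg5 = %acc_init) -> (vector<[8]x[8]xf32>) { ... }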

It would hoist the pair:

%extracted_slice_9 = tensor.extract_slice %arg5[0, 0] [%10, %11] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%inserted_slice_10 = tensor.insert_slice %25 into %arg5[0, 0] [%10, %11] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>

Then the pair:

%22 = vector.transfer_read %extracted_slice_9[%c0, %c0], %cst_1, %21 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[8]x[8]xf32>
%25 = vector.transfer_write %24, %extracted_slice_9[%c0, %c0], %21 {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>

Note: These are both easy to hoist as they're matching pairs to the same value (%arg5 and %extracted_slice_9 respectively).


Now we get:

%16 = scf.for %arg4 = %c0 to %c352 step %c1 iter_args(%arg5 = %15) -> (tensor<?x?xf32>) {
  %extracted_slice_5 = tensor.extract_slice %arg5[0, 0] [%c8_vscale, %12] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %18 = vector.transfer_read %4[%arg4, %arg0], %cst_1 {in_bounds = [true, true]} : tensor<352x128xf32>, vector<1x[8]xf32>
  %19 = vector.create_mask %c1, %12 : vector<1x[8]xi1>
  %20 = vector.transfer_read %5[%arg4, %arg2], %cst_1, %19 {in_bounds = [true, true]} : tensor<352x1xf32>, vector<1x[8]xf32>
  %21 = vector.create_mask %c8_vscale, %12 : vector<[8]x[8]xi1>
  %22 = vector.transfer_read %arg5[%c0, %c0], %cst_1, %21 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[8]x[8]xf32>
  %23 = vector.create_mask %c8_vscale, %12, %c1 : vector<[8]x[8]x1xi1>
  %24 = vector.mask %23 { vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %18, %20, %22 : vector<1x[8]xf32>, vector<1x[8]xf32> into vector<[8]x[8]xf32> } : vector<[8]x[8]x1xi1> -> vector<[8]x[8]xf32>
  %25 = vector.transfer_write %24, %extracted_slice_5[%c0, %c0], %21 {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
  %inserted_slice_6 = tensor.insert_slice %25 into %arg5[0, 0] [%c8_vscale, %12] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
  scf.yield %inserted_slice_6 : tensor<?x?xf32>
}

There are no pairs hoistSubsetAtIterArg() knows how to hoist here. The sources no longer match up, and there are uses in the loop blocking even the hoisting of the tensor.extract_slice/insert_slice pair.

So iree-codegen-optimize-tensor-insert-extract-slices fails to hoist anything (that matters at runtime):

%21 = scf.for %arg4 = %c0 to %c352 step %c1 iter_args(%arg5 = %17) -> (tensor<?x?xf32>) {
  %extracted_slice_7 = tensor.extract_slice %arg5[0, 0] [%c8_vscale, %14] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %23 = vector.transfer_read %4[%arg4, %arg0], %cst_1 {in_bounds = [true, true]} : tensor<352x128xf32>, vector<1x[8]xf32>
  %24 = vector.transfer_read %5[%arg4, %arg2], %cst_1, %18 {in_bounds = [true, true]} : tensor<352x1xf32>, vector<1x[8]xf32>
  %25 = vector.transfer_read %arg5[%c0, %c0], %cst_1, %19 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[8]x[8]xf32>
  %26 = vector.mask %20 { vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %23, %24, %25 : vector<1x[8]xf32>, vector<1x[8]xf32> into vector<[8]x[8]xf32> } : vector<[8]x[8]x1xi1> -> vector<[8]x[8]xf32>
  %27 = vector.transfer_write %26, %extracted_slice_7[%c0, %c0], %19 {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
  %inserted_slice_8 = tensor.insert_slice %27 into %arg5[0, 0] [%c8_vscale, %14] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
  scf.yield %inserted_slice_8 : tensor<?x?xf32>
}

This leaves very costly full-ZA tile reads/writes within the innermost loop of a matmul, which is really bad.

@banach-space
Collaborator

Thanks for all the impressive digging, @MacDue! Sadly this is hurting SME pretty badly :( (Sorry I didn't check earlier when @hanhanW pinged me.)

@Max191, is @tiled_linalg_copy a good representation of what you wanted to achieve? I see 3 options here:

  1. Find a way to preserve the behaviour of @tiled_linalg_copy while keeping nice hoisting that we used to have for SME by moving populateFoldTensorSubsetIntoVectorTransferPatterns further down the compilation path,
  2. Expose earlySubsetTransferFolding so that there's a user-facing flag that we could use in our compilation flow (I know that these are not popular),
  3. Find a way to preserve the behaviour of @tiled_linalg_copy while keeping nice hoisting that we used to have for SME by means other than populateFoldTensorSubsetIntoVectorTransferPatterns. It's a bit counter-intuitive to me that these patterns are needed here TBH - what patterns eliminate xfer_read/xfer_write Ops that you want to preserve? (i.e., what's the root cause of the issue being addressed here?)

Thanks!

@qedawkins
Contributor

The root of the issue is that linalg.copy essentially vectorizes to a no-op: a transfer_write(transfer_read) of the same indices. My opinion is that preventing such foldings shouldn't be a requirement for any lowering flow, i.e. if such a pair of transfers really does cancel, we should try to cancel them. Bufferization will produce a memref.copy later if the copy is material (i.e. between different memory spaces), and we can vectorize such copies post bufferization.
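
For a concrete (if simplified) illustration of the pair that cancels, with hypothetical shapes and values, not IR from this PR:

%v = vector.transfer_read %t[%c0, %c0], %pad {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
%w = vector.transfer_write %v, %t[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
// the transfer_write folder sees a full-size copy of %t onto itself and replaces %w with %t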

With that said, it looks like there is just a missing pattern that is blocking hoisting in the above example. I think we should do both of the following:

  1. Revert this PR, it ideally shouldn't be needed (if someone gives a compelling reason why such a folding is a problem, we can think about relanding).
  2. Fix or add the insert_slice(transfer_write) composition pattern, because that's all that looks to be missing in the above example:
%16 = scf.for %arg4 = %c0 to %c352 step %c1 iter_args(%arg5 = %15) -> (tensor<?x?xf32>) {
  %extracted_slice_5 = tensor.extract_slice %arg5[0, 0] [%c8_vscale, %12] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  ...
  %22 = vector.transfer_read %arg5[%c0, %c0], %cst_1, %21 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[8]x[8]xf32>
  ...
  %25 = vector.transfer_write %24, %extracted_slice_5[%c0, %c0], %21 {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
  %inserted_slice_6 = tensor.insert_slice %25 into %arg5[0, 0] [%c8_vscale, %12] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
  scf.yield %inserted_slice_6 : tensor<?x?xf32>
}

becomes the easily hoistable (and potentially cleaner) IR:

%16 = scf.for %arg4 = %c0 to %c352 step %c1 iter_args(%arg5 = %15) -> (tensor<?x?xf32>) {
  ...
  %22 = vector.transfer_read %arg5[%c0, %c0], %cst_1, %21 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[8]x[8]xf32>
  ...
  %25 = vector.transfer_write %24, %arg5[%c0, %c0], %21 {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
  scf.yield %25 : tensor<?x?xf32>
}

@MacDue
Member

MacDue commented Jul 24, 2024

Had a quick look at 2.:

The insert_slice(transfer_write) fold does not apply because the transfer_write is masked, so just looking at those two ops the replacement may not be legal. I think you'd need to match insert_slice(transfer_write(extract_slice)) and check that both the insert and extract are from the same tensor.

@qedawkins
Contributor

I see, that makes sense. In that case I think the insert_slice(transfer_write(extract_slice)) pattern would be worth writing, because the IR examples you showed above look reasonable to me. AFAICT this was a situation where we were just changing the order in which two different sets of folders were applied (transfer_write(transfer_read) vs transfer(slice)), and it would be good for all backends to handle both orders gracefully. That said, if you want to send a revert, or disable it for CPU backends first, I would stamp.

However, I'm not quite sure I understand the motivation for this change (and where it needs to be enabled). Why do we want to "conserve the vector transfers"?

To give a little more context, for GPU at some point (post bufferization) we really want to end up with these transfers, in particular because any time there is a linalg.copy on GPU it is being used to approximate data movement between memory spaces and/or layouts. Thus the transfers are there to vectorize the copy. With that said, linalg.copy is itself a no-op in the absence of the bufferization dialect, so this PR might have been a case of trying to preserve linalg.copy "nofold" semantics past the lifetime of the op, hence why I think reverting or toggling per backend makes sense.
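
As a minimal sketch of the kind of vectorized copy we want post-bufferization on GPU (hypothetical shapes, values, and address space, just to illustrate data movement between memory spaces):

%cst = arith.constant 0.000000e+00 : f32
%shared = memref.alloc() : memref<32x32xf32, #gpu.address_space<workgroup>>
%v = vector.transfer_read %global[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xf32>, vector<32x32xf32>
vector.transfer_write %v, %shared[%c0, %c0] {in_bounds = [true, true]} : vector<32x32xf32>, memref<32x32xf32, #gpu.address_space<workgroup>>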

banach-space added a commit to banach-space/iree that referenced this pull request Jul 24, 2024
…orization (iree-org#17613)"

This reverts commit 8b83425.

This change is hurting SVE+SME performance pretty badly. See
iree-org#17613 for context.
@banach-space
Collaborator

That said if you want to send a revert, or disable it for CPU backends first, I would stamp.

#17997 :) I know that Ben is having a quick look at the pattern. I'm mostly keen to get our CI to a happier state.

From what you are saying, this PR was mostly to keep GPU and CPU paths consistent rather than fixing any specific CPU issue?

@qedawkins
Contributor

IIUC this PR was meant to change the behavior on GPU and opted to change the behavior on CPU as well. The red flag for me is that this is disabled for LLVMGPUVectorDistribute, meaning we aren't even being consistent about it on the GPU side.

banach-space added a commit to banach-space/iree that referenced this pull request Jul 24, 2024
…orization (iree-org#17613)"

This reverts commit 8b83425.

This change is hurting SVE+SME performance pretty badly. See
iree-org#17613 for context.

Signed-off-by: Andrzej Warzynski <[email protected]>
banach-space added a commit that referenced this pull request Jul 24, 2024
…orization (#17613)" (#17997)

This reverts commit 8b83425.

This change is hurting SVE+SME performance pretty badly. See
#17613 for context.

Signed-off-by: Andrzej Warzynski <[email protected]>
LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
iree-org#17613)

Vectorizing a `linalg.copy` op can result in a sequence of
```
%extract = tensor.extract_slice %source
%read = vector.transfer_read %extract
%write = vector.transfer_write %read, %dest
%insert = tensor.insert_slice %write into %dest
```
This sequence is folded by the transfer_write folder into
```
%extract = tensor.extract_slice %source
%insert = tensor.insert_slice %extract into %dest
```
In order to conserve the vector transfers, this PR adds folding patterns
for vector transfer ops acting on insert/extract slice ops. This will
fold the insert_slice into the transfer_write and the extract_slice into
the transfer_read, and the vector transfers will not be folded.

This is turned off for the vector distribution pipeline because it
causes distribution to fail in some cases.

Also removes `Codegen/LLVMGPU/test/conv_pipeline_test_rocm.mlir`, since
it completes a TODO to remove the test after eliminating some undesired
extra buffers.

---------

Signed-off-by: Max Dawkins <[email protected]>
Signed-off-by: Lubo Litchev <[email protected]>
LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
…orization (iree-org#17613)" (iree-org#17997)

This reverts commit 8b83425.

This change is hurting SVE+SME performance pretty badly. See
iree-org#17613 for context.

Signed-off-by: Andrzej Warzynski <[email protected]>
Signed-off-by: Lubo Litchev <[email protected]>
@MacDue
Member

MacDue commented Aug 6, 2024

Btw, I forgot to mention that when I took a look at the folds I spotted at least one upstream bug, which I reported here: llvm/llvm-project#101708

Max191 added a commit to Max191/iree that referenced this pull request Aug 13, 2024
Max191 added a commit to Max191/iree that referenced this pull request Aug 23, 2024
Max191 added a commit to Max191/iree that referenced this pull request Aug 26, 2024
Max191 added a commit to Max191/iree that referenced this pull request Aug 27, 2024
@Max191 Max191 deleted the generic-vectorization-transfer-slice-folding branch October 25, 2024 14:14