Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flow] Enable softmax-like fusion under aggressive fusion. #17747

Merged
merged 1 commit into from
Jun 28, 2024

Conversation

MaheshRavishankar
Copy link
Contributor

Under aggressive fusion, drop the restriction of consumer iteration space being same dimensionality as the producer iteration space. Typically this can lead to large vectors if not handled properly. So this is guarded under
--iree-flow-enable-aggressive-fusion flag.

Fixes nod-ai/SHARK-ModelDev#749

@MaheshRavishankar MaheshRavishankar added benchmarks:cuda Run default CUDA benchmarks benchmarks:x86_64 Run default x86_64 benchmarks benchmarks:comp-stats Run default compilation statistics benchmarks benchmarks:android-cpu Run default Android CPU benchmarks benchmarks:android-gpu Run default Android GPU benchmarks benchmarks:vulkan-nvidia Run default Vulkan benchmarks on NVIDIA GPU labels Jun 27, 2024
@MaheshRavishankar
Copy link
Contributor Author

Oops. Sorry push the branch to upstream repo instead of my fork and broke the naming convention. Will delete branch after it lands.

Under aggressive fusion, drop the restriction of consumer iteration
space being same dimensionality as the producer iteration
space. Typically this can lead to large vectors if not handled
properly. So this is guarded under
`--iree-flow-enable-aggressive-fusion` flag.

Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar force-pushed the sdxl_quantized_enable_softmax_fusion branch from 10c3216 to a62dd5a Compare June 28, 2024 05:02
Copy link

Abbreviated Benchmark Summary

@ commit b8a2701f5d91366e7318dcfdb76cdba464bab8d3 (vs. base 4294a5b0ebaec6dcca483bf16f5918108b09ea0a)

Data-Tiling Comparison Table

Click to show
Name No-DT (baseline) DT-Only DT-UK
BertLargeTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 786.945 (1.0X) N/A 222.780 (3.5X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 6.951 (1.0X) N/A 8.560 (0.8X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.963 (1.0X) N/A 34.866 (1.0X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.792 (1.0X) N/A 5.021 (1.2X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 9.214 (1.0X) N/A 8.432 (1.1X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.011 (1.0X) N/A 8.844 (1.2X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.783 (1.0X) N/A 13.734 (0.9X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 33.322 (1.0X) N/A 61.578 (0.5X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 33.393 (1.0X) N/A 62.150 (0.5X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 68.291 (1.0X) N/A 64.402 (1.1X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.460 (1.0X) N/A 4.534 (1.0X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 3.674 (1.0X) N/A 4.867 (0.8X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.882 (1.0X) N/A 5.430 (1.1X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 2.868 (1.0X) N/A 2.781 (1.0X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 8.470 (1.0X) N/A 9.904 (0.9X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 0.770 (1.0X) N/A 0.586 (1.3X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.137 (1.0X) N/A 5.252 (0.8X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 7.626 (1.0X) N/A 7.587 (1.0X)
matmul_256x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 6.656 (1.0X) N/A 1.799 (3.7X)
BertForMaskedLMTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 214.525 (1.0X) N/A 108.507 (2.0X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 32.251 (1.0X) N/A 30.344 (1.1X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 275.118 (1.0X) N/A 230.055 (1.2X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 27.059 (1.0X) N/A 13.342 (2.0X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 70.599 (1.0X) N/A 39.758 (1.8X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 88.059 (1.0X) N/A 41.453 (2.1X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 79.640 (1.0X) N/A 56.497 (1.4X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 181.606 (1.0X) N/A 186.313 (1.0X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 181.087 (1.0X) N/A 190.674 (0.9X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 516.559 (1.0X) N/A 240.696 (2.1X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 26.129 (1.0X) N/A 17.898 (1.5X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 12.118 (1.0X) N/A 12.177 (1.0X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 21.686 (1.0X) N/A 12.605 (1.7X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 2.782 (1.0X) N/A 2.792 (1.0X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.862 (1.0X) N/A 31.684 (1.1X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.698 (1.0X) N/A 0.521 (1.3X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 17.596 (1.0X) N/A 19.587 (0.9X)
matmul_1x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.054 (1.0X) N/A 0.054 (1.0X)
matmul_1x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.042 (1.0X) N/A 0.021 (2.0X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 47.961 (1.0X) N/A 42.587 (1.1X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 50.179 (1.0X) N/A 43.990 (1.1X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 29.953 (1.0X) N/A 27.232 (1.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 92.595 (1.0X) N/A 20.784 (4.5X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 93.644 (1.0X) N/A 21.661 (4.3X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 52.749 (1.0X) N/A 21.878 (2.4X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 139.061 (1.0X) N/A 27.060 (5.1X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 129.569 (1.0X) N/A 28.537 (4.5X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 72.141 (1.0X) N/A 26.653 (2.7X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 699.527 (1.0X) N/A 347.031 (2.0X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 695.695 (1.0X) N/A 354.636 (2.0X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 390.395 (1.0X) N/A 218.419 (1.8X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 1060.607 (1.0X) N/A 278.886 (3.8X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 1061.074 (1.0X) N/A 273.321 (3.9X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 551.029 (1.0X) N/A 161.357 (3.4X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 2062.117 (1.0X) N/A 299.939 (6.9X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 2061.354 (1.0X) N/A 299.823 (6.9X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 1094.145 (1.0X) N/A 178.230 (6.1X)
matmul_1x256x2048_i8_i4_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 0.081 (1.0X) N/A 0.016 (5.0X)
matmul_1x256x2048_i8_i8_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 0.074 (1.0X) N/A 0.017 (4.4X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 12.095 (1.0X) N/A 1.314 (9.2X)
matmul_256x256x2048_i8_i8_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 16.530 (1.0X) N/A 1.077 (15.3X)

Regressed Latencies 🚩

Benchmark Name Average Latency (ms) Median Latency (ms) Latency Standard Deviation (ms)
MobileBertSquad\_fp16(tflite) [arm-valhall-vulkan\_android31-vulkan\_spirv][experimental-flags,fuse-padding,max-concurrency,demote-f32-to-f16] vulkan(none)[full-inference,default-flags] with default @ pixel-6-pro[gpu] 108.340 (vs. 96.330, 12.47%↑) 108.971 1.817

Improved Latencies 🎉

Benchmark Name Average Latency (ms) Median Latency (ms) Latency Standard Deviation (ms)
MobileBertSquad\_int8(tflite) [arm-valhall-vulkan\_android31-vulkan\_spirv][default-flags] vulkan(none)[full-inference,default-flags] with default @ pixel-6-pro[gpu] 104.999 (vs. 112.742, 6.87%↓) 106.298 5.082
GPT2\_117M\_TF\_1X4XI32(stablehlo) [armv8.2-a-generic-linux\_android29-llvm\_cpu][default-flags,dt-uk] local\_task(embedded\_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 28.537 (vs. 30.566, 6.64%↓) 28.628 0.687
GPT2\_117M\_TF\_1X4XI32(stablehlo) [armv8.2-a-generic-linux\_android29-llvm\_cpu][experimental-flags,no-dt] local\_task(embedded\_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 129.569 (vs. 138.647, 6.55%↓) 129.426 0.447

[Top 3 out of 7 results showed]

No improved or regressed compilation metrics 🏖️

For more information:

Source Workflow Run

@MaheshRavishankar MaheshRavishankar enabled auto-merge (squash) June 28, 2024 07:02
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, just one question

@@ -1,4 +1,4 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions))" --split-input-file %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}))" --split-input-file %s | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we going to miss the tests for default path? Do we add a new mlir file to test aggressive-fusion=true? I don't suggest to add an additional FileCheck in this file because some of tests seem redundant to the aggressive fusion.

@MaheshRavishankar MaheshRavishankar merged commit 7090f64 into main Jun 28, 2024
63 checks passed
@MaheshRavishankar MaheshRavishankar deleted the sdxl_quantized_enable_softmax_fusion branch June 28, 2024 17:05
@hanhanW
Copy link
Contributor

hanhanW commented Jun 28, 2024

oh, my bad.. I did not notice that there is auto-merge..

LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
…17747)

Under aggressive fusion, drop the restriction of consumer iteration
space being same dimensionality as the producer iteration space.
Typically this can lead to large vectors if not handled properly. So
this is guarded under
`--iree-flow-enable-aggressive-fusion` flag.

Fixes nod-ai/SHARK-ModelDev#749

Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: Lubo Litchev <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
benchmarks:android-cpu Run default Android CPU benchmarks benchmarks:android-gpu Run default Android GPU benchmarks benchmarks:comp-stats Run default compilation statistics benchmarks benchmarks:cuda Run default CUDA benchmarks benchmarks:vulkan-nvidia Run default Vulkan benchmarks on NVIDIA GPU benchmarks:x86_64 Run default x86_64 benchmarks
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[fusion] Missing fusion of truncation operation with reduction-like operations
2 participants