
Add pass to bubble-up extract_slice operations. #18332

Merged
merged 10 commits into from
Aug 28, 2024

Conversation

MaheshRavishankar
Contributor

@MaheshRavishankar MaheshRavishankar commented Aug 22, 2024

This adds a pass to replace a `tensor.extract_slice` operation with a
slice of the producer. In general there might be more opportunities to
apply this pass more aggressively (e.g. when an operation has a single
use which is a slice), but for now it is applied only to bit-extend
operations.

Fixes #18254
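A minimal sketch of the intended rewrite on a bit-extend op (shapes and SSA names are illustrative, not taken from the pass output):

```mlir
// Before: the f32 result is materialized in full, then sliced.
%ext = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%src : tensor<24x128xf16>) outs(%init : tensor<24x128xf32>) {
^bb0(%in: f16, %out: f32):
  %e = arith.extf %in : f16 to f32
  linalg.yield %e : f32
} -> tensor<24x128xf32>
%res = tensor.extract_slice %ext[0, 0] [24, 64] [1, 1]
    : tensor<24x128xf32> to tensor<24x64xf32>

// After: the slice is bubbled up to the f16 operand, so only the
// needed elements are extended.
%src_slice = tensor.extract_slice %src[0, 0] [24, 64] [1, 1]
    : tensor<24x128xf16> to tensor<24x64xf16>
%res = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%src_slice : tensor<24x64xf16>) outs(%init2 : tensor<24x64xf32>) {
^bb0(%in: f16, %out: f32):
  %e = arith.extf %in : f16 to f32
  linalg.yield %e : f32
} -> tensor<24x64xf32>
```

Sliced this way, the intermediate f16-to-f32 extension only materializes the elements the consumer actually reads.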

Signed-off-by: MaheshRavishankar [email protected]

IanWood1 and others added 3 commits August 22, 2024 18:15
@IanWood1
Contributor

IanWood1 commented Aug 23, 2024

I tried https://gist.github.com/monorimet/3a0a4310c1ed09265353ce747599d502 but it seems like there is a collapse_shape in the way:

  %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed : tensor<24x4608x128xf16>) outs(%3 : tensor<24x4608x128xf32>) {
  ^bb0(%in: f16, %out: f32):
    %21 = arith.extf %in : f16 to f32
    linalg.yield %21 : f32
  } -> tensor<24x4608x128xf32>
  %expanded = tensor.expand_shape %4 [[0, 1], [2], [3, 4, 5]] output_shape [1, 24, 4608, 64, 1, 2] : tensor<24x4608x128xf32> into tensor<1x24x4608x64x1x2xf32>
  %extracted_slice_7 = tensor.extract_slice %expanded[0, 0, 0, 0, 0, 1] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
  %expanded_8 = tensor.expand_shape %extracted_slice_7 [[0, 1], [2], [3, 4]] output_shape [1, 24, 4608, 64, 1] : tensor<24x4608x64xf32> into tensor<1x24x4608x64x1xf32>

also, iree/compiler/Dialect/Flow/Transforms/test/bubble_up_extract_slice.mlir just needs to be updated to use --iree-flow-bubble-up-extract-slices

Edit:

It seems like the expand/extracts should be foldable
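If that fold applies, the `expand_shape` + `extract_slice` pair above should reduce to a single strided slice on the unexpanded tensor (offsets/sizes/strides hand-derived for this example; illustrative, not actual pass output):

```mlir
// Extracting index 1 along the trailing group [64, 1, 2] picks every
// second element of the original 128-wide dimension, starting at 1.
%slice = tensor.extract_slice %4[0, 0, 1] [24, 4608, 64] [1, 1, 2]
    : tensor<24x4608x128xf32> to tensor<24x4608x64xf32>
%expanded_8 = tensor.expand_shape %slice [[0, 1], [2], [3, 4]]
    output_shape [1, 24, 4608, 64, 1]
    : tensor<24x4608x64xf32> into tensor<1x24x4608x64x1xf32>
```

With the slice expressed directly on `%4`, it could then be bubbled past the `arith.extf` generic as well.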

@MaheshRavishankar
Contributor Author

I tried https://gist.github.com/monorimet/3a0a4310c1ed09265353ce747599d502 but it seems like there is a collapse_shape in the way:

  %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed : tensor<24x4608x128xf16>) outs(%3 : tensor<24x4608x128xf32>) {
  ^bb0(%in: f16, %out: f32):
    %21 = arith.extf %in : f16 to f32
    linalg.yield %21 : f32
  } -> tensor<24x4608x128xf32>
  %expanded = tensor.expand_shape %4 [[0, 1], [2], [3, 4, 5]] output_shape [1, 24, 4608, 64, 1, 2] : tensor<24x4608x128xf32> into tensor<1x24x4608x64x1x2xf32>
  %extracted_slice_7 = tensor.extract_slice %expanded[0, 0, 0, 0, 0, 1] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
  %expanded_8 = tensor.expand_shape %extracted_slice_7 [[0, 1], [2], [3, 4]] output_shape [1, 24, 4608, 64, 1] : tensor<24x4608x64xf32> into tensor<1x24x4608x64x1xf32>

also, iree/compiler/Dialect/Flow/Transforms/test/bubble_up_extract_slice.mlir just needs to be updated to use --iree-flow-bubble-up-extract-slices

That makes sense. I should just run this after bubble up expand shapes.

@MaheshRavishankar
Contributor Author

Probably needs more tests. I added a rank-reduced slice test; we also want one without rank reduction.

@MaheshRavishankar
Contributor Author

Someone else should review this since I co-authored this commit. Probably needs more tests. I added a rank-reduced slice test; we always want a slice test without rank reduction as well.

Signed-off-by: Ian Wood <[email protected]>
@IanWood1 IanWood1 force-pushed the extract_slice_prop branch from 9e1d446 to 5748e9f Compare August 27, 2024 16:17
Contributor

@qedawkins qedawkins left a comment


Overall looks good, though I have some concerns about the ordering of some of the failure cases within the pattern.

if (tilingResult->tiledOps.size() != 1 ||
    !isa<linalg::GenericOp>(tilingResult->tiledOps[0])) {
  return rewriter.notifyMatchFailure(
      linalgOp, "expected extract_slice to generate a `linalg.generic`");
Contributor


Failure after generating IR in a pattern is problematic. At a minimum I would restrict to generics like I said above, but it would also be worth adding checks for `isProjectedPermutation(indexingMaps)` and `!hasIndexSemantics`.

Contributor


Good point. I removed this check and added the checks you suggested before IR mutation. I also changed the failed checks to asserts, since there is no graceful way to exit at that point.
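A sketch of what such up-front checks could look like in the pattern (helper names follow upstream MLIR, e.g. `LinalgOp::hasIndexSemantics` and `AffineMap::isProjectedPermutation`; surrounding variable names are illustrative, not the actual patch):

```cpp
// Bail out before mutating any IR: only handle projected-permutation
// linalg.generic producers without index semantics.
auto genericOp = dyn_cast<linalg::GenericOp>(producer);
if (!genericOp)
  return rewriter.notifyMatchFailure(producer, "expected linalg.generic");
if (genericOp.hasIndexSemantics())
  return rewriter.notifyMatchFailure(
      genericOp, "ops with index semantics are unsupported");
if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
      return map.isProjectedPermutation();
    }))
  return rewriter.notifyMatchFailure(
      genericOp, "expected projected permutation indexing maps");
```

Running all checks before any `rewriter` mutation keeps the pattern safely restartable on match failure.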

Signed-off-by: Ian Wood <[email protected]>
Contributor

@qedawkins qedawkins left a comment


Nice, this looks good to me!

@IanWood1
Contributor

@ScottTodd test_einsum_inner_prod is timing out in the onnx regression tests (https://github.com/iree-org/iree/actions/runs/10585615045/job/29333076426?pr=18332). The message is:

FAILED iree-test-suites/onnx_ops/onnx/node/generated/test_einsum_inner_prod/run_module_io_flags.txt::model.mlir::gpu_rocm_rdna3 - Failed: Timeout >30.0s

I don't think it is related to this PR, because I was getting a segfault during torch conversion. Should the rdna3 config be updated similar to what was done in #18357?

@MaheshRavishankar
Contributor Author

@ScottTodd test_einsum_inner_prod is timing out in the onnx regression tests (https://github.com/iree-org/iree/actions/runs/10585615045/job/29333076426?pr=18332). The message is:

FAILED iree-test-suites/onnx_ops/onnx/node/generated/test_einsum_inner_prod/run_module_io_flags.txt::model.mlir::gpu_rocm_rdna3 - Failed: Timeout >30.0s

I don't think it is related to this PR, because I was getting a segfault during torch conversion. Should the rdna3 config be updated similar to what was done in #18357?

This is an existing error; we've been hitting it all week.

@MaheshRavishankar
Contributor Author

Oh, you can merge even with the failure

Contributor Author

@MaheshRavishankar MaheshRavishankar left a comment


Thanks @IanWood1 !

@IanWood1 IanWood1 merged commit d6762d4 into iree-org:main Aug 28, 2024
36 of 37 checks passed
@IanWood1 IanWood1 deleted the extract_slice_prop branch August 28, 2024 01:29
@IanWood1
Contributor

Oh, you can merge even with the failure

Great, I wasn't sure!

josemonsalve2 pushed a commit to josemonsalve2/iree that referenced this pull request Sep 14, 2024

Co-authored-by: Ian Wood <[email protected]>
Successfully merging this pull request may close these issues.

[ROCM][gfx942] shared memory limit exceeded on elemwise broadcast (bf16) (flux-dev)