
[custom op] Generalize shape library logic to work with dtypes #1594

Merged
merged 10 commits into llvm:main on Dec 13, 2022

Conversation

ramiro050 (Collaborator)

This commit generalizes the shape library logic so that dtype rules
for ops can also be expressed using the same mechanism. In other
words, each op can now have a shape function and a dtype function
specified in Python that are imported during lowering to calculate the
shapes and dtypes throughout a program. For more information about how
to specify a dtype function, see the updated
docs/adding_a_shape_and_dtype_function.md.

For those not familiar with how the shape library works, the file
docs/calculations_lib.md provides an overview.
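
As a rough illustration (the function names, signatures, and style below are assumptions for this sketch, not copied from the patch), a shape function maps argument shapes to result shapes, and a dtype function maps argument ranks and dtypes to a result dtype:

```python
from typing import List

# Hypothetical shape function for an elementwise op: the result shape
# equals the input shape.
def aten_tanh_shape(self: List[int]) -> List[int]:
    return self

# Hypothetical dtype function for the same op: a single-operand
# elementwise op simply propagates its input dtype (dtypes are passed
# around as integer codes).
def aten_tanh_dtype(self_rank: int, self_dtype: int) -> int:
    return self_dtype
```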

To make reviewing a bit easier, I suggest the following review
order:

  1. Get familiar with the overall architecture by reading
    docs/calculations_lib.md

  2. New op declarations

    • include/torch-mlir/Dialect/Torch/IR/TorchOps.td
    • lib/Dialect/Torch/IR/TorchOps.cpp
  3. New passes

    • include/torch-mlir/Dialect/Torch/Transforms/Passes.td
    • include/torch-mlir/Dialect/Torch/Transforms/Passes.h
    • lib/Dialect/Torch/Transforms/Passes.cpp
    • lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp
    • lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp
    • lib/Dialect/Torch/Transforms/ReifyCalculationsUtils.cpp
    • lib/Dialect/Torch/Transforms/ReifyCalculationsUtils.h
    • lib/Dialect/Torch/Transforms/SimplifyCalculationsUtils.cpp
    • lib/Dialect/Torch/Transforms/SimplifyCalculationsUtils.h

    The *Utils.* files include logic that is shared by dtype and
    shape passes.

  4. Tests

    • test/Dialect/Torch/ops.mlir
    • test/Dialect/Torch/reify-dtype-calculations.mlir
    • test/Dialect/Torch/simplify-dtype-calculations.mlir
  5. Introduce torch_mlir_promote_dtypes (a sketch of how a dtype function might use such a helper follows this list)

    • python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp
    • python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py
  6. Simple refactoring/generalizing

    • include/torch-mlir/Dialect/Torch/Utils/Utils.h
    • lib/Dialect/Torch/Utils/Utils.cpp
    • lib/Dialect/Torch/Transforms/DropCalculations.cpp
    • lib/Dialect/Torch/Transforms/RefineTypes.cpp
    • lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp
    • lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
    • python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py
    • python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py
    • python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/calculations_lib_gen.py
    • lib/Dialect/Torch/Transforms/CalculationsLibrary.cpp
  7. The rest of the files contain minor changes:

    • Replacing shape with calculations in names
    • Replacing m_TorchConstantIntList with
      m_TorchListOfConstantInts (needed to avoid ambiguity with the new
      pattern m_TorchListOfOptionalConstantInts)
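
For item 5, here is a self-contained toy of what a promotion helper lets a dtype function do. Everything below (names, signatures, and dtype codes) is an assumption for illustration; it is a simplified stand-in, not the actual torch_mlir_promote_dtypes implementation:

```python
from typing import List

# Illustrative integer dtype codes in the style of torch's ScalarType
# enum, chosen so that "wider" types compare greater (assumed values).
INT32, INT64, FLOAT32, FLOAT64 = 3, 4, 6, 7

def toy_promote_dtypes(ranks: List[int], dtypes: List[int]) -> int:
    # Simplified stand-in: return the widest dtype. PyTorch's real
    # promotion rules also weigh scalar operands (rank 0) differently
    # from tensor operands, which is why ranks are passed in at all.
    return max(dtypes)

# Hypothetical dtype function for a binary op, written against the toy
# helper above.
def aten_add_tensor_dtype(self_rank: int, self_dtype: int,
                          other_rank: int, other_dtype: int) -> int:
    return toy_promote_dtypes([self_rank, other_rank],
                              [self_dtype, other_dtype])
```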

ramiro050 requested a review from silvasean November 16, 2022 03:06
@silvasean (Contributor) left a comment

First round of comments. Overall looks good. Will probably have a few more nits after these are addressed but not much more than that.

Review threads (now resolved) were opened on:

  • docs/adding_a_shape_and_dtype_function.md
  • .github/workflows/RollPyTorch.yml
  • lib/Dialect/Torch/Transforms/ReifyCalculationsUtils.cpp
  • lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp
  • test/Dialect/Torch/simplify-dtype-calculations.mlir
@ramiro050 (Collaborator, Author)

@silvasean, I've addressed the comments in separate commits. Also, I had to make this change to get all tests to pass: a24b7c6. The commit message (pasted below) has the explanation:

This commit adds back the ops that use the new dtype refinement
pass. This is needed to avoid the catch-22 that results from ops in
the new dtype refinement pass needing dtype information from ops in
the RefineTypes pass, and ops in the RefineTypes pass needing dtype
information from ops in the new dtype refinement pass.

The reason this catch-22 problem is not handled by the iterative
application of passes when lowering to the backend contract is because
the DecomposeComplexOps pass is not currently designed to run more
than once, since every op with a decomposition gets marked illegal
after the pass is done. Marking ops as legal results in no
decomposition patterns being applied because the graph is already in a
legal state.

Adding back the ops to RefineTypes seems like the simplest solution to
this problem while the rest of the ops get relocated to use the new
dtype refinement pass.
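
As a loose sketch of the iteration being described (the names and control flow below are assumptions for illustration, not the actual C++ implementation):

```python
# Toy rendering of the iterative lowering to the backend contract.
def lower_to_backend_contract(module, passes, satisfies, max_iterations=10):
    """Repeatedly run the simplification pipeline until the module
    satisfies the backend contract, or give up."""
    for _ in range(max_iterations):
        for apply_pass in passes:
            module = apply_pass(module)
        if satisfies(module):
            return module
    raise RuntimeError("failed to reach the backend contract")

# The catch-22: if the decomposition pass marks every op that has a
# decomposition as illegal after its first run, iterations 2..N can no
# longer decompose ops whose dtypes were only refined after iteration 1,
# so the loop never converges for those ops.
```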

ramiro050 requested a review from silvasean November 20, 2022 22:15
@silvasean (Contributor)

> This commit adds back the ops that use the new dtype refinement
> pass. [...] Adding back the ops to RefineTypes seems like the simplest
> solution to this problem while the rest of the ops get relocated to
> use the new dtype refinement pass.

I think that before we land this, we should fix things so that the iterative lowering behaves as intended here (perhaps as easy as switching DecomposeComplexOps to use the greedy rewriter? and then perhaps moving the final legality check into satisfiesBackendContract somehow?). That seems like an independently useful improvement as well as the right thing to do here.

I would generally bias pretty heavily against keeping the old path in RefineTypes alive unless absolutely necessary -- when we migrate an op, we should delete the old support. That guarantees monotonic progress towards the new system by making the new system load-bearing. If the old system still works for these ops, it is easy to add support for ops in the new system while the e2e tests still end up using the old system for some reason, so we aren't even testing the new code -- and thus we accumulate buggy, untested code in the new system.

@ramiro050 (Collaborator, Author)

I've replaced two of the ops I was using for e2e testing this patch (AtenAddScalarOp and AtenAddTensorOp) with ops that are less commonly encountered in large workloads: 634dd9a 11ec64d. The new ops maintain the same level of test coverage as the previous ones did.

While running DecomposeComplexOps multiple times did fix the catch-22 I was encountering before, I would still have needed to increase the number of iterations run in lowerToBackendContractPass to get tests like Resnet18 and mobilenet to pass, since they use a lot of adds, which would noticeably affect the time it takes torch-mlir to lower to a particular backend.

The best approach for moving those common ops into the abstract_lib_gen.py file is to do a good chunk of them in a single go, which should be straightforward.

ramiro050 requested a review from silvasean December 13, 2022 00:05
@silvasean (Contributor) left a comment

Nice :) Let's do this!

ramiro050 merged commit a710237 into llvm:main Dec 13, 2022
ramiro050 deleted the custom-op-dtypes branch December 13, 2022 16:25
PriyaBSavithiri pushed a commit to PriyaBSavithiri/mcw-torch-mlir that referenced this pull request Dec 15, 2022
[custom op] Generalize shape library logic to work with dtypes (#1594)