As part of adding support for a Torch operator in Torch-MLIR, it is usually necessary to define a shape and dtype function so that the compiler can infer the shapes and dtypes of result tensors for the operator. We use the abstract interpretation library for this process.
We will use the example of adding support for the `torch.aten.tanh` op.
- First, you need to find the shape and dtype function signatures for the
  operator you are implementing a function for. These can be found in
  `include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt`, generated
  by the `build_tools/update_torch_ods.sh` script. That file is the "rosetta
  stone" that allows translating between e.g. `torch.aten.tanh`, `AtenTanhOp`,
  and the shape and dtype function signatures. For `torch.aten.tanh`, the
  signatures are:

  ```python
  def aten〇tanh〡shape(self: List[int]) -> List[int]:
  def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
  ```

  Note the use of `〇` as a separator, since `.` and `::` aren't legal in a
  Python identifier.
- Paste the function signatures into `abstract_interp_lib_gen.py` in an
  appropriate place (ideally near other similar functions). Note that
  `abstract_interp_lib_gen.py` will check that these signatures are verbatim
  identical to the ones given in `JITOperatorRegistryDump.txt` -- this ensures
  that the functions don't get outdated if Torch changes an operator
  signature.
- Fill in the body of the function. Ideally this will just be a call into a
  helper function from `torch/jit/_shape_functions.py`. But in general, you
  will need to write the function and test it (see the comments about "Shape,
  dtype, and decomposition function testing infrastructure" in
  `testing_framework.py`). New shape functions should be added upstream
  following the example of this PR, though it can be useful to iterate locally
  in `abstract_interp_lib_gen.py` first. A sketch of the completed functions
  for `torch.aten.tanh` is shown after this list.

  Similarly, dtype functions should ideally just be a call to the helper
  `promote_dtypes` defined in `library_generator.py` (see the second sketch
  after this list). However, some ops will require some extra logic to
  calculate the right result type. While dtypes are expressed as `int`s in the
  arguments of the dtype function, using PyTorch dtypes, such as `torch.int`
  and `torch.float32`, in the body of the dtype function is fully supported.
  Dtype functions are also expected to be fully tested.
- Re-run the `build_tools/update_abstract_interp_lib.sh` script to update the
  library. After this step, ideally everything "just works" and shapes and
  dtypes are now correctly inferred for the operator.
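For concreteness, here is a minimal sketch of what the completed functions for
`torch.aten.tanh` could look like, including test annotations in the style of
`testing_framework.py`. The helper names (`upstream_shape_functions.unary`,
`check_shape_function`, `check_dtype_function`, `Invocation`, `TensorOfShape`,
`_check_tensors_with_the_same_dtype`) follow the patterns used in
`abstract_interp_lib_gen.py`, but verify them -- and the exact integer-to-float
promotion rule for `tanh` -- against the current sources:

```python
from typing import List, Tuple

import torch
# The conventional alias for torch/jit/_shape_functions.py:
import torch.jit._shape_functions as upstream_shape_functions

# In abstract_interp_lib_gen.py these decorators and helpers are already in
# scope; this import is illustrative only.
from testing_framework import (Invocation, TensorOfShape, check_shape_function,
                               check_dtype_function,
                               _check_tensors_with_the_same_dtype)

@check_shape_function([
    Invocation(TensorOfShape(2, 3, 4)),  # Basic case.
])
def aten〇tanh〡shape(self: List[int]) -> List[int]:
    # tanh is elementwise: the result shape equals the input shape.
    return upstream_shape_functions.unary(self)

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    # tanh computes in floating point, so integer (and bool) inputs promote
    # to the default float dtype, while float inputs keep their dtype. Note
    # that PyTorch dtypes like torch.float32 can be used directly here.
    if self_dtype in [torch.bool, torch.uint8, torch.int8, torch.int16,
                      torch.int32, torch.int64]:
        return torch.float32
    return self_dtype
```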
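Likewise, here is a sketch of a dtype function built on `promote_dtypes`,
following the pattern of binary ops such as `torch.aten.add.Tensor`. The exact
signature (including the type of `alpha`) should be copied verbatim from
`JITOperatorRegistryDump.txt` as described above:

```python
from typing import List, Optional, Tuple, Union

# promote_dtypes is defined in library_generator.py and is already in scope in
# abstract_interp_lib_gen.py; this import is illustrative only.
from library_generator import promote_dtypes

def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int],
                            other_rank_dtype: Tuple[int, int],
                            alpha: Union[int, float] = 1) -> int:
    self_rank, self_dtype = self_rank_dtype
    other_rank, other_dtype = other_rank_dtype
    # Ranks are passed alongside dtypes because zero-rank (scalar) tensors
    # participate in PyTorch's type promotion differently.
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)
```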
It is possible that the refinement pipeline (see Shape and Dtype Refinement
Pipeline Architecture) is not able to infer the shape or dtype of a tensor
with a given abstract interpretation function. This usually means that there
is something about the function which the optimizations in
`torch-simplify-shape-calculations` and `torch-simplify-dtype-calculations`
(`lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp`,
`lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp`) cannot handle.
To debug this, the overall goal is to pinpoint the IR construct that is not
being simplified. This is usually accomplished by a combination of looking at
the Python code for the function and the IR dumps. The best IR dump to look at
varies, but frequently the IR dump right before `DropAbstractInterpCalculations`
is the most useful, because it has already been simplified as much as possible,
making it easy to see what is blocking further simplification. Examples of
issues you might see:
- You might find that there is a loop with a non-constant trip count, but
  based on your understanding of the function, you would expect it to be
  simplified to a constant trip count -- you can then look at the trip count
  calculation and see if there is a missing fold or canonicalization (see the
  sketch after this list).
- You might find that there is a list operation that is not currently
  understood by the optimizations. You can then teach the optimizations about
  that operation.
- You might find that there is an `Optional` value that you would expect to be
  resolved to either a concrete value or `None`. You can then look at the
  calculation that produces the optional value and see what folds or
  canonicalizations are missing.
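To illustrate the first bullet: a shape function written with an explicit
loop, like the hypothetical one below, compiles to a `torch.prim.Loop` whose
trip count is `len(self)`. Once the input rank is known, `len(self)` folds to
a constant, the trip count becomes constant, and the loop can be unrolled; if
a missing fold keeps the trip count non-constant, simplification stalls at
exactly this point.

```python
from typing import List

# Hypothetical shape function, used only to illustrate loop simplification.
def aten〇example〡shape(self: List[int]) -> List[int]:
    out: List[int] = []
    # This loop becomes a torch.prim.Loop with trip count len(self). It can
    # only be unrolled (and `out` resolved to a fully known list) once
    # len(self) has been folded to a constant, i.e. the input rank is known.
    for i in range(len(self)):
        out.append(self[i])
    return out
```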
See this video for general guidance on debugging Torch-MLIR.
As a last resort, you can rewrite the function using constructs that
`torch-simplify-shape-calculations` and `torch-simplify-dtype-calculations`
can handle (look at other functions for examples; sometimes it requires
writing things a little awkwardly), as in the sketch below.
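As a purely hypothetical example of such a rewrite (the set of constructs the
passes understand evolves, so check the sources): both helpers below compute
the same shape, but the second avoids list repetition in favor of an explicit
append loop, which lowers to a `torch.prim.Loop` that the passes can unroll
once the trip count folds to a constant.

```python
from typing import List

# Hypothetical helpers for building a shape with `rank` dimensions of size
# `size`; the names and the premise (list repetition not being simplified)
# are illustrative only.

def repeated_shape_terse(rank: int, size: int) -> List[int]:
    # Terse, but list repetition may lower to an op the simplification
    # passes do not understand.
    return [size] * rank

def repeated_shape_awkward(rank: int, size: int) -> List[int]:
    # Awkward, but lowers to an explicit loop plus list appends, which the
    # passes can unroll once `rank` folds to a constant.
    shape: List[int] = []
    for _ in range(rank):
        shape.append(size)
    return shape
```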