Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][Layout] Add layout transformation analysis for PrimFunc #14066

Merged
merged 9 commits into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,19 @@ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
*/
TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true);

/*!
* \brief Using the layout transforms on the outputs, suggest layout transformation on the blocks
* and buffers for the PrimFunc.
*
* \param fn The PrimFunc to be analyzed.
* \param write_buffer_transformations Array of IndexMap transformations on PrimFunc outputs.
* \return Suggested transforms per block in `fn`. For each block the returned value is a map
* from the object (block or buffer) to it's index map transformation.
*/

TVM_DLL Map<tir::Block, Map<ObjectRef, tir::IndexMap>> SuggestLayoutTransforms(
const Function& fn, Array<tir::IndexMap> write_buffer_transformations);

} // namespace relax
} // namespace tvm

Expand Down
32 changes: 31 additions & 1 deletion python/tvm/relax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
configuring the passes and scripting them in Python.
"""

from typing import Dict, List
from typing import Dict, List, Union, Callable
from enum import IntEnum

from tvm import tir
from tvm import IRModule
from tvm.relax.ty import Type
from tvm.relax.struct_info import StructInfo, FuncStructInfo
from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call, Binding
from tvm.tir import IndexMap, PrimFunc, Block, Buffer
from . import _ffi_api


Expand Down Expand Up @@ -289,3 +290,32 @@ def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool:
will be well tested and will not be blocked by not having structure info.
"""
return _ffi_api.well_formed(mod, check_struct_info) # type: ignore


def suggest_layout_transforms(
func: PrimFunc, write_buffer_transforms: List[Union[IndexMap, Callable]]
) -> Dict[Block, Dict[Union[Block, Buffer], IndexMap]]:
"""Suggest Layout transformations of blocks and buffers in a PrimFunc.
Parameters
----------
func: PrimFunc
PrimFunc on which analysis will be performed and transformations suggested.
write_buffer_transforms: List[Union[IndexMap, Callable]
List of layout transformations on the output buffers. The number of layout
transformations must match the number of outputs of the PrimFunc.
Returns
-------
ret: Dict[Block, Dict[Union[Block, Buffer], IndexMap]]
Suggested transforms per block in `func`. For each block the returned value is a map
from the object (block or buffer) to it's index map transformation.
"""
write_buffer_index_maps = []
for transform in write_buffer_transforms:
if callable(transform):
transform = IndexMap.from_func(transform)
assert isinstance(transform, IndexMap)
write_buffer_index_maps.append(transform)
return _ffi_api.suggest_layout_transforms(func, write_buffer_index_maps) # type: ignore
Loading