Skip to content

Commit

Permalink
[Unity][Layout] Add layout transformation analysis for PrimFunc (#14066)
Browse files Browse the repository at this point in the history
* [Layout] Add layout transformation analysis for PrimFunc.

This change adds a PrimFunc level analysis to suggest layout transformations to block and buffers in the PrimFunc based on the layout transformations to PrimFunc outputs.

* Add support for multiple blocks such as split op.

* Add negative tests and increase coverage.

* fix warning message

* fix lint

* remove unused header

* Address comments.
Moved some utility functions to support/array.h
improve doc

* fix deprecation warn T.var("int64") to T.int64()

* address comments
  • Loading branch information
psrivas2 authored and tqchen committed Mar 5, 2023
1 parent 74f3007 commit 98d0a01
Show file tree
Hide file tree
Showing 5 changed files with 1,522 additions and 2 deletions.
13 changes: 13 additions & 0 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,19 @@ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
*/
TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true);

/*!
* \brief Using the layout transforms on the outputs, suggest layout transformation on the blocks
* and buffers for the PrimFunc.
*
* \param fn The PrimFunc to be analyzed.
* \param write_buffer_transformations Array of IndexMap transformations on PrimFunc outputs.
* \return Suggested transforms per block in `fn`. For each block the returned value is a map
* from the object (block or buffer) to it's index map transformation.
*/

TVM_DLL Map<tir::Block, Map<ObjectRef, tir::IndexMap>> SuggestLayoutTransforms(
const Function& fn, Array<tir::IndexMap> write_buffer_transformations);

} // namespace relax
} // namespace tvm

Expand Down
32 changes: 31 additions & 1 deletion python/tvm/relax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
configuring the passes and scripting them in Python.
"""

from typing import Dict, List
from typing import Dict, List, Union, Callable
from enum import IntEnum

from tvm import tir
from tvm import IRModule
from tvm.relax.ty import Type
from tvm.relax.struct_info import StructInfo, FuncStructInfo
from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call, Binding
from tvm.tir import IndexMap, PrimFunc, Block, Buffer
from . import _ffi_api


Expand Down Expand Up @@ -289,3 +290,32 @@ def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool:
will be well tested and will not be blocked by not having structure info.
"""
return _ffi_api.well_formed(mod, check_struct_info) # type: ignore


def suggest_layout_transforms(
func: PrimFunc, write_buffer_transforms: List[Union[IndexMap, Callable]]
) -> Dict[Block, Dict[Union[Block, Buffer], IndexMap]]:
"""Suggest Layout transformations of blocks and buffers in a PrimFunc.
Parameters
----------
func: PrimFunc
PrimFunc on which analysis will be performed and transformations suggested.
write_buffer_transforms: List[Union[IndexMap, Callable]
List of layout transformations on the output buffers. The number of layout
transformations must match the number of outputs of the PrimFunc.
Returns
-------
ret: Dict[Block, Dict[Union[Block, Buffer], IndexMap]]
Suggested transforms per block in `func`. For each block the returned value is a map
from the object (block or buffer) to it's index map transformation.
"""
write_buffer_index_maps = []
for transform in write_buffer_transforms:
if callable(transform):
transform = IndexMap.from_func(transform)
assert isinstance(transform, IndexMap)
write_buffer_index_maps.append(transform)
return _ffi_api.suggest_layout_transforms(func, write_buffer_index_maps) # type: ignore
Loading

0 comments on commit 98d0a01

Please sign in to comment.