diff --git a/python/tvm/relax/transform/optimize_layout_transform.py b/python/tvm/relax/transform/optimize_layout_transform.py index 36280fd45c671..ce8b39dfb15d6 100644 --- a/python/tvm/relax/transform/optimize_layout_transform.py +++ b/python/tvm/relax/transform/optimize_layout_transform.py @@ -14,13 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-argument +# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local """Relax Optimize Layout Transform pass.""" -import tvm from tvm.ir.module import IRModule from tvm.ir.transform import PassContext from tvm.relax import Expr, Function -from tvm.relax.dpl import * +from tvm.relax.dpl import is_op, rewrite_call, wildcard from . import function_pass @@ -29,18 +28,6 @@ class OptimizeLayoutTransform: """ Pass to remove redundant transform layout operators introduced by AlterOpImpl pass. - - Parameters - ---------- - func: Expr - The relax function to be optimized - - mod: IRModule - The ir module - - ctx: PassContext - Relax pass context - """ def __init__(self): @@ -51,9 +38,23 @@ def __init__(self): self.pattern = pattern_1 def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRModule: + """ + Tranformation function to pattern match layout_transform -> layout_transform + pattern + + Parameters + ---------- + func: Expr + The relax function to be optimized + + mod: IRModule + The ir module + ctx: PassContext + Relax pass context + """ updated_func = func - for global_var, func in mod.functions.items(): + for _, func in mod.functions.items(): # Skip non-relax functions if not isinstance(func, Function): continue @@ -66,6 +67,7 @@ def rewriter(expr, matches): arg2 = matches[self.input] if list(arg1.struct_info.shape) == list(arg2.struct_info.shape): return arg2 + return expr updated_func = rewrite_call(self.pattern, rewriter, func)