Skip to content

Commit

Permalink
Fix more LINT errors
Browse files Browse the repository at this point in the history
  • Loading branch information
abhikran-quic committed Sep 11, 2023
1 parent cf471c6 commit f847182
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions python/tvm/relax/transform/optimize_layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
"""Relax Optimize Layout Transform pass."""
import tvm
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext
from tvm.relax import Expr, Function
from tvm.relax.dpl import *
from tvm.relax.dpl import is_op, rewrite_call, wildcard
from . import function_pass


Expand All @@ -29,18 +28,6 @@ class OptimizeLayoutTransform:
"""
Pass to remove redundant transform layout operators
introduced by AlterOpImpl pass.
Parameters
----------
func: Expr
The relax function to be optimized
mod: IRModule
The ir module
ctx: PassContext
Relax pass context
"""

def __init__(self):
Expand All @@ -51,9 +38,23 @@ def __init__(self):
self.pattern = pattern_1

def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRModule:
"""
Tranformation function to pattern match layout_transform -> layout_transform
pattern
Parameters
----------
func: Expr
The relax function to be optimized
mod: IRModule
The ir module
ctx: PassContext
Relax pass context
"""
updated_func = func
for global_var, func in mod.functions.items():
for _, func in mod.functions.items():
# Skip non-relax functions
if not isinstance(func, Function):
continue
Expand All @@ -66,6 +67,7 @@ def rewriter(expr, matches):
arg2 = matches[self.input]
if list(arg1.struct_info.shape) == list(arg2.struct_info.shape):
return arg2
return expr

updated_func = rewrite_call(self.pattern, rewriter, func)

Expand Down

0 comments on commit f847182

Please sign in to comment.