Fix LINT errors
abhikran-quic committed Sep 6, 2023
1 parent c5d5d2b commit cfe9eb8
Showing 2 changed files with 157 additions and 55 deletions.
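The reformatting in both diffs is characteristic of an auto-formatter: long signatures and call expressions wrapped to one argument per line, list literals collapsed onto a single line, and docstring quotes normalized to triple double quotes. A minimal sketch of reproducing such a cleanup locally, assuming the project formats Python with black at a 100-character line length (the tool and line length are assumptions, not taken from the repository's lint configuration):

# Hedged sketch: re-run black over one of the changed files. Assumes black is
# installed and a 100-character line length; adjust to the project's actual
# lint settings.
import black

path = "tests/python/relax/test_optimize_layout_transform.py"
with open(path, "r", encoding="utf-8") as f:
    source = f.read()

# format_str returns the source reformatted under the given mode and leaves
# already-conforming code unchanged.
formatted = black.format_str(source, mode=black.Mode(line_length=100))

with open(path, "w", encoding="utf-8") as f:
    f.write(formatted)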
17 changes: 5 additions & 12 deletions python/tvm/relax/transform/optimize_layout_transform.py
@@ -1,4 +1,3 @@

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -24,9 +23,10 @@
from tvm.relax.expr_functor import mutator, PyExprMutator
from typing import Union


@mutator
class OptimizeLayoutTranformMutator(PyExprMutator):
'''
"""
Mutator to iterate over relax functions to
remove redundant transform layout operators
introduced by AlterOpImpl pass.
@@ -36,18 +36,12 @@ class OptimizeLayoutTranformMutator(PyExprMutator):
mod: IRModule
The ir module
'''
"""

def __init__(self, mod: IRModule) -> None:
super().__init__(mod)
self.mod_ = mod
self.patterns = [
[
"relax.layout_transform",
"relax.layout_transform"
]
]

self.patterns = [["relax.layout_transform", "relax.layout_transform"]]

# Matches the call_node against the pattern layout_transform -> layout_transform.
# Based on the pattern matching, returns the updated arguments for call_node.
@@ -62,8 +56,7 @@ def check_op_type(call_node: relax.Call, op_name: str) -> bool:
return True

new_call_args = []

# Update args of call_node be checking the pattern
# Update args of call_node
for arg in call_node.args[1]:
is_pattern_match = False
if not isinstance(arg, relax.expr.Var):
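For context on the pattern the mutator above removes: a layout_transform immediately followed by the inverse layout_transform is an identity mapping, so the pair contributes nothing and its consumers can be rewired to the original value. A small NumPy illustration of the round trip used by the tests below (illustrative only; it is not part of the pass):

import numpy as np

x = np.arange(16, dtype="float32")  # flat 16-element tensor, as in the tests below
packed = x.reshape(4, 4)            # index_map=lambda i: (i // 4, i % 4)
restored = packed.reshape(16)       # index_map=lambda axis0, axis1: (axis0 * 4 + axis1,)
assert np.array_equal(restored, x)  # the transform pair is a no-op, hence redundant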
195 changes: 152 additions & 43 deletions tests/python/relax/test_optimize_layout_transform.py
@@ -1,4 +1,3 @@

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -39,11 +38,14 @@ def _run_pass_compare_output(Before, Expected):


def test_optimize_transform_layout_pass_one_arg():

@I.ir_module
class Before:
@T.prim_func(private=True)
def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")):
def relax_add_replacement(
arg0: T.Buffer((4, 4), "float32"),
arg1: T.Buffer((4, 4), "float32"),
output: T.Buffer((4, 4), "float32"),
):
T.func_attr({"operator_name": "relax.add"})
# with T.block("root"):
for ax0, ax1 in T.grid(4, 4):
@@ -54,24 +56,50 @@ def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4,
output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

@R.function
def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
def main(
x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")
) -> R.Tensor((16,), dtype="float32"):
with R.dataflow():
lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv2 = R.call_tir(Before.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((4, 4), dtype="float32"))
lv0: R.Tensor((16,), dtype="float32") = R.layout_transform(lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None)
lv3: R.Tensor((4, 4), dtype="float32") = R.layout_transform(lv0, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv5 = R.call_tir(Before.relax_add_replacement, (lv4, lv3), out_sinfo=R.Tensor((4, 4), dtype="float32"))
lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None)
lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
x, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
y, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv2 = R.call_tir(
Before.relax_add_replacement,
(lv, lv1),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv0: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
lv3: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
lv0, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
y, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv5 = R.call_tir(
Before.relax_add_replacement,
(lv4, lv3),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
gv: R.Tensor((16,), dtype="float32") = lv2_1
R.output(gv)
return gv

@I.ir_module
class Expected:
@T.prim_func(private=True)
def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")):
def relax_add_replacement(
arg0: T.Buffer((4, 4), "float32"),
arg1: T.Buffer((4, 4), "float32"),
output: T.Buffer((4, 4), "float32"),
):
T.func_attr({"operator_name": "relax.add"})
# with T.block("root"):
for ax0, ax1 in T.grid(4, 4):
@@ -82,14 +110,32 @@ def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4,
output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

@R.function
def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
def main(
x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")
) -> R.Tensor((16,), dtype="float32"):
with R.dataflow():
lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv2 = R.call_tir(Expected.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((4, 4), dtype="float32"))
lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv5 = R.call_tir(Expected.relax_add_replacement, (lv4, lv2), out_sinfo=R.Tensor((4, 4), dtype="float32"))
lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None)
lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
x, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
y, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv2 = R.call_tir(
Expected.relax_add_replacement,
(lv, lv1),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
y, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv5 = R.call_tir(
Expected.relax_add_replacement,
(lv4, lv2),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
gv: R.Tensor((16,), dtype="float32") = lv2_1
R.output(gv)
return gv
@@ -98,11 +144,14 @@ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32"


def test_optimize_transform_layout_pass_two_args():

@I.ir_module
class Before:
@T.prim_func(private=True)
def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")):
def relax_add_replacement(
arg0: T.Buffer((4, 4), "float32"),
arg1: T.Buffer((4, 4), "float32"),
output: T.Buffer((4, 4), "float32"),
):
T.func_attr({"operator_name": "relax.add"})
# with T.block("root"):
for ax0, ax1 in T.grid(4, 4):
@@ -113,27 +162,63 @@ def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4,
output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

@R.function
def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32"), z: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
def main(
x: R.Tensor((16,), dtype="float32"),
y: R.Tensor((16,), dtype="float32"),
z: R.Tensor((16,), dtype="float32"),
) -> R.Tensor((16,), dtype="float32"):
with R.dataflow():
lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv2: R.Tensor((4, 4), dtype="float32") = R.layout_transform(z, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv3 = R.call_tir(Before.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((4, 4), dtype="float32"))
lv4 = R.call_tir(Before.relax_add_replacement, (lv, lv2), out_sinfo=R.Tensor((4, 4), dtype="float32"))
lv5: R.Tensor((16,), dtype="float32") = R.layout_transform(lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None)
lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(lv4, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None)
lv7: R.Tensor((4, 4), dtype="float32") = R.layout_transform(lv5, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv8: R.Tensor((4, 4), dtype="float32") = R.layout_transform(lv6, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv9 = R.call_tir(Before.relax_add_replacement, (lv7, lv8), out_sinfo=R.Tensor((4, 4), dtype="float32"))
lv10: R.Tensor((16,), dtype="float32") = R.layout_transform(lv9, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None)
lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
x, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
y, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv2: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
z, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv3 = R.call_tir(
Before.relax_add_replacement,
(lv, lv1),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv4 = R.call_tir(
Before.relax_add_replacement,
(lv, lv2),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv5: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv4, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
lv7: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
lv5, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv8: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
lv6, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv9 = R.call_tir(
Before.relax_add_replacement,
(lv7, lv8),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv10: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv9, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
gv: R.Tensor((16,), dtype="float32") = lv10
R.output(gv)
return gv

@I.ir_module
class Expected:
@T.prim_func(private=True)
def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")):
def relax_add_replacement(
arg0: T.Buffer((4, 4), "float32"),
arg1: T.Buffer((4, 4), "float32"),
output: T.Buffer((4, 4), "float32"),
):
T.func_attr({"operator_name": "relax.add"})
# with T.block("root"):
for ax0, ax1 in T.grid(4, 4):
@@ -144,15 +229,39 @@ def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4,
output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

@R.function
def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32"), z: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
def main(
x: R.Tensor((16,), dtype="float32"),
y: R.Tensor((16,), dtype="float32"),
z: R.Tensor((16,), dtype="float32"),
) -> R.Tensor((16,), dtype="float32"):
with R.dataflow():
lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv2: R.Tensor((4, 4), dtype="float32") = R.layout_transform(z, index_map=lambda i: (i // 4, i % 4), pad_value=None)
lv3 = R.call_tir(Expected.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((4, 4), dtype="float32"))
lv4 = R.call_tir(Expected.relax_add_replacement, (lv, lv2), out_sinfo=R.Tensor((4, 4), dtype="float32"))
lv5 = R.call_tir(Expected.relax_add_replacement, (lv3, lv4), out_sinfo=R.Tensor((4, 4), dtype="float32"))
lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None)
lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
x, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
y, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv2: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
z, index_map=lambda i: (i // 4, i % 4), pad_value=None
)
lv3 = R.call_tir(
Expected.relax_add_replacement,
(lv, lv1),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv4 = R.call_tir(
Expected.relax_add_replacement,
(lv, lv2),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv5 = R.call_tir(
Expected.relax_add_replacement,
(lv3, lv4),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
gv: R.Tensor((16,), dtype="float32") = lv6
R.output(gv)
return gv
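The helper _run_pass_compare_output referenced at the top of this test diff presumably applies the pass and asserts structural equality between the result and Expected. A hedged sketch of such a helper, assuming the pass in optimize_layout_transform.py is exposed as relax.transform.OptimizeLayoutTransform and that the now-unused layout_transform bindings are removed with DeadCodeElimination; the actual helper in the test file may differ:

import tvm
from tvm import relax


def run_pass_compare_output(before_mod, expected_mod):
    # Rewire redundant layout_transform -> layout_transform pairs, drop the
    # bindings that become dead, and require the result to match Expected.
    after_mod = relax.transform.OptimizeLayoutTransform()(before_mod)
    after_mod = relax.transform.DeadCodeElimination()(after_mod)
    tvm.ir.assert_structural_equal(after_mod, expected_mod)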
