Skip to content

Commit

Permalink
[Hexagon][TOPI] Use IndexMap axis separator instead of TE
Browse files Browse the repository at this point in the history
- This is needed to reuse index_map functions passed to transform_layout in TIR
  • Loading branch information
abhikran-quic committed Apr 3, 2023
1 parent 2c052b2 commit 5087562
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions python/tvm/topi/hexagon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
import struct
from typing import Tuple
from tvm import te

from tvm.tir import IndexMap

def n11c_1024c_2d(n, h, w, c):
"""Return index map for n11c_1024 2d layout"""
return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024]
return [n, h, w, c // 1024, IndexMap.AXIS_SEPARATOR, c % 1024]


def n11c_1024c_1d(n, h, w, c):
Expand All @@ -37,7 +37,7 @@ def n11c_1024c_1d(n, h, w, c):

def nhwc_8h2w32c2w_2d(n, h, w, c):
"""Return index map for nhwc_8h2w32c2w 2d layout"""
return [n, h // 8, w // 4, c // 32, te.AXIS_SEPARATOR, h % 8, (w % 4) // 2, c % 32, w % 2]
return [n, h // 8, w // 4, c // 32, IndexMap.AXIS_SEPARATOR, h % 8, (w % 4) // 2, c % 32, w % 2]


def nhwc_8h2w32c2w_1d(n, h, w, c):
Expand All @@ -47,7 +47,7 @@ def nhwc_8h2w32c2w_1d(n, h, w, c):

def nhw_32h16w_2d(n, h, w):
"""Return index map for nhw_32h16w 2d layout"""
return [n, h // 32, w // 16, te.AXIS_SEPARATOR, h % 32, w % 16]
return [n, h // 32, w // 16, IndexMap.AXIS_SEPARATOR, h % 32, w % 16]


def nhwc_4h4w32c_1d(n, h, w, c):
Expand All @@ -57,7 +57,7 @@ def nhwc_4h4w32c_1d(n, h, w, c):

def nhwc_4h4w32c_2d(n, h, w, c):
"""Return index map for nhwc_4h4w32c 2d layout"""
return [n, h // 4, w // 4, c // 32, te.AXIS_SEPARATOR, h % 4, w % 4, c % 32]
return [n, h // 4, w // 4, c // 32, IndexMap.AXIS_SEPARATOR, h % 4, w % 4, c % 32]


def nc_512c_1d(n, c):
Expand All @@ -67,12 +67,12 @@ def nc_512c_1d(n, c):

def nc_512c_2d(n, c):
"""Return index map for nc_512c 2d layout"""
return [n, c // 512, te.AXIS_SEPARATOR, c % 512]
return [n, c // 512, IndexMap.AXIS_SEPARATOR, c % 512]


def nc_1024c_2d(n, c):
"""Return index map for nc_1024c 2d layout"""
return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024]
return [n, c // 1024, IndexMap.AXIS_SEPARATOR, c % 1024]


def nc_2048c_1d(n, c):
Expand All @@ -82,7 +82,7 @@ def nc_2048c_1d(n, c):

def nc_2048c_2d(n, c):
"""Return index map for nc_2024c 2d layout"""
return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048]
return [n, c // 2048, IndexMap.AXIS_SEPARATOR, c % 2048]


def nc_1024c_1d(n, c):
Expand All @@ -92,37 +92,37 @@ def nc_1024c_1d(n, c):

def nhwc_4h2w32c2w_2d(n, h, w, c):
"""Return index map for nhwc_4h2w32c2w 2d layout"""
return [n, h // 4, w // 4, c // 32, te.AXIS_SEPARATOR, h % 4, (w % 4) // 2, c % 32, w % 2]
return [n, h // 4, w // 4, c // 32, IndexMap.AXIS_SEPARATOR, h % 4, (w % 4) // 2, c % 32, w % 2]


def nhwc_1024c_2d(n, h, w, c):
"""Return index map for nhwc_1024 2d layout"""
return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024]
return [n, h, w, c // 1024, IndexMap.AXIS_SEPARATOR, c % 1024]


def nc_1024_2d(n, c):
"""Return index map for nc_1024 2d layout"""
return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024]
return [n, c // 1024, IndexMap.AXIS_SEPARATOR, c % 1024]


def nhwc_2048c_2d(n, h, w, c):
"""Return index map for nhwc_2048 2d layout"""
return [n, h, w, c // 2048, te.AXIS_SEPARATOR, c % 2048]
return [n, h, w, c // 2048, IndexMap.AXIS_SEPARATOR, c % 2048]


def nc_2048_2d(n, c):
"""Return index map for nc_2048 2d layout"""
return [n, c // 2048, te.AXIS_SEPARATOR, c % 2048]
return [n, c // 2048, IndexMap.AXIS_SEPARATOR, c % 2048]


def nhwc_8h8w32c_2d(n, h, w, c):
"""Return index map for nhwc_8h8w32c 2d layout"""
return [n, h // 8, w // 8, c // 32, te.AXIS_SEPARATOR, h % 8, w % 8, c % 32]
return [n, h // 8, w // 8, c // 32, IndexMap.AXIS_SEPARATOR, h % 8, w % 8, c % 32]


def n11c_2048c_2d(n, h, w, c):
"""Return index map for n11c_2048c 2d layout"""
return [n, h, w, c // 2048, te.AXIS_SEPARATOR, c % 2048]
return [n, h, w, c // 2048, IndexMap.AXIS_SEPARATOR, c % 2048]


def iohw_16i32o2i_1d(height, width, in_channel, out_channel):
Expand All @@ -143,15 +143,15 @@ def ohwi32o_1d(height, width, in_channel, out_channel):

def ncw_32c64w_2d(n, c, w):
"""Return index map for ncw_32c64w 2d layout"""
return [n, c // 32, w // 64, te.AXIS_SEPARATOR, c % 32, w % 64]
return [n, c // 32, w // 64, IndexMap.AXIS_SEPARATOR, c % 32, w % 64]


def nchw_32c8h8w_2d(n, c, h, w):
return [n, c // 32, h // 8, w // 8, te.AXIS_SEPARATOR, c % 32, h % 8, w % 8]
return [n, c // 32, h // 8, w // 8, IndexMap.AXIS_SEPARATOR, c % 32, h % 8, w % 8]


def nchw_32c8h4w_2d(n, c, h, w):
return [n, c // 32, h // 8, w // 4, te.AXIS_SEPARATOR, c % 32, h % 8, w % 4]
return [n, c // 32, h // 8, w // 4, IndexMap.AXIS_SEPARATOR, c % 32, h % 8, w % 4]


def get_layout_transform_fn(layout):
Expand Down

0 comments on commit 5087562

Please sign in to comment.