Skip to content

Commit

Permalink
[Fix] Update shape utility functions to expect Sequence instead of Li…
Browse files Browse the repository at this point in the history
…st (#86)

update shape utility functions to expect Sequence instead of List
  • Loading branch information
yaoyaoding authored Feb 12, 2023
1 parent 0cd78c7 commit 80a35d6
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions python/hidet/graph/ops/definitions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def normalize_index(index: Optional[int], dim_size, default) -> int:
return dim_size


def resolve_out_dtype(input_dtypes: List[Union[DataType, str]]) -> str:
def resolve_out_dtype(input_dtypes: Sequence[Union[DataType, str]]) -> str:
from hidet.ir.utils.type_utils import numeric_promotion

if len(input_dtypes) == 0:
Expand All @@ -114,29 +114,31 @@ def resolve_out_dtype(input_dtypes: List[Union[DataType, str]]) -> str:
return out_dtype.name


def can_broadcast(src_shape: List[int], dst_shape: List[int]) -> bool:
def can_broadcast(src_shape: Sequence[int], dst_shape: Sequence[int]) -> bool:
if len(dst_shape) < len(src_shape):
return False
src_shape = [1 for _ in range(len(dst_shape) - len(src_shape))] + src_shape
src_shape = [1 for _ in range(len(dst_shape) - len(src_shape))] + list(src_shape)
for a, b in zip(src_shape, dst_shape):
if a not in [1, b]:
return False
return True


def can_mutually_broadcast(x_shape: List[int], y_shape: List[int]) -> bool:
def can_mutually_broadcast(x_shape: Sequence[int], y_shape: Sequence[int]) -> bool:
x_shape, y_shape = list(x_shape), list(y_shape)
while len(x_shape) < len(y_shape):
x_shape = [1] + x_shape
while len(y_shape) < len(x_shape):
y_shape = [1] + y_shape
return all(p == q or p == 1 or q == 1 for p, q in zip(x_shape, y_shape))


def broadcast_shape(x_shape: List[int], y_shape: List[int]) -> List[int]:
def broadcast_shape(x_shape: Sequence[int], y_shape: Sequence[int]) -> List[int]:
"""
Broadcast two shapes with the same rule as numpy.
Please refer to https://numpy.org/doc/stable/user/basics.broadcasting.html for details.
"""
x_shape, y_shape = list(x_shape), list(y_shape)
orig_shapes = x_shape, y_shape
while len(x_shape) < len(y_shape):
x_shape = [1] + x_shape
Expand All @@ -150,9 +152,9 @@ def broadcast_shape(x_shape: List[int], y_shape: List[int]) -> List[int]:
return result_shape


def broadcast_shapes(shapes: List[List[int]]) -> List[int]:
def broadcast_shapes(shapes: Sequence[Sequence[int]]) -> List[int]:
assert len(shapes) >= 1
expanded_shape = shapes[0]
expanded_shape = list(shapes[0])
for shape in shapes:
expanded_shape = broadcast_shape(expanded_shape, shape)
return expanded_shape
Expand Down

0 comments on commit 80a35d6

Please sign in to comment.