diff --git a/python/hidet/graph/ops/definitions/utils.py b/python/hidet/graph/ops/definitions/utils.py index b08bd1fde..a9979b71a 100644 --- a/python/hidet/graph/ops/definitions/utils.py +++ b/python/hidet/graph/ops/definitions/utils.py @@ -103,7 +103,7 @@ def normalize_index(index: Optional[int], dim_size, default) -> int: return dim_size -def resolve_out_dtype(input_dtypes: List[Union[DataType, str]]) -> str: +def resolve_out_dtype(input_dtypes: Sequence[Union[DataType, str]]) -> str: from hidet.ir.utils.type_utils import numeric_promotion if len(input_dtypes) == 0: @@ -114,17 +114,18 @@ def resolve_out_dtype(input_dtypes: List[Union[DataType, str]]) -> str: return out_dtype.name -def can_broadcast(src_shape: List[int], dst_shape: List[int]) -> bool: +def can_broadcast(src_shape: Sequence[int], dst_shape: Sequence[int]) -> bool: if len(dst_shape) < len(src_shape): return False - src_shape = [1 for _ in range(len(dst_shape) - len(src_shape))] + src_shape + src_shape = [1 for _ in range(len(dst_shape) - len(src_shape))] + list(src_shape) for a, b in zip(src_shape, dst_shape): if a not in [1, b]: return False return True -def can_mutually_broadcast(x_shape: List[int], y_shape: List[int]) -> bool: +def can_mutually_broadcast(x_shape: Sequence[int], y_shape: Sequence[int]) -> bool: + x_shape, y_shape = list(x_shape), list(y_shape) while len(x_shape) < len(y_shape): x_shape = [1] + x_shape while len(y_shape) < len(x_shape): @@ -132,11 +133,12 @@ def can_mutually_broadcast(x_shape: List[int], y_shape: List[int]) -> bool: return all(p == q or p == 1 or q == 1 for p, q in zip(x_shape, y_shape)) -def broadcast_shape(x_shape: List[int], y_shape: List[int]) -> List[int]: +def broadcast_shape(x_shape: Sequence[int], y_shape: Sequence[int]) -> List[int]: """ Broadcast two shapes with the same rule as numpy. Please refer to https://numpy.org/doc/stable/user/basics.broadcasting.html for details. """ + x_shape, y_shape = list(x_shape), list(y_shape) orig_shapes = x_shape, y_shape while len(x_shape) < len(y_shape): x_shape = [1] + x_shape @@ -150,9 +152,9 @@ def broadcast_shape(x_shape: List[int], y_shape: List[int]) -> List[int]: return result_shape -def broadcast_shapes(shapes: List[List[int]]) -> List[int]: +def broadcast_shapes(shapes: Sequence[Sequence[int]]) -> List[int]: assert len(shapes) >= 1 - expanded_shape = shapes[0] + expanded_shape = list(shapes[0]) for shape in shapes: expanded_shape = broadcast_shape(expanded_shape, shape) return expanded_shape