diff --git a/dicp/dicp/dynamo_bridge/operator.py b/dicp/dicp/dynamo_bridge/operator.py
index da5326bb9..213b8bc31 100644
--- a/dicp/dicp/dynamo_bridge/operator.py
+++ b/dicp/dicp/dynamo_bridge/operator.py
@@ -89,19 +89,14 @@ def make_cpu(x):
         with fake_mode:
             try:
                 if hasattr(self, "infer_result"):
-                    info: TensorInfo = self.infer_result(*new_args, **kwargs)
-                    return torch.empty(
-                        info.shape, dtype=info.dtype, memory_format=info.memory_format
-                    )
+                    return self.infer_result(*new_args, **kwargs)
                 elif hasattr(self, "torch_op"):
                     return self.torch_op(*new_args, **kwargs)
             except Exception as e:
-                print("torch_op: ", self.torch_op if hasattr(self, "torch_op") else "")
-                print(new_args, kwargs)
-                print(e.args)
-                print(traceback.format_exc())
                 log = logging.getLogger(__name__)
                 if hasattr(self, "infer_result"):
-                    log.warning("infer shape and dtype failed")
+                    log.warning(
+                        str(self.__name__) + ": infer shape and dtype failed, ignore"
+                    )
                 elif hasattr(self, "torch_op"):
-                    log.warning("torch_op error")
+                    log.warning("torch_op error: " + str(self.torch_op.__name__))
diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py
index 0514b8e36..36a18ad3a 100644
--- a/dicp/dicp/vendor/AscendGraph/ascend_op.py
+++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -1,12 +1,9 @@
-import typing
 import torch
 from typing import Tuple
 
 from dicp.dynamo_bridge.operator import Operator
-import numpy as np
-from collections.abc import Sequence
 from dicp.vendor.AscendGraph.infer_res_utils import *
-from dicp.dynamo_bridge.utils import TensorInfo, get_memory_format
+from dicp.dynamo_bridge.utils import get_memory_format
 
 
 aten = torch.ops.aten
@@ -30,12 +27,7 @@ def __init__(self):
         super().__init__("adds")
 
     def infer_result(self, x1, x2):
-        x1, x1_shape, x1_dim, x1_dtype = get_fake_tensor_meta_val(x1, True)
-        x2, x2_shape, x2_dim, x2_dtype = get_fake_tensor_meta_val(x2, True)
-        memory_format = get_memory_format(x1)
-        dtype = get_cast_dtype(x1_dtype, x2_dtype)
-        out_shape = get_broadcast_res_two_shape(x1_shape, x2_shape)
-        return TensorInfo(shape=out_shape, dtype=dtype, memory_format=memory_format)
+        return common_binary_op_infer(x1, x2)
 
 
 class Add(Operator):
@@ -43,12 +35,7 @@ def __init__(self):
         super().__init__("add")
 
     def infer_result(self, x1, x2):
-        x1, x1_shape, x1_dim, x1_dtype = get_fake_tensor_meta_val(x1, True)
-        x2, x2_shape, x2_dim, x2_dtype = get_fake_tensor_meta_val(x2, True)
-        memory_format = get_memory_format(x1)
-        dtype = get_cast_dtype(x1_dtype, x2_dtype)
-        out_shape = get_broadcast_res_two_shape(x1_shape, x2_shape)
-        return TensorInfo(shape=out_shape, dtype=dtype, memory_format=memory_format)
+        return common_binary_op_infer(x1, x2)
 
 
 class BroadcastTo(Operator):
@@ -75,11 +62,41 @@ class BatchMatMul(Operator):
     def __init__(self):
         super().__init__("BatchMatMul")
 
+    def infer_result(self, x1, x2, adj_x1=False, adj_x2=False):
+        x1, x1_shape, x1_dim, x1_dtype = get_fake_tensor_meta_val(x1)
+        x2, x2_shape, x2_dim, x2_dtype = get_fake_tensor_meta_val(x2)
+
+        assert x1_dim == 3 and x2_dim == 3, (
+            self.__class__.__name__ + ": bmm's inputs must be 3D tensor!"
+        )  # no broadcast
+        assert x1_dtype == x2_dtype, (
+            self.__class__.__name__ + ": expect same input type!"
+        )  # no dtype cast
+
+        adj_x1_shape = (
+            [x1.shape[0]] + list(reversed(x1.shape[1:])) if adj_x1 else list(x1.shape)
+        )
+        adj_x2_shape = (
+            [x2.shape[0]] + list(reversed(x2.shape[1:])) if adj_x2 else list(x2.shape)
+        )
+
+        assert adj_x1_shape[2] == adj_x2_shape[1], (
+            self.__class__.__name__ + ": shape mismatch!"
+        )
+        out_shape = adj_x1_shape[0:2] + [adj_x2_shape[2]]
+
+        return torch.empty(
+            out_shape, dtype=x1_dtype, memory_format=get_memory_format(x1)
+        )
+
 
 class Sub(Operator):
     def __init__(self):
         super().__init__("Sub")
 
+    def infer_result(self, x1, x2):
+        return common_binary_op_infer(x1, x2)
+
 
 class Mul(Operator):
     def __init__(self):
@@ -87,64 +104,80 @@ def __init__(self):
         self.torch_op = aten.mul
 
     def infer_result(self, x1, x2):
-        x1, x1_shape, x1_dim, x1_dtype = get_fake_tensor_meta_val(x1, True)
-        x2, x2_shape, x2_dim, x2_dtype = get_fake_tensor_meta_val(x2, True)
-        out_shape = get_broadcast_res_two_shape(x1_shape, x2_shape)
-        dtype = get_cast_dtype(x1_dtype, x2_dtype)
-        memory_format = get_memory_format(x1)
-        return TensorInfo(shape=out_shape, dtype=dtype, memory_format=memory_format)
+        return common_binary_op_infer(x1, x2)
 
 
 class Div(Operator):
     def __init__(self):
         super().__init__("Div")
 
+    def infer_result(self, x1, x2):
+        return common_binary_op_infer(x1, x2)
+
 
 class DivNoNan(Operator):
     def __init__(self):
         super().__init__("DivNoNan")
 
+    def infer_result(self, x1, x2):
+        return common_binary_op_infer(x1, x2)
+
 
 class Maximum(Operator):
     def __init__(self):
         super().__init__("Maximum")
 
+    def infer_result(self, x1, x2):
+        return common_binary_op_infer(x1, x2)
+
 
 class Rsqrt(Operator):
     def __init__(self):
         super().__init__("Rsqrt")
 
+    def infer_result(self, x):
+        return common_unary_op_infer(x)
+
 
 class Sqrt(Operator):
     def __init__(self):
         super().__init__("Sqrt")
 
+    def infer_result(self, x):
+        return common_unary_op_infer(x)
+
 
 class Log(Operator):
     def __init__(self):
         super().__init__("Log")
 
     def infer_result(self, x):
-        x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
-        return TensorInfo(
-            list(x_shape), dtype=x_dtype, memory_format=get_memory_format(x)
-        )
+        return common_unary_op_infer(x)
 
 
 class Exp(Operator):
     def __init__(self):
         super().__init__("Exp")
 
+    def infer_result(self, x, base=-1.0, scale=1.0, shift=0.0):
+        return common_unary_op_infer(x)
+
 
 class Neg(Operator):
     def __init__(self):
         super().__init__("Neg")
 
+    def infer_result(self, x, base=-1.0, scale=1.0, shift=0.0):
+        return common_unary_op_infer(x)
+
 
 class Relu(Operator):
     def __init__(self):
         super().__init__("Relu")
 
+    def infer_result(self, x, base=-1.0, scale=1.0, shift=0.0):
+        return common_unary_op_infer(x)
+
 
 class Swish(Operator):
     def __init__(self):
@@ -161,26 +194,54 @@ def __init__(self):
         super().__init__("SoftmaxV2")
 
     def infer_result(self, x, axes=None):
-        x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x, True)
-        return TensorInfo(
-            list(x_shape), dtype=x_dtype, memory_format=get_memory_format(x)
-        )
+        return common_unary_op_infer(x)
 
 
 class ReduceSumD(Operator):
     def __init__(self):
         super().__init__("ReduceSumD")
 
+    def infer_result(self, x, dims, keepdim):
+        return reduce_op_infer(x, dims, keepdim)
+
 
 class Unsqueeze(Operator):
     def __init__(self):
         super().__init__("Unsqueeze")
 
+    def infer_result(self, x, dim=None):
+        x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
+        assert dim is not None, (
+            self.__class__.__name__ + ": doesn't specify axis to unsqueeze!"
+        )
+        x_shape = list(x_shape)
+        for d in sorted(dim, reverse=True):
+            x_shape.insert(d + x_dim + 1 if d < 0 else d, 1)
+        return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))
+
 
 class Squeeze(Operator):
     def __init__(self):
         super().__init__("Squeeze")
 
+    def infer_result(self, x, dim=None):
+        x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
+        if dim is None:
+            shape = [i for i in x_shape if i != 1]
+        else:
+            shape = list(x_shape)
+            for i in dim:
+                assert x_shape[i] == 1, (
+                    self.__class__.__name__
+                    + ": can only squeeze a dimension that is 1!"
+                )
+                shape.pop(i)
+
+        x_memory_format = get_memory_format(x)
+        if len(shape) < 4:
+            x_memory_format = torch.contiguous_format
+        return torch.empty(shape, dtype=x_dtype, memory_format=x_memory_format)
+
 
 class Pack(Operator):
     def __init__(self):
@@ -232,50 +293,85 @@ def __init__(self):
         super().__init__("ReduceMaxD")
 
     def infer_result(self, x, dims, keepdim):
-        x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
-        out_shape = reduce_ops_output_size(x_shape, x_dim, dims, keepdim)
-        return TensorInfo(
-            shape=out_shape, dtype=x_dtype, memory_format=get_memory_format(x)
-        )
+        return reduce_op_infer(x, dims, keepdim)
 
 
 class Const(Operator):
     def __init__(self):
         super().__init__("Const")
 
-    def infer_result(self, x, dtype, x_dim):
-        return TensorInfo(x_dim, dtype=dtype, memory_format=torch.contiguous_format)
+    def infer_result(self, new_args, kwargs):
+        return new_args, kwargs
 
 
 class Sigmoid(Operator):
     def __init__(self):
         super().__init__("Sigmoid")
 
+    def infer_result(self, x):
+        return common_unary_op_infer(x)
+
 
 class Pow(Operator):
     def __init__(self):
         super().__init__("Pow")
 
+    def infer_result(self, base, expo):
+        base, base_shape, base_dim, base_dtype = get_fake_tensor_meta_val(base)
+
+        if isinstance(expo, Tuple):  # Const
+            expo, expo_shape = get_op_const_arg_kwarg(expo)
+            expo_dtype = type(expo[0]) if len(expo) > 0 else base_dtype
+        else:  # fake Tensor
+            expo, expo_shape, expo_dim, expo_dtype = get_fake_tensor_meta_val(
+                expo
+            )
+
+        out_shape = get_broadcast_res_two_shape(base_shape, expo_shape)
+        dtype = get_cast_dtype(base_dtype, expo_dtype)
+        memory_format = get_memory_format(base)
+        return torch.empty(out_shape, dtype=dtype, memory_format=memory_format)
+
 
 class Select(Operator):
     def __init__(self):
         super().__init__("Select")
 
+    def infer_result(self, x1, x2, condition):
+        x1, x1_shape, x1_dim, x1_dtype = get_fake_tensor_meta_val(x1)
+        x2, x2_shape, x2_dim, x2_dtype = get_fake_tensor_meta_val(x2)
+        _, c_shape, _, _ = get_fake_tensor_meta_val(condition)
+        out_shape = get_broadcast_res_two_shape(
+            get_broadcast_res_two_shape(x1_shape, c_shape), x2_shape
+        )
+        dtype = get_cast_dtype(x1_dtype, x2_dtype)
+        memory_format = get_memory_format(x1)
+        return torch.empty(out_shape, dtype=dtype, memory_format=memory_format)
+
 
 class LessEqual(Operator):
     def __init__(self):
         super().__init__("LessEqual")
 
+    def infer_result(self, x1, x2):
+        return common_binary_op_infer(x1, x2, torch.bool)
+
 
 class Less(Operator):
     def __init__(self):
         super().__init__("Less")
 
+    def infer_result(self, x1, x2):
+        return common_binary_op_infer(x1, x2, torch.bool)
+
 
 class Equal(Operator):
     def __init__(self):
         super().__init__("Equal")
 
+    def infer_result(self, x1, x2):
+        return common_binary_op_infer(x1, x2, torch.bool)
+
 
 class Conv2D(Operator):
     def __init__(self):
@@ -286,6 +382,9 @@ class GreaterEqual(Operator):
     def __init__(self):
         super().__init__("GreaterEqual")
 
+    def infer_result(self, x1, x2):
+        return common_binary_op_infer(x1, x2, torch.bool)
+
 
 class InAdd(Operator):
     def __init__(self):
@@ -297,10 +396,7 @@ def __init__(self):
         super().__init__("Cast")
 
     def infer_result(self, x, dtype):
-        x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
-        return TensorInfo(
-            list(x_shape), dtype=dtype, memory_format=get_memory_format(x)
-        )
+        return common_unary_op_infer(x, ascend_type_to_torch(dtype))
 
 
 class CastToCpu(Operator):
@@ -312,16 +408,32 @@ class Identity(Operator):
     def __init__(self):
         super().__init__("Identity")
 
+    def infer_result(self, x, idx=None):
+        x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
+        out_shape = list(x_shape[idx]) if idx is not None else list(x_shape)
+        return torch.empty(out_shape, dtype=x_dtype, memory_format=get_memory_format(x))
+
 
 class IdentityInp(Operator):
     def __init__(self):
         super().__init__("IdentityInp")
 
+    def infer_result(self, src, dst):
+        src, src_shape, src_dim, src_dtype = get_fake_tensor_meta_val(src)
+        dst, dst_shape, dst_dim, dst_dtype = get_fake_tensor_meta_val(dst)
+        out_shape = get_broadcast_res_two_shape(src_shape, dst_shape)
+        return torch.empty(
+            out_shape, dtype=dst_dtype, memory_format=get_memory_format(dst)
+        )
+
 
 class IdentityN(Operator):
     def __init__(self):
         super().__init__("IdentityN")
 
+    def infer_result(self, x):
+        return common_unary_op_infer(x)
+
 
 class Empty(Operator):
     def __init__(self):
@@ -332,6 +444,13 @@ class GatherV2(Operator):
     def __init__(self):
         super().__init__("GatherV2")
 
+    def infer_result(self, x, index, axis):
+        x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
+        idx, idx_shape, idx_dim, idx_dtype = get_fake_tensor_meta_val(index)
+        idx_shape = list(idx_shape)
+        idx_shape.append(x_shape[-1])
+        return torch.empty(idx_shape, dtype=x_dtype, memory_format=get_memory_format(x))
+
 
 class OnesLike(Operator):
     def __init__(self):
@@ -358,10 +477,7 @@ def __init__(self):
         super().__init__("LogSoftmaxV2")
 
     def infer_result(self, x, dim):
-        x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
-        return TensorInfo(
-            list(x_shape), dtype=x_dtype, memory_format=get_memory_format(x)
-        )
+        return common_unary_op_infer(x)
 
 
 class LogSoftmaxGrad(Operator):
@@ -418,6 +534,9 @@ class ZerosLike(Operator):
     def __init__(self, x):
         super().__init__("ZerosLike")
 
+    def infer_result(self, x):
+        return common_unary_op_infer(x)
+
 
 class SplitD(Operator):
     def __init__(self):
@@ -433,6 +552,19 @@ class ConcatD(Operator):
     def __init__(self):
         super().__init__("ConcatD")
 
+    # TODO: memory_format?
+    def infer_result(self, x, dim=0):
+        x0, x0_shape, x0_dim, x0_dtype = get_fake_tensor_meta_val(x[0])
+        dim = (dim + x0_dim) % x0_dim
+        out_shape = list(x0_shape)
+        out_shape[dim] = 0
+        for t in x:
+            _, t, _, _ = get_fake_tensor_meta_val(t)
+            out_shape[dim] += t[dim]
+        return torch.empty(
+            out_shape, dtype=x0_dtype, memory_format=get_memory_format(x0)
+        )
+
 
 class MaskedFill(Operator):
     def __init__(self):
@@ -443,6 +575,30 @@ class Reshape(Operator):
     def __init__(self):
         super().__init__("Reshape")
 
+    # TODO: conflict in resolving stride between "view" and "select"
+    def infer_result(self, x, shape_const_op):
+        x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
+        re_shape, re_dim = get_op_const_arg_kwarg(shape_const_op)
+        # check whether stride and storage_offset are manually specified;
+        # if so, x comes from operators like "Slice", and the stride and storage_offset still need to be adjusted here
+        x_stride = list(x.stride())
+        x_shape = list(x_shape)
+
+        for i in range(len(x_stride) - 2, -1, -1):
+            if x_stride[i + 1] * x_shape[i + 1] != x_stride[i]:
+                del x_stride[i + 1]
+                del x_shape[i + 1]
+                break
+        else:
+            if len(x_shape) != len(re_shape):
+                del x_stride[0]
+                del x_shape[0]
+
+        x_storage_offset = x.storage_offset()
+        res = torch.empty(re_shape, dtype=x_dtype, memory_format=get_memory_format(x))
+        res = torch.as_strided(res, re_shape, x_stride, x_storage_offset)
+        return res
+
 
 class Pad(Operator):
     def __init__(self):
@@ -453,6 +609,9 @@ class Fills(Operator):
     def __init__(self):
         super().__init__("Fills")
 
+    def infer_result(self, x, value):
+        return common_unary_op_infer(x)
+
 
 class SoftmaxGrad(Operator):
     def __init__(self):
@@ -463,23 +622,27 @@ class StatelessBernoulli(Operator):
     def __init__(self):
         super().__init__("StatelessBernoulli")
 
+    def infer_result(self, target, prob, seed, offset, dtype):
+        return common_unary_op_infer(
+            target, spec_dtype=dtype, spec_format=torch.contiguous_format
+        )
+
 
 class Shape(Operator):
     def __init__(self):
         super().__init__("Shape")
 
+    def infer_result(self, x):
+        # like Const, we won't actually use this function, but it must exist as a flag for triggering result-info inference
+        return common_unary_op_infer(x, spec_format=torch.contiguous_format)
+
 
 class AddV2(Operator):
     def __init__(self):
         super().__init__("AddV2")
 
     def infer_result(self, x1, x2):
-        x1, x1_shape, x1_dim, x1_dtype = get_fake_tensor_meta_val(x1, True)
-        x2, x2_shape, x2_dim, x2_dtype = get_fake_tensor_meta_val(x2, True)
-        memory_format = get_memory_format(x1)
-        dtype = get_cast_dtype(x1_dtype, x2_dtype)
-        out_shape = get_broadcast_res_two_shape(x1_shape, x2_shape)
-        return TensorInfo(shape=out_shape, dtype=dtype, memory_format=memory_format)
+        return common_binary_op_infer(x1, x2)
 
 
 class StatelessRandomUniformV2(Operator):
diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
index 8bc5fd5a3..635fb5161 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
+++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -346,6 +346,10 @@ def gen_call_func(self):
                 out_stride_str += '[1],'
                 out_storage_offset_str += '0,'
                 continue
+            if elem.dim() == 0:  # temporary solution for sum.default(a), whose result is a scalar (no dim, no stride)
+                out_stride_str += '[1],'
+                out_storage_offset_str += '0,'
+                continue
             stride = list(elem.stride())
             if len(stride) == 0:
                 raise RuntimeError("Error handling empty output_stride")
diff --git a/dicp/dicp/vendor/AscendGraph/infer_res_utils.py b/dicp/dicp/vendor/AscendGraph/infer_res_utils.py
index 7357ad614..f2b909d24 100644
--- a/dicp/dicp/vendor/AscendGraph/infer_res_utils.py
+++ b/dicp/dicp/vendor/AscendGraph/infer_res_utils.py
@@ -1,32 +1,70 @@
 from collections.abc import Sequence
 from typing import Optional, Tuple, Union
+from dicp.dynamo_bridge.utils import get_memory_format
 
 import torch
 
+"""parse and get val"""
+
+
+# in conversion.py, some ops ("cast") take Ascend dtype strings like 'FLOAT' as inputs, but inference needs torch dtypes
+def ascend_type_to_torch(ascend_type: str) -> torch.dtype:
+    ascend_type_map = {
+        "BOOL": torch.bool,
+        "INT64": torch.int64,
+        "FLOAT": torch.float32,
+        "FLOAT16": torch.float16,
+        "INT32": torch.int32,
+        "COMPLEX64": torch.complex64,
+    }
+
+    assert (
+        ascend_type in ascend_type_map
+    ), "unknown ascend_dtype in ascend_type_to_torch!"
+
+    return ascend_type_map[ascend_type]
+
 
 def get_fake_tensor_meta_val(
     x, req_dim=True, req_dtype=True
-) -> Tuple[any, list, Union[int, None], Union[torch.dtype, type, None]]:
+) -> Tuple[torch.Tensor, Union[torch.Size, list], int, Union[torch.dtype, type, None]]:
     x_shape = x.size() if hasattr(x, "size") else [1]
     x_dim = len(x_shape)
     x_dtype = x.dtype if hasattr(x, "dtype") else None
     return x, x_shape, x_dim, x_dtype
 
 
-def get_broadcast_res_two_shape(shape1, shape2) -> Optional[list]:
-    len1 = len(shape1)
-    len2 = len(shape2)
-    max_len = max(len1, len2)
-    result_shape = []
-    for i in range(-1, -max_len - 1, -1):
-        dim1 = shape1[i] if i >= -len1 else 1
-        dim2 = shape2[i] if i >= -len2 else 1
-        if dim1 == dim2 or dim1 == 1 or dim2 == 1:
-            result_shape.insert(0, max(dim1, dim2))
-        else:
-            print(torch.randn(shape1).shape, " ", torch.randn(shape2).shape, end=" ")
-            assert False, "input shapes must be broadcastable!"
-    return result_shape
+def get_op_const_arg_kwarg(const_arg):
+    """
+    If an operator takes the output of Const as an input, call this helper to unpack
+    the (args, kwargs) that were passed to that Const op.
+    Some operators like "Reshape" need a tensor's value (the target shape), so for "Const"
+    we pass its raw input (value and shape) instead of constructing a FakeTensor, which
+    would lose the tensor's value.
+    input:
+        - const_arg: Tuple (new_args, kwargs)
+            - new_args: Tuple, identical to the "new_args" input of operator Const
+            - kwargs: dict, identical to the "kwargs" input of operator Const
+
+    output:
+        - arg0: list, value of the Const input (e.g. a target shape)
+        - arg2: list, shape of the Const input
+    """
+    new_args = const_arg[0]
+    arg0 = new_args[0]
+    arg2 = new_args[2]
+    return arg0, arg2
+
+
+"""analyze dtype,format"""
 
 
 def get_cast_dtype(
@@ -86,6 +124,25 @@ def analyze_memory_format(tensor: torch.Tensor, operation: str) -> torch.memory_
     return tensor.memory_format if tensor.is_contiguous() else original_format
 
 
+"""calculate size,stride,storage_offset"""
+
+
+def get_broadcast_res_two_shape(shape1, shape2) -> Optional[list]:
+    len1 = len(shape1)
+    len2 = len(shape2)
+    max_len = max(len1, len2)
+    result_shape = []
+    for i in range(-1, -max_len - 1, -1):
+        dim1 = shape1[i] if i >= -len1 else 1
+        dim2 = shape2[i] if i >= -len2 else 1
+        if dim1 == dim2 or dim1 == 1 or dim2 == 1:
+            result_shape.insert(0, max(dim1, dim2))
+        else:
+            print(torch.randn(shape1).shape, " ", torch.randn(shape2).shape, end=" ")
+            assert False, "input shapes must be broadcastable!"
+    return result_shape
+
+
 def reduce_ops_output_size(
     x_shape, x_dim, dim: Union[None, Sequence, int], keepdim=False
 ):
@@ -93,7 +150,7 @@ def reduce_ops_output_size(
         if keepdim is True:
            shape = [1] * x_dim
         else:
-            shape = []
+            shape = []  # sum(all) needs a scalar as output (no shape, no stride)
     else:
         dim = [dim] if not isinstance(dim, Sequence) else dim
         dim = [(d + x_dim) % x_dim for d in dim]
@@ -106,3 +163,40 @@
             if r not in dim and r - x_dim not in dim
         ]
     return shape
+
+
+def cal_stride_offset(new_shape: list, offset: list, res: torch.Tensor):
+    stride = list(res.stride())
+    ori_shape = list(res.size())
+    new_offset = 0
+    for s, off in zip(stride, offset):
+        new_offset += s * off
+    stride = [k for k, i, j in zip(stride, ori_shape, new_shape) if i != j]
+    return stride, new_offset
+
+
+"""binary&unary operators"""
+
+
+def common_binary_op_infer(x1, x2, spec_dtype=None, spec_format=None) -> torch.Tensor:
+    x1, x1_shape, x1_dim, x1_dtype = get_fake_tensor_meta_val(x1)
+    x2, x2_shape, x2_dim, x2_dtype = get_fake_tensor_meta_val(x2)
+    out_shape = get_broadcast_res_two_shape(x1_shape, x2_shape)
+    dtype = get_cast_dtype(x1_dtype, x2_dtype) if not spec_dtype else spec_dtype
+    memory_format = get_memory_format(x1) if not spec_format else spec_format
+    return torch.empty(out_shape, dtype=dtype, memory_format=memory_format)
+
+
+def common_unary_op_infer(x, spec_dtype=None, spec_format=None) -> torch.Tensor:
+    _, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
+    return torch.empty(
+        x_shape,
+        dtype=x_dtype if not spec_dtype else spec_dtype,
+        memory_format=get_memory_format(x) if not spec_format else spec_format,
+    )
+
+
+def reduce_op_infer(x, dims, keepdim) -> torch.Tensor:
+    x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
+    out_shape = reduce_ops_output_size(x_shape, x_dim, dims, keepdim)
+    return torch.empty(out_shape, dtype=x_dtype, memory_format=get_memory_format(x))
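Below is a minimal usage sketch (not part of the patch) of the shared inference helpers added in infer_res_utils.py. It assumes the dicp package is importable, and it uses plain CPU tensors in place of the FakeTensors the dynamo bridge normally passes; the helpers only read shape, dtype and memory format, so the observable behavior is the same.

import torch
from dicp.vendor.AscendGraph.infer_res_utils import (
    common_binary_op_infer,
    common_unary_op_infer,
)

# Plain tensors stand in for FakeTensors; only .size()/.dtype/memory format are inspected.
x1 = torch.empty(4, 1, 8, dtype=torch.float32)
x2 = torch.empty(1, 6, 8, dtype=torch.float32)

out = common_binary_op_infer(x1, x2)
print(out.shape)                # torch.Size([4, 6, 8]): broadcast of (4, 1, 8) and (1, 6, 8)

mask = common_binary_op_infer(x1, x2, torch.bool)
print(mask.dtype)               # torch.bool: spec_dtype overrides the inferred dtype (as in Less/Equal)

half = common_unary_op_infer(x1, spec_dtype=torch.float16)
print(half.shape, half.dtype)   # torch.Size([4, 1, 8]) torch.float16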