[DICP]fuse transpose/mm in ascendgraph #523

Merged: 1 commit, merged on Dec 13, 2023
3 changes: 3 additions & 0 deletions dicp/dicp/dynamo_bridge/op_transformer.py
@@ -48,6 +48,9 @@ def get_proxy(self, target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] =
             'call_function', target.get_singleton(), args, kwargs)
         return proxy
 
+    def get_proxy_from_node(self, node):
+        return self.tracer.proxy(node)
+
     def call_function(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
         if target in self._conversions:
             converted_target = self._conversions[target]
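The new helper wraps an existing fx.Node back into a Proxy, so a conversion can hand an upstream node (for example, the input of a Permute) straight to a newly created op. A minimal torch.fx sketch of the kind of graph walk this enables; the module and names below are illustrative only, not DICP code:

import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x, y):
        # mm fed by a transpose -- the pattern the mm conversion looks for
        return torch.mm(x.t(), y)

gm = fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.mm:
        lhs = node.args[0]          # the transpose node feeding mm
        upstream = lhs.args[0]      # reach through to its un-transposed input
        print(lhs, "->", upstream)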
6 changes: 6 additions & 0 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -107,6 +107,12 @@ def infer_result(self, x1, x2):
         return common_binary_op_infer(x1, x2)
 
 
+class Muls(Operator):
+    def __init__(self):
+        super().__init__("Muls")
+        self.torch_op = aten.mul
+
+
 class Div(Operator):
     def __init__(self):
         super().__init__("Div")
7 changes: 7 additions & 0 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -712,6 +712,13 @@ def Mul(name, x, y):
         op.set_input("x2", y)
         return op.to_node()
 
+    @staticmethod
+    def Muls(name, x, y):
+        op = OP(name, "Muls")
+        op.set_input("x", x)
+        op.set_attr_float("value", float(y))
+        return op.to_node()
+
     @staticmethod
     def IdentityN(name, *args, **kwargs):
         input_names = []
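The codegen difference is that Mul takes two tensor inputs, while Muls takes one tensor input plus a float attribute, so the scalar no longer has to be materialized as a constant tensor in the graph. A simplified, self-contained sketch of the two emitted node shapes (the dict layout is illustrative, not the exact codegen output):

def make_mul(name, x1, x2):
    # elementwise Mul: both operands are graph inputs (a scalar needs a Const node)
    return {"name": name, "type": "Mul", "inputs": {"x1": x1, "x2": x2}}

def make_muls(name, x, value):
    # Muls: single tensor input, scalar folded into a float attribute
    return {"name": name, "type": "Muls", "inputs": {"x": x},
            "attrs": {"value": float(value)}}

print(make_mul("mul_0", "arg0", "const_0"))
print(make_muls("muls_0", "arg0", 2.0))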
2 changes: 1 addition & 1 deletion dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py
@@ -124,6 +124,7 @@ def release_memory(self):
 
 
 memory_pool = MemoryPool()
+zero_tensor = torch.randn(1).to(dipu_device_str)
 
 
 class AscendExecutor(object):
@@ -254,7 +255,6 @@ def init_resource(self):
 
     def _prepare_input(self, images, dims):
         assert self.num_inputs == len(images)
-        zero_tensor = torch.randn(1).to(dipu_device_str)
         for i in range(self.num_inputs):
             buffer_size = self.input_size[i]
             if dims is not None and i in dims.keys():
27 changes: 17 additions & 10 deletions dicp/dicp/vendor/AscendGraph/conversion.py
@@ -135,12 +135,12 @@ def get_param_proxy(self, param, type, target_shape):
 
     def mul_scalar(self, x, y):
         out_dtype = fx_traceback.get_current_meta()['val'].dtype
-        const_dtype = torch.float32 if out_dtype == torch.float16 else out_dtype
-        y_shape = list(x.node.meta['val'].shape)
-        y_op = self.get_param_proxy(y, const_dtype, y_shape)
-        if out_dtype == torch.float16:
-            y_op = self.get_proxy(ascend_op.Cast, (y_op, "FLOAT16"))
-        return self.get_proxy(ascend_op.Mul, (x, y_op))
+        # Muls support bfloat16, int32, int16, float16, float32, complex32, complex64.
+        if out_dtype not in [torch.float, torch.float16, torch.int32]:
+            y_shape = list(x.node.meta['val'].shape)
+            y_op = self.get_param_proxy(y, out_dtype, y_shape)
+            return self.get_proxy(ascend_op.Mul, (x, y_op))
+        return self.get_proxy(ascend_op.Muls, (x, y))
 
     def mul_complex64(self, x, y):
         out_dtype = fx_traceback.get_current_meta()['val'].dtype
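At the PyTorch level this conversion covers expressions like `x * 2.0` (aten.mul with a Python scalar): for the dtypes the converter whitelists (float32, float16, int32) it emits a single Muls node, otherwise it falls back to a constant tensor plus elementwise Mul. A rough standalone sketch of that dispatch, using plain torch as a stand-in for the Ascend ops; everything here is illustrative:

import torch

MULS_DTYPES = (torch.float32, torch.float16, torch.int32)

def mul_scalar_sketch(x: torch.Tensor, y: float) -> torch.Tensor:
    if x.dtype not in MULS_DTYPES:
        # fallback path: materialize the scalar as a tensor, elementwise Mul
        y_tensor = torch.full_like(x, y)
        return x * y_tensor
    # fast path: scalar stays a scalar (folded into a Muls attribute on Ascend)
    return x * y

out = mul_scalar_sketch(torch.randn(4, 4, dtype=torch.float16), 0.5)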
@@ -855,15 +855,22 @@ def symsize(self, x, dim):
     def mm(self, x, y):
         # TODO! MatMul not support fp32 input
         # for higher precision in some cases
-        out_dtype = fx_traceback.get_current_meta()['val'].dtype
         if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0:
             x = self.get_proxy(ascend_op.Unsqueeze, (x, [0]))
             y = self.get_proxy(ascend_op.Unsqueeze, (y, [0]))
             mm = self.get_proxy(ascend_op.BatchMatMul, (x, y, False, False))
             return self.get_proxy(ascend_op.Squeeze, (mm, [0]))
-        else:
-            mm = self.get_proxy(ascend_op.MatMul, (x, y, False, False))
-        return self.get_proxy(ascend_op.Cast, (mm, get_ascend_dtype(out_dtype)))
+        out_dtype = fx_traceback.get_current_meta()['val'].dtype
+        trans_x = False
+        trans_y = False
+        if isinstance(x.node.target, ascend_op.Permute) and x.node.args[1] == [1, 0]:
+            x = self.get_proxy_from_node(x.node.args[0])
+            trans_x = True
+        if isinstance(y.node.target, ascend_op.Permute) and y.node.args[1] == [1, 0]:
+            y = self.get_proxy_from_node(y.node.args[0])
+            trans_y = True
+        mm = self.get_proxy(ascend_op.MatMul, (x, y, trans_x, trans_y))
+        return self.get_proxy(ascend_op.Cast, (mm, get_ascend_dtype(out_dtype)))
 
     @register_conversion(aten.bmm.default)
     def bmm(self, x, y):
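The fusion above rests on the identity mm(transpose(x), y) == MatMul(x, y, trans_x=True): when an mm operand is produced by a Permute with perm [1, 0], the converter reaches through to the Permute's input via get_proxy_from_node and folds the transpose into MatMul's transpose flag, removing the explicit Permute node. A quick numerical check of that identity in plain torch (einsum stands in for the flag-based MatMul; no Ascend involved):

import torch

x = torch.randn(4, 3)
y = torch.randn(4, 5)

# unfused: materialize the transpose, then a plain matmul
unfused = torch.mm(x.permute(1, 0), y)
# fused: fold the transpose into the contraction itself (stand-in for trans_x=True)
fused = torch.einsum("ki,kj->ij", x, y)
assert torch.allclose(unfused, fused, atol=1e-6)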