Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
yao-fengchen committed Dec 16, 2024
1 parent cb06f93 commit 24f5bf9
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
6 changes: 3 additions & 3 deletions dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def ReduceSum(name, x, dim):
param = infer_param.ReduceParam()
param.name = name
param.reduceType = infer_param.ReduceType.REDUCE_SUM
param.axis = dim if isinstance(dim, list) else [dim]
param.axis = dim

op.set_input([x])
op.set_param(param)
Expand All @@ -739,7 +739,7 @@ def ReduceMax(name, x, dim):
param = infer_param.ReduceParam()
param.name = name
param.reduceType = infer_param.ReduceType.REDUCE_MAX
param.axis = dim if isinstance(dim, list) else [dim]
param.axis = dim

op.set_input([x])
op.set_param(param)
Expand All @@ -751,7 +751,7 @@ def ReduceMin(name, x, dim):
param = infer_param.ReduceParam()
param.name = name
param.reduceType = infer_param.ReduceType.REDUCE_MIN
param.axis = dim if isinstance(dim, list) else [dim]
param.axis = dim

op.set_input([x])
op.set_param(param)
Expand Down
7 changes: 7 additions & 0 deletions dlinfer/graph/dicp/vendor/AtbGraph/codegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,5 +163,12 @@ def get_ascend_dtype(dtype: torch.dtype) -> str:
raise RuntimeError(f"unknow torch data type ({dtype}) in get_ascend_dtype!")


def get_reduce_dim(x, dim):
x_rank = len(x.node.meta["val"].shape)
dim = dim if isinstance(dim, list) else [dim]
dim = [(i + x_rank) % x_rank for i in dim]
return dim


def remove_duplicates(lst):
return list(OrderedDict.fromkeys(lst))
8 changes: 7 additions & 1 deletion dlinfer/graph/dicp/vendor/AtbGraph/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from dlinfer.graph.dicp.dynamo_bridge.conversion import register_conversion_impl
from dlinfer.graph.dicp.dynamo_bridge.op_transformer import SingleOpTransformer
from dlinfer.graph.dicp.vendor.AtbGraph import ext_ops
from dlinfer.graph.dicp.vendor.AtbGraph.codegen.utils import get_ascend_dtype
from dlinfer.graph.dicp.vendor.AtbGraph.codegen.utils import (
get_ascend_dtype,
get_reduce_dim,
)


aten = torch.ops.aten
Expand Down Expand Up @@ -652,14 +655,17 @@ def aten_scalar_tensor(self, x, dtype, layout, device):

@register_conversion(torch.ops.aten.sum.dim_IntList)
def aten_reduce_sum(self, x, dim):
dim = get_reduce_dim(x, dim)
return self.get_proxy(atb_op.ReduceSum, (x, dim))

@register_conversion(torch.ops.aten.amax.default)
def aten_reduce_sum(self, x, dim):
dim = get_reduce_dim(x, dim)
return self.get_proxy(atb_op.ReduceMax, (x, dim))

@register_conversion(torch.ops.aten.amin.default)
def aten_reduce_sum(self, x, dim):
dim = get_reduce_dim(x, dim)
return self.get_proxy(atb_op.ReduceMin, (x, dim))


Expand Down

0 comments on commit 24f5bf9

Please sign in to comment.