Skip to content

Commit

Permalink
[dicp][ascend] infer op resinfo (part 2) (DeepLink-org#491)
Browse files Browse the repository at this point in the history
* fix a bug in get_cast_dtype: type(int+bool) should be int

* clean code format

* finish res_op_infer for more simple operators

* Update operator.py

delete some unnecessary print()

* Update operator.py

clean code

* finish operators' info inference except for those having trouble testing solely without inference and operators involving Reshape still have problems

* clean code format

* Update warning message output in operator.py

* extract common function for general binary and unary operator ,add op bmm's inference

* Update ascend_op.py

delete unuse param
  • Loading branch information
KevinfromTJ authored and ustclight-sls committed Dec 8, 2023
1 parent 1327fa2 commit fdfe229
Show file tree
Hide file tree
Showing 4 changed files with 333 additions and 77 deletions.
15 changes: 5 additions & 10 deletions dicp/dicp/dynamo_bridge/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,14 @@ def make_cpu(x):
with fake_mode:
try:
if hasattr(self, "infer_result"):
info: TensorInfo = self.infer_result(*new_args, **kwargs)
return torch.empty(
info.shape, dtype=info.dtype, memory_format=info.memory_format
)
return self.infer_result(*new_args, **kwargs)
elif hasattr(self, "torch_op"):
return self.torch_op(*new_args, **kwargs)
except Exception as e:
print("torch_op: ", self.torch_op if hasattr(self, "torch_op") else "")
print(new_args, kwargs)
print(e.args)
print(traceback.format_exc())
log = logging.getLogger(__name__)
if hasattr(self, "infer_result"):
log.warning("infer shape and dtype failed")
log.warning(
str(self.__name__) + ": infer shape and dtype failed,ignore"
)
elif hasattr(self, "torch_op"):
log.warning("torch_op error")
log.warning("torch_op error: " + str(self.torch_op.__name__))
Loading

0 comments on commit fdfe229

Please sign in to comment.