Skip to content

Commit

Permalink
QuadricCustomOp: Handle multiple outputs when shape inferencing (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
ndrego authored Oct 3, 2023
1 parent d0776ea commit 126cce8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
1 change: 0 additions & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2930,7 +2930,6 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t
.Attr("element_wise", "True (1) if only element-wise ops, False (0) otherwise", AttributeProto::INT, true)
.TypeConstraint("T", OpSchema::all_tensor_types_with_bfloat(),
"Allow inputs and outputs to be any kind of tensor.");
// FIXME: Add a type/shape inference function

#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
Expand Down
30 changes: 19 additions & 11 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"QLinearAveragePool": self._infer_qlinear_unary_op,
# Quadric custom operators
"QuadricCustomOp": self._infer_custom_op,
"QuadricCustomOpElementWise": self._infer_custom_op
}
self.aten_op_dispatcher_ = {
"embedding": self._infer_Gather,
Expand Down Expand Up @@ -455,6 +454,7 @@ def _onnx_infer_single_node(self, node):
"If",
"Loop",
"Scan",
"QuadricCustomOp",
"SplitToSequence",
"ZipMap", # contrib ops
"Attention",
Expand Down Expand Up @@ -974,17 +974,21 @@ def _infer_qgemm(self, node):
def _infer_custom_op(self, node):
# For the CCL custom operators the shape and dtype of the output are present in
# the attributes and can be used to directly create the value info
attr_map = {n.name:n for n in list(node.attribute)}
assert "shape" in attr_map and "elem_type" in attr_map,\
"Custom op output type not found"
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0],
attr_map["elem_type"].i,
attr_map["shape"].ints,
attr_map = {n.name: n for n in list(node.attribute)}
assert "shape" in attr_map and "elem_type" in attr_map, "Custom op output type not found"
if len(node.output) > 1:
for i, out in enumerate(node.output):
vi = self.known_vi_[out]
vi.CopyFrom(
helper.make_tensor_value_info(
out,
attr_map["elem_type"].ints[i],
attr_map["shape"].tensors[i].int32_data,
)
)
)
else:
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], attr_map["elem_type"].i, attr_map["shape"].ints))

def _infer_ConcatFromSequence(self, node):
seq_shape = self._get_shape(node, 0)
Expand Down Expand Up @@ -2582,6 +2586,10 @@ def get_prereq(node):
get_attribute(node, "then_branch"),
get_attribute(node, "else_branch"),
]
elif node.op_type == "QuadricCustomOp":
# Should have a subgraph, but allow for cases where it's not there
subgraph = get_attribute(node, "sub_graph")
subgraphs = [subgraph] if subgraph else []
elif node.op_type in ["Loop", "Scan"]:
subgraphs = [get_attribute(node, "body")]
for g in subgraphs:
Expand Down

0 comments on commit 126cce8

Please sign in to comment.