From 6a19ed89ec00d97ccb0c3006230b60d3adc3e45e Mon Sep 17 00:00:00 2001 From: chentong319 Date: Fri, 14 Jan 2022 10:39:13 -0500 Subject: [PATCH] add check (#1098) Signed-off-by: Tong Chen Co-authored-by: Alexandre Eichenberger --- src/Dialect/ONNX/ONNXOps.cpp | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 7b87fce859eee..00ce5fed1974e 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -884,6 +884,11 @@ LogicalResult ONNXSequenceInsertOp::inferShapes( onnxmlir::SeqType seqType = input_sequence().getType().dyn_cast(); ShapedType tensorType = tensor().getType().dyn_cast(); + ShapedType seqTensorType = seqType.getElementType().cast(); + + // Merge the tensor type for the seq and the inserted tensor + // Pick the weaker attr: known dim > unknown dim > unranked + // If inference gets an unranked tensor, no need to update the result // When the input seq is empty, inherit the tensor type if (seqType.getLength() == 0) { @@ -891,11 +896,21 @@ LogicalResult ONNXSequenceInsertOp::inferShapes( return success(); } - // Merge the tensor type for the seq and the inserted tensor - // Pick the weaker attr: known dim > unknown dim > unranked tensor - // If inference gets an unranked tensor, no need to update the result - auto seqShape = seqType.getElementType().cast().getShape(); - auto seqRank = seqType.getElementType().cast().getRank(); + auto newLength = seqType.getLength() == -1 ? -1 : seqType.getLength() + 1; + + // When one of the tensor is unranked + if (!tensorType.hasRank()) { + getResult().setType(onnxmlir::SeqType::get(tensorType, newLength)); + return success(); + } + if (!seqTensorType.hasRank()) { + getResult().setType(onnxmlir::SeqType::get(seqTensorType, newLength)); + return success(); + } + + // Merge when both are ranked + auto seqShape = seqTensorType.getShape(); + auto seqRank = seqTensorType.getRank(); if (seqRank == -1) return success(); @@ -909,7 +924,7 @@ LogicalResult ONNXSequenceInsertOp::inferShapes( } getResult().setType(onnxmlir::SeqType::get( mlir::RankedTensorType::get(dims, tensorType.getElementType()), - seqType.getLength() == -1 ? -1 : seqType.getLength() + 1)); + newLength)); return success(); }