From aef6b72a807f88df19cc2149ed8f24cc21695f75 Mon Sep 17 00:00:00 2001 From: Chen Date: Thu, 7 Dec 2023 06:27:14 +0000 Subject: [PATCH] fix const tensor align bug in AlignPermuteVectorForElementWise Signed-off-by: Chen --- src/tim/transform/ops/op_layout_inference.cc | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/tim/transform/ops/op_layout_inference.cc b/src/tim/transform/ops/op_layout_inference.cc index 1d1f0e83a..7275a2873 100644 --- a/src/tim/transform/ops/op_layout_inference.cc +++ b/src/tim/transform/ops/op_layout_inference.cc @@ -231,11 +231,13 @@ OpLayoutInfer::AlignPermuteVectorForElementWise() { auto src_inputs = op_->impl()->InputsTensor(); std::shared_ptr required_pv = nullptr; std::shared_ptr ref_input; + int32_t ref_rank = 0; for (const auto& in : src_inputs) { - if (!in->IsConstTensor()) { + int32_t rank = in->GetShape().size(); + if (!in->IsConstTensor() && rank > ref_rank) { required_pv = context_->GetPermuteVector(in); ref_input = in; - break; + ref_rank = rank; } } @@ -297,14 +299,11 @@ void OpLayoutInfer::ReverseInputsPermuteVector() { std::vector OpLayoutInfer::GetExpandedShape( const std::vector& ref_shape, const std::vector& origin_shape) { - std::vector expanded_shape; - for (uint32_t i = 0, j = 0; i < ref_shape.size(); ++i) { - if (ref_shape[i] == origin_shape[j] && j < origin_shape.size()) { - expanded_shape.push_back(origin_shape[j]); - ++j; - } else { - expanded_shape.push_back(1); - } + std::vector expanded_shape(origin_shape); + int32_t ref_rank = ref_shape.size(); + int32_t origin_rank = origin_shape.size(); + for (int32_t i = 0; i < ref_rank; ++i) { + if (i >= origin_rank) expanded_shape.push_back(1); } return expanded_shape; }