Skip to content

Commit

Permalink
fix const tensor align bug in AlignPermuteVectorForElementWise
Browse files Browse the repository at this point in the history
Signed-off-by: Chen <[email protected]>
  • Loading branch information
Chen committed Dec 7, 2023
1 parent 720f0a4 commit aef6b72
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions src/tim/transform/ops/op_layout_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,13 @@ OpLayoutInfer::AlignPermuteVectorForElementWise() {
auto src_inputs = op_->impl()->InputsTensor();
std::shared_ptr<IPermuteVector> required_pv = nullptr;
std::shared_ptr<vx::Tensor> ref_input;
int32_t ref_rank = 0;
for (const auto& in : src_inputs) {
if (!in->IsConstTensor()) {
int32_t rank = in->GetShape().size();
if (!in->IsConstTensor() && rank > ref_rank) {
required_pv = context_->GetPermuteVector(in);
ref_input = in;
break;
ref_rank = rank;
}
}

Expand Down Expand Up @@ -297,14 +299,11 @@ void OpLayoutInfer::ReverseInputsPermuteVector() {
std::vector<uint32_t> OpLayoutInfer::GetExpandedShape(
const std::vector<uint32_t>& ref_shape,
const std::vector<uint32_t>& origin_shape) {
std::vector<uint32_t> expanded_shape;
for (uint32_t i = 0, j = 0; i < ref_shape.size(); ++i) {
if (ref_shape[i] == origin_shape[j] && j < origin_shape.size()) {
expanded_shape.push_back(origin_shape[j]);
++j;
} else {
expanded_shape.push_back(1);
}
std::vector<uint32_t> expanded_shape(origin_shape);
int32_t ref_rank = ref_shape.size();
int32_t origin_rank = origin_shape.size();
for (int32_t i = 0; i < ref_rank; ++i) {
if (i >= origin_rank) expanded_shape.push_back(1);
}
return expanded_shape;
}
Expand Down

0 comments on commit aef6b72

Please sign in to comment.