fix crash when elementwise inputs have different ranks
When two INPUT tensors have different ranks, AlignPermuteVectorForElementWise()
will force-align their permute vectors and crash.

Type: Bug fix

Signed-off-by: Chen <[email protected]>
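
For context, the crash can be reproduced by feeding two INPUT tensors of
different rank into an element-wise op and running layout inference. A minimal
sketch, assuming the usual TIM-VX graph-building API and the
tim::transform::LayoutInference entry point (shapes and names here are
illustrative, not taken from the commit):

#include "tim/transform/layout_inference.h"
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops/elementwise.h"

int main() {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  // Two INPUT tensors of different rank (4 vs. 2).
  tim::vx::TensorSpec spec_a(tim::vx::DataType::FLOAT32, {2, 2, 2, 2},
                             tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec spec_b(tim::vx::DataType::FLOAT32, {2, 2},
                             tim::vx::TensorAttribute::INPUT);
  tim::vx::TensorSpec spec_out(tim::vx::DataType::FLOAT32, {2, 2, 2, 2},
                               tim::vx::TensorAttribute::OUTPUT);
  auto a = graph->CreateTensor(spec_a);
  auto b = graph->CreateTensor(spec_b);
  auto out = graph->CreateTensor(spec_out);

  auto add = graph->CreateOperation<tim::vx::ops::Add>();
  add->BindInputs({a, b}).BindOutput(out);

  // Before this fix, AlignPermuteVectorForElementWise() force-aligned the
  // two rank-mismatched permute vectors here and crashed.
  auto transformed = tim::transform::LayoutInference(graph, ctx);
  return 0;
}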
Chen committed Dec 4, 2023
1 parent 5173979 commit f10351e
55 changes: 55 additions & 0 deletions src/tim/transform/ops/elementwise_layout_inference.h
@@ -42,6 +42,34 @@ class ElementWiseLayoutInfer : public OpLayoutInfer {

  void OnInputs(
      std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
    auto in_0 = op_->impl()->InputsTensor()[0];
    auto in_1 = op_->impl()->InputsTensor()[1];
    std::shared_ptr<tim::vx::Tensor> short_tensor =
        in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0;
    std::shared_ptr<tim::vx::Tensor> long_tensor =
        in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0;
    if (in_0->GetSpec().attr_ == tim::vx::INPUT &&
        in_1->GetSpec().attr_ == tim::vx::INPUT &&
        in_0->GetShape().size() != in_1->GetShape().size()) {
      auto pv_long = context_->GetPermuteVector(long_tensor);
      auto pv_short = context_->GetPermuteVector(short_tensor);
      auto size_long = pv_long->Rank();
      auto size_short = pv_short->Rank();
      auto expand_pv = MakeShared(size_long);
      // Ranks differ: expand the short pv to the long pv's rank.
      for (uint32_t i = 0; i < size_short; ++i) {
        expand_pv->At(i) = pv_short->At(i);  // low dims take the short pv
      }
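      // Dims in [size_short, size_long) keep MakeShared()'s default mapping
      // (identity), which lines up with the size-1 axes appended below.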

      auto short_shape = short_tensor->GetShape();
      // Pad the short shape with trailing 1s up to the long rank.
      for (uint32_t i = 0; i < size_long; ++i) {
        if (i >= size_short) short_shape.push_back(1);
      }
      short_tensor->GetSpec().SetShape(short_shape);

      context_->SetPermuteVector(short_tensor, expand_pv);  // register the expanded pv
    }
    auto required_pv = AlignPermuteVectorForElementWise();
    auto elementwise = context_->infer_graph_->CreateOperation<OpType>();
    for (const auto& i_src : op_->impl()->InputsTensor()) {
@@ -63,6 +91,33 @@ class MultiplyLayoutInfer : public OpLayoutInfer {

  void OnInputs(
      std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
    auto in_0 = op_->impl()->InputsTensor()[0];
    auto in_1 = op_->impl()->InputsTensor()[1];
    std::shared_ptr<tim::vx::Tensor> short_tensor =
        in_0->GetShape().size() > in_1->GetShape().size() ? in_1 : in_0;
    std::shared_ptr<tim::vx::Tensor> long_tensor =
        in_0->GetShape().size() < in_1->GetShape().size() ? in_1 : in_0;
    if (in_0->GetSpec().attr_ == tim::vx::INPUT &&
        in_1->GetSpec().attr_ == tim::vx::INPUT &&
        in_0->GetShape().size() != in_1->GetShape().size()) {
      auto pv_long = context_->GetPermuteVector(long_tensor);
      auto pv_short = context_->GetPermuteVector(short_tensor);
      auto size_long = pv_long->Rank();
      auto size_short = pv_short->Rank();
      auto expand_pv = MakeShared(size_long);
      // Ranks differ: expand the short pv to the long pv's rank.
      for (uint32_t i = 0; i < size_short; ++i) {
        expand_pv->At(i) = pv_short->At(i);  // low dims take the short pv
      }
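      // Dims in [size_short, size_long) keep MakeShared()'s default mapping
      // (identity), which lines up with the size-1 axes appended below.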

      auto short_shape = short_tensor->GetShape();
      // Pad the short shape with trailing 1s up to the long rank.
      for (uint32_t i = 0; i < size_long; ++i) {
        if (i >= size_short) short_shape.push_back(1);
      }
      short_tensor->GetSpec().SetShape(short_shape);

      context_->SetPermuteVector(short_tensor, expand_pv);  // register the expanded pv
    }
    auto required_pv = AlignPermuteVectorForElementWise();
    auto multiply =
        context_->infer_graph_->CreateOperation<tim::vx::ops::Multiply>(
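As a standalone illustration of what the expansion does (plain std::vector
stands in for the permute-vector and shape types; the values are illustrative):
for a transposed rank-2 input aligned against a rank-4 input, pv {1, 0}
becomes {1, 0, 2, 3} and shape {3, 4} becomes {3, 4, 1, 1}.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<uint32_t> pv_short = {1, 0};     // permute vector of the rank-2 input
  std::vector<uint32_t> short_shape = {3, 4};  // its shape
  const uint32_t size_long = 4;                // rank of the other input

  // Expand the permute vector: low dims copy the short pv, the rest identity.
  std::vector<uint32_t> expand_pv(size_long);
  for (uint32_t i = 0; i < size_long; ++i) {
    expand_pv[i] = (i < pv_short.size()) ? pv_short[i] : i;
  }

  // Pad the shape with trailing 1s so both inputs end up the same rank.
  while (short_shape.size() < size_long) short_shape.push_back(1);

  for (auto v : expand_pv) std::cout << v << ' ';    // prints: 1 0 2 3
  std::cout << '\n';
  for (auto v : short_shape) std::cout << v << ' ';  // prints: 3 4 1 1
  std::cout << '\n';
  return 0;
}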
