Skip to content

Commit

Permalink
Fix inappropriate comments for reduce layoutinfer
Browse files Browse the repository at this point in the history
Type: Code refine

Signed-off-by: Chen <[email protected]>
  • Loading branch information
Chen committed Dec 11, 2023
1 parent ce971cf commit dd79505
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/tim/transform/ops/reduce_layout_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,23 @@ class ReduceLayoutInfer : public OpLayoutInfer {
std::vector<std::shared_ptr<vx::Tensor>>& next_tensor) override {
auto t_src = op_->impl()->InputsTensor()[0];
auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]);
std::set<int32_t> axis_set; //Same value as new_axis, convenient for searching
std::set<int32_t> axis_set; // Save unique axis values
std::vector<int32_t> new_axis, pv_reduced;
for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reduce.axis_num;
++i) {
int32_t axis_num = op_->impl()->node()->nn_param.reduce.axis_num;
for (uint32_t i = 0; i < axis_num; ++i) {
int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i];
if (axis < 0) {
axis += pv->Rank();
}
axis = MapAxis(pv->AsStdVec(), axis);
// Save unique axis values for calculating pv length
axis_set.insert(axis);
new_axis.push_back(axis);
}
auto reduce = context_->infer_graph_->CreateOperation<OpType>(
new_axis, op_->impl()->node()->nn_param.reduce.keep_dim);
(*reduce).BindInput(context_->GetMapedTensor(t_src));

if (op_->impl()->node()->nn_param.reduce.keep_dim) {
auto otensor_infer = CreateOutputsTensor(pv);
(*reduce).BindOutput(otensor_infer[0]);
Expand Down

0 comments on commit dd79505

Please sign in to comment.