From 56bc1242b2a59c210b6f70816175581c8247a659 Mon Sep 17 00:00:00 2001 From: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com> Date: Fri, 3 Sep 2021 15:19:54 +0800 Subject: [PATCH] [Feature] Add NCNN on MMSegmentation (#55) * fix custom ops support, fix multiple mark bug, add name mapping * check if the value_info need to be added * remove unnecessary print * add nms implement * two stage split wip * add two stage split * add split retinanet visualize * add two stage split (wip) * finish two stage split * fix lint * move parse string to mmdeploy.utils * add calib data generator * create calib dataset * finish end2end int8 * add split two stage tensorrt visualize * first * fix0 * fix1 * dirty work * wip * add allocator * finally done! * lint * fix lint * better gather * better onnx2ncnn * fix tensorslice bugs * fix lint * fix clang-format * remove comments * fix expand * int param * fix lint * [Fix] NCNN TensorSlice op bugs (#42) * fix custom ops support, fix multiple mark bug, add name mapping * check if the value_info need to be added * remove unnecessary print * add nms implement * two stage split wip * add two stage split * add split retinanet visualize * add two stage split (wip) * finish two stage split * fix lint * move parse string to mmdeploy.utils * add calib data generator * create calib dataset * finish end2end int8 * add split two stage tensorrt visualize * fix tensorslice bugs * fix lint * fix clang-format * remove comments * int param * fix lint Co-authored-by: grimoire * add two stage ncnn support * remove unused ops * git unused config * remove no_grad, should add in refactor * add ncnn wrapper * fix lint * size return tuple * Resolve grammar error * Fix lint * Trim Trailing Whitespace * fix trim * add argmax to topk * add ArgMax parse * add ncnn mmseg deploy cfg * utils add ncnn mmseg * add ncnn * fix lint * fix yapf * fix clang-format-9 * remove debugging code Co-authored-by: grimoire Co-authored-by: grimoire Co-authored-by: maningsheng --- backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp | 7 + backend_ops/ncnn/ops/topk/topk.cpp | 1106 +++++++++++++++------- backend_ops/ncnn/ops/topk/topk.h | 1 + configs/mmseg/ncnn.py | 1 + mmdeploy/apis/ncnn/ncnn_utils.py | 11 +- mmdeploy/mmseg/apis/inference.py | 33 +- 6 files changed, 806 insertions(+), 353 deletions(-) mode change 100755 => 100644 backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp mode change 100755 => 100644 backend_ops/ncnn/ops/topk/topk.cpp create mode 100644 configs/mmseg/ncnn.py diff --git a/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp b/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp old mode 100755 new mode 100644 index 6a350fa73f..c504f5564b --- a/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp +++ b/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp @@ -3296,6 +3296,8 @@ int main(int argc, char** argv) { fprintf(pp, "%-16s", "UnaryOp"); } else if (op == "Add") { fprintf(pp, "%-16s", "BinaryOp"); + } else if (op == "ArgMax") { + fprintf(pp, "%-16s", "TopK"); } else if (op == "Asin") { fprintf(pp, "%-16s", "UnaryOp"); } else if (op == "Atan") { @@ -3604,6 +3606,11 @@ int main(int argc, char** argv) { fprintf(pp, " 1=%d", with_scalar); fprintf(pp, " 2=%e", b); } + } else if (op == "ArgMax") { + int axis = get_node_attr_i(node, "axis"); + int keepdims = get_node_attr_i(node, "keepdims"); + fprintf(pp, " 0=%d", axis - 1); + fprintf(pp, " 3=%d", keepdims); } else if (op == "Asin") { int op_type = 12; fprintf(pp, " 0=%d", op_type); diff --git a/backend_ops/ncnn/ops/topk/topk.cpp b/backend_ops/ncnn/ops/topk/topk.cpp old mode 100755 new mode 100644 index 2fd9fafb75..3a1583a9ed --- a/backend_ops/ncnn/ops/topk/topk.cpp +++ b/backend_ops/ncnn/ops/topk/topk.cpp @@ -18,6 +18,7 @@ int TopK::load_param(const ParamDict& pd) { axis = pd.get(0, -1); largest = pd.get(1, 1); sorted = pd.get(2, 1); + keep_dims = pd.get(3, 1); return 0; } @@ -25,191 +26,54 @@ int TopK::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { int dims = bottom_blobs[0].dims; int positive_axis = axis < 0 ? dims + axis : axis; + int topk; + if (bottom_blobs.size() == 2) { + const Mat& topk_blob = bottom_blobs[1]; + topk = (int)(topk_blob[0] + 0.5); + } else if (bottom_blobs.size() == 1) { + topk = 1; + } else { + fprintf(stderr, "topk input blobs should be 1 or 2, but not %d\n", + bottom_blobs.size()); + return -103; + } - const Mat& topk_blob = bottom_blobs[1]; // To do: Cut the top_val_blob after unit test. And we should change them in // param files. Mat& top_val_blob = top_blobs[0]; - Mat& top_ind_blob = top_blobs[1]; - - int topk = (int)(topk_blob[0] + 0.5); - if (dims == 1 && positive_axis == 0) { - if (topk > bottom_blobs[0].w) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(topk, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - const float* ptr = bottom_blobs[0]; - std::vector > vec; - vec.resize(bottom_blobs[0].w); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(ptr[i], -i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(ptr[i], i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", - largest); - return -100; - } - float* valptr = top_val_blob; - float* indptr = top_ind_blob; - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - valptr[i] = vec[i].first; - indptr[i] = abs(vec[i].second); - } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - - // pair comparison - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0][i] > valtarget) { - valptr[cur] = bottom_blobs[0][i]; - indptr[cur] = i; - cur++; - } else if (bottom_blobs[0][i] == valtarget && i <= indtarget) { - valptr[cur] = bottom_blobs[0][i]; - indptr[cur] = i; - cur++; - } - } - } else { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0][i] < valtarget) { - valptr[cur] = bottom_blobs[0][i]; - indptr[cur] = i; - cur++; - } else if (bottom_blobs[0][i] == valtarget && i <= indtarget) { - valptr[cur] = bottom_blobs[0][i]; - indptr[cur] = i; - cur++; - } - } - } - } - } - - if (dims == 2 && positive_axis == 0) { - if (topk > bottom_blobs[0].h) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; + Mat& top_ind_blob = top_val_blob; // fill the reference + if (top_blobs.size() == 2) top_ind_blob = top_blobs[1]; + + if (topk > 1) { + // real topk + if (keep_dims == 0) { + fprintf(stderr, "real topk should not reduce dims!\n"); + return -102; } - top_val_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - for (int col = 0; col < bottom_blobs[0].w; col++) { - std::vector > vec; - vec.resize(bottom_blobs[0].h); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = std::make_pair(bottom_blobs[0].row(i)[col], -i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = std::make_pair(bottom_blobs[0].row(i)[col], i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", - largest); - return -100; - } - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - top_val_blob.row(i)[col] = vec[i].first; - top_ind_blob.row(i)[col] = abs(vec[i].second); - } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].row(i)[col] > valtarget) { - top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; - top_ind_blob.row(cur)[col] = i; - cur++; - } else if (bottom_blobs[0].row(i)[col] == valtarget && - i <= indtarget) { - top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; - top_ind_blob.row(cur)[col] = i; - cur++; - } - } - } else { - for (int i = 0; i < bottom_blobs[0].h; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].row(i)[col] < valtarget) { - top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; - top_ind_blob.row(cur)[col] = i; - cur++; - } else if (bottom_blobs[0].row(i)[col] == valtarget && - i <= indtarget) { - top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; - top_ind_blob.row(cur)[col] = i; - cur++; - } - } - } - } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", - sorted); + if (dims == 1 && positive_axis == 0) { + if (topk > bottom_blobs[0].w) { + fprintf(stderr, "topk should not greater than total items!\n"); return -100; } - } - } + top_val_blob.create(topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; - if (dims == 2 && positive_axis == 1) { - if (topk > bottom_blobs[0].w) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; + top_ind_blob.create(topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; - for (int r = 0; r < bottom_blobs[0].h; r++) { + const float* ptr = bottom_blobs[0]; std::vector > vec; vec.resize(bottom_blobs[0].w); if (largest == 1) { for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(bottom_blobs[0].row(r)[i], -i); + vec[i] = std::make_pair(ptr[i], -i); } std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater >()); } else if (largest == 0) { for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(bottom_blobs[0].row(r)[i], i); + vec[i] = std::make_pair(ptr[i], i); } std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less >()); @@ -218,81 +82,72 @@ int TopK::forward(const std::vector& bottom_blobs, largest); return -100; } - + float* valptr = top_val_blob; + float* indptr = top_ind_blob; if (sorted == 1) { for (int i = 0; i < topk; i++) { - top_val_blob.row(r)[i] = vec[i].first; - top_ind_blob.row(r)[i] = abs(vec[i].second); + valptr[i] = vec[i].first; + indptr[i] = abs(vec[i].second); } } else if (sorted == 0) { int cur = 0; float valtarget = vec[topk - 1].first; int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + + // pair comparison if (largest == 1) { for (int i = 0; i < bottom_blobs[0].w; i++) { if (cur >= topk) break; - if (bottom_blobs[0].row(r)[i] > valtarget) { - top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; - top_ind_blob.row(r)[cur] = i; + if (bottom_blobs[0][i] > valtarget) { + valptr[cur] = bottom_blobs[0][i]; + indptr[cur] = i; cur++; - } else if (bottom_blobs[0].row(r)[i] == valtarget && - i <= indtarget) { - top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; - top_ind_blob.row(r)[cur] = i; + } else if (bottom_blobs[0][i] == valtarget && i <= indtarget) { + valptr[cur] = bottom_blobs[0][i]; + indptr[cur] = i; cur++; } } } else { for (int i = 0; i < bottom_blobs[0].w; i++) { if (cur >= topk) break; - if (bottom_blobs[0].row(r)[i] < valtarget) { - top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; - top_ind_blob.row(r)[cur] = i; + if (bottom_blobs[0][i] < valtarget) { + valptr[cur] = bottom_blobs[0][i]; + indptr[cur] = i; cur++; - } else if (bottom_blobs[0].row(r)[i] == valtarget && - i <= indtarget) { - top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; - top_ind_blob.row(r)[cur] = i; + } else if (bottom_blobs[0][i] == valtarget && i <= indtarget) { + valptr[cur] = bottom_blobs[0][i]; + indptr[cur] = i; cur++; } } } - - } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", - sorted); - return -100; } } - } - - if (dims == 3 && positive_axis == 0) { - if (topk > bottom_blobs[0].c) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, - opt.blob_allocator); - if (top_val_blob.empty()) return -100; + if (dims == 2 && positive_axis == 0) { + if (topk > bottom_blobs[0].h) { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; - top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, - opt.blob_allocator); - if (top_ind_blob.empty()) return -100; + top_ind_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; - for (int r = 0; r < bottom_blobs[0].h; r++) { for (int col = 0; col < bottom_blobs[0].w; col++) { std::vector > vec; - vec.resize(bottom_blobs[0].c); + vec.resize(bottom_blobs[0].h); if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].c; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(i).row(r)[col], -i); + for (int i = 0; i < bottom_blobs[0].h; i++) { + vec[i] = std::make_pair(bottom_blobs[0].row(i)[col], -i); } std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater >()); } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].c; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(i).row(r)[col], i); + for (int i = 0; i < bottom_blobs[0].h; i++) { + vec[i] = std::make_pair(bottom_blobs[0].row(i)[col], i); } std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less >()); @@ -301,50 +156,44 @@ int TopK::forward(const std::vector& bottom_blobs, largest); return -100; } - if (sorted == 1) { for (int i = 0; i < topk; i++) { - top_val_blob.channel(i).row(r)[col] = vec[i].first; - top_ind_blob.channel(i).row(r)[col] = abs(vec[i].second); + top_val_blob.row(i)[col] = vec[i].first; + top_ind_blob.row(i)[col] = abs(vec[i].second); } } else if (sorted == 0) { int cur = 0; float valtarget = vec[topk - 1].first; int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].c; i++) { + for (int i = 0; i < bottom_blobs[0].h; i++) { if (cur >= topk) break; - if (bottom_blobs[0].channel(i).row(r)[col] > valtarget) { - top_val_blob.channel(cur).row(r)[col] = - bottom_blobs[0].channel(i).row(r)[col]; - top_ind_blob.channel(cur).row(r)[col] = i; + if (bottom_blobs[0].row(i)[col] > valtarget) { + top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; + top_ind_blob.row(cur)[col] = i; cur++; - } else if (bottom_blobs[0].channel(i).row(r)[col] == valtarget && + } else if (bottom_blobs[0].row(i)[col] == valtarget && i <= indtarget) { - top_val_blob.channel(cur).row(r)[col] = - bottom_blobs[0].channel(i).row(r)[col]; - top_ind_blob.channel(cur).row(r)[col] = i; + top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; + top_ind_blob.row(cur)[col] = i; cur++; } } } else { - for (int i = 0; i < bottom_blobs[0].c; i++) { + for (int i = 0; i < bottom_blobs[0].h; i++) { if (cur >= topk) break; - if (bottom_blobs[0].channel(i).row(r)[col] < valtarget) { - top_val_blob.channel(cur).row(r)[col] = - bottom_blobs[0].channel(i).row(r)[col]; - top_ind_blob.channel(cur).row(r)[col] = i; + if (bottom_blobs[0].row(i)[col] < valtarget) { + top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; + top_ind_blob.row(cur)[col] = i; cur++; - } else if (bottom_blobs[0].channel(i).row(r)[col] == valtarget && + } else if (bottom_blobs[0].row(i)[col] == valtarget && i <= indtarget) { - top_val_blob.channel(cur).row(r)[col] = - bottom_blobs[0].channel(i).row(r)[col]; - top_ind_blob.channel(cur).row(r)[col] = i; + top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; + top_ind_blob.row(cur)[col] = i; cur++; } } } - } else { fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); @@ -352,37 +201,30 @@ int TopK::forward(const std::vector& bottom_blobs, } } } - } - - if (dims == 3 && positive_axis == 1) { - if (topk > bottom_blobs[0].h) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, - opt.blob_allocator); - if (top_val_blob.empty()) return -100; + if (dims == 2 && positive_axis == 1) { + if (topk > bottom_blobs[0].w) { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; - top_ind_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, - opt.blob_allocator); - if (top_ind_blob.empty()) return -100; + top_ind_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; - for (int page = 0; page < bottom_blobs[0].c; page++) { - for (int col = 0; col < bottom_blobs[0].w; col++) { + for (int r = 0; r < bottom_blobs[0].h; r++) { std::vector > vec; - vec.resize(bottom_blobs[0].h); + vec.resize(bottom_blobs[0].w); if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = - std::make_pair(bottom_blobs[0].channel(page).row(i)[col], -i); + for (int i = 0; i < bottom_blobs[0].w; i++) { + vec[i] = std::make_pair(bottom_blobs[0].row(r)[i], -i); } std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater >()); } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = - std::make_pair(bottom_blobs[0].channel(page).row(i)[col], i); + for (int i = 0; i < bottom_blobs[0].w; i++) { + vec[i] = std::make_pair(bottom_blobs[0].row(r)[i], i); } std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less >()); @@ -394,45 +236,43 @@ int TopK::forward(const std::vector& bottom_blobs, if (sorted == 1) { for (int i = 0; i < topk; i++) { - top_val_blob.channel(page).row(i)[col] = vec[i].first; - top_ind_blob.channel(page).row(i)[col] = abs(vec[i].second); + top_val_blob.row(r)[i] = vec[i].first; + top_ind_blob.row(r)[i] = abs(vec[i].second); } } else if (sorted == 0) { int cur = 0; float valtarget = vec[topk - 1].first; int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - for (int i = 0; i < bottom_blobs[0].h; i++) { - if (cur >= topk) break; - if (largest == 1) { - if (bottom_blobs[0].channel(page).row(i)[col] > valtarget) { - top_val_blob.channel(page).row(cur)[col] = - bottom_blobs[0].channel(page).row(i)[col]; - top_ind_blob.channel(page).row(cur)[col] = i; + if (largest == 1) { + for (int i = 0; i < bottom_blobs[0].w; i++) { + if (cur >= topk) break; + if (bottom_blobs[0].row(r)[i] > valtarget) { + top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; + top_ind_blob.row(r)[cur] = i; cur++; - } else if (bottom_blobs[0].channel(page).row(i)[col] == - valtarget && + } else if (bottom_blobs[0].row(r)[i] == valtarget && i <= indtarget) { - top_val_blob.channel(page).row(cur)[col] = - bottom_blobs[0].channel(page).row(i)[col]; - top_ind_blob.channel(page).row(cur)[col] = i; + top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; + top_ind_blob.row(r)[cur] = i; cur++; } - } else { - if (bottom_blobs[0].channel(page).row(i)[col] < valtarget) { - top_val_blob.channel(page).row(cur)[col] = - bottom_blobs[0].channel(page).row(i)[col]; - top_ind_blob.channel(page).row(cur)[col] = i; + } + } else { + for (int i = 0; i < bottom_blobs[0].w; i++) { + if (cur >= topk) break; + if (bottom_blobs[0].row(r)[i] < valtarget) { + top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; + top_ind_blob.row(r)[cur] = i; cur++; - } else if (bottom_blobs[0].channel(page).row(i)[col] == - valtarget && + } else if (bottom_blobs[0].row(r)[i] == valtarget && i <= indtarget) { - top_val_blob.channel(page).row(cur)[col] = - bottom_blobs[0].channel(page).row(i)[col]; - top_ind_blob.channel(page).row(cur)[col] = i; + top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; + top_ind_blob.row(r)[cur] = i; cur++; } } } + } else { fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); @@ -440,97 +280,679 @@ int TopK::forward(const std::vector& bottom_blobs, } } } - } - - if (dims == 3 && positive_axis == 2) { - if (topk > bottom_blobs[0].w) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, - opt.blob_allocator); - if (top_val_blob.empty()) return -100; + if (dims == 3 && positive_axis == 0) { + if (topk > bottom_blobs[0].c) { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, + opt.blob_allocator); + if (top_val_blob.empty()) return -100; - top_ind_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, - opt.blob_allocator); - if (top_ind_blob.empty()) return -100; + top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, + opt.blob_allocator); + if (top_ind_blob.empty()) return -100; - for (int page = 0; page < bottom_blobs[0].c; page++) { for (int r = 0; r < bottom_blobs[0].h; r++) { - std::vector > vec; - vec.resize(bottom_blobs[0].w); + for (int col = 0; col < bottom_blobs[0].w; col++) { + std::vector > vec; + vec.resize(bottom_blobs[0].c); - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = - std::make_pair(bottom_blobs[0].channel(page).row(r)[i], -i); + if (largest == 1) { + for (int i = 0; i < bottom_blobs[0].c; i++) { + vec[i] = + std::make_pair(bottom_blobs[0].channel(i).row(r)[col], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), + std::greater >()); + } else if (largest == 0) { + for (int i = 0; i < bottom_blobs[0].c; i++) { + vec[i] = + std::make_pair(bottom_blobs[0].channel(i).row(r)[col], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), + std::less >()); + } else { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", + largest); + return -100; } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(r)[i], i); + + if (sorted == 1) { + for (int i = 0; i < topk; i++) { + top_val_blob.channel(i).row(r)[col] = vec[i].first; + top_ind_blob.channel(i).row(r)[col] = abs(vec[i].second); + } + } else if (sorted == 0) { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + if (largest == 1) { + for (int i = 0; i < bottom_blobs[0].c; i++) { + if (cur >= topk) break; + if (bottom_blobs[0].channel(i).row(r)[col] > valtarget) { + top_val_blob.channel(cur).row(r)[col] = + bottom_blobs[0].channel(i).row(r)[col]; + top_ind_blob.channel(cur).row(r)[col] = i; + cur++; + } else if (bottom_blobs[0].channel(i).row(r)[col] == + valtarget && + i <= indtarget) { + top_val_blob.channel(cur).row(r)[col] = + bottom_blobs[0].channel(i).row(r)[col]; + top_ind_blob.channel(cur).row(r)[col] = i; + cur++; + } + } + } else { + for (int i = 0; i < bottom_blobs[0].c; i++) { + if (cur >= topk) break; + if (bottom_blobs[0].channel(i).row(r)[col] < valtarget) { + top_val_blob.channel(cur).row(r)[col] = + bottom_blobs[0].channel(i).row(r)[col]; + top_ind_blob.channel(cur).row(r)[col] = i; + cur++; + } else if (bottom_blobs[0].channel(i).row(r)[col] == + valtarget && + i <= indtarget) { + top_val_blob.channel(cur).row(r)[col] = + bottom_blobs[0].channel(i).row(r)[col]; + top_ind_blob.channel(cur).row(r)[col] = i; + cur++; + } + } + } + + } else { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", + sorted); + return -100; } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", - largest); - return -100; } + } + } + if (dims == 3 && positive_axis == 1) { + if (topk > bottom_blobs[0].h) { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, + opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + top_ind_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, + opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + for (int page = 0; page < bottom_blobs[0].c; page++) { + for (int col = 0; col < bottom_blobs[0].w; col++) { + std::vector > vec; + vec.resize(bottom_blobs[0].h); - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - top_val_blob.channel(page).row(r)[i] = vec[i].first; - top_ind_blob.channel(page).row(r)[i] = abs(vec[i].second); - } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { + for (int i = 0; i < bottom_blobs[0].h; i++) { + vec[i] = + std::make_pair(bottom_blobs[0].channel(page).row(i)[col], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), + std::greater >()); + } else if (largest == 0) { + for (int i = 0; i < bottom_blobs[0].h; i++) { + vec[i] = + std::make_pair(bottom_blobs[0].channel(page).row(i)[col], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), + std::less >()); + } else { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", + largest); + return -100; + } + + if (sorted == 1) { + for (int i = 0; i < topk; i++) { + top_val_blob.channel(page).row(i)[col] = vec[i].first; + top_ind_blob.channel(page).row(i)[col] = abs(vec[i].second); + } + } else if (sorted == 0) { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + for (int i = 0; i < bottom_blobs[0].h; i++) { if (cur >= topk) break; - if (bottom_blobs[0].channel(page).row(r)[i] > valtarget) { - top_val_blob.channel(page).row(r)[cur] = - bottom_blobs[0].channel(page).row(r)[i]; - top_ind_blob.channel(page).row(r)[cur] = i; - cur++; - } else if (bottom_blobs[0].channel(page).row(r)[i] == valtarget && - i <= indtarget) { - top_val_blob.channel(page).row(r)[cur] = - bottom_blobs[0].channel(page).row(r)[i]; - top_ind_blob.channel(page).row(r)[cur] = i; - cur++; + if (largest == 1) { + if (bottom_blobs[0].channel(page).row(i)[col] > valtarget) { + top_val_blob.channel(page).row(cur)[col] = + bottom_blobs[0].channel(page).row(i)[col]; + top_ind_blob.channel(page).row(cur)[col] = i; + cur++; + } else if (bottom_blobs[0].channel(page).row(i)[col] == + valtarget && + i <= indtarget) { + top_val_blob.channel(page).row(cur)[col] = + bottom_blobs[0].channel(page).row(i)[col]; + top_ind_blob.channel(page).row(cur)[col] = i; + cur++; + } + } else { + if (bottom_blobs[0].channel(page).row(i)[col] < valtarget) { + top_val_blob.channel(page).row(cur)[col] = + bottom_blobs[0].channel(page).row(i)[col]; + top_ind_blob.channel(page).row(cur)[col] = i; + cur++; + } else if (bottom_blobs[0].channel(page).row(i)[col] == + valtarget && + i <= indtarget) { + top_val_blob.channel(page).row(cur)[col] = + bottom_blobs[0].channel(page).row(i)[col]; + top_ind_blob.channel(page).row(cur)[col] = i; + cur++; + } } } } else { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", + sorted); + return -100; + } + } + } + } + if (dims == 3 && positive_axis == 2) { + if (topk > bottom_blobs[0].w) { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, + opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + top_ind_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, + opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + for (int page = 0; page < bottom_blobs[0].c; page++) { + for (int r = 0; r < bottom_blobs[0].h; r++) { + std::vector > vec; + vec.resize(bottom_blobs[0].w); + + if (largest == 1) { for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].channel(page).row(r)[i] < valtarget) { - top_val_blob.channel(page).row(r)[cur] = - bottom_blobs[0].channel(page).row(r)[i]; - top_ind_blob.channel(page).row(r)[cur] = i; - cur++; - } else if (bottom_blobs[0].channel(page).row(r)[i] == valtarget && - i <= indtarget) { - top_val_blob.channel(page).row(r)[cur] = - bottom_blobs[0].channel(page).row(r)[i]; - top_ind_blob.channel(page).row(r)[cur] = i; - cur++; + vec[i] = + std::make_pair(bottom_blobs[0].channel(page).row(r)[i], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), + std::greater >()); + } else if (largest == 0) { + for (int i = 0; i < bottom_blobs[0].w; i++) { + vec[i] = + std::make_pair(bottom_blobs[0].channel(page).row(r)[i], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), + std::less >()); + } else { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", + largest); + return -100; + } + + if (sorted == 1) { + for (int i = 0; i < topk; i++) { + top_val_blob.channel(page).row(r)[i] = vec[i].first; + top_ind_blob.channel(page).row(r)[i] = abs(vec[i].second); + } + } else if (sorted == 0) { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + if (largest == 1) { + for (int i = 0; i < bottom_blobs[0].w; i++) { + if (cur >= topk) break; + if (bottom_blobs[0].channel(page).row(r)[i] > valtarget) { + top_val_blob.channel(page).row(r)[cur] = + bottom_blobs[0].channel(page).row(r)[i]; + top_ind_blob.channel(page).row(r)[cur] = i; + cur++; + } else if (bottom_blobs[0].channel(page).row(r)[i] == + valtarget && + i <= indtarget) { + top_val_blob.channel(page).row(r)[cur] = + bottom_blobs[0].channel(page).row(r)[i]; + top_ind_blob.channel(page).row(r)[cur] = i; + cur++; + } + } + } else { + for (int i = 0; i < bottom_blobs[0].w; i++) { + if (cur >= topk) break; + if (bottom_blobs[0].channel(page).row(r)[i] < valtarget) { + top_val_blob.channel(page).row(r)[cur] = + bottom_blobs[0].channel(page).row(r)[i]; + top_ind_blob.channel(page).row(r)[cur] = i; + cur++; + } else if (bottom_blobs[0].channel(page).row(r)[i] == + valtarget && + i <= indtarget) { + top_val_blob.channel(page).row(r)[cur] = + bottom_blobs[0].channel(page).row(r)[i]; + top_ind_blob.channel(page).row(r)[cur] = i; + cur++; + } } } + + } else { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", + sorted); + return -100; } + } + } + } + } else { + if (topk <= 0) { + fprintf(stderr, "topk should not <= 0!\n"); + return -102; + } + if (dims == 1 && positive_axis == 0) { + if (topk > bottom_blobs[0].w) { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + if (top_blobs.size() == 2) { + top_ind_blob.create(topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + + const float* ptr = bottom_blobs[0]; + std::vector vec; + vec.resize(bottom_blobs[0].w); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + + for (int i = 0; i < bottom_blobs[0].w; i++) { + vec[i] = ptr[i]; + } + if (largest == 1) { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[0] = *index_iter; + if (top_blobs.size() == 2) + indptr[0] = std::distance(vec.begin(), index_iter); + else + valptr[0] = + std::distance(vec.begin(), index_iter); // replace with index + } else if (largest == 0) { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[0] = *index_iter; + if (top_blobs.size() == 2) + indptr[0] = std::distance(vec.begin(), index_iter); + else + valptr[0] = + std::distance(vec.begin(), index_iter); // replace with index + } else { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", + largest); + return -100; + } + } + if (dims == 2 && positive_axis == 0) { + if (keep_dims == 1) { + top_val_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) { + top_ind_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + + } else { + top_val_blob.create(bottom_blobs[0].w, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + if (top_blobs.size() == 2) { + top_ind_blob.create(bottom_blobs[0].w, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + const float* ptr = bottom_blobs[0]; + std::vector vec; + vec.resize(bottom_blobs[0].h); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + for (int col = 0; col < bottom_blobs[0].w; col++) { + for (int i = 0; i < bottom_blobs[0].h; i++) { + vec[i] = ptr[i * bottom_blobs[0].w + col]; + } + if (largest == 1) { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[col] = *index_iter; + if (top_blobs.size() == 2) + indptr[col] = std::distance(vec.begin(), index_iter); + else + valptr[col] = std::distance(vec.begin(), index_iter); + } else if (largest == 0) { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[col] = *index_iter; + if (top_blobs.size() == 2) + indptr[col] = std::distance(vec.begin(), index_iter); + else + valptr[col] = std::distance(vec.begin(), index_iter); } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", - sorted); + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", + largest); return -100; } } } - } + if (dims == 2 && positive_axis == 1) { + if (keep_dims == 1) { + top_val_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) { + top_ind_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + + } else { + top_val_blob.create(bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) { + top_ind_blob.create(bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + + const float* ptr = bottom_blobs[0]; + std::vector vec; + vec.resize(bottom_blobs[0].w); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + + for (int r = 0; r < bottom_blobs[0].h; r++) { + for (int i = 0; i < bottom_blobs[0].w; i++) { + vec[i] = ptr[r * bottom_blobs[0].w + i]; + } + if (largest == 1) { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[r] = *index_iter; + if (top_blobs.size() == 2) + indptr[r] = std::distance(vec.begin(), index_iter); + else + valptr[r] = std::distance(vec.begin(), index_iter); + + } else if (largest == 0) { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[r] = *index_iter; + if (top_blobs.size() == 2) + indptr[r] = std::distance(vec.begin(), index_iter); + else + valptr[r] = std::distance(vec.begin(), index_iter); + } else { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", + largest); + return -100; + } + } + } + if (dims == 3 && positive_axis == 0) { + if (keep_dims == 1) { + top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, + opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) { + top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, + opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + + } else { + top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, 4u, + opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) { + top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, 4u, + opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + const float* ptr = bottom_blobs[0]; + std::vector vec; + vec.resize(bottom_blobs[0].c); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + + for (int r = 0; r < bottom_blobs[0].h; r++) { + for (int col = 0; col < bottom_blobs[0].w; col++) { + for (int i = 0; i < bottom_blobs[0].c; i++) { + ptr = bottom_blobs[0].channel(i); + vec[i] = ptr[r * bottom_blobs[0].w + col]; + } + if (largest == 1) { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[r * top_val_blob.w + col] = *index_iter; + if (top_blobs.size() == 2) + indptr[r * top_ind_blob.w + col] = + std::distance(vec.begin(), index_iter); + else + valptr[r * top_ind_blob.w + col] = + std::distance(vec.begin(), index_iter); + + } else if (largest == 0) { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[r * top_val_blob.w + col] = *index_iter; + + if (top_blobs.size() == 2) + indptr[r * top_ind_blob.w + col] = + std::distance(vec.begin(), index_iter); + else + valptr[r * top_ind_blob.w + col] = + std::distance(vec.begin(), index_iter); + } else { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", + largest); + return -100; + } + } + } + } + if (dims == 3 && positive_axis == 1) { + if (keep_dims == 1) { + top_val_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, + opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) { + top_ind_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, + opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + + std::vector vec; + vec.resize(bottom_blobs[0].h); + + for (int page = 0; page < bottom_blobs[0].c; page++) { + const float* ptr = bottom_blobs[0].channel(page); + float* valptr = top_val_blob.channel(page); + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob.channel(page); + for (int col = 0; col < bottom_blobs[0].w; col++) { + for (int i = 0; i < bottom_blobs[0].h; i++) { + vec[i] = ptr[i * bottom_blobs[0].w + col]; + } + if (largest == 1) { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[col] = *index_iter; + if (top_blobs.size() == 2) + indptr[col] = std::distance(vec.begin(), index_iter); + else + valptr[col] = std::distance(vec.begin(), index_iter); + } else if (largest == 0) { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[col] = *index_iter; + if (top_blobs.size() == 2) + indptr[col] = std::distance(vec.begin(), index_iter); + else + valptr[col] = std::distance(vec.begin(), index_iter); + } else { + fprintf(stderr, + "largest attribute should be 0 or 1, but not %d\n", + largest); + return -100; + } + } + } + } else { + top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].c, 4u, + opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) { + top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].c, 4u, + opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + + std::vector vec; + vec.resize(bottom_blobs[0].h); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + + for (int page = 0; page < bottom_blobs[0].c; page++) { + const float* ptr = bottom_blobs[0].channel(page); + for (int col = 0; col < bottom_blobs[0].w; col++) { + for (int i = 0; i < bottom_blobs[0].h; i++) { + vec[i] = ptr[i * bottom_blobs[0].w + col]; + } + if (largest == 1) { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[page * top_val_blob.w + col] = *index_iter; + if (top_blobs.size() == 2) + indptr[page * top_ind_blob.w + col] = + std::distance(vec.begin(), index_iter); + else + valptr[page * top_ind_blob.w + col] = + std::distance(vec.begin(), index_iter); + } else if (largest == 0) { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[page * top_val_blob.w + col] = *index_iter; + if (top_blobs.size() == 2) + indptr[page * top_ind_blob.w + col] = + std::distance(vec.begin(), index_iter); + else + valptr[page * top_ind_blob.w + col] = + std::distance(vec.begin(), index_iter); + } else { + fprintf(stderr, + "largest attribute should be 0 or 1, but not %d\n", + largest); + return -100; + } + } + } + } + } + if (dims == 3 && positive_axis == 2) { + if (keep_dims == 1) { + top_val_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, + opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) { + top_ind_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, + opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + + std::vector vec; + vec.resize(bottom_blobs[0].w); + + for (int page = 0; page < bottom_blobs[0].c; page++) { + const float* ptr = bottom_blobs[0].channel(page); + float* valptr = top_val_blob.channel(page); + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob.channel(page); + for (int r = 0; r < bottom_blobs[0].h; r++) { + for (int i = 0; i < bottom_blobs[0].w; i++) { + vec[i] = ptr[r * bottom_blobs[0].w + i]; + } + if (largest == 1) { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[r] = *index_iter; + if (top_blobs.size() == 2) + indptr[r] = std::distance(vec.begin(), index_iter); + else + valptr[r] = std::distance(vec.begin(), index_iter); + } else if (largest == 0) { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[r] = *index_iter; + if (top_blobs.size() == 2) + indptr[r] = std::distance(vec.begin(), index_iter); + else + valptr[r] = std::distance(vec.begin(), index_iter); + } else { + fprintf(stderr, + "largest attribute should be 0 or 1, but not %d\n", + largest); + return -100; + } + } + } + } else { + top_val_blob.create(bottom_blobs[0].h, bottom_blobs[0].c, 4u, + opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) { + top_ind_blob.create(bottom_blobs[0].h, bottom_blobs[0].c, 4u, + opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + + std::vector vec; + vec.resize(bottom_blobs[0].w); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + for (int page = 0; page < bottom_blobs[0].c; page++) { + const float* ptr = bottom_blobs[0].channel(page); + for (int r = 0; r < bottom_blobs[0].h; r++) { + for (int i = 0; i < bottom_blobs[0].w; i++) { + vec[i] = ptr[r * bottom_blobs[0].w + i]; + } + if (largest == 1) { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[page * top_val_blob.w + r] = *index_iter; + if (top_blobs.size() == 2) + indptr[page * top_ind_blob.w + r] = + std::distance(vec.begin(), index_iter); + else + valptr[page * top_ind_blob.w + r] = + std::distance(vec.begin(), index_iter); + } else if (largest == 0) { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[page * top_val_blob.w + r] = *index_iter; + if (top_blobs.size() == 2) + indptr[page * top_val_blob.w + r] = + std::distance(vec.begin(), index_iter); + else + valptr[page * top_ind_blob.w + r] = + std::distance(vec.begin(), index_iter); + } else { + fprintf(stderr, + "largest attribute should be 0 or 1, but not %d\n", + largest); + return -100; + } + } + } + } + } + } return 0; } diff --git a/backend_ops/ncnn/ops/topk/topk.h b/backend_ops/ncnn/ops/topk/topk.h index 31e1e35e44..6887180bc0 100755 --- a/backend_ops/ncnn/ops/topk/topk.h +++ b/backend_ops/ncnn/ops/topk/topk.h @@ -17,6 +17,7 @@ class TopK : public ncnn::Layer { int axis; int largest; int sorted; + int keep_dims; }; } // namespace mmlab diff --git a/configs/mmseg/ncnn.py b/configs/mmseg/ncnn.py new file mode 100644 index 0000000000..896944830e --- /dev/null +++ b/configs/mmseg/ncnn.py @@ -0,0 +1 @@ +_base_ = ['./base.py', '../_base_/backends/ncnn.py'] diff --git a/mmdeploy/apis/ncnn/ncnn_utils.py b/mmdeploy/apis/ncnn/ncnn_utils.py index be4d5561b5..020d70b9f5 100644 --- a/mmdeploy/apis/ncnn/ncnn_utils.py +++ b/mmdeploy/apis/ncnn/ncnn_utils.py @@ -20,7 +20,7 @@ def __init__(self, bin_file: str, output_names: Optional[Iterable[str]] = None, **kwargs): - super().__init__() + super(NCNNWrapper, self).__init__() net = ncnn.Net() ncnn_ext.register_mm_custom_layers(net) @@ -41,11 +41,12 @@ def get_output_names(self): return self._net.output_names() def forward(self, inputs: Dict[str, torch.Tensor]): - batch_size = next(iter(inputs.values())).size(0) - for k, v in inputs.items(): - assert v.size( + input_list = list(inputs.values()) + batch_size = input_list[0].size(0) + for tensor in input_list[1:]: + assert tensor.size( 0) == batch_size, 'All tensor should have same batch size' - assert v.device.type == 'cpu', 'NCNN only support cpu device' + assert tensor.device.type == 'cpu', 'NCNN only support cpu device' # set output names output_names = self.get_output_names() diff --git a/mmdeploy/mmseg/apis/inference.py b/mmdeploy/mmseg/apis/inference.py index 0355c700e2..43ffbc3cd1 100644 --- a/mmdeploy/mmseg/apis/inference.py +++ b/mmdeploy/mmseg/apis/inference.py @@ -35,6 +35,8 @@ def aug_test(self, imgs, img_metas, **kwargs): raise NotImplementedError('This method is not implemented.') def forward(self, img, img_metas, **kwargs): + if isinstance(img, (list, tuple)): + img = img[0] seg_pred = self.forward_test(img, img_metas, **kwargs) # whole mode supports dynamic shape ori_shape = img_metas[0][0]['ori_shape'] @@ -60,8 +62,6 @@ def __init__(self, model_file: str, class_names: Sequence[str], self.model = ORTWrapper(model_file, device_id) def forward_test(self, imgs, img_metas, **kwargs): - if isinstance(imgs, (list, tuple)): - imgs = imgs[0] seg_pred = self.model({'input': imgs})[0] return seg_pred @@ -79,20 +79,41 @@ def __init__(self, model_file: str, class_names: Sequence[str], self.output_name = self.model.output_names[0] def forward_test(self, imgs, img_metas, **kwargs): - input_data = imgs[0].contiguous() + input_data = imgs.contiguous() with torch.cuda.device(self.device_id), torch.no_grad(): seg_pred = self.model({'input': input_data})[self.output_name] seg_pred = seg_pred.detach().cpu().numpy() return seg_pred +class NCNNSegmentor(DeployBaseSegmentor): + + def __init__(self, model_file: Sequence[str], class_names: Sequence[str], + palette: np.ndarray, device_id: int): + super(NCNNSegmentor, self).__init__(class_names, palette, device_id) + from mmdeploy.apis.ncnn import NCNNWrapper + assert len(model_file) == 2 + ncnn_param_file = model_file[0] + ncnn_bin_file = model_file[1] + self.model = NCNNWrapper( + ncnn_param_file, ncnn_bin_file, output_names=['output']) + + def forward_test(self, imgs, *args, **kwargs): + results = self.model({'input': imgs})['output'] + results = results.detach().cpu().numpy() + return results + + ONNXRUNTIME_SEGMENTOR_MAP = dict(end2end=ONNXRuntimeSegmentor) TENSORRT_SEGMENTOR_MAP = dict(end2end=TensorRTSegmentor) +NCNN_SEGMENTOR_MAP = dict(end2end=NCNNSegmentor) + BACKEND_SEGMENTOR_MAP = { Backend.ONNXRUNTIME: ONNXRUNTIME_SEGMENTOR_MAP, - Backend.TENSORRT: TENSORRT_SEGMENTOR_MAP + Backend.TENSORRT: TENSORRT_SEGMENTOR_MAP, + Backend.NCNN: NCNN_SEGMENTOR_MAP } @@ -130,9 +151,9 @@ def build_segmentor(model_files, model_cfg, deploy_cfg, device_id): model_type = 'end2end' assert model_type in segmentor_map, f'Unsupported model type: {model_type}' backend_segmentor_class = segmentor_map[model_type] - + model_files = model_files[0] if len(model_files) == 1 else model_files backend_segmentor = backend_segmentor_class( - *model_files, + model_files, class_names=class_names, device_id=device_id, palette=palette)