From ac0b52f12ac897867b9bf3dae74b4192493b3d5f Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Fri, 17 Jun 2022 14:06:35 +0800 Subject: [PATCH] add dim param for `Tensor::Squeeze` (#603) --- csrc/mmdeploy/codebase/mmocr/dbnet.cpp | 4 +++- csrc/mmdeploy/codebase/mmocr/panet.cpp | 4 +++- csrc/mmdeploy/codebase/mmocr/psenet.cpp | 2 +- csrc/mmdeploy/core/tensor.cpp | 12 +++++++----- csrc/mmdeploy/core/tensor.h | 1 + 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/csrc/mmdeploy/codebase/mmocr/dbnet.cpp b/csrc/mmdeploy/codebase/mmocr/dbnet.cpp index c4a84006db..5b9d2f0e20 100644 --- a/csrc/mmdeploy/codebase/mmocr/dbnet.cpp +++ b/csrc/mmdeploy/codebase/mmocr/dbnet.cpp @@ -50,7 +50,9 @@ class DBHead : public MMOCR { return Status(eNotSupported); } - conf.Squeeze(); + // drop batch dimension + conf.Squeeze(0); + conf = conf.Slice(0); std::vector> contours; diff --git a/csrc/mmdeploy/codebase/mmocr/panet.cpp b/csrc/mmdeploy/codebase/mmocr/panet.cpp index 9c5ecb2812..042d088be2 100644 --- a/csrc/mmdeploy/codebase/mmocr/panet.cpp +++ b/csrc/mmdeploy/codebase/mmocr/panet.cpp @@ -53,7 +53,9 @@ class PANHead : public MMOCR { (int)pred.data_type()); return Status(eNotSupported); } - pred.Squeeze(); + + // drop batch dimension + pred.Squeeze(0); auto text_pred = pred.Slice(0); auto kernel_pred = pred.Slice(1); diff --git a/csrc/mmdeploy/codebase/mmocr/psenet.cpp b/csrc/mmdeploy/codebase/mmocr/psenet.cpp index 3b0f2bd5fb..19ab318178 100644 --- a/csrc/mmdeploy/codebase/mmocr/psenet.cpp +++ b/csrc/mmdeploy/codebase/mmocr/psenet.cpp @@ -51,7 +51,7 @@ class PSEHead : public MMOCR { } // drop batch dimension - _preds.Squeeze(); + _preds.Squeeze(0); cv::Mat_ masks; cv::Mat_ kernel_labels; diff --git a/csrc/mmdeploy/core/tensor.cpp b/csrc/mmdeploy/core/tensor.cpp index ddbb08252c..07fac1ae7d 100644 --- a/csrc/mmdeploy/core/tensor.cpp +++ b/csrc/mmdeploy/core/tensor.cpp @@ -88,11 +88,13 @@ void Tensor::Reshape(const TensorShape& shape) { } void Tensor::Squeeze() { - TensorShape new_shape; - new_shape.reserve(shape().size()); - std::copy_if(begin(shape()), end(shape()), std::back_inserter(new_shape), - [](int64_t dim) { return dim != 1; }); - Reshape(new_shape); + desc_.shape.erase(std::remove(desc_.shape.begin(), desc_.shape.end(), 1), desc_.shape.end()); +} + +void Tensor::Squeeze(int dim) { + if (shape(dim) == 1) { + desc_.shape.erase(desc_.shape.begin() + dim); + } } Result Tensor::CopyFrom(const Tensor& tensor, Stream stream) { diff --git a/csrc/mmdeploy/core/tensor.h b/csrc/mmdeploy/core/tensor.h index 92403fe38e..ef967af9ea 100644 --- a/csrc/mmdeploy/core/tensor.h +++ b/csrc/mmdeploy/core/tensor.h @@ -47,6 +47,7 @@ class MMDEPLOY_API Tensor { void Reshape(const TensorShape& shape); void Squeeze(); + void Squeeze(int dim); Tensor Slice(int start, int end); Tensor Slice(int index) { return Slice(index, index + 1); }