Skip to content

Commit

Permalink
add dim param for Tensor::Squeeze (#603)
Browse files Browse the repository at this point in the history
  • Loading branch information
lzhangzz authored Jun 17, 2022
1 parent a822ba7 commit ac0b52f
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 8 deletions.
4 changes: 3 additions & 1 deletion csrc/mmdeploy/codebase/mmocr/dbnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ class DBHead : public MMOCR {
return Status(eNotSupported);
}

conf.Squeeze();
// drop batch dimension
conf.Squeeze(0);

conf = conf.Slice(0);

std::vector<std::vector<cv::Point>> contours;
Expand Down
4 changes: 3 additions & 1 deletion csrc/mmdeploy/codebase/mmocr/panet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ class PANHead : public MMOCR {
(int)pred.data_type());
return Status(eNotSupported);
}
pred.Squeeze();

// drop batch dimension
pred.Squeeze(0);

auto text_pred = pred.Slice(0);
auto kernel_pred = pred.Slice(1);
Expand Down
2 changes: 1 addition & 1 deletion csrc/mmdeploy/codebase/mmocr/psenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class PSEHead : public MMOCR {
}

// drop batch dimension
_preds.Squeeze();
_preds.Squeeze(0);

cv::Mat_<uint8_t> masks;
cv::Mat_<int> kernel_labels;
Expand Down
12 changes: 7 additions & 5 deletions csrc/mmdeploy/core/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,13 @@ void Tensor::Reshape(const TensorShape& shape) {
}

void Tensor::Squeeze() {
TensorShape new_shape;
new_shape.reserve(shape().size());
std::copy_if(begin(shape()), end(shape()), std::back_inserter(new_shape),
[](int64_t dim) { return dim != 1; });
Reshape(new_shape);
desc_.shape.erase(std::remove(desc_.shape.begin(), desc_.shape.end(), 1), desc_.shape.end());
}

void Tensor::Squeeze(int dim) {
if (shape(dim) == 1) {
desc_.shape.erase(desc_.shape.begin() + dim);
}
}

Result<void> Tensor::CopyFrom(const Tensor& tensor, Stream stream) {
Expand Down
1 change: 1 addition & 0 deletions csrc/mmdeploy/core/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class MMDEPLOY_API Tensor {
void Reshape(const TensorShape& shape);

void Squeeze();
void Squeeze(int dim);

Tensor Slice(int start, int end);
Tensor Slice(int index) { return Slice(index, index + 1); }
Expand Down

0 comments on commit ac0b52f

Please sign in to comment.