diff --git a/csrc/mmdeploy/preprocess/elena/collect_impl.cpp b/csrc/mmdeploy/preprocess/elena/collect_impl.cpp index 379f661efe..1c6a8b7742 100644 --- a/csrc/mmdeploy/preprocess/elena/collect_impl.cpp +++ b/csrc/mmdeploy/preprocess/elena/collect_impl.cpp @@ -117,7 +117,8 @@ class CollectImpl : public ::mmdeploy::CollectImpl { visitor.crop_hw[0], visitor.crop_hw[1], visitor.mean[0], visitor.mean[1], visitor.mean[2], visitor.std[0], visitor.std[1], visitor.std[2], visitor.pad_tlbr[0], visitor.pad_tlbr[1], visitor.pad_tlbr[2], visitor.pad_tlbr[3], visitor.pad_hw[0], - visitor.pad_hw[1], visitor.pad_val, dst_tensor.data(), dst_tensor.size()); + visitor.pad_hw[1], visitor.pad_val, dst_tensor.data(), dst_tensor.shape(2), + dst_tensor.shape(3)); output[key] = std::move(dst_tensor); } return ::mmdeploy::CollectImpl::Process(output); diff --git a/csrc/mmdeploy/preprocess/elena/elena_registry.h b/csrc/mmdeploy/preprocess/elena/elena_registry.h index ae71061685..6eda61d8cd 100644 --- a/csrc/mmdeploy/preprocess/elena/elena_registry.h +++ b/csrc/mmdeploy/preprocess/elena/elena_registry.h @@ -16,7 +16,7 @@ using FuseFunc = void (*)(void* stream, uint8_t* data_in, int src_h, int src_w, int crop_left, int crop_h, int crop_w, float mean0, float mean1, float mean2, float std0, float std1, float std2, int pad_top, int pad_left, int pad_bottom, int pad_right, int pad_h, int pad_w, - float pad_value, float* data_out, int data_out_num); + float pad_value, float* data_out, int dst_h, int dst_w); class MMDEPLOY_API FuseKernel { public: diff --git a/third_party/CVFusion b/third_party/CVFusion index 4230f28f1c..e69a8c04ad 160000 --- a/third_party/CVFusion +++ b/third_party/CVFusion @@ -1 +1 @@ -Subproject commit 4230f28f1cb05a2552a1d6b30fd2dc793c14d015 +Subproject commit e69a8c04ada66652951ad6346826fcb473adc368 diff --git a/tools/elena/extract_transform.py b/tools/elena/extract_transform.py index 143695d523..cf0b93e15e 100644 --- a/tools/elena/extract_transform.py +++ b/tools/elena/extract_transform.py @@ -46,7 +46,7 @@ int resize_h, int resize_w, const char* interpolation, int crop_top, int crop_left, int crop_h, int crop_w, float mean0, float mean1, float mean2, float std0, float std1, float std2, int pad_top, int pad_left, int pad_bottom, int pad_right, int pad_h, - int pad_w, float pad_value, float* data_out, int data_out_num) { + int pad_w, float pad_value, float* data_out, int dst_h, int dst_w) { const char* interpolation_ = "nearest"; if (strcmp(interpolation, "bilinear") == 0) { interpolation_ = "bilinear"; @@ -65,7 +65,7 @@ int resize_h, int resize_w, const char* interpolation, int crop_top, int crop_left, int crop_h, int crop_w, float mean0, float mean1, float mean2, float std0, float std1, float std2, int pad_top, int pad_left, int pad_bottom, int pad_right, int pad_h, - int pad_w, float pad_value, float* data_out, int data_out_num) { + int pad_w, float pad_value, float* data_out, int dst_h, int dst_w) { cudaStream_t stream_ = (cudaStream_t)stream; const char* interpolation_ = "nearest"; if (strcmp(interpolation, "bilinear") == 0) { @@ -73,8 +73,8 @@ } FuseKernelCU(stream_, resize_h, resize_w, crop_h, crop_w, crop_top, crop_left, mean0, mean1, mean2, std0, - std1, std2, pad_top, pad_left, pad_bottom, pad_right, pad_h, pad_w, pad_value, data_in, - data_out, data_out_num, src_h, src_w, format, interpolation_); + std1, std2, pad_h, pad_w, pad_top, pad_left, pad_bottom, pad_right, pad_value, data_in, + data_out, dst_h, dst_w, src_h, src_w, format, interpolation_); } REGISTER_FUSE_KERNEL(#TAG#_cuda, "#TAG#_cuda",