Skip to content

Commit

Permalink
update elena
Browse files Browse the repository at this point in the history
  • Loading branch information
irexyc committed Aug 1, 2022
1 parent 5debb37 commit 6258a9c
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
3 changes: 2 additions & 1 deletion csrc/mmdeploy/preprocess/elena/collect_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ class CollectImpl : public ::mmdeploy::CollectImpl {
visitor.crop_hw[0], visitor.crop_hw[1], visitor.mean[0], visitor.mean[1],
visitor.mean[2], visitor.std[0], visitor.std[1], visitor.std[2], visitor.pad_tlbr[0],
visitor.pad_tlbr[1], visitor.pad_tlbr[2], visitor.pad_tlbr[3], visitor.pad_hw[0],
visitor.pad_hw[1], visitor.pad_val, dst_tensor.data<float>(), dst_tensor.size());
visitor.pad_hw[1], visitor.pad_val, dst_tensor.data<float>(), dst_tensor.shape(2),
dst_tensor.shape(3));
output[key] = std::move(dst_tensor);
}
return ::mmdeploy::CollectImpl::Process(output);
Expand Down
2 changes: 1 addition & 1 deletion csrc/mmdeploy/preprocess/elena/elena_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using FuseFunc = void (*)(void* stream, uint8_t* data_in, int src_h, int src_w,
int crop_left, int crop_h, int crop_w, float mean0, float mean1,
float mean2, float std0, float std1, float std2, int pad_top,
int pad_left, int pad_bottom, int pad_right, int pad_h, int pad_w,
float pad_value, float* data_out, int data_out_num);
float pad_value, float* data_out, int dst_h, int dst_w);

class MMDEPLOY_API FuseKernel {
public:
Expand Down
2 changes: 1 addition & 1 deletion third_party/CVFusion
8 changes: 4 additions & 4 deletions tools/elena/extract_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
int resize_h, int resize_w, const char* interpolation, int crop_top, int crop_left,
int crop_h, int crop_w, float mean0, float mean1, float mean2, float std0, float std1,
float std2, int pad_top, int pad_left, int pad_bottom, int pad_right, int pad_h,
int pad_w, float pad_value, float* data_out, int data_out_num) {
int pad_w, float pad_value, float* data_out, int dst_h, int dst_w) {
const char* interpolation_ = "nearest";
if (strcmp(interpolation, "bilinear") == 0) {
interpolation_ = "bilinear";
Expand All @@ -65,16 +65,16 @@
int resize_h, int resize_w, const char* interpolation, int crop_top, int crop_left,
int crop_h, int crop_w, float mean0, float mean1, float mean2, float std0, float std1,
float std2, int pad_top, int pad_left, int pad_bottom, int pad_right, int pad_h,
int pad_w, float pad_value, float* data_out, int data_out_num) {
int pad_w, float pad_value, float* data_out, int dst_h, int dst_w) {
cudaStream_t stream_ = (cudaStream_t)stream;
const char* interpolation_ = "nearest";
if (strcmp(interpolation, "bilinear") == 0) {
interpolation_ = "bilinear";
}
FuseKernelCU(stream_, resize_h, resize_w, crop_h, crop_w, crop_top, crop_left, mean0, mean1, mean2, std0,
std1, std2, pad_top, pad_left, pad_bottom, pad_right, pad_h, pad_w, pad_value, data_in,
data_out, data_out_num, src_h, src_w, format, interpolation_);
std1, std2, pad_h, pad_w, pad_top, pad_left, pad_bottom, pad_right, pad_value, data_in,
data_out, dst_h, dst_w, src_h, src_w, format, interpolation_);
}
REGISTER_FUSE_KERNEL(#TAG#_cuda, "#TAG#_cuda",
Expand Down

0 comments on commit 6258a9c

Please sign in to comment.