Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Torch Selective macros in all C++ Ops for better support on mobile #3218

Merged
merged 2 commits into from
Jan 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ at::Tensor deform_conv2d_autocast(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("deform_conv2d", deform_conv2d_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_autocast));
}

} // namespace ops
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/ops/autocast/nms_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ at::Tensor nms_autocast(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("nms", nms_autocast);
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast));
}

} // namespace ops
Expand Down
4 changes: 3 additions & 1 deletion torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("ps_roi_align", ps_roi_align_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_autocast));
}

} // namespace ops
Expand Down
4 changes: 3 additions & 1 deletion torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("ps_roi_pool", ps_roi_pool_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_autocast));
}

} // namespace ops
Expand Down
4 changes: 3 additions & 1 deletion torchvision/csrc/ops/autocast/roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ at::Tensor roi_align_autocast(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_align", roi_align_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_autocast));
}

} // namespace ops
Expand Down
4 changes: 3 additions & 1 deletion torchvision/csrc/ops/autocast/roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_pool", roi_pool_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_autocast));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,12 @@ deform_conv2d_backward_autograd(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("deform_conv2d", deform_conv2d_autograd);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
TORCH_FN(deform_conv2d_backward_autograd));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,12 @@ at::Tensor ps_roi_align_backward_autograd(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_align", ps_roi_align_autograd);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
TORCH_FN(ps_roi_align_backward_autograd));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,12 @@ at::Tensor ps_roi_pool_backward_autograd(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_pool", ps_roi_pool_autograd);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"),
TORCH_FN(ps_roi_pool_backward_autograd));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/autograd/roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,12 @@ at::Tensor roi_align_backward_autograd(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_align", roi_align_autograd);
m.impl("_roi_align_backward", roi_align_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
TORCH_FN(roi_align_backward_autograd));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/autograd/roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,12 @@ at::Tensor roi_pool_backward_autograd(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_pool", roi_pool_autograd);
m.impl("_roi_pool_backward", roi_pool_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
TORCH_FN(roi_pool_backward_autograd));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1143,8 +1143,12 @@ deform_conv2d_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("deform_conv2d", deform_conv2d_forward_kernel);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
TORCH_FN(deform_conv2d_backward_kernel));
}

} // namespace ops
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/ops/cpu/nms_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ at::Tensor nms_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("nms", nms_kernel);
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cpu/ps_roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,12 @@ at::Tensor ps_roi_align_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("ps_roi_align", ps_roi_align_forward_kernel);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
TORCH_FN(ps_roi_align_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cpu/ps_roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,12 @@ at::Tensor ps_roi_pool_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("ps_roi_pool", ps_roi_pool_forward_kernel);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"),
TORCH_FN(ps_roi_pool_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cpu/roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,8 +500,12 @@ at::Tensor roi_align_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("roi_align", roi_align_forward_kernel);
m.impl("_roi_align_backward", roi_align_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
TORCH_FN(roi_align_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cpu/roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,12 @@ at::Tensor roi_pool_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("roi_pool", roi_pool_forward_kernel);
m.impl("_roi_pool_backward", roi_pool_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
TORCH_FN(roi_pool_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1189,8 +1189,12 @@ deform_conv2d_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("deform_conv2d", deform_conv2d_forward_kernel);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
TORCH_FN(deform_conv2d_backward_kernel));
}

} // namespace ops
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/ops/cuda/nms_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ at::Tensor nms_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("nms", nms_kernel);
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,12 @@ at::Tensor ps_roi_align_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("ps_roi_align", ps_roi_align_forward_kernel);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
TORCH_FN(ps_roi_align_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,12 @@ at::Tensor ps_roi_pool_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("ps_roi_pool", ps_roi_pool_forward_kernel);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"),
TORCH_FN(ps_roi_pool_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cuda/roi_align_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,12 @@ at::Tensor roi_align_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("roi_align", roi_align_forward_kernel);
m.impl("_roi_align_backward", roi_align_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
TORCH_FN(roi_align_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cuda/roi_pool_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,12 @@ at::Tensor roi_pool_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("roi_pool", roi_pool_forward_kernel);
m.impl("_roi_pool_backward", roi_pool_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
TORCH_FN(roi_pool_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/ops/deform_conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ _deform_conv2d_backward(
} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
m.def(
"_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"));
}

} // namespace ops
Expand Down
3 changes: 2 additions & 1 deletion torchvision/csrc/ops/nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ at::Tensor nms(
}

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"));
}

} // namespace ops
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/ops/ps_roi_align.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ at::Tensor _ps_roi_align_backward(
} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
m.def(
"_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"));
}

} // namespace ops
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/ops/ps_roi_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ at::Tensor _ps_roi_pool_backward(
} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
m.def(
"_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"));
}

} // namespace ops
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/ops/roi_align.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ at::Tensor _roi_align_backward(
} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
m.def(
"_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"));
}

} // namespace ops
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/ops/roi_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ at::Tensor _roi_pool_backward(
} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
m.def(
"_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"));
}

} // namespace ops
Expand Down