Skip to content

Commit

Permalink
Add Torch Selective macros in all C++ Ops for better support on mobile (
Browse files Browse the repository at this point in the history
#3218)

* Adding TORCH_SELECTIVE_* macros on op registration.

* Adding torchvision namespace.
  • Loading branch information
datumbox authored Jan 4, 2021
1 parent 4d2d8bb commit 3711754
Show file tree
Hide file tree
Showing 29 changed files with 130 additions and 59 deletions.
4 changes: 3 additions & 1 deletion torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ at::Tensor deform_conv2d_autocast(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("deform_conv2d", deform_conv2d_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_autocast));
}

} // namespace ops
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/ops/autocast/nms_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ at::Tensor nms_autocast(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("nms", nms_autocast);
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast));
}

} // namespace ops
Expand Down
4 changes: 3 additions & 1 deletion torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("ps_roi_align", ps_roi_align_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_autocast));
}

} // namespace ops
Expand Down
4 changes: 3 additions & 1 deletion torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("ps_roi_pool", ps_roi_pool_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_autocast));
}

} // namespace ops
Expand Down
4 changes: 3 additions & 1 deletion torchvision/csrc/ops/autocast/roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ at::Tensor roi_align_autocast(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_align", roi_align_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_autocast));
}

} // namespace ops
Expand Down
4 changes: 3 additions & 1 deletion torchvision/csrc/ops/autocast/roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_pool", roi_pool_autocast);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_autocast));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,12 @@ deform_conv2d_backward_autograd(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("deform_conv2d", deform_conv2d_autograd);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
TORCH_FN(deform_conv2d_backward_autograd));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,12 @@ at::Tensor ps_roi_align_backward_autograd(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_align", ps_roi_align_autograd);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
TORCH_FN(ps_roi_align_backward_autograd));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,12 @@ at::Tensor ps_roi_pool_backward_autograd(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_pool", ps_roi_pool_autograd);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"),
TORCH_FN(ps_roi_pool_backward_autograd));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/autograd/roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,12 @@ at::Tensor roi_align_backward_autograd(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_align", roi_align_autograd);
m.impl("_roi_align_backward", roi_align_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
TORCH_FN(roi_align_backward_autograd));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/autograd/roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,12 @@ at::Tensor roi_pool_backward_autograd(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_pool", roi_pool_autograd);
m.impl("_roi_pool_backward", roi_pool_backward_autograd);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_autograd));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
TORCH_FN(roi_pool_backward_autograd));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1143,8 +1143,12 @@ deform_conv2d_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("deform_conv2d", deform_conv2d_forward_kernel);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
TORCH_FN(deform_conv2d_backward_kernel));
}

} // namespace ops
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/ops/cpu/nms_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ at::Tensor nms_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("nms", nms_kernel);
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cpu/ps_roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,12 @@ at::Tensor ps_roi_align_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("ps_roi_align", ps_roi_align_forward_kernel);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
TORCH_FN(ps_roi_align_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cpu/ps_roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,12 @@ at::Tensor ps_roi_pool_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("ps_roi_pool", ps_roi_pool_forward_kernel);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"),
TORCH_FN(ps_roi_pool_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cpu/roi_align_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,8 +500,12 @@ at::Tensor roi_align_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("roi_align", roi_align_forward_kernel);
m.impl("_roi_align_backward", roi_align_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
TORCH_FN(roi_align_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cpu/roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,12 @@ at::Tensor roi_pool_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("roi_pool", roi_pool_forward_kernel);
m.impl("_roi_pool_backward", roi_pool_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
TORCH_FN(roi_pool_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1189,8 +1189,12 @@ deform_conv2d_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("deform_conv2d", deform_conv2d_forward_kernel);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_deform_conv2d_backward"),
TORCH_FN(deform_conv2d_backward_kernel));
}

} // namespace ops
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/ops/cuda/nms_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ at::Tensor nms_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("nms", nms_kernel);
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,12 @@ at::Tensor ps_roi_align_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("ps_roi_align", ps_roi_align_forward_kernel);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
TORCH_FN(ps_roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
TORCH_FN(ps_roi_align_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,12 @@ at::Tensor ps_roi_pool_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("ps_roi_pool", ps_roi_pool_forward_kernel);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
TORCH_FN(ps_roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"),
TORCH_FN(ps_roi_pool_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cuda/roi_align_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,12 @@ at::Tensor roi_align_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("roi_align", roi_align_forward_kernel);
m.impl("_roi_align_backward", roi_align_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
TORCH_FN(roi_align_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 6 additions & 2 deletions torchvision/csrc/ops/cuda/roi_pool_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,12 @@ at::Tensor roi_pool_backward_kernel(
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("roi_pool", roi_pool_forward_kernel);
m.impl("_roi_pool_backward", roi_pool_backward_kernel);
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
TORCH_FN(roi_pool_backward_kernel));
}

} // namespace ops
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/ops/deform_conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ _deform_conv2d_backward(
} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
m.def(
"_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"));
}

} // namespace ops
Expand Down
3 changes: 2 additions & 1 deletion torchvision/csrc/ops/nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ at::Tensor nms(
}

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"));
}

} // namespace ops
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/ops/ps_roi_align.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ at::Tensor _ps_roi_align_backward(
} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
m.def(
"_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"));
}

} // namespace ops
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/ops/ps_roi_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ at::Tensor _ps_roi_pool_backward(
} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
m.def(
"_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"));
}

} // namespace ops
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/ops/roi_align.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ at::Tensor _roi_align_backward(
} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
m.def(
"_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"));
}

} // namespace ops
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/ops/roi_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ at::Tensor _roi_pool_backward(
} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
m.def(
"_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor"));
}

} // namespace ops
Expand Down

0 comments on commit 3711754

Please sign in to comment.