diff --git a/torchvision/csrc/cpu/DeformConv_cpu.cpp b/torchvision/csrc/cpu/DeformConv_cpu.cpp index a50eec3b76f..c1580f7228a 100644 --- a/torchvision/csrc/cpu/DeformConv_cpu.cpp +++ b/torchvision/csrc/cpu/DeformConv_cpu.cpp @@ -74,8 +74,6 @@ #include #include -using namespace at; - const int kMaxParallelImgs = 32; template @@ -597,7 +595,7 @@ static void deformable_col2im_coord_kernel( out_w; const int offset_c = c - offset_grp * 2 * weight_h * weight_w; - const int is_y_direction = offset_c % 2 == 0; + const bool is_y_direction = offset_c % 2 == 0; const int c_bound = c_per_offset_grp * weight_h * weight_w; for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { @@ -812,9 +810,9 @@ static std::tuple deform_conv2d_backward_input_cpu( static at::Tensor deform_conv2d_backward_parameters_cpu( at::Tensor input, - at::Tensor weight, + const at::Tensor& weight, at::Tensor offset, - at::Tensor grad_out, + const at::Tensor& grad_out, std::pair stride, std::pair pad, std::pair dilation, diff --git a/torchvision/csrc/cuda/DeformConv_cuda.cu b/torchvision/csrc/cuda/DeformConv_cuda.cu index 7a45c6b5c45..89516ae8454 100644 --- a/torchvision/csrc/cuda/DeformConv_cuda.cu +++ b/torchvision/csrc/cuda/DeformConv_cuda.cu @@ -78,8 +78,6 @@ #include #include -using namespace at; - const unsigned int CUDA_NUM_THREADS = 1024; const int kMaxParallelImgs = 32; @@ -618,7 +616,7 @@ __global__ void deformable_col2im_coord_gpu_kernel( out_h * out_w; const int offset_c = c - offset_grp * 2 * weight_h * weight_w; - const int is_y_direction = offset_c % 2 == 0; + const bool is_y_direction = offset_c % 2 == 0; const int c_bound = c_per_offset_grp * weight_h * weight_w; for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) { @@ -840,9 +838,9 @@ static std::tuple deform_conv_backward_input_cuda( static at::Tensor deform_conv_backward_parameters_cuda( at::Tensor input, - at::Tensor weight, + const at::Tensor& weight, at::Tensor offset, - at::Tensor grad_out, + const at::Tensor& grad_out, std::pair stride, std::pair pad, std::pair dilation,