Skip to content

Commit

Permalink
Static Analysis corrections on DeformConv (#2885)
Browse files Browse the repository at this point in the history
* Convert to const reference and eliminate unnecessary bool casting.

* Removing unnecessary namespace use.
  • Loading branch information
datumbox authored Oct 26, 2020
1 parent cffac64 commit 5cb77a2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
8 changes: 3 additions & 5 deletions torchvision/csrc/cpu/DeformConv_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@
#include <iostream>
#include <tuple>

using namespace at;

const int kMaxParallelImgs = 32;

template <typename scalar_t>
Expand Down Expand Up @@ -597,7 +595,7 @@ static void deformable_col2im_coord_kernel(
out_w;

const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
const int is_y_direction = offset_c % 2 == 0;
const bool is_y_direction = offset_c % 2 == 0;

const int c_bound = c_per_offset_grp * weight_h * weight_w;
for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
Expand Down Expand Up @@ -812,9 +810,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(

static at::Tensor deform_conv2d_backward_parameters_cpu(
at::Tensor input,
at::Tensor weight,
const at::Tensor& weight,
at::Tensor offset,
at::Tensor grad_out,
const at::Tensor& grad_out,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
Expand Down
8 changes: 3 additions & 5 deletions torchvision/csrc/cuda/DeformConv_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@
#include <iostream>
#include <tuple>

using namespace at;

const unsigned int CUDA_NUM_THREADS = 1024;
const int kMaxParallelImgs = 32;

Expand Down Expand Up @@ -618,7 +616,7 @@ __global__ void deformable_col2im_coord_gpu_kernel(
out_h * out_w;
const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
const int is_y_direction = offset_c % 2 == 0;
const bool is_y_direction = offset_c % 2 == 0;
const int c_bound = c_per_offset_grp * weight_h * weight_w;
for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
Expand Down Expand Up @@ -840,9 +838,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
static at::Tensor deform_conv_backward_parameters_cuda(
at::Tensor input,
at::Tensor weight,
const at::Tensor& weight,
at::Tensor offset,
at::Tensor grad_out,
const at::Tensor& grad_out,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
Expand Down

0 comments on commit 5cb77a2

Please sign in to comment.