Skip to content

Commit

Permalink
Allocate a big output tensor and split in group_index_select_dim0_bac…
Browse files Browse the repository at this point in the history
…kward

Summary:
Before this diff, `group_index_select_dim0` backward calls `at::zeros`
`group_size` number of times which launches `group_size` elementwise
kernels.  Since `group_size` can be a large value (up to 55), this can
be costly.

This diff fixes the problem by allocating one big tensor and splitting
it into smaller tensors.  This will launch only one elementwise kernel
per group.  However, this can cause higher overhead on the host side.

Differential Revision: D45823864

fbshipit-source-id: f127b82bea6e49d4373bedf6c7307635161db87a
  • Loading branch information
sryap authored and facebook-github-bot committed May 12, 2023
1 parent 36b0d18 commit e477fe9
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions fbgemm_gpu/src/sparse_ops_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ class GroupIndexSelectDim0GPUOp
// Transfer args tensor to GPU
args_tensor = args_tensor.to(first_input.device(), /*non_blocking=*/true);

TORCH_CHECK(group_size * input_dim == (int)input_shape_group.size())
TORCH_CHECK(
static_cast<size_t>(group_size * input_dim) == input_shape_group.size())

struct GroupIndexSelectArgs* gpu_args =
static_cast<struct GroupIndexSelectArgs*>(args_tensor.data_ptr());
Expand Down Expand Up @@ -425,30 +426,51 @@ class GroupIndexSelectDim0GPUOp
Tensor first_indices = *saved_itr;
Tensor fwd_input = *(saved_itr + 1);

std::vector<Tensor> output_group;
output_group.reserve(group_size + 1);
output_group.push_back(torch::autograd::Variable());

Tensor args_tensor = at::empty(
{group_size * 2},
at::TensorOptions().dtype(at::kLong).pinned_memory(true));
int64_t* grad_output_ptrs = args_tensor.data_ptr<int64_t>();
int64_t* grad_input_ptrs = args_tensor.data_ptr<int64_t>() + group_size;
int64_t group_grad_input_numel = 0;

std::vector<int64_t> grad_input_numels;
grad_input_numels.reserve(group_size + 1);
grad_input_numels.push_back(0); // indices_group

for (int i = 0; i < group_size; i++) {
Tensor& grad = grad_output_group[i];
TENSOR_ON_CUDA_GPU(grad);
TENSORS_ON_SAME_DEVICE(grad, first_indices);

auto grad_input_shape = std::vector<int64_t>(
output_shape_group.begin() + i * output_dim,
output_shape_group.begin() + (i + 1) * output_dim);
Tensor grad_input = at::zeros(grad_input_shape, fwd_input.options());
output_group.push_back(grad_input);
// Compute the total number of elements for all grad_inputs
int64_t grad_input_numel = output_shape_group[i * output_dim];
for (auto j = (i * output_dim) + 1; j < (i + 1) * output_dim; j++) {
grad_input_numel *= output_shape_group[j];
}
grad_input_numels.push_back(grad_input_numel);
group_grad_input_numel += grad_input_numel;

// Put all grad output/input pointers in an array
grad_output_ptrs[i] = reinterpret_cast<int64_t>(grad.data_ptr());
grad_input_ptrs[i] = reinterpret_cast<int64_t>(grad_input.data_ptr());
}

// Allocate a big tensor to avoid calling many small elementwise kernels
const auto group_grad_input =
at::zeros({group_grad_input_numel}, fwd_input.options());
auto output_group = group_grad_input.split(grad_input_numels, 0);
TORCH_CHECK(output_group.size() == static_cast<size_t>(group_size) + 1);
output_group[0] = torch::autograd::Variable();

// Reshape grad inputs and obtain their pointers
for (int i = 0; i < group_size; i++) {
const auto grad_input_shape = std::vector<int64_t>(
output_shape_group.begin() + i * output_dim,
output_shape_group.begin() + (i + 1) * output_dim);
output_group[i + 1] = output_group[i + 1].reshape(grad_input_shape);
grad_input_ptrs[i] =
reinterpret_cast<int64_t>(output_group[i + 1].data_ptr());
}

// Transfer grad output pointers to GPU
args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true);

Expand Down

0 comments on commit e477fe9

Please sign in to comment.