Allocate a big output tensor and split in group_index_select_dim0_backward #1764

Closed · wants to merge 1 commit
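In the old backward pass, the code calls at::zeros once per group member, which launches one small elementwise fill kernel per tensor. This PR instead allocates one flat zero tensor for the whole group and split()s it into per-member views. Below is a minimal sketch of that pattern, assuming per-member shapes are known up front; the helper name and signature are illustrative, not FBGEMM's API.

#include <ATen/ATen.h>
#include <cstdint>
#include <vector>

std::vector<at::Tensor> allocate_group_outputs(
    const std::vector<std::vector<int64_t>>& shapes,
    const at::TensorOptions& options) {
  std::vector<int64_t> numels;
  numels.reserve(shapes.size());
  int64_t total = 0;
  for (const auto& shape : shapes) {
    int64_t numel = 1;
    for (const int64_t dim : shape) {
      numel *= dim;
    }
    numels.push_back(numel);
    total += numel;
  }
  // One allocation and one zero-fill kernel for the whole group.
  at::Tensor buffer = at::zeros({total}, options);
  // split() yields contiguous views into `buffer`; give each its real shape.
  std::vector<at::Tensor> outputs = buffer.split(numels, /*dim=*/0);
  for (size_t i = 0; i < outputs.size(); ++i) {
    outputs[i] = outputs[i].reshape(shapes[i]);
  }
  return outputs;
}

Because split() returns views, a backward kernel can still write each gradient through its own raw pointer while only one allocation and one fill kernel are issued for the whole group.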
44 changes: 33 additions & 11 deletions fbgemm_gpu/src/sparse_ops_gpu.cpp
@@ -350,7 +350,8 @@ class GroupIndexSelectDim0GPUOp
   // Transfer args tensor to GPU
   args_tensor = args_tensor.to(first_input.device(), /*non_blocking=*/true);

-  TORCH_CHECK(group_size * input_dim == (int)input_shape_group.size())
+  TORCH_CHECK(
+      static_cast<size_t>(group_size * input_dim) == input_shape_group.size())

   struct GroupIndexSelectArgs* gpu_args =
       static_cast<struct GroupIndexSelectArgs*>(args_tensor.data_ptr());
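The first hunk only tightens a comparison: group_size * input_dim is a signed int expression while std::vector::size() returns size_t, so the old (int) cast narrowed the unsigned side instead of widening the signed one. A hypothetical reduction of the fix, not code from the PR:

#include <cstddef>
#include <cstdint>
#include <vector>

bool shape_count_matches(
    int group_size, int input_dim, const std::vector<int64_t>& shape_group) {
  // Casting the signed product to size_t avoids -Wsign-compare and never
  // truncates the container size, unlike casting size() down to int.
  return static_cast<std::size_t>(group_size * input_dim) ==
      shape_group.size();
}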
@@ -421,30 +422,51 @@ class GroupIndexSelectDim0GPUOp
   Tensor first_indices = *saved_itr;
   Tensor fwd_input = *(saved_itr + 1);

-  std::vector<Tensor> output_group;
-  output_group.reserve(group_size + 1);
-  output_group.push_back(torch::autograd::Variable());
-
   Tensor args_tensor = at::empty(
       {group_size * 2},
       at::TensorOptions().dtype(at::kLong).pinned_memory(true));
   int64_t* grad_output_ptrs = args_tensor.data_ptr<int64_t>();
   int64_t* grad_input_ptrs = args_tensor.data_ptr<int64_t>() + group_size;
+  int64_t group_grad_input_numel = 0;
+
+  std::vector<int64_t> grad_input_numels;
+  grad_input_numels.reserve(group_size + 1);
+  grad_input_numels.push_back(0); // indices_group

   for (const auto i : c10::irange(group_size)) {
     Tensor& grad = grad_output_group[i];
     TENSOR_ON_CUDA_GPU(grad);
     TENSORS_ON_SAME_DEVICE(grad, first_indices);

-    auto grad_input_shape = std::vector<int64_t>(
-        output_shape_group.begin() + i * output_dim,
-        output_shape_group.begin() + (i + 1) * output_dim);
-    Tensor grad_input = at::zeros(grad_input_shape, fwd_input.options());
-    output_group.push_back(grad_input);
+    // Compute the total number of elements for all grad_inputs
+    int64_t grad_input_numel = output_shape_group[i * output_dim];
+    for (auto j = (i * output_dim) + 1; j < (i + 1) * output_dim; j++) {
+      grad_input_numel *= output_shape_group[j];
+    }
+    grad_input_numels.push_back(grad_input_numel);
+    group_grad_input_numel += grad_input_numel;

     // Put all grad output/input pointers in an array
     grad_output_ptrs[i] = reinterpret_cast<int64_t>(grad.data_ptr());
-    grad_input_ptrs[i] = reinterpret_cast<int64_t>(grad_input.data_ptr());
   }

+  // Allocate a big tensor to avoid calling many small elementwise kernels
+  const auto group_grad_input =
+      at::zeros({group_grad_input_numel}, fwd_input.options());
+  auto output_group = group_grad_input.split(grad_input_numels, 0);
+  TORCH_CHECK(output_group.size() == static_cast<size_t>(group_size) + 1);
+  output_group[0] = torch::autograd::Variable();
+
+  // Reshape grad inputs and obtain their pointers
+  for (int i = 0; i < group_size; i++) {
+    const auto grad_input_shape = std::vector<int64_t>(
+        output_shape_group.begin() + i * output_dim,
+        output_shape_group.begin() + (i + 1) * output_dim);
+    output_group[i + 1] = output_group[i + 1].reshape(grad_input_shape);
+    grad_input_ptrs[i] =
+        reinterpret_cast<int64_t>(output_group[i + 1].data_ptr());
+  }

   // Transfer grad output pointers to GPU
   args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true);
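Two split() behaviors the rewritten backward leans on: a requested chunk of size 0 (the grad_input_numels.push_back(0) slot for the non-differentiable indices) comes back as an empty tensor, and every chunk is a view into the big buffer, so writes through the per-member pointers land in the single allocation. A standalone check of both, as a sketch rather than anything from the PR:

#include <ATen/ATen.h>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  at::Tensor buffer = at::zeros({6});
  std::vector<int64_t> sizes = {0, 2, 4}; // slot 0 reserved for indices
  auto parts = buffer.split(sizes, /*dim=*/0);
  std::cout << parts[0].numel() << " " << parts[1].numel() << " "
            << parts[2].numel() << "\n"; // prints: 0 2 4
  parts[2].fill_(1.0); // write through the view...
  std::cout << buffer.sum().item<float>() << "\n"; // ...prints: 4
  return 0;
}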