From 8372bc776ec77da5d51544bae543628c27902763 Mon Sep 17 00:00:00 2001
From: Supadchaya Puangpontip
Date: Thu, 20 Jul 2023 18:46:05 -0700
Subject: [PATCH] Use TensorList in group_index_select_dim0 (#1884)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1884

Replace the variadic template in `group_index_select_dim0` by packing all
tensors into a `TensorList` and passing that `TensorList` to the autograd
function. The returned output must contain as many tensors as the inputs
specified in the forward function and can be returned as a `variable_list`.

This improves performance and lifts the limitations of the variadic template,
which had to be instantiated for each group size and limited the maximum group
size to 55.

Reviewed By: sryap

Differential Revision: D47488358

fbshipit-source-id: aa2c23e1b038743d54e5878ff9e8eb6ccd27e852
---
 fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 306 +++++++++++--------
 1 file changed, 178 insertions(+), 128 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index cc613ba4d3..66daca781b 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -20,55 +20,37 @@ using Tensor = at::Tensor;
 namespace fbgemm_gpu {
 
 namespace {
-// From https://stackoverflow.com/a/28411055
-template <typename T, std::size_t... Indices>
-auto vec_to_tup_helper(
-    const std::vector<T>& v,
-    std::index_sequence<Indices...>) {
-  return std::make_tuple(v[Indices]...);
-}
-template <std::size_t N, typename T>
-auto vec_to_tup(const std::vector<T>& v) {
-  assert(v.size() >= N);
-  return vec_to_tup_helper(v, std::make_index_sequence<N>());
-}
+constexpr int32_t NUM_ARGS = 5;
+enum args_pos {
+  P_input_ptrs = 0,
+  P_output_ptrs = 1,
+  P_indices_ptrs = 2,
+  P_warp_offsets_group_ptrs = 3,
+  P_num_cols_group_ptrs = 4
+};
 
-template <typename F, typename T>
-void apply_(F fn, const std::vector<T>& v) {
-  auto size = v.size();
-#define APPLY_AUTOGRAD_FN_N(N)          \
-  {                                     \
-    case N:                             \
-      std::apply(fn, vec_to_tup<N>(v)); \
-      break;                            \
-  }
+template <typename T>
+int64_t compute_num_int64s(const int64_t num_elements) {
+  const int64_t ratio = sizeof(int64_t) / sizeof(T);
+  return (num_elements + ratio - 1) / ratio;
+}
 
-#define APPLY_AUTOGRAD_FN_2N(N) \
-  APPLY_AUTOGRAD_FN_N(N)        \
-  APPLY_AUTOGRAD_FN_N(N + 1)
-
-#define APPLY_AUTOGRAD_FN_6N(N) \
-  APPLY_AUTOGRAD_FN_2N(N)       \
-  APPLY_AUTOGRAD_FN_2N(N + 2)   \
-  APPLY_AUTOGRAD_FN_2N(N + 4)
-
-#define APPLY_AUTOGRAD_FN_18N(N) \
-  APPLY_AUTOGRAD_FN_6N(N)        \
-  APPLY_AUTOGRAD_FN_6N(N + 6)    \
-  APPLY_AUTOGRAD_FN_6N(N + 12)
-
-#define APPLY_AUTOGRAD_FN_54N(N) \
-  APPLY_AUTOGRAD_FN_18N(N)       \
-  APPLY_AUTOGRAD_FN_18N(N + 18)  \
-  APPLY_AUTOGRAD_FN_18N(N + 36)
-
-  switch (size) {
-    APPLY_AUTOGRAD_FN_54N(1)
-    default:
-      TORCH_CHECK(false, "size is not supported ", size)
-  }
-#undef APPLY_AUTOGRAD_FN
+// Compute offsets to set raw pointers
+void offset_args(
+    int64_t** input_ptrs,
+    int64_t** output_ptrs,
+    int64_t** indices_ptrs,
+    int64_t** warp_offsets_group,
+    int32_t** num_cols_group,
+    int64_t* base_addr,
+    const int64_t* const ptr_offsets) {
+  *input_ptrs = base_addr + ptr_offsets[P_input_ptrs];
+  *output_ptrs = base_addr + ptr_offsets[P_output_ptrs];
+  *indices_ptrs = base_addr + ptr_offsets[P_indices_ptrs];
+  *warp_offsets_group = base_addr + ptr_offsets[P_warp_offsets_group_ptrs];
+  *num_cols_group = reinterpret_cast<int32_t*>(
+      base_addr + ptr_offsets[P_num_cols_group_ptrs]);
 }
 
 } // namespace
@@ -250,34 +232,74 @@ class IndexSelectDim0GPUOp
 class GroupIndexSelectDim0GPUOp
     : public torch::autograd::Function<GroupIndexSelectDim0GPUOp> {
  public:
-  template <typename... Tensors>
   static torch::autograd::variable_list forward(
       torch::autograd::AutogradContext* ctx,
-      const std::vector<Tensor>& indices_group,
-      Tensors&&... input_tensors) {
-    std::vector<Tensor> input_group = {input_tensors...};
-    constexpr int group_size = sizeof...(Tensors);
+      const at::TensorList& all_indices_input,
+      const int group_size) {
+    // Unpack from TensorList
+    std::vector<Tensor> indices_group;
+    std::vector<Tensor> input_group;
 
-    if (group_size == 0) {
-      return {torch::autograd::Variable()};
+    indices_group.reserve(group_size);
+    input_group.reserve(group_size);
+
+    for (const auto i : c10::irange(group_size)) {
+      indices_group.push_back(all_indices_input[i]);
+      input_group.push_back(all_indices_input[group_size + i]);
     }
 
     TORCH_CHECK(group_size == static_cast<int64_t>(indices_group.size()));
 
-    struct GroupIndexSelectArgs {
-      int64_t input_ptrs[group_size];
-      int64_t output_ptrs[group_size];
-      int64_t indices_ptrs[group_size];
-      int64_t warp_offsets_group[group_size + 1];
-      int32_t num_cols_group[group_size];
-    };
+    // args_tensor stores kernel arguments:
+    // input_ptrs (group_size int64_t elements)
+    // output_ptrs (group_size int64_t elements)
+    // indices_ptrs (group_size int64_t elements)
+    // warp_offsets_group (group_size + 1 int64_t elements)
+    // num_cols_group (group_size int32_t elements)
+    int64_t args_ptrs_offsets[NUM_ARGS + 1];
+
+    const int64_t numels_num_cols_group_64 =
+        compute_num_int64s<int32_t>(group_size);
+
+    // Initialize offsets
+    args_ptrs_offsets[P_input_ptrs] = group_size;
+    args_ptrs_offsets[P_output_ptrs] = group_size;
+    args_ptrs_offsets[P_indices_ptrs] = group_size;
+    args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1;
+    args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64;
+
+    // Compute offsets
+    int64_t offset = 0;
+    auto next = args_ptrs_offsets[0];
+    for (const auto i : c10::irange(NUM_ARGS)) {
+      args_ptrs_offsets[i] = offset;
+      offset += next;
+      next = args_ptrs_offsets[i + 1];
+    }
+    // Total number of int64_t elements required
+    args_ptrs_offsets[NUM_ARGS] = offset;
 
     // Allocate memory for GroupIndexSelectArgs
-    Tensor args_tensor = at::empty(
-        {sizeof(GroupIndexSelectArgs)},
+    at::Tensor args_tensor = at::empty(
+        {static_cast<int64_t>(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))},
         at::TensorOptions().dtype(at::kByte).pinned_memory(true));
-    struct GroupIndexSelectArgs* args =
-        reinterpret_cast<struct GroupIndexSelectArgs*>(args_tensor.data_ptr());
+
+    // Initialize raw pointers to point to Tensor args_tensor
+    int64_t* input_ptrs = nullptr;
+    int64_t* output_ptrs = nullptr;
+    int64_t* indices_ptrs = nullptr;
+    int64_t* warp_offsets_group = nullptr;
+    int32_t* num_cols_group = nullptr;
+
+    // Offset host pointers
+    offset_args(
+        &input_ptrs,
+        &output_ptrs,
+        &indices_ptrs,
+        &warp_offsets_group,
+        &num_cols_group,
+        reinterpret_cast<int64_t*>(args_tensor.data_ptr()),
+        args_ptrs_offsets);
 
     auto& first_input = input_group[0];
     auto& first_indices = indices_group[0];
@@ -291,10 +313,14 @@ class GroupIndexSelectDim0GPUOp
     int64_t warp_offset = 0;
     bool use_var_cols = false;
 
-    std::vector<Tensor> outputs;
-    outputs.reserve(group_size);
+    // Allocate memory for output_group
+    std::vector<Tensor> output_group;
+    output_group.reserve(group_size);
+
     std::vector<int64_t> input_shape_group;
     input_shape_group.reserve(group_size * input_dim);
+
+    // For each group, copy input to output
     for (const auto i : c10::irange(group_size)) {
       auto& input = input_group[i];
       auto& indices = indices_group[i];
@@ -334,27 +360,38 @@ class GroupIndexSelectDim0GPUOp
       // Create output pointers
       input_shape[0] = num_output_rows_;
       Tensor output = at::empty(input_shape, input.options());
-      outputs.push_back(output);
+      output_group.push_back(output);
 
       // Store args
-      args->input_ptrs[i] = reinterpret_cast<int64_t>(input.data_ptr());
-      args->output_ptrs[i] = reinterpret_cast<int64_t>(output.data_ptr());
-      args->indices_ptrs[i] = reinterpret_cast<int64_t>(indices.data_ptr());
-      args->warp_offsets_group[i] = warp_offset;
-      args->num_cols_group[i] = num_cols_;
+      input_ptrs[i] = reinterpret_cast<int64_t>(input.data_ptr());
+      output_ptrs[i] = reinterpret_cast<int64_t>(output.data_ptr());
+      indices_ptrs[i] = reinterpret_cast<int64_t>(indices.data_ptr());
+      warp_offsets_group[i] = warp_offset;
+      num_cols_group[i] = num_cols_;
 
       warp_offset += warps_per_row * num_output_rows;
     }
+
     // Store the last offset
-    args->warp_offsets_group[group_size] = warp_offset;
+    warp_offsets_group[group_size] = warp_offset;
+
     // Transfer args tensor to GPU
-    args_tensor = args_tensor.to(first_input.device(), /*non_blocking=*/true);
+    args_tensor = args_tensor.to(
+        first_input.device(),
+        /*non_blocking=*/true);
 
     TORCH_CHECK(
        static_cast<size_t>(group_size * input_dim) == input_shape_group.size())
 
-    struct GroupIndexSelectArgs* gpu_args =
-        static_cast<struct GroupIndexSelectArgs*>(args_tensor.data_ptr());
+    // Offset raw ptrs in GPU memory
+    offset_args(
+        &input_ptrs,
+        &output_ptrs,
+        &indices_ptrs,
+        &warp_offsets_group,
+        &num_cols_group,
+        reinterpret_cast<int64_t*>(args_tensor.data_ptr()),
+        args_ptrs_offsets);
 
     // Need to store args_tensor for backward to keep indices_ptrs alive
     ctx->save_for_backward({indices_group[0], input_group[0], args_tensor});
@@ -362,20 +399,19 @@ class GroupIndexSelectDim0GPUOp
     ctx->saved_data["input_shape_group"] = input_shape_group;
     ctx->saved_data["group_size"] = group_size;
     ctx->saved_data["use_var_cols"] = use_var_cols;
-    ctx->saved_data["indices_ptrs"] =
-        reinterpret_cast<int64_t>(gpu_args->indices_ptrs);
+    ctx->saved_data["indices_ptrs"] = reinterpret_cast<int64_t>(indices_ptrs);
     ctx->saved_data["warp_offsets_group"] =
-        reinterpret_cast<int64_t>(gpu_args->warp_offsets_group);
+        reinterpret_cast<int64_t>(warp_offsets_group);
     ctx->saved_data["num_cols_group"] =
-        reinterpret_cast<int64_t>(gpu_args->num_cols_group);
+        reinterpret_cast<int64_t>(num_cols_group);
     ctx->saved_data["total_num_warps"] = warp_offset;
 
     group_index_select_or_add_cuda(
-        gpu_args->input_ptrs,
-        gpu_args->output_ptrs,
-        gpu_args->indices_ptrs,
-        gpu_args->warp_offsets_group,
-        gpu_args->num_cols_group,
+        input_ptrs,
+        output_ptrs,
+        indices_ptrs,
+        warp_offsets_group,
+        num_cols_group,
         first_input.scalar_type(),
         first_indices.scalar_type(),
         first_input.device().index(),
@@ -386,18 +422,19 @@ class GroupIndexSelectDim0GPUOp
         /*use_index_select=*/true,
         use_var_cols);
 
-    return outputs;
+    return output_group;
   }
 
   static torch::autograd::variable_list backward(
       torch::autograd::AutogradContext* ctx,
       torch::autograd::variable_list grad_output_group) {
-    const int group_size = ctx->saved_data["group_size"].toInt();
+    const auto group_size = ctx->saved_data["group_size"].toInt();
 
     if (group_size == 0) {
      return torch::autograd::variable_list();
     }
 
+    // Retrieve saved data
     const int output_dim = ctx->saved_data["input_dim"].toInt();
     std::vector<int64_t> output_shape_group =
         ctx->saved_data["input_shape_group"].toIntVector();
@@ -410,6 +447,7 @@ class GroupIndexSelectDim0GPUOp
         reinterpret_cast<int32_t*>(ctx->saved_data["num_cols_group"].toInt());
     auto total_num_warps = ctx->saved_data["total_num_warps"].toInt();
 
+    // Check that the size is the same
     TORCH_CHECK(static_cast<int64_t>(grad_output_group.size()) == group_size);
 
     // We checked in forward that all output rows are the same for all member
@@ -419,19 +457,38 @@ class GroupIndexSelectDim0GPUOp
     const auto saved = ctx->get_saved_variables();
     const auto saved_itr = std::begin(saved);
+
+    // Retrieve first index group
     Tensor first_indices = *saved_itr;
+    // Retrieve first input group
     Tensor fwd_input = *(saved_itr + 1);
+
+    std::vector<Tensor> outputs;
+    // Returning 3 outputs:
+    // 1) group_size Variable()'s for indices
+    // 2) group_size gradients for inputs
+    // 3) 1 Variable() for group_size
+    outputs.reserve(group_size * 2 + 1);
+
+    // 1) Add group_size Variable()'s for indices
+    // c10::irange cannot be used in here as it
+    // triggers a build error of i being an unused variable
+    for (auto i = 0; i < group_size; i++) {
+      outputs.push_back(torch::autograd::Variable());
+    }
+
+    // Allocate Tensor for ptrs of grad output and input
     Tensor args_tensor = at::empty(
         {group_size * 2},
         at::TensorOptions().dtype(at::kLong).pinned_memory(true));
     int64_t* grad_output_ptrs = args_tensor.data_ptr<int64_t>();
     int64_t* grad_input_ptrs = args_tensor.data_ptr<int64_t>() + group_size;
-    int64_t group_grad_input_numel = 0;
-    std::vector<int64_t> grad_input_numels;
-    grad_input_numels.reserve(group_size + 1);
-    grad_input_numels.push_back(0); // indices_group
+
+    int64_t group_grad_input_numel = 0;
+    std::vector<int64_t> grad_input_numels;
+
+    // Reserve memory for grad input group
+    grad_input_numels.reserve(group_size);
 
     for (const auto i : c10::irange(group_size)) {
       Tensor& grad = grad_output_group[i];
@@ -453,20 +510,28 @@ class GroupIndexSelectDim0GPUOp
     // Allocate a big tensor to avoid calling many small elementwise kernels
     const auto group_grad_input =
         at::zeros({group_grad_input_numel}, fwd_input.options());
+
+    // Split to output_group
     auto output_group = group_grad_input.split(grad_input_numels, 0);
-    TORCH_CHECK(output_group.size() == static_cast<size_t>(group_size) + 1);
-    output_group[0] = torch::autograd::Variable();
+
+    TORCH_CHECK(output_group.size() == static_cast<size_t>(group_size));
 
     // Reshape grad inputs and obtain their pointers
     for (int i = 0; i < group_size; i++) {
       const auto grad_input_shape = std::vector<int64_t>(
           output_shape_group.begin() + i * output_dim,
          output_shape_group.begin() + (i + 1) * output_dim);
-      output_group[i + 1] = output_group[i + 1].reshape(grad_input_shape);
+      output_group[i] = output_group[i].reshape(grad_input_shape);
       grad_input_ptrs[i] =
-          reinterpret_cast<int64_t>(output_group[i + 1].data_ptr());
+          reinterpret_cast<int64_t>(output_group[i].data_ptr());
+
+      // 2) Add group_size gradients for inputs
+      outputs.push_back(output_group[i]);
     }
 
+    // 3) Add 1 Variable() for group_size
+    outputs.push_back(torch::autograd::Variable());
+
     // Transfer grad output pointers to GPU
     args_tensor =
         args_tensor.to(first_indices.device(), /*non_blocking=*/true);
@@ -486,7 +551,7 @@ class GroupIndexSelectDim0GPUOp
         /*use_index_select=*/false,
         use_var_cols);
 
-    return output_group;
+    return outputs;
   }
 };
 
@@ -518,44 +583,29 @@ Tensor index_select_dim0_gpu(
 std::vector<Tensor> group_index_select_dim0_gpu(
     const std::vector<Tensor>& input_group,
     const std::vector<Tensor>& indices_group) {
-  const auto group_size = input_group.size();
+  const auto group_size = indices_group.size();
   std::vector<Tensor> output_group;
 
-  // We use the APPLY_AUTOGRAD_FN macros to instantiate
-  // GroupIndexSelectDim0GPUOp for different group sizes. We only instantiate
-  // up to group size of 54.
-  constexpr size_t max_group_size = 54;
-  // Specialize this path to avoid copy
-  if (group_size <= max_group_size) {
-    apply_(
-        [&](auto&&... args) {
-          output_group =
-              GroupIndexSelectDim0GPUOp::apply(indices_group, args...);
-        },
-        input_group);
-    return output_group;
+
+  if (group_size == 0) {
+    return std::vector<Tensor>();
   }
 
-  const auto input_itr = input_group.begin();
-  const auto indices_itr = indices_group.begin();
-
-  for (size_t start = 0; start < group_size; start += max_group_size) {
-    const auto end = std::min(start + max_group_size, group_size);
-    std::vector<Tensor> input_subgroup(input_itr + start, input_itr + end);
-    std::vector<Tensor> indices_subgroup(
-        indices_itr + start, indices_itr + end);
-    std::vector<Tensor> output_subgroup;
-    apply_(
-        [&](auto&&... args) {
-          output_subgroup =
-              GroupIndexSelectDim0GPUOp::apply(indices_subgroup, args...);
-        },
-        input_subgroup);
-    output_group.insert(
-        output_group.end(), output_subgroup.begin(), output_subgroup.end());
+  // Pack input_group and indices_group into TensorList
+  std::vector<Tensor> all_indices_input_vec;
+  all_indices_input_vec.reserve(group_size * 2);
+
+  for (const Tensor& index : indices_group) {
+    all_indices_input_vec.push_back(index);
+  }
+  for (const Tensor& input : input_group) {
+    all_indices_input_vec.push_back(input);
   }
-  return output_group;
-}
+  at::TensorList all_indices_input_tensor = all_indices_input_vec;
+
+  return output_group = fbgemm_gpu::GroupIndexSelectDim0GPUOp::apply(
+      all_indices_input_tensor, group_size);
+}
 
 } // namespace fbgemm_gpu
 
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
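
A minimal, self-contained sketch of the calling convention the summary describes, separate from the patch itself: indices and inputs are packed into one at::TensorList so a single autograd Function specialization covers any group size, and backward returns one gradient slot per packed tensor plus one for the trailing non-tensor argument. All names here (ToyGroupIndexSelectDim0 and its helpers) are illustrative assumptions, not FBGEMM's actual classes.

// Sketch only: mirrors the TensorList packing pattern, not the FBGEMM kernels.
#include <torch/torch.h>

class ToyGroupIndexSelectDim0
    : public torch::autograd::Function<ToyGroupIndexSelectDim0> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      // Packed as [indices_0 .. indices_{g-1}, input_0 .. input_{g-1}]
      at::TensorList all_indices_input,
      int64_t group_size) {
    std::vector<at::Tensor> saved_indices;
    std::vector<int64_t> num_input_rows;
    torch::autograd::variable_list outputs;
    outputs.reserve(group_size);
    for (int64_t i = 0; i < group_size; ++i) {
      const at::Tensor& indices = all_indices_input[i];
      const at::Tensor& input = all_indices_input[group_size + i];
      saved_indices.push_back(indices);
      num_input_rows.push_back(input.size(0));
      // One output per input tensor, selected along dim 0
      outputs.push_back(input.index_select(0, indices));
    }
    ctx->saved_data["group_size"] = group_size;
    ctx->saved_data["num_input_rows"] = num_input_rows;
    ctx->save_for_backward(saved_indices);
    return outputs;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output_group) {
    const auto group_size = ctx->saved_data["group_size"].toInt();
    const auto num_input_rows = ctx->saved_data["num_input_rows"].toIntVector();
    const auto saved = ctx->get_saved_variables();
    // One undefined gradient per indices tensor, one real gradient per input
    // tensor, and one undefined gradient for the trailing group_size argument.
    torch::autograd::variable_list grads(group_size);
    for (int64_t i = 0; i < group_size; ++i) {
      auto grad_shape = grad_output_group[i].sizes().vec();
      grad_shape[0] = num_input_rows[i];
      auto grad_input = at::zeros(grad_shape, grad_output_group[i].options());
      // Scatter-add the output gradient back to the selected rows
      grad_input.index_add_(0, saved[i], grad_output_group[i]);
      grads.push_back(grad_input);
    }
    grads.emplace_back(); // gradient slot for group_size
    return grads;
  }
};

// Usage sketch: pack indices first, then inputs, and pass the vector as a
// TensorList, as the patched group_index_select_dim0_gpu does.
// std::vector<at::Tensor> packed(indices_group);
// packed.insert(packed.end(), input_group.begin(), input_group.end());
// auto outs = ToyGroupIndexSelectDim0::apply(
//     at::TensorList(packed), static_cast<int64_t>(indices_group.size()));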