Skip to content

Commit

Permalink
Use PyTorch dispatcher in TBE training GPU autograds (pytorch#1948)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1948

Register TBE GPU training operators in PyTorch and invoke them in
autograd via PyTorch dispatcher

Reviewed By: jianyuh

Differential Revision: D48383830

fbshipit-source-id: 125e19e1886a6a318a443805451e55c28203d28e
  • Loading branch information
sryap authored and facebook-github-bot committed Aug 25, 2023
1 parent df45374 commit 61c0272
Show file tree
Hide file tree
Showing 13 changed files with 767 additions and 450 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ set(embedding_codegen_dependencies
${CMAKE_CODEGEN_DIR}/embedding_forward_split_kernel_nobag_small_template.cu
${CMAKE_CODEGEN_DIR}/embedding_forward_split_template.cu
${CMAKE_CODEGEN_DIR}/embedding_forward_template_helpers.cuh
${CMAKE_CODEGEN_DIR}/embedding_op_registration.h
${CMAKE_CODEGEN_DIR}/__init__.template
${CMAKE_CODEGEN_DIR}/lookup_args.py
${CMAKE_CODEGEN_DIR}/split_embedding_codegen_lookup_invoker.template
Expand Down
30 changes: 15 additions & 15 deletions fbgemm_gpu/codegen/batch_index_select_dim0_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ using Tensor = at::Tensor;
using namespace fbgemm_gpu;

Tensor batch_index_select_dim0_codegen_forward_cuda(
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t max_D,
Tensor indices,
int64_t output_dtype,
const Tensor& dev_weights,
const Tensor& weights_offsets,
const Tensor& D_offsets,
const int64_t max_D,
const Tensor& indices,
const int64_t output_dtype,
const Tensor& output_offsets,
const Tensor& total_L_offsets,
const int64_t output_size,
Expand All @@ -33,15 +33,15 @@ Tensor batch_index_select_dim0_codegen_forward_cuda(
const bool permute_output_dim_0_1);

Tensor batch_index_select_dim0_codegen_backward_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t max_D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
int64_t max_segment_length_per_warp,
const Tensor& grad_output,
const Tensor& dev_weights,
const Tensor& weights_offsets,
const Tensor& D_offsets,
const int64_t max_D,
const Tensor& hash_size_cumsum,
const int64_t total_hash_size_bits,
const Tensor& indices,
const int64_t max_segment_length_per_warp,
const Tensor& grad_offsets,
const Tensor& total_L_offsets,
const int32_t fixed_L_per_warp,
Expand Down
148 changes: 74 additions & 74 deletions fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,70 +18,70 @@ using Tensor = at::Tensor;
using namespace fbgemm_gpu;

Tensor dense_embedding_codegen_forward_unweighted_cuda(
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t total_D,
int64_t max_D,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
int64_t output_dtype,
bool is_experimental);
const Tensor& dev_weights,
const Tensor& weights_offsets,
const Tensor& D_offsets,
const int64_t total_D,
const int64_t max_D,
const Tensor& indices,
const Tensor& offsets,
const int64_t pooling_mode,
const int64_t output_dtype,
const bool is_experimental);

Tensor dense_embedding_codegen_forward_weighted_cuda(
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t total_D,
int64_t max_D,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Tensor indice_weights,
int64_t output_dtype,
bool is_experimental);
const Tensor& dev_weights,
const Tensor& weights_offsets,
const Tensor& D_offsets,
const int64_t total_D,
const int64_t max_D,
const Tensor& indices,
const Tensor& offsets,
const int64_t pooling_mode,
const Tensor& indice_weights,
const int64_t output_dtype,
const bool is_experimental);

Tensor dense_embedding_codegen_grad_indice_weights_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t max_D,
Tensor indices,
Tensor offsets,
Tensor feature_requires_grad);
const Tensor& grad_output,
const Tensor& dev_weights,
const Tensor& weights_offsets,
const Tensor& D_offsets,
const int64_t max_D,
const Tensor& indices,
const Tensor& offsets,
const Tensor& feature_requires_grad);

Tensor split_embedding_backward_codegen_dense_unweighted_exact_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t max_D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
int64_t BT_block_size,
int64_t max_segment_length_per_warp,
double unused);
const Tensor& grad_output,
const Tensor& dev_weights,
const Tensor& weights_offsets,
const Tensor& D_offsets,
const int64_t max_D,
const Tensor& hash_size_cumsum,
const int64_t total_hash_size_bits,
const Tensor& indices,
const Tensor& offsets,
const int64_t pooling_mode,
const int64_t BT_block_size,
const int64_t max_segment_length_per_warp,
const double unused);

Tensor split_embedding_backward_codegen_dense_weighted_exact_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor weights_offsets,
Tensor D_offsets,
int64_t max_D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Tensor indice_weights,
int64_t BT_block_size,
int64_t max_segment_length_per_warp,
double unused);
const Tensor& grad_output,
const Tensor& dev_weights,
const Tensor& weights_offsets,
const Tensor& D_offsets,
const int64_t max_D,
const Tensor& hash_size_cumsum,
const int64_t total_hash_size_bits,
const Tensor& indices,
const Tensor& offsets,
const int64_t pooling_mode,
const Tensor& indice_weights,
const int64_t BT_block_size,
const int64_t max_segment_length_per_warp,
const double unused);

class SplitLookupFunction_Dense_Op
: public torch::autograd::Function<SplitLookupFunction_Dense_Op> {
Expand Down Expand Up @@ -265,26 +265,26 @@ class SplitLookupFunction_Dense_Op

/******** nobag ops ********/
Tensor dense_embedding_nobag_codegen_forward_unweighted_cuda(
Tensor dev_weights,
Tensor weights_offsets,
int64_t D,
Tensor indices,
Tensor offsets,
int64_t output_dtype,
bool is_experimental);
const Tensor& dev_weights,
const Tensor& weights_offsets,
const int64_t D,
const Tensor& indices,
const Tensor& offsets,
const int64_t output_dtype,
const bool is_experimental);

Tensor split_embedding_nobag_backward_codegen_dense_unweighted_exact_cuda(
Tensor grad_output,
Tensor dev_weights,
Tensor weights_offsets,
int64_t D,
Tensor hash_size_cumsum,
int64_t total_hash_size_bits,
Tensor indices,
Tensor offsets,
int64_t BT_block_size,
int64_t max_segment_length_per_warp,
double unused);
const Tensor& grad_output,
const Tensor& dev_weights,
const Tensor& weights_offsets,
const int64_t D,
const Tensor& hash_size_cumsum,
const int64_t total_hash_size_bits,
const Tensor& indices,
const Tensor& offsets,
const int64_t BT_block_size,
const int64_t max_segment_length_per_warp,
const double unused);

class SplitNoBagLookupFunction_Dense_Op
: public torch::autograd::Function<SplitNoBagLookupFunction_Dense_Op> {
Expand Down
Loading

0 comments on commit 61c0272

Please sign in to comment.