Add meta functions for {}_embedding{}_codegen_forward_{}{}_cuda (#2094)
Summary: Pull Request resolved: #2094

For example, this hits split_embedding_codegen_lookup_rowwise_adagrad_function, among others.

There are a few other ways to write this and I'm open to other options; this version seemed the shortest to write, at the cost of some code duplication. The general strategy: I took embedding_forward_split_template.cu, copy-pasted it into embedding_forward_split_meta_template.cpp, deleted all the CUDA-specific stuff, and made it register a Meta implementation instead of a CUDA implementation. Most parts of the original were excised except the kernel itself, where all of the error checking and output tensor calculation is kept.

The biggest annoyance is the code duplication for the error checking, which is copy-pasted with essentially no modifications from the original CUDA kernel. I could potentially template the original logic so I can transclude it in two places, but that would also have made this diff harder to understand; I'm willing to do this if someone tells me to. Another approach would have been to implement the meta function entirely in Python (this is what I did in my original prototype), but getting it to handle all of the optimizer permutations seemed like a big pain.

Reviewed By: zou3519

Differential Revision: D50653995

fbshipit-source-id: 9ea5482d6308c098146c877d29a62b56006c9d17
1 parent 80eaddf · commit 824ef10 · Showing 4 changed files with 223 additions and 90 deletions.
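To make the general strategy concrete before the diff itself: the pattern is to register a shape-only kernel under the Meta dispatch key for an op whose registered name says _cuda. Below is a minimal, self-contained sketch of that pattern, not FBGEMM code; the op myops::row_sum and both of its kernels are hypothetical, invented purely for illustration.

#include <ATen/ATen.h>
#include <torch/library.h>

// A meta kernel repeats the real kernel's argument checks and output-shape
// computation, but allocates no storage: at::empty on a meta-device input
// yields a tensor that tracks sizes/dtype only, and no device code runs.
at::Tensor row_sum_meta(const at::Tensor& input) {
  TORCH_CHECK(input.dim() == 2, "row_sum: expected a 2D tensor");
  return at::empty({input.size(0)}, input.options());
}

TORCH_LIBRARY_FRAGMENT(myops, m) {
  m.def("row_sum(Tensor input) -> Tensor");
  // Same shape as the m.impl(...) call at the bottom of the template below:
  // the kernel is registered under c10::DispatchKey::Meta.
  m.impl(
      "row_sum",
      torch::dispatch(c10::DispatchKey::Meta, TORCH_FN(row_sum_meta)));
}

With such a registration in place, shape-propagation machinery (fake-tensor tracing, for instance) can run the op symbolically, which is what motivates adding meta functions for the embedding forward ops in this commit.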
fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp (200 additions, 0 deletions)
@@ -0,0 +1,200 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

{#
// @lint-ignore LINTIGNORE
// @lint-ignore-every CLANGFORMAT
// clang-format off
// Note: clang-format off doesn't work with this templatized code,
// so we need to keep lint-ignore-every.
// See https://fburl.com/dw9ljh4h
#}

// Companion template is embedding_forward_split_template.cu

{%- set ddesc = "dense" if dense else "split" %}
{%- set wdesc = "weighted" if weighted else "unweighted" %}
{%- set vdesc = "_vbe" if vbe else "" %}

////////////////////////////////////////////////////////////////////////////////
// Required for op registrations
#include "codegen/embedding_op_registration.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
#include "fbgemm_gpu/embedding_common.h"
////////////////////////////////////////////////////////////////////////////////

using namespace fbgemm_gpu;
using Tensor = at::Tensor;

static constexpr float kINT8QparamsBytes = 8;

////////////////////////////////////////////////////////////////////////////////
// Kernel Definitions
////////////////////////////////////////////////////////////////////////////////

{%- for nobag in [True, False] %}
{%- set ndesc = "_nobag" if nobag else "" %}
{%- if (not nobag or (not weighted and not vbe)) %}
{%- set has_experimental = (not dense and not nobag and not vbe) %}

Tensor
{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_meta(
    const Tensor& dev_weights,
    {%- if not dense %}
    const Tensor& uvm_weights,
    const Tensor& lxu_cache_weights,
    const Tensor& weights_placements,
    {%- endif %}
    const Tensor& weights_offsets,
    {%- if not nobag %}
    const Tensor& D_offsets,
    {%- else %}
    const int64_t D,
    {%- endif %}
    {%- if not nobag %}
    const int64_t total_D,
    {%- endif %}
    {%- if not nobag %}
    const int64_t max_D,
    {% endif %}
    const Tensor& indices,
    const Tensor& offsets,
    {%- if not nobag %}
    const int64_t pooling_mode,
    {%- endif %}
    {%- if weighted %}
    const Tensor& indice_weights,
    {%- endif %}
    {%- if not dense %}
    const Tensor& lxu_cache_locations,
    {%- endif %}
    const int64_t output_dtype,
    {%- if vbe %}
    const Tensor& vbe_row_output_offsets,
    const Tensor& vbe_b_t_map,
    const int64_t vbe_output_size,
    const int64_t info_B_num_bits, // int32_t
    const int64_t info_B_mask_int64, // uint32_t
    {%- endif %}
    const bool is_experimental
) {
  // NB: omitted the device tests TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL

  // TODO: SymIntify

  {%- if not nobag %}
  int32_t T = D_offsets.numel() - 1;
  {%- else %}
  int32_t total_L = indices.numel();
  int32_t T = weights_offsets.numel();
  {%- endif %}
  TORCH_CHECK_GT(T, 0);
  // offsets = [B x T + 1]
  {%- if is_index_select %}
  const auto total_B = num_warps_per_feature * T;
  const int32_t B = num_warps_per_feature;
  {%- else %}
  const auto total_B = offsets.size(0) - 1;
  const int32_t B = total_B / T;
  {%- endif %}
  TORCH_CHECK_GE(B, 0);
  {%- if not nobag or is_index_select %}
  {%- if not nobag %}
  TORCH_CHECK_GT(total_D, 0);
  TORCH_CHECK_EQ(total_D % 4, 0);
  {%- endif %}
  TORCH_CHECK_LE(max_D, {{ max_embedding_dim }});
  {%- elif not is_index_select %}
  TORCH_CHECK_GT(D, 0);
  TORCH_CHECK_EQ(D % 4, 0);
  {%- endif %}
  {%- if vbe %}
  TORCH_CHECK_EQ(vbe_row_output_offsets.numel(), total_B);
  TENSORS_HAVE_SAME_NUMEL(vbe_row_output_offsets, vbe_b_t_map);
  TORCH_CHECK_GE(vbe_output_size, 0);

  // Cast info_B_mask from int64_t to uint32_t
  const uint32_t info_B_mask = info_B_mask_int64;
  {%- endif %}

  Tensor output;
  {%- if nobag %}
  SparseType o_dtype = static_cast<SparseType>(output_dtype);
  {%- if is_index_select %}
  TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
              o_dtype == SparseType::BF16);

  TORCH_CHECK_GT(fixed_L_per_warp, 0);
  TORCH_CHECK_GT(num_warps_per_feature, 0);
  if (!permute_output_dim_0_1) {
    TORCH_CHECK_GE(output_size, 0);
    TORCH_CHECK_GT(output_offsets.numel(), 0);
  }

  // If permute_output_dim_0_1 is true, output shape is (batch_size * total_D)
  // Else, output shape is (output_size)
  output = at::empty({output_size}, dev_weights.options().dtype(getScalarType(o_dtype)));
  {%- else %}
  TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
              o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);

  int64_t adjusted_D = D;
  if (o_dtype == SparseType::INT8) {
    adjusted_D += T * kINT8QparamsBytes;
  }

  output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)));
  {%- endif %}
  {%- else %}
  SparseType o_dtype = static_cast<SparseType>(output_dtype);
  TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
              o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
  int64_t total_adjusted_D = total_D;
  if (o_dtype == SparseType::INT8) {
    total_adjusted_D += T * kINT8QparamsBytes;
  }

  {%- if vbe %}
  // Use a 2D tensor to make it compatible with 2D PackedTensorsAccessor of other output
  output = at::empty(
      {1, vbe_output_size},
      dev_weights.options().dtype(getScalarType(o_dtype))
  );
  {%- else %}
  output = at::empty(
      {B, total_adjusted_D},
      dev_weights.options().dtype(getScalarType(o_dtype))
  );
  {%- endif %}
  {%- endif %} // if nobag

  if (B == 0) {
    {%- if vbe %}
    output = output.reshape({-1});
    {%- endif %}
    return output;
  }

  return output;
}

////////////////////////////////////////////////////////////////////////////////
// Op registrations
////////////////////////////////////////////////////////////////////////////////
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  // NB: yes cuda here
  {%- set embedding_codegen_forward_op =
      "{}_embedding{}_codegen_forward_{}{}_cuda".format(
          ddesc, ndesc, wdesc, vdesc
      )
  %}
  m.impl("{{ embedding_codegen_forward_op }}", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_meta)));
}
{%- endif %} {#/* if (not nobag or (not weighted and not vbe)) */#}
{%- endfor %} {#-/* for nobag */#}
// clang-format on
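For completeness, here is a hedged sketch of how a Meta registration like the one above gets exercised, continuing the hypothetical myops::row_sum example from earlier rather than the FBGEMM ops themselves (whose argument lists are much longer): dispatching on meta tensors returns a correctly shaped meta tensor without touching a GPU.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor run_on_meta() {
  // Meta tensors carry shape and dtype but no data.
  auto x = at::empty(
      {4, 8}, at::TensorOptions().device(at::kMeta).dtype(at::kFloat));
  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("myops::row_sum", "")
                .typed<at::Tensor(const at::Tensor&)>();
  at::Tensor y = op.call(x);  // dispatches to row_sum_meta
  TORCH_CHECK(y.is_meta());   // shape {4}; no memory was allocated
  return y;
}

This is the payoff of the diff: shape inference for the forward ops now works end to end on the meta backend, even though the registered op names still end in _cuda.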