Add meta functions for {}_embedding{}_codegen_forward_{}{}_cuda (#2094)
Summary:
Pull Request resolved: #2094

For example, this covers split_embedding_codegen_lookup_rowwise_adagrad_function, among others.

There are a few other ways to write this, and I'm open to other options. This version seemed the shortest to write, at the cost of some code duplication. The general strategy is that I took embedding_forward_split_template.cu, copy-pasted it into embedding_forward_split_meta_template.cpp, deleted all the CUDA-specific stuff, and made it register a Meta implementation instead of a CUDA implementation. Most parts of the original were excised except the kernel itself, where all of the error checking and output tensor calculation are kept.
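To make the strategy concrete, here is a minimal sketch of the pattern (the op name and arguments are hypothetical and simplified; the real generated code is in embedding_forward_split_meta_template.cpp below): the Meta implementation performs the same input checks as the CUDA host function and allocates an empty output of the right shape, but never reads data or launches a kernel.

```cpp
// Minimal sketch only: "toy_forward" and its arguments are hypothetical; the
// real generated code is in embedding_forward_split_meta_template.cpp below.
#include <ATen/ATen.h>
#include <torch/library.h>

at::Tensor toy_forward_meta(
    const at::Tensor& dev_weights,
    const at::Tensor& offsets,
    const int64_t total_D) {
  // Same checks the CUDA host function would perform.
  TORCH_CHECK_GT(total_D, 0);
  const auto B = offsets.numel() - 1;
  TORCH_CHECK_GE(B, 0);
  // Allocate the output with the shape/dtype the CUDA kernel would produce;
  // for meta tensors this records only metadata, no memory is touched.
  return at::empty({B, total_D}, dev_weights.options());
}

TORCH_LIBRARY_FRAGMENT(toy, m) {
  m.def("toy_forward(Tensor dev_weights, Tensor offsets, int total_D) -> Tensor");
  // Register under the Meta dispatch key instead of CUDA.
  m.impl(
      "toy_forward",
      torch::dispatch(c10::DispatchKey::Meta, TORCH_FN(toy_forward_meta)));
}
```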

The biggest annoyance is the code duplication for the error checking, which is copy-pasted with essentially no modifications from the original CUDA kernel. I could potentially template the original logic so that I can transclude it in two places, but that would also have made this diff harder to understand; I'm willing to do that if someone tells me to. Another approach would have been to implement the meta function entirely in Python (this is what I did in my original prototype), but getting it to handle all of the optimizer permutations seemed like a big pain.
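For reference, a rough sketch of the transclusion idea (the header name and helper are hypothetical and not part of this diff): the shared checks could be factored into one place that both the CUDA template and the meta template pull in.

```cpp
// Hypothetical shared header, e.g. embedding_forward_split_common_checks.h
// (not part of this diff): both templates could include it so the error
// checking lives in a single place.
#pragma once
#include <ATen/ATen.h>

inline void check_split_forward_inputs(
    const at::Tensor& D_offsets,
    const at::Tensor& offsets,
    const int64_t total_D,
    const int64_t max_D,
    const int64_t max_embedding_dim) {
  const auto T = D_offsets.numel() - 1;
  TORCH_CHECK_GT(T, 0);
  const auto total_B = offsets.size(0) - 1;
  TORCH_CHECK_GE(total_B / T, 0);
  TORCH_CHECK_GT(total_D, 0);
  TORCH_CHECK_EQ(total_D % 4, 0);
  TORCH_CHECK_LE(max_D, max_embedding_dim);
}
```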

Reviewed By: zou3519

Differential Revision: D50653995

fbshipit-source-id: 9ea5482d6308c098146c877d29a62b56006c9d17
ezyang authored and facebook-github-bot committed Nov 2, 2023
1 parent 80eaddf commit 824ef10
Showing 4 changed files with 223 additions and 90 deletions.
8 changes: 8 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_code_generator.py
@@ -238,6 +238,14 @@ def forward_split() -> None:
vbe_options=[True, False],
)

generate_forward_embedding_cuda(
"embedding_forward_split_meta_template.cpp",
"gen_embedding_forward_{}_codegen_meta.cpp",
dense_options=[True, False],
nobag_options=[False], # nobag is not used
vbe_options=[True, False],
)

# Generate the kernels for the forward splits
generate_forward_embedding_cuda(
"embedding_forward_split_kernel_template.cu",
10 changes: 10 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
@@ -767,11 +767,21 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
" int max_B_feature_rank=-1, "
" int vbe_output_size=-1, "
" bool is_experimental=False) -> Tensor");
// We're playing a funny trick here: we're using the autograd
// implementation of the operator at all the dispatch keys. This is OK
// because autograd.Function works even in a context where there is
// no autograd enabled, and all of the internal implementations redispatch
// appropriately
m.impl(
"split_embedding_codegen_lookup_{{ optimizer }}_function",
torch::dispatch(
c10::DispatchKey::Autograd,
TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function)));
m.impl(
"split_embedding_codegen_lookup_{{ optimizer }}_function",
torch::dispatch(
c10::DispatchKey::Meta,
TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function)));
DISPATCH_TO_CUDA(
"split_embedding_codegen_lookup_{{ optimizer }}_function",
split_embedding_codegen_lookup_{{ optimizer }}_function);
200 changes: 200 additions & 0 deletions fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp
@@ -0,0 +1,200 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

{#
// @lint-ignore LINTIGNORE
// @lint-ignore-every CLANGFORMAT
// clang-format off
// Note: clang-format off doesn't work with this templaterized code,
// so we need to keep lint-ignore-every.
// See https://fburl.com/dw9ljh4h
#}

// Companion template is embedding_forward_split_template.cu

{%- set ddesc = "dense" if dense else "split" %}
{%- set wdesc = "weighted" if weighted else "unweighted" %}
{%- set vdesc = "_vbe" if vbe else "" %}

////////////////////////////////////////////////////////////////////////////////
// Required for op registrations
#include "codegen/embedding_op_registration.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
#include "fbgemm_gpu/embedding_common.h"
////////////////////////////////////////////////////////////////////////////////

using namespace fbgemm_gpu;
using Tensor = at::Tensor;

static constexpr float kINT8QparamsBytes = 8;

////////////////////////////////////////////////////////////////////////////////
// Kernel Definitions
////////////////////////////////////////////////////////////////////////////////

{%- for nobag in [True, False] %}
{%- set ndesc = "_nobag" if nobag else "" %}
{%- if (not nobag or (not weighted and not vbe)) %}
{%- set has_experimental = (not dense and not nobag and not vbe) %}

Tensor
{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_meta(
const Tensor& dev_weights,
{%- if not dense %}
const Tensor& uvm_weights,
const Tensor& lxu_cache_weights,
const Tensor& weights_placements,
{%- endif %}
const Tensor& weights_offsets,
{%- if not nobag %}
const Tensor& D_offsets,
{%- else %}
const int64_t D,
{%- endif %}
{%- if not nobag %}
const int64_t total_D,
{%- endif %}
{%- if not nobag %}
const int64_t max_D,
{% endif %}
const Tensor& indices,
const Tensor& offsets,
{%- if not nobag %}
const int64_t pooling_mode,
{%- endif %}
{%- if weighted %}
const Tensor& indice_weights,
{%- endif %}
{%- if not dense %}
const Tensor& lxu_cache_locations,
{%- endif %}
const int64_t output_dtype,
{%- if vbe %}
const Tensor& vbe_row_output_offsets,
const Tensor& vbe_b_t_map,
const int64_t vbe_output_size,
const int64_t info_B_num_bits, // int32_t
const int64_t info_B_mask_int64, // uint32_t
{%- endif %}
const bool is_experimental
) {
// NB: omitted the device tests TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL

// TODO: SymIntify

{%- if not nobag %}
int32_t T = D_offsets.numel() - 1;
{%- else %}
int32_t total_L = indices.numel();
int32_t T = weights_offsets.numel();
{%- endif %}
TORCH_CHECK_GT(T, 0);
// offsets = [B x T + 1]
{%- if is_index_select %}
const auto total_B = num_warps_per_feature * T;
const int32_t B = num_warps_per_feature;
{%- else %}
const auto total_B = offsets.size(0) - 1;
const int32_t B = total_B / T;
{%- endif %}
TORCH_CHECK_GE(B, 0);
{%- if not nobag or is_index_select %}
{%- if not nobag %}
TORCH_CHECK_GT(total_D, 0);
TORCH_CHECK_EQ(total_D % 4, 0);
{%- endif %}
TORCH_CHECK_LE(max_D, {{ max_embedding_dim }});
{%- elif not is_index_select %}
TORCH_CHECK_GT(D, 0);
TORCH_CHECK_EQ(D % 4, 0);
{%- endif %}
{%- if vbe %}
TORCH_CHECK_EQ(vbe_row_output_offsets.numel(), total_B);
TENSORS_HAVE_SAME_NUMEL(vbe_row_output_offsets, vbe_b_t_map);
TORCH_CHECK_GE(vbe_output_size, 0);

// Cast info_B_mask from int64_t to uint32_t
const uint32_t info_B_mask = info_B_mask_int64;
{%- endif %}

Tensor output;
{%- if nobag %}
SparseType o_dtype = static_cast<SparseType>(output_dtype);
{%- if is_index_select %}
TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
o_dtype == SparseType::BF16);

TORCH_CHECK_GT(fixed_L_per_warp, 0);
TORCH_CHECK_GT(num_warps_per_feature, 0);
if (!permute_output_dim_0_1) {
TORCH_CHECK_GE(output_size, 0);
TORCH_CHECK_GT(output_offsets.numel(), 0);
}

// If permute_output_dim_0_1 is true, output shape is (batch_size * total_D)
// Else, output shape is (output_size)
output = at::empty({output_size}, dev_weights.options().dtype(getScalarType(o_dtype)));
{%- else %}
TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);

int64_t adjusted_D = D;
if (o_dtype == SparseType::INT8) {
adjusted_D += T * kINT8QparamsBytes;
}

output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)));
{%- endif %}
{%- else %}
SparseType o_dtype = static_cast<SparseType>(output_dtype);
TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
int64_t total_adjusted_D = total_D;
if (o_dtype == SparseType::INT8) {
total_adjusted_D += T * kINT8QparamsBytes;
}

{%- if vbe %}
// Use a 2D tensor to make it compatible with 2D PackedTensorsAccessor of other output
output = at::empty(
{1, vbe_output_size},
dev_weights.options().dtype(getScalarType(o_dtype))
);
{%- else %}
output = at::empty(
{B, total_adjusted_D},
dev_weights.options().dtype(getScalarType(o_dtype))
);
{%- endif %}
{%- endif %} // if nobag

if (B == 0) {
{%- if vbe %}
output = output.reshape({-1});
{%- endif %}
return output;
}

return output;
}

////////////////////////////////////////////////////////////////////////////////
// Op registrations
////////////////////////////////////////////////////////////////////////////////
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
// NB: yes cuda here
{%- set embedding_codegen_forward_op =
"{}_embedding{}_codegen_forward_{}{}_cuda".format(
ddesc, ndesc, wdesc, vdesc
)
%}
m.impl("{{ embedding_codegen_forward_op }}", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_meta)));
}
{%- endif %} {#/* if (not nobag or (not weighted and not vbe)) */#}
{%- endfor %} {#-/* for nobag */#}
// clang-format on
95 changes: 5 additions & 90 deletions fbgemm_gpu/test/failures_dict_fast.json
@@ -655,36 +655,7 @@
"status": "xfail"
}
},
"fbgemm::split_embedding_codegen_lookup_adagrad_function": {
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::split_embedding_codegen_lookup_adagrad_function": {},
"fbgemm::split_embedding_codegen_lookup_adagrad_function_cpu": {
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": {
"comment": "",
@@ -697,64 +668,20 @@
"status": "xfail"
}
},
"fbgemm::split_embedding_codegen_lookup_lamb_function": {
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_lamb": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::split_embedding_codegen_lookup_lars_sgd_function": {
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_lars": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::split_embedding_codegen_lookup_none_function": {
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_none": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_none_with_rowwise_adagrad": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::split_embedding_codegen_lookup_lamb_function": {},
"fbgemm::split_embedding_codegen_lookup_lars_sgd_function": {},
"fbgemm::split_embedding_codegen_lookup_none_function": {},
"fbgemm::split_embedding_codegen_lookup_partial_rowwise_adam_function": {},
"fbgemm::split_embedding_codegen_lookup_partial_rowwise_lamb_function": {
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_lamb": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::split_embedding_codegen_lookup_partial_rowwise_lamb_function": {},
"fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function": {
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_none_with_rowwise_adagrad": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": {
"comment": "",
"status": "xfail"
@@ -771,10 +698,6 @@
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_miss_counter": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_pipeline": {
"comment": "",
"status": "xfail"
@@ -790,14 +713,6 @@
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_2": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_forward_fused_pooled_emb_quant": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::split_embedding_codegen_lookup_sgd_function_cpu": {