From 824ef10481be8e871f66e4fdf34f7b27e735b968 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 2 Nov 2023 11:09:27 -0700 Subject: [PATCH] Add meta functions for {}_embedding{}_codegen_forward_{}{}_cuda (#2094) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2094 For example, this hits split_embedding_codegen_lookup_rowwise_adagrad_function among others. There are a few other ways to write this and I'm open to other options. This version seemed the shortest to write, at the cost of some code duplication. The general strategy is I took embedding_forward_split_template.cu, copy pasted it into embedding_forward_split_meta_template.cpp, deleted all the CUDA specific stuff, and made it register a Meta implementation instead of a CUDA implementation. Most parts of the original were excised except the kernel itself, where all of the error checking and output tensor calculation is kept. The biggest annoyance is the code duplication for the error checking, which are copy pasted with essentially no modifications from the original CUDA kernel. I could potentially template the original logic so I can transclude it in two places, but that would also have made this diff harder to understand. I'm willing to do this if someone tells me to. Another approach would have been to implement the meta function entirely in Python (this is what I did in my original prototype), but to get it to handle all of the optimizer permutations seemed like a big pain. Reviewed By: zou3519 Differential Revision: D50653995 fbshipit-source-id: 9ea5482d6308c098146c877d29a62b56006c9d17 --- .../embedding_backward_code_generator.py | 8 + ...embedding_backward_split_host_template.cpp | 10 + .../embedding_forward_split_meta_template.cpp | 200 ++++++++++++++++++ fbgemm_gpu/test/failures_dict_fast.json | 95 +-------- 4 files changed, 223 insertions(+), 90 deletions(-) create mode 100644 fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py index a86c4642b2..f42d2cda5c 100644 --- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py @@ -238,6 +238,14 @@ def forward_split() -> None: vbe_options=[True, False], ) + generate_forward_embedding_cuda( + "embedding_forward_split_meta_template.cpp", + "gen_embedding_forward_{}_codegen_meta.cpp", + dense_options=[True, False], + nobag_options=[False], # nobag is not used + vbe_options=[True, False], + ) + # Generate the kernels for the forward splits generate_forward_embedding_cuda( "embedding_forward_split_kernel_template.cu", diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp index eb01e0b191..63fdf6e684 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp @@ -767,11 +767,21 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) { " int max_B_feature_rank=-1, " " int vbe_output_size=-1, " " bool is_experimental=False) -> Tensor"); + // We're playing a funny trick here: we're using the autograd + // implementation of the operator at all the dispatch keys. This is OK + // because autograd.Function works even in a context where there is + // no autograd enabled, and all of the internal implementations redispatch + // appropriately m.impl( "split_embedding_codegen_lookup_{{ optimizer }}_function", torch::dispatch( c10::DispatchKey::Autograd, TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function))); + m.impl( + "split_embedding_codegen_lookup_{{ optimizer }}_function", + torch::dispatch( + c10::DispatchKey::Meta, + TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function))); DISPATCH_TO_CUDA( "split_embedding_codegen_lookup_{{ optimizer }}_function", split_embedding_codegen_lookup_{{ optimizer }}_function); diff --git a/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp b/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp new file mode 100644 index 0000000000..f46d56c096 --- /dev/null +++ b/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp @@ -0,0 +1,200 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +{# +// @lint-ignore LINTIGNORE +// @lint-ignore-every CLANGFORMAT +// clang-format off +// Note: clang-format off doesn't work with this templaterized code, +// so we need to keep lint-ignore-every. +// See https://fburl.com/dw9ljh4h +#} + +// Companion template is embedding_forward_split_template.cu + +{%- set ddesc = "dense" if dense else "split" %} +{%- set wdesc = "weighted" if weighted else "unweighted" %} +{%- set vdesc = "_vbe" if vbe else "" %} + +//////////////////////////////////////////////////////////////////////////////// +// Required for op registrations +#include "codegen/embedding_op_registration.h" +#include "fbgemm_gpu/sparse_ops_utils.h" +#include "fbgemm_gpu/embedding_common.h" +//////////////////////////////////////////////////////////////////////////////// + +using namespace fbgemm_gpu; +using Tensor = at::Tensor; + +static constexpr float kINT8QparamsBytes = 8; + +//////////////////////////////////////////////////////////////////////////////// +// Kernel Definitions +//////////////////////////////////////////////////////////////////////////////// + +{%- for nobag in [True, False] %} +{%- set ndesc = "_nobag" if nobag else "" %} +{%- if (not nobag or (not weighted and not vbe)) %} +{%- set has_experimental = (not dense and not nobag and not vbe) %} + +Tensor +{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_meta( + const Tensor& dev_weights, + {%- if not dense %} + const Tensor& uvm_weights, + const Tensor& lxu_cache_weights, + const Tensor& weights_placements, + {%- endif %} + const Tensor& weights_offsets, + {%- if not nobag %} + const Tensor& D_offsets, + {%- else %} + const int64_t D, + {%- endif %} + {%- if not nobag %} + const int64_t total_D, + {%- endif %} + {%- if not nobag %} + const int64_t max_D, + {% endif %} + const Tensor& indices, + const Tensor& offsets, + {%- if not nobag %} + const int64_t pooling_mode, + {%- endif %} + {%- if weighted %} + const Tensor& indice_weights, + {%- endif %} + {%- if not dense %} + const Tensor& lxu_cache_locations, + {%- endif %} + const int64_t output_dtype, + {%- if vbe %} + const Tensor& vbe_row_output_offsets, + const Tensor& vbe_b_t_map, + const int64_t vbe_output_size, + const int64_t info_B_num_bits, // int32_t + const int64_t info_B_mask_int64, // uint32_t + {%- endif %} + const bool is_experimental +) { + // NB: omitted the device tests TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL + + // TODO: SymIntify + + {%- if not nobag %} + int32_t T = D_offsets.numel() - 1; + {%- else %} + int32_t total_L = indices.numel(); + int32_t T = weights_offsets.numel(); + {%- endif %} + TORCH_CHECK_GT(T, 0); + // offsets = [B x T + 1] + {%- if is_index_select %} + const auto total_B = num_warps_per_feature * T; + const int32_t B = num_warps_per_feature; + {%- else %} + const auto total_B = offsets.size(0) - 1; + const int32_t B = total_B / T; + {%- endif %} + TORCH_CHECK_GE(B, 0); + {%- if not nobag or is_index_select %} + {%- if not nobag %} + TORCH_CHECK_GT(total_D, 0); + TORCH_CHECK_EQ(total_D % 4, 0); + {%- endif %} + TORCH_CHECK_LE(max_D, {{ max_embedding_dim }}); + {%- elif not is_index_select %} + TORCH_CHECK_GT(D, 0); + TORCH_CHECK_EQ(D % 4, 0); + {%- endif %} + {%- if vbe %} + TORCH_CHECK_EQ(vbe_row_output_offsets.numel(), total_B); + TENSORS_HAVE_SAME_NUMEL(vbe_row_output_offsets, vbe_b_t_map); + TORCH_CHECK_GE(vbe_output_size, 0); + + // Cast info_B_mask from int64_t to uint32_t + const uint32_t info_B_mask = info_B_mask_int64; + {%- endif %} + + Tensor output; + {%- if nobag %} + SparseType o_dtype = static_cast(output_dtype); + {%- if is_index_select %} + TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || + o_dtype == SparseType::BF16); + + TORCH_CHECK_GT(fixed_L_per_warp, 0); + TORCH_CHECK_GT(num_warps_per_feature, 0); + if (!permute_output_dim_0_1) { + TORCH_CHECK_GE(output_size, 0); + TORCH_CHECK_GT(output_offsets.numel(), 0); + } + + // If permute_output_dim_0_1 is true, output shape is (batch_size * total_D) + // Else, output shape is (output_size) + output = at::empty({output_size}, dev_weights.options().dtype(getScalarType(o_dtype))); + {%- else %} + TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || + o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8); + + int64_t adjusted_D = D; + if (o_dtype == SparseType::INT8) { + adjusted_D += T * kINT8QparamsBytes; + } + + output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype))); + {%- endif %} + {%- else %} + SparseType o_dtype = static_cast(output_dtype); + TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || + o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8); + int64_t total_adjusted_D = total_D; + if (o_dtype == SparseType::INT8) { + total_adjusted_D += T * kINT8QparamsBytes; + } + + {%- if vbe %} + // Use a 2D tensor to make it compatible with 2D PackedTensorsAccessor of other output + output = at::empty( + {1, vbe_output_size}, + dev_weights.options().dtype(getScalarType(o_dtype)) + ); + {%- else %} + output = at::empty( + {B, total_adjusted_D}, + dev_weights.options().dtype(getScalarType(o_dtype)) + ); + {%- endif %} + {%- endif %} // if nobag + + if (B == 0) { + {%- if vbe %} + output = output.reshape({-1}); + {%- endif %} + return output; + } + + return output; +} + +//////////////////////////////////////////////////////////////////////////////// +// Op registrations +//////////////////////////////////////////////////////////////////////////////// +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + // NB: yes cuda here + {%- set embedding_codegen_forward_op = + "{}_embedding{}_codegen_forward_{}{}_cuda".format( + ddesc, ndesc, wdesc, vdesc + ) + %} + m.impl("{{ embedding_codegen_forward_op }}", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_meta))); +} +{%- endif %} {#/* if (not nobag or (not weighted and not vbe)) */#} +{%- endfor %} {#-/* for nobag */#} + // clang-format on diff --git a/fbgemm_gpu/test/failures_dict_fast.json b/fbgemm_gpu/test/failures_dict_fast.json index 0e7d666814..a691ff1ad3 100644 --- a/fbgemm_gpu/test/failures_dict_fast.json +++ b/fbgemm_gpu/test/failures_dict_fast.json @@ -655,36 +655,7 @@ "status": "xfail" } }, - "fbgemm::split_embedding_codegen_lookup_adagrad_function": { - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": { - "comment": "", - "status": "xfail" - }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": { - "comment": "", - "status": "xfail" - }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": { - "comment": "", - "status": "xfail" - }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": { - "comment": "", - "status": "xfail" - }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": { - "comment": "", - "status": "xfail" - }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": { - "comment": "", - "status": "xfail" - }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": { - "comment": "", - "status": "xfail" - } - }, + "fbgemm::split_embedding_codegen_lookup_adagrad_function": {}, "fbgemm::split_embedding_codegen_lookup_adagrad_function_cpu": { "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": { "comment": "", @@ -697,64 +668,20 @@ "status": "xfail" } }, - "fbgemm::split_embedding_codegen_lookup_lamb_function": { - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_lamb": { - "comment": "", - "status": "xfail" - } - }, - "fbgemm::split_embedding_codegen_lookup_lars_sgd_function": { - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_lars": { - "comment": "", - "status": "xfail" - } - }, - "fbgemm::split_embedding_codegen_lookup_none_function": { - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_none": { - "comment": "", - "status": "xfail" - }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_none_with_rowwise_adagrad": { - "comment": "", - "status": "xfail" - } - }, + "fbgemm::split_embedding_codegen_lookup_lamb_function": {}, + "fbgemm::split_embedding_codegen_lookup_lars_sgd_function": {}, + "fbgemm::split_embedding_codegen_lookup_none_function": {}, "fbgemm::split_embedding_codegen_lookup_partial_rowwise_adam_function": {}, - "fbgemm::split_embedding_codegen_lookup_partial_rowwise_lamb_function": { - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_lamb": { - "comment": "", - "status": "xfail" - } - }, + "fbgemm::split_embedding_codegen_lookup_partial_rowwise_lamb_function": {}, "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function": { - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": { - "comment": "", - "status": "xfail" - }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": { - "comment": "", - "status": "xfail" - }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": { "comment": "", "status": "xfail" }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": { - "comment": "", - "status": "xfail" - }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": { - "comment": "", - "status": "xfail" - }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": { "comment": "", "status": "xfail" }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_none_with_rowwise_adagrad": { - "comment": "", - "status": "xfail" - }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": { "comment": "", "status": "xfail" @@ -771,10 +698,6 @@ "comment": "", "status": "xfail" }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_miss_counter": { - "comment": "", - "status": "xfail" - }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_pipeline": { "comment": "", "status": "xfail" @@ -790,14 +713,6 @@ "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_2": { "comment": "", "status": "xfail" - }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_forward_fused_pooled_emb_quant": { - "comment": "", - "status": "xfail" - }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": { - "comment": "", - "status": "xfail" } }, "fbgemm::split_embedding_codegen_lookup_sgd_function_cpu": {