diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py
index a86c4642b2..f42d2cda5c 100644
--- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py
+++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py
@@ -238,6 +238,14 @@ def forward_split() -> None:
         vbe_options=[True, False],
     )
 
+    generate_forward_embedding_cuda(
+        "embedding_forward_split_meta_template.cpp",
+        "gen_embedding_forward_{}_codegen_meta.cpp",
+        dense_options=[True, False],
+        nobag_options=[False],  # nobag is not used
+        vbe_options=[True, False],
+    )
+
     # Generate the kernels for the forward splits
     generate_forward_embedding_cuda(
         "embedding_forward_split_kernel_template.cu",
diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
index eb01e0b191..63fdf6e684 100644
--- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
+++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
@@ -767,11 +767,21 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
       "    int max_B_feature_rank=-1, "
       "    int vbe_output_size=-1, "
       "    bool is_experimental=False) -> Tensor");
+  // We're playing a funny trick here: we're using the autograd
+  // implementation of the operator at all the dispatch keys. This is OK
+  // because autograd.Function works even in a context where there is
+  // no autograd enabled, and all of the internal implementations redispatch
+  // appropriately
   m.impl(
       "split_embedding_codegen_lookup_{{ optimizer }}_function",
       torch::dispatch(
           c10::DispatchKey::Autograd,
           TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function)));
+  m.impl(
+      "split_embedding_codegen_lookup_{{ optimizer }}_function",
+      torch::dispatch(
+          c10::DispatchKey::Meta,
+          TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function)));
   DISPATCH_TO_CUDA(
       "split_embedding_codegen_lookup_{{ optimizer }}_function",
       split_embedding_codegen_lookup_{{ optimizer }}_function);
diff --git a/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp b/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp
new file mode 100644
index 0000000000..f46d56c096
--- /dev/null
+++ b/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp
@@ -0,0 +1,200 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+{#
+// @lint-ignore LINTIGNORE
+// @lint-ignore-every CLANGFORMAT
+// clang-format off
+// Note: clang-format off doesn't work with this templaterized code,
+// so we need to keep lint-ignore-every.
+// See https://fburl.com/dw9ljh4h
+#}
+
+// Companion template is embedding_forward_split_template.cu
+
+{%- set ddesc = "dense" if dense else "split" %}
+{%- set wdesc = "weighted" if weighted else "unweighted" %}
+{%- set vdesc = "_vbe" if vbe else "" %}
+
+////////////////////////////////////////////////////////////////////////////////
+// Required for op registrations
+#include "codegen/embedding_op_registration.h"
+#include "fbgemm_gpu/sparse_ops_utils.h"
+#include "fbgemm_gpu/embedding_common.h"
+////////////////////////////////////////////////////////////////////////////////
+
+using namespace fbgemm_gpu;
+using Tensor = at::Tensor;
+
+static constexpr float kINT8QparamsBytes = 8;
+
+////////////////////////////////////////////////////////////////////////////////
+// Kernel Definitions
+////////////////////////////////////////////////////////////////////////////////
+
+{%- for nobag in [True, False] %}
+{%- set ndesc = "_nobag" if nobag else "" %}
+{%- if (not nobag or (not weighted and not vbe)) %}
+{%- set has_experimental = (not dense and not nobag and not vbe) %}
+
+Tensor
+{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_meta(
+    const Tensor& dev_weights,
+    {%- if not dense %}
+    const Tensor& uvm_weights,
+    const Tensor& lxu_cache_weights,
+    const Tensor& weights_placements,
+    {%- endif %}
+    const Tensor& weights_offsets,
+    {%- if not nobag %}
+    const Tensor& D_offsets,
+    {%- else %}
+    const int64_t D,
+    {%- endif %}
+    {%- if not nobag %}
+    const int64_t total_D,
+    {%- endif %}
+    {%- if not nobag %}
+    const int64_t max_D,
+    {% endif %}
+    const Tensor& indices,
+    const Tensor& offsets,
+    {%- if not nobag %}
+    const int64_t pooling_mode,
+    {%- endif %}
+    {%- if weighted %}
+    const Tensor& indice_weights,
+    {%- endif %}
+    {%- if not dense %}
+    const Tensor& lxu_cache_locations,
+    {%- endif %}
+    const int64_t output_dtype,
+    {%- if vbe %}
+    const Tensor& vbe_row_output_offsets,
+    const Tensor& vbe_b_t_map,
+    const int64_t vbe_output_size,
+    const int64_t info_B_num_bits, // int32_t
+    const int64_t info_B_mask_int64, // uint32_t
+    {%- endif %}
+    const bool is_experimental
+) {
+  // NB: omitted the device tests TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL
+
+  // TODO: SymIntify
+
+  {%- if not nobag %}
+  int32_t T = D_offsets.numel() - 1;
+  {%- else %}
+  int32_t total_L = indices.numel();
+  int32_t T = weights_offsets.numel();
+  {%- endif %}
+  TORCH_CHECK_GT(T, 0);
+  // offsets = [B x T + 1]
+  {%- if is_index_select %}
+  const auto total_B = num_warps_per_feature * T;
+  const int32_t B = num_warps_per_feature;
+  {%- else %}
+  const auto total_B = offsets.size(0) - 1;
+  const int32_t B = total_B / T;
+  {%- endif %}
+  TORCH_CHECK_GE(B, 0);
+  {%- if not nobag or is_index_select %}
+  {%- if not nobag %}
+  TORCH_CHECK_GT(total_D, 0);
+  TORCH_CHECK_EQ(total_D % 4, 0);
+  {%- endif %}
+  TORCH_CHECK_LE(max_D, {{ max_embedding_dim }});
+  {%- elif not is_index_select %}
+  TORCH_CHECK_GT(D, 0);
+  TORCH_CHECK_EQ(D % 4, 0);
+  {%- endif %}
+  {%- if vbe %}
+  TORCH_CHECK_EQ(vbe_row_output_offsets.numel(), total_B);
+  TENSORS_HAVE_SAME_NUMEL(vbe_row_output_offsets, vbe_b_t_map);
+  TORCH_CHECK_GE(vbe_output_size, 0);
+
+  // Cast info_B_mask from int64_t to uint32_t
+  const uint32_t info_B_mask = info_B_mask_int64;
+  {%- endif %}
+
+  Tensor output;
+  {%- if nobag %}
+  SparseType o_dtype = static_cast<SparseType>(output_dtype);
+  {%- if is_index_select %}
+  TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
+              o_dtype == SparseType::BF16);
+
+  TORCH_CHECK_GT(fixed_L_per_warp, 0);
+  TORCH_CHECK_GT(num_warps_per_feature, 0);
+  if (!permute_output_dim_0_1) {
+    TORCH_CHECK_GE(output_size, 0);
+    TORCH_CHECK_GT(output_offsets.numel(), 0);
+  }
+
+  // If permute_output_dim_0_1 is true, output shape is (batch_size * total_D)
+  // Else, output shape is (output_size)
+  output = at::empty({output_size}, dev_weights.options().dtype(getScalarType(o_dtype)));
+  {%- else %}
+  TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
+              o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
+
+  int64_t adjusted_D = D;
+  if (o_dtype == SparseType::INT8) {
+    adjusted_D += T * kINT8QparamsBytes;
+  }
+
+  output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)));
+  {%- endif %}
+  {%- else %}
+  SparseType o_dtype = static_cast<SparseType>(output_dtype);
+  TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
+              o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
+
+  int64_t total_adjusted_D = total_D;
+  if (o_dtype == SparseType::INT8) {
+    total_adjusted_D += T * kINT8QparamsBytes;
+  }
+
+  {%- if vbe %}
+  // Use a 2D tensor to make it compatible with 2D PackedTensorsAccessor of other output
+  output = at::empty(
+      {1, vbe_output_size},
+      dev_weights.options().dtype(getScalarType(o_dtype))
+  );
+  {%- else %}
+  output = at::empty(
+      {B, total_adjusted_D},
+      dev_weights.options().dtype(getScalarType(o_dtype))
+  );
+  {%- endif %}
+  {%- endif %} // if nobag
+
+  if (B == 0) {
+    {%- if vbe %}
+    output = output.reshape({-1});
+    {%- endif %}
+    return output;
+  }
+
+  return output;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Op registrations
+////////////////////////////////////////////////////////////////////////////////
+TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  // NB: yes cuda here
+  {%- set embedding_codegen_forward_op =
+      "{}_embedding{}_codegen_forward_{}{}_cuda".format(
+          ddesc, ndesc, wdesc, vdesc
+      )
+  %}
+  m.impl("{{ embedding_codegen_forward_op }}", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_meta)));
+}
+{%- endif %} {#/* if (not nobag or (not weighted and not vbe)) */#}
+{%- endfor %} {#-/* for nobag */#}
+// clang-format on
diff --git a/fbgemm_gpu/test/failures_dict_fast.json b/fbgemm_gpu/test/failures_dict_fast.json
index 0e7d666814..a691ff1ad3 100644
--- a/fbgemm_gpu/test/failures_dict_fast.json
+++ b/fbgemm_gpu/test/failures_dict_fast.json
@@ -655,36 +655,7 @@
         "status": "xfail"
       }
     },
-    "fbgemm::split_embedding_codegen_lookup_adagrad_function": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": {
-        "comment": "",
-        "status": "xfail"
-      }
-    },
+    "fbgemm::split_embedding_codegen_lookup_adagrad_function": {},
     "fbgemm::split_embedding_codegen_lookup_adagrad_function_cpu": {
       "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": {
         "comment": "",
@@ -697,64 +668,20 @@
         "status": "xfail"
       }
     },
-    "fbgemm::split_embedding_codegen_lookup_lamb_function": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_lamb": {
-        "comment": "",
-        "status": "xfail"
-      }
-    },
-    "fbgemm::split_embedding_codegen_lookup_lars_sgd_function": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_lars": {
-        "comment": "",
-        "status": "xfail"
-      }
-    },
-    "fbgemm::split_embedding_codegen_lookup_none_function": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_none": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_none_with_rowwise_adagrad": {
-        "comment": "",
-        "status": "xfail"
-      }
-    },
+    "fbgemm::split_embedding_codegen_lookup_lamb_function": {},
+    "fbgemm::split_embedding_codegen_lookup_lars_sgd_function": {},
+    "fbgemm::split_embedding_codegen_lookup_none_function": {},
     "fbgemm::split_embedding_codegen_lookup_partial_rowwise_adam_function": {},
-    "fbgemm::split_embedding_codegen_lookup_partial_rowwise_lamb_function": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_lamb": {
-        "comment": "",
-        "status": "xfail"
-      }
-    },
+    "fbgemm::split_embedding_codegen_lookup_partial_rowwise_lamb_function": {},
     "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": {
-        "comment": "",
-        "status": "xfail"
-      },
       "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": {
         "comment": "",
         "status": "xfail"
       },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": {
-        "comment": "",
-        "status": "xfail"
-      },
       "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": {
         "comment": "",
         "status": "xfail"
       },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_none_with_rowwise_adagrad": {
-        "comment": "",
-        "status": "xfail"
-      },
       "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": {
         "comment": "",
         "status": "xfail"
@@ -771,10 +698,6 @@
         "comment": "",
         "status": "xfail"
       },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_miss_counter": {
-        "comment": "",
-        "status": "xfail"
-      },
       "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_pipeline": {
         "comment": "",
         "status": "xfail"
@@ -790,14 +713,6 @@
       "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_2": {
         "comment": "",
         "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_forward_fused_pooled_emb_quant": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": {
-        "comment": "",
-        "status": "xfail"
       }
     },
    "fbgemm::split_embedding_codegen_lookup_sgd_function_cpu": {