From ab6f7ff02ced43883905e1b0b6e1fd46c7887383 Mon Sep 17 00:00:00 2001
From: Flavio Sales Truzzi <590773+flaviotruzzi@users.noreply.github.com>
Date: Tue, 13 Aug 2024 14:44:26 -0700
Subject: [PATCH] - Add abstract impl for FloatToHFP8Quantized (#2983)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2983

Register an abstract (meta) implementation for `fbgemm::FloatToHFP8Quantized`
so the operator propagates output shape and dtype under FakeTensor tracing.
The now-passing fake-tensor tests are removed from the xfail list in
`failures_dict_fast.json`.

Differential Revision: D61216517
---
 fbgemm_gpu/fbgemm_gpu/sparse_ops.py              | 10 ++++++++++
 fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp |  8 ++++++++
 fbgemm_gpu/test/quantize/failures_dict_fast.json | 15 +--------------
 3 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
index 62c811bd88..24897c12a0 100644
--- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -966,6 +966,12 @@ def histogram_binning_calibration_abstract(
     return torch.empty_like(logit), torch.empty([logit.numel()], dtype=torch.int64)
 
 
+def float_to_hfp8_quantized(
+    input: Tensor, ebits: int, exponent_bias: int, max_pos: float
+) -> Tensor:
+    return torch.empty_like(input, dtype=torch.uint8)
+
+
 def _setup() -> None:
     # pyre-ignore[16]
     _setup.done = getattr(_setup, "done", False)
@@ -1092,6 +1098,10 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
             "fbgemm::histogram_binning_calibration",
             histogram_binning_calibration_abstract,
         )
+        impl_abstract(
+            "fbgemm::FloatToHFP8Quantized",
+            float_to_hfp8_quantized,
+        )
 
         _setup.done = True
 
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
index 89411a073d..758ac91a9d 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
+++ b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -428,6 +428,14 @@ at::Tensor _hfp8_to_float_cpu(
 } // namespace fbgemm_gpu
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+#ifdef HAS_IMPL_ABSTRACT_PYSTUB
+  m.impl_abstract_pystub(
+      "fbgemm_gpu.sparse_ops",
+      "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
+#endif
+
+  m.set_python_module("fbgemm_gpu.sparse_ops");
+
   m.def("FloatToFused8BitRowwiseQuantized(Tensor t) -> Tensor");
   m.def(
       "FloatToFP8RowwiseQuantized(Tensor t, bool forward) -> Tensor",
diff --git a/fbgemm_gpu/test/quantize/failures_dict_fast.json b/fbgemm_gpu/test/quantize/failures_dict_fast.json
index 1571d8e5c7..fa917c72cc 100644
--- a/fbgemm_gpu/test/quantize/failures_dict_fast.json
+++ b/fbgemm_gpu/test/quantize/failures_dict_fast.json
@@ -30,20 +30,7 @@
         "status": "xfail"
       }
     },
-    "fbgemm::FloatToHFP8Quantized": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_cpu": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_gpu_no_cache": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_gpu_no_cache_fp8_2048": {
-        "comment": "",
-        "status": "xfail"
-      }
-    },
+    "fbgemm::FloatToHFP8Quantized": {},
     "fbgemm::Fused8BitRowwiseQuantizedToFloat": {
       "SplitTableBatchedEmbeddingsTest.test_faketensor__test_forward_cpu_int8": {
         "comment": "",
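
Note for reviewers: below is a minimal sketch (not part of the patch) of what
the new abstract impl enables. Under FakeTensorMode, dispatch for
fbgemm::FloatToHFP8Quantized hits the Python-registered abstract function, so
only the output shape and dtype (uint8, same shape as the input) are
propagated and no real quantization kernel runs. The operator signature is
taken from the diff above; the ebits/exponent_bias/max_pos argument values
are illustrative placeholders, not values from this patch, and the sketch
assumes an installed fbgemm_gpu package.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import fbgemm_gpu  # noqa: F401  # loads the fbgemm ops and triggers _setup()

with FakeTensorMode():
    x = torch.randn(4, 8)
    # No real kernel executes here; the abstract impl returns an empty
    # uint8 tensor with the same shape as the input.
    q = torch.ops.fbgemm.FloatToHFP8Quantized(x, 4, 7, 448.0)
    assert q.dtype == torch.uint8
    assert q.shape == x.shape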