diff --git a/.github/scripts/fbgemm_gpu_build.bash b/.github/scripts/fbgemm_gpu_build.bash index 415eeecc21..51388037ba 100644 --- a/.github/scripts/fbgemm_gpu_build.bash +++ b/.github/scripts/fbgemm_gpu_build.bash @@ -420,6 +420,7 @@ build_fbgemm_gpu_install () { # fbgemm_gpu/ subdirectory present cd - || return 1 (test_python_import_package "${env_name}" fbgemm_gpu) || return 1 + cd - || return 1 echo "[BUILD] FBGEMM-GPU build + install completed" } diff --git a/fbgemm_gpu/README.md b/fbgemm_gpu/README.md index a80186ea2a..8d35169653 100644 --- a/fbgemm_gpu/README.md +++ b/fbgemm_gpu/README.md @@ -13,7 +13,7 @@ packages (2.1+) that are built against those CUDA versions. Only Intel/AMD CPUs with AVX2 extensions are currently supported. -See our [Documentation](docs/README.md) for more information. +See our [Documentation](https://pytorch.org/FBGEMM) for more information. ## Installation diff --git a/fbgemm_gpu/docs/requirements.txt b/fbgemm_gpu/docs/requirements.txt index 010a69c58f..f62b11dae9 100644 --- a/fbgemm_gpu/docs/requirements.txt +++ b/fbgemm_gpu/docs/requirements.txt @@ -10,6 +10,7 @@ breathe bs4 docutils lxml +myst-parser sphinx-lint sphinx-serve six diff --git a/fbgemm_gpu/docs/src/conf.py b/fbgemm_gpu/docs/src/conf.py index ca4e7b21bc..79a19aa500 100644 --- a/fbgemm_gpu/docs/src/conf.py +++ b/fbgemm_gpu/docs/src/conf.py @@ -49,9 +49,11 @@ # ones. extensions = [ "breathe", + "myst_parser", "sphinx.ext.autodoc", "sphinx.ext.autosectionlabel", "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", "sphinx.ext.napoleon", ] diff --git a/fbgemm_gpu/docs/src/general/ContactInfo.rst b/fbgemm_gpu/docs/src/general/ContactInfo.rst new file mode 100644 index 0000000000..b36bb4b6e0 --- /dev/null +++ b/fbgemm_gpu/docs/src/general/ContactInfo.rst @@ -0,0 +1,67 @@ +Testing FBGEMM_GPU +------------------ + +The tests (in the ``fbgemm_gpu/test/`` directory) and benchmarks (in the +``fbgemm_gpu/bench/`` directory) provide good examples on how to use FBGEMM_GPU. + +FBGEMM_GPU Tests +~~~~~~~~~~~~~~~~ + +To run the tests after building / installing the FBGEMM_GPU package: + +.. code:: sh + + # From the /fbgemm_gpu/ directory + cd test + + python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning split_table_batched_embeddings_test.py + python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning quantize_ops_test.py + python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning sparse_ops_test.py + python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning split_embedding_inference_converter_test.py + +Testing with the CUDA Variant +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For the FBGEMM_GPU CUDA package, GPUs will be automatically detected and +used for testing. To run the tests and benchmarks on a GPU-capable +device in CPU-only mode, ``CUDA_VISIBLE_DEVICES=-1`` must be set in the +environment: + +.. code:: sh + + # Enable for running in CPU-only mode (when on a GPU-capable machine) + export CUDA_VISIBLE_DEVICES=-1 + + # Enable for debugging failed kernel executions + export CUDA_LAUNCH_BLOCKING=1 + + python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning split_table_batched_embeddings_test.py + +Testing with the ROCm Variant +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For ROCm machines, testing against a ROCm GPU needs to be enabled with +``FBGEMM_TEST_WITH_ROCM=1`` set in the environment: + +.. 
code:: sh
+
+   # From the /fbgemm_gpu/ directory
+   cd test
+
+   export FBGEMM_TEST_WITH_ROCM=1
+   # Enable for debugging failed kernel executions
+   export HIP_LAUNCH_BLOCKING=1
+
+   python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning split_table_batched_embeddings_test.py
+
+FBGEMM_GPU Benchmarks
+~~~~~~~~~~~~~~~~~~~~~
+
+To run the benchmarks:
+
+.. code:: sh
+
+   # From the /fbgemm_gpu/ directory
+   cd bench
+
+   python split_table_batched_embeddings_benchmark.py uvm
diff --git a/fbgemm_gpu/docs/src/general/DocsInstructions.rst b/fbgemm_gpu/docs/src/general/DocsInstructions.rst
index 37cbd2868b..9ee4a51989 100644
--- a/fbgemm_gpu/docs/src/general/DocsInstructions.rst
+++ b/fbgemm_gpu/docs/src/general/DocsInstructions.rst
@@ -199,6 +199,7 @@ description:
    /// @param param2 Description of param #2
    ///
    /// @return Description of the method's return value.
+   ///
    /// @throw fbgemm_gpu::my_error if an error occurs
    ///
    /// @note This is an example note.
@@ -233,3 +234,39 @@ description:
 #. Verify the changes by building the docs locally or submitting a PR for a
    Netlify preview.
+
+
+Sphinx Documentation Pointers
+-----------------------------
+
+Adding References to Other Sections
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To reference other sections in the documentation, an anchor must first be
+created above the target section:
+
+.. code:: rst
+
+   .. _fbgemm-gpu.docs.example.reference:
+
+   Example Section Header
+   ----------------------
+
+NOTES:
+
+#. The reference anchor must start with an underscore, i.e. ``_``.
+
+#. There must be an empty line between the anchor and its target.
+
+The anchor can then be referenced elsewhere in the docs:
+
+.. code:: rst
+
+   Referencing the section :ref:`fbgemm-gpu.docs.example.reference` from
+   another page in the docs.
+
+   Referencing the section with
+   :ref:`custom text <fbgemm-gpu.docs.example.reference>` from another page
+   in the docs.
+
+Note that the prefix underscore is not needed when referencing the anchor.
diff --git a/fbgemm_gpu/docs/src/general/TestInstructions.rst b/fbgemm_gpu/docs/src/general/TestInstructions.rst
index eee3db4d02..b36bb4b6e0 100644
--- a/fbgemm_gpu/docs/src/general/TestInstructions.rst
+++ b/fbgemm_gpu/docs/src/general/TestInstructions.rst
@@ -1,8 +1,8 @@
 Testing FBGEMM_GPU
 ------------------
 
-The tests (in the ``test/`` directoy) and benchmarks (in the ``bench/``
-directory) provide good examples on how to use FBGEMM_GPU.
+The tests (in the ``fbgemm_gpu/test/`` directory) and benchmarks (in the
+``fbgemm_gpu/bench/`` directory) provide good examples on how to use FBGEMM_GPU.
 
 FBGEMM_GPU Tests
 ~~~~~~~~~~~~~~~~
@@ -29,7 +29,9 @@ environment:
 
 .. code:: sh
 
+   # Enable for running in CPU-only mode (when on a GPU-capable machine)
    export CUDA_VISIBLE_DEVICES=-1
+
+   # Enable for debugging failed kernel executions
    export CUDA_LAUNCH_BLOCKING=1
 
diff --git a/fbgemm_gpu/docs/src/index.rst b/fbgemm_gpu/docs/src/index.rst
index 2b1139d460..057a1d670c 100644
--- a/fbgemm_gpu/docs/src/index.rst
+++ b/fbgemm_gpu/docs/src/index.rst
@@ -20,7 +20,15 @@ library.
    general/TestInstructions.rst
    general/DocsInstructions.rst
 
-.. _fbgemm-gpu.docs.toc.python:
+.. _fbgemm-gpu.docs.toc.overview:
+
+.. toctree::
+   :maxdepth: 2
+   :caption: FBGEMM_GPU Overview
+
+   overview/jagged-tensor-ops/JaggedTensorOps.rst
+
+.. _fbgemm-gpu.docs.toc.api.python:
 
 .. toctree::
    :maxdepth: 2
@@ -29,7 +37,7 @@ library.
    python-api/table_batched_embedding_ops.rst
    python-api/jagged_tensor_ops.rst
 
-.. _fbgemm-gpu.docs.toc.cpp:
+.. _fbgemm-gpu.docs.toc.api.cpp:
 
 .. toctree::
    :maxdepth: 2
diff --git a/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorConversion1.png b/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorConversion1.png
new file mode 100644
index 0000000000..0e00600148
Binary files /dev/null and b/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorConversion1.png differ
diff --git a/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorConversion2.png b/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorConversion2.png
new file mode 100644
index 0000000000..76356ff305
Binary files /dev/null and b/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorConversion2.png differ
diff --git a/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorConversion3.png b/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorConversion3.png
new file mode 100644
index 0000000000..f20378eabe
Binary files /dev/null and b/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorConversion3.png differ
diff --git a/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorExample.png b/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorExample.png
new file mode 100644
index 0000000000..a4c64f785e
Binary files /dev/null and b/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorExample.png differ
diff --git a/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorOps.rst b/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorOps.rst
new file mode 100644
index 0000000000..4ab34f509d
--- /dev/null
+++ b/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorOps.rst
@@ -0,0 +1,308 @@
+Jagged Tensor Operators
+=======================
+
+High Level Overview
+-------------------
+
+The purpose of jagged tensor operators is to handle the case where some
+dimension of the input data is "jagged," i.e. each consecutive row in a given
+dimension may be a different length. This is similar to the ``NestedTensor``
+`implementation <https://pytorch.org/docs/stable/nested.html>`__ in PyTorch
+and the ``RaggedTensor``
+`implementation <https://www.tensorflow.org/guide/ragged_tensor>`__ in
+TensorFlow.
+
+Two notable examples of this type of input are:
+
+* Sparse feature inputs in recommendation systems
+
+* Batches of tokenized sentences which may be input to natural language
+  processing systems.
+
+
+Jagged Tensor Format
+--------------------
+
+Jagged tensors are effectively represented in FBGEMM_GPU as a three-tensor
+object. The three tensors are: **Values**, **Max Lengths**, and **Offsets**.
+
+Values
+~~~~~~
+
+``Values`` is defined as a 2D tensor that contains all the element values
+in the jagged tensor, i.e. ``Values.numel()`` is the number of elements in the
+jagged tensor. The size of each row in ``Values`` is the greatest common
+divisor of the sizes of the smallest (innermost) dimension sub-tensors
+(excluding tensors of size 0) in the jagged tensor.
+
+Offsets
+~~~~~~~
+
+``Offsets`` is a list of tensors, where each tensor ``Offsets[i]`` represents
+the partitioning indices for the values of the next tensor in the list,
+``Offsets[i + 1]``.
+
+For example, ``Offsets[i] = [ 0, 3, 4 ]`` implies that the current
+dimension ``i`` is divided into two groups, denoted by the index bounds
+``[0, 3)`` and ``[3, 4)``. For each ``Offsets[i]``, where
+``0 <= i < len(Offsets) - 1``, ``Offsets[i][0] = 0``, and
+``Offsets[i][-1] = Offsets[i+1].length``.
+
+``Offsets[-1]`` refers to the outer dimension index of ``Values`` (row index),
+i.e. ``Offsets[-1]`` is the partition index of ``Values`` itself. As such,
+``Offsets[-1]`` begins with ``0`` and ends with ``Values.size(0)`` (i.e. the
+number of rows in ``Values``).
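+
+To make the format concrete, the sketch below writes a small jagged tensor
+down directly as plain PyTorch tensors. The data is hypothetical and for
+illustration only; it is not produced by any FBGEMM_GPU API call:
+
+.. code:: py
+
+   import torch
+
+   # Values: 5 rows of 4 elements each (a 5 x 4 tensor)
+   values = torch.arange(20, dtype=torch.float32).reshape(5, 4)
+
+   # Offsets[-1] begins with 0 and ends with Values.size(0) == 5,
+   # partitioning Values into sub-tensors of 2, 0, and 3 rows
+   offsets = [torch.tensor([0, 2, 2, 5])]
+
+   # Recover sub-tensor j by slicing Values between consecutive offsets
+   for j in range(len(offsets[-1]) - 1):
+       start = int(offsets[-1][j])
+       end = int(offsets[-1][j + 1])
+       print(values[start:end])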
+Max Lengths
+~~~~~~~~~~~
+
+``MaxLengths`` is a list of integers, where each value ``MaxLengths[i]`` is
+the maximum difference between consecutive offset values in ``Offsets[i]``:
+
+.. code:: cpp
+
+   MaxLengths[i] = max( Offsets[i][j] - Offsets[i][j-1] | 0 < j < len(Offsets[i]) )
+
+The information in ``MaxLengths`` is used when performing the conversion from
+a jagged tensor to a normal (dense) tensor, where it determines the shape of
+the tensor's dense form.
+
+.. _fbgemm-gpu.docs.overview.ops.jagged.example:
+
+Jagged Tensor Example
+~~~~~~~~~~~~~~~~~~~~~
+
+The figure below shows an example jagged tensor that contains three 2D
+sub-tensors, with each sub-tensor having different dimensions:
+
+.. image:: JaggedTensorExample.png
+
+In this example, the sizes of the rows in the innermost dimension of the jagged
+tensor are ``8``, ``4``, and ``0``, and so the number of elements per row in
+``Values`` is set to ``4`` (their greatest common divisor). This means ``Values``
+must be of size ``9 x 4`` in order to accommodate all values in the jagged
+tensor.
+
+Because the example jagged tensor contains 2D sub-tensors, the ``Offsets`` list
+will need to have a length of 2 to create the partitioning indices.
+``Offsets[0]`` represents the partition for dimension ``0`` and ``Offsets[1]``
+represents the partition for dimension ``1``.
+
+The ``MaxLengths`` values in the example jagged tensor are ``[4, 2]``.
+``MaxLengths[0]`` is derived from the ``Offsets[0]`` range ``[0, 4)`` and
+``MaxLengths[1]`` is derived from the ``Offsets[1]`` range ``[0, 2)`` (or
+``[7, 9)``, ``[3, 5)``).
+
+Below is a table of the partition indices applied to the ``Values`` tensor to
+construct the logical representation of the example jagged tensor:
+
+.. _fbgemm-gpu.docs.overview.ops.jagged.example.table:
+
+.. list-table::
+   :header-rows: 1
+
+   * - ``Offsets[0]``
+     - ``Offsets[0]`` Range
+     - ``Offsets[0]`` Group
+     - Corresponding ``Offsets[1]``
+     - ``Offsets[1]`` Range
+     - ``Values`` Group
+     - Corresponding ``Values``
+   * - ``[ 0, 4, 6, 8 ]``
+     - ``[0, 4)``
+     - Group 1
+     - ``[ 0, 2, 3, 3, 5 ]``
+     - ``[ 0, 2 )``
+     - Group 1
+     - ``[ [ 1, 2, 3, 4 ], [ 5, 6, 7, 8 ] ]``
+   * -
+     -
+     -
+     -
+     - ``[ 2, 3 )``
+     - Group 2
+     - ``[ [ 1, 2, 3, 4 ] ]``
+   * -
+     -
+     -
+     -
+     - ``[ 3, 3 )``
+     - Group 3
+     - ``[ ]``
+   * -
+     -
+     -
+     -
+     - ``[ 3, 5 )``
+     - Group 4
+     - ``[ [ 1, 2, 3, 4 ], [ 5, 6, 7, 8 ] ]``
+   * -
+     - ``[4, 6)``
+     - Group 2
+     - ``[ 5, 6, 7 ]``
+     - ``[ 5, 6 )``
+     - Group 5
+     - ``[ [ 1, 2, 3, 4 ] ]``
+   * -
+     -
+     -
+     -
+     - ``[ 6, 7 )``
+     - Group 6
+     - ``[ [ 1, 2, 7, 9 ] ]``
+   * -
+     - ``[6, 8)``
+     - Group 3
+     - ``[ 7, 9 ]``
+     - ``[ 7, 9 )``
+     - Group 7
+     - ``[ [ 1, 2, 3, 4 ], [ 8, 8, 9, 6 ] ]``
+
+
+Jagged Tensor Operations
+------------------------
+
+At the current stage, FBGEMM_GPU only supports element-wise addition,
+multiplication, and conversion operations for jagged tensors.
+
+Arithmetic Operations
+~~~~~~~~~~~~~~~~~~~~~
+
+Jagged tensor addition and multiplication work similarly to the
+`Hadamard Product <https://en.wikipedia.org/wiki/Hadamard_product_(matrices)>`__
+and involve only the ``Values`` of the jagged tensors. For example:
+
+.. math::
+
+   \begin{bmatrix}
+   \begin{bmatrix}
+   1. & 2. \\
+   3. & 4. \\
+   \end{bmatrix} \\
+   \begin{bmatrix}
+   5. & 6. \\
+   \end{bmatrix} \\
+   \begin{bmatrix}
+   7. & 8. \\
+   9. & 10. \\
+   11. & 12. \\
+   \end{bmatrix} \\
+   \end{bmatrix}
+   \times
+   \begin{bmatrix}
+   \begin{bmatrix}
+   1. & 2. \\
+   3. & 4. \\
+   \end{bmatrix} \\
+   \begin{bmatrix}
+   5. & 6. \\
+   \end{bmatrix} \\
+   \begin{bmatrix}
+   7. & 8. \\
+   9. & 5. \\
+   2. & 3. \\
+   \end{bmatrix} \\
+   \end{bmatrix}
+   \rightarrow
+   \begin{bmatrix}
+   \begin{bmatrix}
+   1. & 4. \\
+   9. & 16. \\
+   \end{bmatrix} \\
+   \begin{bmatrix}
+   25. & 36. \\
+   \end{bmatrix} \\
+   \begin{bmatrix}
+   49. & 64. \\
+   81. & 50. \\
+   22. & 36. \\
+   \end{bmatrix} \\
+   \end{bmatrix}
+
+As such, arithmetic operations on jagged tensors require the two operands to
+have the same shape. In other words, if we have jagged tensors :math:`A`,
+:math:`X`, :math:`B`, and :math:`C`, where :math:`C = AX + B`, then the
+following properties hold:
+
+.. code:: cpp
+
+   // MaxLengths are the same
+   C.maxlengths == A.maxlengths == X.maxlengths == B.maxlengths
+
+   // Offsets are the same
+   C.offsets == A.offsets == X.offsets == B.offsets
+
+   // Values are elementwise equal to the operations applied
+   C.values[i][j] == A.values[i][j] * X.values[i][j] + B.values[i][j]
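+
+Because the operands share the same ``Offsets`` and ``MaxLengths``, an
+element-wise operation on jagged tensors reduces to an element-wise operation
+on their ``Values`` tensors. The following sketch, using plain PyTorch
+tensors and made-up data (not an FBGEMM_GPU API call), illustrates the
+:math:`C = AX + B` properties above:
+
+.. code:: py
+
+   import torch
+
+   # Shared by A, X, B, and C: 3 sub-tensors of 2, 1, and 3 rows
+   offsets = [torch.tensor([0, 2, 3, 6])]
+   max_lengths = [3]
+
+   a_values = torch.ones(6, 2)
+   x_values = torch.full((6, 2), 2.0)
+   b_values = torch.full((6, 2), 0.5)
+
+   # C = AX + B operates on Values only; Offsets and MaxLengths carry over
+   c_values = a_values * x_values + b_values  # every element equals 2.5
+   c_offsets = offsets
+   c_max_lengths = max_lengths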
+Conversion Operations
+~~~~~~~~~~~~~~~~~~~~~
+
+Jagged to Dense
+^^^^^^^^^^^^^^^
+
+.. image:: JaggedTensorConversion1.png
+
+Conversion of a jagged tensor :math:`J` to the equivalent dense tensor
+:math:`D` starts with an empty dense tensor. The shape of :math:`D` is based
+on ``MaxLengths``, the inner dimension of ``Values``, and the length of
+``Offsets[0]``. The number of dimensions in :math:`D` is:
+
+.. code:: cpp
+
+   rank(D) = len(MaxLengths) + 2
+
+For each dimension in :math:`D`, the dimension size is:
+
+.. code:: cpp
+
+   dim(i) = MaxLengths[i-1] // (0 < i < D.rank-1)
+
+Using the example jagged tensor from
+:ref:`fbgemm-gpu.docs.overview.ops.jagged.example`, ``len(MaxLengths) = 2``, so
+the equivalent dense tensor's rank (number of dimensions) will be ``4``. The
+example jagged tensor has two offset tensors, ``Offsets[0]`` and ``Offsets[1]``.
+During the conversion process, elements from ``Values`` are loaded into the
+dense tensor based on the ranges denoted by the partition indices of
+``Offsets[0]`` and ``Offsets[1]`` (see the
+:ref:`table <fbgemm-gpu.docs.overview.ops.jagged.example.table>` for the
+mapping of the groups to corresponding rows in the dense tensor):
+
+.. image:: JaggedTensorConversion2.png
+
+Some parts of :math:`D` will not have values from :math:`J` loaded into them,
+since not every partition range denoted in ``Offsets[i]`` has a size equal to
+``MaxLengths[i]``. In that case, those parts will be padded with a pad value.
+In the above example, the pad value is ``0``.
+
+Dense to Jagged
+^^^^^^^^^^^^^^^
+
+For conversions from dense to jagged tensors, values in the dense tensor are
+loaded into the jagged tensor's ``Values``. However, the given dense tensor's
+shape may not match the shape implied by ``Offsets``; if one of the dense
+tensor's dimensions is smaller than expected, the jagged tensor cannot read
+from the corresponding dense location. When this happens, the pad value is
+written to the corresponding entries of ``Values`` (see below):
+
+.. image:: JaggedTensorConversion3.png
+
+Combined Arithmetic + Conversion Operations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In some situations, we would like to perform the following operation:
+
+.. code:: cpp
+
+   dense_tensor + jagged_tensor → dense_tensor (or jagged_tensor)
+
+We can break such an operation into two steps (see the sketch after this
+list):
+
+#. **Conversion Operation** - convert from jagged → dense or dense → jagged,
+   depending on the desired format for the target tensor. After conversion,
+   the operand tensors, be it dense or jagged, should have the exact same
+   shapes.
+
+#. **Arithmetic operation** - perform the arithmetic operations as usual for
+   dense or jagged tensors.
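+
+Below is a minimal sketch of this two-step approach for
+``dense + jagged → jagged``, using the ``jagged_to_padded_dense`` and
+``dense_to_jagged`` operators that the ``fbgemm_gpu`` package registers under
+``torch.ops.fbgemm``. The data is made up, and the sketch assumes a working
+FBGEMM_GPU installation:
+
+.. code:: py
+
+   import torch
+   import fbgemm_gpu  # registers the torch.ops.fbgemm operators
+
+   # A jagged tensor with 3 sub-tensors of lengths 2, 0, and 1
+   values = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
+   offsets = [torch.tensor([0, 2, 2, 3])]
+   max_lengths = [2]
+
+   # Jagged -> dense: yields shape (3, 2, 2), padded with 0.0
+   dense = torch.ops.fbgemm.jagged_to_padded_dense(
+       values, offsets, max_lengths, 0.0)
+
+   # Step 1 (conversion): dense -> jagged, reusing the jagged tensor's offsets
+   dense_values, _ = torch.ops.fbgemm.dense_to_jagged(dense, offsets)
+
+   # Step 2 (arithmetic): element-wise addition on the Values tensors
+   out_values = values + dense_values
+
+FBGEMM_GPU also provides fused operators, such as
+``torch.ops.fbgemm.jagged_dense_elementwise_add``, that perform both steps in
+a single call.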
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_bfloat16.cu b/fbgemm_gpu/src/quantize_ops/quantize_bfloat16.cu
index 304271ea9d..241545411c 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_bfloat16.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_bfloat16.cu
@@ -15,6 +15,11 @@ namespace fbgemm_gpu {
 /// @ingroup quantize-ops-cuda
 /// Converts a tensor of `float` values into a tensor of Brain Floating Point
 /// (`bfloat16`) values.
+///
+/// @param input A tensor of `float` values
+///
+/// @return A new tensor with values from the input tensor converted to
+/// `bfloat16`.
 DLL_PUBLIC at::Tensor _float_to_bfloat16_gpu(const at::Tensor& input) {
   at::cuda::OptionalCUDAGuard device_guard;
   device_guard.set_index(input.get_device());
@@ -41,6 +46,10 @@ DLL_PUBLIC at::Tensor _float_to_bfloat16_gpu(const at::Tensor& input) {
 /// @ingroup quantize-ops-cuda
 /// Converts a tensor of Brain Floating Point (`bfloat16`) values into a tensor
 /// of `float` values.
+///
+/// @param input A tensor of `bfloat16` values
+///
+/// @return A new tensor with values from the input tensor converted to `float`.
 DLL_PUBLIC at::Tensor _bfloat16_to_float_gpu(const at::Tensor& input) {
   at::cuda::OptionalCUDAGuard device_guard;
   device_guard.set_index(input.get_device());
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu
index b4602c56cb..229ac4a91e 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu
@@ -325,6 +325,16 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
 }
 
 /// @ingroup quantize-ops-cuda
+/// Converts a tensor of `float` values into a tensor of `fp8` values.
+///
+/// @param input A tensor of `float` values. The dtype can be either
+/// `SparseType::FP32`, `SparseType::FP16`, or `SparseType::BF16`
+/// @param forward
+///
+/// @return A new tensor with values from the input tensor converted to `fp8`.
+///
+/// @throw c10::Error if `input.dtype` is not one of (`SparseType::FP32`,
+/// `SparseType::FP16`, or `SparseType::BF16`)
 DLL_PUBLIC Tensor
 _float_to_FP8rowwise_gpu(const Tensor& input, const bool forward) {
   auto input_type = input.dtype();
@@ -406,6 +416,20 @@ Tensor _FP8rowwise_to_float_gpu_t(
   return output;
 }
 
+/// @ingroup quantize-ops-cuda
+/// Converts a tensor of `fp8` values into a tensor of `float` values.
+///
+/// @param input A tensor of `fp8` values
+/// @param forward
+/// @param output_dtype The target floating point type, specified as integer
+/// representation of `SparseType` enum
+///
+/// @return A new tensor with values from the input tensor converted to
+/// `float` (with `dtype` of either `SparseType::FP32`, `SparseType::FP16`, or
+/// `SparseType::BF16`).
+///
+/// @throw c10::Error if `output_dtype` is not one of (`SparseType::FP32`,
+/// `SparseType::FP16`, or `SparseType::BF16`)
 DLL_PUBLIC at::Tensor _FP8rowwise_to_float_gpu(
     const at::Tensor& input,
     bool forward,
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu
index c6585118b7..9f960c3ecc 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu
@@ -330,15 +330,37 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
 }
 
 /// @ingroup quantize-ops-cuda
+/// Converts a tensor of `float` values into a tensor of fused 8-bit rowwise
+/// values.
+///
+/// @param input A tensor of `float` values
+///
+/// @return A new tensor with values from the input tensor converted to
+/// fused 8-bit rowwise values.
 DLL_PUBLIC Tensor _float_to_fused8bitrowwise_gpu(const Tensor& input) {
   return _float_to_fused8bitrowwise_gpu_t<float>(input);
 }
 
+/// @ingroup quantize-ops-cuda
+/// Converts a tensor of `at::Half` values into a tensor of fused 8-bit rowwise
+/// values.
+///
+/// @param input A tensor of `at::Half` values
+///
+/// @return A new tensor with values from the input tensor converted to
+/// fused 8-bit rowwise values.
 DLL_PUBLIC Tensor _half_to_fused8bitrowwise_gpu(const Tensor& input) {
   return _float_to_fused8bitrowwise_gpu_t<at::Half>(input);
 }
 
 /// @ingroup quantize-ops-cuda
+/// Converts a tensor of `float` (single precision) or `at::Half` values into
+/// a tensor of fused 8-bit rowwise values.
+///
+/// @param input A tensor of `float` or `at::Half` values
+///
+/// @return A new tensor with values from the input tensor converted to
+/// fused 8-bit rowwise values.
 DLL_PUBLIC Tensor
 _single_or_half_precision_to_fused8bitrowwise_gpu(const Tensor& input) {
   Tensor output;
@@ -416,15 +438,42 @@ Tensor _fused8bitrowwise_to_float_gpu_t(const Tensor& input) {
   return output;
 }
 
+/// @ingroup quantize-ops-cuda
+/// Converts a tensor of fused 8-bit rowwise values into a tensor of `float`
+/// values.
+///
+/// @param input A tensor of fused 8-bit rowwise values
+///
+/// @return A new tensor with values from the input tensor converted to `float`.
 DLL_PUBLIC at::Tensor _fused8bitrowwise_to_float_gpu(const at::Tensor& input) {
   return _fused8bitrowwise_to_float_gpu_t<float>(input);
 }
 
+/// @ingroup quantize-ops-cuda
+/// Converts a tensor of fused 8-bit rowwise values into a tensor of `at::Half`
+/// values.
+///
+/// @param input A tensor of fused 8-bit rowwise values
+///
+/// @return A new tensor with values from the input tensor converted to
+/// `at::Half`.
 DLL_PUBLIC at::Tensor _fused8bitrowwise_to_half_gpu(const at::Tensor& input) {
   return _fused8bitrowwise_to_float_gpu_t<at::Half>(input);
 }
 
 /// @ingroup quantize-ops-cuda
+/// Converts a tensor of fused 8-bit rowwise values into a tensor of `float`,
+/// `at::Half`, or `at::BFloat16` values.
+///
+/// @param input A tensor of fused 8-bit rowwise values
+/// @param output_dtype The target floating point type, specified as integer
+/// representation of `SparseType` enum
+///
+/// @return A new tensor with values from the input tensor converted to `float`,
+/// `at::Half`, or `at::BFloat16`.
+///
+/// @throw c10::Error if `output_dtype` is not one of (`SparseType::FP32`,
+/// `SparseType::FP16`, or `SparseType::BF16`)
 DLL_PUBLIC at::Tensor _fused8bitrowwise_to_single_or_half_precision_gpu(
     const at::Tensor& input,
     const int64_t output_dtype) {
@@ -449,6 +498,19 @@ DLL_PUBLIC at::Tensor _fused8bitrowwise_to_single_or_half_precision_gpu(
 }
 
 /// @ingroup quantize-ops-cuda
+/// Converts a tensor of fused 8-bit rowwise values into a tensor of
+/// `at::kFloat` or `at::kHalf` values.
+///
+/// @param input A tensor of fused 8-bit rowwise values
+/// @param D_offsets
+/// @param output_dtype The target floating point type, specified as integer
+/// representation of `SparseType` enum
+///
+/// @return A new tensor with values from the input tensor converted to
+/// `at::kFloat` or `at::kHalf`.
+///
+/// @throw c10::Error if `output_dtype` is not one of (`SparseType::FP32`,
+/// `SparseType::FP16`)
 DLL_PUBLIC at::Tensor _fused8bitrowwise_to_float_mixed_dim_gpu(
     const at::Tensor& input,
     const at::Tensor& D_offsets,
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
index 8b07f80057..bfe9103916 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
@@ -107,7 +107,6 @@ __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
 
 } // namespace
 
-/// @ingroup quantize-ops-cuda
 template <typename input_t>
 Tensor _float_to_fusednbitrowwise_gpu_t(
     const Tensor& input,
@@ -191,7 +190,6 @@ DLL_PUBLIC Tensor _float_or_half_to_fusednbitrowwise_gpu(
   return output;
 }
 
-/// @ingroup quantize-ops-cuda
 template <typename output_t>
 Tensor _fusednbitrowwise_to_float_gpu_t(
     const Tensor& input,
@@ -251,6 +249,7 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
   return output;
 }
 
+/// @ingroup quantize-ops-cuda
 DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_gpu(
     const at::Tensor& input,
     const int64_t bit_rate) {
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_hfp8.cu b/fbgemm_gpu/src/quantize_ops/quantize_hfp8.cu
index 1d3194695a..a05ac038f9 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_hfp8.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_hfp8.cu
@@ -15,6 +15,15 @@ namespace fbgemm_gpu {
 /// @ingroup quantize-ops-cuda
 /// Converts a tensor of `float` values into a tensor of Hybrid 8-bit Floating
 /// Point (`hfp8`) values.
+///
+/// @param input A tensor of `float` values
+/// @param ebits
+/// @param exponent_bias
+/// @param max_pos
+///
+/// @return A new tensor with values from the input tensor converted to `hfp8`.
+///
+/// @throw c10::Error if `ebits <= 0` or `exponent_bias <= 0`
 DLL_PUBLIC at::Tensor _float_to_hfp8_gpu(
     const at::Tensor& input,
     const int64_t ebits,
@@ -45,6 +54,14 @@ DLL_PUBLIC at::Tensor _float_to_hfp8_gpu(
 /// @ingroup quantize-ops-cuda
 /// Converts a tensor of Hybrid 8-bit Floating Point (`hfp8`) values into a
 /// tensor of `float` values.
+///
+/// @param input A tensor of `hfp8` values
+/// @param ebits
+/// @param exponent_bias
+///
+/// @return A new tensor with values from the input tensor converted to `float`.
+///
+/// @throw c10::Error if `ebits <= 0` or `exponent_bias <= 0`
 DLL_PUBLIC at::Tensor _hfp8_to_float_gpu(
     const at::Tensor& input,
     const int64_t ebits,
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_msfp.cu b/fbgemm_gpu/src/quantize_ops/quantize_msfp.cu
index 3e14db8f62..1416a8110d 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_msfp.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_msfp.cu
@@ -111,6 +111,16 @@ __global__ inline void _compute_msfp_shared_exponent_cuda_kernel(
 
 /// @ingroup quantize-ops-cuda
 /// Converts a tensor of `float` values into a tensor of Microsoft Floating
 /// Point (`msfp`) values.
+///
+/// @param input A tensor of `float` values
+/// @param bounding_box_size
+/// @param ebits
+/// @param mbits
+/// @param bias
+/// @param min_pos
+/// @param max_pos
+///
+/// @return A new tensor with values from the input tensor converted to `msfp`.
 DLL_PUBLIC at::Tensor _float_to_msfp_gpu(
     const at::Tensor& input,
     const int64_t bounding_box_size,
@@ -180,6 +190,13 @@ DLL_PUBLIC at::Tensor _float_to_msfp_gpu(
 /// @ingroup quantize-ops-cuda
 /// Converts a tensor of Microsoft Floating Point (`msfp`) values into a tensor
 /// of `float` values.
+///
+/// @param input A tensor of `msfp` values
+/// @param ebits
+/// @param mbits
+/// @param bias
+///
+/// @return A new tensor with values from the input tensor converted to `float`.
 DLL_PUBLIC at::Tensor _msfp_to_float_gpu(
     const at::Tensor& input,
     const int64_t ebits,
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp b/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
index 93106561a0..8f62be6a0e 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
+++ b/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
@@ -49,6 +49,7 @@ Tensor FP8rowwise_to_float_meta(
   }
 }
 
+/// @ingroup quantize-data-meta
 Tensor FloatToFP8RowwiseQuantized_meta(const Tensor& input, bool forward) {
   TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
 
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu
index ca9601255a..7e288f2d04 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu
@@ -395,6 +395,16 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
 }
 
 /// @ingroup quantize-ops-cuda
+/// Converts a tensor of `float` values into a tensor of padded `fp8` rowwise
+/// values.
+///
+/// @param input A tensor of `float` values. The dtype can be either
+/// `SparseType::FP32`, `SparseType::FP16`, or `SparseType::BF16`
+/// @param forward
+/// @param row_dim
+///
+/// @return A new tensor with values from the input tensor converted to padded
+/// `fp8` rowwise values.
 DLL_PUBLIC Tensor _float_to_paddedFP8rowwise_gpu(
     const Tensor& input,
     const bool forward,
@@ -402,6 +412,23 @@ DLL_PUBLIC Tensor _float_to_paddedFP8rowwise_gpu(
   return _float_to_paddedFP8rowwise_gpu_t(input, forward, row_dim);
 }
 
+/// @ingroup quantize-ops-cuda
+/// Converts a tensor of padded `fp8` rowwise values into a tensor of `float`
+/// values.
+///
+/// @param input A tensor of padded `fp8` rowwise values
+/// @param forward
+/// @param row_dim
+/// @param output_last_dim
+/// @param output_dtype The target floating point type, specified as integer
+/// representation of `SparseType` enum
+///
+/// @return A new tensor with values from the input tensor converted to `float`
+/// values.
+/// +/// @throw c10::Error if `output_dtype` is not one of (`SparseType::FP32`, +/// `SparseType::FP16`, `SparseType::BF16`) DLL_PUBLIC at::Tensor _paddedFP8rowwise_to_float_gpu( const at::Tensor& input, const bool forward,