From 5a15342f874adb7dad908dfcf318c02f618778a8 Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Tue, 19 Jul 2022 16:00:22 -0700 Subject: [PATCH 01/14] Fix CUDA error when grad_output is contiguous but address is not 16 bytes aligned (#1212) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1212 As title Reviewed By: divchenko Differential Revision: D37951520 fbshipit-source-id: a2c2ab57bb13ed750e986b0326566a8f0a8ea3ae --- fbgemm_gpu/codegen/embedding_backward_dense_host.cpp | 3 +++ .../codegen/embedding_backward_split_host_template.cpp | 6 ++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp index 4cb6903157..1cb28e77d4 100644 --- a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp @@ -328,6 +328,9 @@ class SplitNoBagLookupFunction_Dense_Op grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) { grad_output = grad_output.contiguous(); } + if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0) { + grad_output = at::empty_like(grad_output).copy_(grad_output); + } auto grad_dev_weights = split_embedding_nobag_backward_codegen_dense_unweighted_exact_cuda( diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp index 9e212ffd24..44867e4612 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp @@ -275,10 +275,12 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op : auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0]; if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0 || - grad_output.stride(1) != 1 || - grad_output.stride(0) % 4 != 0) { + grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) { grad_output = grad_output.contiguous(); } + if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0) { + grad_output = at::empty_like(grad_output).copy_(grad_output); + } {% if not nobag %} if (!indice_weights.defined()) { From c9bbb77480c314602fa33e9ba8cd14304bf69811 Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Wed, 20 Jul 2022 10:52:59 -0700 Subject: [PATCH 02/14] Add pooling mode to device bench (#1194) Summary: This would help us benchmark EmbeddingCollection in torchrec. 
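For context, with pooling disabled the op emits one embedding per looked-up row instead of one pooled vector per bag, so the benchmark's output (and backward gradient) shape changes. A rough sketch of the difference, assuming a uniform embedding dim D and a fixed pooling factor L (illustrative values only, not the benchmark defaults):

```
# Illustrative shapes only; B/T/L/D stand in for the benchmark's flags.
B, T, L, D = 512, 4, 20, 128
pooled_shape = (B, T * D)        # PoolingMode.SUM / MEAN: one vector per (sample, table)
sequence_shape = (B * T * L, D)  # PoolingMode.NONE: one vector per looked-up row
```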
Thank you for your time in reviewing this PR :) Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1194 Reviewed By: jianyuh Differential Revision: D37845758 Pulled By: colin2328 fbshipit-source-id: a65ef76b195ca2cfd56b69d39e2d59ae930edfae --- ...plit_table_batched_embeddings_benchmark.py | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index 6e1d71ab2b..e72320407c 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -91,6 +91,7 @@ def cli() -> None: @click.option("--reuse", default=0.0) @click.option("--row-wise/--no-row-wise", default=True) @click.option("--weighted", is_flag=True, default=False) +@click.option("--pooling", type=str, default="sum") @click.option("--weighted-num-requires-grad", type=int, default=None) @click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value) @click.option("--flush-gpu-cache-size-mb", default=0) @@ -113,6 +114,7 @@ def device( # noqa C901 reuse: float, row_wise: bool, weighted: bool, + pooling: str, weighted_num_requires_grad: Optional[int], bounds_check_mode: int, flush_gpu_cache_size_mb: int, @@ -161,6 +163,17 @@ def device( # noqa C901 else: managed_option = EmbeddingLocation.MANAGED + if pooling is None or pooling == "sum": + pooling = "sum" + pooling_mode = PoolingMode.SUM + do_pooling = True + elif pooling == "mean": + pooling_mode = PoolingMode.MEAN + do_pooling = True + else: # "none" + pooling_mode = PoolingMode.NONE + do_pooling = False + if dense: emb = DenseTableBatchedEmbeddingBagsCodegen( [ @@ -170,6 +183,7 @@ def device( # noqa C901 ) for d in Ds ], + pooling_mode=pooling_mode, use_cpu=not torch.cuda.is_available(), ) else: @@ -191,6 +205,7 @@ def device( # noqa C901 weights_precision=weights_precision, stochastic_rounding=stoc, output_dtype=output_dtype, + pooling_mode=pooling_mode, bounds_check_mode=BoundsCheckMode(bounds_check_mode), ) emb = emb.to(get_device()) @@ -200,6 +215,17 @@ def device( # noqa C901 nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 + output_size_multiplier = output_dtype.bit_rate() / 8.0 + if do_pooling: + read_write_bytes = ( + output_size_multiplier * B * sum(Ds) + param_size_multiplier * B * sum(Ds) * L + ) + else: + read_write_bytes = ( + output_size_multiplier * B * sum(Ds) * L + + param_size_multiplier * B * sum(Ds) * L + ) + logging.info( f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, " f"{nparams * param_size_multiplier / 1.0e9: .2f} GB" @@ -236,7 +262,7 @@ def device( # noqa C901 logging.info( f"Forward, B: {B}, " f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " - f"BW: {param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"T: {time_per_iter * 1.0e6:.0f}us" ) @@ -244,7 +270,10 @@ def device( # noqa C901 # backward bench not representative return - grad_output = torch.randn(B, sum(Ds)).to(get_device()) + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) + else: + grad_output = torch.randn(B * T * L, D).to(get_device()) # backward time_per_iter = benchmark_requests( requests, @@ -258,7 +287,7 @@ def device( # noqa C901 ) logging.info( f"ForwardBackward, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, " - 
f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, " + f"BW: {3 * read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " f"T: {time_per_iter * 1.0e6:.0f}us" ) From 58a3df003df60304742f69f81d3c7f72a635e464 Mon Sep 17 00:00:00 2001 From: Chao Gu Date: Wed, 20 Jul 2022 11:46:29 -0700 Subject: [PATCH 03/14] Use proper type for pooling_mode (#1211) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1211 As title. Reviewed By: jianyuh Differential Revision: D37877585 fbshipit-source-id: 8063c05c5ed03b3c9c76c3a20bfb41db332a6d19 --- fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py index 8862ce4aa0..d1b0139a54 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py @@ -2064,7 +2064,7 @@ def forward( max_float32_D=self.max_float32_D, indices=indices, offsets=offsets, - pooling_mode=self.pooling_mode, + pooling_mode=int(self.pooling_mode), indice_weights=per_sample_weights, output_dtype=self.output_dtype, lxu_cache_weights=self.lxu_cache_weights, From 1870839df2051a05b8a84c69ae95c394846d1b92 Mon Sep 17 00:00:00 2001 From: Jiecao Yu Date: Wed, 20 Jul 2022 13:55:25 -0700 Subject: [PATCH 04/14] Optimize the implementation of torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output() (#1213) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1213 Reviewed By: mjanderson09 Differential Revision: D37446339 fbshipit-source-id: 88486749716d87fe618ba1f2366fb5dbe3b6691f --- fbgemm_gpu/src/jagged_tensor_ops.cu | 68 ++++++++++++++++++++++++++--- 1 file changed, 61 insertions(+), 7 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu index 5623cff2e8..477f07fae0 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops.cu @@ -726,7 +726,7 @@ class JaggedDenseDenseAddJaggedOutputGPUOp AT_DISPATCH_FLOATING_TYPES_AND_HALF( x_values.scalar_type(), - "agged_dense_dense_elementwise_jagged_output_forward", + "jagged_dense_dense_elementwise_jagged_output_forward", [&] { jagged_dense_dense_elementwise_jagged_output_( x_values, @@ -775,6 +775,64 @@ std::tuple> dense_to_jagged( const c10::optional& total_L) { return {DenseToJaggedGPUOp::apply(dense, offsets, total_L)[0], offsets}; } + +class JaggedDenseAddJaggedOutputGPUOp + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const Tensor& x_values, + const std::vector& offsets, + const Tensor& dense) { + ctx->save_for_backward(offsets); + ctx->saved_data["dense_shape"] = dense.sizes(); + + auto output = at::empty_like(x_values); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(dense.get_device()); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + x_values.scalar_type(), + "jagged_dense_elementwise_jagged_output_forward", + [&] { + jagged_dense_elementwise_jagged_output_( + x_values, + offsets, + dense, + output, + [] __device__(scalar_t x, scalar_t y) -> scalar_t { + return x + y; + }); + }); + + return {output}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_outputs) { + auto offsets = ctx->get_saved_variables(); + auto dense_shape = 
ctx->saved_data["dense_shape"].toIntVector(); + TORCH_CHECK(grad_outputs.size() == 1); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(grad_outputs[0].get_device()); + + Tensor dense_values_grad = jagged_to_padded_dense( + grad_outputs[0], + offsets, + std::vector(dense_shape.begin() + 1, dense_shape.end() - 1), + /*padding_value=*/0); + TORCH_CHECK(dense_values_grad.sizes() == dense_shape); + + return { + grad_outputs[0], + torch::autograd::Variable(), // offsets + dense_values_grad}; + } +}; + ///@ingroup jagged-tensor-ops-cuda /// output = x + y where x is jagged, y is dense, and output is jagged std::tuple> @@ -782,12 +840,8 @@ jagged_dense_elementwise_add_jagged_output( const Tensor& x_values, const std::vector& x_offsets, const Tensor& y) { - // Convert to jagged - auto jagged_values = - DenseToJaggedGPUOp::apply(y, x_offsets, c10::optional())[0]; - - // Add jagged_values + x_values -> sum_values - auto sum_values = x_values + jagged_values; + auto sum_values = + JaggedDenseAddJaggedOutputGPUOp::apply(x_values, x_offsets, y)[0]; return {sum_values, x_offsets}; } From 055158625eefc0f1f510ca6e9f0b9ccdfb836b95 Mon Sep 17 00:00:00 2001 From: CodemodService FBSourceBlackLinterBot <> Date: Fri, 22 Jul 2022 02:09:05 -0700 Subject: [PATCH 05/14] Daily `arc lint --take BLACK` Reviewed By: zsol Differential Revision: D38069933 fbshipit-source-id: 006a059e4728595df0a69df2890b38d4331b05e6 --- fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index e72320407c..7b550c21fc 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -218,7 +218,8 @@ def device( # noqa C901 output_size_multiplier = output_dtype.bit_rate() / 8.0 if do_pooling: read_write_bytes = ( - output_size_multiplier * B * sum(Ds) + param_size_multiplier * B * sum(Ds) * L + output_size_multiplier * B * sum(Ds) + + param_size_multiplier * B * sum(Ds) * L ) else: read_write_bytes = ( From 49061a284744423bad17162d45a9bdc56adb4739 Mon Sep 17 00:00:00 2001 From: Leon Gao Date: Fri, 22 Jul 2022 15:35:04 -0700 Subject: [PATCH 06/14] add long type for jagged op. (#1214) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1214 as title. 
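Concretely, this lets the dense-to-jagged conversion dispatch on int64 (Long) values on the GPU. A minimal sketch, assuming an OSS fbgemm_gpu install and the usual dense_to_jagged schema (dense, list of offsets, optional total_L); values are illustrative only:

```
# Sketch only: exercises the new Long dispatch in DenseToJaggedGPUOp.
import torch
import fbgemm_gpu  # noqa: F401  # registers torch.ops.fbgemm.*

offsets = torch.tensor([0, 2, 3, 6], dtype=torch.int64, device="cuda")     # 3 segments
dense = torch.randint(0, 10, (3, 3, 4), dtype=torch.int64, device="cuda")  # (B, max_L, D)
values, _ = torch.ops.fbgemm.dense_to_jagged(dense, [offsets], 6)  # int64 values, shape (6, 4)
```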
Reviewed By: jiaqizhai, brad-mengchi, mjanderson09 Differential Revision: D37978119 fbshipit-source-id: c2c004dfb2e1483f6fbf6a415c9dd58d95599cb0 --- fbgemm_gpu/src/jagged_tensor_ops.cu | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu index 477f07fae0..7eb03b901d 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops.cu @@ -667,8 +667,12 @@ class DenseToJaggedGPUOp at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(dense.get_device()); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - values.scalar_type(), "jagged_dense_add_forward", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::Long, + values.scalar_type(), + "jagged_dense_add_forward", + [&] { jagged_dense_elementwise_jagged_output_( values, offsets, From a6f5488c4c3380aa8dc0e567bf9401dde1922cd5 Mon Sep 17 00:00:00 2001 From: Jianyu Huang Date: Mon, 25 Jul 2022 11:44:18 -0700 Subject: [PATCH 07/14] Extract the quantization utils function in FBGEMM (#1204) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1204 This will be better shared between Trec and HPC. - We will refactor https://www.internalfb.com/code/fbsource/[history]/fbcode/caffe2/torch/fb/hpc/quantized_comms_lib.py to extract the common components in FBGEMM - It's open source so TorchRec can call it from FBGEMM. - Reuse the quantize utils functions and dedup the code. This part of code is landable. Reviewed By: YLGH Differential Revision: D37799807 fbshipit-source-id: ced3c98efd096985db02c287449fa48939fd3da3 --- fbgemm_gpu/fbgemm_gpu/quantize_utils.py | 124 ++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 fbgemm_gpu/fbgemm_gpu/quantize_utils.py diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py new file mode 100644 index 0000000000..308bea6b1b --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging + +import torch + +logger: logging.Logger = logging.getLogger() + +try: + # pyre-ignore[21] + from fbgemm_gpu import open_source # noqa: F401 +except Exception: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + +TORCH_HALF_MIN: float = torch.finfo(torch.float16).min +TORCH_HALF_MAX: float = torch.finfo(torch.float16).max + +TORCH_BFLOAT16_MIN: float = torch.finfo(torch.bfloat16).min +TORCH_BFLOAT16_MAX: float = torch.finfo(torch.bfloat16).max + + +def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor: + return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half() + + +def fp32_to_bf16_with_clamp(tensor: torch.Tensor) -> torch.Tensor: + return torch.clamp(tensor, TORCH_BFLOAT16_MIN, TORCH_BFLOAT16_MAX).bfloat16() + + +def fp32_to_hfp8_with_clamp( + tensor: torch.Tensor, ebits: int = 4, mbits: int = 3, bias: int = 15 +) -> torch.Tensor: + max_pos: float = (2 ** ((1 << ebits) - 2 - bias)) * (2 - 2 ** (-mbits)) + return torch.ops.fbgemm.FloatToHFP8Quantized( + tensor.contiguous(), + ebits, + bias, + max_pos, + ) + + +def fp16_to_fp32(tensor: torch.Tensor) -> torch.Tensor: + return tensor.float() + + +def bf16_to_fp32(tensor: torch.Tensor) -> torch.Tensor: + return tensor.view(torch.bfloat16).float() + + +def hfp8_to_fp32(tensor: torch.Tensor, ebits: int = 4, bias: int = 15) -> torch.Tensor: + return torch.ops.fbgemm.HFP8QuantizedToFloat( + tensor.contiguous().view(torch.uint8), + ebits, + bias, + ) + + +def measure_fp16_quant_error(input_tensor: torch.Tensor) -> None: + # TODO: log to tensorboard + + num_nan_fp32_tensor = torch.numel(input_tensor[torch.isnan(input_tensor)]) + logger.info( + "num NaN in fp32 tensor: {}, ratio: {}.".format( + num_nan_fp32_tensor, num_nan_fp32_tensor / torch.numel(input_tensor) + ) + ) + + logger.info( + "fp32 tensor profile: min: {}, max: {}, min abs:{}, max abs:{}.".format( + torch.min(input_tensor), + torch.max(input_tensor), + torch.min(torch.abs(input_tensor)), + torch.max(torch.abs(input_tensor)), + ) + ) + + fp16_tensor = fp32_to_fp16_with_clamp(input_tensor) + num_nan_fp16_tensor = torch.numel(fp16_tensor[torch.isnan(fp16_tensor)]) + + logger.info( + "num NaN in fp16 tensor: {}, ratio: {}.".format( + num_nan_fp16_tensor, num_nan_fp16_tensor / torch.numel(input_tensor) + ) + ) + + diff = torch.abs(input_tensor - fp16_tensor.float()) + rel_diff = diff / torch.abs(input_tensor) + logger.info( + "fp32_to_fp16 abs error: min={}, max={}, avg={}.".format( + torch.min(diff), torch.max(diff), torch.mean(diff) + ) + ) + + rel_diff_not_nan = rel_diff[torch.logical_not(torch.isnan(rel_diff))] + logger.info( + "fp32_to_fp16 rel error: min={}, max={}, avg={}.".format( + torch.min(rel_diff_not_nan), + torch.max(rel_diff_not_nan), + torch.mean(rel_diff_not_nan), + ) + ) + + rel_diff_1_idx = torch.where(rel_diff == 1.0) + fp32_rel_err_1_vals = input_tensor[rel_diff_1_idx] + if torch.numel(fp32_rel_err_1_vals) > 0: + fp32_rel_err_1_vals = torch.abs(fp32_rel_err_1_vals) + logger.info( + "fp32_to_fp16 rel error == 1: fp32 min:{}, fp32 max:{}, fp32 avg:{}.".format( + torch.min(fp32_rel_err_1_vals), + torch.max(fp32_rel_err_1_vals), + torch.mean(fp32_rel_err_1_vals), + ) + ) + + subrange_ratio = torch.numel(fp16_tensor[rel_diff_1_idx]) / torch.numel( + fp16_tensor + ) + logger.info("sub fp16 range ratio: {}".format(subrange_ratio)) From 60a5c2b6eb83cf847949c2dbb7553c8162dc337a Mon Sep 17 00:00:00 2001 From: Ying Liu Date: Mon, 25 Jul 2022 
19:07:25 -0700 Subject: [PATCH 08/14] Added quantization codecs to utils (#1219) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1219 This provides a generic codec interface that can be used for quantized comms. Reviewed By: colin2328, jianyuh Differential Revision: D38125284 fbshipit-source-id: 12a3a1e05e2893e1931e4782c47a75aab9840e52 --- fbgemm_gpu/fbgemm_gpu/quantize_comm.py | 110 +++++++++++++++++++++++++ fbgemm_gpu/test/quantize_comm_test.py | 67 +++++++++++++++ 2 files changed, 177 insertions(+) create mode 100644 fbgemm_gpu/fbgemm_gpu/quantize_comm.py create mode 100644 fbgemm_gpu/test/quantize_comm_test.py diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py new file mode 100644 index 0000000000..871fcb365c --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from typing import Optional + +import torch + +from fbgemm_gpu.quantize_utils import ( + bf16_to_fp32, + fp16_to_fp32, + fp32_to_bf16_with_clamp, + fp32_to_fp16_with_clamp, + fp32_to_hfp8_with_clamp, + hfp8_to_fp32, +) +from fbgemm_gpu.split_embedding_configs import SparseType +from torch.autograd.profiler import record_function + +logger: logging.Logger = logging.getLogger() + + +def _quantize_tensor( + input_tensor: torch.Tensor, + comm_precision: SparseType, +) -> torch.Tensor: + if comm_precision == SparseType.FP32: + return input_tensor + elif comm_precision == SparseType.FP16: + return fp32_to_fp16_with_clamp(input_tensor) + elif comm_precision == SparseType.BF16: + return fp32_to_bf16_with_clamp(input_tensor) + elif comm_precision == SparseType.FP8: + return fp32_to_hfp8_with_clamp(input_tensor) + else: + raise ValueError(f"comm_precision={comm_precision} is not supported") + + +def _dequantize_tensor( + quantized_tensor: torch.Tensor, + comm_precision: SparseType, +) -> torch.Tensor: + if comm_precision == SparseType.FP32: + assert quantized_tensor.dtype == torch.float + return quantized_tensor + elif comm_precision == SparseType.FP16: + assert quantized_tensor.dtype == torch.half + return fp16_to_fp32(quantized_tensor) + elif comm_precision == SparseType.BF16: + assert quantized_tensor.dtype == torch.bfloat16 + return bf16_to_fp32(quantized_tensor) + elif comm_precision == SparseType.FP8: + assert quantized_tensor.dtype == torch.uint8 + return hfp8_to_fp32(quantized_tensor) + else: + raise ValueError(f"comm_precision={comm_precision} is not supported") + + +class QuantizedCommCodec: + def __init__( + self, + comm_precision: SparseType, + loss_scale: Optional[float] = None, + ) -> None: + + if loss_scale is not None: + if comm_precision not in [SparseType.FP16, SparseType.BF16]: + logger.warning( + f"Setting loss scale for comm_precision={comm_precision} is not supported. 
Overriding to None" + ) + loss_scale = None + + logger.info( + f"Creating QuantizedCommsCodec comm_precision:{comm_precision}, loss_scale:{loss_scale}" + ) + + self._comm_precision = comm_precision + self._loss_scale = loss_scale + + def encode(self, input_tensor: torch.Tensor) -> torch.Tensor: + if self._loss_scale is not None: + input_tensor = self._loss_scale * input_tensor + with record_function( + f"## encoder {self._comm_precision} {self._loss_scale} ##" + ): + return _quantize_tensor(input_tensor, self._comm_precision) + + def decode(self, input_grad: torch.Tensor) -> torch.Tensor: + if self._loss_scale is not None: + input_grad = input_grad / self._loss_scale + with record_function( + f"## decoder {self._comm_precision} {self._loss_scale} ##" + ): + dequantized_tensor = _dequantize_tensor(input_grad, self._comm_precision) + return dequantized_tensor + + @property + def quantized_dtype(self) -> torch.dtype: + if self._comm_precision == SparseType.FP16: + return torch.half + elif self._comm_precision == SparseType.BF16: + return torch.bfloat16 + elif self._comm_precision == SparseType.FP8: + return torch.uint8 + return torch.float diff --git a/fbgemm_gpu/test/quantize_comm_test.py b/fbgemm_gpu/test/quantize_comm_test.py new file mode 100644 index 0000000000..df024d645e --- /dev/null +++ b/fbgemm_gpu/test/quantize_comm_test.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Optional, Tuple + +import hypothesis.strategies as st +import torch +from fbgemm_gpu.quantize_comm import QuantizedCommCodec +from fbgemm_gpu.split_embedding_configs import SparseType +from hypothesis import assume, given, settings + + +class QuantizedCommCodecTest(unittest.TestCase): + @settings(deadline=2000) + # pyre-ignore + @given( + comm_precisions_loss_scale=st.sampled_from( + [ + (SparseType.FP32, None), + (SparseType.FP16, None), + (SparseType.FP16, 4.0), + (SparseType.BF16, None), + (SparseType.BF16, 2.0), + (SparseType.FP8, None), + (SparseType.FP8, 3.0), + ] + ), + row_size=st.integers(4, 256), + col_size=st.integers(4, 256), + rand_seed=st.integers(0, 65534), + ) + def test_quantized_comm_codec( + self, + comm_precisions_loss_scale: Tuple[SparseType, Optional[float]], + row_size: int, + col_size: int, + rand_seed: int, + ) -> None: + + (comm_precision, loss_scale) = comm_precisions_loss_scale + if comm_precision == SparseType.FP8: + assume(col_size % 4 == 0) + + torch.manual_seed(rand_seed) + shape = (row_size, col_size) + + quant_codec = QuantizedCommCodec(comm_precision, loss_scale) + + input_tensor = torch.rand(shape, requires_grad=True) + + quant_tensor = quant_codec.encode(input_tensor) + output_tensor = quant_codec.decode(quant_tensor) + + rtol = 0.005 + atol = 0.005 + if comm_precision == SparseType.FP8: + rtol = 0.05 + atol = 0.05 + + torch.testing.assert_close( + input_tensor.detach(), output_tensor.detach(), rtol=rtol, atol=atol + ) From 51ea00e576090c945b1481a1639be48db80a5636 Mon Sep 17 00:00:00 2001 From: Shintaro Iwasaki Date: Mon, 25 Jul 2022 21:59:11 -0700 Subject: [PATCH 09/14] add --no-dkms to install ROCm on CI (#1220) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1220 This patch fixes a failure of `build_amd_gpu` on GitHub CI. 
## Background We start to encounter the following error on `build_amd_gpu` recently (around July 11th). ``` Building initial module for 5.15.0-1014-azure Error! Bad return status for module build on kernel: 5.15.0-1014-azure (x86_64) Consult /var/lib/dkms/amdgpu/5.13.20.22.10-1401700/build/make.log for more information. dpkg: error processing package amdgpu-dkms (--configure): installed amdgpu-dkms package post-installation script subprocess returned error exit status 10 ``` ## Solution I did not perform any deep investigation, but basically, we don't need dkms for this test. This patch disables it. Reviewed By: sryap Differential Revision: D38133302 fbshipit-source-id: ade20dcec658ac186d8f8ad1b713d339e504ba6f --- .github/workflows/fbgemmci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 9595734852..250544bed5 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -267,7 +267,7 @@ jobs: wget https://repo.radeon.com/amdgpu-install/22.10.1/ubuntu/focal/amdgpu-install_22.10.1.50101-1_all.deb export DEBIAN_FRONTEND=noninteractive sudo apt install -y ./amdgpu-install_22.10.1.50101-1_all.deb - amdgpu-install -y --usecase=hiplibsdk,rocm + amdgpu-install -y --usecase=hiplibsdk,rocm --no-dkms sudo rm amdgpu-install_22.10.1.50101-1_all.deb - name: Install dependencies From 308dd51d63702196259a637d0adede371c2f1cf5 Mon Sep 17 00:00:00 2001 From: Jianyu Huang Date: Tue, 26 Jul 2022 02:59:17 -0700 Subject: [PATCH 10/14] refactor grad output non-contiguous handler (#1215) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1215 This is a follow-up on D37951520 (https://github.com/pytorch/FBGEMM/commit/5a15342f874adb7dad908dfcf318c02f618778a8) - Minor clean-up and refactoring for non-contiguous grad output: avoid do .contiguous and clone twice when alignment is not multiple of 16. - Add more comments. - Add unit test coverage TODO: add the 16 alignment unit test coverage. 
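One way to cover that TODO is to feed backward a gradient that is contiguous but deliberately misaligned, e.g. a one-element slice into a flat float32 buffer; a sketch with illustrative names (not the actual test code):

```
# Contiguous grad_output whose data_ptr() is not a multiple of 16: the base CUDA
# allocation is at least 16-byte aligned, so slicing off one float32 element
# shifts the start by 4 bytes.
import torch

B, D = 4, 10
buf = torch.empty(B * D + 1, dtype=torch.float32, device="cuda")
grad = buf[1:].view(B, D)
assert grad.is_contiguous()
assert grad.data_ptr() % 16 != 0
# Passing `grad` to output.backward() should then hit the empty_like().copy_()
# branch added in this diff.
```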
Reviewed By: brad-mengchi Differential Revision: D37988742 fbshipit-source-id: ce8b7f9658c9e6229d67b420212bb3d55daea8c8 --- .../codegen/embedding_backward_dense_host.cpp | 16 ++++++++++------ .../embedding_backward_split_host_template.cpp | 8 ++++---- .../test/split_table_batched_embeddings_test.py | 12 ++++++------ 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp index 1cb28e77d4..7117db7509 100644 --- a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp @@ -174,8 +174,12 @@ class SplitLookupFunction_Dense_Op using torch::autograd::Variable; auto grad_output = grad_outputs[0]; - if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0 || - grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) { + + // FIXME: to support aligned memory access in Vec4T load/store function + // 16 for FP32 and 8 for FP16 + if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0) { + grad_output = at::empty_like(grad_output).copy_(grad_output); + } else if (!grad_output.is_contiguous()) { grad_output = grad_output.contiguous(); } @@ -324,12 +328,12 @@ class SplitNoBagLookupFunction_Dense_Op using torch::autograd::Variable; auto grad_output = grad_outputs[0]; - if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0 || - grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) { - grad_output = grad_output.contiguous(); - } + // FIXME: to support aligned memory access in Vec4T load/store function + // 16 for FP32 and 8 for FP16 if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0) { grad_output = at::empty_like(grad_output).copy_(grad_output); + } else if (!grad_output.is_contiguous()) { + grad_output = grad_output.contiguous(); } auto grad_dev_weights = diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp index 44867e4612..70fcd1f547 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp @@ -274,12 +274,12 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op : using torch::autograd::Variable; auto grad_output = gradient_clipping ? 
clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0]; - if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0 || - grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) { - grad_output = grad_output.contiguous(); - } + // FIXME: to support aligned memory access in Vec4T load/store function + // 16 for FP32 and 8 for FP16 if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0) { grad_output = at::empty_like(grad_output).copy_(grad_output); + } else if (!grad_output.is_contiguous()) { + grad_output = grad_output.contiguous(); } {% if not nobag %} diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index d153a9170d..f6bde3d7cf 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -1324,9 +1324,9 @@ def test_backward_dense( rtol=5.0e-3 if weights_precision == SparseType.FP16 else 1.0e-5, ) if do_pooling: - goc = torch.cat([go.view(B, -1) for go in gos], dim=1).contiguous() + goc = torch.cat([go.view(B, -1) for go in gos], dim=1) else: - goc = torch.cat(gos, dim=0).contiguous() + goc = torch.cat(gos, dim=0) fc2.backward(goc) torch.testing.assert_close( cc.weights.grad, @@ -1584,9 +1584,9 @@ def test_backward_sgd( # noqa C901 else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu)) ) if do_pooling: - goc = torch.cat([go.view(B, -1) for go in gos], dim=1).contiguous() + goc = torch.cat([go.view(B, -1) for go in gos], dim=1) else: - goc = torch.cat(gos, dim=0).contiguous() + goc = torch.cat(gos, dim=0) fc2.backward(goc) if use_cache: cc.flush() @@ -1817,7 +1817,7 @@ def execute_backward_adagrad_( # noqa C901 if do_pooling: goc = torch.cat([go.view(B, -1) for go in gos], dim=1) else: - goc = torch.cat(gos, dim=0).contiguous() + goc = torch.cat(gos, dim=0) fc2.backward(goc) cc.flush() split_optimizer_states = [s for (s,) in cc.split_optimizer_states()] @@ -2637,7 +2637,7 @@ def execute_backward_optimizers_( # noqa C901 if do_pooling: goc = torch.cat([go.view(B, -1) for go in gos], dim=1) else: - goc = torch.cat(gos, dim=0).contiguous() + goc = torch.cat(gos, dim=0) fc2.backward(goc) cc.flush() From 09099eac922921e8db01ba24455c953fe064c58a Mon Sep 17 00:00:00 2001 From: Yu Guo Date: Tue, 26 Jul 2022 14:08:56 -0700 Subject: [PATCH 11/14] add optional warmup in benchmark_requests (#1223) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1223 useful if some kernel warmup is desired Reviewed By: sryap Differential Revision: D38141801 fbshipit-source-id: b7efbad77040fa7a5216a14114a975932163ed14 --- fbgemm_gpu/bench/bench_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/fbgemm_gpu/bench/bench_utils.py b/fbgemm_gpu/bench/bench_utils.py index 38436f0a25..6edf200bbc 100644 --- a/fbgemm_gpu/bench/bench_utils.py +++ b/fbgemm_gpu/bench/bench_utils.py @@ -244,8 +244,15 @@ def benchmark_requests( func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor], flush_gpu_cache_size_mb: int = 0, check_median: bool = False, + num_warmups: int = 0, ) -> float: times = [] + + if num_warmups > 0: + indices, offsets, weights = requests[0] + for _ in range(num_warmups): + func(indices, offsets, weights) + if torch.cuda.is_available(): torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) From 949420cb954ce4c7558d16ef9f695385ed813907 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Tue, 26 Jul 2022 14:41:06 -0700 Subject: [PATCH 12/14] Add Int types in 
jagged_index_select (#1218) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1218 Add support for Int types in jagged_index_select and update the test to validate Int types Reviewed By: jianyuh Differential Revision: D38105237 fbshipit-source-id: c9392da8378759bda384abd768a44f0aef5ad1b1 --- fbgemm_gpu/src/jagged_tensor_ops.cu | 14 ++++++--- fbgemm_gpu/test/jagged_tensor_ops_test.py | 35 +++++++++++++++++------ 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu index 7eb03b901d..3ddd1d61d0 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops.cu @@ -1623,8 +1623,11 @@ Tensor jagged_index_select_2d_cuda( at::empty({num_dense_output_rows, num_cols}, values.options()); if (num_blocks > 0) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - values.scalar_type(), "jagged_index_select_2d_kernel_wrapper_1", [&] { + AT_DISPATCH_ALL_TYPES_AND( + at::ScalarType::Half, + values.scalar_type(), + "jagged_index_select_2d_kernel_wrapper_1", + [&] { AT_DISPATCH_INDEX_TYPES( indices.scalar_type(), "jagged_index_select_2d_kernel_wrapper_2", @@ -1715,8 +1718,11 @@ Tensor jagged_index_add_2d_cuda( Tensor output = at::zeros({num_output_rows, num_cols}, grad.options()); if (num_blocks > 0) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "jagged_index_add_2d_kernel_wrapper_1", [&] { + AT_DISPATCH_ALL_TYPES_AND( + at::ScalarType::Half, + grad.scalar_type(), + "jagged_index_add_2d_kernel_wrapper_1", + [&] { AT_DISPATCH_INDEX_TYPES( indices.scalar_type(), "jagged_index_add_2d_kernel_wrapper_2", diff --git a/fbgemm_gpu/test/jagged_tensor_ops_test.py b/fbgemm_gpu/test/jagged_tensor_ops_test.py index b567652f3a..7a67f0e12a 100644 --- a/fbgemm_gpu/test/jagged_tensor_ops_test.py +++ b/fbgemm_gpu/test/jagged_tensor_ops_test.py @@ -972,7 +972,9 @@ def jagged_index_select_2d_ref( num_cols=st.integers(1, 128), num_jagged_tensor_rows=st.integers(1, 128), index_dtype=st.sampled_from([torch.int, torch.long]), - jagged_tensor_dtype=st.sampled_from([torch.float, torch.half]), + jagged_tensor_dtype=st.sampled_from( + [torch.float, torch.half, torch.int, torch.long] + ), ) @settings(max_examples=20, deadline=None) def test_jagged_index_select_2d( @@ -984,6 +986,7 @@ def test_jagged_index_select_2d( index_dtype: torch.dtype, jagged_tensor_dtype: torch.dtype, ) -> None: + is_float = jagged_tensor_dtype in [torch.float, torch.half] lengths = torch.randint( low=0, high=max_seq_length, @@ -1000,21 +1003,35 @@ def test_jagged_index_select_2d( device="cuda", ) ) - values = torch.rand( - int(lengths.sum().item()), - num_cols, - dtype=jagged_tensor_dtype, - device="cuda", - ) + if is_float: + values = torch.rand( + int(lengths.sum().item()), + num_cols, + dtype=jagged_tensor_dtype, + device="cuda", + ) + else: + values = torch.randint( + 2**16, + (int(lengths.sum().item()), num_cols), + dtype=jagged_tensor_dtype, + device="cuda", + ) values_ref = values.detach().clone() - values.requires_grad = True - values_ref.requires_grad = True + + # Only float tensors can require grad + if is_float: + values.requires_grad = True + values_ref.requires_grad = True output, _ = torch.ops.fbgemm.jagged_index_select(values, lengths, indices) output_ref = self.jagged_index_select_2d_ref(values_ref, lengths, indices) assert torch.equal(output, output_ref) + if not is_float: + return + grad = torch.rand_like(output) grad_ref = grad.detach().clone() From cca200ca52240c7cdddbedc358fcdad0178d6f75 Mon Sep 17 
00:00:00 2001 From: Jianyu Huang Date: Wed, 27 Jul 2022 22:04:03 -0700 Subject: [PATCH 13/14] Update asmjit submodule for OSS (#1202) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1202 Update OSS asmjit version to d3fbf7c9bc7c1d1365a94a45614b91c5a3706b81, sync with the third-party/asmjit. Add a -Wno-unused-variable compilation flag to enable Clang 14. Reviewed By: brad-mengchi Differential Revision: D37763232 fbshipit-source-id: 650ee2ba8601dc0922640507f28691a87d275db3 --- CMakeLists.txt | 4 ++++ third_party/asmjit | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ae2c4300df..d8b6b989c3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -297,6 +297,10 @@ endif() if(FBGEMM_BUILD_BENCHMARKS) add_subdirectory(bench) + # add a flag to enable Clang 14 + set_source_files_properties( + bench/GEMMsBenchmark.cc + PROPERTIES COMPILE_FLAGS "-Wno-unused-variable") endif() if(FBGEMM_BUILD_DOCS) diff --git a/third_party/asmjit b/third_party/asmjit index 752eb38a4d..d3fbf7c9bc 160000 --- a/third_party/asmjit +++ b/third_party/asmjit @@ -1 +1 @@ -Subproject commit 752eb38a4dbe590995cbadaff06baadd8378eeeb +Subproject commit d3fbf7c9bc7c1d1365a94a45614b91c5a3706b81 From 48127588bd48347673a3e277015529e1a687fb12 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 28 Jul 2022 13:37:52 -0700 Subject: [PATCH 14/14] Revert D37670913 "[fbgemm] Use std::fma in tail of Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2" (#1222) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1222 Reviewed By: Liang-Dong Differential Revision: D38155592 fbshipit-source-id: 025fc4445029a104fe38e76ad317a2e0dfb6ede7 --- src/QuantUtilsAvx2.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index 9616f21624..b6dfbba601 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -2164,8 +2164,8 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2( } for (; col < output_columns; ++col) { - float output_value = std::fma( - input_row[col], input_row_scale_bias[0], input_row_scale_bias[1]); + float output_value = + input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1]; if (std::is_same()) { output_row[col] = output_value; } else {