From dd5a3509e2d0161df244f21757d5fcdace5c8b85 Mon Sep 17 00:00:00 2001
From: Siyan Lin
Date: Thu, 2 Nov 2023 16:17:58 -0700
Subject: [PATCH] Add reverse qparam option for MTIA

Summary: Add a reverse_qparam option to IntNBitTableBatchedEmbeddingBagsCodegen for MTIA: when set, the quantization parameters (scale and bias) are stored at the end of each embedding row instead of at the beginning.

Differential Revision: D50945909

Privacy Context Container: L1188860
---
 ..._table_batched_embeddings_ops_inference.py | 77 ++++++++++++++-----
 1 file changed, 57 insertions(+), 20 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
index 261ce7c74e..301dad837a 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -169,6 +169,7 @@ def __init__(  # noqa C901
         scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
         cacheline_alignment: bool = True,
         uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
+        reverse_qparam: bool = False,  # True to load qparams at end of each row; False to load qparams at beginning of each row.
     ) -> None:  # noqa C901
         # tuple of (rows, dims,)
         super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__()
@@ -324,6 +325,7 @@ def max_ty_D(ty: SparseType) -> int:
         self.initialize_physical_weights_placements_and_offsets(cacheline_alignment)
         self.enforce_hbm: bool = enforce_hbm
+        self.reverse_qparam = reverse_qparam
 
         # Assign weights after weights and weights_offsets are initialized.
         if weight_lists:
             self._apply_split(
@@ -1197,28 +1199,63 @@ def split_embedding_weights_with_scale_bias(
                 or weight_ty.value == SparseType.INT2.value
             ):
                 if split_scale_bias_mode == 1:
-                    splits.append(
-                        (
-                            weights_shifts[:, self.scale_bias_size_in_bytes :],
-                            weights_shifts[:, : self.scale_bias_size_in_bytes],
-                            None,
+                    if self.reverse_qparam:
+                        splits.append(
+                            (
+                                weights_shifts[
+                                    :, 0 : (0 - self.scale_bias_size_in_bytes)
+                                ],
+                                weights_shifts[
+                                    :, (0 - self.scale_bias_size_in_bytes) :
+                                ],
+                                None,
+                            )
                         )
-                    )
-                else:  # 2
-                    # weights_shifts: [0:2] is scale; [2:4] is bias; [4:] is real weights
-                    splits.append(
-                        (
-                            weights_shifts[:, self.scale_bias_size_in_bytes :],
-                            weights_shifts[
-                                :, : self.scale_bias_size_in_bytes // 2
-                            ].view(torch.float16),
-                            weights_shifts[
-                                :,
-                                self.scale_bias_size_in_bytes
-                                // 2 : self.scale_bias_size_in_bytes,
-                            ].view(torch.float16),
+                    else:
+                        splits.append(
+                            (
+                                weights_shifts[:, self.scale_bias_size_in_bytes :],
+                                weights_shifts[:, : self.scale_bias_size_in_bytes],
+                                None,
+                            )
                         )
-                    )
+                elif split_scale_bias_mode == 2:
+                    if self.reverse_qparam:
+                        # weights_shifts: [0:-4] is real weights; [-4:-2] is scale; [-2:] is bias
+                        splits.append(
+                            (
+                                weights_shifts[
+                                    :, 0 : (0 - self.scale_bias_size_in_bytes)
+                                ],
+                                weights_shifts[
+                                    :,
+                                    (0 - self.scale_bias_size_in_bytes) : (
+                                        0 - self.scale_bias_size_in_bytes // 2
+                                    ),
+                                ].view(torch.float16),
+                                weights_shifts[
+                                    :, (0 - self.scale_bias_size_in_bytes // 2) :
+                                ].view(torch.float16),
+                            )
+                        )
+                    else:
+                        # weights_shifts: [0:2] is scale; [2:4] is bias; [4:] is real weights
+                        splits.append(
+                            (
+                                weights_shifts[:, self.scale_bias_size_in_bytes :],
+                                weights_shifts[
+                                    :, : self.scale_bias_size_in_bytes // 2
+                                ].view(torch.float16),
+                                weights_shifts[
+                                    :,
+                                    self.scale_bias_size_in_bytes
+                                    // 2 : self.scale_bias_size_in_bytes,
+                                ].view(torch.float16),
+                            )
+                        )
+                else:
+                    raise ValueError("split_scale_bias_mode is not supported")
+
             elif (
                 weight_ty.value == SparseType.FP8.value
                 or weight_ty.value == SparseType.FP16.value
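
Note for reviewers (text below the diff is ignored by git am): a minimal standalone sketch of the two row layouts the new flag selects between, assuming split_scale_bias_mode == 2 and a scale_bias_size_in_bytes of 4 (fp16 scale + fp16 bias). The split_row helper here is hypothetical and only mirrors the patched slicing; it is not part of the FBGEMM API.

import torch

SCALE_BIAS_SIZE_IN_BYTES = 4  # fp16 scale (2 bytes) + fp16 bias (2 bytes)

def split_row(weights_shifts: torch.Tensor, reverse_qparam: bool):
    # Split fused uint8 rows into (weights, scale, bias) views, mirroring
    # the split_scale_bias_mode == 2 branch of the patched method.
    half = SCALE_BIAS_SIZE_IN_BYTES // 2
    if reverse_qparam:
        # Row layout: [0:-4] is weights; [-4:-2] is scale; [-2:] is bias.
        return (
            weights_shifts[:, : -SCALE_BIAS_SIZE_IN_BYTES],
            weights_shifts[:, -SCALE_BIAS_SIZE_IN_BYTES : -half].view(torch.float16),
            weights_shifts[:, -half:].view(torch.float16),
        )
    # Row layout: [0:2] is scale; [2:4] is bias; [4:] is weights.
    return (
        weights_shifts[:, SCALE_BIAS_SIZE_IN_BYTES:],
        weights_shifts[:, :half].view(torch.float16),
        weights_shifts[:, half:SCALE_BIAS_SIZE_IN_BYTES].view(torch.float16),
    )

rows = torch.zeros(3, 12, dtype=torch.uint8)  # 3 rows: 8 weight bytes + 4 qparam bytes
w, s, b = split_row(rows, reverse_qparam=True)
print(w.shape, s.shape, b.shape)  # torch.Size([3, 8]) torch.Size([3, 1]) torch.Size([3, 1])

In both layouts the returned tensors are views aliasing the same underlying storage; only the byte offsets of the weight and qparam regions differ.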