diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
index 9c6d1b5b19..2cf6301150 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -1137,78 +1137,6 @@ def reset_cache_states(self) -> None:
         self.lxu_state.fill_(0)
         self.timestep_counter.reset()
 
-    @torch.jit.export
-    def split_embedding_weights(
-        self,
-        split_scale_shifts: bool = True
-        # When true, return list of two tensors, the first with weights and
-        # the second with scale_bias.
-        # This should've been named as split_scale_bias.
-        # Keep as is for backward compatibility.
-    ) -> List[Tuple[Tensor, Optional[Tensor]]]:
-        """
-        Returns a list of weights, split by table
-        """
-        assert self.weight_initialized
-        splits: List[Tuple[Tensor, Optional[Tensor]]] = []
-        for t, (_, rows, dim, weight_ty, _) in enumerate(self.embedding_specs):
-            placement = self.weights_physical_placements[t]
-            if placement == EmbeddingLocation.DEVICE.value:
-                weights = self.weights_dev
-            elif placement == EmbeddingLocation.HOST.value:
-                weights = self.weights_host
-            else:
-                weights = self.weights_uvm
-            offset = self.weights_physical_offsets[t]
-            weights_shifts = weights.detach()[
-                offset : offset
-                + rows
-                * rounded_row_size_in_bytes(
-                    dim, weight_ty, self.row_alignment, self.scale_bias_size_in_bytes
-                )
-            ].view(
-                rows,
-                rounded_row_size_in_bytes(
-                    dim, weight_ty, self.row_alignment, self.scale_bias_size_in_bytes
-                ),
-            )
-
-            if split_scale_shifts:
-                # remove the padding at the end of each row.
-                weights_shifts = weights_shifts[
-                    :,
-                    : unpadded_row_size_in_bytes(
-                        dim, weight_ty, self.scale_bias_size_in_bytes
-                    ),
-                ]
-                if (
-                    weight_ty == SparseType.INT8
-                    or weight_ty == SparseType.INT4
-                    or weight_ty == SparseType.INT2
-                ):
-                    splits.append(
-                        (
-                            weights_shifts[:, self.scale_bias_size_in_bytes :],
-                            weights_shifts[:, : self.scale_bias_size_in_bytes],
-                        )
-                    )
-                else:
-                    assert (
-                        weight_ty == SparseType.FP8
-                        or weight_ty == SparseType.FP16
-                        or weight_ty == SparseType.FP32
-                    )
-                    splits.append(
-                        (
-                            weights_shifts,
-                            None,
-                        )
-                    )
-            else:
-                splits.append((weights_shifts, None))
-
-        return splits
-
     @torch.jit.export
     def split_embedding_weights_with_scale_bias(
         self, split_scale_bias_mode: int = 1
@@ -1300,6 +1228,28 @@ def split_embedding_weights_with_scale_bias(
 
         return splits
 
+    @torch.jit.export
+    def split_embedding_weights(
+        self,
+        split_scale_shifts: bool = True
+        # When true, return list of two tensors, the first with weights and
+        # the second with scale_bias.
+        # This should've been named as split_scale_bias.
+        # Keep as is for backward compatibility.
+    ) -> List[Tuple[Tensor, Optional[Tensor]]]:
+        """
+        Returns a list of weights, split by table
+        """
+        splits: List[
+            Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
+        ] = self.split_embedding_weights_with_scale_bias(
+            split_scale_bias_mode=(1 if split_scale_shifts else 0)
+        )
+        return [
+            (split_weight_scale_bias[0], split_weight_scale_bias[1])
+            for split_weight_scale_bias in splits
+        ]
+
     @torch.jit.export
     def initialize_weights(self) -> None:
         if not self.weight_initialized:
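
Note for reviewers: below is a minimal, self-contained sketch (not FBGEMM code; every name in it is an illustrative stand-in) of the delegation this patch introduces. The legacy split_embedding_weights() becomes a thin wrapper over split_embedding_weights_with_scale_bias(): the old boolean flag maps onto the newer integer mode (True -> 1, False -> 0), and the 3-tuples it returns are narrowed back to the legacy 2-tuple shape.

    from typing import List, Optional, Tuple

    Row = List[float]  # stand-in for torch.Tensor in this sketch


    class EmbeddingTableSketch:
        """Toy stand-in for the table-batched embedding module."""

        def __init__(self) -> None:
            # One toy "table": a quantized weight row plus its fused
            # scale/bias bytes.
            self._weights: Row = [0.5, 0.25]
            self._scale_bias: Row = [1.0, 0.0]

        def split_embedding_weights_with_scale_bias(
            self, split_scale_bias_mode: int = 1
        ) -> List[Tuple[Row, Optional[Row], Optional[Row]]]:
            if split_scale_bias_mode == 0:
                # Mode 0: raw rows, scale/bias left fused into the weights.
                return [(self._scale_bias + self._weights, None, None)]
            # Mode 1: weights in slot 0, fused scale+bias in slot 1,
            # slot 2 unused. (The real method also supports a mode that
            # splits scale and bias into separate tensors.)
            return [(self._weights, self._scale_bias, None)]

        def split_embedding_weights(
            self, split_scale_shifts: bool = True
        ) -> List[Tuple[Row, Optional[Row]]]:
            # Legacy bool -> newer int mode, exactly as in the patch above.
            splits = self.split_embedding_weights_with_scale_bias(
                split_scale_bias_mode=(1 if split_scale_shifts else 0)
            )
            # Drop the third element to keep the legacy 2-tuple return shape.
            return [(w, sb) for (w, sb, _bias) in splits]


    table = EmbeddingTableSketch()
    assert table.split_embedding_weights(True) == [([0.5, 0.25], [1.0, 0.0])]
    assert table.split_embedding_weights(False) == [([1.0, 0.0, 0.5, 0.25], None)]

One behavioral consequence worth flagging: the old inline implementation asserted self.weight_initialized itself, so after this change the legacy entry point relies on split_embedding_weights_with_scale_bias() to perform that check.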