From 63b9fecdb535c027f32ef11b1743a446e2a8d940 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Wed, 21 Aug 2024 18:26:07 -0700 Subject: [PATCH] Add an input debug function in TBE training (#3022) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/117 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3022 This diff adds a function for printing input stats including weighted/unweighted, number of features, batch size, average pooling factor, total number of indices, number of unique indices, and number of indices that goes through the different backward functions. This function is intended for debugging only. It can enabled during runtime by setting an enviroment variable FBGEMM_DEBUG_PRINT_INPUT_STATS=1. Reviewed By: dshi7 Differential Revision: D61634512 --- ...t_table_batched_embeddings_ops_training.py | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 013dbb48da..c8cbadd46d 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1107,6 +1107,11 @@ def __init__( # noqa C901 self.is_experimental: bool = is_experimental + # Get a debug function pointer + self._debug_print_input_stats: Callable[..., None] = ( + self._debug_print_input_stats_factory() + ) + @torch.jit.ignore def log(self, msg: str) -> None: """Log with TBE id prefix to distinguish between multiple TBE instances per process.""" @@ -1361,6 +1366,9 @@ def forward( # noqa: C901 force_cast_input_types=True, ) + # Print input stats if enable (for debugging purpose only) + self._debug_print_input_stats(indices, offsets, per_sample_weights) + if not is_torchdynamo_compiling(): # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time @@ -2656,6 +2664,132 @@ def prepare_inputs( return indices, offsets, per_sample_weights, vbe_metadata + def _debug_print_input_stats_factory(self) -> Callable[..., None]: + """ + If the environment variable FBGEMM_DEBUG_PRINT_INPUT_STATS=1, + return a function pointer of a function that prints input + stats including weighted/unweighted, number of features, + batch size, average pooling factor, total number of indices, + number of unique indices, and number of indices that goes + through the different backward functions. Otherwise, return + a dummy function pointer. + """ + + @torch.jit.ignore + def _debug_print_input_stats_factory_impl( + indices: Tensor, + offsets: Tensor, + per_sample_weights: Optional[Tensor] = None, + ) -> None: + """ + Print input stats (for debugging purpose only) + + Args: + indices (Tensor): Input indices + offsets (Tensor): Input offsets + per_sample_weights (Optional[Tensor]): Input per + sample weights + """ + if self.debug_step % 100 == 0: + # Get number of features (T) and batch size (B) + T = len(self.feature_table_map) + B = (offsets.numel() - 1) // T + + # Transfer hash_size_cumsum, indices and offsets to CPU + hash_size_cumsum_cpu = self.hash_size_cumsum.cpu() + indices_cpu = indices.cpu() + offsets_cpu = offsets.cpu() + + # Compute linear indices + for t in range(T): + start = offsets_cpu[B * t].item() + end = offsets_cpu[B * (t + 1)].item() + indices_cpu[start:end] += hash_size_cumsum_cpu[t] + + # Compute unique indices + uniq_indices_cpu, counts = indices_cpu.unique(return_counts=True) + + # Compute num unique indices + num_uniq_indices = uniq_indices_cpu.numel() + + # The warp_per_row kernel handles indices that their + # segment lengths <= 32 + # + # The cta_per_row kernel handles indices that their + # segment lengths > 32. A single thread block is used + # if segment lengths <= 1024. Otherwise, multiple + # thread blocks are used. + # + # Counts of indices that segment lengths <= 32 + counts_warp_per_row = counts[counts <= 32] + counts_cta_per_row = counts[counts > 32] + # Counts of indices that segment lengths > 32 and <= 1024 + counts_cta_per_row_sth = counts_cta_per_row[counts_cta_per_row <= 1024] + # Counts of indices that segment lengths > 1024 + counts_cta_per_row_mth = counts_cta_per_row[counts_cta_per_row > 1024] + + def compute_numel_and_avg(counts: Tensor) -> Tuple[int, float]: + numel = counts.numel() + avg = (counts.sum().item() / numel) if numel != 0 else -1.0 + return numel, avg + + # warp_per_row stats + num_warp_per_row, avg_seglen_warp_per_row = compute_numel_and_avg( + counts_warp_per_row + ) + # cta_per_row using a single thread block stats + num_cta_per_row_sth, avg_seglen_cta_per_row_sth = compute_numel_and_avg( + counts_cta_per_row_sth + ) + # cta_per_row using multiple thread block stats + num_cta_per_row_mth, avg_seglen_cta_per_row_mth = compute_numel_and_avg( + counts_cta_per_row_mth + ) + + assert num_uniq_indices == ( + num_warp_per_row + num_cta_per_row_sth + num_cta_per_row_mth + ) + + self.log( + "TBE_DEBUG: " + "weighted {} " + "num features {} " + "batch size {} " + "avg pooling factor {:.2f} " + "total num indices {} " + "num unique indices {} " + "num warp_per_row {} (avg segment length {:.2f}) " + "num cta_per_row single thread block (avg segment length) {} ({:.2f}) " + "num cta_per_row multiple thread blocks (avg segment length) {} ({:.2f})".format( + per_sample_weights is not None, + T, + B, + indices.numel() / (B * T), + indices.numel(), + num_uniq_indices, + num_warp_per_row, + avg_seglen_warp_per_row, + num_cta_per_row_sth, + avg_seglen_cta_per_row_sth, + num_cta_per_row_mth, + avg_seglen_cta_per_row_mth, + ) + ) + self.debug_step += 1 + + @torch.jit.ignore + def _debug_print_input_stats_factory_null( + indices: Tensor, + offsets: Tensor, + per_sample_weights: Optional[Tensor] = None, + ) -> None: + pass + + if int(os.environ.get("FBGEMM_DEBUG_PRINT_INPUT_STATS", "0")) == 1: + self.debug_step = 0 + return _debug_print_input_stats_factory_impl + return _debug_print_input_stats_factory_null + class DenseTableBatchedEmbeddingBagsCodegen(nn.Module): """