Add an input debug function in TBE training (pytorch#3022)
Summary:
X-link: facebookresearch/FBGEMM#117

Pull Request resolved: pytorch#3022

This diff adds a function for printing input stats, including
weighted/unweighted, number of features, batch size, average pooling
factor, total number of indices, number of unique indices, and the
number of indices that go through each of the different backward
functions. This function is intended for debugging only. It can be
enabled at runtime by setting the environment variable
FBGEMM_DEBUG_PRINT_INPUT_STATS=1.
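
For example, a minimal usage sketch (an illustration only: it assumes the
TBE module being modified is SplitTableBatchedEmbeddingBagsCodegen from
this file, and the import paths, embedding specs, and tensor shapes are
made up for the example rather than prescribed by this diff):

    import os

    # Must be set before the module is constructed, because the factory
    # reads the environment variable in __init__.
    os.environ["FBGEMM_DEBUG_PRINT_INPUT_STATS"] = "1"

    import torch
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
    from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
        ComputeDevice,
        SplitTableBatchedEmbeddingBagsCodegen,
    )

    tbe = SplitTableBatchedEmbeddingBagsCodegen(
        embedding_specs=[(1000, 64, EmbeddingLocation.DEVICE, ComputeDevice.CUDA)],
    )
    # One feature (T = 1), batch size 1024, pooling factor 4.
    indices = torch.randint(0, 1000, (4096,), dtype=torch.int64, device="cuda")
    offsets = torch.arange(0, 4097, 4, dtype=torch.int64, device="cuda")
    # Stats are logged on every 100th forward call (steps 0, 100, 200, ...).
    out = tbe(indices, offsets)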

Reviewed By: dshi7

Differential Revision: D61634512
sryap authored and facebook-github-bot committed Aug 22, 2024
1 parent 6bb22e0 commit 63b9fec
Showing 1 changed file with 134 additions and 0 deletions.
134 changes: 134 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -1107,6 +1107,11 @@ def __init__( # noqa C901

        self.is_experimental: bool = is_experimental

        # Get a debug function pointer
        self._debug_print_input_stats: Callable[..., None] = (
            self._debug_print_input_stats_factory()
        )

    @torch.jit.ignore
    def log(self, msg: str) -> None:
        """Log with TBE id prefix to distinguish between multiple TBE instances per process."""
@@ -1361,6 +1366,9 @@ def forward( # noqa: C901
            force_cast_input_types=True,
        )

        # Print input stats if enabled (for debugging purposes only)
        self._debug_print_input_stats(indices, offsets, per_sample_weights)

        if not is_torchdynamo_compiling():
            # Mutations of nn.Module attrs force a dynamo restart of Analysis, which increases compilation time

@@ -2656,6 +2664,132 @@ def prepare_inputs(

        return indices, offsets, per_sample_weights, vbe_metadata

    def _debug_print_input_stats_factory(self) -> Callable[..., None]:
        """
        If the environment variable FBGEMM_DEBUG_PRINT_INPUT_STATS=1,
        return a function that prints input stats, including
        weighted/unweighted, number of features, batch size, average
        pooling factor, total number of indices, number of unique
        indices, and the number of indices that go through the
        different backward functions. Otherwise, return a no-op
        function.
        """

        @torch.jit.ignore
        def _debug_print_input_stats_factory_impl(
            indices: Tensor,
            offsets: Tensor,
            per_sample_weights: Optional[Tensor] = None,
        ) -> None:
            """
            Print input stats (for debugging purposes only)

            Args:
                indices (Tensor): Input indices
                offsets (Tensor): Input offsets
                per_sample_weights (Optional[Tensor]): Input per
                    sample weights
            """
            if self.debug_step % 100 == 0:
                # Get number of features (T) and batch size (B)
                T = len(self.feature_table_map)
                B = (offsets.numel() - 1) // T

                # Transfer hash_size_cumsum, indices and offsets to CPU
                hash_size_cumsum_cpu = self.hash_size_cumsum.cpu()
                indices_cpu = indices.cpu()
                offsets_cpu = offsets.cpu()

                # Compute linear indices
                for t in range(T):
                    start = offsets_cpu[B * t].item()
                    end = offsets_cpu[B * (t + 1)].item()
                    indices_cpu[start:end] += hash_size_cumsum_cpu[t]

                # Compute unique indices
                uniq_indices_cpu, counts = indices_cpu.unique(return_counts=True)

                # Compute num unique indices
                num_uniq_indices = uniq_indices_cpu.numel()

                # The warp_per_row kernel handles indices whose
                # segment lengths are <= 32
                #
                # The cta_per_row kernel handles indices whose
                # segment lengths are > 32. A single thread block is
                # used if segment lengths are <= 1024. Otherwise,
                # multiple thread blocks are used.
                #
                # Counts of indices whose segment lengths are <= 32
                counts_warp_per_row = counts[counts <= 32]
                counts_cta_per_row = counts[counts > 32]
                # Counts of indices whose segment lengths are > 32 and <= 1024
                counts_cta_per_row_sth = counts_cta_per_row[counts_cta_per_row <= 1024]
                # Counts of indices whose segment lengths are > 1024
                counts_cta_per_row_mth = counts_cta_per_row[counts_cta_per_row > 1024]

                def compute_numel_and_avg(counts: Tensor) -> Tuple[int, float]:
                    numel = counts.numel()
                    avg = (counts.sum().item() / numel) if numel != 0 else -1.0
                    return numel, avg

                # warp_per_row stats
                num_warp_per_row, avg_seglen_warp_per_row = compute_numel_and_avg(
                    counts_warp_per_row
                )
                # cta_per_row using a single thread block stats
                num_cta_per_row_sth, avg_seglen_cta_per_row_sth = compute_numel_and_avg(
                    counts_cta_per_row_sth
                )
                # cta_per_row using multiple thread block stats
                num_cta_per_row_mth, avg_seglen_cta_per_row_mth = compute_numel_and_avg(
                    counts_cta_per_row_mth
                )

                assert num_uniq_indices == (
                    num_warp_per_row + num_cta_per_row_sth + num_cta_per_row_mth
                )

                self.log(
                    "TBE_DEBUG: "
                    "weighted {} "
                    "num features {} "
                    "batch size {} "
                    "avg pooling factor {:.2f} "
                    "total num indices {} "
                    "num unique indices {} "
                    "num warp_per_row {} (avg segment length {:.2f}) "
                    "num cta_per_row single thread block (avg segment length) {} ({:.2f}) "
                    "num cta_per_row multiple thread blocks (avg segment length) {} ({:.2f})".format(
                        per_sample_weights is not None,
                        T,
                        B,
                        indices.numel() / (B * T),
                        indices.numel(),
                        num_uniq_indices,
                        num_warp_per_row,
                        avg_seglen_warp_per_row,
                        num_cta_per_row_sth,
                        avg_seglen_cta_per_row_sth,
                        num_cta_per_row_mth,
                        avg_seglen_cta_per_row_mth,
                    )
                )
            self.debug_step += 1

        @torch.jit.ignore
        def _debug_print_input_stats_factory_null(
            indices: Tensor,
            offsets: Tensor,
            per_sample_weights: Optional[Tensor] = None,
        ) -> None:
            pass

        if int(os.environ.get("FBGEMM_DEBUG_PRINT_INPUT_STATS", "0")) == 1:
            self.debug_step = 0
            return _debug_print_input_stats_factory_impl
        return _debug_print_input_stats_factory_null


class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
"""

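
As a reference for reading the logged numbers, here is a small
self-contained sketch (an editorial illustration with made-up values, not
part of the diff) that mirrors the index linearization and segment-length
bucketing performed by the new function; the 32 and 1024 thresholds
correspond to the warp_per_row / cta_per_row split described in the code
comments above:

    import torch

    # Toy input: T = 2 features, B = 3 batch size, so offsets has B * T + 1 entries.
    T, B = 2, 3
    indices = torch.tensor([5, 7, 5, 9, 9, 9, 1, 2, 3, 1])
    offsets = torch.tensor([0, 2, 3, 6, 7, 9, 10])
    # Per-table offsets into a linearized index space (made-up values).
    hash_size_cumsum = torch.tensor([0, 100])

    # Linearize per-table indices, as the debug function does on CPU copies.
    linear = indices.clone()
    for t in range(T):
        start = offsets[B * t].item()
        end = offsets[B * (t + 1)].item()
        linear[start:end] += hash_size_cumsum[t]

    uniq, counts = linear.unique(return_counts=True)
    avg_pooling_factor = indices.numel() / (B * T)  # 10 / 6, about 1.67

    # Segment-length buckets that map to the three backward code paths.
    counts_warp_per_row = counts[counts <= 32]
    counts_cta_per_row = counts[counts > 32]
    counts_cta_per_row_sth = counts_cta_per_row[counts_cta_per_row <= 1024]  # single thread block
    counts_cta_per_row_mth = counts_cta_per_row[counts_cta_per_row > 1024]   # multiple thread blocks

    print(uniq.numel(), avg_pooling_factor, counts_warp_per_row.numel())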