Add an input debug function in TBE training #3022

Closed
fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py: 134 additions & 0 deletions
@@ -1107,6 +1107,11 @@ def __init__( # noqa C901

        self.is_experimental: bool = is_experimental

        # Get a debug function pointer
        self._debug_print_input_stats: Callable[..., None] = (
            self._debug_print_input_stats_factory()
        )

    @torch.jit.ignore
    def log(self, msg: str) -> None:
        """Log with TBE id prefix to distinguish between multiple TBE instances per process."""
@@ -1361,6 +1366,9 @@ def forward( # noqa: C901
            force_cast_input_types=True,
        )

        # Print input stats if enabled (for debugging purposes only)
        self._debug_print_input_stats(indices, offsets, per_sample_weights)

        if not is_torchdynamo_compiling():
            # Mutations of nn.Module attrs force a Dynamo restart of Analysis, which increases compilation time

@@ -2656,6 +2664,132 @@ def prepare_inputs(

        return indices, offsets, per_sample_weights, vbe_metadata

    def _debug_print_input_stats_factory(self) -> Callable[..., None]:
        """
        If the environment variable FBGEMM_DEBUG_PRINT_INPUT_STATS=1,
        return a function that prints input stats, including
        weighted/unweighted, the number of features, the batch size,
        the average pooling factor, the total number of indices, the
        number of unique indices, and the number of indices that go
        through each of the different backward kernels. Otherwise,
        return a no-op function.
        """

        @torch.jit.ignore
        def _debug_print_input_stats_factory_impl(
            indices: Tensor,
            offsets: Tensor,
            per_sample_weights: Optional[Tensor] = None,
        ) -> None:
            """
            Print input stats (for debugging purposes only)

            Args:
                indices (Tensor): Input indices
                offsets (Tensor): Input offsets
                per_sample_weights (Optional[Tensor]): Input per
                    sample weights
            """
            if self.debug_step % 100 == 0:
                # Get number of features (T) and batch size (B)
                T = len(self.feature_table_map)
                B = (offsets.numel() - 1) // T

                # Transfer hash_size_cumsum, indices and offsets to CPU
                hash_size_cumsum_cpu = self.hash_size_cumsum.cpu()
                indices_cpu = indices.cpu()
                offsets_cpu = offsets.cpu()

                # Compute linear indices by adding each feature's
                # hash_size_cumsum offset, so that indices from different
                # tables do not collide when computing uniques
                for t in range(T):
                    start = offsets_cpu[B * t].item()
                    end = offsets_cpu[B * (t + 1)].item()
                    indices_cpu[start:end] += hash_size_cumsum_cpu[t]

                # Compute unique indices
                uniq_indices_cpu, counts = indices_cpu.unique(return_counts=True)

                # Compute num unique indices
                num_uniq_indices = uniq_indices_cpu.numel()

                # The warp_per_row kernel handles indices whose
                # segment lengths are <= 32
                #
                # The cta_per_row kernel handles indices whose
                # segment lengths are > 32. A single thread block is
                # used if the segment length is <= 1024. Otherwise,
                # multiple thread blocks are used.
                #
                # Counts of indices whose segment lengths are <= 32
                counts_warp_per_row = counts[counts <= 32]
                counts_cta_per_row = counts[counts > 32]
                # Counts of indices whose segment lengths are > 32 and <= 1024
                counts_cta_per_row_sth = counts_cta_per_row[counts_cta_per_row <= 1024]
                # Counts of indices whose segment lengths are > 1024
                counts_cta_per_row_mth = counts_cta_per_row[counts_cta_per_row > 1024]
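                # Illustrative example with hypothetical values: if counts were
                # [2, 40, 5, 2000], then counts_warp_per_row would be [2, 5],
                # counts_cta_per_row_sth would be [40], and
                # counts_cta_per_row_mth would be [2000]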

                def compute_numel_and_avg(counts: Tensor) -> Tuple[int, float]:
                    numel = counts.numel()
                    avg = (counts.sum().item() / numel) if numel != 0 else -1.0
                    return numel, avg

                # warp_per_row stats
                num_warp_per_row, avg_seglen_warp_per_row = compute_numel_and_avg(
                    counts_warp_per_row
                )
                # Stats for cta_per_row using a single thread block
                num_cta_per_row_sth, avg_seglen_cta_per_row_sth = compute_numel_and_avg(
                    counts_cta_per_row_sth
                )
                # Stats for cta_per_row using multiple thread blocks
                num_cta_per_row_mth, avg_seglen_cta_per_row_mth = compute_numel_and_avg(
                    counts_cta_per_row_mth
                )

                assert num_uniq_indices == (
                    num_warp_per_row + num_cta_per_row_sth + num_cta_per_row_mth
                )

                self.log(
                    "TBE_DEBUG: "
                    "weighted {} "
                    "num features {} "
                    "batch size {} "
                    "avg pooling factor {:.2f} "
                    "total num indices {} "
                    "num unique indices {} "
                    "num warp_per_row {} (avg segment length {:.2f}) "
                    "num cta_per_row single thread block (avg segment length) {} ({:.2f}) "
                    "num cta_per_row multiple thread blocks (avg segment length) {} ({:.2f})".format(
                        per_sample_weights is not None,
                        T,
                        B,
                        indices.numel() / (B * T),
                        indices.numel(),
                        num_uniq_indices,
                        num_warp_per_row,
                        avg_seglen_warp_per_row,
                        num_cta_per_row_sth,
                        avg_seglen_cta_per_row_sth,
                        num_cta_per_row_mth,
                        avg_seglen_cta_per_row_mth,
                    )
                )
            self.debug_step += 1

        @torch.jit.ignore
        def _debug_print_input_stats_factory_null(
            indices: Tensor,
            offsets: Tensor,
            per_sample_weights: Optional[Tensor] = None,
        ) -> None:
            pass

        if int(os.environ.get("FBGEMM_DEBUG_PRINT_INPUT_STATS", "0")) == 1:
            self.debug_step = 0
            return _debug_print_input_stats_factory_impl
        return _debug_print_input_stats_factory_null


class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
"""