
Add pooling mode to device bench (#1194)
Summary:
This adds a --pooling option to the device benchmark, which would help us benchmark EmbeddingCollection (non-pooled, sequence embedding lookups) in torchrec.

Thank you for your time in reviewing this PR :)
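
For context: the new `--pooling` option on the `device` benchmark accepts `sum` (the default), `mean`, or `none`, where `none` requests a sequence (non-pooled) lookup like the one torchrec's EmbeddingCollection performs. A minimal sketch of the mapping this commit introduces, pulled out as a standalone helper; the import path is an assumption for this sketch (inside the benchmark, `PoolingMode` is already in scope):

```python
from typing import Optional, Tuple

# Import path is an assumption for this sketch; the benchmark already has PoolingMode imported.
from fbgemm_gpu.split_table_batched_embeddings_ops import PoolingMode


def resolve_pooling(pooling: Optional[str]) -> Tuple[PoolingMode, bool]:
    """Map the --pooling CLI value to (pooling_mode, do_pooling), as in this commit."""
    if pooling is None or pooling == "sum":
        return PoolingMode.SUM, True
    if pooling == "mean":
        return PoolingMode.MEAN, True
    # "none": keep one output row per looked-up embedding (sequence embedding),
    # which is what EmbeddingCollection needs.
    return PoolingMode.NONE, False
```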

Pull Request resolved: #1194

Reviewed By: jianyuh

Differential Revision: D37845758

Pulled By: colin2328

fbshipit-source-id: a65ef76b195ca2cfd56b69d39e2d59ae930edfae
zhuzilin authored and facebook-github-bot committed Jul 20, 2022
1 parent 5a15342 commit c9bbb77
Showing 1 changed file with 32 additions and 3 deletions.
fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py (32 additions, 3 deletions)
@@ -91,6 +91,7 @@ def cli() -> None:
 @click.option("--reuse", default=0.0)
 @click.option("--row-wise/--no-row-wise", default=True)
 @click.option("--weighted", is_flag=True, default=False)
+@click.option("--pooling", type=str, default="sum")
 @click.option("--weighted-num-requires-grad", type=int, default=None)
 @click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value)
 @click.option("--flush-gpu-cache-size-mb", default=0)
@@ -113,6 +114,7 @@ def device(  # noqa C901
     reuse: float,
     row_wise: bool,
     weighted: bool,
+    pooling: str,
     weighted_num_requires_grad: Optional[int],
     bounds_check_mode: int,
     flush_gpu_cache_size_mb: int,
@@ -161,6 +163,17 @@ def device(  # noqa C901
     else:
         managed_option = EmbeddingLocation.MANAGED

+    if pooling is None or pooling == "sum":
+        pooling = "sum"
+        pooling_mode = PoolingMode.SUM
+        do_pooling = True
+    elif pooling == "mean":
+        pooling_mode = PoolingMode.MEAN
+        do_pooling = True
+    else:  # "none"
+        pooling_mode = PoolingMode.NONE
+        do_pooling = False
+
     if dense:
         emb = DenseTableBatchedEmbeddingBagsCodegen(
             [
@@ -170,6 +183,7 @@ def device(  # noqa C901
                 )
                 for d in Ds
             ],
+            pooling_mode=pooling_mode,
             use_cpu=not torch.cuda.is_available(),
         )
     else:
@@ -191,6 +205,7 @@ def device(  # noqa C901
             weights_precision=weights_precision,
             stochastic_rounding=stoc,
             output_dtype=output_dtype,
+            pooling_mode=pooling_mode,
             bounds_check_mode=BoundsCheckMode(bounds_check_mode),
         )
         emb = emb.to(get_device())
@@ -200,6 +215,17 @@ def device(  # noqa C901

     nparams = sum(w.numel() for w in emb.split_embedding_weights())
     param_size_multiplier = weights_precision.bit_rate() / 8.0
+    output_size_multiplier = output_dtype.bit_rate() / 8.0
+    if do_pooling:
+        read_write_bytes = (
+            output_size_multiplier * B * sum(Ds) + param_size_multiplier * B * sum(Ds) * L
+        )
+    else:
+        read_write_bytes = (
+            output_size_multiplier * B * sum(Ds) * L
+            + param_size_multiplier * B * sum(Ds) * L
+        )
+
     logging.info(
         f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
         f"{nparams * param_size_multiplier / 1.0e9: .2f} GB"
@@ -236,15 +262,18 @@ def device(  # noqa C901
     logging.info(
         f"Forward, B: {B}, "
         f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
-        f"BW: {param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
+        f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
         f"T: {time_per_iter * 1.0e6:.0f}us"
     )

     if output_dtype == SparseType.INT8:
         # backward bench not representative
         return

-    grad_output = torch.randn(B, sum(Ds)).to(get_device())
+    if do_pooling:
+        grad_output = torch.randn(B, sum(Ds)).to(get_device())
+    else:
+        grad_output = torch.randn(B * T * L, D).to(get_device())
     # backward
     time_per_iter = benchmark_requests(
         requests,
@@ -258,7 +287,7 @@ def device(  # noqa C901
     )
     logging.info(
         f"ForwardBackward, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
-        f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, "
+        f"BW: {3 * read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, "
         f"T: {time_per_iter * 1.0e6:.0f}us"
     )
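
The backward benchmark keeps the existing 3x-traffic heuristic, now applied to `read_write_bytes`, and the synthetic gradient it feeds in changes shape with the pooling mode. A small sketch of the two gradient shapes with toy sizes (CPU tensors and equal per-table dims for simplicity):

```python
import torch

# Toy sizes; in the benchmark these come from the CLI options and the per-table dims Ds.
B, T, L, D = 4, 3, 2, 8
Ds = [D] * T  # equal dims keep the sketch simple; the benchmark may use varying dims

# pooling == "sum" / "mean": one pooled row per (sample, table) -> shape [B, sum(Ds)]
grad_pooled = torch.randn(B, sum(Ds))

# pooling == "none": one row per looked-up embedding -> shape [B * T * L, D]
grad_sequence = torch.randn(B * T * L, D)

print(grad_pooled.shape)    # torch.Size([4, 24])
print(grad_sequence.shape)  # torch.Size([24, 8])
```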
