Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pooling mode to device bench #1194

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def cli() -> None:
@click.option("--reuse", default=0.0)
@click.option("--row-wise/--no-row-wise", default=True)
@click.option("--weighted", is_flag=True, default=False)
@click.option("--pooling", type=str, default="sum")
@click.option("--weighted-num-requires-grad", type=int, default=None)
@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value)
@click.option("--flush-gpu-cache-size-mb", default=0)
Expand All @@ -113,6 +114,7 @@ def device( # noqa C901
reuse: float,
row_wise: bool,
weighted: bool,
pooling: str,
weighted_num_requires_grad: Optional[int],
bounds_check_mode: int,
flush_gpu_cache_size_mb: int,
Expand Down Expand Up @@ -161,6 +163,17 @@ def device( # noqa C901
else:
managed_option = EmbeddingLocation.MANAGED

if pooling is None or pooling == "sum":
pooling = "sum"
pooling_mode = PoolingMode.SUM
do_pooling = True
elif pooling == "mean":
pooling_mode = PoolingMode.MEAN
do_pooling = True
else: # "none"
pooling_mode = PoolingMode.NONE
do_pooling = False

if dense:
emb = DenseTableBatchedEmbeddingBagsCodegen(
[
Expand All @@ -170,6 +183,7 @@ def device( # noqa C901
)
for d in Ds
],
pooling_mode=pooling_mode,
use_cpu=not torch.cuda.is_available(),
)
else:
Expand All @@ -191,6 +205,7 @@ def device( # noqa C901
weights_precision=weights_precision,
stochastic_rounding=stoc,
output_dtype=output_dtype,
pooling_mode=pooling_mode,
Copy link
Member

@jianyuh jianyuh Jul 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you update the BW calculation formula in Line 254 and Line 279 ? Basically referring to

if do_pooling:
read_write_bytes = (
output_size_multiplier * B * T * D + param_size_multiplier * B * T * L * D
)
else:
read_write_bytes = (
output_size_multiplier * B * T * L * D
+ param_size_multiplier * B * T * L * D
)
. The number of write bytes for unpooled embedding are increased by a factor of L.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

bounds_check_mode=BoundsCheckMode(bounds_check_mode),
)
emb = emb.to(get_device())
Expand All @@ -200,6 +215,17 @@ def device( # noqa C901

nparams = sum(w.numel() for w in emb.split_embedding_weights())
param_size_multiplier = weights_precision.bit_rate() / 8.0
output_size_multiplier = output_dtype.bit_rate() / 8.0
if do_pooling:
read_write_bytes = (
output_size_multiplier * B * sum(Ds) + param_size_multiplier * B * sum(Ds) * L
)
else:
read_write_bytes = (
output_size_multiplier * B * sum(Ds) * L
+ param_size_multiplier * B * sum(Ds) * L
)

logging.info(
f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
f"{nparams * param_size_multiplier / 1.0e9: .2f} GB"
Expand Down Expand Up @@ -236,15 +262,18 @@ def device( # noqa C901
logging.info(
f"Forward, B: {B}, "
f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
f"BW: {param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
f"T: {time_per_iter * 1.0e6:.0f}us"
)

if output_dtype == SparseType.INT8:
# backward bench not representative
return

grad_output = torch.randn(B, sum(Ds)).to(get_device())
if do_pooling:
grad_output = torch.randn(B, sum(Ds)).to(get_device())
else:
grad_output = torch.randn(B * T * L, D).to(get_device())
# backward
time_per_iter = benchmark_requests(
requests,
Expand All @@ -258,7 +287,7 @@ def device( # noqa C901
)
logging.info(
f"ForwardBackward, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, "
f"BW: {3 * read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, "
f"T: {time_per_iter * 1.0e6:.0f}us"
)

Expand Down