From b3e129c7ab6486fd971dfbc62c443706c1d5ada3 Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Fri, 8 Jul 2022 20:42:16 +0800 Subject: [PATCH] Add pooling mode to device bench --- ...plit_table_batched_embeddings_benchmark.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index 1c67f9a18d..ec32251fbb 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -91,6 +91,7 @@ def cli() -> None: @click.option("--reuse", default=0.0) @click.option("--row-wise/--no-row-wise", default=True) @click.option("--weighted", is_flag=True, default=False) +@click.option("--pooling", type=str, default="sum") @click.option("--weighted-num-requires-grad", type=int, default=None) @click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value) @click.option("--flush-gpu-cache-size-mb", default=0) @@ -113,6 +114,7 @@ def device( # noqa C901 reuse: float, row_wise: bool, weighted: bool, + pooling: str, weighted_num_requires_grad: Optional[int], bounds_check_mode: int, flush_gpu_cache_size_mb: int, @@ -161,6 +163,17 @@ def device( # noqa C901 else: managed_option = EmbeddingLocation.MANAGED + if pooling is None or pooling == "sum": + pooling = "sum" + pooling_mode = PoolingMode.SUM + do_pooling = True + elif pooling == "mean": + pooling_mode = PoolingMode.MEAN + do_pooling = True + else: # "none" + pooling_mode = PoolingMode.NONE + do_pooling = False + if dense: emb = DenseTableBatchedEmbeddingBagsCodegen( [ @@ -170,6 +183,7 @@ def device( # noqa C901 ) for d in Ds ], + pooling_mode=pooling_mode, use_cpu=not torch.cuda.is_available(), ) else: @@ -191,6 +205,7 @@ def device( # noqa C901 weights_precision=weights_precision, stochastic_rounding=stoc, output_dtype=output_dtype, + pooling_mode=pooling_mode, bounds_check_mode=BoundsCheckMode(bounds_check_mode), ) emb = emb.to(get_device()) @@ -244,7 +259,10 @@ def device( # noqa C901 # backward bench not representative return - grad_output = torch.randn(B, sum(Ds)).to(get_device()) + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) + else: + grad_output = torch.randn(B * T * L, D).to(get_device()) # backward time_per_iter = benchmark_requests( requests,