diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
index a1a0852a5c..321141dbac 100644
--- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
+++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -2305,6 +2305,7 @@ def pruned_array(  # noqa C901
         E,
         requests_data_file=requests_data_file,
         tables=tables,
+        use_cpu=True if device == "cpu" else False,
     )
 
     requests = [(a.int().to(device), b.int().to(device), c) for (a, b, c) in requests]
diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py
index 825f788e70..dad47615e5 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py
@@ -111,6 +111,7 @@ def generate_requests(  # noqa C901
     # and mu_L
     sigma_L: Optional[int] = None,
     emulate_pruning: bool = False,
+    use_cpu: bool = False,
 ) -> List[Tuple[torch.IntTensor, torch.IntTensor, Optional[torch.Tensor]]]:
     if requests_data_file is not None:
         indices_tensor, offsets_tensor, lengths_tensor = torch.load(requests_data_file)
@@ -309,7 +310,9 @@ def generate_requests(  # noqa C901
             )  # per sample weights will always be FP32
         )
         rs.append(
-            get_table_batched_offsets_from_dense(all_indices[it].view(T, B, L))
+            get_table_batched_offsets_from_dense(
+                all_indices[it].view(T, B, L), use_cpu=use_cpu
+            )
             + (weights_tensor,)
         )
     return rs
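
Illustrative usage (not part of the patch): a minimal sketch of how a CPU benchmark run might call generate_requests once the use_cpu flag is threaded through. The positional parameters (iters, B, T, L, E) and all values shown are assumptions for illustration only; they are not taken from this diff.

from fbgemm_gpu.split_embedding_utils import generate_requests

device = "cpu"
requests = generate_requests(
    iters=10,    # hypothetical number of request batches
    B=128,       # hypothetical batch size
    T=4,         # hypothetical number of tables
    L=20,        # hypothetical pooling factor (indices per bag)
    E=10_000,    # hypothetical rows per table
    use_cpu=(device == "cpu"),  # new flag introduced by this diff
)
# As in the benchmark code above, move indices/offsets to the target device;
# the per-sample weights entry may be None for unweighted requests.
requests = [(a.int().to(device), b.int().to(device), c) for (a, b, c) in requests]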