Skip to content

Commit

Permalink
support cpu device for pruned-array benchmark (#1874)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1874

As titled: we need to pass through whether the CPU device is in use so that the pruned-array CPU benchmark works correctly.

Differential Revision: D47344641

fbshipit-source-id: 5405b671948ca6e531284f0d4bc1c266aa13ab8b
  • Loading branch information
Feixiong Zhang authored and facebook-github-bot committed Jul 18, 2023
1 parent 9c4fffc commit e359740
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -2305,6 +2305,7 @@ def pruned_array( # noqa C901
E,
requests_data_file=requests_data_file,
tables=tables,
use_cpu=True if device == "cpu" else False,
)
requests = [(a.int().to(device), b.int().to(device), c) for (a, b, c) in requests]

Expand Down
5 changes: 4 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def generate_requests( # noqa C901
# and mu_L
sigma_L: Optional[int] = None,
emulate_pruning: bool = False,
use_cpu: bool = False,
) -> List[Tuple[torch.IntTensor, torch.IntTensor, Optional[torch.Tensor]]]:
if requests_data_file is not None:
indices_tensor, offsets_tensor, lengths_tensor = torch.load(requests_data_file)
Expand Down Expand Up @@ -309,7 +310,9 @@ def generate_requests( # noqa C901
) # per sample weights will always be FP32
)
rs.append(
get_table_batched_offsets_from_dense(all_indices[it].view(T, B, L))
get_table_batched_offsets_from_dense(
all_indices[it].view(T, B, L), use_cpu=use_cpu
)
+ (weights_tensor,)
)
return rs
Expand Down

0 comments on commit e359740

Please sign in to comment.