Skip to content

Commit

Permalink
Abstract operator for bounds_check_indices (#2495)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2495

Create an abstract impl for fbgemm operator bounds_check_indices

Reviewed By: IvanKobzarev

Differential Revision: D56020789

fbshipit-source-id: e5291bb553614a57d405f4d6c4b901c1a03818b4
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Apr 11, 2024
1 parent 818a26f commit 69a3309
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
3 changes: 3 additions & 0 deletions fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
// The (a!) tells PyTorch this is an impure operation and so cannot be CSE'd
// or DCE'd, etc.
m.impl_abstract_pystub(
"fbgemm_gpu.sparse_ops",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
m.def(
"bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(b!) offsets, int bounds_check_mode, Tensor(c!) warning, Tensor(d!)? weights=None, Tensor? B_offsets=None, SymInt max_B=-1) -> ()",
{PT2_COMPLIANT_TAG});
Expand Down
16 changes: 16 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,19 @@ def keyed_jagged_index_select_dim1_abstract(
ret.append(weights.new_empty([selected_lengths_sum]))

return ret


@impl_abstract("fbgemm::bounds_check_indices")
def bounds_check_indices_abstract(
rows_per_table: torch.Tensor,
indices: torch.Tensor,
offsets: torch.Tensor,
bounds_check_mode_int: int,
bounds_check_warning: torch.Tensor,
per_sample_weights: Optional[torch.Tensor] = None,
) -> None:
"""
This meta function is used to fake the bounds checking
from the original function `fbgemm::bounds_check_indices`
"""
return

0 comments on commit 69a3309

Please sign in to comment.