NBit forward TBE: remap indices before prefetching; and some more unit tests.

Summary:
It was reported that UVM_CACHING doesn't work with pruning.
D40788589 (bcc69ed) fixed issues in the linearize kernel.

Pruning worked with int_nbit_split_embedding_uvm_caching_codegen_lookup_function because,
on that path, index remapping for pruning happens before int_nbit_split_embedding_uvm_caching_codegen_lookup_function is called.

IntNBitTableBatchedEmbeddingBagsCodegen::forward(), however, ran index remapping after
prefetch, so it still failed even with the linearize-kernel fix.

This diff changes the order of the two steps: index remapping runs first, then prefetch.

Also added related unit tests.
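
For illustration, here is a minimal sketch of the new ordering in forward(). It condenses the diff below; remap_for_pruning is a hypothetical stand-in for the hash-table / array remapping branches in the real code:

```python
def forward_sketch(self, indices, offsets):
    # 1) Remap first, so pruned rows are already mapped (e.g. to -1) before
    #    anything else consumes the indices.
    indices = remap_for_pruning(indices, offsets)  # hypothetical helper

    # 2) Prefetch (UVM_CACHING) now sees the remapped indices, so the cache is
    #    populated for the rows the embedding lookup will actually read.
    if self.timestep_prefetch_size.get() <= 0:
        self.prefetch(indices, offsets)
    self.timestep_prefetch_size.decrement()
    lxu_cache_locations = self.lxu_cache_locations_list.pop()

    # 3) Bounds check and the embedding lookup follow, as in the real forward().
```

Previously the prefetch block ran before remapping, so with pruning enabled the cache was primed with pre-remapping indices that the lookup never used.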

Differential Revision: D40821768

fbshipit-source-id: 418a568da62ecc1e8758de4047f478c5b7391ca5
doehyun authored and facebook-github-bot committed Oct 28, 2022
1 parent bcc69ed commit 817c0f2
Showing 2 changed files with 33 additions and 10 deletions.
12 changes: 6 additions & 6 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -2169,15 +2169,10 @@ def forward(
         offsets: Tensor,
         per_sample_weights: Optional[Tensor] = None,
     ) -> Tensor:
-        if self.timestep_prefetch_size.get() <= 0:
-            self.prefetch(indices, offsets)
-        self.timestep_prefetch_size.decrement()
-
-        lxu_cache_locations = self.lxu_cache_locations_list.pop()
-
         assert (
             self.weight_initialized
         ), "weight needs to be initialized before forward function"
+        # Remap indices, before prefetch, bound check, and emb lookup.
         if self.index_remapping_hash_table_cpu is not None:
             indices = self.index_remapping_hash_table_cpu.lookup(indices, offsets)
         elif self.index_remapping_hash_table.numel() > 0:
@@ -2195,6 +2190,11 @@ def forward(
                 self.index_remappings_array,
                 self.index_remappings_array_offsets,
             )
+        if self.timestep_prefetch_size.get() <= 0:
+            self.prefetch(indices, offsets)
+        self.timestep_prefetch_size.decrement()
+
+        lxu_cache_locations = self.lxu_cache_locations_list.pop()

         # We cast to int as a TorchScript workaround.
         if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
31 changes: 27 additions & 4 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -145,6 +145,7 @@ def generate_requests(
     alpha: float = 1.0,
     weights_precision: SparseType = SparseType.FP32,
     weighted: bool = False,
+    emulate_pruning: bool = False,
 ) -> List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]:
     if alpha <= 1.0:
         all_indices = torch.randint(
@@ -170,6 +171,24 @@ def generate_requests(
             ]
             all_indices[it + 1, t, reused_indices] = all_indices[it, t, reused_indices]

+    # Some indices are set to -1 for emulating pruned rows.
+    if emulate_pruning:
+        for it in range(iters):
+            for t in range(T):
+                num_negative_indices = int(B / 2)
+                random_locations = torch.randint(
+                    low=0,
+                    high=(B * L),
+                    size=(num_negative_indices,),
+                    device=torch.cuda.current_device(),
+                    dtype=torch.int32,
+                )
+                all_indices[it, t, random_locations] = torch.tensor(
+                    [-1] * num_negative_indices,
+                    dtype=torch.int,
+                    device=torch.cuda.current_device(),
+                )
+
     rs = []
     for it in range(iters):
         weight_tensor = (
@@ -3665,17 +3684,19 @@ def test_nbit_forward_cpu(
         nbit_weights_ty=get_nbit_weights_ty(),
         use_array_for_index_remapping=st.booleans(),
         do_pruning=st.booleans(),
+        use_cache=st.booleans(),
     )
     @settings(
         verbosity=Verbosity.verbose,
         max_examples=MAX_EXAMPLES_LONG_RUNNING,
         deadline=None,
     )
-    def test_nbit_forward_gpu_no_cache(
+    def test_nbit_forward_gpu(
         self,
         nbit_weights_ty: Optional[SparseType],
         use_array_for_index_remapping: bool,
         do_pruning: bool,
+        use_cache: bool,
     ) -> None:
         use_cpu = False
         T = random.randint(1, 50)
@@ -3684,8 +3705,6 @@ def test_nbit_forward_gpu_no_cache(
         D = random.randint(2, 1024)
         log_E = random.randint(2, 4)

-        use_cache = False
-        # cache_algorithm is don't care as we don't use cache.
         cache_algorithm = split_table_batched_embeddings_ops.CacheAlgorithm.LRU

         pooling_mode = random.choice(
@@ -3742,11 +3761,13 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
                 SparseType.INT2,
             ]
         ),
+        emulate_pruning=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
     def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
         self,
         weights_ty: SparseType,
+        emulate_pruning: bool,
     ) -> None:
         # TODO: support direct-mapped in int_nbit_split_embedding_uvm_caching_codegen_lookup_function
         # This test is for int_nbit_split_embedding_uvm_caching_codegen_lookup_function.
@@ -3812,7 +3833,9 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
             0, 0, device=current_device, dtype=torch.uint8
         )

-        requests = generate_requests(iters, B, T, L, min(Es), reuse=0.1)
+        requests = generate_requests(
+            iters, B, T, L, min(Es), reuse=0.1, emulate_pruning=emulate_pruning
+        )
         for indices, offsets, _ in requests:
             indices = indices.int()
             offsets = offsets.int()
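
For context on the new emulate_pruning knob exercised above, a hedged usage sketch (sizes are made up for illustration; it assumes a CUDA device, since generate_requests allocates indices on the current GPU):

```python
# Illustrative sizes only; the real tests draw these from hypothesis/random.
iters, B, T, L, E = 3, 16, 4, 8, 10_000

requests = generate_requests(iters, B, T, L, E, reuse=0.1, emulate_pruning=True)
for indices, offsets, _ in requests:
    # emulate_pruning overwrites roughly B / 2 positions per (iteration, table)
    # with -1, mimicking rows that index remapping resolves as pruned.
    assert (indices == -1).sum().item() > 0
```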
