NBit forward TBE: remap indices before prefetching; and some more unit tests. (pytorch#1433)

Summary:
Pull Request resolved: pytorch#1433

It was reported that UVM_CACHING doesn't work with pruning.
D40788589 (pytorch@bcc69ed) fixed issues in the linearize kernel.

UVM_CACHING with pruning worked through int_nbit_split_embedding_uvm_caching_codegen_lookup_function, because
in that path, index remapping for pruning occurs before int_nbit_split_embedding_uvm_caching_codegen_lookup_function is called.

IntNBitTableBatchedEmbeddingBagsCodegen::forward(), however, runs index remapping after
prefetch, so it still failed even with the linearize-kernel fix.

This diff changes the order of these calls: index remapping runs first, then prefetch.
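
In outline, the new flow in forward() looks like the minimal sketch below (not the actual implementation; remap_indices is a hypothetical stand-in for the hash-table / array remapping branches shown in the diff further down):

    def remap_indices(module, indices, offsets):
        # Hypothetical stand-in for the remapping branches added in the diff;
        # pruned rows come back as -1.
        return indices

    def forward(self, indices, offsets, per_sample_weights=None):
        # 1) Remap indices for pruning first, before anything touches the cache.
        indices = remap_indices(self, indices, offsets)

        # 2) Only then prefetch; the linearize kernel now sees post-remapping
        #    indices.
        if self.timestep_prefetch_size.get() <= 0:
            self.prefetch(indices, offsets)
        self.timestep_prefetch_size.decrement()
        lxu_cache_locations = self.lxu_cache_locations_list.pop()

        # 3) Bounds check and embedding lookup proceed as before.
        ...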

Also added related unit tests.

Differential Revision: D40821768

fbshipit-source-id: 017db632cd2d7891c177cd3959f1d8967538b2ea
doehyun authored and facebook-github-bot committed Dec 8, 2022
1 parent 81ba6c5 commit 2d6f38c
Showing 4 changed files with 47 additions and 8 deletions.
19 changes: 19 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py
@@ -102,6 +102,7 @@ def generate_requests(
requests_data_file: Optional[str] = None,
# Comma-separated list of table numbers
tables: Optional[str] = None,
emulate_pruning: bool = False,
) -> List[Tuple[torch.IntTensor, torch.IntTensor, Optional[torch.Tensor]]]:
if requests_data_file is not None:
indices_tensor, offsets_tensor, lengths_tensor = torch.load(requests_data_file)
@@ -214,6 +215,24 @@ def generate_requests(
]
all_indices[it + 1, t, reused_indices] = all_indices[it, t, reused_indices]

# Some indices are set to -1 for emulating pruned rows.
if emulate_pruning:
for it in range(iters):
for t in range(T):
num_negative_indices = int(B / 2)
random_locations = torch.randint(
low=0,
high=(B * L),
size=(num_negative_indices,),
device=torch.cuda.current_device(),
dtype=torch.int32,
)
all_indices[it, t, random_locations] = torch.tensor(
[-1] * num_negative_indices,
dtype=torch.int,
device=torch.cuda.current_device(),
)

rs = []
for it in range(iters):
weights_tensor = (
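
For context, a hedged usage sketch of the new emulate_pruning flag follows. The import path is assumed from the file location, the positional arguments mirror the test call later in this commit (iters, B, T, L, E), and the values are illustrative only; generate_requests places indices on the current CUDA device, so a GPU is required.

    from fbgemm_gpu.split_embedding_utils import generate_requests

    # With emulate_pruning=True, roughly B/2 random positions per table and
    # iteration are overwritten with -1 to emulate pruned rows.
    requests = generate_requests(10, 128, 4, 20, 10000, reuse=0.1, emulate_pruning=True)
    for indices, offsets, _ in requests:
        assert (indices == -1).any()  # some lookups hit emulated pruned rows
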
12 changes: 6 additions & 6 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -2209,15 +2209,10 @@ def forward(
offsets: Tensor,
per_sample_weights: Optional[Tensor] = None,
) -> Tensor:
if self.timestep_prefetch_size.get() <= 0:
self.prefetch(indices, offsets)
self.timestep_prefetch_size.decrement()

lxu_cache_locations = self.lxu_cache_locations_list.pop()

assert (
self.weight_initialized
), "weight needs to be initialized before forward function"
# Remap indices before prefetch, bounds check, and emb lookup.
if self.index_remapping_hash_table_cpu is not None:
indices = self.index_remapping_hash_table_cpu.lookup(indices, offsets)
elif self.index_remapping_hash_table.numel() > 0:
@@ -2235,6 +2230,11 @@ def forward(
self.index_remappings_array,
self.index_remappings_array_offsets,
)
if self.timestep_prefetch_size.get() <= 0:
self.prefetch(indices, offsets)
self.timestep_prefetch_size.decrement()

lxu_cache_locations = self.lxu_cache_locations_list.pop()

# We cast to int as a TorchScript workaround.
if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
5 changes: 4 additions & 1 deletion fbgemm_gpu/src/split_embeddings_cache_cuda.cu
@@ -579,6 +579,8 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_find_uncached_kernel
CUDA_KERNEL_LOOP(n, N) {
int64_t idx = linear_cache_indices[n];
if (idx == max_indices) {
// Invalid index value; will skip this one.
cache_sets[n] = C;
continue;
}
int32_t cache_set = cache_slot(idx, C);
@@ -1223,14 +1225,15 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_insert_byte_kernel(
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> cache_sets,
const int64_t row_alignment) {
const int32_t N = cache_sets.size(0);
const int32_t C = lxu_cache_state.size(0);
// one warp for each set (multiple times)
// (no divergence for each control branch)
for (int32_t pos = blockIdx.x * blockDim.y + threadIdx.y; pos < N;
pos += gridDim.x * blockDim.y) {
auto cache_set = cache_sets[pos];
if (cache_set == -1) {
if (cache_set == -1 || cache_set == C) {
// default value
continue;
}
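
Taken together, the two kernel changes above form a simple sentinel scheme: an invalid linear cache index (signalled as max_indices) is assigned the out-of-range set id C in the find-uncached pass, and the insert pass skips both that sentinel and the existing default value -1. A rough Python analogue of that logic (not the actual CUDA code; cache_slot here is a trivial stand-in for the real set hash):

    def cache_slot(idx, C):
        # Stand-in for the real hash that maps a linear index to one of C sets.
        return idx % C

    def find_uncached(linear_cache_indices, max_indices, C):
        cache_sets = []
        for idx in linear_cache_indices:
            if idx == max_indices:
                cache_sets.append(C)  # invalid index: out-of-range sentinel
            else:
                cache_sets.append(cache_slot(idx, C))
        return cache_sets

    def insert_rows(cache_sets, C):
        for cache_set in cache_sets:
            if cache_set == -1 or cache_set == C:
                continue  # skip default (-1) and invalid-index (C) entries
            # ... insert the corresponding row into set `cache_set` ...
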
19 changes: 18 additions & 1 deletion fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -3510,11 +3510,13 @@ def test_nbit_forward_gpu_no_cache(
SparseType.INT2,
]
),
emulate_pruning=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
self,
weights_ty: SparseType,
emulate_pruning: bool,
) -> None:
# TODO: support direct-mapped in int_nbit_split_embedding_uvm_caching_codegen_lookup_function
# This test is for int_nbit_split_embedding_uvm_caching_codegen_lookup_function.
@@ -3580,7 +3582,9 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
0, 0, device=current_device, dtype=torch.uint8
)

requests = generate_requests(iters, B, T, L, min(Es), reuse=0.1)
requests = generate_requests(
iters, B, T, L, min(Es), reuse=0.1, emulate_pruning=emulate_pruning
)
for indices, offsets, _ in requests:
indices = indices.int()
offsets = offsets.int()
@@ -3715,13 +3719,15 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
associativity=st.sampled_from(
[1, split_table_batched_embeddings_ops.DEFAULT_ASSOC]
),
do_pruning=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
def test_nbit_forward_uvm_cache(
self,
weights_ty: SparseType,
cache_algorithm: split_table_batched_embeddings_ops.CacheAlgorithm,
associativity: int,
do_pruning: bool,
) -> None:
assume(
cache_algorithm == split_table_batched_embeddings_ops.CacheAlgorithm.LRU
@@ -3768,6 +3774,11 @@ def test_nbit_forward_uvm_cache(
if d < average_D
else managed[t]
)
index_remapping = None
use_array_for_index_remapping = False
pruning_hash_load_factor = 0.5
if do_pruning:
index_remapping = [torch.empty(0, dtype=torch.int32) for _ in Es]
cc_ref = (
split_table_batched_embeddings_ops.IntNBitTableBatchedEmbeddingBagsCodegen(
[
@@ -3780,13 +3791,19 @@ def test_nbit_forward_uvm_cache(
)
for (E, D) in zip(Es, Ds)
],
index_remapping=index_remapping,
use_array_for_index_remapping=use_array_for_index_remapping,
pruning_hash_load_factor=pruning_hash_load_factor,
)
)
cc_ref.fill_random_weights()
cc = split_table_batched_embeddings_ops.IntNBitTableBatchedEmbeddingBagsCodegen(
[("", E, D, weights_ty, M) for (E, D, M) in zip(Es, Ds, managed)],
cache_algorithm=cache_algorithm,
cache_assoc=associativity,
index_remapping=index_remapping,
use_array_for_index_remapping=use_array_for_index_remapping,
pruning_hash_load_factor=pruning_hash_load_factor,
)
cc.fill_random_weights()

