Add cache update function for delta in-place update (#1436)
Summary:
Pull Request resolved: #1436

With in-place update, we need to update dev_weights and uvm_weights as well as cache_weights. Note that during inference the cache is read-only, but the in-place delta update must modify it. Since the cache is never written back / flushed to its backing storage during inference (neither when inference finishes nor when a cache line is evicted), we need to update the cache weights **and** the backing storage (uvm_weights) simultaneously during the in-place delta update (thanks yinghai for pointing this out!).
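
To make the dual update concrete, here is a minimal sketch in plain PyTorch. The tensor names (uvm_weights, lxu_cache_weights) and the precomputed cache_locations lookup are illustrative assumptions rather than the actual production code path, all update rows are assumed valid, and the only op taken from this commit is torch.ops.fbgemm.linearize_cache_indices_from_row_idx:

import torch
import fbgemm_gpu  # registers the torch.ops.fbgemm operators

def apply_delta_inplace(
    uvm_weights,        # backing storage; assumed to hold one row per linearized index
    lxu_cache_weights,  # cache rows
    cache_locations,    # hypothetical: cache slot of each update row, -1 if not cached
    cache_hash_size_cumsum,
    update_table_indices,
    update_row_indices,
    update_weights,
):
    # Linearize (table index, row index) pairs into indices that are unique
    # across tables (the op added by this commit).
    linear_indices = torch.ops.fbgemm.linearize_cache_indices_from_row_idx(
        cache_hash_size_cumsum, update_table_indices, update_row_indices
    )
    # 1) Always write the delta into the backing (UVM) storage.
    uvm_weights[linear_indices.long()] = update_weights
    # 2) Also overwrite the rows currently resident in the inference cache,
    #    since nothing flushes the cache back to the backing storage.
    hit = cache_locations >= 0
    lxu_cache_weights[cache_locations[hit].long()] = update_weights[hit]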

Reviewed By: jspark1105

Differential Revision: D39158010

fbshipit-source-id: ec1d973f71213180b9b0ed501e3e157b8e2f18e3
jianyuh authored and facebook-github-bot committed Nov 2, 2022
1 parent 43ca0c7 commit 0fc01bb
Showing 4 changed files with 158 additions and 0 deletions.
9 changes: 9 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -43,6 +43,15 @@ at::Tensor linearize_cache_indices_cuda(
at::Tensor indices,
at::Tensor offsets);

///@ingroup table-batched-embed-cuda
/// Linearize the indices of all tables so that they are unique across tables.
/// Note that update_table_indices and update_row_indices are given in the
/// (table index, row index) format used by the in-place update.
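/// Illustrative example (mirroring the unit test added in this commit): with
/// cache_hash_size_cumsum = [0, 12, 24, 36, 48], the pair (table 1, row 4)
/// maps to 12 + 4 = 16; if a table's cumsum entry is -1 (table not cached) or
/// the row index is negative (pruned), the output is the sentinel max_offset,
/// i.e. the last entry of cache_hash_size_cumsum.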
at::Tensor linearize_cache_indices_from_row_idx_cuda(
at::Tensor cache_hash_size_cumsum,
at::Tensor update_table_indices,
at::Tensor update_row_indices);

///@ingroup table-batched-embed-cuda
/// LRU cache: fetch the rows corresponding to `linear_cache_indices` from
///`weights`, and insert them into the cache at timestep `time_stamp`.
75 changes: 75 additions & 0 deletions fbgemm_gpu/src/split_embeddings_cache_cuda.cu
@@ -314,6 +314,81 @@ Tensor linearize_cache_indices_cuda(
return linear_cache_indices;
}
namespace {
template <typename index_t>
__global__
__launch_bounds__(kMaxThreads) void linearize_cache_indices_from_row_idx_kernel(
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
cache_hash_size_cumsum,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
update_table_indices,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
update_row_indices,
at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
linear_cache_indices) {
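// One thread per update row: map its (table index, row index) pair to an
// index that is unique across all tables, or to the invalid sentinel max_offset.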
const index_t index = blockIdx.x * blockDim.x + threadIdx.x;
if (index >= update_row_indices.size(0)) {
return;
}
const int table_index = update_table_indices[index];
const auto max_offset =
::__ldg(&cache_hash_size_cumsum[cache_hash_size_cumsum.size(0) - 1]);
const auto curr_offset = ::__ldg(&cache_hash_size_cumsum[table_index]);
if (curr_offset >= 0 && update_row_indices[index] >= 0) {
linear_cache_indices[index] = update_row_indices[index] + curr_offset;
} else {
// Either the table index is invalid, or the row index is negative (due to
// pruning): map it to the invalid sentinel value (max_offset).
linear_cache_indices[index] = max_offset;
}
}
} // namespace
Tensor linearize_cache_indices_from_row_idx_cuda(
Tensor cache_hash_size_cumsum,
Tensor update_table_indices,
Tensor update_row_indices) {
TENSOR_ON_CUDA_GPU(cache_hash_size_cumsum);
TENSOR_ON_CUDA_GPU(update_table_indices);
TENSOR_ON_CUDA_GPU(update_row_indices);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(cache_hash_size_cumsum.get_device());
const auto T = cache_hash_size_cumsum.size(0) - 1;
TORCH_CHECK(T > 0);
auto linear_cache_indices = at::empty_like(update_row_indices);
const auto num_indices = update_row_indices.numel();
if (num_indices == 0) {
return linear_cache_indices;
}
AT_DISPATCH_INDEX_TYPES(
update_row_indices.scalar_type(),
"linearize_cache_indices_from_row_idx_kernel",
[&] {
linearize_cache_indices_from_row_idx_kernel<<<
div_round_up(num_indices, kMaxThreads),
kMaxThreads,
0,
at::cuda::getCurrentCUDAStream()>>>(
cache_hash_size_cumsum
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
update_table_indices
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
update_row_indices
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
linear_cache_indices
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
return linear_cache_indices;
}
std::tuple<Tensor, Tensor, c10::optional<Tensor>> get_unique_indices_cuda(
Tensor linear_indices,
int64_t max_indices,
5 changes: 5 additions & 0 deletions fbgemm_gpu/src/split_table_batched_embeddings.cpp
@@ -17,6 +17,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"linearize_cache_indices(Tensor cache_hash_size_cumsum, Tensor indices, Tensor offsets) -> Tensor");
DISPATCH_TO_CUDA("linearize_cache_indices", linearize_cache_indices_cuda);
m.def(
"linearize_cache_indices_from_row_idx(Tensor cache_hash_size_cumsum, Tensor update_table_indices, Tensor update_row_indices) -> Tensor");
DISPATCH_TO_CUDA(
"linearize_cache_indices_from_row_idx",
linearize_cache_indices_from_row_idx_cuda);
m.def(
"lru_cache_populate(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, bool stochastic_rounding) -> ()");
DISPATCH_TO_CUDA("lru_cache_populate", lru_cache_populate_cuda);
69 changes: 69 additions & 0 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -4717,6 +4717,75 @@ def test_linearize_cache_indices(self) -> None:
)
)

@unittest.skipIf(*gpu_unavailable)
def test_linearize_cache_indices_from_row_idx(self) -> None:
    update_row_indices = torch.tensor(
        [10, 2, 3, 7, 1, 4, 5, 9, 2, 7, 6, 8, 5, 1, 0, 4],
        dtype=torch.int,
        device="cuda",
    )
    update_table_indices = torch.tensor(
        [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
        dtype=torch.int,
        device="cuda",
    )
    varying_update_table_indices = torch.tensor(
        [0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3],
        dtype=torch.int,
        device="cuda",
    )

    # Testing equal-sized tables.
    cache_hash_size_cumsum_0 = torch.tensor([0, 12, 24, 36, 48]).cuda()
    linear_cache_indices_0 = torch.ops.fbgemm.linearize_cache_indices_from_row_idx(
        cache_hash_size_cumsum_0,
        update_table_indices,
        update_row_indices,
    )
    self.assertTrue(
        torch.equal(
            linear_cache_indices_0.cpu(),
            torch.tensor(
                [10, 2, 3, 7, 13, 16, 17, 21, 26, 31, 30, 32, 41, 37, 36, 40],
                dtype=torch.int,
            ),
        )
    )

    # Testing partially cached tables.
    cache_hash_size_cumsum_1 = torch.tensor([0, 12, -1, 24, 36]).cuda()
    linear_cache_indices_1 = torch.ops.fbgemm.linearize_cache_indices_from_row_idx(
        cache_hash_size_cumsum_1,
        update_table_indices,
        update_row_indices,
    )
    self.assertTrue(
        torch.equal(
            linear_cache_indices_1.cpu(),
            torch.tensor(
                [10, 2, 3, 7, 13, 16, 17, 21, 36, 36, 36, 36, 29, 25, 24, 28],
                dtype=torch.int,
            ),
        )
    )

    # Testing a varying number of update rows per table.
    cache_hash_size_cumsum_2 = torch.tensor([0, 12, -1, 24, 36]).cuda()
    linear_cache_indices_2 = torch.ops.fbgemm.linearize_cache_indices_from_row_idx(
        cache_hash_size_cumsum_2,
        varying_update_table_indices,
        update_row_indices,
    )
    self.assertTrue(
        torch.equal(
            linear_cache_indices_2.cpu(),
            torch.tensor(
                [10, 2, 3, 19, 13, 16, 17, 21, 36, 36, 36, 36, 36, 36, 24, 28],
                dtype=torch.int,
            ),
        )
    )

@unittest.skipIf(*gpu_unavailable)
@given(
associativity=st.sampled_from(