Move all fbgemm_gpu provided Python ops to fbgemm namespace from fb namespace. (#823)

Summary: Pull Request resolved: #823

Reviewed By: jianyuh

Differential Revision: D33147038

fbshipit-source-id: fdcb667dfb920b4f04b7d0b08082afabe7213cc1
Rick Weyrauch authored and facebook-github-bot committed Dec 21, 2021
1 parent d53ba96 commit fa44d9a
Showing 14 changed files with 232 additions and 47 deletions.
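
The user-visible effect of this commit: every Python-visible op that fbgemm_gpu registers under torch.ops.fb is now also registered under torch.ops.fbgemm, and the Python call sites in the diffs below switch to the new name. The sketch that follows is illustrative only and is not part of the diff; resolve_op is a hypothetical shim, and it assumes an installed fbgemm_gpu whose import registers the op libraries.

```python
import torch
import fbgemm_gpu  # noqa: F401  -- assumed to load/register the fbgemm ops


def resolve_op(name: str):
    """Look the op up in the new namespace first, then fall back to the old one."""
    for ns in (torch.ops.fbgemm, torch.ops.fb):
        try:
            return getattr(ns, name)
        except (AttributeError, RuntimeError):
            continue
    raise RuntimeError(f"op {name!r} not found in torch.ops.fbgemm or torch.ops.fb")


# New code should simply use torch.ops.fbgemm.<op>; the shim only helps code
# that must run against both pre- and post-rename builds.
bounds_check_indices = resolve_op("bounds_check_indices")
```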
@@ -1875,7 +1875,7 @@ def bounds_check_indices( # noqa C901
# forward
time_per_iter = benchmark_requests(
requests,
- lambda indices, offsets, _: torch.ops.fb.bounds_check_indices(
+ lambda indices, offsets, _: torch.ops.fbgemm.bounds_check_indices(
rows_per_table,
indices,
offsets,
10 changes: 10 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
@@ -391,3 +391,13 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
c10::DispatchKey::CUDA,
TORCH_FN(split_embedding_codegen_lookup_dense_function)));
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad) -> Tensor");
m.impl(
"dense_embedding_codegen_lookup_function",
torch::dispatch(
c10::DispatchKey::CUDA,
TORCH_FN(split_embedding_codegen_lookup_dense_function)));
}
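
Note that the existing TORCH_LIBRARY_FRAGMENT(fb, m) block above is left in place, so during the transition the dense lookup op resolves under both qualified names. A quick, illustrative check (assuming a CUDA build of fbgemm_gpu is installed and imported):

```python
import torch
import fbgemm_gpu  # noqa: F401

# Both handles are backed by the same kernel registration; only the
# qualified namespace differs.
old_op = torch.ops.fb.dense_embedding_codegen_lookup_function
new_op = torch.ops.fbgemm.dense_embedding_codegen_lookup_function
print(old_op, new_op)
```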
8 changes: 8 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp
@@ -176,4 +176,12 @@ TORCH_LIBRARY_IMPL(fb, CPU, m) {
TORCH_FN(split_embedding_codegen_lookup_dense_function)));
}

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
m.impl(
"dense_embedding_codegen_lookup_function",
torch::dispatch(
c10::DispatchKey::CPU,
TORCH_FN(split_embedding_codegen_lookup_dense_function)));
}

} // namespace
@@ -212,5 +212,10 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
m.impl("split_embedding_codegen_lookup_{{ optimizer }}_function_cpu", torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_cpu)));
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(Tensor host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}) -> Tensor");
m.impl("split_embedding_codegen_lookup_{{ optimizer }}_function_cpu", torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_cpu)));
}

} // namespace
// clang-format on
6 changes: 6 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
@@ -485,4 +485,10 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0) -> Tensor");
m.impl("split_embedding_codegen_lookup_{{ optimizer }}_function", torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function)));
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0) -> Tensor");
m.impl("split_embedding_codegen_lookup_{{ optimizer }}_function", torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function)));
}

// clang-format on
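
{{ optimizer }} and {{ args.split_function_schemas }} are Jinja placeholders expanded by the embedding codegen, so one fbgemm registration is emitted per generated optimizer. A hedged sketch of what that yields at runtime, taking rowwise_adagrad as an assumed example of a generated optimizer name:

```python
import torch
import fbgemm_gpu  # noqa: F401

# After codegen, each optimizer variant becomes a concrete op under the new
# namespace; its argument list matches the templated schema above.
op = torch.ops.fbgemm.split_embedding_codegen_lookup_rowwise_adagrad_function
print(op)
```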
11 changes: 11 additions & 0 deletions fbgemm_gpu/codegen/embedding_bounds_check_host.cpp
@@ -29,3 +29,14 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
torch::dispatch(
c10::DispatchKey::CUDA, TORCH_FN(bounds_check_indices_cuda)));
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
// The (a!) tells PyTorch this is an impure operation and so cannot be CSE'd
// or DCE'd, etc.
m.def(
"bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(a!) offsets, int bounds_check_mode, Tensor(a!) warning) -> ()");
m.impl(
"bounds_check_indices",
torch::dispatch(
c10::DispatchKey::CUDA, TORCH_FN(bounds_check_indices_cuda)));
}
7 changes: 7 additions & 0 deletions fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp
@@ -110,3 +110,10 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
torch::dispatch(
c10::DispatchKey::CPU, TORCH_FN(bounds_check_indices_cpu)));
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.impl(
"bounds_check_indices",
torch::dispatch(
c10::DispatchKey::CPU, TORCH_FN(bounds_check_indices_cpu)));
}
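
With the CPU impl above registered in the fbgemm fragment, the op can be exercised end to end without a GPU. A minimal illustrative call (assumes fbgemm_gpu is installed, and that mode 1 corresponds to BoundsCheckMode.WARNING, which zeroes offending indices in place and bumps the warning counter):

```python
import torch
import fbgemm_gpu  # noqa: F401

rows_per_table = torch.tensor([100], dtype=torch.int64)     # one table with 100 rows
indices = torch.tensor([0, 5, 99, 250], dtype=torch.int64)  # 250 is out of bounds
offsets = torch.tensor([0, 4], dtype=torch.int64)           # T * B + 1 = 2 entries
warning = torch.zeros(1, dtype=torch.int64)                 # violation counter

WARNING_MODE = 1  # assumed value of BoundsCheckMode.WARNING
torch.ops.fbgemm.bounds_check_indices(
    rows_per_table, indices, offsets, WARNING_MODE, warning
)
print(indices)  # offending entry reset in place
print(warning)  # > 0 once a violation has been reported
```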
25 changes: 25 additions & 0 deletions fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp
@@ -157,3 +157,28 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
torch::dispatch(
c10::DispatchKey::CUDA, TORCH_FN(pruned_array_lookup_cuda)));
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, int total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None) -> Tensor");
m.impl(
"int_nbit_split_embedding_codegen_lookup_function",
torch::dispatch(
c10::DispatchKey::CUDA,
TORCH_FN(int_nbit_split_embedding_codegen_lookup_function)));

m.def(
"pruned_hashmap_lookup(Tensor indices, Tensor offsets, Tensor hash_table, Tensor hash_table_offsets) -> Tensor");
m.impl(
"pruned_hashmap_lookup",
torch::dispatch(
c10::DispatchKey::CUDA,
TORCH_FN(pruned_hashmap_lookup_unweighted_cuda)));

m.def(
"pruned_array_lookup(Tensor indices, Tensor offsets, Tensor index_remappings, Tensor index_remappings_offsets) -> Tensor");
m.impl(
"pruned_array_lookup",
torch::dispatch(
c10::DispatchKey::CUDA, TORCH_FN(pruned_array_lookup_cuda)));
}
32 changes: 32 additions & 0 deletions fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp
@@ -151,6 +151,38 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
c10::DispatchKey::CPU, TORCH_FN(pruned_array_lookup_cpu)));
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.impl(
"int_nbit_split_embedding_codegen_lookup_function",
torch::dispatch(
c10::DispatchKey::CPU,
TORCH_FN(int_nbit_split_embedding_codegen_lookup_function_cpu)));

// GPU version of pruned_hashmap needs to use CPU version of
// pruned_hashmap_insert
m.def(
"pruned_hashmap_insert(Tensor indices, Tensor dense_indices, Tensor offsets, Tensor hash_table, Tensor hash_table_offsets) -> ()");
m.impl(
"pruned_hashmap_insert",
torch::dispatch(
c10::DispatchKey::CPU,
TORCH_FN(pruned_hashmap_insert_unweighted_cpu)));

// CPU version of hashmap Lookup isn't used. For CPUs, we should use
// PrunedMapCPU below.
m.impl(
"pruned_hashmap_lookup",
torch::dispatch(
c10::DispatchKey::CPU,
TORCH_FN(pruned_hashmap_lookup_unweighted_cpu)));

// CPU version of array lookup.
m.impl(
"pruned_array_lookup",
torch::dispatch(
c10::DispatchKey::CPU, TORCH_FN(pruned_array_lookup_cpu)));
}

class PrunedMapCPU : public torch::jit::CustomClassHolder {
public:
PrunedMapCPU() {}
6 changes: 3 additions & 3 deletions fbgemm_gpu/fbgemm_gpu/uvm.py
@@ -18,18 +18,18 @@
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils")

# Import all uvm enums from c++ library
- create_enums(globals(), torch.ops.fb.fbgemm_gpu_uvm_enum_query)
+ create_enums(globals(), torch.ops.fbgemm.fbgemm_gpu_uvm_enum_query)


def cudaMemAdvise(
t: torch.Tensor,
advice: Enum,
) -> None:
- torch.ops.fb.cuda_mem_advise(t, advice.value)
+ torch.ops.fbgemm.cuda_mem_advise(t, advice.value)


def cudaMemPrefetchAsync(
t: torch.Tensor,
device_t: Optional[torch.Tensor] = None,
) -> None:
- torch.ops.fb.cuda_mem_prefetch_async(t, device_t)
+ torch.ops.fbgemm.cuda_mem_prefetch_async(t, device_t)
32 changes: 32 additions & 0 deletions fbgemm_gpu/src/cumem_utils_host.cpp
@@ -46,4 +46,36 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
m.def(FBGEMM_GPU_ENUM_OP(uvm, fbgemm_gpu_uvm_enum_query));
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("is_uvm_tensor(Tensor t) -> bool", TORCH_FN(is_uvm_tensor));
m.def("uvm_storage(Tensor t) -> bool", TORCH_FN(uvm_storage));
m.def(
"uvm_to_device(Tensor self, Tensor prototype) -> Tensor",
TORCH_FN(uvm_to_device));
m.def("uvm_to_cpu(Tensor t) -> Tensor");
m.impl(
"uvm_to_cpu",
torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(uvm_to_cpu)));
m.def("new_managed_tensor(Tensor self, int[] sizes) -> Tensor");
m.impl(
"new_managed_tensor",
torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(new_managed_tensor)));
m.def("new_vanilla_managed_tensor(Tensor self, int[] sizes) -> Tensor");
m.impl(
"new_vanilla_managed_tensor",
torch::dispatch(
c10::DispatchKey::CUDA, TORCH_FN(new_vanilla_managed_tensor)));
m.def(
"cuda_mem_advise(Tensor t, int advice) -> ()",
TORCH_FN(uvm_cuda_mem_advise));
m.def(
"cuda_mem_prefetch_async(Tensor t, Tensor? device_t) -> ()",
TORCH_FN(uvm_cuda_mem_prefetch_async));
m.def(
"uvm_mem_advice_dont_fork(Tensor t) -> ()",
TORCH_FN(uvm_mem_advice_dont_fork));

m.def(FBGEMM_GPU_ENUM_OP(uvm, fbgemm_gpu_uvm_enum_query));
}

} // namespace fbgemm_gpu
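
The registrations above are the C++ side of the fbgemm_gpu/fbgemm_gpu/uvm.py change shown earlier. A hedged sketch of the UVM surface under the new namespace (requires a CUDA build of fbgemm_gpu and a CUDA device; the uvm_to_cpu comment describes the op's intent rather than a guaranteed aliasing contract):

```python
import torch
import fbgemm_gpu  # noqa: F401

# The prototype tensor only supplies dtype/device for the managed allocation.
prototype = torch.empty(0, device="cuda")
managed = torch.ops.fbgemm.new_managed_tensor(prototype, [1024, 64])

assert torch.ops.fbgemm.is_uvm_tensor(managed)
assert torch.ops.fbgemm.uvm_storage(managed)

# cuda_mem_advise takes an integer advice code; fbgemm_gpu.uvm.cudaMemAdvise
# (see the uvm.py diff above) wraps it with enum values obtained from
# fbgemm_gpu_uvm_enum_query.

# CPU-side handle over the same managed allocation.
cpu_view = torch.ops.fbgemm.uvm_to_cpu(managed)
```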
49 changes: 49 additions & 0 deletions fbgemm_gpu/src/split_table_batched_embeddings.cpp
@@ -151,4 +151,53 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
torch::dispatch(
c10::DispatchKey::CatchAll, TORCH_FN(host_lxu_cache_slot)));
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"linearize_cache_indices(Tensor cache_hash_size_cumsum, Tensor indices, Tensor offsets) -> Tensor");
m.impl(
"linearize_cache_indices",
torch::dispatch(
c10::DispatchKey::CUDA, TORCH_FN(linearize_cache_indices_cuda)));
m.def(
"lru_cache_populate(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, bool stochastic_rounding) -> ()");
m.impl(
"lru_cache_populate",
torch::dispatch(
c10::DispatchKey::CUDA, TORCH_FN(lru_cache_populate_cuda)));
m.def(
"lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state) -> ()");
m.impl(
"lru_cache_populate_byte",
torch::dispatch(
c10::DispatchKey::CUDA, TORCH_FN(lru_cache_populate_byte_cuda)));
m.def(
"lfu_cache_populate(Tensor weights, Tensor cache_hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, Tensor(c!) lfu_state, bool stochastic_rounding) -> ()");
m.impl(
"lfu_cache_populate",
torch::dispatch(
c10::DispatchKey::CUDA, TORCH_FN(lfu_cache_populate_cuda)));
m.def(
"lfu_cache_populate_byte(Tensor weights, Tensor cache_hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, Tensor(c!) lfu_state) -> ()");
m.impl(
"lfu_cache_populate_byte",
torch::dispatch(
c10::DispatchKey::CUDA, TORCH_FN(lfu_cache_populate_byte_cuda)));
m.def(
"lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state) -> Tensor");
m.impl(
"lxu_cache_lookup",
torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(lxu_cache_lookup_cuda)));
m.def(
"lxu_cache_flush(Tensor(a!) uvm_weights, Tensor cache_hash_size_cumsum, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, int total_D, Tensor(b!) lxu_cache_state, Tensor(c!) lxu_cache_weights, bool stochastic_rounding) -> ()");
m.impl(
"lxu_cache_flush",
torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(lxu_cache_flush_cuda)));
m.def("lxu_cache_slot(int h_in, int C) -> int");
m.impl(
"lxu_cache_slot",
torch::dispatch(
c10::DispatchKey::CatchAll, TORCH_FN(host_lxu_cache_slot)));
}

} // namespace
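
Most of the cache ops above dispatch to CUDA, but lxu_cache_slot is registered as CatchAll, so it can be called without a GPU. An illustrative use, assuming it maps a linearized index to its cache set for a cache with C sets:

```python
import torch
import fbgemm_gpu  # noqa: F401

C = 1024  # number of cache sets (illustrative value)
slot = torch.ops.fbgemm.lxu_cache_slot(12345, C)
assert 0 <= slot < C
print(slot)
```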
10 changes: 5 additions & 5 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -3438,14 +3438,14 @@ def test_bounds_check(
warning.cuda(),
)
indices_copy = indices.clone()
- torch.ops.fb.bounds_check_indices(
+ torch.ops.fbgemm.bounds_check_indices(
rows_per_table, indices, offsets, bounds_check_mode, warning
)
# we don't modify when we are in-bounds.
torch.testing.assert_allclose(indices_copy, indices)
indices[:] = torch.iinfo(dtype).max
if bounds_check_mode != BoundsCheckMode.FATAL:
- torch.ops.fb.bounds_check_indices(
+ torch.ops.fbgemm.bounds_check_indices(
rows_per_table, indices, offsets, bounds_check_mode, warning
)
torch.testing.assert_allclose(indices, torch.zeros_like(indices))
@@ -3454,7 +3454,7 @@
else:
if use_cpu and indices.numel():
with self.assertRaises(RuntimeError):
- torch.ops.fb.bounds_check_indices(
+ torch.ops.fbgemm.bounds_check_indices(
rows_per_table, indices, offsets, bounds_check_mode, warning
)
# It would be nice to test the CUDA implementation of BoundsCheckMode==FATAL,
@@ -3468,7 +3468,7 @@
if offsets.numel() > 1:
offsets[-1] += 100
if bounds_check_mode != BoundsCheckMode.FATAL:
- torch.ops.fb.bounds_check_indices(
+ torch.ops.fbgemm.bounds_check_indices(
rows_per_table, indices, offsets, bounds_check_mode, warning
)
if offsets.numel() > 0:
@@ -3482,7 +3482,7 @@
else:
if use_cpu and indices.numel():
with self.assertRaises(RuntimeError):
- torch.ops.fb.bounds_check_indices(
+ torch.ops.fbgemm.bounds_check_indices(
rows_per_table, indices, offsets, bounds_check_mode, warning
)
