Skip to content

Commit

Permalink
support op dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoguochun1995 committed Apr 15, 2024
1 parent cf4b44f commit d0c203a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 30 deletions.
48 changes: 18 additions & 30 deletions csrc/extensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
callDiopi(diopiApplyPenalty, logits, presence_penalty, frequency_penalty,
p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch);
}

#if 0

Check warning on line 295 in csrc/extensions.cpp

View workflow job for this annotation

GitHub Actions / static checks on sco

clang-tidy Warning [readability-avoid-unconditional-preprocessor-if]

Preprocessor condition is always 'false', consider removing both the condition and its contents
// Check whether a corresponding diopi implementation exists:
//     if it does, bind it directly via pybind;
//     otherwise skip registration and let the Python layer handle it.
Expand Down Expand Up @@ -363,6 +363,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"deeplink ext_scaled_masked_softmax_bwd");
}
}
#endif

at::Tensor& apply_penalty(at::Tensor& logits, const at::Tensor& presence_penalty,

Check warning on line 368 in csrc/extensions.cpp

View workflow job for this annotation

GitHub Actions / static checks on sco

clang-tidy Warning [readability-identifier-naming]

Invalid case style for function 'apply_penalty' (fix available)
const at::Tensor& frequency_penalty,
Expand All @@ -381,40 +382,27 @@ at::Tensor& dest_index_copy_kv(const at::Tensor& k, const at::Tensor& dest_loc,
return out;
}

TORCH_LIBRARY(ops, m) {
//m.def("adamw(Tensor(a!) input, Tensor(b!) grad, Tensor(c!) exp_avg, Tensor(d!) exp_avg_sq, Tensor(e!) max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay, int step, bool amsgrad)->(Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))");
m.def("apply_penalty(Tensor(a!) logits, Tensor presence_penalty, Tensor frequency_penalty, Tensor p_token_ids, Tensor p_token_counts, Tensor p_cumsum_seq_len, int p_max_len_in_batch)->Tensor(a!)");
m.def("dest_index_copy_kv(Tensor(a!) out, Tensor k, Tensor dest_loc)->Tensor(a!)");
// Example op implementation registered for every backend: prints the
// function name plus the tensor's options (device/dtype/layout) and
// returns the input tensor unchanged.
// NOTE(review): clangd reports no direct include providing std::cout —
// add <iostream> to the file's include block.
at::Tensor& example_for_all_backend(at::Tensor& inout) {
  // '\n' instead of std::endl: no flush needed here
  // (clang-tidy performance-avoid-endl).
  std::cout << __FUNCTION__ << ": " << inout.options() << '\n';
  return inout;
}

// impl for dipu
TORCH_LIBRARY_IMPL(ops, XPU, m) {
if (reinterpret_cast<void*>(diopiApplyPenalty) != nullptr) {
m.impl("apply_penalty", apply_penalty);
}
if (reinterpret_cast<void*>(diopiDestIndexCopyKV) != nullptr) {
m.impl("dest_index_copy_kv", dest_index_copy_kv);
}
// Example op implementation intended for XPU-only registration: prints
// the function name plus the tensor's options and returns the input
// tensor unchanged.
// NOTE(review): clangd reports no direct include providing std::cout —
// add <iostream> to the file's include block.
at::Tensor& example_only_for_xpu(at::Tensor& inout) {
  // '\n' instead of std::endl: no flush needed here
  // (clang-tidy performance-avoid-endl).
  std::cout << __FUNCTION__ << ": " << inout.options() << '\n';
  return inout;
}

// impl for torch
TORCH_LIBRARY_IMPL(ops, CUDA, m) {
if (reinterpret_cast<void*>(diopiApplyPenalty) != nullptr) {
m.impl("apply_penalty", apply_penalty);
}
if (reinterpret_cast<void*>(diopiDestIndexCopyKV) != nullptr) {
m.impl("dest_index_copy_kv", dest_index_copy_kv);
}
// By default, all backends (XPU, AutocastXPU, AutoGradXPU, CUDA, PrivateUse1, AutogradPrivateUse1 etc) are registered. If you need to register separately for a certain backend, separate registration for a certain backend is also supported.
// Declare the public schemas of the deeplink_ext_ op namespace.
// Schemas declared here are backend-agnostic; per-backend kernels are
// bound separately via TORCH_LIBRARY_IMPL. In the dispatcher schema
// language, "Tensor(a!)" marks an argument that is mutated in place
// and aliased by the return value.
TORCH_LIBRARY(deeplink_ext_, m) {
  // Fused AdamW step: mutates input/grad/exp_avg/exp_avg_sq in place.
  m.def("adamw(Tensor(a!) input, Tensor(b!) grad, Tensor(c!) exp_avg, Tensor(d!) exp_avg_sq, Tensor(e!) max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay, int step, bool amsgrad)->(Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))");
  // In-place penalty application on logits (schema only; impl per backend).
  m.def("apply_penalty(Tensor(a!) logits, Tensor presence_penalty, Tensor frequency_penalty, Tensor p_token_ids, Tensor p_token_counts, Tensor p_cumsum_seq_len, int p_max_len_in_batch)->Tensor(a!)");
  // Presumably a KV-cache write: copies k into out at dest_loc — confirm
  // against the diopi kernel.
  m.def("dest_index_copy_kv(Tensor(a!) out, Tensor k, Tensor dest_loc)->Tensor(a!)");
  // 'example' also gets a default, all-backend implementation bound here.
  m.def("example(Tensor(a!) inout)->Tensor(a!)", example_for_all_backend);
}

// impl for torch_npu
TORCH_LIBRARY_IMPL(ops, PrivateUse1, m) {
if (reinterpret_cast<void*>(diopiApplyPenalty) != nullptr) {
m.impl("apply_penalty", apply_penalty);
}
if (reinterpret_cast<void*>(diopiDestIndexCopyKV) != nullptr) {
m.impl("dest_index_copy_kv", dest_index_copy_kv);
}
// only impl for dipu
// Backend-specific registration: impls bound here run only when the
// dispatcher selects the XPU key (the dipu device backend). Currently
// empty — the example impl is left commented out as a template.
// NOTE(review): clangd reports no direct include providing
// c10::DispatchKey::XPU — confirm the file's include list.
TORCH_LIBRARY_IMPL(deeplink_ext_, XPU, m) {
  // m.impl("example", example_only_for_xpu);
}

} // namespace dipu::dipu_ext
} // namespace dipu::dipu_ext
24 changes: 24 additions & 0 deletions test_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Smoke-test dispatch of the deeplink_ext custom ops under the torch profiler."""
import torch
import torch_dipu
import deeplink_ext

# Make the compiled custom ops visible under torch.ops.deeplink_ext_.
extension_so = deeplink_ext.__path__[0] + "/cpp_extensions.cpython-39-x86_64-linux-gnu.so"
torch.ops.load_library(extension_so)
print(f"torch.ops.loaded_libraries:{torch.ops.loaded_libraries}")

#print(torch.ops.deeplink_ext_.dest_index_copy_kv)


def exercise_example_op():
    """Invoke the 'example' custom op once on CPU and once on the device."""
    sample = torch.randn(3, 4)
    torch.ops.deeplink_ext_.example(sample)
    torch.ops.deeplink_ext_.example(sample.cuda())


profiler_activities = [
    torch.profiler.ProfilerActivity.CPU,
    torch.profiler.ProfilerActivity.CUDA,
]
with torch.profiler.profile(activities=profiler_activities) as p:
    exercise_example_op()

# Dump per-op timing, slowest device-side ops first.
print(p.key_averages().table(
    sort_by="self_cuda_time_total", row_limit=-1))

0 comments on commit d0c203a

Please sign in to comment.