fix broken cpp integration caused by #567 #572

Merged 1 commit from fix_broken_build into flashinfer-ai:main on Oct 30, 2024

Conversation

@tsu-bin (Contributor) commented Oct 30, 2024

Hi, the cpp integration was broken again by #567. Please be aware that there are a cpp test, a cpp benchmark, and also the TVM integration, and they all rely on the CMake build.

@zhyncs requested review from yzh119 and abcdabcd987 and removed the request for yzh119 · Oct 30, 2024 12:15
@zhyncs (Member) commented Oct 30, 2024

```
src/tvm_wrapper.cu(690): error: identifier "BatchQKApplyRotaryInPlace" is undefined
    if (q->dtype.code == kDLFloat && q->dtype.bits == 16) { using dtype = half; {if (indptr->dtype.code == kDLInt && indptr->dtype.bits == 32) { using idtype = int32_t; { cudaError_t status = BatchQKApplyRotaryInPlace( static_cast<dtype*>(q->data), static_cast<dtype*>(k->data), static_cast<idtype*>(indptr->data), static_cast<idtype*>(offsets->data), batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, false, rope_scale, rope_theta); if (status != cudaSuccess) { ::tvm::runtime::detail::LogFatal("/workspace/persistent-storage/flashinfer_dev/src/tvm_wrapper.cu", 698).stream() << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); } } } else { ::tvm::runtime::detail::LogFatal("/workspace/persistent-storage/flashinfer_dev/src/tvm_wrapper.cu", 691).stream() << "Unsupported data type " << indptr->dtype.code; }} } else { ::tvm::runtime::detail::LogFatal("/workspace/persistent-storage/flashinfer_dev/src/tvm_wrapper.cu", 690).stream() << "Unsupported data type " << q->dtype.code; }
```

@tsu-bin I verified your branch, and overall there are no major issues. This error wasn't introduced by you. Would it be convenient for you to fix it in the next PR? I'll merge this PR first. cc @yzh119
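
For context, here is a self-contained sketch of the dispatch pattern behind this error. This is not FlashInfer's actual code: `BatchApplyInPlace` and the simplified signature are hypothetical stand-ins for the real `BatchQKApplyRotaryInPlace` call. The wrapper in src/tvm_wrapper.cu branches on the DLTensor's runtime dtype codes and instantiates a templated CUDA entry point; if a refactor such as #567 removes or renames the declaration that provides that entry point, every dispatch arm fails to compile with "identifier ... is undefined".

```cpp
// Self-contained sketch (NOT FlashInfer's actual code) of the runtime-dtype
// dispatch used in src/tvm_wrapper.cu. BatchApplyInPlace is a hypothetical
// stand-in for the templated kernel entry point (BatchQKApplyRotaryInPlace
// in the real file) whose declaration went missing after the refactor.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdint>

// Stand-in for the templated CUDA entry point the wrapper expects a header
// to declare. If this declaration disappears, the dispatch below stops
// compiling with exactly the "identifier ... is undefined" error above.
template <typename DType, typename IdType>
cudaError_t BatchApplyInPlace(DType* q, IdType* indptr, int64_t batch_size) {
  // ... launch the real kernel here ...
  return cudaSuccess;
}

// Mirror of the kDLFloat/kDLInt checks in the error message: pick concrete
// template types from the DLPack dtype codes carried by the tensors.
cudaError_t Dispatch(void* q, int q_code, int q_bits,
                     void* indptr, int idx_code, int idx_bits,
                     int64_t batch_size) {
  constexpr int kDLInt = 0, kDLFloat = 2;  // DLPack dtype codes
  if (q_code == kDLFloat && q_bits == 16) {
    using dtype = half;
    if (idx_code == kDLInt && idx_bits == 32) {
      using idtype = int32_t;
      return BatchApplyInPlace(static_cast<dtype*>(q),
                               static_cast<idtype*>(indptr), batch_size);
    }
  }
  return cudaErrorInvalidValue;  // unsupported dtype combination
}
```

The corresponding fix is to restore whatever provides the declaration, either re-adding the include or updating the call site to the renamed function; presumably that is the shape of the follow-up fix in #582 mentioned later in this thread.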

@zhyncs merged commit f19e308 into flashinfer-ai:main · Oct 30, 2024
@tsu-bin (Contributor, Author) commented Oct 30, 2024

Hi @zhyncs, I will look into this issue later; I think it's less urgent, since TVM / MLC only rely on a specific commit point.
Could you help merge this PR, so that I can continue to rebase my current work?

@tsu-bin deleted the fix_broken_build branch · Oct 30, 2024 13:05
@abcdabcd987 (Member) commented
Oops, sorry for breaking cpp.

@yzh119 (Collaborator) commented Oct 30, 2024

Sorry for the late reply. Yes, this PR fixes #571, thank you!

@tsu-bin (Contributor, Author) commented Nov 5, 2024

> @zhyncs: src/tvm_wrapper.cu(690): error: identifier "BatchQKApplyRotaryInPlace" is undefined … Would it be convenient for you to fix it in the next PR?

Hi @zhyncs, just fixed it in #582.
