
MUSA: support ARM64 and enable dp4a, etc. #11843

Open · BodhiHu wants to merge 31 commits into master
Conversation

BodhiHu commented Feb 13, 2025

This PR does the following:

  1. enable dp4a on MUSA (see the sketch after this list);
  2. fix compile errors on MUSA ARM64;
  3. support the sparse-MoE parameter expert_weights_scale for MoE-sparsified LLaMA models.
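
For reference, a minimal sketch of what enabling dp4a amounts to, assuming the MUSA toolchain exposes a CUDA-compatible `__dp4a` intrinsic; the guard and fallback below are illustrative, not the exact diff in this PR:

```cpp
// Illustrative sketch: route ggml's 4x int8 dot product through the hardware
// dp4a instruction where available, with a portable scalar fallback.
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
#if defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= 610
    return __dp4a(a, b, c); // assumes the compiler lowers this to one instruction
#else
    // fallback: unpack the four signed 8-bit lanes and accumulate manually
    const int8_t * a8 = (const int8_t *) &a;
    const int8_t * b8 = (const int8_t *) &b;
    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif
}
```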

Tested with the following models:

ARM64:

MUSA SDK: 3.1.2
CPU compiler: clang-17

  • qwen2.5-1.5b-instruct-q8_0.gguf
  • qwen2.5-3b-instruct-q4_k_m.gguf
  • deepseek-r1-7B-Q4_K_M.gguf

x86:

MUSA SDK: 3.1.1
CPU compiler: clang-14

  • llama3_8b_q4_0.gguf
  • deepseek-r1_7b_q4_0.gguf
  • qwen2.5-3b-instruct-q4_k_m.gguf

github-actions bot added the documentation, build, Nvidia GPU, python, and ggml labels on Feb 13, 2025
BodhiHu changed the title from "[wip] MUSA: enable dp4a and fix compile errors on ARM64" to "MUSA: enable dp4a and fix compile errors on ARM64" on Feb 13, 2025
BodhiHu (Author) commented Feb 13, 2025

Hi @JohannesGaessler, @ggerganov, @slaren, @yeahdongcn,

Could you please help review this PR?

Thanks a lot.

JohannesGaessler (Collaborator) left a comment

The changes to the CUDA backend look fine to me other than the things I commented on. I don't know whether the changes for model support are correct.

Review threads (outdated, resolved): CMakeLists.txt, ggml/src/ggml-cuda/common.cuh, ggml/src/ggml-cuda/ggml-cuda.cu, ggml/src/ggml-cuda/mmq.cu
BodhiHu and others added 3 commits February 13, 2025 19:15
Co-authored-by: Johannes Gäßler <[email protected]>
Co-authored-by: Johannes Gäßler <[email protected]>
yeahdongcn (Contributor) commented Feb 13, 2025

Please run the functionality tests and the tests under the tests directory on amd64 as well.
BTW, I'm updating the MUSA SDK version to rc3.1.1. You may want to hold off until #11822 is reviewed and merged.

BodhiHu (Author) commented Feb 14, 2025

> The changes to the CUDA backend look fine to me other than the things I commented on. I don't know whether the changes for model support are correct.

Hi @JohannesGaessler, the changes to model support enable expert_weights_scale for MoE-sparsified LLaMA models.
I tested with the following LLaMA MoE model, and it runs well:

https://huggingface.co/llama-moe/LLaMA-MoE-v2-3_8B-2_8-sft

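For context, here is a hedged sketch of where expert_weights_scale enters the MoE graph build; it mirrors the shape of the MoE FFN construction in llama.cpp, but `select_top_k_experts` is a hypothetical helper and the surrounding code is illustrative:

```cpp
// Illustrative sketch: scale the routing weights of the selected experts by
// the per-model hparam expert_weights_scale (used by MoE-sparsified LLaMA).
ggml_tensor * logits  = ggml_mul_mat(ctx, gate_inp, cur);   // [n_expert, n_tokens]
ggml_tensor * probs   = ggml_soft_max(ctx, logits);
ggml_tensor * weights = select_top_k_experts(ctx, probs, n_expert_used); // hypothetical helper
if (hparams.expert_weights_scale != 1.0f) {
    weights = ggml_scale(ctx, weights, hparams.expert_weights_scale);
}
```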

BodhiHu (Author) commented Feb 14, 2025

> Please run the functionality tests and the tests under the tests directory on amd64 as well. BTW, I'm updating the MUSA SDK version to rc3.1.1. You may want to hold off until #11822 is reviewed and merged.

Hi @yeahdongcn, I see #11822 has been merged.

When running ./build/bin/test-backend-ops, there's an exception. I don't know if this also happens on your side or if it's a known issue:

  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16,permute=[0,1,2,3]): not supported [MUSA0]
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=bf16,permute=[0,1,2,3]): not supported [MUSA0]
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q8_0,permute=[0,1,2,3]): not supported [MUSA0]
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=35,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=q4_0,permute=[0,1,2,3]): not supported [MUSA0]
  CROSS_ENTROPY_LOSS(type=f32,ne=[10,5,4,3]): MUSA error: invalid argument
  current device: 0, in function ggml_cuda_cross_entropy_loss at /home/mm/bodhi/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu:129
  musaFuncSetAttribute(cross_entropy_loss_back_f32<true>, musaFuncAttributeMaxDynamicSharedMemorySize, smpbo)
/home/mm/bodhi/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:73: MUSA error
[New LWP 178917]
[New LWP 178918]
[New LWP 178919]
[New LWP 178920]
[New LWP 178933]
[New LWP 178982]
[New LWP 179583]
[New LWP 179584]
[New LWP 179585]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/usr/lib/aarch64-linux-gnu/libthread_db.so.1".
0x0000ffff8d436800 in __GI___wait4 (pid=<optimized out>, stat_loc=0xffffc4a3e86c, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30      ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory.
#0  0x0000ffff8d436800 in __GI___wait4 (pid=<optimized out>, stat_loc=0xffffc4a3e86c, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30      in ../sysdeps/unix/sysv/linux/wait4.c
#1  0x0000aaaac10a8d44 in ggml_print_backtrace ()
#2  0x0000aaaac10a8cd8 in ggml_abort ()
#3  0x0000aaaac0f6c8cc in ggml_cuda_error(char const*, char const*, char const*, int, char const*) ()
#4  0x0000aaaac1054b2c in ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context&, ggml_tensor*) ()
#5  0x0000aaaac0f715bc in ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) ()
#6  0x0000aaaac10bf404 in ggml_backend_compare_graph_backend ()
#7  0x0000aaaac0ee4e78 in test_case::eval(ggml_backend*, ggml_backend*, char const*) ()
#8  0x0000aaaac0ed1f14 in main ()
[Inferior 1 (process 178916) detached]
Aborted (core dumped)

FYI, the above CROSS_ENTROPY_LOSS op test error has been fixed.
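
For anyone hitting the same thing: the abort comes from musaFuncSetAttribute rejecting an opt-in dynamic shared memory size the device cannot provide. Here is a hedged sketch of the kind of guard that avoids it, following the CUDA backend's conventions (the smpbo field and 48 KiB default are assumptions here, and the exact fix in this PR may differ):

```cpp
// Illustrative sketch: raise the kernel's dynamic shared memory limit only
// when the device supports more than the default opt-out limit of 48 KiB.
const int id    = ggml_cuda_get_device();
const int smpbo = ggml_cuda_info().devices[id].smpbo; // shared mem per block, opt-in
if (smpbo > 48*1024) {
    CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>,
        cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
}
```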

@BodhiHu BodhiHu closed this Feb 14, 2025
@BodhiHu BodhiHu reopened this Feb 14, 2025
Review threads: ggml/src/ggml-cuda/common.cuh (two threads, outdated, resolved), ggml/src/ggml-cuda/cross-entropy-loss.cu (resolved), ggml/src/ggml-cuda/mmq.cu (outdated, resolved)
BodhiHu changed the title from "MUSA: enable dp4a and fix compile errors on ARM64" to "MUSA: support ARM64 and enable dp4a, etc." on Feb 17, 2025
Review thread: convert_hf_to_gguf.py (outdated, resolved)
BodhiHu (Author) commented Feb 18, 2025

Hi @yeahdongcn, the model-running issue on x86 has been fixed.
Tested with the following models, and they run well now:

  • llama3_8b_q4_0.gguf
  • deepseek-r1_7b_q4_0.gguf
  • qwen2.5-3b-instruct-q4_k_m.gguf

BodhiHu (Author) commented Feb 18, 2025

Hi @slaren, the LLaMA-MoE changes to convert_hf_to_gguf.py have been removed. Could you please review again? Thanks.

BodhiHu requested review from slaren and yeahdongcn on February 19, 2025 at 02:20
Review threads (outdated, resolved): docs/build.md, ggml/src/ggml-cuda/ggml-cuda.cu, ggml/src/ggml-cuda/common.cuh, src/llama-model.cpp
Labels
build (Compilation issues), documentation (Improvements or additions to documentation), ggml (changes relating to the ggml tensor library for machine learning), Nvidia GPU (Issues specific to Nvidia GPUs), python (python script changes)
Projects
None yet
4 participants