gptManagerBenchmark seems to go into a dead loop with GPU usage 0% #1562

Closed
sleepwalker2017 opened this issue May 8, 2024 · 4 comments
Labels
stale triaged Issue has been triaged by maintainers

Comments

@sleepwalker2017

sleepwalker2017 commented May 8, 2024

I ran this on 2 × A30 GPUs with CUDA driver 535.104.12.
The Docker image was built with make -C docker release_build CUDA_ARCHS="80-real".
I am using the latest code on the main branch:

commit 89ba1b1a67d570e41b03da87e5518eaff0d31fbf (HEAD -> main, origin/main, origin/HEAD)
Author: Kaiyu Xie <[email protected]>
Date:   Tue May 7 23:34:28 2024 +0800

GPU utilization stays at 0% while CPU usage stays at 100% for a long time.


Stack trace for both processes:

Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
0x00007f1d26c97117 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
(gdb) bt
#0  0x00007f1d26c97117 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
#1  0x00007f1d26c9c624 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
#2  0x00007f1d270122c7 in std::thread::join() () from /lib/x86_64-linux-gnu/libstdc++.so.6
#3  0x00007f1d2e082539 in tensorrt_llm::batch_manager::GptManager::waitUntilTerminate() () from /data/TensorRT-LLM/cpp/build/tensorrt_llm/libtensorrt_llm.so
#4  0x000056485357f62e in main ()
(gdb) info threads
  Id   Target Id                                           Frame
* 1    Thread 0x7f1d26b11000 (LWP 44455) "gptManagerBench" 0x00007f1d26c97117 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
  2    Thread 0x7f1ce47ba000 (LWP 44456) "cuda00001400006" 0x00007f1d26d1ebcf in poll () from /lib/x86_64-linux-gnu/libc.so.6
  3    Thread 0x7f1cbffff000 (LWP 44461) "gptManagerBench" 0x00007f1d26d2be2e in epoll_wait () from /lib/x86_64-linux-gnu/libc.so.6
  4    Thread 0x7f1cbbffe000 (LWP 44462) "gptManagerBench" 0x00007f1d26d2be2e in epoll_wait () from /lib/x86_64-linux-gnu/libc.so.6
  5    Thread 0x7f1ca7fff000 (LWP 44466) "fuse"            0x00007f1d26d1a81c in read () from /lib/x86_64-linux-gnu/libc.so.6
  6    Thread 0x7f1ca3ffe000 (LWP 44468) "async"           0x00007f1d26d2be2e in epoll_wait () from /lib/x86_64-linux-gnu/libc.so.6
  7    Thread 0x7f1c93fa9000 (LWP 44470) "cuda-EvtHandlr"  0x00007f1d26d1ebcf in poll () from /lib/x86_64-linux-gnu/libc.so.6
  8    Thread 0x7f1965f26000 (LWP 44492) "cuda-EvtHandlr"  0x00007f1d26d1ebcf in poll () from /lib/x86_64-linux-gnu/libc.so.6
  9    Thread 0x7f1637fff000 (LWP 44509) "gptManagerBench" 0x00007f1d26d1ebcf in poll () from /lib/x86_64-linux-gnu/libc.so.6
  10   Thread 0x7f1633ffe000 (LWP 44510) "gptManagerBench" 0x00007f1d26d1ebcf in poll () from /lib/x86_64-linux-gnu/libc.so.6
  11   Thread 0x7f161ffff000 (LWP 44516) "gptManagerBench" 0x00007f1d26c97117 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
  12   Thread 0x7f11ddcf4000 (LWP 44524) "gptManagerBench" 0x00007f1d26c97117 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
  13   Thread 0x7f11d9cf3000 (LWP 44525) "gptManagerBench" 0x00007f1d26c97117 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
  14   Thread 0x7f11cffff000 (LWP 44526) "gptManagerBench" 0x00007f1d26c97117 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
  15   Thread 0x7f11c3fff000 (LWP 44527) "gptManagerBench" 0x00007f1d26c97117 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
  16   Thread 0x7f11b7fff000 (LWP 44528) "gptManagerBench" 0x00007f1d26c97117 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
  17   Thread 0x7f11abfff000 (LWP 44529) "gptManagerBench" 0x00007f1d26c97117 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
  18   Thread 0x7f119ffff000 (LWP 44530) "gptManagerBench" 0x00007f1d26c97117 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
  19   Thread 0x7f1193fff000 (LWP 44531) "gptManagerBench" 0x00007f1d26c97117 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
  20   Thread 0x7f1c77fff000 (LWP 44544) "gptManagerBench" 0x00007f1d26c4ac10 in getenv () from /lib/x86_64-linux-gnu/libc.so.6
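
The listing above shows only the innermost frame of each thread, and most of those are unresolved libc addresses. A full backtrace per thread is usually more useful for locating a hang. The following is a minimal sketch, assuming gdb and pgrep are installed and attaching is permitted (for example, as root inside the container). The function name dump_all_threads is hypothetical; the process name gptManagerBench is taken from the thread names above.

```shell
#!/usr/bin/env bash
# Print a backtrace of every thread in each process whose name matches
# the given pattern. Assumes gdb and pgrep are available.
dump_all_threads() {
    local pattern="$1" pid pids
    # pgrep without -f matches the (15-character) process name, so the
    # mpirun launcher itself is not selected.
    pids=$(pgrep "$pattern") || {
        echo "no process matches: $pattern" >&2
        return 1
    }
    for pid in $pids; do
        echo "=== PID ${pid} ==="
        # -batch exits after running the commands, which detaches gdb.
        gdb -p "$pid" -batch -ex "thread apply all bt"
    done
}

# Usage (run while the benchmark is stuck):
#   dump_all_threads gptManagerBench > stacks.txt 2>&1
```

With TP=2 this should produce one section per rank, which makes it easier to see whether both ranks are waiting at the same point.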

Note that inference for the non-LoRA requests finishes; the process gets stuck only when running the LoRA benchmark.

The script for reproducing this is posted at the end of this issue.

Some additional questions about this script:

  1. What is the correct value for --lora_target_modules? The manual only gives attn_qkv. Why is that, and what does attn_qkv mean?
  2. The manual sets --lora_num_device_mod_layers $(( 32 * $NUM_LAYERS * $NUM_LORA_MODS * $MAX_LORA_RANK )). What does the 32 mean? Is it the number of LoRAs?
  3. NUM_LORA_MODS is set to 7. What is that for? Does it refer to attn_q attn_k attn_v attn_dense mlp_h_to_4h mlp_gate mlp_4h_to_h? If so, what is attn_qkv for?
MODEL_CHECKPOINT=/data/vicuna-13b/vicuna-13b-v1.5/
CONVERTED_CHECKPOINT=Llama-2-13b-hf-ckpt
TOKENIZER=/data/vicuna-13b/vicuna-13b-v1.5/
LORA_ENGINE=Llama-2-13b-hf-engine

DTYPE=float16
TP=2
PP=1
MAX_LEN=1024
MAX_BATCH=24
MAX_LORA_RANK=32
NUM_LORA_MODS=7

SOURCE_LORA=chinese-llama-2-lora-13b
CPP_LORA=chinese-llama-2-lora-13b-cpp

EG_DIR=/tmp/lora-eg

# Build lora enabled engine
python examples/llama/convert_checkpoint.py --model_dir ${MODEL_CHECKPOINT} \
                              --output_dir ${CONVERTED_CHECKPOINT} \
                              --dtype ${DTYPE} \
                              --tp_size ${TP} \
                              --pp_size 1

trtllm-build \
    --checkpoint_dir ${CONVERTED_CHECKPOINT} \
    --output_dir ${LORA_ENGINE} \
    --max_batch_size ${MAX_BATCH} \
    --max_input_len $MAX_LEN \
    --max_output_len $MAX_LEN \
    --gemm_plugin float16 \
    --lora_plugin float16 \
    --use_paged_context_fmha enable \
    --lora_target_modules attn_qkv attn_q attn_k attn_v attn_dense mlp_h_to_4h mlp_gate mlp_4h_to_h \
    --use_custom_all_reduce disable \
    --max_lora_rank ${MAX_LORA_RANK}

NUM_LORAS=(1)
NUM_REQUESTS=200

# Convert LoRA to cpp format
python examples/hf_lora_convert.py \
    -i $SOURCE_LORA \
    --storage-type $DTYPE \
    -o $CPP_LORA

# Prepare datasets
mkdir -p $EG_DIR/data

# Prepare dataset without lora_task_id
python benchmarks/cpp/prepare_dataset.py \
    --output "${EG_DIR}/data/token-norm-dist.json" \
    --tokenizer $TOKENIZER \
    token-norm-dist \
    --num-requests $NUM_REQUESTS \
    --input-mean 256 --input-stdev 16 --output-mean 128 --output-stdev 24

# Prepare dataset with lora_task_ids from 0 - $nloras
for nloras in ${NUM_LORAS[@]}; do
    python benchmarks/cpp/prepare_dataset.py \
        --output "${EG_DIR}/data/token-norm-dist-lora-${nloras}.json" \
        --rand-task-id 0 $(( $nloras - 1 )) \
        --tokenizer $TOKENIZER \
        token-norm-dist \
        --num-requests $NUM_REQUESTS \
        --input-mean 256 --input-stdev 16 --output-mean 128 --output-stdev 24
done

# Generate random lora weights for 8 adapters
python benchmarks/cpp/utils/generate_rand_loras.py ${CPP_LORA} ${EG_DIR}/loras 8

# perform benchmarking
NUM_LAYERS=40
NUM_LORA_MODS=7
MAX_LORA_RANK=32
EOS_ID=-1
# First run inference without LoRAs
mkdir -p ${EG_DIR}/log-base-lora
mpirun -n ${TP} --allow-run-as-root --output-filename ${EG_DIR}/log-base-lora \
    cpp/build/benchmarks/gptManagerBenchmark \
    --engine_dir $LORA_ENGINE \
    --type IFB \
    --dataset "${EG_DIR}/data/token-norm-dist.json" \
    --lora_host_cache_bytes 8589934592 \
    --lora_num_device_mod_layers $(( 1 * $NUM_LAYERS * $NUM_LORA_MODS * $MAX_LORA_RANK )) \
    --kv_cache_free_gpu_mem_fraction 0.80 \
    --log_level error \
    --eos_id ${EOS_ID}

# Now run inference with various numbers of LoRAs
# The host cache is set large enough to hold all the LoRAs in lora_dir
# GPU cache is set to hold 1 LoRA (the leading factor in --lora_num_device_mod_layers)
# This benchmark will preload all the LoRAs into the host cache
# We run inference on a range of active LoRAs exercising different cache miss rates.
for nloras in ${NUM_LORAS[@]}; do
    mkdir -p ${EG_DIR}/log-lora-${nloras}
    mpirun -n ${TP} --allow-run-as-root --output-filename "${EG_DIR}/log-lora-${nloras}" \
        cpp/build/benchmarks/gptManagerBenchmark \
        --engine_dir $LORA_ENGINE \
        --type IFB \
        --dataset "${EG_DIR}/data/token-norm-dist-lora-${nloras}.json" \
        --lora_host_cache_bytes 8589934592 \
        --lora_num_device_mod_layers $(( 1 * $NUM_LAYERS * $NUM_LORA_MODS * $MAX_LORA_RANK )) \
        --kv_cache_free_gpu_mem_fraction 0.80 \
        --log_level error \
        --eos_id ${EOS_ID} \
        --lora_dir ${EG_DIR}/loras
done
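
Because the LoRA run never returns, it can help to put a wall-clock limit around each mpirun call so the script fails visibly instead of spinning. This is only a sketch: the helper name run_with_limit and the 30-minute limit in the usage comment are assumptions, not part of the benchmark. It relies on the timeout utility, which exits with status 124 when the limit is hit (or 137 if it had to send KILL).

```shell
#!/usr/bin/env bash
# Run a command under a time limit: send TERM at the limit, then KILL
# 10 seconds later if the command is still alive.
run_with_limit() {
    local limit="$1" rc
    shift
    timeout -k 10 -s TERM "$limit" "$@"
    rc=$?
    if [ "$rc" -eq 124 ] || [ "$rc" -eq 137 ]; then
        echo "timed out after ${limit}: $*" >&2
    fi
    return "$rc"
}

# Usage (hypothetical limit of 30 minutes):
#   run_with_limit 30m mpirun -n ${TP} --allow-run-as-root \
#       cpp/build/benchmarks/gptManagerBenchmark --engine_dir $LORA_ENGINE ...
```

This does not fix the hang; it only bounds how long a stuck run can occupy the machine.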
@sleepwalker2017
Author

sleepwalker2017 commented May 20, 2024

@byshiue @juney-nvidia Could anyone comment on this?

@VincentJing

VincentJing commented May 29, 2024

  • Question 1: You can refer to this link for the definitions of the LoRA modules. attn_qkv is a combined QKV adapter.
  • Question 2: Since the GPU cache is set to hold 32 LoRAs, lora_num_device_mod_layers needs to be multiplied by 32. lora_num_device_mod_layers is the number of max-sized 1-layer 1-module sets of weights that can be stored in the device (GPU) cache.
  • Question 3: You're right.
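
To make the sizing in question 2 concrete, here is a small sketch of the arithmetic, assuming the formula used by the scripts in this thread (adapters held in the GPU cache × layers × modules × max rank). The function name device_mod_layers is hypothetical.

```shell
#!/usr/bin/env bash
# rows = adapters_in_gpu_cache * layers * modules * max_rank
device_mod_layers() {
    local adapters="$1" layers="$2" modules="$3" max_rank="$4"
    echo $(( adapters * layers * modules * max_rank ))
}

# Reference settings: 40 layers, 7 modules, max rank 64.
device_mod_layers 32 40 7 64   # prints 573440 (room for 32 adapters)
device_mod_layers 16 40 7 64   # prints 286720 (room for 16 adapters)

# Settings from the original report: leading factor 1, max rank 32.
device_mod_layers 1 40 7 32    # prints 8960 (room for a single adapter)
```

With a leading factor of 1, the device cache is sized for only one max-rank adapter under this formula.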

Set the parameters correctly and try again.
Here is a script that you can refer to.

# git-lfs clone https://huggingface.co/meta-llama/Llama-2-13b-hf
# git-lfs clone https://huggingface.co/hfl/chinese-llama-2-lora-13b

MODEL_CHECKPOINT=/llm_data/llm-models/llama-models-v2/llama-v2-13b-hf
CONVERTED_CHECKPOINT=Llama-2-13b-hf-ckpt
TOKENIZER=/llm_data/llm-models/llama-models-v2/llama-v2-13b-hf
LORA_ENGINE=Llama-2-13b-hf-engine

DTYPE=float16
TP=2
PP=1
MAX_LEN=1024
MAX_BATCH=32
NUM_LAYERS=40
MAX_LORA_RANK=64
NUM_LORA_MODS=7

SOURCE_LORA=/llm_data/llm-models/llama-models-v2/chinese-llama-2-lora-13b
CPP_LORA=chinese-llama-2-lora-13b-cpp

EG_DIR=tmp/lora-eg

# Build lora enabled engine
python examples/llama/convert_checkpoint.py --model_dir ${MODEL_CHECKPOINT} \
                              --output_dir ${CONVERTED_CHECKPOINT} \
                              --dtype ${DTYPE} \
                              --tp_size ${TP} \
                              --pp_size 1

trtllm-build \
    --checkpoint_dir ${CONVERTED_CHECKPOINT} \
    --output_dir ${LORA_ENGINE} \
    --max_batch_size ${MAX_BATCH} \
    --max_input_len $MAX_LEN \
    --max_output_len $MAX_LEN \
    --gemm_plugin float16 \
    --lora_plugin float16 \
    --use_paged_context_fmha enable \
    --lora_target_modules attn_q attn_k attn_v attn_dense mlp_h_to_4h mlp_4h_to_h mlp_gate \
    --max_lora_rank ${MAX_LORA_RANK}

NUM_LORAS=(8 16)
NUM_REQUESTS=1024


# Convert LoRA to cpp format
python examples/hf_lora_convert.py \
    -i $SOURCE_LORA \
    --storage-type $DTYPE \
    -o $CPP_LORA

# Prepare datasets
mkdir -p $EG_DIR/data

# Prepare dataset without lora_task_id
python benchmarks/cpp/prepare_dataset.py \
    --output "${EG_DIR}/data/token-norm-dist.json" \
    --tokenizer $TOKENIZER \
    token-norm-dist \
    --num-requests $NUM_REQUESTS \
    --input-mean 256 --input-stdev 16 --output-mean 128 --output-stdev 24

# Prepare dataset with lora_task_ids from 0 - $nloras
for nloras in ${NUM_LORAS[@]}; do
    python benchmarks/cpp/prepare_dataset.py \
        --output "${EG_DIR}/data/token-norm-dist-lora-${nloras}.json" \
        --rand-task-id 0 $(( $nloras - 1 )) \
        --tokenizer $TOKENIZER \
        token-norm-dist \
        --num-requests $NUM_REQUESTS \
        --input-mean 256 --input-stdev 16 --output-mean 128 --output-stdev 24
done

# Generate random lora weights for 16 adapters
python benchmarks/cpp/utils/generate_rand_loras.py ${CPP_LORA} ${EG_DIR}/loras 16

# perform benchmarking

EOS_ID=2

# First run inference without LoRAs
mkdir -p ${EG_DIR}/log-base-lora
mpirun -n ${TP} --output-filename ${EG_DIR}/log-base-lora \
    cpp/build/benchmarks/gptManagerBenchmark \
    --engine_dir $LORA_ENGINE \
    --type IFB \
    --dataset "${EG_DIR}/data/token-norm-dist.json" \
    --lora_host_cache_bytes 8589934592 \
    --lora_num_device_mod_layers $(( 32 * $NUM_LAYERS * $NUM_LORA_MODS * $MAX_LORA_RANK )) \
    --kv_cache_free_gpu_mem_fraction 0.80 \
    --log_level info \
    --eos_id ${EOS_ID}

# Now run inference with various numbers of LoRAs
# The host cache is set large enough to hold all the LoRAs in lora_dir
# GPU cache is set to hold 16 LoRAs
# This benchmark will preload all the LoRAs into the host cache
# We run inference on a range of active LoRAs exercising different cache miss rates.
for nloras in ${NUM_LORAS[@]}; do
    mkdir -p ${EG_DIR}/log-lora-${nloras}
    mpirun -n ${TP} --output-filename "${EG_DIR}/log-lora-${nloras}" \
        cpp/build/benchmarks/gptManagerBenchmark \
        --engine_dir $LORA_ENGINE \
        --type IFB \
        --dataset "${EG_DIR}/data/token-norm-dist-lora-${nloras}.json" \
        --lora_host_cache_bytes 8589934592 \
        --lora_num_device_mod_layers $(( 16 * $NUM_LAYERS * $NUM_LORA_MODS * $MAX_LORA_RANK )) \
        --kv_cache_free_gpu_mem_fraction 0.80 \
        --log_level info \
        --eos_id ${EOS_ID} \
        --lora_dir ${EG_DIR}/loras
done
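
One way to keep NUM_LORA_MODS in step with --lora_target_modules is to derive the count from the module list instead of hard-coding 7. The sketch below assumes NUM_LORA_MODS is meant to equal the number of target modules, as the answer to question 3 suggests; count_words is a hypothetical helper.

```shell
#!/usr/bin/env bash
# Count the whitespace-separated words in a string.
count_words() {
    # Intentionally unquoted so the shell splits the list into words.
    set -- $1
    echo $#
}

LORA_TARGET_MODULES="attn_q attn_k attn_v attn_dense mlp_h_to_4h mlp_4h_to_h mlp_gate"
NUM_LORA_MODS=$(count_words "$LORA_TARGET_MODULES")
echo "NUM_LORA_MODS=${NUM_LORA_MODS}"   # prints NUM_LORA_MODS=7

# The list can then be reused when building the engine:
#   trtllm-build ... --lora_target_modules ${LORA_TARGET_MODULES} ...
```

The module list in the original report also includes attn_qkv, which gives 8 entries against NUM_LORA_MODS=7. Whether that mismatch contributes to the hang is not established in this thread.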

@kaiyux
Member

kaiyux commented Aug 31, 2024

Hi @sleepwalker2017, could you please check whether the issue has been fixed on the latest main branch? Thanks.

@nv-guomingz
Collaborator

Hi @sleepwalker2017, do you still have any further issues or questions? If not, we'll close this soon.
