Add LLaMA GQA ragged batching (microsoft#18337)
This PR updates the MHA-to-GQA replacement logic and updates the LLaMA scripts for
the modified GQA op. It is related to the changes in [this
PR](microsoft#18283).

### Motivation and Context
This PR allows us to run LLaMA with the GQA op end-to-end using ragged
batching (i.e. batched inputs of different lengths).
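
For intuition, here is a minimal NumPy sketch (not part of this PR's code) of the two values that the new attention-mask subgraph feeds to every GroupQueryAttention node for a ragged batch: `seqlens_k` (index of each row's last real token) and the total (padded) sequence length.

```python
import numpy as np

# Hypothetical ragged batch with lengths 3 and 5, right-padded to 5 tokens.
attention_mask = np.array([[1, 1, 1, 0, 0],
                           [1, 1, 1, 1, 1]], dtype=np.int64)

# Mirrors the ReduceSum -> Sub -> Cast branch: last real token index per row.
seqlens_k = (attention_mask.sum(axis=1) - 1).astype(np.int32)   # [2, 4]

# Mirrors the Shape -> Gather -> Cast branch: padded (total) sequence length.
total_sequence_length = np.int32(attention_mask.shape[1])       # 5
```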
kunal-vaishnavi authored and kleiti committed Mar 22, 2024
1 parent 7bac1fc commit a76ff4b
Showing 7 changed files with 257 additions and 106 deletions.
119 changes: 88 additions & 31 deletions onnxruntime/python/tools/transformers/convert_generation.py
@@ -1272,39 +1272,96 @@ def find_past_seq_len_usage(subg: GraphProto):
    return tensor_names_to_rename, nodes_to_remove


-def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1):
-    past_seq_len = past_seq_len_input
-    if past_seq_len not in model.get_graphs_input_names():
-        # Add model input for past sequence length
-        new_input = onnx.helper.make_tensor_value_info(past_seq_len, onnx.TensorProto.INT64, shape=[1])
-        model.model.graph.input.append(new_input)
+def replace_mha_with_gqa(model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1):
+    # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
+    #
+    #                attention_mask
+    #               /              \
+    #      ReduceSum                Shape
+    #          |                      |
+    #         Sub                   Gather
+    #          |                      |
+    #      seqlens_k      total_sequence_length
+    #          |                      |
+    #    Cast to int32          Cast to int32

+    model.add_initializer(
+        onnx.helper.make_tensor(
+            name="one",
+            data_type=TensorProto.INT64,
+            dims=[1],
+            vals=[1],
+        )
+    )
+    reduce_sum_node = onnx.helper.make_node(
+        "ReduceSum",
+        inputs=[attn_mask, "one"],
+        outputs=[attn_mask + "_row_sums"],
+        name=model.create_node_name("ReduceSum"),
+    )
+    sub_node = onnx.helper.make_node(
+        "Sub",
+        inputs=[attn_mask + "_row_sums", "one"],
+        outputs=["seqlens_k_int64"],
+        name=model.create_node_name("Sub"),
+    )
+    seqlen_k_cast_node = onnx.helper.make_node(
+        "Cast",
+        inputs=["seqlens_k_int64"],
+        outputs=["seqlens_k"],
+        name=model.create_node_name("Cast"),
+        to=TensorProto.INT32,
+    )
+    shape_node = onnx.helper.make_node(
+        "Shape",
+        inputs=[attn_mask],
+        outputs=[attn_mask + "_shape"],
+        name=model.create_node_name("Shape"),
+    )
+    gather_node = onnx.helper.make_node(
+        "Gather",
+        inputs=[attn_mask + "_shape", "one"],
+        outputs=["total_seq_len_int64"],
+        name=model.create_node_name("Gather"),
+        axis=0,
+    )
+    total_seqlen_cast_node = onnx.helper.make_node(
+        "Cast",
+        inputs=["total_seq_len_int64"],
+        outputs=["total_seq_len"],
+        name=model.create_node_name("Cast"),
+        to=TensorProto.INT32,
+    )
+    model.model.graph.node.extend(
+        [reduce_sum_node, sub_node, seqlen_k_cast_node, shape_node, gather_node, total_seqlen_cast_node]
+    )

    # Replace MultiHeadAttention with GroupQueryAttention
-    for node in model.model.graph.node:
-        if node.op_type == "MultiHeadAttention":
-            num_heads_mha = 0
-            for att in node.attribute:
-                if att.name == "num_heads":
-                    num_heads_mha = att.i
-            gqa_node = onnx.helper.make_node(
-                "GroupQueryAttention",
-                inputs=[
-                    node.input[0],  # query
-                    node.input[1],  # key
-                    node.input[2],  # value
-                    node.input[6],  # past_key
-                    node.input[7],  # past_value
-                    past_seq_len,  # past_sequence_length
-                ],
-                outputs=node.output,
-                name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
-                domain="com.microsoft",
-                num_heads=num_heads_mha // world_size,
-                kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
-                is_past_bsnh=0,
-            )
-            model.model.graph.node.remove(node)
-            model.model.graph.node.extend([gqa_node])
+    mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
+    for node in mha_nodes:
+        num_heads_mha = 0
+        for att in node.attribute:
+            if att.name == "num_heads":
+                num_heads_mha = att.i
+        gqa_node = onnx.helper.make_node(
+            "GroupQueryAttention",
+            inputs=[
+                node.input[0],  # query
+                node.input[1],  # key
+                node.input[2],  # value
+                node.input[6],  # past_key
+                node.input[7],  # past_value
+                "seqlens_k",  # seqlens_k (for attention_mask)
+                "total_seq_len",  # total_seq_len (for attention_mask)
+            ],
+            outputs=node.output,
+            name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
+            domain="com.microsoft",
+            num_heads=num_heads_mha // world_size,
+            kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
+        )
+        model.model.graph.node.remove(node)
+        model.model.graph.node.extend([gqa_node])
    return model
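
For orientation, a hypothetical usage sketch of the updated helper (not part of this diff; the file paths and standalone invocation are illustrative — the LLaMA conversion script calls this helper internally when `--use_gqa` is passed):

```python
# Hypothetical sketch: apply the MHA-to-GQA replacement to an exported MHA model.
import onnx
from onnxruntime.transformers.convert_generation import replace_mha_with_gqa
from onnxruntime.transformers.onnx_model import OnnxModel

# Paths are illustrative.
model = OnnxModel(onnx.load("llama2-7b-fp16/model.onnx", load_external_data=True))
model = replace_mha_with_gqa(model, attn_mask="attention_mask", kv_num_heads=0, world_size=1)
model.prune_graph()
model.save_model_to_file("llama2-7b-fp16-gqa/model.onnx", use_external_data_format=True)
```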


97 changes: 86 additions & 11 deletions onnxruntime/python/tools/transformers/models/llama/README.md
@@ -117,7 +117,7 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu
```

Export for FP16 CUDA
Export for FP16 CUDA (with MultiHeadAttention)
```
# From source:
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda
@@ -126,6 +126,63 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda
```

Export for FP16 CUDA (with GroupQueryAttention)
```
# From source:
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --use_gqa
# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --use_gqa
```

Note: GroupQueryAttention currently runs on Linux for FP16 CUDA and INT4 CUDA models, and it can provide faster inference than MultiHeadAttention, especially for large sequence lengths (e.g. 1024 or larger). For the best performance, you should pre-allocate the KV cache buffers to have size `(batch_size, num_heads, max_sequence_length, head_size)` so that the past KV and present KV caches share the same memory. You also need to bind them with ONNX Runtime's [IO binding](https://onnxruntime.ai/docs/api/python/api_summary.html#iobinding).

Here is an example of how you can bind directly to `torch.tensor` objects:
```
# Assumes all inputs and outputs to the model are pre-allocated with the correct shapes in GPU memory

# Create the IO binding for the ONNX Runtime session
io_binding = sess.io_binding()

# Bind inputs
for k, v in inputs.items():
    io_binding.bind_input(
        name=k,
        device_type="cuda",
        device_id=0,
        element_type=np.float16,
        shape=tuple(v.shape),
        buffer_ptr=v.data_ptr()
    )

# Bind outputs
for output in sess.get_outputs():
    name = output.name
    if "present" in name:
        # Bind KV cache outputs to KV cache inputs
        v = inputs[name.replace("present", "past_key_values")]
        io_binding.bind_output(
            name=name,
            device_type="cuda",
            device_id=0,
            element_type=np.float16,
            shape=tuple(v.shape),
            buffer_ptr=v.data_ptr()
        )
    else:
        # Bind other outputs as actual outputs
        v = outputs[name]
        io_binding.bind_output(
            name=name,
            device_type="cuda",
            device_id=0,
            element_type=np.float16,
            shape=tuple(v.shape),
            buffer_ptr=v.data_ptr()
        )

io_binding.synchronize_inputs()
sess.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
```
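
As a companion to the note above, here is a hedged sketch of pre-allocating the shared KV cache buffers that those bindings point at. The input names, layer count, and dimensions are illustrative assumptions, not taken from this PR:

```python
import torch

# Illustrative model dimensions (assumptions, not from this PR).
num_layers, batch_size, num_heads, max_sequence_length, head_size = 32, 1, 32, 2048, 128

inputs = {}
for i in range(num_layers):
    for kv in ("key", "value"):
        # One buffer per layer and per key/value; the past and present views share this memory.
        inputs[f"past_key_values.{i}.{kv}"] = torch.zeros(
            batch_size, num_heads, max_sequence_length, head_size,
            device="cuda", dtype=torch.float16,
        )
```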

Export for INT8 CPU (SmoothQuant)
```
# From source:
@@ -149,12 +206,14 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama
Export for INT4 CUDA
```
# From source:
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda
$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda --use_gqa
# From wheel:
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda
$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda --use_gqa
```

Note: See the FP16 CUDA notes about GroupQueryAttention. The `--use_gqa` flag is optional.

Export for INT4 CPU
```
# From source:
@@ -168,13 +227,13 @@ Export LLaMA-2 70B sharded model into 4 partitions
```
# From source:
# 1. Install necessary packages from requirements-70b-model.txt
$ pip install -r requirements-70b-model.txt
# 2. Build ONNX Runtime from source with NCCL enabled. Here is a sample command:
$ ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/
$ ./build.sh --config Release --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/
# 3. Shard and export the LLaMA-2 70B model. With FP16, you will need at least 140GB of GPU memory to load the model. Therefore, you will need at least four 40GB A100 GPUs or two 80GB A100 GPUs to shard the PyTorch model and export each shard to ONNX. Here is an example command:
$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda
$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-distributed --precision fp16 --execution_provider cuda --use_gqa
```

## Benchmark LLaMA-2
@@ -220,7 +279,20 @@ python3 -m models.llama.benchmark \
--auth
```

4. ONNX Runtime, FP32, Microsoft custom export
4. Optimum + ONNX Runtime, FP16, export via Optimum or convert_to_onnx
```
python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
--device cuda \
--auth
```

5. ONNX Runtime, FP32, Microsoft custom export
```
python3 -m models.llama.benchmark \
--benchmark-type ort-msft \
@@ -232,7 +304,7 @@ python3 -m models.llama.benchmark \
--device cpu
```

5. ONNX Runtime, FP16, Microsoft custom export
6. ONNX Runtime, FP16, Microsoft custom export
```
python3 -m models.llama.benchmark \
--benchmark-type ort-msft \
@@ -244,7 +316,7 @@ python3 -m models.llama.benchmark \
--device cuda
```

6. ONNX Runtime, FP32, convert_to_onnx, use 2nd GPU
7. ONNX Runtime, FP32, convert_to_onnx, use 2nd GPU
```
CUDA_VISIBLE_DEVICES=1 python3 -m models.llama.benchmark \
--benchmark-type ort-convert-to-onnx \
@@ -256,7 +328,7 @@ CUDA_VISIBLE_DEVICES=1 python3 -m models.llama.benchmark \
--device cpu
```

7. ONNX Runtime, FP16, convert_to_onnx, use 5th GPU
8. ONNX Runtime, FP16, convert_to_onnx, use 5th GPU
```
CUDA_VISIBLE_DEVICES=4 python3 -m models.llama.benchmark \
--benchmark-type ort-convert-to-onnx \
@@ -283,5 +355,8 @@ python3 -m models.llama.benchmark_all \
--precision fp16 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
--device cuda
--device cuda \
--warmup-runs 5 \
--num-runs 1000 \
--timeout 60 # number of minutes before moving to the next benchmark
```

