Support Mixtral quantization using HQT #67

Closed
wants to merge 24 commits
24 commits
4b9b955
support hqt on vllm
nirda7 Jun 18, 2024
f3ffc8c
Support HQT on VLLM - KVCache and Mark Step uses
nirda7 Jun 18, 2024
8ffc3d0
HQT on VLLM - prep model and finish measurements and multi cards run
nirda7 Jun 18, 2024
f5f0972
HQT on VLLM - separate kv caches
nirda7 Jun 18, 2024
c521c4d
HQT on VLLM - remove code duplications
nirda7 Jun 18, 2024
64c8c7f
HQT on VLLM - move matmul and softmax to hpu utils and revert logits …
nirda7 Jun 18, 2024
2e291c5
Move model to hpu when HQT is not used
nirda7 Jun 19, 2024
9d0fbb7
fix CR comments
nirda7 Jun 19, 2024
09e0078
add model weights device load
nirda7 Jun 21, 2024
24847a9
skip replay cached graphs during warmup
nirda7 Jun 26, 2024
90c2527
HQT on VLLM - Enable split value in G3
nirda7 Jun 26, 2024
f7c2157
pass optimizations flags only in Lazy mode
nirda7 Jun 27, 2024
83770dc
Merge remote-tracking branch 'origin/habana_next' into vllm-hqt-fork
madamczykhabana Jun 28, 2024
ae1d3f4
Filter-out warmup_mode before passing to model.forward
madamczykhabana Jun 28, 2024
33a2620
Merge pull request #75 from HabanaAI/vllm-hqt-fork
madamczykhabana Jun 28, 2024
566bdd2
Profile single forward (#68)
adobrzyniewicz-habana Jul 1, 2024
55ea726
Skip logprobs processing for greedy
Jun 28, 2024
0674aea
Fix lower bucket range calculation
madamczykhabana Jul 2, 2024
15c67ed
Disable warmup_mode for now
madamczykhabana Jul 2, 2024
77e1ab8
Introduce delayed sampling mechanism (#84)
lahead Jul 4, 2024
1dc6cb2
Disable tensor cache set to True (#88)
michalkuligowski Jul 4, 2024
4afe86d
Revert "Disable tensor cache set to True (#88)" (#89)
madamczykhabana Jul 4, 2024
ca1dbf6
Revert "Revert "Disable tensor cache set to True (#88)" (#89)" (#90)
michalkuligowski Jul 5, 2024
87d95ad
Support Mixtral quantization using HQT
dudilester Jun 20, 2024
122 changes: 122 additions & 0 deletions hpu-utils/profile_forward.py
@@ -0,0 +1,122 @@
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

import argparse
import torch
import os
import glob
import shutil

os.environ['VLLM_SKIP_WARMUP']='true'
from vllm import LLM, SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata, ExecuteModelRequest
from multiprocessing import Process

def setup_profiler(steps):
    activities = [torch.profiler.ProfilerActivity.CPU]
    activities.extend([torch.profiler.ProfilerActivity.HPU])
    wait = 0
    active = 1
    warmup = steps - active

    schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1)
    profiler = torch.profiler.profile(
        schedule=schedule,
        activities=activities,
        on_trace_ready=torch.profiler.tensorboard_trace_handler('.', use_gzip=True),
        record_shapes=False,
        with_stack=True)
    return profiler

def profiler_files_organise(output_file):
    """Moves the newest profiling trace to the specified output path"""
    profiler_files = glob.glob('./*.json.gz')
    latest_file = max(profiler_files, key=os.path.getctime)
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    shutil.move(latest_file, output_file)

def kill_process(pid):
    """Kills python3 main process manually"""
    print("Killing process manually")
    import psutil
    for proc in psutil.process_iter():
        if proc.pid == pid:
            proc.kill()

def round_up(n, k):
    return ((n + k - 1) // k) * k

def run_forward(llm, is_prompt, block_size, batch_size, seq_len):
    """Single forward run"""
    sampling_params = SamplingParams(temperature=0)
    seqs = []
    if is_prompt:
        input_len = seq_len
        output_len = 0
    else:
        input_len = seq_len - 1
        output_len = 1

    for group_id in range(batch_size):
        prompt_token_ids = [0] * input_len
        output_token_ids = [1] * output_len
        block_tables = {group_id: [0] * (round_up(seq_len, block_size) // block_size)}
        seq_data = SequenceData(prompt_token_ids)
        seq_data.output_token_ids = output_token_ids
        seq = SequenceGroupMetadata(
            request_id=str(group_id),
            is_prompt=(output_len == 0),
            seq_data={group_id: seq_data},
            sampling_params=sampling_params,
            block_tables=block_tables,
        )
        seqs.append(seq)

    model_request = ExecuteModelRequest(seq_group_metadata_list=seqs)

    llm.llm_engine.model_executor.execute_model(model_request)

    print("Forward completed")

def run_vllm(model_dtype, is_prompt, args):
    """vLLM setup and run"""
    llm = LLM(model=args.model_path, enforce_eager=True, dtype=model_dtype, block_size=args.block_size, tensor_parallel_size=args.num_cards)
    profiler = setup_profiler(args.steps)
    profiler.start()
    print("Starting steps")
    for _ in range(args.steps):
        run_forward(llm, is_prompt, args.block_size, args.batch_size, args.seq_len)
        profiler.step()
    profiler.stop()
    print("Finished running llm")

parser = argparse.ArgumentParser("vLLM arguments parser", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument("--model-path", help="Path to the model that will be used", type=str, required=True)
parser.add_argument("--num-cards", help="Number of cards that will be used by model", type=int, default=1)
parser.add_argument("--phase", help="Phase", type=str, choices=["prompt", "decode"], default="decode")
parser.add_argument("--data-type", help="Type of data that will be used", type=str, default="bf16", choices=["bf16"])
parser.add_argument("--output-path", help="Path where profiler file will be stored", type=str, default="./output.json.gz")
parser.add_argument("--block-size", help="Block size", type=int, default=128)
parser.add_argument("--batch-size", help="Batch size", type=int, default=32)
parser.add_argument("--seq-len", help="Sequence length", type=int, default=1024)
parser.add_argument("--steps", help="Number of steps", type=int, default=3)
args = parser.parse_args()

print(args)

if args.data_type == "bf16":
    model_dtype = torch.bfloat16

is_prompt = args.phase == "prompt"

pid = os.getpid()

run_vllm(model_dtype, is_prompt, args)

profiler_files_organise(args.output_path)

print("Done")

kill_process(pid)
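For reference, a typical invocation of the script above (paths are placeholders) would be: python hpu-utils/profile_forward.py --model-path /path/to/model --phase decode --batch-size 32 --seq-len 1024 --steps 3. This runs three decode-phase forward passes with vLLM warmup disabled (VLLM_SKIP_WARMUP is forced to 'true'), traces only the final step (warmup = steps - active in setup_profiler), moves the resulting .json.gz trace to --output-path (default ./output.json.gz), and then force-kills the process.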
2 changes: 1 addition & 1 deletion setup.py
@@ -229,7 +229,7 @@ def _is_neuron() -> bool:
    torch_neuronx_installed = True
    try:
        subprocess.run(["neuron-ls"], capture_output=True, check=True)
    except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
    except (FileNotFoundError, NotADirectoryError, PermissionError, subprocess.CalledProcessError):
        torch_neuronx_installed = False
    return torch_neuronx_installed or envs.VLLM_BUILD_WITH_NEURON

30 changes: 22 additions & 8 deletions vllm/attention/backends/habana_attn.py
@@ -5,9 +5,11 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import os
import torch
import math
import vllm.hpu.xops as xops
from vllm.hpu import cache_ops, xops
from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache
from vllm.hpu.attn_bias import (AttentionBias,
                                LowerTriangularMaskWithTensorBias)

@@ -111,7 +113,7 @@ def __post_init__(self):
        self.attn_bias: Optional[List[AttentionBias]] = None


class HabanaAttentionImpl(AttentionImpl):
class HabanaAttentionImpl(AttentionImpl, torch.nn.Module):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
@@ -137,8 +139,14 @@ def __init__(
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super(AttentionImpl, self).__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.qk_matmul = Matmul()
        self.softmax = Softmax()
        self.kv_matmul = Matmul()
        self.key_cache = VLLMKVCache()
        self.value_cache = VLLMKVCache()
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
@@ -188,11 +196,9 @@ def forward(
        # Reshape the input keys and values and store them in the cache.
        # If kv_cache is not provided, the new key and value tensors are
        # not cached. This happens during the initial memory profiling run.
        HabanaPagedAttention.write_to_paged_cache(key, value, key_cache,
                                                  value_cache,
                                                  attn_metadata.slot_mapping,
                                                  attn_metadata.kv_cache_dtype,
                                                  attn_metadata.prefill_metadata is not None)
        block_indices, block_offset = cache_ops.prepare_to_cache(key_cache, attn_metadata.slot_mapping)
        key_cache = self.key_cache(key, key_cache, block_indices, block_offset)
        value_cache = self.value_cache(value, value_cache, block_indices, block_offset)

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
@@ -208,6 +214,9 @@
                attn_bias=prefill_meta.attn_bias,
                p=0.0,
                scale=self.scale,
                qk_matmul_op=self.qk_matmul,
                softmax_op=self.softmax,
                kv_matmul_op=self.kv_matmul,
            )
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
@@ -237,7 +246,12 @@ def forward(
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale
                kv_scale,
                self.qk_matmul,
                self.softmax,
                self.kv_matmul,
                self.key_cache.fetch_from_cache,
                self.value_cache.fetch_from_cache,
            )

        # Reshape the output tensor.
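The Matmul, Softmax and VLLMKVCache modules imported from vllm.hpu.utils are not part of this diff; the sketch below is an assumption about their shape, shown only to illustrate why HabanaAttentionImpl now also derives from torch.nn.Module and holds these ops as submodules: wrapping plain tensor ops in nn.Module gives HQT discrete module boundaries it can measure and later swap for quantized implementations.

# Hedged sketch only -- the real implementations live in vllm/hpu/utils.py, which this
# PR does not show. Signatures are inferred from the call sites above.
import torch

class Matmul(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y)

class Softmax(torch.nn.Module):
    def forward(self, x, dim=-1):
        return torch.softmax(x, dim=dim)

class VLLMKVCache(torch.nn.Module):
    def forward(self, src, cache, block_indices, block_offset):
        # Writes the new keys/values into the paged cache and returns it; the
        # index_put_ below is an assumed stand-in for the real cache_ops insert.
        cache.index_put_((block_indices, block_offset), src)
        return cache

    def fetch_from_cache(self, cache, blocks):
        # Assumed gather of cached blocks used by the decode path.
        return cache.index_select(0, blocks)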
10 changes: 10 additions & 0 deletions vllm/attention/ops/habana_paged_attn.py
@@ -85,6 +85,11 @@ def forward_decode(
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        kv_scale: float,
        qk_op=torch.matmul,
        softmax_op=torch.softmax,
        kv_op=torch.matmul,
        keys_fetch=ops.fetch_from_cache,
        values_fetch=ops.fetch_from_cache,
    ) -> torch.Tensor:
        block_size = value_cache.shape[1]
        return ops.paged_attention_v1(
@@ -98,6 +103,11 @@
            block_size,
            alibi_slopes,
            kv_cache_dtype,
            qk_op,
            softmax_op,
            kv_op,
            keys_fetch,
            values_fetch,
        )

    @staticmethod
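The new keyword arguments default to the plain torch ops and to ops.fetch_from_cache, so existing callers keep their behavior, while HabanaAttentionImpl injects its HQT-wrapped modules instead. Schematically (a simplified illustration, not the actual ops.paged_attention_v1 kernel, which also handles block tables, cache fetches, alibi and masking), the injected callables sit at the three matmul/softmax points of attention:

import torch

def attention_core(query, keys, values, scale,
                   qk_op=torch.matmul, softmax_op=torch.softmax, kv_op=torch.matmul):
    # Illustrative only: scores = Q*K^T, normalize, then weight the values.
    scores = qk_op(query, keys.transpose(-2, -1)) * scale
    probs = softmax_op(scores, dim=-1)
    return kv_op(probs, values)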
18 changes: 16 additions & 2 deletions vllm/config.py
@@ -366,14 +366,15 @@ def _verify_args(self) -> None:
    def _verify_cache_dtype(self) -> None:
        if self.cache_dtype == "auto":
            pass
        elif self.cache_dtype == "fp8":
        elif self.cache_dtype in ["fp8", "hf8"]:
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "But it may cause slight accuracy drop without scaling "
                "factors. FP8_E5M2 (without scaling) is only supported on "
                "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
                "is instead supported for common inference criteria.")
                "is instead supported for common inference criteria. "
                "FP8_E4M3 is also supported on hpu.")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

@@ -474,10 +475,12 @@ class LoadConfig:
            mainly for profiling.
        "tensorizer" will use CoreWeave's tensorizer library for
            fast weight loading.
        device: Device on which weights are loaded.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    device: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(
        default_factory=dict)

@@ -602,6 +605,7 @@ def __init__(
        num_lookahead_slots: int = 0,
        delay_factor: float = 0.0,
        enable_chunked_prefill: bool = False,
        enable_delayed_sampling: bool = False,
    ) -> None:
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
@@ -623,6 +627,7 @@
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
        self.chunked_prefill_enabled = enable_chunked_prefill
        self.enable_delayed_sampling = enable_delayed_sampling

        self._verify_args()

@@ -649,6 +654,15 @@ def _verify_args(self) -> None:
                f"({self.num_lookahead_slots}) must be greater than or "
                "equal to 0.")

        if self.enable_delayed_sampling and self.num_lookahead_slots != 1:
            raise ValueError(
                "num_lookahead_slots "
                f"({self.num_lookahead_slots}) must be 1 for delayed sampling.")

        if self.enable_delayed_sampling and not self.use_v2_block_manager:
            raise ValueError(
                "use_v2_block_manager "
                f"({self.use_v2_block_manager}) must be True for delayed sampling.")

class DeviceConfig:

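Taken together, the two new checks mean delayed sampling only runs with exactly one lookahead slot and the v2 block manager. A configuration that passes _verify_args would therefore look roughly like the sketch below (argument names other than those visible in this hunk, including use_v2_block_manager being a constructor parameter, are assumptions about the surrounding class):

# Assumed-valid combination for delayed sampling per the checks above.
scheduler_config = SchedulerConfig(
    max_num_batched_tokens=None,   # fall back to the defaults derived in __init__
    max_num_seqs=128,
    max_model_len=4096,
    use_v2_block_manager=True,     # required by the second check
    num_lookahead_slots=1,         # must be exactly 1 for delayed sampling
    enable_delayed_sampling=True,
)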
22 changes: 20 additions & 2 deletions vllm/engine/arg_utils.py
@@ -28,6 +28,7 @@ class EngineArgs:
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    weights_load_device: Optional[str] = None
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    quantization_param_path: Optional[str] = None
@@ -77,6 +78,7 @@ class EngineArgs:
    image_feature_size: Optional[int] = None
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False
    enable_delayed_sampling: bool = False

    guided_decoding_backend: str = 'outlines'
    # Speculative decoding configuration.
@@ -168,6 +170,11 @@ def add_cli_args(
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave which assumes tensorizer_uri is set to the location of '
            'the serialized weights.')
        parser.add_argument("--weights-load-device",
                            type=str,
                            default=EngineArgs.weights_load_device,
                            choices=["cuda", "neuron", "hpu", "cpu"],
                            help='Device on which weights are loaded.')
        parser.add_argument(
            '--dtype',
            type=str,
@@ -186,12 +193,13 @@
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8'],
            choices=['auto', 'fp8', 'hf8'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
            'supported for common inference criteria. FP8_E4M3 is also supported '
            'on hpu.')
        parser.add_argument(
            '--quantization-param-path',
            type=nullable_str,
@@ -444,6 +452,13 @@ def add_cli_args(
            action='store_true',
            help='If set, the prefill requests can be chunked based on the '
            'max_num_batched_tokens.')
        parser.add_argument(
            '--enable-delayed-sampling',
            action='store_true',
            help='If set, the sampling will be delayed by 1 step. First '
            'model request execution (prefill) will return an invalid token '
            'id that will be discarded. Actual sampling of valid token ids '
            'starts from second model execution.')

        parser.add_argument(
            '--speculative-model',
@@ -564,6 +579,7 @@ def create_engine_config(self, ) -> EngineConfig:
                speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            enable_delayed_sampling=self.enable_delayed_sampling,
        )
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
@@ -574,9 +590,11 @@
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        device = device_config.device_type if self.weights_load_device is None else self.weights_load_device
        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            device=device,
            model_loader_extra_config=self.model_loader_extra_config,
        )

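For completeness, a hedged offline-inference sketch that exercises the new engine arguments (the model name is a placeholder and the HQT measurement/quantization setup itself is outside the scope of this snippet; --enable-delayed-sampling would additionally require num_lookahead_slots=1 and the v2 block manager per the config.py checks above):

from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder Mixtral checkpoint
    kv_cache_dtype="hf8",        # new choice added by this PR (FP8-E4M3 KV cache on HPU)
    weights_load_device="cpu",   # new flag: device on which weights are loaded
)
out = llm.generate(["What does HQT quantize?"], SamplingParams(temperature=0, max_tokens=32))
print(out[0].outputs[0].text)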