diff --git a/Makefile b/Makefile index ccf99e0a..ca5d752e 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,7 @@ install-punica-kernel: pip install wheel setuptools --upgrade + git submodule sync + git submodule update --init cd server/punica_kernels && pip install -v --no-build-isolation . install-server: diff --git a/server/examples/test_local_api.py b/server/examples/test_local_api.py index 87606efc..87fad7c7 100644 --- a/server/examples/test_local_api.py +++ b/server/examples/test_local_api.py @@ -2,10 +2,6 @@ import torch from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama from text_generation_server.models_flashinfer.flashinfer_gemma import FlashinferGemma -from text_generation_server.models_flashinfer.flashinfer_qwen2 import FlashinferQwen2 -from text_generation_server.models_flashinfer.flashinfer_chatglm import ( - FlashinferChatGLM, -) import sys try: @@ -31,13 +27,11 @@ # test = "gemma" # test = "llama-3" # test = 'llama-3-70' - test = "gemma" + test = "llama-2" # test = 'mistral' - # test = 'qwen1.5-7' - # test = 'qwen1.5-1.8' - # test = 'qwen1.5-70' - # test = 'qwen2-7' - # test = "chatglm4" + # test = 'qwen2' + # test = 'qwen2-1.8' + # test = 'qwen2-70' print("Testing " + test) # Load demo inputs @@ -167,7 +161,7 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None): ), ] service = FlashinferMistral(model_id="mistralai/Mistral-7B-v0.3") -elif test == "qwen1.5-7": +elif test == "qwen2": requests = [ make_input( "REILX/Qwen1.5-7B-Chat-750Mb-lora", @@ -186,7 +180,7 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None): service = FlashinferQwen2( model_id="Qwen/Qwen1.5-7B-Chat", lora_ids=["REILX/Qwen1.5-7B-Chat-750Mb-lora"] ) -elif test == "qwen1.5-1.8": +elif test == "qwen2-1.8": # Todo: Add qwen1.5 1.8b chat lora adapter / Output Repetition Problem requests = [ make_input( @@ -200,7 +194,7 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None): service = FlashinferQwen2( model_id="Qwen/Qwen1.5-1.8B-Chat", lora_ids=["REILX/Qwen1.5-7B-Chat-750Mb-lora"] ) -elif test == "qwen1.5-70": +elif test == "qwen2-70": # Todo: Add qwen1.5 72b chat lora adapter requests = [ make_input( @@ -266,45 +260,23 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None): service = FlashinferLlama( model_id="baichuan-inc/Baichuan2-7B-Chat", trust_remote_code=True ) -elif test == "qwen2-7": - # Todo: qwen2-7b instruct lora adapter - requests = [ - make_input( - "abcdabcd987/gsm8k-llama2-7b-lora-16", - "base", - id=0, - promptOverride="给我讲个故事", - ), - ] - service = FlashinferQwen2(model_id="Qwen/Qwen2-7B-Instruct", trust_remote_code=True) - -elif test == "chatglm4": - # Todo: chatglm4-9b lora adapter - requests = [ - make_input( - "abcdabcd987/gsm8k-llama2-7b-lora-16", - "base", - id=0, - promptOverride="给我讲个故事", - ), - ] - service = FlashinferChatGLM(model_id="THUDM/glm-4-9b-chat", trust_remote_code=True) print(service.get_lora_adapters()) tokenizer = service.tokenizer batch = generate_pb2.Batch(id=0, requests=requests, size=len(requests)) +pb_batch = FlashinferBatch.from_pb( + batch, tokenizer, torch.float16, torch.device("cuda") +) + +# Add input batch to model service +ids = service.add_request(pb_batch) display_results = {} # Iterative generation: each step generates a token for each input in the batch isPrefill = True while True: - if isPrefill: - generations, next_batch, _ = service.prefill_batch(batch) - isPrefill = False - else: - generations, next_batch, _, _ = service.decode_batch([next_batch.to_pb()]) - + 
generations, _, _ = service.generate_token(FlashinferBatch.Empty(batch.id)) for gen in generations: if gen.prefill_tokens: display_results[gen.request_id] = [ diff --git a/server/examples/test_local_grpc.py b/server/examples/test_local_grpc.py index 8b92865c..10ac1e59 100644 --- a/server/examples/test_local_grpc.py +++ b/server/examples/test_local_grpc.py @@ -46,8 +46,18 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None): requests = [ - make_input("tjluyao/gemma-2b-it-math", "base", id=0), - make_input("tjluyao/gemma-2b-it-math", "base", id=1), + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "base", + id=0, + promptOverride="Give me a brief introduction to Byzantine Fault Tolerance and why it is important?", + ), + make_input( + "abcdabcd987/gsm8k-llama2-7b-lora-16", + "lora", + id=1, + promptOverride="Which network interface card is more suitable for distributed systems, Mellanox or Broadcom?", + ), ] # Assemble input batch @@ -68,26 +78,11 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None): ) stub.Warmup(wr) # Prefill - pr = generate_pb2.PrefillRequest(batch=pb_batch_with_inputs) + pr = generate_pb2.PrefillRequest(batch=pb_batch_empty) resp = stub.Prefill(pr) - generations, cbatch = resp.generations, resp.batch - for gen in generations: - print(gen.tokens.texts) - - print("finished prefill tokens") - - while True: - dr = generate_pb2.DecodeRequest(batches=[cbatch]) - resp = stub.Decode(dr) - generations, cbatch = resp.generations, resp.batch - toExit = False - for gen in generations: - if gen.generated_text.text: - print("finished") - res = gen.generated_text.text - toExit = True - - if toExit: - break - - print(res) + gen, cbatch = resp.generations, resp.batch + # Decode + dr = generate_pb2.DecodeRequest(batches=[cbatch]) + resp = stub.Decode(dr) + gen, cbatch = resp.generations, resp.batch + print("done") diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py index 117f8499..4504733e 100644 --- a/server/text_generation_server/cache.py +++ b/server/text_generation_server/cache.py @@ -11,9 +11,6 @@ class Cache: def __init__(self): self.cache: Dict[int, B] = {} - def get_all_values(self): - return self.cache.values() - def pop(self, batch_id: int) -> Optional[B]: return self.cache.pop(batch_id, None) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 30bf479f..8406946b 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -93,20 +93,30 @@ def serve( if use_flashinfer: from text_generation_server import server_flashinfer - serv = server_flashinfer + server_flashinfer.serve( + model_id, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + uds_path, + lora_ids, + ) else: - serv = server - serv.serve( - model_id, - revision, - sharded, - quantize, - speculate, - dtype, - trust_remote_code, - uds_path, - lora_ids, - ) + server.serve( + model_id, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + uds_path, + lora_ids, + ) + @app.command() diff --git a/server/text_generation_server/layers/flashinfer_attention.py b/server/text_generation_server/layers/flashinfer_attention.py index 73245680..a13ddfdd 100644 --- a/server/text_generation_server/layers/flashinfer_attention.py +++ b/server/text_generation_server/layers/flashinfer_attention.py @@ -43,8 +43,6 @@ def __init__( ) self.page_size = 16 - self.group_size = self.num_attention_heads // self.num_key_value_heads - def
computeAttention( self, q: torch.Tensor, @@ -184,17 +182,9 @@ def _batchDecode( decodeBatchPosition.kv_last_page_len, ) - if self.group_size in [7, 16]: - decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer=self._workspace_buffer, - kv_layout="NHD", - use_tensor_cores=True, - ) - else: - decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer=self._workspace_buffer, kv_layout="NHD" - ) - + decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer=self._workspace_buffer, kv_layout="NHD" + ) decode_wrapper.begin_forward( decodeBatchPosition.kv_page_indptr, decodeBatchPosition.kv_page_indices, diff --git a/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_llama_modeling.py b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_llama_modeling.py index bbd3e6d5..270f4b46 100644 --- a/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_llama_modeling.py +++ b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_llama_modeling.py @@ -123,8 +123,7 @@ def forward( q = q_proj.contiguous() k = k_proj.contiguous() v = v_proj.contiguous() - if loraWeight: - loraWeight.apply_lora_weight_kvq(q, k, v, hidden_states, self.layer_idx) + loraWeight.apply_lora_weight_kvq(q, k, v, hidden_states, self.layer_idx) self.rotary_emb( q.view( @@ -152,10 +151,9 @@ def forward( self.rotaryParams, ) attn_outputs = self.o_proj(attn_outputs_raw) - if loraWeight: - loraWeight.apply_lora_weight_attn( - attn_outputs, attn_outputs_raw, self.layer_idx - ) + loraWeight.apply_lora_weight_attn( + attn_outputs, attn_outputs_raw, self.layer_idx + ) return attn_outputs @@ -208,16 +206,13 @@ def forward( gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) gate = gate_up_states[:, 0].contiguous() - if loraWeight: - loraWeight.apply_lora_weight_gate(gate, hidden_states, self.layer_idx) + loraWeight.apply_lora_weight_gate(gate, hidden_states, self.layer_idx) gate = self.act(gate) up = gate_up_states[:, 1].contiguous() - if loraWeight: - loraWeight.apply_lora_weight_up(up, hidden_states, self.layer_idx) + loraWeight.apply_lora_weight_up(up, hidden_states, self.layer_idx) t = gate * up down = self.down_proj(t) - if loraWeight: - loraWeight.apply_lora_weight_down(down, t, self.layer_idx) + loraWeight.apply_lora_weight_down(down, t, self.layer_idx) return down diff --git a/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_mistral_modeling.py b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_mistral_modeling.py index 3057d32b..7b8c54c7 100644 --- a/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_mistral_modeling.py +++ b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_mistral_modeling.py @@ -47,11 +47,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import FastRMSNorm -from text_generation_server.layers.flashinfer_attention import ( - FlashinferAttentionWrapper, - AttentionRotaryParams, -) - class FlashinferBatch: def __init__(self, seq_indptr, kv_page_indptr, kv_page_indices, kv_last_page_len): @@ -231,38 +226,177 @@ def forward( kvCachePool: KvCachePool, prefillBatchPosition: KvCacheBatchPosition, decodeBatchPosition: KvCacheBatchPosition, - loraWeight: BatchedModelLoraWeight | None, + lora: BatchedModelLoraWeight | None, ) -> 
torch.Tensor: - q_dim = ( - self.flashinferWrapper.num_attention_heads * self.flashinferWrapper.head_dim - ) - kv_dim = ( - self.flashinferWrapper.num_key_value_heads * self.flashinferWrapper.head_dim - ) - qkv = self.qkv_proj(hidden_states) + qkv = self.query_key_value(hidden_states) + + # qkv = qkv.to('cuda') + q_proj, k_proj, v_proj = qkv.split( - [q_dim, kv_dim, kv_dim], + [ + self.head_size * self.num_heads, + self.head_size * self.num_key_value_heads, + self.head_size * self.num_key_value_heads, + ], dim=1, ) - q = q_proj.contiguous() - k = k_proj.contiguous() - v = v_proj.contiguous() - loraWeight.apply_lora_weight_kvq(q, k, v, hidden_states, self.layer_idx) - attn_outputs_raw = self.flashinferWrapper.computeAttention( - q, - k, - v, - kvCachePool.cache_data[self.layer_idx], - kvCachePool.page_len, - prefillBatchPosition, - decodeBatchPosition, - self.rotaryParams, - ) - attn_outputs = self.o_proj(attn_outputs_raw) - loraWeight.apply_lora_weight_attn( - attn_outputs, attn_outputs_raw, self.layer_idx + + q_proj = q_proj.contiguous() + k_proj = k_proj.contiguous() + v_proj = v_proj.contiguous() + + # print(f"q proj {q_proj}") + # print(f"lora rank: {lora.rank}") + + if lora: + add_lora( + q_proj, + hidden_states, + lora.q.wa_ptr, + lora.q.wb_ptr, + lora.segment, + self.layer_idx, + lora.rank, + ) + add_lora( + k_proj, + hidden_states, + lora.k.wa_ptr, + lora.k.wb_ptr, + lora.segment, + self.layer_idx, + lora.rank, + ) + add_lora( + v_proj, + hidden_states, + lora.v.wa_ptr, + lora.v.wb_ptr, + lora.segment, + self.layer_idx, + lora.rank, + ) + + stack_attn_output = [] + workspace_buffer = torch.empty( + 32 * 1024 * 1024, dtype=torch.int8, device=kvCachePool.device ) - return attn_outputs + prefillTotalSeqLen = prefillBatchPosition.total_seq_len + if prefillTotalSeqLen > 0: + q = ( + q_proj[:prefillTotalSeqLen] + .view(prefillTotalSeqLen, self.num_heads, self.head_size) + .contiguous() + ) + k = ( + k_proj[:prefillTotalSeqLen] + .view(prefillTotalSeqLen, self.num_key_value_heads, self.head_size) + .contiguous() + ) + v = ( + v_proj[:prefillTotalSeqLen] + .view(prefillTotalSeqLen, self.num_key_value_heads, self.head_size) + .contiguous() + ) + + seq_indptr = prefillBatchPosition.seq_indptr.clone() + kv_page_indices = prefillBatchPosition.kv_page_indices.clone() + kv_page_indptr = prefillBatchPosition.kv_page_indptr.clone() + kv_last_page_len = prefillBatchPosition.kv_last_page_len.clone() + + flashinfer.append_paged_kv_cache( + k, + v, + seq_indptr, + kvCachePool.cache_data[self.layer_idx], + kv_page_indices, + kv_page_indptr, + kv_last_page_len, + ) + + prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, "NHD" + ) + + prefill_wrapper.begin_forward( + seq_indptr, + kv_page_indptr, + kv_page_indices, + kv_last_page_len, + self.num_heads, + self.num_key_value_heads, + self.head_size, + ) + + attn_output_prefill = prefill_wrapper.forward( + q, + kvCachePool.cache_data[self.layer_idx], + causal=True, + pos_encoding_mode="ROPE_LLAMA", + ).view(prefillTotalSeqLen, self.hidden_size) + + prefill_wrapper.end_forward() + stack_attn_output.append(attn_output_prefill) + + decodeTotalSeqLen = decodeBatchPosition.total_seq_len + if decodeTotalSeqLen > 0: + q = ( + q_proj[prefillTotalSeqLen:] + .view(decodeTotalSeqLen, self.num_heads, self.head_size) + .contiguous() + ) + k = ( + k_proj[prefillTotalSeqLen:] + .view(decodeTotalSeqLen, self.num_key_value_heads, self.head_size) + .contiguous() + ) + v = ( + v_proj[prefillTotalSeqLen:] + .view(decodeTotalSeqLen, 
self.num_key_value_heads, self.head_size) + .contiguous() + ) + + flashinfer.append_paged_kv_cache( + k, + v, + decodeBatchPosition.seq_indptr, + kvCachePool.cache_data[self.layer_idx], + decodeBatchPosition.kv_page_indices, + decodeBatchPosition.kv_page_indptr, + decodeBatchPosition.kv_last_page_len, + ) + + decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, "NHD" + ) + + decode_wrapper.begin_forward( + decodeBatchPosition.kv_page_indptr, + decodeBatchPosition.kv_page_indices, + decodeBatchPosition.kv_last_page_len, + self.num_heads, + self.num_key_value_heads, + self.head_size, + kvCachePool.page_len, + pos_encoding_mode="ROPE_LLAMA", + ) + + attn_output_decode = decode_wrapper.forward( + q, + kvCachePool.cache_data[self.layer_idx], + pos_encoding_mode="ROPE_LLAMA", + ).view(decodeTotalSeqLen, self.hidden_size) + + decode_wrapper.end_forward() + stack_attn_output.append(attn_output_decode) + + if len(stack_attn_output) == 1: + attn_output = stack_attn_output[0] + else: + attn_output = torch.cat(stack_attn_output, dim=0) + + o = self.o_proj(attn_output) + return o class MistralMLP(nn.Module): @@ -347,7 +481,6 @@ def __init__(self, flashinferWrapper: FlashinferAttentionWrapper, layer_id, conf prefix = f"model.layers.{layer_id}" self.self_attn = MistralAttention( prefix=f"{prefix}.self_attn", - flashinferWrapper=flashinferWrapper, config=config, weights=weights, layer_idx=layer_id, diff --git a/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_qwen2_modeling.py b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_qwen2_modeling.py index 9c5a85b6..11e33dc3 100644 --- a/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_qwen2_modeling.py +++ b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_qwen2_modeling.py @@ -32,11 +32,9 @@ Qwen2Config, PreTrainedModel, ) -from text_generation_server.layers.flashinfer_attention import ( - FlashinferAttentionWrapper, - AttentionRotaryParams, -) + from punica_kernels import ( + add_lora_sgmv_custom_cutlass as add_lora, rms_norm, ) @@ -236,17 +234,44 @@ def __init__(self, prefix, config, weights, layer_idx: int): config.intermediate_size // weights.process_group.size() ) - def forward(self, hidden_states, loraWeight: BatchedModelLoraWeight | None): + def forward(self, hidden_states, lora: BatchedModelLoraWeight | None): gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) gate = gate_up_states[:, 0].contiguous() - loraWeight.apply_lora_weight_gate(gate, hidden_states, self.layer_idx) + if lora: + add_lora( + gate, + hidden_states, + lora.gate.wa_ptr, + lora.gate.wb_ptr, + lora.segment, + self.layer_idx, + lora.rank, + ) gate = self.act(gate) up = gate_up_states[:, 1].contiguous() - loraWeight.apply_lora_weight_up(up, hidden_states, self.layer_idx) + if lora: + add_lora( + up, + hidden_states, + lora.up.wa_ptr, + lora.up.wb_ptr, + lora.segment, + self.layer_idx, + lora.rank, + ) t = gate * up down = self.down_proj(t) - loraWeight.apply_lora_weight_down(down, t, self.layer_idx) + if lora: + add_lora( + down, + hidden_states, + lora.down.wa_ptr, + lora.down.wb_ptr, + lora.segment, + self.layer_idx, + lora.rank, + ) return down @@ -299,21 +324,19 @@ def _load_gqa(config, prefix: str, weights): class FlashQwen2Attention(nn.Module): - def __init__( - self, - prefix: str, - flashinferWrapper: FlashinferAttentionWrapper, - config: Qwen2Config, - weights, - layer_idx: int - 
): + def __init__(self, prefix: str, config: Qwen2Config, weights, layer_idx: int): super().__init__() - - self.flashinferWrapper = flashinferWrapper - self.rotaryParams = AttentionRotaryParams( - rope_scale=None, rope_theta=config.rope_theta - ) - + self.num_heads = config.num_attention_heads + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_qo_heads = self.num_heads // weights.process_group.size() + self.num_kv_heads = config.num_key_value_heads // weights.process_group.size() + self.config = config + self.hidden_size = config.hidden_size + self.head_dim = self.hidden_size // self.num_heads self.layer_idx = layer_idx self.qkv_proj = load_attention(config, prefix, weights) self.o_proj = TensorParallelRowLinear.load( @@ -323,6 +346,8 @@ def __init__( bias=False, ) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads def forward( self, @@ -330,54 +355,181 @@ def forward( kvCachePool: KvCachePool, prefillBatchPosition: KvCacheBatchPosition, decodeBatchPosition: KvCacheBatchPosition, - loraWeight: BatchedModelLoraWeight | None, + lora: BatchedModelLoraWeight | None, ): - q_dim = ( - self.flashinferWrapper.num_attention_heads * self.flashinferWrapper.head_dim - ) - kv_dim = ( - self.flashinferWrapper.num_key_value_heads * self.flashinferWrapper.head_dim - ) qkv = self.qkv_proj(hidden_states) + q_proj, k_proj, v_proj = qkv.split( - [q_dim, kv_dim, kv_dim], + [ + self.head_dim * self.num_qo_heads, + self.head_dim * self.num_kv_heads, + self.head_dim * self.num_kv_heads, + ], dim=1, ) - q = q_proj.contiguous() - k = k_proj.contiguous() - v = v_proj.contiguous() - - loraWeight.apply_lora_weight_kvq(q, k, v, hidden_states, self.layer_idx) - - attn_outputs_raw = self.flashinferWrapper.computeAttention( - q, - k, - v, - kvCachePool.cache_data[self.layer_idx], - kvCachePool.page_len, - prefillBatchPosition, - decodeBatchPosition, - self.rotaryParams, - ) - attn_outputs = self.o_proj(attn_outputs_raw) - loraWeight.apply_lora_weight_attn( - attn_outputs, attn_outputs_raw, self.layer_idx + + q_proj = q_proj.contiguous() + k_proj = k_proj.contiguous() + v_proj = v_proj.contiguous() + + if lora: + add_lora( + q_proj, + hidden_states, + lora.q.wa_ptr, + lora.q.wb_ptr, + lora.segment, + self.layer_idx, + lora.rank, + ) + add_lora( + k_proj, + hidden_states, + lora.k.wa_ptr, + lora.k.wb_ptr, + lora.segment, + self.layer_idx, + lora.rank, + ) + add_lora( + v_proj, + hidden_states, + lora.v.wa_ptr, + lora.v.wb_ptr, + lora.segment, + self.layer_idx, + lora.rank, + ) + + stack_attn_output = [] + workspace_buffer = torch.empty( + 32 * 1024 * 1024, dtype=torch.int8, device=kvCachePool.device ) - return attn_outputs + prefillTotalSeqLen = prefillBatchPosition.total_seq_len + if prefillTotalSeqLen > 0: + # need to revisit if contiguous conversion is the best way + q = ( + q_proj[:prefillTotalSeqLen] + .view(prefillTotalSeqLen, self.num_qo_heads, self.head_dim) + .contiguous() + ) + k = ( + k_proj[:prefillTotalSeqLen] + .view(prefillTotalSeqLen, self.num_kv_heads, self.head_dim) + .contiguous() + ) + v = ( + v_proj[:prefillTotalSeqLen] + .view(prefillTotalSeqLen, self.num_kv_heads, self.head_dim) + .contiguous() + ) + + seq_indptr = prefillBatchPosition.seq_indptr.clone() + kv_page_indices = prefillBatchPosition.kv_page_indices.clone() + kv_page_indptr = 
prefillBatchPosition.kv_page_indptr.clone() + kv_last_page_len = prefillBatchPosition.kv_last_page_len.clone() + + flashinfer.append_paged_kv_cache( + k, + v, + seq_indptr, + kvCachePool.cache_data[self.layer_idx], + kv_page_indices, + kv_page_indptr, + kv_last_page_len, + ) + + prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, "NHD" + ) + + prefill_wrapper.begin_forward( + seq_indptr, + kv_page_indptr, + kv_page_indices, + kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + ) + + attn_output_prefill = prefill_wrapper.forward( + q, + kvCachePool.cache_data[self.layer_idx], + causal=True, + pos_encoding_mode="ROPE_LLAMA", # this may need change + ).view(prefillTotalSeqLen, self.hidden_size) + prefill_wrapper.end_forward() + stack_attn_output.append(attn_output_prefill) + + decodeTotalSeqLen = decodeBatchPosition.total_seq_len + if decodeTotalSeqLen > 0: + q = ( + q_proj[prefillTotalSeqLen:] + .view(decodeTotalSeqLen, self.num_qo_heads, self.head_dim) + .contiguous() + ) + k = ( + k_proj[prefillTotalSeqLen:] + .view(decodeTotalSeqLen, self.num_kv_heads, self.head_dim) + .contiguous() + ) + v = ( + v_proj[prefillTotalSeqLen:] + .view(decodeTotalSeqLen, self.num_kv_heads, self.head_dim) + .contiguous() + ) + + flashinfer.append_paged_kv_cache( + k, + v, + decodeBatchPosition.seq_indptr, + kvCachePool.cache_data[self.layer_idx], + decodeBatchPosition.kv_page_indices, + decodeBatchPosition.kv_page_indptr, + decodeBatchPosition.kv_last_page_len, + ) + + decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, "NHD" + ) + decode_wrapper.begin_forward( + decodeBatchPosition.kv_page_indptr, + decodeBatchPosition.kv_page_indices, + decodeBatchPosition.kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + kvCachePool.page_len, + pos_encoding_mode="ROPE_LLAMA", + ) + + attn_output_decode = decode_wrapper.forward( + q, + kvCachePool.cache_data[self.layer_idx], + pos_encoding_mode="ROPE_LLAMA", + ).view(decodeTotalSeqLen, self.hidden_size) + + decode_wrapper.end_forward() + stack_attn_output.append(attn_output_decode) + + if len(stack_attn_output) == 1: + attn_outputs = stack_attn_output[0] + else: + attn_outputs = torch.cat(stack_attn_output, dim=0) + + # output projection + o = self.o_proj(attn_outputs) + return o class FlashQwen2Layer(nn.Module): - def __init__( - self, - flashinferWrapper: FlashinferAttentionWrapper, - layer_id, config, weights - ): + def __init__(self, layer_id, config, weights): super().__init__() self.layer_id = layer_id prefix = f"model.layers.{layer_id}" self.self_attn = FlashQwen2Attention( prefix=f"{prefix}.self_attn", - flashinferWrapper=flashinferWrapper, config=config, weights=weights, layer_idx=layer_id, @@ -402,7 +554,7 @@ def forward( kvCachePool: KvCachePool, prefillBatchPosition: KvCacheBatchPosition, decodeBatchPosition: KvCacheBatchPosition, - loraWeight: BatchedModelLoraWeight | None, + lora: BatchedModelLoraWeight | None, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -411,14 +563,14 @@ def forward( kvCachePool, prefillBatchPosition, decodeBatchPosition, - loraWeight, + lora, ) normed_attn_res_output, attn_res = self.post_attention_layernorm( attn_output, res ) - mlp_output = self.mlp(normed_attn_res_output, loraWeight) + mlp_output = self.mlp(normed_attn_res_output, lora) return mlp_output, attn_res @@ -435,19 +587,10 @@ def __init__(self, config, weights): prefix="model.embed_tokens", weights=weights ) # 
self.embed_tokens.weight *= embed_norm - assert config.num_attention_heads % weights.process_group.size() == 0 - assert config.num_key_value_heads % weights.process_group.size() == 0 - num_attention_heads = config.num_attention_heads // weights.process_group.size() - num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - - flashinferWrapper = FlashinferAttentionWrapper( - num_attention_heads, num_key_value_heads, config.hidden_size - ) self.layers = nn.ModuleList( [ FlashQwen2Layer( - flashinferWrapper, layer_id, config, weights, @@ -461,6 +604,9 @@ def __init__(self, config, weights): self.gradient_checkpointing = False + self.head_size = self.layers[0].self_attn.head_dim + self.num_heads = self.layers[0].self_attn.num_qo_heads + self.num_key_value_heads = self.layers[0].self_attn.num_kv_heads def forward( self, @@ -468,7 +614,7 @@ def forward( kvCachePool: KvCachePool, prefillBatchPosition: KvCacheBatchPosition, decodeBatchPosition: KvCacheBatchPosition, - loraWeight: BatchedModelLoraWeight, + lora: BatchedModelLoraWeight | None, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None @@ -479,7 +625,7 @@ def forward( kvCachePool, prefillBatchPosition, decodeBatchPosition, - loraWeight, + lora, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -504,14 +650,10 @@ def forward( kvCachePool: KvCachePool, prefillBatchPosition: KvCacheBatchPosition, decodeBatchPosition: KvCacheBatchPosition, - loraWeight: BatchedModelLoraWeight, + lora: BatchedModelLoraWeight | None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( - input_ids, - kvCachePool, - prefillBatchPosition, - decodeBatchPosition, - loraWeight + input_ids, kvCachePool, prefillBatchPosition, decodeBatchPosition, lora ) logits, speculative_logits = self.lm_head(hidden_states) return logits, speculative_logits diff --git a/server/text_generation_server/models_flashinfer/flashinfer_causal_lm.py b/server/text_generation_server/models_flashinfer/flashinfer_causal_lm.py index 93eeeff7..83be3de2 100644 --- a/server/text_generation_server/models_flashinfer/flashinfer_causal_lm.py +++ b/server/text_generation_server/models_flashinfer/flashinfer_causal_lm.py @@ -1,14 +1,14 @@ import torch import torch.distributed -from typing import Any, Optional +from typing import Any, TypedDict, Optional from text_generation_server.utils.lora_utils import ModelLoraManager, ModelConfigForLora from text_generation_server.utils.cache_manager_flashinfer import ( - getKvCacheBatchPosition, - KvCacheBatchPosition, + ModelKvCache, KvCachePool, - RequestKvCache, ) from text_generation_server.utils.tokens import ( + StopSequenceCriteria, + StoppingCriteria, FinishReason, ) from text_generation_server.layers.flashinfer_attention import find_padded_head_dim @@ -20,24 +20,119 @@ from opentelemetry import trace from typing import Optional, Tuple, List, Type, Dict from text_generation_server.models import Model +from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.types import ( + Batch, Tokens, Generation, GeneratedText, ) +from text_generation_server.utils import ( + NextTokenChooser, + StoppingCriteria, +) from text_generation_server.utils.dist import MEMORY_FRACTION from dataclasses import dataclass -from collections.abc import Iterable -from text_generation_server.cache import Cache tracer = trace.get_tracer(__name__) +class TextGenerationChunk(TypedDict): + index: int + token_id: int + text: str + is_stop: bool + + +@dataclass +class 
FlashinferBatch(CausalLMBatch): + @classmethod + def Empty(cls, batch_id): + return cls( + batch_id=batch_id, + requests=None, + prefix_offsets=None, + read_offsets=None, + next_token_choosers=None, + stopping_criterias=None, + top_n_tokens=None, + top_n_tokens_tensor=None, + input_ids=None, + requests_idx_mapping=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + all_input_ids=None, + input_lengths=None, + max_input_length=None, + padding_right_offset=None, + max_tokens=None, + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase = None, + dtype: torch.dtype = None, + device: torch.device = "cuda", + ) -> "CausalLMBatch": + input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + prefix_offsets = [] + read_offsets = [] + + # Parse batch + for i, r in enumerate(pb.requests): + prompt = r.inputs + + next_token_choosers.append( + NextTokenChooser.from_pb(r.parameters, device, tokenizer) + ) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) + tokenized_inputs = tokenizer.encode(prompt) + input_len = len(tokenized_inputs) + prefix_offsets.append(input_len - 5) + read_offsets.append(input_len) + input_ids.append(tokenized_inputs) + + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) + + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=None, + input_ids=input_ids, + attention_mask=None, + position_ids=None, + past_key_values=None, + all_input_ids=None, + input_lengths=None, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + max_input_length=None, + padding_right_offset=None, + max_tokens=None, + ) + + class RequestContext: def __init__( self, - request_id: str, input_ids: list[int], + lora_id: str, tokenizer, *, temperature: float, @@ -46,12 +141,8 @@ def __init__( top_k: int, maxlen: int, stop_token_id: int, - is_stopped: bool, - request_kv_cache: RequestKvCache, prefill_logprobs: bool = True, - lora_id: str = "empty", ): - self.request_id = request_id self.temperature = temperature self.repetition_penalty = repetition_penalty self.top_p = top_p @@ -81,9 +172,6 @@ def __init__( self.tokenizer = tokenizer self.prefix_offset = 0 self.read_offset = 0 - self.is_stopped = is_stopped - self.prefill_tokens: Optional[Tokens] = None - self.request_kv_cache = request_kv_cache def get_next_token_id(self, logits: torch.Tensor) -> int: if self.logits_processor: @@ -106,34 +194,15 @@ def get_next_token_id(self, logits: torch.Tensor) -> int: def append_token(self, token_id: int): self.output_ids.append(token_id) - def get_stop_reason(self) -> FinishReason: + def is_stop(self) -> FinishReason: if len(self.output_ids) - self.prompt_len >= self.maxlen: return FinishReason.FINISH_REASON_LENGTH if self.output_ids[-1] == self.stop_token_id: return FinishReason.FINISH_REASON_EOS_TOKEN return None - -@dataclass(frozen=True) -class FlashinferBatch: - batch_id: int - is_prefill: bool - request_contexts: List[RequestContext] - - def to_pb(self) -> generate_pb2.CachedBatch: - - max_input_length = max([r.prompt_len for r in self.request_contexts]) - max_decode_tokens = max([r.maxlen for r in self.request_contexts]) - max_tokens = len(self.request_contexts) * 
(max_input_length + max_decode_tokens) - - return generate_pb2.CachedBatch( - id=self.batch_id, - request_ids=[ - request_context.request_id for request_context in self.request_contexts - ], - size=len(self.request_contexts), - max_tokens=max_tokens, - ) + def is_prefill(self) -> bool: + return len(self.output_ids) == self.prompt_len class FlashinferLM(Model): @@ -144,12 +213,11 @@ def __init__( config: PretrainedConfig, dtype: torch.dtype, device: torch.device, - lora_ids: List[str], + lora_ids: List[str] = None, ): self.device = device self.dtype = dtype self.model_config = config - self.batch_cache = Cache() if ( torch.cuda.is_available() @@ -199,7 +267,7 @@ def __init__( f" Number of Pages to Allocate: {num_pages_to_allocate}" ) - self.kvCachePool = KvCachePool( + kvCachePool = KvCachePool( max_pages=num_pages_to_allocate, num_layers=self.model_config.num_hidden_layers, num_heads=self.model_config.num_key_value_heads, @@ -209,6 +277,7 @@ def __init__( device=device, ) + self.modelKvCache = ModelKvCache(kvCachePool) self.model_config_for_lora = ModelConfigForLora( num_hidden_layers=config.num_hidden_layers, hidden_size=config.hidden_size, @@ -220,8 +289,9 @@ def __init__( self.loraManager = ModelLoraManager(self.model_config_for_lora, dtype) if lora_ids: self.loraManager.set_lora_weights( - lora_ids, self.model_config_for_lora, dtype + lora_ids, self.model_config_for_lora or {}, dtype ) + self.reqctx: dict[int, RequestContext] = {} super(FlashinferLM, self).__init__( model=model, @@ -231,6 +301,13 @@ def __init__( device=device, ) + def _find_padded_head_dim(self, head_dim): + flashInferDimensions = [64, 128, 256] + for dim in flashInferDimensions: + if head_dim <= dim: + return dim + raise ValueError("The head dimension is too large for FlashInfer") + def load_lora_adapters(self, lora_ids: List[str]): self.loraManager.set_lora_weights( lora_ids, @@ -244,152 +321,115 @@ def remove_lora_adapters(self, lora_ids: list[str] = None): def get_lora_adapters(self): return list(self.loraManager.lora_weights_cpu) - def decode_batch( - self, cachedBatchesPb: Iterable[generate_pb2.CachedBatch] - ) -> Tuple[List[Generation], Optional[FlashinferBatch], Tuple[int, int], int]: - start_concat = time.time_ns() - batch = self._convertCachedBatch(cachedBatchesPb) - concat_ns = time.time_ns() - start_concat - generations, next_batch, timings = self.generate_token(batch) - if next_batch: - self.batch_cache.set(next_batch) - return generations, batch, timings, concat_ns - - def prefill_batch( - self, batchPb: generate_pb2.Batch - ) -> Tuple[List[Generation], Optional[FlashinferBatch], Tuple[int, int]]: - batch = self._convertPbBatch(batchPb) - generations, next_batch, timings = self.generate_token(batch) - if next_batch: - self.batch_cache.set(next_batch) - return generations, batch, timings - - def clear_cache(self): - all_batches: List[FlashinferBatch] = self.batch_cache.get_all_values() - for batch in all_batches: - for request_context in batch.request_contexts: - request_context.request_kv_cache.release() - - self.batch_cache.clear() + def has_request(self): + return len(self.reqctx) > 0 - def _find_padded_head_dim(self, head_dim): - flashInferDimensions = [64, 128, 256] - for dim in flashInferDimensions: - if head_dim <= dim: - return dim - raise ValueError("The head dimension is too large for FlashInfer") - - def _convertPbBatch(self, batchPb: generate_pb2.Batch) -> FlashinferBatch: - request_contexts = [] - - for request in batchPb.requests: - prompt = request.inputs - input_ids = 
self.tokenizer.encode(prompt) - parameters = request.parameters - request_context = RequestContext( - request.id, - input_ids, - self.tokenizer, - temperature=parameters.temperature, - repetition_penalty=parameters.repetition_penalty, - top_p=parameters.top_p, - top_k=parameters.top_k, - maxlen=min(request.stopping_parameters.max_new_tokens, 4096), - stop_token_id=self.tokenizer.eos_token_id, - is_stopped=False, - request_kv_cache=RequestKvCache( - self.kvCachePool, - self.kvCachePool.page_len, - len(input_ids), - ), - prefill_logprobs=request.prefill_logprobs, - lora_id=request.lora_id, - ) - - request_contexts.append(request_context) + @property + def batch_type(self) -> Type[FlashinferBatch]: + return FlashinferBatch - return FlashinferBatch( - batch_id=batchPb.id, is_prefill=True, request_contexts=request_contexts + def decode(self, generated_ids: List[int]) -> str: + return self.tokenizer.decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False ) - def _convertCachedBatch( - self, cachedBatchesPb: Iterable[generate_pb2.CachedBatch] - ) -> FlashinferBatch: - batches: List[FlashinferBatch] = [] - for batch_pb in cachedBatchesPb: - batch = self.batch_cache.pop(batch_pb.id) - if batch is None: - raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") - batches.append(batch) - - if len(batches) == 0: - raise ValueError("All batches are empty") - - request_contexts_combined: List[RequestContext] = [] - for batch in batches: - request_contexts_combined.extend(batch.request_contexts) - - return FlashinferBatch( - batch_id=batches[0].batch_id, - is_prefill=False, - request_contexts=request_contexts_combined, - ) + def add_request(self, batch: FlashinferBatch): + ids = [] + for r in range(len(batch.requests)): + id = batch.requests[r].id + # Router sends initial request in each iteration + if id not in self.reqctx: + lora_id = batch.requests[r].lora_id or "empty" + input = batch.input_ids[r] + parameters = batch.requests[r].parameters + stop = batch.requests[r].stopping_parameters + prefill_logprobs = batch.requests[r].prefill_logprobs + + if lora_id not in self.loraManager.lora_weights_cpu: + raise ValueError("Cannot find lora weights", lora_id) + + self.reqctx[id] = RequestContext( + input, + lora_id, + self.tokenizer, + temperature=parameters.temperature, + repetition_penalty=parameters.repetition_penalty, + top_p=parameters.top_p, + top_k=parameters.top_k, + maxlen=min(stop.max_new_tokens, 4096), + stop_token_id=self.tokenizer.eos_token_id, + prefill_logprobs=prefill_logprobs, + ) + ids.append(id) + return ids - def batch_type(self): - return FlashinferBatch + def warmup(self, batch: FlashinferBatch): + pass @tracer.start_as_current_span("generate_token") + @torch.no_grad() def generate_token( self, batch: FlashinferBatch ) -> Tuple[List[Generation], Optional[FlashinferBatch], Tuple[int, int]]: start = time.time_ns() - input_ids, lora_ids, lora_lens = [], [], [] - request_kv_caches = [] - for request_context in batch.request_contexts: - if not request_context.is_stopped: - if batch.is_prefill: - input_ids.extend(request_context.output_ids) - else: - input_ids.append(request_context.output_ids[-1]) - request_kv_caches.append(request_context.request_kv_cache) - if not batch.is_prefill: - request_context.request_kv_cache.increment() - - if lora_ids and lora_ids[-1] == request_context.lora_id: - lora_lens[-1] += 1 - elif request_context.lora_id: - lora_ids.append(request_context.lora_id) - lora_lens.append(1) - - input_ids_tensor = torch.tensor( + + if 
hasattr(batch, "requests") and batch.requests: + ids = self.add_request(batch) + + if not self.reqctx: + return None, batch, (0, 0) + + reqs = sorted( + self.reqctx.items(), + key=lambda req: (not req[1].is_prefill(), req[1].lora_id), + ) + + input_ids = [] + lora_ids, lora_lens = [], [] + batchKvCache = self.modelKvCache.getOrCreate(batch.batch_id) + prefill_reqIds = [] + decode_reqIds = [] + + for requestId, req in reqs: + req.prefill = req.is_prefill() + if req.prefill: + input_ids.extend(req.output_ids) + prefill_reqIds.append(requestId) + batchKvCache.create(requestId, req.prompt_len) + else: + input_ids.append(req.output_ids[-1]) + decode_reqIds.append(requestId) + batchKvCache.get(requestId).increment() + if lora_ids and lora_ids[-1] == req.lora_id: + lora_lens[-1] += 1 + else: + lora_ids.append(req.lora_id) + lora_lens.append(1) + + input_ids = torch.tensor( input_ids, dtype=torch.long, device=self.device, ) - request_kv_caches_prefill = request_kv_caches if batch.is_prefill else [] - request_kv_caches_decode = [] if batch.is_prefill else request_kv_caches - prefillBatchPosition: KvCacheBatchPosition = getKvCacheBatchPosition( - request_kv_caches_prefill, isPrefill=True, device=self.device + prefillBatchPosition = batchKvCache.getKvCacheBatchPosition( + prefill_reqIds, isPrefill=True ) - decodeBatchPosition: KvCacheBatchPosition = getKvCacheBatchPosition( - request_kv_caches_decode, isPrefill=False, device=self.device + decodeBatchPosition = batchKvCache.getKvCacheBatchPosition( + decode_reqIds, isPrefill=False ) - loraWeights = ( - self.loraManager.get_lora_batched_weights(lora_ids, lora_lens) - if lora_ids - else None - ) + # Forward pass raw_logits, _ = self.model( - input_ids_tensor, - self.kvCachePool, + input_ids, + self.modelKvCache.kvCachePool, prefillBatchPosition, decodeBatchPosition, - loraWeights, + self.loraManager.get_lora_batched_weights(lora_ids, lora_lens), ) start_decode = time.time_ns() + prefill_logits = ( raw_logits[prefillBatchPosition.seq_indptr[1:] - 1] if prefillBatchPosition.total_seq_len > 0 @@ -400,82 +440,58 @@ def generate_token( all_stop = True generations: List[Generation] = [] - num_stopped_requests = 0 - for i, request_context in enumerate(batch.request_contexts): - if request_context.is_stopped: - num_stopped_requests += 1 - continue - next_token_id = request_context.get_next_token_id( - logits[i - num_stopped_requests].unsqueeze(0) - ) - request_context.append_token(next_token_id) + for i, (reqid, reqctx) in enumerate(reqs): + next_token_id = reqctx.get_next_token_id(logits[i].unsqueeze(0)) + reqctx.append_token(next_token_id) # text = reqctx.decode_tokens() # todo: ?? 
- # special handling for ChatGLM - if "ChatGLM" in str(type(self.model)): - text = self.tokenizer.decode( - [next_token_id], - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - else: - text = self.tokenizer.decode( - next_token_id, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) + text = self.tokenizer.decode( + next_token_id, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) - stop_reason = request_context.get_stop_reason() - if stop_reason != None: + is_stop = reqctx.is_stop() + if is_stop != None: output_text = self.tokenizer.decode( - request_context.output_ids[request_context.prompt_len :], + reqctx.output_ids[reqctx.prompt_len :], clean_up_tokenization_spaces=False, skip_special_tokens=False, ) generated_text = GeneratedText( output_text, - len(request_context.output_ids) - request_context.prompt_len + 1, - stop_reason, + len(reqctx.output_ids) - reqctx.prompt_len + 1, + is_stop, None, ) - request_context.is_stopped = True - request_context.request_kv_cache.release() + self.reqctx.pop(reqid) + batchKvCache.release(reqid) else: generated_text = None all_stop = False # Prefill - if batch.is_prefill: # and request_context.prefill_logprobs: + if reqctx.prefill: # and reqctx.prefill_logprobs: # Remove generated token to only have prefill and add nan for first prompt token prefill_logprobs = [] # todo - prefill_token_ids = request_context.output_ids[ - : request_context.prompt_len - ] - # special handling for ChatGLM - if "ChatGLM" in str(type(self.model)): - prefill_texts = self.tokenizer.batch_decode( - [prefill_token_ids], - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - else: - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) - request_context.prefill_tokens = Tokens( + prefill_token_ids = reqctx.output_ids[: reqctx.prompt_len] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + reqctx.prefill_tokens = Tokens( prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[], ) - request_context.prefix_offset = request_context.prompt_len + reqctx.prefix_offset = reqctx.prompt_len else: - request_context.prefill_tokens = None + reqctx.prefill_tokens = None generation = Generation( - request_context.request_id, - request_context.prefill_tokens, + reqid, + reqctx.prefill_tokens, Tokens( [next_token_id], [0], # prob @@ -492,6 +508,5 @@ def generate_token( decode_ns = time.time_ns() - start_decode # The router stops generation only when batch=None if all_stop: - return generations, None, (forward_ns, decode_ns) - else: - return generations, batch, (forward_ns, decode_ns) + batch = None + return generations, batch, (forward_ns, decode_ns) diff --git a/server/text_generation_server/models_flashinfer/flashinfer_qwen2.py b/server/text_generation_server/models_flashinfer/flashinfer_qwen2.py index 42ab6be9..69c83271 100644 --- a/server/text_generation_server/models_flashinfer/flashinfer_qwen2.py +++ b/server/text_generation_server/models_flashinfer/flashinfer_qwen2.py @@ -2,18 +2,20 @@ import torch.distributed from typing import Optional, List -from transformers import AutoTokenizer, AutoConfig from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM from text_generation_server.models_flashinfer.custom_modeling.flashinfer_qwen2_modeling import ( Qwen2Config, FlashQwen2ForCausalLM, ) + from 
text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, ) +from transformers import AutoTokenizer, AutoConfig + class FlashinferQwen2(FlashinferLM): def __init__( diff --git a/server/text_generation_server/server_flashinfer.py b/server/text_generation_server/server_flashinfer.py index 3a636e02..b8127639 100644 --- a/server/text_generation_server/server_flashinfer.py +++ b/server/text_generation_server/server_flashinfer.py @@ -13,8 +13,7 @@ from text_generation_server.cache import Cache from text_generation_server.interceptor import ExceptionInterceptor -from text_generation_server.models_flashinfer import get_model -from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM +from text_generation_server.models_flashinfer import Model, get_model from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor @@ -35,7 +34,7 @@ def exit_gracefully(self, signum, frame): class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def __init__( self, - model: FlashinferLM, + model: Model, cache: Cache, quantize: Optional[str], server_urls: List[str], @@ -61,26 +60,40 @@ async def ServiceDiscovery(self, request, context): return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls) async def ClearCache(self, request, context): - self.model.clear_cache() + if request.HasField("id"): + self.cache.delete(request.id) + else: + self.cache.clear() return generate_pb2.ClearCacheResponse() - # async def FilterBatch(self, request, context): - # batch = self.cache.pop(request.batch_id) - # if batch is None: - # raise ValueError(f"Batch ID {request.batch_id} not found in cache.") - # filtered_batch = batch.filter(request.request_ids) - # self.cache.set(filtered_batch) + async def FilterBatch(self, request, context): + batch = self.cache.pop(request.batch_id) + if batch is None: + raise ValueError(f"Batch ID {request.batch_id} not found in cache.") + filtered_batch = batch.filter(request.request_ids) + self.cache.set(filtered_batch) - # return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) + return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): + + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.dtype, self.model.device + ) + max_supported_total_tokens = self.model.warmup(batch) + return generate_pb2.WarmupResponse( - max_supported_total_tokens=request.max_total_tokens + max_supported_total_tokens=max_supported_total_tokens ) async def Prefill(self, request, context): start = time.time_ns() - generations, next_batch, timings = self.model.prefill_batch(request.batch) + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.dtype, self.model.device + ) + + generations, next_batch, timings = self.model.generate_token(batch) + self.cache.set(next_batch) return generate_pb2.PrefillResponse( generations=[generation.to_pb() for generation in generations], batch=next_batch.to_pb() if next_batch else None, @@ -91,9 +104,30 @@ async def Prefill(self, request, context): async def Decode(self, request, context): start = time.time_ns() - generations, next_batch, timings, concat_ns = self.model.decode_batch( - request.batches - ) + if len(request.batches) == 0: + raise ValueError("Must provide at least one batch") + + batches = [] + for batch_pb in request.batches: + batch = self.cache.pop(batch_pb.id) + 
if batch is None: + raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") + batches.append(batch) + + if len(batches) == 0: + raise ValueError("All batches are empty") + + if len(batches) > 1: + start_concat = time.time_ns() + batch = self.model.batch_type.concatenate(batches) + concat_ns = time.time_ns() - start_concat + else: + batch = batches[0] + concat_ns = None + + generations, next_batch, timings = self.model.generate_token(batch) + self.cache.set(next_batch) + return generate_pb2.DecodeResponse( generations=[generation.to_pb() for generation in generations], batch=next_batch.to_pb() if next_batch else None, diff --git a/server/text_generation_server/utils/cache_manager_flashinfer.py b/server/text_generation_server/utils/cache_manager_flashinfer.py index 7d34edcf..f0e27d49 100644 --- a/server/text_generation_server/utils/cache_manager_flashinfer.py +++ b/server/text_generation_server/utils/cache_manager_flashinfer.py @@ -80,6 +80,51 @@ def release(self): self.is_released = True +class BatchKvCache: + def __init__(self, kvCachePool: KvCachePool, page_len, device): + self.kvCachePool = kvCachePool + self.page_len = page_len + self.device = device + self.kvCacheDict: dict[int, RequestKvCache] = {} + + def get(self, req_id): + return self.kvCacheDict.get(req_id) + + def create(self, req_id, seq_init_len): + self.kvCacheDict[req_id] = RequestKvCache( + self.kvCachePool, self.page_len, seq_init_len + ) + return self.kvCacheDict[req_id] + + def release(self, req_id): + self.kvCacheDict[req_id].release() + del self.kvCacheDict[req_id] + + def increment(self): + for kvCache in self.kvCacheDict.values(): + kvCache.increment() + + def setRequestOrder(self, requestIds: List[int]): + self.requestIds = requestIds + + def getKvCacheBatchPosition(self, requestIds: List[int], isPrefill: bool): + kv_page_indices_list = [] + kv_page_indptr_list = [] + seq_indptr_list = [] + kv_last_page_len_list = [] + seq_lens_list = [] + cum_pages = 0 + cum_seq_len = 0 + for requestId in requestIds: + kvCache = self.kvCacheDict[requestId] + kv_page_indices_list.extend(kvCache.kv_page_indices) + kv_page_indptr_list.append(cum_pages) + seq_indptr_list.append(cum_seq_len) + kv_last_page_len_list.append(kvCache.kv_last_page_len) + seq_lens_list.append(kvCache.kv_len) + cum_pages += len(kvCache.kv_page_indices) + cum_seq_len += kvCache.kv_len if isPrefill else 1 + def getKvCacheBatchPosition( request_kv_caches: List[RequestKvCache], isPrefill: bool, device: torch.device ) -> KvCacheBatchPosition: @@ -94,31 +139,43 @@ def getKvCacheBatchPosition( kv_page_indices_list.extend(request_kv_cache.kv_page_indices) kv_page_indptr_list.append(cum_pages) seq_indptr_list.append(cum_seq_len) - kv_last_page_len_list.append(request_kv_cache.kv_last_page_len) - seq_lens_list.append(request_kv_cache.kv_len) - cum_pages += len(request_kv_cache.kv_page_indices) - cum_seq_len += request_kv_cache.kv_len if isPrefill else 1 - - kv_page_indptr_list.append(cum_pages) - seq_indptr_list.append(cum_seq_len) - kv_page_indices = torch.tensor( - kv_page_indices_list, dtype=torch.int32, device=device - ) - kv_page_indptr = torch.tensor(kv_page_indptr_list, dtype=torch.int32, device=device) - kv_last_page_len = torch.tensor( - kv_last_page_len_list, dtype=torch.int32, device=device - ) - seq_indptr = torch.tensor(seq_indptr_list, dtype=torch.int32, device=device) - seq_lens = torch.tensor( - seq_lens_list, - dtype=torch.int32, - device=device, - ) - return KvCacheBatchPosition( - seq_indptr=seq_indptr, - kv_page_indptr=kv_page_indptr, - 
kv_page_indices=kv_page_indices, - kv_last_page_len=kv_last_page_len, - seq_lens=seq_lens, - total_seq_len=cum_seq_len, - ) + kv_page_indices = torch.tensor( + kv_page_indices_list, dtype=torch.int32, device=self.device + ) + kv_page_indptr = torch.tensor( + kv_page_indptr_list, dtype=torch.int32, device=self.device + ) + kv_last_page_len = torch.tensor( + kv_last_page_len_list, dtype=torch.int32, device=self.device + ) + seq_indptr = torch.tensor( + seq_indptr_list, dtype=torch.int32, device=self.device + ) + seq_lens = torch.tensor( + seq_lens_list, + dtype=torch.int32, + device=self.device, + ) + return KvCacheBatchPosition( + seq_indptr=seq_indptr, + kv_page_indptr=kv_page_indptr, + kv_page_indices=kv_page_indices, + kv_last_page_len=kv_last_page_len, + seq_lens=seq_lens, + total_seq_len=cum_seq_len, + ) + + +class ModelKvCache: + def __init__(self, kvCachePool: KvCachePool): + self.kvCachePool = kvCachePool + self.device = kvCachePool.device + self.page_len = kvCachePool.page_len + self.batchKvCacheDict: dict[int, BatchKvCache] = {} + + def getOrCreate(self, batch_id): + batchKvCache = self.batchKvCacheDict.get(batch_id) or BatchKvCache( + self.kvCachePool, self.page_len, self.device + ) + self.batchKvCacheDict[batch_id] = batchKvCache + return batchKvCache diff --git a/server/text_generation_server/utils/lora_utils.py b/server/text_generation_server/utils/lora_utils.py index bb1984be..bde45ca7 100644 --- a/server/text_generation_server/utils/lora_utils.py +++ b/server/text_generation_server/utils/lora_utils.py @@ -351,8 +351,8 @@ def get_lora_batched_weights( self, lora_ids: List[str], lora_lens: List[int] ) -> BatchedModelLoraWeight: assert len(lora_ids) <= self.lora_cap - # for lora_id in lora_ids: - # assert lora_id in self.lora_weights_cpu + for lora_id in lora_ids: + assert lora_id in self.lora_weights_cpu loraweights = [] for lora_id in lora_ids: if lora_id and lora_id not in self.lora_weights_gpu:
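
Usage sketch: the refactored in-process API above is driven roughly as in server/examples/test_local_api.py, i.e. FlashinferBatch.from_pb -> add_request -> repeated generate_token calls. This is a minimal sketch only; the model id, the requests list (built with make_input as in that script), and the print format are illustrative assumptions, not part of the patch.

import torch
from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama
from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferBatch
from text_generation_server.pb import generate_pb2

service = FlashinferLlama(model_id="meta-llama/Llama-2-7b-hf")  # assumed model id
# requests: a list of generate_pb2.Request, e.g. built with make_input(...) as in test_local_api.py
batch = generate_pb2.Batch(id=0, requests=requests, size=len(requests))
pb_batch = FlashinferBatch.from_pb(batch, service.tokenizer, torch.float16, torch.device("cuda"))
service.add_request(pb_batch)  # registers per-request contexts inside the model

while service.has_request():
    # One step: prefills newly added requests, decodes one token for the rest.
    generations, _, _ = service.generate_token(FlashinferBatch.Empty(batch.id))
    for gen in generations:
        if gen.generated_text:
            print(gen.request_id, gen.generated_text.text)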