From 241be3af2a3f1f5a7a71a33c37546ed4e091572e Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Thu, 8 Aug 2024 07:15:29 +0000
Subject: [PATCH] add io struct for embedding models

---
 python/sglang/srt/managers/io_struct.py  | 60 ++++++++++++++++++++
 python/sglang/srt/models/llama2.py       |  4 +-
 python/sglang/srt/openai_api/adapter.py  | 70 +++++++++++++++++++++++-
 python/sglang/srt/openai_api/protocol.py | 16 ++++++
 4 files changed, 146 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5aa767d5872..e4c3040c9a9 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -22,6 +22,8 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
+import torch
+
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams
 
@@ -166,6 +168,56 @@ class TokenizedGenerateReqInput:
     stream: bool
 
 
+@dataclass
+class EmbeddingReqInput:
+    # The input prompt. It can be a single prompt or a batch of prompts.
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids.
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The request id.
+    rid: Optional[Union[List[str], str]] = None
+    # Dummy sampling params for compatibility
+    sampling_params: Union[List[Dict], Dict] = None
+
+    def post_init(self):
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
+            raise ValueError("Either text or input_ids should be provided.")
+
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+        self.is_single = is_single
+
+        if is_single:
+            if self.rid is None:
+                self.rid = uuid.uuid4().hex
+            self.sampling_params = {"max_new_tokens": 0}
+        else:
+            # support select operation
+            self.batch_size = (
+                len(self.text) if self.text is not None else len(self.input_ids)
+            )
+            if self.rid is None:
+                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
+            else:
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")
+            self.sampling_params = [
+                {"max_new_tokens": 0} for _ in range(self.batch_size)
+            ]
+
+
+@dataclass
+class TokenizedEmbeddingReqInput:
+    rid: str
+    input_text: str
+    input_ids: List[int]
+    sampling_params: SamplingParams
+
+
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
@@ -187,6 +239,14 @@ class BatchStrOut:
     finished_reason: List[BaseFinishReason]
 
 
+@dataclass
+class BatchEmbeddingOut:
+    rids: List[str]
+    embeddings: List[List[float]]
+    meta_info: List[Dict]
+    finished_reason: List[BaseFinishReason]
+
+
 @dataclass
 class FlushCacheReq:
     pass
diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py
index 7a6d570a457..20f8970f7d0 100644
--- a/python/sglang/srt/models/llama2.py
+++ b/python/sglang/srt/models/llama2.py
@@ -39,7 +39,7 @@
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -310,7 +310,7 @@ def forward(
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> LogitProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 2b6fd961a74..c2cdfefe353 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -52,6 +52,8 @@
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    EmbeddingRequest,
+    EmbeddingResponse,
     ErrorResponse,
     FileDeleteResponse,
     FileRequest,
@@ -357,7 +359,6 @@ def iter_file():
 
 
 def v1_generate_request(all_requests):
-
     prompts = []
     sampling_params_list = []
     return_logprobs = []
@@ -648,7 +649,6 @@ async def generate_stream_resp():
 
 
 def v1_chat_generate_request(all_requests, tokenizer_manager):
-
     input_ids = []
     sampling_params_list = []
     image_data_list = []
@@ -961,6 +961,72 @@ async def generate_stream_resp():
     return response
 
 
+def v1_embedding_request(all_requests, tokenizer_manager):
+    prompts = []
+    sampling_params_list = []
+    first_prompt_type = type(all_requests[0].input)
+
+    for request in all_requests:
+        prompt = request.input
+        assert (
+            type(prompt) == first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        prompts.append(prompt)
+
+    if len(all_requests) == 1:
+        prompt = prompts[0]
+        if isinstance(prompt, str) or isinstance(prompt[0], str):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
+    else:
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
+            prompt_kwargs = {"text": prompts}
+        else:
+            prompt_kwargs = {"input_ids": prompts}
+
+    adapted_request = EmbeddingReqInput(
+        **prompt_kwargs,
+    )
+
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_embedding_response(request, ret, to_file=False):
+    response = []
+    for idx, ret_item in enumerate(ret):
+        response.append(
+            EmbeddingResponse(
+                index=idx,
+                embedding=ret_item,
+                object="embedding",
+            )
+        )
+    return response
+
+
+async def v1_embeddings(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [EmbeddingRequest(**request_json)]
+    adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)
+
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    if not isinstance(ret, list):
+        ret = [ret]
+
+    response = v1_embedding_response(request, ret)
+
+    return response
+
+
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 0e9b902231d..3a91c12e86b 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -294,3 +294,19 @@ class ChatCompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
+
+
+class EmbeddingRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/embeddings/create
+    input: Union[List[int], List[List[int]], str, List[str]]
+    model: str
+    encoding_format: str = "float"
+    dimensions: Optional[int] = None
+    user: Optional[str] = None
+
+
+class EmbeddingResponse(BaseModel):
+    index: int
+    embedding: Optional[List[float]] = None
+    object: str = "embedding"