Skip to content

Commit

Permalink
refact embedding/ranking/llm request/response by referring to openai …
Browse files Browse the repository at this point in the history
…format (#405)

Co-authored-by: sys-lpot-val <[email protected]>
Co-authored-by: lvliang-intel <[email protected]>
  • Loading branch information
3 people authored Aug 12, 2024
1 parent 761f7e0 commit 7287caa
Show file tree
Hide file tree
Showing 11 changed files with 563 additions and 139 deletions.
1 change: 1 addition & 0 deletions comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ async def handle_request(self, request: Request):
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"text": prompt}, llm_parameters=parameters
Expand Down
268 changes: 226 additions & 42 deletions comps/cores/proto/api_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,243 @@ class UsageInfo(BaseModel):
completion_tokens: Optional[int] = 0


class ResponseFormat(BaseModel):
# type must be "json_object" or "text"
type: Literal["text", "json_object"]


class StreamOptions(BaseModel):
# refer https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py#L105
include_usage: Optional[bool]


class FunctionDefinition(BaseModel):
name: str
description: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(BaseModel):
type: Literal["function"] = "function"
function: FunctionDefinition


class ChatCompletionNamedFunction(BaseModel):
name: str


class ChatCompletionNamedToolChoiceParam(BaseModel):
function: ChatCompletionNamedFunction
type: Literal["function"] = "function"


class TokenCheckRequestItem(BaseModel):
model: str
prompt: str
max_tokens: int


class TokenCheckRequest(BaseModel):
prompts: List[TokenCheckRequestItem]


class TokenCheckResponseItem(BaseModel):
fits: bool
tokenCount: int
contextLength: int


class TokenCheckResponse(BaseModel):
prompts: List[TokenCheckResponseItem]


class EmbeddingRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model: Optional[str] = None
input: Union[List[int], List[List[int]], str, List[str]]
encoding_format: Optional[str] = Field("float", pattern="^(float|base64)$")
dimensions: Optional[int] = None
user: Optional[str] = None

# define
request_type: Literal["embedding"] = "embedding"


class EmbeddingResponseData(BaseModel):
index: int
object: str = "embedding"
embedding: Union[List[float], str]


class EmbeddingResponse(BaseModel):
object: str = "list"
model: Optional[str] = None
data: List[EmbeddingResponseData]
usage: Optional[UsageInfo] = None


class RetrievalRequest(BaseModel):
embedding: Union[EmbeddingResponse, List[float]] = None
input: Optional[str] = None # search_type maybe need, like "mmr"
search_type: str = "similarity"
k: int = 4
distance_threshold: Optional[float] = None
fetch_k: int = 20
lambda_mult: float = 0.5
score_threshold: float = 0.2

# define
request_type: Literal["retrieval"] = "retrieval"


class RetrievalResponseData(BaseModel):
text: str
metadata: Optional[Dict[str, Any]] = None


class RetrievalResponse(BaseModel):
retrieved_docs: List[RetrievalResponseData]


class RerankingRequest(BaseModel):
input: str
retrieved_docs: Union[List[RetrievalResponseData], List[Dict[str, Any]], List[str]]
top_n: int = 1

# define
request_type: Literal["reranking"] = "reranking"


class RerankingResponseData(BaseModel):
text: str
score: Optional[float] = 0.0


class RerankingResponse(BaseModel):
reranked_docs: List[RerankingResponseData]


class ChatCompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: Union[
str,
List[Dict[str, str]],
List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]],
]
model: Optional[str] = "Intel/neural-chat-7b-v3-3"
temperature: Optional[float] = 0.01
top_p: Optional[float] = 0.95
top_k: Optional[int] = 10
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
max_tokens: Optional[int] = 16 # use https://platform.openai.com/docs/api-reference/completions/create
n: Optional[int] = 1
max_tokens: Optional[int] = 1024
stop: Optional[Union[str, List[str]]] = None
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = None
service_tier: Optional[str] = None
stop: Union[str, List[str], None] = Field(default_factory=list)
stream: Optional[bool] = False
presence_penalty: Optional[float] = 1.03
frequency_penalty: Optional[float] = 0.0
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = 1.0 # vllm default 0.7
top_p: Optional[float] = None # openai default 1.0, but tgi needs `top_p` must be > 0.0 and < 1.0, set None
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"], ChatCompletionNamedToolChoiceParam]] = "none"
parallel_tool_calls: Optional[bool] = True
user: Optional[str] = None

# Ordered by official OpenAI API documentation
# default values are same with
# https://platform.openai.com/docs/api-reference/completions/create
best_of: Optional[int] = 1
suffix: Optional[str] = None

# vllm reference: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py#L130
repetition_penalty: Optional[float] = 1.0

# tgi reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate
# some tgi parameters in use
# default values are same with
# https://github.com/huggingface/text-generation-inference/blob/main/router/src/lib.rs#L190
# max_new_tokens: Optional[int] = 100 # Priority use openai
top_k: Optional[int] = None
# top_p: Optional[float] = None # Priority use openai
typical_p: Optional[float] = None
# repetition_penalty: Optional[float] = None

# doc: begin-chat-completion-extra-params
echo: Optional[bool] = Field(
default=False,
description=(
"If true, the new message will be prepended with the last message " "if they belong to the same role."
),
)
add_generation_prompt: Optional[bool] = Field(
default=True,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
add_special_tokens: Optional[bool] = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to False (as is the "
"default)."
),
)
documents: Optional[Union[List[Dict[str, str]], List[str]]] = Field(
default=None,
description=(
"A list of dicts representing documents that will be accessible to "
"the model if it is performing RAG (retrieval-augmented generation)."
" If the template does not support RAG, this argument will have no "
"effect. We recommend that each document should be a dict containing "
'"title" and "text" keys.'
),
)
chat_template: Optional[str] = Field(
default=None,
description=(
"A template to use for this conversion. "
"If this is not passed, the model's default chat template will be "
"used instead. We recommend that the template contains {context} and {question} for rag,"
"or only contains {question} for chat completion without rag."
),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the template renderer. " "Will be accessible by the chat template."),
)
# doc: end-chat-completion-extra-params

# embedding
input: Union[List[int], List[List[int]], str, List[str]] = None # user query/question from messages[-]
encoding_format: Optional[str] = Field("float", pattern="^(float|base64)$")
dimensions: Optional[int] = None
embedding: Union[EmbeddingResponse, List[float]] = Field(default_factory=list)

# retrieval
search_type: str = "similarity"
k: int = 4
distance_threshold: Optional[float] = None
fetch_k: int = 20
lambda_mult: float = 0.5
score_threshold: float = 0.2
retrieved_docs: Union[List[RetrievalResponseData], List[Dict[str, Any]]] = Field(default_factory=list)

# reranking
top_n: int = 1
reranked_docs: Union[List[RerankingResponseData], List[Dict[str, Any]]] = Field(default_factory=list)

# define
request_type: Literal["chat"] = "chat"


class AudioChatCompletionRequest(BaseModel):
audio: str
Expand Down Expand Up @@ -110,41 +329,6 @@ class ChatCompletionStreamResponse(BaseModel):
choices: List[ChatCompletionResponseStreamChoice]


class TokenCheckRequestItem(BaseModel):
model: str
prompt: str
max_tokens: int


class TokenCheckRequest(BaseModel):
prompts: List[TokenCheckRequestItem]


class TokenCheckResponseItem(BaseModel):
fits: bool
tokenCount: int
contextLength: int


class TokenCheckResponse(BaseModel):
prompts: List[TokenCheckResponseItem]


class EmbeddingsRequest(BaseModel):
model: Optional[str] = None
engine: Optional[str] = None
input: Union[str, List[Any]]
user: Optional[str] = None
encoding_format: Optional[str] = None


class EmbeddingsResponse(BaseModel):
object: str = "list"
data: List[Dict[str, Any]]
model: str
usage: UsageInfo


class CompletionRequest(BaseModel):
model: str
prompt: Union[str, List[Any]]
Expand Down
38 changes: 36 additions & 2 deletions comps/cores/proto/docarray.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Optional
from typing import Dict, List, Optional, Union

import numpy as np
from docarray import BaseDoc, DocList
from docarray.documents import AudioDoc
from docarray.typing import AudioUrl
from pydantic import Field, conint, conlist
from pydantic import Field, conint, conlist, field_validator


class TopologyInfo:
Expand Down Expand Up @@ -88,6 +88,30 @@ class LLMParamsDoc(BaseDoc):
repetition_penalty: float = 1.03
streaming: bool = True

chat_template: Optional[str] = Field(
default=None,
description=(
"A template to use for this conversion. "
"If this is not passed, the model's default chat template will be "
"used instead. We recommend that the template contains {context} and {question} for rag,"
"or only contains {question} for chat completion without rag."
),
)
documents: Optional[Union[List[Dict[str, str]], List[str]]] = Field(
default=[],
description=(
"A list of dicts representing documents that will be accessible to "
"the model if it is performing RAG (retrieval-augmented generation)."
" If the template does not support RAG, this argument will have no "
"effect. We recommend that each document should be a dict containing "
'"title" and "text" keys.'
),
)

@field_validator("chat_template")
def chat_template_must_contain_variables(cls, v):
return v


class LLMParams(BaseDoc):
max_new_tokens: int = 1024
Expand All @@ -98,6 +122,16 @@ class LLMParams(BaseDoc):
repetition_penalty: float = 1.03
streaming: bool = True

chat_template: Optional[str] = Field(
default=None,
description=(
"A template to use for this conversion. "
"If this is not passed, the model's default chat template will be "
"used instead. We recommend that the template contains {context} and {question} for rag,"
"or only contains {question} for chat completion without rag."
),
)


class RAGASParams(BaseDoc):
questions: DocList[TextDoc]
Expand Down
Loading

0 comments on commit 7287caa

Please sign in to comment.