Add openai embedding API #997

Merged 1 commit on Aug 9, 2024

10 changes: 6 additions & 4 deletions python/sglang/srt/managers/io_struct.py
@@ -194,7 +194,8 @@ def post_init(self):
if is_single:
if self.rid is None:
self.rid = uuid.uuid4().hex
self.sampling_params = {"max_new_tokens": 0}
if self.sampling_params is None:
self.sampling_params = {"max_new_tokens": 1}
else:
# support select operation
self.batch_size = (
@@ -205,9 +206,10 @@ def post_init(self):
else:
if not isinstance(self.rid, list):
raise ValueError("The rid should be a list.")
self.sampling_params = [
{"max_new_tokens": 0} for _ in range(self.batch_size)
]
if self.sampling_params is None:
self.sampling_params = [
{"max_new_tokens": 1} for _ in range(self.batch_size)
]


@dataclass
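
For context, a minimal standalone sketch (not the actual sglang dataclass; field names simplified and hypothetical) of what the new default amounts to: caller-supplied sampling_params are no longer overwritten, and embedding requests fall back to a single dummy decode token instead of zero.

from dataclasses import dataclass
from typing import Optional
import uuid

@dataclass
class EmbeddingReqSketch:
    text: str
    rid: Optional[str] = None
    sampling_params: Optional[dict] = None

    def post_init(self):
        if self.rid is None:
            self.rid = uuid.uuid4().hex
        if self.sampling_params is None:
            # embedding models still schedule one dummy output token
            self.sampling_params = {"max_new_tokens": 1}

req = EmbeddingReqSketch(text="hello")
req.post_init()
print(req.sampling_params)  # {'max_new_tokens': 1}
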
1 change: 1 addition & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -262,6 +262,7 @@ async def _handle_single_request(
):
yield response
else:
assert self.is_generation
await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
yield input_ids

2 changes: 2 additions & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -499,6 +499,8 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
req.embedding = embeddings[i]
if req is not self.current_inflight_req:
# Inflight reqs' prefill is not finished
# dummy output token for embedding models
req.output_ids.append(0)
req.check_finished()

if req.finished():
32 changes: 21 additions & 11 deletions python/sglang/srt/openai_api/adapter.py
@@ -34,7 +34,7 @@
generate_chat_conv,
register_conv_template,
)
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.openai_api.protocol import (
BatchRequest,
BatchResponse,
@@ -52,6 +52,7 @@
CompletionResponseStreamChoice,
CompletionStreamResponse,
DeltaMessage,
EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
@@ -1016,10 +1017,10 @@ async def generate_stream_resp():
def v1_embedding_request(all_requests, tokenizer_manager):
prompts = []
sampling_params_list = []
first_prompt_type = type(all_requests[0].prompt)
first_prompt_type = type(all_requests[0].input)

for request in all_requests:
prompt = request.prompt
prompt = request.input
assert (
type(prompt) == first_prompt_type
), "All prompts must be of the same type in file input settings"
@@ -1046,17 +1047,26 @@ def v1_embedding_request(all_requests, tokenizer_manager):
return adapted_request, all_requests


def v1_embedding_response(request, ret, to_file=False):
response = []
def v1_embedding_response(ret, model_path, to_file=False):
embedding_objects = []
prompt_tokens = 0
for idx, ret_item in enumerate(ret):
response.append(
EmbeddingResponse(
embedding_objects.append(
EmbeddingObject(
embedding=ret[idx]["embedding"],
index=idx,
embedding=ret[idx],
object="embedding",
)
)
return response
prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]

return EmbeddingResponse(
data=embedding_objects,
model=model_path,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
total_tokens=prompt_tokens,
),
)


async def v1_embeddings(tokenizer_manager, raw_request: Request):
@@ -1074,7 +1084,7 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
if not isinstance(ret, list):
ret = [ret]

response = v1_embedding_response(request, ret)
response = v1_embedding_response(ret, tokenizer_manager.model_path)

return response

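
As a rough illustration of the data flow (the field names below are taken from the diff above; the values are made up), v1_embedding_response expects each item returned by the tokenizer manager to carry the embedding plus prompt-token accounting, and flattens the items into OpenAI-style objects:

# Hypothetical ret payload; only the "embedding" and meta_info["prompt_tokens"]
# keys are assumed, matching what v1_embedding_response reads above.
ret = [
    {"embedding": [0.01, -0.02, 0.03], "meta_info": {"prompt_tokens": 6}},
    {"embedding": [0.04, 0.05, -0.06], "meta_info": {"prompt_tokens": 6}},
]

prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
data = [
    {"object": "embedding", "index": i, "embedding": item["embedding"]}
    for i, item in enumerate(ret)
]
print(prompt_tokens)                # 12
print([d["index"] for d in data])   # [0, 1]
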
12 changes: 9 additions & 3 deletions python/sglang/srt/openai_api/protocol.py
@@ -319,8 +319,14 @@ class EmbeddingRequest(BaseModel):
user: Optional[str] = None


class EmbeddingResponse(BaseModel):
index: str
embedding: List[float] = None
class EmbeddingObject(BaseModel):
embedding: List[float]
index: int
object: str = "embedding"


class EmbeddingResponse(BaseModel):
data: List[EmbeddingObject]
model: str
object: str = "list"
usage: Optional[UsageInfo] = None
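
To make the wire format concrete, here is a small self-contained sketch of the models above and the OpenAI-compatible shape they produce; UsageInfo is a simplified stand-in, not sglang's full definition, and the values are illustrative.

from typing import List, Optional
from pydantic import BaseModel

class UsageInfo(BaseModel):  # simplified stand-in for sglang's UsageInfo
    prompt_tokens: int = 0
    total_tokens: int = 0

class EmbeddingObject(BaseModel):
    embedding: List[float]
    index: int
    object: str = "embedding"

class EmbeddingResponse(BaseModel):
    data: List[EmbeddingObject]
    model: str
    object: str = "list"
    usage: Optional[UsageInfo] = None

resp = EmbeddingResponse(
    data=[EmbeddingObject(embedding=[0.01, -0.02, 0.03], index=0)],
    model="intfloat/e5-mistral-7b-instruct",
    usage=UsageInfo(prompt_tokens=6, total_tokens=6),
)
print(resp.object, resp.data[0].object)  # list embedding
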
9 changes: 8 additions & 1 deletion python/sglang/srt/server.py
@@ -60,6 +60,7 @@
v1_chat_completions,
v1_completions,
v1_delete_file,
v1_embeddings,
v1_files_create,
v1_retrieve_batch,
v1_retrieve_file,
@@ -174,6 +175,12 @@ async def openai_v1_chat_completions(raw_request: Request):
return await v1_chat_completions(tokenizer_manager, raw_request)


@app.post("/v1/embeddings")
async def openai_v1_embeddings(raw_request: Request):
response = await v1_embeddings(tokenizer_manager, raw_request)
return response


@app.get("/v1/models")
def available_models():
"""Show available models."""
@@ -406,7 +413,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):

# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
max_new_tokens = 8 if model_info["is_generation"] else 0
max_new_tokens = 8 if model_info["is_generation"] else 1
try:
for _ in range(server_args.dp_size):
res = requests.post(
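
With the route registered, any OpenAI-compatible client can call the new endpoint. A minimal usage sketch, assuming an sglang server is already serving an embedding model at this address (model path, port, and API key are illustrative):

import openai

client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

resp = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",
    input="The capital of France is",
)
print(len(resp.data[0].embedding))  # embedding dimension
print(resp.usage.prompt_tokens)     # token count of the input

This mirrors the fields that test_embedding_openai_server.py below asserts on (data, index, object, model, and usage).
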
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -6,6 +6,7 @@
suites = {
"minimal": [
"test_eval_accuracy.py",
"test_embedding_openai_server.py",
"test_openai_server.py",
"test_vision_openai_server.py",
"test_chunked_prefill.py",
87 changes: 87 additions & 0 deletions test/srt/test_embedding_openai_server.py
@@ -0,0 +1,87 @@
import json
import time
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.openai_api.protocol import EmbeddingObject
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import popen_launch_server


class TestOpenAIServer(unittest.TestCase):

@classmethod
def setUpClass(cls):
cls.model = "intfloat/e5-mistral-7b-instruct"
cls.base_url = "http://127.0.0.1:8157"
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=300, api_key=cls.api_key
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)

def run_embedding(self, use_list_input, token_input):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
prompt = "The capital of France is"
if token_input:
prompt_input = self.tokenizer.encode(prompt)
num_prompt_tokens = len(prompt_input)
else:
prompt_input = prompt
num_prompt_tokens = len(self.tokenizer.encode(prompt))

if use_list_input:
prompt_arg = [prompt_input, prompt_input]
num_prompts = len(prompt_arg)
else:
prompt_arg = prompt_input
num_prompts = 1

response = client.embeddings.create(
input=prompt_arg,
model=self.model,
)

assert len(response.data) == num_prompts
assert isinstance(response.data, list)
assert response.data[0].embedding
assert response.data[0].index is not None
assert response.data[0].object == "embedding"
assert response.model == self.model
assert response.object == "list"
assert (
response.usage.prompt_tokens == num_prompt_tokens
), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
assert (
response.usage.total_tokens == num_prompt_tokens
), f"{response.usage.total_tokens} vs {num_prompt_tokens}"

def run_batch(self):
# FIXME not implemented
pass

def test_embedding(self):
# TODO the fields of encoding_format, dimensions, user are skipped
# TODO support use_list_input
for use_list_input in [False]:
for token_input in [False, True]:
self.run_embedding(use_list_input, token_input)

def test_batch(self):
self.run_batch()


if __name__ == "__main__":
unittest.main(warnings="ignore")

# t = TestOpenAIServer()
# t.setUpClass()
# t.test_embedding()
# t.tearDownClass()