Update embeddings signature so inputs and outputs list align
ashwinb committed Feb 20, 2025
1 parent 984a803 commit 6e8d7e6
Showing 16 changed files with 62 additions and 31 deletions.
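
Every file below makes essentially the same change: embeddings now takes contents: List[str] | List[InterleavedContentItem] instead of contents: List[InterleavedContent]. A single InterleavedContent value may itself be a list of content items, so under the old type the number of inputs did not necessarily match the number of returned embeddings; with the new type the response carries exactly one embedding per list element. Note that the annotation is a union of two homogeneous lists, not a list of unions. A minimal Python sketch of that distinction (only the type names come from this diff; the alias and variables are illustrative):

# Sketch: union of homogeneous lists vs. a list of unions (Python 3.10+).
from typing import List

from llama_stack.apis.common.content_types import InterleavedContentItem

# The new annotation: a batch is either all strings or all content items.
Contents = List[str] | List[InterleavedContentItem]

batch: Contents = ["first input", "second input"]  # accepted: uniformly str

# A mixed batch like ["text", some_image_item] would satisfy
# List[str | InterleavedContentItem] but not the union-of-lists form,
# so each request stays on a single, uniform code path per batch.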
6 changes: 3 additions & 3 deletions llama_stack/apis/inference/inference.py
@@ -20,7 +20,7 @@
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated

-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
@@ -481,12 +481,12 @@ async def chat_completion(
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
"""Generate embeddings for content pieces using the specified model.
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
-:param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
+:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
"""
...
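
A usage sketch of the updated method, under stated assumptions: client is some object implementing this Inference API, the model id is hypothetical, and EmbeddingsResponse exposes the vectors as an embeddings field holding one list of floats per input, as the docstring above describes:

# Usage sketch; `client` and the model id are assumptions, not part of the diff.
texts = ["the capital of France", "a recipe for pancakes"]
response = await client.embeddings(
    model_id="all-MiniLM-L6-v2",  # hypothetical embedding model id
    contents=texts,               # List[str]: one embedding per element
)
assert len(response.embeddings) == len(texts)  # inputs and outputs align
first = response.embeddings[0]  # List[float]; dimensionality is model-specific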
4 changes: 2 additions & 2 deletions llama_stack/distribution/routers/routers.py
@@ -6,7 +6,7 @@

from typing import Any, AsyncGenerator, Dict, List, Optional

-from llama_stack.apis.common.content_types import URL, InterleavedContent
+from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
BenchmarkConfig,
@@ -214,7 +214,7 @@ async def completion(
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
3 changes: 2 additions & 1 deletion llama_stack/providers/inline/inference/vllm/vllm.py
@@ -23,6 +23,7 @@
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
+InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
@@ -230,5 +231,5 @@ async def _generate_and_convert_to_openai_compat():
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk

-async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
+async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
raise NotImplementedError()
4 changes: 2 additions & 2 deletions llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -9,7 +9,7 @@

from botocore.client import BaseClient

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -162,7 +162,7 @@ async def _get_params_for_chat_completion(self, request: ChatCompletionRequest)
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embeddings = []
4 changes: 2 additions & 2 deletions llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -8,7 +8,7 @@

from cerebras.cloud.sdk import AsyncCerebras

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
@@ -172,6 +172,6 @@ async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequ
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
6 changes: 3 additions & 3 deletions
@@ -8,7 +8,7 @@

from openai import OpenAI

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -130,7 +130,7 @@ def _get_params(self, request: ChatCompletionRequest) -> dict:

async def embeddings(
self,
-model: str,
-contents: List[InterleavedContent],
+model_id: str,
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
4 changes: 2 additions & 2 deletions llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -8,7 +8,7 @@

from fireworks.client import Fireworks

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -232,7 +232,7 @@ async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequ
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

3 changes: 2 additions & 1 deletion llama_stack/providers/remote/inference/groq/groq.py
@@ -19,6 +19,7 @@
EmbeddingsResponse,
Inference,
InterleavedContent,
+InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
@@ -140,7 +141,7 @@ async def chat_completion(
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

3 changes: 2 additions & 1 deletion llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -20,6 +20,7 @@
EmbeddingsResponse,
Inference,
InterleavedContent,
+InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
@@ -117,7 +118,7 @@ async def completion(
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

3 changes: 2 additions & 1 deletion llama_stack/providers/remote/inference/ollama/ollama.py
@@ -13,6 +13,7 @@
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
+InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
@@ -332,7 +333,7 @@ async def _generate_and_convert_to_openai_compat():
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

5 changes: 3 additions & 2 deletions llama_stack/providers/remote/inference/runpod/runpod.py
@@ -69,9 +69,10 @@ async def chat_completion(
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
@@ -119,6 +120,6 @@ def _get_params(self, request: ChatCompletionRequest) -> dict:
async def embeddings(
self,
model: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
26 changes: 23 additions & 3 deletions llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -5,16 +5,36 @@
# the root directory of this source tree.

import json
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional

from openai import OpenAI

from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
+InterleavedContentItem,
TextContentItem,
)
-from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionMessage,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    StopReason,
+    SystemMessage,
+    ToolCall,
+    ToolChoice,
+    ToolConfig,
+    ToolDefinition,
+    ToolPromptFormat,
+    ToolResponseMessage,
+    UserMessage,
+)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy,
@@ -119,7 +139,7 @@ async def _to_async_generator():
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

4 changes: 2 additions & 2 deletions llama_stack/providers/remote/inference/tgi/tgi.py
@@ -10,7 +10,7 @@

from huggingface_hub import AsyncInferenceClient, HfApi

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -268,7 +268,7 @@ async def _get_params(self, request: ChatCompletionRequest) -> dict:
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

4 changes: 2 additions & 2 deletions llama_stack/providers/remote/inference/together/together.py
@@ -8,7 +8,7 @@

from together import Together

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -219,7 +219,7 @@ async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequ
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(not content_has_media(content) for content in contents), (
10 changes: 8 additions & 2 deletions llama_stack/providers/remote/inference/vllm/vllm.py
@@ -10,7 +10,13 @@
from llama_models.datatypes import StopReason, ToolCall
from openai import OpenAI

-from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -369,7 +375,7 @@ async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequ
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

4 changes: 2 additions & 2 deletions llama_stack/providers/utils/inference/embedding_mixin.py
@@ -9,7 +9,7 @@

from llama_stack.apis.inference import (
EmbeddingsResponse,
-InterleavedContent,
+InterleavedContentItem,
ModelStore,
)

@@ -25,7 +25,7 @@ class SentenceTransformerEmbeddingMixin:
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedContent],
+contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
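
Providers in this commit either raise NotImplementedError or follow the pattern visible in the fireworks/ollama/together hunks: resolve the model via self.model_store.get_model, reject media (see the content_has_media assertion in the together hunk), then embed. A hedged sketch of that text-only pattern under the new signature; the _embed_texts helper is hypothetical, and the embeddings= field name is an assumption consistent with the EmbeddingsResponse docstring:

# Sketch of a text-only provider's embeddings() under the new signature.
# self.model_store and model.provider_resource_id appear in the diffs above;
# self._embed_texts is a hypothetical backend call.
from typing import List

from llama_stack.apis.common.content_types import InterleavedContentItem, TextContentItem
from llama_stack.apis.inference import EmbeddingsResponse

async def embeddings(
    self,
    model_id: str,
    contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
    model = await self.model_store.get_model(model_id)
    texts: List[str] = []
    for content in contents:
        if isinstance(content, str):
            texts.append(content)        # plain string element
        elif isinstance(content, TextContentItem):
            texts.append(content.text)   # text item: unwrap to str
        else:
            raise ValueError("this sketch supports text-only embeddings")
    vectors = await self._embed_texts(model.provider_resource_id, texts)
    return EmbeddingsResponse(embeddings=vectors)  # one vector per input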
