Support non-llama models for inference providers
ashwinb committed Feb 21, 2025
1 parent 0fe0717 commit 1b64573
Showing 5 changed files with 27 additions and 20 deletions.
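What this buys in practice: a provider-native model with no llama model mapping can now be registered by its provider id alone, and inference against it falls back to OpenAI-style chat messages instead of a locally rendered llama prompt. A rough usage sketch against the client API that appears in the conftest diff below (the base URL and model id are illustrative assumptions, not taken from this commit):

# Sketch only; assumes a running stack and a Fireworks-served non-llama model.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Register by provider id alone -- no llama_model metadata required anymore.
client.models.register(
    model_id="accounts/fireworks/models/deepseek-v3",  # illustrative id
    provider_id="fireworks",
)

# For this model the provider sends OpenAI-style messages rather than a llama prompt.
response = client.inference.chat_completion(
    model_id="accounts/fireworks/models/deepseek-v3",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.completion_message.content)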
7 changes: 3 additions & 4 deletions llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -209,15 +209,14 @@ async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)

+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [
                     await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
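Net effect of this hunk, with the model id and sampling plumbing elided -- a sketch of the new branching, not a verbatim copy of the file (together.py below converges on the same shape; ollama.py differs only in also setting raw=True next to the locally rendered prompt):

async def _get_params(self, request) -> dict:
    input_dict = {}
    media_present = request_has_media(request)
    # None when the registered model has no llama model mapping, i.e. a non-llama model.
    llama_model = self.get_llama_model(request.model)

    if isinstance(request, ChatCompletionRequest):
        if media_present or not llama_model:
            # No llama chat template to apply (or media involved): send OpenAI-style
            # messages and let Fireworks do its own prompt formatting.
            input_dict["messages"] = [
                await convert_message_to_openai_dict(m, download=True) for m in request.messages
            ]
        else:
            # Known llama model: render the llama chat template locally and send raw text.
            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
    else:
        assert not media_present, "Fireworks does not support media for Completion requests"
        input_dict["prompt"] = await completion_request_to_prompt(request)

    return input_dict  # the real method merges this with model and sampling options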
5 changes: 3 additions & 2 deletions llama_stack/providers/remote/inference/ollama/ollama.py
@@ -178,16 +178,17 @@ async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:

         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                 # flatten the list of lists
                 input_dict["messages"] = [item for sublist in contents for item in sublist]
             else:
                 input_dict["raw"] = True
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
-                    self.register_helper.get_llama_model(request.model),
+                    llama_model,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"
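The two Ollama branches produce differently shaped request bodies; roughly (a sketch with illustrative values, not output captured from a run):

# With a llama model mapping: the prompt is rendered locally from the llama chat
# template and sent with raw=True so Ollama applies no template of its own.
params_with_mapping = {
    "raw": True,
    "prompt": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>...",
}

# Without a mapping (non-llama model), or when media is present: structured chat
# messages go out instead and Ollama's own chat handling formats them.
params_without_mapping = {
    "messages": [{"role": "user", "content": "Hello"}],
}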
7 changes: 3 additions & 4 deletions llama_stack/providers/remote/inference/together/together.py
@@ -203,13 +203,12 @@ async def _to_async_generator():
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
+            if media_present or not llama_model:
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model)
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
18 changes: 9 additions & 9 deletions llama_stack/providers/utils/inference/model_registry.py
@@ -79,28 +79,28 @@ async def register_model(self, model: Model) -> Model:
             provider_resource_id = model.provider_resource_id
         else:
             provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
-            if model.metadata.get("llama_model") is None:
-                raise ValueError(
-                    f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
-                    "Please specify a llama_model in metadata or use a supported model identifier"
-                )
+            llama_model = model.metadata.get("llama_model")
+            if llama_model is None:
+                return model
+
             existing_llama_model = self.get_llama_model(model.provider_resource_id)
             if existing_llama_model:
-                if existing_llama_model != model.metadata["llama_model"]:
+                if existing_llama_model != llama_model:
                     raise ValueError(
                         f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                     )
             else:
-                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                     raise ValueError(
-                        f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
+                        f"Invalid llama_model '{llama_model}' specified in metadata. "
                         f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                     )
                 self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
+                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                 )

         return model
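The behavioral change: an unrecognized provider model id registered without llama_model metadata used to raise; it is now accepted as-is, and get_llama_model() returns None for it -- which is exactly what triggers the messages fallback in the provider diffs above. A sketch of the two registration paths (the import path, Model constructor details, and model ids are assumptions for illustration, not taken from this diff; helper stands for the provider's ModelRegistryHelper instance):

from llama_stack.apis.models import Model, ModelType  # assumed import path

# 1. Non-llama model: no llama_model metadata required; register_model() just returns it.
await helper.register_model(
    Model(
        identifier="my-custom-model",
        provider_resource_id="accounts/acme/models/custom-model",  # illustrative
        provider_id="fireworks",
        model_type=ModelType.llm,
    )
)

# 2. Unrecognized provider id mapped to a llama model via metadata: the value is
#    validated against ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR and recorded in
#    provider_id_to_llama_model_map.
await helper.register_model(
    Model(
        identifier="my-finetuned-llama",
        provider_resource_id="accounts/acme/models/finetuned-llama",  # illustrative
        provider_id="fireworks",
        model_type=ModelType.llm,
        metadata={"llama_model": "meta-llama/Llama-3.1-8B-Instruct"},
    )
)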
10 changes: 9 additions & 1 deletion tests/client-sdk/conftest.py
@@ -78,7 +78,7 @@ def provider_data():


 @pytest.fixture(scope="session")
-def llama_stack_client(provider_data):
+def llama_stack_client(provider_data, text_model_id):
     if os.environ.get("LLAMA_STACK_CONFIG"):
         client = LlamaStackAsLibraryClient(
             get_env_or_fail("LLAMA_STACK_CONFIG"),
@@ -95,6 +95,14 @@ def llama_stack_client(provider_data):
         )
     else:
         raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+
+    inference_providers = [
+        p.provider_id
+        for p in client.providers.list()
+        if p.api == "inference" and p.provider_id != "sentence-transformers"
+    ]
+    assert len(inference_providers) > 0, "No inference providers found"
+    client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
     return client
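text_model_id is assumed to be a fixture defined elsewhere in this conftest; a minimal sketch of what such a fixture could look like (the option name and default are assumptions, not part of this commit):

import pytest

def pytest_addoption(parser):
    # Hypothetical option name; the real conftest may differ.
    parser.addoption(
        "--text-model",
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="Model id to register and run text inference tests against",
    )

@pytest.fixture(scope="session")
def text_model_id(request):
    return request.config.getoption("--text-model")

With the registration folded into the llama_stack_client fixture, pointing that option at a provider-native, non-llama model id should exercise the new messages fallback end to end.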
