Make LlamaStackLibraryClient work correctly #581

Merged: 4 commits, Dec 7, 2024
6 changes: 3 additions & 3 deletions llama_stack/distribution/build.py
@@ -46,7 +46,7 @@ class ApiInput(BaseModel):


def get_provider_dependencies(
- config_providers: Dict[str, List[Provider]]
+ config_providers: Dict[str, List[Provider]],
) -> tuple[list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
all_providers = get_provider_registry()
@@ -92,11 +92,11 @@ def print_pip_install_help(providers: Dict[str, List[Provider]]):
normal_deps, special_deps = get_provider_dependencies(providers)

cprint(
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}",
f"Please install needed dependencies using the following commands:\n\npip install {' '.join(normal_deps)}",
"yellow",
)
for special_dep in special_deps:
cprint(f"\tpip install {special_dep}", "yellow")
cprint(f"pip install {special_dep}", "yellow")
print()


272 changes: 272 additions & 0 deletions llama_stack/distribution/library_client.py
@@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import inspect
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar

import yaml
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
from pydantic import TypeAdapter
from rich.console import Console

from termcolor import cprint

from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
construct_stack,
get_stack_run_config_from_template,
replace_env_vars,
)

T = TypeVar("T")


def stream_across_asyncio_run_boundary(
Comment from the Contributor Author:

This is the most crucial part of this PR. Without it, you cannot make the "non-async" generators (a necessary mode for our client SDK) work properly. You must be able to write for chunk in inference.chat_completion(...), since that is what synchronous generators are about. To bridge a sync generator to the async generators in our server-side code, we intermediate via a thread pool. A minimal illustration of the pattern follows.
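A hedged, self-contained sketch of the bridging pattern (not part of the diff): it assumes only that stream_across_asyncio_run_boundary is importable from llama_stack.distribution.library_client, as added in this PR, and uses a toy async generator in place of a real inference stream.

from concurrent.futures import ThreadPoolExecutor

from llama_stack.distribution.library_client import stream_across_asyncio_run_boundary


async def fake_chat_completion_stream():
    # stand-in for the async generator returned by server-side inference code
    for chunk in ("The", " capital", " of", " France", " is", " Paris."):
        yield chunk


async def make_stream():
    # the bridge awaits this callable inside the worker thread's event loop
    # and expects it to produce an async generator
    return fake_chat_completion_stream()


pool = ThreadPoolExecutor(max_workers=1)
# To the caller this behaves like an ordinary synchronous generator, even
# though the chunks are produced by an asyncio loop on a worker thread.
for chunk in stream_across_asyncio_run_boundary(make_stream, pool):
    print(chunk, end="")
print()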

async_gen_maker,
pool_executor: ThreadPoolExecutor,
) -> Generator[T, None, None]:
result_queue = queue.Queue()
stop_event = threading.Event()

async def consumer():
# make sure we make the generator in the event loop context
gen = await async_gen_maker()
try:
async for item in gen:
result_queue.put(item)
except Exception as e:
print(f"Error in generator {e}")
result_queue.put(e)
except asyncio.CancelledError:
return
finally:
result_queue.put(StopIteration)
stop_event.set()

def run_async():
# Run our own loop to avoid double async generator cleanup which is done
# by asyncio.run()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
task = loop.create_task(consumer())
loop.run_until_complete(task)
finally:
# Handle pending tasks like a generator's athrow()
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.close()

future = pool_executor.submit(run_async)

try:
# yield results as they come in
while not stop_event.is_set() or not result_queue.empty():
try:
item = result_queue.get(timeout=0.1)
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
except queue.Empty:
continue
finally:
future.result()


class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,
custom_provider_registry: Optional[ProviderRegistry] = None,
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_template_name, custom_provider_registry
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)

def initialize(self):
asyncio.run(self.async_client.initialize())

def get(self, *args, **kwargs):
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.get(*args, **kwargs),
self.pool_executor,
)
else:
return asyncio.run(self.async_client.get(*args, **kwargs))

def post(self, *args, **kwargs):
if kwargs.get("stream"):
return stream_across_asyncio_run_boundary(
lambda: self.async_client.post(*args, **kwargs),
self.pool_executor,
)
else:
return asyncio.run(self.async_client.post(*args, **kwargs))


class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,
custom_provider_registry: Optional[ProviderRegistry] = None,
):
super().__init__()

if config_path_or_template_name.endswith(".yaml"):
config_path = Path(config_path_or_template_name)
if not config_path.exists():
raise ValueError(f"Config file {config_path} does not exist")
config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
config = parse_and_maybe_upgrade_config(config_dict)
else:
# template
config = get_stack_run_config_from_template(config_path_or_template_name)

self.config_path_or_template_name = config_path_or_template_name
self.config = config
self.custom_provider_registry = custom_provider_registry

async def initialize(self):
try:
self.impls = await construct_stack(
self.config, self.custom_provider_registry
)
except ModuleNotFoundError as e:
cprint(
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
"yellow",
)
print_pip_install_help(self.config.providers)
raise e

console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
console.print(yaml.dump(self.config.model_dump(), indent=2))

endpoints = get_all_api_endpoints()
endpoint_impls = {}
for api, api_endpoints in endpoints.items():
for endpoint in api_endpoints:
impl = self.impls[api]
func = getattr(impl, endpoint.name)
endpoint_impls[endpoint.route] = func

self.endpoint_impls = endpoint_impls

async def get(
self,
path: str,
*,
stream=False,
**kwargs,
):
if not self.endpoint_impls:
raise ValueError("Client not initialized")

if stream:
return self._call_streaming(path, "GET")
else:
return await self._call_non_streaming(path, "GET")

async def post(
self,
path: str,
*,
body: dict = None,
stream=False,
**kwargs,
):
if not self.endpoint_impls:
raise ValueError("Client not initialized")

if stream:
return self._call_streaming(path, "POST", body)
else:
return await self._call_non_streaming(path, "POST", body)

async def _call_non_streaming(self, path: str, method: str, body: dict = None):
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")

body = self._convert_body(path, body)
return await func(**body)

async def _call_streaming(self, path: str, method: str, body: dict = None):
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")

body = self._convert_body(path, body)
async for chunk in await func(**body):
yield chunk

def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
if not body:
return {}

func = self.endpoint_impls[path]
sig = inspect.signature(func)

# Strip NOT_GIVENs to use the defaults in signature
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}

# Convert parameters to Pydantic models where needed
converted_body = {}
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
converted_body[param_name] = self._convert_param(
param.annotation, value
)
return converted_body

def _convert_param(self, annotation: Any, value: Any) -> Any:
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
return value

origin = get_origin(annotation)
if origin is list:
item_type = get_args(annotation)[0]
try:
return [self._convert_param(item_type, item) for item in value]
except Exception:
print(f"Error converting list {value}")
return value

elif origin is dict:
key_type, val_type = get_args(annotation)
try:
return {k: self._convert_param(val_type, v) for k, v in value.items()}
except Exception:
print(f"Error converting dict {value}")
return value

try:
# Handle Pydantic models and discriminated unions
return TypeAdapter(annotation).validate_python(value)
except Exception as e:
cprint(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
"yellow",
)
return value
103 changes: 103 additions & 0 deletions llama_stack/distribution/tests/library_client_test.py
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import os

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger
from llama_stack_client.lib.inference.event_logger import EventLogger
from llama_stack_client.types import UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig


def main(config_path: str):
client = LlamaStackAsLibraryClient(config_path)
Comment from a Contributor:

https://llama-stack.readthedocs.io/en/latest/distributions/importing_as_library.html

Does this reference to LlamaStackDirectClient also need to be updated?
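For reference, a hedged guess at what the updated docs snippet might look like (the exact docs wording is not part of this PR; usage is modeled on library_client_test.py below):

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

# argument is a template name or a path to a run config YAML;
# "ollama" here is only an illustrative choice
client = LlamaStackAsLibraryClient("ollama")
client.initialize()
print(client.models.list())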

client.initialize()

models = client.models.list()
print("\nModels:")
for model in models:
print(model)

if not models:
print("No models found, skipping chat completion test")
return

model_id = models[0].identifier
response = client.inference.chat_completion(
messages=[UserMessage(content="What is the capital of France?", role="user")],
model_id=model_id,
stream=False,
)
print("\nChat completion response (non-stream):")
print(response)

response = client.inference.chat_completion(
messages=[UserMessage(content="What is the capital of France?", role="user")],
model_id=model_id,
stream=True,
)

print("\nChat completion response (stream):")
for log in EventLogger().log(response):
log.print()

print("\nAgent test:")
agent_config = AgentConfig(
model=model_id,
instructions="You are a helpful assistant",
sampling_params={
"strategy": "greedy",
"temperature": 1.0,
"top_p": 0.9,
},
tools=(
[
{
"type": "brave_search",
"engine": "brave",
"api_key": os.getenv("BRAVE_SEARCH_API_KEY"),
}
]
if os.getenv("BRAVE_SEARCH_API_KEY")
else []
),
tool_choice="auto",
tool_prompt_format="json",
input_shields=[],
output_shields=[],
enable_session_persistence=False,
)
agent = Agent(client, agent_config)
user_prompts = [
"Hello",
"Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools",
]

session_id = agent.create_session("test-session")

for prompt in user_prompts:
response = agent.create_turn(
messages=[
{
"role": "user",
"content": prompt,
}
],
session_id=session_id,
)

for log in AgentEventLogger().log(response):
log.print()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("config_path", help="Path to the config YAML file")
args = parser.parse_args()
main(args.config_path)