Make LlamaStackLibraryClient work correctly #581
Merged
Changes from all commits · 4 commits
@@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import inspect
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Generator, get_args, get_origin, Optional, TypeVar

import yaml
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
from pydantic import TypeAdapter
from rich.console import Console

from termcolor import cprint

from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
    construct_stack,
    get_stack_run_config_from_template,
    replace_env_vars,
)

T = TypeVar("T")


def stream_across_asyncio_run_boundary(
    async_gen_maker,
    pool_executor: ThreadPoolExecutor,
) -> Generator[T, None, None]:
    result_queue = queue.Queue()
    stop_event = threading.Event()

    async def consumer():
        # make sure we make the generator in the event loop context
        gen = await async_gen_maker()
        try:
            async for item in gen:
                result_queue.put(item)
        except Exception as e:
            print(f"Error in generator {e}")
            result_queue.put(e)
        except asyncio.CancelledError:
            return
        finally:
            result_queue.put(StopIteration)
            stop_event.set()

    def run_async():
        # Run our own loop to avoid double async generator cleanup which is done
        # by asyncio.run()
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            task = loop.create_task(consumer())
            loop.run_until_complete(task)
        finally:
            # Handle pending tasks like a generator's athrow()
            pending = asyncio.all_tasks(loop)
            if pending:
                loop.run_until_complete(
                    asyncio.gather(*pending, return_exceptions=True)
                )
            loop.close()

    future = pool_executor.submit(run_async)

    try:
        # yield results as they come in
        while not stop_event.is_set() or not result_queue.empty():
            try:
                item = result_queue.get(timeout=0.1)
                if item is StopIteration:
                    break
                if isinstance(item, Exception):
                    raise item
                yield item
            except queue.Empty:
                continue
    finally:
        future.result()


class LlamaStackAsLibraryClient(LlamaStackClient):
    def __init__(
        self,
        config_path_or_template_name: str,
        custom_provider_registry: Optional[ProviderRegistry] = None,
    ):
        super().__init__()
        self.async_client = AsyncLlamaStackAsLibraryClient(
            config_path_or_template_name, custom_provider_registry
        )
        self.pool_executor = ThreadPoolExecutor(max_workers=4)

    def initialize(self):
        asyncio.run(self.async_client.initialize())

    def get(self, *args, **kwargs):
        if kwargs.get("stream"):
            return stream_across_asyncio_run_boundary(
                lambda: self.async_client.get(*args, **kwargs),
                self.pool_executor,
            )
        else:
            return asyncio.run(self.async_client.get(*args, **kwargs))

    def post(self, *args, **kwargs):
        if kwargs.get("stream"):
            return stream_across_asyncio_run_boundary(
                lambda: self.async_client.post(*args, **kwargs),
                self.pool_executor,
            )
        else:
            return asyncio.run(self.async_client.post(*args, **kwargs))


class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
    def __init__(
        self,
        config_path_or_template_name: str,
        custom_provider_registry: Optional[ProviderRegistry] = None,
    ):
        super().__init__()

        if config_path_or_template_name.endswith(".yaml"):
            config_path = Path(config_path_or_template_name)
            if not config_path.exists():
                raise ValueError(f"Config file {config_path} does not exist")
            config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
            config = parse_and_maybe_upgrade_config(config_dict)
        else:
            # template
            config = get_stack_run_config_from_template(config_path_or_template_name)

        self.config_path_or_template_name = config_path_or_template_name
        self.config = config
        self.custom_provider_registry = custom_provider_registry

    async def initialize(self):
        try:
            self.impls = await construct_stack(
                self.config, self.custom_provider_registry
            )
        except ModuleNotFoundError as e:
            cprint(
                "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
                "yellow",
            )
            print_pip_install_help(self.config.providers)
            raise e

        console = Console()
        console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
        console.print(yaml.dump(self.config.model_dump(), indent=2))

        endpoints = get_all_api_endpoints()
        endpoint_impls = {}
        for api, api_endpoints in endpoints.items():
            for endpoint in api_endpoints:
                impl = self.impls[api]
                func = getattr(impl, endpoint.name)
                endpoint_impls[endpoint.route] = func

        self.endpoint_impls = endpoint_impls

    async def get(
        self,
        path: str,
        *,
        stream=False,
        **kwargs,
    ):
        if not self.endpoint_impls:
            raise ValueError("Client not initialized")

        if stream:
            return self._call_streaming(path, "GET")
        else:
            return await self._call_non_streaming(path, "GET")

    async def post(
        self,
        path: str,
        *,
        body: dict = None,
        stream=False,
        **kwargs,
    ):
        if not self.endpoint_impls:
            raise ValueError("Client not initialized")

        if stream:
            return self._call_streaming(path, "POST", body)
        else:
            return await self._call_non_streaming(path, "POST", body)

    async def _call_non_streaming(self, path: str, method: str, body: dict = None):
        func = self.endpoint_impls.get(path)
        if not func:
            raise ValueError(f"No endpoint found for {path}")

        body = self._convert_body(path, body)
        return await func(**body)

    async def _call_streaming(self, path: str, method: str, body: dict = None):
        func = self.endpoint_impls.get(path)
        if not func:
            raise ValueError(f"No endpoint found for {path}")

        body = self._convert_body(path, body)
        async for chunk in await func(**body):
            yield chunk

    def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
        if not body:
            return {}

        func = self.endpoint_impls[path]
        sig = inspect.signature(func)

        # Strip NOT_GIVENs to use the defaults in signature
        body = {k: v for k, v in body.items() if v is not NOT_GIVEN}

        # Convert parameters to Pydantic models where needed
        converted_body = {}
        for param_name, param in sig.parameters.items():
            if param_name in body:
                value = body.get(param_name)
                converted_body[param_name] = self._convert_param(
                    param.annotation, value
                )
        return converted_body

    def _convert_param(self, annotation: Any, value: Any) -> Any:
        if isinstance(annotation, type) and annotation in {str, int, float, bool}:
            return value

        origin = get_origin(annotation)
        if origin is list:
            item_type = get_args(annotation)[0]
            try:
                return [self._convert_param(item_type, item) for item in value]
            except Exception:
                print(f"Error converting list {value}")
                return value

        elif origin is dict:
            key_type, val_type = get_args(annotation)
            try:
                return {k: self._convert_param(val_type, v) for k, v in value.items()}
            except Exception:
                print(f"Error converting dict {value}")
                return value

        try:
            # Handle Pydantic models and discriminated unions
            return TypeAdapter(annotation).validate_python(value)
        except Exception as e:
            cprint(
                f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
                "yellow",
            )
            return value
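Note on usage: besides a path ending in .yaml, the constructor above also accepts a distribution template name (the else branch in __init__ resolves it via get_stack_run_config_from_template). The snippet below is a minimal, hedged sketch of that path — the template name "ollama" is only an illustration, and the import path llama_stack.distribution.library_client is inferred from the example script later in this PR rather than stated in this diff.

# Illustrative sketch: construct the library client from a template name rather
# than a YAML path. "ollama" is a hypothetical template name; use any template
# whose provider dependencies are installed.
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("ollama")
client.initialize()  # builds the stack in-process; re-raises ModuleNotFoundError if provider deps are missing

# Non-streaming calls run through asyncio.run(); streaming calls return a plain
# synchronous generator via stream_across_asyncio_run_boundary above.
for model in client.models.list():
    print(model.identifier)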
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import os

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger
from llama_stack_client.lib.inference.event_logger import EventLogger
from llama_stack_client.types import UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig


def main(config_path: str):
    client = LlamaStackAsLibraryClient(config_path)
Review comment on the line above: https://llama-stack.readthedocs.io/en/latest/distributions/importing_as_library.html does this reference to …
    client.initialize()

    models = client.models.list()
    print("\nModels:")
    for model in models:
        print(model)

    if not models:
        print("No models found, skipping chat completion test")
        return

    model_id = models[0].identifier
    response = client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=model_id,
        stream=False,
    )
    print("\nChat completion response (non-stream):")
    print(response)

    response = client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model_id=model_id,
        stream=True,
    )

    print("\nChat completion response (stream):")
    for log in EventLogger().log(response):
        log.print()

    print("\nAgent test:")
    agent_config = AgentConfig(
        model=model_id,
        instructions="You are a helpful assistant",
        sampling_params={
            "strategy": "greedy",
            "temperature": 1.0,
            "top_p": 0.9,
        },
        tools=(
            [
                {
                    "type": "brave_search",
                    "engine": "brave",
                    "api_key": os.getenv("BRAVE_SEARCH_API_KEY"),
                }
            ]
            if os.getenv("BRAVE_SEARCH_API_KEY")
            else []
        ),
        tool_choice="auto",
        tool_prompt_format="json",
        input_shields=[],
        output_shields=[],
        enable_session_persistence=False,
    )
    agent = Agent(client, agent_config)
    user_prompts = [
        "Hello",
        "Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools",
    ]

    session_id = agent.create_session("test-session")

    for prompt in user_prompts:
        response = agent.create_turn(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            session_id=session_id,
        )

        for log in AgentEventLogger().log(response):
            log.print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("config_path", help="Path to the config YAML file")
    args = parser.parse_args()
    main(args.config_path)
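Because the sync client's streaming path returns an ordinary generator, the streaming response in the script above can also be consumed with a plain for loop instead of EventLogger (exactly the property the review note below calls out). A hedged sketch reusing client, model_id, and UserMessage from the script; the precise chunk type depends on the installed llama-stack-client version.

# Hypothetical continuation of main() above: iterate the streaming response directly.
response = client.inference.chat_completion(
    messages=[UserMessage(content="What is the capital of France?", role="user")],
    model_id=model_id,
    stream=True,
)
for chunk in response:
    # Each chunk is handed over by stream_across_asyncio_run_boundary as the
    # server-side async generator produces it.
    print(chunk)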
Review comment: This is the most crucial part of this PR. Without it, you cannot make the "non-async" generators (a necessary mode for our client-sdk) work properly. You must be able to do

for chunk in inference.chat_completion()

since that's what synchronous generators are about. To bridge a sync generator to the async generators we have in our server-side code, we need to intermediate via a thread pool.
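A stripped-down sketch of that intermediation, exercising the helper from the first file with a toy async generator. The import path is assumed (it matches what the example script imports), and the toy generator merely stands in for a server-side streaming implementation.

import asyncio
from concurrent.futures import ThreadPoolExecutor

# Assumed import path, matching the module the example script imports from.
from llama_stack.distribution.library_client import stream_across_asyncio_run_boundary


async def make_toy_stream():
    # Stand-in for a server-side async generator (e.g. a streaming chat completion).
    async def gen():
        for i in range(3):
            await asyncio.sleep(0)
            yield i

    return gen()


executor = ThreadPoolExecutor(max_workers=1)
# Plain synchronous iteration over values produced by an async generator that is
# driven on its own event loop inside the thread pool.
for item in stream_across_asyncio_run_boundary(make_toy_stream, executor):
    print(item)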