Skip to content

Commit

Permalink
feature: configurable (OpenAI) client
Browse files Browse the repository at this point in the history
All public interfaces that work with OpenAI, directly or transitively,
now have an optional AutoEvalClient client option. If not provided,
we'll prepare and handle the OpenAI client as we do already.

At the moment you need to pass the client in each call, but
in a follow-up commit you'll be able to set it globally.
  • Loading branch information
ibolmo committed Dec 12, 2024
1 parent 40a051f commit ca678f8
Show file tree
Hide file tree
Showing 5 changed files with 280 additions and 132 deletions.
13 changes: 11 additions & 2 deletions py/autoevals/llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import abc
import json
import os
import re
Expand All @@ -11,7 +10,7 @@

from autoevals.partial import ScorerWithPartial

from .oai import arun_cached_request, run_cached_request
from .oai import AutoEvalClient, arun_cached_request, run_cached_request

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

Expand Down Expand Up @@ -79,24 +78,29 @@ def __init__(
self,
api_key=None,
base_url=None,
client: Optional[AutoEvalClient] = None,
):
self.extra_args = {}
if api_key:
self.extra_args["api_key"] = api_key
if base_url:
self.extra_args["base_url"] = base_url

self.client = client


class OpenAILLMScorer(OpenAIScorer):
def __init__(
self,
temperature=None,
api_key=None,
base_url=None,
client: Optional[AutoEvalClient] = None,
):
super().__init__(
api_key=api_key,
base_url=base_url,
client=client,
)
self.extra_args["temperature"] = temperature or 0

Expand All @@ -115,8 +119,10 @@ def __init__(
engine=None,
api_key=None,
base_url=None,
client: Optional[AutoEvalClient] = None,
):
super().__init__(
client=client,
api_key=api_key,
base_url=base_url,
)
Expand Down Expand Up @@ -162,6 +168,7 @@ def _render_messages(self, **kwargs):

def _request_args(self, output, expected, **kwargs):
ret = {
"client": self.client,
**self.extra_args,
**self._build_args(output, expected, **kwargs),
}
Expand Down Expand Up @@ -233,6 +240,7 @@ def __init__(
engine=None,
api_key=None,
base_url=None,
client: Optional[AutoEvalClient] = None,
**extra_render_args,
):
choice_strings = list(choice_scores.keys())
Expand All @@ -257,6 +265,7 @@ def __init__(
api_key=api_key,
base_url=base_url,
render_args={"__choices": choice_strings, **extra_render_args},
client=client,
)

@classmethod
Expand Down
23 changes: 18 additions & 5 deletions py/autoevals/moderation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Optional

from braintrust_core.score import Score

from autoevals.llm import OpenAIScorer

from .oai import arun_cached_request, run_cached_request
from .oai import AutoEvalClient, arun_cached_request, run_cached_request

REQUEST_TYPE = "moderation"

Expand All @@ -15,7 +17,13 @@ class Moderation(OpenAIScorer):
threshold = None
extra_args = {}

def __init__(self, threshold=None, api_key=None, base_url=None):
def __init__(
self,
threshold=None,
api_key=None,
base_url=None,
client: Optional[AutoEvalClient] = None,
):
"""
Create a new Moderation scorer.
Expand All @@ -24,11 +32,14 @@ def __init__(self, threshold=None, api_key=None, base_url=None):
:param api_key: OpenAI key
:param base_url: Base URL to be used to reach OpenAI moderation endpoint.
"""
super().__init__(api_key=api_key, base_url=base_url)
super().__init__(api_key=api_key, base_url=base_url, client=client)
self.threshold = threshold

# need to check who calls _run_eval_a?sync
def _run_eval_sync(self, output, __expected=None):
moderation_response = run_cached_request(REQUEST_TYPE, input=output, **self.extra_args)["results"][0]
moderation_response = run_cached_request(
client=self.client, request_type=REQUEST_TYPE, input=output, **self.extra_args
)["results"][0]
return self.__postprocess_response(moderation_response)

def __postprocess_response(self, moderation_response) -> Score:
Expand All @@ -42,7 +53,9 @@ def __postprocess_response(self, moderation_response) -> Score:
)

async def _run_eval_async(self, output, expected=None, **kwargs) -> Score:
moderation_response = (await arun_cached_request(REQUEST_TYPE, input=output, **self.extra_args))["results"][0]
moderation_response = (
await arun_cached_request(client=self.client, request_type=REQUEST_TYPE, input=output, **self.extra_args)
)["results"][0]
return self.__postprocess_response(moderation_response)

@staticmethod
Expand Down
181 changes: 119 additions & 62 deletions py/autoevals/oai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,93 +5,146 @@
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Any, Optional

PROXY_URL = "https://api.braintrust.dev/v1/proxy"


@dataclass
class OpenAIWrapper:
class AutoEvalClient:
# TODO: add docs
# TODO: how to type if we don't depend on openai
openai: Any
complete: Any
embed: Any
moderation: Any
RateLimitError: Exception


def prepare_openai(is_async=False, api_key=None, base_url=None):
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("BRAINTRUST_API_KEY")
if base_url is None:
base_url = os.environ.get("OPENAI_BASE_URL", PROXY_URL)
def prepare_openai(client: Optional[AutoEvalClient] = None, is_async=False, api_key=None, base_url=None):
"""Prepares and configures an OpenAI client for use with AutoEval, if client is not provided.
try:
import openai
except Exception as e:
print(
textwrap.dedent(
f"""\
Unable to import openai: {e}
Please install it, e.g. with
pip install 'openai'
"""
),
file=sys.stderr,
)
raise
This function handles both v0 and v1 of the OpenAI SDK, configuring the client
with the appropriate authentication and base URL settings.
We will also attempt to enable Braintrust tracing export, if you've configured tracing.
Args:
client (Optional[AutoEvalClient], optional): Existing AutoEvalClient instance.
If provided, this client will be used instead of creating a new one.
is_async (bool, optional): Whether to create a client with async operations. Defaults to False.
Deprecated: Use the `client` argument and set the `openai` with the async/sync that you'd like to use.
api_key (str, optional): OpenAI API key. If not provided, will look for
OPENAI_API_KEY or BRAINTRUST_API_KEY in environment variables.
Deprecated: Use the `client` argument and set the `openai`.
base_url (str, optional): Base URL for API requests. If not provided, will
use OPENAI_BASE_URL from environment or fall back to PROXY_URL.
Deprecated: Use the `client` argument and set the `openai`.
Returns:
Tuple[AutoEvalClient, bool]: A tuple containing:
- The configured AutoEvalClient instance, or the client you've provided
- A boolean indicating whether the client was wrapped with Braintrust tracing
Raises:
ImportError: If the OpenAI package is not installed
"""
openai = getattr(client, "openai", None)
if not openai:
try:
import openai
except Exception as e:
print(
textwrap.dedent(
f"""\
Unable to import openai: {e}
Please install it, e.g. with
pip install 'openai'
"""
),
file=sys.stderr,
)
raise

openai_obj = openai

is_v1 = False

if hasattr(openai, "OpenAI"):
# This is the new v1 API
is_v1 = True
if is_async:
openai_obj = openai.AsyncOpenAI(api_key=api_key, base_url=base_url)

if client is None:
# prepare the default openai sdk, if not provided
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("BRAINTRUST_API_KEY")
if base_url is None:
base_url = os.environ.get("OPENAI_BASE_URL", PROXY_URL)

if is_v1:
if is_async:
openai_obj = openai.AsyncOpenAI(api_key=api_key, base_url=base_url)
else:
openai_obj = openai.OpenAI(api_key=api_key, base_url=base_url)
else:
openai_obj = openai.OpenAI(api_key=api_key, base_url=base_url)
else:
if api_key:
openai.api_key = api_key
openai.api_base = base_url
if api_key:
openai.api_key = api_key
openai.api_base = base_url

# optimistically wrap openai instance for tracing
wrapped = False
try:
from braintrust.oai import wrap_openai
from braintrust.oai import NamedWrapper, wrap_openai

if not isinstance(openai_obj, NamedWrapper):
openai_obj = wrap_openai(openai_obj)

openai_obj = wrap_openai(openai_obj)
wrapped = True
except ImportError:
pass

complete_fn = None
rate_limit_error = None
if is_v1:
wrapper = OpenAIWrapper(
complete=openai_obj.chat.completions.create,
embed=openai_obj.embeddings.create,
moderation=openai_obj.moderations.create,
RateLimitError=openai.RateLimitError,
)
else:
rate_limit_error = openai.error.RateLimitError
if is_async:
complete_fn = openai_obj.ChatCompletion.acreate
embedding_fn = openai_obj.Embedding.acreate
moderation_fn = openai_obj.Moderations.acreate
if client is None:
# prepare the default client if not provided
complete_fn = None
rate_limit_error = None

# TODO: allow overriding globally
Client = AutoEvalClient

if is_v1:
client = Client(
openai=openai,
complete=openai_obj.chat.completions.create,
embed=openai_obj.embeddings.create,
moderation=openai_obj.moderations.create,
RateLimitError=openai.RateLimitError,
)
else:
complete_fn = openai_obj.ChatCompletion.create
embedding_fn = openai_obj.Embedding.create
moderation_fn = openai_obj.Moderations.create
wrapper = OpenAIWrapper(
complete=complete_fn,
embed=embedding_fn,
moderation=moderation_fn,
RateLimitError=rate_limit_error,
)

return wrapper, wrapped
rate_limit_error = openai.error.RateLimitError
if is_async:
complete_fn = openai_obj.ChatCompletion.acreate
embedding_fn = openai_obj.Embedding.acreate
moderation_fn = openai_obj.Moderations.acreate
else:
complete_fn = openai_obj.ChatCompletion.create
embedding_fn = openai_obj.Embedding.create
moderation_fn = openai_obj.Moderations.create
client = Client(
openai=openai,
complete=complete_fn,
embed=embedding_fn,
moderation=moderation_fn,
RateLimitError=rate_limit_error,
)

return client, wrapped


def post_process_response(resp):
Expand All @@ -108,8 +161,10 @@ def set_span_purpose(kwargs):
kwargs.setdefault("span_info", {}).setdefault("span_attributes", {})["purpose"] = "scorer"


def run_cached_request(request_type="complete", api_key=None, base_url=None, **kwargs):
wrapper, wrapped = prepare_openai(is_async=False, api_key=api_key, base_url=base_url)
def run_cached_request(
*, client: Optional[AutoEvalClient] = None, request_type="complete", api_key=None, base_url=None, **kwargs
):
wrapper, wrapped = prepare_openai(client=client, is_async=False, api_key=api_key, base_url=base_url)
if wrapped:
set_span_purpose(kwargs)

Expand All @@ -127,8 +182,10 @@ def run_cached_request(request_type="complete", api_key=None, base_url=None, **k
return resp


async def arun_cached_request(request_type="complete", api_key=None, base_url=None, **kwargs):
wrapper, wrapped = prepare_openai(is_async=True, api_key=api_key, base_url=base_url)
async def arun_cached_request(
*, client: Optional[AutoEvalClient] = None, request_type="complete", api_key=None, base_url=None, **kwargs
):
wrapper, wrapped = prepare_openai(client=client, is_async=True, api_key=api_key, base_url=base_url)
if wrapped:
set_span_purpose(kwargs)

Expand Down
Loading

0 comments on commit ca678f8

Please sign in to comment.