From ca678f8c3a802a506950152e7196f01fb7712798 Mon Sep 17 00:00:00 2001
From: Olmo Maldonado
Date: Wed, 11 Dec 2024 21:48:10 -0600
Subject: [PATCH] feature: configurable (OpenAI) client

All public interfaces that work with OpenAI, directly or transitively, now
accept an optional AutoEvalClient client option. If not provided, we prepare
and handle the OpenAI client as we do today. At the moment you need to pass
the client in each call; a follow-up commit will let you set it globally.
---
 py/autoevals/llm.py        | 13 ++-
 py/autoevals/moderation.py | 23 ++++-
 py/autoevals/oai.py        | 181 ++++++++++++++++++++++++-------------
 py/autoevals/ragas.py      | 172 +++++++++++++++++++++++------------
 py/autoevals/string.py     | 23 ++++-
 5 files changed, 280 insertions(+), 132 deletions(-)

diff --git a/py/autoevals/llm.py b/py/autoevals/llm.py
index a3f3a27..bca51c4 100644
--- a/py/autoevals/llm.py
+++ b/py/autoevals/llm.py
@@ -1,4 +1,3 @@
-import abc
 import json
 import os
 import re
@@ -11,7 +10,7 @@
 
 from autoevals.partial import ScorerWithPartial
 
-from .oai import arun_cached_request, run_cached_request
+from .oai import AutoEvalClient, arun_cached_request, run_cached_request
 
 SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
 
@@ -79,6 +78,7 @@ def __init__(
         self,
         api_key=None,
         base_url=None,
+        client: Optional[AutoEvalClient] = None,
     ):
         self.extra_args = {}
         if api_key:
@@ -86,6 +86,8 @@ def __init__(
         if base_url:
             self.extra_args["base_url"] = base_url
 
+        self.client = client
+
 
 class OpenAILLMScorer(OpenAIScorer):
     def __init__(
@@ -93,10 +95,12 @@ def __init__(
         temperature=None,
         api_key=None,
         base_url=None,
+        client: Optional[AutoEvalClient] = None,
     ):
         super().__init__(
             api_key=api_key,
             base_url=base_url,
+            client=client,
         )
         self.extra_args["temperature"] = temperature or 0
 
@@ -115,8 +119,10 @@ def __init__(
         engine=None,
         api_key=None,
         base_url=None,
+        client: Optional[AutoEvalClient] = None,
     ):
         super().__init__(
+            client=client,
             api_key=api_key,
             base_url=base_url,
         )
@@ -162,6 +168,7 @@ def _render_messages(self, **kwargs):
 
     def _request_args(self, output, expected, **kwargs):
         ret = {
+            "client": self.client,
             **self.extra_args,
             **self._build_args(output, expected, **kwargs),
         }
@@ -233,6 +240,7 @@ def __init__(
         engine=None,
         api_key=None,
         base_url=None,
+        client: Optional[AutoEvalClient] = None,
         **extra_render_args,
     ):
         choice_strings = list(choice_scores.keys())
@@ -257,6 +265,7 @@ def __init__(
             api_key=api_key,
             base_url=base_url,
             render_args={"__choices": choice_strings, **extra_render_args},
+            client=client,
         )
 
     @classmethod
diff --git a/py/autoevals/moderation.py b/py/autoevals/moderation.py
index 4164e13..a2efa30 100644
--- a/py/autoevals/moderation.py
+++ b/py/autoevals/moderation.py
@@ -1,8 +1,10 @@
+from typing import Optional
+
 from braintrust_core.score import Score
 
 from autoevals.llm import OpenAIScorer
 
-from .oai import arun_cached_request, run_cached_request
+from .oai import AutoEvalClient, arun_cached_request, run_cached_request
 
 REQUEST_TYPE = "moderation"
 
@@ -15,7 +17,13 @@ class Moderation(OpenAIScorer):
     threshold = None
     extra_args = {}
 
-    def __init__(self, threshold=None, api_key=None, base_url=None):
+    def __init__(
+        self,
+        threshold=None,
+        api_key=None,
+        base_url=None,
+        client: Optional[AutoEvalClient] = None,
+    ):
         """
         Create a new Moderation scorer.
 
         :param threshold: Optional. Threshold to use to determine whether content has exceeded content safety.
         :param api_key: OpenAI key
         :param base_url: Base URL to be used to reach OpenAI moderation endpoint.
""" - super().__init__(api_key=api_key, base_url=base_url) + super().__init__(api_key=api_key, base_url=base_url, client=client) self.threshold = threshold + # need to check who calls _run_eval_a?sync def _run_eval_sync(self, output, __expected=None): - moderation_response = run_cached_request(REQUEST_TYPE, input=output, **self.extra_args)["results"][0] + moderation_response = run_cached_request( + client=self.client, request_type=REQUEST_TYPE, input=output, **self.extra_args + )["results"][0] return self.__postprocess_response(moderation_response) def __postprocess_response(self, moderation_response) -> Score: @@ -42,7 +53,9 @@ def __postprocess_response(self, moderation_response) -> Score: ) async def _run_eval_async(self, output, expected=None, **kwargs) -> Score: - moderation_response = (await arun_cached_request(REQUEST_TYPE, input=output, **self.extra_args))["results"][0] + moderation_response = ( + await arun_cached_request(client=self.client, request_type=REQUEST_TYPE, input=output, **self.extra_args) + )["results"][0] return self.__postprocess_response(moderation_response) @staticmethod diff --git a/py/autoevals/oai.py b/py/autoevals/oai.py index 9141945..ec90a5c 100644 --- a/py/autoevals/oai.py +++ b/py/autoevals/oai.py @@ -5,93 +5,146 @@ import time from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, Optional PROXY_URL = "https://api.braintrust.dev/v1/proxy" @dataclass -class OpenAIWrapper: +class AutoEvalClient: + # TODO: add docs + # TODO: how to type if we don't depend on openai + openai: Any complete: Any embed: Any moderation: Any RateLimitError: Exception -def prepare_openai(is_async=False, api_key=None, base_url=None): - if api_key is None: - api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("BRAINTRUST_API_KEY") - if base_url is None: - base_url = os.environ.get("OPENAI_BASE_URL", PROXY_URL) +def prepare_openai(client: Optional[AutoEvalClient] = None, is_async=False, api_key=None, base_url=None): + """Prepares and configures an OpenAI client for use with AutoEval, if client is not provided. - try: - import openai - except Exception as e: - print( - textwrap.dedent( - f"""\ - Unable to import openai: {e} - - Please install it, e.g. with - - pip install 'openai' - """ - ), - file=sys.stderr, - ) - raise + This function handles both v0 and v1 of the OpenAI SDK, configuring the client + with the appropriate authentication and base URL settings. + + We will also attempt to enable Braintrust tracing export, if you've configured tracing. + + Args: + client (Optional[AutoEvalClient], optional): Existing AutoEvalClient instance. + If provided, this client will be used instead of creating a new one. + + is_async (bool, optional): Whether to create a client with async operations. Defaults to False. + Deprecated: Use the `client` argument and set the `openai` with the async/sync that you'd like to use. + + api_key (str, optional): OpenAI API key. If not provided, will look for + OPENAI_API_KEY or BRAINTRUST_API_KEY in environment variables. + + Deprecated: Use the `client` argument and set the `openai`. + + base_url (str, optional): Base URL for API requests. If not provided, will + use OPENAI_BASE_URL from environment or fall back to PROXY_URL. + + Deprecated: Use the `client` argument and set the `openai`. 
+ + Returns: + Tuple[AutoEvalClient, bool]: A tuple containing: + - The configured AutoEvalClient instance, or the client you've provided + - A boolean indicating whether the client was wrapped with Braintrust tracing + + Raises: + ImportError: If the OpenAI package is not installed + """ + openai = getattr(client, "openai", None) + if not openai: + try: + import openai + except Exception as e: + print( + textwrap.dedent( + f"""\ + Unable to import openai: {e} + + Please install it, e.g. with + + pip install 'openai' + """ + ), + file=sys.stderr, + ) + raise openai_obj = openai + is_v1 = False if hasattr(openai, "OpenAI"): # This is the new v1 API is_v1 = True - if is_async: - openai_obj = openai.AsyncOpenAI(api_key=api_key, base_url=base_url) + + if client is None: + # prepare the default openai sdk, if not provided + if api_key is None: + api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("BRAINTRUST_API_KEY") + if base_url is None: + base_url = os.environ.get("OPENAI_BASE_URL", PROXY_URL) + + if is_v1: + if is_async: + openai_obj = openai.AsyncOpenAI(api_key=api_key, base_url=base_url) + else: + openai_obj = openai.OpenAI(api_key=api_key, base_url=base_url) else: - openai_obj = openai.OpenAI(api_key=api_key, base_url=base_url) - else: - if api_key: - openai.api_key = api_key - openai.api_base = base_url + if api_key: + openai.api_key = api_key + openai.api_base = base_url + # optimistically wrap openai instance for tracing wrapped = False try: - from braintrust.oai import wrap_openai + from braintrust.oai import NamedWrapper, wrap_openai + + if not isinstance(openai_obj, NamedWrapper): + openai_obj = wrap_openai(openai_obj) - openai_obj = wrap_openai(openai_obj) wrapped = True except ImportError: pass - complete_fn = None - rate_limit_error = None - if is_v1: - wrapper = OpenAIWrapper( - complete=openai_obj.chat.completions.create, - embed=openai_obj.embeddings.create, - moderation=openai_obj.moderations.create, - RateLimitError=openai.RateLimitError, - ) - else: - rate_limit_error = openai.error.RateLimitError - if is_async: - complete_fn = openai_obj.ChatCompletion.acreate - embedding_fn = openai_obj.Embedding.acreate - moderation_fn = openai_obj.Moderations.acreate + if client is None: + # prepare the default client if not provided + complete_fn = None + rate_limit_error = None + + # TODO: allow overriding globally + Client = AutoEvalClient + + if is_v1: + client = Client( + openai=openai, + complete=openai_obj.chat.completions.create, + embed=openai_obj.embeddings.create, + moderation=openai_obj.moderations.create, + RateLimitError=openai.RateLimitError, + ) else: - complete_fn = openai_obj.ChatCompletion.create - embedding_fn = openai_obj.Embedding.create - moderation_fn = openai_obj.Moderations.create - wrapper = OpenAIWrapper( - complete=complete_fn, - embed=embedding_fn, - moderation=moderation_fn, - RateLimitError=rate_limit_error, - ) - - return wrapper, wrapped + rate_limit_error = openai.error.RateLimitError + if is_async: + complete_fn = openai_obj.ChatCompletion.acreate + embedding_fn = openai_obj.Embedding.acreate + moderation_fn = openai_obj.Moderations.acreate + else: + complete_fn = openai_obj.ChatCompletion.create + embedding_fn = openai_obj.Embedding.create + moderation_fn = openai_obj.Moderations.create + client = Client( + openai=openai, + complete=complete_fn, + embed=embedding_fn, + moderation=moderation_fn, + RateLimitError=rate_limit_error, + ) + + return client, wrapped def post_process_response(resp): @@ -108,8 +161,10 @@ def 
set_span_purpose(kwargs): kwargs.setdefault("span_info", {}).setdefault("span_attributes", {})["purpose"] = "scorer" -def run_cached_request(request_type="complete", api_key=None, base_url=None, **kwargs): - wrapper, wrapped = prepare_openai(is_async=False, api_key=api_key, base_url=base_url) +def run_cached_request( + *, client: Optional[AutoEvalClient] = None, request_type="complete", api_key=None, base_url=None, **kwargs +): + wrapper, wrapped = prepare_openai(client=client, is_async=False, api_key=api_key, base_url=base_url) if wrapped: set_span_purpose(kwargs) @@ -127,8 +182,10 @@ def run_cached_request(request_type="complete", api_key=None, base_url=None, **k return resp -async def arun_cached_request(request_type="complete", api_key=None, base_url=None, **kwargs): - wrapper, wrapped = prepare_openai(is_async=True, api_key=api_key, base_url=base_url) +async def arun_cached_request( + *, client: Optional[AutoEvalClient] = None, request_type="complete", api_key=None, base_url=None, **kwargs +): + wrapper, wrapped = prepare_openai(client=client, is_async=True, api_key=api_key, base_url=base_url) if wrapped: set_span_purpose(kwargs) diff --git a/py/autoevals/ragas.py b/py/autoevals/ragas.py index 6d5a820..20ab0af 100644 --- a/py/autoevals/ragas.py +++ b/py/autoevals/ragas.py @@ -2,13 +2,14 @@ import asyncio import json +from typing import Optional import chevron from . import Score from .list import ListContains from .llm import OpenAILLMScorer -from .oai import arun_cached_request, run_cached_request +from .oai import AutoEvalClient, arun_cached_request, run_cached_request from .string import EmbeddingSimilarity @@ -76,13 +77,13 @@ def extract_entities_request(text, **extra_args): ) -async def aextract_entities(text, **extra_args): - response = await arun_cached_request(**extract_entities_request(text=text, **extra_args)) +async def aextract_entities(*, text, client: Optional[AutoEvalClient] = None, **extra_args): + response = await arun_cached_request(client=client, **extract_entities_request(text=text, **extra_args)) return json.loads(response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"]) -def extract_entities(text, **extra_args): - response = run_cached_request(**extract_entities_request(text=text, **extra_args)) +def extract_entities(*, text, client: Optional[AutoEvalClient] = None, **extra_args): + response = run_cached_request(client=client, **extract_entities_request(text=text, **extra_args)) return json.loads(response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"]) @@ -92,12 +93,14 @@ class ContextEntityRecall(OpenAILLMScorer): retrieved context. 
""" - def __init__(self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, **kwargs): - super().__init__(**kwargs) + def __init__( + self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, client: Optional[AutoEvalClient] = None, **kwargs + ): + super().__init__(client=client, **kwargs) self.extraction_model = model self.contains_scorer = ListContains( - pairwise_scorer=pairwise_scorer or EmbeddingSimilarity(), allow_extra_entities=True + pairwise_scorer=pairwise_scorer or EmbeddingSimilarity(client=client), allow_extra_entities=True ) async def _run_eval_async(self, output, expected=None, context=None, **kwargs): @@ -106,8 +109,8 @@ async def _run_eval_async(self, output, expected=None, context=None, **kwargs): context = "\n".join(context) if isinstance(context, list) else context expected_entities_future, context_entities_future = ( - aextract_entities(text=expected, model=self.extraction_model, **self.extra_args), - aextract_entities(text=context, model=self.extraction_model, **self.extra_args), + aextract_entities(client=self.client, text=expected, model=self.extraction_model, **self.extra_args), + aextract_entities(client=self.client, text=context, model=self.extraction_model, **self.extra_args), ) expected_entities = [e for e in (await expected_entities_future)["entities"]] @@ -127,10 +130,16 @@ def _run_eval_sync(self, output, expected=None, context=None, **kwargs): context = "\n".join(context) if isinstance(context, list) else context expected_entities = [ - e for e in (extract_entities(text=expected, model=self.extraction_model, **self.extra_args))["entities"] + e + for e in ( + extract_entities(client=self.client, text=expected, model=self.extraction_model, **self.extra_args) + )["entities"] ] context_entities = [ - e for e in (extract_entities(text=context, model=self.extraction_model, **self.extra_args))["entities"] + e + for e in ( + extract_entities(client=self.client, text=context, model=self.extraction_model, **self.extra_args) + )["entities"] ] score = self.contains_scorer.eval(output=context_entities, expected=expected_entities) @@ -208,8 +217,10 @@ class ContextRelevancy(OpenAILLMScorer): self-consistency checks. The number of relevant sentences and is used as the score. """ - def __init__(self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, **kwargs): - super().__init__(**kwargs) + def __init__( + self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, client: Optional[AutoEvalClient] = None, **kwargs + ): + super().__init__(client=client, **kwargs) self.model = model @@ -234,7 +245,8 @@ async def _run_eval_async(self, output, expected=None, input=None, context=None, return self._postprocess( context, await arun_cached_request( - **extract_sentences_request(question=input, context=context, model=self.model, **self.extra_args) + client=self.client, + **extract_sentences_request(question=input, context=context, model=self.model, **self.extra_args), ), ) @@ -247,7 +259,8 @@ def _run_eval_sync(self, output, expected=None, input=None, context=None, **kwar return self._postprocess( context, run_cached_request( - **extract_sentences_request(question=input, context=context, model=self.model, **self.extra_args) + client=self.client, + **extract_sentences_request(question=input, context=context, model=self.model, **self.extra_args), ), ) @@ -342,8 +355,10 @@ class ContextRecall(OpenAILLMScorer): retrieved context. 
""" - def __init__(self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, **kwargs): - super().__init__(**kwargs) + def __init__( + self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, client: Optional[AutoEvalClient] = None, **kwargs + ): + super().__init__(client=client, **kwargs) self.model = model @@ -369,9 +384,10 @@ async def _run_eval_async(self, output, expected=None, input=None, context=None, return self._postprocess( await arun_cached_request( + client=self.client, **extract_context_recall_request( question=input, answer=expected, context=context, model=self.model, **self.extra_args - ) + ), ) ) @@ -383,9 +399,10 @@ def _run_eval_sync(self, output, expected=None, input=None, context=None, **kwar return self._postprocess( run_cached_request( + client=self.client, **extract_context_recall_request( question=input, answer=expected, context=context, model=self.model, **self.extra_args - ) + ), ) ) @@ -475,8 +492,10 @@ class ContextPrecision(OpenAILLMScorer): relevant items selected by the model are ranked higher or not. """ - def __init__(self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, **kwargs): - super().__init__(**kwargs) + def __init__( + self, pairwise_scorer=None, model=DEFAULT_RAGAS_MODEL, client: Optional[AutoEvalClient] = None, **kwargs + ): + super().__init__(client=client, **kwargs) self.model = model @@ -499,9 +518,10 @@ async def _run_eval_async(self, output, expected=None, input=None, context=None, return self._postprocess( await arun_cached_request( + client=self.client, **extract_context_precision_request( question=input, answer=expected, context=context, model=self.model, **self.extra_args - ) + ), ) ) @@ -513,9 +533,10 @@ def _run_eval_sync(self, output, expected=None, input=None, context=None, **kwar return self._postprocess( run_cached_request( + client=self.client, **extract_context_precision_request( question=input, answer=expected, context=context, model=self.model, **self.extra_args - ) + ), ) ) @@ -679,25 +700,31 @@ def extract_faithfulness_request(context, statements, **extra_args): ) -async def aextract_statements(question, answer, **extra_args): - response = await arun_cached_request(**extract_statements_request(question=question, answer=answer, **extra_args)) +async def aextract_statements(question, answer, client: Optional[AutoEvalClient] = None, **extra_args): + response = await arun_cached_request( + client=client, **extract_statements_request(question=question, answer=answer, **extra_args) + ) return load_function_call(response) -def extract_statements(question, answer, **extra_args): - response = run_cached_request(**extract_statements_request(question=question, answer=answer, **extra_args)) +def extract_statements(question, answer, client: Optional[AutoEvalClient] = None, **extra_args): + response = run_cached_request( + client=client, **extract_statements_request(question=question, answer=answer, **extra_args) + ) return load_function_call(response) -async def aextract_faithfulness(context, statements, **extra_args): +async def aextract_faithfulness(context, statements, client: Optional[AutoEvalClient] = None, **extra_args): response = await arun_cached_request( - **extract_faithfulness_request(context=context, statements=statements, **extra_args) + client=client, **extract_faithfulness_request(context=context, statements=statements, **extra_args) ) return load_function_call(response) -def extract_faithfulness(context, statements, **extra_args): - response = run_cached_request(**extract_faithfulness_request(context=context, 
statements=statements, **extra_args)) +def extract_faithfulness(context, statements, client: Optional[AutoEvalClient] = None, **extra_args): + response = run_cached_request( + client=client, **extract_faithfulness_request(context=context, statements=statements, **extra_args) + ) return load_function_call(response) @@ -706,20 +733,24 @@ class Faithfulness(OpenAILLMScorer): Measures factual consistency of a generated answer against the given context. """ - def __init__(self, model=DEFAULT_RAGAS_MODEL, **kwargs): - super().__init__(**kwargs) + def __init__(self, model=DEFAULT_RAGAS_MODEL, client: Optional[AutoEvalClient] = None, **kwargs): + super().__init__(client=client, **kwargs) self.model = model async def _run_eval_async(self, output, expected=None, input=None, context=None, **kwargs): check_required("Faithfulness", input=input, output=output, context=context) - statements = (await aextract_statements(question=input, answer=expected, model=self.model, **self.extra_args))[ - "statements" - ] + statements = ( + await aextract_statements( + client=self.client, question=input, answer=expected, model=self.model, **self.extra_args + ) + )["statements"] faithfulness = ( - await aextract_faithfulness(context=context, statements=statements, model=self.model, **self.extra_args) + await aextract_faithfulness( + client=self.client, context=context, statements=statements, model=self.model, **self.extra_args + ) )["faithfulness"] return Score( @@ -734,12 +765,16 @@ async def _run_eval_async(self, output, expected=None, input=None, context=None, def _run_eval_sync(self, output, expected=None, input=None, context=None, **kwargs): check_required("Faithfulness", input=input, context=context) - statements = (extract_statements(question=input, answer=expected, model=self.model, **self.extra_args))[ - "statements" - ] + statements = ( + extract_statements( + client=self.client, question=input, answer=expected, model=self.model, **self.extra_args + ) + )["statements"] faithfulness = ( - extract_faithfulness(context=context, statements=statements, model=self.model, **self.extra_args) + extract_faithfulness( + client=self.client, context=context, statements=statements, model=self.model, **self.extra_args + ) )["faithfulness"] return Score( @@ -837,9 +872,10 @@ def __init__( strictness=3, temperature=0.5, embedding_model=DEFAULT_RAGAS_EMBEDDING_MODEL, + client: Optional[AutoEvalClient] = None, **kwargs, ): - super().__init__(temperature=temperature, **kwargs) + super().__init__(temperature=temperature, client=client, **kwargs) self.model = model self.strictness = strictness @@ -868,14 +904,19 @@ async def _run_eval_async(self, output, expected=None, input=None, context=None, questions = await asyncio.gather( *[ aload_function_call_request( - **extract_question_gen_request(answer=output, context=context, model=self.model, **self.extra_args) + client=self.client, + **extract_question_gen_request( + answer=output, context=context, model=self.model, **self.extra_args + ), ) for _ in range(self.strictness) ] ) similarity = await asyncio.gather( *[ - EmbeddingSimilarity().eval_async(output=q["question"], expected=input, model=self.embedding_model) + EmbeddingSimilarity(client=self.client).eval_async( + output=q["question"], expected=input, model=self.embedding_model + ) for q in questions ] ) @@ -887,12 +928,14 @@ def _run_eval_sync(self, output, expected=None, input=None, context=None, **kwar questions = [ load_function_call_request( - **extract_question_gen_request(answer=output, context=context, model=self.model, 
**self.extra_args) + client=self.client, + **extract_question_gen_request(answer=output, context=context, model=self.model, **self.extra_args), ) for _ in range(self.strictness) ] similarity = [ - EmbeddingSimilarity().eval(output=q["question"], expected=input, model=self.model) for q in questions + EmbeddingSimilarity(client=self.client).eval(output=q["question"], expected=input, model=self.model) + for q in questions ] return self._postprocess(questions, similarity) @@ -903,22 +946,30 @@ class AnswerSimilarity(OpenAILLMScorer): Measures the similarity between the generated answer and the expected answer. """ - def __init__(self, pairwise_scorer=None, model=DEFAULT_RAGAS_EMBEDDING_MODEL, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + pairwise_scorer=None, + model=DEFAULT_RAGAS_EMBEDDING_MODEL, + client: Optional[AutoEvalClient] = None, + **kwargs, + ): + super().__init__(client=client, **kwargs) self.model = model async def _run_eval_async(self, output, expected=None, input=None, **kwargs): check_required("AnswerSimilarity", expected=expected, output=output) - return await EmbeddingSimilarity().eval_async( + return await EmbeddingSimilarity(client=self.client).eval_async( output=output, expected=expected, model=self.model, **self.extra_args ) def _run_eval_sync(self, output, expected=None, input=None, **kwargs): check_required("AnswerSimilarity", expected=expected, output=output) - return EmbeddingSimilarity().eval(output=output, expected=expected, model=self.model, **self.extra_args) + return EmbeddingSimilarity(client=self.client).eval( + output=output, expected=expected, model=self.model, **self.extra_args + ) CORRECTNESS_PROMPT = """Given a ground truth and an answer, analyze each statement in the answer and classify them in one of the following categories: @@ -1015,12 +1066,13 @@ def __init__( factuality_weight=0.75, answer_similarity_weight=0.25, answer_similarity=None, + client: Optional[AutoEvalClient] = None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(client=client, **kwargs) self.model = model - self.answer_similarity = answer_similarity or AnswerSimilarity() + self.answer_similarity = answer_similarity or AnswerSimilarity(client=client) if factuality_weight == 0 and answer_similarity_weight == 0: raise ValueError("At least one weight must be nonzero") @@ -1065,9 +1117,10 @@ async def _run_eval_async(self, output, expected=None, input=None, **kwargs): factuality_future, similarity_future = ( aload_function_call_request( + client=self.client, **extract_correctness_request( question=input, answer=output, ground_truth=expected, model=self.model, **self.extra_args - ) + ), ), self._run_answer_similarity_async(output, expected), ) @@ -1079,9 +1132,10 @@ def _run_eval_sync(self, output, expected=None, input=None, **kwargs): factuality, similarity = ( load_function_call_request( + client=self.client, **extract_correctness_request( question=input, answer=output, ground_truth=expected, model=self.model, **self.extra_args - ) + ), ), self._run_answer_similarity_sync(output, expected), ) @@ -1093,9 +1147,9 @@ def load_function_call(response): return json.loads(response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"]) -async def aload_function_call_request(**kwargs): - return load_function_call(await arun_cached_request(**kwargs)) +async def aload_function_call_request(client: Optional[AutoEvalClient] = None, **kwargs): + return load_function_call(await arun_cached_request(client=client, **kwargs)) -def 
load_function_call_request(**kwargs): - return load_function_call(run_cached_request(**kwargs)) +def load_function_call_request(client: Optional[AutoEvalClient] = None, **kwargs): + return load_function_call(run_cached_request(client=client, **kwargs)) diff --git a/py/autoevals/string.py b/py/autoevals/string.py index e078969..dced7a0 100644 --- a/py/autoevals/string.py +++ b/py/autoevals/string.py @@ -1,4 +1,5 @@ import threading +from typing import Optional from braintrust_core.score import Score from Levenshtein import distance @@ -6,7 +7,7 @@ from autoevals.partial import ScorerWithPartial from autoevals.value import normalize_value -from .oai import arun_cached_request, run_cached_request +from .oai import AutoEvalClient, arun_cached_request, run_cached_request class Levenshtein(ScorerWithPartial): @@ -41,7 +42,15 @@ class EmbeddingSimilarity(ScorerWithPartial): _CACHE = {} _CACHE_LOCK = threading.Lock() - def __init__(self, prefix="", model=MODEL, expected_min=0.7, api_key=None, base_url=None): + def __init__( + self, + prefix="", + model=MODEL, + expected_min=0.7, + api_key=None, + base_url=None, + client: Optional[AutoEvalClient] = None, + ): """ Create a new EmbeddingSimilarity scorer. @@ -59,13 +68,17 @@ def __init__(self, prefix="", model=MODEL, expected_min=0.7, api_key=None, base_ if base_url: self.extra_args["base_url"] = base_url + self.client = client + async def _a_embed(self, value): value = normalize_value(value, maybe_object=False) with self._CACHE_LOCK: if value in self._CACHE: return self._CACHE[value] - result = await arun_cached_request("embed", input=f"{self.prefix}{value}", **self.extra_args) + result = await arun_cached_request( + client=self.client, request_type="embed", input=f"{self.prefix}{value}", **self.extra_args + ) with self._CACHE_LOCK: self._CACHE[value] = result @@ -78,7 +91,9 @@ def _embed(self, value): if value in self._CACHE: return self._CACHE[value] - result = run_cached_request("embed", input=f"{self.prefix}{value}", **self.extra_args) + result = run_cached_request( + client=self.client, request_type="embed", input=f"{self.prefix}{value}", **self.extra_args + ) with self._CACHE_LOCK: self._CACHE[value] = result
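
Usage sketch (not part of the diff above): the snippet below shows how the new
client option is meant to be passed, based only on the interfaces introduced in
this patch. It assumes the OpenAI v1 SDK; EmbeddingSimilarity is used purely as
an example scorer, and the API key handling is a placeholder.

    import os

    import openai

    from autoevals.oai import AutoEvalClient
    from autoevals.string import EmbeddingSimilarity

    # Build the OpenAI client yourself instead of letting prepare_openai()
    # construct one from OPENAI_API_KEY/OPENAI_BASE_URL (or the Braintrust proxy).
    openai_obj = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    # Mirror the fields prepare_openai() fills in for its default v1 client.
    client = AutoEvalClient(
        openai=openai_obj,  # per the new docstring: the sync/async client you want to use
        complete=openai_obj.chat.completions.create,
        embed=openai_obj.embeddings.create,
        moderation=openai_obj.moderations.create,
        RateLimitError=openai.RateLimitError,
    )

    # Until the follow-up commit adds a global setter, the client is passed per scorer/call.
    score = EmbeddingSimilarity(client=client).eval(output="Paris", expected="Paris, France")
    print(score.score)

Passing a client this way bypasses the environment-variable and proxy defaults
in prepare_openai(), which is the point of the change; per the commit message, a
follow-up will allow setting the client globally so it does not have to be
threaded through every call.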