fixes tests & move braintrust api_keys to request headers #535

Merged
3 commits merged on Nov 26, 2024
2 changes: 1 addition & 1 deletion llama_stack/distribution/request_headers.py
@@ -35,7 +35,7 @@ def get_request_provider_data(self) -> Any:
provider_data = validator(**val)
return provider_data
except Exception as e:
log.error("Error parsing provider data", e)
log.error(f"Error parsing provider data: {e}")


def set_request_provider_data(headers: Dict[str, str]):
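For context, get_request_provider_data() (partially visible above) JSON-decodes the X-LlamaStack-ProviderData header and feeds the result to the provider's registered validator. A simplified sketch of that flow; the standalone function and its exact error handling are illustrative, not the actual implementation:

```python
import json
import logging
from typing import Any, Dict, Optional

log = logging.getLogger(__name__)


def parse_provider_data(headers: Dict[str, str], validator_cls) -> Optional[Any]:
    # Illustrative stand-in for what request_headers.py does: the header value
    # is JSON-decoded and validated by the provider's pydantic validator class.
    val = headers.get("X-LlamaStack-ProviderData")
    if val is None:
        return None
    try:
        return validator_cls(**json.loads(val))
    except Exception as e:
        # Mirrors the logging fix in this diff: interpolate the exception text.
        log.error(f"Error parsing provider data: {e}")
        return None
```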
5 changes: 5 additions & 0 deletions llama_stack/providers/inline/scoring/braintrust/__init__.py
@@ -6,10 +6,15 @@
from typing import Dict

from llama_stack.distribution.datatypes import Api, ProviderSpec
from pydantic import BaseModel

from .config import BraintrustScoringConfig


class BraintrustProviderDataValidator(BaseModel):
openai_api_key: str


async def get_provider_impl(
config: BraintrustScoringConfig,
deps: Dict[Api, ProviderSpec],
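To illustrate, the new validator simply enforces that the provider-data JSON carries an openai_api_key string. A small sketch, not part of the diff:

```python
import json

from llama_stack.providers.inline.scoring.braintrust import (
    BraintrustProviderDataValidator,
)

# The X-LlamaStack-ProviderData header is expected to carry a JSON object
# with an "openai_api_key" field; anything without it fails validation.
header_value = '{"openai_api_key": "sk-example"}'
provider_data = BraintrustProviderDataValidator(**json.loads(header_value))
assert provider_data.openai_api_key == "sk-example"
```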
23 changes: 21 additions & 2 deletions llama_stack/providers/inline/scoring/braintrust/braintrust.py
@@ -12,9 +12,11 @@
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403

# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
import os

from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate

from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
@@ -24,7 +26,9 @@
from .scoring_fn.fn_defs.factuality import factuality_fn_def


class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class BraintrustScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
):
def __init__(
self,
config: BraintrustScoringConfig,
@@ -79,12 +83,25 @@ async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)

async def set_api_key(self) -> None:
# api key is in the request headers
if self.config.openai_api_key is None:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.openai_api_key:
raise ValueError(
'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": <your api key>}'
)
self.config.openai_api_key = provider_data.openai_api_key

os.environ["OPENAI_API_KEY"] = self.config.openai_api_key

async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.set_api_key()
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
@@ -105,6 +122,7 @@ async def score_batch(
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
@@ -118,6 +136,7 @@ async def score(
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
await self.set_api_key()
res = {}
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in self.supported_fn_defs_registry:
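Per the error message in set_api_key(), a caller that leaves openai_api_key out of the provider config supplies it per request via the X-LlamaStack-ProviderData header. A rough client-side sketch; the server URL, endpoint path, payload shape, and scoring-function id below are illustrative assumptions, not taken from this PR:

```python
import json

import requests

# Hypothetical llama-stack deployment; adjust host/port and path as needed.
url = "http://localhost:5000/scoring/score"

headers = {
    "Content-Type": "application/json",
    # Header format required by set_api_key() above.
    "X-LlamaStack-ProviderData": json.dumps({"openai_api_key": "sk-example"}),
}

payload = {
    # Illustrative row shape matching the columns the braintrust scorer reads.
    "input_rows": [
        {
            "input_query": "What is 2 + 2?",
            "generated_answer": "4",
            "expected_answer": "4",
        }
    ],
    "scoring_functions": ["braintrust::factuality"],  # id is an assumption
}

response = requests.post(url, headers=headers, json=payload)
print(response.json())
```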
6 changes: 5 additions & 1 deletion llama_stack/providers/inline/scoring/braintrust/config.py
@@ -6,4 +6,8 @@
from llama_stack.apis.scoring import * # noqa: F401, F403


class BraintrustScoringConfig(BaseModel): ...
class BraintrustScoringConfig(BaseModel):
openai_api_key: Optional[str] = Field(
default=None,
description="The OpenAI API Key",
)
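With the key now optional, the provider accepts either an inline key or none at all, in which case set_api_key() falls back to the request header. A short sketch:

```python
from llama_stack.providers.inline.scoring.braintrust import BraintrustScoringConfig

# Key baked into the provider config (as the test fixture below does):
cfg_inline = BraintrustScoringConfig(openai_api_key="sk-example")

# Key omitted: set_api_key() will then require it in the
# X-LlamaStack-ProviderData request header instead.
cfg_deferred = BraintrustScoringConfig()
assert cfg_deferred.openai_api_key is None
```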
1 change: 1 addition & 0 deletions llama_stack/providers/registry/scoring.py
@@ -44,5 +44,6 @@ def available_providers() -> List[ProviderSpec]:
Api.datasetio,
Api.datasets,
],
provider_data_validator="llama_stack.providers.inline.scoring.braintrust.BraintrustProviderDataValidator",
),
]
16 changes: 16 additions & 0 deletions llama_stack/providers/tests/eval/conftest.py
@@ -6,10 +6,14 @@

import pytest

from ..agents.fixtures import AGENTS_FIXTURES

from ..conftest import get_provider_fixture_overrides

from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from ..scoring.fixtures import SCORING_FIXTURES
from .fixtures import EVAL_FIXTURES

@@ -20,6 +24,9 @@
"scoring": "basic",
"datasetio": "localfs",
"inference": "fireworks",
"agents": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
},
id="meta_reference_eval_fireworks_inference",
marks=pytest.mark.meta_reference_eval_fireworks_inference,
@@ -30,6 +37,9 @@
"scoring": "basic",
"datasetio": "localfs",
"inference": "together",
"agents": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
},
id="meta_reference_eval_together_inference",
marks=pytest.mark.meta_reference_eval_together_inference,
@@ -40,6 +50,9 @@
"scoring": "basic",
"datasetio": "huggingface",
"inference": "together",
"agents": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
},
id="meta_reference_eval_together_inference_huggingface_datasetio",
marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio,
@@ -75,6 +88,9 @@ def pytest_generate_tests(metafunc):
"scoring": SCORING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
"inference": INFERENCE_FIXTURES,
"agents": AGENTS_FIXTURES,
"safety": SAFETY_FIXTURES,
"memory": MEMORY_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
20 changes: 18 additions & 2 deletions llama_stack/providers/tests/eval/fixtures.py
@@ -40,14 +40,30 @@ async def eval_stack(request):

providers = {}
provider_data = {}
for key in ["datasetio", "eval", "scoring", "inference"]:
for key in [
"datasetio",
"eval",
"scoring",
"inference",
"agents",
"safety",
"memory",
]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)

test_stack = await construct_stack_for_test(
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
[
Api.eval,
Api.datasetio,
Api.inference,
Api.scoring,
Api.agents,
Api.safety,
Api.memory,
],
providers,
provider_data,
)
7 changes: 5 additions & 2 deletions llama_stack/providers/tests/scoring/fixtures.py
@@ -10,9 +10,10 @@
from llama_stack.apis.models import ModelInput

from llama_stack.distribution.datatypes import Api, Provider

from llama_stack.providers.inline.scoring.braintrust import BraintrustScoringConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail


@pytest.fixture(scope="session")
@@ -40,7 +41,9 @@ def scoring_braintrust() -> ProviderFixture:
Provider(
provider_id="braintrust",
provider_type="inline::braintrust",
config={},
config=BraintrustScoringConfig(
openai_api_key=get_env_or_fail("OPENAI_API_KEY"),
).model_dump(),
)
],
)