-
-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(AI): Create abstractions for generic LLM calls within sentry (#6…
…8771) ### Background For the User Feedback #61372 Spam Detection feature, we intend to use an LLM. Along with other use cases such as suggested fix, and code integration, there is a need across the Sentry codebase to be able to call LLMs. Because Sentry is self hosted, some features use different LLMs and we want to provide modularity, we need to be able to configure different LLM providers, models, and usecases. ### Solution: - We define an options based config for Providers and use cases, where you can specify a LLM provider's options and settings. - For each use case, you then define what LLM provider it uses, and what model. Within `sentry/llm` we define a `providers` module, which consists of a base and implementations. To start, we have OpenAI, Google Vertex, and a Preview implementation used for testing. These will use the provider options to initialize a client and connect to the LLM provider. The providers inherit from `LazyServiceWrapper`. Also within `sentry/llm`, we define a `usecases` module, which simply consists of a function `complete_prompt`, along with an enum of use cases. These options are passed to the LLM provider per use case, and can be configured via the above option. ### Testing I've added unit tests which mock the LLM calls, and I've tested in my local environment that calls to the actual services work. ### In practice: So to use an LLM, you do the following steps: 1. define your usecase in the [usecase enum](https://github.com/getsentry/sentry/blob/a4e7a0e4af8c09a1d4007a3d7c53b71a2d4db5ff/src/sentry/llm/usecases/__init__.py#L14) 2. Call the `complete_prompt` function with your `usecase`, prompt, content, temperature, and max_tokens) ### Limitations: Because each LLM right now has a different interface, some things that are specific, say to OpenAI like "function calling", where an output is guaranteed to be in a specific JSON format, this solution does not currently support. Advanced usecases beyond simple "prompt" + "text" and a singe output, are not currently supported. It is likely possible to add support for these on a case by case basis. LLM providers are not quite to the point where they have standardized on a consistent API, which makes supporting these somewhat difficult. Third parties have come up with various solutions [LangChain](https://github.com/langchain-ai/langchain), [LiteLLM](https://github.com/BerriAI/litellm), [LocalAI](https://github.com/mudler/LocalAI), [OpenRouter](https://openrouter.ai/). It will probably make sense eventually to adopt one of these tools, or our own advanced tooling, once our use cases outgrow this solution. There is also a possible future where we want different use cases to use different API keys, but for now, one provider only has one set of credentials. ### TODO - [ ] create develop docs for how to add a usecase, or new LLM provider - [x] Follow up PR to replace suggested fix openai calls with new abstraction - [ ] PR in getsentry to set provider / usecase values for SaaS - [ ] PR followup to add telemetry information - [ ] We'll likely want to support streaming responses. --------- Co-authored-by: Michelle Zhang <[email protected]> Co-authored-by: getsantry[bot] <66042841+getsantry[bot]@users.noreply.github.com>
- Loading branch information
1 parent
729b942
commit 32f7e6f
Showing
21 changed files
with
540 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
class InvalidUsecaseError(ValueError): | ||
pass | ||
|
||
|
||
class InvalidProviderError(ValueError): | ||
pass | ||
|
||
|
||
class InvalidModelError(ValueError): | ||
pass | ||
|
||
|
||
class InvalidTemperature(ValueError): | ||
pass |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from sentry.llm.exceptions import InvalidModelError, InvalidProviderError | ||
from sentry.llm.types import ProviderConfig, UseCaseConfig | ||
from sentry.utils.services import Service | ||
|
||
|
||
class LlmModelBase(Service): | ||
def __init__(self, provider_config: ProviderConfig) -> None: | ||
self.provider_config = provider_config | ||
|
||
def complete_prompt( | ||
self, | ||
*, | ||
usecase_config: UseCaseConfig, | ||
prompt: str, | ||
message: str, | ||
temperature: float, | ||
max_output_tokens: int, | ||
) -> str | None: | ||
self.validate_model(usecase_config["options"]["model"]) | ||
|
||
return self._complete_prompt( | ||
usecase_config=usecase_config, | ||
prompt=prompt, | ||
message=message, | ||
temperature=temperature, | ||
max_output_tokens=max_output_tokens, | ||
) | ||
|
||
def _complete_prompt( | ||
self, | ||
*, | ||
usecase_config: UseCaseConfig, | ||
prompt: str, | ||
message: str, | ||
temperature: float, | ||
max_output_tokens: int, | ||
) -> str | None: | ||
raise NotImplementedError | ||
|
||
def validate_model(self, model_name: str) -> None: | ||
if "models" not in self.provider_config: | ||
raise InvalidProviderError(f"No models defined for provider {self.__class__.__name__}") | ||
|
||
if model_name not in self.provider_config["models"]: | ||
raise InvalidModelError(f"Invalid model: {model_name}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from openai import OpenAI | ||
|
||
from sentry.llm.providers.base import LlmModelBase | ||
from sentry.llm.types import UseCaseConfig | ||
|
||
|
||
class OpenAIProvider(LlmModelBase): | ||
|
||
provider_name = "openai" | ||
|
||
def _complete_prompt( | ||
self, | ||
*, | ||
usecase_config: UseCaseConfig, | ||
prompt: str, | ||
message: str, | ||
temperature: float, | ||
max_output_tokens: int, | ||
) -> str | None: | ||
model = usecase_config["options"]["model"] | ||
client = get_openai_client(self.provider_config["options"]["api_key"]) | ||
|
||
response = client.chat.completions.create( | ||
model=model, | ||
temperature=temperature | ||
* 2, # open AI temp range is [0.0 - 2.0], so we have to multiply by two | ||
messages=[ | ||
{"role": "system", "content": prompt}, | ||
{ | ||
"role": "user", | ||
"content": message, | ||
}, | ||
], | ||
stream=False, | ||
max_tokens=max_output_tokens, | ||
) | ||
|
||
return response.choices[0].message.content | ||
|
||
|
||
openai_client: OpenAI | None = None | ||
|
||
|
||
class OpenAIClientSingleton: | ||
_instance = None | ||
client: OpenAI | ||
|
||
def __init__(self) -> None: | ||
raise RuntimeError("Call instance() instead") | ||
|
||
@classmethod | ||
def instance(cls, api_key: str) -> "OpenAIClientSingleton": | ||
if cls._instance is None: | ||
cls._instance = cls.__new__(cls) | ||
cls._instance.client = OpenAI(api_key=api_key) | ||
return cls._instance | ||
|
||
|
||
def get_openai_client(api_key: str) -> OpenAI: | ||
return OpenAIClientSingleton.instance(api_key).client |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from sentry.llm.providers.base import LlmModelBase | ||
from sentry.llm.types import UseCaseConfig | ||
|
||
|
||
class PreviewLLM(LlmModelBase): | ||
""" | ||
A dummy LLM provider that does not actually send any requests to any LLM API. | ||
""" | ||
|
||
provider_name = "preview" | ||
|
||
def _complete_prompt( | ||
self, | ||
*, | ||
usecase_config: UseCaseConfig, | ||
prompt: str, | ||
message: str, | ||
temperature: float = 0.7, | ||
max_output_tokens: int = 1000, | ||
) -> str | None: | ||
return "" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import logging | ||
|
||
import google.auth | ||
import google.auth.transport.requests | ||
import requests | ||
|
||
from sentry.llm.providers.base import LlmModelBase | ||
from sentry.llm.types import UseCaseConfig | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class VertexProvider(LlmModelBase): | ||
""" | ||
A provider for Google Vertex AI. Uses default service account credentials. | ||
""" | ||
|
||
provider_name = "vertex" | ||
candidate_count = 1 # we only want one candidate returned at the moment | ||
top_p = 1 # TODO: make this configurable? | ||
|
||
def _complete_prompt( | ||
self, | ||
*, | ||
usecase_config: UseCaseConfig, | ||
prompt: str, | ||
message: str, | ||
temperature: float, | ||
max_output_tokens: int, | ||
) -> str | None: | ||
|
||
payload = { | ||
"instances": [{"content": f"{prompt} {message}"}], | ||
"parameters": { | ||
"candidateCount": self.candidate_count, | ||
"maxOutputTokens": max_output_tokens, | ||
"temperature": temperature, | ||
"topP": self.top_p, | ||
}, | ||
} | ||
|
||
headers = { | ||
"Authorization": f"Bearer {self._get_access_token()}", | ||
"Content-Type": "application/json", | ||
} | ||
vertex_url = self.provider_config["options"]["url"] | ||
vertex_url += usecase_config["options"]["model"] + ":predict" | ||
|
||
response = requests.post(vertex_url, headers=headers, json=payload) | ||
|
||
if response.status_code == 200: | ||
logger.info("Request successful.") | ||
else: | ||
logger.info( | ||
"Request failed with status code and response text.", | ||
extra={"status_code": response.status_code, "response_text": response.text}, | ||
) | ||
|
||
return response.json()["predictions"][0]["content"] | ||
|
||
def _get_access_token(self) -> str: | ||
# https://stackoverflow.com/questions/53472429/how-to-get-a-gcp-bearer-token-programmatically-with-python | ||
|
||
creds, _ = google.auth.default() | ||
creds.refresh(google.auth.transport.requests.Request()) | ||
return creds.token |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from typing import TypedDict | ||
|
||
|
||
class ProviderConfig(TypedDict): | ||
options: dict[str, str] | ||
models: list[str] | ||
|
||
|
||
class UseCaseConfig(TypedDict): | ||
provider: str | ||
options: dict[str, str] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from enum import Enum | ||
|
||
from sentry import options | ||
from sentry.llm.exceptions import InvalidProviderError, InvalidTemperature, InvalidUsecaseError | ||
from sentry.llm.providers.base import LlmModelBase | ||
from sentry.llm.providers.openai import OpenAIProvider | ||
from sentry.llm.providers.preview import PreviewLLM | ||
from sentry.llm.providers.vertex import VertexProvider | ||
from sentry.llm.types import ProviderConfig, UseCaseConfig | ||
|
||
SENTRY_LLM_SERVICE_ALIASES = { | ||
"vertex": VertexProvider, | ||
"openai": OpenAIProvider, | ||
"preview": PreviewLLM, | ||
} | ||
|
||
|
||
class LLMUseCase(Enum): | ||
EXAMPLE = "example" # used in tests / examples | ||
SUGGESTED_FIX = "suggestedfix" # OG version of suggested fix | ||
|
||
|
||
llm_provider_backends: dict[str, LlmModelBase] = {} | ||
|
||
|
||
def get_llm_provider_backend(usecase: LLMUseCase) -> LlmModelBase: | ||
usecase_config = get_usecase_config(usecase.value) | ||
global llm_provider_backends | ||
|
||
if usecase_config["provider"] in llm_provider_backends: | ||
return llm_provider_backends[usecase_config["provider"]] | ||
|
||
if usecase_config["provider"] not in SENTRY_LLM_SERVICE_ALIASES: | ||
raise InvalidProviderError(f"LLM provider {usecase_config['provider']} not found") | ||
|
||
provider = SENTRY_LLM_SERVICE_ALIASES[usecase_config["provider"]] | ||
|
||
provider_config = get_provider_config(usecase_config["provider"]) | ||
|
||
llm_provider_backends[usecase_config["provider"]] = provider( | ||
provider_config, | ||
) | ||
|
||
return llm_provider_backends[usecase_config["provider"]] | ||
|
||
|
||
def complete_prompt( | ||
*, | ||
usecase: LLMUseCase, | ||
prompt: str, | ||
message: str, | ||
temperature: float = 0.5, | ||
max_output_tokens: int = 1000, | ||
) -> str | None: | ||
""" | ||
Complete a prompt with a message using the specified usecase. | ||
Default temperature and max_output_tokens set to a hopefully | ||
reasonable value, but please consider what makes sense for | ||
your specific use case. | ||
Note that temperature should be between 0 and 1, and we will | ||
normalize to any providers who have a different range | ||
""" | ||
_validate_temperature(temperature) | ||
|
||
usecase_config = get_usecase_config(usecase.value) | ||
|
||
backend = get_llm_provider_backend(usecase) | ||
return backend.complete_prompt( | ||
usecase_config=usecase_config, | ||
prompt=prompt, | ||
message=message, | ||
temperature=temperature, | ||
max_output_tokens=max_output_tokens, | ||
) | ||
|
||
|
||
def get_usecase_config(usecase: str) -> UseCaseConfig: | ||
usecase_options_all = options.get("llm.usecases.options") | ||
if not usecase_options_all: | ||
raise InvalidUsecaseError( | ||
"LLM usecase options not found. please check llm.usecases.options" | ||
) | ||
|
||
if usecase not in usecase_options_all: | ||
raise InvalidUsecaseError( | ||
f"LLM usecase options not found for {usecase}. please check llm.usecases.options" | ||
) | ||
|
||
return usecase_options_all[usecase] | ||
|
||
|
||
def get_provider_config(provider: str) -> ProviderConfig: | ||
llm_provider_options_all = options.get("llm.provider.options") | ||
if not llm_provider_options_all: | ||
raise InvalidProviderError("LLM provider option value not found") | ||
if provider not in llm_provider_options_all: | ||
raise InvalidProviderError(f"LLM provider {provider} not found") | ||
return llm_provider_options_all[provider] | ||
|
||
|
||
def _validate_temperature(temperature: float) -> None: | ||
if not (0 <= temperature <= 1): | ||
raise InvalidTemperature("Temperature must be between 0 and 1") |
Oops, something went wrong.