Fix Vertex embedding crash and add Cohere AWS Bedrock re-ranking model support #4146

Open · wants to merge 18 commits into base: main
2 changes: 1 addition & 1 deletion backend/model_server/constants.py
@@ -6,7 +6,7 @@
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
-DEFAULT_VERTEX_MODEL = "text-embedding-004"
+DEFAULT_VERTEX_MODEL = "text-embedding-005"


class EmbeddingModelTextType:
92 changes: 79 additions & 13 deletions backend/model_server/encoders.py
@@ -5,6 +5,7 @@
from typing import cast
from typing import Optional

+import aioboto3
import httpx
import openai
import vertexai # type: ignore
@@ -28,11 +29,13 @@
from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
+from model_server.utils import pass_aws_key
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
+from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
@@ -178,17 +181,24 @@ async def _embed_vertex(
vertexai.init(project=project_id, credentials=credentials)
client = TextEmbeddingModel.from_pretrained(model)

-embeddings = await client.get_embeddings_async(
-[
-TextEmbeddingInput(
-text,
-embedding_type,
-)
-for text in texts
-],
-auto_truncate=True, # This is the default
-)
-return [embedding.values for embedding in embeddings]
+inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
+
+# Split into batches of VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE texts (25 by default)
+max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
+batches = [
+inputs[i : i + max_texts_per_batch]
+for i in range(0, len(inputs), max_texts_per_batch)
+]
+
+# Dispatch all embedding calls asynchronously at once
+tasks = [
+client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
+]
+
+# Wait for all tasks to complete in parallel
+results = await asyncio.gather(*tasks)
+
+return [embedding.values for batch in results for embedding in batch]
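For context on the fix: the old code handed every text to a single get_embeddings_async call, which on large indexing runs can exceed Vertex AI's per-request limits (presumably the crash named in the PR title); the new code fans the inputs out as smaller concurrent calls. A minimal, self-contained sketch of the same chunk-and-gather pattern, where embed_batch is a hypothetical stand-in for the Vertex client call:

import asyncio

async def embed_batch(batch: list[str]) -> list[list[float]]:
    # Hypothetical stand-in for client.get_embeddings_async(batch, ...).
    return [[float(len(text))] for text in batch]

async def embed_all(texts: list[str], batch_size: int = 25) -> list[list[float]]:
    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    # gather() preserves argument order, so flattening the per-batch results
    # keeps each vector aligned with its input text.
    results = await asyncio.gather(*(embed_batch(b) for b in batches))
    return [vec for batch in results for vec in batch]

# asyncio.run(embed_all(["some text"] * 60))  # 3 concurrent calls, 60 vectors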

async def _embed_litellm_proxy(
self, texts: list[str], model_name: str | None
@@ -440,7 +450,7 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
)


-async def cohere_rerank(
+async def cohere_rerank_api(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereAsyncClient(api_key=api_key)
@@ -450,6 +460,45 @@ async def cohere_rerank(
return [result.relevance_score for result in sorted_results]


+async def cohere_rerank_aws(
+query: str,
+docs: list[str],
+model_name: str,
+region_name: str,
+aws_access_key_id: str,
+aws_secret_access_key: str,
+) -> list[float]:
+session = aioboto3.Session(
+aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
+)
+async with session.client(
+"bedrock-runtime", region_name=region_name
+) as bedrock_client:
+body = json.dumps(
+{
+"query": query,
+"documents": docs,
+"api_version": 2,
+}
+)
+# Invoke the Bedrock model asynchronously
+response = await bedrock_client.invoke_model(
+modelId=model_name,
+accept="application/json",
+contentType="application/json",
+body=body,
+)
+
+# Read the response asynchronously
+response_body = json.loads(await response["body"].read())
+
+# Extract and sort the results
+results = response_body.get("results", [])
+sorted_results = sorted(results, key=lambda item: item["index"])
+
+return [result["relevance_score"] for result in sorted_results]


async def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
@@ -564,15 +613,32 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
-sim_scores = await cohere_rerank(
+sim_scores = await cohere_rerank_api(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)

+elif rerank_request.provider_type == RerankerProvider.BEDROCK:
+if rerank_request.api_key is None:
+raise RuntimeError("Bedrock Rerank Requires an API Key")
+aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
+rerank_request.api_key
+)
+sim_scores = await cohere_rerank_aws(
+query=rerank_request.query,
+docs=rerank_request.documents,
+model_name=rerank_request.model_name,
+region_name=aws_region,
+aws_access_key_id=aws_access_key_id,
+aws_secret_access_key=aws_secret_access_key,
+)
+return RerankResponse(scores=sim_scores)
else:
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")

except Exception as e:
logger.exception(f"Error during reranking process:\n{str(e)}")
raise HTTPException(
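End to end, the Bedrock path is selected by provider_type alone, and the AWS credentials ride in through the existing api_key field. A rerank request aimed at this branch might look like the following (a sketch: field names are inferred from how process_rerank_request reads RerankRequest above, and all values are placeholders):

# Hypothetical payload; the api_key packs region + credentials in the
# Onyx-specific "aws_ACCESSKEY_SECRETKEY_REGION" format that pass_aws_key()
# (added in utils.py below) splits apart.
rerank_request = {
    "query": "What is the refund policy?",
    "documents": ["Refunds are issued within 30 days.", "Shipping takes 5 business days."],
    "model_name": "cohere.rerank-v3-5:0",
    "provider_type": "bedrock",
    "api_key": "aws_AKIAEXAMPLE_wJalrXUtnFEMIEXAMPLEKEY_us-east-1",
}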
29 changes: 29 additions & 0 deletions backend/model_server/utils.py
@@ -70,3 +70,32 @@ def get_gpu_type() -> str:
return GPUStatus.MAC_MPS

return GPUStatus.NONE


+def pass_aws_key(api_key: str) -> tuple[str, str, str]:
+"""Parse an AWS API key string into its components.
+
+Args:
+api_key: String in the format 'aws_ACCESSKEY_SECRETKEY_REGION'
+
+Returns:
+Tuple of (access_key, secret_key, region)
+
+Raises:
+ValueError: If the key format is invalid
+"""
+if not api_key.startswith("aws"):
+raise ValueError("API key must start with 'aws' prefix")
+
+parts = api_key.split("_")
+if len(parts) != 4:
+raise ValueError(
+f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts)} parts. "
+"This is an Onyx-specific format for passing the AWS secrets for Bedrock."
+)
+
+try:
+_, aws_access_key_id, aws_secret_access_key, aws_region = parts
+return aws_access_key_id, aws_secret_access_key, aws_region
+except Exception as e:
+raise ValueError(f"Failed to parse AWS key components: {str(e)}") from e
3 changes: 2 additions & 1 deletion backend/requirements/model_server.txt
@@ -13,4 +13,5 @@ transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.61.16
-sentry-sdk[fastapi,celery,starlette]==2.14.0
+sentry-sdk[fastapi,celery,starlette]==2.14.0
+aioboto3==13.4.0
10 changes: 8 additions & 2 deletions backend/shared_configs/configs.py
@@ -68,6 +68,12 @@
# allow us to specify a custom timeout
API_BASED_EMBEDDING_TIMEOUT = int(os.environ.get("API_BASED_EMBEDDING_TIMEOUT", "600"))

+# Local batch size for Vertex AI embedding models, currently calibrated for an item size of 512 tokens.
+# NOTE: increasing this value may lead to API errors due to per-call token limit exhaustion.
+VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE = int(
+os.environ.get("VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE", "25")
+)
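As a rough sanity check on the default (a hypothetical workload, not a measured figure), embedding 1,000 chunks at the default batch size issues ceil(1000 / 25) = 40 concurrent Vertex calls:

import math

num_texts, batch_size = 1000, 25  # hypothetical indexing workload
print(math.ceil(num_texts / batch_size))  # -> 40 concurrent API calls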

# Only used for OpenAI
OPENAI_EMBEDDING_TIMEOUT = int(
os.environ.get("OPENAI_EMBEDDING_TIMEOUT", API_BASED_EMBEDDING_TIMEOUT)
@@ -200,12 +206,12 @@ async def async_return_default_schema(*args: Any, **kwargs: Any) -> str:
index_name="danswer_chunk_text_embedding_3_small",
),
SupportedEmbeddingModel(
name="google/text-embedding-004",
name="google/text-embedding-005",
dim=768,
index_name="danswer_chunk_google_text_embedding_004",
),
SupportedEmbeddingModel(
name="google/text-embedding-004",
name="google/text-embedding-005",
dim=768,
index_name="danswer_chunk_text_embedding_004",
),
1 change: 1 addition & 0 deletions backend/shared_configs/enums.py
@@ -13,6 +13,7 @@ class EmbeddingProvider(str, Enum):
class RerankerProvider(str, Enum):
COHERE = "cohere"
LITELLM = "litellm"
+BEDROCK = "bedrock"


class EmbedTextType(str, Enum):
21 changes: 19 additions & 2 deletions web/src/app/admin/embeddings/RerankingFormPage.tsx
@@ -15,6 +15,7 @@ import {
} from "./interfaces";
import { FiExternalLink } from "react-icons/fi";
import {
+AmazonIcon,
CohereIcon,
LiteLLMIcon,
MixedBreadIcon,
@@ -185,6 +186,11 @@ const RerankingDetailsForm = forwardRef<
card.rerank_provider_type == RerankerProvider.COHERE
) {
setIsApiKeyModalOpen(true);
+} else if (
+card.rerank_provider_type ==
+RerankerProvider.BEDROCK
+) {
+setIsApiKeyModalOpen(true);
} else if (
card.rerank_provider_type ==
RerankerProvider.LITELLM
@@ -221,6 +227,9 @@
) : card.rerank_provider_type ===
RerankerProvider.COHERE ? (
<CohereIcon size={24} className="mr-2" />
+) : card.rerank_provider_type ===
+RerankerProvider.BEDROCK ? (
+<AmazonIcon size={24} className="mr-2" />
) : (
<MixedBreadIcon size={24} className="mr-2" />
)}
@@ -380,7 +389,10 @@ const RerankingDetailsForm = forwardRef<
placeholder={
values.rerank_api_key
? "*".repeat(values.rerank_api_key.length)
-: undefined
+: values.rerank_provider_type ===
+RerankerProvider.BEDROCK
+? "aws_ACCESSKEY_SECRETKEY_REGION"
+: "Enter your API key"
}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
@@ -391,7 +403,12 @@
setFieldValue("api_key", value);
}}
type="password"
label="Cohere API Key"
label={
values.rerank_provider_type ===
RerankerProvider.BEDROCK
? "AWS Credentials in format: aws_ACCESSKEY_SECRETKEY_REGION"
: "Cohere API Key"
}
name="rerank_api_key"
/>
<div className="flex w-full justify-end mt-4">
10 changes: 10 additions & 0 deletions web/src/app/admin/embeddings/interfaces.ts
@@ -18,6 +18,7 @@ export interface RerankingDetails {
export enum RerankerProvider {
COHERE = "cohere",
LITELLM = "litellm",
+BEDROCK = "bedrock",
}

export interface AdvancedSearchConfiguration {
@@ -92,6 +93,15 @@ export const rerankingModels: RerankingModel[] = [
description: "Powerful multilingual reranking model.",
link: "https://docs.cohere.com/v2/reference/rerank",
},
+{
+cloud: true,
+rerank_provider_type: RerankerProvider.BEDROCK,
+modelName: "cohere.rerank-v3-5:0",
+displayName: "Cohere Rerank 3.5",
+description:
+"Powerful multilingual reranking model invoked through AWS Bedrock.",
+link: "https://aws.amazon.com/blogs/machine-learning/cohere-rerank-3-5-is-now-available-in-amazon-bedrock-through-rerank-api",
+},
];

export const getCurrentModelCopy = (
2 changes: 1 addition & 1 deletion web/src/components/embedding/interfaces.tsx
@@ -267,7 +267,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
embedding_models: [
{
provider_type: EmbeddingProvider.GOOGLE,
model_name: "text-embedding-004",
model_name: "text-embedding-005",
description: "Google's most recent text embedding model.",
pricePerMillion: 0.025,
model_dim: 768,
8 changes: 7 additions & 1 deletion web/src/lib/hooks.ts
@@ -730,7 +730,10 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
"gemini-1.5-flash-001": "Gemini 1.5 Flash",
"gemini-1.5-pro-002": "Gemini 1.5 Pro (v2)",
"gemini-1.5-flash-002": "Gemini 1.5 Flash (v2)",
"gemini-2.0-flash-exp": "Gemini 2.0 Flash (Experimental)",
"gemini-2.0-flash-001": "Gemini 2.0 Flash",
"gemini-2.0-flash": "Gemini 2.0 Flash",
"gemini-2.0-pro-exp-02-05": "Gemini 2.0 Pro",
"gemini-2.0-flash-thinking-exp-01-21": "Gemini 2.0 Flash Thinking",

// Mistral Models
"mistral-large-2411": "Mistral Large 24.11",
@@ -755,6 +758,8 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
"anthropic.claude-v2:1": "Claude v2.1",
"anthropic.claude-v2": "Claude v2",
"anthropic.claude-v1": "Claude v1",
"anthropic.claude-3-7-sonnet-20250219-v1:0": "Claude 3.7 Sonnet",
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": "Claude 3.7 Sonnet",
"anthropic.claude-3-opus-20240229-v1:0": "Claude 3 Opus",
"anthropic.claude-3-haiku-20240307-v1:0": "Claude 3 Haiku",
"anthropic.claude-3-5-sonnet-20240620-v1:0": "Claude 3.5 Sonnet",
@@ -788,6 +793,7 @@ export const defaultModelsByProvider: { [name: string]: string[] } = {
"anthropic.claude-3-opus-20240229-v1:0",
"mistral.mistral-large-2402-v1:0",
"anthropic.claude-3-5-sonnet-20241022-v2:0",
"anthropic.claude-3-7-sonnet-20250219-v1:0",
],
anthropic: ["claude-3-opus-20240229", "claude-3-5-sonnet-20241022"],
};