Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support patching request url, headers and body #756

Merged
merged 2 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,15 @@ clients:
# type: reranker
# max_input_tokens: 2048
# patch: # Patch api request
# chat_completions_body:
# chat_completions: # Api types, one of chat_completions, embeddings, and rerank
# <regex>: # The regex to match model names, e.g. '.*' 'gpt-4o' 'gpt-4o|gpt-4-.*'
# <json> # The JSON to be merged with the chat completions request body.
# url: '' # Patch request url
# body: # Patch request body
# <json>
# headers: # Patch request headers
# <key>: <value>
# extra:
# proxy: socks5://127.0.0.1:1080 # Set https/socks5 proxy. ENV: HTTPS_PROXY/https_proxy/ALL_PROXY/all_proxy
# proxy: socks5://127.0.0.1:1080 # Set https/socks5 proxy. ENV: HTTPS_PROXY/ALL_PROXY
# connect_timeout: 10 # Set timeout in seconds for connect to api

# See https://platform.openai.com/docs/quickstart
Expand All @@ -123,17 +127,18 @@ clients:
- type: gemini
api_key: xxx # ENV: {client}_API_KEY
patch:
chat_completions_body:
'.*':
safetySettings:
- category: HARM_CATEGORY_HARASSMENT
threshold: BLOCK_NONE
- category: HARM_CATEGORY_HATE_SPEECH
threshold: BLOCK_NONE
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
threshold: BLOCK_NONE
- category: HARM_CATEGORY_DANGEROUS_CONTENT
threshold: BLOCK_NONE
chat_completions:
'.*':
body:
safetySettings:
- category: HARM_CATEGORY_HARASSMENT
threshold: BLOCK_NONE
- category: HARM_CATEGORY_HATE_SPEECH
threshold: BLOCK_NONE
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
threshold: BLOCK_NONE
- category: HARM_CATEGORY_DANGEROUS_CONTENT
threshold: BLOCK_NONE

# See https://docs.anthropic.com/claude/reference/getting-started-with-the-api
- type: claude
Expand Down Expand Up @@ -189,17 +194,18 @@ clients:
# see https://cloud.google.com/docs/authentication/external/set-up-adc
adc_file: <path-to/gcloud/application_default_credentials.json>
patch:
chat_completions_body:
chat_completions:
'gemini-.*':
safetySettings:
- category: HARM_CATEGORY_HARASSMENT
threshold: BLOCK_ONLY_HIGH
- category: HARM_CATEGORY_HATE_SPEECH
threshold: BLOCK_ONLY_HIGH
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
threshold: BLOCK_ONLY_HIGH
- category: HARM_CATEGORY_DANGEROUS_CONTENT
threshold: BLOCK_ONLY_HIGH
body:
safetySettings:
- category: HARM_CATEGORY_HARASSMENT
threshold: BLOCK_ONLY_HIGH
- category: HARM_CATEGORY_HATE_SPEECH
threshold: BLOCK_ONLY_HIGH
- category: HARM_CATEGORY_SEXUALLY_EXPLICIT
threshold: BLOCK_ONLY_HIGH
- category: HARM_CATEGORY_DANGEROUS_CONTENT
threshold: BLOCK_ONLY_HIGH

# See https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude
- type: vertexai-claude
Expand Down
36 changes: 13 additions & 23 deletions src/client/azure_openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use super::openai::*;
use super::*;

use anyhow::Result;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;

#[derive(Debug, Clone, Deserialize)]
Expand All @@ -12,7 +11,7 @@ pub struct AzureOpenAIConfig {
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<ModelPatch>,
pub patch: Option<RequestPatch>,
pub extra: Option<ExtraConfig>,
}

Expand All @@ -32,51 +31,42 @@ impl AzureOpenAIClient {
),
];

fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result<RequestData> {
let api_base = self.get_api_base()?;
let api_key = self.get_api_key()?;

let mut body = openai_build_chat_completions_body(data, &self.model);
self.patch_chat_completions_body(&mut body);

let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",
&api_base,
self.model.name()
);

debug!("AzureOpenAI Chat Completions Request: {url} {body}");
let body = openai_build_chat_completions_body(data, &self.model);

let mut request_data = RequestData::new(url, body);

let builder = client.post(url).header("api-key", api_key).json(&body);
request_data.header("api-key", api_key);

Ok(builder)
Ok(request_data)
}

fn embeddings_builder(
&self,
client: &ReqwestClient,
data: EmbeddingsData,
) -> Result<RequestBuilder> {
fn prepare_embeddings(&self, data: EmbeddingsData) -> Result<RequestData> {
let api_base = self.get_api_base()?;
let api_key = self.get_api_key()?;

let body = openai_build_embeddings_body(data, &self.model);

let url = format!(
"{}/openai/deployments/{}/embeddings?api-version=2024-02-01",
&api_base,
self.model.name()
);

debug!("AzureOpenAI Embeddings Request: {url} {body}");
let body = openai_build_embeddings_body(data, &self.model);

let mut request_data = RequestData::new(url, body);

let builder = client.post(url).header("api-key", api_key).json(&body);
request_data.header("api-key", api_key);

Ok(builder)
Ok(request_data)
}
}

Expand Down
84 changes: 58 additions & 26 deletions src/client/bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,16 @@ use super::*;
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};

use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
use aws_smithy_eventstream::smithy::parse_response_headers;
use bytes::BytesMut;
use chrono::{DateTime, Utc};
use futures_util::StreamExt;
use indexmap::IndexMap;
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client as ReqwestClient, Method, RequestBuilder,
};
use reqwest::{Client as ReqwestClient, Method, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::str::FromStr;

#[derive(Debug, Clone, Deserialize)]
pub struct BedrockConfig {
Expand All @@ -25,7 +22,7 @@ pub struct BedrockConfig {
pub region: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<ModelPatch>,
pub patch: Option<RequestPatch>,
pub extra: Option<ExtraConfig>,
}

Expand Down Expand Up @@ -61,16 +58,22 @@ impl BedrockClient {
let host = format!("bedrock-runtime.{region}.amazonaws.com");

let model_name = &self.model.name();

let uri = if data.stream {
format!("/model/{model_name}/converse-stream")
} else {
format!("/model/{model_name}/converse")
};

let headers = IndexMap::new();
let body = build_chat_completions_body(data, &self.model)?;

let mut body = build_chat_completions_body(data, &self.model)?;
self.patch_chat_completions_body(&mut body);
let mut request_data = RequestData::new("", body);
self.patch_request_data(&mut request_data, ApiType::ChatCompletions);
let RequestData {
url: _,
headers,
body,
} = request_data;

let builder = aws_fetch(
client,
Expand Down Expand Up @@ -105,8 +108,6 @@ impl BedrockClient {

let uri = format!("/model/{}/invoke", self.model.name());

let headers = IndexMap::new();

let input_type = match data.query {
true => "search_query",
false => "search_document",
Expand All @@ -117,6 +118,14 @@ impl BedrockClient {
"input_type": input_type,
});

let mut request_data = RequestData::new("", body);
self.patch_request_data(&mut request_data, ApiType::Embeddings);
let RequestData {
url: _,
headers,
body,
} = request_data;

let builder = aws_fetch(
client,
&AwsCredentials {
Expand All @@ -139,12 +148,38 @@ impl BedrockClient {
}
}

impl_client_trait!(
BedrockClient,
chat_completions,
chat_completions_streaming,
embeddings
);
#[async_trait]
impl Client for BedrockClient {
client_common_fns!();

async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
let builder = self.chat_completions_builder(client, data)?;
chat_completions(builder).await
}

async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: ChatCompletionsData,
) -> Result<()> {
let builder = self.chat_completions_builder(client, data)?;
chat_completions_streaming(builder, handler).await
}

async fn embeddings_inner(
&self,
client: &ReqwestClient,
data: EmbeddingsData,
) -> Result<EmbeddingsOutput> {
let builder = self.embeddings_builder(client, data)?;
embeddings(builder).await
}
}

async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
Expand Down Expand Up @@ -550,17 +585,14 @@ fn aws_fetch(

headers.insert("authorization".into(), authorization_header);

let mut req_headers = HeaderMap::new();
for (k, v) in &headers {
req_headers.insert(HeaderName::from_str(k)?, HeaderValue::from_str(v)?);
}
debug!("Request {endpoint} {body}");

debug!("Bedrock Request: {endpoint} {body}");
let mut request_builder = client.request(method, endpoint).body(body);

for (key, value) in &headers {
request_builder = request_builder.header(key, value);
}

let request_builder = client
.request(method, endpoint)
.headers(req_headers)
.body(body);
Ok(request_builder)
}

Expand Down
24 changes: 8 additions & 16 deletions src/client/claude.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::*;

use anyhow::{bail, Context, Result};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};

Expand All @@ -13,7 +13,7 @@ pub struct ClaudeConfig {
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<ModelPatch>,
pub patch: Option<RequestPatch>,
pub extra: Option<ExtraConfig>,
}

Expand All @@ -23,27 +23,19 @@ impl ClaudeClient {
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];

fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result<RequestData> {
let api_key = self.get_api_key().ok();

let mut body = claude_build_chat_completions_body(data, &self.model)?;
self.patch_chat_completions_body(&mut body);
let body = claude_build_chat_completions_body(data, &self.model)?;

let url = API_BASE;
let mut request_data = RequestData::new(API_BASE, body);

debug!("Claude Request: {url} {body}");

let mut builder = client.post(url).json(&body);
builder = builder.header("anthropic-version", "2023-06-01");
request_data.header("anthropic-version", "2023-06-01");
if let Some(api_key) = api_key {
builder = builder.header("x-api-key", api_key)
request_data.header("x-api-key", api_key)
}

Ok(builder)
Ok(request_data)
}
}

Expand Down
Loading