Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement groq completions #2

Merged
merged 2 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/c00-readme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ const MODEL_OPENAI: &str = "gpt-3.5-turbo";
const MODEL_ANTHROPIC: &str = "claude-3-haiku-20240307";
const MODEL_COHERE: &str = "command-light";
const MODEL_GEMINI: &str = "gemini-1.5-flash-latest";
const MODEL_GROQ: &str = "llama3-8b-8192";
const MODEL_OLLAMA: &str = "mixtral";

// NOTE: Those are the default env keys for each AI Provider type.
Expand All @@ -15,6 +16,7 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
(MODEL_ANTHROPIC, "ANTHROPIC_API_KEY"),
(MODEL_COHERE, "COHERE_API_KEY"),
(MODEL_GEMINI, "GEMINI_API_KEY"),
(MODEL_GROQ, "GROQ_API_KEY"),
(MODEL_OLLAMA, ""),
];

Expand All @@ -23,6 +25,7 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
// - starts_with "claude" -> Anthropic
// - starts_with "command" -> Cohere
// - starts_with "gemini" -> Gemini
// - model in Groq models -> Groq
// - For anything else -> Ollama
//
// Refined mapping rules will be added later and extended as provider support grows.
Expand All @@ -44,6 +47,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
for (model, env_name) in MODEL_AND_KEY_ENV_NAME_LIST {
// Skip if does not have the environment name set
if !env_name.is_empty() && std::env::var(env_name).is_err() {
println!("Skipping model: {model} (env var not set: {env_name})");
continue;
}

Expand Down
5 changes: 5 additions & 0 deletions src/adapter/adapter_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ use derive_more::Display;
use reqwest::RequestBuilder;
use serde_json::Value;

use super::groq::MODELS as GROQ_MODELS;

/// Identifies which AI provider an adapter targets.
/// Used as the dispatch key throughout the adapter layer (see dispatcher.rs).
#[derive(Debug, Clone, Copy, Display, Eq, PartialEq, Hash)]
pub enum AdapterKind {
	OpenAI,
	Ollama,
	Anthropic,
	Cohere,
	Gemini,
	Groq,
	// Note: Variants will probably be suffixed
	// AnthropicBedrock,
}
Expand All @@ -28,6 +31,8 @@ impl AdapterKind {
Ok(AdapterKind::Cohere)
} else if model.starts_with("gemini") {
Ok(AdapterKind::Gemini)
} else if GROQ_MODELS.contains(&model) {
return Ok(AdapterKind::Groq);
}
// for now, fallback on Ollama
else {
Expand Down
59 changes: 59 additions & 0 deletions src/adapter/adapters/groq/adapter_impl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use std::sync::OnceLock;

use reqwest::RequestBuilder;

use crate::adapter::openai::OpenAIAdapter;
use crate::adapter::support::get_api_key_resolver;
use crate::adapter::{Adapter, AdapterConfig, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ChatRequest, ChatRequestOptions, ChatResponse, ChatStreamResponse};
use crate::webc::WebResponse;
use crate::{ConfigSet, Result};

/// Adapter for the Groq cloud API.
/// Groq exposes an OpenAI-compatible API, so this adapter delegates most work to `OpenAIAdapter`.
pub struct GroqAdapter;

// Base URL of Groq's OpenAI-compatible endpoint (note the `/openai/v1/` path).
const BASE_URL: &str = "https://api.groq.com/openai/v1/";
// Static list of models served by Groq (hardcoded rather than fetched from the API).
// Also imported by adapter_types.rs (as GROQ_MODELS) to route model names to this adapter.
pub(crate) const MODELS: &[&str] = &[
	"llama3-8b-8192",
	"llama3-70b-8192",
	"mixtral-8x7b-32768",
	"gemma-7b-it",
	"whisper-large-v3",
];

// The Groq API adapter is modeled after the OpenAI adapter, as the Groq API is compatible with the OpenAI API.
impl Adapter for GroqAdapter {
	/// Returns the known Groq models from the static `MODELS` list
	/// (no network call is made).
	async fn list_models(_kind: AdapterKind) -> Result<Vec<String>> {
		Ok(MODELS.iter().map(|s| s.to_string()).collect())
	}

	/// Returns the lazily-initialized default config for Groq,
	/// which resolves the API key from the `GROQ_API_KEY` env var.
	fn default_adapter_config(_kind: AdapterKind) -> &'static AdapterConfig {
		static INSTANCE: OnceLock<AdapterConfig> = OnceLock::new();
		INSTANCE.get_or_init(|| AdapterConfig::default().with_auth_env_name("GROQ_API_KEY"))
	}

	/// Builds the service URL by appending the OpenAI-style service path to Groq's `BASE_URL`.
	fn get_service_url(kind: AdapterKind, service_type: ServiceType) -> String {
		OpenAIAdapter::util_get_service_url(kind, service_type, BASE_URL)
	}

	/// Builds the web request (url, headers, payload) for a chat request,
	/// reusing the OpenAI payload format since Groq is API-compatible.
	fn to_web_request_data(
		kind: AdapterKind,
		config_set: &ConfigSet<'_>,
		service_type: ServiceType,
		model: &str,
		chat_req: ChatRequest,
		_chat_req_options: Option<&ChatRequestOptions>,
	) -> Result<WebRequestData> {
		// Resolve the API key via the adapter config (GROQ_API_KEY by default).
		let api_key = get_api_key_resolver(kind, config_set)?;
		let url = Self::get_service_url(kind, service_type);

		// NOTE(review): the trailing `false` flag is undocumented here — presumably it toggles
		// an OpenAI-variant behavior (e.g. Ollama-style handling); confirm against
		// `OpenAIAdapter::util_to_web_request_data`.
		OpenAIAdapter::util_to_web_request_data(kind, url, model, chat_req, service_type, &api_key, false)
	}

	/// Parses a non-streaming chat response using the OpenAI response format.
	fn to_chat_response(kind: AdapterKind, web_response: WebResponse) -> Result<ChatResponse> {
		OpenAIAdapter::to_chat_response(kind, web_response)
	}

	/// Builds a streaming chat response using the OpenAI stream format.
	fn to_chat_stream(kind: AdapterKind, reqwest_builder: RequestBuilder) -> Result<ChatStreamResponse> {
		OpenAIAdapter::to_chat_stream(kind, reqwest_builder)
	}
}
9 changes: 9 additions & 0 deletions src/adapter/adapters/groq/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//! API DOC: https://console.groq.com/docs/api-reference#chat

// region: --- Modules

mod adapter_impl;

pub use adapter_impl::*;

// endregion: --- Modules
1 change: 1 addition & 0 deletions src/adapter/adapters/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub(crate) mod anthropic;
pub(crate) mod cohere;
pub(crate) mod gemini;
pub(crate) mod groq;
pub(crate) mod ollama;
pub(crate) mod openai;
10 changes: 10 additions & 0 deletions src/adapter/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use crate::webc::WebResponse;
use crate::{ConfigSet, Result};
use reqwest::RequestBuilder;

use super::groq::GroqAdapter;

pub struct AdapterDispatcher;

impl Adapter for AdapterDispatcher {
Expand All @@ -19,6 +21,7 @@ impl Adapter for AdapterDispatcher {
AdapterKind::Cohere => CohereAdapter::list_models(kind).await,
AdapterKind::Ollama => OllamaAdapter::list_models(kind).await,
AdapterKind::Gemini => GeminiAdapter::list_models(kind).await,
AdapterKind::Groq => GroqAdapter::list_models(kind).await,
}
}

Expand All @@ -29,6 +32,7 @@ impl Adapter for AdapterDispatcher {
AdapterKind::Cohere => CohereAdapter::default_adapter_config(kind),
AdapterKind::Ollama => OllamaAdapter::default_adapter_config(kind),
AdapterKind::Gemini => GeminiAdapter::default_adapter_config(kind),
AdapterKind::Groq => GroqAdapter::default_adapter_config(kind),
}
}

Expand All @@ -39,6 +43,7 @@ impl Adapter for AdapterDispatcher {
AdapterKind::Cohere => CohereAdapter::get_service_url(kind, service_type),
AdapterKind::Ollama => OllamaAdapter::get_service_url(kind, service_type),
AdapterKind::Gemini => GeminiAdapter::get_service_url(kind, service_type),
AdapterKind::Groq => GroqAdapter::get_service_url(kind, service_type),
}
}

Expand Down Expand Up @@ -66,6 +71,9 @@ impl Adapter for AdapterDispatcher {
AdapterKind::Gemini => {
GeminiAdapter::to_web_request_data(kind, config_set, service_type, model, chat_req, chat_req_options)
}
AdapterKind::Groq => {
GroqAdapter::to_web_request_data(kind, config_set, service_type, model, chat_req, chat_req_options)
}
}
}

Expand All @@ -76,6 +84,7 @@ impl Adapter for AdapterDispatcher {
AdapterKind::Cohere => CohereAdapter::to_chat_response(kind, web_response),
AdapterKind::Ollama => OllamaAdapter::to_chat_response(kind, web_response),
AdapterKind::Gemini => GeminiAdapter::to_chat_response(kind, web_response),
AdapterKind::Groq => GroqAdapter::to_chat_response(kind, web_response),
}
}

Expand All @@ -86,6 +95,7 @@ impl Adapter for AdapterDispatcher {
AdapterKind::Cohere => CohereAdapter::to_chat_stream(kind, reqwest_builder),
AdapterKind::Ollama => OpenAIAdapter::to_chat_stream(kind, reqwest_builder),
AdapterKind::Gemini => GeminiAdapter::to_chat_stream(kind, reqwest_builder),
AdapterKind::Groq => GroqAdapter::to_chat_stream(kind, reqwest_builder),
}
}
}