From 61aa705ef9ff5d4450a5ec3d04fada57ba47717f Mon Sep 17 00:00:00 2001 From: Rex Magana Date: Wed, 19 Jun 2024 09:54:48 -0600 Subject: [PATCH 1/2] implement groq completions --- examples/c00-readme.rs | 14 +++-- src/adapter/adapter_types.rs | 31 ++++++---- .../adapters/anthropic/adapter_impl.rs | 2 +- src/adapter/adapters/cohere/adapter_impl.rs | 2 +- src/adapter/adapters/gemini/adapter_impl.rs | 2 +- src/adapter/adapters/groq/adapter_impl.rs | 59 +++++++++++++++++++ src/adapter/adapters/groq/mod.rs | 9 +++ src/adapter/adapters/mod.rs | 1 + src/adapter/adapters/openai/adapter_impl.rs | 2 +- src/adapter/dispatcher.rs | 10 ++++ 10 files changed, 111 insertions(+), 21 deletions(-) create mode 100644 src/adapter/adapters/groq/adapter_impl.rs create mode 100644 src/adapter/adapters/groq/mod.rs diff --git a/examples/c00-readme.rs b/examples/c00-readme.rs index 97fb421..ac678fc 100644 --- a/examples/c00-readme.rs +++ b/examples/c00-readme.rs @@ -6,6 +6,7 @@ const MODEL_OPENAI: &str = "gpt-3.5-turbo"; const MODEL_ANTHROPIC: &str = "claude-3-haiku-20240307"; const MODEL_COHERE: &str = "command-light"; const MODEL_GEMINI: &str = "gemini-1.5-flash-latest"; +const MODEL_GROQ: &str = "llama3-8b-8192"; const MODEL_OLLAMA: &str = "mixtral"; // NOTE: Those are the default env keys for each AI Provider type. @@ -15,15 +16,17 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[ (MODEL_ANTHROPIC, "ANTHROPIC_API_KEY"), (MODEL_COHERE, "COHERE_API_KEY"), (MODEL_GEMINI, "GEMINI_API_KEY"), + (MODEL_GROQ, "GROQ_API_KEY"), (MODEL_OLLAMA, ""), ]; // NOTE: Model to AdapterKind (AI Provider) type mapping rule -// - starts_with "gpt" -> OpenAI -// - starts_with "claude" -> Anthropic -// - starts_with "command" -> Cohere -// - starts_with "gemini" -> Gemini -// - For anything else -> Ollama +// - If the model is in the OpenAI models, then the AdapterKind is OpenAI +// - If the model is in the Anthropic models, then the AdapterKind is Anthropic +// - If the model is in the Cohere models, then the AdapterKind is Cohere +// - If the model is in the Gemini models, then the AdapterKind is Gemini +// - If the model is in the Groq models, then the AdapterKind is Groq +// - Otherwise, the AdapterKind is Ollama // // Refined mapping rules will be added later and extended as provider support grows. @@ -44,6 +47,7 @@ async fn main() -> Result<(), Box> { for (model, env_name) in MODEL_AND_KEY_ENV_NAME_LIST { // Skip if does not have the environment name set if !env_name.is_empty() && std::env::var(env_name).is_err() { + println!("Skipping model: {model} (env var not set: {env_name})"); continue; } diff --git a/src/adapter/adapter_types.rs b/src/adapter/adapter_types.rs index 03c280a..7c6a363 100644 --- a/src/adapter/adapter_types.rs +++ b/src/adapter/adapter_types.rs @@ -6,6 +6,12 @@ use derive_more::Display; use reqwest::RequestBuilder; use serde_json::Value; +use super::anthropic::MODELS as ANTHROPIC_MODELS; +use super::cohere::MODELS as COHERE_MODELS; +use super::gemini::MODELS as GEMINI_MODELS; +use super::groq::MODELS as GROQ_MODELS; +use super::openai::MODELS as OPENAI_MODELS; + #[derive(Debug, Clone, Copy, Display, Eq, PartialEq, Hash)] pub enum AdapterKind { OpenAI, @@ -13,6 +19,7 @@ pub enum AdapterKind { Anthropic, Cohere, Gemini, + Groq, // Note: Variants will probalby be suffixed // AnthropicBerock, } @@ -20,18 +27,18 @@ pub enum AdapterKind { impl AdapterKind { /// Very simplistic mapper for now. 
pub fn from_model(model: &str) -> Result { - if model.starts_with("gpt") { - Ok(AdapterKind::OpenAI) - } else if model.starts_with("claude") { - Ok(AdapterKind::Anthropic) - } else if model.starts_with("command") { - Ok(AdapterKind::Cohere) - } else if model.starts_with("gemini") { - Ok(AdapterKind::Gemini) - } - // for now, fallback on Ollama - else { - Ok(Self::Ollama) + if OPENAI_MODELS.contains(&model) { + return Ok(AdapterKind::OpenAI); + } else if ANTHROPIC_MODELS.contains(&model) { + return Ok(AdapterKind::Anthropic); + } else if COHERE_MODELS.contains(&model) { + return Ok(AdapterKind::Cohere); + } else if GEMINI_MODELS.contains(&model) { + return Ok(AdapterKind::Gemini); + } else if GROQ_MODELS.contains(&model) { + return Ok(AdapterKind::Groq); + } else { + return Ok(AdapterKind::Ollama); } } } diff --git a/src/adapter/adapters/anthropic/adapter_impl.rs b/src/adapter/adapters/anthropic/adapter_impl.rs index f5b8353..a21e268 100644 --- a/src/adapter/adapters/anthropic/adapter_impl.rs +++ b/src/adapter/adapters/anthropic/adapter_impl.rs @@ -14,7 +14,7 @@ pub struct AnthropicAdapter; const MAX_TOKENS: u32 = 1024; const ANTRHOPIC_VERSION: &str = "2023-06-01"; -const MODELS: &[&str] = &["claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"]; +pub(crate) const MODELS: &[&str] = &["claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"]; const BASE_URL: &str = "https://api.anthropic.com/v1/"; diff --git a/src/adapter/adapters/cohere/adapter_impl.rs b/src/adapter/adapters/cohere/adapter_impl.rs index b1bfb79..a5f83be 100644 --- a/src/adapter/adapters/cohere/adapter_impl.rs +++ b/src/adapter/adapters/cohere/adapter_impl.rs @@ -13,7 +13,7 @@ pub struct CohereAdapter; const MAX_TOKENS: u32 = 1024; const BASE_URL: &str = "https://api.cohere.com/v1/"; -const MODELS: &[&str] = &[ +pub(crate) const MODELS: &[&str] = &[ "command-r-plus", "command-r", "command", diff --git a/src/adapter/adapters/gemini/adapter_impl.rs b/src/adapter/adapters/gemini/adapter_impl.rs index 81ce0ca..28eeb62 100644 --- a/src/adapter/adapters/gemini/adapter_impl.rs +++ b/src/adapter/adapters/gemini/adapter_impl.rs @@ -12,7 +12,7 @@ use std::sync::OnceLock; pub struct GeminiAdapter; const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/"; -const MODELS: &[&str] = &[ +pub(crate) const MODELS: &[&str] = &[ "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro", diff --git a/src/adapter/adapters/groq/adapter_impl.rs b/src/adapter/adapters/groq/adapter_impl.rs new file mode 100644 index 0000000..7bfc980 --- /dev/null +++ b/src/adapter/adapters/groq/adapter_impl.rs @@ -0,0 +1,59 @@ +use std::sync::OnceLock; + +use reqwest::RequestBuilder; + +use crate::adapter::openai::OpenAIAdapter; +use crate::adapter::support::get_api_key_resolver; +use crate::adapter::{Adapter, AdapterConfig, AdapterKind, ServiceType, WebRequestData}; +use crate::chat::{ChatRequest, ChatRequestOptions, ChatResponse, ChatStreamResponse}; +use crate::webc::WebResponse; +use crate::{ConfigSet, Result}; + +pub struct GroqAdapter; + +const BASE_URL: &str = "https://api.groq.com/openai/v1/"; +pub(crate) const MODELS: &[&str] = &[ + "llama3-8b-8192", + "llama3-70b-8192", + "mixtral-8x7b-32768", + "gemma-7b-it", + "whisper-large-v3", +]; + +// The Groq API adapter is modeled after the OpenAI adapter, as the Groq API is compatible with the OpenAI API. 
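Illustration (not part of the patch): because the adapter delegates to the OpenAI helpers, the request that ultimately goes out is a plain OpenAI-style chat completion against the BASE_URL above. The standalone sketch below shows that wire format; the chat/completions path, bearer auth, and response shape come from Groq's documented OpenAI-compatible API rather than from this diff, and the snippet assumes reqwest (with the json feature), tokio, and serde_json as dependencies.

use serde_json::{json, Value};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Same env var the adapter's default AdapterConfig reads.
    let api_key = std::env::var("GROQ_API_KEY")?;

    // OpenAI-style chat completion body, using one of the models listed above.
    let body = json!({
        "model": "llama3-8b-8192",
        "messages": [{ "role": "user", "content": "Why is the sky blue?" }]
    });

    let resp = reqwest::Client::new()
        .post("https://api.groq.com/openai/v1/chat/completions")
        .bearer_auth(&api_key)
        .json(&body)
        .send()
        .await?
        .error_for_status()?;

    // OpenAI-style response: the first choice carries the assistant message.
    let resp_json: Value = resp.json().await?;
    println!("{}", resp_json["choices"][0]["message"]["content"]);
    Ok(())
}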
+impl Adapter for GroqAdapter { + async fn list_models(_kind: AdapterKind) -> Result> { + Ok(MODELS.iter().map(|s| s.to_string()).collect()) + } + + fn default_adapter_config(_kind: AdapterKind) -> &'static AdapterConfig { + static INSTANCE: OnceLock = OnceLock::new(); + INSTANCE.get_or_init(|| AdapterConfig::default().with_auth_env_name("GROQ_API_KEY")) + } + + fn get_service_url(kind: AdapterKind, service_type: ServiceType) -> String { + OpenAIAdapter::util_get_service_url(kind, service_type, BASE_URL) + } + + fn to_web_request_data( + kind: AdapterKind, + config_set: &ConfigSet<'_>, + service_type: ServiceType, + model: &str, + chat_req: ChatRequest, + _chat_req_options: Option<&ChatRequestOptions>, + ) -> Result { + let api_key = get_api_key_resolver(kind, config_set)?; + let url = Self::get_service_url(kind, service_type); + + OpenAIAdapter::util_to_web_request_data(kind, url, model, chat_req, service_type, &api_key, false) + } + + fn to_chat_response(kind: AdapterKind, web_response: WebResponse) -> Result { + OpenAIAdapter::to_chat_response(kind, web_response) + } + + fn to_chat_stream(kind: AdapterKind, reqwest_builder: RequestBuilder) -> Result { + OpenAIAdapter::to_chat_stream(kind, reqwest_builder) + } +} diff --git a/src/adapter/adapters/groq/mod.rs b/src/adapter/adapters/groq/mod.rs new file mode 100644 index 0000000..ac3f32a --- /dev/null +++ b/src/adapter/adapters/groq/mod.rs @@ -0,0 +1,9 @@ +//! API DOC: https://console.groq.com/docs/api-reference#chat + +// region: --- Modules + +mod adapter_impl; + +pub use adapter_impl::*; + +// endregion: --- Modules diff --git a/src/adapter/adapters/mod.rs b/src/adapter/adapters/mod.rs index a0e32cb..033203c 100644 --- a/src/adapter/adapters/mod.rs +++ b/src/adapter/adapters/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod anthropic; pub(crate) mod cohere; pub(crate) mod gemini; +pub(crate) mod groq; pub(crate) mod ollama; pub(crate) mod openai; diff --git a/src/adapter/adapters/openai/adapter_impl.rs b/src/adapter/adapters/openai/adapter_impl.rs index 1cefffe..78331fb 100644 --- a/src/adapter/adapters/openai/adapter_impl.rs +++ b/src/adapter/adapters/openai/adapter_impl.rs @@ -13,7 +13,7 @@ use std::sync::OnceLock; pub struct OpenAIAdapter; const BASE_URL: &str = "https://api.openai.com/v1/"; -const MODELS: &[&str] = &["gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]; +pub(crate) const MODELS: &[&str] = &["gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]; impl Adapter for OpenAIAdapter { /// Note: For now returns the common ones (see above) diff --git a/src/adapter/dispatcher.rs b/src/adapter/dispatcher.rs index afdf0aa..cc26d69 100644 --- a/src/adapter/dispatcher.rs +++ b/src/adapter/dispatcher.rs @@ -9,6 +9,8 @@ use crate::webc::WebResponse; use crate::{ConfigSet, Result}; use reqwest::RequestBuilder; +use super::groq::GroqAdapter; + pub struct AdapterDispatcher; impl Adapter for AdapterDispatcher { @@ -19,6 +21,7 @@ impl Adapter for AdapterDispatcher { AdapterKind::Cohere => CohereAdapter::list_models(kind).await, AdapterKind::Ollama => OllamaAdapter::list_models(kind).await, AdapterKind::Gemini => GeminiAdapter::list_models(kind).await, + AdapterKind::Groq => GroqAdapter::list_models(kind).await, } } @@ -29,6 +32,7 @@ impl Adapter for AdapterDispatcher { AdapterKind::Cohere => CohereAdapter::default_adapter_config(kind), AdapterKind::Ollama => OllamaAdapter::default_adapter_config(kind), AdapterKind::Gemini => GeminiAdapter::default_adapter_config(kind), + AdapterKind::Groq => GroqAdapter::default_adapter_config(kind), } } @@ -39,6 +43,7 
@@ impl Adapter for AdapterDispatcher { AdapterKind::Cohere => CohereAdapter::get_service_url(kind, service_type), AdapterKind::Ollama => OllamaAdapter::get_service_url(kind, service_type), AdapterKind::Gemini => GeminiAdapter::get_service_url(kind, service_type), + AdapterKind::Groq => GroqAdapter::get_service_url(kind, service_type), } } @@ -66,6 +71,9 @@ impl Adapter for AdapterDispatcher { AdapterKind::Gemini => { GeminiAdapter::to_web_request_data(kind, config_set, service_type, model, chat_req, chat_req_options) } + AdapterKind::Groq => { + GroqAdapter::to_web_request_data(kind, config_set, service_type, model, chat_req, chat_req_options) + } } } @@ -76,6 +84,7 @@ impl Adapter for AdapterDispatcher { AdapterKind::Cohere => CohereAdapter::to_chat_response(kind, web_response), AdapterKind::Ollama => OllamaAdapter::to_chat_response(kind, web_response), AdapterKind::Gemini => GeminiAdapter::to_chat_response(kind, web_response), + AdapterKind::Groq => GroqAdapter::to_chat_response(kind, web_response), } } @@ -86,6 +95,7 @@ impl Adapter for AdapterDispatcher { AdapterKind::Cohere => CohereAdapter::to_chat_stream(kind, reqwest_builder), AdapterKind::Ollama => OpenAIAdapter::to_chat_stream(kind, reqwest_builder), AdapterKind::Gemini => GeminiAdapter::to_chat_stream(kind, reqwest_builder), + AdapterKind::Groq => GroqAdapter::to_chat_stream(kind, reqwest_builder), } } } From c519f52ba55e97e9090991b5f44a97d66adcca15 Mon Sep 17 00:00:00 2001 From: Rex Magana Date: Wed, 19 Jun 2024 21:01:25 -0700 Subject: [PATCH 2/2] review comments --- examples/c00-readme.rs | 12 ++++----- src/adapter/adapter_types.rs | 26 +++++++++---------- .../adapters/anthropic/adapter_impl.rs | 2 +- src/adapter/adapters/cohere/adapter_impl.rs | 2 +- src/adapter/adapters/gemini/adapter_impl.rs | 2 +- src/adapter/adapters/openai/adapter_impl.rs | 2 +- 6 files changed, 22 insertions(+), 24 deletions(-) diff --git a/examples/c00-readme.rs b/examples/c00-readme.rs index ac678fc..180588a 100644 --- a/examples/c00-readme.rs +++ b/examples/c00-readme.rs @@ -21,12 +21,12 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[ ]; // NOTE: Model to AdapterKind (AI Provider) type mapping rule -// - If the model is in the OpenAI models, then the AdapterKind is OpenAI -// - If the model is in the Anthropic models, then the AdapterKind is Anthropic -// - If the model is in the Cohere models, then the AdapterKind is Cohere -// - If the model is in the Gemini models, then the AdapterKind is Gemini -// - If the model is in the Groq models, then the AdapterKind is Groq -// - Otherwise, the AdapterKind is Ollama +// - starts_with "gpt" -> OpenAI +// - starts_with "claude" -> Anthropic +// - starts_with "command" -> Cohere +// - starts_with "gemini" -> Gemini +// - model in Groq models -> Groq +// - For anything else -> Ollama // // Refined mapping rules will be added later and extended as provider support grows. 
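Illustration (not part of the patch): how the rules above resolve a few concrete model names. This test-style sketch relies only on the derives shown on AdapterKind (Debug, PartialEq); the re-export path and an error type that supports unwrap are assumptions.

use crate::adapter::AdapterKind;

#[test]
fn adapter_kind_from_model_examples() {
    // "gpt" prefix -> OpenAI
    assert_eq!(AdapterKind::from_model("gpt-3.5-turbo").unwrap(), AdapterKind::OpenAI);
    // "claude" prefix -> Anthropic
    assert_eq!(AdapterKind::from_model("claude-3-haiku-20240307").unwrap(), AdapterKind::Anthropic);
    // listed in groq::MODELS -> Groq
    assert_eq!(AdapterKind::from_model("llama3-8b-8192").unwrap(), AdapterKind::Groq);
    // no prefix match and not in the Groq list -> Ollama fallback
    assert_eq!(AdapterKind::from_model("mixtral").unwrap(), AdapterKind::Ollama);
}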
diff --git a/src/adapter/adapter_types.rs b/src/adapter/adapter_types.rs index 7c6a363..69e1e46 100644 --- a/src/adapter/adapter_types.rs +++ b/src/adapter/adapter_types.rs @@ -6,11 +6,7 @@ use derive_more::Display; use reqwest::RequestBuilder; use serde_json::Value; -use super::anthropic::MODELS as ANTHROPIC_MODELS; -use super::cohere::MODELS as COHERE_MODELS; -use super::gemini::MODELS as GEMINI_MODELS; use super::groq::MODELS as GROQ_MODELS; -use super::openai::MODELS as OPENAI_MODELS; #[derive(Debug, Clone, Copy, Display, Eq, PartialEq, Hash)] pub enum AdapterKind { @@ -27,18 +23,20 @@ pub enum AdapterKind { impl AdapterKind { /// Very simplistic mapper for now. pub fn from_model(model: &str) -> Result { - if OPENAI_MODELS.contains(&model) { - return Ok(AdapterKind::OpenAI); - } else if ANTHROPIC_MODELS.contains(&model) { - return Ok(AdapterKind::Anthropic); - } else if COHERE_MODELS.contains(&model) { - return Ok(AdapterKind::Cohere); - } else if GEMINI_MODELS.contains(&model) { - return Ok(AdapterKind::Gemini); + if model.starts_with("gpt") { + Ok(AdapterKind::OpenAI) + } else if model.starts_with("claude") { + Ok(AdapterKind::Anthropic) + } else if model.starts_with("command") { + Ok(AdapterKind::Cohere) + } else if model.starts_with("gemini") { + Ok(AdapterKind::Gemini) } else if GROQ_MODELS.contains(&model) { return Ok(AdapterKind::Groq); - } else { - return Ok(AdapterKind::Ollama); + } + // for now, fallback on Ollama + else { + Ok(Self::Ollama) } } } diff --git a/src/adapter/adapters/anthropic/adapter_impl.rs b/src/adapter/adapters/anthropic/adapter_impl.rs index a21e268..f5b8353 100644 --- a/src/adapter/adapters/anthropic/adapter_impl.rs +++ b/src/adapter/adapters/anthropic/adapter_impl.rs @@ -14,7 +14,7 @@ pub struct AnthropicAdapter; const MAX_TOKENS: u32 = 1024; const ANTRHOPIC_VERSION: &str = "2023-06-01"; -pub(crate) const MODELS: &[&str] = &["claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"]; +const MODELS: &[&str] = &["claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"]; const BASE_URL: &str = "https://api.anthropic.com/v1/"; diff --git a/src/adapter/adapters/cohere/adapter_impl.rs b/src/adapter/adapters/cohere/adapter_impl.rs index a5f83be..b1bfb79 100644 --- a/src/adapter/adapters/cohere/adapter_impl.rs +++ b/src/adapter/adapters/cohere/adapter_impl.rs @@ -13,7 +13,7 @@ pub struct CohereAdapter; const MAX_TOKENS: u32 = 1024; const BASE_URL: &str = "https://api.cohere.com/v1/"; -pub(crate) const MODELS: &[&str] = &[ +const MODELS: &[&str] = &[ "command-r-plus", "command-r", "command", diff --git a/src/adapter/adapters/gemini/adapter_impl.rs b/src/adapter/adapters/gemini/adapter_impl.rs index 28eeb62..81ce0ca 100644 --- a/src/adapter/adapters/gemini/adapter_impl.rs +++ b/src/adapter/adapters/gemini/adapter_impl.rs @@ -12,7 +12,7 @@ use std::sync::OnceLock; pub struct GeminiAdapter; const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/"; -pub(crate) const MODELS: &[&str] = &[ +const MODELS: &[&str] = &[ "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro", diff --git a/src/adapter/adapters/openai/adapter_impl.rs b/src/adapter/adapters/openai/adapter_impl.rs index 78331fb..1cefffe 100644 --- a/src/adapter/adapters/openai/adapter_impl.rs +++ b/src/adapter/adapters/openai/adapter_impl.rs @@ -13,7 +13,7 @@ use std::sync::OnceLock; pub struct OpenAIAdapter; const BASE_URL: &str = "https://api.openai.com/v1/"; -pub(crate) const MODELS: &[&str] = &["gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]; 
+const MODELS: &[&str] = &["gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]; impl Adapter for OpenAIAdapter { /// Note: For now returns the common ones (see above)