feat(integrations): Add ollama embeddings support #278

Merged 2 commits on Sep 8, 2024
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion swiftide-integrations/Cargo.toml
@@ -60,7 +60,7 @@ aws-sdk-bedrockruntime = { version = "1.37", features = [
], optional = true }
secrecy = { version = "0.8.0", optional = true }
reqwest = { version = "0.12.5", optional = true, default-features = false }
ollama-rs = { version = "0.2.0", optional = true }
ollama-rs = { version = "0.2.1", optional = true }
deadpool = { version = "0.12", optional = true, features = [
"managed",
"rt_tokio_1",
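The integration stays behind the crate's "ollama" feature flag (see the module docs below). As a rough sketch, a downstream consumer would enable it along these lines in their own Cargo.toml; the version is a placeholder, not taken from this PR:

```toml
[dependencies]
# Feature name as referenced in the module docs; version is illustrative only.
swiftide-integrations = { version = "x.y", features = ["ollama"] }
```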
33 changes: 33 additions & 0 deletions swiftide-integrations/src/ollama/embed.rs
@@ -0,0 +1,33 @@
use anyhow::{Context as _, Result};
use async_trait::async_trait;

use ollama_rs::generation::embeddings::request::GenerateEmbeddingsRequest;
use swiftide_core::{EmbeddingModel, Embeddings};

use super::Ollama;

#[async_trait]
impl EmbeddingModel for Ollama {
    async fn embed(&self, input: Vec<String>) -> Result<Embeddings> {
        let model = self
            .default_options
            .embed_model
            .as_ref()
            .context("Model not set")?;

        let request = GenerateEmbeddingsRequest::new(model.to_string(), input.into());
        tracing::debug!(
            messages = serde_json::to_string_pretty(&request)?,
            "[Embed] Request to ollama"
        );
        let response = self
            .client
            .generate_embeddings(request)
            .await
            .context("Request to Ollama Failed")?;

        tracing::debug!("[Embed] Response ollama");

        Ok(response.embeddings)
    }
}
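For orientation, here is a minimal usage sketch of the new `EmbeddingModel` implementation. It is not part of the PR: the crate path `swiftide_integrations::ollama::Ollama` is assumed, `tokio` and `anyhow` are presumed as the runtime and error crates, and a local Ollama server with the model already pulled is required:

```rust
use swiftide_core::EmbeddingModel;
// Assumed re-export path for the struct defined in mod.rs below.
use swiftide_integrations::ollama::Ollama;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Configure a default embedding model up front; without one,
    // `embed` bails out via the `context("Model not set")?` guard above.
    let ollama = Ollama::builder()
        .default_embed_model("mxbai-embed-large")
        .build()?;

    // One embedding vector is expected per input string.
    let embeddings = ollama
        .embed(vec!["Hello, world!".to_string()])
        .await?;
    println!("received {} embedding vector(s)", embeddings.len());
    Ok(())
}
```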
114 changes: 104 additions & 10 deletions swiftide-integrations/src/ollama/mod.rs
@@ -1,18 +1,18 @@
-//! This module provides integration with `Ollama`'s API, enabling the use of language models within the Swiftide project.
-//! It includes the `Ollama` struct for managing API clients and default options for prompt models.
+//! This module provides integration with `Ollama`'s API, enabling the use of language models and embeddings within the Swiftide project.
+//! It includes the `Ollama` struct for managing API clients and default options for embedding and prompt models.
//! The module is conditionally compiled based on the "ollama" feature flag.

use derive_builder::Builder;
use std::sync::Arc;

+mod embed;
mod simple_prompt;

-/// The `Ollama` struct encapsulates a `Ollama` client that implements [`swiftide::traits::SimplePrompt`]
+/// The `Ollama` struct encapsulates an `Ollama` client and default options for embedding and prompt models.
/// It uses the `Builder` pattern for flexible and customizable instantiation.
///
/// There is also a builder available.
///
-/// By default it will look for a `OLLAMA_API_KEY` environment variable. Note that a model
-/// always needs to be set, either with [`Ollama::with_default_prompt_model`] or via the builder.
+/// By default it will look for a `OLLAMA_API_KEY` environment variable. Note that either a prompt model or an embedding model
+/// always needs to be set, either with [`Ollama::with_default_prompt_model`] or [`Ollama::with_default_embed_model`] or via the builder.
/// You can find available models in the Ollama documentation.
///
/// Under the hood it uses [`async_openai`], with the Ollama openai mapping. This means
@@ -23,7 +23,7 @@ pub struct Ollama {
/// The `Ollama` client, wrapped in an `Arc` for thread-safe reference counting.
#[builder(default = "default_client()", setter(custom))]
client: Arc<ollama_rs::Ollama>,
-    /// Default options for prompt models.
+    /// Default options for the embedding and prompt models.
#[builder(default)]
default_options: Options,
}
@@ -38,10 +38,14 @@ impl Default for Ollama {
}

/// The `Options` struct holds configuration options for the `Ollama` client.
-/// It includes optional fields for specifying the prompt model.
+/// It includes optional fields for specifying the embedding and prompt models.
#[derive(Debug, Default, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct Options {
+    /// The default embedding model to use, if specified.
+    #[builder(default)]
+    pub embed_model: Option<String>,

/// The default prompt model to use, if specified.
#[builder(default)]
pub prompt_model: Option<String>,
@@ -64,6 +68,16 @@ impl Ollama {
pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
self.default_options = Options {
prompt_model: Some(model.into()),
+            embed_model: self.default_options.embed_model.clone(),
};
self
}

+    /// Sets a default embedding model to use when embedding
+    pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
+        self.default_options = Options {
+            prompt_model: self.default_options.prompt_model.clone(),
+            embed_model: Some(model.into()),
+        };
+        self
+    }
@@ -82,6 +96,25 @@ impl OllamaBuilder {
self
}

+    /// Sets the default embedding model for the `Ollama` instance.
+    ///
+    /// # Parameters
+    /// - `model`: The embedding model to set.
+    ///
+    /// # Returns
+    /// A mutable reference to the `OllamaBuilder`.
+    pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
+        if let Some(options) = self.default_options.as_mut() {
+            options.embed_model = Some(model.into());
+        } else {
+            self.default_options = Some(Options {
+                embed_model: Some(model.into()),
+                ..Default::default()
+            });
+        }
+        self
+    }

/// Sets the default prompt model for the `Ollama` instance.
///
/// # Parameters
@@ -95,6 +128,7 @@
} else {
self.default_options = Some(Options {
prompt_model: Some(model.into()),
+                ..Default::default()
});
}
self
@@ -122,7 +156,36 @@ mod test {
}

#[test]
-    fn test_building_via_default() {
+    fn test_default_embed_model() {
+        let ollama = Ollama::builder()
+            .default_embed_model("mxbai-embed-large")
+            .build()
+            .unwrap();
+        assert_eq!(
+            ollama.default_options.embed_model,
+            Some("mxbai-embed-large".to_string())
+        );
+    }

+    #[test]
+    fn test_default_models() {
+        let ollama = Ollama::builder()
+            .default_embed_model("mxbai-embed-large")
+            .default_prompt_model("llama3.1")
+            .build()
+            .unwrap();
+        assert_eq!(
+            ollama.default_options.embed_model,
+            Some("mxbai-embed-large".to_string())
+        );
+        assert_eq!(
+            ollama.default_options.prompt_model,
+            Some("llama3.1".to_string())
+        );
+    }

+    #[test]
+    fn test_building_via_default_prompt_model() {
let mut client = Ollama::default();

assert!(client.default_options.prompt_model.is_none());
@@ -133,4 +196,35 @@
Some("llama3.1".to_string())
);
}

+    #[test]
+    fn test_building_via_default_embed_model() {
+        let mut client = Ollama::default();

+        assert!(client.default_options.embed_model.is_none());

+        client.with_default_embed_model("mxbai-embed-large");
+        assert_eq!(
+            client.default_options.embed_model,
+            Some("mxbai-embed-large".to_string())
+        );
+    }

+    #[test]
+    fn test_building_via_default_models() {
+        let mut client = Ollama::default();

+        assert!(client.default_options.embed_model.is_none());

+        client.with_default_prompt_model("llama3.1");
+        client.with_default_embed_model("mxbai-embed-large");
+        assert_eq!(
+            client.default_options.prompt_model,
+            Some("llama3.1".to_string())
+        );
+        assert_eq!(
+            client.default_options.embed_model,
+            Some("mxbai-embed-large".to_string())
+        );
+    }
}