From b2d31e555cb8da525513490e7603df1f6b2bfa5b Mon Sep 17 00:00:00 2001
From: Tinco Andringa
Date: Fri, 2 Aug 2024 19:33:54 +0200
Subject: [PATCH] feat(integrations): add ollama support (#214)

---
 Cargo.lock                              |  15 ++
 examples/Cargo.toml                     |  13 +-
 examples/index_ollama.rs                |  55 +++++++
 swiftide-integrations/Cargo.toml        |  39 ++---
 swiftide-integrations/src/lib.rs        |   2 +
 swiftide-integrations/src/ollama/mod.rs | 136 ++++++++++++++++++
 .../src/ollama/simple_prompt.rs         |  59 ++++++++
 swiftide/Cargo.toml                     |  25 ++--
 8 files changed, 311 insertions(+), 33 deletions(-)
 create mode 100644 examples/index_ollama.rs
 create mode 100644 swiftide-integrations/src/ollama/mod.rs
 create mode 100644 swiftide-integrations/src/ollama/simple_prompt.rs

diff --git a/Cargo.lock b/Cargo.lock
index 04bfe759..e71e826c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2766,6 +2766,20 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "ollama-rs"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "255252ec57e13d2d6ae074c7b7cd8c004d17dafb1e03f954ba2fd5cc226f8f49"
+dependencies = [
+ "async-trait",
+ "log",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "url",
+]
+
 [[package]]
 name = "once_cell"
 version = "1.19.0"
@@ -4298,6 +4312,7 @@ dependencies = [
  "indoc",
  "itertools 0.13.0",
  "mockall",
+ "ollama-rs",
 "qdrant-client",
 "redis",
 "reqwest",
diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index 26020fe5..2a5ae7f1 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -7,10 +7,11 @@ edition = "2021"
 [dependencies]
 tokio = { version = "1.0", features = ["full"] }
 swiftide = { path = "../swiftide/", features = [
-    "all",
-    "scraping",
-    "aws-bedrock",
-    "groq",
+  "all",
+  "scraping",
+  "aws-bedrock",
+  "groq",
+  "ollama",
 ] }
 tracing-subscriber = "0.3"
 serde_json = "1.0"
@@ -55,6 +56,10 @@ path = "store_multiple_vectors.rs"
 name = "index-groq"
 path = "index_groq.rs"
 
+[[example]]
+name = "index-ollama"
+path = "index_ollama.rs"
+
 [[example]]
 name = "query-pipeline"
 path = "query_pipeline.rs"
diff --git a/examples/index_ollama.rs b/examples/index_ollama.rs
new file mode 100644
index 00000000..62bf3d8d
--- /dev/null
+++ b/examples/index_ollama.rs
@@ -0,0 +1,55 @@
+//! # [Swiftide] Indexing with Ollama
+//!
+//! This example demonstrates how to index the Swiftide README with a local Ollama model.
+//! Note that for it to work correctly you need to have Ollama running on the default local port.
+//!
+//! The pipeline will:
+//! - Load the README from the project
+//! - Chunk the markdown into pieces of 10 to 2048 bytes
+//! - Run metadata QA on each chunk with Ollama, generating questions and answers and adding them as metadata
+//! - Embed the chunks in batches of 10; metadata is embedded by default
+//! - Store the nodes in memory storage
+//!
+//! [Swiftide]: https://github.com/bosun-ai/swiftide
+//! [examples]: https://github.com/bosun-ai/swiftide/blob/master/examples
+
+use swiftide::{
+    indexing,
+    indexing::loaders::FileLoader,
+    indexing::persist::MemoryStorage,
+    indexing::transformers::{ChunkMarkdown, Embed, MetadataQAText},
+    integrations,
+};
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    tracing_subscriber::fmt::init();
+
+    let ollama_client = integrations::ollama::Ollama::default()
+        .with_default_prompt_model("llama3.1")
+        .to_owned();
+
+    let fastembed = integrations::fastembed::FastEmbed::try_default()?;
+    let memory_store = MemoryStorage::default();
+
+    indexing::Pipeline::from_loader(FileLoader::new("README.md"))
+        .then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
+        .then(MetadataQAText::new(ollama_client.clone()))
+        .then_in_batch(10, Embed::new(fastembed))
+        .then_store_with(memory_store.clone())
+        .run()
+        .await?;
+
+    println!("Example results:");
+    println!(
+        "{}",
+        memory_store
+            .get_all_values()
+            .await
+            .into_iter()
+            .flat_map(|n| n.metadata.into_values().map(|v| v.to_string()))
+            .collect::<Vec<_>>()
+            .join("\n")
+    );
+    Ok(())
+}
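Before running the example, it can be worth sanity-checking that the local Ollama daemon answers at all. A minimal sketch of such a check; it assumes the `SimplePrompt` trait is exposed as `swiftide::traits::SimplePrompt` (the path the module docs below reference) and that `Prompt` converts from `&str`:

    use swiftide::{integrations, traits::SimplePrompt};

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Same client setup as in the indexing example above.
        let ollama = integrations::ollama::Ollama::default()
            .with_default_prompt_model("llama3.1")
            .to_owned();

        // One round-trip against the local daemon; this fails fast if
        // Ollama is not running on its default port.
        let reply = ollama
            .prompt("Reply with exactly one word: pong".into())
            .await?;
        println!("{reply}");
        Ok(())
    }
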
diff --git a/swiftide-integrations/Cargo.toml b/swiftide-integrations/Cargo.toml
index 42cedbce..a380f15d 100644
--- a/swiftide-integrations/Cargo.toml
+++ b/swiftide-integrations/Cargo.toml
@@ -29,13 +29,13 @@ strum_macros = { workspace = true }
 # Integrations
 async-openai = { workspace = true, optional = true }
 qdrant-client = { workspace = true, optional = true, default-features = false, features = [
-    "serde",
+  "serde",
 ] }
 redis = { version = "0.26", features = [
-    "aio",
-    "tokio-comp",
-    "connection-manager",
-    "tokio-rustls-comp",
+  "aio",
+  "tokio-comp",
+  "connection-manager",
+  "tokio-rustls-comp",
 ], optional = true }
 tree-sitter = { version = "0.22", optional = true }
 tree-sitter-rust = { version = "0.21", optional = true }
@@ -47,16 +47,17 @@ fastembed = { version = "3.6", optional = true }
 spider = { version = "1.98", optional = true }
 htmd = { version = "0.1", optional = true }
 aws-config = { version = "1.5", features = [
-    "behavior-version-latest",
+  "behavior-version-latest",
 ], optional = true }
 aws-credential-types = { version = "1.2", features = [
-    "hardcoded-credentials",
+  "hardcoded-credentials",
 ], optional = true }
 aws-sdk-bedrockruntime = { version = "1.37", features = [
-    "behavior-version-latest",
+  "behavior-version-latest",
 ], optional = true }
 secrecy = { version = "0.8.0", optional = true }
 reqwest = { version = "0.12.5", optional = true, default-features = false }
+ollama-rs = { version = "0.2.0", optional = true }
 
 [dev-dependencies]
 swiftide-core = { path = "../swiftide-core", features = ["test-utils"] }
@@ -70,7 +71,7 @@ indoc = { workspace = true }
 
 [features]
 default = ["rustls"]
-# Ensures rustls is used 
+# Ensures rustls is used
 rustls = ["reqwest/rustls-tls-native-roots"]
 # Qdrant for storage
 qdrant = ["dep:qdrant-client"]
@@ -78,26 +79,28 @@ qdrant = ["dep:qdrant-client"]
 redis = ["dep:redis"]
 # Tree-sitter for code operations and chunking
 tree-sitter = [
-    "dep:tree-sitter",
-    "dep:tree-sitter-rust",
-    "dep:tree-sitter-python",
-    "dep:tree-sitter-ruby",
-    "dep:tree-sitter-typescript",
-    "dep:tree-sitter-javascript",
+  "dep:tree-sitter",
+  "dep:tree-sitter-rust",
+  "dep:tree-sitter-python",
+  "dep:tree-sitter-ruby",
+  "dep:tree-sitter-typescript",
+  "dep:tree-sitter-javascript",
 ]
 # OpenAI for embedding and prompting
 openai = ["dep:async-openai"]
 # Groq prompting
 groq = ["dep:async-openai", "dep:secrecy", "dep:reqwest"]
+# Ollama prompting
+ollama = ["dep:ollama-rs"]
 # FastEmbed (by qdrant) for fast, local embeddings
 fastembed = ["dep:fastembed"]
 # Scraping via spider as loader and a html to markdown transformer
 scraping = ["dep:spider", "dep:htmd"]
 # AWS Bedrock for prompting
 aws-bedrock = [
-    "dep:aws-config",
-    "dep:aws-credential-types",
-    "dep:aws-sdk-bedrockruntime",
+  "dep:aws-config",
+  "dep:aws-credential-types",
+  "dep:aws-sdk-bedrockruntime",
 ]
 
 [lints]
diff --git a/swiftide-integrations/src/lib.rs b/swiftide-integrations/src/lib.rs
index 374dad80..dbb1f4a3 100644
--- a/swiftide-integrations/src/lib.rs
+++ b/swiftide-integrations/src/lib.rs
@@ -6,6 +6,8 @@ pub mod aws_bedrock;
 pub mod fastembed;
 #[cfg(feature = "groq")]
 pub mod groq;
+#[cfg(feature = "ollama")]
+pub mod ollama;
 #[cfg(feature = "openai")]
 pub mod openai;
 #[cfg(feature = "qdrant")]
diff --git a/swiftide-integrations/src/ollama/mod.rs b/swiftide-integrations/src/ollama/mod.rs
new file mode 100644
index 00000000..98847e1a
--- /dev/null
+++ b/swiftide-integrations/src/ollama/mod.rs
@@ -0,0 +1,136 @@
+//! This module provides integration with `Ollama`'s API, enabling the use of language models within the Swiftide project.
+//! It includes the `Ollama` struct for managing API clients and default options for prompt models.
+//! The module is conditionally compiled based on the "ollama" feature flag.
+
+use derive_builder::Builder;
+use std::sync::Arc;
+
+mod simple_prompt;
+
+/// The `Ollama` struct encapsulates an `Ollama` client that implements [`swiftide::traits::SimplePrompt`]
+///
+/// There is also a builder available.
+///
+/// By default it connects to a local Ollama instance on the default port. Note that a model
+/// always needs to be set, either with [`Ollama::with_default_prompt_model`] or via the builder.
+/// You can find available models in the Ollama documentation.
+///
+/// Under the hood it uses [`ollama_rs`] to talk to the Ollama API. See the Ollama
+/// documentation for details on the available models and endpoints.
+#[derive(Debug, Builder, Clone)]
+#[builder(setter(into, strip_option))]
+pub struct Ollama {
+    /// The `Ollama` client, wrapped in an `Arc` for thread-safe reference counting.
+    #[builder(default = "default_client()", setter(custom))]
+    client: Arc<ollama_rs::Ollama>,
+    /// Default options for prompt models.
+    #[builder(default)]
+    default_options: Options,
+}
+
+impl Default for Ollama {
+    fn default() -> Self {
+        Self {
+            client: default_client(),
+            default_options: Options::default(),
+        }
+    }
+}
+
+/// The `Options` struct holds configuration options for the `Ollama` client.
+/// It includes optional fields for specifying the prompt model.
+#[derive(Debug, Default, Clone, Builder)]
+#[builder(setter(into, strip_option))]
+pub struct Options {
+    /// The default prompt model to use, if specified.
+    #[builder(default)]
+    pub prompt_model: Option<String>,
+}
+
+impl Options {
+    /// Creates a new `OptionsBuilder` for constructing `Options` instances.
+    pub fn builder() -> OptionsBuilder {
+        OptionsBuilder::default()
+    }
+}
+
+impl Ollama {
+    /// Creates a new `OllamaBuilder` for constructing `Ollama` instances.
+    pub fn builder() -> OllamaBuilder {
+        OllamaBuilder::default()
+    }
+
+    /// Sets a default prompt model to use when prompting
+    pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
+        self.default_options = Options {
+            prompt_model: Some(model.into()),
+        };
+        self
+    }
+}
+
+impl OllamaBuilder {
+    /// Sets the `ollama_rs::Ollama` client for the `Ollama` instance.
+    ///
+    /// # Parameters
+    /// - `client`: The `Ollama` client to set.
+    ///
+    /// # Returns
+    /// A mutable reference to the `OllamaBuilder`.
+    pub fn client(&mut self, client: ollama_rs::Ollama) -> &mut Self {
+        self.client = Some(Arc::new(client));
+        self
+    }
+
+    /// Sets the default prompt model for the `Ollama` instance.
+    ///
+    /// # Parameters
+    /// - `model`: The prompt model to set.
+    ///
+    /// # Returns
+    /// A mutable reference to the `OllamaBuilder`.
+    pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
+        if let Some(options) = self.default_options.as_mut() {
+            options.prompt_model = Some(model.into());
+        } else {
+            self.default_options = Some(Options {
+                prompt_model: Some(model.into()),
+            });
+        }
+        self
+    }
+}
+
+fn default_client() -> Arc<ollama_rs::Ollama> {
+    ollama_rs::Ollama::default().into()
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_default_prompt_model() {
+        let ollama = Ollama::builder()
+            .default_prompt_model("llama3.1")
+            .build()
+            .unwrap();
+        assert_eq!(
+            ollama.default_options.prompt_model,
+            Some("llama3.1".to_string())
+        );
+    }
+
+    #[test]
+    fn test_building_via_default() {
+        let mut client = Ollama::default();
+
+        assert!(client.default_options.prompt_model.is_none());
+
+        client.with_default_prompt_model("llama3.1");
+        assert_eq!(
+            client.default_options.prompt_model,
+            Some("llama3.1".to_string())
+        );
+    }
+}
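Because `client` is a custom setter, a preconfigured `ollama_rs` client can be swapped in, which is the hook for talking to a daemon that is not on localhost. A minimal sketch; it assumes the caller depends on `ollama_rs` directly, the host and port are illustrative placeholders, and the exact `ollama_rs::Ollama::new` parameter types should be double-checked against the ollama-rs 0.2 docs:

    use swiftide_integrations::ollama::Ollama;

    fn remote_ollama() -> Result<Ollama, Box<dyn std::error::Error>> {
        // An ollama-rs client pointed at a remote daemon instead of the
        // default localhost instance (placeholder endpoint, not real).
        let client = ollama_rs::Ollama::new("http://ollama.internal".to_string(), 11434);

        // The builder stores the client behind an Arc; a prompt model
        // must still be set before prompting.
        Ok(Ollama::builder()
            .client(client)
            .default_prompt_model("llama3.1")
            .build()?)
    }
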
diff --git a/swiftide-integrations/src/ollama/simple_prompt.rs b/swiftide-integrations/src/ollama/simple_prompt.rs
new file mode 100644
index 00000000..1eeb2523
--- /dev/null
+++ b/swiftide-integrations/src/ollama/simple_prompt.rs
@@ -0,0 +1,59 @@
+//! This module provides an implementation of the `SimplePrompt` trait for the `Ollama` struct.
+//! It defines an asynchronous function to interact with the `Ollama` API, allowing prompt processing
+//! and generating responses as part of the Swiftide system.
+use async_trait::async_trait;
+use swiftide_core::{prompt::Prompt, SimplePrompt};
+
+use super::Ollama;
+use anyhow::{Context as _, Result};
+
+/// The `SimplePrompt` trait defines a method for sending a prompt to an AI model and receiving a response.
+#[async_trait]
+impl SimplePrompt for Ollama {
+    /// Sends a prompt to the Ollama API and returns the response content.
+    ///
+    /// # Parameters
+    /// - `prompt`: The `Prompt` to be rendered and sent to the Ollama API.
+    ///
+    /// # Returns
+    /// - `Result<String>`: On success, returns the content of the response as a `String`.
+    ///   On failure, returns an error wrapped in a `Result`.
+    ///
+    /// # Errors
+    /// - Returns an error if the model is not set in the default options.
+    /// - Returns an error if the request to the Ollama API fails.
+    /// - Returns an error if the response does not contain the expected content.
+    #[tracing::instrument(skip_all, err)]
+    async fn prompt(&self, prompt: Prompt) -> Result<String> {
+        // Retrieve the model from the default options, returning an error if not set.
+        let model = self
+            .default_options
+            .prompt_model
+            .as_ref()
+            .context("Model not set")?;
+
+        // Build the request to be sent to the Ollama API.
+        let request = ollama_rs::generation::completion::request::GenerationRequest::new(
+            model.to_string(),
+            prompt.render().await?,
+        );
+
+        // Log the request for debugging purposes.
+        tracing::debug!(
+            messages = serde_json::to_string_pretty(&request)?,
+            "[SimplePrompt] Request to ollama"
+        );
+
+        // Send the request to the Ollama API and await the response.
+        let response = self.client.generate(request).await?;
+
+        // Log the response for debugging purposes.
+        tracing::debug!(
+            response = serde_json::to_string_pretty(&response.response)?,
+            "[SimplePrompt] Response from ollama"
+        );
+
+        // Return the content of the response.
+        Ok(response.response)
+    }
+}
diff --git a/swiftide/Cargo.toml b/swiftide/Cargo.toml
index 3de996ca..879f0b2a 100644
--- a/swiftide/Cargo.toml
+++ b/swiftide/Cargo.toml
@@ -22,14 +22,15 @@ swiftide-query = { path = "../swiftide-query", version = "0.7" }
 [features]
 default = []
 all = [
-    "qdrant",
-    "redis",
-    "tree-sitter",
-    "openai",
-    "fastembed",
-    "scraping",
-    "aws-bedrock",
-    "groq",
+  "qdrant",
+  "redis",
+  "tree-sitter",
+  "openai",
+  "fastembed",
+  "scraping",
+  "aws-bedrock",
+  "groq",
+  "ollama",
 ]
 # Qdrant for storage
 qdrant = ["swiftide-integrations/qdrant"]
@@ -37,13 +38,15 @@ qdrant = ["swiftide-integrations/qdrant"]
 redis = ["swiftide-integrations/redis"]
 # Tree-sitter for code operations and chunking
 tree-sitter = [
-    "swiftide-integrations/tree-sitter",
-    "swiftide-indexing/tree-sitter",
+  "swiftide-integrations/tree-sitter",
+  "swiftide-indexing/tree-sitter",
 ]
 # OpenAI for embedding and prompting
 openai = ["swiftide-integrations/openai"]
 # Groq prompting
 groq = ["swiftide-integrations/groq"]
+# Ollama prompting
+ollama = ["swiftide-integrations/ollama"]
 # FastEmbed (by qdrant) for fast, local embeddings
 fastembed = ["swiftide-integrations/fastembed"]
 # Scraping via spider as loader and a html to markdown transformer
@@ -60,7 +63,7 @@ swiftide-test-utils = { path = "../swiftide-test-utils" }
 
 async-openai = { workspace = true }
 qdrant-client = { workspace = true, default-features = false, features = [
-    "serde",
+  "serde",
 ] }
 anyhow = { workspace = true }
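With the feature threaded through the facade crate, downstream users opt in via `features = ["ollama"]` on their `swiftide` dependency. A rough sketch of consumer-side usage; the module and helper names are hypothetical, and the `cfg` gate assumes the consumer forwards its own `ollama` feature to `swiftide/ollama`:

    // Compiled only when the consumer's own "ollama" feature
    // (forwarding to swiftide/ollama) is enabled.
    #[cfg(feature = "ollama")]
    mod ollama_prompting {
        use swiftide::integrations::ollama::Ollama;
        use swiftide::traits::SimplePrompt;

        /// One prompt round-trip with a default model; errors cover both
        /// the "Model not set" guard and request failures.
        pub async fn ask(question: &str) -> anyhow::Result<String> {
            let ollama = Ollama::default()
                .with_default_prompt_model("llama3.1")
                .to_owned();
            // Assumes `Prompt` converts from `&str`, as in the sanity
            // check earlier in this patch description.
            ollama.prompt(question.into()).await
        }
    }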