feat(integrations): add ollama support (#214)
tinco authored Aug 2, 2024
1 parent f7accde commit b2d31e5
Showing 8 changed files with 312 additions and 33 deletions.
15 changes: 15 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

13 changes: 9 additions & 4 deletions examples/Cargo.toml
@@ -7,10 +7,11 @@ edition = "2021"
[dependencies]
tokio = { version = "1.0", features = ["full"] }
swiftide = { path = "../swiftide/", features = [
"all",
"scraping",
"aws-bedrock",
"groq",
"all",
"scraping",
"aws-bedrock",
"groq",
"ollama",
] }
tracing-subscriber = "0.3"
serde_json = "1.0"
@@ -55,6 +56,10 @@ path = "store_multiple_vectors.rs"
name = "index-groq"
path = "index_groq.rs"

[[example]]
name = "index-ollama"
path = "index_ollama.rs"

[[example]]
name = "query-pipeline"
path = "query_pipeline.rs"
55 changes: 55 additions & 0 deletions examples/index_ollama.rs
@@ -0,0 +1,55 @@
//! # [Swiftide] Indexing with Ollama
//!
//! This example demonstrates how to index the Swiftide codebase itself.
//! Note that for it to work correctly you need to have Ollama running on its default local port.
//!
//! The pipeline will:
//! - Load the README from the project
//! - Chunk the markdown into pieces of 10 to 2048 bytes
//! - Run metadata QA on each chunk with Ollama, generating questions and answers and adding them as metadata
//! - Embed the chunks in batches of 10 (metadata is embedded by default)
//! - Store the nodes in memory storage
//!
//! [Swiftide]: https://github.com/bosun-ai/swiftide
//! [examples]: https://github.com/bosun-ai/swiftide/blob/master/examples
use swiftide::{
indexing,
indexing::loaders::FileLoader,
indexing::persist::MemoryStorage,
indexing::transformers::{ChunkMarkdown, Embed, MetadataQAText},
integrations,
};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();

let ollama_client = integrations::ollama::Ollama::default()
.with_default_prompt_model("llama3.1")
.to_owned();

let fastembed = integrations::fastembed::FastEmbed::try_default()?;
let memory_store = MemoryStorage::default();

indexing::Pipeline::from_loader(FileLoader::new("README.md"))
.then_chunk(ChunkMarkdown::from_chunk_range(10..2048))
.then(MetadataQAText::new(ollama_client.clone()))
.then_in_batch(10, Embed::new(fastembed))
.then_store_with(memory_store.clone())
.run()
.await?;

println!("Example results:");
println!(
"{}",
memory_store
.get_all_values()
.await
.into_iter()
.flat_map(|n| n.metadata.into_values().map(|v| v.to_string()))
.collect::<Vec<_>>()
.join("\n")
);
Ok(())
}
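To try this example locally, assuming the Ollama CLI is installed and serving on its default port, pulling the model first should be enough: `ollama pull llama3.1`, then `cargo run --example index-ollama` from the repository root.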
39 changes: 21 additions & 18 deletions swiftide-integrations/Cargo.toml
@@ -29,13 +29,13 @@ strum_macros = { workspace = true }
# Integrations
async-openai = { workspace = true, optional = true }
qdrant-client = { workspace = true, optional = true, default-features = false, features = [
"serde",
"serde",
] }
redis = { version = "0.26", features = [
"aio",
"tokio-comp",
"connection-manager",
"tokio-rustls-comp",
"aio",
"tokio-comp",
"connection-manager",
"tokio-rustls-comp",
], optional = true }
tree-sitter = { version = "0.22", optional = true }
tree-sitter-rust = { version = "0.21", optional = true }
@@ -47,16 +47,17 @@ fastembed = { version = "3.6", optional = true }
spider = { version = "1.98", optional = true }
htmd = { version = "0.1", optional = true }
aws-config = { version = "1.5", features = [
"behavior-version-latest",
"behavior-version-latest",
], optional = true }
aws-credential-types = { version = "1.2", features = [
"hardcoded-credentials",
"hardcoded-credentials",
], optional = true }
aws-sdk-bedrockruntime = { version = "1.37", features = [
"behavior-version-latest",
"behavior-version-latest",
], optional = true }
secrecy = { version = "0.8.0", optional = true }
reqwest = { version = "0.12.5", optional = true, default-features = false }
ollama-rs = { version = "0.2.0", optional = true }

[dev-dependencies]
swiftide-core = { path = "../swiftide-core", features = ["test-utils"] }
@@ -70,34 +71,36 @@ indoc = { workspace = true }

[features]
default = ["rustls"]
# Ensures rustls is used
rustls = ["reqwest/rustls-tls-native-roots"]
# Qdrant for storage
qdrant = ["dep:qdrant-client"]
# Redis for caching and storage
redis = ["dep:redis"]
# Tree-sitter for code operations and chunking
tree-sitter = [
"dep:tree-sitter",
"dep:tree-sitter-rust",
"dep:tree-sitter-python",
"dep:tree-sitter-ruby",
"dep:tree-sitter-typescript",
"dep:tree-sitter-javascript",
"dep:tree-sitter",
"dep:tree-sitter-rust",
"dep:tree-sitter-python",
"dep:tree-sitter-ruby",
"dep:tree-sitter-typescript",
"dep:tree-sitter-javascript",
]
# OpenAI for embedding and prompting
openai = ["dep:async-openai"]
# Groq prompting
groq = ["dep:async-openai", "dep:secrecy", "dep:reqwest"]
# Ollama prompting
ollama = ["dep:ollama-rs"]
# FastEmbed (by qdrant) for fast, local embeddings
fastembed = ["dep:fastembed"]
# Scraping via spider as loader and a html to markdown transformer
scraping = ["dep:spider", "dep:htmd"]
# AWS Bedrock for prompting
aws-bedrock = [
"dep:aws-config",
"dep:aws-credential-types",
"dep:aws-sdk-bedrockruntime",
"dep:aws-config",
"dep:aws-credential-types",
"dep:aws-sdk-bedrockruntime",
]

[lints]
2 changes: 2 additions & 0 deletions swiftide-integrations/src/lib.rs
@@ -6,6 +6,8 @@ pub mod aws_bedrock;
pub mod fastembed;
#[cfg(feature = "groq")]
pub mod groq;
#[cfg(feature = "ollama")]
pub mod ollama;
#[cfg(feature = "openai")]
pub mod openai;
#[cfg(feature = "qdrant")]
136 changes: 136 additions & 0 deletions swiftide-integrations/src/ollama/mod.rs
@@ -0,0 +1,136 @@
//! This module provides integration with `Ollama`'s API, enabling the use of language models within the Swiftide project.
//! It includes the `Ollama` struct for managing API clients and default options for prompt models.
//! The module is conditionally compiled based on the "ollama" feature flag.
use derive_builder::Builder;
use std::sync::Arc;

mod simple_prompt;

/// The `Ollama` struct encapsulates an `Ollama` client that implements [`swiftide::traits::SimplePrompt`].
///
/// There is also a builder available.
///
/// By default it connects to an Ollama instance running on the default local port. Note that a
/// model always needs to be set, either with [`Ollama::with_default_prompt_model`] or via the
/// builder. You can find available models in the Ollama documentation.
///
/// Under the hood it uses [`ollama_rs`] to talk to the Ollama API; see the Ollama documentation
/// for details on supported features.
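///
/// A minimal construction sketch (hedged: `llama3.1` is just an illustrative model
/// name, and an Ollama instance is assumed to be reachable locally):
///
/// ```no_run
/// use swiftide_integrations::ollama::Ollama;
///
/// // Build a client with a default prompt model set.
/// let ollama = Ollama::builder()
///     .default_prompt_model("llama3.1")
///     .build()
///     .unwrap();
/// ```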
#[derive(Debug, Builder, Clone)]
#[builder(setter(into, strip_option))]
pub struct Ollama {
/// The `Ollama` client, wrapped in an `Arc` for thread-safe reference counting.
#[builder(default = "default_client()", setter(custom))]
client: Arc<ollama_rs::Ollama>,
/// Default options for prompt models.
#[builder(default)]
default_options: Options,
}

impl Default for Ollama {
fn default() -> Self {
Self {
client: default_client(),
default_options: Options::default(),
}
}
}

/// The `Options` struct holds configuration options for the `Ollama` client.
/// It includes optional fields for specifying the prompt model.
#[derive(Debug, Default, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct Options {
/// The default prompt model to use, if specified.
#[builder(default)]
pub prompt_model: Option<String>,
}

impl Options {
/// Creates a new `OptionsBuilder` for constructing `Options` instances.
pub fn builder() -> OptionsBuilder {
OptionsBuilder::default()
}
}

impl Ollama {
/// Creates a new `OllamaBuilder` for constructing `Ollama` instances.
pub fn builder() -> OllamaBuilder {
OllamaBuilder::default()
}

/// Sets a default prompt model to use when prompting
pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
self.default_options = Options {
prompt_model: Some(model.into()),
};
self
}
}

impl OllamaBuilder {
/// Sets the `Ollama` client for the `Ollama` instance.
///
/// # Parameters
/// - `client`: The `Ollama` client to set.
///
/// # Returns
/// A mutable reference to the `OllamaBuilder`.
pub fn client(&mut self, client: ollama_rs::Ollama) -> &mut Self {
self.client = Some(Arc::new(client));
self
}

/// Sets the default prompt model for the `Ollama` instance.
///
/// # Parameters
/// - `model`: The prompt model to set.
///
/// # Returns
/// A mutable reference to the `OllamaBuilder`.
pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
if let Some(options) = self.default_options.as_mut() {
options.prompt_model = Some(model.into());
} else {
self.default_options = Some(Options {
prompt_model: Some(model.into()),
});
}
self
}
}
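
// A hedged usage sketch: the custom `client` setter points the integration at a
// non-default Ollama host. This assumes the `ollama_rs::Ollama::new(host, port)`
// constructor from ollama-rs 0.2; the host and model below are illustrative.
//
// let remote = Ollama::builder()
//     .client(ollama_rs::Ollama::new("http://192.168.1.10".to_string(), 11434))
//     .default_prompt_model("llama3.1")
//     .build()?;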

fn default_client() -> Arc<ollama_rs::Ollama> {
ollama_rs::Ollama::default().into()
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_default_prompt_model() {
let ollama = Ollama::builder()
.default_prompt_model("llama3.1")
.build()
.unwrap();
assert_eq!(
ollama.default_options.prompt_model,
Some("llama3.1".to_string())
);
}

#[test]
fn test_building_via_default() {
let mut client = Ollama::default();

assert!(client.default_options.prompt_model.is_none());

client.with_default_prompt_model("llama3.1");
assert_eq!(
client.default_options.prompt_model,
Some("llama3.1".to_string())
);
}
}
