feat: Code outlines in chunk metadata (#137)
Added a transformer that generates outlines for code files using tree-sitter, and another that compresses the outline so it is more relevant to each chunk. Also added a step to the metadata QA tool that uses the outline to improve contextual awareness during QA generation.
tinco authored Aug 6, 2024
1 parent 1ff2855 commit e728a7c
Showing 23 changed files with 748 additions and 131 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 4 additions & 0 deletions examples/Cargo.toml
@@ -29,6 +29,10 @@ doc-scrape-examples = true
name = "index-codebase"
path = "index_codebase.rs"

[[example]]
name = "index-codebase-reduced-context"
path = "index_codebase_reduced_context.rs"

[[example]]
doc-scrape-examples = true
name = "fastembed"
75 changes: 75 additions & 0 deletions examples/index_codebase_reduced_context.rs
@@ -0,0 +1,75 @@
//! # [Swiftide] Indexing the Swiftide codebase itself with a reduced context size
//!
//! This example demonstrates how to index the Swiftide codebase itself, optimizing for a smaller context size.
//! Note that for it to work correctly you need to have OPENAI_API_KEY set, and Redis and Qdrant
//! running.
//!
//! The pipeline will:
//! - Load all `.rs` files from the current directory
//! - Skip any nodes previously processed; hashes are based on the path and chunk (not the
//! metadata!)
//! - Generate an outline of the symbols defined in each file to be used as context in a later step and store it in the metadata
//! - Chunk the code into pieces of 10 to 2048 bytes
//! - For each chunk, generate a condensed subset of the symbols outline tailored for that specific chunk and store that in the metadata
//! - Run metadata QA on each chunk; generating questions and answers and adding metadata
//! - Embed the chunks in batches of 10; metadata is embedded by default
//! - Store the nodes in Qdrant
//!
//! Note that metadata is copied over to smaller chunks when chunking. When making LLM requests
//! with lots of small chunks, consider the rate limits of the API.
//!
//! [Swiftide]: https://github.com/bosun-ai/swiftide
//! [examples]: https://github.com/bosun-ai/swiftide/blob/master/examples
use swiftide::indexing;
use swiftide::indexing::loaders::FileLoader;
use swiftide::indexing::transformers::{ChunkCode, Embed, MetadataQACode};
use swiftide::integrations::{self, qdrant::Qdrant, redis::Redis};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();

let openai_client = integrations::openai::OpenAI::builder()
.default_embed_model("text-embedding-3-small")
.default_prompt_model("gpt-3.5-turbo")
.build()?;

let redis_url = std::env::var("REDIS_URL")
.as_deref()
.unwrap_or("redis://localhost:6379")
.to_owned();

let chunk_size = 2048;

indexing::Pipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"]))
.filter_cached(Redis::try_from_url(
redis_url,
"swiftide-examples-codebase-reduced-context",
)?)
.then(
indexing::transformers::OutlineCodeTreeSitter::try_for_language(
"rust",
Some(chunk_size),
)?,
)
.then(MetadataQACode::new(openai_client.clone()))
.then_chunk(ChunkCode::try_for_language_and_chunk_size(
"rust",
10..chunk_size,
)?)
.then(indexing::transformers::CompressCodeOutline::new(
openai_client.clone(),
))
.then_in_batch(10, Embed::new(openai_client.clone()))
.then_store_with(
Qdrant::builder()
.batch_size(50)
.vector_size(1536)
.collection_name("swiftide-examples-codebase-reduced-context")
.build()?,
)
.run()
.await?;
Ok(())
}
9 changes: 8 additions & 1 deletion swiftide-core/src/node.rs
@@ -48,6 +48,10 @@ pub struct Node {
pub metadata: Metadata,
/// Mode of embedding data Chunk and Metadata
pub embed_mode: EmbedMode,
/// Size, in bytes, of the input this node was originally derived from
pub original_size: usize,
/// Offset, in bytes, of the chunk relative to the start of the input this node was derived from
pub offset: usize,
}

impl Debug for Node {
@@ -80,8 +84,11 @@ impl Node {
///
/// The other fields are set to their default values.
pub fn new(chunk: impl Into<String>) -> Node {
let chunk = chunk.into();
let original_size = chunk.len();
Node {
chunk: chunk.into(),
chunk,
original_size,
..Default::default()
}
}
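A minimal sketch of the two new fields from the caller's side, grounded in the `Node::new` change above; it assumes `swiftide-core` is available as a dependency and that `Default` zero-initializes `offset`, as implied by the `..Default::default()` above.

use swiftide_core::indexing::Node;

fn main() {
    // `Node::new` now records the byte length of the original input.
    let node = Node::new("fn main() {}");
    assert_eq!(node.original_size, "fn main() {}".len());

    // Until a chunker splits the node, the chunk starts at the beginning of the input.
    assert_eq!(node.offset, 0);
}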
5 changes: 3 additions & 2 deletions swiftide-indexing/Cargo.toml
@@ -29,6 +29,7 @@ strum = { workspace = true }
strum_macros = { workspace = true }
indoc = { workspace = true }

regex = "1.10.5"
ignore = "0.4"
text-splitter = { version = "0.14", features = ["markdown"] }

@@ -42,8 +43,8 @@ test-case = { workspace = true }
[features]
# TODO: Should not depend on integrations; transformers that use them should be in integrations instead and re-exported from the root for convenience
tree-sitter = [
"swiftide-integrations?/tree-sitter",
"dep:swiftide-integrations",
"swiftide-integrations?/tree-sitter",
"dep:swiftide-integrations",
]

[lints]
4 changes: 4 additions & 0 deletions swiftide-indexing/src/loaders/file_loader.rs
@@ -60,9 +60,11 @@ impl FileLoader {
.map(|entry| {
tracing::debug!("Reading file: {:?}", entry);
let content = std::fs::read_to_string(&entry).unwrap();
let original_size = content.len();
Node {
path: entry,
chunk: content,
original_size,
..Default::default()
}
})
@@ -99,9 +101,11 @@ impl Loader for FileLoader {
tracing::debug!("Reading file: {:?}", entry);
let content =
std::fs::read_to_string(entry.path()).context("Failed to read file")?;
let original_size = content.len();
Ok(Node {
path: entry.path().into(),
chunk: content,
original_size,
..Default::default()
})
});
10 changes: 8 additions & 2 deletions swiftide-indexing/src/transformers/chunk_code.rs
@@ -90,11 +90,17 @@ impl ChunkerTransformer for ChunkCode {
let split_result = self.chunker.split(&node.chunk);

if let Ok(split) = split_result {
let mut offset = 0;

IndexingStream::iter(split.into_iter().map(move |chunk| {
Ok(Node {
let chunk_size = chunk.len();
let mut node = Node {
chunk,
..node.clone()
})
};
node.offset = offset;
offset += chunk_size;
Ok(node)
}))
} else {
// Send the error downstream
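A standalone sketch of the offset bookkeeping `ChunkCode` now performs, using plain strings instead of `Node`s (a hypothetical helper, not part of the commit): each chunk records where it starts, in bytes, relative to the original input.

/// Assigns byte offsets to consecutive chunks of a single input, mirroring the
/// accumulation in the chunker above: each chunk starts where the previous one ended.
fn assign_offsets(chunks: Vec<String>) -> Vec<(usize, String)> {
    let mut offset = 0;
    chunks
        .into_iter()
        .map(|chunk| {
            let start = offset;
            offset += chunk.len();
            (start, chunk)
        })
        .collect()
}

fn main() {
    let chunks = vec!["fn a() {}\n".to_string(), "fn b() {}\n".to_string()];
    // The second chunk begins right after the 10 bytes of the first one.
    assert_eq!(assign_offsets(chunks)[1].0, 10);
}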
171 changes: 171 additions & 0 deletions swiftide-indexing/src/transformers/compress_code_outline.rs
@@ -0,0 +1,171 @@
//! `CompressCodeOutline` is a transformer that reduces the size of the outline of the parent file of a chunk to make it more relevant to the chunk.
use derive_builder::Builder;
use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use swiftide_core::{indexing::Node, prompt::PromptTemplate, SimplePrompt, Transformer};

/// `CompressCodeOutline` rewrites the "Outline" metadata field of a chunk to
/// condense it and make it more relevant to the chunk in question. It is useful as a
/// step after chunking a file whose outline was generated with `OutlineCodeTreeSitter`.
#[derive(Debug, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct CompressCodeOutline {
#[builder(setter(custom))]
client: Arc<dyn SimplePrompt>,
#[builder(default = "default_prompt()")]
prompt_template: PromptTemplate,
#[builder(default)]
concurrency: Option<usize>,
}

fn extract_markdown_codeblock(text: String) -> String {
let re = regex::Regex::new(r"(?sm)```\w*\n(.*?)```").unwrap();
let captures = re.captures(text.as_str());
captures
.map(|c| c.get(1).unwrap().as_str().to_string())
.unwrap_or(text)
}

impl CompressCodeOutline {
pub fn builder() -> CompressCodeOutlineBuilder {
CompressCodeOutlineBuilder::default()
}

pub fn from_client(client: impl SimplePrompt + 'static) -> CompressCodeOutlineBuilder {
CompressCodeOutlineBuilder::default()
.client(client)
.to_owned()
}
/// Creates a new instance of `CompressCodeOutline`.
///
/// # Arguments
///
/// * `client` - An implementation of the `SimplePrompt` trait used to compress the outline.
///
/// # Returns
///
/// A new instance of `CompressCodeOutline` with the default prompt template and no concurrency limit.
pub fn new(client: impl SimplePrompt + 'static) -> Self {
Self {
client: Arc::new(client),
prompt_template: default_prompt(),
concurrency: None,
}
}

#[must_use]
pub fn with_concurrency(mut self, concurrency: usize) -> Self {
self.concurrency = Some(concurrency);
self
}
}

/// Returns the default prompt template for compressing a code outline.
///
/// The template includes placeholders for the outline and the code chunk.
///
/// # Returns
///
/// A `PromptTemplate` containing the default compression prompt.
fn default_prompt() -> PromptTemplate {
include_str!("prompts/compress_code_outline.prompt.md").into()
}

impl CompressCodeOutlineBuilder {
pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self {
self.client = Some(Arc::new(client));
self
}
}

#[async_trait]
impl Transformer for CompressCodeOutline {
/// Asynchronously transforms a `Node` by reducing the size of the outline to make it more relevant to the chunk.
///
/// This method uses the `SimplePrompt` client to compress the outline of the `Node` and updates the `Node` with the compressed outline.
///
/// # Arguments
///
/// * `node` - The `Node` to be transformed.
///
/// # Returns
///
/// A result containing the transformed `Node` or an error if the transformation fails.
///
/// # Errors
///
/// This function will return an error if the `SimplePrompt` client fails to generate a response.
#[tracing::instrument(skip_all, name = "transformers.compress_code_outline")]
async fn transform_node(&self, mut node: Node) -> Result<Node> {
let maybe_outline = node.metadata.get("Outline");

let Some(outline) = maybe_outline else {
return Ok(node);
};

let prompt = self
.prompt_template
.to_prompt()
.with_context_value("outline", outline.as_str())
.with_context_value("code", node.chunk.as_str());

let response = extract_markdown_codeblock(self.client.prompt(prompt).await?);

node.metadata.insert("Outline".to_string(), response);

Ok(node)
}

fn concurrency(&self) -> Option<usize> {
self.concurrency
}
}

#[cfg(test)]
mod test {
use swiftide_core::MockSimplePrompt;

use super::*;

#[test_log::test(tokio::test)]
async fn test_compress_code_template() {
let template = default_prompt();

let outline = "Relevant Outline";
let code = "Code using outline";

let prompt = template
.to_prompt()
.with_context_value("outline", outline)
.with_context_value("code", code);

insta::assert_snapshot!(prompt.render().await.unwrap());
}

#[tokio::test]
async fn test_compress_code_outline() {
let mut client = MockSimplePrompt::new();

client
.expect_prompt()
.returning(|_| Ok("RelevantOutline".to_string()));

let transformer = CompressCodeOutline::builder()
.client(client)
.build()
.unwrap();
let mut node = Node::new("Some text");
node.offset = 0;
node.original_size = 100;

node.metadata
.insert("Outline".to_string(), "Some outline".to_string());

let result = transformer.transform_node(node).await.unwrap();

assert_eq!(result.chunk, "Some text");
assert_eq!(result.metadata.get("Outline").unwrap(), "RelevantOutline");
}
}
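The fenced-block extraction above is private to the transformer; a standalone sketch of the same behaviour (same regex, using the `regex` crate added to Cargo.toml in this commit) shows how the LLM response is post-processed.

// Mirrors the private `extract_markdown_codeblock` helper: if the LLM wraps its
// answer in a fenced code block, keep only the block body; otherwise pass the
// text through unchanged.
fn extract_markdown_codeblock(text: String) -> String {
    let re = regex::Regex::new(r"(?sm)```\w*\n(.*?)```").unwrap();
    re.captures(text.as_str())
        .map(|c| c.get(1).unwrap().as_str().to_string())
        .unwrap_or(text)
}

fn main() {
    let fenced = "```rust\nstruct Outline;\n```".to_string();
    assert_eq!(extract_markdown_codeblock(fenced), "struct Outline;\n");

    let plain = "struct Outline;".to_string();
    assert_eq!(extract_markdown_codeblock(plain), "struct Outline;");
}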
18 changes: 17 additions & 1 deletion swiftide-indexing/src/transformers/metadata_qa_code.rs
@@ -93,12 +93,16 @@ impl Transformer for MetadataQACode {
/// This function will return an error if the `SimplePrompt` client fails to generate a response.
#[tracing::instrument(skip_all, name = "transformers.metadata_qa_code")]
async fn transform_node(&self, mut node: Node) -> Result<Node> {
let prompt = self
let mut prompt = self
.prompt_template
.to_prompt()
.with_node(&node)
.with_context_value("questions", self.num_questions);

if let Some(outline) = node.metadata.get("Outline") {
prompt = prompt.with_context_value("outline", outline.as_str());
}

let response = self.client.prompt(prompt).await?;

node.metadata.insert(NAME, response);
@@ -128,6 +132,18 @@ mod test {
insta::assert_snapshot!(prompt.render().await.unwrap());
}

#[tokio::test]
async fn test_template_with_outline() {
let template = default_prompt();

let prompt = template
.to_prompt()
.with_node(&Node::new("test"))
.with_context_value("questions", 5)
.with_context_value("outline", "Test outline");
insta::assert_snapshot!(prompt.render().await.unwrap());
}

#[tokio::test]
async fn test_metadata_qacode() {
let mut client = MockSimplePrompt::new();
Expand Down
(The remaining file diffs are not shown.)
