Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(chunkcode): use correct chunksizes #122

Merged
merged 6 commits into from
Jul 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 129 additions & 21 deletions swiftide/src/integrations/treesitter/splitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,26 +109,46 @@ impl CodeSplitter {
/// # Returns
///
/// * `Vec<String>` - A vector of code chunks as strings.
fn chunk_node(&self, node: Node, source: &str, mut last_end: usize) -> Vec<String> {
fn chunk_node(
&self,
node: Node,
source: &str,
mut last_end: usize,
current_chunk: Option<String>,
) -> Vec<String> {
let mut new_chunks: Vec<String> = Vec::new();
let mut current_chunk = String::new();
let mut current_chunk = current_chunk.unwrap_or(String::new());

for child in node.children(&mut node.walk()) {
if child.end_byte() - child.start_byte() > self.max_bytes() {
// Child is too big, recursively chunk the child
if !current_chunk.is_empty() && current_chunk.len() > self.min_bytes() {
new_chunks.push(current_chunk);
debug_assert!(
current_chunk.len() <= self.max_bytes(),
"Chunk too big: {} > {}",
current_chunk.len(),
self.max_bytes()
);
tinco marked this conversation as resolved.
Show resolved Hide resolved

// if the next child will make the chunk too big then there are two options:
// 1. if the next child is too big to fit in a whole chunk, then recursively chunk it one level down
// 2. if the next child is small enough to fit in a chunk, then add the current chunk to the list and start a new chunk

let next_child_size = child.end_byte() - last_end;
if current_chunk.len() + next_child_size >= self.max_bytes() {
if next_child_size > self.max_bytes() {
let mut sub_chunks =
self.chunk_node(child, source, last_end, Some(current_chunk));
current_chunk = sub_chunks.pop().unwrap_or(String::new());
new_chunks.extend(sub_chunks);
} else {
// NOTE: if the current chunk was smaller than then the min_bytes, then it is discarded here
if !current_chunk.is_empty() && current_chunk.len() > self.min_bytes() {
new_chunks.push(current_chunk);
tinco marked this conversation as resolved.
Show resolved Hide resolved
}
current_chunk = source[last_end..child.end_byte()].to_string();
}
current_chunk = String::new();
new_chunks.extend(self.chunk_node(child, source, last_end));
} else if current_chunk.len() + child.end_byte() - child.start_byte() > self.max_bytes()
{
// Child would make the current chunk too big, so start a new chunk
new_chunks.push(current_chunk.trim().to_string());
current_chunk = source[last_end..child.end_byte()].to_string();
} else {
current_chunk += &source[last_end..child.end_byte()];
}

last_end = child.end_byte();
}

Expand Down Expand Up @@ -157,7 +177,7 @@ impl CodeSplitter {
if root_node.has_error() {
anyhow::bail!("Root node has invalid syntax");
} else {
Ok(self.chunk_node(root_node, code, 0))
Ok(self.chunk_node(root_node, code, 0, None))
}
}

Expand Down Expand Up @@ -242,12 +262,15 @@ mod test {
"#};
let chunks = splitter.split(text).unwrap();

dbg!(&chunks);
assert!(chunks.iter().all(|chunk| chunk.len() <= 50));
assert!(chunks
.windows(2)
.all(|pair| pair.iter().map(|chunk| chunk.len()).sum::<usize>() >= 50));

assert_eq!(
chunks,
vec![
"fn main()",
"{\n println!(\"Hello, World!\");",
"fn main() {\n println!(\"Hello, World!\");",
"\n println!(\"Goodbye, World!\");\n}",
]
)
Expand Down Expand Up @@ -288,8 +311,7 @@ mod test {
assert_eq!(
chunks,
vec![
"fn main()",
"{\n println!(\"Hello, World!\");",
"fn main() {\n println!(\"Hello, World!\");",
"\n println!(\"Goodbye, World!\");\n}",
]
)
Expand All @@ -310,12 +332,98 @@ mod test {
}
"#};
let chunks = splitter.split(text).unwrap();

assert!(chunks.iter().all(|chunk| chunk.len() <= 50));
assert!(chunks
.windows(2)
.all(|pair| pair.iter().map(|chunk| chunk.len()).sum::<usize>() > 50));
assert!(chunks.iter().all(|chunk| chunk.len() >= 20));

assert_eq!(
chunks,
vec![
"{\n println!(\"Hello, World!\");",
"\n println!(\"Goodbye, World!\");\n}",
"fn main() {\n println!(\"Hello, World!\");",
"\n println!(\"Goodbye, World!\");\n}"
]
)
}

#[test]
fn test_on_self() {
// read the current file
let code = include_str!("splitter.rs");
// try chunking with varying ranges of bytes, give me ten with different min and max
let ranges = vec![
10..200,
50..100,
100..150,
150..200,
200..250,
250..300,
300..350,
350..400,
400..450,
450..500,
];

for range in ranges {
let min = range.start;
let max = range.end;
let splitter = CodeSplitter::builder()
.try_language("rust")
.unwrap()
.chunk_size(range)
.build()
.unwrap();

assert_eq!(splitter.min_bytes(), min);
assert_eq!(splitter.max_bytes(), max);

let chunks = splitter.split(code).unwrap();

assert!(chunks.iter().all(|chunk| chunk.len() <= max));
let chunk_pairs_that_are_smaller_than_max = chunks
.windows(2)
.filter(|pair| pair.iter().map(|chunk| chunk.len()).sum::<usize>() < max);
assert!(
chunk_pairs_that_are_smaller_than_max.clone().count() == 0,
"max: {}, {} + {}, {:?}",
max,
chunk_pairs_that_are_smaller_than_max
.clone()
.next()
.unwrap()[0]
.len(),
chunk_pairs_that_are_smaller_than_max
.clone()
.next()
.unwrap()[1]
.len(),
chunk_pairs_that_are_smaller_than_max
.collect::<Vec<_>>()
.first()
);
assert!(chunks.iter().all(|chunk| chunk.len() >= min));

assert!(
chunks.iter().all(|chunk| chunk.len() >= min),
"{:?}",
chunks
.iter()
.filter(|chunk| chunk.len() < min)
.collect::<Vec<_>>()
);
assert!(
chunks.iter().all(|chunk| chunk.len() <= max),
"max = {}, chunks = {:?}",
max,
chunks
.iter()
.filter(|chunk| chunk.len() > max)
.collect::<Vec<_>>()
);
}

// assert there are no nodes smaller than 10
}
}
Loading