fix: Bad embed error propagation (#293)
- **fix(indexing): Limit logged chunk to max 100 chars**
- **fix: Embed transformers must correctly propagate errors**
timonv authored Sep 12, 2024
1 parent c74f1e5 commit 9464ca1
Showing 4 changed files with 49 additions and 5 deletions.
4 changes: 4 additions & 0 deletions swiftide-core/src/indexing_stream.rs
@@ -84,6 +84,10 @@ impl IndexingStream {
}
}

+ /// Creates an `IndexingStream` from an iterator of `Result<Node>`.
+ ///
+ /// WARN: Also works with Err items directly, which will result
+ /// in an _incorrect_ stream
pub fn iter<I>(iter: I) -> Self
where
I: IntoIterator<Item = Result<Node>> + Send + 'static,
21 changes: 20 additions & 1 deletion swiftide-indexing/src/transformers/embed.rs
@@ -89,7 +89,7 @@ impl BatchableTransformer for Embed {
// Embeddings vectors of every node stored in order of processed nodes.
let mut embeddings = match self.embed_model.embed(embeddables_data).await {
Ok(embeddngs) => VecDeque::from(embeddngs),
- Err(err) => return IndexingStream::iter(Err(err)),
+ Err(err) => return err.into(),
};

// Iterator of nodes with embeddings vectors map.
@@ -281,4 +281,23 @@ mod tests {
debug_assert_eq!(ingested_node, expected_node);
}
}

+ #[tokio::test]
+ async fn test_returns_error_properly_if_embed_fails() {
+ let test_nodes = vec![Node::new("chunk")];
+ let mut model_mock = MockEmbeddingModel::new();
+ model_mock
+ .expect_embed()
+ .times(1)
+ .returning(|_| Err(anyhow::anyhow!("error")));
+ let embed = Embed::new(model_mock);
+ let mut stream = embed.batch_transform(test_nodes).await;
+ let error = stream
+ .next()
+ .await
+ .expect("IngestionStream has same length as expected_nodes")
+ .expect_err("Is Err");
+
+ assert_eq!(error.to_string(), "error");
+ }
}
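The heart of the fix is replacing `IndexingStream::iter(Err(err))` with `err.into()`, so a failed embed call becomes a stream that yields exactly one `Err` item; the same pattern is applied to `SparseEmbed` below. A minimal sketch of that conversion, using simplified synchronous stand-ins for `Node` and `IndexingStream` rather than swiftide's real async types:

```rust
use anyhow::{anyhow, Result};

// Simplified stand-ins for illustration only: the real Node carries chunk
// text, metadata and embeddings, and the real IndexingStream is an async
// stream of `Result<Node>`, not a Vec.
#[allow(dead_code)]
struct Node(String);

struct IndexingStream {
    items: Vec<Result<Node>>,
}

impl IndexingStream {
    // Mirrors the documented `iter` constructor: it happily accepts Err items,
    // which is exactly what the new doc comment warns about.
    fn iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = Result<Node>>,
    {
        Self {
            items: iter.into_iter().collect(),
        }
    }
}

// The conversion that `return err.into()` relies on: a single error becomes
// a stream holding exactly one Err item, so the failure reaches the caller.
impl From<anyhow::Error> for IndexingStream {
    fn from(err: anyhow::Error) -> Self {
        Self {
            items: vec![Err(err)],
        }
    }
}

fn main() {
    let failed: IndexingStream = anyhow!("embedding backend unreachable").into();
    assert_eq!(failed.items.len(), 1);
    assert!(failed.items[0].is_err());

    // Still compiles, as the doc comment notes, but is easier to get wrong.
    let _mixed = IndexingStream::iter(vec![Err(anyhow!("boom"))]);
}
```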
21 changes: 20 additions & 1 deletion swiftide-indexing/src/transformers/sparse_embed.rs
@@ -89,7 +89,7 @@ impl BatchableTransformer for SparseEmbed {
// SparseEmbeddings vectors of every node stored in order of processed nodes.
let mut embeddings = match self.embed_model.sparse_embed(embeddables_data).await {
Ok(embeddngs) => VecDeque::from(embeddngs),
- Err(err) => return IndexingStream::iter(Err(err)),
+ Err(err) => return err.into(),
};

// Iterator of nodes with embeddings vectors map.
@@ -309,4 +309,23 @@ mod tests {
debug_assert_eq!(ingested_node, expected_node);
}
}

+ #[tokio::test]
+ async fn test_returns_error_properly_if_sparse_embed_fails() {
+ let test_nodes = vec![Node::new("chunk")];
+ let mut model_mock = MockSparseEmbeddingModel::new();
+ model_mock
+ .expect_sparse_embed()
+ .times(1)
+ .returning(|_| Err(anyhow::anyhow!("error")));
+ let embed = SparseEmbed::new(model_mock);
+ let mut stream = embed.batch_transform(test_nodes).await;
+ let error = stream
+ .next()
+ .await
+ .expect("IngestionStream has same length as expected_nodes")
+ .expect_err("Is Err");
+
+ assert_eq!(error.to_string(), "error");
+ }
}
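Downstream, the propagated error is what lets a pipeline stop on a failed embed instead of silently indexing nodes without embeddings. A generic sketch of a consumer that short-circuits on the first `Err` item (plain `futures-util`/`tokio`/`anyhow`, not swiftide's Pipeline API):

```rust
use anyhow::{anyhow, Result};
use futures_util::{stream, StreamExt};

#[tokio::main]
async fn main() -> Result<()> {
    // Stand-in stream: one good node followed by a propagated embed error.
    let mut nodes = stream::iter(vec![Ok("chunk"), Err(anyhow!("embed failed"))]);

    while let Some(item) = nodes.next().await {
        let node = item?; // bubbles the embedding error up to the caller
        println!("indexed {node}");
    }
    Ok(())
}
```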
8 changes: 5 additions & 3 deletions swiftide-integrations/src/openai/embed.rs
@@ -17,10 +17,11 @@ impl EmbeddingModel for OpenAI {

let request = CreateEmbeddingRequestArgs::default()
.model(model)
- .input(input)
+ .input(&input)
.build()?;
tracing::debug!(
- messages = serde_json::to_string_pretty(&request)?,
+ num_chunks = input.len(),
+ model = &model,
"[Embed] Request to openai"
);
let response = self
@@ -30,7 +31,8 @@
.await
.context("Request to OpenAI Failed")?;

tracing::debug!("[Embed] Response openai");
let num_embeddings = response.data.len();
tracing::debug!(num_embeddings = num_embeddings, "[Embed] Response openai");

// WARN: Naively assumes that the order is preserved. Might not always be the case.
Ok(response.data.into_iter().map(|d| d.embedding).collect())
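The OpenAI change swaps dumping the whole serialized request for structured fields: `num_chunks` and `model` on the request, `num_embeddings` on the response (and `.input(&input)` now borrows, presumably so `input.len()` can still be logged afterwards). To actually see those fields, the application needs a subscriber at debug level; a minimal sketch, assuming `tracing-subscriber` with its `env-filter` feature and assuming the events live under the `swiftide_integrations` target:

```rust
fn main() {
    // Enable debug-level output so the new structured fields are visible.
    tracing_subscriber::fmt()
        .with_env_filter("swiftide_integrations=debug")
        .init();

    // The kind of event the changed code emits (values here are made up).
    tracing::debug!(
        num_chunks = 3usize,
        model = "text-embedding-3-small",
        "[Embed] Request to openai"
    );
}
```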
