Skip to content

Commit

Permalink
[ENH] Integrate e2e tracing for query vectors endpoint (#1991)
Browse files Browse the repository at this point in the history
## Description of changes

This PR integrates end-to-end tracing for the Query Service
(specifically the `query_vectors()` RPC). The main highlights are:
- Span inheritance by component - whenever a component is started the
executor executes the `run()` function inside a child span with the
current span as the parent. This works even across threads now for e.g.
if the component was spawned in `inherited` mode but it invoked the
receive msg handler on a different thread, etc.
- Span propagation across components. For e.g. when the HNSW Query
orchestrator submits the Pull logs request or brute force KNN request to
the worker (via the dispatcher), the worker invokes the task handler
inside a child span with parent span as the HNSW orchestrator. This is
implemented by adding a new field to the `Task` struct - the span id of
the parent and the worker then creates a child span with this id as the
parent.

## Test plan

- [+] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

Check
https://gist.github.com/sanketkedia/96f3b00a2037a0fb2e99c16c3a379e69 for
the exact logs for `query_vectors()` rpc

## Documentation Changes
None required

---------

Co-authored-by: skedia <[email protected]>
  • Loading branch information
sanketkedia and skedia authored Apr 23, 2024
1 parent 26b8ac6 commit adf011c
Show file tree
Hide file tree
Showing 17 changed files with 124 additions and 41 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions rust/worker/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ aws-config = { version = "1.1.2", features = ["behavior-version-latest"] }
arrow = "50.0.0"
roaring = "0.10.3"
tantivy = "0.21.1"
tracing = "0.1"
tracing-subscriber = "0.3"

[dev-dependencies]
proptest = "1.4.0"
Expand Down
3 changes: 3 additions & 0 deletions rust/worker/src/bin/query_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,8 @@ use worker::query_service_entrypoint;

#[tokio::main]
async fn main() {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.init();
query_service_entrypoint().await;
}
15 changes: 12 additions & 3 deletions rust/worker/src/execution/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
};
use async_trait::async_trait;
use std::fmt::Debug;
use tracing::Span;

/// The dispatcher is responsible for distributing tasks to worker threads.
/// It is a component that receives tasks and distributes them to worker threads.
Expand Down Expand Up @@ -97,7 +98,11 @@ impl Dispatcher {
// If a worker is waiting for a task, send it to the worker in FIFO order
// Otherwise, add it to the task queue
match self.waiters.pop() {
Some(channel) => match channel.reply_to.send(task).await {
Some(channel) => match channel
.reply_to
.send(task, Some(Span::current().clone()))
.await
{
Ok(_) => {}
Err(e) => {
println!("Error sending task to worker: {:?}", e);
Expand All @@ -116,7 +121,11 @@ impl Dispatcher {
/// when one is available
async fn handle_work_request(&mut self, request: TaskRequestMessage) {
match self.task_queue.pop() {
Some(task) => match request.reply_to.send(task).await {
Some(task) => match request
.reply_to
.send(task, Some(Span::current().clone()))
.await
{
Ok(_) => {}
Err(e) => {
println!("Error sending task to worker: {:?}", e);
Expand Down Expand Up @@ -265,7 +274,7 @@ mod tests {
impl Handler<()> for MockDispatchUser {
async fn handle(&mut self, _message: (), ctx: &ComponentContext<MockDispatchUser>) {
let task = wrap(Box::new(MockOperator {}), 42.0, ctx.sender.as_receiver());
let res = self.dispatcher.send(task).await;
let res = self.dispatcher.send(task, None).await;
}
}

Expand Down
2 changes: 1 addition & 1 deletion rust/worker/src/execution/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ where
{
async fn run(&self) {
let output = self.operator.run(&self.input).await;
let res = self.reply_channel.send(output).await;
let res = self.reply_channel.send(output, None).await;
// TODO: if this errors, it means the caller was dropped
}
}
Expand Down
7 changes: 7 additions & 0 deletions rust/worker/src/execution/operators/brute_force_knn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{distance::DistanceFunction, execution::operator::Operator};
use async_trait::async_trait;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use tracing::{debug, trace};

/// The brute force k-nearest neighbors operator is responsible for computing the k-nearest neighbors
/// of a given query vector against a set of vectors using brute force calculation.
Expand Down Expand Up @@ -114,6 +115,12 @@ impl Operator<BruteForceKnnOperatorInput, BruteForceKnnOperatorOutput> for Brute
}
let mut data_chunk = data_chunk.clone();
data_chunk.set_visibility(visibility);
trace!(
"Brute force Knn result. data: {:?}, indices: {:?}, distances: {:?}",
data_chunk,
sorted_indices,
sorted_distances
);
Ok(BruteForceKnnOperatorOutput {
data: data_chunk,
indices: sorted_indices,
Expand Down
4 changes: 4 additions & 0 deletions rust/worker/src/execution/operators/pull_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use crate::execution::operator::Operator;
use crate::log::log::Log;
use crate::log::log::PullLogsError;
use async_trait::async_trait;
use tracing::debug;
use tracing::trace;
use uuid::Uuid;

/// The pull logs operator is responsible for reading logs from the log service.
Expand Down Expand Up @@ -130,8 +132,10 @@ impl Operator<PullLogsInput, PullLogsOutput> for PullLogsOperator {
break;
}
}
trace!("Log records {:?}", result);
if input.num_records.is_some() && result.len() > input.num_records.unwrap() as usize {
result.truncate(input.num_records.unwrap() as usize);
trace!("Truncated log records {:?}", result);
}
// Convert to DataChunk
let data_chunk = DataChunk::new(result.into());
Expand Down
6 changes: 3 additions & 3 deletions rust/worker/src/execution/orchestration/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl CompactOrchestrator {
};
let input = PullLogsInput::new(collection_id, 0, 100, None, Some(end_timestamp));
let task = wrap(operator, input, self_address);
match self.dispatcher.send(task).await {
match self.dispatcher.send(task, None).await {
Ok(_) => (),
Err(e) => {
// TODO: log an error and reply to caller
Expand All @@ -139,7 +139,7 @@ impl CompactOrchestrator {
let operator = PartitionOperator::new();
let input = PartitionInput::new(records, max_partition_size);
let task = wrap(operator, input, self_address);
match self.dispatcher.send(task).await {
match self.dispatcher.send(task, None).await {
Ok(_) => (),
Err(e) => {
// TODO: log an error and reply to caller
Expand Down Expand Up @@ -179,7 +179,7 @@ impl CompactOrchestrator {
);

let task = wrap(operator, input, self_address);
match self.dispatcher.send(task).await {
match self.dispatcher.send(task, None).await {
Ok(_) => (),
Err(e) => {
// TODO: log an error and reply to caller
Expand Down
25 changes: 21 additions & 4 deletions rust/worker/src/execution/orchestration/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::{
use async_trait::async_trait;
use std::fmt::Debug;
use std::time::{SystemTime, UNIX_EPOCH};
use tracing::{trace, trace_span, Instrument, Span};
use uuid::Uuid;

/** The state of the orchestrator.
Expand Down Expand Up @@ -97,7 +98,10 @@ impl HnswQueryOrchestrator {
.await;
match segments {
Ok(segments) => match segments.get(0) {
Some(segment) => segment.collection,
Some(segment) => {
trace!("Collection Id {:?}", segment.collection);
segment.collection
}
None => None,
},
Err(e) => {
Expand All @@ -110,7 +114,13 @@ impl HnswQueryOrchestrator {
async fn pull_logs(&mut self, self_address: Box<dyn Receiver<PullLogsResult>>) {
self.state = ExecutionState::PullLogs;
let operator = PullLogsOperator::new(self.log.clone());
let collection_id = match self.get_collection_id_for_segment_id(self.segment_id).await {
let child_span: tracing::Span =
trace_span!(parent: Span::current(), "get collection id for segment id");
let get_collection_id_future = self.get_collection_id_for_segment_id(self.segment_id);
let collection_id = match get_collection_id_future
.instrument(child_span.clone())
.await
{
Some(collection_id) => collection_id,
None => {
// Log an error and reply + return
Expand All @@ -128,7 +138,9 @@ impl HnswQueryOrchestrator {
};
let input = PullLogsInput::new(collection_id, 0, 100, None, Some(end_timestamp));
let task = wrap(operator, input, self_address);
match self.dispatcher.send(task).await {
// Wrap the task with current span as the parent. The worker then executes it
// inside a child span with this parent.
match self.dispatcher.send(task, Some(child_span.clone())).await {
Ok(_) => (),
Err(e) => {
// TODO: log an error and reply to caller
Expand Down Expand Up @@ -186,7 +198,11 @@ impl Handler<PullLogsResult> for HnswQueryOrchestrator {
};
let operator = Box::new(BruteForceKnnOperator {});
let task = wrap(operator, bf_input, ctx.sender.as_receiver());
match self.dispatcher.send(task).await {
match self
.dispatcher
.send(task, Some(Span::current().clone()))
.await
{
Ok(_) => (),
Err(e) => {
// TODO: log an error and reply to caller
Expand Down Expand Up @@ -230,6 +246,7 @@ impl Handler<BruteForceKnnOperatorResult> for HnswQueryOrchestrator {
query_results.push(query_result);
}
result.push(query_results);
trace!("Merged results: {:?}", result);

match result_channel.send(Ok(result)) {
Ok(_) => (),
Expand Down
4 changes: 2 additions & 2 deletions rust/worker/src/execution/worker_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Component for WorkerThread {

async fn on_start(&mut self, ctx: &ComponentContext<Self>) -> () {
let req = TaskRequestMessage::new(ctx.sender.as_receiver());
let res = self.dispatcher.send(req).await;
let res = self.dispatcher.send(req, None).await;
// TODO: what to do with resp?
}
}
Expand All @@ -52,7 +52,7 @@ impl Handler<TaskMessage> for WorkerThread {
async fn handle(&mut self, task: TaskMessage, ctx: &ComponentContext<WorkerThread>) {
task.run().await;
let req: TaskRequestMessage = TaskRequestMessage::new(ctx.sender.as_receiver());
let res = self.dispatcher.send(req).await;
let res = self.dispatcher.send(req, None).await;
// TODO: task run should be able to error and we should send it as part of the result
}
}
2 changes: 1 addition & 1 deletion rust/worker/src/memberlist/memberlist_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ impl CustomResourceMemberlistProvider {
};

for subscriber in self.subscribers.iter() {
let _ = subscriber.send(curr_memberlist.clone()).await;
let _ = subscriber.send(curr_memberlist.clone(), None).await;
}
}
}
Expand Down
24 changes: 15 additions & 9 deletions rust/worker/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::system::{Receiver, System};
use crate::types::ScalarEncoding;
use async_trait::async_trait;
use tonic::{transport::Server, Request, Response, Status};
use tracing::{debug, trace, trace_span};
use uuid::Uuid;

pub struct WorkerServer {
Expand Down Expand Up @@ -93,6 +94,7 @@ impl chroma_proto::vector_reader_server::VectorReader for WorkerServer {
Err(Status::unimplemented("Not yet implemented"))
}

#[tracing::instrument(skip(self, request), fields(request_metadata = ?request.metadata(), k = request.get_ref().k, segment_id = request.get_ref().segment_id, include_embeddings = request.get_ref().include_embeddings, allowed_ids = ?request.get_ref().allowed_ids))]
async fn query_vectors(
&self,
request: Request<QueryVectorsRequest>,
Expand All @@ -108,15 +110,19 @@ impl chroma_proto::vector_reader_server::VectorReader for WorkerServer {
let mut proto_results_for_all = Vec::new();

let mut query_vectors = Vec::new();
for proto_query_vector in request.vectors {
let (query_vector, _encoding) = match proto_query_vector.try_into() {
Ok((vector, encoding)) => (vector, encoding),
Err(e) => {
return Err(Status::internal(format!("Error converting vector: {}", e)));
}
};
query_vectors.push(query_vector);
}
trace_span!("Input vectors parsing").in_scope(|| {
for proto_query_vector in request.vectors {
let (query_vector, _encoding) = match proto_query_vector.try_into() {
Ok((vector, encoding)) => (vector, encoding),
Err(e) => {
return Err(Status::internal(format!("Error converting vector: {}", e)));
}
};
query_vectors.push(query_vector);
}
trace!("Parsed vectors {:?}", query_vectors);
Ok(())
});

let dispatcher = match self.dispatcher {
Some(ref dispatcher) => dispatcher,
Expand Down
19 changes: 15 additions & 4 deletions rust/worker/src/system/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use super::{
use crate::system::ComponentContext;
use std::sync::Arc;
use tokio::select;
use tracing::{trace_span, Instrument, Span};

struct Inner<C>
where
Expand Down Expand Up @@ -69,14 +70,24 @@ where
message = channel.recv() => {
match message {
Some(mut message) => {
message.handle(&mut self.handler,
&ComponentContext{
let parent_span: tracing::Span;
match message.get_tracing_context() {
Some(spn) => {
parent_span = spn;
},
None => {
parent_span = Span::current().clone();
}
}
let child_span = trace_span!(parent: parent_span, "task handler");
let component_context = ComponentContext {
system: self.inner.system.clone(),
sender: self.inner.sender.clone(),
cancellation_token: self.inner.cancellation_token.clone(),
scheduler: self.inner.scheduler.clone(),
}
).await;
};
let task_future = message.handle(&mut self.handler, &component_context);
task_future.instrument(child_span).await;
}
None => {
// TODO: Log error
Expand Down
4 changes: 2 additions & 2 deletions rust/worker/src/system/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Scheduler {
return;
}
_ = tokio::time::sleep(duration) => {
match sender.send(message).await {
match sender.send(message, None).await {
Ok(_) => {
return;
},
Expand Down Expand Up @@ -83,7 +83,7 @@ impl Scheduler {
return;
}
_ = tokio::time::sleep(duration) => {
match sender.send(message.clone()).await {
match sender.send(message.clone(), None).await {
Ok(_) => {
},
Err(e) => {
Expand Down
Loading

0 comments on commit adf011c

Please sign in to comment.