From 5dd161c820371150b2ffcdba2b89bf11eceedf01 Mon Sep 17 00:00:00 2001
From: Robert Escriva
Date: Fri, 6 Dec 2024 14:45:40 -0800
Subject: [PATCH] [ENH] Use tasks for concurrency. (#3257)

This PR brings more concurrency to chroma-load for high-request-rate
workloads. It does so by spawning a tokio task for each in-flight
workload step.
---
 rust/load/src/lib.rs | 39 ++++++++++++++++++++++++++++++++-------
 1 file changed, 32 insertions(+), 7 deletions(-)

diff --git a/rust/load/src/lib.rs b/rust/load/src/lib.rs
index 27adfd04bd5..d554cc97c6f 100644
--- a/rust/load/src/lib.rs
+++ b/rust/load/src/lib.rs
@@ -33,6 +33,7 @@ pub enum Error {
     NotFound(String),
     InvalidRequest(String),
     InternalError(String),
+    FailWorkload(String),
 }
 
 impl std::fmt::Display for Error {
@@ -41,6 +42,7 @@ impl std::fmt::Display for Error {
             Error::NotFound(msg) => write!(f, "not found: {}", msg),
             Error::InvalidRequest(msg) => write!(f, "invalid request: {}", msg),
             Error::InternalError(msg) => write!(f, "internal error: {}", msg),
+            Error::FailWorkload(msg) => write!(f, "workload failed: {}", msg),
         }
     }
 }
@@ -53,6 +55,7 @@ impl axum::response::IntoResponse for Error {
             Error::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
             Error::InvalidRequest(msg) => (StatusCode::BAD_REQUEST, msg),
             Error::InternalError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
+            Error::FailWorkload(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
         };
         axum::http::Response::builder()
             .status(status)
@@ -528,9 +531,20 @@ impl LoadService {
         inhibit: Arc<AtomicBool>,
         spec: RunningWorkload,
     ) {
-        let client = client();
+        let client = Arc::new(client());
         let mut guac = Guacamole::new(spec.expires.timestamp_millis() as u64);
         let mut next_op = Instant::now();
+        let (tx, mut rx) = tokio::sync::mpsc::channel(spec.throughput.ceil() as usize);
+        let _ = tx
+            .send(tokio::spawn(async move { Ok::<(), Error>(()) }))
+            .await;
+        let reaper = tokio::spawn(async move {
+            while let Some(task) = rx.recv().await {
+                if let Err(err) = task.await.unwrap() {
+                    tracing::error!("workload task failed: {err:?}");
+                }
+            }
+        });
         while !done.load(std::sync::atomic::Ordering::Relaxed) {
             let delay = interarrival_duration(spec.throughput)(&mut guac);
             next_op += delay;
@@ -542,14 +556,25 @@ impl LoadService {
                 tracing::info!("inhibited");
             } else if !spec.workload.is_active() {
                 tracing::debug!("workload inactive");
-            } else if let Err(err) = spec
-                .workload
-                .step(&client, &*spec.data_set, &mut guac)
-                .await
-            {
-                tracing::error!("workload failed: {err:?}");
+            } else {
+                let workload = spec.workload.clone();
+                let client = Arc::clone(&client);
+                let data_set = Arc::clone(&spec.data_set);
+                let mut guacamole = Guacamole::new(any(&mut guac));
+                let fut = async move {
+                    workload
+                        .step(&client, &*data_set, &mut guacamole)
+                        .await
+                        .map_err(|err| {
+                            tracing::error!("workload failed: {err:?}");
+                            Error::FailWorkload(err.to_string())
+                        })
+                };
+                tx.send(tokio::spawn(fut)).await.unwrap();
             }
         }
+        drop(tx);
+        reaper.await.unwrap();
     }
 }
 
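
Note: the diff above bounds in-flight concurrency with a channel of JoinHandles that a
"reaper" task drains in submission order. The following is a minimal, self-contained sketch
of that pattern, not the chroma-load code itself; the `do_step` function, the channel
capacity of 8, and the simulated failure rate are illustrative stand-ins.

use tokio::sync::mpsc;
use tokio::task::JoinHandle;

#[tokio::main]
async fn main() {
    // Channel of JoinHandles; its capacity caps how many steps may be in flight,
    // because the sender blocks once the reaper falls behind.
    let (tx, mut rx) = mpsc::channel::<JoinHandle<Result<(), String>>>(8);

    // Reaper: awaits each spawned step in submission order and logs failures.
    let reaper = tokio::spawn(async move {
        while let Some(task) = rx.recv().await {
            if let Err(err) = task.await.expect("step task panicked") {
                eprintln!("workload step failed: {err}");
            }
        }
    });

    for i in 0..32 {
        // Spawn the step so the issuing loop is not blocked by the step's latency,
        // then hand its JoinHandle to the reaper.
        let handle = tokio::spawn(do_step(i));
        tx.send(handle).await.expect("reaper exited early");
    }

    // Dropping the sender closes the channel; the reaper drains what is left.
    drop(tx);
    reaper.await.expect("reaper panicked");
}

// Hypothetical stand-in for Workload::step in the patch.
async fn do_step(i: u64) -> Result<(), String> {
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    if i % 13 == 0 {
        Err(format!("step {i} failed"))
    } else {
        Ok(())
    }
}

Keeping the issuing loop free of per-step latency is what lets the load generator hold a
target request rate; the bounded channel is the backpressure that keeps spawned tasks from
piling up without limit.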