Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement zstd compression for multipart payloads on the Rust side. #1401

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ resolver = "2"
chrono = "0.4.38"
flate2 = "1.0.34"
futures = "0.3.31"
http = "1.2.0"
rayon = "1.10.0"
serde = { version = "1.0.210", features = ["derive"] }
serde_json = "1.0.128"
Expand All @@ -20,6 +21,7 @@ tokio = { version = "1", features = ["full"] }
tokio-util = "0.7.12"
ureq = "2.10.1"
uuid = { version = "1.11.0", features = ["v4"] }
zstd = { version = "0.13.2", features = ["zstdmt"] }

# Use rustls instead of OpenSSL, because OpenSSL is a nightmare when compiling across platforms.
# OpenSSL is a default feature, so we have to disable all default features, then re-add
Expand Down
2 changes: 1 addition & 1 deletion rust/crates/langsmith-pyo3/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "langsmith-pyo3"
version = "0.1.0-rc5"
version = "0.2.0-rc1"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
2 changes: 2 additions & 0 deletions rust/crates/langsmith-pyo3/src/blocking_tracing_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ impl BlockingTracingClient {
batch_size: usize,
batch_timeout_millis: u64,
worker_threads: usize,
compression_level: i32,
) -> PyResult<Self> {
let config = langsmith_tracing_client::client::blocking::ClientConfig {
endpoint,
Expand All @@ -39,6 +40,7 @@ impl BlockingTracingClient {

headers: None, // TODO: support custom headers
num_worker_threads: worker_threads,
compression_level,
};

let client = RustTracingClient::new(config)
Expand Down
4 changes: 3 additions & 1 deletion rust/crates/langsmith-tracing-client/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "langsmith-tracing-client"
version = "0.1.0"
version = "0.2.0"
edition = "2021"

[dependencies]
Expand All @@ -17,6 +17,8 @@ futures = { workspace = true }
rayon = { workspace = true }
ureq = { workspace = true }
flate2 = { workspace = true }
zstd = { workspace = true }
http = { workspace = true }

[dev-dependencies]
multer = "3.1.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ fn create_mock_client_config_sync(server_url: &str, batch_size: usize) -> Blocki
batch_timeout: Duration::from_secs(1),
headers: Default::default(),
num_worker_threads: 1,
compression_level: 1,
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::collections::HashMap;
use std::io::Write;
use std::sync::mpsc::{Receiver, Sender};
use std::sync::{mpsc, Arc, Mutex};
use std::thread::available_parallelism;
use std::time::{Duration, Instant};

use rayon::iter::{IntoParallelIterator, ParallelIterator};
Expand Down Expand Up @@ -266,14 +268,18 @@ impl RunProcessor {
for (part_name, part) in json_parts.into_iter().chain(attachment_parts) {
form = form.part(part_name, part);
}
let content_type = format!("multipart/form-data; boundary={}", form.boundary());
let compressed_data = compress_multipart_body(form, self.config.compression_level)?;

// send the multipart POST request
let start_send_batch = Instant::now();
let response = self
.http_client
.post(format!("{}/runs/multipart", self.config.endpoint))
.multipart(form)
.headers(self.config.headers.as_ref().cloned().unwrap_or_default())
.header(http::header::CONTENT_TYPE, content_type)
.header(http::header::CONTENT_ENCODING, "zstd")
.body(compressed_data)
.send()?;
// println!("Sending batch took {:?}", start_send_batch.elapsed());

Expand Down Expand Up @@ -315,3 +321,63 @@ impl RunProcessor {
Ok(part)
}
}

fn compress_multipart_body(
form: Form,
compression_level: i32,
) -> Result<Vec<u8>, TracingClientError> {
// We want to use as many threads as available cores to compress data.
// However, we have to be mindful of special values in the zstd library:
// - A setting of `0` here means "use the current thread only."
// - A setting of `1` means "use a separate thread, but only one."
//
// `1` isn't a useful setting for us, so turn `1` into `0` while
// keeping higher numbers the same.
let n_workers = match available_parallelism() {
Ok(num) => {
if num.get() == 1 {
0
} else {
num.get() as u32
}
}
Err(_) => {
// We failed to query the available number of cores.
// Use only the current single thread, to be safe.
0
}
};

// `reqwest` doesn't have a public method for getting *just* the multipart form data:
// the `Form::reader()` method isn't public. Instead, we pretend to be preparing a request,
// place the multipart data, and ask for the body to be buffered and made available as `&[u8]`.
// *This does not send the request!* We merely prepare a request in memory as a workaround.
//
// Just in case, we use `example.com` as the URL here, which is explicitly reserved for
// example use in the standards and is guaranteed to not be taken. This means that
// under no circumstances will we send multipart data to any other place,
// even if `reqwest` were to suddenly change the API to send data on `.build()`.
let builder = reqwest::blocking::Client::new().get("http://example.com/").multipart(form);
let mut request = builder.build().expect("failed to construct request");
let multipart_form_bytes = request
.body_mut()
.as_mut()
.expect("multipart request had no body, this should never happen")
.buffer()
.map_err(|e| TracingClientError::IoError(e.to_string()))?;

let mut buffer = Vec::with_capacity(multipart_form_bytes.len());
let mut encoder = zstd::Encoder::new(&mut buffer, compression_level)
.and_then(|mut encoder| {
encoder.multithread(n_workers)?;
Ok(encoder)
})
.map_err(|e| TracingClientError::IoError(e.to_string()))?;

encoder
.write_all(multipart_form_bytes)
.map_err(|e| TracingClientError::IoError(e.to_string()))?;
encoder.finish().map_err(|e| TracingClientError::IoError(e.to_string()))?;
Comment on lines +351 to +380
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This workaround will not be required anymore once the next reqwest version is released. This code will be replaced in #1403 which makes use of new reqwest functionality to do this properly.


Ok(buffer)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct ClientConfig {
pub batch_timeout: Duration,
pub headers: Option<HeaderMap>,
pub num_worker_threads: usize,
pub compression_level: i32,
}

pub struct TracingClient {
Expand Down
Loading