diff --git a/Cargo.lock b/Cargo.lock index e037dcac20b..46df5b64478 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4369,11 +4369,14 @@ dependencies = [ [[package]] name = "quick-protobuf-codec" -version = "0.3.0" +version = "0.3.1" dependencies = [ "asynchronous-codec", "bytes", + "criterion", + "futures", "quick-protobuf", + "quickcheck-ext", "thiserror", "unsigned-varint 0.8.0", ] diff --git a/Cargo.toml b/Cargo.toml index 0603a22629b..1dce34011c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -117,7 +117,7 @@ multiaddr = "0.18.1" multihash = "0.19.1" multistream-select = { version = "0.13.0", path = "misc/multistream-select" } prometheus-client = "0.22.0" -quick-protobuf-codec = { version = "0.3.0", path = "misc/quick-protobuf-codec" } +quick-protobuf-codec = { version = "0.3.1", path = "misc/quick-protobuf-codec" } quickcheck = { package = "quickcheck-ext", path = "misc/quickcheck-ext" } rw-stream-sink = { version = "0.4.0", path = "misc/rw-stream-sink" } unsigned-varint = { version = "0.8.0" } diff --git a/misc/quick-protobuf-codec/CHANGELOG.md b/misc/quick-protobuf-codec/CHANGELOG.md index 779dd750abd..a301293621f 100644 --- a/misc/quick-protobuf-codec/CHANGELOG.md +++ b/misc/quick-protobuf-codec/CHANGELOG.md @@ -1,3 +1,8 @@ +## 0.3.1 + +- Reduce allocations during encoding. + See [PR 4782](https://github.com/libp2p/rust-libp2p/pull/4782). + ## 0.3.0 - Update to `asynchronous-codec` `v0.7.0`. diff --git a/misc/quick-protobuf-codec/Cargo.toml b/misc/quick-protobuf-codec/Cargo.toml index fdc65cfa93c..484e2c9bc8b 100644 --- a/misc/quick-protobuf-codec/Cargo.toml +++ b/misc/quick-protobuf-codec/Cargo.toml @@ -3,7 +3,7 @@ name = "quick-protobuf-codec" edition = "2021" rust-version = { workspace = true } description = "Asynchronous de-/encoding of Protobuf structs using asynchronous-codec, unsigned-varint and quick-protobuf." -version = "0.3.0" +version = "0.3.1" authors = ["Max Inden "] license = "MIT" repository = "https://github.com/libp2p/rust-libp2p" @@ -14,9 +14,18 @@ categories = ["asynchronous"] asynchronous-codec = { workspace = true } bytes = { version = "1" } thiserror = "1.0" -unsigned-varint = { workspace = true, features = ["asynchronous_codec"] } +unsigned-varint = { workspace = true, features = ["std"] } quick-protobuf = "0.8" +[dev-dependencies] +criterion = "0.5.1" +futures = "0.3.28" +quickcheck = { workspace = true } + +[[bench]] +name = "codec" +harness = false + # Passing arguments to the docsrs builder in order to properly document cfg's. # More information: https://docs.rs/about/builds#cross-compiling [package.metadata.docs.rs] diff --git a/misc/quick-protobuf-codec/benches/codec.rs b/misc/quick-protobuf-codec/benches/codec.rs new file mode 100644 index 00000000000..0f6ce9469c5 --- /dev/null +++ b/misc/quick-protobuf-codec/benches/codec.rs @@ -0,0 +1,28 @@ +use asynchronous_codec::Encoder; +use bytes::BytesMut; +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use quick_protobuf_codec::{proto, Codec}; + +pub fn benchmark(c: &mut Criterion) { + for size in [1000, 10_000, 100_000, 1_000_000, 10_000_000] { + c.bench_with_input(BenchmarkId::new("encode", size), &size, |b, i| { + b.iter_batched( + || { + let mut out = BytesMut::new(); + out.reserve(i + 100); + let codec = Codec::::new(i + 100); + let msg = proto::Message { + data: vec![0; size], + }; + + (codec, out, msg) + }, + |(mut codec, mut out, msg)| codec.encode(msg, &mut out).unwrap(), + BatchSize::SmallInput, + ); + }); + } +} + +criterion_group!(benches, benchmark); +criterion_main!(benches); diff --git a/misc/quick-protobuf-codec/src/generated/mod.rs b/misc/quick-protobuf-codec/src/generated/mod.rs new file mode 100644 index 00000000000..b9f982f8dfd --- /dev/null +++ b/misc/quick-protobuf-codec/src/generated/mod.rs @@ -0,0 +1,2 @@ +// Automatically generated mod.rs +pub mod test; diff --git a/misc/quick-protobuf-codec/src/generated/test.proto b/misc/quick-protobuf-codec/src/generated/test.proto new file mode 100644 index 00000000000..5b1f46c0bfa --- /dev/null +++ b/misc/quick-protobuf-codec/src/generated/test.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package test; + +message Message { + bytes data = 1; +} diff --git a/misc/quick-protobuf-codec/src/generated/test.rs b/misc/quick-protobuf-codec/src/generated/test.rs new file mode 100644 index 00000000000..b353e6d9183 --- /dev/null +++ b/misc/quick-protobuf-codec/src/generated/test.rs @@ -0,0 +1,47 @@ +// Automatically generated rust module for 'test.proto' file + +#![allow(non_snake_case)] +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(unused_imports)] +#![allow(unknown_lints)] +#![allow(clippy::all)] +#![cfg_attr(rustfmt, rustfmt_skip)] + + +use quick_protobuf::{MessageInfo, MessageRead, MessageWrite, BytesReader, Writer, WriterBackend, Result}; +use quick_protobuf::sizeofs::*; +use super::*; + +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Debug, Default, PartialEq, Clone)] +pub struct Message { + pub data: Vec, +} + +impl<'a> MessageRead<'a> for Message { + fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result { + let mut msg = Self::default(); + while !r.is_eof() { + match r.next_tag(bytes) { + Ok(10) => msg.data = r.read_bytes(bytes)?.to_owned(), + Ok(t) => { r.read_unknown(bytes, t)?; } + Err(e) => return Err(e), + } + } + Ok(msg) + } +} + +impl MessageWrite for Message { + fn get_size(&self) -> usize { + 0 + + if self.data.is_empty() { 0 } else { 1 + sizeof_len((&self.data).len()) } + } + + fn write_message(&self, w: &mut Writer) -> Result<()> { + if !self.data.is_empty() { w.write_with_tag(10, |w| w.write_bytes(&**&self.data))?; } + Ok(()) + } +} + diff --git a/misc/quick-protobuf-codec/src/lib.rs b/misc/quick-protobuf-codec/src/lib.rs index 2d1fda99a70..c50b1264af6 100644 --- a/misc/quick-protobuf-codec/src/lib.rs +++ b/misc/quick-protobuf-codec/src/lib.rs @@ -1,16 +1,21 @@ #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] use asynchronous_codec::{Decoder, Encoder}; -use bytes::{Bytes, BytesMut}; -use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer}; +use bytes::{Buf, BufMut, BytesMut}; +use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer, WriterBackend}; +use std::io; use std::marker::PhantomData; -use unsigned_varint::codec::UviBytes; + +mod generated; + +#[doc(hidden)] // NOT public API. Do not use. +pub use generated::test as proto; /// [`Codec`] implements [`Encoder`] and [`Decoder`], uses [`unsigned_varint`] /// to prefix messages with their length and uses [`quick_protobuf`] and a provided /// `struct` implementing [`MessageRead`] and [`MessageWrite`] to do the encoding. pub struct Codec { - uvi: UviBytes, + max_message_len_bytes: usize, phantom: PhantomData<(In, Out)>, } @@ -21,10 +26,8 @@ impl Codec { /// Protobuf message. The limit does not include the bytes needed for the /// [`unsigned_varint`]. pub fn new(max_message_len_bytes: usize) -> Self { - let mut uvi = UviBytes::default(); - uvi.set_max_len(max_message_len_bytes); Self { - uvi, + max_message_len_bytes, phantom: PhantomData, } } @@ -35,16 +38,32 @@ impl Encoder for Codec { type Error = Error; fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> { - let mut encoded_msg = Vec::new(); - let mut writer = Writer::new(&mut encoded_msg); - item.write_message(&mut writer) - .expect("Encoding to succeed"); - self.uvi.encode(Bytes::from(encoded_msg), dst)?; + write_length(&item, dst); + write_message(&item, dst)?; Ok(()) } } +/// Write the message's length (i.e. `size`) to `dst` as a variable-length integer. +fn write_length(message: &impl MessageWrite, dst: &mut BytesMut) { + let message_length = message.get_size(); + + let mut uvi_buf = unsigned_varint::encode::usize_buffer(); + let encoded_length = unsigned_varint::encode::usize(message_length, &mut uvi_buf); + + dst.extend_from_slice(encoded_length); +} + +/// Write the message itself to `dst`. +fn write_message(item: &impl MessageWrite, dst: &mut BytesMut) -> io::Result<()> { + let mut writer = Writer::new(BytesMutWriterBackend::new(dst)); + item.write_message(&mut writer) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + + Ok(()) +} + impl Decoder for Codec where Out: for<'a> MessageRead<'a>, @@ -53,24 +72,203 @@ where type Error = Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - let msg = match self.uvi.decode(src)? { - None => return Ok(None), - Some(msg) => msg, + let (message_length, remaining) = match unsigned_varint::decode::usize(src) { + Ok((len, remaining)) => (len, remaining), + Err(unsigned_varint::decode::Error::Insufficient) => return Ok(None), + Err(e) => return Err(Error(io::Error::new(io::ErrorKind::InvalidData, e))), }; - let mut reader = BytesReader::from_bytes(&msg); - let message = Self::Item::from_reader(&mut reader, &msg) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + if message_length > self.max_message_len_bytes { + return Err(Error(io::Error::new( + io::ErrorKind::PermissionDenied, + format!( + "message with {message_length}b exceeds maximum of {}b", + self.max_message_len_bytes + ), + ))); + } + + // Compute how many bytes the varint itself consumed. + let varint_length = src.len() - remaining.len(); + + // Ensure we can read an entire message. + if src.len() < (message_length + varint_length) { + return Ok(None); + } + + // Safe to advance buffer now. + src.advance(varint_length); + + let message = src.split_to(message_length); + + let mut reader = BytesReader::from_bytes(&message); + let message = Self::Item::from_reader(&mut reader, &message) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(Some(message)) } } +struct BytesMutWriterBackend<'a> { + dst: &'a mut BytesMut, +} + +impl<'a> BytesMutWriterBackend<'a> { + fn new(dst: &'a mut BytesMut) -> Self { + Self { dst } + } +} + +impl<'a> WriterBackend for BytesMutWriterBackend<'a> { + fn pb_write_u8(&mut self, x: u8) -> quick_protobuf::Result<()> { + self.dst.put_u8(x); + + Ok(()) + } + + fn pb_write_u32(&mut self, x: u32) -> quick_protobuf::Result<()> { + self.dst.put_u32_le(x); + + Ok(()) + } + + fn pb_write_i32(&mut self, x: i32) -> quick_protobuf::Result<()> { + self.dst.put_i32_le(x); + + Ok(()) + } + + fn pb_write_f32(&mut self, x: f32) -> quick_protobuf::Result<()> { + self.dst.put_f32_le(x); + + Ok(()) + } + + fn pb_write_u64(&mut self, x: u64) -> quick_protobuf::Result<()> { + self.dst.put_u64_le(x); + + Ok(()) + } + + fn pb_write_i64(&mut self, x: i64) -> quick_protobuf::Result<()> { + self.dst.put_i64_le(x); + + Ok(()) + } + + fn pb_write_f64(&mut self, x: f64) -> quick_protobuf::Result<()> { + self.dst.put_f64_le(x); + + Ok(()) + } + + fn pb_write_all(&mut self, buf: &[u8]) -> quick_protobuf::Result<()> { + self.dst.put_slice(buf); + + Ok(()) + } +} + #[derive(thiserror::Error, Debug)] #[error("Failed to encode/decode message")] -pub struct Error(#[from] std::io::Error); +pub struct Error(#[from] io::Error); -impl From for std::io::Error { +impl From for io::Error { fn from(e: Error) -> Self { e.0 } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto; + use asynchronous_codec::FramedRead; + use futures::io::Cursor; + use futures::{FutureExt, StreamExt}; + use quickcheck::{Arbitrary, Gen, QuickCheck}; + use std::error::Error; + + #[test] + fn honors_max_message_length() { + let codec = Codec::::new(1); + let mut src = varint_zeroes(100); + + let mut read = FramedRead::new(Cursor::new(&mut src), codec); + let err = read.next().now_or_never().unwrap().unwrap().unwrap_err(); + + assert_eq!( + err.source().unwrap().to_string(), + "message with 100b exceeds maximum of 1b" + ) + } + + #[test] + fn empty_bytes_mut_does_not_panic() { + let mut codec = Codec::::new(100); + + let mut src = varint_zeroes(100); + src.truncate(50); + + let result = codec.decode(&mut src); + + assert!(result.unwrap().is_none()); + assert_eq!( + src.len(), + 50, + "to not modify `src` if we cannot read a full message" + ) + } + + #[test] + fn only_partial_message_in_bytes_mut_does_not_panic() { + let mut codec = Codec::::new(100); + + let result = codec.decode(&mut BytesMut::new()); + + assert!(result.unwrap().is_none()); + } + + #[test] + fn handles_arbitrary_initial_capacity() { + fn prop(message: proto::Message, initial_capacity: u16) { + let mut buffer = BytesMut::with_capacity(initial_capacity as usize); + let mut codec = Codec::::new(u32::MAX as usize); + + codec.encode(message.clone(), &mut buffer).unwrap(); + let decoded = codec.decode(&mut buffer).unwrap().unwrap(); + + assert_eq!(message, decoded); + } + + QuickCheck::new().quickcheck(prop as fn(_, _) -> _) + } + + /// Constructs a [`BytesMut`] of the provided length where the message is all zeros. + fn varint_zeroes(length: usize) -> BytesMut { + let mut buf = unsigned_varint::encode::usize_buffer(); + let encoded_length = unsigned_varint::encode::usize(length, &mut buf); + + let mut src = BytesMut::new(); + src.extend_from_slice(encoded_length); + src.extend(std::iter::repeat(0).take(length)); + src + } + + impl Arbitrary for proto::Message { + fn arbitrary(g: &mut Gen) -> Self { + Self { + data: Vec::arbitrary(g), + } + } + } + + #[derive(Debug)] + struct Dummy; + + impl<'a> MessageRead<'a> for Dummy { + fn from_reader(_: &mut BytesReader, _: &'a [u8]) -> quick_protobuf::Result { + todo!() + } + } +} diff --git a/misc/quick-protobuf-codec/tests/large_message.rs b/misc/quick-protobuf-codec/tests/large_message.rs new file mode 100644 index 00000000000..65dafe065d1 --- /dev/null +++ b/misc/quick-protobuf-codec/tests/large_message.rs @@ -0,0 +1,16 @@ +use asynchronous_codec::Encoder; +use bytes::BytesMut; +use quick_protobuf_codec::proto; +use quick_protobuf_codec::Codec; + +#[test] +fn encode_large_message() { + let mut codec = Codec::::new(1_001_000); + let mut dst = BytesMut::new(); + dst.reserve(1_001_000); + let message = proto::Message { + data: vec![0; 1_000_000], + }; + + codec.encode(message, &mut dst).unwrap(); +}