Skip to content

Commit

Permalink
style: Move encoding functions into separate modules (#1111)
Browse files Browse the repository at this point in the history
* style: Move varint encoding to separate module

While the `encoding` module is undocumented and not stable, there are known users. Leave an alias in place to prevent breaking them.

* style: Move length delimiter encoding to separate module

Leave an alias in `lib.rs` to prevent a breaking change.

* style: Move wire type to separate module

While the `encoding` module is undocumented and not stable, there are known users. Leave an alias in place to prevent breaking them.
  • Loading branch information
caspermeijn authored Jul 31, 2024
1 parent ad5650b commit ef27f65
Show file tree
Hide file tree
Showing 9 changed files with 392 additions and 361 deletions.
4 changes: 2 additions & 2 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
fn merge_field(
&mut self,
tag: u32,
wire_type: ::prost::encoding::WireType,
wire_type: ::prost::encoding::wire_type::WireType,
buf: &mut impl ::prost::bytes::Buf,
ctx: ::prost::encoding::DecodeContext,
) -> ::core::result::Result<(), ::prost::DecodeError>
Expand Down Expand Up @@ -472,7 +472,7 @@ fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
pub fn merge(
field: &mut ::core::option::Option<#ident #ty_generics>,
tag: u32,
wire_type: ::prost::encoding::WireType,
wire_type: ::prost::encoding::wire_type::WireType,
buf: &mut impl ::prost::bytes::Buf,
ctx: ::prost::encoding::DecodeContext,
) -> ::core::result::Result<(), ::prost::DecodeError>
Expand Down
2 changes: 1 addition & 1 deletion prost/benches/varint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::mem;

use bytes::Buf;
use criterion::{Criterion, Throughput};
use prost::encoding::{decode_varint, encode_varint, encoded_len_varint};
use prost::encoding::varint::{decode_varint, encode_varint, encoded_len_varint};
use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};

fn benchmark_varint(criterion: &mut Criterion, name: &str, mut values: Vec<u64>) {
Expand Down
311 changes: 8 additions & 303 deletions prost/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use core::cmp::min;
use core::mem;
use core::str;

Expand All @@ -17,162 +16,16 @@ use ::bytes::{Buf, BufMut, Bytes};
use crate::DecodeError;
use crate::Message;

/// Writes `value` to `buf` in LEB128 variable length format.
///
/// Seven payload bits are emitted per byte, least-significant group first; every byte except
/// the last has its high bit set. The buffer must have enough remaining space (maximum 10
/// bytes).
#[inline]
pub fn encode_varint(mut value: u64, buf: &mut impl BufMut) {
    // A u64 never needs more than 10 seven-bit groups, so the loop is bounded.
    for _ in 0..10 {
        let group = (value & 0x7F) as u8;
        value >>= 7;
        if value == 0 {
            // Final group: the continuation bit stays clear.
            buf.put_u8(group);
            return;
        }
        buf.put_u8(group | 0x80);
    }
}

/// Reads one LEB128-encoded variable length integer from the front of `buf`.
///
/// A single-byte varint is handled inline. Longer varints are decoded straight from the
/// current chunk when it is more than 10 bytes long or ends in a byte below `0x80`; otherwise
/// decoding is delegated to `decode_varint_slow`.
///
/// Returns an "invalid varint" error when the current chunk is empty, and forwards any error
/// produced by the helper decoders.
#[inline]
pub fn decode_varint(buf: &mut impl Buf) -> Result<u64, DecodeError> {
    let chunk = buf.chunk();
    let first = match chunk.first() {
        Some(&byte) => byte,
        None => return Err(DecodeError::new("invalid varint")),
    };

    // Fast path: the whole varint is a single byte.
    if first < 0x80 {
        buf.advance(1);
        return Ok(u64::from(first));
    }

    // The slice decoder requires a chunk longer than 10 bytes or one whose last byte has no
    // continuation bit; anything else goes through the slow decoder.
    let slice_decodable = chunk.len() > 10 || chunk[chunk.len() - 1] < 0x80;
    if !slice_decodable {
        return decode_varint_slow(buf);
    }

    let (value, consumed) = decode_varint_slice(chunk)?;
    buf.advance(consumed);
    Ok(value)
}
pub mod varint;
pub use varint::{decode_varint, encode_varint, encoded_len_varint};

/// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the
/// number of bytes read.
///
/// Based loosely on [`ReadVarint64FromArray`][1] with a varint overflow check from
/// [`ConsumeVarint`][2].
///
/// Returns an "invalid varint" error when the tenth byte is `0x02` or greater, i.e. when the
/// varint is longer than 10 bytes or its final byte would overflow a `u64`.
///
/// ## Safety
///
/// The caller must ensure that `bytes` is non-empty and either `bytes.len() >= 10` or the last
/// element in bytes is < `0x80`.
///
/// NOTE: the assertions at the top of the body enforce these preconditions and panic when they
/// are violated. The length assertion is the stricter `bytes.len() > 10`, so a slice of exactly
/// 10 bytes must also end in a byte < `0x80` to avoid the panic.
///
/// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406
/// [2]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
#[inline]
fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
// Fully unrolled varint decoding loop. Splitting into 32-bit pieces gives better performance.
pub mod length_delimiter;
pub use length_delimiter::{
decode_length_delimiter, encode_length_delimiter, length_delimiter_len,
};

// Use assertions to ensure memory safety, but it should always be optimized after inline.
assert!(!bytes.is_empty());
assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);

// SAFETY (applies to every `get_unchecked` below): index 0 is in bounds because `bytes` is
// non-empty. Index `i > 0` is read only after bytes `0..i` all had the continuation bit set
// (>= 0x80). So either `bytes.len() > 10`, which covers every index up to 9, or the last byte
// is < 0x80 and therefore cannot be one of bytes `0..i`, meaning index `i` still exists.

// part0 accumulates bytes 0..=3 (bits 0..28). After each continuing byte its continuation bit
// is subtracted back out so only the seven payload bits remain.
let mut b: u8 = unsafe { *bytes.get_unchecked(0) };
let mut part0: u32 = u32::from(b);
if b < 0x80 {
return Ok((u64::from(part0), 1));
};
part0 -= 0x80;
b = unsafe { *bytes.get_unchecked(1) };
part0 += u32::from(b) << 7;
if b < 0x80 {
return Ok((u64::from(part0), 2));
};
part0 -= 0x80 << 7;
b = unsafe { *bytes.get_unchecked(2) };
part0 += u32::from(b) << 14;
if b < 0x80 {
return Ok((u64::from(part0), 3));
};
part0 -= 0x80 << 14;
b = unsafe { *bytes.get_unchecked(3) };
part0 += u32::from(b) << 21;
if b < 0x80 {
return Ok((u64::from(part0), 4));
};
part0 -= 0x80 << 21;
let value = u64::from(part0);

// part1 accumulates bytes 4..=7 and is merged into the result shifted left by 28 bits.
b = unsafe { *bytes.get_unchecked(4) };
let mut part1: u32 = u32::from(b);
if b < 0x80 {
return Ok((value + (u64::from(part1) << 28), 5));
};
part1 -= 0x80;
b = unsafe { *bytes.get_unchecked(5) };
part1 += u32::from(b) << 7;
if b < 0x80 {
return Ok((value + (u64::from(part1) << 28), 6));
};
part1 -= 0x80 << 7;
b = unsafe { *bytes.get_unchecked(6) };
part1 += u32::from(b) << 14;
if b < 0x80 {
return Ok((value + (u64::from(part1) << 28), 7));
};
part1 -= 0x80 << 14;
b = unsafe { *bytes.get_unchecked(7) };
part1 += u32::from(b) << 21;
if b < 0x80 {
return Ok((value + (u64::from(part1) << 28), 8));
};
part1 -= 0x80 << 21;
let value = value + ((u64::from(part1)) << 28);

// part2 accumulates bytes 8..=9 and is merged into the result shifted left by 56 bits.
b = unsafe { *bytes.get_unchecked(8) };
let mut part2: u32 = u32::from(b);
if b < 0x80 {
return Ok((value + (u64::from(part2) << 56), 9));
};
part2 -= 0x80;
b = unsafe { *bytes.get_unchecked(9) };
part2 += u32::from(b) << 7;
// Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
if b < 0x02 {
return Ok((value + (u64::from(part2) << 56), 10));
};

// We have overrun the maximum size of a varint (10 bytes) or the final byte caused an overflow.
// Assume the data is corrupt.
Err(DecodeError::new("invalid varint"))
}

/// Byte-at-a-time decoder for a LEB128-encoded variable length integer; consumes bytes from
/// `buf` as it goes.
///
/// Reads at most 10 bytes (fewer if the buffer runs out). Fails with "invalid varint" when no
/// terminating byte is found within that window, or when the tenth byte would overflow a `u64`
/// (overflow check taken from [`ConsumeVarint`][1]).
///
/// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
#[inline(never)]
#[cold]
fn decode_varint_slow(buf: &mut impl Buf) -> Result<u64, DecodeError> {
    let limit = min(10, buf.remaining());
    let mut value = 0;
    let mut shift = 0;

    for index in 0..limit {
        let byte = buf.get_u8();
        value |= u64::from(byte & 0x7F) << shift;
        shift += 7;

        // A clear high bit marks the final byte of the varint.
        if byte & 0x80 == 0 {
            // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
            // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
            let overflows = index == 9 && byte >= 0x02;
            return if overflows {
                Err(DecodeError::new("invalid varint"))
            } else {
                Ok(value)
            };
        }
    }

    // Ran out of bytes (or hit the 10-byte cap) without seeing a terminating byte.
    Err(DecodeError::new("invalid varint"))
}
pub mod wire_type;
pub use wire_type::{check_wire_type, WireType};

/// Additional information passed to every decode/merge function.
///
Expand Down Expand Up @@ -244,49 +97,9 @@ impl DecodeContext {
}
}

/// Computes how many bytes `value` occupies in LEB128 variable length format.
///
/// The result is always between 1 and 10, inclusive.
#[inline]
pub fn encoded_len_varint(value: u64) -> usize {
    // Based on [VarintSize64][1].
    // [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.h#L1301-L1309
    //
    // `value | 1` guarantees at least one set bit, so zero is sized like one.
    let highest_set_bit = 63 - (value | 1).leading_zeros();
    // Branch-free equivalent of `highest_set_bit / 7 + 1` for bit positions 0..=63.
    let byte_count = (highest_set_bit * 9 + 73) / 64;
    byte_count as usize
}

/// Wire type designator that, together with the field tag, makes up a Protobuf field key.
///
/// The enum is `#[repr(u8)]`; each variant's discriminant is the numeric value that the
/// `TryFrom<u64>` implementation maps back to that variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum WireType {
Varint = 0,
SixtyFourBit = 1,
LengthDelimited = 2,
StartGroup = 3,
EndGroup = 4,
ThirtyTwoBit = 5,
}

/// Lower tag bound, `1`. NOTE(review): presumably the smallest valid field tag — confirm against callers.
pub const MIN_TAG: u32 = 1;
/// Upper tag bound, `(1 << 29) - 1`. NOTE(review): presumably the largest valid field tag — confirm against callers.
pub const MAX_TAG: u32 = (1 << 29) - 1;

/// Converts a raw numeric value into a [`WireType`].
///
/// Values `0` through `5` map to the variant with the same discriminant; any other value
/// yields a [`DecodeError`] that names the offending value.
impl TryFrom<u64> for WireType {
    type Error = DecodeError;

    #[inline]
    fn try_from(value: u64) -> Result<Self, Self::Error> {
        let wire_type = match value {
            0 => WireType::Varint,
            1 => WireType::SixtyFourBit,
            2 => WireType::LengthDelimited,
            3 => WireType::StartGroup,
            4 => WireType::EndGroup,
            5 => WireType::ThirtyTwoBit,
            _ => {
                return Err(DecodeError::new(format!(
                    "invalid wire type value: {}",
                    value
                )));
            }
        };
        Ok(wire_type)
    }
}

/// Encodes a Protobuf field key, which consists of a wire type designator and
/// the field tag.
#[inline]
Expand Down Expand Up @@ -321,19 +134,6 @@ pub fn key_len(tag: u32) -> usize {
encoded_len_varint(u64::from(tag << 3))
}

/// Verifies that the wire type found on the wire (`actual`) is the one the caller requires
/// (`expected`).
///
/// Returns `Ok(())` on a match, otherwise a [`DecodeError`] reporting both wire types.
#[inline]
pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> {
    if expected == actual {
        Ok(())
    } else {
        Err(DecodeError::new(format!(
            "invalid wire type: {:?} (expected {:?})",
            actual, expected
        )))
    }
}

/// Helper function which abstracts reading a length delimiter prefix followed
/// by decoding values until the length of bytes is exhausted.
pub fn merge_loop<T, M, B>(
Expand Down Expand Up @@ -1522,101 +1322,6 @@ mod test {
assert!(s.is_empty());
}

/// Round-trips values on both sides of every seven-bit boundary through the varint encoder,
/// the length calculation, and both decoders.
#[test]
fn varint() {
    /// Asserts that `value` encodes to exactly `expected`, that the reported length matches,
    /// and that both the regular and the slow decoder recover `value`.
    fn check(value: u64, expected: &[u8]) {
        // Encode once into a small buffer and once into a large one.
        for capacity in [1, 100] {
            let mut out = Vec::with_capacity(capacity);
            encode_varint(value, &mut out);
            assert_eq!(out, expected);
        }

        assert_eq!(encoded_len_varint(value), expected.len());

        // See: https://github.com/tokio-rs/prost/pull/1008 for copying reasoning.
        let mut input = expected;
        assert_eq!(
            value,
            decode_varint(&mut input).expect("decoding failed")
        );

        let mut input = expected;
        assert_eq!(
            value,
            decode_varint_slow(&mut input).expect("slow decoding failed")
        );
    }

    let cases: &[(u64, &[u8])] = &[
        (0, &[0x00]),
        (1, &[0x01]),
        ((1 << 7) - 1, &[0x7F]),
        (1 << 7, &[0x80, 0x01]),
        (300, &[0xAC, 0x02]),
        ((1 << 14) - 1, &[0xFF, 0x7F]),
        (1 << 14, &[0x80, 0x80, 0x01]),
        ((1 << 21) - 1, &[0xFF, 0xFF, 0x7F]),
        (1 << 21, &[0x80, 0x80, 0x80, 0x01]),
        ((1 << 28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]),
        (1 << 28, &[0x80, 0x80, 0x80, 0x80, 0x01]),
        ((1 << 35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
        (1 << 35, &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
        ((1 << 42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
        (1 << 42, &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
        ((1 << 49) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
        (1 << 49, &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
        (
            (1 << 56) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        ),
        (
            1 << 56,
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        ),
        (
            (1 << 63) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        ),
        (
            1 << 63,
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        ),
        (
            u64::MAX,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
        ),
    ];

    for &(value, encoded) in cases {
        check(value, encoded);
    }
}

/// Ten-byte varint made of nine `0xFF` bytes followed by `0x02`. Both varint decoders treat a
/// tenth byte of `0x02` or greater as a `u64` overflow, so decoding this input must fail.
const U64_MAX_PLUS_ONE: &[u8] = &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];

/// `decode_varint` must reject the ten-byte encoding of `u64::MAX + 1`.
#[test]
fn varint_overflow() {
    let mut input = U64_MAX_PLUS_ONE;
    let outcome = decode_varint(&mut input);
    outcome.expect_err("decoding u64::MAX + 1 succeeded");
}

/// `decode_varint_slow` must reject the ten-byte encoding of `u64::MAX + 1`.
#[test]
fn variant_slow_overflow() {
    let mut input = U64_MAX_PLUS_ONE;
    let outcome = decode_varint_slow(&mut input);
    outcome.expect_err("slow decoding u64::MAX + 1 succeeded");
}

/// This big bowl o' macro soup generates an encoding property test for each combination of map
/// type, scalar map key, and value type.
/// TODO: these tests take a long time to compile, can this be improved?
Expand Down
Loading

0 comments on commit ef27f65

Please sign in to comment.