From 9dd640c8e93644e73fa39c78ea88e92b4166f2e5 Mon Sep 17 00:00:00 2001 From: Jeffrey Vo Date: Wed, 21 Aug 2024 21:19:04 +1000 Subject: [PATCH] Initial write support (#122) * Zigzag and varint encoders * Short repeat and direct RLEv2 writers * Minor refactoring * Signed msb encoder * Refacftor RLE delta+patched base to be more functional * Minor refactoring * Byte RLE writer * Remove unused error types * Initial version of ORC writer, supporting only float * Int8 array write support * Integer RLEv2 Delta writing * Minor optimization * Abstract bits_used functionality to common NInt function * Remove overcomplicated AbsVarint code, replace with i64/u64 in delta encoding * Minor fixes * Initial RLEv2 encoder base * Remove u64 impl of NInt in favour of generic to determine sign * Simplify getting RLE reader * Fix percentile bit calculation encoding/decoding * Patched base writing * Support writing int arrays * Multi stripe write test case * Reduce duplication for primitive ColumnStripeEncoders * Introduce EstimateMemory trait to standardize * Comment * Remove need for seek from writer + minor PR comments * Rename s_type to kind * Deduplicate get_closest_fixed_bits * Fix comments * Switch arrow writer tests to be in-memory instead of writing to disk * Fix writing arrays with nulls * Add license to new files --- Cargo.toml | 3 +- src/array_decoder/list.rs | 8 +- src/array_decoder/map.rs | 8 +- src/array_decoder/mod.rs | 10 +- src/array_decoder/string.rs | 25 +- src/array_decoder/timestamp.rs | 4 +- src/array_decoder/union.rs | 4 +- src/arrow_writer.rs | 386 ++++++++++++++ src/lib.rs | 3 + src/reader/decode/boolean_rle.rs | 8 +- src/reader/decode/byte_rle.rs | 282 ++++++++-- src/reader/decode/decimal.rs | 4 +- src/reader/decode/mod.rs | 290 ++++++----- src/reader/decode/rle_v1.rs | 24 +- src/reader/decode/rle_v2/delta.rs | 325 ++++++++---- src/reader/decode/rle_v2/direct.rs | 134 ++++- src/reader/decode/rle_v2/mod.rs | 535 +++++++++++++++++-- src/reader/decode/rle_v2/patched_base.rs | 494 +++++++++++++----- src/reader/decode/rle_v2/short_repeat.rs | 141 +++-- src/reader/decode/timestamp.rs | 9 +- src/reader/decode/util.rs | 638 ++++++++++++++++++++--- src/reader/metadata.rs | 29 +- src/reader/mod.rs | 16 +- src/stripe.rs | 15 + src/writer/column.rs | 242 +++++++++ src/writer/mod.rs | 143 +++++ src/writer/stripe.rs | 200 +++++++ 27 files changed, 3404 insertions(+), 576 deletions(-) create mode 100644 src/arrow_writer.rs create mode 100644 src/writer/column.rs create mode 100644 src/writer/mod.rs create mode 100644 src/writer/stripe.rs diff --git a/Cargo.toml b/Cargo.toml index 988b42dc..48bacfb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ license = "Apache-2.0" description = "Implementation of Apache ORC file format using Apache Arrow in-memory format" keywords = ["arrow", "orc", "arrow-rs", "datafusion"] include = ["src/**/*.rs", "Cargo.toml"] -rust-version = "1.70" +rust-version = "1.73" [dependencies] arrow = { version = "52", features = ["prettyprint", "chrono-tz"] } @@ -57,6 +57,7 @@ arrow-json = "52.0.0" criterion = { version = "0.5", default-features = false, features = ["async_tokio"] } opendal = { version = "0.48", default-features = false, features = ["services-memory"] } pretty_assertions = "1.3.0" +proptest = "1.0.0" serde_json = { version = "1.0", default-features = false, features = ["std"] } [features] diff --git a/src/array_decoder/list.rs b/src/array_decoder/list.rs index a77ecc98..39f84e96 100644 --- a/src/array_decoder/list.rs +++ b/src/array_decoder/list.rs @@ -25,9 +25,9 @@ use snafu::ResultExt; use crate::array_decoder::{derive_present_vec, populate_lengths_with_nulls}; use crate::column::{get_present_vec, Column}; use crate::proto::stream::Kind; -use crate::reader::decode::get_rle_reader; use crate::error::{ArrowSnafu, Result}; +use crate::reader::decode::get_unsigned_rle_reader; use crate::stripe::Stripe; use super::{array_decoder_factory, ArrayBatchDecoder}; @@ -35,7 +35,7 @@ use super::{array_decoder_factory, ArrayBatchDecoder}; pub struct ListArrayDecoder { inner: Box, present: Option + Send>>, - lengths: Box> + Send>, + lengths: Box> + Send>, field: FieldRef, } @@ -48,7 +48,7 @@ impl ListArrayDecoder { let inner = array_decoder_factory(child, field.clone(), stripe)?; let reader = stripe.stream_map().get(column, Kind::Length); - let lengths = get_rle_reader(column, reader)?; + let lengths = get_unsigned_rle_reader(column, reader); Ok(Self { inner, @@ -83,7 +83,7 @@ impl ArrayBatchDecoder for ListArrayDecoder { elements_to_fetch, "less lengths than expected in ListArray" ); - let total_length: u64 = lengths.iter().sum(); + let total_length: i64 = lengths.iter().sum(); // Fetch child array as one Array with total_length elements let child_array = self.inner.next_batch(total_length as usize, None)?; let lengths = populate_lengths_with_nulls(lengths, batch_size, &present); diff --git a/src/array_decoder/map.rs b/src/array_decoder/map.rs index 64197086..b9c874e5 100644 --- a/src/array_decoder/map.rs +++ b/src/array_decoder/map.rs @@ -26,7 +26,7 @@ use crate::array_decoder::{derive_present_vec, populate_lengths_with_nulls}; use crate::column::{get_present_vec, Column}; use crate::error::{ArrowSnafu, Result}; use crate::proto::stream::Kind; -use crate::reader::decode::get_rle_reader; +use crate::reader::decode::get_unsigned_rle_reader; use crate::stripe::Stripe; use super::{array_decoder_factory, ArrayBatchDecoder}; @@ -35,7 +35,7 @@ pub struct MapArrayDecoder { keys: Box, values: Box, present: Option + Send>>, - lengths: Box> + Send>, + lengths: Box> + Send>, fields: Fields, } @@ -56,7 +56,7 @@ impl MapArrayDecoder { let values = array_decoder_factory(values_column, values_field.clone(), stripe)?; let reader = stripe.stream_map().get(column, Kind::Length); - let lengths = get_rle_reader(column, reader)?; + let lengths = get_unsigned_rle_reader(column, reader); let fields = Fields::from(vec![keys_field, values_field]); @@ -94,7 +94,7 @@ impl ArrayBatchDecoder for MapArrayDecoder { elements_to_fetch, "less lengths than expected in MapArray" ); - let total_length: u64 = lengths.iter().sum(); + let total_length: i64 = lengths.iter().sum(); // Fetch key and value arrays, each with total_length elements // Fetch child array as one Array with total_length elements let keys_array = self.keys.next_batch(total_length as usize, None)?; diff --git a/src/array_decoder/mod.rs b/src/array_decoder/mod.rs index db6011ec..b3351604 100644 --- a/src/array_decoder/mod.rs +++ b/src/array_decoder/mod.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, BooleanArray, BooleanBuilder, PrimitiveArray, PrimitiveBuilder}; use arrow::buffer::NullBuffer; -use arrow::datatypes::{ArrowPrimitiveType, Decimal128Type, UInt64Type}; +use arrow::datatypes::{ArrowPrimitiveType, Decimal128Type}; use arrow::datatypes::{DataType as ArrowDataType, Field}; use arrow::datatypes::{ Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef, @@ -33,7 +33,7 @@ use crate::error::{ }; use crate::proto::stream::Kind; use crate::reader::decode::boolean_rle::BooleanIter; -use crate::reader::decode::byte_rle::ByteRleIter; +use crate::reader::decode::byte_rle::ByteRleReader; use crate::reader::decode::float::FloatIter; use crate::reader::decode::get_rle_reader; use crate::schema::DataType; @@ -119,7 +119,6 @@ impl ArrayBatchDecoder for PrimitiveArrayDecoder { } } -type UInt64ArrayDecoder = PrimitiveArrayDecoder; type Int64ArrayDecoder = PrimitiveArrayDecoder; type Int32ArrayDecoder = PrimitiveArrayDecoder; type Int16ArrayDecoder = PrimitiveArrayDecoder; @@ -260,7 +259,7 @@ fn derive_present_vec( /// Fix the lengths to account for nulls (represented as 0 length) fn populate_lengths_with_nulls( - lengths: Vec, + lengths: Vec, batch_size: usize, present: &Option>, ) -> Vec { @@ -365,7 +364,8 @@ pub fn array_decoder_factory( } ); let iter = stripe.stream_map().get(column, Kind::Data); - let iter = Box::new(ByteRleIter::new(iter).map(|value| value.map(|value| value as i8))); + let iter = + Box::new(ByteRleReader::new(iter).map(|value| value.map(|value| value as i8))); let present = get_present_vec(column, stripe)? .map(|iter| Box::new(iter.into_iter()) as Box + Send>); Box::new(Int8ArrayDecoder::new(iter, present)) diff --git a/src/array_decoder/string.rs b/src/array_decoder/string.rs index efe58b92..6712f7ad 100644 --- a/src/array_decoder/string.rs +++ b/src/array_decoder/string.rs @@ -30,11 +30,11 @@ use crate::column::{get_present_vec, Column}; use crate::error::{ArrowSnafu, IoSnafu, Result}; use crate::proto::column_encoding::Kind as ColumnEncodingKind; use crate::proto::stream::Kind; -use crate::reader::decode::{get_rle_reader, RleVersion}; +use crate::reader::decode::get_unsigned_rle_reader; use crate::reader::decompress::Decompressor; use crate::stripe::Stripe; -use super::{ArrayBatchDecoder, UInt64ArrayDecoder}; +use super::{ArrayBatchDecoder, Int64ArrayDecoder}; // TODO: reduce duplication with string below pub fn new_binary_decoder(column: &Column, stripe: &Stripe) -> Result> { @@ -42,7 +42,7 @@ pub fn new_binary_decoder(column: &Column, stripe: &Stripe) -> Result + Send>); let lengths = stripe.stream_map().get(column, Kind::Length); - let lengths = get_rle_reader::(column, lengths)?; + let lengths = get_unsigned_rle_reader(column, lengths); let bytes = Box::new(stripe.stream_map().get(column, Kind::Data)); Ok(Box::new(BinaryArrayDecoder::new(bytes, lengths, present))) @@ -50,12 +50,11 @@ pub fn new_binary_decoder(column: &Column, stripe: &Stripe) -> Result Result> { let kind = column.encoding().kind(); - let rle_version = RleVersion::from(kind); let present = get_present_vec(column, stripe)? .map(|iter| Box::new(iter.into_iter()) as Box + Send>); let lengths = stripe.stream_map().get(column, Kind::Length); - let lengths = rle_version.get_unsigned_rle_reader(lengths); + let lengths = get_unsigned_rle_reader(column, lengths); match kind { ColumnEncodingKind::Direct | ColumnEncodingKind::DirectV2 => { @@ -74,8 +73,8 @@ pub fn new_string_decoder(column: &Column, stripe: &Stripe) -> Result>; pub struct GenericByteArrayDecoder { bytes: Box, - lengths: Box> + Send>, + lengths: Box> + Send>, present: Option + Send>>, phantom: PhantomData, } @@ -99,7 +98,7 @@ pub struct GenericByteArrayDecoder { impl GenericByteArrayDecoder { fn new( bytes: Box, - lengths: Box> + Send>, + lengths: Box> + Send>, present: Option + Send>>, ) -> Self { Self { @@ -133,12 +132,12 @@ impl GenericByteArrayDecoder { elements_to_fetch, "less lengths than expected in ByteArray" ); - let total_length: u64 = lengths.iter().sum(); + let total_length: i64 = lengths.iter().sum(); // Fetch all data bytes at once let mut bytes = Vec::with_capacity(total_length as usize); self.bytes .by_ref() - .take(total_length) + .take(total_length as u64) .read_to_end(&mut bytes) .context(IoSnafu)?; let bytes = Buffer::from(bytes); @@ -165,12 +164,12 @@ impl ArrayBatchDecoder for GenericByteArrayDecoder { } pub struct DictionaryStringArrayDecoder { - indexes: UInt64ArrayDecoder, + indexes: Int64ArrayDecoder, dictionary: Arc, } impl DictionaryStringArrayDecoder { - fn new(indexes: UInt64ArrayDecoder, dictionary: Arc) -> Result { + fn new(indexes: Int64ArrayDecoder, dictionary: Arc) -> Result { Ok(Self { indexes, dictionary, diff --git a/src/array_decoder/timestamp.rs b/src/array_decoder/timestamp.rs index abb3fc7a..430cabb8 100644 --- a/src/array_decoder/timestamp.rs +++ b/src/array_decoder/timestamp.rs @@ -22,7 +22,7 @@ use crate::{ column::{get_present_vec, Column}, error::{MismatchedSchemaSnafu, Result}, proto::stream::Kind, - reader::decode::{get_rle_reader, timestamp::TimestampIterator}, + reader::decode::{get_rle_reader, get_unsigned_rle_reader, timestamp::TimestampIterator}, stripe::Stripe, }; use arrow::array::ArrayRef; @@ -54,7 +54,7 @@ macro_rules! decoder_for_time_unit { let data = get_rle_reader(column, data)?; let secondary = stripe.stream_map().get(column, Kind::Secondary); - let secondary = get_rle_reader(column, secondary)?; + let secondary = get_unsigned_rle_reader(column, secondary); let present = get_present_vec(column, stripe)? .map(|iter| Box::new(iter.into_iter()) as Box + Send>); diff --git a/src/array_decoder/union.rs b/src/array_decoder/union.rs index df88afaa..53b1f4ab 100644 --- a/src/array_decoder/union.rs +++ b/src/array_decoder/union.rs @@ -26,7 +26,7 @@ use crate::column::{get_present_vec, Column}; use crate::error::ArrowSnafu; use crate::error::Result; use crate::proto::stream::Kind; -use crate::reader::decode::byte_rle::ByteRleIter; +use crate::reader::decode::byte_rle::ByteRleReader; use crate::stripe::Stripe; use super::{array_decoder_factory, derive_present_vec, ArrayBatchDecoder}; @@ -47,7 +47,7 @@ impl UnionArrayDecoder { .map(|iter| Box::new(iter.into_iter()) as Box + Send>); let tags = stripe.stream_map().get(column, Kind::Data); - let tags = Box::new(ByteRleIter::new(tags)); + let tags = Box::new(ByteRleReader::new(tags)); let variants = column .children() diff --git a/src/arrow_writer.rs b/src/arrow_writer.rs new file mode 100644 index 00000000..c364ed34 --- /dev/null +++ b/src/arrow_writer.rs @@ -0,0 +1,386 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::io::Write; + +use arrow::{ + array::RecordBatch, + datatypes::{DataType as ArrowDataType, SchemaRef}, +}; +use prost::Message; +use snafu::{ensure, ResultExt}; + +use crate::{ + error::{IoSnafu, Result, UnexpectedSnafu}, + proto, + writer::{ + column::EstimateMemory, + stripe::{StripeInformation, StripeWriter}, + }, +}; + +/// Construct an [`ArrowWriter`] to encode [`RecordBatch`]es into a single +/// ORC file. +pub struct ArrowWriterBuilder { + writer: W, + schema: SchemaRef, + batch_size: usize, + stripe_byte_size: usize, +} + +impl ArrowWriterBuilder { + /// Create a new [`ArrowWriterBuilder`], which will write an ORC file to + /// the provided writer, with the expected Arrow schema. + pub fn new(writer: W, schema: SchemaRef) -> Self { + Self { + writer, + schema, + batch_size: 1024, + // 64 MiB + stripe_byte_size: 64 * 1024 * 1024, + } + } + + /// Batch size controls the encoding behaviour, where `batch_size` values + /// are encoded at a time. Default is `1024`. + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// The approximate size of stripes. Default is `64MiB`. + pub fn with_stripe_byte_size(mut self, stripe_byte_size: usize) -> Self { + self.stripe_byte_size = stripe_byte_size; + self + } + + /// Construct an [`ArrowWriter`] ready to encode [`RecordBatch`]es into + /// an ORC file. + pub fn try_build(mut self) -> Result> { + // Required magic "ORC" bytes at start of file + self.writer.write_all(b"ORC").context(IoSnafu)?; + let writer = StripeWriter::new(self.writer, &self.schema); + Ok(ArrowWriter { + writer, + schema: self.schema, + batch_size: self.batch_size, + stripe_byte_size: self.stripe_byte_size, + written_stripes: vec![], + // Accounting for the 3 magic bytes above + total_bytes_written: 3, + }) + } +} + +/// Encodes [`RecordBatch`]es into an ORC file. Will encode `batch_size` rows +/// at a time into a stripe, flushing the stripe to the underlying writer when +/// it's estimated memory footprint exceeds the configures `stripe_byte_size`. +pub struct ArrowWriter { + writer: StripeWriter, + schema: SchemaRef, + batch_size: usize, + stripe_byte_size: usize, + written_stripes: Vec, + /// Used to keep track of progress in file so far (instead of needing Seek on the writer) + total_bytes_written: u64, +} + +impl ArrowWriter { + /// Encode the provided batch at `batch_size` rows at a time, flushing any + /// stripes that exceed the configured stripe size. + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + ensure!( + batch.schema() == self.schema, + UnexpectedSnafu { + msg: "RecordBatch doesn't match expected schema" + } + ); + + for offset in (0..batch.num_rows()).step_by(self.batch_size) { + let length = self.batch_size.min(batch.num_rows() - offset); + let batch = batch.slice(offset, length); + self.writer.encode_batch(&batch)?; + + // Flush stripe when it exceeds estimated configured size + if self.writer.estimate_memory_size() > self.stripe_byte_size { + self.flush_stripe()?; + } + } + Ok(()) + } + + /// Flush any buffered data that hasn't been written, and write the stripe + /// footer metadata. + pub fn flush_stripe(&mut self) -> Result<()> { + let info = self.writer.finish_stripe(self.total_bytes_written)?; + self.total_bytes_written += info.total_byte_size(); + self.written_stripes.push(info); + Ok(()) + } + + /// Flush the current stripe if it is still in progress, and write the tail + /// metadata and close the writer. + pub fn close(mut self) -> Result<()> { + // Flush in-progress stripe + if self.writer.row_count > 0 { + self.flush_stripe()?; + } + let footer = serialize_footer(&self.written_stripes, &self.schema); + let footer = footer.encode_to_vec(); + let postscript = serialize_postscript(footer.len() as u64); + let postscript = postscript.encode_to_vec(); + let postscript_len = postscript.len() as u8; + + let mut writer = self.writer.finish(); + writer.write_all(&footer).context(IoSnafu)?; + writer.write_all(&postscript).context(IoSnafu)?; + // Postscript length as last byte + writer.write_all(&[postscript_len]).context(IoSnafu)?; + + // TODO: return file metadata + Ok(()) + } +} + +fn serialize_schema(schema: &SchemaRef) -> Vec { + let mut types = vec![]; + + let field_names = schema + .fields() + .iter() + .map(|f| f.name().to_owned()) + .collect(); + // TODO: consider nested types + let subtypes = (1..(schema.fields().len() as u32 + 1)).collect(); + let root_type = proto::Type { + kind: Some(proto::r#type::Kind::Struct.into()), + subtypes, + field_names, + maximum_length: None, + precision: None, + scale: None, + attributes: vec![], + }; + types.push(root_type); + for field in schema.fields() { + let t = match field.data_type() { + ArrowDataType::Float32 => proto::Type { + kind: Some(proto::r#type::Kind::Float.into()), + ..Default::default() + }, + ArrowDataType::Float64 => proto::Type { + kind: Some(proto::r#type::Kind::Double.into()), + ..Default::default() + }, + ArrowDataType::Int8 => proto::Type { + kind: Some(proto::r#type::Kind::Byte.into()), + ..Default::default() + }, + ArrowDataType::Int16 => proto::Type { + kind: Some(proto::r#type::Kind::Short.into()), + ..Default::default() + }, + ArrowDataType::Int32 => proto::Type { + kind: Some(proto::r#type::Kind::Int.into()), + ..Default::default() + }, + ArrowDataType::Int64 => proto::Type { + kind: Some(proto::r#type::Kind::Long.into()), + ..Default::default() + }, + // TODO: support more types + _ => unimplemented!("unsupported datatype"), + }; + types.push(t); + } + types +} + +fn serialize_footer(stripes: &[StripeInformation], schema: &SchemaRef) -> proto::Footer { + let body_length = stripes + .iter() + .map(|s| s.index_length + s.data_length + s.footer_length) + .sum::(); + let number_of_rows = stripes.iter().map(|s| s.row_count as u64).sum::(); + let stripes = stripes.iter().map(From::from).collect(); + let types = serialize_schema(schema); + proto::Footer { + header_length: Some(3), + content_length: Some(body_length + 3), + stripes, + types, + metadata: vec![], + number_of_rows: Some(number_of_rows), + statistics: vec![], + row_index_stride: None, + writer: Some(u32::MAX), + encryption: None, + calendar: None, + software_version: None, + } +} + +fn serialize_postscript(footer_length: u64) -> proto::PostScript { + proto::PostScript { + footer_length: Some(footer_length), + compression: Some(proto::CompressionKind::None.into()), // TODO: support compression + compression_block_size: None, + version: vec![0, 12], + metadata_length: Some(0), // TODO: statistics + writer_version: Some(u32::MAX), // TODO: check which version to use + stripe_statistics_length: None, + magic: Some("ORC".to_string()), + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::{ + array::{ + Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, + RecordBatchReader, + }, + compute::concat_batches, + datatypes::{DataType as ArrowDataType, Field, Schema}, + }; + use bytes::Bytes; + + use crate::ArrowReaderBuilder; + + use super::*; + + #[test] + fn test_roundtrip_write() { + let f32_array = Arc::new(Float32Array::from(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); + let f64_array = Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); + let int8_array = Arc::new(Int8Array::from(vec![0, 1, 2, 3, 4, 5, 6])); + let int16_array = Arc::new(Int16Array::from(vec![0, 1, 2, 3, 4, 5, 6])); + let int32_array = Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6])); + let int64_array = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, 6])); + let schema = Schema::new(vec![ + Field::new("f32", ArrowDataType::Float32, false), + Field::new("f64", ArrowDataType::Float64, false), + Field::new("int8", ArrowDataType::Int8, false), + Field::new("int16", ArrowDataType::Int16, false), + Field::new("int32", ArrowDataType::Int32, false), + Field::new("int64", ArrowDataType::Int64, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + f32_array, + f64_array, + int8_array, + int16_array, + int32_array, + int64_array, + ], + ) + .unwrap(); + + let mut f = vec![]; + let mut writer = ArrowWriterBuilder::new(&mut f, batch.schema()) + .try_build() + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let f = Bytes::from(f); + let reader = ArrowReaderBuilder::try_new(f).unwrap().build(); + let rows = reader.collect::, _>>().unwrap(); + assert_eq!(batch, rows[0]); + } + + #[test] + fn test_write_small_stripes() { + // Set small stripe size to ensure writing across multiple stripes works + let data: Vec = (0..1_000_000).collect(); + let int64_array = Arc::new(Int64Array::from(data)); + let schema = Schema::new(vec![Field::new("int64", ArrowDataType::Int64, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema), vec![int64_array]).unwrap(); + + let mut f = vec![]; + let mut writer = ArrowWriterBuilder::new(&mut f, batch.schema()) + .with_stripe_byte_size(256) + .try_build() + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let f = Bytes::from(f); + let reader = ArrowReaderBuilder::try_new(f).unwrap().build(); + let schema = reader.schema(); + // Current reader doesn't read a batch across stripe boundaries, so we expect + // more than one batch to prove multiple stripes are being written here + let rows = reader.collect::, _>>().unwrap(); + assert!( + rows.len() > 1, + "must have written more than 1 stripe (each stripe read as separate recordbatch)" + ); + let actual = concat_batches(&schema, rows.iter()).unwrap(); + assert_eq!(batch, actual); + } + + #[test] + fn test_write_inconsistent_null_buffers() { + // When writing arrays where null buffer can appear/disappear between writes + let schema = Arc::new(Schema::new(vec![Field::new( + "int64", + ArrowDataType::Int64, + true, + )])); + + // Ensure first batch has array with no null buffer + let array_no_nulls = Arc::new(Int64Array::from(vec![1, 2, 3])); + assert!(array_no_nulls.nulls().is_none()); + // But subsequent batch has array with null buffer + let array_with_nulls = Arc::new(Int64Array::from(vec![None, Some(4), None])); + assert!(array_with_nulls.nulls().is_some()); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![array_no_nulls]).unwrap(); + let batch2 = RecordBatch::try_new(schema.clone(), vec![array_with_nulls]).unwrap(); + + let mut f = vec![]; + let mut writer = ArrowWriterBuilder::new(&mut f, schema.clone()) + .with_stripe_byte_size(256) + .try_build() + .unwrap(); + writer.write(&batch1).unwrap(); + writer.write(&batch2).unwrap(); + writer.close().unwrap(); + + // ORC writer should be able to handle this gracefully + let expected_array = Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(4), + None, + ])); + let expected_batch = RecordBatch::try_new(schema, vec![expected_array]).unwrap(); + + let f = Bytes::from(f); + let reader = ArrowReaderBuilder::try_new(f).unwrap().build(); + let rows = reader.collect::, _>>().unwrap(); + assert_eq!(expected_batch, rows[0]); + } +} diff --git a/src/lib.rs b/src/lib.rs index 2e61c26c..122af9bd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,6 +33,7 @@ mod array_decoder; pub mod arrow_reader; +pub mod arrow_writer; #[cfg(feature = "async")] pub mod async_arrow_reader; mod column; @@ -43,8 +44,10 @@ pub mod reader; pub mod schema; pub mod statistics; pub mod stripe; +mod writer; pub use arrow_reader::{ArrowReader, ArrowReaderBuilder}; +pub use arrow_writer::{ArrowWriter, ArrowWriterBuilder}; #[cfg(feature = "async")] pub use async_arrow_reader::ArrowStreamReader; diff --git a/src/reader/decode/boolean_rle.rs b/src/reader/decode/boolean_rle.rs index 47590e08..a6ccd1fe 100644 --- a/src/reader/decode/boolean_rle.rs +++ b/src/reader/decode/boolean_rle.rs @@ -19,10 +19,10 @@ use std::io::Read; use crate::error::Result; -use super::byte_rle::ByteRleIter; +use super::byte_rle::ByteRleReader; pub struct BooleanIter { - iter: ByteRleIter, + iter: ByteRleReader, data: u8, bits_in_data: usize, } @@ -30,7 +30,7 @@ pub struct BooleanIter { impl BooleanIter { pub fn new(reader: R) -> Self { Self { - iter: ByteRleIter::new(reader), + iter: ByteRleReader::new(reader), bits_in_data: 0, data: 0, } @@ -68,7 +68,7 @@ impl Iterator for BooleanIter { } #[cfg(test)] -mod test { +mod tests { use super::*; #[test] diff --git a/src/reader/decode/byte_rle.rs b/src/reader/decode/byte_rle.rs index 6a713308..5b133640 100644 --- a/src/reader/decode/byte_rle.rs +++ b/src/reader/decode/byte_rle.rs @@ -15,50 +15,214 @@ // specific language governing permissions and limitations // under the License. -use crate::error::Result; +use bytes::{BufMut, BytesMut}; + +use crate::{ + error::Result, + writer::column::{EstimateMemory, PrimitiveValueEncoder}, +}; use std::io::Read; use super::util::read_u8; -const MAX_LITERAL_SIZE: usize = 128; -const MIN_REPEAT_SIZE: usize = 3; +const MAX_LITERAL_LENGTH: usize = 128; +const MIN_REPEAT_LENGTH: usize = 3; +const MAX_REPEAT_LENGTH: usize = 130; + +pub struct ByteRleWriter { + writer: BytesMut, + /// Literal values to encode. + literals: [u8; MAX_LITERAL_LENGTH], + /// Represents the number of elements currently in `literals` if Literals, + /// otherwise represents the length of the Run. + num_literals: usize, + /// Tracks if current Literal sequence will turn into a Run sequence due to + /// repeated values at the end of the value sequence. + tail_run_length: usize, + /// If in Run sequence or not, and keeps the corresponding value. + run_value: Option, +} + +impl ByteRleWriter { + /// Incrementally encode bytes using Run Length Encoding, where the subencodings are: + /// - Run: at least 3 repeated values in sequence (up to `MAX_REPEAT_LENGTH`) + /// - Literals: disparate values (up to `MAX_LITERAL_LENGTH` length) + /// + /// How the relevant encodings are chosen: + /// - Keep of track of values as they come, starting off assuming Literal sequence + /// - Keep track of latest value, to see if we are encountering a sequence of repeated + /// values (Run sequence) + /// - If this tail end exceeds the required minimum length, flush the current Literal + /// sequence (or switch to Run if entire current sequence is the repeated value) + /// - Whether in Literal or Run mode, keep buffering values and flushing when max length + /// reached or encoding is broken (e.g. non-repeated value found in Run mode) + fn process_value(&mut self, value: u8) { + // Adapted from https://github.com/apache/orc/blob/main/java/core/src/java/org/apache/orc/impl/RunLengthByteWriter.java + if self.num_literals == 0 { + // Start off in Literal mode + self.run_value = None; + self.literals[0] = value; + self.num_literals = 1; + self.tail_run_length = 1; + } else if let Some(run_value) = self.run_value { + // Run mode + + if value == run_value { + // Continue buffering for Run sequence, flushing if reaching max length + self.num_literals += 1; + if self.num_literals == MAX_REPEAT_LENGTH { + write_run(&mut self.writer, run_value, MAX_REPEAT_LENGTH); + self.clear_state(); + } + } else { + // Run is broken, flush then start again in Literal mode + write_run(&mut self.writer, run_value, self.num_literals); + self.run_value = None; + self.literals[0] = value; + self.num_literals = 1; + self.tail_run_length = 1; + } + } else { + // Literal mode + + // tail_run_length tracks length of repetition of last value + if value == self.literals[self.num_literals - 1] { + self.tail_run_length += 1; + } else { + self.tail_run_length = 1; + } + + if self.tail_run_length == MIN_REPEAT_LENGTH { + // When the tail end of the current sequence is enough for a Run sequence + + if self.num_literals + 1 == MIN_REPEAT_LENGTH { + // If current values are enough for a Run sequence, switch to Run encoding + self.run_value = Some(value); + self.num_literals += 1; + } else { + // Flush the current Literal sequence, then switch to Run encoding + // We don't flush the tail end which is a Run sequence + let len = self.num_literals - (MIN_REPEAT_LENGTH - 1); + let literals = &self.literals[..len]; + write_literals(&mut self.writer, literals); + self.run_value = Some(value); + self.num_literals = MIN_REPEAT_LENGTH; + } + } else { + // Continue buffering for Literal sequence, flushing if reaching max length + self.literals[self.num_literals] = value; + self.num_literals += 1; + if self.num_literals == MAX_LITERAL_LENGTH { + // Entire literals is filled, pass in as is + write_literals(&mut self.writer, &self.literals); + self.clear_state(); + } + } + } + } + + fn clear_state(&mut self) { + self.run_value = None; + self.tail_run_length = 0; + self.num_literals = 0; + } -pub struct ByteRleIter { + /// Flush any buffered values to writer in appropriate sequence. + fn flush(&mut self) { + if self.num_literals != 0 { + if let Some(value) = self.run_value { + write_run(&mut self.writer, value, self.num_literals); + } else { + let literals = &self.literals[..self.num_literals]; + write_literals(&mut self.writer, literals); + } + self.clear_state(); + } + } +} + +impl EstimateMemory for ByteRleWriter { + fn estimate_memory_size(&self) -> usize { + self.writer.len() + self.num_literals + } +} + +/// i8 to match with Arrow Int8 type. +impl PrimitiveValueEncoder for ByteRleWriter { + fn new() -> Self { + Self { + writer: BytesMut::new(), + literals: [0; MAX_LITERAL_LENGTH], + num_literals: 0, + tail_run_length: 0, + run_value: None, + } + } + + fn write_one(&mut self, value: i8) { + self.process_value(value as u8); + } + + fn take_inner(&mut self) -> bytes::Bytes { + self.flush(); + std::mem::take(&mut self.writer).into() + } +} + +fn write_run(writer: &mut BytesMut, value: u8, run_length: usize) { + debug_assert!( + (MIN_REPEAT_LENGTH..=MAX_REPEAT_LENGTH).contains(&run_length), + "Byte RLE Run sequence must be in range 3..=130" + ); + // [3, 130] to [0, 127] + let header = run_length - MIN_REPEAT_LENGTH; + writer.put_u8(header as u8); + writer.put_u8(value); +} + +fn write_literals(writer: &mut BytesMut, literals: &[u8]) { + debug_assert!( + (1..=MAX_LITERAL_LENGTH).contains(&literals.len()), + "Byte RLE Literal sequence must be in range 1..=128" + ); + // [1, 128] to [-1, -128], then writing as a byte + let header = -(literals.len() as i32); + writer.put_u8(header as u8); + writer.put_slice(literals); +} + +pub struct ByteRleReader { reader: R, - literals: [u8; MAX_LITERAL_SIZE], + literals: [u8; MAX_LITERAL_LENGTH], num_literals: usize, used: usize, repeat: bool, } -impl ByteRleIter { +impl ByteRleReader { pub fn new(reader: R) -> Self { Self { reader, - literals: [0u8; MAX_LITERAL_SIZE], + literals: [0; MAX_LITERAL_LENGTH], num_literals: 0, used: 0, repeat: false, } } - fn read_byte(&mut self) -> Result { - read_u8(&mut self.reader) - } - fn read_values(&mut self) -> Result<()> { - let control = self.read_byte()?; + let control = read_u8(&mut self.reader)?; self.used = 0; if control < 0x80 { self.repeat = true; - self.num_literals = control as usize + MIN_REPEAT_SIZE; - let val = self.read_byte()?; + self.num_literals = control as usize + MIN_REPEAT_LENGTH; + let val = read_u8(&mut self.reader)?; self.literals[0] = val; } else { self.repeat = false; self.num_literals = 0x100 - control as usize; for i in 0..self.num_literals { - let result = self.read_byte()?; + let result = read_u8(&mut self.reader)?; self.literals[i] = result; } } @@ -66,15 +230,12 @@ impl ByteRleIter { } } -impl Iterator for ByteRleIter { +impl Iterator for ByteRleReader { type Item = Result; fn next(&mut self) -> Option { if self.used == self.num_literals { - match self.read_values() { - Ok(_) => {} - Err(_err) => return None, - } + self.read_values().ok()?; } let result = if self.repeat { @@ -88,33 +249,92 @@ impl Iterator for ByteRleIter { } #[cfg(test)] -mod test { +mod tests { + use std::io::Cursor; + use super::*; + use proptest::prelude::*; + #[test] fn reader_test() { let data = [0x61u8, 0x00]; - let data = &mut data.as_ref(); - - let iter = ByteRleIter::new(data).collect::>>().unwrap(); - + let iter = ByteRleReader::new(data) + .collect::>>() + .unwrap(); assert_eq!(iter, vec![0; 100]); let data = [0x01, 0x01]; + let data = &mut data.as_ref(); + let iter = ByteRleReader::new(data) + .collect::>>() + .unwrap(); + assert_eq!(iter, vec![1; 4]); + let data = [0xfe, 0x44, 0x45]; let data = &mut data.as_ref(); + let iter = ByteRleReader::new(data) + .collect::>>() + .unwrap(); + assert_eq!(iter, vec![0x44, 0x45]); + } - let iter = ByteRleIter::new(data).collect::>>().unwrap(); + fn roundtrip_byte_rle_helper(values: &[u8]) -> Result> { + let mut writer = ByteRleWriter::new(); + let values = values.iter().map(|&b| b as i8).collect::>(); + writer.write_slice(&values); + writer.flush(); - assert_eq!(iter, vec![1; 4]); + let buf = writer.take_inner(); + let mut cursor = Cursor::new(&buf); + let reader = ByteRleReader::new(&mut cursor); + reader.into_iter().collect::>>() + } - let data = [0xfe, 0x44, 0x45]; + #[derive(Debug, Clone)] + enum ByteSequence { + Run(u8, usize), + Literals(Vec), + } - let data = &mut data.as_ref(); + fn byte_sequence_strategy() -> impl Strategy { + // We limit the max length of the sequences to 140 to try get more interleaving + prop_oneof![ + (any::(), 1..140_usize).prop_map(|(a, b)| ByteSequence::Run(a, b)), + prop::collection::vec(any::(), 1..140).prop_map(ByteSequence::Literals) + ] + } + + fn generate_bytes_from_sequences(sequences: Vec) -> Vec { + let mut bytes = vec![]; + for sequence in sequences { + match sequence { + ByteSequence::Run(value, length) => { + bytes.extend(std::iter::repeat(value).take(length)) + } + ByteSequence::Literals(literals) => bytes.extend(literals), + } + } + bytes + } - let iter = ByteRleIter::new(data).collect::>>().unwrap(); + proptest! { + #[test] + fn roundtrip_byte_rle_pure_random(values: Vec) { + // Biased towards literal sequences due to purely random values + let out = roundtrip_byte_rle_helper(&values).unwrap(); + prop_assert_eq!(out, values); + } - assert_eq!(iter, vec![0x44, 0x45]); + #[test] + fn roundtrip_byte_rle_biased( + sequences in prop::collection::vec(byte_sequence_strategy(), 1..200) + ) { + // Intentionally introduce run sequences to not be entirely random literals + let values = generate_bytes_from_sequences(sequences); + let out = roundtrip_byte_rle_helper(&values).unwrap(); + prop_assert_eq!(out, values); + } } } diff --git a/src/reader/decode/decimal.rs b/src/reader/decode/decimal.rs index 4544d56c..375a401f 100644 --- a/src/reader/decode/decimal.rs +++ b/src/reader/decode/decimal.rs @@ -19,6 +19,8 @@ use std::io::Read; use crate::{error::Result, reader::decode::util::read_varint_zigzagged}; +use super::SignedEncoding; + /// Read stream of zigzag encoded varints as i128 (unbound). pub struct UnboundedVarintStreamDecoder { reader: R, @@ -40,7 +42,7 @@ impl Iterator for UnboundedVarintStreamDecoder { fn next(&mut self) -> Option { (self.remaining > 0).then(|| { self.remaining -= 1; - read_varint_zigzagged::(&mut self.reader) + read_varint_zigzagged::(&mut self.reader) }) } } diff --git a/src/reader/decode/mod.rs b/src/reader/decode/mod.rs index 431866b4..9d4d8664 100644 --- a/src/reader/decode/mod.rs +++ b/src/reader/decode/mod.rs @@ -20,7 +20,7 @@ use std::io::Read; use std::ops::{BitOrAssign, ShlAssign}; use num::traits::CheckedShl; -use num::PrimInt; +use num::{PrimInt, Signed}; use crate::column::Column; use crate::error::{InvalidColumnEncodingSnafu, Result}; @@ -28,40 +28,31 @@ use crate::proto::column_encoding::Kind as ProtoColumnKind; use self::rle_v1::RleReaderV1; use self::rle_v2::RleReaderV2; -use self::util::{signed_msb_decode, signed_zigzag_decode}; +use self::util::{ + get_closest_aligned_bit_width, signed_msb_decode, signed_zigzag_decode, signed_zigzag_encode, +}; + +// TODO: rename mod to encoding pub mod boolean_rle; pub mod byte_rle; pub mod decimal; pub mod float; -pub mod rle_v1; +mod rle_v1; pub mod rle_v2; pub mod timestamp; mod util; -#[derive(Clone, Copy, Debug)] -pub enum RleVersion { - V1, - V2, -} - -impl RleVersion { - pub fn get_unsigned_rle_reader( - &self, - reader: R, - ) -> Box> + Send> { - match self { - RleVersion::V1 => Box::new(RleReaderV1::new(reader)), - RleVersion::V2 => Box::new(RleReaderV2::new(reader)), +pub fn get_unsigned_rle_reader( + column: &Column, + reader: R, +) -> Box> + Send> { + match column.encoding().kind() { + ProtoColumnKind::Direct | ProtoColumnKind::Dictionary => { + Box::new(RleReaderV1::::new(reader)) } - } -} - -impl From for RleVersion { - fn from(value: ProtoColumnKind) -> Self { - match value { - ProtoColumnKind::Direct | ProtoColumnKind::Dictionary => Self::V1, - ProtoColumnKind::DirectV2 | ProtoColumnKind::DictionaryV2 => Self::V2, + ProtoColumnKind::DirectV2 | ProtoColumnKind::DictionaryV2 => { + Box::new(RleReaderV2::::new(reader)) } } } @@ -71,8 +62,8 @@ pub fn get_rle_reader( reader: R, ) -> Result> + Send>> { match column.encoding().kind() { - ProtoColumnKind::Direct => Ok(Box::new(RleReaderV1::::new(reader))), - ProtoColumnKind::DirectV2 => Ok(Box::new(RleReaderV2::::new(reader))), + ProtoColumnKind::Direct => Ok(Box::new(RleReaderV1::::new(reader))), + ProtoColumnKind::DirectV2 => Ok(Box::new(RleReaderV2::::new(reader))), k => InvalidColumnEncodingSnafu { name: column.name(), encoding: k, @@ -81,98 +72,145 @@ pub fn get_rle_reader( } } -/// Helps generalise the decoder efforts to be specific to supported integers. -/// (Instead of decoding to u64/i64 for all then downcasting). -pub trait NInt: - PrimInt - + CheckedShl - + BitOrAssign - + ShlAssign - + fmt::Debug - + fmt::Display - + fmt::Binary - + Send - + Sync - + 'static -{ - type Bytes: AsRef<[u8]> + AsMut<[u8]> + Default + Clone + Copy; - const BYTE_SIZE: usize; +pub trait EncodingSign: Send + 'static { + // TODO: have separate type/trait to represent Zigzag encoded NInt? + fn zigzag_decode(v: N) -> N; + fn zigzag_encode(v: N) -> N; + fn decode_signed_msb(v: N, encoded_byte_size: usize) -> N; +} + +pub struct SignedEncoding; + +impl EncodingSign for SignedEncoding { #[inline] - fn empty_byte_array() -> Self::Bytes { - Self::Bytes::default() + fn zigzag_decode(v: N) -> N { + signed_zigzag_decode(v) } - /// Should truncate any extra bits. - fn from_u64(u: u64) -> Self; + #[inline] + fn zigzag_encode(v: N) -> N { + signed_zigzag_encode(v) + } - fn from_u8(u: u8) -> Self; + #[inline] + fn decode_signed_msb(v: N, encoded_byte_size: usize) -> N { + signed_msb_decode(v, encoded_byte_size) + } +} - fn from_be_bytes(b: Self::Bytes) -> Self; +pub struct UnsignedEncoding; +impl EncodingSign for UnsignedEncoding { #[inline] - fn zigzag_decode(self) -> Self { - // Default noop for unsigned (signed should override this) - self + fn zigzag_decode(v: N) -> N { + v } #[inline] - fn decode_signed_from_msb(self, _encoded_byte_size: usize) -> Self { - // Default noop for unsigned (signed should override this) - // Used when decoding Patched Base in RLEv2 - // TODO: is this correct for unsigned? Spec doesn't state, but seems logical. - // Add a test for this to check. - self + fn zigzag_encode(v: N) -> N { + v } -} -// We only implement for i16, i32, i64 and u64. -// ORC supports only signed Short, Integer and Long types for its integer types, -// and i8 is encoded as bytes. u64 is used for other encodings such as Strings -// (to encode length, etc.). + #[inline] + fn decode_signed_msb(v: N, _encoded_byte_size: usize) -> N { + v + } +} -impl NInt for i16 { - type Bytes = [u8; 2]; - const BYTE_SIZE: usize = 2; +pub trait VarintSerde: PrimInt + CheckedShl + BitOrAssign + Signed { + const BYTE_SIZE: usize; + /// Calculate the minimum bit size required to represent this value, by truncating + /// the leading zeros. #[inline] - fn from_u64(u: u64) -> Self { - u as Self + fn bits_used(self) -> usize { + Self::BYTE_SIZE * 8 - self.leading_zeros() as usize + } + + /// Feeds [`Self::bits_used`] into a mapping to get an aligned bit width. + fn closest_aligned_bit_width(self) -> usize { + get_closest_aligned_bit_width(self.bits_used()) } + fn from_u8(b: u8) -> Self; +} + +/// Helps generalise the decoder efforts to be specific to supported integers. +/// (Instead of decoding to u64/i64 for all then downcasting). +pub trait NInt: + VarintSerde + ShlAssign + fmt::Debug + fmt::Display + fmt::Binary + Send + Sync + 'static +{ + type Bytes: AsRef<[u8]> + AsMut<[u8]> + Default + Clone + Copy + fmt::Debug; + #[inline] - fn from_u8(u: u8) -> Self { - u as Self + fn empty_byte_array() -> Self::Bytes { + Self::Bytes::default() } + /// Should truncate any extra bits. + fn from_i64(u: i64) -> Self; + + fn from_be_bytes(b: Self::Bytes) -> Self; + + // TODO: use num_traits::ToBytes instead + fn to_be_bytes(self) -> Self::Bytes; + + fn add_i64(self, i: i64) -> Option; + + fn sub_i64(self, i: i64) -> Option; + + // TODO: use Into instead? + fn as_i64(self) -> i64; +} + +impl VarintSerde for i16 { + const BYTE_SIZE: usize = 2; + #[inline] - fn from_be_bytes(b: Self::Bytes) -> Self { - Self::from_be_bytes(b) + fn from_u8(b: u8) -> Self { + b as Self } +} + +impl VarintSerde for i32 { + const BYTE_SIZE: usize = 4; #[inline] - fn zigzag_decode(self) -> Self { - signed_zigzag_decode(self) + fn from_u8(b: u8) -> Self { + b as Self } +} + +impl VarintSerde for i64 { + const BYTE_SIZE: usize = 8; #[inline] - fn decode_signed_from_msb(self, encoded_byte_size: usize) -> Self { - signed_msb_decode(self, encoded_byte_size) + fn from_u8(b: u8) -> Self { + b as Self } } -impl NInt for i32 { - type Bytes = [u8; 4]; - const BYTE_SIZE: usize = 4; +impl VarintSerde for i128 { + const BYTE_SIZE: usize = 16; #[inline] - fn from_u64(u: u64) -> Self { - u as Self + fn from_u8(b: u8) -> Self { + b as Self } +} + +// We only implement for i16, i32, i64 and u64. +// ORC supports only signed Short, Integer and Long types for its integer types, +// and i8 is encoded as bytes. u64 is used for other encodings such as Strings +// (to encode length, etc.). + +impl NInt for i16 { + type Bytes = [u8; 2]; #[inline] - fn from_u8(u: u8) -> Self { - u as Self + fn from_i64(i: i64) -> Self { + i as Self } #[inline] @@ -181,28 +219,32 @@ impl NInt for i32 { } #[inline] - fn zigzag_decode(self) -> Self { - signed_zigzag_decode(self) + fn to_be_bytes(self) -> Self::Bytes { + self.to_be_bytes() } #[inline] - fn decode_signed_from_msb(self, encoded_byte_size: usize) -> Self { - signed_msb_decode(self, encoded_byte_size) + fn add_i64(self, i: i64) -> Option { + i.try_into().ok().and_then(|i| self.checked_add(i)) } -} -impl NInt for i64 { - type Bytes = [u8; 8]; - const BYTE_SIZE: usize = 8; + #[inline] + fn sub_i64(self, i: i64) -> Option { + i.try_into().ok().and_then(|i| self.checked_sub(i)) + } #[inline] - fn from_u64(u: u64) -> Self { - u as Self + fn as_i64(self) -> i64 { + self as i64 } +} + +impl NInt for i32 { + type Bytes = [u8; 4]; #[inline] - fn from_u8(u: u8) -> Self { - u as Self + fn from_i64(i: i64) -> Self { + i as Self } #[inline] @@ -211,60 +253,56 @@ impl NInt for i64 { } #[inline] - fn zigzag_decode(self) -> Self { - signed_zigzag_decode(self) + fn to_be_bytes(self) -> Self::Bytes { + self.to_be_bytes() } #[inline] - fn decode_signed_from_msb(self, encoded_byte_size: usize) -> Self { - signed_msb_decode(self, encoded_byte_size) + fn add_i64(self, i: i64) -> Option { + i.try_into().ok().and_then(|i| self.checked_add(i)) } -} -impl NInt for u64 { - type Bytes = [u8; 8]; - const BYTE_SIZE: usize = 8; + #[inline] + fn sub_i64(self, i: i64) -> Option { + i.try_into().ok().and_then(|i| self.checked_sub(i)) + } #[inline] - fn from_u64(u: u64) -> Self { - u as Self + fn as_i64(self) -> i64 { + self as i64 } +} + +impl NInt for i64 { + type Bytes = [u8; 8]; #[inline] - fn from_u8(u: u8) -> Self { - u as Self + fn from_i64(i: i64) -> Self { + i as Self } #[inline] fn from_be_bytes(b: Self::Bytes) -> Self { Self::from_be_bytes(b) } -} -// This impl is used only for varint decoding. -// Hence some methods are left unimplemented since they are not used. -// TODO: maybe split NInt into traits for the specific use case -// - patched base decoding -// - varint decoding -// - etc. -impl NInt for i128 { - type Bytes = [u8; 16]; - const BYTE_SIZE: usize = 16; - - fn from_u64(_u: u64) -> Self { - unimplemented!() + #[inline] + fn to_be_bytes(self) -> Self::Bytes { + self.to_be_bytes() } - fn from_u8(u: u8) -> Self { - u as Self + #[inline] + fn add_i64(self, i: i64) -> Option { + self.checked_add(i) } - fn from_be_bytes(_b: Self::Bytes) -> Self { - unimplemented!() + #[inline] + fn sub_i64(self, i: i64) -> Option { + self.checked_sub(i) } #[inline] - fn zigzag_decode(self) -> Self { - signed_zigzag_decode(self) + fn as_i64(self) -> i64 { + self } } diff --git a/src/reader/decode/rle_v1.rs b/src/reader/decode/rle_v1.rs index 043ff8bc..1a435388 100644 --- a/src/reader/decode/rle_v1.rs +++ b/src/reader/decode/rle_v1.rs @@ -17,7 +17,7 @@ //! Handling decoding of Integer Run Length Encoded V1 data in ORC files -use std::io::Read; +use std::{io::Read, marker::PhantomData}; use snafu::OptionExt; @@ -25,24 +25,26 @@ use crate::error::{OutOfSpecSnafu, Result}; use super::{ util::{read_u8, read_varint_zigzagged, try_read_u8}, - NInt, + EncodingSign, NInt, }; const MAX_RUN_LENGTH: usize = 130; /// Decodes a stream of Integer Run Length Encoded version 1 bytes. -pub struct RleReaderV1 { +pub struct RleReaderV1 { reader: R, decoded_ints: Vec, current_head: usize, + phantom: PhantomData, } -impl RleReaderV1 { +impl RleReaderV1 { pub fn new(reader: R) -> Self { Self { reader, decoded_ints: Vec::with_capacity(MAX_RUN_LENGTH), current_head: 0, + phantom: Default::default(), } } @@ -55,7 +57,7 @@ impl RleReaderV1 { Some(byte) if byte < 0 => { let length = byte.unsigned_abs(); for _ in 0..length { - let lit = read_varint_zigzagged(&mut self.reader)?; + let lit = read_varint_zigzagged::<_, _, S>(&mut self.reader)?; self.decoded_ints.push(lit); } Ok(true) @@ -65,7 +67,7 @@ impl RleReaderV1 { let byte = byte as u8; let length = byte + 2; // Technically +3, but we subtract 1 for the base let delta = read_u8(&mut self.reader)? as i8; - let mut base = read_varint_zigzagged(&mut self.reader)?; + let mut base = read_varint_zigzagged::<_, _, S>(&mut self.reader)?; self.decoded_ints.push(base); if delta < 0 { let delta = delta.unsigned_abs(); @@ -93,7 +95,7 @@ impl RleReaderV1 { } } -impl Iterator for RleReaderV1 { +impl Iterator for RleReaderV1 { type Item = Result; fn next(&mut self) -> Option { @@ -121,18 +123,20 @@ impl Iterator for RleReaderV1 { mod tests { use std::io::Cursor; + use crate::reader::decode::UnsignedEncoding; + use super::*; #[test] fn test_run() -> Result<()> { let input = [0x61, 0x00, 0x07]; - let decoder = RleReaderV1::::new(Cursor::new(&input)); + let decoder = RleReaderV1::::new(Cursor::new(&input)); let expected = vec![7; 100]; let actual = decoder.collect::>>()?; assert_eq!(actual, expected); let input = [0x61, 0xff, 0x64]; - let decoder = RleReaderV1::::new(Cursor::new(&input)); + let decoder = RleReaderV1::::new(Cursor::new(&input)); let expected = (1..=100).rev().collect::>(); let actual = decoder.collect::>>()?; assert_eq!(actual, expected); @@ -143,7 +147,7 @@ mod tests { #[test] fn test_literal() -> Result<()> { let input = [0xfb, 0x02, 0x03, 0x06, 0x07, 0xb]; - let decoder = RleReaderV1::::new(Cursor::new(&input)); + let decoder = RleReaderV1::::new(Cursor::new(&input)); let expected = vec![2, 3, 6, 7, 11]; let actual = decoder.collect::>>()?; assert_eq!(actual, expected); diff --git a/src/reader/decode/rle_v2/delta.rs b/src/reader/decode/rle_v2/delta.rs index f66b88a9..3a5d1072 100644 --- a/src/reader/decode/rle_v2/delta.rs +++ b/src/reader/decode/rle_v2/delta.rs @@ -17,119 +17,266 @@ use std::io::Read; +use bytes::{BufMut, BytesMut}; use snafu::OptionExt; use crate::error::{OrcError, OutOfSpecSnafu, Result}; +use crate::reader::decode::rle_v2::{EncodingType, MAX_RUN_LENGTH}; use crate::reader::decode::util::{ - extract_run_length_from_header, read_abs_varint, read_ints, read_u8, read_varint_zigzagged, - rle_v2_decode_bit_width, AbsVarint, AccumulateOp, AddOp, SubOp, + extract_run_length_from_header, read_ints, read_u8, read_varint_zigzagged, + rle_v2_decode_bit_width, rle_v2_encode_bit_width, write_aligned_packed_ints, + write_varint_zigzagged, }; +use crate::reader::decode::{EncodingSign, SignedEncoding, VarintSerde}; -use super::{NInt, RleReaderV2}; +use super::NInt; -impl RleReaderV2 { - fn fixed_delta( - &mut self, - length: usize, - base_value: N, - delta: N, - ) -> Result<()> { +/// We use i64 and u64 for delta to make things easier and to avoid edge cases, +/// as for example for i16, the delta may be too large to represent in an i16. +// TODO: expand on the above +pub fn read_delta_values( + reader: &mut R, + out_ints: &mut Vec, + deltas: &mut Vec, + header: u8, +) -> Result<()> { + // Encoding format: + // 2 bytes header + // - 2 bits for encoding type (constant 3) + // - 5 bits for encoded delta bitwidth (0 to 64) + // - 9 bits for run length (1 to 512) + // Base value (signed or unsigned) varint + // Delta value signed varint + // Sequence of delta values + + let encoded_delta_bit_width = (header >> 1) & 0x1f; + // Uses same encoding table as for direct & patched base, + // but special case where 0 indicates 0 width (for fixed delta) + let delta_bit_width = if encoded_delta_bit_width == 0 { + encoded_delta_bit_width as usize + } else { + rle_v2_decode_bit_width(encoded_delta_bit_width) + }; + + let second_byte = read_u8(reader)?; + let length = extract_run_length_from_header(header, second_byte); + + let base_value = read_varint_zigzagged::(reader)?; + out_ints.push(base_value); + + // Always signed since can be decreasing sequence + let delta_base = read_varint_zigzagged::(reader)?; + // TODO: does this get inlined? + let op: fn(N, i64) -> Option = if delta_base.is_positive() { + |acc, delta| acc.add_i64(delta) + } else { + |acc, delta| acc.sub_i64(delta) + }; + let delta_base = delta_base.abs(); // TODO: i64::MIN? + + if delta_bit_width == 0 { + // If width is 0 then all values have fixed delta of delta_base // Skip first value since that's base_value (1..length).try_fold(base_value, |acc, _| { - let acc = A::acc(acc, delta).context(OutOfSpecSnafu { + let acc = op(acc, delta_base).context(OutOfSpecSnafu { msg: "over/underflow when decoding delta integer", })?; - self.decoded_ints.push(acc); + out_ints.push(acc); Ok::<_, OrcError>(acc) })?; - Ok(()) - } - - fn varied_deltas( - &mut self, - length: usize, - base_value: N, - delta: N, - delta_bit_width: usize, - ) -> Result<()> { + } else { + deltas.clear(); // Add delta base and first value - let second_value = A::acc(base_value, delta).context(OutOfSpecSnafu { + let second_value = op(base_value, delta_base).context(OutOfSpecSnafu { msg: "over/underflow when decoding delta integer", })?; - self.decoded_ints.push(second_value); + out_ints.push(second_value); // Run length includes base value and first delta, so skip them let length = length - 2; // Unpack the delta values - read_ints( - &mut self.decoded_ints, - length, - delta_bit_width, - &mut self.reader, - )?; - self.decoded_ints - .iter_mut() - // Ignore base_value and second_value - .skip(2) - // Each element is the delta, so find actual value using running accumulator - .try_fold(second_value, |acc, delta| { - let acc = A::acc(acc, *delta).context(OutOfSpecSnafu { - msg: "over/underflow when decoding delta integer", - })?; - *delta = acc; - Ok::<_, OrcError>(acc) + read_ints(deltas, length, delta_bit_width, reader)?; + let mut acc = second_value; + // Each element is the delta, so find actual value using running accumulator + for delta in deltas { + acc = op(acc, *delta).context(OutOfSpecSnafu { + msg: "over/underflow when decoding delta integer", })?; - Ok(()) + out_ints.push(acc); + } } + Ok(()) +} - pub fn read_delta_values(&mut self, header: u8) -> Result<()> { - // Encoding format: - // 2 bytes header - // - 2 bits for encoding type (constant 3) - // - 5 bits for encoded delta bitwidth (0 to 64) - // - 9 bits for run length (1 to 512) - // Base value (signed or unsigned) varint - // Delta value signed varint - // Sequence of delta values - - let encoded_delta_bit_width = (header >> 1) & 0x1f; - // Uses same encoding table as for direct & patched base, - // but special case where 0 indicates 0 width (for fixed delta) - let delta_bit_width = if encoded_delta_bit_width == 0 { - encoded_delta_bit_width as usize - } else { - rle_v2_decode_bit_width(encoded_delta_bit_width) - }; - - let second_byte = read_u8(&mut self.reader)?; - let length = extract_run_length_from_header(header, second_byte); - - let base_value = read_varint_zigzagged::(&mut self.reader)?; - self.decoded_ints.push(base_value); - - // Always signed since can be decreasing sequence - let delta_base = read_abs_varint::(&mut self.reader)?; +pub fn write_varying_delta( + writer: &mut BytesMut, + base_value: N, + first_delta: i64, + max_delta: i64, + subsequent_deltas: &[i64], +) { + debug_assert!( + max_delta > 0, + "varying deltas must have at least one non-zero delta" + ); + let bit_width = max_delta.closest_aligned_bit_width(); + // We can't have bit width of 1 for delta as that would get decoded as + // 0 bit width on reader, which indicates fixed delta, so bump 1 to 2 + // in this case. + let bit_width = if bit_width == 1 { 2 } else { bit_width }; + // Add 2 to len for the base_value and first_delta + let header = derive_delta_header(bit_width, subsequent_deltas.len() + 2); + writer.put_slice(&header); - // If width is 0 then all values have fixed delta of delta_base - if delta_bit_width == 0 { - match delta_base { - AbsVarint::Negative(delta) => { - self.fixed_delta::(length, base_value, delta)?; - } - AbsVarint::Positive(delta) => { - self.fixed_delta::(length, base_value, delta)?; - } - }; - } else { - match delta_base { - AbsVarint::Negative(delta) => { - self.varied_deltas::(length, base_value, delta, delta_bit_width)?; - } - AbsVarint::Positive(delta) => { - self.varied_deltas::(length, base_value, delta, delta_bit_width)?; - } - }; + write_varint_zigzagged::<_, S>(writer, base_value); + // First delta always signed to indicate increasing/decreasing sequence + write_varint_zigzagged::<_, SignedEncoding>(writer, first_delta); + + // Bitpacked deltas + write_aligned_packed_ints(writer, bit_width, subsequent_deltas); +} + +pub fn write_fixed_delta( + writer: &mut BytesMut, + base_value: N, + fixed_delta: i64, + subsequent_deltas_len: usize, +) { + // Assuming len excludes base_value and first delta + let header = derive_delta_header(0, subsequent_deltas_len + 2); + writer.put_slice(&header); + + write_varint_zigzagged::<_, S>(writer, base_value); + // First delta always signed to indicate increasing/decreasing sequence + write_varint_zigzagged::<_, SignedEncoding>(writer, fixed_delta); +} + +fn derive_delta_header(delta_width: usize, run_length: usize) -> [u8; 2] { + debug_assert!( + (1..=MAX_RUN_LENGTH).contains(&run_length), + "delta run length cannot exceed 512 values" + ); + // [1, 512] to [0, 511] + let run_length = run_length as u16 - 1; + // 0 is special value to indicate fixed delta + let delta_width = if delta_width == 0 { + 0 + } else { + rle_v2_encode_bit_width(delta_width) + }; + // No need to mask as we guarantee max length is 512 + let encoded_length_high_bit = (run_length >> 8) as u8; + let encoded_length_low_bits = (run_length & 0xFF) as u8; + + let header1 = EncodingType::Delta.to_header() | delta_width << 1 | encoded_length_high_bit; + let header2 = encoded_length_low_bits; + + [header1, header2] +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use crate::reader::decode::UnsignedEncoding; + + use super::*; + + // TODO: figure out how to write proptests for these + + #[test] + fn test_fixed_delta_positive() { + let mut buf = BytesMut::new(); + let mut out = vec![]; + let mut deltas = vec![]; + write_fixed_delta::(&mut buf, 0, 10, 100 - 2); + let header = buf[0]; + read_delta_values::( + &mut Cursor::new(&buf[1..]), + &mut out, + &mut deltas, + header, + ) + .unwrap(); + + let expected = (0..100).map(|i| i * 10).collect::>(); + assert_eq!(expected, out); + } + + #[test] + fn test_fixed_delta_negative() { + let mut buf = BytesMut::new(); + let mut out = vec![]; + let mut deltas = vec![]; + write_fixed_delta::(&mut buf, 10_000, -63, 150 - 2); + let header = buf[0]; + read_delta_values::( + &mut Cursor::new(&buf[1..]), + &mut out, + &mut deltas, + header, + ) + .unwrap(); + + let expected = (0..150).map(|i| 10_000 - i * 63).collect::>(); + assert_eq!(expected, out); + } + + #[test] + fn test_varying_delta_positive() { + let deltas = [ + 1, 6, 98, 12, 65, 9, 0, 0, 1, 128, 643, 129, 469, 123, 4572, 124, + ]; + let max = *deltas.iter().max().unwrap(); + + let mut buf = BytesMut::new(); + let mut out = vec![]; + let mut deltas = vec![]; + write_varying_delta::(&mut buf, 0, 10, max, &deltas); + let header = buf[0]; + read_delta_values::( + &mut Cursor::new(&buf[1..]), + &mut out, + &mut deltas, + header, + ) + .unwrap(); + + let mut expected = vec![0, 10]; + let mut i = 1; + for d in deltas { + expected.push(d + expected[i]); + i += 1; + } + assert_eq!(expected, out); + } + + #[test] + fn test_varying_delta_negative() { + let deltas = [ + 1, 6, 98, 12, 65, 9, 0, 0, 1, 128, 643, 129, 469, 123, 4572, 124, + ]; + let max = *deltas.iter().max().unwrap(); + + let mut buf = BytesMut::new(); + let mut out = vec![]; + let mut deltas = vec![]; + write_varying_delta::(&mut buf, 10_000, -1, max, &deltas); + let header = buf[0]; + read_delta_values::( + &mut Cursor::new(&buf[1..]), + &mut out, + &mut deltas, + header, + ) + .unwrap(); + + let mut expected = vec![10_000, 9_999]; + let mut i = 1; + for d in deltas { + expected.push(expected[i] - d); + i += 1; } - Ok(()) + assert_eq!(expected, out); } } diff --git a/src/reader/decode/rle_v2/direct.rs b/src/reader/decode/rle_v2/direct.rs index e20cc38f..20a6f534 100644 --- a/src/reader/decode/rle_v2/direct.rs +++ b/src/reader/decode/rle_v2/direct.rs @@ -17,35 +17,135 @@ use std::io::Read; +use bytes::{BufMut, BytesMut}; + use crate::error::{OutOfSpecSnafu, Result}; +use crate::reader::decode::rle_v2::{EncodingType, MAX_RUN_LENGTH}; use crate::reader::decode::util::{ extract_run_length_from_header, read_ints, read_u8, rle_v2_decode_bit_width, + rle_v2_encode_bit_width, write_aligned_packed_ints, }; +use crate::reader::decode::EncodingSign; -use super::{NInt, RleReaderV2}; +use super::NInt; -impl RleReaderV2 { - pub fn read_direct_values(&mut self, header: u8) -> Result<()> { - let encoded_bit_width = (header >> 1) & 0x1F; - let bit_width = rle_v2_decode_bit_width(encoded_bit_width); +pub fn read_direct_values( + reader: &mut R, + out_ints: &mut Vec, + header: u8, +) -> Result<()> { + let encoded_bit_width = (header >> 1) & 0x1F; + let bit_width = rle_v2_decode_bit_width(encoded_bit_width); - if (N::BYTE_SIZE * 8) < bit_width { - return OutOfSpecSnafu { - msg: "byte width of direct encoding exceeds byte size of integer being decoded to", - } - .fail(); + if (N::BYTE_SIZE * 8) < bit_width { + return OutOfSpecSnafu { + msg: "byte width of direct encoding exceeds byte size of integer being decoded to", } + .fail(); + } + + let second_byte = read_u8(reader)?; + let length = extract_run_length_from_header(header, second_byte); + + // Write the unpacked values and zigzag decode to result buffer + read_ints(out_ints, length, bit_width, reader)?; + + for lit in out_ints.iter_mut() { + *lit = S::zigzag_decode(*lit); + } + + Ok(()) +} + +/// `values` and `max` must be zigzag encoded. If `max` is not provided, it is derived +/// by iterating over `values`. +pub fn write_direct(writer: &mut BytesMut, values: &[N], max: Option) { + debug_assert!( + (1..=MAX_RUN_LENGTH).contains(&values.len()), + "direct run length cannot exceed 512 values" + ); + + let max = max.unwrap_or_else(|| { + // Assert guards that values is non-empty + *values.iter().max_by_key(|x| x.bits_used()).unwrap() + }); + + let bit_width = max.closest_aligned_bit_width(); + let encoded_bit_width = rle_v2_encode_bit_width(bit_width); + // From [1, 512] to [0, 511] + let encoded_length = values.len() as u16 - 1; + // No need to mask as we guarantee max length is 512 + let encoded_length_high_bit = (encoded_length >> 8) as u8; + let encoded_length_low_bits = (encoded_length & 0xFF) as u8; + + let header1 = + EncodingType::Direct.to_header() | (encoded_bit_width << 1) | encoded_length_high_bit; + let header2 = encoded_length_low_bits; + + writer.put_u8(header1); + writer.put_u8(header2); + write_aligned_packed_ints(writer, bit_width, values); +} - let second_byte = read_u8(&mut self.reader)?; - let length = extract_run_length_from_header(header, second_byte); +#[cfg(test)] +mod tests { + use std::io::Cursor; - // Write the unpacked values and zigzag decode to result buffer - read_ints(&mut self.decoded_ints, length, bit_width, &mut self.reader)?; + use proptest::prelude::*; - for lit in self.decoded_ints.iter_mut() { - *lit = lit.zigzag_decode(); + use crate::reader::decode::{SignedEncoding, UnsignedEncoding}; + + use super::*; + + fn roundtrip_direct_helper(values: &[N]) -> Result> { + let mut buf = BytesMut::new(); + let mut out = vec![]; + + write_direct(&mut buf, values, None); + let header = buf[0]; + read_direct_values::<_, _, S>(&mut Cursor::new(&buf[1..]), &mut out, header)?; + + Ok(out) + } + + #[test] + fn test_direct_edge_case() { + let values: Vec = vec![109, -17809, -29946, -17285]; + let encoded = values + .iter() + .map(|&v| SignedEncoding::zigzag_encode(v)) + .collect::>(); + let out = roundtrip_direct_helper::<_, SignedEncoding>(&encoded).unwrap(); + assert_eq!(out, values); + } + + proptest! { + #[test] + fn roundtrip_direct_i16(values in prop::collection::vec(any::(), 1..=512)) { + let encoded = values.iter().map(|v| SignedEncoding::zigzag_encode(*v)).collect::>(); + let out = roundtrip_direct_helper::<_, SignedEncoding>(&encoded)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_direct_i32(values in prop::collection::vec(any::(), 1..=512)) { + let encoded = values.iter().map(|v| SignedEncoding::zigzag_encode(*v)).collect::>(); + let out = roundtrip_direct_helper::<_, SignedEncoding>(&encoded)?; + prop_assert_eq!(out, values); } - Ok(()) + #[test] + fn roundtrip_direct_i64(values in prop::collection::vec(any::(), 1..=512)) { + let encoded = values.iter().map(|v| SignedEncoding::zigzag_encode(*v)).collect::>(); + let out = roundtrip_direct_helper::<_, SignedEncoding>(&encoded)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_direct_i64_unsigned(values in prop::collection::vec(0..=i64::MAX, 1..=512)) { + let encoded = values.iter().map(|v| UnsignedEncoding::zigzag_encode(*v)).collect::>(); + let out = roundtrip_direct_helper::<_, UnsignedEncoding>(&encoded)?; + prop_assert_eq!(out, values); + } } } diff --git a/src/reader/decode/rle_v2/mod.rs b/src/reader/decode/rle_v2/mod.rs index 652800ef..1cdb66a5 100644 --- a/src/reader/decode/rle_v2/mod.rs +++ b/src/reader/decode/rle_v2/mod.rs @@ -15,30 +15,56 @@ // specific language governing permissions and limitations // under the License. +use std::{io::Read, marker::PhantomData}; + +use bytes::BytesMut; + +use crate::{ + error::Result, + writer::column::{EstimateMemory, PrimitiveValueEncoder}, +}; + +use self::{ + delta::{read_delta_values, write_fixed_delta, write_varying_delta}, + direct::{read_direct_values, write_direct}, + patched_base::{read_patched_base, write_patched_base}, + short_repeat::{read_short_repeat_values, write_short_repeat}, +}; + +use super::{ + util::{calculate_percentile_bits, try_read_u8}, + EncodingSign, NInt, VarintSerde, +}; + pub mod delta; pub mod direct; pub mod patched_base; pub mod short_repeat; -use std::io::Read; - -use crate::error::Result; - -use super::{util::try_read_u8, NInt}; const MAX_RUN_LENGTH: usize = 512; +/// Minimum number of repeated values required to use Short Repeat sub-encoding +const SHORT_REPEAT_MIN_LENGTH: usize = 3; +const SHORT_REPEAT_MAX_LENGTH: usize = 10; +const BASE_VALUE_LIMIT: i64 = 1 << 56; -pub struct RleReaderV2 { +// TODO: switch to read from Bytes directly? +pub struct RleReaderV2 { reader: R, decoded_ints: Vec, + /// Indexes into decoded_ints to make it act like a queue current_head: usize, + deltas: Vec, + phantom: PhantomData, } -impl RleReaderV2 { +impl RleReaderV2 { pub fn new(reader: R) -> Self { Self { reader, decoded_ints: Vec::with_capacity(MAX_RUN_LENGTH), current_head: 0, + deltas: Vec::with_capacity(MAX_RUN_LENGTH), + phantom: Default::default(), } } @@ -50,17 +76,30 @@ impl RleReaderV2 { }; match EncodingType::from_header(header) { - EncodingType::ShortRepeat => self.read_short_repeat_values(header)?, - EncodingType::Direct => self.read_direct_values(header)?, - EncodingType::PatchedBase => self.read_patched_base(header)?, - EncodingType::Delta => self.read_delta_values(header)?, + EncodingType::ShortRepeat => read_short_repeat_values::<_, _, S>( + &mut self.reader, + &mut self.decoded_ints, + header, + )?, + EncodingType::Direct => { + read_direct_values::<_, _, S>(&mut self.reader, &mut self.decoded_ints, header)? + } + EncodingType::PatchedBase => { + read_patched_base::<_, _, S>(&mut self.reader, &mut self.decoded_ints, header)? + } + EncodingType::Delta => read_delta_values::<_, _, S>( + &mut self.reader, + &mut self.decoded_ints, + &mut self.deltas, + header, + )?, } Ok(true) } } -impl Iterator for RleReaderV2 { +impl Iterator for RleReaderV2 { type Item = Result; fn next(&mut self) -> Option { @@ -84,6 +123,366 @@ impl Iterator for RleReaderV2 { } } +struct DeltaEncodingCheckResult { + base_value: N, + min: N, + max: N, + first_delta: i64, + max_delta: i64, + is_monotonic: bool, + is_fixed_delta: bool, + adjacent_deltas: Vec, +} + +/// Calculate the necessary values to determine if sequence can be delta encoded. +fn delta_encoding_check(literals: &[N]) -> DeltaEncodingCheckResult { + let base_value = literals[0]; + let mut min = base_value.min(literals[1]); + let mut max = base_value.max(literals[1]); + // Saturating should be fine here (and below) as we later check the + // difference between min & max and defer to direct encoding if it + // is too large (so the corrupt delta here won't actually be used). + // TODO: is there a more explicit way of ensuring this behaviour? + let first_delta = literals[1].as_i64().saturating_sub(base_value.as_i64()); + let mut current_delta; + let mut max_delta = 0; + + let mut is_increasing = first_delta.is_positive(); + let mut is_decreasing = first_delta.is_negative(); + let mut is_fixed_delta = true; + + let mut adjacent_deltas = vec![]; + + // We've already preprocessed the first step above + for i in 2..literals.len() { + let l1 = literals[i]; + let l0 = literals[i - 1]; + + min = min.min(l1); + max = max.max(l1); + + current_delta = l1.as_i64().saturating_sub(l0.as_i64()); + + is_increasing &= current_delta >= 0; + is_decreasing &= current_delta <= 0; + + is_fixed_delta &= current_delta == first_delta; + let current_delta = current_delta.saturating_abs(); + adjacent_deltas.push(current_delta); + max_delta = max_delta.max(current_delta); + } + let is_monotonic = is_increasing || is_decreasing; + + DeltaEncodingCheckResult { + base_value, + min, + max, + first_delta, + max_delta, + is_monotonic, + is_fixed_delta, + adjacent_deltas, + } +} + +/// Runs are guaranteed to have length > 1. +#[derive(Debug, Clone, Eq, PartialEq)] +enum RleV2EncodingState { + /// When buffer is empty and no values to encode. + Empty, + /// Special state for first value as we determine after the first + /// value whether to go fixed or variable run. + One(N), + /// Run of identical value of specified count. + FixedRun { value: N, count: usize }, + /// Run of variable values. + VariableRun { literals: Vec }, +} + +impl Default for RleV2EncodingState { + fn default() -> Self { + Self::Empty + } +} + +pub struct RleWriterV2 { + /// Stores the run length encoded sequences. + data: BytesMut, + /// Used in state machine for determining which sub-encoding + /// for a sequence to use. + state: RleV2EncodingState, + phantom: PhantomData, +} + +impl RleWriterV2 { + // Algorithm adapted from: + // https://github.com/apache/orc/blob/main/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriterV2.java + + /// Process each value to build up knowledge to determine which encoding to use. We attempt + /// to identify runs of identical values (fixed runs), otherwise falling back to variable + /// runs (varying values). + /// + /// When in a fixed run state, as long as identical values are found, we keep incrementing + /// the run length up to a maximum of 512, flushing to fixed delta run if so. If we encounter + /// a differing value, we flush to short repeat or fixed delta depending on the length and + /// reset the state (if the current run is small enough, we switch direct to variable run). + /// + /// When in a variable run state, if we find 3 identical values in a row as the latest values, + /// we flush the variable run to a sub-encoding then switch to fixed run, otherwise continue + /// incrementing the run length up to a max length of 512, before flushing and resetting the + /// state. For a variable run, extra logic must take place to determine which sub-encoding to + /// use when flushing, see [`Self::determine_variable_run_encoding`] for more details. + fn process_value(&mut self, value: N) { + match &mut self.state { + // When we start, or when a run was flushed to a sub-encoding + RleV2EncodingState::Empty => { + self.state = RleV2EncodingState::One(value); + } + // Here we determine if we look like we're in a fixed run or variable run + RleV2EncodingState::One(one_value) => { + if value == *one_value { + self.state = RleV2EncodingState::FixedRun { value, count: 2 }; + } else { + // TODO: alloc here + let mut literals = Vec::with_capacity(MAX_RUN_LENGTH); + literals.push(*one_value); + literals.push(value); + self.state = RleV2EncodingState::VariableRun { literals }; + } + } + // When we're in a run of identical values + RleV2EncodingState::FixedRun { + value: fixed_value, + count, + } => { + if value == *fixed_value { + // Continue fixed run, flushing to delta when max length reached + *count += 1; + if *count == MAX_RUN_LENGTH { + write_fixed_delta::<_, S>(&mut self.data, value, 0, *count - 2); + self.state = RleV2EncodingState::Empty; + } + } else { + // If fixed run is broken by a different value. + match count { + // Note that count cannot be 0 or 1 here as that is encoded + // by Empty and One states in self.state + 2 => { + // If fixed run is smaller than short repeat then just include + // it at the start of the variable run we're switching to. + // TODO: alloc here + let mut literals = Vec::with_capacity(MAX_RUN_LENGTH); + literals.push(*fixed_value); + literals.push(*fixed_value); + literals.push(value); + self.state = RleV2EncodingState::VariableRun { literals }; + } + SHORT_REPEAT_MIN_LENGTH..=SHORT_REPEAT_MAX_LENGTH => { + // If we have enough values for a Short Repeat, then encode as + // such. + write_short_repeat::<_, S>(&mut self.data, *fixed_value, *count); + self.state = RleV2EncodingState::One(value); + } + _ => { + // Otherwise if too large, use Delta encoding. + write_fixed_delta::<_, S>(&mut self.data, *fixed_value, 0, *count - 2); + self.state = RleV2EncodingState::One(value); + } + } + } + } + // When we're in a run of varying values + RleV2EncodingState::VariableRun { literals } => { + let length = literals.len(); + let last_value = literals[length - 1]; + let second_last_value = literals[length - 2]; + if value == last_value && value == second_last_value { + // Last 3 values (including current new one) are identical. Break the current + // variable run, flushing it to a sub-encoding, then switch to a fixed run + // state. + + // Pop off the last two values (which are identical to value) and flush + // the variable run to writer + literals.truncate(literals.len() - 2); + determine_variable_run_encoding::<_, S>(&mut self.data, literals); + + self.state = RleV2EncodingState::FixedRun { value, count: 3 }; + } else { + // Continue variable run, flushing sub-encoding if max length reached + literals.push(value); + if literals.len() == MAX_RUN_LENGTH { + determine_variable_run_encoding::<_, S>(&mut self.data, literals); + self.state = RleV2EncodingState::Empty; + } + } + } + } + } + + /// Flush any buffered values to the writer. + fn flush(&mut self) { + let state = std::mem::take(&mut self.state); + match state { + RleV2EncodingState::Empty => {} + RleV2EncodingState::One(value) => { + let value = S::zigzag_encode(value); + write_direct(&mut self.data, &[value], Some(value)); + } + RleV2EncodingState::FixedRun { value, count: 2 } => { + // Direct has smallest overhead + let value = S::zigzag_encode(value); + write_direct(&mut self.data, &[value, value], Some(value)); + } + RleV2EncodingState::FixedRun { value, count } if count <= SHORT_REPEAT_MAX_LENGTH => { + // Short repeat must have length [3, 10] + write_short_repeat::<_, S>(&mut self.data, value, count); + } + RleV2EncodingState::FixedRun { value, count } => { + write_fixed_delta::<_, S>(&mut self.data, value, 0, count - 2); + } + RleV2EncodingState::VariableRun { mut literals } => { + determine_variable_run_encoding::<_, S>(&mut self.data, &mut literals); + } + } + } +} + +impl EstimateMemory for RleWriterV2 { + fn estimate_memory_size(&self) -> usize { + self.data.len() + } +} + +impl PrimitiveValueEncoder for RleWriterV2 { + fn new() -> Self { + Self { + data: BytesMut::new(), + state: RleV2EncodingState::Empty, + phantom: Default::default(), + } + } + + fn write_one(&mut self, value: N) { + self.process_value(value); + } + + fn take_inner(&mut self) -> bytes::Bytes { + self.flush(); + std::mem::take(&mut self.data).into() + } +} + +fn determine_variable_run_encoding( + writer: &mut BytesMut, + literals: &mut [N], +) { + // Direct will have smallest overhead for tiny runs + if literals.len() <= SHORT_REPEAT_MIN_LENGTH { + for v in literals.iter_mut() { + *v = S::zigzag_encode(*v); + } + write_direct(writer, literals, None); + return; + } + + // Invariant: literals.len() > 3 + let DeltaEncodingCheckResult { + base_value, + min, + max, + first_delta, + max_delta, + is_monotonic, + is_fixed_delta, + adjacent_deltas, + } = delta_encoding_check(literals); + + // Quick check for delta overflow, if so just move to Direct as it has less + // overhead than Patched Base. + // TODO: should min/max be N or i64 here? + if max.checked_sub(&min).is_none() { + for v in literals.iter_mut() { + *v = S::zigzag_encode(*v); + } + write_direct(writer, literals, None); + return; + } + + // Any subtractions here on are safe due to above check + + if is_fixed_delta { + write_fixed_delta::<_, S>(writer, literals[0], first_delta, literals.len() - 2); + return; + } + + // First delta used to indicate if increasing or decreasing, so must be non-zero + if first_delta != 0 && is_monotonic { + write_varying_delta::<_, S>(writer, base_value, first_delta, max_delta, &adjacent_deltas); + return; + } + + // In Java implementation, Patched Base encoding base value cannot exceed 56 + // bits in value otherwise it can overflow the maximum 8 bytes used to encode + // the value when signed MSB encoding is used (adds an extra bit). + let min = min.as_i64(); + if min.abs() >= BASE_VALUE_LIMIT && min != i64::MIN { + for v in literals.iter_mut() { + *v = S::zigzag_encode(*v); + } + write_direct(writer, literals, None); + return; + } + + // TODO: another allocation here + let zigzag_literals = literals + .iter() + .map(|&v| S::zigzag_encode(v)) + .collect::>(); + let zigzagged_90_percentile_bit_width = calculate_percentile_bits(&zigzag_literals, 0.90); + // TODO: can derive from min/max? + let zigzagged_100_percentile_bit_width = calculate_percentile_bits(&zigzag_literals, 1.00); + // If variation of bit width between largest value and lower 90% of values isn't + // significant enough, just use direct encoding as patched base wouldn't be as + // efficient. + if (zigzagged_100_percentile_bit_width.saturating_sub(zigzagged_90_percentile_bit_width)) <= 1 { + // TODO: pass through the 100p here + write_direct(writer, &zigzag_literals, None); + return; + } + + // Base value for patched base is the minimum value + // Patch data values are the literals with the base value subtracted + // We use base_reduced_literals to store these base reduced literals + let mut max_data_value = 0; + let mut base_reduced_literals = vec![]; + for l in literals.iter() { + // All base reduced literals become positive here + let base_reduced_literal = l.as_i64() - min; + base_reduced_literals.push(base_reduced_literal); + max_data_value = max_data_value.max(base_reduced_literal); + } + + // Aka 100th percentile + let base_reduced_literals_max_bit_width = max_data_value.closest_aligned_bit_width(); + // 95th percentile width is used to find the 5% of values to encode with patches + let base_reduced_literals_95th_percentile_bit_width = + calculate_percentile_bits(&base_reduced_literals, 0.95); + + // Patch only if we have outliers, based on bit width + if base_reduced_literals_max_bit_width != base_reduced_literals_95th_percentile_bit_width { + write_patched_base( + writer, + &mut base_reduced_literals, + min, + base_reduced_literals_max_bit_width, + base_reduced_literals_95th_percentile_bit_width, + ); + } else { + // TODO: pass through the 100p here + write_direct(writer, &zigzag_literals, None); + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum EncodingType { ShortRepeat, @@ -104,6 +503,17 @@ impl EncodingType { _ => unreachable!(), } } + + /// Return byte with highest two bits set according to variant. + #[inline] + fn to_header(self) -> u8 { + match self { + EncodingType::Delta => 0b_1100_0000, + EncodingType::PatchedBase => 0b_1000_0000, + EncodingType::Direct => 0b_0100_0000, + EncodingType::ShortRepeat => 0b_0000_0000, + } + } } #[cfg(test)] @@ -111,63 +521,67 @@ mod tests { use std::io::Cursor; + use proptest::prelude::*; + + use crate::reader::decode::{SignedEncoding, UnsignedEncoding}; + use super::*; #[test] fn reader_test() { - let data = [2u8, 1, 64, 5, 80, 1, 1]; + let data = [2, 1, 64, 5, 80, 1, 1]; let expected = [1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1]; let cursor = Cursor::new(data); - let reader = RleReaderV2::::new(cursor); + let reader = RleReaderV2::::new(cursor); let a = reader.collect::>>().unwrap(); assert_eq!(a, expected); // direct - let data = [0x5eu8, 0x03, 0x5c, 0xa1, 0xab, 0x1e, 0xde, 0xad, 0xbe, 0xef]; + let data = [0x5e, 0x03, 0x5c, 0xa1, 0xab, 0x1e, 0xde, 0xad, 0xbe, 0xef]; let expected = [23713, 43806, 57005, 48879]; let cursor = Cursor::new(data); - let reader = RleReaderV2::::new(cursor); + let reader = RleReaderV2::::new(cursor); let a = reader.collect::>>().unwrap(); assert_eq!(a, expected); // patched base let data = [ - 102u8, 9, 0, 126, 224, 7, 208, 0, 126, 79, 66, 64, 0, 127, 128, 8, 2, 0, 128, 192, 8, - 22, 0, 130, 0, 8, 42, + 102, 9, 0, 126, 224, 7, 208, 0, 126, 79, 66, 64, 0, 127, 128, 8, 2, 0, 128, 192, 8, 22, + 0, 130, 0, 8, 42, ]; let expected = [ - 2030u64, 2000, 2020, 1000000, 2040, 2050, 2060, 2070, 2080, 2090, + 2030, 2000, 2020, 1000000, 2040, 2050, 2060, 2070, 2080, 2090, ]; let cursor = Cursor::new(data); - let reader = RleReaderV2::::new(cursor); + let reader = RleReaderV2::::new(cursor); let a = reader.collect::>>().unwrap(); assert_eq!(a, expected); - let data = [196u8, 9, 2, 2, 74, 40, 166]; - let expected = [2u64, 3, 5, 7, 11, 13, 17, 19, 23, 29]; + let data = [196, 9, 2, 2, 74, 40, 166]; + let expected = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]; let cursor = Cursor::new(data); - let reader = RleReaderV2::::new(cursor); + let reader = RleReaderV2::::new(cursor); let a = reader.collect::>>().unwrap(); assert_eq!(a, expected); - let data = [0xc6u8, 0x09, 0x02, 0x02, 0x22, 0x42, 0x42, 0x46]; - let expected = [2u64, 3, 5, 7, 11, 13, 17, 19, 23, 29]; + let data = [0xc6, 0x09, 0x02, 0x02, 0x22, 0x42, 0x42, 0x46]; + let expected = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]; let cursor = Cursor::new(data); - let reader = RleReaderV2::::new(cursor); + let reader = RleReaderV2::::new(cursor); let a = reader.collect::>>().unwrap(); assert_eq!(a, expected); - let data = [7u8, 1]; - let expected = [1u64, 1, 1, 1, 1, 1, 1, 1, 1, 1]; + let data = [7, 1]; + let expected = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; let cursor = Cursor::new(data); - let reader = RleReaderV2::::new(cursor); + let reader = RleReaderV2::::new(cursor); let a = reader.collect::>>().unwrap(); assert_eq!(a, expected); } @@ -178,7 +592,7 @@ mod tests { let data: [u8; 3] = [0x0a, 0x27, 0x10]; let cursor = Cursor::new(data); - let reader = RleReaderV2::::new(cursor); + let reader = RleReaderV2::::new(cursor); let a = reader.collect::>>().unwrap(); assert_eq!(a, vec![10000, 10000, 10000, 10000, 10000]); @@ -190,7 +604,7 @@ mod tests { let data: [u8; 10] = [0x5e, 0x03, 0x5c, 0xa1, 0xab, 0x1e, 0xde, 0xad, 0xbe, 0xef]; let cursor = Cursor::new(data); - let reader = RleReaderV2::::new(cursor); + let reader = RleReaderV2::::new(cursor); let a = reader.collect::>>().unwrap(); assert_eq!(a, vec![23713, 43806, 57005, 48879]); @@ -199,10 +613,10 @@ mod tests { #[test] fn direct_signed() { // [23713, 43806, 57005, 48879] - let data = [110u8, 3, 0, 185, 66, 1, 86, 60, 1, 189, 90, 1, 125, 222]; + let data = [110, 3, 0, 185, 66, 1, 86, 60, 1, 189, 90, 1, 125, 222]; let cursor = Cursor::new(data); - let reader = RleReaderV2::::new(cursor); + let reader = RleReaderV2::::new(cursor); let a = reader.collect::>>().unwrap(); assert_eq!(a, vec![23713, 43806, 57005, 48879]); @@ -217,7 +631,7 @@ mod tests { let data: [u8; 8] = [0xc6, 0x09, 0x02, 0x02, 0x22, 0x42, 0x42, 0x46]; let cursor = Cursor::new(data); - let reader = RleReaderV2::::new(cursor); + let reader = RleReaderV2::::new(cursor); let a = reader.collect::>>().unwrap(); assert_eq!(a, vec![2, 3, 5, 7, 11, 13, 17, 19, 23, 29]); @@ -226,16 +640,16 @@ mod tests { #[test] fn patched_base() { let data = vec![ - 0x8eu8, 0x09, 0x2b, 0x21, 0x07, 0xd0, 0x1e, 0x00, 0x14, 0x70, 0x28, 0x32, 0x3c, 0x46, + 0x8e, 0x09, 0x2b, 0x21, 0x07, 0xd0, 0x1e, 0x00, 0x14, 0x70, 0x28, 0x32, 0x3c, 0x46, 0x50, 0x5a, 0xfc, 0xe8, ]; let expected = vec![ - 2030u64, 2000, 2020, 1000000, 2040, 2050, 2060, 2070, 2080, 2090, + 2030, 2000, 2020, 1000000, 2040, 2050, 2060, 2070, 2080, 2090, ]; let cursor = Cursor::new(data); - let reader = RleReaderV2::::new(cursor); + let reader = RleReaderV2::::new(cursor); let a = reader .collect::>>() .unwrap() @@ -248,7 +662,7 @@ mod tests { #[test] fn patched_base_1() { let data = vec![ - 144u8, 109, 4, 164, 141, 16, 131, 194, 0, 240, 112, 64, 60, 84, 24, 3, 193, 201, 128, + 144, 109, 4, 164, 141, 16, 131, 194, 0, 240, 112, 64, 60, 84, 24, 3, 193, 201, 128, 120, 60, 33, 4, 244, 3, 193, 192, 224, 128, 56, 32, 15, 22, 131, 129, 225, 0, 112, 84, 86, 14, 8, 106, 193, 192, 228, 160, 64, 32, 14, 213, 131, 193, 192, 240, 121, 124, 30, 18, 9, 132, 67, 0, 224, 120, 60, 28, 14, 32, 132, 65, 192, 240, 160, 56, 61, 91, 7, 3, @@ -278,9 +692,52 @@ mod tests { ]; let cursor = Cursor::new(data); - let reader = RleReaderV2::::new(cursor); + let reader = RleReaderV2::::new(cursor); let a = reader.collect::>>().unwrap(); assert_eq!(a, expected); } + + // TODO: be smarter about prop test here, generate different patterns of ints + // - e.g. increasing/decreasing sequences, outliers, repeated + // - to ensure all different subencodings are being used (and might make shrinking better) + // currently 99% of the time here the subencoding will be Direct due to random generation + + fn roundtrip_helper(values: &[N]) -> Result> { + let mut writer = RleWriterV2::::new(); + writer.write_slice(values); + let data = writer.take_inner(); + + let cursor = Cursor::new(data); + let reader = RleReaderV2::::new(cursor); + let out = reader.collect::>>()?; + + Ok(out) + } + + proptest! { + #[test] + fn roundtrip_i16(values in prop::collection::vec(any::(), 1..1_000)) { + let out = roundtrip_helper::<_, SignedEncoding>(&values)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_i32(values in prop::collection::vec(any::(), 1..1_000)) { + let out = roundtrip_helper::<_, SignedEncoding>(&values)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_i64(values in prop::collection::vec(any::(), 1..1_000)) { + let out = roundtrip_helper::<_, SignedEncoding>(&values)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_i64_unsigned(values in prop::collection::vec(0..=i64::MAX, 1..1_000)) { + let out = roundtrip_helper::<_, UnsignedEncoding>(&values)?; + prop_assert_eq!(out, values); + } + } } diff --git a/src/reader/decode/rle_v2/patched_base.rs b/src/reader/decode/rle_v2/patched_base.rs index 157ce849..0f26471d 100644 --- a/src/reader/decode/rle_v2/patched_base.rs +++ b/src/reader/decode/rle_v2/patched_base.rs @@ -17,149 +17,401 @@ use std::io::Read; +use bytes::{BufMut, BytesMut}; use snafu::{OptionExt, ResultExt}; -use super::{NInt, RleReaderV2}; +use super::{EncodingType, NInt}; use crate::error::{IoSnafu, OutOfSpecSnafu, Result}; use crate::reader::decode::util::{ - extract_run_length_from_header, read_ints, read_u8, rle_v2_decode_bit_width, + encode_bit_width, extract_run_length_from_header, get_closest_fixed_bits, read_ints, read_u8, + rle_v2_decode_bit_width, signed_msb_encode, write_packed_ints, }; +use crate::reader::decode::{EncodingSign, VarintSerde}; -/// Patches (gap + actual patch bits) width are ceil'd here. -/// -/// Not mentioned in ORC specification, but happens in their implementation. -fn get_closest_fixed_bits(width: usize) -> usize { - match width { - 1..=24 => width, - 25..=26 => 26, - 27..=28 => 28, - 29..=30 => 30, - 31..=32 => 32, - 33..=40 => 40, - 41..=48 => 48, - 49..=56 => 56, - 57..=64 => 64, - _ => unreachable!(), +pub fn read_patched_base( + reader: &mut R, + out_ints: &mut Vec, + header: u8, +) -> Result<()> { + let encoded_bit_width = (header >> 1) & 0x1F; + let value_bit_width = rle_v2_decode_bit_width(encoded_bit_width); + let value_bit_width_u32 = u32::try_from(value_bit_width).or_else(|_| { + OutOfSpecSnafu { + msg: "value_bit_width overflows u32", + } + .fail() + })?; + + let second_byte = read_u8(reader)?; + let length = extract_run_length_from_header(header, second_byte); + + let third_byte = read_u8(reader)?; + let fourth_byte = read_u8(reader)?; + + // Base width is one off + let base_byte_width = ((third_byte >> 5) & 0x07) as usize + 1; + + let patch_bit_width = rle_v2_decode_bit_width(third_byte & 0x1f); + + // Patch gap width is one off + let patch_gap_bit_width = ((fourth_byte >> 5) & 0x07) as usize + 1; + + let patch_total_bit_width = patch_bit_width + patch_gap_bit_width; + if patch_total_bit_width > 64 { + return OutOfSpecSnafu { + msg: "combined patch width and patch gap width cannot be greater than 64 bits", + } + .fail(); } -} -impl RleReaderV2 { - pub fn read_patched_base(&mut self, header: u8) -> Result<()> { - let encoded_bit_width = (header >> 1) & 0x1F; - let value_bit_width = rle_v2_decode_bit_width(encoded_bit_width); - let value_bit_width_u32 = u32::try_from(value_bit_width).or_else(|_| { - OutOfSpecSnafu { - msg: "value_bit_width overflows u32", - } - .fail() - })?; + let patch_list_length = (fourth_byte & 0x1f) as usize; - let second_byte = read_u8(&mut self.reader)?; - let length = extract_run_length_from_header(header, second_byte); + let mut buffer = N::empty_byte_array(); + // Read into back part of buffer since is big endian. + // So if smaller than N::BYTE_SIZE bytes, most significant bytes will be 0. + reader + .read_exact(&mut buffer.as_mut()[N::BYTE_SIZE - base_byte_width..]) + .context(IoSnafu)?; + let base = N::from_be_bytes(buffer); + let base = S::decode_signed_msb(base, base_byte_width); - let third_byte = read_u8(&mut self.reader)?; - let fourth_byte = read_u8(&mut self.reader)?; + // Get data values + // TODO: this should read into Vec + // as base reduced values can exceed N::max() + // (e.g. if base is N::min() and this is signed type) + read_ints(out_ints, length, value_bit_width, reader)?; - // Base width is one off - let base_byte_width = ((third_byte >> 5) & 0x07) as usize + 1; + // Get patches that will be applied to base values. + // At most they might be u64 in width (because of check above). + let ceil_patch_total_bit_width = get_closest_fixed_bits(patch_total_bit_width); + let mut patches: Vec = Vec::with_capacity(patch_list_length); + read_ints( + &mut patches, + patch_list_length, + ceil_patch_total_bit_width, + reader, + )?; - let patch_bit_width = rle_v2_decode_bit_width(third_byte & 0x1f); + // TODO: document and explain below logic + let mut patch_index = 0; + let patch_mask = (1 << patch_bit_width) - 1; + let mut current_gap = patches[patch_index] >> patch_bit_width; + let mut current_patch = patches[patch_index] & patch_mask; + let mut actual_gap = 0; - // Patch gap width is one off - let patch_gap_bit_width = ((fourth_byte >> 5) & 0x07) as usize + 1; + while current_gap == 255 && current_patch == 0 { + actual_gap += 255; + patch_index += 1; + current_gap = patches[patch_index] >> patch_bit_width; + current_patch = patches[patch_index] & patch_mask; + } + actual_gap += current_gap; - let patch_total_bit_width = patch_bit_width + patch_gap_bit_width; - if patch_total_bit_width > 64 { - return OutOfSpecSnafu { - msg: "combined patch width and patch gap width cannot be greater than 64 bits", - } - .fail(); - } + for (idx, value) in out_ints.iter_mut().enumerate() { + if idx == actual_gap as usize { + let patch_bits = + current_patch + .checked_shl(value_bit_width_u32) + .context(OutOfSpecSnafu { + msg: "Overflow while shifting patch bits by value_bit_width", + })?; + // Safe conversion without loss as we check the bit width prior + let patch_bits = N::from_i64(patch_bits); + let patched_value = *value | patch_bits; + + *value = patched_value.checked_add(&base).context(OutOfSpecSnafu { + msg: "over/underflow when decoding patched base integer", + })?; - let patch_list_length = (fourth_byte & 0x1f) as usize; - - let mut buffer = N::empty_byte_array(); - // Read into back part of buffer since is big endian. - // So if smaller than N::BYTE_SIZE bytes, most significant bytes will be 0. - self.reader - .read_exact(&mut buffer.as_mut()[N::BYTE_SIZE - base_byte_width..]) - .context(IoSnafu)?; - let base = N::from_be_bytes(buffer).decode_signed_from_msb(base_byte_width); - - // Get data values - read_ints( - &mut self.decoded_ints, - length, - value_bit_width, - &mut self.reader, - )?; - - // Get patches that will be applied to base values. - // At most they might be u64 in width (because of check above). - let ceil_patch_total_bit_width = get_closest_fixed_bits(patch_total_bit_width); - let mut patches: Vec = Vec::with_capacity(patch_list_length); - read_ints( - &mut patches, - patch_list_length, - ceil_patch_total_bit_width, - &mut self.reader, - )?; - - // TODO: document and explain below logic - let mut patch_index = 0; - let patch_mask = (1 << patch_bit_width) - 1; - let mut current_gap = patches[patch_index] >> patch_bit_width; - let mut current_patch = patches[patch_index] & patch_mask; - let mut actual_gap = 0; - - while current_gap == 255 && current_patch == 0 { - actual_gap += 255; patch_index += 1; - current_gap = patches[patch_index] >> patch_bit_width; - current_patch = patches[patch_index] & patch_mask; - } - actual_gap += current_gap; - - for (idx, value) in self.decoded_ints.iter_mut().enumerate() { - if idx == actual_gap as usize { - let patch_bits = - current_patch - .checked_shl(value_bit_width_u32) - .context(OutOfSpecSnafu { - msg: "Overflow while shifting patch bits by value_bit_width", - })?; - // Safe conversion without loss as we check the bit width prior - let patch_bits = N::from_u64(patch_bits); - let patched_value = *value | patch_bits; - - *value = patched_value.checked_add(&base).context(OutOfSpecSnafu { - msg: "over/underflow when decoding patched base integer", - })?; - - patch_index += 1; - - if patch_index < patches.len() { + + if patch_index < patches.len() { + current_gap = patches[patch_index] >> patch_bit_width; + current_patch = patches[patch_index] & patch_mask; + actual_gap = 0; + + while current_gap == 255 && current_patch == 0 { + actual_gap += 255; + patch_index += 1; current_gap = patches[patch_index] >> patch_bit_width; current_patch = patches[patch_index] & patch_mask; - actual_gap = 0; + } + + actual_gap += current_gap; + actual_gap += idx as i64; + } + } else { + *value = value.checked_add(&base).context(OutOfSpecSnafu { + msg: "over/underflow when decoding patched base integer", + })?; + } + } + + Ok(()) +} + +fn derive_patches( + base_reduced_literals: &mut [i64], + patch_bits_width: usize, + max_base_value_bit_width: usize, +) -> (Vec, usize) { + // Values with bits exceeding this mask will be patched. + let max_base_value_mask = (1 << max_base_value_bit_width) - 1; + // Used to encode gaps greater than 255 (no patch bits, just used for gap). + let jump_patch = 255 << patch_bits_width; + + // At most 5% of values that must be patched. + // (Since max buffer length is 512, at most this can be 26) + let mut patches: Vec = Vec::with_capacity(26); + let mut last_patch_index = 0; + // Needed to determine bit width of patch gaps to encode in header. + let mut max_gap = 0; + for (idx, lit) in base_reduced_literals + .iter_mut() + .enumerate() + // Find all values which need to be patched (the 5% of values larger than the others) + .filter(|(_, &mut lit)| lit > max_base_value_mask) + { + // Convert to unsigned to ensure leftmost bits are 0 + let patch_bits = (*lit as u64) >> max_base_value_bit_width; + + // Gaps can at most be 255 (since gap bit width cannot exceed 8; in spec it states + // the header has only 3 bits to encode the size of the patch gap, so 8 is the largest + // value). + // + // Therefore if gap is found greater than 255 then we insert an empty patch with gap of 255 + // (and the empty patch will have no effect when reading as patching using empty bits will + // be a no-op). + // + // Extra special case if gap is 511, we unroll into inserting two empty patches (instead of + // relying on a loop). Max buffer size cannot exceed 512 so this is the largest possible gap. + let gap = idx - last_patch_index; + let gap = if gap == 511 { + max_gap = 255; + patches.push(jump_patch); + patches.push(jump_patch); + 1 + } else if gap > 255 { + max_gap = 255; + patches.push(jump_patch); + gap - 255 + } else { + max_gap = max_gap.max(gap); + gap + }; + let patch = patch_bits | (gap << patch_bits_width) as u64; + patches.push(patch as i64); + + last_patch_index = idx; + + // Stripping patch bits + *lit &= max_base_value_mask; + } - while current_gap == 255 && current_patch == 0 { - actual_gap += 255; - patch_index += 1; - current_gap = patches[patch_index] >> patch_bit_width; - current_patch = patches[patch_index] & patch_mask; - } + // If only one element to be patched, and is the very first one. + // Patch gap width minimum is 1. + let patch_gap_width = if max_gap == 0 { + 1 + } else { + (max_gap as i16).bits_used() + }; - actual_gap += current_gap; - actual_gap += idx as u64; + (patches, patch_gap_width) +} + +pub fn write_patched_base( + writer: &mut BytesMut, + base_reduced_literals: &mut [i64], + base: i64, + brl_100p_bit_width: usize, + brl_95p_bit_width: usize, +) { + let patch_bits_width = brl_100p_bit_width - brl_95p_bit_width; + let patch_bits_width = get_closest_fixed_bits(patch_bits_width); + // According to spec, each patch (patch bits + gap) must be <= 64 bits. + // So we adjust accordingly here if we hit this edge case where patch_width + // is 64 bits (which would have no space for gap). + let (patch_bits_width, brl_95p_bit_width) = if patch_bits_width == 64 { + (56, 8) + } else { + (patch_bits_width, brl_95p_bit_width) + }; + + let (patches, patch_gap_width) = + derive_patches(base_reduced_literals, patch_bits_width, brl_95p_bit_width); + + let encoded_bit_width = encode_bit_width(brl_95p_bit_width) as u8; + + // [1, 512] to [0, 511] + let run_length = base_reduced_literals.len() as u16 - 1; + + // No need to mask as we guarantee max length is 512 + let encoded_length_high_bit = (run_length >> 8) as u8; + let encoded_length_low_bits = (run_length & 0xFF) as u8; + + // +1 to account for sign bit + let base_bit_width = get_closest_fixed_bits(base.abs().bits_used() + 1); + let base_byte_width = base_bit_width.div_ceil(8).max(1); + let msb_encoded_min = signed_msb_encode(base, base_byte_width); + // [1, 8] to [0, 7] + let encoded_base_width = base_byte_width - 1; + let encoded_patch_bits_width = encode_bit_width(patch_bits_width); + let encoded_patch_gap_width = patch_gap_width - 1; + + let header1 = + EncodingType::PatchedBase.to_header() | encoded_bit_width << 1 | encoded_length_high_bit; + let header2 = encoded_length_low_bits; + let header3 = (encoded_base_width as u8) << 5 | encoded_patch_bits_width as u8; + let header4 = (encoded_patch_gap_width as u8) << 5 | patches.len() as u8; + writer.put_slice(&[header1, header2, header3, header4]); + + // Write out base value as big endian bytes + let base_bytes = msb_encoded_min.to_be_bytes(); + // 8 since i64 + let base_bytes = &base_bytes.as_ref()[8 - base_byte_width..]; + writer.put_slice(base_bytes); + + // Writing base reduced literals followed by patch list + let bit_width = get_closest_fixed_bits(brl_95p_bit_width); + write_packed_ints(writer, bit_width, base_reduced_literals); + let bit_width = get_closest_fixed_bits(patch_gap_width + patch_bits_width); + write_packed_ints(writer, bit_width, &patches); +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use proptest::prelude::*; + + use crate::reader::decode::{util::calculate_percentile_bits, SignedEncoding}; + + use super::*; + + #[derive(Debug)] + struct PatchesStrategy { + base: i64, + base_reduced_values: Vec, + patches: Vec, + patch_indices: Vec, + base_index: usize, + } + + fn patches_strategy() -> impl Strategy { + // TODO: clean this up a bit + prop::collection::vec(0..1_000_000_i64, 20..=512) + .prop_flat_map(|base_reduced_values| { + let base_strategy = -1_000_000_000..1_000_000_000_i64; + let max_patches_length = (base_reduced_values.len() as f32 * 0.05).ceil() as usize; + let base_reduced_values_strategy = Just(base_reduced_values); + let patches_strategy = prop::collection::vec( + 1_000_000_000_000_000..1_000_000_000_000_000_000_i64, + 1..=max_patches_length, + ); + ( + base_strategy, + base_reduced_values_strategy, + patches_strategy, + ) + }) + .prop_flat_map(|(base, base_reduced_values, patches)| { + let base_strategy = Just(base); + // +1 for the base index, so we don't have to deduplicate separately + let patch_indices_strategy = + prop::collection::hash_set(0..base_reduced_values.len(), patches.len() + 1); + let base_reduced_values_strategy = Just(base_reduced_values); + let patches_strategy = Just(patches); + ( + base_strategy, + base_reduced_values_strategy, + patches_strategy, + patch_indices_strategy, + ) + }) + .prop_map(|(base, base_reduced_values, patches, patch_indices)| { + let mut patch_indices = patch_indices.into_iter().collect::>(); + let base_index = patch_indices.pop().unwrap(); + PatchesStrategy { + base, + base_reduced_values, + patches, + patch_indices, + base_index, } - } else { - *value = value.checked_add(&base).context(OutOfSpecSnafu { - msg: "over/underflow when decoding patched base integer", - })?; - } + }) + } + + fn roundtrip_patched_base_helper( + base_reduced_literals: &[i64], + base: i64, + brl_95p_bit_width: usize, + brl_100p_bit_width: usize, + ) -> Result> { + let mut base_reduced_literals = base_reduced_literals.to_vec(); + + let mut buf = BytesMut::new(); + let mut out = vec![]; + + write_patched_base( + &mut buf, + &mut base_reduced_literals, + base, + brl_100p_bit_width, + brl_95p_bit_width, + ); + let header = buf[0]; + read_patched_base::(&mut Cursor::new(&buf[1..]), &mut out, header)?; + + Ok(out) + } + + fn form_patched_base_values( + base_reduced_values: &[i64], + patches: &[i64], + patch_indices: &[usize], + base_index: usize, + ) -> Vec { + let mut base_reduced_values = base_reduced_values.to_vec(); + for (&patch, &index) in patches.iter().zip(patch_indices) { + base_reduced_values[index] = patch; } + // Need at least one zero to represent the base + base_reduced_values[base_index] = 0; + base_reduced_values + } - Ok(()) + fn form_expected_values(base: i64, base_reduced_values: &[i64]) -> Vec { + base_reduced_values.iter().map(|&v| base + v).collect() + } + + proptest! { + #[test] + fn roundtrip_patched_base_i64(patches_strategy in patches_strategy()) { + let PatchesStrategy { + base, + base_reduced_values, + patches, + patch_indices, + base_index + } = patches_strategy; + let base_reduced_values = form_patched_base_values( + &base_reduced_values, + &patches, + &patch_indices, + base_index + ); + let expected = form_expected_values(base, &base_reduced_values); + let brl_95p_bit_width = calculate_percentile_bits(&base_reduced_values, 0.95); + let brl_100p_bit_width = calculate_percentile_bits(&base_reduced_values, 1.0); + // Need enough outliers to require patching + prop_assume!(brl_95p_bit_width != brl_100p_bit_width); + let actual = roundtrip_patched_base_helper( + &base_reduced_values, + base, + brl_95p_bit_width, + brl_100p_bit_width + )?; + prop_assert_eq!(actual, expected); + } } } diff --git a/src/reader/decode/rle_v2/short_repeat.rs b/src/reader/decode/rle_v2/short_repeat.rs index 6d63e82e..ce7b09cb 100644 --- a/src/reader/decode/rle_v2/short_repeat.rs +++ b/src/reader/decode/rle_v2/short_repeat.rs @@ -17,50 +17,123 @@ use std::io::Read; +use bytes::{BufMut, BytesMut}; use snafu::ResultExt; -use crate::error::{IoSnafu, OutOfSpecSnafu, Result}; +use crate::{ + error::{IoSnafu, OutOfSpecSnafu, Result}, + reader::decode::{rle_v2::EncodingType, EncodingSign}, +}; -use super::{NInt, RleReaderV2}; +use super::{NInt, SHORT_REPEAT_MIN_LENGTH}; -/// Minimum number of repeated values required to use this sub-encoding -const MIN_REPEAT_SIZE: usize = 3; +pub fn read_short_repeat_values( + reader: &mut R, + out_ints: &mut Vec, + header: u8, +) -> Result<()> { + // Header byte: + // + // eeww_wccc + // 7 0 LSB + // + // ee = Sub-encoding bits, always 00 + // www = Value width bits + // ccc = Repeat count bits -impl RleReaderV2 { - pub fn read_short_repeat_values(&mut self, header: u8) -> Result<()> { - // Header byte: - // - // eeww_wccc - // 7 0 LSB - // - // ee = Sub-encoding bits, always 00 - // www = Value width bits - // ccc = Repeat count bits + let byte_width = (header >> 3) & 0x07; // Encoded as 0 to 7 + let byte_width = byte_width as usize + 1; // Decode to 1 to 8 bytes - let byte_width = (header >> 3) & 0x07; // Encoded as 0 to 7 - let byte_width = byte_width as usize + 1; // Decode to 1 to 8 bytes - - if N::BYTE_SIZE < byte_width { - return OutOfSpecSnafu { - msg: "byte width of short repeat encoding exceeds byte size of integer being decoded to", - } - .fail(); + if N::BYTE_SIZE < byte_width { + return OutOfSpecSnafu { + msg: + "byte width of short repeat encoding exceeds byte size of integer being decoded to", } + .fail(); + } + + let run_length = (header & 0x07) as usize + SHORT_REPEAT_MIN_LENGTH; + + // Value that is being repeated is encoded as value_byte_width bytes in big endian format + let mut buffer = N::empty_byte_array(); + // Read into back part of buffer since is big endian. + // So if smaller than N::BYTE_SIZE bytes, most significant bytes will be 0. + reader + .read_exact(&mut buffer.as_mut()[N::BYTE_SIZE - byte_width..]) + .context(IoSnafu)?; + let val = N::from_be_bytes(buffer); + let val = S::zigzag_decode(val); + + out_ints.extend(std::iter::repeat(val).take(run_length)); + + Ok(()) +} + +pub fn write_short_repeat(writer: &mut BytesMut, value: N, count: usize) { + debug_assert!((SHORT_REPEAT_MIN_LENGTH..=10).contains(&count)); + + let value = S::zigzag_encode(value); + + // Take max in case value = 0 + let byte_size = value.bits_used().div_ceil(8).max(1) as u8; + let encoded_byte_size = byte_size - 1; + let encoded_count = (count - SHORT_REPEAT_MIN_LENGTH) as u8; + + let header = EncodingType::ShortRepeat.to_header() | (encoded_byte_size << 3) | encoded_count; + let bytes = value.to_be_bytes(); + let bytes = &bytes.as_ref()[N::BYTE_SIZE - byte_size as usize..]; + + writer.put_u8(header); + writer.put_slice(bytes); +} - let run_length = (header & 0x07) as usize + MIN_REPEAT_SIZE; +#[cfg(test)] +mod tests { + use std::io::Cursor; - // Value that is being repeated is encoded as value_byte_width bytes in big endian format - let mut buffer = N::empty_byte_array(); - // Read into back part of buffer since is big endian. - // So if smaller than N::BYTE_SIZE bytes, most significant bytes will be 0. - self.reader - .read_exact(&mut buffer.as_mut()[N::BYTE_SIZE - byte_width..]) - .context(IoSnafu)?; - let val = N::from_be_bytes(buffer).zigzag_decode(); + use proptest::prelude::*; - self.decoded_ints - .extend(std::iter::repeat(val).take(run_length)); + use crate::reader::decode::{SignedEncoding, UnsignedEncoding}; - Ok(()) + use super::*; + + fn roundtrip_short_repeat_helper( + value: N, + count: usize, + ) -> Result> { + let mut buf = BytesMut::new(); + let mut out = vec![]; + + write_short_repeat::<_, S>(&mut buf, value, count); + let header = buf[0]; + read_short_repeat_values::<_, _, S>(&mut Cursor::new(&buf[1..]), &mut out, header)?; + + Ok(out) + } + + proptest! { + #[test] + fn roundtrip_short_repeat_i16(value: i16, count in 3_usize..=10) { + let out = roundtrip_short_repeat_helper::<_, SignedEncoding>(value, count)?; + prop_assert_eq!(out, vec![value; count]); + } + + #[test] + fn roundtrip_short_repeat_i32(value: i32, count in 3_usize..=10) { + let out = roundtrip_short_repeat_helper::<_, SignedEncoding>(value, count)?; + prop_assert_eq!(out, vec![value; count]); + } + + #[test] + fn roundtrip_short_repeat_i64(value: i64, count in 3_usize..=10) { + let out = roundtrip_short_repeat_helper::<_, SignedEncoding>(value, count)?; + prop_assert_eq!(out, vec![value; count]); + } + + #[test] + fn roundtrip_short_repeat_i64_unsigned(value in 0..=i64::MAX, count in 3_usize..=10) { + let out = roundtrip_short_repeat_helper::<_, UnsignedEncoding>(value, count)?; + prop_assert_eq!(out, vec![value; count]); + } } } diff --git a/src/reader/decode/timestamp.rs b/src/reader/decode/timestamp.rs index 8e3dcf5a..c65e4f03 100644 --- a/src/reader/decode/timestamp.rs +++ b/src/reader/decode/timestamp.rs @@ -27,7 +27,7 @@ const NANOSECONDS_IN_SECOND: i64 = 1_000_000_000; pub struct TimestampIterator> { base_from_epoch: i64, data: Box> + Send>, - secondary: Box> + Send>, + secondary: Box> + Send>, _marker: PhantomData<(T, Item)>, } @@ -35,7 +35,7 @@ impl> TimestampIterator { pub fn new( base_from_epoch: i64, data: Box> + Send>, - secondary: Box> + Send>, + secondary: Box> + Send>, ) -> Self { Self { base_from_epoch, @@ -61,10 +61,11 @@ impl> Iterator for TimestampIterator< fn decode_timestamp>( base: i64, seconds_since_orc_base: Result, - nanoseconds: Result, + nanoseconds: Result, ) -> Result> { let data = seconds_since_orc_base?; - let mut nanoseconds = nanoseconds?; + // TODO + let mut nanoseconds = nanoseconds? as u64; // Last 3 bits indicate how many trailing zeros were truncated let zeros = nanoseconds & 0x7; nanoseconds >>= 3; diff --git a/src/reader/decode/util.rs b/src/reader/decode/util.rs index 67d65942..cfd0a0ee 100644 --- a/src/reader/decode/util.rs +++ b/src/reader/decode/util.rs @@ -17,12 +17,13 @@ use std::io::Read; +use bytes::{BufMut, BytesMut}; use num::Signed; use snafu::{OptionExt, ResultExt}; use crate::error::{self, IoSnafu, Result, VarintTooLargeSnafu}; -use super::NInt; +use super::{EncodingSign, NInt, VarintSerde}; /// Read single byte #[inline] @@ -233,7 +234,151 @@ fn unrolled_unpack_byte_aligned( Ok(()) } -/// Encoding table for RLEv2 sub-encodings bit width. +/// Write bit packed integers, where we expect the `bit_width` to be aligned +/// by [`get_closest_aligned_bit_width`], and we write the bytes as big endian. +pub fn write_aligned_packed_ints(writer: &mut BytesMut, bit_width: usize, values: &[N]) { + debug_assert!( + bit_width == 1 || bit_width == 2 || bit_width == 4 || bit_width % 8 == 0, + "bit_width must be 1, 2, 4 or a multiple of 8" + ); + match bit_width { + 1 => unrolled_pack_1(writer, values), + 2 => unrolled_pack_2(writer, values), + 4 => unrolled_pack_4(writer, values), + n => unrolled_pack_bytes(writer, n / 8, values), + } +} + +/// Similar to [`write_aligned_packed_ints`] but the `bit_width` allows any value +/// in the range `[1, 64]`. +pub fn write_packed_ints(writer: &mut BytesMut, bit_width: usize, values: &[N]) { + debug_assert!( + (1..=64).contains(&bit_width), + "bit_width must be in the range [1, 64]" + ); + if bit_width == 1 || bit_width == 2 || bit_width == 4 || bit_width % 8 == 0 { + write_aligned_packed_ints(writer, bit_width, values); + } else { + write_unaligned_packed_ints(writer, bit_width, values) + } +} + +fn write_unaligned_packed_ints(writer: &mut BytesMut, bit_width: usize, values: &[N]) { + debug_assert!( + (1..=64).contains(&bit_width), + "bit_width must be in the range [1, 64]" + ); + let mut bits_left = 8; + let mut current_byte = 0; + for &value in values { + let mut bits_to_write = bit_width; + // This loop will write 8 bits at a time into current_byte, except for the + // first iteration after a previous value has been written. The previous + // value may have bits left over, still in current_byte, which is represented + // by 8 - bits_left (aka bits_left is the amount of space left in current_byte). + while bits_to_write > bits_left { + // Writing from most significant bits first. + let shift = bits_to_write - bits_left; + // Shift so bits to write are in least significant 8 bits. + // Masking out higher bits so conversion to u8 is safe. + let bits = value.unsigned_shr(shift as u32) & N::from_u8(0xFF); + current_byte |= bits.to_u8().unwrap(); + bits_to_write -= bits_left; + + writer.put_u8(current_byte); + current_byte = 0; + bits_left = 8; + } + + // If there are trailing bits then include these into current_byte. + bits_left -= bits_to_write; + let bits = (value << bits_left) & N::from_u8(0xFF); + current_byte |= bits.to_u8().unwrap(); + + if bits_left == 0 { + writer.put_u8(current_byte); + current_byte = 0; + bits_left = 8; + } + } + // Flush any remaining bits + if bits_left != 8 { + writer.put_u8(current_byte); + } +} + +fn unrolled_pack_1(writer: &mut BytesMut, values: &[N]) { + let mut iter = values.chunks_exact(8); + for chunk in &mut iter { + let n1 = chunk[0].to_u8().unwrap() & 0x01; + let n2 = chunk[1].to_u8().unwrap() & 0x01; + let n3 = chunk[2].to_u8().unwrap() & 0x01; + let n4 = chunk[3].to_u8().unwrap() & 0x01; + let n5 = chunk[4].to_u8().unwrap() & 0x01; + let n6 = chunk[5].to_u8().unwrap() & 0x01; + let n7 = chunk[6].to_u8().unwrap() & 0x01; + let n8 = chunk[7].to_u8().unwrap() & 0x01; + let byte = + (n1 << 7) | (n2 << 6) | (n3 << 5) | (n4 << 4) | (n5 << 3) | (n6 << 2) | (n7 << 1) | n8; + writer.put_u8(byte); + } + let remainder = iter.remainder(); + if !remainder.is_empty() { + let mut byte = 0; + for (i, n) in remainder.iter().enumerate() { + let n = n.to_u8().unwrap(); + byte |= (n & 0x03) << (7 - i); + } + writer.put_u8(byte); + } +} + +fn unrolled_pack_2(writer: &mut BytesMut, values: &[N]) { + let mut iter = values.chunks_exact(4); + for chunk in &mut iter { + let n1 = chunk[0].to_u8().unwrap() & 0x03; + let n2 = chunk[1].to_u8().unwrap() & 0x03; + let n3 = chunk[2].to_u8().unwrap() & 0x03; + let n4 = chunk[3].to_u8().unwrap() & 0x03; + let byte = (n1 << 6) | (n2 << 4) | (n3 << 2) | n4; + writer.put_u8(byte); + } + let remainder = iter.remainder(); + if !remainder.is_empty() { + let mut byte = 0; + for (i, n) in remainder.iter().enumerate() { + let n = n.to_u8().unwrap(); + byte |= (n & 0x03) << (6 - i * 2); + } + writer.put_u8(byte); + } +} + +fn unrolled_pack_4(writer: &mut BytesMut, values: &[N]) { + let mut iter = values.chunks_exact(2); + for chunk in &mut iter { + let n1 = chunk[0].to_u8().unwrap() & 0x0F; + let n2 = chunk[1].to_u8().unwrap() & 0x0F; + let byte = (n1 << 4) | n2; + writer.put_u8(byte); + } + let remainder = iter.remainder(); + if !remainder.is_empty() { + let byte = remainder[0].to_u8().unwrap() & 0x0F; + let byte = byte << 4; + writer.put_u8(byte); + } +} + +fn unrolled_pack_bytes(writer: &mut BytesMut, byte_size: usize, values: &[N]) { + for num in values { + let bytes = num.to_be_bytes(); + let bytes = &bytes.as_ref()[N::BYTE_SIZE - byte_size..]; + writer.put_slice(bytes); + } +} + +/// Decoding table for RLEv2 sub-encodings bit width. /// /// Used by Direct, Patched Base and Delta. By default this assumes non-delta /// (0 maps to 1), so Delta handles this discrepancy at the caller side. @@ -255,8 +400,96 @@ pub fn rle_v2_decode_bit_width(encoded: u8) -> usize { } } +/// Inverse of [`rle_v2_decode_bit_width`]. +/// +/// Assumes supported bit width is passed in. Will panic on invalid +/// inputs that aren't defined in the ORC bit width encoding table +/// (such as 50). +pub fn rle_v2_encode_bit_width(width: usize) -> u8 { + debug_assert!(width <= 64, "bit width cannot exceed 64"); + match width { + 64 => 31, + 56 => 30, + 48 => 29, + 40 => 28, + 32 => 27, + 30 => 26, + 28 => 25, + 26 => 24, + 1..=24 => width as u8 - 1, + _ => unreachable!(), + } +} + +pub fn get_closest_fixed_bits(n: usize) -> usize { + match n { + 0 => 1, + 1..=24 => n, + 25..=26 => 26, + 27..=28 => 28, + 29..=30 => 30, + 31..=32 => 32, + 33..=40 => 40, + 41..=48 => 48, + 49..=56 => 56, + 57..=64 => 64, + _ => unreachable!(), + } +} + +pub fn encode_bit_width(n: usize) -> usize { + let n = get_closest_fixed_bits(n); + match n { + 1..=24 => n - 1, + 25..=26 => 24, + 27..=28 => 25, + 29..=30 => 26, + 31..=32 => 27, + 33..=40 => 28, + 41..=48 => 29, + 49..=56 => 30, + 57..=64 => 31, + _ => unreachable!(), + } +} + +fn decode_bit_width(n: usize) -> usize { + match n { + 0..=23 => n + 1, + 24 => 26, + 25 => 28, + 26 => 30, + 27 => 32, + 28 => 40, + 29 => 48, + 30 => 56, + 31 => 64, + _ => unreachable!(), + } +} + +/// Converts width of 64 bits or less to an aligned width, either rounding +/// up to the nearest multiple of 8, or rounding up to 1, 2 or 4. +pub fn get_closest_aligned_bit_width(width: usize) -> usize { + debug_assert!(width <= 64, "bit width cannot exceed 64"); + match width { + 0..=1 => 1, + 2 => 2, + 3..=4 => 4, + 5..=8 => 8, + 9..=16 => 16, + 17..=24 => 24, + 25..=32 => 32, + 33..=40 => 40, + 41..=48 => 48, + 49..=54 => 56, + 55..=64 => 64, + _ => unreachable!(), + } +} + /// Decode Base 128 Unsigned Varint -fn read_varint_n(r: &mut R) -> Result { +fn read_varint(reader: &mut R) -> Result { // Varints are encoded as sequence of bytes. // Where the high bit of a byte is set to 1 if the varint // continues into the next byte. Eventually it should terminate @@ -264,7 +497,7 @@ fn read_varint_n(r: &mut R) -> Result { let mut num = N::zero(); let mut offset = 0; loop { - let byte = read_u8(r)?; + let byte = read_u8(reader)?; let is_last_byte = byte & 0x80 == 0; let without_continuation_bit = byte & 0x7F; num |= N::from_u8(without_continuation_bit) @@ -281,58 +514,43 @@ fn read_varint_n(r: &mut R) -> Result { Ok(num) } -pub fn read_varint_zigzagged(r: &mut R) -> Result { - Ok(read_varint_n::(r)?.zigzag_decode()) -} +/// Encode Base 128 Unsigned Varint +fn write_varint(writer: &mut BytesMut, value: N) { + // Take max in case value = 0. + // Divide by 7 as high bit is always used as continuation flag. + let byte_size = value.bits_used().div_ceil(7).max(1); + // By default we'll have continuation bit set + // TODO: can probably do without Vec allocation? + let mut bytes = vec![0x80; byte_size]; + // Then just clear for the last one + let i = bytes.len() - 1; + bytes[i] = 0; + + // Encoding 7 bits at a time into bytes + let mask = N::from_u8(0x7F); + for (i, b) in bytes.iter_mut().enumerate() { + let shift = i * 7; + *b |= ((value >> shift) & mask).to_u8().unwrap(); + } -#[derive(Clone, Copy, Debug)] -pub(crate) enum AbsVarint { - Negative(N), - Positive(N), + writer.put_slice(&bytes); } -/// This trait (and its implementations) is intended to make generic the -/// behaviour of addition/subtraction over NInt. -// TODO: probably can be done in a cleaner/better way -pub trait AccumulateOp { - fn acc(a: N, b: N) -> Option; +pub fn read_varint_zigzagged( + reader: &mut R, +) -> Result { + let unsigned = read_varint::(reader)?; + Ok(S::zigzag_decode(unsigned)) } -pub struct AddOp; - -impl AccumulateOp for AddOp { - fn acc(a: N, b: N) -> Option { - a.checked_add(&b) - } -} - -pub struct SubOp; - -impl AccumulateOp for SubOp { - fn acc(a: N, b: N) -> Option { - a.checked_sub(&b) - } -} - -/// Special case for delta where we need to parse as NInt, but it's signed. -/// So we calculate the absolute value and return the sign via enum variants. -pub fn read_abs_varint(r: &mut R) -> Result> { - let num = read_varint_n::(r)?; - let is_negative = (num & N::one()) == N::one(); - // Unsigned >> to ensure new MSB is always 0 and not 1 - let num = num.unsigned_shr(1); - if is_negative { - // Because of two's complement - let num = num + N::one(); - Ok(AbsVarint::Negative(num)) - } else { - Ok(AbsVarint::Positive(num)) - } +pub fn write_varint_zigzagged(writer: &mut BytesMut, value: N) { + let value = S::zigzag_encode(value); + write_varint(writer, value) } /// Zigzag encoding stores the sign bit in the least significant bit. #[inline] -pub fn signed_zigzag_decode(encoded: N) -> N { +pub fn signed_zigzag_decode(encoded: N) -> N { let without_sign_bit = encoded.unsigned_shr(1); let sign_bit = encoded & N::one(); // If positive, sign_bit is 0 @@ -344,13 +562,20 @@ pub fn signed_zigzag_decode(encoded: N) -> N { without_sign_bit ^ -sign_bit } +/// Opposite of [`signed_zigzag_decode`]. +#[inline] +pub fn signed_zigzag_encode(value: N) -> N { + let l = N::BYTE_SIZE * 8 - 1; + (value << 1_usize) ^ (value >> l) +} + /// MSB indicates if value is negated (1 if negative, else positive). Note we /// take the MSB of the encoded number which might be smaller than N, hence /// we need the encoded number byte size to find this MSB. #[inline] pub fn signed_msb_decode(encoded: N, encoded_byte_size: usize) -> N { let msb_mask = N::one() << (encoded_byte_size * 8 - 1); - let is_positive = msb_mask & encoded == N::zero(); + let is_positive = (encoded & msb_mask) == N::zero(); let clean_sign_bit_mask = !msb_mask; let encoded = encoded & clean_sign_bit_mask; if is_positive { @@ -360,13 +585,58 @@ pub fn signed_msb_decode(encoded: N, encoded_byte_size: usize) } } +/// Inverse of [`signed_msb_decode`]. +#[inline] +// TODO: bound this to only allow i64 input? might mess up for i32::MIN? +pub fn signed_msb_encode(value: N, encoded_byte_size: usize) -> N { + let is_signed = value.is_negative(); + // 0 if unsigned, 1 if signed + let sign_bit = N::from_u8(is_signed as u8); + let value = value.abs(); + let encoded_msb = sign_bit << (encoded_byte_size * 8 - 1); + encoded_msb | value +} + +/// Get the nth percentile, where input percentile must be in range (0.0, 1.0]. +pub fn calculate_percentile_bits(values: &[N], percentile: f32) -> usize { + debug_assert!( + percentile > 0.0 && percentile <= 1.0, + "percentile must be in range (0.0, 1.0]" + ); + + let mut histogram = [0; 32]; + // Fill out histogram + for n in values { + // Map into range [0, 31] + let encoded_bit_width = encode_bit_width(n.bits_used()); + histogram[encoded_bit_width] += 1; + } + + // Then calculate the percentile here + let count = values.len() as f32; + let mut per_len = ((1.0 - percentile) * count) as usize; + for i in (0..32).rev() { + if let Some(a) = per_len.checked_sub(histogram[i]) { + per_len = a; + } else { + return decode_bit_width(i); + } + } + + // If percentile is in correct input range then we should always return above + unreachable!() +} + #[cfg(test)] mod tests { + use super::*; + use crate::{ + error::Result, + reader::decode::{SignedEncoding, UnsignedEncoding}, + }; + use proptest::prelude::*; use std::io::Cursor; - use crate::error::Result; - use crate::reader::decode::util::{read_varint_zigzagged, signed_zigzag_decode}; - #[test] fn test_zigzag_decode() { assert_eq!(0, signed_zigzag_decode(0)); @@ -385,10 +655,143 @@ mod tests { } #[test] - fn test_read_vulong() -> Result<()> { - fn test_assert(serialized: &[u8], expected: u64) -> Result<()> { + fn test_zigzag_encode() { + assert_eq!(0, signed_zigzag_encode(0)); + assert_eq!(1, signed_zigzag_encode(-1)); + assert_eq!(2, signed_zigzag_encode(1)); + assert_eq!(3, signed_zigzag_encode(-2)); + assert_eq!(4, signed_zigzag_encode(2)); + assert_eq!(5, signed_zigzag_encode(-3)); + assert_eq!(6, signed_zigzag_encode(3)); + assert_eq!(7, signed_zigzag_encode(-4)); + assert_eq!(8, signed_zigzag_encode(4)); + assert_eq!(9, signed_zigzag_encode(-5)); + + assert_eq!(-2_i64, signed_zigzag_encode(9_223_372_036_854_775_807)); + assert_eq!(-1_i64, signed_zigzag_encode(-9_223_372_036_854_775_808)); + } + + #[test] + fn roundtrip_zigzag_edge_cases() { + let value = 0_i16; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + let value = i16::MAX; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + + let value = 0_i32; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + let value = i32::MAX; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + let value = i32::MIN; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + + let value = 0_i64; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + let value = i64::MAX; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + let value = i64::MIN; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + } + + proptest! { + #[test] + fn roundtrip_zigzag_i16(value: i16) { + let out = signed_zigzag_decode(signed_zigzag_encode(value)); + prop_assert_eq!(value, out); + } + + #[test] + fn roundtrip_zigzag_i32(value: i32) { + let out = signed_zigzag_decode(signed_zigzag_encode(value)); + prop_assert_eq!(value, out); + } + + #[test] + fn roundtrip_zigzag_i64(value: i64) { + let out = signed_zigzag_decode(signed_zigzag_encode(value)); + prop_assert_eq!(value, out); + } + } + + fn generate_msb_test_value( + seed_value: N, + byte_size: usize, + signed: bool, + ) -> N { + // We mask out to values that can fit within the specified byte_size. + let shift = (N::BYTE_SIZE - byte_size) * 8; + let mask = N::max_value().unsigned_shr(shift as u32); + // And remove the msb since we manually set a value to signed based on the signed parameter. + let mask = mask >> 1; + let value = seed_value & mask; + // This guarantees values that can fit within byte_size when they are msb encoded, both + // signed and unsigned. + if signed { + -value + } else { + value + } + } + + #[test] + fn roundtrip_msb_edge_cases() { + // Testing all cases of max values for byte_size + signed combinations + for byte_size in 1..=2 { + for signed in [true, false] { + let value = generate_msb_test_value(i16::MAX, byte_size, signed); + let out = signed_msb_decode(signed_msb_encode(value, byte_size), byte_size); + assert_eq!(value, out); + } + } + + for byte_size in 1..=4 { + for signed in [true, false] { + let value = generate_msb_test_value(i32::MAX, byte_size, signed); + let out = signed_msb_decode(signed_msb_encode(value, byte_size), byte_size); + assert_eq!(value, out); + } + } + + for byte_size in 1..=8 { + for signed in [true, false] { + let value = generate_msb_test_value(i64::MAX, byte_size, signed); + let out = signed_msb_decode(signed_msb_encode(value, byte_size), byte_size); + assert_eq!(value, out); + } + } + } + + proptest! { + #[test] + fn roundtrip_msb_i16(value: i16, byte_size in 1..=2_usize, signed: bool) { + let value = generate_msb_test_value(value, byte_size, signed); + let out = signed_msb_decode(signed_msb_encode(value, byte_size), byte_size); + prop_assert_eq!(value, out); + } + + #[test] + fn roundtrip_msb_i32(value: i32, byte_size in 1..=4_usize, signed: bool) { + let value = generate_msb_test_value(value, byte_size, signed); + let out = signed_msb_decode(signed_msb_encode(value, byte_size), byte_size); + prop_assert_eq!(value, out); + } + + #[test] + fn roundtrip_msb_i64(value: i64, byte_size in 1..=8_usize, signed: bool) { + let value = generate_msb_test_value(value, byte_size, signed); + let out = signed_msb_decode(signed_msb_encode(value, byte_size), byte_size); + prop_assert_eq!(value, out); + } + } + + #[test] + fn test_read_varint() -> Result<()> { + fn test_assert(serialized: &[u8], expected: i64) -> Result<()> { let mut reader = Cursor::new(serialized); - assert_eq!(expected, read_varint_zigzagged::(&mut reader)?); + assert_eq!( + expected, + read_varint_zigzagged::(&mut reader)? + ); Ok(()) } @@ -400,13 +803,9 @@ mod tests { test_assert(&[0xff, 0x7f], 16_383)?; test_assert(&[0x80, 0x80, 0x01], 16_384)?; test_assert(&[0x81, 0x80, 0x01], 16_385)?; - test_assert( - &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01], - u64::MAX, - )?; // when too large - let err = read_varint_zigzagged::(&mut Cursor::new(&[ + let err = read_varint_zigzagged::(&mut Cursor::new(&[ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01, ])); assert!(err.is_err()); @@ -416,7 +815,8 @@ mod tests { ); // when unexpected end to stream - let err = read_varint_zigzagged::(&mut Cursor::new(&[0x80, 0x80])); + let err = + read_varint_zigzagged::(&mut Cursor::new(&[0x80, 0x80])); assert!(err.is_err()); assert_eq!( "Failed to read, source: failed to fill whole buffer", @@ -425,4 +825,122 @@ mod tests { Ok(()) } + + fn roundtrip_varint(value: N) -> N { + let mut buf = BytesMut::new(); + write_varint_zigzagged::(&mut buf, value); + read_varint_zigzagged::(&mut Cursor::new(&buf)).unwrap() + } + + proptest! { + #[test] + fn roundtrip_varint_i16(value: i16) { + let out = roundtrip_varint::<_, SignedEncoding>(value); + prop_assert_eq!(out, value); + } + + #[test] + fn roundtrip_varint_i32(value: i32) { + let out = roundtrip_varint::<_, SignedEncoding>(value); + prop_assert_eq!(out, value); + } + + #[test] + fn roundtrip_varint_i64(value: i64) { + let out = roundtrip_varint::<_, SignedEncoding>(value); + prop_assert_eq!(out, value); + } + + #[test] + fn roundtrip_varint_i128(value: i128) { + let out = roundtrip_varint::<_, SignedEncoding>(value); + prop_assert_eq!(out, value); + } + + #[test] + fn roundtrip_varint_u64(value in 0..=i64::MAX) { + let out = roundtrip_varint::<_, UnsignedEncoding>(value); + prop_assert_eq!(out, value); + } + } + + #[test] + fn roundtrip_varint_edge_cases() { + let value = 0_i16; + assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value); + let value = i16::MIN; + assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value); + let value = i16::MAX; + assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value); + + let value = 0_i32; + assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value); + let value = i32::MIN; + assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value); + let value = i32::MAX; + assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value); + + let value = 0_i64; + assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value); + let value = i64::MIN; + assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value); + let value = i64::MAX; + assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value); + + let value = 0_i128; + assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value); + let value = i128::MIN; + assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value); + let value = i128::MAX; + assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value); + } + + /// Easier to generate values then bound them, instead of generating correctly bounded + /// values. In this case, bounds are that no value will exceed the `bit_width` in terms + /// of bit size. + fn mask_to_bit_width(values: &[N], bit_width: usize) -> Vec { + let shift = N::BYTE_SIZE * 8 - bit_width; + let mask = N::max_value().unsigned_shr(shift as u32); + values.iter().map(|&v| v & mask).collect() + } + + fn roundtrip_packed_ints_serde(values: &[N], bit_width: usize) -> Result> { + let mut buf = BytesMut::new(); + let mut out = vec![]; + write_packed_ints(&mut buf, bit_width, values); + read_ints(&mut out, values.len(), bit_width, &mut Cursor::new(buf))?; + Ok(out) + } + + proptest! { + #[test] + fn roundtrip_packed_ints_serde_i64( + values in prop::collection::vec(any::(), 1..=512), + bit_width in 1..=64_usize + ) { + let values = mask_to_bit_width(&values, bit_width); + let out = roundtrip_packed_ints_serde(&values, bit_width)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_packed_ints_serde_i32( + values in prop::collection::vec(any::(), 1..=512), + bit_width in 1..=32_usize + ) { + let values = mask_to_bit_width(&values, bit_width); + let out = roundtrip_packed_ints_serde(&values, bit_width)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_packed_ints_serde_i16( + values in prop::collection::vec(any::(), 1..=512), + bit_width in 1..=16_usize + ) { + let values = mask_to_bit_width(&values, bit_width); + let out = roundtrip_packed_ints_serde(&values, bit_width)?; + prop_assert_eq!(out, values); + } + } } diff --git a/src/reader/metadata.rs b/src/reader/metadata.rs index 6c747e65..c9219b83 100644 --- a/src/reader/metadata.rs +++ b/src/reader/metadata.rs @@ -46,7 +46,7 @@ use std::io::Read; use bytes::{Bytes, BytesMut}; use prost::Message; -use snafu::{OptionExt, ResultExt}; +use snafu::{ensure, OptionExt, ResultExt}; use crate::error::{self, EmptyFileSnafu, OutOfSpecSnafu, Result}; use crate::proto::{self, Footer, Metadata, PostScript}; @@ -88,12 +88,27 @@ impl FileMetadata { .iter() .map(TryFrom::try_from) .collect::>>()?; - let stripes = footer - .stripes - .iter() - .zip(metadata.stripe_stats.iter()) - .map(TryFrom::try_from) - .collect::>>()?; + ensure!( + metadata.stripe_stats.is_empty() || metadata.stripe_stats.len() == footer.stripes.len(), + OutOfSpecSnafu { + msg: "stripe stats length must equal the number of stripes" + } + ); + // TODO: confirm if this is valid + let stripes = if metadata.stripe_stats.is_empty() { + footer + .stripes + .iter() + .map(TryFrom::try_from) + .collect::>>()? + } else { + footer + .stripes + .iter() + .zip(metadata.stripe_stats.iter()) + .map(TryFrom::try_from) + .collect::>>()? + }; let user_custom_metadata = footer .metadata .iter() diff --git a/src/reader/mod.rs b/src/reader/mod.rs index d4f7021e..ade121d5 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -22,7 +22,7 @@ pub mod metadata; use std::fs::File; use std::io::{BufReader, Read, Seek, SeekFrom}; -use bytes::Bytes; +use bytes::{Buf, Bytes}; /// Primary source used for reading required bytes for operations. #[allow(clippy::len_without_is_empty)] @@ -31,6 +31,7 @@ pub trait ChunkReader { /// Get total length of bytes. Useful for parsing the metadata located at /// the end of the file. + // TODO: this is only used for file tail, so replace with load_metadata? fn len(&self) -> u64; /// Get a reader starting at a specific offset. @@ -49,7 +50,6 @@ pub trait ChunkReader { impl ChunkReader for File { type T = BufReader; - // TODO: this is only used for file tail, so replace with load_metadata? fn len(&self) -> u64 { self.metadata().map(|m| m.len()).unwrap_or(0u64) } @@ -65,6 +65,18 @@ impl ChunkReader for File { } } +impl ChunkReader for Bytes { + type T = bytes::buf::Reader; + + fn len(&self) -> u64 { + self.len() as u64 + } + + fn get_read(&self, offset_from_start: u64) -> std::io::Result { + Ok(self.slice(offset_from_start as usize..).reader()) + } +} + #[cfg(feature = "async")] mod async_chunk_reader { use super::*; diff --git a/src/stripe.rs b/src/stripe.rs index 63fd94be..e9557c60 100644 --- a/src/stripe.rs +++ b/src/stripe.rs @@ -104,6 +104,21 @@ impl TryFrom<(&proto::StripeInformation, &proto::StripeStatistics)> for StripeMe } } +impl TryFrom<&proto::StripeInformation> for StripeMetadata { + type Error = error::OrcError; + + fn try_from(value: &proto::StripeInformation) -> Result { + Ok(Self { + column_statistics: vec![], + offset: value.offset(), + index_length: value.index_length(), + data_length: value.data_length(), + footer_length: value.footer_length(), + number_of_rows: value.number_of_rows(), + }) + } +} + #[derive(Debug)] pub struct Stripe { columns: Vec, diff --git a/src/writer/column.rs b/src/writer/column.rs new file mode 100644 index 00000000..40cfa368 --- /dev/null +++ b/src/writer/column.rs @@ -0,0 +1,242 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::marker::PhantomData; + +use arrow::{ + array::{Array, ArrayRef, AsArray}, + datatypes::{ + ArrowPrimitiveType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + ToByteSlice, + }, +}; +use bytes::{Bytes, BytesMut}; + +use crate::{ + error::Result, + reader::decode::{byte_rle::ByteRleWriter, float::Float, rle_v2::RleWriterV2, SignedEncoding}, + writer::StreamType, +}; + +use super::{ColumnEncoding, PresentStreamEncoder, Stream}; + +/// Used to help determine when to finish writing a stripe once a certain +/// size threshold has been reached. +pub trait EstimateMemory { + /// Approximate current memory usage in bytes. + fn estimate_memory_size(&self) -> usize; +} + +/// Encodes a specific column for a stripe. Will encode to an internal memory +/// buffer until it is finished, in which case it returns the stream bytes to +/// be serialized to a writer. +pub trait ColumnStripeEncoder: EstimateMemory { + /// Encode entire provided [`ArrayRef`] to internal buffer. + fn encode_array(&mut self, array: &ArrayRef) -> Result<()>; + + /// Column encoding used for streams. + fn column_encoding(&self) -> ColumnEncoding; + + /// Emit buffered streams to be written to the writer, and reset state + /// in preparation for next stripe. + fn finish(&mut self) -> Vec; +} + +/// Encodes primitive values into an internal buffer, usually with a specialized run length +/// encoding for better compression. +pub trait PrimitiveValueEncoder: EstimateMemory +where + V: Copy, +{ + fn new() -> Self; + + fn write_one(&mut self, value: V); + + fn write_slice(&mut self, values: &[V]) { + for &value in values { + self.write_one(value); + } + } + + /// Take the encoded bytes, replacing it with an empty buffer. + // TODO: Figure out how to retain the allocation instead of handing + // it off each time. + fn take_inner(&mut self) -> Bytes; +} + +// TODO: simplify these generics, probably overcomplicating things here + +/// Encoder for primitive ORC types (e.g. int, float). Uses a specific [`PrimitiveValueEncoder`] to +/// encode the primitive values into internal memory. When finished, outputs a DATA stream and +/// optionally a PRESENT stream. +pub struct PrimitiveStripeEncoder> { + encoder: E, + column_encoding: ColumnEncoding, + /// Lazily initialized once we encounter an [`Array`] with a [`NullBuffer`]. + present: Option, + encoded_count: usize, + _phantom: PhantomData, +} + +impl> PrimitiveStripeEncoder { + // TODO: encode knowledge of the ColumnEncoding as part of the type, instead of requiring it + // to be passed at runtime + pub fn new(column_encoding: ColumnEncoding) -> Self { + Self { + encoder: E::new(), + column_encoding, + present: None, + encoded_count: 0, + _phantom: Default::default(), + } + } +} + +impl> EstimateMemory + for PrimitiveStripeEncoder +{ + fn estimate_memory_size(&self) -> usize { + self.encoder.estimate_memory_size() + + self + .present + .as_ref() + .map(|p| p.estimate_memory_size()) + .unwrap_or(0) + } +} + +impl> ColumnStripeEncoder + for PrimitiveStripeEncoder +{ + fn encode_array(&mut self, array: &ArrayRef) -> Result<()> { + // TODO: return as result instead of panicking here? + let array = array.as_primitive::(); + // Handling case where if encoding across RecordBatch boundaries, arrays + // might introduce a NullBuffer + match (array.nulls(), &mut self.present) { + // Need to copy only the valid values as indicated by null_buffer + (Some(null_buffer), Some(present)) => { + present.extend(null_buffer); + for index in null_buffer.valid_indices() { + let v = array.value(index); + self.encoder.write_one(v); + } + } + (Some(null_buffer), None) => { + // Lazily initiate present buffer and ensure backfill the already encoded values + let mut present = PresentStreamEncoder::new(); + present.extend_present(self.encoded_count); + present.extend(null_buffer); + self.present = Some(present); + for index in null_buffer.valid_indices() { + let v = array.value(index); + self.encoder.write_one(v); + } + } + // Simple direct copy from values buffer, extending present if needed + (None, Some(present)) => { + let values = array.values(); + self.encoder.write_slice(values); + present.extend_present(array.len()); + } + (None, None) => { + let values = array.values(); + self.encoder.write_slice(values); + } + } + self.encoded_count += array.len() - array.null_count(); + Ok(()) + } + + fn column_encoding(&self) -> ColumnEncoding { + self.column_encoding + } + + fn finish(&mut self) -> Vec { + let bytes = self.encoder.take_inner(); + // Return mandatory Data stream and optional Present stream + let data = Stream { + kind: StreamType::Data, + bytes, + }; + self.encoded_count = 0; + match &mut self.present { + Some(present) => { + let bytes = present.finish(); + let present = Stream { + kind: StreamType::Present, + bytes, + }; + vec![data, present] + } + None => vec![data], + } + } +} + +/// No special run encoding for floats/doubles, they are stored as their IEEE 754 floating +/// point bit layout. This encoder simply copies incoming floats/doubles to its internal +/// byte buffer. +pub struct FloatValueEncoder +where + T::Native: Float, +{ + data: BytesMut, + _phantom: PhantomData, +} + +impl EstimateMemory for FloatValueEncoder +where + T::Native: Float, +{ + fn estimate_memory_size(&self) -> usize { + self.data.len() + } +} + +impl PrimitiveValueEncoder for FloatValueEncoder +where + T::Native: Float, +{ + fn new() -> Self { + Self { + data: BytesMut::new(), + _phantom: Default::default(), + } + } + + fn write_one(&mut self, value: T::Native) { + let bytes = value.to_byte_slice(); + self.data.extend_from_slice(bytes); + } + + fn write_slice(&mut self, values: &[T::Native]) { + let bytes = values.to_byte_slice(); + self.data.extend_from_slice(bytes) + } + + fn take_inner(&mut self) -> Bytes { + std::mem::take(&mut self.data).into() + } +} + +pub type FloatStripeEncoder = PrimitiveStripeEncoder>; +pub type DoubleStripeEncoder = PrimitiveStripeEncoder>; +pub type ByteStripeEncoder = PrimitiveStripeEncoder; +pub type Int16StripeEncoder = PrimitiveStripeEncoder>; +pub type Int32StripeEncoder = PrimitiveStripeEncoder>; +pub type Int64StripeEncoder = PrimitiveStripeEncoder>; diff --git a/src/writer/mod.rs b/src/writer/mod.rs new file mode 100644 index 00000000..6f0ec840 --- /dev/null +++ b/src/writer/mod.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Debug; + +use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer}; +use bytes::Bytes; + +use crate::{proto, reader::decode::byte_rle::ByteRleWriter}; + +use self::column::{EstimateMemory, PrimitiveValueEncoder}; + +pub mod column; +pub mod stripe; + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum StreamType { + Present, + Data, + Length, + DictionaryData, + Secondary, +} + +impl From for proto::stream::Kind { + fn from(value: StreamType) -> Self { + match value { + StreamType::Present => proto::stream::Kind::Present, + StreamType::Data => proto::stream::Kind::Data, + StreamType::Length => proto::stream::Kind::Length, + StreamType::DictionaryData => proto::stream::Kind::DictionaryData, + StreamType::Secondary => proto::stream::Kind::Secondary, + } + } +} + +#[derive(Debug, Clone)] +pub struct Stream { + kind: StreamType, + bytes: Bytes, +} + +impl Stream { + pub fn into_parts(self) -> (StreamType, Bytes) { + (self.kind, self.bytes) + } +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum ColumnEncoding { + Direct, + DirectV2, + Dictionary { size: usize }, + DictionaryV2 { size: usize }, +} + +impl From<&ColumnEncoding> for proto::ColumnEncoding { + fn from(value: &ColumnEncoding) -> Self { + match value { + ColumnEncoding::Direct => proto::ColumnEncoding { + kind: Some(proto::column_encoding::Kind::Direct.into()), + dictionary_size: None, + bloom_encoding: None, + }, + ColumnEncoding::DirectV2 => proto::ColumnEncoding { + kind: Some(proto::column_encoding::Kind::DirectV2.into()), + dictionary_size: None, + bloom_encoding: None, + }, + ColumnEncoding::Dictionary { size } => proto::ColumnEncoding { + kind: Some(proto::column_encoding::Kind::Dictionary.into()), + dictionary_size: Some(*size as u32), + bloom_encoding: None, + }, + ColumnEncoding::DictionaryV2 { size } => proto::ColumnEncoding { + kind: Some(proto::column_encoding::Kind::DictionaryV2.into()), + dictionary_size: Some(*size as u32), + bloom_encoding: None, + }, + } + } +} + +/// ORC encodes validity starting from MSB, whilst Arrow encodes it +/// from LSB. +struct PresentStreamEncoder { + builder: BooleanBufferBuilder, +} + +impl EstimateMemory for PresentStreamEncoder { + fn estimate_memory_size(&self) -> usize { + self.builder.len() / 8 + } +} + +impl PresentStreamEncoder { + pub fn new() -> Self { + Self { + builder: BooleanBufferBuilder::new(8), + } + } + + pub fn extend(&mut self, null_buffer: &NullBuffer) { + let bb = null_buffer.inner(); + self.builder.append_buffer(bb); + } + + /// Extend with n true bits. + pub fn extend_present(&mut self, n: usize) { + self.builder.append_n(n, true); + } + + /// Produce ORC present stream bytes and reset internal builder. + pub fn finish(&mut self) -> Bytes { + let bb = self.builder.finish(); + // We use BooleanBufferBuilder so offset is 0 + let bytes = bb.values(); + // Reverse bits as ORC stores from MSB + let bytes = bytes.iter().map(|b| b.reverse_bits()).collect::>(); + // Bytes are then further encoded via Byte RLE + // TODO: refactor; this is a hack to ensure writing nulls works for now + // figure a better way than throwing away this writer everytime + let mut encoder = ByteRleWriter::new(); + for &b in bytes.as_slice() { + encoder.write_one(b as i8); + } + encoder.take_inner() + } +} diff --git a/src/writer/stripe.rs b/src/writer/stripe.rs new file mode 100644 index 00000000..edc58806 --- /dev/null +++ b/src/writer/stripe.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::io::Write; + +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType as ArrowDataType, FieldRef, SchemaRef}; +use prost::Message; +use snafu::ResultExt; + +use crate::error::{IoSnafu, Result}; +use crate::proto; + +use super::column::{ + ByteStripeEncoder, ColumnStripeEncoder, DoubleStripeEncoder, EstimateMemory, + FloatStripeEncoder, Int16StripeEncoder, Int32StripeEncoder, Int64StripeEncoder, +}; +use super::{ColumnEncoding, StreamType}; + +#[derive(Copy, Clone, Eq, Debug, PartialEq)] +pub struct StripeInformation { + pub start_offset: u64, + pub index_length: u64, + pub data_length: u64, + pub footer_length: u64, + pub row_count: usize, +} + +impl StripeInformation { + pub fn total_byte_size(&self) -> u64 { + self.index_length + self.data_length + self.footer_length + } +} + +impl From<&StripeInformation> for proto::StripeInformation { + fn from(value: &StripeInformation) -> Self { + proto::StripeInformation { + offset: Some(value.start_offset), + index_length: Some(value.index_length), + data_length: Some(value.data_length), + footer_length: Some(value.footer_length), + number_of_rows: Some(value.row_count as u64), + encrypt_stripe_id: None, + encrypted_local_keys: vec![], + } + } +} + +/// Encode a stripe. Will encode columns into an in-memory buffer before flushing +/// entire stripe to the underlying writer. +pub struct StripeWriter { + writer: W, + /// Flattened columns, in order of their column ID. + columns: Vec>, + pub row_count: usize, +} + +impl EstimateMemory for StripeWriter { + /// Used to estimate when stripe size is over threshold and should be flushed + /// to the writer and a new stripe started. + fn estimate_memory_size(&self) -> usize { + self.columns.iter().map(|c| c.estimate_memory_size()).sum() + } +} + +impl StripeWriter { + pub fn new(writer: W, schema: &SchemaRef) -> Self { + let columns = schema.fields().iter().map(create_encoder).collect(); + Self { + writer, + columns, + row_count: 0, + } + } + + /// Attempt to encode entire [`RecordBatch`]. Relies on caller slicing the batch + /// to required batch size. + pub fn encode_batch(&mut self, batch: &RecordBatch) -> Result<()> { + // TODO: consider how to handle nested types (including parent nullability) + for (array, encoder) in batch.columns().iter().zip(self.columns.iter_mut()) { + encoder.encode_array(array)?; + } + self.row_count += batch.num_rows(); + Ok(()) + } + + /// Flush streams to the writer, and write the stripe footer to finish + /// the stripe. After this, the [`StripeWriter`] will be reset and ready + /// to write a fresh new stripe. + /// + /// `start_offset` is used to manually keep track of position in the writer (instead + /// of relying on Seek). + pub fn finish_stripe(&mut self, start_offset: u64) -> Result { + // Order of column_encodings needs to match final type vec order. + // (see arrow_writer::serialize_schema()) + // Direct encoding to represent root struct + let mut column_encodings = vec![ColumnEncoding::Direct]; + let child_column_encodings = self + .columns + .iter() + .map(|c| c.column_encoding()) + .collect::>(); + column_encodings.extend(child_column_encodings); + let column_encodings = column_encodings.iter().map(From::from).collect(); + + // Root type won't have any streams + let mut written_streams = vec![]; + let mut data_length = 0; + for (index, c) in self.columns.iter_mut().enumerate() { + // Offset by 1 to account for root of 0 + let column = index + 1; + let streams = c.finish(); + // Flush the streams to the writer + for s in streams { + let (kind, bytes) = s.into_parts(); + let length = bytes.len(); + self.writer.write_all(&bytes).context(IoSnafu)?; + data_length += length as u64; + written_streams.push(WrittenStream { + kind, + column, + length, + }); + } + } + let streams = written_streams.iter().map(From::from).collect(); + let stripe_footer = proto::StripeFooter { + streams, + columns: column_encodings, + writer_timezone: None, + encryption: vec![], + }; + + let footer_bytes = stripe_footer.encode_to_vec(); + let footer_length = footer_bytes.len() as u64; + let row_count = self.row_count; + self.writer.write_all(&footer_bytes).context(IoSnafu)?; + + // Reset state for next stripe + self.row_count = 0; + + Ok(StripeInformation { + start_offset, + index_length: 0, + data_length, + footer_length, + row_count, + }) + } + + /// When finished writing all stripes, return the inner writer. + pub fn finish(self) -> W { + self.writer + } +} + +fn create_encoder(field: &FieldRef) -> Box { + match field.data_type() { + ArrowDataType::Float32 => Box::new(FloatStripeEncoder::new(ColumnEncoding::Direct)), + ArrowDataType::Float64 => Box::new(DoubleStripeEncoder::new(ColumnEncoding::Direct)), + ArrowDataType::Int8 => Box::new(ByteStripeEncoder::new(ColumnEncoding::Direct)), + ArrowDataType::Int16 => Box::new(Int16StripeEncoder::new(ColumnEncoding::DirectV2)), + ArrowDataType::Int32 => Box::new(Int32StripeEncoder::new(ColumnEncoding::DirectV2)), + ArrowDataType::Int64 => Box::new(Int64StripeEncoder::new(ColumnEncoding::DirectV2)), + // TODO: support more datatypes + _ => unimplemented!("unsupported datatype"), + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +struct WrittenStream { + kind: StreamType, + column: usize, + length: usize, +} + +impl From<&WrittenStream> for proto::Stream { + fn from(value: &WrittenStream) -> Self { + let kind = proto::stream::Kind::from(value.kind); + proto::Stream { + kind: Some(kind.into()), + column: Some(value.column as u32), + length: Some(value.length as u64), + } + } +}