Skip to content

Commit

Permalink
Convert some panics that happen on invalid parquet files to error res…
Browse files Browse the repository at this point in the history
…ults (apache#6738)

* Reduce  panics

* t pushmove integer logical type from format.rs to schema type.rs

* remove some changes as per reviews

* use wrapping_shl

* fix typo in error message

* return error for invalid decimal length

---------

Co-authored-by: jp0317 <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
3 people authored and CurtHagenlocher committed Jan 13, 2025
1 parent 30f46c7 commit 618d81c
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 28 deletions.
7 changes: 7 additions & 0 deletions parquet/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Common Parquet errors and macros.
use core::num::TryFromIntError;
use std::error::Error;
use std::{cell, io, result, str};

Expand Down Expand Up @@ -81,6 +82,12 @@ impl Error for ParquetError {
}
}

impl From<TryFromIntError> for ParquetError {
fn from(e: TryFromIntError) -> ParquetError {
ParquetError::General(format!("Integer overflow: {e}"))
}
}

impl From<io::Error> for ParquetError {
fn from(e: io::Error) -> ParquetError {
ParquetError::External(Box::new(e))
Expand Down
26 changes: 13 additions & 13 deletions parquet/src/file/metadata/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,8 @@ impl ParquetMetaDataReader {
for rg in t_file_metadata.row_groups {
row_groups.push(RowGroupMetaData::from_thrift(schema_descr.clone(), rg)?);
}
let column_orders = Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr);
let column_orders =
Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr)?;

let file_metadata = FileMetaData::new(
t_file_metadata.version,
Expand All @@ -645,15 +646,13 @@ impl ParquetMetaDataReader {
fn parse_column_orders(
t_column_orders: Option<Vec<TColumnOrder>>,
schema_descr: &SchemaDescriptor,
) -> Option<Vec<ColumnOrder>> {
) -> Result<Option<Vec<ColumnOrder>>> {
match t_column_orders {
Some(orders) => {
// Should always be the case
assert_eq!(
orders.len(),
schema_descr.num_columns(),
"Column order length mismatch"
);
if orders.len() != schema_descr.num_columns() {
return Err(general_err!("Column order length mismatch"));
};
let mut res = Vec::new();
for (i, column) in schema_descr.columns().iter().enumerate() {
match orders[i] {
Expand All @@ -667,9 +666,9 @@ impl ParquetMetaDataReader {
}
}
}
Some(res)
Ok(Some(res))
}
None => None,
None => Ok(None),
}
}
}
Expand Down Expand Up @@ -741,7 +740,7 @@ mod tests {
]);

assert_eq!(
ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr),
ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr).unwrap(),
Some(vec![
ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED),
ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED)
Expand All @@ -750,20 +749,21 @@ mod tests {

// Test when no column orders are defined.
assert_eq!(
ParquetMetaDataReader::parse_column_orders(None, &schema_descr),
ParquetMetaDataReader::parse_column_orders(None, &schema_descr).unwrap(),
None
);
}

#[test]
#[should_panic(expected = "Column order length mismatch")]
fn test_metadata_column_orders_len_mismatch() {
let schema = SchemaType::group_type_builder("schema").build().unwrap();
let schema_descr = SchemaDescriptor::new(Arc::new(schema));

let t_column_orders = Some(vec![TColumnOrder::TYPEORDER(TypeDefinedOrder::new())]);

ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr);
let res = ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr);
assert!(res.is_err());
assert!(format!("{:?}", res.unwrap_err()).contains("Column order length mismatch"));
}

#[test]
Expand Down
53 changes: 46 additions & 7 deletions parquet/src/file/serialized_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ pub(crate) fn decode_page(
let is_sorted = dict_header.is_sorted.unwrap_or(false);
Page::DictionaryPage {
buf: buffer,
num_values: dict_header.num_values as u32,
num_values: dict_header.num_values.try_into()?,
encoding: Encoding::try_from(dict_header.encoding)?,
is_sorted,
}
Expand All @@ -446,7 +446,7 @@ pub(crate) fn decode_page(
.ok_or_else(|| ParquetError::General("Missing V1 data page header".to_string()))?;
Page::DataPage {
buf: buffer,
num_values: header.num_values as u32,
num_values: header.num_values.try_into()?,
encoding: Encoding::try_from(header.encoding)?,
def_level_encoding: Encoding::try_from(header.definition_level_encoding)?,
rep_level_encoding: Encoding::try_from(header.repetition_level_encoding)?,
Expand All @@ -460,12 +460,12 @@ pub(crate) fn decode_page(
let is_compressed = header.is_compressed.unwrap_or(true);
Page::DataPageV2 {
buf: buffer,
num_values: header.num_values as u32,
num_values: header.num_values.try_into()?,
encoding: Encoding::try_from(header.encoding)?,
num_nulls: header.num_nulls as u32,
num_rows: header.num_rows as u32,
def_levels_byte_len: header.definition_levels_byte_length as u32,
rep_levels_byte_len: header.repetition_levels_byte_length as u32,
num_nulls: header.num_nulls.try_into()?,
num_rows: header.num_rows.try_into()?,
def_levels_byte_len: header.definition_levels_byte_length.try_into()?,
rep_levels_byte_len: header.repetition_levels_byte_length.try_into()?,
is_compressed,
statistics: statistics::from_thrift(physical_type, header.statistics)?,
}
Expand Down Expand Up @@ -578,6 +578,27 @@ impl<R: ChunkReader> Iterator for SerializedPageReader<R> {
}
}

fn verify_page_header_len(header_len: usize, remaining_bytes: usize) -> Result<()> {
if header_len > remaining_bytes {
return Err(eof_err!("Invalid page header"));
}
Ok(())
}

fn verify_page_size(
compressed_size: i32,
uncompressed_size: i32,
remaining_bytes: usize,
) -> Result<()> {
// The page's compressed size should not exceed the remaining bytes that are
// available to read. The page's uncompressed size is the expected size
// after decompression, which can never be negative.
if compressed_size < 0 || compressed_size as usize > remaining_bytes || uncompressed_size < 0 {
return Err(eof_err!("Invalid page header"));
}
Ok(())
}

impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
fn get_next_page(&mut self) -> Result<Option<Page>> {
loop {
Expand All @@ -596,10 +617,16 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
*header
} else {
let (header_len, header) = read_page_header_len(&mut read)?;
verify_page_header_len(header_len, *remaining)?;
*offset += header_len;
*remaining -= header_len;
header
};
verify_page_size(
header.compressed_page_size,
header.uncompressed_page_size,
*remaining,
)?;
let data_len = header.compressed_page_size as usize;
*offset += data_len;
*remaining -= data_len;
Expand Down Expand Up @@ -683,6 +710,7 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
} else {
let mut read = self.reader.get_read(*offset as u64)?;
let (header_len, header) = read_page_header_len(&mut read)?;
verify_page_header_len(header_len, *remaining_bytes)?;
*offset += header_len;
*remaining_bytes -= header_len;
let page_meta = if let Ok(page_meta) = (&header).try_into() {
Expand Down Expand Up @@ -733,12 +761,23 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
next_page_header,
} => {
if let Some(buffered_header) = next_page_header.take() {
verify_page_size(
buffered_header.compressed_page_size,
buffered_header.uncompressed_page_size,
*remaining_bytes,
)?;
// The next page header has already been peeked, so just advance the offset
*offset += buffered_header.compressed_page_size as usize;
*remaining_bytes -= buffered_header.compressed_page_size as usize;
} else {
let mut read = self.reader.get_read(*offset as u64)?;
let (header_len, header) = read_page_header_len(&mut read)?;
verify_page_header_len(header_len, *remaining_bytes)?;
verify_page_size(
header.compressed_page_size,
header.uncompressed_page_size,
*remaining_bytes,
)?;
let data_page_size = header.compressed_page_size as usize;
*offset += header_len + data_page_size;
*remaining_bytes -= header_len + data_page_size;
Expand Down
26 changes: 26 additions & 0 deletions parquet/src/file/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,32 @@ pub fn from_thrift(
stats.max_value
};

fn check_len(min: &Option<Vec<u8>>, max: &Option<Vec<u8>>, len: usize) -> Result<()> {
if let Some(min) = min {
if min.len() < len {
return Err(ParquetError::General(
"Insufficient bytes to parse min statistic".to_string(),
));
}
}
if let Some(max) = max {
if max.len() < len {
return Err(ParquetError::General(
"Insufficient bytes to parse max statistic".to_string(),
));
}
}
Ok(())
}

match physical_type {
Type::BOOLEAN => check_len(&min, &max, 1),
Type::INT32 | Type::FLOAT => check_len(&min, &max, 4),
Type::INT64 | Type::DOUBLE => check_len(&min, &max, 8),
Type::INT96 => check_len(&min, &max, 12),
_ => Ok(()),
}?;

// Values are encoded using PLAIN encoding definition, except that
// variable-length byte arrays do not include a length prefix.
//
Expand Down
25 changes: 24 additions & 1 deletion parquet/src/schema/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,11 @@ impl<'a> PrimitiveTypeBuilder<'a> {
}
}
PhysicalType::FIXED_LEN_BYTE_ARRAY => {
let max_precision = (2f64.powi(8 * self.length - 1) - 1f64).log10().floor() as i32;
let length = self
.length
.checked_mul(8)
.ok_or(general_err!("Invalid length {} for Decimal", self.length))?;
let max_precision = (2f64.powi(length - 1) - 1f64).log10().floor() as i32;

if self.precision > max_precision {
return Err(general_err!(
Expand Down Expand Up @@ -1171,9 +1175,25 @@ pub fn from_thrift(elements: &[SchemaElement]) -> Result<TypePtr> {
));
}

if !schema_nodes[0].is_group() {
return Err(general_err!("Expected root node to be a group type"));
}

Ok(schema_nodes.remove(0))
}

/// Checks if the logical type is valid.
fn check_logical_type(logical_type: &Option<LogicalType>) -> Result<()> {
if let Some(LogicalType::Integer { bit_width, .. }) = *logical_type {
if bit_width != 8 && bit_width != 16 && bit_width != 32 && bit_width != 64 {
return Err(general_err!(
"Bit width must be 8, 16, 32, or 64 for Integer logical type"
));
}
}
Ok(())
}

/// Constructs a new Type from the `elements`, starting at index `index`.
/// The first result is the starting index for the next Type after this one. If it is
/// equal to `elements.len()`, then this Type is the last one.
Expand All @@ -1198,6 +1218,9 @@ fn from_thrift_helper(elements: &[SchemaElement], index: usize) -> Result<(usize
.logical_type
.as_ref()
.map(|value| LogicalType::from(value.clone()));

check_logical_type(&logical_type)?;

let field_id = elements[index].field_id;
match elements[index].num_children {
// From parquet-format:
Expand Down
35 changes: 29 additions & 6 deletions parquet/src/thrift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl<'a> TCompactSliceInputProtocol<'a> {
let mut shift = 0;
loop {
let byte = self.read_byte()?;
in_progress |= ((byte & 0x7F) as u64) << shift;
in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift);
shift += 7;
if byte & 0x80 == 0 {
return Ok(in_progress);
Expand Down Expand Up @@ -96,13 +96,22 @@ impl<'a> TCompactSliceInputProtocol<'a> {
}
}

macro_rules! thrift_unimplemented {
() => {
Err(thrift::Error::Protocol(thrift::ProtocolError {
kind: thrift::ProtocolErrorKind::NotImplemented,
message: "not implemented".to_string(),
}))
};
}

impl TInputProtocol for TCompactSliceInputProtocol<'_> {
fn read_message_begin(&mut self) -> thrift::Result<TMessageIdentifier> {
unimplemented!()
}

fn read_message_end(&mut self) -> thrift::Result<()> {
unimplemented!()
thrift_unimplemented!()
}

fn read_struct_begin(&mut self) -> thrift::Result<Option<TStructIdentifier>> {
Expand Down Expand Up @@ -147,7 +156,21 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> {
),
_ => {
if field_delta != 0 {
self.last_read_field_id += field_delta as i16;
self.last_read_field_id = self
.last_read_field_id
.checked_add(field_delta as i16)
.map_or_else(
|| {
Err(thrift::Error::Protocol(thrift::ProtocolError {
kind: thrift::ProtocolErrorKind::InvalidData,
message: format!(
"cannot add {} to {}",
field_delta, self.last_read_field_id
),
}))
},
Ok,
)?;
} else {
self.last_read_field_id = self.read_i16()?;
};
Expand Down Expand Up @@ -226,15 +249,15 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> {
}

fn read_set_begin(&mut self) -> thrift::Result<TSetIdentifier> {
unimplemented!()
thrift_unimplemented!()
}

fn read_set_end(&mut self) -> thrift::Result<()> {
unimplemented!()
thrift_unimplemented!()
}

fn read_map_begin(&mut self) -> thrift::Result<TMapIdentifier> {
unimplemented!()
thrift_unimplemented!()
}

fn read_map_end(&mut self) -> thrift::Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion parquet/tests/arrow_reader/bad_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ fn test_arrow_rs_gh_6229_dict_header() {
let err = read_file("ARROW-RS-GH-6229-DICTHEADER.parquet").unwrap_err();
assert_eq!(
err.to_string(),
"External: Parquet argument error: EOF: eof decoding byte array"
"External: Parquet argument error: Parquet error: Integer overflow: out of range integral type conversion attempted"
);
}

Expand Down

0 comments on commit 618d81c

Please sign in to comment.