From 618d81ce1f3bd7591ae0c40be19065e13d3d68d6 Mon Sep 17 00:00:00 2001 From: Jinpeng Date: Mon, 6 Jan 2025 16:09:05 -0500 Subject: [PATCH] Convert some panics that happen on invalid parquet files to error results (#6738) * Reduce panics * t pushmove integer logical type from format.rs to schema type.rs * remove some changes as per reviews * use wrapping_shl * fix typo in error message * return error for invalid decimal length --------- Co-authored-by: jp0317 Co-authored-by: Andrew Lamb --- parquet/src/errors.rs | 7 ++++ parquet/src/file/metadata/reader.rs | 26 ++++++------- parquet/src/file/serialized_reader.rs | 53 ++++++++++++++++++++++---- parquet/src/file/statistics.rs | 26 +++++++++++++ parquet/src/schema/types.rs | 25 +++++++++++- parquet/src/thrift.rs | 35 ++++++++++++++--- parquet/tests/arrow_reader/bad_data.rs | 2 +- 7 files changed, 146 insertions(+), 28 deletions(-) diff --git a/parquet/src/errors.rs b/parquet/src/errors.rs index 8dc97f4ca2e6..d749287bba62 100644 --- a/parquet/src/errors.rs +++ b/parquet/src/errors.rs @@ -17,6 +17,7 @@ //! Common Parquet errors and macros. +use core::num::TryFromIntError; use std::error::Error; use std::{cell, io, result, str}; @@ -81,6 +82,12 @@ impl Error for ParquetError { } } +impl From for ParquetError { + fn from(e: TryFromIntError) -> ParquetError { + ParquetError::General(format!("Integer overflow: {e}")) + } +} + impl From for ParquetError { fn from(e: io::Error) -> ParquetError { ParquetError::External(Box::new(e)) diff --git a/parquet/src/file/metadata/reader.rs b/parquet/src/file/metadata/reader.rs index ec2cd1094d3a..c6715a33b5ae 100644 --- a/parquet/src/file/metadata/reader.rs +++ b/parquet/src/file/metadata/reader.rs @@ -627,7 +627,8 @@ impl ParquetMetaDataReader { for rg in t_file_metadata.row_groups { row_groups.push(RowGroupMetaData::from_thrift(schema_descr.clone(), rg)?); } - let column_orders = Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr); + let column_orders = + Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr)?; let file_metadata = FileMetaData::new( t_file_metadata.version, @@ -645,15 +646,13 @@ impl ParquetMetaDataReader { fn parse_column_orders( t_column_orders: Option>, schema_descr: &SchemaDescriptor, - ) -> Option> { + ) -> Result>> { match t_column_orders { Some(orders) => { // Should always be the case - assert_eq!( - orders.len(), - schema_descr.num_columns(), - "Column order length mismatch" - ); + if orders.len() != schema_descr.num_columns() { + return Err(general_err!("Column order length mismatch")); + }; let mut res = Vec::new(); for (i, column) in schema_descr.columns().iter().enumerate() { match orders[i] { @@ -667,9 +666,9 @@ impl ParquetMetaDataReader { } } } - Some(res) + Ok(Some(res)) } - None => None, + None => Ok(None), } } } @@ -741,7 +740,7 @@ mod tests { ]); assert_eq!( - ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr), + ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr).unwrap(), Some(vec![ ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED), ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED) @@ -750,20 +749,21 @@ mod tests { // Test when no column orders are defined. assert_eq!( - ParquetMetaDataReader::parse_column_orders(None, &schema_descr), + ParquetMetaDataReader::parse_column_orders(None, &schema_descr).unwrap(), None ); } #[test] - #[should_panic(expected = "Column order length mismatch")] fn test_metadata_column_orders_len_mismatch() { let schema = SchemaType::group_type_builder("schema").build().unwrap(); let schema_descr = SchemaDescriptor::new(Arc::new(schema)); let t_column_orders = Some(vec![TColumnOrder::TYPEORDER(TypeDefinedOrder::new())]); - ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr); + let res = ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr); + assert!(res.is_err()); + assert!(format!("{:?}", res.unwrap_err()).contains("Column order length mismatch")); } #[test] diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs index 06f3cf9fb23f..a942481f7e4d 100644 --- a/parquet/src/file/serialized_reader.rs +++ b/parquet/src/file/serialized_reader.rs @@ -435,7 +435,7 @@ pub(crate) fn decode_page( let is_sorted = dict_header.is_sorted.unwrap_or(false); Page::DictionaryPage { buf: buffer, - num_values: dict_header.num_values as u32, + num_values: dict_header.num_values.try_into()?, encoding: Encoding::try_from(dict_header.encoding)?, is_sorted, } @@ -446,7 +446,7 @@ pub(crate) fn decode_page( .ok_or_else(|| ParquetError::General("Missing V1 data page header".to_string()))?; Page::DataPage { buf: buffer, - num_values: header.num_values as u32, + num_values: header.num_values.try_into()?, encoding: Encoding::try_from(header.encoding)?, def_level_encoding: Encoding::try_from(header.definition_level_encoding)?, rep_level_encoding: Encoding::try_from(header.repetition_level_encoding)?, @@ -460,12 +460,12 @@ pub(crate) fn decode_page( let is_compressed = header.is_compressed.unwrap_or(true); Page::DataPageV2 { buf: buffer, - num_values: header.num_values as u32, + num_values: header.num_values.try_into()?, encoding: Encoding::try_from(header.encoding)?, - num_nulls: header.num_nulls as u32, - num_rows: header.num_rows as u32, - def_levels_byte_len: header.definition_levels_byte_length as u32, - rep_levels_byte_len: header.repetition_levels_byte_length as u32, + num_nulls: header.num_nulls.try_into()?, + num_rows: header.num_rows.try_into()?, + def_levels_byte_len: header.definition_levels_byte_length.try_into()?, + rep_levels_byte_len: header.repetition_levels_byte_length.try_into()?, is_compressed, statistics: statistics::from_thrift(physical_type, header.statistics)?, } @@ -578,6 +578,27 @@ impl Iterator for SerializedPageReader { } } +fn verify_page_header_len(header_len: usize, remaining_bytes: usize) -> Result<()> { + if header_len > remaining_bytes { + return Err(eof_err!("Invalid page header")); + } + Ok(()) +} + +fn verify_page_size( + compressed_size: i32, + uncompressed_size: i32, + remaining_bytes: usize, +) -> Result<()> { + // The page's compressed size should not exceed the remaining bytes that are + // available to read. The page's uncompressed size is the expected size + // after decompression, which can never be negative. + if compressed_size < 0 || compressed_size as usize > remaining_bytes || uncompressed_size < 0 { + return Err(eof_err!("Invalid page header")); + } + Ok(()) +} + impl PageReader for SerializedPageReader { fn get_next_page(&mut self) -> Result> { loop { @@ -596,10 +617,16 @@ impl PageReader for SerializedPageReader { *header } else { let (header_len, header) = read_page_header_len(&mut read)?; + verify_page_header_len(header_len, *remaining)?; *offset += header_len; *remaining -= header_len; header }; + verify_page_size( + header.compressed_page_size, + header.uncompressed_page_size, + *remaining, + )?; let data_len = header.compressed_page_size as usize; *offset += data_len; *remaining -= data_len; @@ -683,6 +710,7 @@ impl PageReader for SerializedPageReader { } else { let mut read = self.reader.get_read(*offset as u64)?; let (header_len, header) = read_page_header_len(&mut read)?; + verify_page_header_len(header_len, *remaining_bytes)?; *offset += header_len; *remaining_bytes -= header_len; let page_meta = if let Ok(page_meta) = (&header).try_into() { @@ -733,12 +761,23 @@ impl PageReader for SerializedPageReader { next_page_header, } => { if let Some(buffered_header) = next_page_header.take() { + verify_page_size( + buffered_header.compressed_page_size, + buffered_header.uncompressed_page_size, + *remaining_bytes, + )?; // The next page header has already been peeked, so just advance the offset *offset += buffered_header.compressed_page_size as usize; *remaining_bytes -= buffered_header.compressed_page_size as usize; } else { let mut read = self.reader.get_read(*offset as u64)?; let (header_len, header) = read_page_header_len(&mut read)?; + verify_page_header_len(header_len, *remaining_bytes)?; + verify_page_size( + header.compressed_page_size, + header.uncompressed_page_size, + *remaining_bytes, + )?; let data_page_size = header.compressed_page_size as usize; *offset += header_len + data_page_size; *remaining_bytes -= header_len + data_page_size; diff --git a/parquet/src/file/statistics.rs b/parquet/src/file/statistics.rs index 2e05b83369cf..b7522a76f0fc 100644 --- a/parquet/src/file/statistics.rs +++ b/parquet/src/file/statistics.rs @@ -157,6 +157,32 @@ pub fn from_thrift( stats.max_value }; + fn check_len(min: &Option>, max: &Option>, len: usize) -> Result<()> { + if let Some(min) = min { + if min.len() < len { + return Err(ParquetError::General( + "Insufficient bytes to parse min statistic".to_string(), + )); + } + } + if let Some(max) = max { + if max.len() < len { + return Err(ParquetError::General( + "Insufficient bytes to parse max statistic".to_string(), + )); + } + } + Ok(()) + } + + match physical_type { + Type::BOOLEAN => check_len(&min, &max, 1), + Type::INT32 | Type::FLOAT => check_len(&min, &max, 4), + Type::INT64 | Type::DOUBLE => check_len(&min, &max, 8), + Type::INT96 => check_len(&min, &max, 12), + _ => Ok(()), + }?; + // Values are encoded using PLAIN encoding definition, except that // variable-length byte arrays do not include a length prefix. // diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs index d168e46de047..d9e9b22e809f 100644 --- a/parquet/src/schema/types.rs +++ b/parquet/src/schema/types.rs @@ -556,7 +556,11 @@ impl<'a> PrimitiveTypeBuilder<'a> { } } PhysicalType::FIXED_LEN_BYTE_ARRAY => { - let max_precision = (2f64.powi(8 * self.length - 1) - 1f64).log10().floor() as i32; + let length = self + .length + .checked_mul(8) + .ok_or(general_err!("Invalid length {} for Decimal", self.length))?; + let max_precision = (2f64.powi(length - 1) - 1f64).log10().floor() as i32; if self.precision > max_precision { return Err(general_err!( @@ -1171,9 +1175,25 @@ pub fn from_thrift(elements: &[SchemaElement]) -> Result { )); } + if !schema_nodes[0].is_group() { + return Err(general_err!("Expected root node to be a group type")); + } + Ok(schema_nodes.remove(0)) } +/// Checks if the logical type is valid. +fn check_logical_type(logical_type: &Option) -> Result<()> { + if let Some(LogicalType::Integer { bit_width, .. }) = *logical_type { + if bit_width != 8 && bit_width != 16 && bit_width != 32 && bit_width != 64 { + return Err(general_err!( + "Bit width must be 8, 16, 32, or 64 for Integer logical type" + )); + } + } + Ok(()) +} + /// Constructs a new Type from the `elements`, starting at index `index`. /// The first result is the starting index for the next Type after this one. If it is /// equal to `elements.len()`, then this Type is the last one. @@ -1198,6 +1218,9 @@ fn from_thrift_helper(elements: &[SchemaElement], index: usize) -> Result<(usize .logical_type .as_ref() .map(|value| LogicalType::from(value.clone())); + + check_logical_type(&logical_type)?; + let field_id = elements[index].field_id; match elements[index].num_children { // From parquet-format: diff --git a/parquet/src/thrift.rs b/parquet/src/thrift.rs index ceb6b1c29fe8..b216fec6f3e7 100644 --- a/parquet/src/thrift.rs +++ b/parquet/src/thrift.rs @@ -67,7 +67,7 @@ impl<'a> TCompactSliceInputProtocol<'a> { let mut shift = 0; loop { let byte = self.read_byte()?; - in_progress |= ((byte & 0x7F) as u64) << shift; + in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift); shift += 7; if byte & 0x80 == 0 { return Ok(in_progress); @@ -96,13 +96,22 @@ impl<'a> TCompactSliceInputProtocol<'a> { } } +macro_rules! thrift_unimplemented { + () => { + Err(thrift::Error::Protocol(thrift::ProtocolError { + kind: thrift::ProtocolErrorKind::NotImplemented, + message: "not implemented".to_string(), + })) + }; +} + impl TInputProtocol for TCompactSliceInputProtocol<'_> { fn read_message_begin(&mut self) -> thrift::Result { unimplemented!() } fn read_message_end(&mut self) -> thrift::Result<()> { - unimplemented!() + thrift_unimplemented!() } fn read_struct_begin(&mut self) -> thrift::Result> { @@ -147,7 +156,21 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> { ), _ => { if field_delta != 0 { - self.last_read_field_id += field_delta as i16; + self.last_read_field_id = self + .last_read_field_id + .checked_add(field_delta as i16) + .map_or_else( + || { + Err(thrift::Error::Protocol(thrift::ProtocolError { + kind: thrift::ProtocolErrorKind::InvalidData, + message: format!( + "cannot add {} to {}", + field_delta, self.last_read_field_id + ), + })) + }, + Ok, + )?; } else { self.last_read_field_id = self.read_i16()?; }; @@ -226,15 +249,15 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> { } fn read_set_begin(&mut self) -> thrift::Result { - unimplemented!() + thrift_unimplemented!() } fn read_set_end(&mut self) -> thrift::Result<()> { - unimplemented!() + thrift_unimplemented!() } fn read_map_begin(&mut self) -> thrift::Result { - unimplemented!() + thrift_unimplemented!() } fn read_map_end(&mut self) -> thrift::Result<()> { diff --git a/parquet/tests/arrow_reader/bad_data.rs b/parquet/tests/arrow_reader/bad_data.rs index 74342031432a..cfd61e82d32b 100644 --- a/parquet/tests/arrow_reader/bad_data.rs +++ b/parquet/tests/arrow_reader/bad_data.rs @@ -106,7 +106,7 @@ fn test_arrow_rs_gh_6229_dict_header() { let err = read_file("ARROW-RS-GH-6229-DICTHEADER.parquet").unwrap_err(); assert_eq!( err.to_string(), - "External: Parquet argument error: EOF: eof decoding byte array" + "External: Parquet argument error: Parquet error: Integer overflow: out of range integral type conversion attempted" ); }