From 36d56a9b36bef27f6a1fc67a8cb34ded459f36ca Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 30 Dec 2024 04:18:32 -0600 Subject: [PATCH 1/8] Added codec support + tests for: 1. Namespaces 2. Enums 3. Maps 4. Decimals --- arrow-avro/src/codec.rs | 777 ++++++++++++++++++++++++++++---- arrow-avro/src/reader/record.rs | 3 + 2 files changed, 692 insertions(+), 88 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 25a790fa476a..aab38a45e444 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -17,11 +17,14 @@ use crate::schema::{ Attributes, ComplexType, PrimitiveType, Schema, TypeName, Array, Fixed, Map, Record, - Field as AvroFieldDef -}; -use arrow_schema::{ - ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit, + Field as AvroFieldDef, + Fixed as AvroFixed, + Enum as AvroEnum, + Map as AvroMap }; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, SchemaBuilder, + SchemaRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE}; use arrow_array::{ArrayRef, Int32Array, StringArray, StructArray, RecordBatch}; use std::borrow::Cow; use std::collections::HashMap; @@ -49,7 +52,6 @@ pub struct AvroDataType { } impl AvroDataType { - /// Create a new AvroDataType with the given parts. /// This helps you construct it from outside `codec.rs` without exposing internals. pub fn new( @@ -64,6 +66,7 @@ impl AvroDataType { } } + /// Create a new AvroDataType from a `Codec`, with default (no) nullability and empty metadata. pub fn from_codec(codec: Codec) -> Self { Self::new(codec, None, Default::default()) } @@ -74,30 +77,57 @@ impl AvroDataType { Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone()) } + /// Return a reference to the inner `Codec`. pub fn codec(&self) -> &Codec { &self.codec } + /// Return the nullability for this Avro type, if any. 
pub fn nullability(&self) -> Option { self.nullability } /// Convert this `AvroDataType`, which encapsulates an Arrow data type (`codec`) - /// plus nullability, back into an Avro `Schema<'a>`. + /// plus nullability and metadata, back into an Avro `Schema<'a>`. + /// + /// - If `metadata["namespace"]` is present, we'll store it in the resulting schema for named types + /// (record, enum, fixed). pub fn to_avro_schema<'a>(&'a self, name: &'a str) -> Schema<'a> { let inner_schema = self.codec.to_avro_schema(name); - // If the field is nullable in Arrow, wrap Avro schema in a union: ["null", ]. - // Otherwise, return the schema as-is. if let Some(_) = self.nullability { Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), - inner_schema, + maybe_add_namespace(inner_schema, self), ]) } else { - inner_schema + maybe_add_namespace(inner_schema, self) + } + } +} + +/// If this is a named complex type (Record, Enum, Fixed), attach `namespace` +/// from `dt.metadata["namespace"]` if present. Otherwise, return as-is. +fn maybe_add_namespace<'a>(mut schema: Schema<'a>, dt: &'a AvroDataType) -> Schema<'a> { + let ns = dt.metadata.get("namespace"); + if let Some(ns_str) = ns { + if let Schema::Complex(ref mut c) = schema { + match c { + ComplexType::Record(r) => { + r.namespace = Some(ns_str); + } + ComplexType::Enum(e) => { + e.namespace = Some(ns_str); + } + ComplexType::Fixed(f) => { + f.namespace = Some(ns_str); + } + // Arrays and Maps do not have a namespace field, so do nothing + _ => {} + } } } + schema } /// A named [`AvroDataType`] @@ -118,6 +148,7 @@ impl AvroField { &self.data_type } + /// Returns the name of this field pub fn name(&self) -> &str { &self.name } @@ -167,9 +198,14 @@ pub enum Codec { List(Arc), Struct(Arc<[AvroField]>), Interval, + /// In Arrow, use Dictionary(Int32, Utf8) for Enum. 
+ Enum(Vec), + Map(Arc), + Decimal(usize, Option, Option), } impl Codec { + /// Convert this to an Arrow `DataType` fn data_type(&self) -> DataType { match self { Self::Null => DataType::Null, @@ -195,11 +231,50 @@ impl Codec { DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) } Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), + Self::Enum(_symbols) => { + // Produce a Dictionary type with index = Int32, value = Utf8 + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ) + } + Self::Map(values) => { + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + values.field_with_name("value"), + ]) + ), + false, + )), + false, + ) + } + Self::Decimal(precision, scale, size) => match size { + Some(s) if *s > 16 && *s <= 32 => { + DataType::Decimal256(*precision as u8, scale.unwrap_or(0) as i8) + }, + Some(s) if *s <= 16 => { + DataType::Decimal128(*precision as u8, scale.unwrap_or(0) as i8) + }, + _ => { + // Infer based on precision when size is None + if *precision <= DECIMAL128_MAX_PRECISION as usize + && scale.unwrap_or(0) <= DECIMAL128_MAX_SCALE as usize + { + DataType::Decimal128(*precision as u8, scale.unwrap_or(0) as i8) + } else { + DataType::Decimal256(*precision as u8, scale.unwrap_or(0) as i8) + } + } + }, } } /// Convert this `Codec` variant to an Avro `Schema<'a>`. - /// More work needed to handle `decimal`, `enum`, `map`, etc. 
pub fn to_avro_schema<'a>(&'a self, name: &'a str) -> Schema<'a> { match self { Codec::Null => Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), @@ -210,7 +285,6 @@ impl Codec { Codec::Float64 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Double)), Codec::Binary => Schema::TypeName(TypeName::Primitive(PrimitiveType::Bytes)), Codec::Utf8 => Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), - // date32 => Avro int + logicalType=date Codec::Date32 => Schema::Type(crate::schema::Type { r#type: TypeName::Primitive(PrimitiveType::Int), @@ -219,7 +293,6 @@ impl Codec { additional: Default::default(), }, }), - // time-millis => Avro int with logicalType=time-millis Codec::TimeMillis => Schema::Type(crate::schema::Type { r#type: TypeName::Primitive(PrimitiveType::Int), @@ -228,7 +301,6 @@ impl Codec { additional: Default::default(), }, }), - // time-micros => Avro long with logicalType=time-micros Codec::TimeMicros => Schema::Type(crate::schema::Type { r#type: TypeName::Primitive(PrimitiveType::Long), @@ -237,52 +309,53 @@ impl Codec { additional: Default::default(), }, }), - - // timestamp-millis => Avro long with logicalType=timestamp-millis + // timestamp-millis => Avro long with logicalType=timestamp-millis or local-timestamp-millis Codec::TimestampMillis(is_utc) => { - // TODO `is_utc` or store it in metadata + let lt = if *is_utc { + Some("timestamp-millis") + } else { + Some("local-timestamp-millis") + }; Schema::Type(crate::schema::Type { r#type: TypeName::Primitive(PrimitiveType::Long), attributes: Attributes { - logical_type: Some("timestamp-millis"), + logical_type: lt, additional: Default::default(), }, }) } - - // timestamp-micros => Avro long with logicalType=timestamp-micros + // timestamp-micros => Avro long with logicalType=timestamp-micros or local-timestamp-micros Codec::TimestampMicros(is_utc) => { + let lt = if *is_utc { + Some("timestamp-micros") + } else { + Some("local-timestamp-micros") + }; 
Schema::Type(crate::schema::Type { r#type: TypeName::Primitive(PrimitiveType::Long), attributes: Attributes { - logical_type: Some("timestamp-micros"), - additional: Default::default(), - }, - }) - } - - Codec::Interval => { - Schema::Type(crate::schema::Type { - r#type: TypeName::Primitive(PrimitiveType::Bytes), - attributes: Attributes { - logical_type: Some("duration"), + logical_type: lt, additional: Default::default(), }, }) } - + Codec::Interval => Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Bytes), + attributes: Attributes { + logical_type: Some("duration"), + additional: Default::default(), + }, + }), Codec::Fixed(size) => { - // Convert Arrow FixedSizeBinary => Avro fixed with a known name & size - // TODO namespace/aliases. + // Convert Arrow FixedSizeBinary => Avro fixed with name & size Schema::Complex(ComplexType::Fixed(Fixed { name, - namespace: None, // TODO namespace implementation - aliases: vec![], // TODO alias implementation + namespace: None, + aliases: vec![], size: *size as usize, attributes: Attributes::default(), })) } - Codec::List(item_type) => { // Avro array with "items" recursively derived let items_schema = item_type.to_avro_schema("items"); @@ -291,32 +364,80 @@ impl Codec { attributes: Attributes::default(), })) } - Codec::Struct(fields) => { // Avro record with nested fields let record_fields = fields .iter() .map(|f| { - // For each `AvroField`, get its Avro schema let child_schema = f.data_type().to_avro_schema(f.name()); AvroFieldDef { - name: f.name(), // Avro field name + name: f.name(), doc: None, r#type: child_schema, default: None, } }) .collect(); - Schema::Complex(ComplexType::Record(Record { name, - namespace: None, // TODO follow up for namespace implementation + namespace: None, doc: None, - aliases: vec![], // TODO follow up for alias implementation + aliases: vec![], fields: record_fields, attributes: Attributes::default(), })) } + Codec::Enum(symbols) => { + // If there's a 
namespace in metadata, we will apply it later in maybe_add_namespace. + Schema::Complex(ComplexType::Enum(AvroEnum { + name, + namespace: None, + doc: None, + aliases: vec![], + symbols: symbols.iter().map(|s| s.as_str()).collect(), + default: None, + attributes: Attributes::default(), + })) + } + Codec::Map(values) => { + let val_schema = values.to_avro_schema("values"); + Schema::Complex(ComplexType::Map(AvroMap { + values: Box::new(val_schema), + attributes: Attributes::default(), + })) + } + Codec::Decimal(precision, scale, size) => { + // If size is Some(n), produce Avro "fixed", else "bytes". + if let Some(n) = size { + // fixed with logicalType=decimal, plus precision/scale + Schema::Complex(ComplexType::Fixed(AvroFixed { + name, + namespace: None, + aliases: vec![], + size: *n, + attributes: Attributes { + logical_type: Some("decimal"), + additional: HashMap::from([ + ("precision", serde_json::json!(*precision)), + ("scale", serde_json::json!(scale.unwrap_or(0))), + ("size", serde_json::json!(*n)), + ]), + }, + })) + } else { + // "type":"bytes", "logicalType":"decimal" + Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Bytes), + attributes: Attributes { + logical_type: Some("decimal"), + additional: HashMap::from([ + ("precision", serde_json::json!(*precision)), + ("scale", serde_json::json!(scale.unwrap_or(0))), + ]), + }, + }) + } + } } } } @@ -365,8 +486,6 @@ impl<'a> Resolver<'a> { /// /// `name`: is name used to refer to `schema` in its parent /// `namespace`: an optional qualifier used as part of a type hierarchy -/// -/// See [`Resolver`] for more information fn make_data_type<'a>( schema: &Schema<'a>, namespace: Option<&'a str>, @@ -380,7 +499,7 @@ fn make_data_type<'a>( }), Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), Schema::Union(f) => { - // Special case the common case of nullable primitives + // Special case the common case of nullable primitives or single-type let null = f .iter() 
.position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); @@ -431,50 +550,132 @@ fn make_data_type<'a>( }) } ComplexType::Fixed(f) => { + // Possibly decimal with logicalType=decimal let size = f.size.try_into().map_err(|e| { ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) })?; + if let Some("decimal") = f.attributes.logical_type { + let precision = f + .attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .ok_or_else(|| { + ArrowError::ParseError("Decimal requires precision".to_string()) + })?; + let size_val = f + .attributes + .additional + .get("size") + .and_then(|v| v.as_u64()) + .ok_or_else(|| { + ArrowError::ParseError("Decimal requires size".to_string()) + })?; + let scale = f + .attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .or_else(|| Some(0)); + + let field = AvroDataType { + nullability: None, + metadata: f.attributes.field_metadata(), + codec: Codec::Decimal( + precision as usize, + Some(scale.unwrap_or(0) as usize), + Some(size_val as usize), + ), + }; + resolver.register(f.name, namespace, field.clone()); + Ok(field) + } else { + let field = AvroDataType { + nullability: None, + metadata: f.attributes.field_metadata(), + codec: Codec::Fixed(size), + }; + resolver.register(f.name, namespace, field.clone()); + Ok(field) + } + } + ComplexType::Enum(e) => { + let symbols = e.symbols.iter().map(|sym| sym.to_string()).collect::>(); let field = AvroDataType { nullability: None, - metadata: f.attributes.field_metadata(), - codec: Codec::Fixed(size), + metadata: e.attributes.field_metadata(), + codec: Codec::Enum(symbols), + }; + resolver.register(e.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Map(m) => { + let values_data_type = make_data_type(m.values.as_ref(), namespace, resolver)?; + let field = AvroDataType { + nullability: None, + metadata: m.attributes.field_metadata(), + codec: Codec::Map(Arc::new(values_data_type)), }; - 
resolver.register(f.name, namespace, field.clone()); Ok(field) } - ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!( - "Enum of {e:?} not currently supported" - ))), - ComplexType::Map(m) => Err(ArrowError::NotYetImplemented(format!( - "Map of {m:?} not currently supported" - ))), }, Schema::Type(t) => { + // Possibly decimal, or other logical types let mut field = make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; - // https://avro.apache.org/docs/1.11.1/specification/#logical-types match (t.attributes.logical_type, &mut field.codec) { (Some("decimal"), c @ Codec::Fixed(_)) => { - return Err(ArrowError::NotYetImplemented( - "Decimals are not currently supported".to_string(), - )) + *c = Codec::Decimal( + t.attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .unwrap_or(10) as usize, + Some( + t.attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize, + ), + Some( + t.attributes + .additional + .get("size") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize, + ), + ); + } + (Some("decimal"), c @ Codec::Binary) => { + *c = Codec::Decimal( + t.attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .unwrap_or(10) as usize, + Some( + t.attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize, + ), + None, + ); } (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true), (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true), - (Some("local-timestamp-millis"), c @ Codec::Int64) => { - *c = Codec::TimestampMillis(false) - } - (Some("local-timestamp-micros"), c @ Codec::Int64) => { - *c = Codec::TimestampMicros(false) - } + (Some("local-timestamp-millis"), c @ Codec::Int64) => 
*c = Codec::TimestampMillis(false), + (Some("local-timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(false), (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval, (Some(logical), _) => { - // Insert unrecognized logical type into metadata map + // Insert unrecognized logical type into metadata field.metadata.insert("logicalType".into(), logical.into()); } (None, _) => {} @@ -490,20 +691,20 @@ fn make_data_type<'a>( } } - /// Convert an Arrow `Field` into an `AvroField`. -pub(crate) fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { - // TODO advanced metadata logic here +pub fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { + // Basic metadata logic: + // If arrow_field.metadata().get("namespace") is present, we store it below in AvroDataType let codec = arrow_type_to_codec(arrow_field.data_type()); - // Set nullability if the Arrow field is nullable let nullability = if arrow_field.is_nullable() { Some(Nullability::NullFirst) } else { None }; + let mut metadata = arrow_field.metadata().clone(); let avro_data_type = AvroDataType { nullability, - metadata: arrow_field.metadata().clone(), + metadata, codec, }; AvroField { @@ -512,7 +713,7 @@ pub(crate) fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { } } -/// Maps an Arrow `DataType` to a `Codec`: +/// Maps an Arrow `DataType` to a `Codec`. 
fn arrow_type_to_codec(dt: &DataType) -> Codec { use arrow_schema::DataType::*; match dt { @@ -527,29 +728,429 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { Date32 => Codec::Date32, Time32(TimeUnit::Millisecond) => Codec::TimeMillis, Time64(TimeUnit::Microsecond) => Codec::TimeMicros, - Timestamp(TimeUnit::Millisecond, _) => Codec::TimestampMillis(true), - Timestamp(TimeUnit::Microsecond, _) => Codec::TimestampMicros(true), - FixedSizeBinary(n) => Codec::Fixed(*n as i32), - - List(field) => { - // Recursively create Codec for the child item - let child_codec = arrow_type_to_codec(field.data_type()); - Codec::List(Arc::new(AvroDataType { - nullability: None, - metadata: Default::default(), - codec: child_codec, - })) + Timestamp(TimeUnit::Millisecond, None) => Codec::TimestampMillis(false), + Timestamp(TimeUnit::Microsecond, None) => Codec::TimestampMicros(false), + Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => { + Codec::TimestampMillis(true) + } + Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => { + Codec::TimestampMicros(true) + } + FixedSizeBinary(n) => Codec::Fixed(*n), + Decimal128(prec, scale) => Codec::Decimal( + *prec as usize, + Some(*scale as usize), + Some(16), + ), + Decimal256(prec, scale) => Codec::Decimal( + *prec as usize, + Some(*scale as usize), + Some(32), + ),Dictionary(index_type, value_type) => { + let mut md = HashMap::new(); + md.insert("dictionary_index_type".to_string(), format!("{:?}", index_type)); + if matches!(value_type.as_ref(), Utf8 | LargeUtf8) { + let mut dt = AvroDataType::from_codec(Codec::Enum(vec![])); + dt.metadata.extend(md); + Codec::Enum(vec![]) + } else { + // fallback + Codec::Utf8 + } + } + // For map => "type":"map" => in Arrow: DataType::Map + Map(field, _keys_sorted) => { + if let Struct(child_fields) = field.data_type() { + let value_field = &child_fields[1]; // name="value" + let sub_codec = arrow_type_to_codec(value_field.data_type()); + 
Codec::Map(Arc::new(AvroDataType { + nullability: value_field.is_nullable().then(|| Nullability::NullFirst), + metadata: value_field.metadata().clone(), + codec: sub_codec, + })) + } else { + Codec::Map(Arc::new(AvroDataType::from_codec(Codec::Utf8))) + } } Struct(child_fields) => { let avro_fields: Vec = child_fields .iter() - .map(|fref| arrow_field_to_avro_field(fref.as_ref())) + .map(|f_ref| arrow_field_to_avro_field(f_ref.as_ref())) .collect(); Codec::Struct(Arc::from(avro_fields)) } - _ => { - // TODO handle more arrow types (e.g. decimal, map, union, etc.) - Codec::Utf8 + _ => Codec::Utf8, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::{DataType, Field}; + use std::sync::Arc; + use serde_json::json; + + #[test] + fn test_decimal256_tuple_variant_fixed() { + // Arrow decimal(60,3) => Codec::Decimal(60,3,Some(32)) + let c = arrow_type_to_codec(&DataType::Decimal256(60, 3)); + match c { + Codec::Decimal(p, s, Some(32)) => { + assert_eq!(p, 60); + assert_eq!(s, Some(3)); + } + _ => panic!("Expected decimal(60,3,Some(32))"), + } + let avro_dt = AvroDataType::from_codec(c); + let avro_schema = avro_dt.to_avro_schema("FixedDec"); + let j = serde_json::to_value(&avro_schema).unwrap(); + let expected = json!({ + "type": "fixed", + "name": "FixedDec", + "aliases": [], + "size": 32, + "logicalType": "decimal", + "precision": 60, + "scale": 3 + }); + assert_eq!(j, expected); + } + + #[test] + fn test_decimal128_tuple_variant_fixed() { + // Avro "fixed" => decimal(6,2,Some(4)) + // arrow => decimal(6,2) + let c = Codec::Decimal(6, Some(2), Some(4)); + let dt = c.data_type(); + match dt { + DataType::Decimal128(p, s) => { + assert_eq!(p, 6); + assert_eq!(s, 2); + } + _ => panic!("Expected decimal(6,2) arrow type"), + } + + // Convert back to Avro schema => "fixed" + let avro_dt = AvroDataType::from_codec(c); + let schema = avro_dt.to_avro_schema("FixedDec"); + let j = serde_json::to_value(&schema).unwrap(); + let expected = json!({ + "type": 
"fixed", + "name": "FixedDec", + "aliases": [], + "size": 4, + "logicalType": "decimal", + "precision": 6, + "scale": 2, + }); + assert_eq!(j, expected); + } + + #[test] + fn test_decimal_size_decision() { + // Decimal128 (size <= 16) + let codec = Codec::Decimal(10, Some(3), Some(16)); + let dt = codec.data_type(); + match dt { + DataType::Decimal128(precision, scale) => { + assert_eq!(precision, 10); + assert_eq!(scale, 3); + } + _ => panic!("Expected Decimal128"), + } + + // Decimal256 (size > 16) + let codec = Codec::Decimal(18, Some(4), Some(32)); + let dt = codec.data_type(); + match dt { + DataType::Decimal256(precision, scale) => { + assert_eq!(precision, 18); + assert_eq!(scale, 4); + } + _ => panic!("Expected Decimal256"), + } + + // Default to Decimal128 (size not specified) + let codec = Codec::Decimal(8, Some(2), None); + let dt = codec.data_type(); + match dt { + DataType::Decimal128(precision, scale) => { + assert_eq!(precision, 8); + assert_eq!(scale, 2); + } + _ => panic!("Expected Decimal128"), + } + } + + #[test] + fn test_avro_data_type_new_and_from_codec() { + let dt1 = AvroDataType::new( + Codec::Int32, + Some(Nullability::NullFirst), + HashMap::from([("namespace".into(), "my.ns".into())]), + ); + + let actual_str = format!("{:?}", dt1.nullability()); + let expected_str = format!("{:?}", Some(Nullability::NullFirst)); + assert_eq!(actual_str, expected_str); + + let actual_str2 = format!("{:?}", dt1.codec()); + let expected_str2 = format!("{:?}", &Codec::Int32); + assert_eq!(actual_str2, expected_str2); + assert_eq!(dt1.metadata.get("namespace"), Some(&"my.ns".to_string())); + + let dt2 = AvroDataType::from_codec(Codec::Float64); + let actual_str4 = format!("{:?}", dt2.codec()); + let expected_str4 = format!("{:?}", &Codec::Float64); + assert_eq!(actual_str4, expected_str4); + assert!(dt2.metadata.is_empty()); + } + + #[test] + fn test_avro_data_type_field_with_name() { + let dt = AvroDataType::new( + Codec::Binary, + None, + 
HashMap::from([("something".into(), "else".into())]), + ); + let f = dt.field_with_name("bin_col"); + assert_eq!(f.name(), "bin_col"); + assert_eq!(f.data_type(), &DataType::Binary); + assert!(!f.is_nullable()); + assert_eq!(f.metadata().get("something"), Some(&"else".to_string())); + } + + #[test] + fn test_avro_data_type_to_avro_schema_with_namespace_record() { + let mut meta = HashMap::new(); + meta.insert("namespace".to_string(), "com.example".to_string()); + let fields = Arc::from(vec![ + AvroField { + name: "id".to_string(), + data_type: AvroDataType::from_codec(Codec::Int32), + }, + AvroField { + name: "label".to_string(), + data_type: AvroDataType::new(Codec::Utf8, Some(Nullability::NullFirst), Default::default()), + } + ]); + let top_level = AvroDataType::new(Codec::Struct(fields), None, meta); + let avro_schema = top_level.to_avro_schema("TopRecord"); + let json_val = serde_json::to_value(&avro_schema).unwrap(); + + let expected = json!({ + "type": "record", + "name": "TopRecord", + "namespace": "com.example", + "doc": null, + "logicalType": null, + "aliases": [], + "fields": [ + { "name": "id", "doc": null, "type": "int" }, + { "name": "label", "doc": null, "type": ["null","string"] } + ], + }); + assert_eq!(json_val, expected); + } + + #[test] + fn test_avro_data_type_to_avro_schema_with_namespace_enum() { + let mut meta = HashMap::new(); + meta.insert("namespace".to_string(), "com.example.enum".to_string()); + + let enum_dt = AvroDataType::new( + Codec::Enum(vec!["A".to_string(), "B".to_string(), "C".to_string()]), + None, + meta, + ); + let avro_schema = enum_dt.to_avro_schema("MyEnum"); + let json_val = serde_json::to_value(&avro_schema).unwrap(); + let expected = json!({ + "type": "enum", + "name": "MyEnum", + "logicalType": null, + "namespace": "com.example.enum", + "doc": null, + "aliases": [], + "symbols": ["A","B","C"] + }); + assert_eq!(json_val, expected); + } + + #[test] + fn test_avro_data_type_to_avro_schema_with_namespace_fixed() { + let 
mut meta = HashMap::new(); + meta.insert("namespace".to_string(), "com.example.fixed".to_string()); + + let fixed_dt = AvroDataType::new(Codec::Fixed(8), None, meta); + + let avro_schema = fixed_dt.to_avro_schema("MyFixed"); + let json_val = serde_json::to_value(&avro_schema).unwrap(); + + let expected = json!({ + "type": "fixed", + "name": "MyFixed", + "logicalType": null, + "namespace": "com.example.fixed", + "aliases": [], + "size": 8 + }); + assert_eq!(json_val, expected); + } + + #[test] + fn test_avro_field() { + let field_codec = AvroDataType::from_codec(Codec::Int64); + let avro_field = AvroField { + name: "long_col".to_string(), + data_type: field_codec.clone(), + }; + + assert_eq!(avro_field.name(), "long_col"); + + let actual_str = format!("{:?}", avro_field.data_type().codec()); + let expected_str = format!("{:?}", &Codec::Int64); + assert_eq!(actual_str, expected_str, "Codec debug output mismatch"); + + let arrow_field = avro_field.field(); + assert_eq!(arrow_field.name(), "long_col"); + assert_eq!(arrow_field.data_type(), &DataType::Int64); + assert!(!arrow_field.is_nullable()); + } + + #[test] + fn test_arrow_field_to_avro_field() { + let arrow_field = Field::new( + "test_meta", + DataType::Utf8, + true, + ) + .with_metadata(HashMap::from([ + ("namespace".to_string(), "arrow_meta_ns".to_string()) + ])); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert_eq!(avro_field.name(), "test_meta"); + + let actual_str = format!("{:?}", avro_field.data_type().codec()); + let expected_str = format!("{:?}", &Codec::Utf8); + assert_eq!(actual_str, expected_str); + + let actual_str = format!("{:?}", avro_field.data_type().nullability()); + let expected_str = format!("{:?}", Some(Nullability::NullFirst)); + assert_eq!(actual_str, expected_str); + + // Confirm we kept the metadata + assert_eq!( + avro_field.data_type().metadata.get("namespace"), + Some(&"arrow_meta_ns".to_string()) + ); + } + + #[test] + fn test_codec_struct() { + let fields = 
Arc::from(vec![ + AvroField { + name: "a".to_string(), + data_type: AvroDataType::from_codec(Codec::Boolean), + }, + AvroField { + name: "b".to_string(), + data_type: AvroDataType::from_codec(Codec::Float64), + }, + ]); + let codec = Codec::Struct(fields); + let dt = codec.data_type(); + match dt { + DataType::Struct(fields) => { + assert_eq!(fields.len(), 2); + assert_eq!(fields[0].name(), "a"); + assert_eq!(fields[0].data_type(), &DataType::Boolean); + assert_eq!(fields[1].name(), "b"); + assert_eq!(fields[1].data_type(), &DataType::Float64); + } + _ => panic!("Expected Struct data type"), + } + } + + #[test] + fn test_codec_fixedsizebinary() { + let codec = Codec::Fixed(12); + let dt = codec.data_type(); + match dt { + DataType::FixedSizeBinary(n) => assert_eq!(n, 12), + _ => panic!("Expected FixedSizeBinary(12)"), } } + + #[test] + fn test_utc_timestamp_millis() { + let arrow_field = Field::new( + "utc_ts_ms", + DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), + false, + ); + + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + + assert!( + matches!(codec, Codec::TimestampMillis(true)), + "Expected Codec::TimestampMillis(true), got: {:?}", + codec + ); + } + + #[test] + fn test_utc_timestamp_micros() { + let arrow_field = Field::new( + "utc_ts_us", + DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), + false, + ); + + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + + assert!( + matches!(codec, Codec::TimestampMicros(true)), + "Expected Codec::TimestampMicros(true), got: {:?}", + codec + ); + } + + #[test] + fn test_local_timestamp_millis() { + let arrow_field = Field::new( + "local_ts_ms", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ); + + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + + assert!( + matches!(codec, 
Codec::TimestampMillis(false)), + "Expected Codec::TimestampMillis(false), got: {:?}", + codec + ); + } + + #[test] + fn test_local_timestamp_micros() { + let arrow_field = Field::new( + "local_ts_us", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ); + + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + + assert!( + matches!(codec, Codec::TimestampMicros(false)), + "Expected Codec::TimestampMicros(false), got: {:?}", + codec + ); + } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 52a58cf63303..97ccc1032b76 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -144,6 +144,9 @@ impl Decoder { } Self::Record(arrow_fields.into(), encodings) } + _ => { + Self::Null(0) // TODO: Add decoders for Enum, Map, and Decimal + } }; Ok(match data_type.nullability() { From 36b4b734cea5d9a8b7c28044c76ee076988acebd Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 30 Dec 2024 22:48:29 -0600 Subject: [PATCH 2/8] Added reader record decoder support for non-null Enum, Map, and Decimal types. Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 47 +-- arrow-avro/src/reader/cursor.rs | 28 +- arrow-avro/src/reader/record.rs | 500 ++++++++++++++++++++++++++++++-- 3 files changed, 500 insertions(+), 75 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index aab38a45e444..01a2732e99bc 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -53,7 +53,6 @@ pub struct AvroDataType { impl AvroDataType { /// Create a new AvroDataType with the given parts. - /// This helps you construct it from outside `codec.rs` without exposing internals. 
pub fn new( codec: Codec, nullability: Option, @@ -261,7 +260,7 @@ impl Codec { DataType::Decimal128(*precision as u8, scale.unwrap_or(0) as i8) }, _ => { - // Infer based on precision when size is None + // Note: Infer based on precision when size is None if *precision <= DECIMAL128_MAX_PRECISION as usize && scale.unwrap_or(0) <= DECIMAL128_MAX_SCALE as usize { @@ -409,7 +408,6 @@ impl Codec { Codec::Decimal(precision, scale, size) => { // If size is Some(n), produce Avro "fixed", else "bytes". if let Some(n) = size { - // fixed with logicalType=decimal, plus precision/scale Schema::Complex(ComplexType::Fixed(AvroFixed { name, namespace: None, @@ -532,7 +530,6 @@ fn make_data_type<'a>( }) }) .collect::>()?; - let field = AvroDataType { nullability: None, codec: Codec::Struct(fields), @@ -554,7 +551,6 @@ fn make_data_type<'a>( let size = f.size.try_into().map_err(|e| { ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) })?; - if let Some("decimal") = f.attributes.logical_type { let precision = f .attributes @@ -578,7 +574,6 @@ fn make_data_type<'a>( .get("scale") .and_then(|v| v.as_u64()) .or_else(|| Some(0)); - let field = AvroDataType { nullability: None, metadata: f.attributes.field_metadata(), @@ -624,7 +619,6 @@ fn make_data_type<'a>( // Possibly decimal, or other logical types let mut field = make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; - match (t.attributes.logical_type, &mut field.codec) { (Some("decimal"), c @ Codec::Fixed(_)) => { *c = Codec::Decimal( @@ -792,7 +786,6 @@ mod tests { #[test] fn test_decimal256_tuple_variant_fixed() { - // Arrow decimal(60,3) => Codec::Decimal(60,3,Some(32)) let c = arrow_type_to_codec(&DataType::Decimal256(60, 3)); match c { Codec::Decimal(p, s, Some(32)) => { @@ -818,8 +811,6 @@ mod tests { #[test] fn test_decimal128_tuple_variant_fixed() { - // Avro "fixed" => decimal(6,2,Some(4)) - // arrow => decimal(6,2) let c = Codec::Decimal(6, Some(2), Some(4)); let dt = 
c.data_type(); match dt { @@ -829,8 +820,6 @@ mod tests { } _ => panic!("Expected decimal(6,2) arrow type"), } - - // Convert back to Avro schema => "fixed" let avro_dt = AvroDataType::from_codec(c); let schema = avro_dt.to_avro_schema("FixedDec"); let j = serde_json::to_value(&schema).unwrap(); @@ -848,7 +837,6 @@ mod tests { #[test] fn test_decimal_size_decision() { - // Decimal128 (size <= 16) let codec = Codec::Decimal(10, Some(3), Some(16)); let dt = codec.data_type(); match dt { @@ -858,8 +846,6 @@ mod tests { } _ => panic!("Expected Decimal128"), } - - // Decimal256 (size > 16) let codec = Codec::Decimal(18, Some(4), Some(32)); let dt = codec.data_type(); match dt { @@ -869,8 +855,6 @@ mod tests { } _ => panic!("Expected Decimal256"), } - - // Default to Decimal128 (size not specified) let codec = Codec::Decimal(8, Some(2), None); let dt = codec.data_type(); match dt { @@ -889,16 +873,13 @@ mod tests { Some(Nullability::NullFirst), HashMap::from([("namespace".into(), "my.ns".into())]), ); - let actual_str = format!("{:?}", dt1.nullability()); let expected_str = format!("{:?}", Some(Nullability::NullFirst)); assert_eq!(actual_str, expected_str); - let actual_str2 = format!("{:?}", dt1.codec()); let expected_str2 = format!("{:?}", &Codec::Int32); assert_eq!(actual_str2, expected_str2); assert_eq!(dt1.metadata.get("namespace"), Some(&"my.ns".to_string())); - let dt2 = AvroDataType::from_codec(Codec::Float64); let actual_str4 = format!("{:?}", dt2.codec()); let expected_str4 = format!("{:?}", &Codec::Float64); @@ -937,7 +918,6 @@ mod tests { let top_level = AvroDataType::new(Codec::Struct(fields), None, meta); let avro_schema = top_level.to_avro_schema("TopRecord"); let json_val = serde_json::to_value(&avro_schema).unwrap(); - let expected = json!({ "type": "record", "name": "TopRecord", @@ -981,12 +961,9 @@ mod tests { fn test_avro_data_type_to_avro_schema_with_namespace_fixed() { let mut meta = HashMap::new(); meta.insert("namespace".to_string(), 
"com.example.fixed".to_string()); - let fixed_dt = AvroDataType::new(Codec::Fixed(8), None, meta); - let avro_schema = fixed_dt.to_avro_schema("MyFixed"); let json_val = serde_json::to_value(&avro_schema).unwrap(); - let expected = json!({ "type": "fixed", "name": "MyFixed", @@ -1005,13 +982,10 @@ mod tests { name: "long_col".to_string(), data_type: field_codec.clone(), }; - assert_eq!(avro_field.name(), "long_col"); - let actual_str = format!("{:?}", avro_field.data_type().codec()); let expected_str = format!("{:?}", &Codec::Int64); assert_eq!(actual_str, expected_str, "Codec debug output mismatch"); - let arrow_field = avro_field.field(); assert_eq!(arrow_field.name(), "long_col"); assert_eq!(arrow_field.data_type(), &DataType::Int64); @@ -1024,22 +998,17 @@ mod tests { "test_meta", DataType::Utf8, true, - ) - .with_metadata(HashMap::from([ - ("namespace".to_string(), "arrow_meta_ns".to_string()) - ])); + ).with_metadata(HashMap::from([ + ("namespace".to_string(), "arrow_meta_ns".to_string()) + ])); let avro_field = arrow_field_to_avro_field(&arrow_field); assert_eq!(avro_field.name(), "test_meta"); - let actual_str = format!("{:?}", avro_field.data_type().codec()); let expected_str = format!("{:?}", &Codec::Utf8); assert_eq!(actual_str, expected_str); - let actual_str = format!("{:?}", avro_field.data_type().nullability()); let expected_str = format!("{:?}", Some(Nullability::NullFirst)); assert_eq!(actual_str, expected_str); - - // Confirm we kept the metadata assert_eq!( avro_field.data_type().metadata.get("namespace"), Some(&"arrow_meta_ns".to_string()) @@ -1089,10 +1058,8 @@ mod tests { DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), false, ); - let avro_field = arrow_field_to_avro_field(&arrow_field); let codec = avro_field.data_type().codec(); - assert!( matches!(codec, Codec::TimestampMillis(true)), "Expected Codec::TimestampMillis(true), got: {:?}", @@ -1107,10 +1074,8 @@ mod tests { DataType::Timestamp(TimeUnit::Microsecond, 
Some(Arc::from("UTC"))), false, ); - let avro_field = arrow_field_to_avro_field(&arrow_field); let codec = avro_field.data_type().codec(); - assert!( matches!(codec, Codec::TimestampMicros(true)), "Expected Codec::TimestampMicros(true), got: {:?}", @@ -1125,10 +1090,8 @@ mod tests { DataType::Timestamp(TimeUnit::Millisecond, None), false, ); - let avro_field = arrow_field_to_avro_field(&arrow_field); let codec = avro_field.data_type().codec(); - assert!( matches!(codec, Codec::TimestampMillis(false)), "Expected Codec::TimestampMillis(false), got: {:?}", @@ -1143,10 +1106,8 @@ mod tests { DataType::Timestamp(TimeUnit::Microsecond, None), false, ); - let avro_field = arrow_field_to_avro_field(&arrow_field); let codec = avro_field.data_type().codec(); - assert!( matches!(codec, Codec::TimestampMicros(false)), "Expected Codec::TimestampMicros(false), got: {:?}", diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs index 4b6a5a4d65db..ba1d01f72d7e 100644 --- a/arrow-avro/src/reader/cursor.rs +++ b/arrow-avro/src/reader/cursor.rs @@ -14,7 +14,6 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - use crate::reader::vlq::read_varint; use arrow_schema::ArrowError; @@ -65,27 +64,32 @@ impl<'a> AvroCursor<'a> { Ok(val) } + /// Decode a zig-zag encoded Avro int (32-bit). #[inline] pub(crate) fn get_int(&mut self) -> Result { let varint = self.read_vlq()?; let val: u32 = varint .try_into() .map_err(|_| ArrowError::ParseError("varint overflow".to_string()))?; + // Zig-zag decode Ok((val >> 1) as i32 ^ -((val & 1) as i32)) } + /// Decode a zig-zag encoded Avro long (64-bit). #[inline] pub(crate) fn get_long(&mut self) -> Result { let val = self.read_vlq()?; + // Zig-zag decode Ok((val >> 1) as i64 ^ -((val & 1) as i64)) } + /// Read a variable-length byte array from Avro (where the length is stored as an Avro long). 
pub(crate) fn get_bytes(&mut self) -> Result<&'a [u8], ArrowError> { let len: usize = self.get_long()?.try_into().map_err(|_| { ArrowError::ParseError("offset overflow reading avro bytes".to_string()) })?; - if (self.buf.len() < len) { + if self.buf.len() < len { return Err(ArrowError::ParseError( "Unexpected EOF reading bytes".to_string(), )); @@ -95,9 +99,10 @@ impl<'a> AvroCursor<'a> { Ok(ret) } + /// Read a little-endian 32-bit float #[inline] pub(crate) fn get_float(&mut self) -> Result { - if (self.buf.len() < 4) { + if self.buf.len() < 4 { return Err(ArrowError::ParseError( "Unexpected EOF reading float".to_string(), )); @@ -107,15 +112,28 @@ impl<'a> AvroCursor<'a> { Ok(ret) } + /// Read a little-endian 64-bit float #[inline] pub(crate) fn get_double(&mut self) -> Result { - if (self.buf.len() < 8) { + if self.buf.len() < 8 { return Err(ArrowError::ParseError( - "Unexpected EOF reading float".to_string(), + "Unexpected EOF reading double".to_string(), )); } let ret = f64::from_le_bytes(self.buf[..8].try_into().unwrap()); self.buf = &self.buf[8..]; Ok(ret) } + + /// Read exactly `n` bytes from the buffer (e.g. for Avro `fixed`). 
+ pub(crate) fn get_fixed(&mut self, n: usize) -> Result<&'a [u8], ArrowError> { + if self.buf.len() < n { + return Err(ArrowError::ParseError( + "Unexpected EOF reading fixed".to_string(), + )); + } + let ret = &self.buf[..n]; + self.buf = &self.buf[n..]; + Ok(ret) + } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 97ccc1032b76..4c57a3426bd6 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -23,12 +23,12 @@ use crate::schema::*; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::*; -use arrow_schema::{ - ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, -}; +use arrow_schema::{ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, TimeUnit}; use std::collections::HashMap; use std::io::Read; +use std::ptr::null; use std::sync::Arc; +use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; /// Decodes avro encoded data into [`RecordBatch`] pub struct RecordDecoder { @@ -94,9 +94,17 @@ enum Decoder { List(FieldRef, OffsetBufferBuilder, Box), Record(Fields, Vec), Nullable(Nullability, NullBufferBuilder, Box), + Enum(Vec, Vec), + Map(FieldRef, OffsetBufferBuilder, OffsetBufferBuilder, Vec, Box, usize), + Decimal(usize, usize, Option, Vec>), } impl Decoder { + /// Checks if the Decoder is nullable + fn is_nullable(&self) -> bool { + matches!(self, Decoder::Nullable(_, _, _)) + } + fn try_new(data_type: &AvroDataType) -> Result { let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); @@ -144,11 +152,39 @@ impl Decoder { } Self::Record(arrow_fields.into(), encodings) } - _ => { - Self::Null(0) // TODO: Add decoders for Enum, Map, and Decimal + Codec::Enum(symbols) => { + Decoder::Enum( + symbols.clone(), + Vec::with_capacity(DEFAULT_CAPACITY), + ) + } + Codec::Map(value_type) => { + let map_field = Arc::new(ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + 
Arc::new(ArrowField::new("key", DataType::Utf8, false)), + Arc::new(value_type.field_with_name("value")), + ])), + false, + )); + Decoder::Map( + map_field, + OffsetBufferBuilder::new(DEFAULT_CAPACITY), // key_offsets + OffsetBufferBuilder::new(DEFAULT_CAPACITY), // map_offsets + Vec::with_capacity(DEFAULT_CAPACITY), // key_data + Box::new(Self::try_new(value_type)?), // values_decoder_inner + 0, // current_entry_count + ) + } + Codec::Decimal(precision, scale, size) => { + Decoder::Decimal( + *precision, + scale.unwrap_or(0), + *size, + Vec::with_capacity(DEFAULT_CAPACITY), + ) } }; - Ok(match data_type.nullability() { Some(nullability) => Self::Nullable( nullability, @@ -178,6 +214,23 @@ impl Decoder { } Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()), Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"), + Self::Enum(_, _) => { + // For Enum, appending a null is not straightforward. Handle accordingly if needed. + } + Self::Map( + _, + key_offsets, + map_offsets_builder, + key_data, + values_decoder_inner, + current_entry_count, + ) => { + key_offsets.push_length(0); + map_offsets_builder.push_length(*current_entry_count); + } + Self::Decimal(_, _, _, _) => { + // For Decimal, appending a null doesn't make sense as per current implementation + } } } @@ -218,59 +271,256 @@ impl Decoder { false => e.append_null(), } } + Self::Enum(symbols, indices) => { + // Encodes enum by writing its zero-based index as an int + let index = buf.get_int()?; + indices.push(index); + } + Self::Map( + field, + key_offsets, + map_offsets_builder, + key_data, + values_decoder_inner, + current_entry_count, + ) => { + let block_count = buf.get_long()?; + if block_count <= 0 { + // Push the current_entry_count without changes + map_offsets_builder.push_length(*current_entry_count); + } else { + let n = block_count as usize; + for _ in 0..n { + let key_bytes = buf.get_bytes()?; + key_offsets.push_length(key_bytes.len()); + 
key_data.extend_from_slice(key_bytes); + values_decoder_inner.decode(buf)?; + } + // Update the current_entry_count and push to map_offsets_builder + *current_entry_count += n; + map_offsets_builder.push_length(*current_entry_count); + } + } + Self::Decimal( + precision, + scale, + size, + data + ) => { + let raw = if let Some(fixed_len) = size { + // get_fixed used to get exactly fixed_len bytes + buf.get_fixed(*fixed_len)? + } else { + // get_bytes used for variable-length + buf.get_bytes()? + }; + data.push(raw.to_vec()); + } } Ok(()) } /// Flush decoded records to an [`ArrayRef`] fn flush(&mut self, nulls: Option) -> Result { - Ok(match self { - Self::Nullable(_, n, e) => e.flush(n.finish())?, - Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))), - Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)), - Self::Int32(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Date32(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Int64(values) => Arc::new(flush_primitive::(values, nulls)), + match self { + Self::Nullable(_, n, e) => e.flush(n.finish()), + Self::Null(size) => Ok(Arc::new(NullArray::new(std::mem::replace(size, 0)))), + Self::Boolean(b) => Ok(Arc::new(BooleanArray::new(b.finish(), nulls))), + Self::Int32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Self::Date32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Self::Int64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), Self::TimeMillis(values) => { - Arc::new(flush_primitive::(values, nulls)) + Ok(Arc::new(flush_primitive::(values, nulls))) } Self::TimeMicros(values) => { - Arc::new(flush_primitive::(values, nulls)) + Ok(Arc::new(flush_primitive::(values, nulls))) } - Self::TimestampMillis(is_utc, values) => Arc::new( + Self::TimestampMillis(is_utc, values) => Ok(Arc::new( flush_primitive::(values, nulls) .with_timezone_opt(is_utc.then(|| "+00:00")), - ), - Self::TimestampMicros(is_utc, values) => Arc::new( + 
)), + Self::TimestampMicros(is_utc, values) => Ok(Arc::new( flush_primitive::(values, nulls) .with_timezone_opt(is_utc.then(|| "+00:00")), - ), - Self::Float32(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Float64(values) => Arc::new(flush_primitive::(values, nulls)), - + )), + Self::Float32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Self::Float64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), Self::Binary(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values).into(); - Arc::new(BinaryArray::new(offsets, values, nulls)) + Ok(Arc::new(BinaryArray::new(offsets, values, nulls))) } Self::String(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values).into(); - Arc::new(StringArray::new(offsets, values, nulls)) + Ok(Arc::new(StringArray::new(offsets, values, nulls))) } Self::List(field, offsets, values) => { let values = values.flush(None)?; let offsets = flush_offsets(offsets); - Arc::new(ListArray::new(field.clone(), offsets, values, nulls)) + Ok(Arc::new(ListArray::new(field.clone(), offsets, values, nulls))) } Self::Record(fields, encodings) => { let arrays = encodings .iter_mut() .map(|x| x.flush(None)) .collect::, _>>()?; - Arc::new(StructArray::new(fields.clone(), arrays, nulls)) + Ok(Arc::new(StructArray::new(fields.clone(), arrays, nulls))) } - }) + Self::Enum(symbols, indices) => { + let dict_values = StringArray::from_iter_values(symbols.iter()); + let flushed_indices = flush_values(indices); // Vec + let indices_array: Int32Array = match nulls { + Some(buf) => { + let buffer = Buffer::from_slice_ref(&flushed_indices); + PrimitiveArray::::try_new(ScalarBuffer::from(buffer), Some(buf.clone()))? 
+ }, + None => { + Int32Array::from_iter_values(flushed_indices) + } + }; + let dict_array = DictionaryArray::::try_new( + indices_array, + Arc::new(dict_values), + )?; + Ok(Arc::new(dict_array)) + } + Self::Map( + field, + key_offsets_builder, + map_offsets_builder, + key_data, + values_decoder_inner, + current_entry_count, + ) => { + let map_offsets = flush_offsets(map_offsets_builder); + let key_offsets = flush_offsets(key_offsets_builder); + let key_data = flush_values(key_data).into(); + let key_array = StringArray::new(key_offsets, key_data, None); + let val_array = values_decoder_inner.flush(None)?; + let is_nullable = matches!(**values_decoder_inner, Decoder::Nullable(_, _, _)); + let struct_fields = vec![ + Arc::new(ArrowField::new("key", DataType::Utf8, false)), + Arc::new(ArrowField::new("value", val_array.data_type().clone(), is_nullable)), + ]; + let struct_array = StructArray::new( + Fields::from(struct_fields), + vec![Arc::new(key_array), val_array], + None, + ); + let map_array = MapArray::new(field.clone(), map_offsets.clone(), struct_array.clone(), nulls, false); + Ok(Arc::new(map_array)) + } + Self::Decimal( + precision, + scale, + size, + data, + ) => { + let mut array_builder = DecimalBuilder::new(*precision, *scale, *size)?; + for raw in data.drain(..) 
{ + if let Some(s) = size { + if raw.len() < *s { + let extended = sign_extend(&raw, *s); + array_builder.append_bytes(&extended)?; + continue; + } + } + array_builder.append_bytes(&raw)?; + } + let arr = array_builder.finish()?; + Ok(Arc::new(arr)) + } + } + } +} + +/// Helper to build a field with a given type +fn field_with_type(name: &str, dt: DataType, nullable: bool) -> FieldRef { + Arc::new(ArrowField::new(name, dt, nullable)) +} + +fn sign_extend(raw: &[u8], target_len: usize) -> Vec { + if raw.is_empty() { + return vec![0; target_len]; + } + let sign_bit = raw[0] & 0x80; + let mut extended = Vec::with_capacity(target_len); + if sign_bit != 0 { // negative + extended.resize(target_len - raw.len(), 0xFF); + } else { // positive + extended.resize(target_len - raw.len(), 0x00); + } + extended.extend_from_slice(raw); + extended +} + +/// Extend raw bytes to 16 bytes (for Decimal128) +fn extend_to_16_bytes(raw: &[u8]) -> Result<[u8; 16], ArrowError> { + let extended = sign_extend(raw, 16); + if extended.len() != 16 { + return Err(ArrowError::ParseError(format!( + "Failed to extend bytes to 16 bytes: got {} bytes", + extended.len() + ))); + } + Ok(extended.try_into().unwrap()) +} + +/// Extend raw bytes to 32 bytes (for Decimal256) +fn extend_to_32_bytes(raw: &[u8]) -> Result<[u8; 32], ArrowError> { + let extended = sign_extend(raw, 32); + if extended.len() != 32 { + return Err(ArrowError::ParseError(format!( + "Failed to extend bytes to 32 bytes: got {} bytes", + extended.len() + ))); + } + Ok(extended.try_into().unwrap()) +} + +/// Trait for building decimal arrays +enum DecimalBuilder { + Decimal128(Decimal128Builder), + Decimal256(Decimal256Builder), +} + +impl DecimalBuilder { + + fn new(precision: usize, scale: usize, size: Option) -> Result { + match size { + Some(s) if s > 16 => { + // decimal256 + Ok(Self::Decimal256(Decimal256Builder::new().with_precision_and_scale(precision as u8, scale as i8)?)) + } + _ => { + // decimal128 + 
Ok(Self::Decimal128(Decimal128Builder::new().with_precision_and_scale(precision as u8, scale as i8)?)) + } + } + } + + fn append_bytes(&mut self, bytes: &[u8]) -> Result<(), ArrowError> { + match self { + DecimalBuilder::Decimal128(b) => { + let padded = extend_to_16_bytes(bytes)?; + let value = i128::from_be_bytes(padded); + b.append_value(value); + } + DecimalBuilder::Decimal256(b) => { + let padded = extend_to_32_bytes(bytes)?; + let value = i256::from_be_bytes(padded); + b.append_value(value); + } + } + Ok(()) + } + + fn finish(self) -> Result { + match self { + DecimalBuilder::Decimal128(mut b) => Ok(Arc::new(b.finish())), + DecimalBuilder::Decimal256(mut b) => Ok(Arc::new(b.finish())), + } } } @@ -293,3 +543,199 @@ fn flush_primitive( } const DEFAULT_CAPACITY: usize = 1024; + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{ + Array, ArrayRef, Int32Array, MapArray, StringArray, StructArray, + Decimal128Array, Decimal256Array, DictionaryArray, + }; + use arrow_array::cast::AsArray; + use arrow_schema::{Field as ArrowField, DataType as ArrowDataType}; + + /// Helper functions for encoding test data + fn encode_avro_int(value: i32) -> Vec { + let mut buf = Vec::new(); + let mut v = (value << 1) ^ (value >> 31); + while v & !0x7F != 0 { + buf.push(((v & 0x7F) | 0x80) as u8); + v >>= 7; + } + buf.push(v as u8); + buf + } + + fn encode_avro_long(value: i64) -> Vec { + let mut buf = Vec::new(); + let mut v = (value << 1) ^ (value >> 63); + while v & !0x7F != 0 { + buf.push(((v & 0x7F) | 0x80) as u8); + v >>= 7; + } + buf.push(v as u8); + buf + } + + fn encode_avro_bytes(bytes: &[u8]) -> Vec { + let mut buf = encode_avro_long(bytes.len() as i64); + buf.extend_from_slice(bytes); + buf + } + + #[test] + fn test_enum_decoding() { + let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; + let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); + let mut decoder = Decoder::try_new(&enum_dt).unwrap(); + // Encode the 
indices [1, 0, 2] using zigzag encoding + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(1)); // Encodes to [2] + data.extend_from_slice(&encode_avro_int(0)); // Encodes to [0] + data.extend_from_slice(&encode_avro_int(2)); // Encodes to [4] + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let dict_arr = array.as_any().downcast_ref::>().unwrap(); + assert_eq!(dict_arr.len(), 3); + let keys = dict_arr.keys(); + assert_eq!(keys.value(0), 1); + assert_eq!(keys.value(1), 0); + assert_eq!(keys.value(2), 2); + let dict_values = dict_arr.values().as_string::(); + assert_eq!(dict_values.value(0), "RED"); + assert_eq!(dict_values.value(1), "GREEN"); + assert_eq!(dict_values.value(2), "BLUE"); + } + + #[test] + fn test_map_decoding_one_entry() { + let value_type = AvroDataType::from_codec(Codec::Utf8); + let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); + let mut decoder = Decoder::try_new(&map_type).unwrap(); + // Avro encoding for a map: + // - block_count: 1 (number of entries) + // - keys: "hello" (5 bytes) + // - values: "world" (5 bytes) + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(1)); // block_count = 1 + data.extend_from_slice(&encode_avro_bytes(b"hello")); // key = "hello" + data.extend_from_slice(&encode_avro_bytes(b"world")); // value = "world" + decoder.decode(&mut AvroCursor::new(&data)).unwrap(); + let array = decoder.flush(None).unwrap(); + let map_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(map_arr.len(), 1); // Verify 1 map + assert_eq!(map_arr.value_length(0), 1); // Verify 1 entry in the map + let entries = map_arr.value(0); + let struct_entries = entries.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_entries.len(), 1); // Verify 1 entry in StructArray + let key = struct_entries + 
.column_by_name("key") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let value = struct_entries + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(key.value(0), "hello"); // Verify Key + assert_eq!(value.value(0), "world"); // Verify Value + } + + #[test] + fn test_map_decoding_empty() { + let value_type = AvroDataType::from_codec(Codec::Utf8); + let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); + let mut decoder = Decoder::try_new(&map_type).unwrap(); + // Avro encoding for an empty map: + // - block_count: 0 (no entries) + let data = encode_avro_long(0); // block_count = 0 + decoder.decode(&mut AvroCursor::new(&data)).unwrap(); + let array = decoder.flush(None).unwrap(); + let map_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(map_arr.len(), 1); // Verify 1 map + assert_eq!(map_arr.value_length(0), 0); // Verify 0 entries in the map + let entries = map_arr.value(0); + let struct_entries = entries.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_entries.len(), 0); // // Verify 0 entries StructArray + let key = struct_entries + .column_by_name("key") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let value = struct_entries + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(key.len(), 0); + assert_eq!(value.len(), 0); + } + + #[test] + fn test_decimal_decoding_fixed128() { + let dt = AvroDataType::from_codec(Codec::Decimal(5, Some(2), Some(16))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + // Row1: 123.45 => unscaled: 12345 => i128: 0x00000000000000000000000000003039 + // Row2: -1.23 => unscaled: -123 => i128: 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF85 + let row1 = [ + 0x00, 0x00, 0x00, 0x00, // First 8 bytes + 0x00, 0x00, 0x00, 0x00, // Next 8 bytes + 0x00, 0x00, 0x00, 0x00, // Next 8 bytes + 0x00, 0x00, 0x30, 0x39, // Last 8 bytes: 0x3039 = 12345 + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, // First 8 
bytes (two's complement) + 0xFF, 0xFF, 0xFF, 0xFF, // Next 8 bytes + 0xFF, 0xFF, 0xFF, 0xFF, // Next 8 bytes + 0xFF, 0xFF, 0xFF, 0x85, // Last 8 bytes: 0xFFFFFF85 = -123 + ]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + decoder.decode(&mut AvroCursor::new(&data)).unwrap(); + decoder.decode(&mut AvroCursor::new(&data[16..])).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 2); + assert_eq!(dec_arr.value_as_string(0), "123.45"); + assert_eq!(dec_arr.value_as_string(1), "-1.23"); + } + + #[test] + fn test_decimal_decoding_bytes() { + let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None)); + let mut decoder = Decoder::try_new(&dt).unwrap(); + let unscaled_row1: i128 = 1234; // 123.4 + let unscaled_row2: i128 = -1234; // -123.4 + // Note: convert unscaled values to big-endian bytes + let bytes_row1 = unscaled_row1.to_be_bytes(); + let bytes_row2 = unscaled_row2.to_be_bytes(); + // Row1: 1234 => 0x04D2 (2 bytes) + // Row2: -1234 => two's complement of 0x04D2 = 0xFB2E (2 bytes) + let row1_bytes = &bytes_row1[14..16]; // Last 2 bytes + let row2_bytes = &bytes_row2[14..16]; // Last 2 bytes + let mut data = Vec::new(); + // Encode row1 + data.extend_from_slice(&encode_avro_long(2)); // Length=2 + data.extend_from_slice(row1_bytes); // 0x04, 0xD2 + // Encode row2 + data.extend_from_slice(&encode_avro_long(2)); // Length=2 + data.extend_from_slice(row2_bytes); // 0xFB, 0x2E + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 2); + assert_eq!(dec_arr.value_as_string(0), "123.4"); + assert_eq!(dec_arr.value_as_string(1), "-123.4"); + } +} From 9d0bf4cc7e490e34aa409d62e8a4d3eee0783e70 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: 
Mon, 30 Dec 2024 23:01:18 -0600 Subject: [PATCH 3/8] Added reader record decoder support for non-null Enum, Map, and Decimal types. Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 01a2732e99bc..92274a167de6 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -740,7 +740,8 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { *prec as usize, Some(*scale as usize), Some(32), - ),Dictionary(index_type, value_type) => { + ), + Dictionary(index_type, value_type) => { let mut md = HashMap::new(); md.insert("dictionary_index_type".to_string(), format!("{:?}", index_type)); if matches!(value_type.as_ref(), Utf8 | LargeUtf8) { From 082581a4d98a6931718a8491ee524f6b86fc7dd5 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 30 Dec 2024 23:02:15 -0600 Subject: [PATCH 4/8] Added reader record decoder support for non-null Enum, Map, and Decimal types. 
Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 92274a167de6..d58390a57bf2 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -748,15 +748,13 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { let mut dt = AvroDataType::from_codec(Codec::Enum(vec![])); dt.metadata.extend(md); Codec::Enum(vec![]) - } else { - // fallback + } else { // fallback Codec::Utf8 } } - // For map => "type":"map" => in Arrow: DataType::Map Map(field, _keys_sorted) => { if let Struct(child_fields) = field.data_type() { - let value_field = &child_fields[1]; // name="value" + let value_field = &child_fields[1]; let sub_codec = arrow_type_to_codec(value_field.data_type()); Codec::Map(Arc::new(AvroDataType { nullability: value_field.is_nullable().then(|| Nullability::NullFirst), From 6647b250054b6a9681c976e265d643ffcbc1d85c Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Tue, 31 Dec 2024 12:39:55 -0600 Subject: [PATCH 5/8] Added null support Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 86 ++--- arrow-avro/src/reader/record.rs | 651 ++++++++++++++++++++++---------- 2 files changed, 493 insertions(+), 244 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index d58390a57bf2..4e57d4d186bc 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -29,6 +29,7 @@ use arrow_array::{ArrayRef, Int32Array, StringArray, StructArray, RecordBatch}; use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; +use arrow_schema::DataType::*; /// Avro types are not nullable, with nullability instead encoded as a union /// where one of the variants is the null type. 
@@ -207,43 +208,43 @@ impl Codec { /// Convert this to an Arrow `DataType` fn data_type(&self) -> DataType { match self { - Self::Null => DataType::Null, - Self::Boolean => DataType::Boolean, - Self::Int32 => DataType::Int32, - Self::Int64 => DataType::Int64, - Self::Float32 => DataType::Float32, - Self::Float64 => DataType::Float64, - Self::Binary => DataType::Binary, - Self::Utf8 => DataType::Utf8, - Self::Date32 => DataType::Date32, - Self::TimeMillis => DataType::Time32(TimeUnit::Millisecond), - Self::TimeMicros => DataType::Time64(TimeUnit::Microsecond), + Self::Null => Null, + Self::Boolean => Boolean, + Self::Int32 => Int32, + Self::Int64 => Int64, + Self::Float32 => Float32, + Self::Float64 => Float64, + Self::Binary => Binary, + Self::Utf8 => Utf8, + Self::Date32 => Date32, + Self::TimeMillis => Time32(TimeUnit::Millisecond), + Self::TimeMicros => Time64(TimeUnit::Microsecond), Self::TimestampMillis(is_utc) => { - DataType::Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into())) + Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into())) } Self::TimestampMicros(is_utc) => { - DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) + Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) } - Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano), - Self::Fixed(size) => DataType::FixedSizeBinary(*size), + Self::Interval => Interval(IntervalUnit::MonthDayNano), + Self::Fixed(size) => FixedSizeBinary(*size), Self::List(f) => { - DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) + List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) } - Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), + Self::Struct(f) => Struct(f.iter().map(|x| x.field()).collect()), Self::Enum(_symbols) => { // Produce a Dictionary type with index = Int32, value = Utf8 - DataType::Dictionary( + Dictionary( Box::new(DataType::Int32), Box::new(DataType::Utf8), ) } 
Self::Map(values) => { - DataType::Map( + Map( Arc::new(Field::new( "entries", - DataType::Struct( + Struct( Fields::from(vec![ - Field::new("key", DataType::Utf8, false), + Field::new("key", Utf8, false), values.field_with_name("value"), ]) ), @@ -254,19 +255,19 @@ impl Codec { } Self::Decimal(precision, scale, size) => match size { Some(s) if *s > 16 && *s <= 32 => { - DataType::Decimal256(*precision as u8, scale.unwrap_or(0) as i8) + Decimal256(*precision as u8, scale.unwrap_or(0) as i8) }, Some(s) if *s <= 16 => { - DataType::Decimal128(*precision as u8, scale.unwrap_or(0) as i8) + Decimal128(*precision as u8, scale.unwrap_or(0) as i8) }, _ => { // Note: Infer based on precision when size is None if *precision <= DECIMAL128_MAX_PRECISION as usize && scale.unwrap_or(0) <= DECIMAL128_MAX_SCALE as usize { - DataType::Decimal128(*precision as u8, scale.unwrap_or(0) as i8) + Decimal128(*precision as u8, scale.unwrap_or(0) as i8) } else { - DataType::Decimal256(*precision as u8, scale.unwrap_or(0) as i8) + Decimal256(*precision as u8, scale.unwrap_or(0) as i8) } } }, @@ -687,8 +688,6 @@ fn make_data_type<'a>( /// Convert an Arrow `Field` into an `AvroField`. pub fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { - // Basic metadata logic: - // If arrow_field.metadata().get("namespace") is present, we store it below in AvroDataType let codec = arrow_type_to_codec(arrow_field.data_type()); let nullability = if arrow_field.is_nullable() { Some(Nullability::NullFirst) @@ -709,7 +708,6 @@ pub fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { /// Maps an Arrow `DataType` to a `Codec`. 
fn arrow_type_to_codec(dt: &DataType) -> Codec { - use arrow_schema::DataType::*; match dt { Null => Codec::Null, Boolean => Codec::Boolean, @@ -742,13 +740,9 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { Some(32), ), Dictionary(index_type, value_type) => { - let mut md = HashMap::new(); - md.insert("dictionary_index_type".to_string(), format!("{:?}", index_type)); - if matches!(value_type.as_ref(), Utf8 | LargeUtf8) { - let mut dt = AvroDataType::from_codec(Codec::Enum(vec![])); - dt.metadata.extend(md); + if let Utf8 = **value_type { Codec::Enum(vec![]) - } else { // fallback + } else { // Fallback to Utf8 Codec::Utf8 } } @@ -785,7 +779,7 @@ mod tests { #[test] fn test_decimal256_tuple_variant_fixed() { - let c = arrow_type_to_codec(&DataType::Decimal256(60, 3)); + let c = arrow_type_to_codec(&Decimal256(60, 3)); match c { Codec::Decimal(p, s, Some(32)) => { assert_eq!(p, 60); @@ -813,7 +807,7 @@ mod tests { let c = Codec::Decimal(6, Some(2), Some(4)); let dt = c.data_type(); match dt { - DataType::Decimal128(p, s) => { + Decimal128(p, s) => { assert_eq!(p, 6); assert_eq!(s, 2); } @@ -839,7 +833,7 @@ mod tests { let codec = Codec::Decimal(10, Some(3), Some(16)); let dt = codec.data_type(); match dt { - DataType::Decimal128(precision, scale) => { + Decimal128(precision, scale) => { assert_eq!(precision, 10); assert_eq!(scale, 3); } @@ -848,7 +842,7 @@ mod tests { let codec = Codec::Decimal(18, Some(4), Some(32)); let dt = codec.data_type(); match dt { - DataType::Decimal256(precision, scale) => { + Decimal256(precision, scale) => { assert_eq!(precision, 18); assert_eq!(scale, 4); } @@ -857,7 +851,7 @@ mod tests { let codec = Codec::Decimal(8, Some(2), None); let dt = codec.data_type(); match dt { - DataType::Decimal128(precision, scale) => { + Decimal128(precision, scale) => { assert_eq!(precision, 8); assert_eq!(scale, 2); } @@ -995,7 +989,7 @@ mod tests { fn test_arrow_field_to_avro_field() { let arrow_field = Field::new( "test_meta", - DataType::Utf8, + 
Utf8, true, ).with_metadata(HashMap::from([ ("namespace".to_string(), "arrow_meta_ns".to_string()) @@ -1029,7 +1023,7 @@ mod tests { let codec = Codec::Struct(fields); let dt = codec.data_type(); match dt { - DataType::Struct(fields) => { + Struct(fields) => { assert_eq!(fields.len(), 2); assert_eq!(fields[0].name(), "a"); assert_eq!(fields[0].data_type(), &DataType::Boolean); @@ -1045,7 +1039,7 @@ mod tests { let codec = Codec::Fixed(12); let dt = codec.data_type(); match dt { - DataType::FixedSizeBinary(n) => assert_eq!(n, 12), + FixedSizeBinary(n) => assert_eq!(n, 12), _ => panic!("Expected FixedSizeBinary(12)"), } } @@ -1054,7 +1048,7 @@ mod tests { fn test_utc_timestamp_millis() { let arrow_field = Field::new( "utc_ts_ms", - DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), + Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -1070,7 +1064,7 @@ mod tests { fn test_utc_timestamp_micros() { let arrow_field = Field::new( "utc_ts_us", - DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), + Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -1086,7 +1080,7 @@ mod tests { fn test_local_timestamp_millis() { let arrow_field = Field::new( "local_ts_ms", - DataType::Timestamp(TimeUnit::Millisecond, None), + Timestamp(TimeUnit::Millisecond, None), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -1102,7 +1096,7 @@ mod tests { fn test_local_timestamp_micros() { let arrow_field = Field::new( "local_ts_us", - DataType::Timestamp(TimeUnit::Microsecond, None), + Timestamp(TimeUnit::Microsecond, None), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 4c57a3426bd6..a29f293107c0 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -23,7 
+23,7 @@ use crate::schema::*; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::*; -use arrow_schema::{ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, TimeUnit}; +use arrow_schema::{ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION}; use std::collections::HashMap; use std::io::Read; use std::ptr::null; @@ -76,6 +76,7 @@ impl RecordDecoder { } } +/// Enum representing different decoders for various data types. #[derive(Debug)] enum Decoder { Null(usize), @@ -95,51 +96,59 @@ enum Decoder { Record(Fields, Vec), Nullable(Nullability, NullBufferBuilder, Box), Enum(Vec, Vec), - Map(FieldRef, OffsetBufferBuilder, OffsetBufferBuilder, Vec, Box, usize), - Decimal(usize, usize, Option, Vec>), + Map( + FieldRef, + OffsetBufferBuilder, // key_offsets + OffsetBufferBuilder, // map_offsets + Vec, // key_data + Box, // values_decoder_inner + usize, // current_entry_count + ), + Decimal(usize, Option, Option, DecimalBuilder), } impl Decoder { - /// Checks if the Decoder is nullable + /// Checks if the Decoder is nullable. fn is_nullable(&self) -> bool { matches!(self, Decoder::Nullable(_, _, _)) } + /// Creates a new `Decoder` based on the provided `AvroDataType`. 
fn try_new(data_type: &AvroDataType) -> Result { let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); let decoder = match data_type.codec() { - Codec::Null => Self::Null(0), - Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), - Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Binary => Self::Binary( + Codec::Null => Decoder::Null(0), + Codec::Boolean => Decoder::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), + Codec::Int32 => Decoder::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Int64 => Decoder::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Float32 => Decoder::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Float64 => Decoder::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Binary => Decoder::Binary( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Utf8 => Self::String( + Codec::Utf8 => Decoder::String( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Date32 => Decoder::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMillis => Decoder::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMicros => Decoder::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), Codec::TimestampMillis(is_utc) => { - Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + Decoder::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } Codec::TimestampMicros(is_utc) => { - Self::TimestampMicros(*is_utc, 
Vec::with_capacity(DEFAULT_CAPACITY)) + Decoder::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } Codec::Fixed(_) => return nyi("decoding fixed"), Codec::Interval => return nyi("decoding interval"), Codec::List(item) => { - let decoder = Self::try_new(item)?; - Self::List( + let decoder = Box::new(Self::try_new(item)?); + Decoder::List( Arc::new(item.field_with_name("item")), OffsetBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), + decoder, ) } Codec::Struct(fields) => { @@ -150,14 +159,9 @@ impl Decoder { arrow_fields.push(avro_field.field()); encodings.push(encoding); } - Self::Record(arrow_fields.into(), encodings) - } - Codec::Enum(symbols) => { - Decoder::Enum( - symbols.clone(), - Vec::with_capacity(DEFAULT_CAPACITY), - ) + Decoder::Record(arrow_fields.into(), encodings) } + Codec::Enum(symbols) => Decoder::Enum(symbols.clone(), Vec::with_capacity(DEFAULT_CAPACITY)), Codec::Map(value_type) => { let map_field = Arc::new(ArrowField::new( "entries", @@ -170,54 +174,58 @@ impl Decoder { Decoder::Map( map_field, OffsetBufferBuilder::new(DEFAULT_CAPACITY), // key_offsets - OffsetBufferBuilder::new(DEFAULT_CAPACITY), // map_offsets - Vec::with_capacity(DEFAULT_CAPACITY), // key_data - Box::new(Self::try_new(value_type)?), // values_decoder_inner + OffsetBufferBuilder::new(DEFAULT_CAPACITY), // map_offsets + Vec::with_capacity(DEFAULT_CAPACITY), // key_data + Box::new(Self::try_new(value_type)?), // values_decoder_inner 0, // current_entry_count ) } Codec::Decimal(precision, scale, size) => { - Decoder::Decimal( - *precision, - scale.unwrap_or(0), - *size, - Vec::with_capacity(DEFAULT_CAPACITY), - ) + let builder = DecimalBuilder::new(*precision, *scale, *size)?; + Decoder::Decimal(*precision, *scale, *size, builder) } }; - Ok(match data_type.nullability() { - Some(nullability) => Self::Nullable( + + // Wrap the decoder in Nullable if necessary + match data_type.nullability() { + Some(nullability) => Ok(Decoder::Nullable( nullability, 
NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(decoder), - ), - None => decoder, - }) + )), + None => Ok(decoder), + } } - /// Append a null record + /// Appends a null value to the decoder. fn append_null(&mut self) { match self { - Self::Null(count) => *count += 1, - Self::Boolean(b) => b.append(false), - Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0), - Self::Int64(v) - | Self::TimeMicros(v) - | Self::TimestampMillis(_, v) - | Self::TimestampMicros(_, v) => v.push(0), - Self::Float32(v) => v.push(0.), - Self::Float64(v) => v.push(0.), - Self::Binary(offsets, _) | Self::String(offsets, _) => offsets.push_length(0), - Self::List(_, offsets, e) => { + Decoder::Null(count) => *count += 1, + Decoder::Boolean(b) => b.append(false), + Decoder::Int32(v) | Decoder::Date32(v) | Decoder::TimeMillis(v) => v.push(0), + Decoder::Int64(v) + | Decoder::TimeMicros(v) + | Decoder::TimestampMillis(_, v) + | Decoder::TimestampMicros(_, v) => v.push(0), + Decoder::Float32(v) => v.push(0.0), + Decoder::Float64(v) => v.push(0.0), + Decoder::Binary(offsets, _) | Decoder::String(offsets, _) => { + offsets.push_length(0); + } + Decoder::List(_, offsets, e) => { offsets.push_length(0); e.append_null(); } - Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()), - Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"), - Self::Enum(_, _) => { - // For Enum, appending a null is not straightforward. Handle accordingly if needed. 
+ Decoder::Record(_, encodings) => { + for encoding in encodings.iter_mut() { + encoding.append_null(); + } + } + Decoder::Enum(_, indices) => { + // Append a placeholder index for null entries + indices.push(0); } - Self::Map( + Decoder::Map( _, key_offsets, map_offsets_builder, @@ -228,55 +236,60 @@ impl Decoder { key_offsets.push_length(0); map_offsets_builder.push_length(*current_entry_count); } - Self::Decimal(_, _, _, _) => { - // For Decimal, appending a null doesn't make sense as per current implementation + Decoder::Decimal(_, _, _, builder) => { + builder.append_null(); } + Decoder::Nullable(_, _, _) => { /* Nulls are handled by the Nullable variant */ } } } - /// Decode a single record from `buf` + /// Decodes a single record from the provided buffer `buf`. fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { - Self::Null(x) => *x += 1, - Self::Boolean(values) => values.append(buf.get_bool()?), - Self::Int32(values) | Self::Date32(values) | Self::TimeMillis(values) => { - values.push(buf.get_int()?) 
- } - Self::Int64(values) - | Self::TimeMicros(values) - | Self::TimestampMillis(_, values) - | Self::TimestampMicros(_, values) => values.push(buf.get_long()?), - Self::Float32(values) => values.push(buf.get_float()?), - Self::Float64(values) => values.push(buf.get_double()?), - Self::Binary(offsets, values) | Self::String(offsets, values) => { + Decoder::Null(x) => *x += 1, + Decoder::Boolean(values) => values.append(buf.get_bool()?), + Decoder::Int32(values) => values.push(buf.get_int()?), + Decoder::Date32(values) => values.push(buf.get_int()?), + Decoder::Int64(values) => values.push(buf.get_long()?), + Decoder::TimeMillis(values) => values.push(buf.get_int()?), + Decoder::TimeMicros(values) => values.push(buf.get_long()?), + Decoder::TimestampMillis(is_utc, values) => { + values.push(buf.get_long()?); + } + Decoder::TimestampMicros(is_utc, values) => { + values.push(buf.get_long()?); + } + Decoder::Float32(values) => values.push(buf.get_float()?), + Decoder::Float64(values) => values.push(buf.get_double()?), + Decoder::Binary(offsets, values) | Decoder::String(offsets, values) => { let data = buf.get_bytes()?; offsets.push_length(data.len()); values.extend_from_slice(data); } - Self::List(_, _, _) => { + Decoder::List(_, _, _) => { return Err(ArrowError::NotYetImplemented( "Decoding ListArray".to_string(), - )) + )); } - Self::Record(_, encodings) => { - for encoding in encodings { + Decoder::Record(fields, encodings) => { + for encoding in encodings.iter_mut() { encoding.decode(buf)?; } } - Self::Nullable(nullability, nulls, e) => { - let is_valid = buf.get_bool()? 
== matches!(nullability, Nullability::NullFirst); + Decoder::Nullable(_, nulls, e) => { + let is_valid = buf.get_bool()?; nulls.append(is_valid); match is_valid { true => e.decode(buf)?, false => e.append_null(), } } - Self::Enum(symbols, indices) => { - // Encodes enum by writing its zero-based index as an int + Decoder::Enum(symbols, indices) => { + // Enums are encoded as zero-based indices using zigzag encoding let index = buf.get_int()?; indices.push(index); } - Self::Map( + Decoder::Map( field, key_offsets, map_offsets_builder, @@ -301,83 +314,79 @@ impl Decoder { map_offsets_builder.push_length(*current_entry_count); } } - Self::Decimal( - precision, - scale, - size, - data - ) => { - let raw = if let Some(fixed_len) = size { - // get_fixed used to get exactly fixed_len bytes - buf.get_fixed(*fixed_len)? + Decoder::Decimal(_precision, _scale, _size, builder) => { + if let Some(size) = _size { + // Fixed-size decimal + let raw = buf.get_fixed(*size)?; + builder.append_bytes(raw)?; } else { - // get_bytes used for variable-length - buf.get_bytes()? - }; - data.push(raw.to_vec()); + // Variable-size decimal + let bytes = buf.get_bytes()?; + builder.append_bytes(bytes)?; + } } } Ok(()) } - /// Flush decoded records to an [`ArrayRef`] + /// Flushes decoded records to an [`ArrayRef`]. 
fn flush(&mut self, nulls: Option) -> Result { match self { - Self::Nullable(_, n, e) => e.flush(n.finish()), - Self::Null(size) => Ok(Arc::new(NullArray::new(std::mem::replace(size, 0)))), - Self::Boolean(b) => Ok(Arc::new(BooleanArray::new(b.finish(), nulls))), - Self::Int32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Self::Date32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Self::Int64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Self::TimeMillis(values) => { + Decoder::Nullable(_, n, e) => e.flush(n.finish()), + Decoder::Null(size) => Ok(Arc::new(NullArray::new(std::mem::replace(size, 0)))), + Decoder::Boolean(b) => Ok(Arc::new(BooleanArray::new(b.finish(), nulls))), + Decoder::Int32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Decoder::Date32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Decoder::Int64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Decoder::TimeMillis(values) => { Ok(Arc::new(flush_primitive::(values, nulls))) } - Self::TimeMicros(values) => { + Decoder::TimeMicros(values) => { Ok(Arc::new(flush_primitive::(values, nulls))) } - Self::TimestampMillis(is_utc, values) => Ok(Arc::new( + Decoder::TimestampMillis(is_utc, values) => Ok(Arc::new( flush_primitive::(values, nulls) - .with_timezone_opt(is_utc.then(|| "+00:00")), + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())), )), - Self::TimestampMicros(is_utc, values) => Ok(Arc::new( + Decoder::TimestampMicros(is_utc, values) => Ok(Arc::new( flush_primitive::(values, nulls) - .with_timezone_opt(is_utc.then(|| "+00:00")), + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())), )), - Self::Float32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Self::Float64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Self::Binary(offsets, values) => { + Decoder::Float32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Decoder::Float64(values) => 
Ok(Arc::new(flush_primitive::(values, nulls))), + Decoder::Binary(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values).into(); Ok(Arc::new(BinaryArray::new(offsets, values, nulls))) } - Self::String(offsets, values) => { + Decoder::String(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values).into(); Ok(Arc::new(StringArray::new(offsets, values, nulls))) } - Self::List(field, offsets, values) => { + Decoder::List(field, offsets, values) => { let values = values.flush(None)?; let offsets = flush_offsets(offsets); Ok(Arc::new(ListArray::new(field.clone(), offsets, values, nulls))) } - Self::Record(fields, encodings) => { + Decoder::Record(fields, encodings) => { let arrays = encodings .iter_mut() .map(|x| x.flush(None)) .collect::, _>>()?; Ok(Arc::new(StructArray::new(fields.clone(), arrays, nulls))) } - Self::Enum(symbols, indices) => { + Decoder::Enum(symbols, indices) => { let dict_values = StringArray::from_iter_values(symbols.iter()); - let flushed_indices = flush_values(indices); // Vec let indices_array: Int32Array = match nulls { Some(buf) => { - let buffer = Buffer::from_slice_ref(&flushed_indices); - PrimitiveArray::::try_new(ScalarBuffer::from(buffer), Some(buf.clone()))? - }, - None => { - Int32Array::from_iter_values(flushed_indices) + let buffer = arrow_buffer::Buffer::from_slice_ref(&indices); + PrimitiveArray::::try_new( + arrow_buffer::ScalarBuffer::from(buffer), + Some(buf.clone()), + )? 
} + None => Int32Array::from_iter_values(indices.iter().cloned()), }; let dict_array = DictionaryArray::::try_new( indices_array, @@ -385,7 +394,7 @@ impl Decoder { )?; Ok(Arc::new(dict_array)) } - Self::Map( + Decoder::Map( field, key_offsets_builder, map_offsets_builder, @@ -401,36 +410,37 @@ impl Decoder { let is_nullable = matches!(**values_decoder_inner, Decoder::Nullable(_, _, _)); let struct_fields = vec![ Arc::new(ArrowField::new("key", DataType::Utf8, false)), - Arc::new(ArrowField::new("value", val_array.data_type().clone(), is_nullable)), + Arc::new(ArrowField::new( + "value", + val_array.data_type().clone(), + is_nullable, + )), ]; let struct_array = StructArray::new( Fields::from(struct_fields), vec![Arc::new(key_array), val_array], None, ); - let map_array = MapArray::new(field.clone(), map_offsets.clone(), struct_array.clone(), nulls, false); + let map_array = MapArray::new( + field.clone(), + map_offsets.clone(), + struct_array.clone(), + nulls, + false, + ); Ok(Arc::new(map_array)) } - Self::Decimal( - precision, - scale, - size, - data, - ) => { - let mut array_builder = DecimalBuilder::new(*precision, *scale, *size)?; - for raw in data.drain(..) { - if let Some(s) = size { - if raw.len() < *s { - let extended = sign_extend(&raw, *s); - array_builder.append_bytes(&extended)?; - continue; - } - } - array_builder.append_bytes(&raw)?; - } - let arr = array_builder.finish()?; - Ok(Arc::new(arr)) + Decoder::Decimal(_precision, _scale, _size, builder) => { + let precision = *_precision; + let scale = _scale.unwrap_or(0); // Default scale if None + let size = _size.clone(); + let builder = std::mem::replace( + builder, + DecimalBuilder::new(precision, *_scale, *_size)?, + ); + Ok(builder.finish(nulls, precision, scale)?) 
// Pass precision and scale } + } } } @@ -440,22 +450,23 @@ fn field_with_type(name: &str, dt: DataType, nullable: bool) -> FieldRef { Arc::new(ArrowField::new(name, dt, nullable)) } +/// Extends raw bytes to the target length with sign extension. fn sign_extend(raw: &[u8], target_len: usize) -> Vec { if raw.is_empty() { return vec![0; target_len]; } let sign_bit = raw[0] & 0x80; let mut extended = Vec::with_capacity(target_len); - if sign_bit != 0 { // negative + if sign_bit != 0 { extended.resize(target_len - raw.len(), 0xFF); - } else { // positive + } else { extended.resize(target_len - raw.len(), 0x00); } extended.extend_from_slice(raw); extended } -/// Extend raw bytes to 16 bytes (for Decimal128) +/// Extends raw bytes to 16 bytes (for Decimal128). fn extend_to_16_bytes(raw: &[u8]) -> Result<[u8; 16], ArrowError> { let extended = sign_extend(raw, 16); if extended.len() != 16 { @@ -464,10 +475,12 @@ fn extend_to_16_bytes(raw: &[u8]) -> Result<[u8; 16], ArrowError> { extended.len() ))); } - Ok(extended.try_into().unwrap()) + let mut arr = [0u8; 16]; + arr.copy_from_slice(&extended); + Ok(arr) } -/// Extend raw bytes to 32 bytes (for Decimal256) +/// Extends raw bytes to 32 bytes (for Decimal256). fn extend_to_32_bytes(raw: &[u8]) -> Result<[u8; 32], ArrowError> { let extended = sign_extend(raw, 32); if extended.len() != 32 { @@ -476,30 +489,67 @@ fn extend_to_32_bytes(raw: &[u8]) -> Result<[u8; 32], ArrowError> { extended.len() ))); } - Ok(extended.try_into().unwrap()) + let mut arr = [0u8; 32]; + arr.copy_from_slice(&extended); + Ok(arr) } -/// Trait for building decimal arrays +/// Enum representing the builder for Decimal arrays. +#[derive(Debug)] enum DecimalBuilder { Decimal128(Decimal128Builder), Decimal256(Decimal256Builder), } impl DecimalBuilder { - - fn new(precision: usize, scale: usize, size: Option) -> Result { + /// Initializes a new `DecimalBuilder` based on precision, scale, and size. 
+ fn new( + precision: usize, + scale: Option, + size: Option, + ) -> Result { match size { - Some(s) if s > 16 => { - // decimal256 - Ok(Self::Decimal256(Decimal256Builder::new().with_precision_and_scale(precision as u8, scale as i8)?)) + Some(s) if s > 16 && s <= 32 => { + // Decimal256 + Ok(Self::Decimal256( + Decimal256Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )) } - _ => { - // decimal128 - Ok(Self::Decimal128(Decimal128Builder::new().with_precision_and_scale(precision as u8, scale as i8)?)) + Some(s) if s <= 16 => { + // Decimal128 + Ok(Self::Decimal128( + Decimal128Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )) } + None => { + // Infer based on precision + if precision <= DECIMAL128_MAX_PRECISION as usize { + Ok(Self::Decimal128( + Decimal128Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )) + } else if precision <= DECIMAL256_MAX_PRECISION as usize { + Ok(Self::Decimal256( + Decimal256Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )) + } else { + Err(ArrowError::ParseError(format!( + "Decimal precision {} exceeds maximum supported", + precision + ))) + } + } + _ => Err(ArrowError::ParseError(format!( + "Unsupported decimal size: {:?}", + size + ))), } } + /// Appends bytes to the decimal builder. fn append_bytes(&mut self, bytes: &[u8]) -> Result<(), ArrowError> { match self { DecimalBuilder::Decimal128(b) => { @@ -516,10 +566,46 @@ impl DecimalBuilder { Ok(()) } - fn finish(self) -> Result { + /// Appends a null value to the decimal builder by appending placeholder bytes. 
+ fn append_null(&mut self) -> Result<(), ArrowError> { match self { - DecimalBuilder::Decimal128(mut b) => Ok(Arc::new(b.finish())), - DecimalBuilder::Decimal256(mut b) => Ok(Arc::new(b.finish())), + DecimalBuilder::Decimal128(b) => { + // Append zeroed bytes as placeholder + let placeholder = [0u8; 16]; + let value = i128::from_be_bytes(placeholder); + b.append_value(value); + } + DecimalBuilder::Decimal256(b) => { + // Append zeroed bytes as placeholder + let placeholder = [0u8; 32]; + let value = i256::from_be_bytes(placeholder); + b.append_value(value); + } + } + Ok(()) + } + + /// Finalizes the decimal array and returns it as an `ArrayRef`. + fn finish(self, nulls: Option, precision: usize, scale: usize) -> Result { + match self { + DecimalBuilder::Decimal128(mut b) => { + let array = b.finish(); + let values = array.values().clone(); + let decimal_array = Decimal128Array::new( + values, + nulls, + ).with_precision_and_scale(precision as u8, scale as i8)?; + Ok(Arc::new(decimal_array)) + } + DecimalBuilder::Decimal256(mut b) => { + let array = b.finish(); + let values = array.values().clone(); + let decimal_array = Decimal256Array::new( + values, + nulls, + ).with_precision_and_scale(precision as u8, scale as i8)?; + Ok(Arc::new(decimal_array)) + } } } } @@ -551,10 +637,12 @@ mod tests { Array, ArrayRef, Int32Array, MapArray, StringArray, StructArray, Decimal128Array, Decimal256Array, DictionaryArray, }; - use arrow_array::cast::AsArray; + use arrow_buffer::Buffer; use arrow_schema::{Field as ArrowField, DataType as ArrowDataType}; + use serde_json::json; + use arrow_array::cast::AsArray; - /// Helper functions for encoding test data + /// Helper functions for encoding test data. 
fn encode_avro_int(value: i32) -> Vec { let mut buf = Vec::new(); let mut v = (value << 1) ^ (value >> 31); @@ -615,22 +703,23 @@ mod tests { let value_type = AvroDataType::from_codec(Codec::Utf8); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); + // Encode a single map with one entry: {"hello": "world"} // Avro encoding for a map: - // - block_count: 1 (number of entries) - // - keys: "hello" (5 bytes) - // - values: "world" (5 bytes) + // - block_count: 1 (encoded as [2] due to ZigZag) + // - keys: "hello" (encoded with length prefix) + // - values: "world" (encoded with length prefix) let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_long(1)); // block_count = 1 + data.extend_from_slice(&encode_avro_long(1)); // block_count = 1 data.extend_from_slice(&encode_avro_bytes(b"hello")); // key = "hello" data.extend_from_slice(&encode_avro_bytes(b"world")); // value = "world" decoder.decode(&mut AvroCursor::new(&data)).unwrap(); let array = decoder.flush(None).unwrap(); let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); // Verify 1 map - assert_eq!(map_arr.value_length(0), 1); // Verify 1 entry in the map + assert_eq!(map_arr.len(), 1); // One map + assert_eq!(map_arr.value_length(0), 1); // One entry in the map let entries = map_arr.value(0); let struct_entries = entries.as_any().downcast_ref::().unwrap(); - assert_eq!(struct_entries.len(), 1); // Verify 1 entry in StructArray + assert_eq!(struct_entries.len(), 1); // One entry in StructArray let key = struct_entries .column_by_name("key") .unwrap() @@ -643,7 +732,7 @@ mod tests { .as_any() .downcast_ref::() .unwrap(); - assert_eq!(key.value(0), "hello"); // Verify Key + assert_eq!(key.value(0), "hello"); // Verify Key assert_eq!(value.value(0), "world"); // Verify Value } @@ -652,17 +741,18 @@ mod tests { let value_type = AvroDataType::from_codec(Codec::Utf8); let map_type = 
AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); + // Encode an empty map // Avro encoding for an empty map: - // - block_count: 0 (no entries) + // - block_count: 0 (encoded as [0] due to ZigZag) let data = encode_avro_long(0); // block_count = 0 decoder.decode(&mut AvroCursor::new(&data)).unwrap(); let array = decoder.flush(None).unwrap(); let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); // Verify 1 map - assert_eq!(map_arr.value_length(0), 0); // Verify 0 entries in the map + assert_eq!(map_arr.len(), 1); // One map + assert_eq!(map_arr.value_length(0), 0); // Zero entries in the map let entries = map_arr.value(0); let struct_entries = entries.as_any().downcast_ref::().unwrap(); - assert_eq!(struct_entries.len(), 0); // // Verify 0 entries StructArray + assert_eq!(struct_entries.len(), 0); // Zero entries in StructArray let key = struct_entries .column_by_name("key") .unwrap() @@ -710,32 +800,197 @@ mod tests { } #[test] - fn test_decimal_decoding_bytes() { + fn test_decimal_decoding_bytes_with_nulls() { let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None)); let mut decoder = Decoder::try_new(&dt).unwrap(); - let unscaled_row1: i128 = 1234; // 123.4 - let unscaled_row2: i128 = -1234; // -123.4 - // Note: convert unscaled values to big-endian bytes - let bytes_row1 = unscaled_row1.to_be_bytes(); - let bytes_row2 = unscaled_row2.to_be_bytes(); - // Row1: 1234 => 0x04D2 (2 bytes) - // Row2: -1234 => two's complement of 0x04D2 = 0xFB2E (2 bytes) - let row1_bytes = &bytes_row1[14..16]; // Last 2 bytes - let row2_bytes = &bytes_row2[14..16]; // Last 2 bytes + // Wrap the decimal in a Nullable decoder + let mut nullable_decoder = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(decoder), + ); + // Row1: 123.4 => unscaled: 1234 => bytes: [0x04, 0xD2] + // Row2: null + // Row3: -123.4 => unscaled: -1234 
=> bytes: [0xFB, 0x2E] let mut data = Vec::new(); - // Encode row1 - data.extend_from_slice(&encode_avro_long(2)); // Length=2 - data.extend_from_slice(row1_bytes); // 0x04, 0xD2 - // Encode row2 - data.extend_from_slice(&encode_avro_long(2)); // Length=2 - data.extend_from_slice(row2_bytes); // 0xFB, 0x2E + // Row1: valid + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2])); // 0x04D2 = 1234 + // Row2: null + data.extend_from_slice(&[0u8]); // is_valid = false + // Row3: valid + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); // 0xFB2E = -1234 let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); - decoder.decode(&mut cursor).unwrap(); - let arr = decoder.flush(None).unwrap(); - let dec_arr = arr.as_any().downcast_ref::().unwrap(); - assert_eq!(dec_arr.len(), 2); + nullable_decoder.decode(&mut cursor).unwrap(); // Row1: 123.4 + nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null + nullable_decoder.decode(&mut cursor).unwrap(); // Row3: -123.4 + let array = nullable_decoder.flush(None).unwrap(); + let dec_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); assert_eq!(dec_arr.value_as_string(0), "123.4"); - assert_eq!(dec_arr.value_as_string(1), "-123.4"); + assert_eq!(dec_arr.value_as_string(2), "-123.4"); } -} + + #[test] + fn test_decimal_decoding_bytes_with_nulls_fixed_size() { + let dt = AvroDataType::from_codec(Codec::Decimal(6, Some(2), Some(16))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + // Wrap the decimal in a Nullable decoder + let mut nullable_decoder = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(decoder), + ); + // Correct Byte Encoding: + // Row1: 1234.56 => unscaled: 123456 => bytes: [0x00; 12] + [0x00, 
0x01, 0xE2, 0x40] + // Row2: null + // Row3: -1234.56 => unscaled: -123456 => bytes: [0xFF; 12] + [0xFE, 0x1D, 0xC0, 0x00] + let row1_bytes = &[ + 0x00, 0x00, 0x00, 0x00, // First 4 bytes + 0x00, 0x00, 0x00, 0x00, // Next 4 bytes + 0x00, 0x00, 0x00, 0x01, // Next 4 bytes + 0xE2, 0x40, 0x00, 0x00, // Last 4 bytes + ]; + let row3_bytes = &[ + 0xFF, 0xFF, 0xFF, 0xFF, // First 4 bytes (two's complement) + 0xFF, 0xFF, 0xFF, 0xFF, // Next 4 bytes + 0xFF, 0xFF, 0xFE, 0x1D, // Next 4 bytes + 0xC0, 0x00, 0x00, 0x00, // Last 4 bytes + ]; + + let mut data = Vec::new(); + // Row1: valid + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(row1_bytes); // 1234.56 + // Row2: null + data.extend_from_slice(&[0u8]); // is_valid = false + // Row3: valid + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(row3_bytes); // -1234.56 + + let mut cursor = AvroCursor::new(&data); + nullable_decoder.decode(&mut cursor).unwrap(); // Row1: 1234.56 + nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null + nullable_decoder.decode(&mut cursor).unwrap(); // Row3: -1234.56 + + let array = nullable_decoder.flush(None).unwrap(); + let dec_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "1234.56"); + assert_eq!(dec_arr.value_as_string(2), "-1234.56"); + } + + #[test] + fn test_enum_decoding_with_nulls() { + let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; + let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); + let mut decoder = Decoder::try_new(&enum_dt).unwrap(); + + // Wrap the enum in a Nullable decoder + let mut nullable_decoder = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(decoder), + ); + + // Encode the indices [1, null, 2] using ZigZag encoding + // Indices: 1 -> 
[2], null -> no index, 2 -> [4] + let mut data = Vec::new(); + // Row1: valid (1) + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(&encode_avro_int(1)); // Encodes to [2] + // Row2: null + data.extend_from_slice(&[0u8]); // is_valid = false + // Row3: valid (2) + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(&encode_avro_int(2)); // Encodes to [4] + + let mut cursor = AvroCursor::new(&data); + nullable_decoder.decode(&mut cursor).unwrap(); // Row1: GREEN (index 1) + nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null + nullable_decoder.decode(&mut cursor).unwrap(); // Row3: BLUE (index 2) + + let array = nullable_decoder.flush(None).unwrap(); + let dict_arr = array.as_any().downcast_ref::>().unwrap(); + + assert_eq!(dict_arr.len(), 3); + let keys = dict_arr.keys(); + let validity = dict_arr.is_valid(0); // Validity (bool) of the first entry, not the null buffer itself + + assert_eq!(keys.value(0), 1); + assert_eq!(keys.value(1), 0); // Placeholder index for null + assert_eq!(keys.value(2), 2); + + assert!(dict_arr.is_valid(0)); + assert!(!dict_arr.is_valid(1)); // Ensure the second entry is null + assert!(dict_arr.is_valid(2)); + + let dict_values = dict_arr.values().as_string::(); + assert_eq!(dict_values.value(0), "RED"); + assert_eq!(dict_values.value(1), "GREEN"); + assert_eq!(dict_values.value(2), "BLUE"); + } + + #[test] + fn test_enum_with_nullable_entries() { + let symbols = vec!["APPLE".to_string(), "BANANA".to_string(), "CHERRY".to_string()]; + let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); + let mut decoder = Decoder::try_new(&enum_dt).unwrap(); + + // Wrap the enum in a Nullable decoder + let mut nullable_decoder = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(decoder), + ); + + // Encode the indices [0, null, 2, 1] using ZigZag encoding + let mut data = Vec::new(); + // Row1: valid (0) -> "APPLE" + data.extend_from_slice(&[1u8]); // is_valid = true + 
data.extend_from_slice(&encode_avro_int(0)); // Encodes to [0] + // Row2: null + data.extend_from_slice(&[0u8]); // is_valid = false + // Row3: valid (2) -> "CHERRY" + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(&encode_avro_int(2)); // Encodes to [4] + // Row4: valid (1) -> "BANANA" + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(&encode_avro_int(1)); // Encodes to [2] + + let mut cursor = AvroCursor::new(&data); + nullable_decoder.decode(&mut cursor).unwrap(); // Row1: APPLE + nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null + nullable_decoder.decode(&mut cursor).unwrap(); // Row3: CHERRY + nullable_decoder.decode(&mut cursor).unwrap(); // Row4: BANANA + + let array = nullable_decoder.flush(None).unwrap(); + let dict_arr = array.as_any().downcast_ref::>().unwrap(); + + assert_eq!(dict_arr.len(), 4); + let keys = dict_arr.keys(); + let validity = dict_arr.is_valid(0); // Correctly access the null buffer + + assert_eq!(keys.value(0), 0); + assert_eq!(keys.value(1), 0); // Placeholder index for null + assert_eq!(keys.value(2), 2); + assert_eq!(keys.value(3), 1); + + assert!(dict_arr.is_valid(0)); + assert!(!dict_arr.is_valid(1)); // Ensure the second entry is null + assert!(dict_arr.is_valid(2)); + assert!(dict_arr.is_valid(3)); + + let dict_values = dict_arr.values().as_string::(); + assert_eq!(dict_values.value(0), "APPLE"); + assert_eq!(dict_values.value(1), "BANANA"); + assert_eq!(dict_values.value(2), "CHERRY"); + } +} \ No newline at end of file From 9cfda09a48a73b0c3ae8fa540bb552eaf02da8b1 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Tue, 31 Dec 2024 15:05:40 -0600 Subject: [PATCH 6/8] * Reader decoder Support for nullable types. * Implemented reader decoder for Avro Lists * Cleaned up reader/record.rs and added comments for readability. 
Signed-off-by: Connor Sanders --- arrow-avro/src/reader/record.rs | 1240 +++++++++++++++++-------------- 1 file changed, 691 insertions(+), 549 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index a29f293107c0..500fe27fd53b 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -22,38 +22,48 @@ use crate::reader::header::Header; use crate::schema::*; use arrow_array::types::*; use arrow_array::*; +use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; use arrow_buffer::*; -use arrow_schema::{ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION}; +use arrow_schema::{ + ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, + TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, +}; use std::collections::HashMap; use std::io::Read; -use std::ptr::null; use std::sync::Arc; -use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; -/// Decodes avro encoded data into [`RecordBatch`] +/// The default capacity used for internal buffers +const DEFAULT_CAPACITY: usize = 1024; + +/// A decoder that converts Avro-encoded data into an Arrow [`RecordBatch`]. pub struct RecordDecoder { schema: SchemaRef, fields: Vec, } impl RecordDecoder { + /// Create a new [`RecordDecoder`] from an [`AvroDataType`] expected to be a `Record`. pub fn try_new(data_type: &AvroDataType) -> Result { match Decoder::try_new(data_type)? { Decoder::Record(fields, encodings) => Ok(Self { schema: Arc::new(ArrowSchema::new(fields)), fields: encodings, }), - encoding => Err(ArrowError::ParseError(format!( - "Expected record got {encoding:?}" + other => Err(ArrowError::ParseError(format!( + "Expected record got {other:?}" ))), } } + /// Return the [`SchemaRef`] describing the Arrow schema of rows produced by this decoder. 
pub fn schema(&self) -> &SchemaRef { &self.schema } - /// Decode `count` records from `buf` + /// Decode `count` Avro records from `buf`. + /// + /// This accumulates data in internal buffers. Once done reading, call + /// [`Self::flush`] to yield an Arrow [`RecordBatch`]. pub fn decode(&mut self, buf: &[u8], count: usize) -> Result { let mut cursor = AvroCursor::new(buf); for _ in 0..count { @@ -64,7 +74,7 @@ impl RecordDecoder { Ok(cursor.position()) } - /// Flush the decoded records into a [`RecordBatch`] + /// Flush the accumulated data into a [`RecordBatch`], clearing internal state. pub fn flush(&mut self) -> Result { let arrays = self .fields @@ -76,47 +86,78 @@ impl RecordDecoder { } } -/// Enum representing different decoders for various data types. +/// Decoder for Avro data of various shapes. +/// +/// This is the “internal” representation used by [`RecordDecoder`]. #[derive(Debug)] enum Decoder { + /// Avro `null` Null(usize), + /// Avro `boolean` Boolean(BooleanBufferBuilder), + /// Avro `int` => i32 Int32(Vec), + /// Avro `long` => i64 Int64(Vec), + /// Avro `float` => f32 Float32(Vec), + /// Avro `double` => f64 Float64(Vec), + /// Avro `date` => Date32 Date32(Vec), + /// Avro `time-millis` => Time32(Millisecond) TimeMillis(Vec), + /// Avro `time-micros` => Time64(Microsecond) TimeMicros(Vec), + /// Avro `timestamp-millis` (bool = UTC?) TimestampMillis(bool, Vec), + /// Avro `timestamp-micros` (bool = UTC?) 
TimestampMicros(bool, Vec), + /// Avro `bytes` => Arrow Binary Binary(OffsetBufferBuilder, Vec), + /// Avro `string` => Arrow String String(OffsetBufferBuilder, Vec), + /// Avro `array` + /// * `FieldRef` is the arrow field for the list + /// * `OffsetBufferBuilder` holds offsets into the child array + /// * The boxed `Decoder` decodes T itself List(FieldRef, OffsetBufferBuilder, Box), + /// Avro `record` + /// * `Fields` is the Arrow schema of the record + /// * The `Vec` is one decoder per child field Record(Fields, Vec), + /// Avro union that includes `null` => decodes as a single arrow field + a null bit mask Nullable(Nullability, NullBufferBuilder, Box), + /// Avro `enum` => Dictionary(int32 -> string) Enum(Vec, Vec), + /// Avro `map` + /// * The `FieldRef` is the arrow field for the map + /// * `key_offsets`, `map_offsets`: offset builders + /// * `key_data` accumulates the raw UTF8 for keys + /// * `values_decoder_inner` decodes the map’s value type + /// * `current_entry_count` how many (key,value) pairs total seen so far Map( FieldRef, - OffsetBufferBuilder, // key_offsets - OffsetBufferBuilder, // map_offsets - Vec, // key_data - Box, // values_decoder_inner - usize, // current_entry_count + OffsetBufferBuilder, + OffsetBufferBuilder, + Vec, + Box, + usize, ), + /// Avro decimal => Arrow decimal + /// (precision, scale, size, builder) Decimal(usize, Option, Option, DecimalBuilder), } impl Decoder { - /// Checks if the Decoder is nullable. + /// Checks if the Decoder is nullable, i.e. wrapped in [`Decoder::Nullable`]. fn is_nullable(&self) -> bool { matches!(self, Decoder::Nullable(_, _, _)) } - /// Creates a new `Decoder` based on the provided `AvroDataType`. + /// Create a `Decoder` from an [`AvroDataType`]. 
fn try_new(data_type: &AvroDataType) -> Result { - let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); - + let not_implemented = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); let decoder = match data_type.codec() { Codec::Null => Decoder::Null(0), Codec::Boolean => Decoder::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), @@ -141,25 +182,25 @@ impl Decoder { Codec::TimestampMicros(is_utc) => { Decoder::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::Fixed(_) => return nyi("decoding fixed"), - Codec::Interval => return nyi("decoding interval"), + Codec::Fixed(_) => return not_implemented("decoding Avro fixed-typed data"), + Codec::Interval => return not_implemented("decoding Avro interval"), Codec::List(item) => { - let decoder = Box::new(Self::try_new(item)?); + let item_decoder = Box::new(Self::try_new(item)?); Decoder::List( Arc::new(item.field_with_name("item")), OffsetBufferBuilder::new(DEFAULT_CAPACITY), - decoder, + item_decoder, ) } Codec::Struct(fields) => { let mut arrow_fields = Vec::with_capacity(fields.len()); - let mut encodings = Vec::with_capacity(fields.len()); + let mut decoders = Vec::with_capacity(fields.len()); for avro_field in fields.iter() { - let encoding = Self::try_new(avro_field.data_type())?; + let d = Self::try_new(avro_field.data_type())?; arrow_fields.push(avro_field.field()); - encodings.push(encoding); + decoders.push(d); } - Decoder::Record(arrow_fields.into(), encodings) + Decoder::Record(arrow_fields.into(), decoders) } Codec::Enum(symbols) => Decoder::Enum(symbols.clone(), Vec::with_capacity(DEFAULT_CAPACITY)), Codec::Map(value_type) => { @@ -173,11 +214,11 @@ impl Decoder { )); Decoder::Map( map_field, - OffsetBufferBuilder::new(DEFAULT_CAPACITY), // key_offsets - OffsetBufferBuilder::new(DEFAULT_CAPACITY), // map_offsets - Vec::with_capacity(DEFAULT_CAPACITY), // key_data - Box::new(Self::try_new(value_type)?), // values_decoder_inner - 0, // current_entry_count 
+ OffsetBufferBuilder::new(DEFAULT_CAPACITY), + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + Box::new(Self::try_new(value_type)?), + 0, ) } Codec::Decimal(precision, scale, size) => { @@ -185,11 +226,9 @@ impl Decoder { Decoder::Decimal(*precision, *scale, *size, builder) } }; - - // Wrap the decoder in Nullable if necessary match data_type.nullability() { - Some(nullability) => Ok(Decoder::Nullable( - nullability, + Some(nb) => Ok(Decoder::Nullable( + nb, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(decoder), )), @@ -197,304 +236,375 @@ impl Decoder { } } - /// Appends a null value to the decoder. + /// Append a null to this decoder. + /// + /// This must keep the “row counts” in sync across child buffers, etc. fn append_null(&mut self) { match self { - Decoder::Null(count) => *count += 1, - Decoder::Boolean(b) => b.append(false), - Decoder::Int32(v) | Decoder::Date32(v) | Decoder::TimeMillis(v) => v.push(0), + Decoder::Null(n) => { + *n += 1; + } + Decoder::Boolean(b) => { + b.append(false); + } + Decoder::Int32(v) | Decoder::Date32(v) | Decoder::TimeMillis(v) => { + v.push(0); + } Decoder::Int64(v) | Decoder::TimeMicros(v) | Decoder::TimestampMillis(_, v) - | Decoder::TimestampMicros(_, v) => v.push(0), - Decoder::Float32(v) => v.push(0.0), - Decoder::Float64(v) => v.push(0.0), - Decoder::Binary(offsets, _) | Decoder::String(offsets, _) => { - offsets.push_length(0); - } - Decoder::List(_, offsets, e) => { - offsets.push_length(0); - e.append_null(); - } - Decoder::Record(_, encodings) => { - for encoding in encodings.iter_mut() { - encoding.append_null(); + | Decoder::TimestampMicros(_, v) => { + v.push(0); + } + Decoder::Float32(v) => { + v.push(0.0); + } + Decoder::Float64(v) => { + v.push(0.0); + } + Decoder::Binary(off, _) | Decoder::String(off, _) => { + off.push_length(0); + } + Decoder::List(_, off, child) => { + off.push_length(0); + child.append_null(); + } + Decoder::Record(_, children) => { + for c in 
children.iter_mut() { + c.append_null(); } } Decoder::Enum(_, indices) => { - // Append a placeholder index for null entries indices.push(0); } - Decoder::Map( - _, - key_offsets, - map_offsets_builder, - key_data, - values_decoder_inner, - current_entry_count, - ) => { - key_offsets.push_length(0); - map_offsets_builder.push_length(*current_entry_count); + Decoder::Map(_, key_off, map_off, _, _, entry_count) => { + key_off.push_length(0); + map_off.push_length(*entry_count); } Decoder::Decimal(_, _, _, builder) => { - builder.append_null(); + let _ = builder.append_null(); } - Decoder::Nullable(_, _, _) => { /* Nulls are handled by the Nullable variant */ } + Decoder::Nullable(_, _, _) => { /* The null mask is handled by the outer decoder */ } } } - /// Decodes a single record from the provided buffer `buf`. + /// Decode a single “row” of data from `buf`. fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { - Decoder::Null(x) => *x += 1, - Decoder::Boolean(values) => values.append(buf.get_bool()?), - Decoder::Int32(values) => values.push(buf.get_int()?), - Decoder::Date32(values) => values.push(buf.get_int()?), - Decoder::Int64(values) => values.push(buf.get_long()?), - Decoder::TimeMillis(values) => values.push(buf.get_int()?), - Decoder::TimeMicros(values) => values.push(buf.get_long()?), - Decoder::TimestampMillis(is_utc, values) => { - values.push(buf.get_long()?); - } - Decoder::TimestampMicros(is_utc, values) => { - values.push(buf.get_long()?); - } - Decoder::Float32(values) => values.push(buf.get_float()?), - Decoder::Float64(values) => values.push(buf.get_double()?), - Decoder::Binary(offsets, values) | Decoder::String(offsets, values) => { - let data = buf.get_bytes()?; - offsets.push_length(data.len()); - values.extend_from_slice(data); - } - Decoder::List(_, _, _) => { - return Err(ArrowError::NotYetImplemented( - "Decoding ListArray".to_string(), - )); + Decoder::Null(n) => { + *n += 1; } - Decoder::Record(fields, 
encodings) => { - for encoding in encodings.iter_mut() { - encoding.decode(buf)?; - } + Decoder::Boolean(vals) => { + vals.append(buf.get_bool()?); } - Decoder::Nullable(_, nulls, e) => { - let is_valid = buf.get_bool()?; - nulls.append(is_valid); - match is_valid { - true => e.decode(buf)?, - false => e.append_null(), + Decoder::Int32(vals) => { + vals.push(buf.get_int()?); + } + Decoder::Date32(vals) => { + vals.push(buf.get_int()?); + } + Decoder::Int64(vals) => { + vals.push(buf.get_long()?); + } + Decoder::TimeMillis(vals) => { + vals.push(buf.get_int()?); + } + Decoder::TimeMicros(vals) => { + vals.push(buf.get_long()?); + } + Decoder::TimestampMillis(_, vals) => { + vals.push(buf.get_long()?); + } + Decoder::TimestampMicros(_, vals) => { + vals.push(buf.get_long()?); + } + Decoder::Float32(vals) => { + vals.push(buf.get_float()?); + } + Decoder::Float64(vals) => { + vals.push(buf.get_double()?); + } + Decoder::Binary(off, data) | Decoder::String(off, data) => { + let bytes = buf.get_bytes()?; + off.push_length(bytes.len()); + data.extend_from_slice(bytes); + } + Decoder::List(_, off, child) => { + let total_items = read_array_blocks(buf, |b| child.decode(b))?; + off.push_length(total_items); + } + Decoder::Record(_, children) => { + for c in children.iter_mut() { + c.decode(buf)?; } } - Decoder::Enum(symbols, indices) => { - // Enums are encoded as zero-based indices using zigzag encoding - let index = buf.get_int()?; - indices.push(index); - } - Decoder::Map( - field, - key_offsets, - map_offsets_builder, - key_data, - values_decoder_inner, - current_entry_count, - ) => { - let block_count = buf.get_long()?; - if block_count <= 0 { - // Push the current_entry_count without changes - map_offsets_builder.push_length(*current_entry_count); - } else { - let n = block_count as usize; - for _ in 0..n { - let key_bytes = buf.get_bytes()?; - key_offsets.push_length(key_bytes.len()); - key_data.extend_from_slice(key_bytes); - values_decoder_inner.decode(buf)?; + 
Decoder::Nullable(_, null_buf, child) => { + let branch_index = buf.get_int()?; + match branch_index { + 0 => { + // child + null_buf.append(true); + child.decode(buf)?; + } + 1 => { + // null + null_buf.append(false); + child.append_null(); + } + other => { + return Err(ArrowError::ParseError(format!( + "Unsupported union branch index {other} for Nullable" + ))); } - // Update the current_entry_count and push to map_offsets_builder - *current_entry_count += n; - map_offsets_builder.push_length(*current_entry_count); } } - Decoder::Decimal(_precision, _scale, _size, builder) => { - if let Some(size) = _size { - // Fixed-size decimal - let raw = buf.get_fixed(*size)?; + Decoder::Enum(_, indices) => { + let idx = buf.get_int()?; + indices.push(idx); + } + Decoder::Map(_, key_off, map_off, key_data, val_decoder, entry_count) => { + let newly_added = read_map_blocks(buf, |b| { + let kb = b.get_bytes()?; + key_off.push_length(kb.len()); + key_data.extend_from_slice(kb); + val_decoder.decode(b) + })?; + *entry_count += newly_added; + map_off.push_length(*entry_count); + } + Decoder::Decimal(_, _, size, builder) => { + if let Some(sz) = *size { + let raw = buf.get_fixed(sz)?; builder.append_bytes(raw)?; } else { - // Variable-size decimal - let bytes = buf.get_bytes()?; - builder.append_bytes(bytes)?; + let variable = buf.get_bytes()?; + builder.append_bytes(variable)?; } } } Ok(()) } - /// Flushes decoded records to an [`ArrayRef`]. + /// Flush buffered data into an [`ArrayRef`], optionally applying `nulls`. 
fn flush(&mut self, nulls: Option) -> Result { match self { - Decoder::Nullable(_, n, e) => e.flush(n.finish()), - Decoder::Null(size) => Ok(Arc::new(NullArray::new(std::mem::replace(size, 0)))), - Decoder::Boolean(b) => Ok(Arc::new(BooleanArray::new(b.finish(), nulls))), - Decoder::Int32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Decoder::Date32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Decoder::Int64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Decoder::TimeMillis(values) => { - Ok(Arc::new(flush_primitive::(values, nulls))) - } - Decoder::TimeMicros(values) => { - Ok(Arc::new(flush_primitive::(values, nulls))) - } - Decoder::TimestampMillis(is_utc, values) => Ok(Arc::new( - flush_primitive::(values, nulls) - .with_timezone_opt::>(is_utc.then(|| "+00:00".into())), - )), - Decoder::TimestampMicros(is_utc, values) => Ok(Arc::new( - flush_primitive::(values, nulls) - .with_timezone_opt::>(is_utc.then(|| "+00:00".into())), - )), - Decoder::Float32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Decoder::Float64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Decoder::Binary(offsets, values) => { - let offsets = flush_offsets(offsets); - let values = flush_values(values).into(); + Decoder::Nullable(_, nb, child) => { + let mask = nb.finish(); + child.flush(mask) + } + // Null => produce NullArray + Decoder::Null(len) => { + let count = std::mem::replace(len, 0); + Ok(Arc::new(NullArray::new(count))) + } + // boolean => flush to BooleanArray + Decoder::Boolean(b) => { + let bits = b.finish(); + Ok(Arc::new(BooleanArray::new(bits, nulls))) + } + // int32 => flush to Int32Array + Decoder::Int32(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // date32 => flush to Date32Array + Decoder::Date32(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // int64 => flush to Int64Array + Decoder::Int64(vals) => { + let arr = 
flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // time-millis => Time32Millisecond + Decoder::TimeMillis(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // time-micros => Time64Microsecond + Decoder::TimeMicros(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // timestamp-millis => TimestampMillisecond + Decoder::TimestampMillis(is_utc, vals) => { + let arr = flush_primitive::(vals, nulls) + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); + Ok(Arc::new(arr)) + } + // timestamp-micros => TimestampMicrosecond + Decoder::TimestampMicros(is_utc, vals) => { + let arr = flush_primitive::(vals, nulls) + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); + Ok(Arc::new(arr)) + } + // float32 => flush to Float32Array + Decoder::Float32(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // float64 => flush to Float64Array + Decoder::Float64(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // Avro bytes => BinaryArray + Decoder::Binary(off, data) => { + let offsets = flush_offsets(off); + let values = flush_values(data).into(); Ok(Arc::new(BinaryArray::new(offsets, values, nulls))) } - Decoder::String(offsets, values) => { - let offsets = flush_offsets(offsets); - let values = flush_values(values).into(); + // Avro string => StringArray + Decoder::String(off, data) => { + let offsets = flush_offsets(off); + let values = flush_values(data).into(); Ok(Arc::new(StringArray::new(offsets, values, nulls))) } - Decoder::List(field, offsets, values) => { - let values = values.flush(None)?; - let offsets = flush_offsets(offsets); - Ok(Arc::new(ListArray::new(field.clone(), offsets, values, nulls))) + // Avro array => ListArray + Decoder::List(field, off, item_dec) => { + let child_arr = item_dec.flush(None)?; + let offsets = flush_offsets(off); + let arr = ListArray::new(field.clone(), offsets, child_arr, nulls); + Ok(Arc::new(arr)) } - 
Decoder::Record(fields, encodings) => { - let arrays = encodings - .iter_mut() - .map(|x| x.flush(None)) - .collect::, _>>()?; + // Avro record => StructArray + Decoder::Record(fields, children) => { + let mut arrays = Vec::with_capacity(children.len()); + for c in children.iter_mut() { + let a = c.flush(None)?; + arrays.push(a); + } Ok(Arc::new(StructArray::new(fields.clone(), arrays, nulls))) } + // Avro enum => DictionaryArray utf8> Decoder::Enum(symbols, indices) => { let dict_values = StringArray::from_iter_values(symbols.iter()); - let indices_array: Int32Array = match nulls { - Some(buf) => { - let buffer = arrow_buffer::Buffer::from_slice_ref(&indices); + let idxs: Int32Array = match nulls { + Some(b) => { + let buff = Buffer::from_slice_ref(&indices); PrimitiveArray::::try_new( - arrow_buffer::ScalarBuffer::from(buffer), - Some(buf.clone()), + arrow_buffer::ScalarBuffer::from(buff), + Some(b), )? } None => Int32Array::from_iter_values(indices.iter().cloned()), }; - let dict_array = DictionaryArray::::try_new( - indices_array, - Arc::new(dict_values), - )?; - Ok(Arc::new(dict_array)) - } - Decoder::Map( - field, - key_offsets_builder, - map_offsets_builder, - key_data, - values_decoder_inner, - current_entry_count, - ) => { - let map_offsets = flush_offsets(map_offsets_builder); - let key_offsets = flush_offsets(key_offsets_builder); - let key_data = flush_values(key_data).into(); - let key_array = StringArray::new(key_offsets, key_data, None); - let val_array = values_decoder_inner.flush(None)?; - let is_nullable = matches!(**values_decoder_inner, Decoder::Nullable(_, _, _)); + let dict = DictionaryArray::::try_new(idxs, Arc::new(dict_values))?; + indices.clear(); // reset + Ok(Arc::new(dict)) + } + // Avro map => MapArray + Decoder::Map(field, key_off, map_off, key_data, val_dec, entry_count) => { + let moff = flush_offsets(map_off); + let koff = flush_offsets(key_off); + let kd = flush_values(key_data).into(); + let val_arr = val_dec.flush(None)?; + let 
is_nullable = matches!(**val_dec, Decoder::Nullable(_, _, _)); + let key_arr = StringArray::new(koff, kd, None); let struct_fields = vec![ Arc::new(ArrowField::new("key", DataType::Utf8, false)), Arc::new(ArrowField::new( "value", - val_array.data_type().clone(), + val_arr.data_type().clone(), is_nullable, )), ]; - let struct_array = StructArray::new( + let entries = StructArray::new( Fields::from(struct_fields), - vec![Arc::new(key_array), val_array], + vec![Arc::new(key_arr), val_arr], None, ); - let map_array = MapArray::new( - field.clone(), - map_offsets.clone(), - struct_array.clone(), - nulls, - false, - ); - Ok(Arc::new(map_array)) - } - Decoder::Decimal(_precision, _scale, _size, builder) => { - let precision = *_precision; - let scale = _scale.unwrap_or(0); // Default scale if None - let size = _size.clone(); - let builder = std::mem::replace( - builder, - DecimalBuilder::new(precision, *_scale, *_size)?, - ); - Ok(builder.finish(nulls, precision, scale)?) // Pass precision and scale + let map_arr = MapArray::new(field.clone(), moff, entries, nulls, false); + *entry_count = 0; + Ok(Arc::new(map_arr)) + } + // Avro decimal => Arrow decimal + Decoder::Decimal(prec, sc, sz, builder) => { + let precision = *prec; + let scale = sc.unwrap_or(0); + let new_builder = DecimalBuilder::new(precision, *sc, *sz)?; + let old_builder = std::mem::replace(builder, new_builder); + let arr = old_builder.finish(nulls, precision, scale)?; + Ok(arr) } - } } } -/// Helper to build a field with a given type -fn field_with_type(name: &str, dt: DataType, nullable: bool) -> FieldRef { - Arc::new(ArrowField::new(name, dt, nullable)) +/// Helper to decode an Avro array in blocks until a 0 block_count signals end. +/// +/// Each block may be negative, in which case we read an extra “block size” `long`, +/// but typically ignore it unless we want to skip. This function invokes `decode_item` once per item. 
+fn read_array_blocks( + buf: &mut AvroCursor, + mut decode_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + let mut total_items = 0usize; + loop { + let block_count = buf.get_long()?; + if block_count == 0 { + break; + } else if block_count < 0 { + let item_count = (-block_count) as usize; + let _block_size = buf.get_long()?; // read but ignore + for _ in 0..item_count { + decode_item(buf)?; + } + total_items += item_count; + } else { + let item_count = block_count as usize; + for _ in 0..item_count { + decode_item(buf)?; + } + total_items += item_count; + } + } + Ok(total_items) } -/// Extends raw bytes to the target length with sign extension. -fn sign_extend(raw: &[u8], target_len: usize) -> Vec { - if raw.is_empty() { - return vec![0; target_len]; - } - let sign_bit = raw[0] & 0x80; - let mut extended = Vec::with_capacity(target_len); - if sign_bit != 0 { - extended.resize(target_len - raw.len(), 0xFF); +/// Helper to decode an Avro map in blocks until a 0 block_count signals end. +/// +/// For each entry in a block, we decode a key (bytes) + a value (`decode_value`). +/// Returns how many map entries were decoded. +fn read_map_blocks( + buf: &mut AvroCursor, + mut decode_value: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + let block_count = buf.get_long()?; + if block_count <= 0 { + Ok(0) } else { - extended.resize(target_len - raw.len(), 0x00); + let n = block_count as usize; + for _ in 0..n { + decode_value(buf)?; + } + Ok(n) } - extended.extend_from_slice(raw); - extended } -/// Extends raw bytes to 16 bytes (for Decimal128). 
-fn extend_to_16_bytes(raw: &[u8]) -> Result<[u8; 16], ArrowError> { - let extended = sign_extend(raw, 16); - if extended.len() != 16 { - return Err(ArrowError::ParseError(format!( - "Failed to extend bytes to 16 bytes: got {} bytes", - extended.len() - ))); - } - let mut arr = [0u8; 16]; - arr.copy_from_slice(&extended); - Ok(arr) +/// Flush a [`Vec`] of primitive values to a [`PrimitiveArray`], applying optional `nulls`. +#[inline] +fn flush_primitive( + values: &mut Vec, + nulls: Option, +) -> PrimitiveArray { + PrimitiveArray::new(flush_values(values).into(), nulls) } -/// Extends raw bytes to 32 bytes (for Decimal256). -fn extend_to_32_bytes(raw: &[u8]) -> Result<[u8; 32], ArrowError> { - let extended = sign_extend(raw, 32); - if extended.len() != 32 { - return Err(ArrowError::ParseError(format!( - "Failed to extend bytes to 32 bytes: got {} bytes", - extended.len() - ))); - } - let mut arr = [0u8; 32]; - arr.copy_from_slice(&extended); - Ok(arr) +/// Flush an [`OffsetBufferBuilder`], returning its completed offsets. +#[inline] +fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { + std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() +} + +/// Remove and return the contents of `values`, replacing it with an empty buffer. +#[inline] +fn flush_values(values: &mut Vec) -> Vec { + std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) } -/// Enum representing the builder for Decimal arrays. +/// A builder for Avro decimal, either 128-bit or 256-bit. #[derive(Debug)] enum DecimalBuilder { Decimal128(Decimal128Builder), @@ -502,7 +612,7 @@ enum DecimalBuilder { } impl DecimalBuilder { - /// Initializes a new `DecimalBuilder` based on precision, scale, and size. + /// Create a new DecimalBuilder given precision, scale, and optional byte-size (`fixed`). 
fn new( precision: usize, scale: Option, @@ -510,30 +620,38 @@ impl DecimalBuilder { ) -> Result { match size { Some(s) if s > 16 && s <= 32 => { - // Decimal256 + // decimal256 Ok(Self::Decimal256( - Decimal256Builder::new() - .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + Decimal256Builder::new().with_precision_and_scale( + precision as u8, + scale.unwrap_or(0) as i8, + )?, )) } Some(s) if s <= 16 => { - // Decimal128 + // decimal128 Ok(Self::Decimal128( - Decimal128Builder::new() - .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + Decimal128Builder::new().with_precision_and_scale( + precision as u8, + scale.unwrap_or(0) as i8, + )?, )) } None => { - // Infer based on precision + // infer from precision when fixed size is None if precision <= DECIMAL128_MAX_PRECISION as usize { Ok(Self::Decimal128( - Decimal128Builder::new() - .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + Decimal128Builder::new().with_precision_and_scale( + precision as u8, + scale.unwrap_or(0) as i8, + )?, )) } else if precision <= DECIMAL256_MAX_PRECISION as usize { Ok(Self::Decimal256( - Decimal256Builder::new() - .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + Decimal256Builder::new().with_precision_and_scale( + precision as u8, + scale.unwrap_or(0) as i8, + )?, )) } else { Err(ArrowError::ParseError(format!( @@ -549,100 +667,127 @@ impl DecimalBuilder { } } - /// Appends bytes to the decimal builder. 
- fn append_bytes(&mut self, bytes: &[u8]) -> Result<(), ArrowError> { + /// Append sign-extended bytes to this decimal builder + fn append_bytes(&mut self, raw: &[u8]) -> Result<(), ArrowError> { match self { - DecimalBuilder::Decimal128(b) => { - let padded = extend_to_16_bytes(bytes)?; - let value = i128::from_be_bytes(padded); - b.append_value(value); + Self::Decimal128(b) => { + let padded = sign_extend_to_16(raw)?; + let val = i128::from_be_bytes(padded); + b.append_value(val); } - DecimalBuilder::Decimal256(b) => { - let padded = extend_to_32_bytes(bytes)?; - let value = i256::from_be_bytes(padded); - b.append_value(value); + Self::Decimal256(b) => { + let padded = sign_extend_to_32(raw)?; + let val = i256::from_be_bytes(padded); + b.append_value(val); } } Ok(()) } - /// Appends a null value to the decimal builder by appending placeholder bytes. + /// Append a null decimal value (0) fn append_null(&mut self) -> Result<(), ArrowError> { match self { - DecimalBuilder::Decimal128(b) => { - // Append zeroed bytes as placeholder - let placeholder = [0u8; 16]; - let value = i128::from_be_bytes(placeholder); - b.append_value(value); - } - DecimalBuilder::Decimal256(b) => { - // Append zeroed bytes as placeholder - let placeholder = [0u8; 32]; - let value = i256::from_be_bytes(placeholder); - b.append_value(value); + Self::Decimal128(b) => { + let zero = [0u8; 16]; + b.append_value(i128::from_be_bytes(zero)); + } + Self::Decimal256(b) => { + let zero = [0u8; 32]; + b.append_value(i256::from_be_bytes(zero)); } } Ok(()) } - /// Finalizes the decimal array and returns it as an `ArrayRef`. - fn finish(self, nulls: Option, precision: usize, scale: usize) -> Result { + /// Finish building this decimal array, returning an [`ArrayRef`]. 
+ fn finish( + self, + nulls: Option, + precision: usize, + scale: usize, + ) -> Result { match self { - DecimalBuilder::Decimal128(mut b) => { - let array = b.finish(); - let values = array.values().clone(); - let decimal_array = Decimal128Array::new( - values, - nulls, - ).with_precision_and_scale(precision as u8, scale as i8)?; - Ok(Arc::new(decimal_array)) - } - DecimalBuilder::Decimal256(mut b) => { - let array = b.finish(); - let values = array.values().clone(); - let decimal_array = Decimal256Array::new( - values, - nulls, - ).with_precision_and_scale(precision as u8, scale as i8)?; - Ok(Arc::new(decimal_array)) + Self::Decimal128(mut b) => { + let arr = b.finish(); + let vals = arr.values().clone(); + let dec = Decimal128Array::new(vals, nulls) + .with_precision_and_scale(precision as u8, scale as i8)?; + Ok(Arc::new(dec)) + } + Self::Decimal256(mut b) => { + let arr = b.finish(); + let vals = arr.values().clone(); + let dec = Decimal256Array::new(vals, nulls) + .with_precision_and_scale(precision as u8, scale as i8)?; + Ok(Arc::new(dec)) } } } } -#[inline] -fn flush_values(values: &mut Vec) -> Vec { - std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) +/// Sign-extend `raw` to 16 bytes. +fn sign_extend_to_16(raw: &[u8]) -> Result<[u8; 16], ArrowError> { + let extended = sign_extend(raw, 16); + if extended.len() != 16 { + return Err(ArrowError::ParseError(format!( + "Failed to extend to 16 bytes, got {} bytes", + extended.len() + ))); + } + let mut arr = [0u8; 16]; + arr.copy_from_slice(&extended); + Ok(arr) } -#[inline] -fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { - std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() +/// Sign-extend `raw` to 32 bytes. 
+fn sign_extend_to_32(raw: &[u8]) -> Result<[u8; 32], ArrowError> { + let extended = sign_extend(raw, 32); + if extended.len() != 32 { + return Err(ArrowError::ParseError(format!( + "Failed to extend to 32 bytes, got {} bytes", + extended.len() + ))); + } + let mut arr = [0u8; 32]; + arr.copy_from_slice(&extended); + Ok(arr) } -#[inline] -fn flush_primitive( - values: &mut Vec, - nulls: Option, -) -> PrimitiveArray { - PrimitiveArray::new(flush_values(values).into(), nulls) +/// Sign-extend the first byte to produce `target_len` bytes total. +fn sign_extend(raw: &[u8], target_len: usize) -> Vec { + if raw.is_empty() { + return vec![0; target_len]; + } + let sign_bit = raw[0] & 0x80; + let mut out = Vec::with_capacity(target_len); + if sign_bit != 0 { + out.resize(target_len - raw.len(), 0xFF); + } else { + out.resize(target_len - raw.len(), 0x00); + } + out.extend_from_slice(raw); + out } -const DEFAULT_CAPACITY: usize = 1024; +/// Convenience helper to build a field with `name`, `DataType` and `nullable`. +fn field_with_type(name: &str, dt: DataType, nullable: bool) -> FieldRef { + Arc::new(ArrowField::new(name, dt, nullable)) +} #[cfg(test)] mod tests { use super::*; use arrow_array::{ - Array, ArrayRef, Int32Array, MapArray, StringArray, StructArray, - Decimal128Array, Decimal256Array, DictionaryArray, + cast::AsArray, Array, ArrayRef, Decimal128Array, Decimal256Array, DictionaryArray, + Int32Array, ListArray, MapArray, StringArray, StructArray, }; use arrow_buffer::Buffer; - use arrow_schema::{Field as ArrowField, DataType as ArrowDataType}; + use arrow_schema::{DataType as ArrowDataType, Field as ArrowField}; use serde_json::json; - use arrow_array::cast::AsArray; - /// Helper functions for encoding test data. 
+ // ------------------- + // Zig-Zag Encoding Helper Functions + // ------------------- fn encode_avro_int(value: i32) -> Vec { let mut buf = Vec::new(); let mut v = (value << 1) ^ (value >> 31); @@ -671,20 +816,23 @@ mod tests { buf } + // ------------------- + // Tests for Enum + // ------------------- #[test] fn test_enum_decoding() { let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); let mut decoder = Decoder::try_new(&enum_dt).unwrap(); - // Encode the indices [1, 0, 2] using zigzag encoding + // Encode the indices [1, 0, 2] => zigzag => 1->2, 0->0, 2->4 let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_int(1)); // Encodes to [2] - data.extend_from_slice(&encode_avro_int(0)); // Encodes to [0] - data.extend_from_slice(&encode_avro_int(2)); // Encodes to [4] + data.extend_from_slice(&encode_avro_int(1)); // => [2] + data.extend_from_slice(&encode_avro_int(0)); // => [0] + data.extend_from_slice(&encode_avro_int(2)); // => [4] let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); - decoder.decode(&mut cursor).unwrap(); - decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); // => GREEN + decoder.decode(&mut cursor).unwrap(); // => RED + decoder.decode(&mut cursor).unwrap(); // => BLUE let array = decoder.flush(None).unwrap(); let dict_arr = array.as_any().downcast_ref::>().unwrap(); assert_eq!(dict_arr.len(), 3); @@ -698,187 +846,208 @@ mod tests { assert_eq!(dict_values.value(2), "BLUE"); } + #[test] + fn test_enum_decoding_with_nulls() { + // Union => [Enum(...), null] + // "child" => branch_index=0 => [0x00], "null" => 1 => [0x02] + let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; + let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); + let mut inner_decoder = Decoder::try_new(&enum_dt).unwrap(); + let mut nullable_decoder = Decoder::Nullable( + 
Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(inner_decoder), + ); + // Indices: [1, null, 2] => in Avro union + let mut data = Vec::new(); + // Row1 => union branch=0 => child => [0x00] + data.extend_from_slice(&encode_avro_int(0)); + // Then child's enum index=1 => [0x02] + data.extend_from_slice(&encode_avro_int(1)); + // Row2 => union branch=1 => null => [0x02] + data.extend_from_slice(&encode_avro_int(1)); + // Row3 => union branch=0 => child => [0x00] + data.extend_from_slice(&encode_avro_int(0)); + // Then child's enum index=2 => [0x04] + data.extend_from_slice(&encode_avro_int(2)); + let mut cursor = AvroCursor::new(&data); + nullable_decoder.decode(&mut cursor).unwrap(); // => GREEN + nullable_decoder.decode(&mut cursor).unwrap(); // => null + nullable_decoder.decode(&mut cursor).unwrap(); // => BLUE + let array = nullable_decoder.flush(None).unwrap(); + let dict_arr = array.as_any().downcast_ref::>().unwrap(); + assert_eq!(dict_arr.len(), 3); + // [GREEN, null, BLUE] + assert!(dict_arr.is_valid(0)); + assert!(!dict_arr.is_valid(1)); + assert!(dict_arr.is_valid(2)); + let keys = dict_arr.keys(); + // keys.value(0) => 1 => GREEN + // keys.value(2) => 2 => BLUE + let dict_values = dict_arr.values().as_string::(); + assert_eq!(dict_values.value(0), "RED"); + assert_eq!(dict_values.value(1), "GREEN"); + assert_eq!(dict_values.value(2), "BLUE"); + } + + // ------------------- + // Tests for Map + // ------------------- #[test] fn test_map_decoding_one_entry() { let value_type = AvroDataType::from_codec(Codec::Utf8); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); // Encode a single map with one entry: {"hello": "world"} - // Avro encoding for a map: - // - block_count: 1 (encoded as [2] due to ZigZag) - // - keys: "hello" (encoded with length prefix) - // - values: "world" (encoded with length prefix) let mut data = Vec::new(); - 
data.extend_from_slice(&encode_avro_long(1)); // block_count = 1 - data.extend_from_slice(&encode_avro_bytes(b"hello")); // key = "hello" - data.extend_from_slice(&encode_avro_bytes(b"world")); // value = "world" - decoder.decode(&mut AvroCursor::new(&data)).unwrap(); + // block_count=1 => zigzag => [0x02] + data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_avro_bytes(b"hello")); // key + data.extend_from_slice(&encode_avro_bytes(b"world")); // value + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); let array = decoder.flush(None).unwrap(); let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); // One map - assert_eq!(map_arr.value_length(0), 1); // One entry in the map + assert_eq!(map_arr.len(), 1); // one map + assert_eq!(map_arr.value_length(0), 1); let entries = map_arr.value(0); let struct_entries = entries.as_any().downcast_ref::().unwrap(); - assert_eq!(struct_entries.len(), 1); // One entry in StructArray - let key = struct_entries + assert_eq!(struct_entries.len(), 1); + let key_arr = struct_entries .column_by_name("key") .unwrap() .as_any() .downcast_ref::() .unwrap(); - let value = struct_entries + let val_arr = struct_entries .column_by_name("value") .unwrap() .as_any() .downcast_ref::() .unwrap(); - assert_eq!(key.value(0), "hello"); // Verify Key - assert_eq!(value.value(0), "world"); // Verify Value + assert_eq!(key_arr.value(0), "hello"); + assert_eq!(val_arr.value(0), "world"); } #[test] fn test_map_decoding_empty() { + // block_count=0 => empty map let value_type = AvroDataType::from_codec(Codec::Utf8); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); - // Encode an empty map - // Avro encoding for an empty map: - // - block_count: 0 (encoded as [0] due to ZigZag) - let data = encode_avro_long(0); // block_count = 0 + // Encode an empty map => block_count=0 => [0x00] + let 
data = encode_avro_long(0); decoder.decode(&mut AvroCursor::new(&data)).unwrap(); let array = decoder.flush(None).unwrap(); let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); // One map - assert_eq!(map_arr.value_length(0), 0); // Zero entries in the map - let entries = map_arr.value(0); - let struct_entries = entries.as_any().downcast_ref::().unwrap(); - assert_eq!(struct_entries.len(), 0); // Zero entries in StructArray - let key = struct_entries - .column_by_name("key") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - let value = struct_entries - .column_by_name("value") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!(key.len(), 0); - assert_eq!(value.len(), 0); + assert_eq!(map_arr.len(), 1); + assert_eq!(map_arr.value_length(0), 0); } + // ------------------- + // Tests for Decimal + // ------------------- #[test] fn test_decimal_decoding_fixed128() { let dt = AvroDataType::from_codec(Codec::Decimal(5, Some(2), Some(16))); let mut decoder = Decoder::try_new(&dt).unwrap(); - // Row1: 123.45 => unscaled: 12345 => i128: 0x00000000000000000000000000003039 - // Row2: -1.23 => unscaled: -123 => i128: 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF85 + // Row1 => 123.45 => unscaled=12345 => i128 0x000...3039 + // Row2 => -1.23 => unscaled=-123 => i128 0xFFFF...FF85 let row1 = [ - 0x00, 0x00, 0x00, 0x00, // First 8 bytes - 0x00, 0x00, 0x00, 0x00, // Next 8 bytes - 0x00, 0x00, 0x00, 0x00, // Next 8 bytes - 0x00, 0x00, 0x30, 0x39, // Last 8 bytes: 0x3039 = 12345 + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x30, 0x39, ]; let row2 = [ - 0xFF, 0xFF, 0xFF, 0xFF, // First 8 bytes (two's complement) - 0xFF, 0xFF, 0xFF, 0xFF, // Next 8 bytes - 0xFF, 0xFF, 0xFF, 0xFF, // Next 8 bytes - 0xFF, 0xFF, 0xFF, 0x85, // Last 8 bytes: 0xFFFFFF85 = -123 + 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0x85, ]; + let mut data = Vec::new(); 
data.extend_from_slice(&row1); data.extend_from_slice(&row2); - decoder.decode(&mut AvroCursor::new(&data)).unwrap(); - decoder.decode(&mut AvroCursor::new(&data[16..])).unwrap(); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); let arr = decoder.flush(None).unwrap(); - let dec_arr = arr.as_any().downcast_ref::().unwrap(); - assert_eq!(dec_arr.len(), 2); - assert_eq!(dec_arr.value_as_string(0), "123.45"); - assert_eq!(dec_arr.value_as_string(1), "-1.23"); + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); } #[test] fn test_decimal_decoding_bytes_with_nulls() { + // Avro union => [ Decimal(4,1), null ] + // child => index=0 => [0x00], null => index=1 => [0x02] let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None)); - let mut decoder = Decoder::try_new(&dt).unwrap(); - // Wrap the decimal in a Nullable decoder - let mut nullable_decoder = Decoder::Nullable( + let mut inner = Decoder::try_new(&dt).unwrap(); + let mut decoder = Decoder::Nullable( Nullability::NullFirst, NullBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), + Box::new(inner), ); - // Row1: 123.4 => unscaled: 1234 => bytes: [0x04, 0xD2] - // Row2: null - // Row3: -123.4 => unscaled: -1234 => bytes: [0xFB, 0x2E] + // Decode three rows: [123.4, null, -123.4] let mut data = Vec::new(); - // Row1: valid - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2])); // 0x04D2 = 1234 - // Row2: null - data.extend_from_slice(&[0u8]); // is_valid = false - // Row3: valid - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); // 0xFB2E = -1234 + // Row1 => child => [0x00], then decimal => e.g. 
0x04D2 => 1234 => "123.4" + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2])); + // Row2 => null => [0x02] + data.extend_from_slice(&encode_avro_int(1)); + // Row3 => child => [0x00], then decimal => 0xFB2E => -1234 => "-123.4" + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); let mut cursor = AvroCursor::new(&data); - nullable_decoder.decode(&mut cursor).unwrap(); // Row1: 123.4 - nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null - nullable_decoder.decode(&mut cursor).unwrap(); // Row3: -123.4 - let array = nullable_decoder.flush(None).unwrap(); - let dec_arr = array.as_any().downcast_ref::().unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec_arr = arr.as_any().downcast_ref::().unwrap(); assert_eq!(dec_arr.len(), 3); - assert!(dec_arr.is_valid(0)); - assert!(!dec_arr.is_valid(1)); - assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.is_valid(0), true); + assert_eq!(dec_arr.is_valid(1), false); + assert_eq!(dec_arr.is_valid(2), true); assert_eq!(dec_arr.value_as_string(0), "123.4"); assert_eq!(dec_arr.value_as_string(2), "-123.4"); } #[test] fn test_decimal_decoding_bytes_with_nulls_fixed_size() { + // Avro union => [Decimal(6,2,16), null] let dt = AvroDataType::from_codec(Codec::Decimal(6, Some(2), Some(16))); - let mut decoder = Decoder::try_new(&dt).unwrap(); - // Wrap the decimal in a Nullable decoder - let mut nullable_decoder = Decoder::Nullable( + let mut inner = Decoder::try_new(&dt).unwrap(); + let mut decoder = Decoder::Nullable( Nullability::NullFirst, NullBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), + Box::new(inner), ); - // Correct Byte Encoding: - // Row1: 1234.56 => unscaled: 123456 => bytes: [0x00; 12] + [0x00, 0x01, 0xE2, 0x40] - // Row2: null - // Row3: -1234.56 => unscaled: -123456 
=> bytes: [0xFF; 12] + [0xFE, 0x1D, 0xC0, 0x00] - let row1_bytes = &[ - 0x00, 0x00, 0x00, 0x00, // First 4 bytes - 0x00, 0x00, 0x00, 0x00, // Next 4 bytes - 0x00, 0x00, 0x00, 0x01, // Next 4 bytes - 0xE2, 0x40, 0x00, 0x00, // Last 4 bytes + // Decode [1234.56, null, -1234.56] + let row1 = [ + 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, + 0x00,0x00,0x00,0x00, 0x00,0x01,0xE2,0x40 ]; - let row3_bytes = &[ - 0xFF, 0xFF, 0xFF, 0xFF, // First 4 bytes (two's complement) - 0xFF, 0xFF, 0xFF, 0xFF, // Next 4 bytes - 0xFF, 0xFF, 0xFE, 0x1D, // Next 4 bytes - 0xC0, 0x00, 0x00, 0x00, // Last 4 bytes + let row3 = [ + 0xFF,0xFF,0xFF,0xFF, 0xFF,0xFF,0xFF,0xFF, + 0xFF,0xFF,0xFF,0xFF, 0xFF,0xFE,0x1D,0xC0 ]; - let mut data = Vec::new(); - // Row1: valid - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(row1_bytes); // 1234.56 - // Row2: null - data.extend_from_slice(&[0u8]); // is_valid = false - // Row3: valid - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(row3_bytes); // -1234.56 - + // Row1 => child => [0x00] + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row1); + // Row2 => null => [0x02] + data.extend_from_slice(&encode_avro_int(1)); + // Row3 => child => [0x00] + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row3); let mut cursor = AvroCursor::new(&data); - nullable_decoder.decode(&mut cursor).unwrap(); // Row1: 1234.56 - nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null - nullable_decoder.decode(&mut cursor).unwrap(); // Row3: -1234.56 - - let array = nullable_decoder.flush(None).unwrap(); - let dec_arr = array.as_any().downcast_ref::().unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec_arr = arr.as_any().downcast_ref::().unwrap(); assert_eq!(dec_arr.len(), 3); assert!(dec_arr.is_valid(0)); assert!(!dec_arr.is_valid(1)); @@ 
-887,110 +1056,83 @@ mod tests { assert_eq!(dec_arr.value_as_string(2), "-1234.56"); } + // ------------------- + // Tests for List + // ------------------- #[test] - fn test_enum_decoding_with_nulls() { - let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; - let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); - let mut decoder = Decoder::try_new(&enum_dt).unwrap(); - - // Wrap the enum in a Nullable decoder - let mut nullable_decoder = Decoder::Nullable( - Nullability::NullFirst, - NullBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), - ); - - // Encode the indices [1, null, 2] using ZigZag encoding - // Indices: 1 -> [2], null -> no index, 2 -> [4] - let mut data = Vec::new(); - // Row1: valid (1) - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(&encode_avro_int(1)); // Encodes to [2] - // Row2: null - data.extend_from_slice(&[0u8]); // is_valid = false - // Row3: valid (2) - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(&encode_avro_int(2)); // Encodes to [4] - - let mut cursor = AvroCursor::new(&data); - nullable_decoder.decode(&mut cursor).unwrap(); // Row1: RED - nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null - nullable_decoder.decode(&mut cursor).unwrap(); // Row3: BLUE - - let array = nullable_decoder.flush(None).unwrap(); - let dict_arr = array.as_any().downcast_ref::>().unwrap(); - - assert_eq!(dict_arr.len(), 3); - let keys = dict_arr.keys(); - let validity = dict_arr.is_valid(0); // Correctly access the null buffer - - assert_eq!(keys.value(0), 1); - assert_eq!(keys.value(1), 0); // Placeholder index for null - assert_eq!(keys.value(2), 2); - - assert!(dict_arr.is_valid(0)); - assert!(!dict_arr.is_valid(1)); // Ensure the second entry is null - assert!(dict_arr.is_valid(2)); - - let dict_values = dict_arr.values().as_string::(); - assert_eq!(dict_values.value(0), "RED"); - assert_eq!(dict_values.value(1), "GREEN"); - 
assert_eq!(dict_values.value(2), "BLUE"); + fn test_list_decoding() { + // Avro array => block1(count=2), item1, item2, block2(count=0 => end) + // + // 1. Create 2 rows: + // Row1 => [10, 20] + // Row2 => [ ] + // + // 2. flush => should yield 2-element array => first row has 2 items, second row has 0 items + let item_dt = AvroDataType::from_codec(Codec::Int32); + let list_dt = AvroDataType::from_codec(Codec::List(Arc::new(item_dt))); + let mut decoder = Decoder::try_new(&list_dt).unwrap(); + // Row1 => block_count=2 => item=10 => item=20 => block_count=0 => end + // - 2 => zigzag => [0x04] + // - item=10 => zigzag => [0x14] + // - item=20 => zigzag => [0x28] + // - 0 => [0x00] + let mut row1 = Vec::new(); + row1.extend_from_slice(&encode_avro_long(2)); // block_count=2 + row1.extend_from_slice(&encode_avro_int(10)); // item=10 + row1.extend_from_slice(&encode_avro_int(20)); // item=20 + row1.extend_from_slice(&encode_avro_long(0)); // end of array + // Row2 => block_count=0 => empty array + let mut row2 = Vec::new(); + row2.extend_from_slice(&encode_avro_long(0)); + let mut cursor = AvroCursor::new(&row1); + decoder.decode(&mut cursor).unwrap(); + let mut cursor2 = AvroCursor::new(&row2); + decoder.decode(&mut cursor2).unwrap(); + let array = decoder.flush(None).unwrap(); + let list_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(list_arr.len(), 2); + // row0 => 2 items => [10, 20] + // row1 => 0 items + let offsets = list_arr.value_offsets(); + assert_eq!(offsets, &[0, 2, 2]); + let values = list_arr.values(); + let int_arr = values.as_primitive::(); + assert_eq!(int_arr.len(), 2); + assert_eq!(int_arr.value(0), 10); + assert_eq!(int_arr.value(1), 20); } #[test] - fn test_enum_with_nullable_entries() { - let symbols = vec!["APPLE".to_string(), "BANANA".to_string(), "CHERRY".to_string()]; - let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); - let mut decoder = Decoder::try_new(&enum_dt).unwrap(); - - // Wrap the enum in a Nullable 
decoder
-        let mut nullable_decoder = Decoder::Nullable(
-            Nullability::NullFirst,
-            NullBufferBuilder::new(DEFAULT_CAPACITY),
-            Box::new(decoder),
-        );
-
-        // Encode the indices [0, null, 2, 1] using ZigZag encoding
-        let mut data = Vec::new();
-        // Row1: valid (0) -> "APPLE"
-        data.extend_from_slice(&[1u8]); // is_valid = true
-        data.extend_from_slice(&encode_avro_int(0)); // Encodes to [0]
-        // Row2: null
-        data.extend_from_slice(&[0u8]); // is_valid = false
-        // Row3: valid (2) -> "CHERRY"
-        data.extend_from_slice(&[1u8]); // is_valid = true
-        data.extend_from_slice(&encode_avro_int(2)); // Encodes to [4]
-        // Row4: valid (1) -> "BANANA"
-        data.extend_from_slice(&[1u8]); // is_valid = true
-        data.extend_from_slice(&encode_avro_int(1)); // Encodes to [2]
-
+    fn test_list_decoding_with_negative_block_count() {
+        // Start with single row => [1, 2, 3]
+        // We'll store them in a single negative block => block_count=-3 => #items=3
+        // Then read block_size (we encode 12 bytes below), then the items.
+        // Then a block_count=0 => done
+        let item_dt = AvroDataType::from_codec(Codec::Int32);
+        let list_dt = AvroDataType::from_codec(Codec::List(Arc::new(item_dt)));
+        let mut decoder = Decoder::try_new(&list_dt).unwrap();
+        // block_count=-3 => zigzag => (-3 << 1) ^ (-3 >> 63)
+        // => -6 ^ -1 => 5 => varint byte [0x05]
+        // Encode directly with `encode_avro_long(-3)`. 
+ let mut data = encode_avro_long(-3); + // Next => block_size => let's pretend 12 => encode_avro_long(12) + data.extend_from_slice(&encode_avro_long(12)); + // Then 3 items => [1, 2, 3] + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(2)); + data.extend_from_slice(&encode_avro_int(3)); + // Then block_count=0 => done + data.extend_from_slice(&encode_avro_long(0)); let mut cursor = AvroCursor::new(&data); - nullable_decoder.decode(&mut cursor).unwrap(); // Row1: APPLE - nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null - nullable_decoder.decode(&mut cursor).unwrap(); // Row3: CHERRY - nullable_decoder.decode(&mut cursor).unwrap(); // Row4: BANANA - - let array = nullable_decoder.flush(None).unwrap(); - let dict_arr = array.as_any().downcast_ref::>().unwrap(); - - assert_eq!(dict_arr.len(), 4); - let keys = dict_arr.keys(); - let validity = dict_arr.is_valid(0); // Correctly access the null buffer - - assert_eq!(keys.value(0), 0); - assert_eq!(keys.value(1), 0); // Placeholder index for null - assert_eq!(keys.value(2), 2); - assert_eq!(keys.value(3), 1); - - assert!(dict_arr.is_valid(0)); - assert!(!dict_arr.is_valid(1)); // Ensure the second entry is null - assert!(dict_arr.is_valid(2)); - assert!(dict_arr.is_valid(3)); - - let dict_values = dict_arr.values().as_string::(); - assert_eq!(dict_values.value(0), "APPLE"); - assert_eq!(dict_values.value(1), "BANANA"); - assert_eq!(dict_values.value(2), "CHERRY"); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let list_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(list_arr.len(), 1); + assert_eq!(list_arr.value_length(0), 3); + let values = list_arr.values().as_primitive::(); + assert_eq!(values.len(), 3); + assert_eq!(values.value(0), 1); + assert_eq!(values.value(1), 2); + assert_eq!(values.value(2), 3); } -} \ No newline at end of file +} From 84ffb62c6333479effe2f08ad92c5a14103a24f2 Mon Sep 17 00:00:00 2001 From: 
Connor Sanders Date: Tue, 31 Dec 2024 15:23:35 -0600 Subject: [PATCH 7/8] * Minor Cleanup Signed-off-by: Connor Sanders --- arrow-avro/src/reader/cursor.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs index ba1d01f72d7e..9e38a78c63ec 100644 --- a/arrow-avro/src/reader/cursor.rs +++ b/arrow-avro/src/reader/cursor.rs @@ -71,7 +71,6 @@ impl<'a> AvroCursor<'a> { let val: u32 = varint .try_into() .map_err(|_| ArrowError::ParseError("varint overflow".to_string()))?; - // Zig-zag decode Ok((val >> 1) as i32 ^ -((val & 1) as i32)) } @@ -79,7 +78,6 @@ impl<'a> AvroCursor<'a> { #[inline] pub(crate) fn get_long(&mut self) -> Result { let val = self.read_vlq()?; - // Zig-zag decode Ok((val >> 1) as i64 ^ -((val & 1) as i64)) } From 8600680df4f75b5a937a5fc0ffae2f743705379b Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Tue, 31 Dec 2024 17:26:36 -0600 Subject: [PATCH 8/8] * Added record decoder support for the following types: - Fixed - Interval Signed-off-by: Connor Sanders --- arrow-avro/src/reader/record.rs | 381 +++++++++++++++++++++++--------- 1 file changed, 278 insertions(+), 103 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 500fe27fd53b..87ae7e2426a5 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -20,13 +20,13 @@ use crate::reader::block::{Block, BlockDecoder}; use crate::reader::cursor::AvroCursor; use crate::reader::header::Header; use crate::schema::*; +use arrow_array::builder::{Decimal128Builder, Decimal256Builder, PrimitiveBuilder}; use arrow_array::types::*; use arrow_array::*; -use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; use arrow_buffer::*; use arrow_schema::{ - ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, - TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + ArrowError, DataType, Field as ArrowField, FieldRef, 
Fields, IntervalUnit, Schema as ArrowSchema, + SchemaRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use std::collections::HashMap; use std::io::Read; @@ -87,8 +87,6 @@ impl RecordDecoder { } /// Decoder for Avro data of various shapes. -/// -/// This is the “internal” representation used by [`RecordDecoder`]. #[derive(Debug)] enum Decoder { /// Avro `null` @@ -117,25 +115,19 @@ enum Decoder { Binary(OffsetBufferBuilder, Vec), /// Avro `string` => Arrow String String(OffsetBufferBuilder, Vec), + /// Avro `fixed(n)` => Arrow `FixedSizeBinaryArray` + Fixed(i32, Vec), + /// Avro `interval` => Arrow `IntervalMonthDayNanoType` (12 bytes) + Interval(Vec), /// Avro `array` - /// * `FieldRef` is the arrow field for the list - /// * `OffsetBufferBuilder` holds offsets into the child array - /// * The boxed `Decoder` decodes T itself List(FieldRef, OffsetBufferBuilder, Box), /// Avro `record` - /// * `Fields` is the Arrow schema of the record - /// * The `Vec` is one decoder per child field Record(Fields, Vec), - /// Avro union that includes `null` => decodes as a single arrow field + a null bit mask + /// Avro union that includes `null` Nullable(Nullability, NullBufferBuilder, Box), /// Avro `enum` => Dictionary(int32 -> string) Enum(Vec, Vec), /// Avro `map` - /// * The `FieldRef` is the arrow field for the map - /// * `key_offsets`, `map_offsets`: offset builders - /// * `key_data` accumulates the raw UTF8 for keys - /// * `values_decoder_inner` decodes the map’s value type - /// * `current_entry_count` how many (key,value) pairs total seen so far Map( FieldRef, OffsetBufferBuilder, @@ -145,19 +137,17 @@ enum Decoder { usize, ), /// Avro decimal => Arrow decimal - /// (precision, scale, size, builder) Decimal(usize, Option, Option, DecimalBuilder), } impl Decoder { - /// Checks if the Decoder is nullable, i.e. wrapped in [`Decoder::Nullable`]. + /// Checks if the Decoder is nullable, i.e. wrapped in `Nullable`. 
fn is_nullable(&self) -> bool { matches!(self, Decoder::Nullable(_, _, _)) } /// Create a `Decoder` from an [`AvroDataType`]. fn try_new(data_type: &AvroDataType) -> Result { - let not_implemented = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); let decoder = match data_type.codec() { Codec::Null => Decoder::Null(0), Codec::Boolean => Decoder::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), @@ -182,8 +172,8 @@ impl Decoder { Codec::TimestampMicros(is_utc) => { Decoder::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::Fixed(_) => return not_implemented("decoding Avro fixed-typed data"), - Codec::Interval => return not_implemented("decoding Avro interval"), + Codec::Fixed(n) => Decoder::Fixed(*n, Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Interval => Decoder::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), Codec::List(item) => { let item_decoder = Box::new(Self::try_new(item)?); Decoder::List( @@ -192,17 +182,19 @@ impl Decoder { item_decoder, ) } - Codec::Struct(fields) => { - let mut arrow_fields = Vec::with_capacity(fields.len()); - let mut decoders = Vec::with_capacity(fields.len()); - for avro_field in fields.iter() { + Codec::Struct(avro_fields) => { + let mut arrow_fields = Vec::with_capacity(avro_fields.len()); + let mut decoders = Vec::with_capacity(avro_fields.len()); + for avro_field in avro_fields.iter() { let d = Self::try_new(avro_field.data_type())?; arrow_fields.push(avro_field.field()); decoders.push(d); } Decoder::Record(arrow_fields.into(), decoders) } - Codec::Enum(symbols) => Decoder::Enum(symbols.clone(), Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Enum(symbols) => { + Decoder::Enum(symbols.clone(), Vec::with_capacity(DEFAULT_CAPACITY)) + } Codec::Map(value_type) => { let map_field = Arc::new(ArrowField::new( "entries", @@ -226,6 +218,8 @@ impl Decoder { Decoder::Decimal(*precision, *scale, *size, builder) } }; + + // Wrap in Nullable if needed match data_type.nullability() { Some(nb) => 
Ok(Decoder::Nullable( nb, @@ -237,8 +231,6 @@ impl Decoder { } /// Append a null to this decoder. - /// - /// This must keep the “row counts” in sync across child buffers, etc. fn append_null(&mut self) { match self { Decoder::Null(n) => { @@ -265,6 +257,19 @@ impl Decoder { Decoder::Binary(off, _) | Decoder::String(off, _) => { off.push_length(0); } + Decoder::Fixed(fsize, buf) => { + // For a null, push `fsize` zeroed bytes + let n = *fsize as usize; + buf.extend(std::iter::repeat(0u8).take(n)); + } + Decoder::Interval(intervals) => { + // null => store a 12-byte zero => months=0, days=0, nanos=0 + intervals.push(IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 0, + }); + } Decoder::List(_, off, child) => { off.push_length(0); child.append_null(); @@ -277,58 +282,82 @@ impl Decoder { Decoder::Enum(_, indices) => { indices.push(0); } - Decoder::Map(_, key_off, map_off, _, _, entry_count) => { + Decoder::Map( + _, + key_off, + map_off, + _, + _, + entry_count, + ) => { key_off.push_length(0); map_off.push_length(*entry_count); } Decoder::Decimal(_, _, _, builder) => { let _ = builder.append_null(); } - Decoder::Nullable(_, _, _) => { /* The null mask is handled by the outer decoder */ } + Decoder::Nullable(_, _, _) => { /* The null bit is stored in the NullBufferBuilder */ } } } - /// Decode a single “row” of data from `buf`. + /// Decode a single row of data from `buf`. 
fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { - Decoder::Null(n) => { - *n += 1; + Decoder::Null(count) => { + *count += 1; } - Decoder::Boolean(vals) => { - vals.append(buf.get_bool()?); + Decoder::Boolean(values) => { + values.append(buf.get_bool()?); } - Decoder::Int32(vals) => { - vals.push(buf.get_int()?); + Decoder::Int32(values) => { + values.push(buf.get_int()?); } - Decoder::Date32(vals) => { - vals.push(buf.get_int()?); + Decoder::Date32(values) => { + values.push(buf.get_int()?); } - Decoder::Int64(vals) => { - vals.push(buf.get_long()?); + Decoder::Int64(values) => { + values.push(buf.get_long()?); } - Decoder::TimeMillis(vals) => { - vals.push(buf.get_int()?); + Decoder::TimeMillis(values) => { + values.push(buf.get_int()?); } - Decoder::TimeMicros(vals) => { - vals.push(buf.get_long()?); + Decoder::TimeMicros(values) => { + values.push(buf.get_long()?); } - Decoder::TimestampMillis(_, vals) => { - vals.push(buf.get_long()?); + Decoder::TimestampMillis(_, values) => { + values.push(buf.get_long()?); } - Decoder::TimestampMicros(_, vals) => { - vals.push(buf.get_long()?); + Decoder::TimestampMicros(_, values) => { + values.push(buf.get_long()?); } - Decoder::Float32(vals) => { - vals.push(buf.get_float()?); + Decoder::Float32(values) => { + values.push(buf.get_float()?); } - Decoder::Float64(vals) => { - vals.push(buf.get_double()?); + Decoder::Float64(values) => { + values.push(buf.get_double()?); } Decoder::Binary(off, data) | Decoder::String(off, data) => { let bytes = buf.get_bytes()?; off.push_length(bytes.len()); data.extend_from_slice(bytes); } + Decoder::Fixed(fsize, accum) => { + let raw = buf.get_fixed(*fsize as usize)?; + accum.extend_from_slice(raw); + } + Decoder::Interval(intervals) => { + let raw = buf.get_fixed(12)?; + let months = i32::from_le_bytes(raw[0..4].try_into().unwrap()); + let days = i32::from_le_bytes(raw[4..8].try_into().unwrap()); + let millis = 
i32::from_le_bytes(raw[8..12].try_into().unwrap()); + let nanos = millis as i64 * 1_000_000; + let val = IntervalMonthDayNano { + months, + days, + nanoseconds: nanos, + }; + intervals.push(val); + } Decoder::List(_, off, child) => { let total_items = read_array_blocks(buf, |b| child.decode(b))?; off.push_length(total_items); @@ -338,17 +367,15 @@ impl Decoder { c.decode(buf)?; } } - Decoder::Nullable(_, null_buf, child) => { + Decoder::Nullable(_, nulls, child) => { let branch_index = buf.get_int()?; match branch_index { 0 => { - // child - null_buf.append(true); + nulls.append(true); child.decode(buf)?; } 1 => { - // null - null_buf.append(false); + nulls.append(false); child.append_null(); } other => { @@ -388,6 +415,7 @@ impl Decoder { /// Flush buffered data into an [`ArrayRef`], optionally applying `nulls`. fn flush(&mut self, nulls: Option) -> Result { match self { + // For a nullable wrapper => flush the child with the built null buffer Decoder::Nullable(_, nb, child) => { let mask = nb.finish(); child.flush(mask) @@ -461,6 +489,32 @@ impl Decoder { let values = flush_values(data).into(); Ok(Arc::new(StringArray::new(offsets, values, nulls))) } + // Avro fixed => FixedSizeBinaryArray + Decoder::Fixed(fsize, raw) => { + let size = *fsize; + let buf: Buffer = flush_values(raw).into(); + let total_len = buf.len() / (size as usize); + let array = FixedSizeBinaryArray::try_new(size, buf, nulls) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Ok(Arc::new(array)) + } + // Avro interval => IntervalMonthDayNanoType + Decoder::Interval(vals) => { + let data_len = vals.len(); + let mut builder = PrimitiveBuilder::::with_capacity(data_len); + for v in vals.drain(..) 
{ + builder.append_value(v); + } + let arr = builder.finish().with_data_type(DataType::Interval(IntervalUnit::MonthDayNano)); + if let Some(nb) = nulls { + // "merge" the newly built array with the nulls + let arr_data = arr.into_data().into_builder().nulls(Some(nb)); + let arr_data = unsafe { arr_data.build_unchecked() }; + Ok(Arc::new(PrimitiveArray::::from(arr_data))) + } else { + Ok(Arc::new(arr)) + } + } // Avro array => ListArray Decoder::List(field, off, item_dec) => { let child_arr = item_dec.flush(None)?; @@ -532,10 +586,7 @@ impl Decoder { } } -/// Helper to decode an Avro array in blocks until a 0 block_count signals end. -/// -/// Each block may be negative, in which case we read an extra “block size” `long`, -/// but typically ignore it unless we want to skip. This function invokes `decode_item` once per item. +/// Decode an Avro array in blocks until a 0 block_count signals end. fn read_array_blocks( buf: &mut AvroCursor, mut decode_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, @@ -547,7 +598,7 @@ fn read_array_blocks( break; } else if block_count < 0 { let item_count = (-block_count) as usize; - let _block_size = buf.get_long()?; // read but ignore + let _block_size = buf.get_long()?; // “block size” is read but not used for _ in 0..item_count { decode_item(buf)?; } @@ -563,13 +614,10 @@ fn read_array_blocks( Ok(total_items) } -/// Helper to decode an Avro map in blocks until a 0 block_count signals end. -/// -/// For each entry in a block, we decode a key (bytes) + a value (`decode_value`). -/// Returns how many map entries were decoded. +/// Decode an Avro map in blocks until 0 block_count => end. 
fn read_map_blocks( buf: &mut AvroCursor, - mut decode_value: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, + mut decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, ) -> Result { let block_count = buf.get_long()?; if block_count <= 0 { @@ -577,7 +625,7 @@ fn read_map_blocks( } else { let n = block_count as usize; for _ in 0..n { - decode_value(buf)?; + decode_entry(buf)?; } Ok(n) } @@ -592,13 +640,13 @@ fn flush_primitive( PrimitiveArray::new(flush_values(values).into(), nulls) } -/// Flush an [`OffsetBufferBuilder`], returning its completed offsets. +/// Flush an [`OffsetBufferBuilder`]. #[inline] fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() } -/// Remove and return the contents of `values`, replacing it with an empty buffer. +/// Take ownership of `values`. #[inline] fn flush_values(values: &mut Vec) -> Vec { std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) @@ -619,39 +667,25 @@ impl DecimalBuilder { size: Option, ) -> Result { match size { - Some(s) if s > 16 && s <= 32 => { - // decimal256 - Ok(Self::Decimal256( - Decimal256Builder::new().with_precision_and_scale( - precision as u8, - scale.unwrap_or(0) as i8, - )?, - )) - } - Some(s) if s <= 16 => { - // decimal128 - Ok(Self::Decimal128( - Decimal128Builder::new().with_precision_and_scale( - precision as u8, - scale.unwrap_or(0) as i8, - )?, - )) - } + Some(s) if s > 16 && s <= 32 => Ok(Self::Decimal256( + Decimal256Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )), + Some(s) if s <= 16 => Ok(Self::Decimal128( + Decimal128Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )), None => { - // infer from precision when fixed size is None + // infer from precision if precision <= DECIMAL128_MAX_PRECISION as usize { Ok(Self::Decimal128( - Decimal128Builder::new().with_precision_and_scale( - 
precision as u8, - scale.unwrap_or(0) as i8, - )?, + Decimal128Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, )) } else if precision <= DECIMAL256_MAX_PRECISION as usize { Ok(Self::Decimal256( - Decimal256Builder::new().with_precision_and_scale( - precision as u8, - scale.unwrap_or(0) as i8, - )?, + Decimal256Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, )) } else { Err(ArrowError::ParseError(format!( @@ -699,7 +733,7 @@ impl DecimalBuilder { Ok(()) } - /// Finish building this decimal array, returning an [`ArrayRef`]. + /// Finish building the decimal array, returning an [`ArrayRef`]. fn finish( self, nulls: Option, @@ -779,15 +813,17 @@ mod tests { use super::*; use arrow_array::{ cast::AsArray, Array, ArrayRef, Decimal128Array, Decimal256Array, DictionaryArray, - Int32Array, ListArray, MapArray, StringArray, StructArray, + FixedSizeBinaryArray, Int32Array, IntervalMonthDayNanoArray, ListArray, MapArray, + StringArray, StructArray, }; use arrow_buffer::Buffer; use arrow_schema::{DataType as ArrowDataType, Field as ArrowField}; use serde_json::json; + use std::iter; - // ------------------- - // Zig-Zag Encoding Helper Functions - // ------------------- + // --------------- + // Zig-Zag Helpers + // --------------- fn encode_avro_int(value: i32) -> Vec { let mut buf = Vec::new(); let mut v = (value << 1) ^ (value >> 31); @@ -816,6 +852,145 @@ mod tests { buf } + // ----------------- + // Test Fixed + // ----------------- + #[test] + fn test_fixed_decoding() { + // `fixed(4)` => Arrow FixedSizeBinary(4) + let dt = AvroDataType::from_codec(Codec::Fixed(4)); + let mut dec = Decoder::try_new(&dt).unwrap(); + // 2 rows, each row => 4 bytes + let row1 = [0xDE, 0xAD, 0xBE, 0xEF]; + let row2 = [0x01, 0x23, 0x45, 0x67]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); 
+ dec.decode(&mut cursor).unwrap(); + let arr = dec.flush(None).unwrap(); + let fsb = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(fsb.len(), 2); + assert_eq!(fsb.value_length(), 4); + assert_eq!(fsb.value(0), row1); + assert_eq!(fsb.value(1), row2); + } + + #[test] + fn test_fixed_with_nulls() { + // Avro union => [ fixed(2), null] + let dt = AvroDataType::from_codec(Codec::Fixed(2)); + let child = Decoder::try_new(&dt).unwrap(); + let mut dec = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(child), + ); + // Decode 3 rows: row1 => branch=0 => [0x00], then 2 bytes + // row2 => branch=1 => null => [0x02] + // row3 => branch=0 => 2 bytes + let row1 = [0x11, 0x22]; + let row3 = [0x55, 0x66]; + let mut data = Vec::new(); + // row1 => union=0 => child => 2 bytes + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row1); + // row2 => union=1 => null + data.extend_from_slice(&encode_avro_int(1)); + // row3 => union=0 => child => 2 bytes + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row3); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); // row1 + dec.decode(&mut cursor).unwrap(); // row2 => null + dec.decode(&mut cursor).unwrap(); // row3 + let arr = dec.flush(None).unwrap(); + let fsb = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(fsb.len(), 3); + assert!(fsb.is_valid(0)); + assert!(!fsb.is_valid(1)); + assert!(fsb.is_valid(2)); + assert_eq!(fsb.value_length(), 2); + assert_eq!(fsb.value(0), row1); + assert_eq!(fsb.value(2), row3); + } + + // ----------------- + // Test Interval + // ----------------- + #[test] + fn test_interval_decoding() { + // Avro interval => 12 bytes => [ months i32, days i32, ms i32 ] + // decode 2 rows => row1 => months=1, days=2, ms=100 => row2 => months=-1, days=10, ms=9999 + let dt = AvroDataType::from_codec(Codec::Interval); + let mut dec = Decoder::try_new(&dt).unwrap(); + // row1 => months=1 => 
01,00,00,00, days=2 => 02,00,00,00, ms=100 => 64,00,00,00 + // row2 => months=-1 => 0xFF,0xFF,0xFF,0xFF, days=10 => 0x0A,0x00,0x00,0x00, ms=9999 => 0x0F,0x27,0x00,0x00 + let row1 = [0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x64, 0x00, 0x00, 0x00]; + let row2 = [0xFF, 0xFF, 0xFF, 0xFF, + 0x0A, 0x00, 0x00, 0x00, + 0x0F, 0x27, 0x00, 0x00]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); + dec.decode(&mut cursor).unwrap(); + let arr = dec.flush(None).unwrap(); + let intervals = arr + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(intervals.len(), 2); + // row0 => months=1, days=2, ms=100 => nanos=100_000_000 + // row1 => months=-1, days=10, ms=9999 => nanos=9999_000_000 + let val0 = intervals.value(0); + assert_eq!(val0.months, 1); + assert_eq!(val0.days, 2); + assert_eq!(val0.nanoseconds, 100_000_000); + let val1 = intervals.value(1); + assert_eq!(val1.months, -1); + assert_eq!(val1.days, 10); + assert_eq!(val1.nanoseconds, 9_999_000_000); + } + + #[test] + fn test_interval_decoding_with_nulls() { + // Avro union => [ interval, null] + let dt = AvroDataType::from_codec(Codec::Interval); + let child = Decoder::try_new(&dt).unwrap(); + let mut dec = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(child), + ); + // We'll decode 2 rows: row1 => interval => months=2, days=3, ms=500 => row2 => null + // row1 => union=0 => child => 12 bytes + // row2 => union=1 => null => no data + let row1 = [0x02, 0x00, 0x00, 0x00, // months=2 + 0x03, 0x00, 0x00, 0x00, // days=3 + 0xF4, 0x01, 0x00, 0x00]; // ms=500 => nanos=500_000_000 + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); // union=0 => child + data.extend_from_slice(&row1); + data.extend_from_slice(&encode_avro_int(1)); // union=1 => null + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut 
cursor).unwrap(); // row1 + dec.decode(&mut cursor).unwrap(); // row2 => null + let arr = dec.flush(None).unwrap(); + let intervals = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(intervals.len(), 2); + assert!(intervals.is_valid(0)); + assert!(!intervals.is_valid(1)); + let val0 = intervals.value(0); + assert_eq!(val0.months, 2); + assert_eq!(val0.days, 3); + assert_eq!(val0.nanoseconds, 500_000_000); + } + // ------------------- // Tests for Enum // -------------------