From 6a51971d630f8c303cd1460669958c86a2d97286 Mon Sep 17 00:00:00 2001 From: Adrian Ehrsam <5270024+aersam@users.noreply.github.com> Date: Thu, 21 Mar 2024 10:24:53 +0100 Subject: [PATCH] fix: schema evolution not coercing with large arrow types (#2305) # Description This fixes schema merge for large arrow types in combination with dictionary types. Basically we just allow merging any arrow data type that is the same delta type # Related Issue(s) Fixes #2298 # Documentation --- crates/core/src/kernel/arrow/mod.rs | 1 + crates/core/src/operations/cast.rs | 192 +++++++++++++++++++++---- crates/core/src/operations/write.rs | 6 +- crates/core/src/writer/record_batch.rs | 8 +- python/tests/test_writer.py | 2 +- 5 files changed, 171 insertions(+), 38 deletions(-) diff --git a/crates/core/src/kernel/arrow/mod.rs b/crates/core/src/kernel/arrow/mod.rs index d27bc6463b..45d6432e1d 100644 --- a/crates/core/src/kernel/arrow/mod.rs +++ b/crates/core/src/kernel/arrow/mod.rs @@ -261,6 +261,7 @@ impl TryFrom<&ArrowDataType> for DataType { panic!("DataType::Map should contain a struct field child"); } } + ArrowDataType::Dictionary(_, value_type) => Ok(value_type.as_ref().try_into()?), s => Err(ArrowError::SchemaError(format!( "Invalid data type for Delta Lake: {s}" ))), diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index 33155dedd8..f92b3c646e 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -1,53 +1,66 @@ //! Provide common cast functionality for callers //! -use arrow::datatypes::DataType::Dictionary; +use crate::kernel::{ + ArrayType, DataType as DeltaDataType, MapType, MetadataValue, StructField, StructType, +}; use arrow_array::{new_null_array, Array, ArrayRef, RecordBatch, StructArray}; use arrow_cast::{cast_with_options, CastOptions}; -use arrow_schema::{ - ArrowError, DataType, Field as ArrowField, Fields, Schema as ArrowSchema, - SchemaRef as ArrowSchemaRef, -}; +use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef}; +use std::collections::HashMap; use std::sync::Arc; use crate::DeltaResult; -pub(crate) fn merge_field(left: &ArrowField, right: &ArrowField) -> Result { - if let Dictionary(_, value_type) = right.data_type() { - if value_type.equals_datatype(left.data_type()) { - return Ok(left.clone()); - } - } - if let Dictionary(_, value_type) = left.data_type() { - if value_type.equals_datatype(right.data_type()) { - return Ok(right.clone()); +fn try_merge_metadata( + left: &mut HashMap, + right: &HashMap, +) -> Result<(), ArrowError> { + for (k, v) in right { + if let Some(vl) = left.get(k) { + if vl != v { + return Err(ArrowError::SchemaError(format!( + "Cannot merge metadata with different values for key {}", + k + ))); + } + } else { + left.insert(k.clone(), v.clone()); } } - let mut new_field = left.clone(); - new_field.try_merge(right)?; - Ok(new_field) + Ok(()) } -pub(crate) fn merge_schema( - left: ArrowSchema, - right: ArrowSchema, -) -> Result { +pub(crate) fn merge_struct( + left: &StructType, + right: &StructType, +) -> Result { let mut errors = Vec::with_capacity(left.fields().len()); - let merged_fields: Result, ArrowError> = left + let merged_fields: Result, ArrowError> = left .fields() .iter() .map(|field| { let right_field = right.field_with_name(field.name()); if let Ok(right_field) = right_field { - let field_or_not = merge_field(field.as_ref(), right_field); - match field_or_not { + let type_or_not = merge_type(field.data_type(), right_field.data_type()); + match type_or_not { Err(e) => { errors.push(e.to_string()); Err(e) } - Ok(f) => Ok(f), + Ok(f) => { + let mut new_field = StructField::new( + field.name(), + f, + field.is_nullable() || right_field.is_nullable(), + ); + + new_field.metadata = field.metadata.clone(); + try_merge_metadata(&mut new_field.metadata, &right_field.metadata)?; + Ok(new_field) + } } } else { - Ok(field.as_ref().clone()) + Ok(field.clone()) } }) .collect(); @@ -55,11 +68,11 @@ pub(crate) fn merge_schema( Ok(mut fields) => { for field in right.fields() { if !left.field_with_name(field.name()).is_ok() { - fields.push(field.as_ref().clone()); + fields.push(field.clone()); } } - Ok(ArrowSchema::new(fields)) + Ok(StructType::new(fields)) } Err(e) => { errors.push(e.to_string()); @@ -68,6 +81,51 @@ pub(crate) fn merge_schema( } } +pub(crate) fn merge_type( + left: &DeltaDataType, + right: &DeltaDataType, +) -> Result { + if left == right { + return Ok(left.clone()); + } + match (left, right) { + (DeltaDataType::Array(a), DeltaDataType::Array(b)) => { + let merged = merge_type(&a.element_type, &b.element_type)?; + Ok(DeltaDataType::Array(Box::new(ArrayType::new( + merged, + a.contains_null() || b.contains_null(), + )))) + } + (DeltaDataType::Map(a), DeltaDataType::Map(b)) => { + let merged_key = merge_type(&a.key_type, &b.key_type)?; + let merged_value = merge_type(&a.value_type, &b.value_type)?; + Ok(DeltaDataType::Map(Box::new(MapType::new( + merged_key, + merged_value, + a.value_contains_null() || b.value_contains_null(), + )))) + } + (DeltaDataType::Struct(a), DeltaDataType::Struct(b)) => { + let merged = merge_struct(a, b)?; + Ok(DeltaDataType::Struct(Box::new(merged))) + } + (a, b) => Err(ArrowError::SchemaError(format!( + "Cannot merge types {} and {}", + a, b + ))), + } +} + +pub(crate) fn merge_schema( + left: ArrowSchemaRef, + right: ArrowSchemaRef, +) -> Result { + let left_delta: StructType = left.try_into()?; + let right_delta: StructType = right.try_into()?; + let merged: StructType = merge_struct(&left_delta, &right_delta)?; + Ok(Arc::new((&merged).try_into()?)) +} + fn cast_struct( struct_array: &StructArray, fields: &Fields, @@ -142,13 +200,91 @@ pub fn cast_record_batch( #[cfg(test)] mod tests { + use crate::kernel::{ + ArrayType as DeltaArrayType, DataType as DeltaDataType, StructField as DeltaStructField, + StructType as DeltaStructType, + }; + use crate::operations::cast::MetadataValue; use crate::operations::cast::{cast_record_batch, is_cast_required}; use arrow::array::ArrayData; use arrow_array::{Array, ArrayRef, ListArray, RecordBatch}; use arrow_buffer::Buffer; use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; + use std::collections::HashMap; use std::sync::Arc; + #[test] + fn test_merge_schema_with_dict() { + let left_schema = Arc::new(Schema::new(vec![Field::new( + "f", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )])); + let right_schema = Arc::new(Schema::new(vec![Field::new( + "f", + DataType::LargeUtf8, + true, + )])); + + let result = super::merge_schema(left_schema, right_schema).unwrap(); + assert_eq!(result.fields().len(), 1); + let delta_type: DeltaDataType = result.fields()[0].data_type().try_into().unwrap(); + assert_eq!(delta_type, DeltaDataType::STRING); + assert_eq!(result.fields()[0].is_nullable(), true); + } + + #[test] + fn test_merge_schema_with_meta() { + let mut left_meta = HashMap::new(); + left_meta.insert("a".to_string(), "a1".to_string()); + let left_schema = DeltaStructType::new(vec![DeltaStructField::new( + "f", + DeltaDataType::STRING, + false, + ) + .with_metadata(left_meta)]); + let mut right_meta = HashMap::new(); + right_meta.insert("b".to_string(), "b2".to_string()); + let right_schema = DeltaStructType::new(vec![DeltaStructField::new( + "f", + DeltaDataType::STRING, + true, + ) + .with_metadata(right_meta)]); + + let result = super::merge_struct(&left_schema, &right_schema).unwrap(); + assert_eq!(result.fields().len(), 1); + let delta_type = result.fields()[0].data_type(); + assert_eq!(delta_type, &DeltaDataType::STRING); + let mut expected_meta = HashMap::new(); + expected_meta.insert("a".to_string(), MetadataValue::String("a1".to_string())); + expected_meta.insert("b".to_string(), MetadataValue::String("b2".to_string())); + assert_eq!(result.fields()[0].metadata(), &expected_meta); + } + + #[test] + fn test_merge_schema_with_nested() { + let left_schema = Arc::new(Schema::new(vec![Field::new( + "f", + DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, false))), + false, + )])); + let right_schema = Arc::new(Schema::new(vec![Field::new( + "f", + DataType::List(Arc::new(Field::new("item", DataType::LargeUtf8, false))), + true, + )])); + + let result = super::merge_schema(left_schema, right_schema).unwrap(); + assert_eq!(result.fields().len(), 1); + let delta_type: DeltaDataType = result.fields()[0].data_type().try_into().unwrap(); + assert_eq!( + delta_type, + DeltaDataType::Array(Box::new(DeltaArrayType::new(DeltaDataType::STRING, false))) + ); + assert_eq!(result.fields()[0].is_nullable(), true); + } + #[test] fn test_cast_record_batch_with_list_non_default_item() { let array = Arc::new(make_list_array()) as ArrayRef; diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index d2751a6b1f..ec0a6c80d1 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -623,10 +623,8 @@ impl std::future::IntoFuture for WriteBuilder { if this.mode == SaveMode::Overwrite && this.schema_mode.is_some() { new_schema = None // we overwrite anyway, so no need to cast } else if this.schema_mode == Some(SchemaMode::Merge) { - new_schema = Some(Arc::new(merge_schema( - table_schema.as_ref().clone(), - schema.as_ref().clone(), - )?)); + new_schema = + Some(merge_schema(table_schema.clone(), schema.clone())?); } else { return Err(schema_err.into()); } diff --git a/crates/core/src/writer/record_batch.rs b/crates/core/src/writer/record_batch.rs index 5c8fb57509..8b43f35242 100644 --- a/crates/core/src/writer/record_batch.rs +++ b/crates/core/src/writer/record_batch.rs @@ -306,11 +306,9 @@ impl PartitionWriter { WriteMode::MergeSchema => { debug!("The writer and record batch schemas do not match, merging"); - let merged = merge_schema( - self.arrow_schema.as_ref().clone(), - record_batch.schema().as_ref().clone(), - )?; - self.arrow_schema = Arc::new(merged); + let merged = + merge_schema(self.arrow_schema.clone(), record_batch.schema().clone())?; + self.arrow_schema = merged; let mut cols = vec![]; for field in self.arrow_schema.fields() { diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index dfd124a73d..96903f0824 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -259,7 +259,7 @@ def test_update_schema_rust_writer_append(existing_table: DeltaTable): ) with pytest.raises( SchemaMismatchError, - match="Schema error: Fail to merge schema field 'utf8' because the from data_type = Int64 does not equal Utf8", + match="Schema error: Cannot merge types string and long", ): write_deltalake( existing_table,