From 8d2b240b7c21ac20475b42b37fbe3f7f5b8b7956 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Tue, 10 Sep 2024 12:01:09 +0100
Subject: [PATCH] Allow using dictionary arrays as filters (#12382)

* Allow using dictionaries as filters

* revert, nested

* fmt
---
 datafusion/core/tests/dataframe/mod.rs   | 107 ++++++++++++++++++++++-
 datafusion/expr/src/logical_plan/plan.rs |  14 ++-
 2 files changed, 118 insertions(+), 3 deletions(-)

diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 19ce9294cfad..171ef9561e55 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -29,7 +29,10 @@ use arrow::{
     },
     record_batch::RecordBatch,
 };
-use arrow_array::{Array, Float32Array, Float64Array, UnionArray};
+use arrow_array::{
+    Array, BooleanArray, DictionaryArray, Float32Array, Float64Array, Int8Array,
+    UnionArray,
+};
 use arrow_buffer::ScalarBuffer;
 use arrow_schema::{ArrowError, UnionFields, UnionMode};
 use datafusion_functions_aggregate::count::count_udaf;
@@ -2363,3 +2366,105 @@ async fn dense_union_is_null() {
     ];
     assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap());
 }
+
+#[tokio::test]
+async fn boolean_dictionary_as_filter() {
+    let values = vec![Some(true), Some(false), None, Some(true)];
+    let keys = vec![0, 0, 1, 2, 1, 3, 1];
+    let values_array = BooleanArray::from(values);
+    let keys_array = Int8Array::from(keys);
+    let array =
+        DictionaryArray::new(keys_array, Arc::new(values_array) as Arc<dyn Array>);
+    let array = Arc::new(array);
+
+    let field = Field::new(
+        "my_dict",
+        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Boolean)),
+        true,
+    );
+    let schema = Arc::new(Schema::new(vec![field]));
+
+    let batch = RecordBatch::try_new(schema, vec![array.clone()]).unwrap();
+
+    let ctx = SessionContext::new();
+
+    ctx.register_batch("dict_batch", batch).unwrap();
+
+    let df = ctx.table("dict_batch").await.unwrap();
+
+    // view_all
+    let expected = [
+        "+---------+",
+        "| my_dict |",
+        "+---------+",
+        "| true    |",
+        "| true    |",
+        "| false   |",
+        "|         |",
+        "| false   |",
+        "| true    |",
+        "| false   |",
+        "+---------+",
+    ];
+    assert_batches_eq!(expected, &df.clone().collect().await.unwrap());
+
+    let result_df = df.clone().filter(col("my_dict")).unwrap();
+    let expected = [
+        "+---------+",
+        "| my_dict |",
+        "+---------+",
+        "| true    |",
+        "| true    |",
+        "| true    |",
+        "+---------+",
+    ];
+    assert_batches_eq!(expected, &result_df.collect().await.unwrap());
+
+    // test nested dictionary
+    let keys = vec![0, 2]; // 0 -> true, 2 -> false
+    let keys_array = Int8Array::from(keys);
+    let nested_array = DictionaryArray::new(keys_array, array);
+
+    let field = Field::new(
+        "my_nested_dict",
+        DataType::Dictionary(
+            Box::new(DataType::Int8),
+            Box::new(DataType::Dictionary(
+                Box::new(DataType::Int8),
+                Box::new(DataType::Boolean),
+            )),
+        ),
+        true,
+    );
+
+    let schema = Arc::new(Schema::new(vec![field]));
+
+    let batch = RecordBatch::try_new(schema, vec![Arc::new(nested_array)]).unwrap();
+
+    ctx.register_batch("nested_dict_batch", batch).unwrap();
+
+    let df = ctx.table("nested_dict_batch").await.unwrap();
+
+    // view_all
+    let expected = [
+        "+----------------+",
+        "| my_nested_dict |",
+        "+----------------+",
+        "| true           |",
+        "| false          |",
+        "+----------------+",
+    ];
+
+    assert_batches_eq!(expected, &df.clone().collect().await.unwrap());
+
+    let result_df = df.clone().filter(col("my_nested_dict")).unwrap();
+    let expected = [
+        "+----------------+",
+        "| my_nested_dict |",
+        "+----------------+",
+        "| true           |",
+        "+----------------+",
+    ];
+
+    assert_batches_eq!(expected, &result_df.collect().await.unwrap());
+}
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 975bfc9feebf..1c94c7f3afd3 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -2207,6 +2207,17 @@ impl Filter {
         Self::try_new_internal(predicate, input, true)
     }
 
+    fn is_allowed_filter_type(data_type: &DataType) -> bool {
+        match data_type {
+            // Interpret NULL as a missing boolean value.
+            DataType::Boolean | DataType::Null => true,
+            DataType::Dictionary(_, value_type) => {
+                Filter::is_allowed_filter_type(value_type.as_ref())
+            }
+            _ => false,
+        }
+    }
+
     fn try_new_internal(
         predicate: Expr,
         input: Arc<LogicalPlan>,
@@ -2217,8 +2228,7 @@ impl Filter {
         // construction (such as with correlated subqueries) so we make a best effort here and
         // ignore errors resolving the expression against the schema.
         if let Ok(predicate_type) = predicate.get_type(input.schema()) {
-            // Interpret NULL as a missing boolean value.
-            if predicate_type != DataType::Boolean && predicate_type != DataType::Null {
+            if !Filter::is_allowed_filter_type(&predicate_type) {
                 return plan_err!(
                     "Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}"
                 );