diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 809bbcaebbd9..859ef79fc1d2 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -639,7 +639,7 @@ fn build_is_null_column_expr( Expr::Column(ref col) => { let field = schema.field_with_name(&col.name).ok()?; - let null_count_field = &Field::new(field.name(), DataType::UInt64, false); + let null_count_field = &Field::new(field.name(), DataType::UInt64, true); required_columns .null_count_column_expr(col, expr, null_count_field) .map(|null_count_column_expr| { @@ -809,10 +809,19 @@ mod tests { use std::collections::HashMap; #[derive(Debug)] - /// Test for container stats + + /// Mock statistic provider for tests + /// + /// Each row represents the statistics for a "container" (which + /// might represent an entire parquet file, or directory of files, + /// or some other collection of data for which we had statistics) + /// + /// Note All `ArrayRefs` must be the same size. struct ContainerStats { min: ArrayRef, max: ArrayRef, + /// Optional values + null_counts: Option, } impl ContainerStats { @@ -835,6 +844,7 @@ mod tests { .with_precision_and_scale(precision, scale) .unwrap(), ), + null_counts: None, } } @@ -845,6 +855,7 @@ mod tests { Self { min: Arc::new(min.into_iter().collect::()), max: Arc::new(max.into_iter().collect::()), + null_counts: None, } } @@ -855,6 +866,7 @@ mod tests { Self { min: Arc::new(min.into_iter().collect::()), max: Arc::new(max.into_iter().collect::()), + null_counts: None, } } @@ -865,6 +877,7 @@ mod tests { Self { min: Arc::new(min.into_iter().collect::()), max: Arc::new(max.into_iter().collect::()), + null_counts: None, } } @@ -875,6 +888,7 @@ mod tests { Self { min: Arc::new(min.into_iter().collect::()), max: Arc::new(max.into_iter().collect::()), + null_counts: None, } } @@ -886,10 +900,29 @@ mod tests { Some(self.max.clone()) } + fn null_counts(&self) -> Option { + self.null_counts.clone() + } + fn len(&self) -> usize { assert_eq!(self.min.len(), self.max.len()); self.min.len() } + + /// Add null counts. There must be the same number of null counts as + /// there are containers + fn with_null_counts( + mut self, + counts: impl IntoIterator>, + ) -> Self { + // take stats out and update them + let null_counts: ArrayRef = + Arc::new(counts.into_iter().collect::()); + + assert_eq!(null_counts.len(), self.len()); + self.null_counts = Some(null_counts); + self + } } #[derive(Debug, Default)] @@ -908,8 +941,30 @@ mod tests { name: impl Into, container_stats: ContainerStats, ) -> Self { - self.stats - .insert(Column::from_name(name.into()), container_stats); + let col = Column::from_name(name.into()); + self.stats.insert(col, container_stats); + self + } + + /// Add null counts for the specified columm. + /// There must be the same number of null counts as + /// there are containers + fn with_null_counts( + mut self, + name: impl Into, + counts: impl IntoIterator>, + ) -> Self { + let col = Column::from_name(name.into()); + + // take stats out and update them + let container_stats = self + .stats + .remove(&col) + .expect("Can not find stats for column") + .with_null_counts(counts); + + // put stats back in + self.stats.insert(col, container_stats); self } } @@ -937,8 +992,11 @@ mod tests { .unwrap_or(0) } - fn null_counts(&self, _column: &Column) -> Option { - None + fn null_counts(&self, column: &Column) -> Option { + self.stats + .get(column) + .map(|container_stats| container_stats.null_counts()) + .unwrap_or(None) } } @@ -1761,4 +1819,39 @@ mod tests { let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); } + + #[test] + fn prune_int32_is_null() { + let (schema, statistics) = int32_setup(); + + // Expression "i IS NULL" when there are no null statistics, + // should all be kept + let expected_ret = vec![true, true, true, true, true]; + + // i IS NULL, no null statistics + let expr = col("i").is_null(); + let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + + // provide null counts for each column + let statistics = statistics.with_null_counts( + "i", + vec![ + Some(0), // no nulls (don't keep) + Some(1), // 1 null + None, // unknown nulls + None, // unknown nulls (min/max are both null too, like no stats at all) + Some(0), // 0 nulls (max=null too which means no known max) (don't keep) + ], + ); + + let expected_ret = vec![false, true, true, true, false]; + + // i IS NULL, with actual null statistcs + let expr = col("i").is_null(); + let p = PruningPredicate::try_new(expr, schema).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected_ret); + } }