Skip to content

Commit

Permalink
Remove Box from Sort
Browse files Browse the repository at this point in the history
`expr::Sort` had `Box<Expr>` because Sort was also an expression (via
`expr::Expr::Sort`). This has been removed, obsoleting need to use a
`Box`.
  • Loading branch information
findepi committed Aug 29, 2024
1 parent 0f16849 commit 03d6167
Show file tree
Hide file tree
Showing 21 changed files with 46 additions and 59 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ fn create_ordering(
// Construct PhysicalSortExpr objects from Expr objects:
let mut sort_exprs = vec![];
for sort in exprs {
match sort.expr.as_ref() {
match &sort.expr {
Expr::Column(col) => match expressions::col(&col.name, schema) {
Ok(expr) => {
sort_exprs.push(PhysicalSortExpr {
Expand Down
20 changes: 10 additions & 10 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![wildcard()],
))
.order_by(vec![Sort::new(Box::new(col("a")), false, true)])
.order_by(vec![Sort::new(col("a"), false, true)])
.window_frame(WindowFrame::new_bounds(
WindowFrameUnits::Range,
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
Expand Down Expand Up @@ -352,7 +352,7 @@ async fn sort_on_unprojected_columns() -> Result<()> {
.unwrap()
.select(vec![col("a")])
.unwrap()
.sort(vec![Sort::new(Box::new(col("b")), false, true)])
.sort(vec![Sort::new(col("b"), false, true)])
.unwrap();
let results = df.collect().await.unwrap();

Expand Down Expand Up @@ -396,7 +396,7 @@ async fn sort_on_distinct_columns() -> Result<()> {
.unwrap()
.distinct()
.unwrap()
.sort(vec![Sort::new(Box::new(col("a")), false, true)])
.sort(vec![Sort::new(col("a"), false, true)])
.unwrap();
let results = df.collect().await.unwrap();

Expand Down Expand Up @@ -435,7 +435,7 @@ async fn sort_on_distinct_unprojected_columns() -> Result<()> {
.await?
.select(vec![col("a")])?
.distinct()?
.sort(vec![Sort::new(Box::new(col("b")), false, true)])
.sort(vec![Sort::new(col("b"), false, true)])
.unwrap_err();
assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions b must appear in select list");
Ok(())
Expand Down Expand Up @@ -599,8 +599,8 @@ async fn test_grouping_sets() -> Result<()> {
.await?
.aggregate(vec![grouping_set_expr], vec![count(col("a"))])?
.sort(vec![
Sort::new(Box::new(col("a")), false, true),
Sort::new(Box::new(col("b")), false, true),
Sort::new(col("a"), false, true),
Sort::new(col("b"), false, true),
])?;

let results = df.collect().await?;
Expand Down Expand Up @@ -640,8 +640,8 @@ async fn test_grouping_sets_count() -> Result<()> {
.await?
.aggregate(vec![grouping_set_expr], vec![count(lit(1))])?
.sort(vec![
Sort::new(Box::new(col("c1")), false, true),
Sort::new(Box::new(col("c2")), false, true),
Sort::new(col("c1"), false, true),
Sort::new(col("c2"), false, true),
])?;

let results = df.collect().await?;
Expand Down Expand Up @@ -687,8 +687,8 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> {
],
)?
.sort(vec![
Sort::new(Box::new(col("c1")), false, true),
Sort::new(Box::new(col("c2")), false, true),
Sort::new(col("c1"), false, true),
Sort::new(col("c2"), false, true),
])?;

let results = df.collect().await?;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
}

fn expressions(&self) -> Vec<Expr> {
vec![self.expr.expr.as_ref().clone()]
vec![self.expr.expr.clone()]
}

/// For example: `TopK: k=10`
Expand Down
6 changes: 3 additions & 3 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ impl TryCast {
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Sort {
/// The expression to sort on
pub expr: Box<Expr>,
pub expr: Expr,
/// The direction of the sort
pub asc: bool,
/// Whether to put Nulls before all other data values
Expand All @@ -611,7 +611,7 @@ pub struct Sort {

impl Sort {
/// Create a new Sort expression
pub fn new(expr: Box<Expr>, asc: bool, nulls_first: bool) -> Self {
pub fn new(expr: Expr, asc: bool, nulls_first: bool) -> Self {
Self {
expr,
asc,
Expand Down Expand Up @@ -1368,7 +1368,7 @@ impl Expr {
/// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST
/// ```
pub fn sort(self, asc: bool, nulls_first: bool) -> Sort {
Sort::new(Box::new(self), asc, nulls_first)
Sort::new(self, asc, nulls_first)
}

/// Return `IsTrue(Box(self))`
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ pub fn normalize_sorts(
.into_iter()
.map(|e| {
let sort = e.into();
normalize_col(*sort.expr, plan)
.map(|expr| Sort::new(Box::new(expr), sort.asc, sort.nulls_first))
normalize_col(sort.expr, plan)
.map(|expr| Sort::new(expr, sort.asc, sort.nulls_first))
})
.collect()
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/expr_rewriter/order_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub fn rewrite_sort_cols_by_aggs(
.map(|e| {
let sort = e.into();
Ok(Sort::new(
Box::new(rewrite_sort_col_by_aggs(*sort.expr, plan)?),
rewrite_sort_col_by_aggs(sort.expr, plan)?,
sort.asc,
sort.nulls_first,
))
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1720,8 +1720,8 @@ mod tests {
let plan =
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?
.sort(vec![
expr::Sort::new(Box::new(col("state")), true, true),
expr::Sort::new(Box::new(col("salary")), false, false),
expr::Sort::new(col("state"), true, true),
expr::Sort::new(col("salary"), false, false),
])?
.build()?;

Expand Down Expand Up @@ -2147,8 +2147,8 @@ mod tests {
let plan =
table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?
.sort(vec![
expr::Sort::new(Box::new(col("state")), true, true),
expr::Sort::new(Box::new(col("salary")), false, false),
expr::Sort::new(col("state"), true, true),
expr::Sort::new(col("salary"), false, false),
])?
.build()?;

Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2616,7 +2616,7 @@ impl DistinctOn {
// Check that the left-most sort expressions are the same as the `ON` expressions.
let mut matched = true;
for (on, sort) in self.on_expr.iter().zip(sort_expr.iter()) {
if on != &*sort.expr {
if on != &sort.expr {
matched = false;
break;
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ impl LogicalPlan {
})) => on_expr
.iter()
.chain(select_expr.iter())
.chain(sort_expr.iter().flatten().map(|sort| &*sort.expr))
.chain(sort_expr.iter().flatten().map(|sort| &sort.expr))
.apply_until_stop(f),
// plans without expressions
LogicalPlan::EmptyRelation(_)
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ impl TreeNode for Expr {
expr_vec.push(f.as_ref());
}
if let Some(order_by) = order_by {
expr_vec.extend(order_by.iter().map(|sort| sort.expr.as_ref()));
expr_vec.extend(order_by.iter().map(|sort| &sort.expr));
}
expr_vec
}
Expand All @@ -109,7 +109,7 @@ impl TreeNode for Expr {
}) => {
let mut expr_vec = args.iter().collect::<Vec<_>>();
expr_vec.extend(partition_by);
expr_vec.extend(order_by.iter().map(|sort| sort.expr.as_ref()));
expr_vec.extend(order_by.iter().map(|sort| &sort.expr));
expr_vec
}
Expr::InList(InList { expr, list, .. }) => {
Expand Down Expand Up @@ -395,7 +395,7 @@ pub fn transform_sort_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
) -> Result<Transformed<Vec<Sort>>> {
Ok(sorts
.iter()
.map(|sort| (*sort.expr).clone())
.map(|sort| sort.expr.clone())
.map_until_stop_and_collect(&mut f)?
.update_data(|transformed_exprs| {
replace_sort_expressions(sorts, transformed_exprs)
Expand All @@ -413,7 +413,7 @@ pub fn replace_sort_expressions(sorts: Vec<Sort>, new_expr: Vec<Expr>) -> Vec<So

pub fn replace_sort_expression(sort: Sort, new_expr: Expr) -> Sort {
Sort {
expr: Box::new(new_expr),
expr: new_expr,
..sort
}
}
16 changes: 8 additions & 8 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1401,9 +1401,9 @@ mod tests {

#[test]
fn test_group_window_expr_by_sort_keys() -> Result<()> {
let age_asc = expr::Sort::new(Box::new(col("age")), true, true);
let name_desc = expr::Sort::new(Box::new(col("name")), false, true);
let created_at_desc = expr::Sort::new(Box::new(col("created_at")), false, true);
let age_asc = expr::Sort::new(col("age"), true, true);
let name_desc = expr::Sort::new(col("name"), false, true);
let created_at_desc = expr::Sort::new(col("created_at"), false, true);
let max1 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(max_udaf()),
vec![col("name")],
Expand Down Expand Up @@ -1463,12 +1463,12 @@ mod tests {
for nulls_first_ in nulls_first_or_last {
let order_by = &[
Sort {
expr: Box::new(col("age")),
expr: col("age"),
asc: asc_,
nulls_first: nulls_first_,
},
Sort {
expr: Box::new(col("name")),
expr: col("name"),
asc: asc_,
nulls_first: nulls_first_,
},
Expand All @@ -1477,23 +1477,23 @@ mod tests {
let expected = vec![
(
Sort {
expr: Box::new(col("age")),
expr: col("age"),
asc: asc_,
nulls_first: nulls_first_,
},
true,
),
(
Sort {
expr: Box::new(col("name")),
expr: col("name"),
asc: asc_,
nulls_first: nulls_first_,
},
true,
),
(
Sort {
expr: Box::new(col("created_at")),
expr: col("created_at"),
asc: true,
nulls_first: false,
},
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ mod tests {
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![wildcard()],
))
.order_by(vec![Sort::new(Box::new(col("a")), false, true)])
.order_by(vec![Sort::new(col("a"), false, true)])
.window_frame(WindowFrame::new_bounds(
WindowFrameUnits::Range,
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
Expand Down
3 changes: 1 addition & 2 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,7 @@ impl CommonSubexprEliminate {
) -> Result<Transformed<LogicalPlan>> {
let Sort { expr, input, fetch } = sort;
let input = Arc::unwrap_or_clone(input);
let sort_expressions =
expr.iter().map(|sort| sort.expr.as_ref().clone()).collect();
let sort_expressions = expr.iter().map(|sort| sort.expr.clone()).collect();
let new_sort = self
.try_unary_plan(sort_expressions, input, config)?
.update_data(|(new_expr, new_input)| {
Expand Down
7 changes: 1 addition & 6 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -645,12 +645,7 @@ pub fn parse_sort(
codec: &dyn LogicalExtensionCodec,
) -> Result<Sort, Error> {
Ok(Sort::new(
Box::new(parse_required_expr(
sort.expr.as_ref(),
registry,
"expr",
codec,
)?),
parse_required_expr(sort.expr.as_ref(), registry, "expr", codec)?,
sort.asc,
sort.nulls_first,
))
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ where
nulls_first,
} = sort;
Ok(protobuf::SortExprNode {
expr: Some(serialize_expr(expr.as_ref(), codec)?),
expr: Some(serialize_expr(expr, codec)?),
asc: *asc,
nulls_first: *nulls_first,
})
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let func_deps = schema.functional_dependencies();
// Find whether ties are possible in the given ordering
let is_ordering_strict = order_by.iter().find_map(|orderby_expr| {
if let Expr::Column(col) = orderby_expr.expr.as_ref() {
if let Expr::Column(col) = &orderby_expr.expr {
let idx = schema.index_of_column(col).ok()?;
return if func_deps.iter().any(|dep| {
dep.source_indices == vec![idx] && dep.mode == Dependency::Single
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/expr/order_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
};
let asc = asc.unwrap_or(true);
expr_vec.push(Sort::new(
Box::new(expr),
expr,
asc,
// when asc is true, by default nulls last to be consistent with postgres
// postgres rule: https://www.postgresql.org/docs/current/queries-order.html
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1761,7 +1761,7 @@ mod tests {
fun: WindowFunctionDefinition::AggregateUDF(count_udaf()),
args: vec![wildcard()],
partition_by: vec![],
order_by: vec![Sort::new(Box::new(col("a")), false, true)],
order_by: vec![Sort::new(col("a"), false, true)],
window_frame: WindowFrame::new_bounds(
datafusion_expr::WindowFrameUnits::Range,
datafusion_expr::WindowFrameBound::Preceding(
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/unparser/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields(

let mut collects = p.expr.clone();
for sort in &sort.expr {
collects.push(sort.expr.as_ref().clone());
collects.push(sort.expr.clone());
}

// Compare outer collects Expr::to_string with inner collected transformed values
Expand Down
2 changes: 1 addition & 1 deletion datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ pub async fn from_substrait_sorts(
};
let (asc, nulls_first) = asc_nullfirst.unwrap();
sorts.push(Sort {
expr: Box::new(expr),
expr,
asc,
nulls_first,
});
Expand Down
9 changes: 1 addition & 8 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
// under the License.

use itertools::Itertools;
use std::ops::Deref;
use std::sync::Arc;

use arrow_buffer::ToByteSlice;
Expand Down Expand Up @@ -819,13 +818,7 @@ fn to_substrait_sort_field(
(false, false) => SortDirection::DescNullsLast,
};
Ok(SortField {
expr: Some(to_substrait_rex(
ctx,
sort.expr.deref(),
schema,
0,
extensions,
)?),
expr: Some(to_substrait_rex(ctx, &sort.expr, schema, 0, extensions)?),
sort_kind: Some(SortKind::Direction(sort_kind.into())),
})
}
Expand Down

0 comments on commit 03d6167

Please sign in to comment.