/Users/andrewlamb/Software/datafusion/datafusion/expr/src/expr_fn.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Functions for creating logical expressions |
19 | | |
20 | | use crate::expr::{ |
21 | | AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, |
22 | | Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, |
23 | | }; |
24 | | use crate::function::{ |
25 | | AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, |
26 | | StateFieldsArgs, |
27 | | }; |
28 | | use crate::{ |
29 | | conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, |
30 | | AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, |
31 | | Signature, Volatility, |
32 | | }; |
33 | | use crate::{ |
34 | | AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, |
35 | | }; |
36 | | use arrow::compute::kernels::cast_utils::{ |
37 | | parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, |
38 | | }; |
39 | | use arrow::datatypes::{DataType, Field}; |
40 | | use datafusion_common::{plan_err, Column, Result, ScalarValue, TableReference}; |
41 | | use datafusion_functions_window_common::field::WindowUDFFieldArgs; |
42 | | use sqlparser::ast::NullTreatment; |
43 | | use std::any::Any; |
44 | | use std::fmt::Debug; |
45 | | use std::ops::Not; |
46 | | use std::sync::Arc; |
47 | | |
48 | | /// Create a column expression based on a qualified or unqualified column name. Will |
49 | | /// normalize unquoted identifiers according to SQL rules (identifiers will become lowercase). |
50 | | /// |
51 | | /// For example: |
52 | | /// |
53 | | /// ```rust |
54 | | /// # use datafusion_expr::col; |
55 | | /// let c1 = col("a"); |
56 | | /// let c2 = col("A"); |
57 | | /// assert_eq!(c1, c2); |
58 | | /// |
59 | | /// // note how quoting with double quotes preserves the case |
60 | | /// let c3 = col(r#""A""#); |
61 | | /// assert_ne!(c1, c3); |
62 | | /// ``` |
63 | 0 | pub fn col(ident: impl Into<Column>) -> Expr { |
64 | 0 | Expr::Column(ident.into()) |
65 | 0 | } |
66 | | |
67 | | /// Create an out reference column which hold a reference that has been resolved to a field |
68 | | /// outside of the current plan. |
69 | 0 | pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr { |
70 | 0 | Expr::OuterReferenceColumn(dt, ident.into()) |
71 | 0 | } |
72 | | |
73 | | /// Create an unqualified column expression from the provided name, without normalizing |
74 | | /// the column. |
75 | | /// |
76 | | /// For example: |
77 | | /// |
78 | | /// ```rust |
79 | | /// # use datafusion_expr::{col, ident}; |
80 | | /// let c1 = ident("A"); // not normalized staying as column 'A' |
81 | | /// let c2 = col("A"); // normalized via SQL rules becoming column 'a' |
82 | | /// assert_ne!(c1, c2); |
83 | | /// |
84 | | /// let c3 = col(r#""A""#); |
85 | | /// assert_eq!(c1, c3); |
86 | | /// |
87 | | /// let c4 = col("t1.a"); // parses as relation 't1' column 'a' |
88 | | /// let c5 = ident("t1.a"); // parses as column 't1.a' |
89 | | /// assert_ne!(c4, c5); |
90 | | /// ``` |
91 | 0 | pub fn ident(name: impl Into<String>) -> Expr { |
92 | 0 | Expr::Column(Column::from_name(name)) |
93 | 0 | } |
94 | | |
95 | | /// Create placeholder value that will be filled in (such as `$1`) |
96 | | /// |
97 | | /// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`] |
98 | | /// |
99 | | /// # Example |
100 | | /// |
101 | | /// ```rust |
102 | | /// # use datafusion_expr::{placeholder}; |
103 | | /// let p = placeholder("$0"); // $0, refers to parameter 1 |
104 | | /// assert_eq!(p.to_string(), "$0") |
105 | | /// ``` |
106 | 0 | pub fn placeholder(id: impl Into<String>) -> Expr { |
107 | 0 | Expr::Placeholder(Placeholder { |
108 | 0 | id: id.into(), |
109 | 0 | data_type: None, |
110 | 0 | }) |
111 | 0 | } |
112 | | |
113 | | /// Create an '*' [`Expr::Wildcard`] expression that matches all columns |
114 | | /// |
115 | | /// # Example |
116 | | /// |
117 | | /// ```rust |
118 | | /// # use datafusion_expr::{wildcard}; |
119 | | /// let p = wildcard(); |
120 | | /// assert_eq!(p.to_string(), "*") |
121 | | /// ``` |
122 | 0 | pub fn wildcard() -> Expr { |
123 | 0 | Expr::Wildcard { |
124 | 0 | qualifier: None, |
125 | 0 | options: WildcardOptions::default(), |
126 | 0 | } |
127 | 0 | } |
128 | | |
129 | | /// Create an '*' [`Expr::Wildcard`] expression with the wildcard options |
130 | 0 | pub fn wildcard_with_options(options: WildcardOptions) -> Expr { |
131 | 0 | Expr::Wildcard { |
132 | 0 | qualifier: None, |
133 | 0 | options, |
134 | 0 | } |
135 | 0 | } |
136 | | |
137 | | /// Create an 't.*' [`Expr::Wildcard`] expression that matches all columns from a specific table |
138 | | /// |
139 | | /// # Example |
140 | | /// |
141 | | /// ```rust |
142 | | /// # use datafusion_common::TableReference; |
143 | | /// # use datafusion_expr::{qualified_wildcard}; |
144 | | /// let p = qualified_wildcard(TableReference::bare("t")); |
145 | | /// assert_eq!(p.to_string(), "t.*") |
146 | | /// ``` |
147 | 0 | pub fn qualified_wildcard(qualifier: impl Into<TableReference>) -> Expr { |
148 | 0 | Expr::Wildcard { |
149 | 0 | qualifier: Some(qualifier.into()), |
150 | 0 | options: WildcardOptions::default(), |
151 | 0 | } |
152 | 0 | } |
153 | | |
154 | | /// Create an 't.*' [`Expr::Wildcard`] expression with the wildcard options |
155 | 0 | pub fn qualified_wildcard_with_options( |
156 | 0 | qualifier: impl Into<TableReference>, |
157 | 0 | options: WildcardOptions, |
158 | 0 | ) -> Expr { |
159 | 0 | Expr::Wildcard { |
160 | 0 | qualifier: Some(qualifier.into()), |
161 | 0 | options, |
162 | 0 | } |
163 | 0 | } |
164 | | |
165 | | /// Return a new expression `left <op> right` |
166 | 0 | pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr { |
167 | 0 | Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) |
168 | 0 | } |
169 | | |
170 | | /// Return a new expression with a logical AND |
171 | 0 | pub fn and(left: Expr, right: Expr) -> Expr { |
172 | 0 | Expr::BinaryExpr(BinaryExpr::new( |
173 | 0 | Box::new(left), |
174 | 0 | Operator::And, |
175 | 0 | Box::new(right), |
176 | 0 | )) |
177 | 0 | } |
178 | | |
179 | | /// Return a new expression with a logical OR |
180 | 0 | pub fn or(left: Expr, right: Expr) -> Expr { |
181 | 0 | Expr::BinaryExpr(BinaryExpr::new( |
182 | 0 | Box::new(left), |
183 | 0 | Operator::Or, |
184 | 0 | Box::new(right), |
185 | 0 | )) |
186 | 0 | } |
187 | | |
188 | | /// Return a new expression with a logical NOT |
189 | 0 | pub fn not(expr: Expr) -> Expr { |
190 | 0 | expr.not() |
191 | 0 | } |
192 | | |
193 | | /// Return a new expression with bitwise AND |
194 | 0 | pub fn bitwise_and(left: Expr, right: Expr) -> Expr { |
195 | 0 | Expr::BinaryExpr(BinaryExpr::new( |
196 | 0 | Box::new(left), |
197 | 0 | Operator::BitwiseAnd, |
198 | 0 | Box::new(right), |
199 | 0 | )) |
200 | 0 | } |
201 | | |
202 | | /// Return a new expression with bitwise OR |
203 | 0 | pub fn bitwise_or(left: Expr, right: Expr) -> Expr { |
204 | 0 | Expr::BinaryExpr(BinaryExpr::new( |
205 | 0 | Box::new(left), |
206 | 0 | Operator::BitwiseOr, |
207 | 0 | Box::new(right), |
208 | 0 | )) |
209 | 0 | } |
210 | | |
211 | | /// Return a new expression with bitwise XOR |
212 | 0 | pub fn bitwise_xor(left: Expr, right: Expr) -> Expr { |
213 | 0 | Expr::BinaryExpr(BinaryExpr::new( |
214 | 0 | Box::new(left), |
215 | 0 | Operator::BitwiseXor, |
216 | 0 | Box::new(right), |
217 | 0 | )) |
218 | 0 | } |
219 | | |
220 | | /// Return a new expression with bitwise SHIFT RIGHT |
221 | 0 | pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr { |
222 | 0 | Expr::BinaryExpr(BinaryExpr::new( |
223 | 0 | Box::new(left), |
224 | 0 | Operator::BitwiseShiftRight, |
225 | 0 | Box::new(right), |
226 | 0 | )) |
227 | 0 | } |
228 | | |
229 | | /// Return a new expression with bitwise SHIFT LEFT |
230 | 0 | pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { |
231 | 0 | Expr::BinaryExpr(BinaryExpr::new( |
232 | 0 | Box::new(left), |
233 | 0 | Operator::BitwiseShiftLeft, |
234 | 0 | Box::new(right), |
235 | 0 | )) |
236 | 0 | } |
237 | | |
238 | | /// Create an in_list expression |
239 | 0 | pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr { |
240 | 0 | Expr::InList(InList::new(Box::new(expr), list, negated)) |
241 | 0 | } |
242 | | |
243 | | /// Create an EXISTS subquery expression |
244 | 0 | pub fn exists(subquery: Arc<LogicalPlan>) -> Expr { |
245 | 0 | let outer_ref_columns = subquery.all_out_ref_exprs(); |
246 | 0 | Expr::Exists(Exists { |
247 | 0 | subquery: Subquery { |
248 | 0 | subquery, |
249 | 0 | outer_ref_columns, |
250 | 0 | }, |
251 | 0 | negated: false, |
252 | 0 | }) |
253 | 0 | } |
254 | | |
255 | | /// Create a NOT EXISTS subquery expression |
256 | 0 | pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr { |
257 | 0 | let outer_ref_columns = subquery.all_out_ref_exprs(); |
258 | 0 | Expr::Exists(Exists { |
259 | 0 | subquery: Subquery { |
260 | 0 | subquery, |
261 | 0 | outer_ref_columns, |
262 | 0 | }, |
263 | 0 | negated: true, |
264 | 0 | }) |
265 | 0 | } |
266 | | |
267 | | /// Create an IN subquery expression |
268 | 0 | pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr { |
269 | 0 | let outer_ref_columns = subquery.all_out_ref_exprs(); |
270 | 0 | Expr::InSubquery(InSubquery::new( |
271 | 0 | Box::new(expr), |
272 | 0 | Subquery { |
273 | 0 | subquery, |
274 | 0 | outer_ref_columns, |
275 | 0 | }, |
276 | 0 | false, |
277 | 0 | )) |
278 | 0 | } |
279 | | |
280 | | /// Create a NOT IN subquery expression |
281 | 0 | pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr { |
282 | 0 | let outer_ref_columns = subquery.all_out_ref_exprs(); |
283 | 0 | Expr::InSubquery(InSubquery::new( |
284 | 0 | Box::new(expr), |
285 | 0 | Subquery { |
286 | 0 | subquery, |
287 | 0 | outer_ref_columns, |
288 | 0 | }, |
289 | 0 | true, |
290 | 0 | )) |
291 | 0 | } |
292 | | |
293 | | /// Create a scalar subquery expression |
294 | 0 | pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr { |
295 | 0 | let outer_ref_columns = subquery.all_out_ref_exprs(); |
296 | 0 | Expr::ScalarSubquery(Subquery { |
297 | 0 | subquery, |
298 | 0 | outer_ref_columns, |
299 | 0 | }) |
300 | 0 | } |
301 | | |
302 | | /// Create a grouping set |
303 | 0 | pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr { |
304 | 0 | Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) |
305 | 0 | } |
306 | | |
307 | | /// Create a grouping set for all combination of `exprs` |
308 | 0 | pub fn cube(exprs: Vec<Expr>) -> Expr { |
309 | 0 | Expr::GroupingSet(GroupingSet::Cube(exprs)) |
310 | 0 | } |
311 | | |
312 | | /// Create a grouping set for rollup |
313 | 0 | pub fn rollup(exprs: Vec<Expr>) -> Expr { |
314 | 0 | Expr::GroupingSet(GroupingSet::Rollup(exprs)) |
315 | 0 | } |
316 | | |
317 | | /// Create a cast expression |
318 | 0 | pub fn cast(expr: Expr, data_type: DataType) -> Expr { |
319 | 0 | Expr::Cast(Cast::new(Box::new(expr), data_type)) |
320 | 0 | } |
321 | | |
322 | | /// Create a try cast expression |
323 | 0 | pub fn try_cast(expr: Expr, data_type: DataType) -> Expr { |
324 | 0 | Expr::TryCast(TryCast::new(Box::new(expr), data_type)) |
325 | 0 | } |
326 | | |
327 | | /// Create is null expression |
328 | 0 | pub fn is_null(expr: Expr) -> Expr { |
329 | 0 | Expr::IsNull(Box::new(expr)) |
330 | 0 | } |
331 | | |
332 | | /// Create is true expression |
333 | 0 | pub fn is_true(expr: Expr) -> Expr { |
334 | 0 | Expr::IsTrue(Box::new(expr)) |
335 | 0 | } |
336 | | |
337 | | /// Create is not true expression |
338 | 0 | pub fn is_not_true(expr: Expr) -> Expr { |
339 | 0 | Expr::IsNotTrue(Box::new(expr)) |
340 | 0 | } |
341 | | |
342 | | /// Create is false expression |
343 | 0 | pub fn is_false(expr: Expr) -> Expr { |
344 | 0 | Expr::IsFalse(Box::new(expr)) |
345 | 0 | } |
346 | | |
347 | | /// Create is not false expression |
348 | 0 | pub fn is_not_false(expr: Expr) -> Expr { |
349 | 0 | Expr::IsNotFalse(Box::new(expr)) |
350 | 0 | } |
351 | | |
352 | | /// Create is unknown expression |
353 | 0 | pub fn is_unknown(expr: Expr) -> Expr { |
354 | 0 | Expr::IsUnknown(Box::new(expr)) |
355 | 0 | } |
356 | | |
357 | | /// Create is not unknown expression |
358 | 0 | pub fn is_not_unknown(expr: Expr) -> Expr { |
359 | 0 | Expr::IsNotUnknown(Box::new(expr)) |
360 | 0 | } |
361 | | |
362 | | /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. |
363 | 0 | pub fn case(expr: Expr) -> CaseBuilder { |
364 | 0 | CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None) |
365 | 0 | } |
366 | | |
367 | | /// Create a CASE WHEN statement with boolean WHEN expressions and no base expression. |
368 | 0 | pub fn when(when: Expr, then: Expr) -> CaseBuilder { |
369 | 0 | CaseBuilder::new(None, vec![when], vec![then], None) |
370 | 0 | } |
371 | | |
372 | | /// Create a Unnest expression |
373 | 0 | pub fn unnest(expr: Expr) -> Expr { |
374 | 0 | Expr::Unnest(Unnest { |
375 | 0 | expr: Box::new(expr), |
376 | 0 | }) |
377 | 0 | } |
378 | | |
379 | | /// Convenience method to create a new user defined scalar function (UDF) with a |
380 | | /// specific signature and specific return type. |
381 | | /// |
382 | | /// Note this function does not expose all available features of [`ScalarUDF`], |
383 | | /// such as |
384 | | /// |
385 | | /// * computing return types based on input types |
386 | | /// * multiple [`Signature`]s |
387 | | /// * aliases |
388 | | /// |
389 | | /// See [`ScalarUDF`] for details and examples on how to use the full |
390 | | /// functionality. |
391 | 0 | pub fn create_udf( |
392 | 0 | name: &str, |
393 | 0 | input_types: Vec<DataType>, |
394 | 0 | return_type: DataType, |
395 | 0 | volatility: Volatility, |
396 | 0 | fun: ScalarFunctionImplementation, |
397 | 0 | ) -> ScalarUDF { |
398 | 0 | ScalarUDF::from(SimpleScalarUDF::new( |
399 | 0 | name, |
400 | 0 | input_types, |
401 | 0 | return_type, |
402 | 0 | volatility, |
403 | 0 | fun, |
404 | 0 | )) |
405 | 0 | } |
406 | | |
407 | | /// Implements [`ScalarUDFImpl`] for functions that have a single signature and |
408 | | /// return type. |
409 | | pub struct SimpleScalarUDF { |
410 | | name: String, |
411 | | signature: Signature, |
412 | | return_type: DataType, |
413 | | fun: ScalarFunctionImplementation, |
414 | | } |
415 | | |
416 | | impl Debug for SimpleScalarUDF { |
417 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
418 | 0 | f.debug_struct("ScalarUDF") |
419 | 0 | .field("name", &self.name) |
420 | 0 | .field("signature", &self.signature) |
421 | 0 | .field("fun", &"<FUNC>") |
422 | 0 | .finish() |
423 | 0 | } |
424 | | } |
425 | | |
426 | | impl SimpleScalarUDF { |
427 | | /// Create a new `SimpleScalarUDF` from a name, input types, return type and |
428 | | /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility |
429 | 0 | pub fn new( |
430 | 0 | name: impl Into<String>, |
431 | 0 | input_types: Vec<DataType>, |
432 | 0 | return_type: DataType, |
433 | 0 | volatility: Volatility, |
434 | 0 | fun: ScalarFunctionImplementation, |
435 | 0 | ) -> Self { |
436 | 0 | let name = name.into(); |
437 | 0 | let signature = Signature::exact(input_types, volatility); |
438 | 0 | Self { |
439 | 0 | name, |
440 | 0 | signature, |
441 | 0 | return_type, |
442 | 0 | fun, |
443 | 0 | } |
444 | 0 | } |
445 | | } |
446 | | |
447 | | impl ScalarUDFImpl for SimpleScalarUDF { |
448 | 0 | fn as_any(&self) -> &dyn Any { |
449 | 0 | self |
450 | 0 | } |
451 | | |
452 | 0 | fn name(&self) -> &str { |
453 | 0 | &self.name |
454 | 0 | } |
455 | | |
456 | 0 | fn signature(&self) -> &Signature { |
457 | 0 | &self.signature |
458 | 0 | } |
459 | | |
460 | 0 | fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
461 | 0 | Ok(self.return_type.clone()) |
462 | 0 | } |
463 | | |
464 | 0 | fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { |
465 | 0 | (self.fun)(args) |
466 | 0 | } |
467 | | } |
468 | | |
469 | | /// Creates a new UDAF with a specific signature, state type and return type. |
470 | | /// The signature and state type must match the `Accumulator's implementation`. |
471 | 0 | pub fn create_udaf( |
472 | 0 | name: &str, |
473 | 0 | input_type: Vec<DataType>, |
474 | 0 | return_type: Arc<DataType>, |
475 | 0 | volatility: Volatility, |
476 | 0 | accumulator: AccumulatorFactoryFunction, |
477 | 0 | state_type: Arc<Vec<DataType>>, |
478 | 0 | ) -> AggregateUDF { |
479 | 0 | let return_type = Arc::unwrap_or_clone(return_type); |
480 | 0 | let state_type = Arc::unwrap_or_clone(state_type); |
481 | 0 | let state_fields = state_type |
482 | 0 | .into_iter() |
483 | 0 | .enumerate() |
484 | 0 | .map(|(i, t)| Field::new(format!("{i}"), t, true)) |
485 | 0 | .collect::<Vec<_>>(); |
486 | 0 | AggregateUDF::from(SimpleAggregateUDF::new( |
487 | 0 | name, |
488 | 0 | input_type, |
489 | 0 | return_type, |
490 | 0 | volatility, |
491 | 0 | accumulator, |
492 | 0 | state_fields, |
493 | 0 | )) |
494 | 0 | } |
495 | | |
496 | | /// Implements [`AggregateUDFImpl`] for functions that have a single signature and |
497 | | /// return type. |
498 | | pub struct SimpleAggregateUDF { |
499 | | name: String, |
500 | | signature: Signature, |
501 | | return_type: DataType, |
502 | | accumulator: AccumulatorFactoryFunction, |
503 | | state_fields: Vec<Field>, |
504 | | } |
505 | | |
506 | | impl Debug for SimpleAggregateUDF { |
507 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
508 | 0 | f.debug_struct("AggregateUDF") |
509 | 0 | .field("name", &self.name) |
510 | 0 | .field("signature", &self.signature) |
511 | 0 | .field("fun", &"<FUNC>") |
512 | 0 | .finish() |
513 | 0 | } |
514 | | } |
515 | | |
516 | | impl SimpleAggregateUDF { |
517 | | /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and |
518 | | /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility |
519 | 0 | pub fn new( |
520 | 0 | name: impl Into<String>, |
521 | 0 | input_type: Vec<DataType>, |
522 | 0 | return_type: DataType, |
523 | 0 | volatility: Volatility, |
524 | 0 | accumulator: AccumulatorFactoryFunction, |
525 | 0 | state_fields: Vec<Field>, |
526 | 0 | ) -> Self { |
527 | 0 | let name = name.into(); |
528 | 0 | let signature = Signature::exact(input_type, volatility); |
529 | 0 | Self { |
530 | 0 | name, |
531 | 0 | signature, |
532 | 0 | return_type, |
533 | 0 | accumulator, |
534 | 0 | state_fields, |
535 | 0 | } |
536 | 0 | } |
537 | | |
538 | 0 | pub fn new_with_signature( |
539 | 0 | name: impl Into<String>, |
540 | 0 | signature: Signature, |
541 | 0 | return_type: DataType, |
542 | 0 | accumulator: AccumulatorFactoryFunction, |
543 | 0 | state_fields: Vec<Field>, |
544 | 0 | ) -> Self { |
545 | 0 | let name = name.into(); |
546 | 0 | Self { |
547 | 0 | name, |
548 | 0 | signature, |
549 | 0 | return_type, |
550 | 0 | accumulator, |
551 | 0 | state_fields, |
552 | 0 | } |
553 | 0 | } |
554 | | } |
555 | | |
556 | | impl AggregateUDFImpl for SimpleAggregateUDF { |
557 | 0 | fn as_any(&self) -> &dyn Any { |
558 | 0 | self |
559 | 0 | } |
560 | | |
561 | 0 | fn name(&self) -> &str { |
562 | 0 | &self.name |
563 | 0 | } |
564 | | |
565 | 0 | fn signature(&self) -> &Signature { |
566 | 0 | &self.signature |
567 | 0 | } |
568 | | |
569 | 0 | fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
570 | 0 | Ok(self.return_type.clone()) |
571 | 0 | } |
572 | | |
573 | 0 | fn accumulator( |
574 | 0 | &self, |
575 | 0 | acc_args: AccumulatorArgs, |
576 | 0 | ) -> Result<Box<dyn crate::Accumulator>> { |
577 | 0 | (self.accumulator)(acc_args) |
578 | 0 | } |
579 | | |
580 | 0 | fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> { |
581 | 0 | Ok(self.state_fields.clone()) |
582 | 0 | } |
583 | | } |
584 | | |
585 | | /// Creates a new UDWF with a specific signature, state type and return type. |
586 | | /// |
587 | | /// The signature and state type must match the [`PartitionEvaluator`]'s implementation`. |
588 | | /// |
589 | | /// [`PartitionEvaluator`]: crate::PartitionEvaluator |
590 | 0 | pub fn create_udwf( |
591 | 0 | name: &str, |
592 | 0 | input_type: DataType, |
593 | 0 | return_type: Arc<DataType>, |
594 | 0 | volatility: Volatility, |
595 | 0 | partition_evaluator_factory: PartitionEvaluatorFactory, |
596 | 0 | ) -> WindowUDF { |
597 | 0 | let return_type = Arc::unwrap_or_clone(return_type); |
598 | 0 | WindowUDF::from(SimpleWindowUDF::new( |
599 | 0 | name, |
600 | 0 | input_type, |
601 | 0 | return_type, |
602 | 0 | volatility, |
603 | 0 | partition_evaluator_factory, |
604 | 0 | )) |
605 | 0 | } |
606 | | |
607 | | /// Implements [`WindowUDFImpl`] for functions that have a single signature and |
608 | | /// return type. |
609 | | pub struct SimpleWindowUDF { |
610 | | name: String, |
611 | | signature: Signature, |
612 | | return_type: DataType, |
613 | | partition_evaluator_factory: PartitionEvaluatorFactory, |
614 | | } |
615 | | |
616 | | impl Debug for SimpleWindowUDF { |
617 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
618 | 0 | f.debug_struct("WindowUDF") |
619 | 0 | .field("name", &self.name) |
620 | 0 | .field("signature", &self.signature) |
621 | 0 | .field("return_type", &"<func>") |
622 | 0 | .field("partition_evaluator_factory", &"<FUNC>") |
623 | 0 | .finish() |
624 | 0 | } |
625 | | } |
626 | | |
627 | | impl SimpleWindowUDF { |
628 | | /// Create a new `SimpleWindowUDF` from a name, input types, return type and |
629 | | /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility |
630 | 0 | pub fn new( |
631 | 0 | name: impl Into<String>, |
632 | 0 | input_type: DataType, |
633 | 0 | return_type: DataType, |
634 | 0 | volatility: Volatility, |
635 | 0 | partition_evaluator_factory: PartitionEvaluatorFactory, |
636 | 0 | ) -> Self { |
637 | 0 | let name = name.into(); |
638 | 0 | let signature = Signature::exact([input_type].to_vec(), volatility); |
639 | 0 | Self { |
640 | 0 | name, |
641 | 0 | signature, |
642 | 0 | return_type, |
643 | 0 | partition_evaluator_factory, |
644 | 0 | } |
645 | 0 | } |
646 | | } |
647 | | |
648 | | impl WindowUDFImpl for SimpleWindowUDF { |
649 | 0 | fn as_any(&self) -> &dyn Any { |
650 | 0 | self |
651 | 0 | } |
652 | | |
653 | 0 | fn name(&self) -> &str { |
654 | 0 | &self.name |
655 | 0 | } |
656 | | |
657 | 0 | fn signature(&self) -> &Signature { |
658 | 0 | &self.signature |
659 | 0 | } |
660 | | |
661 | 0 | fn partition_evaluator(&self) -> Result<Box<dyn crate::PartitionEvaluator>> { |
662 | 0 | (self.partition_evaluator_factory)() |
663 | 0 | } |
664 | | |
665 | 0 | fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> { |
666 | 0 | Ok(Field::new( |
667 | 0 | field_args.name(), |
668 | 0 | self.return_type.clone(), |
669 | 0 | true, |
670 | 0 | )) |
671 | 0 | } |
672 | | } |
673 | | |
674 | 0 | pub fn interval_year_month_lit(value: &str) -> Expr { |
675 | 0 | let interval = parse_interval_year_month(value).ok(); |
676 | 0 | Expr::Literal(ScalarValue::IntervalYearMonth(interval)) |
677 | 0 | } |
678 | | |
679 | 0 | pub fn interval_datetime_lit(value: &str) -> Expr { |
680 | 0 | let interval = parse_interval_day_time(value).ok(); |
681 | 0 | Expr::Literal(ScalarValue::IntervalDayTime(interval)) |
682 | 0 | } |
683 | | |
684 | 0 | pub fn interval_month_day_nano_lit(value: &str) -> Expr { |
685 | 0 | let interval = parse_interval_month_day_nano(value).ok(); |
686 | 0 | Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) |
687 | 0 | } |
688 | | |
689 | | /// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] |
690 | | /// |
691 | | /// Adds methods to [`Expr`] that make it easy to set optional options |
692 | | /// such as `ORDER BY`, `FILTER` and `DISTINCT` |
693 | | /// |
694 | | /// # Example |
695 | | /// ```no_run |
696 | | /// # use datafusion_common::Result; |
697 | | /// # use datafusion_expr::test::function_stub::count; |
698 | | /// # use sqlparser::ast::NullTreatment; |
699 | | /// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col}; |
700 | | /// # use datafusion_expr::window_function::percent_rank; |
701 | | /// # // first_value is an aggregate function in another crate |
702 | | /// # fn first_value(_arg: Expr) -> Expr { |
703 | | /// unimplemented!() } |
704 | | /// # fn main() -> Result<()> { |
705 | | /// // Create an aggregate count, filtering on column y > 5 |
706 | | /// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?; |
707 | | /// |
708 | | /// // Find the first value in an aggregate sorted by column y |
709 | | /// // equivalent to: |
710 | | /// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)` |
711 | | /// let sort_expr = col("y").sort(true, true); |
712 | | /// let agg = first_value(col("x")) |
713 | | /// .order_by(vec![sort_expr]) |
714 | | /// .null_treatment(NullTreatment::IgnoreNulls) |
715 | | /// .build()?; |
716 | | /// |
717 | | /// // Create a window expression for percent rank partitioned on column a |
718 | | /// // equivalent to: |
719 | | /// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)` |
720 | | /// let window = percent_rank() |
721 | | /// .partition_by(vec![col("a")]) |
722 | | /// .order_by(vec![col("b").sort(true, true)]) |
723 | | /// .null_treatment(NullTreatment::IgnoreNulls) |
724 | | /// .build()?; |
725 | | /// # Ok(()) |
726 | | /// # } |
727 | | /// ``` |
728 | | pub trait ExprFunctionExt { |
729 | | /// Add `ORDER BY <order_by>` |
730 | | fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder; |
731 | | /// Add `FILTER <filter>` |
732 | | fn filter(self, filter: Expr) -> ExprFuncBuilder; |
733 | | /// Add `DISTINCT` |
734 | | fn distinct(self) -> ExprFuncBuilder; |
735 | | /// Add `RESPECT NULLS` or `IGNORE NULLS` |
736 | | fn null_treatment( |
737 | | self, |
738 | | null_treatment: impl Into<Option<NullTreatment>>, |
739 | | ) -> ExprFuncBuilder; |
740 | | /// Add `PARTITION BY` |
741 | | fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder; |
742 | | /// Add appropriate window frame conditions |
743 | | fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder; |
744 | | } |
745 | | |
746 | | #[derive(Debug, Clone)] |
747 | | pub enum ExprFuncKind { |
748 | | Aggregate(AggregateFunction), |
749 | | Window(WindowFunction), |
750 | | } |
751 | | |
752 | | /// Implementation of [`ExprFunctionExt`]. |
753 | | /// |
754 | | /// See [`ExprFunctionExt`] for usage and examples |
755 | | #[derive(Debug, Clone)] |
756 | | pub struct ExprFuncBuilder { |
757 | | fun: Option<ExprFuncKind>, |
758 | | order_by: Option<Vec<Sort>>, |
759 | | filter: Option<Expr>, |
760 | | distinct: bool, |
761 | | null_treatment: Option<NullTreatment>, |
762 | | partition_by: Option<Vec<Expr>>, |
763 | | window_frame: Option<WindowFrame>, |
764 | | } |
765 | | |
766 | | impl ExprFuncBuilder { |
767 | | /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`] |
768 | 0 | fn new(fun: Option<ExprFuncKind>) -> Self { |
769 | 0 | Self { |
770 | 0 | fun, |
771 | 0 | order_by: None, |
772 | 0 | filter: None, |
773 | 0 | distinct: false, |
774 | 0 | null_treatment: None, |
775 | 0 | partition_by: None, |
776 | 0 | window_frame: None, |
777 | 0 | } |
778 | 0 | } |
779 | | |
780 | | /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] |
781 | | /// |
782 | | /// # Errors: |
783 | | /// |
784 | | /// Returns an error if this builder [`ExprFunctionExt`] was used with an |
785 | | /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] |
786 | 0 | pub fn build(self) -> Result<Expr> { |
787 | 0 | let Self { |
788 | 0 | fun, |
789 | 0 | order_by, |
790 | 0 | filter, |
791 | 0 | distinct, |
792 | 0 | null_treatment, |
793 | 0 | partition_by, |
794 | 0 | window_frame, |
795 | 0 | } = self; |
796 | | |
797 | 0 | let Some(fun) = fun else { |
798 | 0 | return plan_err!( |
799 | 0 | "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction" |
800 | 0 | ); |
801 | | }; |
802 | | |
803 | 0 | let fun_expr = match fun { |
804 | 0 | ExprFuncKind::Aggregate(mut udaf) => { |
805 | 0 | udaf.order_by = order_by; |
806 | 0 | udaf.filter = filter.map(Box::new); |
807 | 0 | udaf.distinct = distinct; |
808 | 0 | udaf.null_treatment = null_treatment; |
809 | 0 | Expr::AggregateFunction(udaf) |
810 | | } |
811 | 0 | ExprFuncKind::Window(mut udwf) => { |
812 | 0 | let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); |
813 | 0 | udwf.order_by = order_by.unwrap_or_default(); |
814 | 0 | udwf.partition_by = partition_by.unwrap_or_default(); |
815 | 0 | udwf.window_frame = |
816 | 0 | window_frame.unwrap_or(WindowFrame::new(has_order_by)); |
817 | 0 | udwf.null_treatment = null_treatment; |
818 | 0 | Expr::WindowFunction(udwf) |
819 | | } |
820 | | }; |
821 | | |
822 | 0 | Ok(fun_expr) |
823 | 0 | } |
824 | | } |
825 | | |
826 | | impl ExprFunctionExt for ExprFuncBuilder { |
827 | | /// Add `ORDER BY <order_by>` |
828 | 0 | fn order_by(mut self, order_by: Vec<Sort>) -> ExprFuncBuilder { |
829 | 0 | self.order_by = Some(order_by); |
830 | 0 | self |
831 | 0 | } |
832 | | |
833 | | /// Add `FILTER <filter>` |
834 | 0 | fn filter(mut self, filter: Expr) -> ExprFuncBuilder { |
835 | 0 | self.filter = Some(filter); |
836 | 0 | self |
837 | 0 | } |
838 | | |
839 | | /// Add `DISTINCT` |
840 | 0 | fn distinct(mut self) -> ExprFuncBuilder { |
841 | 0 | self.distinct = true; |
842 | 0 | self |
843 | 0 | } |
844 | | |
845 | | /// Add `RESPECT NULLS` or `IGNORE NULLS` |
846 | 0 | fn null_treatment( |
847 | 0 | mut self, |
848 | 0 | null_treatment: impl Into<Option<NullTreatment>>, |
849 | 0 | ) -> ExprFuncBuilder { |
850 | 0 | self.null_treatment = null_treatment.into(); |
851 | 0 | self |
852 | 0 | } |
853 | | |
854 | 0 | fn partition_by(mut self, partition_by: Vec<Expr>) -> ExprFuncBuilder { |
855 | 0 | self.partition_by = Some(partition_by); |
856 | 0 | self |
857 | 0 | } |
858 | | |
859 | 0 | fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { |
860 | 0 | self.window_frame = Some(window_frame); |
861 | 0 | self |
862 | 0 | } |
863 | | } |
864 | | |
865 | | impl ExprFunctionExt for Expr { |
866 | 0 | fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder { |
867 | 0 | let mut builder = match self { |
868 | 0 | Expr::AggregateFunction(udaf) => { |
869 | 0 | ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) |
870 | | } |
871 | 0 | Expr::WindowFunction(udwf) => { |
872 | 0 | ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) |
873 | | } |
874 | 0 | _ => ExprFuncBuilder::new(None), |
875 | | }; |
876 | 0 | if builder.fun.is_some() { |
877 | 0 | builder.order_by = Some(order_by); |
878 | 0 | } |
879 | 0 | builder |
880 | 0 | } |
881 | 0 | fn filter(self, filter: Expr) -> ExprFuncBuilder { |
882 | 0 | match self { |
883 | 0 | Expr::AggregateFunction(udaf) => { |
884 | 0 | let mut builder = |
885 | 0 | ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); |
886 | 0 | builder.filter = Some(filter); |
887 | 0 | builder |
888 | | } |
889 | 0 | _ => ExprFuncBuilder::new(None), |
890 | | } |
891 | 0 | } |
892 | 0 | fn distinct(self) -> ExprFuncBuilder { |
893 | 0 | match self { |
894 | 0 | Expr::AggregateFunction(udaf) => { |
895 | 0 | let mut builder = |
896 | 0 | ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); |
897 | 0 | builder.distinct = true; |
898 | 0 | builder |
899 | | } |
900 | 0 | _ => ExprFuncBuilder::new(None), |
901 | | } |
902 | 0 | } |
903 | 0 | fn null_treatment( |
904 | 0 | self, |
905 | 0 | null_treatment: impl Into<Option<NullTreatment>>, |
906 | 0 | ) -> ExprFuncBuilder { |
907 | 0 | let mut builder = match self { |
908 | 0 | Expr::AggregateFunction(udaf) => { |
909 | 0 | ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) |
910 | | } |
911 | 0 | Expr::WindowFunction(udwf) => { |
912 | 0 | ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) |
913 | | } |
914 | 0 | _ => ExprFuncBuilder::new(None), |
915 | | }; |
916 | 0 | if builder.fun.is_some() { |
917 | 0 | builder.null_treatment = null_treatment.into(); |
918 | 0 | } |
919 | 0 | builder |
920 | 0 | } |
921 | | |
922 | 0 | fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder { |
923 | 0 | match self { |
924 | 0 | Expr::WindowFunction(udwf) => { |
925 | 0 | let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); |
926 | 0 | builder.partition_by = Some(partition_by); |
927 | 0 | builder |
928 | | } |
929 | 0 | _ => ExprFuncBuilder::new(None), |
930 | | } |
931 | 0 | } |
932 | | |
933 | 0 | fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { |
934 | 0 | match self { |
935 | 0 | Expr::WindowFunction(udwf) => { |
936 | 0 | let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); |
937 | 0 | builder.window_frame = Some(window_frame); |
938 | 0 | builder |
939 | | } |
940 | 0 | _ => ExprFuncBuilder::new(None), |
941 | | } |
942 | 0 | } |
943 | | } |
944 | | |
945 | | #[cfg(test)] |
946 | | mod test { |
947 | | use super::*; |
948 | | |
949 | | #[test] |
950 | | fn filter_is_null_and_is_not_null() { |
951 | | let col_null = col("col1"); |
952 | | let col_not_null = ident("col2"); |
953 | | assert_eq!(format!("{}", col_null.is_null()), "col1 IS NULL"); |
954 | | assert_eq!( |
955 | | format!("{}", col_not_null.is_not_null()), |
956 | | "col2 IS NOT NULL" |
957 | | ); |
958 | | } |
959 | | } |