Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/expr_fn.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Functions for creating logical expressions
19
20
use crate::expr::{
21
    AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22
    Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction,
23
};
24
use crate::function::{
25
    AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
26
    StateFieldsArgs,
27
};
28
use crate::{
29
    conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
30
    AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF,
31
    Signature, Volatility,
32
};
33
use crate::{
34
    AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
35
};
36
use arrow::compute::kernels::cast_utils::{
37
    parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
38
};
39
use arrow::datatypes::{DataType, Field};
40
use datafusion_common::{plan_err, Column, Result, ScalarValue, TableReference};
41
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
42
use sqlparser::ast::NullTreatment;
43
use std::any::Any;
44
use std::fmt::Debug;
45
use std::ops::Not;
46
use std::sync::Arc;
47
48
/// Create a column expression based on a qualified or unqualified column name. Will
49
/// normalize unquoted identifiers according to SQL rules (identifiers will become lowercase).
50
///
51
/// For example:
52
///
53
/// ```rust
54
/// # use datafusion_expr::col;
55
/// let c1 = col("a");
56
/// let c2 = col("A");
57
/// assert_eq!(c1, c2);
58
///
59
/// // note how quoting with double quotes preserves the case
60
/// let c3 = col(r#""A""#);
61
/// assert_ne!(c1, c3);
62
/// ```
63
0
pub fn col(ident: impl Into<Column>) -> Expr {
64
0
    Expr::Column(ident.into())
65
0
}
66
67
/// Create an out reference column which hold a reference that has been resolved to a field
68
/// outside of the current plan.
69
0
pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
70
0
    Expr::OuterReferenceColumn(dt, ident.into())
71
0
}
72
73
/// Create an unqualified column expression from the provided name, without normalizing
74
/// the column.
75
///
76
/// For example:
77
///
78
/// ```rust
79
/// # use datafusion_expr::{col, ident};
80
/// let c1 = ident("A"); // not normalized staying as column 'A'
81
/// let c2 = col("A"); // normalized via SQL rules becoming column 'a'
82
/// assert_ne!(c1, c2);
83
///
84
/// let c3 = col(r#""A""#);
85
/// assert_eq!(c1, c3);
86
///
87
/// let c4 = col("t1.a"); // parses as relation 't1' column 'a'
88
/// let c5 = ident("t1.a"); // parses as column 't1.a'
89
/// assert_ne!(c4, c5);
90
/// ```
91
0
pub fn ident(name: impl Into<String>) -> Expr {
92
0
    Expr::Column(Column::from_name(name))
93
0
}
94
95
/// Create placeholder value that will be filled in (such as `$1`)
96
///
97
/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`]
98
///
99
/// # Example
100
///
101
/// ```rust
102
/// # use datafusion_expr::{placeholder};
103
/// let p = placeholder("$0"); // $0, refers to parameter 1
104
/// assert_eq!(p.to_string(), "$0")
105
/// ```
106
0
pub fn placeholder(id: impl Into<String>) -> Expr {
107
0
    Expr::Placeholder(Placeholder {
108
0
        id: id.into(),
109
0
        data_type: None,
110
0
    })
111
0
}
112
113
/// Create an '*' [`Expr::Wildcard`] expression that matches all columns
114
///
115
/// # Example
116
///
117
/// ```rust
118
/// # use datafusion_expr::{wildcard};
119
/// let p = wildcard();
120
/// assert_eq!(p.to_string(), "*")
121
/// ```
122
0
pub fn wildcard() -> Expr {
123
0
    Expr::Wildcard {
124
0
        qualifier: None,
125
0
        options: WildcardOptions::default(),
126
0
    }
127
0
}
128
129
/// Create an '*' [`Expr::Wildcard`] expression with the wildcard options
130
0
pub fn wildcard_with_options(options: WildcardOptions) -> Expr {
131
0
    Expr::Wildcard {
132
0
        qualifier: None,
133
0
        options,
134
0
    }
135
0
}
136
137
/// Create an 't.*' [`Expr::Wildcard`] expression that matches all columns from a specific table
138
///
139
/// # Example
140
///
141
/// ```rust
142
/// # use datafusion_common::TableReference;
143
/// # use datafusion_expr::{qualified_wildcard};
144
/// let p = qualified_wildcard(TableReference::bare("t"));
145
/// assert_eq!(p.to_string(), "t.*")
146
/// ```
147
0
pub fn qualified_wildcard(qualifier: impl Into<TableReference>) -> Expr {
148
0
    Expr::Wildcard {
149
0
        qualifier: Some(qualifier.into()),
150
0
        options: WildcardOptions::default(),
151
0
    }
152
0
}
153
154
/// Create an 't.*' [`Expr::Wildcard`] expression with the wildcard options
155
0
pub fn qualified_wildcard_with_options(
156
0
    qualifier: impl Into<TableReference>,
157
0
    options: WildcardOptions,
158
0
) -> Expr {
159
0
    Expr::Wildcard {
160
0
        qualifier: Some(qualifier.into()),
161
0
        options,
162
0
    }
163
0
}
164
165
/// Return a new expression `left <op> right`
166
0
pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
167
0
    Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
168
0
}
169
170
/// Return a new expression with a logical AND
171
0
pub fn and(left: Expr, right: Expr) -> Expr {
172
0
    Expr::BinaryExpr(BinaryExpr::new(
173
0
        Box::new(left),
174
0
        Operator::And,
175
0
        Box::new(right),
176
0
    ))
177
0
}
178
179
/// Return a new expression with a logical OR
180
0
pub fn or(left: Expr, right: Expr) -> Expr {
181
0
    Expr::BinaryExpr(BinaryExpr::new(
182
0
        Box::new(left),
183
0
        Operator::Or,
184
0
        Box::new(right),
185
0
    ))
186
0
}
187
188
/// Return a new expression with a logical NOT
189
0
pub fn not(expr: Expr) -> Expr {
190
0
    expr.not()
191
0
}
192
193
/// Return a new expression with bitwise AND
194
0
pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
195
0
    Expr::BinaryExpr(BinaryExpr::new(
196
0
        Box::new(left),
197
0
        Operator::BitwiseAnd,
198
0
        Box::new(right),
199
0
    ))
200
0
}
201
202
/// Return a new expression with bitwise OR
203
0
pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
204
0
    Expr::BinaryExpr(BinaryExpr::new(
205
0
        Box::new(left),
206
0
        Operator::BitwiseOr,
207
0
        Box::new(right),
208
0
    ))
209
0
}
210
211
/// Return a new expression with bitwise XOR
212
0
pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
213
0
    Expr::BinaryExpr(BinaryExpr::new(
214
0
        Box::new(left),
215
0
        Operator::BitwiseXor,
216
0
        Box::new(right),
217
0
    ))
218
0
}
219
220
/// Return a new expression with bitwise SHIFT RIGHT
221
0
pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
222
0
    Expr::BinaryExpr(BinaryExpr::new(
223
0
        Box::new(left),
224
0
        Operator::BitwiseShiftRight,
225
0
        Box::new(right),
226
0
    ))
227
0
}
228
229
/// Return a new expression with bitwise SHIFT LEFT
230
0
pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
231
0
    Expr::BinaryExpr(BinaryExpr::new(
232
0
        Box::new(left),
233
0
        Operator::BitwiseShiftLeft,
234
0
        Box::new(right),
235
0
    ))
236
0
}
237
238
/// Create an in_list expression
239
0
pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
240
0
    Expr::InList(InList::new(Box::new(expr), list, negated))
241
0
}
242
243
/// Create an EXISTS subquery expression
244
0
pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
245
0
    let outer_ref_columns = subquery.all_out_ref_exprs();
246
0
    Expr::Exists(Exists {
247
0
        subquery: Subquery {
248
0
            subquery,
249
0
            outer_ref_columns,
250
0
        },
251
0
        negated: false,
252
0
    })
253
0
}
254
255
/// Create a NOT EXISTS subquery expression
256
0
pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
257
0
    let outer_ref_columns = subquery.all_out_ref_exprs();
258
0
    Expr::Exists(Exists {
259
0
        subquery: Subquery {
260
0
            subquery,
261
0
            outer_ref_columns,
262
0
        },
263
0
        negated: true,
264
0
    })
265
0
}
266
267
/// Create an IN subquery expression
268
0
pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
269
0
    let outer_ref_columns = subquery.all_out_ref_exprs();
270
0
    Expr::InSubquery(InSubquery::new(
271
0
        Box::new(expr),
272
0
        Subquery {
273
0
            subquery,
274
0
            outer_ref_columns,
275
0
        },
276
0
        false,
277
0
    ))
278
0
}
279
280
/// Create a NOT IN subquery expression
281
0
pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
282
0
    let outer_ref_columns = subquery.all_out_ref_exprs();
283
0
    Expr::InSubquery(InSubquery::new(
284
0
        Box::new(expr),
285
0
        Subquery {
286
0
            subquery,
287
0
            outer_ref_columns,
288
0
        },
289
0
        true,
290
0
    ))
291
0
}
292
293
/// Create a scalar subquery expression
294
0
pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
295
0
    let outer_ref_columns = subquery.all_out_ref_exprs();
296
0
    Expr::ScalarSubquery(Subquery {
297
0
        subquery,
298
0
        outer_ref_columns,
299
0
    })
300
0
}
301
302
/// Create a grouping set
303
0
pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
304
0
    Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
305
0
}
306
307
/// Create a grouping set for all combination of `exprs`
308
0
pub fn cube(exprs: Vec<Expr>) -> Expr {
309
0
    Expr::GroupingSet(GroupingSet::Cube(exprs))
310
0
}
311
312
/// Create a grouping set for rollup
313
0
pub fn rollup(exprs: Vec<Expr>) -> Expr {
314
0
    Expr::GroupingSet(GroupingSet::Rollup(exprs))
315
0
}
316
317
/// Create a cast expression
318
0
pub fn cast(expr: Expr, data_type: DataType) -> Expr {
319
0
    Expr::Cast(Cast::new(Box::new(expr), data_type))
320
0
}
321
322
/// Create a try cast expression
323
0
pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
324
0
    Expr::TryCast(TryCast::new(Box::new(expr), data_type))
325
0
}
326
327
/// Create is null expression
328
0
pub fn is_null(expr: Expr) -> Expr {
329
0
    Expr::IsNull(Box::new(expr))
330
0
}
331
332
/// Create is true expression
333
0
pub fn is_true(expr: Expr) -> Expr {
334
0
    Expr::IsTrue(Box::new(expr))
335
0
}
336
337
/// Create is not true expression
338
0
pub fn is_not_true(expr: Expr) -> Expr {
339
0
    Expr::IsNotTrue(Box::new(expr))
340
0
}
341
342
/// Create is false expression
343
0
pub fn is_false(expr: Expr) -> Expr {
344
0
    Expr::IsFalse(Box::new(expr))
345
0
}
346
347
/// Create is not false expression
348
0
pub fn is_not_false(expr: Expr) -> Expr {
349
0
    Expr::IsNotFalse(Box::new(expr))
350
0
}
351
352
/// Create is unknown expression
353
0
pub fn is_unknown(expr: Expr) -> Expr {
354
0
    Expr::IsUnknown(Box::new(expr))
355
0
}
356
357
/// Create is not unknown expression
358
0
pub fn is_not_unknown(expr: Expr) -> Expr {
359
0
    Expr::IsNotUnknown(Box::new(expr))
360
0
}
361
362
/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
363
0
pub fn case(expr: Expr) -> CaseBuilder {
364
0
    CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None)
365
0
}
366
367
/// Create a CASE WHEN statement with boolean WHEN expressions and no base expression.
368
0
pub fn when(when: Expr, then: Expr) -> CaseBuilder {
369
0
    CaseBuilder::new(None, vec![when], vec![then], None)
370
0
}
371
372
/// Create a Unnest expression
373
0
pub fn unnest(expr: Expr) -> Expr {
374
0
    Expr::Unnest(Unnest {
375
0
        expr: Box::new(expr),
376
0
    })
377
0
}
378
379
/// Convenience method to create a new user defined scalar function (UDF) with a
380
/// specific signature and specific return type.
381
///
382
/// Note this function does not expose all available features of [`ScalarUDF`],
383
/// such as
384
///
385
/// * computing return types based on input types
386
/// * multiple [`Signature`]s
387
/// * aliases
388
///
389
/// See [`ScalarUDF`] for details and examples on how to use the full
390
/// functionality.
391
0
pub fn create_udf(
392
0
    name: &str,
393
0
    input_types: Vec<DataType>,
394
0
    return_type: DataType,
395
0
    volatility: Volatility,
396
0
    fun: ScalarFunctionImplementation,
397
0
) -> ScalarUDF {
398
0
    ScalarUDF::from(SimpleScalarUDF::new(
399
0
        name,
400
0
        input_types,
401
0
        return_type,
402
0
        volatility,
403
0
        fun,
404
0
    ))
405
0
}
406
407
/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
408
/// return type.
409
pub struct SimpleScalarUDF {
410
    name: String,
411
    signature: Signature,
412
    return_type: DataType,
413
    fun: ScalarFunctionImplementation,
414
}
415
416
impl Debug for SimpleScalarUDF {
417
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
418
0
        f.debug_struct("ScalarUDF")
419
0
            .field("name", &self.name)
420
0
            .field("signature", &self.signature)
421
0
            .field("fun", &"<FUNC>")
422
0
            .finish()
423
0
    }
424
}
425
426
impl SimpleScalarUDF {
427
    /// Create a new `SimpleScalarUDF` from a name, input types, return type and
428
    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
429
0
    pub fn new(
430
0
        name: impl Into<String>,
431
0
        input_types: Vec<DataType>,
432
0
        return_type: DataType,
433
0
        volatility: Volatility,
434
0
        fun: ScalarFunctionImplementation,
435
0
    ) -> Self {
436
0
        let name = name.into();
437
0
        let signature = Signature::exact(input_types, volatility);
438
0
        Self {
439
0
            name,
440
0
            signature,
441
0
            return_type,
442
0
            fun,
443
0
        }
444
0
    }
445
}
446
447
impl ScalarUDFImpl for SimpleScalarUDF {
448
0
    fn as_any(&self) -> &dyn Any {
449
0
        self
450
0
    }
451
452
0
    fn name(&self) -> &str {
453
0
        &self.name
454
0
    }
455
456
0
    fn signature(&self) -> &Signature {
457
0
        &self.signature
458
0
    }
459
460
0
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
461
0
        Ok(self.return_type.clone())
462
0
    }
463
464
0
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
465
0
        (self.fun)(args)
466
0
    }
467
}
468
469
/// Creates a new UDAF with a specific signature, state type and return type.
470
/// The signature and state type must match the `Accumulator's implementation`.
471
0
pub fn create_udaf(
472
0
    name: &str,
473
0
    input_type: Vec<DataType>,
474
0
    return_type: Arc<DataType>,
475
0
    volatility: Volatility,
476
0
    accumulator: AccumulatorFactoryFunction,
477
0
    state_type: Arc<Vec<DataType>>,
478
0
) -> AggregateUDF {
479
0
    let return_type = Arc::unwrap_or_clone(return_type);
480
0
    let state_type = Arc::unwrap_or_clone(state_type);
481
0
    let state_fields = state_type
482
0
        .into_iter()
483
0
        .enumerate()
484
0
        .map(|(i, t)| Field::new(format!("{i}"), t, true))
485
0
        .collect::<Vec<_>>();
486
0
    AggregateUDF::from(SimpleAggregateUDF::new(
487
0
        name,
488
0
        input_type,
489
0
        return_type,
490
0
        volatility,
491
0
        accumulator,
492
0
        state_fields,
493
0
    ))
494
0
}
495
496
/// Implements [`AggregateUDFImpl`] for functions that have a single signature and
497
/// return type.
498
pub struct SimpleAggregateUDF {
499
    name: String,
500
    signature: Signature,
501
    return_type: DataType,
502
    accumulator: AccumulatorFactoryFunction,
503
    state_fields: Vec<Field>,
504
}
505
506
impl Debug for SimpleAggregateUDF {
507
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
508
0
        f.debug_struct("AggregateUDF")
509
0
            .field("name", &self.name)
510
0
            .field("signature", &self.signature)
511
0
            .field("fun", &"<FUNC>")
512
0
            .finish()
513
0
    }
514
}
515
516
impl SimpleAggregateUDF {
517
    /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and
518
    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
519
0
    pub fn new(
520
0
        name: impl Into<String>,
521
0
        input_type: Vec<DataType>,
522
0
        return_type: DataType,
523
0
        volatility: Volatility,
524
0
        accumulator: AccumulatorFactoryFunction,
525
0
        state_fields: Vec<Field>,
526
0
    ) -> Self {
527
0
        let name = name.into();
528
0
        let signature = Signature::exact(input_type, volatility);
529
0
        Self {
530
0
            name,
531
0
            signature,
532
0
            return_type,
533
0
            accumulator,
534
0
            state_fields,
535
0
        }
536
0
    }
537
538
0
    pub fn new_with_signature(
539
0
        name: impl Into<String>,
540
0
        signature: Signature,
541
0
        return_type: DataType,
542
0
        accumulator: AccumulatorFactoryFunction,
543
0
        state_fields: Vec<Field>,
544
0
    ) -> Self {
545
0
        let name = name.into();
546
0
        Self {
547
0
            name,
548
0
            signature,
549
0
            return_type,
550
0
            accumulator,
551
0
            state_fields,
552
0
        }
553
0
    }
554
}
555
556
impl AggregateUDFImpl for SimpleAggregateUDF {
557
0
    fn as_any(&self) -> &dyn Any {
558
0
        self
559
0
    }
560
561
0
    fn name(&self) -> &str {
562
0
        &self.name
563
0
    }
564
565
0
    fn signature(&self) -> &Signature {
566
0
        &self.signature
567
0
    }
568
569
0
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
570
0
        Ok(self.return_type.clone())
571
0
    }
572
573
0
    fn accumulator(
574
0
        &self,
575
0
        acc_args: AccumulatorArgs,
576
0
    ) -> Result<Box<dyn crate::Accumulator>> {
577
0
        (self.accumulator)(acc_args)
578
0
    }
579
580
0
    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
581
0
        Ok(self.state_fields.clone())
582
0
    }
583
}
584
585
/// Creates a new UDWF with a specific signature, state type and return type.
586
///
587
/// The signature and state type must match the [`PartitionEvaluator`]'s implementation`.
588
///
589
/// [`PartitionEvaluator`]: crate::PartitionEvaluator
590
0
pub fn create_udwf(
591
0
    name: &str,
592
0
    input_type: DataType,
593
0
    return_type: Arc<DataType>,
594
0
    volatility: Volatility,
595
0
    partition_evaluator_factory: PartitionEvaluatorFactory,
596
0
) -> WindowUDF {
597
0
    let return_type = Arc::unwrap_or_clone(return_type);
598
0
    WindowUDF::from(SimpleWindowUDF::new(
599
0
        name,
600
0
        input_type,
601
0
        return_type,
602
0
        volatility,
603
0
        partition_evaluator_factory,
604
0
    ))
605
0
}
606
607
/// Implements [`WindowUDFImpl`] for functions that have a single signature and
608
/// return type.
609
pub struct SimpleWindowUDF {
610
    name: String,
611
    signature: Signature,
612
    return_type: DataType,
613
    partition_evaluator_factory: PartitionEvaluatorFactory,
614
}
615
616
impl Debug for SimpleWindowUDF {
617
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
618
0
        f.debug_struct("WindowUDF")
619
0
            .field("name", &self.name)
620
0
            .field("signature", &self.signature)
621
0
            .field("return_type", &"<func>")
622
0
            .field("partition_evaluator_factory", &"<FUNC>")
623
0
            .finish()
624
0
    }
625
}
626
627
impl SimpleWindowUDF {
628
    /// Create a new `SimpleWindowUDF` from a name, input types, return type and
629
    /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility
630
0
    pub fn new(
631
0
        name: impl Into<String>,
632
0
        input_type: DataType,
633
0
        return_type: DataType,
634
0
        volatility: Volatility,
635
0
        partition_evaluator_factory: PartitionEvaluatorFactory,
636
0
    ) -> Self {
637
0
        let name = name.into();
638
0
        let signature = Signature::exact([input_type].to_vec(), volatility);
639
0
        Self {
640
0
            name,
641
0
            signature,
642
0
            return_type,
643
0
            partition_evaluator_factory,
644
0
        }
645
0
    }
646
}
647
648
impl WindowUDFImpl for SimpleWindowUDF {
649
0
    fn as_any(&self) -> &dyn Any {
650
0
        self
651
0
    }
652
653
0
    fn name(&self) -> &str {
654
0
        &self.name
655
0
    }
656
657
0
    fn signature(&self) -> &Signature {
658
0
        &self.signature
659
0
    }
660
661
0
    fn partition_evaluator(&self) -> Result<Box<dyn crate::PartitionEvaluator>> {
662
0
        (self.partition_evaluator_factory)()
663
0
    }
664
665
0
    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
666
0
        Ok(Field::new(
667
0
            field_args.name(),
668
0
            self.return_type.clone(),
669
0
            true,
670
0
        ))
671
0
    }
672
}
673
674
0
pub fn interval_year_month_lit(value: &str) -> Expr {
675
0
    let interval = parse_interval_year_month(value).ok();
676
0
    Expr::Literal(ScalarValue::IntervalYearMonth(interval))
677
0
}
678
679
0
pub fn interval_datetime_lit(value: &str) -> Expr {
680
0
    let interval = parse_interval_day_time(value).ok();
681
0
    Expr::Literal(ScalarValue::IntervalDayTime(interval))
682
0
}
683
684
0
pub fn interval_month_day_nano_lit(value: &str) -> Expr {
685
0
    let interval = parse_interval_month_day_nano(value).ok();
686
0
    Expr::Literal(ScalarValue::IntervalMonthDayNano(interval))
687
0
}
688
689
/// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
690
///
691
/// Adds methods to [`Expr`] that make it easy to set optional options
692
/// such as `ORDER BY`, `FILTER` and `DISTINCT`
693
///
694
/// # Example
695
/// ```no_run
696
/// # use datafusion_common::Result;
697
/// # use datafusion_expr::test::function_stub::count;
698
/// # use sqlparser::ast::NullTreatment;
699
/// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col};
700
/// # use datafusion_expr::window_function::percent_rank;
701
/// # // first_value is an aggregate function in another crate
702
/// # fn first_value(_arg: Expr) -> Expr {
703
/// unimplemented!() }
704
/// # fn main() -> Result<()> {
705
/// // Create an aggregate count, filtering on column y > 5
706
/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?;
707
///
708
/// // Find the first value in an aggregate sorted by column y
709
/// // equivalent to:
710
/// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)`
711
/// let sort_expr = col("y").sort(true, true);
712
/// let agg = first_value(col("x"))
713
///     .order_by(vec![sort_expr])
714
///     .null_treatment(NullTreatment::IgnoreNulls)
715
///     .build()?;
716
///
717
/// // Create a window expression for percent rank partitioned on column a
718
/// // equivalent to:
719
/// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)`
720
/// let window = percent_rank()
721
///     .partition_by(vec![col("a")])
722
///     .order_by(vec![col("b").sort(true, true)])
723
///     .null_treatment(NullTreatment::IgnoreNulls)
724
///     .build()?;
725
/// #     Ok(())
726
/// # }
727
/// ```
728
pub trait ExprFunctionExt {
729
    /// Add `ORDER BY <order_by>`
730
    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder;
731
    /// Add `FILTER <filter>`
732
    fn filter(self, filter: Expr) -> ExprFuncBuilder;
733
    /// Add `DISTINCT`
734
    fn distinct(self) -> ExprFuncBuilder;
735
    /// Add `RESPECT NULLS` or `IGNORE NULLS`
736
    fn null_treatment(
737
        self,
738
        null_treatment: impl Into<Option<NullTreatment>>,
739
    ) -> ExprFuncBuilder;
740
    /// Add `PARTITION BY`
741
    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder;
742
    /// Add appropriate window frame conditions
743
    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder;
744
}
745
746
#[derive(Debug, Clone)]
747
pub enum ExprFuncKind {
748
    Aggregate(AggregateFunction),
749
    Window(WindowFunction),
750
}
751
752
/// Implementation of [`ExprFunctionExt`].
753
///
754
/// See [`ExprFunctionExt`] for usage and examples
755
#[derive(Debug, Clone)]
756
pub struct ExprFuncBuilder {
757
    fun: Option<ExprFuncKind>,
758
    order_by: Option<Vec<Sort>>,
759
    filter: Option<Expr>,
760
    distinct: bool,
761
    null_treatment: Option<NullTreatment>,
762
    partition_by: Option<Vec<Expr>>,
763
    window_frame: Option<WindowFrame>,
764
}
765
766
impl ExprFuncBuilder {
767
    /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`]
768
0
    fn new(fun: Option<ExprFuncKind>) -> Self {
769
0
        Self {
770
0
            fun,
771
0
            order_by: None,
772
0
            filter: None,
773
0
            distinct: false,
774
0
            null_treatment: None,
775
0
            partition_by: None,
776
0
            window_frame: None,
777
0
        }
778
0
    }
779
780
    /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
781
    ///
782
    /// # Errors:
783
    ///
784
    /// Returns an error if this builder  [`ExprFunctionExt`] was used with an
785
    /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
786
0
    pub fn build(self) -> Result<Expr> {
787
0
        let Self {
788
0
            fun,
789
0
            order_by,
790
0
            filter,
791
0
            distinct,
792
0
            null_treatment,
793
0
            partition_by,
794
0
            window_frame,
795
0
        } = self;
796
797
0
        let Some(fun) = fun else {
798
0
            return plan_err!(
799
0
                "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction"
800
0
            );
801
        };
802
803
0
        let fun_expr = match fun {
804
0
            ExprFuncKind::Aggregate(mut udaf) => {
805
0
                udaf.order_by = order_by;
806
0
                udaf.filter = filter.map(Box::new);
807
0
                udaf.distinct = distinct;
808
0
                udaf.null_treatment = null_treatment;
809
0
                Expr::AggregateFunction(udaf)
810
            }
811
0
            ExprFuncKind::Window(mut udwf) => {
812
0
                let has_order_by = order_by.as_ref().map(|o| !o.is_empty());
813
0
                udwf.order_by = order_by.unwrap_or_default();
814
0
                udwf.partition_by = partition_by.unwrap_or_default();
815
0
                udwf.window_frame =
816
0
                    window_frame.unwrap_or(WindowFrame::new(has_order_by));
817
0
                udwf.null_treatment = null_treatment;
818
0
                Expr::WindowFunction(udwf)
819
            }
820
        };
821
822
0
        Ok(fun_expr)
823
0
    }
824
}
825
826
impl ExprFunctionExt for ExprFuncBuilder {
827
    /// Add `ORDER BY <order_by>`
828
0
    fn order_by(mut self, order_by: Vec<Sort>) -> ExprFuncBuilder {
829
0
        self.order_by = Some(order_by);
830
0
        self
831
0
    }
832
833
    /// Add `FILTER <filter>`
834
0
    fn filter(mut self, filter: Expr) -> ExprFuncBuilder {
835
0
        self.filter = Some(filter);
836
0
        self
837
0
    }
838
839
    /// Add `DISTINCT`
840
0
    fn distinct(mut self) -> ExprFuncBuilder {
841
0
        self.distinct = true;
842
0
        self
843
0
    }
844
845
    /// Add `RESPECT NULLS` or `IGNORE NULLS`
846
0
    fn null_treatment(
847
0
        mut self,
848
0
        null_treatment: impl Into<Option<NullTreatment>>,
849
0
    ) -> ExprFuncBuilder {
850
0
        self.null_treatment = null_treatment.into();
851
0
        self
852
0
    }
853
854
0
    fn partition_by(mut self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
855
0
        self.partition_by = Some(partition_by);
856
0
        self
857
0
    }
858
859
0
    fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder {
860
0
        self.window_frame = Some(window_frame);
861
0
        self
862
0
    }
863
}
864
865
impl ExprFunctionExt for Expr {
866
0
    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder {
867
0
        let mut builder = match self {
868
0
            Expr::AggregateFunction(udaf) => {
869
0
                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
870
            }
871
0
            Expr::WindowFunction(udwf) => {
872
0
                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
873
            }
874
0
            _ => ExprFuncBuilder::new(None),
875
        };
876
0
        if builder.fun.is_some() {
877
0
            builder.order_by = Some(order_by);
878
0
        }
879
0
        builder
880
0
    }
881
0
    fn filter(self, filter: Expr) -> ExprFuncBuilder {
882
0
        match self {
883
0
            Expr::AggregateFunction(udaf) => {
884
0
                let mut builder =
885
0
                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
886
0
                builder.filter = Some(filter);
887
0
                builder
888
            }
889
0
            _ => ExprFuncBuilder::new(None),
890
        }
891
0
    }
892
0
    fn distinct(self) -> ExprFuncBuilder {
893
0
        match self {
894
0
            Expr::AggregateFunction(udaf) => {
895
0
                let mut builder =
896
0
                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
897
0
                builder.distinct = true;
898
0
                builder
899
            }
900
0
            _ => ExprFuncBuilder::new(None),
901
        }
902
0
    }
903
0
    fn null_treatment(
904
0
        self,
905
0
        null_treatment: impl Into<Option<NullTreatment>>,
906
0
    ) -> ExprFuncBuilder {
907
0
        let mut builder = match self {
908
0
            Expr::AggregateFunction(udaf) => {
909
0
                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
910
            }
911
0
            Expr::WindowFunction(udwf) => {
912
0
                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
913
            }
914
0
            _ => ExprFuncBuilder::new(None),
915
        };
916
0
        if builder.fun.is_some() {
917
0
            builder.null_treatment = null_treatment.into();
918
0
        }
919
0
        builder
920
0
    }
921
922
0
    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
923
0
        match self {
924
0
            Expr::WindowFunction(udwf) => {
925
0
                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
926
0
                builder.partition_by = Some(partition_by);
927
0
                builder
928
            }
929
0
            _ => ExprFuncBuilder::new(None),
930
        }
931
0
    }
932
933
0
    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder {
934
0
        match self {
935
0
            Expr::WindowFunction(udwf) => {
936
0
                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
937
0
                builder.window_frame = Some(window_frame);
938
0
                builder
939
            }
940
0
            _ => ExprFuncBuilder::new(None),
941
        }
942
0
    }
943
}
944
945
#[cfg(test)]
mod test {
    use super::*;

    /// `is_null` / `is_not_null` on column expressions display as
    /// `<col> IS NULL` / `<col> IS NOT NULL`.
    #[test]
    fn filter_is_null_and_is_not_null() {
        let null_check = col("col1").is_null();
        let not_null_check = ident("col2").is_not_null();

        assert_eq!(null_check.to_string(), "col1 IS NULL");
        assert_eq!(not_null_check.to_string(), "col2 IS NOT NULL");
    }
}