Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/udf.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! [`ScalarUDF`]: Scalar User Defined Functions
19
20
use crate::expr::schema_name_from_exprs_comma_seperated_without_space;
21
use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
22
use crate::sort_properties::{ExprProperties, SortProperties};
23
use crate::{
24
    ColumnarValue, Documentation, Expr, ScalarFunctionImplementation, Signature,
25
};
26
use arrow::datatypes::DataType;
27
use datafusion_common::{not_impl_err, ExprSchema, Result};
28
use datafusion_expr_common::interval_arithmetic::Interval;
29
use std::any::Any;
30
use std::cmp::Ordering;
31
use std::fmt::Debug;
32
use std::hash::{DefaultHasher, Hash, Hasher};
33
use std::sync::Arc;
34
35
/// Logical representation of a Scalar User Defined Function.
36
///
37
/// A scalar function produces a single row output for each row of input. This
38
/// struct contains the information DataFusion needs to plan and invoke
39
/// functions you supply, such as name, type signature, return type, and actual
40
/// implementation.
41
///
42
/// 1. For simple use cases, use [`create_udf`] (examples in [`simple_udf.rs`]).
43
///
44
/// 2. For advanced use cases, use [`ScalarUDFImpl`] which provides full API
45
///    access (examples in  [`advanced_udf.rs`]).
46
///
47
/// See [`Self::call`] to invoke a `ScalarUDF` with arguments.
48
///
49
/// # API Note
50
///
51
/// This is a separate struct from `ScalarUDFImpl` to maintain backwards
52
/// compatibility with the older API.
53
///
54
/// [`create_udf`]: crate::expr_fn::create_udf
55
/// [`simple_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
56
/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
57
#[derive(Debug, Clone)]
58
pub struct ScalarUDF {
59
    inner: Arc<dyn ScalarUDFImpl>,
60
}
61
62
impl PartialEq for ScalarUDF {
63
0
    fn eq(&self, other: &Self) -> bool {
64
0
        self.inner.equals(other.inner.as_ref())
65
0
    }
66
}
67
68
// Manual implementation consistent with the default `ScalarUDFImpl::equals`: compares by name, then by signature
69
impl PartialOrd for ScalarUDF {
70
0
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
71
0
        match self.name().partial_cmp(other.name()) {
72
0
            Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
73
0
            cmp => cmp,
74
        }
75
0
    }
76
}
77
78
impl Eq for ScalarUDF {}
79
80
impl Hash for ScalarUDF {
81
0
    fn hash<H: Hasher>(&self, state: &mut H) {
82
0
        self.inner.hash_value().hash(state)
83
0
    }
84
}
85
86
impl ScalarUDF {
87
    /// Create a new `ScalarUDF` from a [`ScalarUDFImpl`] trait object
88
    ///
89
    /// Note this is the same as using the `From` impl (`ScalarUDF::from`)
90
0
    pub fn new_from_impl<F>(fun: F) -> ScalarUDF
91
0
    where
92
0
        F: ScalarUDFImpl + 'static,
93
0
    {
94
0
        Self {
95
0
            inner: Arc::new(fun),
96
0
        }
97
0
    }
98
99
    /// Return the underlying [`ScalarUDFImpl`] trait object for this function
100
0
    pub fn inner(&self) -> &Arc<dyn ScalarUDFImpl> {
101
0
        &self.inner
102
0
    }
103
104
    /// Adds additional names that can be used to invoke this function, in
105
    /// addition to `name`
106
    ///
107
    /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly.
108
0
    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
109
0
        Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases))
110
0
    }
111
112
    /// Returns a [`Expr`] logical expression to call this UDF with specified
113
    /// arguments.
114
    ///
115
    /// This utility allows easily calling UDFs
116
    ///
117
    /// # Example
118
    /// ```no_run
119
    /// use datafusion_expr::{col, lit, ScalarUDF};
120
    /// # fn my_udf() -> ScalarUDF { unimplemented!() }
121
    /// let my_func: ScalarUDF = my_udf();
122
    /// // Create an expr for `my_func(a, 12.3)`
123
    /// let expr = my_func.call(vec![col("a"), lit(12.3)]);
124
    /// ```
125
0
    pub fn call(&self, args: Vec<Expr>) -> Expr {
126
0
        Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
127
0
            Arc::new(self.clone()),
128
0
            args,
129
0
        ))
130
0
    }
131
132
    /// Returns this function's name.
133
    ///
134
    /// See [`ScalarUDFImpl::name`] for more details.
135
0
    pub fn name(&self) -> &str {
136
0
        self.inner.name()
137
0
    }
138
139
    /// Returns this function's display_name.
140
    ///
141
    /// See [`ScalarUDFImpl::display_name`] for more details
142
0
    pub fn display_name(&self, args: &[Expr]) -> Result<String> {
143
0
        self.inner.display_name(args)
144
0
    }
145
146
    /// Returns this function's schema_name.
147
    ///
148
    /// See [`ScalarUDFImpl::schema_name`] for more details
149
0
    pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
150
0
        self.inner.schema_name(args)
151
0
    }
152
153
    /// Returns the aliases for this function.
154
    ///
155
    /// See [`ScalarUDF::with_aliases`] for more details
156
0
    pub fn aliases(&self) -> &[String] {
157
0
        self.inner.aliases()
158
0
    }
159
160
    /// Returns this function's [`Signature`] (what input types are accepted).
161
    ///
162
    /// See [`ScalarUDFImpl::signature`] for more details.
163
0
    pub fn signature(&self) -> &Signature {
164
0
        self.inner.signature()
165
0
    }
166
167
    /// The datatype this function returns given the input argument input types.
168
    /// This function is used when the input arguments are [`Expr`]s.
169
    ///
170
    ///
171
    /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details.
172
0
    pub fn return_type_from_exprs(
173
0
        &self,
174
0
        args: &[Expr],
175
0
        schema: &dyn ExprSchema,
176
0
        arg_types: &[DataType],
177
0
    ) -> Result<DataType> {
178
0
        // If the implementation provides a return_type_from_exprs, use it
179
0
        self.inner.return_type_from_exprs(args, schema, arg_types)
180
0
    }
181
182
    /// Do the function rewrite
183
    ///
184
    /// See [`ScalarUDFImpl::simplify`] for more details.
185
0
    pub fn simplify(
186
0
        &self,
187
0
        args: Vec<Expr>,
188
0
        info: &dyn SimplifyInfo,
189
0
    ) -> Result<ExprSimplifyResult> {
190
0
        self.inner.simplify(args, info)
191
0
    }
192
193
    /// Invoke the function on `args`, returning the appropriate result.
194
    ///
195
    /// See [`ScalarUDFImpl::invoke`] for more details.
196
0
    pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
197
0
        self.inner.invoke(args)
198
0
    }
199
200
0
    pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
201
0
        self.inner.is_nullable(args, schema)
202
0
    }
203
204
    /// Invoke the function without `args` but with the number of rows, returning the appropriate result.
205
    ///
206
    /// See [`ScalarUDFImpl::invoke_no_args`] for more details.
207
0
    pub fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
208
0
        self.inner.invoke_no_args(number_rows)
209
0
    }
210
211
    /// Returns a `ScalarFunctionImplementation` that can invoke the function
212
    /// during execution
213
    #[deprecated(since = "42.0.0", note = "Use `invoke` or `invoke_no_args` instead")]
214
0
    pub fn fun(&self) -> ScalarFunctionImplementation {
215
0
        let captured = Arc::clone(&self.inner);
216
0
        Arc::new(move |args| captured.invoke(args))
217
0
    }
218
219
    /// Returns whether the inner implementation short-circuits; see [`ScalarUDFImpl::short_circuits`].
220
0
    pub fn short_circuits(&self) -> bool {
221
0
        self.inner.short_circuits()
222
0
    }
223
224
    /// Computes the output interval for a [`ScalarUDF`], given the input
225
    /// intervals.
226
    ///
227
    /// # Parameters
228
    ///
229
    /// * `inputs` are the intervals for the inputs (children) of this function.
230
    ///
231
    /// # Example
232
    ///
233
    /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
234
    /// then the output interval would be `[0, 3]`.
235
0
    pub fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
236
0
        self.inner.evaluate_bounds(inputs)
237
0
    }
238
239
    /// Updates bounds for child expressions, given a known interval for this
240
    /// function. This is used to propagate constraints down through an expression
241
    /// tree.
242
    ///
243
    /// # Parameters
244
    ///
245
    /// * `interval` is the currently known interval for this function.
246
    /// * `inputs` are the current intervals for the inputs (children) of this function.
247
    ///
248
    /// # Returns
249
    ///
250
    /// A `Vec` of new intervals for the children, in order.
251
    ///
252
    /// If constraint propagation reveals an infeasibility for any child, returns
253
    /// [`None`]. If none of the children intervals change as a result of
254
    /// propagation, may return an empty vector instead of cloning `inputs`.
255
    /// This is the default (and conservative) return value.
256
    ///
257
    /// # Example
258
    ///
259
    /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
260
    /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
261
0
    pub fn propagate_constraints(
262
0
        &self,
263
0
        interval: &Interval,
264
0
        inputs: &[&Interval],
265
0
    ) -> Result<Option<Vec<Interval>>> {
266
0
        self.inner.propagate_constraints(interval, inputs)
267
0
    }
268
269
    /// Calculates the [`SortProperties`] of this function based on its
270
    /// children's properties.
271
0
    pub fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
272
0
        self.inner.output_ordering(inputs)
273
0
    }
274
275
    /// See [`ScalarUDFImpl::coerce_types`] for more details.
276
0
    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
277
0
        self.inner.coerce_types(arg_types)
278
0
    }
279
280
    /// Returns the documentation for this Scalar UDF.
281
    ///
282
    /// Documentation can be accessed programmatically as well as
283
    /// generating publicly facing documentation.
284
0
    pub fn documentation(&self) -> Option<&Documentation> {
285
0
        self.inner.documentation()
286
0
    }
287
}
288
289
impl<F> From<F> for ScalarUDF
290
where
291
    F: ScalarUDFImpl + Send + Sync + 'static,
292
{
293
0
    fn from(fun: F) -> Self {
294
0
        Self::new_from_impl(fun)
295
0
    }
296
}
297
298
/// Trait for implementing [`ScalarUDF`].
299
///
300
/// This trait exposes the full API for implementing user defined functions and
301
/// can be used to implement any function.
302
///
303
/// See [`advanced_udf.rs`] for a full example with complete implementation and
304
/// [`ScalarUDF`] for other available options.
305
///
306
///
307
/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
308
/// # Basic Example
309
/// ```
310
/// # use std::any::Any;
311
/// # use std::sync::OnceLock;
312
/// # use arrow::datatypes::DataType;
313
/// # use datafusion_common::{DataFusionError, plan_err, Result};
314
/// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility};
315
/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
316
/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
317
///
318
/// #[derive(Debug)]
319
/// struct AddOne {
320
///   signature: Signature,
321
/// }
322
///
323
/// impl AddOne {
324
///   fn new() -> Self {
325
///     Self {
326
///       signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable),
327
///      }
328
///   }
329
/// }
330
///  
331
/// static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
332
///
333
/// fn get_doc() -> &'static Documentation {
334
///     DOCUMENTATION.get_or_init(|| {
335
///         Documentation::builder()
336
///             .with_doc_section(DOC_SECTION_MATH)
337
///             .with_description("Add one to an int32")
338
///             .with_syntax_example("add_one(2)")
339
///             .with_argument("arg1", "The int32 number to add one to")
340
///             .build()
341
///             .unwrap()
342
///     })
343
/// }
344
///
345
/// /// Implement the ScalarUDFImpl trait for AddOne
346
/// impl ScalarUDFImpl for AddOne {
347
///    fn as_any(&self) -> &dyn Any { self }
348
///    fn name(&self) -> &str { "add_one" }
349
///    fn signature(&self) -> &Signature { &self.signature }
350
///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
351
///      if !matches!(args.get(0), Some(&DataType::Int32)) {
352
///        return plan_err!("add_one only accepts Int32 arguments");
353
///      }
354
///      Ok(DataType::Int32)
355
///    }
356
///    // The actual implementation would add one to the argument
357
///    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { unimplemented!() }
358
///    fn documentation(&self) -> Option<&Documentation> {
359
///         Some(get_doc())
360
///     }
361
/// }
362
///
363
/// // Create a new ScalarUDF from the implementation
364
/// let add_one = ScalarUDF::from(AddOne::new());
365
///
366
/// // Call the function `add_one(col)`
367
/// let expr = add_one.call(vec![col("a")]);
368
/// ```
369
pub trait ScalarUDFImpl: Debug + Send + Sync {
370
    // Note: When adding any methods (with default implementations), remember to add them also
371
    // into the AliasedScalarUDFImpl below!
372
373
    /// Returns this object as an [`Any`] trait object
374
    fn as_any(&self) -> &dyn Any;
375
376
    /// Returns this function's name
377
    fn name(&self) -> &str;
378
379
    /// Returns the user-defined display name of the UDF given the arguments
380
0
    fn display_name(&self, args: &[Expr]) -> Result<String> {
381
0
        let names: Vec<String> = args.iter().map(ToString::to_string).collect();
382
0
        // TODO: join with ", " to standardize the formatting of Vec<Expr>, <https://github.com/apache/datafusion/issues/10364>
383
0
        Ok(format!("{}({})", self.name(), names.join(",")))
384
0
    }
385
386
    /// Returns the name of the column this expression would create
387
    ///
388
    /// See [`Expr::schema_name`] for details
389
0
    fn schema_name(&self, args: &[Expr]) -> Result<String> {
390
0
        Ok(format!(
391
0
            "{}({})",
392
0
            self.name(),
393
0
            schema_name_from_exprs_comma_seperated_without_space(args)?
394
        ))
395
0
    }
396
397
    /// Returns the function's [`Signature`] for information about what input
398
    /// types are accepted and the function's Volatility.
399
    fn signature(&self) -> &Signature;
400
401
    /// What [`DataType`] will be returned by this function, given the types of
402
    /// the arguments.
403
    ///
404
    /// # Notes
405
    ///
406
    /// If you provide an implementation for [`Self::return_type_from_exprs`],
407
    /// DataFusion will not call `return_type` (this function). In this case it
408
    /// is recommended to return [`DataFusionError::Internal`].
409
    ///
410
    /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal
411
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
412
413
    /// What [`DataType`] will be returned by this function, given the
414
    /// arguments?
415
    ///
416
    /// Note most UDFs should implement [`Self::return_type`] and not this
417
    /// function. The output type for most functions only depends on the types
418
    /// of their inputs (e.g. `sqrt(f32)` is always `f32`).
419
    ///
420
    /// By default, this function calls [`Self::return_type`] with the
421
    /// types of each argument.
422
    ///
423
    /// This method can be overridden for functions that return different
424
    /// *types* based on the *values* of their arguments.
425
    ///
426
    /// For example, the following two function calls get the same argument
427
    /// types (something and a `Utf8` string) but return different types based
428
    /// on the value of the second argument:
429
    ///
430
    /// * `arrow_cast(x, 'Int16')` --> `Int16`
431
    /// * `arrow_cast(x, 'Float32')` --> `Float32`
432
    ///
433
    /// # Notes:
434
    ///
435
    /// This function must consistently return the same type for the same
436
    /// logical input even if the input is simplified (e.g. it must return the same
437
    /// value for `('foo' | 'bar')` as it does for `('foobar')`).
438
0
    fn return_type_from_exprs(
439
0
        &self,
440
0
        _args: &[Expr],
441
0
        _schema: &dyn ExprSchema,
442
0
        arg_types: &[DataType],
443
0
    ) -> Result<DataType> {
444
0
        self.return_type(arg_types)
445
0
    }
446
447
0
    fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
448
0
        true
449
0
    }
450
451
    /// Invoke the function on `args`, returning the appropriate result
452
    ///
453
    /// The function will be invoked with a slice of [`ColumnarValue`]s
454
    /// (either scalar or array).
455
    ///
456
    /// If the function does not take any arguments, please use [invoke_no_args]
457
    /// instead and return [not_impl_err] for this function.
458
    ///
459
    ///
460
    /// # Performance
461
    ///
462
    /// For the best performance, the implementations of `invoke` should handle
463
    /// the common case when one or more of their arguments are constant values
464
    /// (aka  [`ColumnarValue::Scalar`]).
465
    ///
466
    /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
467
    /// to arrays, which will likely be simpler code, but be slower.
468
    ///
469
    /// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
470
    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue>;
471
472
    /// Invoke the function without `args`; instead, the number of rows is provided,
473
    /// returning the appropriate result.
474
0
    fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
475
0
        not_impl_err!(
476
0
            "Function {} does not implement invoke_no_args but called",
477
0
            self.name()
478
0
        )
479
0
    }
480
481
    /// Returns any aliases (alternate names) for this function.
482
    ///
483
    /// Aliases can be used to invoke the same function using different names.
484
    /// For example in some databases `now()` and `current_timestamp()` are
485
    /// aliases for the same function. This behavior can be obtained by
486
    /// returning `current_timestamp` as an alias for the `now` function.
487
    ///
488
    /// Note: `aliases` should only include names other than [`Self::name`].
489
    /// Defaults to `[]` (no aliases)
490
0
    fn aliases(&self) -> &[String] {
491
0
        &[]
492
0
    }
493
494
    /// Optionally apply per-UDF simplification / rewrite rules.
495
    ///
496
    /// This can be used to apply function specific simplification rules during
497
    /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
498
    /// implementation does nothing.
499
    ///
500
    /// Note that DataFusion handles simplifying arguments and  "constant
501
    /// folding" (replacing a function call with constant arguments such as
502
    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
503
    /// optimizations manually for specific UDFs.
504
    ///
505
    /// # Arguments
506
    /// * `args`: The arguments of the function
507
    /// * `info`: The necessary information for simplification
508
    ///
509
    /// # Returns
510
    /// [`ExprSimplifyResult`] indicating the result of the simplification. NOTE:
511
    /// if the function cannot be simplified, the arguments *MUST* be returned
512
    /// unmodified
513
0
    fn simplify(
514
0
        &self,
515
0
        args: Vec<Expr>,
516
0
        _info: &dyn SimplifyInfo,
517
0
    ) -> Result<ExprSimplifyResult> {
518
0
        Ok(ExprSimplifyResult::Original(args))
519
0
    }
520
521
    /// Returns true if some of this function's argument expressions may not be evaluated
522
    /// and thus any side effects (like divide by zero) may not be encountered.
523
    /// Setting this to true prevents certain optimizations such as common subexpression elimination.
524
0
    fn short_circuits(&self) -> bool {
525
0
        false
526
0
    }
527
528
    /// Computes the output interval for a [`ScalarUDFImpl`], given the input
529
    /// intervals.
530
    ///
531
    /// # Parameters
532
    ///
533
    /// * `input` are the intervals for the inputs (children) of this function.
534
    ///
535
    /// # Example
536
    ///
537
    /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
538
    /// then the output interval would be `[0, 3]`.
539
0
    fn evaluate_bounds(&self, _input: &[&Interval]) -> Result<Interval> {
540
0
        // We cannot assume the input datatype is the same as the output type.
541
0
        Interval::make_unbounded(&DataType::Null)
542
0
    }
543
544
    /// Updates bounds for child expressions, given a known interval for this
545
    /// function. This is used to propagate constraints down through an expression
546
    /// tree.
547
    ///
548
    /// # Parameters
549
    ///
550
    /// * `interval` is the currently known interval for this function.
551
    /// * `inputs` are the current intervals for the inputs (children) of this function.
552
    ///
553
    /// # Returns
554
    ///
555
    /// A `Vec` of new intervals for the children, in order.
556
    ///
557
    /// If constraint propagation reveals an infeasibility for any child, returns
558
    /// [`None`]. If none of the children intervals change as a result of
559
    /// propagation, may return an empty vector instead of cloning `inputs`.
560
    /// This is the default (and conservative) return value.
561
    ///
562
    /// # Example
563
    ///
564
    /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
565
    /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
566
0
    fn propagate_constraints(
567
0
        &self,
568
0
        _interval: &Interval,
569
0
        _inputs: &[&Interval],
570
0
    ) -> Result<Option<Vec<Interval>>> {
571
0
        Ok(Some(vec![]))
572
0
    }
573
574
    /// Calculates the [`SortProperties`] of this function based on its
575
    /// children's properties.
576
0
    fn output_ordering(&self, _inputs: &[ExprProperties]) -> Result<SortProperties> {
577
0
        Ok(SortProperties::Unordered)
578
0
    }
579
580
    /// Coerce arguments of a function call to types that the function can evaluate.
581
    ///
582
    /// This function is only called if [`ScalarUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
583
    /// UDFs should return one of the other variants of `TypeSignature` which handle common
584
    /// cases
585
    ///
586
    /// See the [type coercion module](crate::type_coercion)
587
    /// documentation for more details on type coercion
588
    ///
589
    /// For example, if your function requires floating point arguments, but the user calls
590
    /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]`
591
    /// to ensure the argument is converted to `1::double`
592
    ///
593
    /// # Parameters
594
    /// * `arg_types`: The types of the arguments this function was called with
595
    ///
596
    /// # Return value
597
    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
598
    /// arguments to these specific types.
599
0
    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
600
0
        not_impl_err!("Function {} does not implement coerce_types", self.name())
601
0
    }
602
603
    /// Return true if this scalar UDF is equal to the other.
604
    ///
605
    /// Allows customizing the equality of scalar UDFs.
606
    /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
607
    ///
608
    /// - reflexive: `a.equals(a)`;
609
    /// - symmetric: `a.equals(b)` implies `b.equals(a)`;
610
    /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
611
    ///
612
    /// By default, compares [`Self::name`] and [`Self::signature`].
613
0
    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
614
0
        self.name() == other.name() && self.signature() == other.signature()
615
0
    }
616
617
    /// Returns a hash value for this scalar UDF.
618
    ///
619
    /// Allows customizing the hash code of scalar UDFs. Similarly to [`Hash`] and [`Eq`],
620
    /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
621
    ///
622
    /// By default, hashes [`Self::name`] and [`Self::signature`].
623
0
    fn hash_value(&self) -> u64 {
624
0
        let hasher = &mut DefaultHasher::new();
625
0
        self.name().hash(hasher);
626
0
        self.signature().hash(hasher);
627
0
        hasher.finish()
628
0
    }
629
630
    /// Returns the documentation for this Scalar UDF.
631
    ///
632
    /// Documentation can be accessed programmatically as well as
633
    /// generating publicly facing documentation.
634
0
    fn documentation(&self) -> Option<&Documentation> {
635
0
        None
636
0
    }
637
}
638
639
/// ScalarUDF that adds an alias to the underlying function. It is better to
640
/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
641
#[derive(Debug)]
642
struct AliasedScalarUDFImpl {
643
    inner: Arc<dyn ScalarUDFImpl>,
644
    aliases: Vec<String>,
645
}
646
647
impl AliasedScalarUDFImpl {
648
0
    pub fn new(
649
0
        inner: Arc<dyn ScalarUDFImpl>,
650
0
        new_aliases: impl IntoIterator<Item = &'static str>,
651
0
    ) -> Self {
652
0
        let mut aliases = inner.aliases().to_vec();
653
0
        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
654
0
        Self { inner, aliases }
655
0
    }
656
}
657
658
impl ScalarUDFImpl for AliasedScalarUDFImpl {
659
0
    fn as_any(&self) -> &dyn Any {
660
0
        self
661
0
    }
662
663
0
    fn name(&self) -> &str {
664
0
        self.inner.name()
665
0
    }
666
667
0
    fn display_name(&self, args: &[Expr]) -> Result<String> {
668
0
        self.inner.display_name(args)
669
0
    }
670
671
0
    fn schema_name(&self, args: &[Expr]) -> Result<String> {
672
0
        self.inner.schema_name(args)
673
0
    }
674
675
0
    fn signature(&self) -> &Signature {
676
0
        self.inner.signature()
677
0
    }
678
679
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
680
0
        self.inner.return_type(arg_types)
681
0
    }
682
683
0
    fn aliases(&self) -> &[String] {
684
0
        &self.aliases
685
0
    }
686
687
0
    fn return_type_from_exprs(
688
0
        &self,
689
0
        args: &[Expr],
690
0
        schema: &dyn ExprSchema,
691
0
        arg_types: &[DataType],
692
0
    ) -> Result<DataType> {
693
0
        self.inner.return_type_from_exprs(args, schema, arg_types)
694
0
    }
695
696
0
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
697
0
        self.inner.invoke(args)
698
0
    }
699
700
0
    fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
701
0
        self.inner.invoke_no_args(number_rows)
702
0
    }
703
704
0
    fn simplify(
705
0
        &self,
706
0
        args: Vec<Expr>,
707
0
        info: &dyn SimplifyInfo,
708
0
    ) -> Result<ExprSimplifyResult> {
709
0
        self.inner.simplify(args, info)
710
0
    }
711
712
0
    fn short_circuits(&self) -> bool {
713
0
        self.inner.short_circuits()
714
0
    }
715
716
0
    fn evaluate_bounds(&self, input: &[&Interval]) -> Result<Interval> {
717
0
        self.inner.evaluate_bounds(input)
718
0
    }
719
720
0
    fn propagate_constraints(
721
0
        &self,
722
0
        interval: &Interval,
723
0
        inputs: &[&Interval],
724
0
    ) -> Result<Option<Vec<Interval>>> {
725
0
        self.inner.propagate_constraints(interval, inputs)
726
0
    }
727
728
0
    fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
729
0
        self.inner.output_ordering(inputs)
730
0
    }
731
732
0
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
733
0
        self.inner.coerce_types(arg_types)
734
0
    }
735
736
0
    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
737
0
        if let Some(other) = other.as_any().downcast_ref::<AliasedScalarUDFImpl>() {
738
0
            self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
739
        } else {
740
0
            false
741
        }
742
0
    }
743
744
0
    fn hash_value(&self) -> u64 {
745
0
        let hasher = &mut DefaultHasher::new();
746
0
        self.inner.hash_value().hash(hasher);
747
0
        self.aliases.hash(hasher);
748
0
        hasher.finish()
749
0
    }
750
751
0
    fn documentation(&self) -> Option<&Documentation> {
752
0
        self.inner.documentation()
753
0
    }
754
}
755
756
// Scalar UDF doc sections for use in public documentation
757
pub mod scalar_doc_sections {
758
    use crate::DocSection;
759
760
    pub fn doc_sections() -> Vec<DocSection> {
761
        vec![
762
            DOC_SECTION_MATH,
763
            DOC_SECTION_CONDITIONAL,
764
            DOC_SECTION_STRING,
765
            DOC_SECTION_BINARY_STRING,
766
            DOC_SECTION_REGEX,
767
            DOC_SECTION_DATETIME,
768
            DOC_SECTION_ARRAY,
769
            DOC_SECTION_STRUCT,
770
            DOC_SECTION_MAP,
771
            DOC_SECTION_HASHING,
772
            DOC_SECTION_OTHER,
773
        ]
774
    }
775
776
    pub const DOC_SECTION_MATH: DocSection = DocSection {
777
        include: true,
778
        label: "Math Functions",
779
        description: None,
780
    };
781
782
    pub const DOC_SECTION_CONDITIONAL: DocSection = DocSection {
783
        include: true,
784
        label: "Conditional Functions",
785
        description: None,
786
    };
787
788
    pub const DOC_SECTION_STRING: DocSection = DocSection {
789
        include: true,
790
        label: "String Functions",
791
        description: None,
792
    };
793
794
    pub const DOC_SECTION_BINARY_STRING: DocSection = DocSection {
795
        include: true,
796
        label: "Binary String Functions",
797
        description: None,
798
    };
799
800
    pub const DOC_SECTION_REGEX: DocSection = DocSection {
801
        include: true,
802
        label: "Regular Expression Functions",
803
        description: Some(
804
            r#"Apache DataFusion uses a [PCRE-like](https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions)
805
regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax)
806
(minus support for several features including look-around and backreferences).
807
The following regular expression functions are supported:"#,
808
        ),
809
    };
810
811
    pub const DOC_SECTION_DATETIME: DocSection = DocSection {
812
        include: true,
813
        label: "Time and Date Functions",
814
        description: None,
815
    };
816
817
    pub const DOC_SECTION_ARRAY: DocSection = DocSection {
818
        include: true,
819
        label: "Array Functions",
820
        description: None,
821
    };
822
823
    pub const DOC_SECTION_STRUCT: DocSection = DocSection {
824
        include: true,
825
        label: "Struct Functions",
826
        description: None,
827
    };
828
829
    pub const DOC_SECTION_MAP: DocSection = DocSection {
830
        include: true,
831
        label: "Map Functions",
832
        description: None,
833
    };
834
835
    pub const DOC_SECTION_HASHING: DocSection = DocSection {
836
        include: true,
837
        label: "Hashing Functions",
838
        description: None,
839
    };
840
841
    pub const DOC_SECTION_OTHER: DocSection = DocSection {
842
        include: true,
843
        label: "Other Functions",
844
        description: None,
845
    };
846
}