Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/type_coercion/functions.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use super::binary::{binary_numeric_coercion, comparison_coercion};
19
use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20
use arrow::{
21
    compute::can_cast_types,
22
    datatypes::{DataType, TimeUnit},
23
};
24
use datafusion_common::{
25
    exec_err, internal_datafusion_err, internal_err, plan_err,
26
    utils::{coerced_fixed_size_list_to_list, list_ndims},
27
    Result,
28
};
29
use datafusion_expr_common::signature::{
30
    ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
31
};
32
use std::sync::Arc;
33
34
/// Performs type coercion for scalar function arguments.
35
///
36
/// Returns the data types to which each argument must be coerced to
37
/// match `signature`.
38
///
39
/// For more details on coercion in general, please see the
40
/// [`type_coercion`](crate::type_coercion) module.
41
0
pub fn data_types_with_scalar_udf(
42
0
    current_types: &[DataType],
43
0
    func: &ScalarUDF,
44
0
) -> Result<Vec<DataType>> {
45
0
    let signature = func.signature();
46
0
47
0
    if current_types.is_empty() {
48
0
        if signature.type_signature.supports_zero_argument() {
49
0
            return Ok(vec![]);
50
        } else {
51
0
            return plan_err!("{} does not support zero arguments.", func.name());
52
        }
53
0
    }
54
55
0
    let valid_types =
56
0
        get_valid_types_with_scalar_udf(&signature.type_signature, current_types, func)?;
57
58
0
    if valid_types
59
0
        .iter()
60
0
        .any(|data_type| data_type == current_types)
61
    {
62
0
        return Ok(current_types.to_vec());
63
0
    }
64
0
65
0
    try_coerce_types(valid_types, current_types, &signature.type_signature)
66
0
}
67
68
/// Performs type coercion for aggregate function arguments.
69
///
70
/// Returns the data types to which each argument must be coerced to
71
/// match `signature`.
72
///
73
/// For more details on coercion in general, please see the
74
/// [`type_coercion`](crate::type_coercion) module.
75
0
pub fn data_types_with_aggregate_udf(
76
0
    current_types: &[DataType],
77
0
    func: &AggregateUDF,
78
0
) -> Result<Vec<DataType>> {
79
0
    let signature = func.signature();
80
0
81
0
    if current_types.is_empty() {
82
0
        if signature.type_signature.supports_zero_argument() {
83
0
            return Ok(vec![]);
84
        } else {
85
0
            return plan_err!("{} does not support zero arguments.", func.name());
86
        }
87
0
    }
88
89
0
    let valid_types = get_valid_types_with_aggregate_udf(
90
0
        &signature.type_signature,
91
0
        current_types,
92
0
        func,
93
0
    )?;
94
0
    if valid_types
95
0
        .iter()
96
0
        .any(|data_type| data_type == current_types)
97
    {
98
0
        return Ok(current_types.to_vec());
99
0
    }
100
0
101
0
    try_coerce_types(valid_types, current_types, &signature.type_signature)
102
0
}
103
104
/// Performs type coercion for window function arguments.
105
///
106
/// Returns the data types to which each argument must be coerced to
107
/// match `signature`.
108
///
109
/// For more details on coercion in general, please see the
110
/// [`type_coercion`](crate::type_coercion) module.
111
0
pub fn data_types_with_window_udf(
112
0
    current_types: &[DataType],
113
0
    func: &WindowUDF,
114
0
) -> Result<Vec<DataType>> {
115
0
    let signature = func.signature();
116
0
117
0
    if current_types.is_empty() {
118
0
        if signature.type_signature.supports_zero_argument() {
119
0
            return Ok(vec![]);
120
        } else {
121
0
            return plan_err!("{} does not support zero arguments.", func.name());
122
        }
123
0
    }
124
125
0
    let valid_types =
126
0
        get_valid_types_with_window_udf(&signature.type_signature, current_types, func)?;
127
0
    if valid_types
128
0
        .iter()
129
0
        .any(|data_type| data_type == current_types)
130
    {
131
0
        return Ok(current_types.to_vec());
132
0
    }
133
0
134
0
    try_coerce_types(valid_types, current_types, &signature.type_signature)
135
0
}
136
137
/// Performs type coercion for function arguments.
138
///
139
/// Returns the data types to which each argument must be coerced to
140
/// match `signature`.
141
///
142
/// For more details on coercion in general, please see the
143
/// [`type_coercion`](crate::type_coercion) module.
144
0
pub fn data_types(
145
0
    current_types: &[DataType],
146
0
    signature: &Signature,
147
0
) -> Result<Vec<DataType>> {
148
0
    if current_types.is_empty() {
149
0
        if signature.type_signature.supports_zero_argument() {
150
0
            return Ok(vec![]);
151
        } else {
152
0
            return plan_err!(
153
0
                "signature {:?} does not support zero arguments.",
154
0
                &signature.type_signature
155
0
            );
156
        }
157
0
    }
158
159
0
    let valid_types = get_valid_types(&signature.type_signature, current_types)?;
160
0
    if valid_types
161
0
        .iter()
162
0
        .any(|data_type| data_type == current_types)
163
    {
164
0
        return Ok(current_types.to_vec());
165
0
    }
166
0
167
0
    try_coerce_types(valid_types, current_types, &signature.type_signature)
168
0
}
169
170
0
fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
171
0
    if let TypeSignature::OneOf(signatures) = type_signature {
172
0
        return signatures.iter().all(is_well_supported_signature);
173
0
    }
174
175
0
    matches!(
176
0
        type_signature,
177
        TypeSignature::UserDefined
178
            | TypeSignature::Numeric(_)
179
            | TypeSignature::Coercible(_)
180
            | TypeSignature::Any(_)
181
    )
182
0
}
183
184
0
fn try_coerce_types(
185
0
    valid_types: Vec<Vec<DataType>>,
186
0
    current_types: &[DataType],
187
0
    type_signature: &TypeSignature,
188
0
) -> Result<Vec<DataType>> {
189
0
    let mut valid_types = valid_types;
190
0
191
0
    // Well-supported signature that returns exact valid types.
192
0
    if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
193
        // exact valid types
194
0
        assert_eq!(valid_types.len(), 1);
195
0
        let valid_types = valid_types.swap_remove(0);
196
0
        if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
197
0
            return Ok(t);
198
0
        }
199
    } else {
200
        // Try and coerce the argument types to match the signature, returning the
201
        // coerced types from the first matching signature.
202
0
        for valid_types in valid_types {
203
0
            if let Some(types) = maybe_data_types(&valid_types, current_types) {
204
0
                return Ok(types);
205
0
            }
206
        }
207
    }
208
209
    // none possible -> Error
210
0
    plan_err!(
211
0
        "Coercion from {:?} to the signature {:?} failed.",
212
0
        current_types,
213
0
        type_signature
214
0
    )
215
0
}
216
217
0
fn get_valid_types_with_scalar_udf(
218
0
    signature: &TypeSignature,
219
0
    current_types: &[DataType],
220
0
    func: &ScalarUDF,
221
0
) -> Result<Vec<Vec<DataType>>> {
222
0
    let valid_types = match signature {
223
0
        TypeSignature::UserDefined => match func.coerce_types(current_types) {
224
0
            Ok(coerced_types) => vec![coerced_types],
225
0
            Err(e) => return exec_err!("User-defined coercion failed with {:?}", e),
226
        },
227
0
        TypeSignature::OneOf(signatures) => signatures
228
0
            .iter()
229
0
            .filter_map(|t| get_valid_types_with_scalar_udf(t, current_types, func).ok())
230
0
            .flatten()
231
0
            .collect::<Vec<_>>(),
232
0
        _ => get_valid_types(signature, current_types)?,
233
    };
234
235
0
    Ok(valid_types)
236
0
}
237
238
0
fn get_valid_types_with_aggregate_udf(
239
0
    signature: &TypeSignature,
240
0
    current_types: &[DataType],
241
0
    func: &AggregateUDF,
242
0
) -> Result<Vec<Vec<DataType>>> {
243
0
    let valid_types = match signature {
244
0
        TypeSignature::UserDefined => match func.coerce_types(current_types) {
245
0
            Ok(coerced_types) => vec![coerced_types],
246
0
            Err(e) => return exec_err!("User-defined coercion failed with {:?}", e),
247
        },
248
0
        TypeSignature::OneOf(signatures) => signatures
249
0
            .iter()
250
0
            .filter_map(|t| {
251
0
                get_valid_types_with_aggregate_udf(t, current_types, func).ok()
252
0
            })
253
0
            .flatten()
254
0
            .collect::<Vec<_>>(),
255
0
        _ => get_valid_types(signature, current_types)?,
256
    };
257
258
0
    Ok(valid_types)
259
0
}
260
261
0
fn get_valid_types_with_window_udf(
262
0
    signature: &TypeSignature,
263
0
    current_types: &[DataType],
264
0
    func: &WindowUDF,
265
0
) -> Result<Vec<Vec<DataType>>> {
266
0
    let valid_types = match signature {
267
0
        TypeSignature::UserDefined => match func.coerce_types(current_types) {
268
0
            Ok(coerced_types) => vec![coerced_types],
269
0
            Err(e) => return exec_err!("User-defined coercion failed with {:?}", e),
270
        },
271
0
        TypeSignature::OneOf(signatures) => signatures
272
0
            .iter()
273
0
            .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok())
274
0
            .flatten()
275
0
            .collect::<Vec<_>>(),
276
0
        _ => get_valid_types(signature, current_types)?,
277
    };
278
279
0
    Ok(valid_types)
280
0
}
281
282
/// Returns a Vec of all possible valid argument types for the given signature.
283
0
fn get_valid_types(
284
0
    signature: &TypeSignature,
285
0
    current_types: &[DataType],
286
0
) -> Result<Vec<Vec<DataType>>> {
287
0
    fn array_element_and_optional_index(
288
0
        current_types: &[DataType],
289
0
    ) -> Result<Vec<Vec<DataType>>> {
290
0
        // make sure there's 2 or 3 arguments
291
0
        if !(current_types.len() == 2 || current_types.len() == 3) {
292
0
            return Ok(vec![vec![]]);
293
0
        }
294
0
295
0
        let first_two_types = &current_types[0..2];
296
0
        let mut valid_types = array_append_or_prepend_valid_types(first_two_types, true)?;
297
298
        // Early return if there are only 2 arguments
299
0
        if current_types.len() == 2 {
300
0
            return Ok(valid_types);
301
0
        }
302
0
303
0
        let valid_types_with_index = valid_types
304
0
            .iter()
305
0
            .map(|t| {
306
0
                let mut t = t.clone();
307
0
                t.push(DataType::Int64);
308
0
                t
309
0
            })
310
0
            .collect::<Vec<_>>();
311
0
312
0
        valid_types.extend(valid_types_with_index);
313
0
314
0
        Ok(valid_types)
315
0
    }
316
317
0
    fn array_append_or_prepend_valid_types(
318
0
        current_types: &[DataType],
319
0
        is_append: bool,
320
0
    ) -> Result<Vec<Vec<DataType>>> {
321
0
        if current_types.len() != 2 {
322
0
            return Ok(vec![vec![]]);
323
0
        }
324
325
0
        let (array_type, elem_type) = if is_append {
326
0
            (&current_types[0], &current_types[1])
327
        } else {
328
0
            (&current_types[1], &current_types[0])
329
        };
330
331
        // We follow Postgres on `array_append(Null, T)`, which is not valid.
332
0
        if array_type.eq(&DataType::Null) {
333
0
            return Ok(vec![vec![]]);
334
0
        }
335
0
336
0
        // We need to find the coerced base type, mainly for cases like:
337
0
        // `array_append(List(null), i64)` -> `List(i64)`
338
0
        let array_base_type = datafusion_common::utils::base_type(array_type);
339
0
        let elem_base_type = datafusion_common::utils::base_type(elem_type);
340
0
        let new_base_type = comparison_coercion(&array_base_type, &elem_base_type);
341
342
0
        let new_base_type = new_base_type.ok_or_else(|| {
343
0
            internal_datafusion_err!(
344
0
                "Coercion from {array_base_type:?} to {elem_base_type:?} not supported."
345
0
            )
346
0
        })?;
347
348
0
        let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only(
349
0
            array_type,
350
0
            &new_base_type,
351
0
        );
352
0
353
0
        match new_array_type {
354
0
            DataType::List(ref field)
355
0
            | DataType::LargeList(ref field)
356
0
            | DataType::FixedSizeList(ref field, _) => {
357
0
                let new_elem_type = field.data_type();
358
0
                if is_append {
359
0
                    Ok(vec![vec![new_array_type.clone(), new_elem_type.clone()]])
360
                } else {
361
0
                    Ok(vec![vec![new_elem_type.to_owned(), new_array_type.clone()]])
362
                }
363
            }
364
0
            _ => Ok(vec![vec![]]),
365
        }
366
0
    }
367
0
    fn array(array_type: &DataType) -> Option<DataType> {
368
0
        match array_type {
369
            DataType::List(_)
370
            | DataType::LargeList(_)
371
            | DataType::FixedSizeList(_, _) => {
372
0
                let array_type = coerced_fixed_size_list_to_list(array_type);
373
0
                Some(array_type)
374
            }
375
0
            _ => None,
376
        }
377
0
    }
378
379
0
    let valid_types = match signature {
380
0
        TypeSignature::Variadic(valid_types) => valid_types
381
0
            .iter()
382
0
            .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
383
0
            .collect(),
384
0
        TypeSignature::Numeric(number) => {
385
0
            if *number < 1 {
386
0
                return plan_err!(
387
0
                    "The signature expected at least one argument but received {}",
388
0
                    current_types.len()
389
0
                );
390
0
            }
391
0
            if *number != current_types.len() {
392
0
                return plan_err!(
393
0
                    "The signature expected {} arguments but received {}",
394
0
                    number,
395
0
                    current_types.len()
396
0
                );
397
0
            }
398
0
399
0
            let mut valid_type = current_types.first().unwrap().clone();
400
0
            for t in current_types.iter().skip(1) {
401
0
                if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
402
0
                    valid_type = coerced_type;
403
0
                } else {
404
0
                    return plan_err!(
405
0
                        "{} and {} are not coercible to a common numeric type",
406
0
                        valid_type,
407
0
                        t
408
0
                    );
409
                }
410
            }
411
412
0
            vec![vec![valid_type; *number]]
413
        }
414
0
        TypeSignature::Coercible(target_types) => {
415
0
            if target_types.is_empty() {
416
0
                return plan_err!(
417
0
                    "The signature expected at least one argument but received {}",
418
0
                    current_types.len()
419
0
                );
420
0
            }
421
0
            if target_types.len() != current_types.len() {
422
0
                return plan_err!(
423
0
                    "The signature expected {} arguments but received {}",
424
0
                    target_types.len(),
425
0
                    current_types.len()
426
0
                );
427
0
            }
428
429
0
            for (data_type, target_type) in current_types.iter().zip(target_types.iter())
430
            {
431
0
                if !can_cast_types(data_type, target_type) {
432
0
                    return plan_err!("{data_type} is not coercible to {target_type}");
433
0
                }
434
            }
435
436
0
            vec![target_types.to_owned()]
437
        }
438
0
        TypeSignature::Uniform(number, valid_types) => valid_types
439
0
            .iter()
440
0
            .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
441
0
            .collect(),
442
        TypeSignature::UserDefined => {
443
0
            return internal_err!(
444
0
            "User-defined signature should be handled by function-specific coerce_types."
445
0
        )
446
        }
447
        TypeSignature::VariadicAny => {
448
0
            vec![current_types.to_vec()]
449
        }
450
0
        TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
451
0
        TypeSignature::ArraySignature(ref function_signature) => match function_signature
452
        {
453
            ArrayFunctionSignature::ArrayAndElement => {
454
0
                array_append_or_prepend_valid_types(current_types, true)?
455
            }
456
            ArrayFunctionSignature::ElementAndArray => {
457
0
                array_append_or_prepend_valid_types(current_types, false)?
458
            }
459
            ArrayFunctionSignature::ArrayAndIndex => {
460
0
                if current_types.len() != 2 {
461
0
                    return Ok(vec![vec![]]);
462
0
                }
463
0
                array(&current_types[0]).map_or_else(
464
0
                    || vec![vec![]],
465
0
                    |array_type| vec![vec![array_type, DataType::Int64]],
466
0
                )
467
            }
468
            ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => {
469
0
                array_element_and_optional_index(current_types)?
470
            }
471
            ArrayFunctionSignature::Array => {
472
0
                if current_types.len() != 1 {
473
0
                    return Ok(vec![vec![]]);
474
0
                }
475
0
476
0
                array(&current_types[0])
477
0
                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
478
            }
479
            ArrayFunctionSignature::MapArray => {
480
0
                if current_types.len() != 1 {
481
0
                    return Ok(vec![vec![]]);
482
0
                }
483
0
484
0
                match &current_types[0] {
485
0
                    DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
486
0
                    _ => vec![vec![]],
487
                }
488
            }
489
        },
490
0
        TypeSignature::Any(number) => {
491
0
            if current_types.len() != *number {
492
0
                return plan_err!(
493
0
                    "The function expected {} arguments but received {}",
494
0
                    number,
495
0
                    current_types.len()
496
0
                );
497
0
            }
498
0
            vec![(0..*number).map(|i| current_types[i].clone()).collect()]
499
        }
500
0
        TypeSignature::OneOf(types) => types
501
0
            .iter()
502
0
            .filter_map(|t| get_valid_types(t, current_types).ok())
503
0
            .flatten()
504
0
            .collect::<Vec<_>>(),
505
    };
506
507
0
    Ok(valid_types)
508
0
}
509
510
/// Try to coerce the current argument types to match the given `valid_types`.
511
///
512
/// For example, if a function `func` accepts arguments of  `(int64, int64)`,
513
/// but was called with `(int32, int64)`, this function could match the
514
/// valid_types by coercing the first argument to `int64`, and would return
515
/// `Some([int64, int64])`.
516
0
fn maybe_data_types(
517
0
    valid_types: &[DataType],
518
0
    current_types: &[DataType],
519
0
) -> Option<Vec<DataType>> {
520
0
    if valid_types.len() != current_types.len() {
521
0
        return None;
522
0
    }
523
0
524
0
    let mut new_type = Vec::with_capacity(valid_types.len());
525
0
    for (i, valid_type) in valid_types.iter().enumerate() {
526
0
        let current_type = &current_types[i];
527
0
528
0
        if current_type == valid_type {
529
0
            new_type.push(current_type.clone())
530
        } else {
531
            // attempt to coerce.
532
            // TODO: Replace with `can_cast_types` after failing cases are resolved
533
            // (they need new signature that returns exactly valid types instead of list of possible valid types).
534
0
            if let Some(coerced_type) = coerced_from(valid_type, current_type) {
535
0
                new_type.push(coerced_type)
536
            } else {
537
                // not possible
538
0
                return None;
539
            }
540
        }
541
    }
542
0
    Some(new_type)
543
0
}
544
545
/// Check if the current argument types can be coerced to match the given `valid_types`
546
/// unlike `maybe_data_types`, this function does not coerce the types.
547
/// TODO: I think this function should replace `maybe_data_types` after signature are well-supported.
548
0
fn maybe_data_types_without_coercion(
549
0
    valid_types: &[DataType],
550
0
    current_types: &[DataType],
551
0
) -> Option<Vec<DataType>> {
552
0
    if valid_types.len() != current_types.len() {
553
0
        return None;
554
0
    }
555
0
556
0
    let mut new_type = Vec::with_capacity(valid_types.len());
557
0
    for (i, valid_type) in valid_types.iter().enumerate() {
558
0
        let current_type = &current_types[i];
559
0
560
0
        if current_type == valid_type {
561
0
            new_type.push(current_type.clone())
562
0
        } else if can_cast_types(current_type, valid_type) {
563
            // validate the valid type is castable from the current type
564
0
            new_type.push(valid_type.clone())
565
        } else {
566
0
            return None;
567
        }
568
    }
569
0
    Some(new_type)
570
0
}
571
572
/// Return true if a value of type `type_from` can be coerced
573
/// (losslessly converted) into a value of `type_to`
574
///
575
/// See the module level documentation for more detail on coercion.
576
0
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
577
0
    if type_into == type_from {
578
0
        return true;
579
0
    }
580
0
    if let Some(coerced) = coerced_from(type_into, type_from) {
581
0
        return coerced == *type_into;
582
0
    }
583
0
    false
584
0
}
585
586
/// Find the coerced type for the given `type_into` and `type_from`.
587
/// Returns `None` if coercion is not possible.
588
///
589
/// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32.
590
///
591
/// Unlike [comparison_coercion], the coerced type is usually `wider` for lossless conversion.
592
0
fn coerced_from<'a>(
593
0
    type_into: &'a DataType,
594
0
    type_from: &'a DataType,
595
0
) -> Option<DataType> {
596
    use self::DataType::*;
597
598
    // match Dictionary first
599
0
    match (type_into, type_from) {
600
        // coerced dictionary first
601
0
        (_, Dictionary(_, value_type))
602
0
            if coerced_from(type_into, value_type).is_some() =>
603
0
        {
604
0
            Some(type_into.clone())
605
        }
606
0
        (Dictionary(_, value_type), _)
607
0
            if coerced_from(value_type, type_from).is_some() =>
608
0
        {
609
0
            Some(type_into.clone())
610
        }
611
        // coerced into type_into
612
0
        (Int8, Null | Int8) => Some(type_into.clone()),
613
0
        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
614
0
        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
615
        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
616
0
            Some(type_into.clone())
617
        }
618
0
        (UInt8, Null | UInt8) => Some(type_into.clone()),
619
0
        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
620
0
        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
621
0
        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
622
        (
623
            Float32,
624
            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
625
            | Float32,
626
0
        ) => Some(type_into.clone()),
627
        (
628
            Float64,
629
            Null
630
            | Int8
631
            | Int16
632
            | Int32
633
            | Int64
634
            | UInt8
635
            | UInt16
636
            | UInt32
637
            | UInt64
638
            | Float32
639
            | Float64
640
            | Decimal128(_, _),
641
0
        ) => Some(type_into.clone()),
642
        (
643
            Timestamp(TimeUnit::Nanosecond, None),
644
            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
645
0
        ) => Some(type_into.clone()),
646
0
        (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
647
        // We can go into a Utf8View from a Utf8 or LargeUtf8
648
0
        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
649
        // Any type can be coerced into strings
650
0
        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
651
0
        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
652
653
0
        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),
654
655
        // Only accept list and largelist with the same number of dimensions unless the type is Null.
656
        // List or LargeList with different dimensions should be handled in TypeSignature or other places before this
657
        (List(_) | LargeList(_), _)
658
0
            if datafusion_common::utils::base_type(type_from).eq(&Null)
659
0
                || list_ndims(type_from) == list_ndims(type_into) =>
660
        {
661
0
            Some(type_into.clone())
662
        }
663
        // should be able to coerce wildcard fixed size list to non wildcard fixed size list
664
        (
665
0
            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
666
0
            FixedSizeList(f_from, size_from),
667
0
        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
668
0
            Some(data_type) if &data_type != f_into.data_type() => {
669
0
                let new_field =
670
0
                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
671
0
                Some(FixedSizeList(new_field, *size_from))
672
            }
673
0
            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
674
0
            _ => None,
675
        },
676
0
        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
677
0
            match type_from {
678
0
                Timestamp(_, Some(from_tz)) => {
679
0
                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
680
                }
681
                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
682
                    // In the absence of any other information assume the time zone is "+00" (UTC).
683
0
                    Some(Timestamp(*unit, Some("+00".into())))
684
                }
685
0
                _ => None,
686
            }
687
        }
688
        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
689
0
            Some(type_into.clone())
690
        }
691
0
        _ => None,
692
    }
693
0
}
694
695
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Volatility;
    use arrow::datatypes::Field;

    /// Builds a `FixedSizeList` of `size` whose non-nullable element field is
    /// named "item" and has type `element`.
    fn fixed_size_list(element: DataType, size: i32) -> DataType {
        DataType::FixedSizeList(Arc::new(Field::new("item", element, false)), size)
    }

    #[test]
    fn test_string_conversion() {
        let cases = [
            (DataType::Utf8View, DataType::Utf8, true),
            (DataType::Utf8View, DataType::LargeUtf8, true),
        ];

        for (type_into, type_from, expected) in cases {
            assert_eq!(can_coerce_from(&type_into, &type_from), expected);
        }
    }

    #[test]
    fn test_maybe_data_types() {
        let ns_timestamp = |tz: Option<&str>| {
            DataType::Timestamp(TimeUnit::Nanosecond, tz.map(Into::into))
        };

        // (valid types, current types, expected result)
        let cases = [
            // identical types need no coercion
            (
                vec![DataType::UInt8, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt8, DataType::UInt16]),
            ),
            // UInt8 widens to UInt16
            (
                vec![DataType::UInt16, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt16, DataType::UInt16]),
            ),
            // no arguments trivially match
            (vec![], vec![], Some(vec![])),
            // UInt8 cannot become Boolean
            (
                vec![DataType::Boolean, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                None,
            ),
            // UInt16 widens to UInt32
            (
                vec![DataType::Boolean, DataType::UInt32],
                vec![DataType::Boolean, DataType::UInt16],
                Some(vec![DataType::Boolean, DataType::UInt32]),
            ),
            // Utf8 -> Timestamp, with the "+TZ" wildcard resolved to "+00"
            (
                vec![
                    ns_timestamp(None),
                    ns_timestamp(Some("+TZ")),
                    ns_timestamp(Some("+01")),
                ],
                vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
                Some(vec![
                    ns_timestamp(None),
                    ns_timestamp(Some("+00")),
                    ns_timestamp(Some("+01")),
                ]),
            ),
        ];

        for (valid_types, current_types, expected) in cases {
            assert_eq!(maybe_data_types(&valid_types, &current_types), expected);
        }
    }

    #[test]
    fn test_get_valid_types_one_of() -> Result<()> {
        let signature =
            TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

        // three arguments match neither branch
        let three_args = [DataType::Int32, DataType::Int32, DataType::Int32];
        assert!(get_valid_types(&signature, &three_args)?.is_empty());

        // one or two arguments match exactly one branch, verbatim
        for args in [vec![DataType::Int32, DataType::Int32], vec![DataType::Int32]] {
            let valid_types = get_valid_types(&signature, &args)?;
            assert_eq!(valid_types, vec![args]);
        }

        Ok(())
    }

    #[test]
    fn test_fixed_list_wildcard_coerce() -> Result<()> {
        // a list of size 2 can be coerced to the wildcard size
        let current_types = vec![fixed_size_list(DataType::Int32, 2)];
        let exact_signature = |size: i32| {
            Signature::exact(
                vec![fixed_size_list(DataType::Int32, size)],
                Volatility::Stable,
            )
        };

        let coerced =
            data_types(&current_types, &exact_signature(FIXED_SIZE_LIST_WILDCARD))
                .unwrap();
        assert_eq!(coerced, current_types);

        // make sure it can't coerce to a different size
        assert!(data_types(&current_types, &exact_signature(3)).is_err());

        // make sure it works with the same type.
        let coerced = data_types(&current_types, &exact_signature(2)).unwrap();
        assert_eq!(coerced, current_types);

        Ok(())
    }

    #[test]
    fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
        let type_into = fixed_size_list(
            fixed_size_list(DataType::Int32, FIXED_SIZE_LIST_WILDCARD),
            FIXED_SIZE_LIST_WILDCARD,
        );
        let type_from = fixed_size_list(fixed_size_list(DataType::Int8, 4), 3);
        let expected = fixed_size_list(fixed_size_list(DataType::Int32, 4), 3);

        assert_eq!(coerced_from(&type_into, &type_from), Some(expected));

        Ok(())
    }

    #[test]
    fn test_coerced_from_dictionary() {
        let dictionary =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));

        // Int64 cannot be narrowed into a dictionary of UInt32 values
        assert_eq!(coerced_from(&dictionary, &DataType::Int64), None);

        // the dictionary's UInt32 values widen into Int64
        assert_eq!(
            coerced_from(&DataType::Int64, &dictionary),
            Some(DataType::Int64)
        );
    }
}