Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr-common/src/type_coercion/binary.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Coercion rules for matching argument types for binary operators
19
20
use std::collections::HashSet;
21
use std::sync::Arc;
22
23
use crate::operator::Operator;
24
25
use arrow::array::{new_empty_array, Array};
26
use arrow::compute::can_cast_types;
27
use arrow::datatypes::{
28
    DataType, Field, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
29
    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
30
};
31
use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result};
32
33
/// The type signature of an instantiation of binary operator expression such as
34
/// `lhs + rhs`
35
///
36
/// Note this is different than [`crate::signature::Signature`] which
37
/// describes the type signature of a function.
38
struct Signature {
39
    /// The type to coerce the left argument to
40
    lhs: DataType,
41
    /// The type to coerce the right argument to
42
    rhs: DataType,
43
    /// The return type of the expression
44
    ret: DataType,
45
}
46
47
impl Signature {
48
    /// A signature where the inputs are the same type as the output
49
23.2k
    fn uniform(t: DataType) -> Self {
50
23.2k
        Self {
51
23.2k
            lhs: t.clone(),
52
23.2k
            rhs: t.clone(),
53
23.2k
            ret: t,
54
23.2k
        }
55
23.2k
    }
56
57
    /// A signature where the inputs are the same type with a boolean output
58
48.7k
    fn comparison(t: DataType) -> Self {
59
48.7k
        Self {
60
48.7k
            lhs: t.clone(),
61
48.7k
            rhs: t,
62
48.7k
            ret: DataType::Boolean,
63
48.7k
        }
64
48.7k
    }
65
}
66
67
/// Returns a [`Signature`] for applying `op` to arguments of type `lhs` and `rhs`
68
214k
fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result<Signature> {
69
    use arrow::datatypes::DataType::*;
70
    use Operator::*;
71
214k
    match op {
72
        Eq |
73
        NotEq |
74
        Lt |
75
        LtEq |
76
        Gt |
77
        GtEq |
78
        IsDistinctFrom |
79
        IsNotDistinctFrom => {
80
48.7k
            comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
81
0
                plan_datafusion_err!(
82
0
                    "Cannot infer common argument type for comparison operation {lhs} {op} {rhs}"
83
0
                )
84
48.7k
            })
85
        }
86
23.2k
        And | Or => if 
matches!0
((lhs, rhs), (Boolean | Null, Boolean | Null)) {
87
            // Logical binary boolean operators can only be evaluated for
88
            // boolean or null arguments.                   
89
23.2k
            Ok(Signature::uniform(DataType::Boolean))
90
        } else {
91
0
            plan_err!(
92
0
                "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}"
93
0
            )
94
        }
95
        RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => {
96
0
            regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
97
0
                plan_datafusion_err!(
98
0
                    "Cannot infer common argument type for regex operation {lhs} {op} {rhs}"
99
0
                )
100
0
            })
101
        }
102
        LikeMatch | ILikeMatch | NotLikeMatch | NotILikeMatch => {
103
0
            regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
104
0
                plan_datafusion_err!(
105
0
                    "Cannot infer common argument type for regex operation {lhs} {op} {rhs}"
106
0
                )
107
0
            })
108
        }
109
        BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => {
110
0
            bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
111
0
                plan_datafusion_err!(
112
0
                    "Cannot infer common type for bitwise operation {lhs} {op} {rhs}"
113
0
                )
114
0
            })
115
        }
116
        StringConcat => {
117
0
            string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
118
0
                plan_datafusion_err!(
119
0
                    "Cannot infer common string type for string concat operation {lhs} {op} {rhs}"
120
0
                )
121
0
            })
122
        }
123
        AtArrow | ArrowAt => {
124
            // ArrowAt and AtArrow check for whether one array is contained in another.
125
            // The result type is boolean. Signature::comparison defines this signature.
126
            // Operation has nothing to do with comparison
127
0
            array_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
128
0
                plan_datafusion_err!(
129
0
                    "Cannot infer common array type for arrow operation {lhs} {op} {rhs}"
130
0
                )
131
0
            })
132
        }
133
        Plus | Minus | Multiply | Divide | Modulo =>  {
134
142k
            let get_result = |lhs, rhs| {
135
                use arrow::compute::kernels::numeric::*;
136
142k
                let l = new_empty_array(lhs);
137
142k
                let r = new_empty_array(rhs);
138
139
142k
                let result = match op {
140
43.5k
                    Plus => add_wrapping(&l, &r),
141
74.6k
                    Minus => sub_wrapping(&l, &r),
142
0
                    Multiply => mul_wrapping(&l, &r),
143
0
                    Divide => div(&l, &r),
144
24.2k
                    Modulo => rem(&l, &r),
145
0
                    _ => unreachable!(),
146
                };
147
142k
                result.map(|x| x.data_type().clone())
148
142k
            };
149
150
142k
            if let Ok(ret) = get_result(lhs, rhs) {
151
                // Temporal arithmetic, e.g. Date32 + Interval
152
142k
                Ok(Signature{
153
142k
                    lhs: lhs.clone(),
154
142k
                    rhs: rhs.clone(),
155
142k
                    ret,
156
142k
                })
157
0
            } else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) {
158
                // Temporal arithmetic by first coercing to a common time representation
159
                // e.g. Date32 - Timestamp
160
0
                let ret = get_result(&coerced, &coerced).map_err(|e| {
161
0
                    plan_datafusion_err!(
162
0
                        "Cannot get result type for temporal operation {coerced} {op} {coerced}: {e}"
163
0
                    )
164
0
                })?;
165
0
                Ok(Signature{
166
0
                    lhs: coerced.clone(),
167
0
                    rhs: coerced,
168
0
                    ret,
169
0
                })
170
0
            } else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) {
171
                // Decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0)
172
0
                let ret = get_result(&lhs, &rhs).map_err(|e| {
173
0
                    plan_datafusion_err!(
174
0
                        "Cannot get result type for decimal operation {lhs} {op} {rhs}: {e}"
175
0
                    )
176
0
                })?;
177
0
                Ok(Signature{
178
0
                    lhs,
179
0
                    rhs,
180
0
                    ret,
181
0
                })
182
0
            } else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) {
183
                // Numeric arithmetic, e.g. Int32 + Int32
184
0
                Ok(Signature::uniform(numeric))
185
            } else {
186
0
                plan_err!(
187
0
                    "Cannot coerce arithmetic expression {lhs} {op} {rhs} to valid types"
188
0
                )
189
            }
190
        }
191
    }
192
214k
}
193
194
/// returns the resulting type of a binary expression evaluating the `op` with the left and right hand types
195
214k
pub fn get_result_type(
196
214k
    lhs: &DataType,
197
214k
    op: &Operator,
198
214k
    rhs: &DataType,
199
214k
) -> Result<DataType> {
200
214k
    signature(lhs, op, rhs).map(|sig| sig.ret)
201
214k
}
202
203
/// Returns the coerced input types for a binary expression evaluating the `op` with the left and right hand types
204
0
pub fn get_input_types(
205
0
    lhs: &DataType,
206
0
    op: &Operator,
207
0
    rhs: &DataType,
208
0
) -> Result<(DataType, DataType)> {
209
0
    signature(lhs, op, rhs).map(|sig| (sig.lhs, sig.rhs))
210
0
}
211
212
/// Coercion rules for mathematics operators between decimal and non-decimal types.
213
0
fn math_decimal_coercion(
214
0
    lhs_type: &DataType,
215
0
    rhs_type: &DataType,
216
0
) -> Option<(DataType, DataType)> {
217
    use arrow::datatypes::DataType::*;
218
219
0
    match (lhs_type, rhs_type) {
220
0
        (Dictionary(_, value_type), _) => {
221
0
            let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type)?;
222
0
            Some((value_type, rhs_type))
223
        }
224
0
        (_, Dictionary(_, value_type)) => {
225
0
            let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?;
226
0
            Some((lhs_type, value_type))
227
        }
228
0
        (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => {
229
0
            Some((dec_type.clone(), dec_type.clone()))
230
        }
231
        (Decimal128(_, _), Decimal128(_, _)) | (Decimal256(_, _), Decimal256(_, _)) => {
232
0
            Some((lhs_type.clone(), rhs_type.clone()))
233
        }
234
        // Unlike with comparison we don't coerce to a decimal in the case of floating point
235
        // numbers, instead falling back to floating point arithmetic instead
236
        (Decimal128(_, _), Int8 | Int16 | Int32 | Int64) => {
237
0
            Some((lhs_type.clone(), coerce_numeric_type_to_decimal(rhs_type)?))
238
        }
239
        (Int8 | Int16 | Int32 | Int64, Decimal128(_, _)) => {
240
0
            Some((coerce_numeric_type_to_decimal(lhs_type)?, rhs_type.clone()))
241
        }
242
        (Decimal256(_, _), Int8 | Int16 | Int32 | Int64) => Some((
243
0
            lhs_type.clone(),
244
0
            coerce_numeric_type_to_decimal256(rhs_type)?,
245
        )),
246
        (Int8 | Int16 | Int32 | Int64, Decimal256(_, _)) => Some((
247
0
            coerce_numeric_type_to_decimal256(lhs_type)?,
248
0
            rhs_type.clone(),
249
        )),
250
0
        _ => None,
251
    }
252
0
}
253
254
/// Returns the output type of applying bitwise operations such as
255
/// `&`, `|`, or `xor`to arguments of `lhs_type` and `rhs_type`.
256
0
fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataType> {
257
    use arrow::datatypes::DataType::*;
258
259
0
    if !both_numeric_or_null_and_numeric(left_type, right_type) {
260
0
        return None;
261
0
    }
262
0
263
0
    if left_type == right_type {
264
0
        return Some(left_type.clone());
265
0
    }
266
0
267
0
    match (left_type, right_type) {
268
0
        (UInt64, _) | (_, UInt64) => Some(UInt64),
269
        (Int64, _)
270
        | (_, Int64)
271
        | (UInt32, Int8)
272
        | (Int8, UInt32)
273
        | (UInt32, Int16)
274
        | (Int16, UInt32)
275
        | (UInt32, Int32)
276
0
        | (Int32, UInt32) => Some(Int64),
277
        (Int32, _)
278
        | (_, Int32)
279
        | (UInt16, Int16)
280
        | (Int16, UInt16)
281
        | (UInt16, Int8)
282
0
        | (Int8, UInt16) => Some(Int32),
283
0
        (UInt32, _) | (_, UInt32) => Some(UInt32),
284
0
        (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
285
0
        (UInt16, _) | (_, UInt16) => Some(UInt16),
286
0
        (Int8, _) | (_, Int8) => Some(Int8),
287
0
        (UInt8, _) | (_, UInt8) => Some(UInt8),
288
0
        _ => None,
289
    }
290
0
}
291
292
/// Coarse grouping of [`DataType`]s, derived via `From<&DataType>` and used
/// by [`type_union_resolution`] to reject inputs that mix unrelated kinds of
/// types.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
enum TypeCategory {
    /// `List`, `FixedSizeList` and `LargeList`
    Array,
    /// `Boolean`
    Boolean,
    /// Any type for which `DataType::is_numeric` returns true
    Numeric,
    // Strings are a well-defined type, but they are deliberately not given a
    // category of their own: they are classified as `Unknown`.
    /// Dates, times, timestamps, intervals and durations
    DateTime,
    /// `Map`, `Struct` and `Union`
    Composite,
    /// `Utf8`, `LargeUtf8` and `Null`
    Unknown,
    /// Every type that falls in none of the categories above
    NotSupported,
}
303
304
impl From<&DataType> for TypeCategory {
305
0
    fn from(data_type: &DataType) -> Self {
306
0
        match data_type {
307
            // Dict is a special type in arrow, we check the value type
308
0
            DataType::Dictionary(_, v) => {
309
0
                let v = v.as_ref();
310
0
                TypeCategory::from(v)
311
            }
312
            _ => {
313
0
                if data_type.is_numeric() {
314
0
                    return TypeCategory::Numeric;
315
0
                }
316
317
0
                if matches!(data_type, DataType::Boolean) {
318
0
                    return TypeCategory::Boolean;
319
0
                }
320
321
0
                if matches!(
322
0
                    data_type,
323
                    DataType::List(_)
324
                        | DataType::FixedSizeList(_, _)
325
                        | DataType::LargeList(_)
326
                ) {
327
0
                    return TypeCategory::Array;
328
0
                }
329
330
                // String literal is possible to cast to many other types like numeric or datetime,
331
                // therefore, it is categorized as a unknown type
332
0
                if matches!(
333
0
                    data_type,
334
                    DataType::Utf8 | DataType::LargeUtf8 | DataType::Null
335
                ) {
336
0
                    return TypeCategory::Unknown;
337
0
                }
338
339
0
                if matches!(
340
0
                    data_type,
341
                    DataType::Date32
342
                        | DataType::Date64
343
                        | DataType::Time32(_)
344
                        | DataType::Time64(_)
345
                        | DataType::Timestamp(_, _)
346
                        | DataType::Interval(_)
347
                        | DataType::Duration(_)
348
                ) {
349
0
                    return TypeCategory::DateTime;
350
0
                }
351
352
0
                if matches!(
353
0
                    data_type,
354
                    DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _)
355
                ) {
356
0
                    return TypeCategory::Composite;
357
0
                }
358
0
359
0
                TypeCategory::NotSupported
360
            }
361
        }
362
0
    }
363
}
364
365
/// Coerce dissimilar data types to a single data type.
366
/// UNION, INTERSECT, EXCEPT, CASE, ARRAY, VALUES, and the GREATEST and LEAST functions are
367
/// examples that has the similar resolution rules.
368
/// See <https://www.postgresql.org/docs/current/typeconv-union-case.html> for more information.
369
/// The rules in the document provide a clue, but adhering strictly to them doesn't precisely
370
/// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules
371
/// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted
372
/// decimal precision and scale when coercing decimal types.
373
0
pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> {
374
0
    if data_types.is_empty() {
375
0
        return None;
376
0
    }
377
0
378
0
    // if all the data_types is the same return first one
379
0
    if data_types.iter().all(|t| t == &data_types[0]) {
380
0
        return Some(data_types[0].clone());
381
0
    }
382
0
383
0
    // if all the data_types are null, return string
384
0
    if data_types.iter().all(|t| t == &DataType::Null) {
385
0
        return Some(DataType::Utf8);
386
0
    }
387
0
388
0
    // Ignore Nulls, if any data_type category is not the same, return None
389
0
    let data_types_category: Vec<TypeCategory> = data_types
390
0
        .iter()
391
0
        .filter(|&t| t != &DataType::Null)
392
0
        .map(|t| t.into())
393
0
        .collect();
394
0
395
0
    if data_types_category
396
0
        .iter()
397
0
        .any(|t| t == &TypeCategory::NotSupported)
398
    {
399
0
        return None;
400
0
    }
401
0
402
0
    // check if there is only one category excluding Unknown
403
0
    let categories: HashSet<TypeCategory> = HashSet::from_iter(
404
0
        data_types_category
405
0
            .iter()
406
0
            .filter(|&c| c != &TypeCategory::Unknown)
407
0
            .cloned(),
408
0
    );
409
0
    if categories.len() > 1 {
410
0
        return None;
411
0
    }
412
0
413
0
    // Ignore Nulls
414
0
    let mut candidate_type: Option<DataType> = None;
415
0
    for data_type in data_types.iter() {
416
0
        if data_type == &DataType::Null {
417
0
            continue;
418
0
        }
419
0
        if let Some(ref candidate_t) = candidate_type {
420
            // Find candidate type that all the data types can be coerced to
421
            // Follows the behavior of Postgres and DuckDB
422
            // Coerced type may be different from the candidate and current data type
423
            // For example,
424
            //  i64 and decimal(7, 2) are expect to get coerced type decimal(22, 2)
425
            //  numeric string ('1') and numeric (2) are expect to get coerced type numeric (1, 2)
426
0
            if let Some(t) = type_union_resolution_coercion(data_type, candidate_t) {
427
0
                candidate_type = Some(t);
428
0
            } else {
429
0
                return None;
430
            }
431
0
        } else {
432
0
            candidate_type = Some(data_type.clone());
433
0
        }
434
    }
435
436
0
    candidate_type
437
0
}
438
439
/// Coerce `lhs_type` and `rhs_type` to a common type for [type_union_resolution]
440
/// See [type_union_resolution] for more information.
441
0
fn type_union_resolution_coercion(
442
0
    lhs_type: &DataType,
443
0
    rhs_type: &DataType,
444
0
) -> Option<DataType> {
445
0
    if lhs_type == rhs_type {
446
0
        return Some(lhs_type.clone());
447
0
    }
448
0
449
0
    match (lhs_type, rhs_type) {
450
        (
451
0
            DataType::Dictionary(lhs_index_type, lhs_value_type),
452
0
            DataType::Dictionary(rhs_index_type, rhs_value_type),
453
0
        ) => {
454
0
            let new_index_type =
455
0
                type_union_resolution_coercion(lhs_index_type, rhs_index_type);
456
0
            let new_value_type =
457
0
                type_union_resolution_coercion(lhs_value_type, rhs_value_type);
458
0
            if let (Some(new_index_type), Some(new_value_type)) =
459
0
                (new_index_type, new_value_type)
460
            {
461
0
                Some(DataType::Dictionary(
462
0
                    Box::new(new_index_type),
463
0
                    Box::new(new_value_type),
464
0
                ))
465
            } else {
466
0
                None
467
            }
468
        }
469
0
        (DataType::Dictionary(index_type, value_type), other_type)
470
0
        | (other_type, DataType::Dictionary(index_type, value_type)) => {
471
0
            let new_value_type = type_union_resolution_coercion(value_type, other_type);
472
0
            new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t)))
473
        }
474
0
        (DataType::List(lhs), DataType::List(rhs)) => {
475
0
            let new_item_type =
476
0
                type_union_resolution_coercion(lhs.data_type(), rhs.data_type());
477
0
            new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true))))
478
        }
479
        _ => {
480
            // numeric coercion is the same as comparison coercion, both find the narrowest type
481
            // that can accommodate both types
482
0
            binary_numeric_coercion(lhs_type, rhs_type)
483
0
                .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type))
484
0
                .or_else(|| string_coercion(lhs_type, rhs_type))
485
0
                .or_else(|| numeric_string_coercion(lhs_type, rhs_type))
486
        }
487
    }
488
0
}
489
490
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a
491
/// comparison operation
492
///
493
/// Example comparison operations are `lhs = rhs` and `lhs > rhs`
494
///
495
/// Binary comparison kernels require the two arguments to be the (exact) same
496
/// data type. However, users can write queries where the two arguments are
497
/// different data types. In such cases, the data types are automatically cast
498
/// (coerced) to a single data type to pass to the kernels.
499
48.7k
pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
500
48.7k
    if lhs_type == rhs_type {
501
        // same type => equality is possible
502
48.7k
        return Some(lhs_type.clone());
503
1
    }
504
1
    binary_numeric_coercion(lhs_type, rhs_type)
505
1
        .or_else(|| 
dictionary_comparison_coercion(lhs_type, rhs_type, true)0
)
506
1
        .or_else(|| 
temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)0
)
507
1
        .or_else(|| 
string_coercion(lhs_type, rhs_type)0
)
508
1
        .or_else(|| 
list_coercion(lhs_type, rhs_type)0
)
509
1
        .or_else(|| 
null_coercion(lhs_type, rhs_type)0
)
510
1
        .or_else(|| 
string_numeric_coercion(lhs_type, rhs_type)0
)
511
1
        .or_else(|| 
string_temporal_coercion(lhs_type, rhs_type)0
)
512
1
        .or_else(|| 
binary_coercion(lhs_type, rhs_type)0
)
513
1
        .or_else(|| 
struct_coercion(lhs_type, rhs_type)0
)
514
48.7k
}
515
516
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
517
/// where one is numeric and one is `Utf8`/`LargeUtf8`.
518
0
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
519
    use arrow::datatypes::DataType::*;
520
0
    match (lhs_type, rhs_type) {
521
0
        (Utf8, _) if rhs_type.is_numeric() => Some(Utf8),
522
0
        (LargeUtf8, _) if rhs_type.is_numeric() => Some(LargeUtf8),
523
0
        (_, Utf8) if lhs_type.is_numeric() => Some(Utf8),
524
0
        (_, LargeUtf8) if lhs_type.is_numeric() => Some(LargeUtf8),
525
0
        _ => None,
526
    }
527
0
}
528
529
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
530
/// where one is temporal and one is `Utf8View`/`Utf8`/`LargeUtf8`.
531
///
532
/// Note this cannot be performed in case of arithmetic as there is insufficient information
533
/// to correctly determine the type of argument. Consider
534
///
535
/// ```sql
536
/// timestamp > now() - '1 month'
537
/// interval > now() - '1970-01-2021'
538
/// ```
539
///
540
/// In the absence of a full type inference system, we can't determine the correct type
541
/// to parse the string argument
542
0
fn string_temporal_coercion(
543
0
    lhs_type: &DataType,
544
0
    rhs_type: &DataType,
545
0
) -> Option<DataType> {
546
    use arrow::datatypes::DataType::*;
547
548
0
    fn match_rule(l: &DataType, r: &DataType) -> Option<DataType> {
549
0
        match (l, r) {
550
            // Coerce Utf8View/Utf8/LargeUtf8 to Date32/Date64/Time32/Time64/Timestamp
551
0
            (Utf8, temporal) | (LargeUtf8, temporal) | (Utf8View, temporal) => {
552
0
                match temporal {
553
0
                    Date32 | Date64 => Some(temporal.clone()),
554
                    Time32(_) | Time64(_) => {
555
0
                        if is_time_with_valid_unit(temporal.to_owned()) {
556
0
                            Some(temporal.to_owned())
557
                        } else {
558
0
                            None
559
                        }
560
                    }
561
0
                    Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())),
562
0
                    _ => None,
563
                }
564
            }
565
0
            _ => None,
566
        }
567
0
    }
568
569
0
    match_rule(lhs_type, rhs_type).or_else(|| match_rule(rhs_type, lhs_type))
570
0
}
571
572
/// Coerce `lhs_type` and `rhs_type` to a common type where both are numeric
573
1
pub fn binary_numeric_coercion(
574
1
    lhs_type: &DataType,
575
1
    rhs_type: &DataType,
576
1
) -> Option<DataType> {
577
    use arrow::datatypes::DataType::*;
578
1
    if !lhs_type.is_numeric() || !rhs_type.is_numeric() {
579
0
        return None;
580
1
    };
581
1
582
1
    // same type => all good
583
1
    if lhs_type == rhs_type {
584
0
        return Some(lhs_type.clone());
585
1
    }
586
587
1
    if let Some(t) = decimal_coercion(lhs_type, rhs_type) {
588
1
        return Some(t);
589
0
    }
590
0
591
0
    // these are ordered from most informative to least informative so
592
0
    // that the coercion does not lose information via truncation
593
0
    match (lhs_type, rhs_type) {
594
0
        (Float64, _) | (_, Float64) => Some(Float64),
595
0
        (_, Float32) | (Float32, _) => Some(Float32),
596
        // The following match arms encode the following logic: Given the two
597
        // integral types, we choose the narrowest possible integral type that
598
        // accommodates all values of both types. Note that some information
599
        // loss is inevitable when we have a signed type and a `UInt64`, in
600
        // which case we use `Int64`;i.e. the widest signed integral type.
601
602
        // TODO: For i64 and u64, we can use decimal or float64
603
        // Postgres has no unsigned type :(
604
        // DuckDB v.0.10.0 has double (double precision floating-point number (8 bytes))
605
        // for largest signed (signed sixteen-byte integer) and unsigned integer (unsigned sixteen-byte integer)
606
        (Int64, _)
607
        | (_, Int64)
608
        | (UInt64, Int8)
609
        | (Int8, UInt64)
610
        | (UInt64, Int16)
611
        | (Int16, UInt64)
612
        | (UInt64, Int32)
613
        | (Int32, UInt64)
614
        | (UInt32, Int8)
615
        | (Int8, UInt32)
616
        | (UInt32, Int16)
617
        | (Int16, UInt32)
618
        | (UInt32, Int32)
619
0
        | (Int32, UInt32) => Some(Int64),
620
0
        (UInt64, _) | (_, UInt64) => Some(UInt64),
621
        (Int32, _)
622
        | (_, Int32)
623
        | (UInt16, Int16)
624
        | (Int16, UInt16)
625
        | (UInt16, Int8)
626
0
        | (Int8, UInt16) => Some(Int32),
627
0
        (UInt32, _) | (_, UInt32) => Some(UInt32),
628
0
        (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
629
0
        (UInt16, _) | (_, UInt16) => Some(UInt16),
630
0
        (Int8, _) | (_, Int8) => Some(Int8),
631
0
        (UInt8, _) | (_, UInt8) => Some(UInt8),
632
0
        _ => None,
633
    }
634
1
}
635
636
/// Decimal coercion rules.
637
1
pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
638
    use arrow::datatypes::DataType::*;
639
640
1
    match (lhs_type, rhs_type) {
641
        // Prefer decimal data type over floating point for comparison operation
642
        (Decimal128(_, _), Decimal128(_, _)) => {
643
1
            get_wider_decimal_type(lhs_type, rhs_type)
644
        }
645
0
        (Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
646
0
        (_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
647
        (Decimal256(_, _), Decimal256(_, _)) => {
648
0
            get_wider_decimal_type(lhs_type, rhs_type)
649
        }
650
0
        (Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
651
0
        (_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
652
0
        (_, _) => None,
653
    }
654
1
}
655
656
/// Coerce `lhs_type` and `rhs_type` to a common type.
657
0
fn get_common_decimal_type(
658
0
    decimal_type: &DataType,
659
0
    other_type: &DataType,
660
0
) -> Option<DataType> {
661
    use arrow::datatypes::DataType::*;
662
0
    match decimal_type {
663
        Decimal128(_, _) => {
664
0
            let other_decimal_type = coerce_numeric_type_to_decimal(other_type)?;
665
0
            get_wider_decimal_type(decimal_type, &other_decimal_type)
666
        }
667
        Decimal256(_, _) => {
668
0
            let other_decimal_type = coerce_numeric_type_to_decimal256(other_type)?;
669
0
            get_wider_decimal_type(decimal_type, &other_decimal_type)
670
        }
671
0
        _ => None,
672
    }
673
0
}
674
675
/// Returns a `DataType::Decimal128` that can store any value from either
676
/// `lhs_decimal_type` and `rhs_decimal_type`
677
///
678
/// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))`.
679
1
fn get_wider_decimal_type(
680
1
    lhs_decimal_type: &DataType,
681
1
    rhs_type: &DataType,
682
1
) -> Option<DataType> {
683
1
    match (lhs_decimal_type, rhs_type) {
684
1
        (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => {
685
1
            // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
686
1
            let s = *s1.max(s2);
687
1
            let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
688
1
            Some(create_decimal_type((range + s) as u8, s))
689
        }
690
0
        (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => {
691
0
            // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
692
0
            let s = *s1.max(s2);
693
0
            let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
694
0
            Some(create_decimal256_type((range + s) as u8, s))
695
        }
696
0
        (_, _) => None,
697
    }
698
1
}
699
700
/// Returns the wider type among arguments `lhs` and `rhs`.
701
/// The wider type is the type that can safely represent values from both types
702
/// without information loss. Returns an Error if types are incompatible.
703
0
pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result<DataType> {
704
    use arrow::datatypes::DataType::*;
705
0
    Ok(match (lhs, rhs) {
706
0
        (lhs, rhs) if lhs == rhs => lhs.clone(),
707
        // Right UInt is larger than left UInt.
708
        (UInt8, UInt16 | UInt32 | UInt64) | (UInt16, UInt32 | UInt64) | (UInt32, UInt64) |
709
        // Right Int is larger than left Int.
710
        (Int8, Int16 | Int32 | Int64) | (Int16, Int32 | Int64) | (Int32, Int64) |
711
        // Right Float is larger than left Float.
712
        (Float16, Float32 | Float64) | (Float32, Float64) |
713
        // Right String is larger than left String.
714
        (Utf8, LargeUtf8) |
715
        // Any right type is wider than a left hand side Null.
716
0
        (Null, _) => rhs.clone(),
717
        // Left UInt is larger than right UInt.
718
        (UInt16 | UInt32 | UInt64, UInt8) | (UInt32 | UInt64, UInt16) | (UInt64, UInt32) |
719
        // Left Int is larger than right Int.
720
        (Int16 | Int32 | Int64, Int8) | (Int32 | Int64, Int16) | (Int64, Int32) |
721
        // Left Float is larger than right Float.
722
        (Float32 | Float64, Float16) | (Float64, Float32) |
723
        // Left String is larger than right String.
724
        (LargeUtf8, Utf8) |
725
        // Any left type is wider than a right hand side Null.
726
0
        (_, Null) => lhs.clone(),
727
0
        (List(lhs_field), List(rhs_field)) => {
728
0
            let field_type =
729
0
                get_wider_type(lhs_field.data_type(), rhs_field.data_type())?;
730
0
            if lhs_field.name() != rhs_field.name() {
731
0
                return Err(exec_datafusion_err!(
732
0
                    "There is no wider type that can represent both {lhs} and {rhs}."
733
0
                ));
734
0
            }
735
0
            assert_eq!(lhs_field.name(), rhs_field.name());
736
0
            let field_name = lhs_field.name();
737
0
            let nullable = lhs_field.is_nullable() | rhs_field.is_nullable();
738
0
            List(Arc::new(Field::new(field_name, field_type, nullable)))
739
        }
740
        (_, _) => {
741
0
            return Err(exec_datafusion_err!(
742
0
                "There is no wider type that can represent both {lhs} and {rhs}."
743
0
            ));
744
        }
745
    })
746
0
}
747
748
/// Convert the numeric data type to the decimal data type.
749
/// Now, we just support the signed integer type and floating-point type.
750
0
fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
751
    use arrow::datatypes::DataType::*;
752
    // This conversion rule is from spark
753
    // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
754
0
    match numeric_type {
755
0
        Int8 => Some(Decimal128(3, 0)),
756
0
        Int16 => Some(Decimal128(5, 0)),
757
0
        Int32 => Some(Decimal128(10, 0)),
758
0
        Int64 => Some(Decimal128(20, 0)),
759
        // TODO if we convert the floating-point data to the decimal type, it maybe overflow.
760
0
        Float32 => Some(Decimal128(14, 7)),
761
0
        Float64 => Some(Decimal128(30, 15)),
762
0
        _ => None,
763
    }
764
0
}
765
766
/// Convert the numeric data type to the decimal data type.
767
/// Now, we just support the signed integer type and floating-point type.
768
0
fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option<DataType> {
769
    use arrow::datatypes::DataType::*;
770
    // This conversion rule is from spark
771
    // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
772
0
    match numeric_type {
773
0
        Int8 => Some(Decimal256(3, 0)),
774
0
        Int16 => Some(Decimal256(5, 0)),
775
0
        Int32 => Some(Decimal256(10, 0)),
776
0
        Int64 => Some(Decimal256(20, 0)),
777
        // TODO if we convert the floating-point data to the decimal type, it maybe overflow.
778
0
        Float32 => Some(Decimal256(14, 7)),
779
0
        Float64 => Some(Decimal256(30, 15)),
780
0
        _ => None,
781
    }
782
0
}
783
784
0
fn struct_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
785
    use arrow::datatypes::DataType::*;
786
0
    match (lhs_type, rhs_type) {
787
0
        (Struct(lhs_fields), Struct(rhs_fields)) => {
788
0
            if lhs_fields.len() != rhs_fields.len() {
789
0
                return None;
790
0
            }
791
792
0
            let types = std::iter::zip(lhs_fields.iter(), rhs_fields.iter())
793
0
                .map(|(lhs, rhs)| comparison_coercion(lhs.data_type(), rhs.data_type()))
794
0
                .collect::<Option<Vec<DataType>>>()?;
795
796
0
            let fields = types
797
0
                .into_iter()
798
0
                .enumerate()
799
0
                .map(|(i, datatype)| {
800
0
                    Arc::new(Field::new(format!("c{i}"), datatype, true))
801
0
                })
802
0
                .collect::<Vec<FieldRef>>();
803
0
            Some(Struct(fields.into()))
804
        }
805
0
        _ => None,
806
    }
807
0
}
808
809
/// Returns the output type of applying mathematics operations such as
810
/// `+` to arguments of `lhs_type` and `rhs_type`.
811
0
fn mathematics_numerical_coercion(
812
0
    lhs_type: &DataType,
813
0
    rhs_type: &DataType,
814
0
) -> Option<DataType> {
815
    use arrow::datatypes::DataType::*;
816
817
    // error on any non-numeric type
818
0
    if !both_numeric_or_null_and_numeric(lhs_type, rhs_type) {
819
0
        return None;
820
0
    };
821
0
822
0
    // these are ordered from most informative to least informative so
823
0
    // that the coercion removes the least amount of information
824
0
    match (lhs_type, rhs_type) {
825
0
        (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
826
0
            mathematics_numerical_coercion(lhs_value_type, rhs_value_type)
827
        }
828
0
        (Dictionary(_, value_type), _) => {
829
0
            mathematics_numerical_coercion(value_type, rhs_type)
830
        }
831
0
        (_, Dictionary(_, value_type)) => {
832
0
            mathematics_numerical_coercion(lhs_type, value_type)
833
        }
834
0
        (Float64, _) | (_, Float64) => Some(Float64),
835
0
        (_, Float32) | (Float32, _) => Some(Float32),
836
0
        (Int64, _) | (_, Int64) => Some(Int64),
837
0
        (Int32, _) | (_, Int32) => Some(Int32),
838
0
        (Int16, _) | (_, Int16) => Some(Int16),
839
0
        (Int8, _) | (_, Int8) => Some(Int8),
840
0
        (UInt64, _) | (_, UInt64) => Some(UInt64),
841
0
        (UInt32, _) | (_, UInt32) => Some(UInt32),
842
0
        (UInt16, _) | (_, UInt16) => Some(UInt16),
843
0
        (UInt8, _) | (_, UInt8) => Some(UInt8),
844
0
        _ => None,
845
    }
846
0
}
847
848
1
fn create_decimal_type(precision: u8, scale: i8) -> DataType {
849
1
    DataType::Decimal128(
850
1
        DECIMAL128_MAX_PRECISION.min(precision),
851
1
        DECIMAL128_MAX_SCALE.min(scale),
852
1
    )
853
1
}
854
855
0
fn create_decimal256_type(precision: u8, scale: i8) -> DataType {
856
0
    DataType::Decimal256(
857
0
        DECIMAL256_MAX_PRECISION.min(precision),
858
0
        DECIMAL256_MAX_SCALE.min(scale),
859
0
    )
860
0
}
861
862
/// Determine if at least of one of lhs and rhs is numeric, and the other must be NULL or numeric
863
0
fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool {
864
    use arrow::datatypes::DataType::*;
865
0
    match (lhs_type, rhs_type) {
866
0
        (_, Null) => lhs_type.is_numeric(),
867
0
        (Null, _) => rhs_type.is_numeric(),
868
0
        (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
869
0
            lhs_value_type.is_numeric() && rhs_value_type.is_numeric()
870
        }
871
0
        (Dictionary(_, value_type), _) => {
872
0
            value_type.is_numeric() && rhs_type.is_numeric()
873
        }
874
0
        (_, Dictionary(_, value_type)) => {
875
0
            lhs_type.is_numeric() && value_type.is_numeric()
876
        }
877
0
        _ => lhs_type.is_numeric() && rhs_type.is_numeric(),
878
    }
879
0
}
880
881
/// Coercion rules for Dictionaries: the type that both lhs and rhs
882
/// can be casted to for the purpose of a computation.
883
///
884
/// Not all operators support dictionaries, if `preserve_dictionaries` is true
885
/// dictionaries will be preserved if possible
886
0
fn dictionary_comparison_coercion(
887
0
    lhs_type: &DataType,
888
0
    rhs_type: &DataType,
889
0
    preserve_dictionaries: bool,
890
0
) -> Option<DataType> {
891
    use arrow::datatypes::DataType::*;
892
0
    match (lhs_type, rhs_type) {
893
        (
894
0
            Dictionary(_lhs_index_type, lhs_value_type),
895
0
            Dictionary(_rhs_index_type, rhs_value_type),
896
0
        ) => comparison_coercion(lhs_value_type, rhs_value_type),
897
0
        (d @ Dictionary(_, value_type), other_type)
898
0
        | (other_type, d @ Dictionary(_, value_type))
899
0
            if preserve_dictionaries && value_type.as_ref() == other_type =>
900
        {
901
0
            Some(d.clone())
902
        }
903
0
        (Dictionary(_index_type, value_type), _) => {
904
0
            comparison_coercion(value_type, rhs_type)
905
        }
906
0
        (_, Dictionary(_index_type, value_type)) => {
907
0
            comparison_coercion(lhs_type, value_type)
908
        }
909
0
        _ => None,
910
    }
911
0
}
912
913
/// Coercion rules for string concat.
914
/// This is a union of string coercion rules and specified rules:
915
/// 1. At least one side of lhs and rhs should be string type (Utf8 / LargeUtf8)
916
/// 2. Data type of the other side should be able to cast to string type
917
0
fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
918
    use arrow::datatypes::DataType::*;
919
0
    string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) {
920
0
        (Utf8View, from_type) | (from_type, Utf8View) => {
921
0
            string_concat_internal_coercion(from_type, &Utf8View)
922
        }
923
0
        (Utf8, from_type) | (from_type, Utf8) => {
924
0
            string_concat_internal_coercion(from_type, &Utf8)
925
        }
926
0
        (LargeUtf8, from_type) | (from_type, LargeUtf8) => {
927
0
            string_concat_internal_coercion(from_type, &LargeUtf8)
928
        }
929
0
        (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
930
0
            string_coercion(lhs_value_type, rhs_value_type).or(None)
931
        }
932
0
        _ => None,
933
    })
934
0
}
935
936
0
fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
937
0
    if lhs_type.equals_datatype(rhs_type) {
938
0
        Some(lhs_type.to_owned())
939
    } else {
940
0
        None
941
    }
942
0
}
943
944
/// If `from_type` can be casted to `to_type`, return `to_type`, otherwise
945
/// return `None`.
946
0
fn string_concat_internal_coercion(
947
0
    from_type: &DataType,
948
0
    to_type: &DataType,
949
0
) -> Option<DataType> {
950
0
    if can_cast_types(from_type, to_type) {
951
0
        Some(to_type.to_owned())
952
    } else {
953
0
        None
954
    }
955
0
}
956
957
/// Coercion rules for string view types (Utf8/LargeUtf8/Utf8View):
958
/// If at least one argument is a string view, we coerce to string view
959
/// based on the observation that StringArray to StringViewArray is cheap but not vice versa.
960
///
961
/// Between Utf8 and LargeUtf8, we coerce to LargeUtf8.
962
0
fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
963
    use arrow::datatypes::DataType::*;
964
0
    match (lhs_type, rhs_type) {
965
        // If Utf8View is in any side, we coerce to Utf8View.
966
        (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => {
967
0
            Some(Utf8View)
968
        }
969
        // Then, if LargeUtf8 is in any side, we coerce to LargeUtf8.
970
0
        (LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8),
971
        // Utf8 coerces to Utf8
972
0
        (Utf8, Utf8) => Some(Utf8),
973
0
        _ => None,
974
    }
975
0
}
976
977
/// This will be deprecated when binary operators native support
978
/// for Utf8View (use `string_coercion` instead).
979
0
fn regex_comparison_string_coercion(
980
0
    lhs_type: &DataType,
981
0
    rhs_type: &DataType,
982
0
) -> Option<DataType> {
983
    use arrow::datatypes::DataType::*;
984
0
    match (lhs_type, rhs_type) {
985
        // If Utf8View is in any side, we coerce to Utf8.
986
        (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => {
987
0
            Some(Utf8)
988
        }
989
        // Then, if LargeUtf8 is in any side, we coerce to LargeUtf8.
990
0
        (LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8),
991
        // Utf8 coerces to Utf8
992
0
        (Utf8, Utf8) => Some(Utf8),
993
0
        _ => None,
994
    }
995
0
}
996
997
0
fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
998
    use arrow::datatypes::DataType::*;
999
0
    match (lhs_type, rhs_type) {
1000
0
        (Utf8 | LargeUtf8, other_type) | (other_type, Utf8 | LargeUtf8)
1001
0
            if other_type.is_numeric() =>
1002
        {
1003
0
            Some(other_type.clone())
1004
        }
1005
0
        _ => None,
1006
    }
1007
0
}
1008
1009
/// Coercion rules for list types.
1010
0
fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1011
    use arrow::datatypes::DataType::*;
1012
0
    match (lhs_type, rhs_type) {
1013
0
        (List(_), List(_)) => Some(lhs_type.clone()),
1014
0
        (LargeList(_), List(_)) => Some(lhs_type.clone()),
1015
0
        (List(_), LargeList(_)) => Some(rhs_type.clone()),
1016
0
        (LargeList(_), LargeList(_)) => Some(lhs_type.clone()),
1017
0
        (List(_), FixedSizeList(_, _)) => Some(lhs_type.clone()),
1018
0
        (FixedSizeList(_, _), List(_)) => Some(rhs_type.clone()),
1019
        // Coerce to the left side FixedSizeList type if the list lengths are the same,
1020
        // otherwise coerce to list with the left type for dynamic length
1021
0
        (FixedSizeList(lf, ls), FixedSizeList(_, rs)) => {
1022
0
            if ls == rs {
1023
0
                Some(lhs_type.clone())
1024
            } else {
1025
0
                Some(List(Arc::clone(lf)))
1026
            }
1027
        }
1028
0
        (LargeList(_), FixedSizeList(_, _)) => Some(lhs_type.clone()),
1029
0
        (FixedSizeList(_, _), LargeList(_)) => Some(rhs_type.clone()),
1030
0
        _ => None,
1031
    }
1032
0
}
1033
1034
/// Coercion rules for binary (Binary/LargeBinary) to string (Utf8/LargeUtf8):
1035
/// If one argument is binary and the other is a string then coerce to string
1036
/// (e.g. for `like`)
1037
0
fn binary_to_string_coercion(
1038
0
    lhs_type: &DataType,
1039
0
    rhs_type: &DataType,
1040
0
) -> Option<DataType> {
1041
    use arrow::datatypes::DataType::*;
1042
0
    match (lhs_type, rhs_type) {
1043
0
        (Binary, Utf8) => Some(Utf8),
1044
0
        (Binary, LargeUtf8) => Some(LargeUtf8),
1045
0
        (BinaryView, Utf8) => Some(Utf8View),
1046
0
        (BinaryView, LargeUtf8) => Some(LargeUtf8),
1047
0
        (LargeBinary, Utf8) => Some(LargeUtf8),
1048
0
        (LargeBinary, LargeUtf8) => Some(LargeUtf8),
1049
0
        (Utf8, Binary) => Some(Utf8),
1050
0
        (Utf8, LargeBinary) => Some(LargeUtf8),
1051
0
        (Utf8, BinaryView) => Some(Utf8View),
1052
0
        (LargeUtf8, Binary) => Some(LargeUtf8),
1053
0
        (LargeUtf8, LargeBinary) => Some(LargeUtf8),
1054
0
        (LargeUtf8, BinaryView) => Some(LargeUtf8),
1055
0
        _ => None,
1056
    }
1057
0
}
1058
1059
/// Coercion rules for binary types (Binary/LargeBinary/BinaryView): If at least one argument is
1060
/// a binary type and both arguments can be coerced into a binary type, coerce
1061
/// to binary type.
1062
0
fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1063
    use arrow::datatypes::DataType::*;
1064
0
    match (lhs_type, rhs_type) {
1065
        // If BinaryView is in any side, we coerce to BinaryView.
1066
        (BinaryView, BinaryView | Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View)
1067
        | (LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View, BinaryView) => {
1068
0
            Some(BinaryView)
1069
        }
1070
        // Prefer LargeBinary over Binary
1071
        (LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View, LargeBinary)
1072
0
        | (LargeBinary, Binary | Utf8 | LargeUtf8 | Utf8View) => Some(LargeBinary),
1073
1074
        // If Utf8View/LargeUtf8 presents need to be large Binary
1075
        (Utf8View | LargeUtf8, Binary) | (Binary, Utf8View | LargeUtf8) => {
1076
0
            Some(LargeBinary)
1077
        }
1078
0
        (Binary, Utf8) | (Utf8, Binary) => Some(Binary),
1079
0
        _ => None,
1080
    }
1081
0
}
1082
1083
/// coercion rules for like operations.
1084
/// This is a union of string coercion rules and dictionary coercion rules
1085
0
pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1086
0
    string_coercion(lhs_type, rhs_type)
1087
0
        .or_else(|| list_coercion(lhs_type, rhs_type))
1088
0
        .or_else(|| binary_to_string_coercion(lhs_type, rhs_type))
1089
0
        .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false))
1090
0
        .or_else(|| regex_null_coercion(lhs_type, rhs_type))
1091
0
        .or_else(|| null_coercion(lhs_type, rhs_type))
1092
0
}
1093
1094
/// coercion rules for regular expression comparison operations with NULL input.
1095
0
fn regex_null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1096
    use arrow::datatypes::DataType::*;
1097
0
    match (lhs_type, rhs_type) {
1098
0
        (DataType::Null, Utf8View | Utf8 | LargeUtf8) => Some(rhs_type.clone()),
1099
0
        (Utf8View | Utf8 | LargeUtf8, DataType::Null) => Some(lhs_type.clone()),
1100
0
        (DataType::Null, DataType::Null) => Some(Utf8),
1101
0
        _ => None,
1102
    }
1103
0
}
1104
1105
/// Coercion rules for regular expression comparison operations.
1106
/// This is a union of string coercion rules and dictionary coercion rules
1107
0
pub fn regex_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1108
0
    regex_comparison_string_coercion(lhs_type, rhs_type)
1109
0
        .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false))
1110
0
        .or_else(|| regex_null_coercion(lhs_type, rhs_type))
1111
0
}
1112
1113
/// Checks if the TimeUnit associated with a Time32 or Time64 type is consistent,
1114
/// as Time32 can only be used to Second and Millisecond accuracy, while Time64
1115
/// is exclusively used to Microsecond and Nanosecond accuracy
1116
0
fn is_time_with_valid_unit(datatype: DataType) -> bool {
1117
0
    matches!(
1118
0
        datatype,
1119
        DataType::Time32(TimeUnit::Second)
1120
            | DataType::Time32(TimeUnit::Millisecond)
1121
            | DataType::Time64(TimeUnit::Microsecond)
1122
            | DataType::Time64(TimeUnit::Nanosecond)
1123
    )
1124
0
}
1125
1126
/// Non-strict Timezone Coercion is useful in scenarios where we can guarantee
1127
/// a stable relationship between two timestamps of different timezones.
1128
///
1129
/// An example of this is binary comparisons (<, >, ==, etc). Arrow stores timestamps
1130
/// as relative to UTC epoch, and then adds the timezone as an offset. As a result, we can always
1131
/// do a binary comparison between the two times.
1132
///
1133
/// Timezone coercion is handled by the following rules:
1134
/// - If only one has a timezone, coerce the other to match
1135
/// - If both have a timezone, coerce to the left type
1136
/// - "UTC" and "+00:00" are considered equivalent
1137
0
fn temporal_coercion_nonstrict_timezone(
1138
0
    lhs_type: &DataType,
1139
0
    rhs_type: &DataType,
1140
0
) -> Option<DataType> {
1141
    use arrow::datatypes::DataType::*;
1142
1143
0
    match (lhs_type, rhs_type) {
1144
0
        (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => {
1145
0
            let tz = match (lhs_tz, rhs_tz) {
1146
                // If both have a timezone, use the left timezone.
1147
0
                (Some(lhs_tz), Some(_rhs_tz)) => Some(Arc::clone(lhs_tz)),
1148
0
                (Some(lhs_tz), None) => Some(Arc::clone(lhs_tz)),
1149
0
                (None, Some(rhs_tz)) => Some(Arc::clone(rhs_tz)),
1150
0
                (None, None) => None,
1151
            };
1152
1153
0
            let unit = timeunit_coercion(lhs_unit, rhs_unit);
1154
0
1155
0
            Some(Timestamp(unit, tz))
1156
        }
1157
0
        _ => temporal_coercion(lhs_type, rhs_type),
1158
    }
1159
0
}
1160
1161
/// Strict Timezone coercion is useful in scenarios where we cannot guarantee a stable relationship
1162
/// between two timestamps with different timezones or do not want implicit coercion between them.
1163
///
1164
/// An example of this when attempting to coerce function arguments. Functions already have a mechanism
1165
/// for defining which timestamp types they want to support, so we do not want to do any further coercion.
1166
///
1167
/// Coercion rules for Temporal columns: the type that both lhs and rhs can be
1168
/// casted to for the purpose of a date computation
1169
/// For interval arithmetic, it doesn't handle datetime type +/- interval
1170
/// Timezone coercion is handled by the following rules:
1171
/// - If only one has a timezone, coerce the other to match
1172
/// - If both have a timezone, throw an error
1173
/// - "UTC" and "+00:00" are considered equivalent
1174
0
fn temporal_coercion_strict_timezone(
1175
0
    lhs_type: &DataType,
1176
0
    rhs_type: &DataType,
1177
0
) -> Option<DataType> {
1178
    use arrow::datatypes::DataType::*;
1179
1180
0
    match (lhs_type, rhs_type) {
1181
0
        (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => {
1182
0
            let tz = match (lhs_tz, rhs_tz) {
1183
0
                (Some(lhs_tz), Some(rhs_tz)) => {
1184
0
                    match (lhs_tz.as_ref(), rhs_tz.as_ref()) {
1185
0
                        // UTC and "+00:00" are the same by definition. Most other timezones
1186
0
                        // do not have a 1-1 mapping between timezone and an offset from UTC
1187
0
                        ("UTC", "+00:00") | ("+00:00", "UTC") => Some(Arc::clone(lhs_tz)),
1188
0
                        (lhs, rhs) if lhs == rhs => Some(Arc::clone(lhs_tz)),
1189
                        // can't cast across timezones
1190
                        _ => {
1191
0
                            return None;
1192
                        }
1193
                    }
1194
                }
1195
0
                (Some(lhs_tz), None) => Some(Arc::clone(lhs_tz)),
1196
0
                (None, Some(rhs_tz)) => Some(Arc::clone(rhs_tz)),
1197
0
                (None, None) => None,
1198
            };
1199
1200
0
            let unit = timeunit_coercion(lhs_unit, rhs_unit);
1201
0
1202
0
            Some(Timestamp(unit, tz))
1203
        }
1204
0
        _ => temporal_coercion(lhs_type, rhs_type),
1205
    }
1206
0
}
1207
1208
0
fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1209
    use arrow::datatypes::DataType::*;
1210
    use arrow::datatypes::IntervalUnit::*;
1211
    use arrow::datatypes::TimeUnit::*;
1212
1213
0
    match (lhs_type, rhs_type) {
1214
        (Interval(_) | Duration(_), Interval(_) | Duration(_)) => {
1215
0
            Some(Interval(MonthDayNano))
1216
        }
1217
0
        (Date64, Date32) | (Date32, Date64) => Some(Date64),
1218
        (Timestamp(_, None), Date64) | (Date64, Timestamp(_, None)) => {
1219
0
            Some(Timestamp(Nanosecond, None))
1220
        }
1221
0
        (Timestamp(_, _tz), Date64) | (Date64, Timestamp(_, _tz)) => {
1222
0
            Some(Timestamp(Nanosecond, None))
1223
        }
1224
        (Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => {
1225
0
            Some(Timestamp(Nanosecond, None))
1226
        }
1227
0
        (Timestamp(_, _tz), Date32) | (Date32, Timestamp(_, _tz)) => {
1228
0
            Some(Timestamp(Nanosecond, None))
1229
        }
1230
0
        _ => None,
1231
    }
1232
0
}
1233
1234
0
fn timeunit_coercion(lhs_unit: &TimeUnit, rhs_unit: &TimeUnit) -> TimeUnit {
1235
    use arrow::datatypes::TimeUnit::*;
1236
0
    match (lhs_unit, rhs_unit) {
1237
0
        (Second, Millisecond) => Second,
1238
0
        (Second, Microsecond) => Second,
1239
0
        (Second, Nanosecond) => Second,
1240
0
        (Millisecond, Second) => Second,
1241
0
        (Millisecond, Microsecond) => Millisecond,
1242
0
        (Millisecond, Nanosecond) => Millisecond,
1243
0
        (Microsecond, Second) => Second,
1244
0
        (Microsecond, Millisecond) => Millisecond,
1245
0
        (Microsecond, Nanosecond) => Microsecond,
1246
0
        (Nanosecond, Second) => Second,
1247
0
        (Nanosecond, Millisecond) => Millisecond,
1248
0
        (Nanosecond, Microsecond) => Microsecond,
1249
0
        (l, r) => {
1250
0
            assert_eq!(l, r);
1251
0
            *l
1252
        }
1253
    }
1254
0
}
1255
1256
/// coercion rules from NULL type. Since NULL can be casted to any other type in arrow,
1257
/// either lhs or rhs is NULL, if NULL can be casted to type of the other side, the coercion is valid.
1258
0
fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1259
0
    match (lhs_type, rhs_type) {
1260
0
        (DataType::Null, other_type) | (other_type, DataType::Null) => {
1261
0
            if can_cast_types(&DataType::Null, other_type) {
1262
0
                Some(other_type.clone())
1263
            } else {
1264
0
                None
1265
            }
1266
        }
1267
0
        _ => None,
1268
    }
1269
0
}
1270
1271
#[cfg(test)]
1272
mod tests {
1273
    use super::*;
1274
1275
    use datafusion_common::assert_contains;
1276
1277
    #[test]
    fn test_coercion_error() -> Result<()> {
        // Float32 + Utf8 has no valid coercion, so resolving the input types
        // must fail with a planning error.
        let err = get_input_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8)
            .unwrap_err();

        let expected = "Error during planning: Cannot coerce arithmetic expression Float32 + Utf8 to valid types";
        assert_eq!(err.strip_backtrace(), expected);
        Ok(())
    }
1286
1287
    #[test]
    fn test_decimal_binary_comparison_coercion() -> Result<()> {
        let input_decimal = DataType::Decimal128(20, 3);

        // (right-hand input type, type both sides are expected to coerce to)
        let cases = [
            (DataType::Int8, DataType::Decimal128(20, 3)),
            (DataType::Int16, DataType::Decimal128(20, 3)),
            (DataType::Int32, DataType::Decimal128(20, 3)),
            (DataType::Int64, DataType::Decimal128(23, 3)),
            (DataType::Float32, DataType::Decimal128(24, 7)),
            (DataType::Float64, DataType::Decimal128(32, 15)),
            (DataType::Decimal128(38, 10), DataType::Decimal128(38, 10)),
            (DataType::Decimal128(20, 8), DataType::Decimal128(25, 8)),
            (DataType::Null, DataType::Decimal128(20, 3)),
        ];
        let comparison_ops = [
            Operator::NotEq,
            Operator::Eq,
            Operator::Gt,
            Operator::GtEq,
            Operator::Lt,
            Operator::LtEq,
        ];

        for (input_type, expect_type) in cases.iter() {
            for op in comparison_ops.iter() {
                let (lhs, rhs) = get_input_types(&input_decimal, op, input_type)?;
                assert_eq!(expect_type, &lhs);
                assert_eq!(expect_type, &rhs);
            }
        }

        // negative test
        let boolean_result =
            get_input_types(&input_decimal, &Operator::Eq, &DataType::Boolean);
        assert!(boolean_result.is_err());
        Ok(())
    }
1334
1335
    #[test]
    fn test_decimal_mathematics_op_type() {
        // (numeric input, Decimal128 type it is expected to map to)
        let expectations = [
            (DataType::Int8, DataType::Decimal128(3, 0)),
            (DataType::Int16, DataType::Decimal128(5, 0)),
            (DataType::Int32, DataType::Decimal128(10, 0)),
            (DataType::Int64, DataType::Decimal128(20, 0)),
            (DataType::Float32, DataType::Decimal128(14, 7)),
            (DataType::Float64, DataType::Decimal128(30, 15)),
        ];

        for (numeric_type, decimal_type) in expectations.iter() {
            assert_eq!(
                coerce_numeric_type_to_decimal(numeric_type).unwrap(),
                *decimal_type
            );
        }
    }
1362
1363
    #[test]
    fn test_dictionary_type_coercion() {
        use DataType::*;

        // Every dictionary in this test has Int8 keys; only the value type varies.
        let dict_of =
            |value_type: DataType| Dictionary(Box::new(Int8), Box::new(value_type));

        // Two dictionaries coerce through their value types, regardless of
        // whether dictionaries are preserved.
        let int32_dict = dict_of(Int32);
        let int16_dict = dict_of(Int16);
        assert_eq!(
            dictionary_comparison_coercion(&int32_dict, &int16_dict, true),
            Some(Int32)
        );
        assert_eq!(
            dictionary_comparison_coercion(&int32_dict, &int16_dict, false),
            Some(Int32)
        );

        // Since we can coerce values of Int16 to Utf8 can support this
        let utf8_dict = dict_of(Utf8);
        assert_eq!(
            dictionary_comparison_coercion(&utf8_dict, &int16_dict, true),
            Some(Utf8)
        );

        // Since we can coerce values of Utf8 to Binary can support this
        let binary_dict = dict_of(Binary);
        assert_eq!(
            dictionary_comparison_coercion(&utf8_dict, &binary_dict, true),
            Some(Binary)
        );

        // Dictionary on the left, matching plain type on the right.
        let plain_utf8 = Utf8;
        assert_eq!(
            dictionary_comparison_coercion(&utf8_dict, &plain_utf8, false),
            Some(Utf8)
        );
        assert_eq!(
            dictionary_comparison_coercion(&utf8_dict, &plain_utf8, true),
            Some(utf8_dict.clone())
        );

        // Same pair with the sides swapped.
        assert_eq!(
            dictionary_comparison_coercion(&plain_utf8, &utf8_dict, false),
            Some(Utf8)
        );
        assert_eq!(
            dictionary_comparison_coercion(&plain_utf8, &utf8_dict, true),
            Some(utf8_dict.clone())
        );
    }
1416
1417
    /// Checks one coercion rule for a binary operator.
    ///
    /// Resolves the input types for `$LHS_TYPE $OP $RHS_TYPE` and asserts that
    /// both sides are coerced to `$RESULT_TYPE`. Uses `?`, so it must be
    /// invoked from a function returning `Result`.
    macro_rules! test_coercion_binary_rule {
        ($LHS_TYPE:expr, $RHS_TYPE:expr, $OP:expr, $RESULT_TYPE:expr) => {{
            let (coerced_lhs, coerced_rhs) =
                get_input_types(&$LHS_TYPE, &$OP, &$RHS_TYPE)?;
            assert_eq!(coerced_lhs, $RESULT_TYPE);
            assert_eq!(coerced_rhs, $RESULT_TYPE);
        }};
    }
1428
1429
    /// Checks one coercion rule for `LIKE`.
    ///
    /// Runs `like_coercion` in both argument orders:
    /// * `$LHS_TYPE LIKE $RHS_TYPE`
    /// * `$RHS_TYPE LIKE $LHS_TYPE`
    ///
    /// and asserts that each produces `$RESULT_TYPE`.
    macro_rules! test_like_rule {
        ($LHS_TYPE:expr, $RHS_TYPE:expr, $RESULT_TYPE:expr) => {{
            println!("Coercing {} LIKE {}", $LHS_TYPE, $RHS_TYPE);
            let forward = like_coercion(&$LHS_TYPE, &$RHS_TYPE);
            assert_eq!(forward, $RESULT_TYPE);
            // same rule with the operands swapped
            let reversed = like_coercion(&$RHS_TYPE, &$LHS_TYPE);
            assert_eq!(reversed, $RESULT_TYPE);
        }};
    }
1446
1447
    #[test]
    fn test_date_timestamp_arithmetic_error() -> Result<()> {
        // Subtracting timestamps of different units coerces both sides to the
        // coarser unit.
        let nanos = DataType::Timestamp(TimeUnit::Nanosecond, None);
        let millis = DataType::Timestamp(TimeUnit::Millisecond, None);
        let (lhs, rhs) = get_input_types(&nanos, &Operator::Minus, &millis)?;
        assert_eq!(lhs.to_string(), "Timestamp(Millisecond, None)");
        assert_eq!(rhs.to_string(), "Timestamp(Millisecond, None)");

        // Adding two dates is rejected.
        let message =
            get_input_types(&DataType::Date32, &Operator::Plus, &DataType::Date64)
                .unwrap_err()
                .to_string();

        assert_contains!(
            &message,
            "Cannot get result type for temporal operation Date64 + Date64"
        );

        Ok(())
    }
1468
1469
    #[test]
    fn test_like_coercion() {
        // (lhs, rhs, expected coerced type); each case is checked in both
        // argument orders by `test_like_rule!`.
        let cases = [
            // string coerce to strings
            (DataType::Utf8, DataType::Utf8, DataType::Utf8),
            (DataType::LargeUtf8, DataType::Utf8, DataType::LargeUtf8),
            (DataType::Utf8, DataType::LargeUtf8, DataType::LargeUtf8),
            (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8),
            // Also coerce binary to strings
            (DataType::Binary, DataType::Utf8, DataType::Utf8),
            (DataType::LargeBinary, DataType::Utf8, DataType::LargeUtf8),
            (DataType::Binary, DataType::LargeUtf8, DataType::LargeUtf8),
            (DataType::LargeBinary, DataType::LargeUtf8, DataType::LargeUtf8),
        ];

        for (lhs, rhs, expected) in cases.iter() {
            test_like_rule!(*lhs, *rhs, Some(expected.clone()));
        }
    }
1507
1508
    /// Exercises binary-operator coercion for string/temporal comparisons,
    /// regex operators and bitwise AND.
    #[test]
    fn test_type_coercion() -> Result<()> {
        let dict_utf8 = || {
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))
        };
        let timestamp = |unit: TimeUnit| DataType::Timestamp(unit, None);

        // Each row is (lhs type, rhs type, operator, expected coerced type).
        let cases = [
            // Utf8 against date / time types
            (DataType::Utf8, DataType::Date32, Operator::Eq, DataType::Date32),
            (DataType::Utf8, DataType::Date64, Operator::Lt, DataType::Date64),
            (
                DataType::Utf8,
                DataType::Time32(TimeUnit::Second),
                Operator::Eq,
                DataType::Time32(TimeUnit::Second),
            ),
            (
                DataType::Utf8,
                DataType::Time32(TimeUnit::Millisecond),
                Operator::Eq,
                DataType::Time32(TimeUnit::Millisecond),
            ),
            (
                DataType::Utf8,
                DataType::Time64(TimeUnit::Microsecond),
                Operator::Eq,
                DataType::Time64(TimeUnit::Microsecond),
            ),
            (
                DataType::Utf8,
                DataType::Time64(TimeUnit::Nanosecond),
                Operator::Eq,
                DataType::Time64(TimeUnit::Nanosecond),
            ),
            // Utf8 against timestamps: expected type is always nanosecond precision
            (
                DataType::Utf8,
                timestamp(TimeUnit::Second),
                Operator::Lt,
                timestamp(TimeUnit::Nanosecond),
            ),
            (
                DataType::Utf8,
                timestamp(TimeUnit::Millisecond),
                Operator::Lt,
                timestamp(TimeUnit::Nanosecond),
            ),
            (
                DataType::Utf8,
                timestamp(TimeUnit::Microsecond),
                Operator::Lt,
                timestamp(TimeUnit::Nanosecond),
            ),
            (
                DataType::Utf8,
                timestamp(TimeUnit::Nanosecond),
                Operator::Lt,
                timestamp(TimeUnit::Nanosecond),
            ),
            // regex operators on plain strings
            (DataType::Utf8, DataType::Utf8, Operator::RegexMatch, DataType::Utf8),
            (DataType::Utf8, DataType::Utf8, Operator::RegexNotMatch, DataType::Utf8),
            (DataType::Utf8, DataType::Utf8, Operator::RegexNotIMatch, DataType::Utf8),
            // regex operators on dictionary-encoded strings
            (dict_utf8(), DataType::Utf8, Operator::RegexMatch, DataType::Utf8),
            (dict_utf8(), DataType::Utf8, Operator::RegexIMatch, DataType::Utf8),
            (dict_utf8(), DataType::Utf8, Operator::RegexNotMatch, DataType::Utf8),
            (dict_utf8(), DataType::Utf8, Operator::RegexNotIMatch, DataType::Utf8),
            // bitwise AND on integers
            (DataType::Int16, DataType::Int64, Operator::BitwiseAnd, DataType::Int64),
            (DataType::UInt64, DataType::UInt64, Operator::BitwiseAnd, DataType::UInt64),
            (DataType::Int8, DataType::UInt32, Operator::BitwiseAnd, DataType::Int64),
            (DataType::UInt32, DataType::Int32, Operator::BitwiseAnd, DataType::Int64),
            (DataType::UInt16, DataType::Int16, Operator::BitwiseAnd, DataType::Int32),
            (DataType::UInt32, DataType::UInt32, Operator::BitwiseAnd, DataType::UInt32),
            (DataType::UInt16, DataType::UInt32, Operator::BitwiseAnd, DataType::UInt32),
        ];

        for (lhs_type, rhs_type, op, coerced) in cases {
            test_coercion_binary_rule!(lhs_type, rhs_type, op, coerced);
        }
        Ok(())
    }
1656
1657
    /// Exercises arithmetic-operator coercion for integer and float operands.
    #[test]
    fn test_type_coercion_arithmetic() -> Result<()> {
        // Each row is (lhs type, rhs type, operator, expected coerced type).
        let cases = [
            // integer
            (DataType::Int32, DataType::UInt32, Operator::Plus, DataType::Int32),
            (DataType::Int32, DataType::UInt16, Operator::Minus, DataType::Int32),
            (DataType::Int8, DataType::Int64, Operator::Multiply, DataType::Int64),
            // float
            (DataType::Float32, DataType::Int32, Operator::Plus, DataType::Float32),
            (DataType::Float32, DataType::Float64, Operator::Multiply, DataType::Float64),
            // TODO add other data type
        ];

        for (lhs_type, rhs_type, op, coerced) in cases {
            test_coercion_binary_rule!(lhs_type, rhs_type, op, coerced);
        }
        Ok(())
    }
1694
1695
    fn test_math_decimal_coercion_rule(
1696
        lhs_type: DataType,
1697
        rhs_type: DataType,
1698
        expected_lhs_type: DataType,
1699
        expected_rhs_type: DataType,
1700
    ) {
1701
        // The coerced types for lhs and rhs, if any of them is not decimal
1702
        let (lhs_type, rhs_type) = math_decimal_coercion(&lhs_type, &rhs_type).unwrap();
1703
        assert_eq!(lhs_type, expected_lhs_type);
1704
        assert_eq!(rhs_type, expected_rhs_type);
1705
    }
1706
1707
    /// Exercises `math_decimal_coercion` through
    /// `test_math_decimal_coercion_rule`.
    ///
    /// Each distinct input pair is checked exactly once.
    #[test]
    fn test_coercion_arithmetic_decimal() -> Result<()> {
        // Two decimals with identical precision and scale are left unchanged.
        test_math_decimal_coercion_rule(
            DataType::Decimal128(10, 2),
            DataType::Decimal128(10, 2),
            DataType::Decimal128(10, 2),
            DataType::Decimal128(10, 2),
        );

        // An Int32 operand is widened to Decimal128(10, 0); the decimal operand
        // keeps its original type.
        test_math_decimal_coercion_rule(
            DataType::Int32,
            DataType::Decimal128(10, 2),
            DataType::Decimal128(10, 0),
            DataType::Decimal128(10, 2),
        );

        Ok(())
    }
1753
1754
    /// Exercises comparison-operator coercion across booleans, numerics,
    /// decimals, binary/string types, timestamps with time zones, and lists.
    ///
    /// Every row is a distinct `(lhs, rhs, operator)` combination; both coerced
    /// input types are expected to equal the row's last element.
    #[test]
    fn test_type_coercion_compare() -> Result<()> {
        // Second-precision timestamp in the given time zone.
        let ts = |tz: &str| DataType::Timestamp(TimeUnit::Second, Some(tz.into()));

        // List flavours sharing one nullable Int64 item field.
        let inner_field = Arc::new(Field::new("item", DataType::Int64, true));
        let list = || DataType::List(Arc::clone(&inner_field));
        let large_list = || DataType::LargeList(Arc::clone(&inner_field));
        let fixed_list = || DataType::FixedSizeList(Arc::clone(&inner_field), 10);

        // Each row is (lhs type, rhs type, operator, expected coerced type).
        let cases = [
            // boolean
            (DataType::Boolean, DataType::Boolean, Operator::Eq, DataType::Boolean),
            // float
            (DataType::Float32, DataType::Int64, Operator::Eq, DataType::Float32),
            (DataType::Float32, DataType::Float64, Operator::GtEq, DataType::Float64),
            // signed integer
            (DataType::Int8, DataType::Int32, Operator::LtEq, DataType::Int32),
            (DataType::Int64, DataType::Int32, Operator::LtEq, DataType::Int64),
            // unsigned integer
            (DataType::UInt32, DataType::UInt8, Operator::Gt, DataType::UInt32),
            // numeric/decimal
            (
                DataType::Int64,
                DataType::Decimal128(10, 0),
                Operator::Eq,
                DataType::Decimal128(20, 0),
            ),
            (
                DataType::Int64,
                DataType::Decimal128(10, 2),
                Operator::Lt,
                DataType::Decimal128(22, 2),
            ),
            (
                DataType::Float64,
                DataType::Decimal128(10, 3),
                Operator::Gt,
                DataType::Decimal128(30, 15),
            ),
            (
                DataType::Decimal128(14, 2),
                DataType::Decimal128(10, 3),
                Operator::GtEq,
                DataType::Decimal128(15, 3),
            ),
            // Binary
            (DataType::Binary, DataType::Binary, Operator::Eq, DataType::Binary),
            (DataType::Utf8, DataType::Binary, Operator::Eq, DataType::Binary),
            (DataType::Binary, DataType::Utf8, Operator::Eq, DataType::Binary),
            // LargeBinary
            (
                DataType::LargeBinary,
                DataType::LargeBinary,
                Operator::Eq,
                DataType::LargeBinary,
            ),
            (DataType::Binary, DataType::LargeBinary, Operator::Eq, DataType::LargeBinary),
            (DataType::LargeBinary, DataType::Binary, Operator::Eq, DataType::LargeBinary),
            (DataType::Utf8, DataType::LargeBinary, Operator::Eq, DataType::LargeBinary),
            (DataType::LargeBinary, DataType::Utf8, Operator::Eq, DataType::LargeBinary),
            (
                DataType::LargeUtf8,
                DataType::LargeBinary,
                Operator::Eq,
                DataType::LargeBinary,
            ),
            (
                DataType::LargeBinary,
                DataType::LargeUtf8,
                Operator::Eq,
                DataType::LargeBinary,
            ),
            // Timestamps: the expected type carries the left operand's time zone
            (ts("UTC"), ts("UTC"), Operator::Eq, ts("UTC")),
            (ts("UTC"), ts("Europe/Brussels"), Operator::Eq, ts("UTC")),
            (
                ts("America/New_York"),
                ts("Europe/Brussels"),
                Operator::Eq,
                ts("America/New_York"),
            ),
            (ts("Europe/Brussels"), ts("UTC"), Operator::Eq, ts("Europe/Brussels")),
            // list
            (list(), list(), Operator::Eq, list()),
            (list(), large_list(), Operator::Eq, large_list()),
            (large_list(), list(), Operator::Eq, large_list()),
            (large_list(), large_list(), Operator::Eq, large_list()),
            (fixed_list(), fixed_list(), Operator::Eq, fixed_list()),
            (fixed_list(), large_list(), Operator::Eq, large_list()),
            (large_list(), fixed_list(), Operator::Eq, large_list()),
            (list(), fixed_list(), Operator::Eq, list()),
            (fixed_list(), list(), Operator::Eq, list()),
            // TODO add other data type
        ];

        for (lhs_type, rhs_type, op, coerced) in cases {
            test_coercion_binary_rule!(lhs_type, rhs_type, op, coerced);
        }
        Ok(())
    }
1979
1980
    /// Exercises coercion for `AND` / `OR` over every pairing of `Boolean` and
    /// `Null` operands; the expected coerced type is always `Boolean`.
    #[test]
    fn test_type_coercion_logical_op() -> Result<()> {
        // Each row is (lhs type, rhs type, operator).
        let cases = [
            (DataType::Boolean, DataType::Boolean, Operator::And),
            (DataType::Boolean, DataType::Boolean, Operator::Or),
            (DataType::Boolean, DataType::Null, Operator::And),
            (DataType::Boolean, DataType::Null, Operator::Or),
            (DataType::Null, DataType::Null, Operator::Or),
            (DataType::Null, DataType::Null, Operator::And),
            (DataType::Null, DataType::Boolean, Operator::And),
            (DataType::Null, DataType::Boolean, Operator::Or),
        ];

        for (lhs_type, rhs_type, op) in cases {
            test_coercion_binary_rule!(lhs_type, rhs_type, op, DataType::Boolean);
        }
        Ok(())
    }
2033
}