/Users/andrewlamb/Software/datafusion/datafusion/expr-common/src/type_coercion/binary.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Coercion rules for matching argument types for binary operators |
19 | | |
20 | | use std::collections::HashSet; |
21 | | use std::sync::Arc; |
22 | | |
23 | | use crate::operator::Operator; |
24 | | |
25 | | use arrow::array::{new_empty_array, Array}; |
26 | | use arrow::compute::can_cast_types; |
27 | | use arrow::datatypes::{ |
28 | | DataType, Field, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, |
29 | | DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, |
30 | | }; |
31 | | use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result}; |
32 | | |
33 | | /// The type signature of an instantiation of binary operator expression such as |
34 | | /// `lhs + rhs` |
35 | | /// |
36 | | /// Note this is different than [`crate::signature::Signature`] which |
37 | | /// describes the type signature of a function. |
38 | | struct Signature { |
39 | | /// The type to coerce the left argument to |
40 | | lhs: DataType, |
41 | | /// The type to coerce the right argument to |
42 | | rhs: DataType, |
43 | | /// The return type of the expression |
44 | | ret: DataType, |
45 | | } |
46 | | |
47 | | impl Signature { |
48 | | /// A signature where the inputs are the same type as the output |
49 | 23.2k | fn uniform(t: DataType) -> Self { |
50 | 23.2k | Self { |
51 | 23.2k | lhs: t.clone(), |
52 | 23.2k | rhs: t.clone(), |
53 | 23.2k | ret: t, |
54 | 23.2k | } |
55 | 23.2k | } |
56 | | |
57 | | /// A signature where the inputs are the same type with a boolean output |
58 | 48.7k | fn comparison(t: DataType) -> Self { |
59 | 48.7k | Self { |
60 | 48.7k | lhs: t.clone(), |
61 | 48.7k | rhs: t, |
62 | 48.7k | ret: DataType::Boolean, |
63 | 48.7k | } |
64 | 48.7k | } |
65 | | } |
66 | | |
67 | | /// Returns a [`Signature`] for applying `op` to arguments of type `lhs` and `rhs` |
68 | 214k | fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result<Signature> { |
69 | | use arrow::datatypes::DataType::*; |
70 | | use Operator::*; |
71 | 214k | match op { |
72 | | Eq | |
73 | | NotEq | |
74 | | Lt | |
75 | | LtEq | |
76 | | Gt | |
77 | | GtEq | |
78 | | IsDistinctFrom | |
79 | | IsNotDistinctFrom => { |
80 | 48.7k | comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { |
81 | 0 | plan_datafusion_err!( |
82 | 0 | "Cannot infer common argument type for comparison operation {lhs} {op} {rhs}" |
83 | 0 | ) |
84 | 48.7k | }) |
85 | | } |
86 | 23.2k | And | Or => if matches!0 ((lhs, rhs), (Boolean | Null, Boolean | Null)) { |
87 | | // Logical binary boolean operators can only be evaluated for |
88 | | // boolean or null arguments. |
89 | 23.2k | Ok(Signature::uniform(DataType::Boolean)) |
90 | | } else { |
91 | 0 | plan_err!( |
92 | 0 | "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}" |
93 | 0 | ) |
94 | | } |
95 | | RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => { |
96 | 0 | regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { |
97 | 0 | plan_datafusion_err!( |
98 | 0 | "Cannot infer common argument type for regex operation {lhs} {op} {rhs}" |
99 | 0 | ) |
100 | 0 | }) |
101 | | } |
102 | | LikeMatch | ILikeMatch | NotLikeMatch | NotILikeMatch => { |
103 | 0 | regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { |
104 | 0 | plan_datafusion_err!( |
105 | 0 | "Cannot infer common argument type for regex operation {lhs} {op} {rhs}" |
106 | 0 | ) |
107 | 0 | }) |
108 | | } |
109 | | BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => { |
110 | 0 | bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { |
111 | 0 | plan_datafusion_err!( |
112 | 0 | "Cannot infer common type for bitwise operation {lhs} {op} {rhs}" |
113 | 0 | ) |
114 | 0 | }) |
115 | | } |
116 | | StringConcat => { |
117 | 0 | string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { |
118 | 0 | plan_datafusion_err!( |
119 | 0 | "Cannot infer common string type for string concat operation {lhs} {op} {rhs}" |
120 | 0 | ) |
121 | 0 | }) |
122 | | } |
123 | | AtArrow | ArrowAt => { |
124 | | // ArrowAt and AtArrow check for whether one array is contained in another. |
125 | | // The result type is boolean. Signature::comparison defines this signature. |
126 | | // Operation has nothing to do with comparison |
127 | 0 | array_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { |
128 | 0 | plan_datafusion_err!( |
129 | 0 | "Cannot infer common array type for arrow operation {lhs} {op} {rhs}" |
130 | 0 | ) |
131 | 0 | }) |
132 | | } |
133 | | Plus | Minus | Multiply | Divide | Modulo => { |
134 | 142k | let get_result = |lhs, rhs| { |
135 | | use arrow::compute::kernels::numeric::*; |
136 | 142k | let l = new_empty_array(lhs); |
137 | 142k | let r = new_empty_array(rhs); |
138 | | |
139 | 142k | let result = match op { |
140 | 43.5k | Plus => add_wrapping(&l, &r), |
141 | 74.6k | Minus => sub_wrapping(&l, &r), |
142 | 0 | Multiply => mul_wrapping(&l, &r), |
143 | 0 | Divide => div(&l, &r), |
144 | 24.2k | Modulo => rem(&l, &r), |
145 | 0 | _ => unreachable!(), |
146 | | }; |
147 | 142k | result.map(|x| x.data_type().clone()) |
148 | 142k | }; |
149 | | |
150 | 142k | if let Ok(ret) = get_result(lhs, rhs) { |
151 | | // Temporal arithmetic, e.g. Date32 + Interval |
152 | 142k | Ok(Signature{ |
153 | 142k | lhs: lhs.clone(), |
154 | 142k | rhs: rhs.clone(), |
155 | 142k | ret, |
156 | 142k | }) |
157 | 0 | } else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) { |
158 | | // Temporal arithmetic by first coercing to a common time representation |
159 | | // e.g. Date32 - Timestamp |
160 | 0 | let ret = get_result(&coerced, &coerced).map_err(|e| { |
161 | 0 | plan_datafusion_err!( |
162 | 0 | "Cannot get result type for temporal operation {coerced} {op} {coerced}: {e}" |
163 | 0 | ) |
164 | 0 | })?; |
165 | 0 | Ok(Signature{ |
166 | 0 | lhs: coerced.clone(), |
167 | 0 | rhs: coerced, |
168 | 0 | ret, |
169 | 0 | }) |
170 | 0 | } else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) { |
171 | | // Decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0) |
172 | 0 | let ret = get_result(&lhs, &rhs).map_err(|e| { |
173 | 0 | plan_datafusion_err!( |
174 | 0 | "Cannot get result type for decimal operation {lhs} {op} {rhs}: {e}" |
175 | 0 | ) |
176 | 0 | })?; |
177 | 0 | Ok(Signature{ |
178 | 0 | lhs, |
179 | 0 | rhs, |
180 | 0 | ret, |
181 | 0 | }) |
182 | 0 | } else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) { |
183 | | // Numeric arithmetic, e.g. Int32 + Int32 |
184 | 0 | Ok(Signature::uniform(numeric)) |
185 | | } else { |
186 | 0 | plan_err!( |
187 | 0 | "Cannot coerce arithmetic expression {lhs} {op} {rhs} to valid types" |
188 | 0 | ) |
189 | | } |
190 | | } |
191 | | } |
192 | 214k | } |
193 | | |
194 | | /// returns the resulting type of a binary expression evaluating the `op` with the left and right hand types |
195 | 214k | pub fn get_result_type( |
196 | 214k | lhs: &DataType, |
197 | 214k | op: &Operator, |
198 | 214k | rhs: &DataType, |
199 | 214k | ) -> Result<DataType> { |
200 | 214k | signature(lhs, op, rhs).map(|sig| sig.ret) |
201 | 214k | } |
202 | | |
203 | | /// Returns the coerced input types for a binary expression evaluating the `op` with the left and right hand types |
204 | 0 | pub fn get_input_types( |
205 | 0 | lhs: &DataType, |
206 | 0 | op: &Operator, |
207 | 0 | rhs: &DataType, |
208 | 0 | ) -> Result<(DataType, DataType)> { |
209 | 0 | signature(lhs, op, rhs).map(|sig| (sig.lhs, sig.rhs)) |
210 | 0 | } |
211 | | |
212 | | /// Coercion rules for mathematics operators between decimal and non-decimal types. |
213 | 0 | fn math_decimal_coercion( |
214 | 0 | lhs_type: &DataType, |
215 | 0 | rhs_type: &DataType, |
216 | 0 | ) -> Option<(DataType, DataType)> { |
217 | | use arrow::datatypes::DataType::*; |
218 | | |
219 | 0 | match (lhs_type, rhs_type) { |
220 | 0 | (Dictionary(_, value_type), _) => { |
221 | 0 | let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type)?; |
222 | 0 | Some((value_type, rhs_type)) |
223 | | } |
224 | 0 | (_, Dictionary(_, value_type)) => { |
225 | 0 | let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?; |
226 | 0 | Some((lhs_type, value_type)) |
227 | | } |
228 | 0 | (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => { |
229 | 0 | Some((dec_type.clone(), dec_type.clone())) |
230 | | } |
231 | | (Decimal128(_, _), Decimal128(_, _)) | (Decimal256(_, _), Decimal256(_, _)) => { |
232 | 0 | Some((lhs_type.clone(), rhs_type.clone())) |
233 | | } |
234 | | // Unlike with comparison we don't coerce to a decimal in the case of floating point |
235 | | // numbers, instead falling back to floating point arithmetic instead |
236 | | (Decimal128(_, _), Int8 | Int16 | Int32 | Int64) => { |
237 | 0 | Some((lhs_type.clone(), coerce_numeric_type_to_decimal(rhs_type)?)) |
238 | | } |
239 | | (Int8 | Int16 | Int32 | Int64, Decimal128(_, _)) => { |
240 | 0 | Some((coerce_numeric_type_to_decimal(lhs_type)?, rhs_type.clone())) |
241 | | } |
242 | | (Decimal256(_, _), Int8 | Int16 | Int32 | Int64) => Some(( |
243 | 0 | lhs_type.clone(), |
244 | 0 | coerce_numeric_type_to_decimal256(rhs_type)?, |
245 | | )), |
246 | | (Int8 | Int16 | Int32 | Int64, Decimal256(_, _)) => Some(( |
247 | 0 | coerce_numeric_type_to_decimal256(lhs_type)?, |
248 | 0 | rhs_type.clone(), |
249 | | )), |
250 | 0 | _ => None, |
251 | | } |
252 | 0 | } |
253 | | |
254 | | /// Returns the output type of applying bitwise operations such as |
255 | | /// `&`, `|`, or `xor`to arguments of `lhs_type` and `rhs_type`. |
256 | 0 | fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataType> { |
257 | | use arrow::datatypes::DataType::*; |
258 | | |
259 | 0 | if !both_numeric_or_null_and_numeric(left_type, right_type) { |
260 | 0 | return None; |
261 | 0 | } |
262 | 0 |
|
263 | 0 | if left_type == right_type { |
264 | 0 | return Some(left_type.clone()); |
265 | 0 | } |
266 | 0 |
|
267 | 0 | match (left_type, right_type) { |
268 | 0 | (UInt64, _) | (_, UInt64) => Some(UInt64), |
269 | | (Int64, _) |
270 | | | (_, Int64) |
271 | | | (UInt32, Int8) |
272 | | | (Int8, UInt32) |
273 | | | (UInt32, Int16) |
274 | | | (Int16, UInt32) |
275 | | | (UInt32, Int32) |
276 | 0 | | (Int32, UInt32) => Some(Int64), |
277 | | (Int32, _) |
278 | | | (_, Int32) |
279 | | | (UInt16, Int16) |
280 | | | (Int16, UInt16) |
281 | | | (UInt16, Int8) |
282 | 0 | | (Int8, UInt16) => Some(Int32), |
283 | 0 | (UInt32, _) | (_, UInt32) => Some(UInt32), |
284 | 0 | (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16), |
285 | 0 | (UInt16, _) | (_, UInt16) => Some(UInt16), |
286 | 0 | (Int8, _) | (_, Int8) => Some(Int8), |
287 | 0 | (UInt8, _) | (_, UInt8) => Some(UInt8), |
288 | 0 | _ => None, |
289 | | } |
290 | 0 | } |
291 | | |
292 | | #[derive(Debug, PartialEq, Eq, Hash, Clone)] |
293 | | enum TypeCategory { |
294 | | Array, |
295 | | Boolean, |
296 | | Numeric, |
297 | | // String, well-defined type, but are considered as unknown type. |
298 | | DateTime, |
299 | | Composite, |
300 | | Unknown, |
301 | | NotSupported, |
302 | | } |
303 | | |
304 | | impl From<&DataType> for TypeCategory { |
305 | 0 | fn from(data_type: &DataType) -> Self { |
306 | 0 | match data_type { |
307 | | // Dict is a special type in arrow, we check the value type |
308 | 0 | DataType::Dictionary(_, v) => { |
309 | 0 | let v = v.as_ref(); |
310 | 0 | TypeCategory::from(v) |
311 | | } |
312 | | _ => { |
313 | 0 | if data_type.is_numeric() { |
314 | 0 | return TypeCategory::Numeric; |
315 | 0 | } |
316 | | |
317 | 0 | if matches!(data_type, DataType::Boolean) { |
318 | 0 | return TypeCategory::Boolean; |
319 | 0 | } |
320 | | |
321 | 0 | if matches!( |
322 | 0 | data_type, |
323 | | DataType::List(_) |
324 | | | DataType::FixedSizeList(_, _) |
325 | | | DataType::LargeList(_) |
326 | | ) { |
327 | 0 | return TypeCategory::Array; |
328 | 0 | } |
329 | | |
330 | | // String literal is possible to cast to many other types like numeric or datetime, |
331 | | // therefore, it is categorized as a unknown type |
332 | 0 | if matches!( |
333 | 0 | data_type, |
334 | | DataType::Utf8 | DataType::LargeUtf8 | DataType::Null |
335 | | ) { |
336 | 0 | return TypeCategory::Unknown; |
337 | 0 | } |
338 | | |
339 | 0 | if matches!( |
340 | 0 | data_type, |
341 | | DataType::Date32 |
342 | | | DataType::Date64 |
343 | | | DataType::Time32(_) |
344 | | | DataType::Time64(_) |
345 | | | DataType::Timestamp(_, _) |
346 | | | DataType::Interval(_) |
347 | | | DataType::Duration(_) |
348 | | ) { |
349 | 0 | return TypeCategory::DateTime; |
350 | 0 | } |
351 | | |
352 | 0 | if matches!( |
353 | 0 | data_type, |
354 | | DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _) |
355 | | ) { |
356 | 0 | return TypeCategory::Composite; |
357 | 0 | } |
358 | 0 |
|
359 | 0 | TypeCategory::NotSupported |
360 | | } |
361 | | } |
362 | 0 | } |
363 | | } |
364 | | |
365 | | /// Coerce dissimilar data types to a single data type. |
366 | | /// UNION, INTERSECT, EXCEPT, CASE, ARRAY, VALUES, and the GREATEST and LEAST functions are |
367 | | /// examples that has the similar resolution rules. |
368 | | /// See <https://www.postgresql.org/docs/current/typeconv-union-case.html> for more information. |
369 | | /// The rules in the document provide a clue, but adhering strictly to them doesn't precisely |
370 | | /// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules |
371 | | /// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted |
372 | | /// decimal precision and scale when coercing decimal types. |
373 | 0 | pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> { |
374 | 0 | if data_types.is_empty() { |
375 | 0 | return None; |
376 | 0 | } |
377 | 0 |
|
378 | 0 | // if all the data_types is the same return first one |
379 | 0 | if data_types.iter().all(|t| t == &data_types[0]) { |
380 | 0 | return Some(data_types[0].clone()); |
381 | 0 | } |
382 | 0 |
|
383 | 0 | // if all the data_types are null, return string |
384 | 0 | if data_types.iter().all(|t| t == &DataType::Null) { |
385 | 0 | return Some(DataType::Utf8); |
386 | 0 | } |
387 | 0 |
|
388 | 0 | // Ignore Nulls, if any data_type category is not the same, return None |
389 | 0 | let data_types_category: Vec<TypeCategory> = data_types |
390 | 0 | .iter() |
391 | 0 | .filter(|&t| t != &DataType::Null) |
392 | 0 | .map(|t| t.into()) |
393 | 0 | .collect(); |
394 | 0 |
|
395 | 0 | if data_types_category |
396 | 0 | .iter() |
397 | 0 | .any(|t| t == &TypeCategory::NotSupported) |
398 | | { |
399 | 0 | return None; |
400 | 0 | } |
401 | 0 |
|
402 | 0 | // check if there is only one category excluding Unknown |
403 | 0 | let categories: HashSet<TypeCategory> = HashSet::from_iter( |
404 | 0 | data_types_category |
405 | 0 | .iter() |
406 | 0 | .filter(|&c| c != &TypeCategory::Unknown) |
407 | 0 | .cloned(), |
408 | 0 | ); |
409 | 0 | if categories.len() > 1 { |
410 | 0 | return None; |
411 | 0 | } |
412 | 0 |
|
413 | 0 | // Ignore Nulls |
414 | 0 | let mut candidate_type: Option<DataType> = None; |
415 | 0 | for data_type in data_types.iter() { |
416 | 0 | if data_type == &DataType::Null { |
417 | 0 | continue; |
418 | 0 | } |
419 | 0 | if let Some(ref candidate_t) = candidate_type { |
420 | | // Find candidate type that all the data types can be coerced to |
421 | | // Follows the behavior of Postgres and DuckDB |
422 | | // Coerced type may be different from the candidate and current data type |
423 | | // For example, |
424 | | // i64 and decimal(7, 2) are expect to get coerced type decimal(22, 2) |
425 | | // numeric string ('1') and numeric (2) are expect to get coerced type numeric (1, 2) |
426 | 0 | if let Some(t) = type_union_resolution_coercion(data_type, candidate_t) { |
427 | 0 | candidate_type = Some(t); |
428 | 0 | } else { |
429 | 0 | return None; |
430 | | } |
431 | 0 | } else { |
432 | 0 | candidate_type = Some(data_type.clone()); |
433 | 0 | } |
434 | | } |
435 | | |
436 | 0 | candidate_type |
437 | 0 | } |
438 | | |
439 | | /// Coerce `lhs_type` and `rhs_type` to a common type for [type_union_resolution] |
440 | | /// See [type_union_resolution] for more information. |
441 | 0 | fn type_union_resolution_coercion( |
442 | 0 | lhs_type: &DataType, |
443 | 0 | rhs_type: &DataType, |
444 | 0 | ) -> Option<DataType> { |
445 | 0 | if lhs_type == rhs_type { |
446 | 0 | return Some(lhs_type.clone()); |
447 | 0 | } |
448 | 0 |
|
449 | 0 | match (lhs_type, rhs_type) { |
450 | | ( |
451 | 0 | DataType::Dictionary(lhs_index_type, lhs_value_type), |
452 | 0 | DataType::Dictionary(rhs_index_type, rhs_value_type), |
453 | 0 | ) => { |
454 | 0 | let new_index_type = |
455 | 0 | type_union_resolution_coercion(lhs_index_type, rhs_index_type); |
456 | 0 | let new_value_type = |
457 | 0 | type_union_resolution_coercion(lhs_value_type, rhs_value_type); |
458 | 0 | if let (Some(new_index_type), Some(new_value_type)) = |
459 | 0 | (new_index_type, new_value_type) |
460 | | { |
461 | 0 | Some(DataType::Dictionary( |
462 | 0 | Box::new(new_index_type), |
463 | 0 | Box::new(new_value_type), |
464 | 0 | )) |
465 | | } else { |
466 | 0 | None |
467 | | } |
468 | | } |
469 | 0 | (DataType::Dictionary(index_type, value_type), other_type) |
470 | 0 | | (other_type, DataType::Dictionary(index_type, value_type)) => { |
471 | 0 | let new_value_type = type_union_resolution_coercion(value_type, other_type); |
472 | 0 | new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t))) |
473 | | } |
474 | 0 | (DataType::List(lhs), DataType::List(rhs)) => { |
475 | 0 | let new_item_type = |
476 | 0 | type_union_resolution_coercion(lhs.data_type(), rhs.data_type()); |
477 | 0 | new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true)))) |
478 | | } |
479 | | _ => { |
480 | | // numeric coercion is the same as comparison coercion, both find the narrowest type |
481 | | // that can accommodate both types |
482 | 0 | binary_numeric_coercion(lhs_type, rhs_type) |
483 | 0 | .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) |
484 | 0 | .or_else(|| string_coercion(lhs_type, rhs_type)) |
485 | 0 | .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) |
486 | | } |
487 | | } |
488 | 0 | } |
489 | | |
490 | | /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a |
491 | | /// comparison operation |
492 | | /// |
493 | | /// Example comparison operations are `lhs = rhs` and `lhs > rhs` |
494 | | /// |
495 | | /// Binary comparison kernels require the two arguments to be the (exact) same |
496 | | /// data type. However, users can write queries where the two arguments are |
497 | | /// different data types. In such cases, the data types are automatically cast |
498 | | /// (coerced) to a single data type to pass to the kernels. |
499 | 48.7k | pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
500 | 48.7k | if lhs_type == rhs_type { |
501 | | // same type => equality is possible |
502 | 48.7k | return Some(lhs_type.clone()); |
503 | 1 | } |
504 | 1 | binary_numeric_coercion(lhs_type, rhs_type) |
505 | 1 | .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, true)0 ) |
506 | 1 | .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)0 ) |
507 | 1 | .or_else(|| string_coercion(lhs_type, rhs_type)0 ) |
508 | 1 | .or_else(|| list_coercion(lhs_type, rhs_type)0 ) |
509 | 1 | .or_else(|| null_coercion(lhs_type, rhs_type)0 ) |
510 | 1 | .or_else(|| string_numeric_coercion(lhs_type, rhs_type)0 ) |
511 | 1 | .or_else(|| string_temporal_coercion(lhs_type, rhs_type)0 ) |
512 | 1 | .or_else(|| binary_coercion(lhs_type, rhs_type)0 ) |
513 | 1 | .or_else(|| struct_coercion(lhs_type, rhs_type)0 ) |
514 | 48.7k | } |
515 | | |
516 | | /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation |
517 | | /// where one is numeric and one is `Utf8`/`LargeUtf8`. |
518 | 0 | fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
519 | | use arrow::datatypes::DataType::*; |
520 | 0 | match (lhs_type, rhs_type) { |
521 | 0 | (Utf8, _) if rhs_type.is_numeric() => Some(Utf8), |
522 | 0 | (LargeUtf8, _) if rhs_type.is_numeric() => Some(LargeUtf8), |
523 | 0 | (_, Utf8) if lhs_type.is_numeric() => Some(Utf8), |
524 | 0 | (_, LargeUtf8) if lhs_type.is_numeric() => Some(LargeUtf8), |
525 | 0 | _ => None, |
526 | | } |
527 | 0 | } |
528 | | |
529 | | /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation |
530 | | /// where one is temporal and one is `Utf8View`/`Utf8`/`LargeUtf8`. |
531 | | /// |
532 | | /// Note this cannot be performed in case of arithmetic as there is insufficient information |
533 | | /// to correctly determine the type of argument. Consider |
534 | | /// |
535 | | /// ```sql |
536 | | /// timestamp > now() - '1 month' |
537 | | /// interval > now() - '1970-01-2021' |
538 | | /// ``` |
539 | | /// |
540 | | /// In the absence of a full type inference system, we can't determine the correct type |
541 | | /// to parse the string argument |
542 | 0 | fn string_temporal_coercion( |
543 | 0 | lhs_type: &DataType, |
544 | 0 | rhs_type: &DataType, |
545 | 0 | ) -> Option<DataType> { |
546 | | use arrow::datatypes::DataType::*; |
547 | | |
548 | 0 | fn match_rule(l: &DataType, r: &DataType) -> Option<DataType> { |
549 | 0 | match (l, r) { |
550 | | // Coerce Utf8View/Utf8/LargeUtf8 to Date32/Date64/Time32/Time64/Timestamp |
551 | 0 | (Utf8, temporal) | (LargeUtf8, temporal) | (Utf8View, temporal) => { |
552 | 0 | match temporal { |
553 | 0 | Date32 | Date64 => Some(temporal.clone()), |
554 | | Time32(_) | Time64(_) => { |
555 | 0 | if is_time_with_valid_unit(temporal.to_owned()) { |
556 | 0 | Some(temporal.to_owned()) |
557 | | } else { |
558 | 0 | None |
559 | | } |
560 | | } |
561 | 0 | Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())), |
562 | 0 | _ => None, |
563 | | } |
564 | | } |
565 | 0 | _ => None, |
566 | | } |
567 | 0 | } |
568 | | |
569 | 0 | match_rule(lhs_type, rhs_type).or_else(|| match_rule(rhs_type, lhs_type)) |
570 | 0 | } |
571 | | |
572 | | /// Coerce `lhs_type` and `rhs_type` to a common type where both are numeric |
573 | 1 | pub fn binary_numeric_coercion( |
574 | 1 | lhs_type: &DataType, |
575 | 1 | rhs_type: &DataType, |
576 | 1 | ) -> Option<DataType> { |
577 | | use arrow::datatypes::DataType::*; |
578 | 1 | if !lhs_type.is_numeric() || !rhs_type.is_numeric() { |
579 | 0 | return None; |
580 | 1 | }; |
581 | 1 | |
582 | 1 | // same type => all good |
583 | 1 | if lhs_type == rhs_type { |
584 | 0 | return Some(lhs_type.clone()); |
585 | 1 | } |
586 | | |
587 | 1 | if let Some(t) = decimal_coercion(lhs_type, rhs_type) { |
588 | 1 | return Some(t); |
589 | 0 | } |
590 | 0 |
|
591 | 0 | // these are ordered from most informative to least informative so |
592 | 0 | // that the coercion does not lose information via truncation |
593 | 0 | match (lhs_type, rhs_type) { |
594 | 0 | (Float64, _) | (_, Float64) => Some(Float64), |
595 | 0 | (_, Float32) | (Float32, _) => Some(Float32), |
596 | | // The following match arms encode the following logic: Given the two |
597 | | // integral types, we choose the narrowest possible integral type that |
598 | | // accommodates all values of both types. Note that some information |
599 | | // loss is inevitable when we have a signed type and a `UInt64`, in |
600 | | // which case we use `Int64`;i.e. the widest signed integral type. |
601 | | |
602 | | // TODO: For i64 and u64, we can use decimal or float64 |
603 | | // Postgres has no unsigned type :( |
604 | | // DuckDB v.0.10.0 has double (double precision floating-point number (8 bytes)) |
605 | | // for largest signed (signed sixteen-byte integer) and unsigned integer (unsigned sixteen-byte integer) |
606 | | (Int64, _) |
607 | | | (_, Int64) |
608 | | | (UInt64, Int8) |
609 | | | (Int8, UInt64) |
610 | | | (UInt64, Int16) |
611 | | | (Int16, UInt64) |
612 | | | (UInt64, Int32) |
613 | | | (Int32, UInt64) |
614 | | | (UInt32, Int8) |
615 | | | (Int8, UInt32) |
616 | | | (UInt32, Int16) |
617 | | | (Int16, UInt32) |
618 | | | (UInt32, Int32) |
619 | 0 | | (Int32, UInt32) => Some(Int64), |
620 | 0 | (UInt64, _) | (_, UInt64) => Some(UInt64), |
621 | | (Int32, _) |
622 | | | (_, Int32) |
623 | | | (UInt16, Int16) |
624 | | | (Int16, UInt16) |
625 | | | (UInt16, Int8) |
626 | 0 | | (Int8, UInt16) => Some(Int32), |
627 | 0 | (UInt32, _) | (_, UInt32) => Some(UInt32), |
628 | 0 | (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16), |
629 | 0 | (UInt16, _) | (_, UInt16) => Some(UInt16), |
630 | 0 | (Int8, _) | (_, Int8) => Some(Int8), |
631 | 0 | (UInt8, _) | (_, UInt8) => Some(UInt8), |
632 | 0 | _ => None, |
633 | | } |
634 | 1 | } |
635 | | |
636 | | /// Decimal coercion rules. |
637 | 1 | pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
638 | | use arrow::datatypes::DataType::*; |
639 | | |
640 | 1 | match (lhs_type, rhs_type) { |
641 | | // Prefer decimal data type over floating point for comparison operation |
642 | | (Decimal128(_, _), Decimal128(_, _)) => { |
643 | 1 | get_wider_decimal_type(lhs_type, rhs_type) |
644 | | } |
645 | 0 | (Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), |
646 | 0 | (_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type), |
647 | | (Decimal256(_, _), Decimal256(_, _)) => { |
648 | 0 | get_wider_decimal_type(lhs_type, rhs_type) |
649 | | } |
650 | 0 | (Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), |
651 | 0 | (_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type), |
652 | 0 | (_, _) => None, |
653 | | } |
654 | 1 | } |
655 | | |
656 | | /// Coerce `lhs_type` and `rhs_type` to a common type. |
657 | 0 | fn get_common_decimal_type( |
658 | 0 | decimal_type: &DataType, |
659 | 0 | other_type: &DataType, |
660 | 0 | ) -> Option<DataType> { |
661 | | use arrow::datatypes::DataType::*; |
662 | 0 | match decimal_type { |
663 | | Decimal128(_, _) => { |
664 | 0 | let other_decimal_type = coerce_numeric_type_to_decimal(other_type)?; |
665 | 0 | get_wider_decimal_type(decimal_type, &other_decimal_type) |
666 | | } |
667 | | Decimal256(_, _) => { |
668 | 0 | let other_decimal_type = coerce_numeric_type_to_decimal256(other_type)?; |
669 | 0 | get_wider_decimal_type(decimal_type, &other_decimal_type) |
670 | | } |
671 | 0 | _ => None, |
672 | | } |
673 | 0 | } |
674 | | |
675 | | /// Returns a `DataType::Decimal128` that can store any value from either |
676 | | /// `lhs_decimal_type` and `rhs_decimal_type` |
677 | | /// |
678 | | /// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))`. |
679 | 1 | fn get_wider_decimal_type( |
680 | 1 | lhs_decimal_type: &DataType, |
681 | 1 | rhs_type: &DataType, |
682 | 1 | ) -> Option<DataType> { |
683 | 1 | match (lhs_decimal_type, rhs_type) { |
684 | 1 | (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => { |
685 | 1 | // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) |
686 | 1 | let s = *s1.max(s2); |
687 | 1 | let range = (*p1 as i8 - s1).max(*p2 as i8 - s2); |
688 | 1 | Some(create_decimal_type((range + s) as u8, s)) |
689 | | } |
690 | 0 | (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => { |
691 | 0 | // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) |
692 | 0 | let s = *s1.max(s2); |
693 | 0 | let range = (*p1 as i8 - s1).max(*p2 as i8 - s2); |
694 | 0 | Some(create_decimal256_type((range + s) as u8, s)) |
695 | | } |
696 | 0 | (_, _) => None, |
697 | | } |
698 | 1 | } |
699 | | |
700 | | /// Returns the wider type among arguments `lhs` and `rhs`. |
701 | | /// The wider type is the type that can safely represent values from both types |
702 | | /// without information loss. Returns an Error if types are incompatible. |
703 | 0 | pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result<DataType> { |
704 | | use arrow::datatypes::DataType::*; |
705 | 0 | Ok(match (lhs, rhs) { |
706 | 0 | (lhs, rhs) if lhs == rhs => lhs.clone(), |
707 | | // Right UInt is larger than left UInt. |
708 | | (UInt8, UInt16 | UInt32 | UInt64) | (UInt16, UInt32 | UInt64) | (UInt32, UInt64) | |
709 | | // Right Int is larger than left Int. |
710 | | (Int8, Int16 | Int32 | Int64) | (Int16, Int32 | Int64) | (Int32, Int64) | |
711 | | // Right Float is larger than left Float. |
712 | | (Float16, Float32 | Float64) | (Float32, Float64) | |
713 | | // Right String is larger than left String. |
714 | | (Utf8, LargeUtf8) | |
715 | | // Any right type is wider than a left hand side Null. |
716 | 0 | (Null, _) => rhs.clone(), |
717 | | // Left UInt is larger than right UInt. |
718 | | (UInt16 | UInt32 | UInt64, UInt8) | (UInt32 | UInt64, UInt16) | (UInt64, UInt32) | |
719 | | // Left Int is larger than right Int. |
720 | | (Int16 | Int32 | Int64, Int8) | (Int32 | Int64, Int16) | (Int64, Int32) | |
721 | | // Left Float is larger than right Float. |
722 | | (Float32 | Float64, Float16) | (Float64, Float32) | |
723 | | // Left String is larger than right String. |
724 | | (LargeUtf8, Utf8) | |
725 | | // Any left type is wider than a right hand side Null. |
726 | 0 | (_, Null) => lhs.clone(), |
727 | 0 | (List(lhs_field), List(rhs_field)) => { |
728 | 0 | let field_type = |
729 | 0 | get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; |
730 | 0 | if lhs_field.name() != rhs_field.name() { |
731 | 0 | return Err(exec_datafusion_err!( |
732 | 0 | "There is no wider type that can represent both {lhs} and {rhs}." |
733 | 0 | )); |
734 | 0 | } |
735 | 0 | assert_eq!(lhs_field.name(), rhs_field.name()); |
736 | 0 | let field_name = lhs_field.name(); |
737 | 0 | let nullable = lhs_field.is_nullable() | rhs_field.is_nullable(); |
738 | 0 | List(Arc::new(Field::new(field_name, field_type, nullable))) |
739 | | } |
740 | | (_, _) => { |
741 | 0 | return Err(exec_datafusion_err!( |
742 | 0 | "There is no wider type that can represent both {lhs} and {rhs}." |
743 | 0 | )); |
744 | | } |
745 | | }) |
746 | 0 | } |
747 | | |
748 | | /// Convert the numeric data type to the decimal data type. |
749 | | /// Now, we just support the signed integer type and floating-point type. |
750 | 0 | fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> { |
751 | | use arrow::datatypes::DataType::*; |
752 | | // This conversion rule is from spark |
753 | | // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 |
754 | 0 | match numeric_type { |
755 | 0 | Int8 => Some(Decimal128(3, 0)), |
756 | 0 | Int16 => Some(Decimal128(5, 0)), |
757 | 0 | Int32 => Some(Decimal128(10, 0)), |
758 | 0 | Int64 => Some(Decimal128(20, 0)), |
759 | | // TODO if we convert the floating-point data to the decimal type, it maybe overflow. |
760 | 0 | Float32 => Some(Decimal128(14, 7)), |
761 | 0 | Float64 => Some(Decimal128(30, 15)), |
762 | 0 | _ => None, |
763 | | } |
764 | 0 | } |
765 | | |
766 | | /// Convert the numeric data type to the decimal data type. |
767 | | /// Now, we just support the signed integer type and floating-point type. |
768 | 0 | fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option<DataType> { |
769 | | use arrow::datatypes::DataType::*; |
770 | | // This conversion rule is from spark |
771 | | // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 |
772 | 0 | match numeric_type { |
773 | 0 | Int8 => Some(Decimal256(3, 0)), |
774 | 0 | Int16 => Some(Decimal256(5, 0)), |
775 | 0 | Int32 => Some(Decimal256(10, 0)), |
776 | 0 | Int64 => Some(Decimal256(20, 0)), |
777 | | // TODO if we convert the floating-point data to the decimal type, it maybe overflow. |
778 | 0 | Float32 => Some(Decimal256(14, 7)), |
779 | 0 | Float64 => Some(Decimal256(30, 15)), |
780 | 0 | _ => None, |
781 | | } |
782 | 0 | } |
783 | | |
784 | 0 | fn struct_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
785 | | use arrow::datatypes::DataType::*; |
786 | 0 | match (lhs_type, rhs_type) { |
787 | 0 | (Struct(lhs_fields), Struct(rhs_fields)) => { |
788 | 0 | if lhs_fields.len() != rhs_fields.len() { |
789 | 0 | return None; |
790 | 0 | } |
791 | | |
792 | 0 | let types = std::iter::zip(lhs_fields.iter(), rhs_fields.iter()) |
793 | 0 | .map(|(lhs, rhs)| comparison_coercion(lhs.data_type(), rhs.data_type())) |
794 | 0 | .collect::<Option<Vec<DataType>>>()?; |
795 | | |
796 | 0 | let fields = types |
797 | 0 | .into_iter() |
798 | 0 | .enumerate() |
799 | 0 | .map(|(i, datatype)| { |
800 | 0 | Arc::new(Field::new(format!("c{i}"), datatype, true)) |
801 | 0 | }) |
802 | 0 | .collect::<Vec<FieldRef>>(); |
803 | 0 | Some(Struct(fields.into())) |
804 | | } |
805 | 0 | _ => None, |
806 | | } |
807 | 0 | } |
808 | | |
809 | | /// Returns the output type of applying mathematics operations such as |
810 | | /// `+` to arguments of `lhs_type` and `rhs_type`. |
811 | 0 | fn mathematics_numerical_coercion( |
812 | 0 | lhs_type: &DataType, |
813 | 0 | rhs_type: &DataType, |
814 | 0 | ) -> Option<DataType> { |
815 | | use arrow::datatypes::DataType::*; |
816 | | |
817 | | // error on any non-numeric type |
818 | 0 | if !both_numeric_or_null_and_numeric(lhs_type, rhs_type) { |
819 | 0 | return None; |
820 | 0 | }; |
821 | 0 |
|
822 | 0 | // these are ordered from most informative to least informative so |
823 | 0 | // that the coercion removes the least amount of information |
824 | 0 | match (lhs_type, rhs_type) { |
825 | 0 | (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { |
826 | 0 | mathematics_numerical_coercion(lhs_value_type, rhs_value_type) |
827 | | } |
828 | 0 | (Dictionary(_, value_type), _) => { |
829 | 0 | mathematics_numerical_coercion(value_type, rhs_type) |
830 | | } |
831 | 0 | (_, Dictionary(_, value_type)) => { |
832 | 0 | mathematics_numerical_coercion(lhs_type, value_type) |
833 | | } |
834 | 0 | (Float64, _) | (_, Float64) => Some(Float64), |
835 | 0 | (_, Float32) | (Float32, _) => Some(Float32), |
836 | 0 | (Int64, _) | (_, Int64) => Some(Int64), |
837 | 0 | (Int32, _) | (_, Int32) => Some(Int32), |
838 | 0 | (Int16, _) | (_, Int16) => Some(Int16), |
839 | 0 | (Int8, _) | (_, Int8) => Some(Int8), |
840 | 0 | (UInt64, _) | (_, UInt64) => Some(UInt64), |
841 | 0 | (UInt32, _) | (_, UInt32) => Some(UInt32), |
842 | 0 | (UInt16, _) | (_, UInt16) => Some(UInt16), |
843 | 0 | (UInt8, _) | (_, UInt8) => Some(UInt8), |
844 | 0 | _ => None, |
845 | | } |
846 | 0 | } |
847 | | |
848 | 1 | fn create_decimal_type(precision: u8, scale: i8) -> DataType { |
849 | 1 | DataType::Decimal128( |
850 | 1 | DECIMAL128_MAX_PRECISION.min(precision), |
851 | 1 | DECIMAL128_MAX_SCALE.min(scale), |
852 | 1 | ) |
853 | 1 | } |
854 | | |
855 | 0 | fn create_decimal256_type(precision: u8, scale: i8) -> DataType { |
856 | 0 | DataType::Decimal256( |
857 | 0 | DECIMAL256_MAX_PRECISION.min(precision), |
858 | 0 | DECIMAL256_MAX_SCALE.min(scale), |
859 | 0 | ) |
860 | 0 | } |
861 | | |
862 | | /// Determine if at least of one of lhs and rhs is numeric, and the other must be NULL or numeric |
863 | 0 | fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool { |
864 | | use arrow::datatypes::DataType::*; |
865 | 0 | match (lhs_type, rhs_type) { |
866 | 0 | (_, Null) => lhs_type.is_numeric(), |
867 | 0 | (Null, _) => rhs_type.is_numeric(), |
868 | 0 | (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { |
869 | 0 | lhs_value_type.is_numeric() && rhs_value_type.is_numeric() |
870 | | } |
871 | 0 | (Dictionary(_, value_type), _) => { |
872 | 0 | value_type.is_numeric() && rhs_type.is_numeric() |
873 | | } |
874 | 0 | (_, Dictionary(_, value_type)) => { |
875 | 0 | lhs_type.is_numeric() && value_type.is_numeric() |
876 | | } |
877 | 0 | _ => lhs_type.is_numeric() && rhs_type.is_numeric(), |
878 | | } |
879 | 0 | } |
880 | | |
881 | | /// Coercion rules for Dictionaries: the type that both lhs and rhs |
882 | | /// can be casted to for the purpose of a computation. |
883 | | /// |
884 | | /// Not all operators support dictionaries, if `preserve_dictionaries` is true |
885 | | /// dictionaries will be preserved if possible |
886 | 0 | fn dictionary_comparison_coercion( |
887 | 0 | lhs_type: &DataType, |
888 | 0 | rhs_type: &DataType, |
889 | 0 | preserve_dictionaries: bool, |
890 | 0 | ) -> Option<DataType> { |
891 | | use arrow::datatypes::DataType::*; |
892 | 0 | match (lhs_type, rhs_type) { |
893 | | ( |
894 | 0 | Dictionary(_lhs_index_type, lhs_value_type), |
895 | 0 | Dictionary(_rhs_index_type, rhs_value_type), |
896 | 0 | ) => comparison_coercion(lhs_value_type, rhs_value_type), |
897 | 0 | (d @ Dictionary(_, value_type), other_type) |
898 | 0 | | (other_type, d @ Dictionary(_, value_type)) |
899 | 0 | if preserve_dictionaries && value_type.as_ref() == other_type => |
900 | | { |
901 | 0 | Some(d.clone()) |
902 | | } |
903 | 0 | (Dictionary(_index_type, value_type), _) => { |
904 | 0 | comparison_coercion(value_type, rhs_type) |
905 | | } |
906 | 0 | (_, Dictionary(_index_type, value_type)) => { |
907 | 0 | comparison_coercion(lhs_type, value_type) |
908 | | } |
909 | 0 | _ => None, |
910 | | } |
911 | 0 | } |
912 | | |
913 | | /// Coercion rules for string concat. |
914 | | /// This is a union of string coercion rules and specified rules: |
915 | | /// 1. At least one side of lhs and rhs should be string type (Utf8 / LargeUtf8) |
916 | | /// 2. Data type of the other side should be able to cast to string type |
917 | 0 | fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
918 | | use arrow::datatypes::DataType::*; |
919 | 0 | string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) { |
920 | 0 | (Utf8View, from_type) | (from_type, Utf8View) => { |
921 | 0 | string_concat_internal_coercion(from_type, &Utf8View) |
922 | | } |
923 | 0 | (Utf8, from_type) | (from_type, Utf8) => { |
924 | 0 | string_concat_internal_coercion(from_type, &Utf8) |
925 | | } |
926 | 0 | (LargeUtf8, from_type) | (from_type, LargeUtf8) => { |
927 | 0 | string_concat_internal_coercion(from_type, &LargeUtf8) |
928 | | } |
929 | 0 | (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { |
930 | 0 | string_coercion(lhs_value_type, rhs_value_type).or(None) |
931 | | } |
932 | 0 | _ => None, |
933 | | }) |
934 | 0 | } |
935 | | |
936 | 0 | fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
937 | 0 | if lhs_type.equals_datatype(rhs_type) { |
938 | 0 | Some(lhs_type.to_owned()) |
939 | | } else { |
940 | 0 | None |
941 | | } |
942 | 0 | } |
943 | | |
944 | | /// If `from_type` can be casted to `to_type`, return `to_type`, otherwise |
945 | | /// return `None`. |
946 | 0 | fn string_concat_internal_coercion( |
947 | 0 | from_type: &DataType, |
948 | 0 | to_type: &DataType, |
949 | 0 | ) -> Option<DataType> { |
950 | 0 | if can_cast_types(from_type, to_type) { |
951 | 0 | Some(to_type.to_owned()) |
952 | | } else { |
953 | 0 | None |
954 | | } |
955 | 0 | } |
956 | | |
957 | | /// Coercion rules for string view types (Utf8/LargeUtf8/Utf8View): |
958 | | /// If at least one argument is a string view, we coerce to string view |
959 | | /// based on the observation that StringArray to StringViewArray is cheap but not vice versa. |
960 | | /// |
961 | | /// Between Utf8 and LargeUtf8, we coerce to LargeUtf8. |
962 | 0 | fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
963 | | use arrow::datatypes::DataType::*; |
964 | 0 | match (lhs_type, rhs_type) { |
965 | | // If Utf8View is in any side, we coerce to Utf8View. |
966 | | (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => { |
967 | 0 | Some(Utf8View) |
968 | | } |
969 | | // Then, if LargeUtf8 is in any side, we coerce to LargeUtf8. |
970 | 0 | (LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8), |
971 | | // Utf8 coerces to Utf8 |
972 | 0 | (Utf8, Utf8) => Some(Utf8), |
973 | 0 | _ => None, |
974 | | } |
975 | 0 | } |
976 | | |
977 | | /// This will be deprecated when binary operators native support |
978 | | /// for Utf8View (use `string_coercion` instead). |
979 | 0 | fn regex_comparison_string_coercion( |
980 | 0 | lhs_type: &DataType, |
981 | 0 | rhs_type: &DataType, |
982 | 0 | ) -> Option<DataType> { |
983 | | use arrow::datatypes::DataType::*; |
984 | 0 | match (lhs_type, rhs_type) { |
985 | | // If Utf8View is in any side, we coerce to Utf8. |
986 | | (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => { |
987 | 0 | Some(Utf8) |
988 | | } |
989 | | // Then, if LargeUtf8 is in any side, we coerce to LargeUtf8. |
990 | 0 | (LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8), |
991 | | // Utf8 coerces to Utf8 |
992 | 0 | (Utf8, Utf8) => Some(Utf8), |
993 | 0 | _ => None, |
994 | | } |
995 | 0 | } |
996 | | |
997 | 0 | fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
998 | | use arrow::datatypes::DataType::*; |
999 | 0 | match (lhs_type, rhs_type) { |
1000 | 0 | (Utf8 | LargeUtf8, other_type) | (other_type, Utf8 | LargeUtf8) |
1001 | 0 | if other_type.is_numeric() => |
1002 | | { |
1003 | 0 | Some(other_type.clone()) |
1004 | | } |
1005 | 0 | _ => None, |
1006 | | } |
1007 | 0 | } |
1008 | | |
1009 | | /// Coercion rules for list types. |
1010 | 0 | fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
1011 | | use arrow::datatypes::DataType::*; |
1012 | 0 | match (lhs_type, rhs_type) { |
1013 | 0 | (List(_), List(_)) => Some(lhs_type.clone()), |
1014 | 0 | (LargeList(_), List(_)) => Some(lhs_type.clone()), |
1015 | 0 | (List(_), LargeList(_)) => Some(rhs_type.clone()), |
1016 | 0 | (LargeList(_), LargeList(_)) => Some(lhs_type.clone()), |
1017 | 0 | (List(_), FixedSizeList(_, _)) => Some(lhs_type.clone()), |
1018 | 0 | (FixedSizeList(_, _), List(_)) => Some(rhs_type.clone()), |
1019 | | // Coerce to the left side FixedSizeList type if the list lengths are the same, |
1020 | | // otherwise coerce to list with the left type for dynamic length |
1021 | 0 | (FixedSizeList(lf, ls), FixedSizeList(_, rs)) => { |
1022 | 0 | if ls == rs { |
1023 | 0 | Some(lhs_type.clone()) |
1024 | | } else { |
1025 | 0 | Some(List(Arc::clone(lf))) |
1026 | | } |
1027 | | } |
1028 | 0 | (LargeList(_), FixedSizeList(_, _)) => Some(lhs_type.clone()), |
1029 | 0 | (FixedSizeList(_, _), LargeList(_)) => Some(rhs_type.clone()), |
1030 | 0 | _ => None, |
1031 | | } |
1032 | 0 | } |
1033 | | |
1034 | | /// Coercion rules for binary (Binary/LargeBinary) to string (Utf8/LargeUtf8): |
1035 | | /// If one argument is binary and the other is a string then coerce to string |
1036 | | /// (e.g. for `like`) |
1037 | 0 | fn binary_to_string_coercion( |
1038 | 0 | lhs_type: &DataType, |
1039 | 0 | rhs_type: &DataType, |
1040 | 0 | ) -> Option<DataType> { |
1041 | | use arrow::datatypes::DataType::*; |
1042 | 0 | match (lhs_type, rhs_type) { |
1043 | 0 | (Binary, Utf8) => Some(Utf8), |
1044 | 0 | (Binary, LargeUtf8) => Some(LargeUtf8), |
1045 | 0 | (BinaryView, Utf8) => Some(Utf8View), |
1046 | 0 | (BinaryView, LargeUtf8) => Some(LargeUtf8), |
1047 | 0 | (LargeBinary, Utf8) => Some(LargeUtf8), |
1048 | 0 | (LargeBinary, LargeUtf8) => Some(LargeUtf8), |
1049 | 0 | (Utf8, Binary) => Some(Utf8), |
1050 | 0 | (Utf8, LargeBinary) => Some(LargeUtf8), |
1051 | 0 | (Utf8, BinaryView) => Some(Utf8View), |
1052 | 0 | (LargeUtf8, Binary) => Some(LargeUtf8), |
1053 | 0 | (LargeUtf8, LargeBinary) => Some(LargeUtf8), |
1054 | 0 | (LargeUtf8, BinaryView) => Some(LargeUtf8), |
1055 | 0 | _ => None, |
1056 | | } |
1057 | 0 | } |
1058 | | |
1059 | | /// Coercion rules for binary types (Binary/LargeBinary/BinaryView): If at least one argument is |
1060 | | /// a binary type and both arguments can be coerced into a binary type, coerce |
1061 | | /// to binary type. |
1062 | 0 | fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
1063 | | use arrow::datatypes::DataType::*; |
1064 | 0 | match (lhs_type, rhs_type) { |
1065 | | // If BinaryView is in any side, we coerce to BinaryView. |
1066 | | (BinaryView, BinaryView | Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View) |
1067 | | | (LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View, BinaryView) => { |
1068 | 0 | Some(BinaryView) |
1069 | | } |
1070 | | // Prefer LargeBinary over Binary |
1071 | | (LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View, LargeBinary) |
1072 | 0 | | (LargeBinary, Binary | Utf8 | LargeUtf8 | Utf8View) => Some(LargeBinary), |
1073 | | |
1074 | | // If Utf8View/LargeUtf8 presents need to be large Binary |
1075 | | (Utf8View | LargeUtf8, Binary) | (Binary, Utf8View | LargeUtf8) => { |
1076 | 0 | Some(LargeBinary) |
1077 | | } |
1078 | 0 | (Binary, Utf8) | (Utf8, Binary) => Some(Binary), |
1079 | 0 | _ => None, |
1080 | | } |
1081 | 0 | } |
1082 | | |
1083 | | /// coercion rules for like operations. |
1084 | | /// This is a union of string coercion rules and dictionary coercion rules |
1085 | 0 | pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
1086 | 0 | string_coercion(lhs_type, rhs_type) |
1087 | 0 | .or_else(|| list_coercion(lhs_type, rhs_type)) |
1088 | 0 | .or_else(|| binary_to_string_coercion(lhs_type, rhs_type)) |
1089 | 0 | .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false)) |
1090 | 0 | .or_else(|| regex_null_coercion(lhs_type, rhs_type)) |
1091 | 0 | .or_else(|| null_coercion(lhs_type, rhs_type)) |
1092 | 0 | } |
1093 | | |
1094 | | /// coercion rules for regular expression comparison operations with NULL input. |
1095 | 0 | fn regex_null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
1096 | | use arrow::datatypes::DataType::*; |
1097 | 0 | match (lhs_type, rhs_type) { |
1098 | 0 | (DataType::Null, Utf8View | Utf8 | LargeUtf8) => Some(rhs_type.clone()), |
1099 | 0 | (Utf8View | Utf8 | LargeUtf8, DataType::Null) => Some(lhs_type.clone()), |
1100 | 0 | (DataType::Null, DataType::Null) => Some(Utf8), |
1101 | 0 | _ => None, |
1102 | | } |
1103 | 0 | } |
1104 | | |
1105 | | /// Coercion rules for regular expression comparison operations. |
1106 | | /// This is a union of string coercion rules and dictionary coercion rules |
1107 | 0 | pub fn regex_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
1108 | 0 | regex_comparison_string_coercion(lhs_type, rhs_type) |
1109 | 0 | .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false)) |
1110 | 0 | .or_else(|| regex_null_coercion(lhs_type, rhs_type)) |
1111 | 0 | } |
1112 | | |
1113 | | /// Checks if the TimeUnit associated with a Time32 or Time64 type is consistent, |
1114 | | /// as Time32 can only be used to Second and Millisecond accuracy, while Time64 |
1115 | | /// is exclusively used to Microsecond and Nanosecond accuracy |
1116 | 0 | fn is_time_with_valid_unit(datatype: DataType) -> bool { |
1117 | 0 | matches!( |
1118 | 0 | datatype, |
1119 | | DataType::Time32(TimeUnit::Second) |
1120 | | | DataType::Time32(TimeUnit::Millisecond) |
1121 | | | DataType::Time64(TimeUnit::Microsecond) |
1122 | | | DataType::Time64(TimeUnit::Nanosecond) |
1123 | | ) |
1124 | 0 | } |
1125 | | |
1126 | | /// Non-strict Timezone Coercion is useful in scenarios where we can guarantee |
1127 | | /// a stable relationship between two timestamps of different timezones. |
1128 | | /// |
1129 | | /// An example of this is binary comparisons (<, >, ==, etc). Arrow stores timestamps |
1130 | | /// as relative to UTC epoch, and then adds the timezone as an offset. As a result, we can always |
1131 | | /// do a binary comparison between the two times. |
1132 | | /// |
1133 | | /// Timezone coercion is handled by the following rules: |
1134 | | /// - If only one has a timezone, coerce the other to match |
1135 | | /// - If both have a timezone, coerce to the left type |
1136 | | /// - "UTC" and "+00:00" are considered equivalent |
1137 | 0 | fn temporal_coercion_nonstrict_timezone( |
1138 | 0 | lhs_type: &DataType, |
1139 | 0 | rhs_type: &DataType, |
1140 | 0 | ) -> Option<DataType> { |
1141 | | use arrow::datatypes::DataType::*; |
1142 | | |
1143 | 0 | match (lhs_type, rhs_type) { |
1144 | 0 | (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => { |
1145 | 0 | let tz = match (lhs_tz, rhs_tz) { |
1146 | | // If both have a timezone, use the left timezone. |
1147 | 0 | (Some(lhs_tz), Some(_rhs_tz)) => Some(Arc::clone(lhs_tz)), |
1148 | 0 | (Some(lhs_tz), None) => Some(Arc::clone(lhs_tz)), |
1149 | 0 | (None, Some(rhs_tz)) => Some(Arc::clone(rhs_tz)), |
1150 | 0 | (None, None) => None, |
1151 | | }; |
1152 | | |
1153 | 0 | let unit = timeunit_coercion(lhs_unit, rhs_unit); |
1154 | 0 |
|
1155 | 0 | Some(Timestamp(unit, tz)) |
1156 | | } |
1157 | 0 | _ => temporal_coercion(lhs_type, rhs_type), |
1158 | | } |
1159 | 0 | } |
1160 | | |
1161 | | /// Strict Timezone coercion is useful in scenarios where we cannot guarantee a stable relationship |
1162 | | /// between two timestamps with different timezones or do not want implicit coercion between them. |
1163 | | /// |
1164 | | /// An example of this when attempting to coerce function arguments. Functions already have a mechanism |
1165 | | /// for defining which timestamp types they want to support, so we do not want to do any further coercion. |
1166 | | /// |
1167 | | /// Coercion rules for Temporal columns: the type that both lhs and rhs can be |
1168 | | /// casted to for the purpose of a date computation |
1169 | | /// For interval arithmetic, it doesn't handle datetime type +/- interval |
1170 | | /// Timezone coercion is handled by the following rules: |
1171 | | /// - If only one has a timezone, coerce the other to match |
1172 | | /// - If both have a timezone, throw an error |
1173 | | /// - "UTC" and "+00:00" are considered equivalent |
1174 | 0 | fn temporal_coercion_strict_timezone( |
1175 | 0 | lhs_type: &DataType, |
1176 | 0 | rhs_type: &DataType, |
1177 | 0 | ) -> Option<DataType> { |
1178 | | use arrow::datatypes::DataType::*; |
1179 | | |
1180 | 0 | match (lhs_type, rhs_type) { |
1181 | 0 | (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => { |
1182 | 0 | let tz = match (lhs_tz, rhs_tz) { |
1183 | 0 | (Some(lhs_tz), Some(rhs_tz)) => { |
1184 | 0 | match (lhs_tz.as_ref(), rhs_tz.as_ref()) { |
1185 | 0 | // UTC and "+00:00" are the same by definition. Most other timezones |
1186 | 0 | // do not have a 1-1 mapping between timezone and an offset from UTC |
1187 | 0 | ("UTC", "+00:00") | ("+00:00", "UTC") => Some(Arc::clone(lhs_tz)), |
1188 | 0 | (lhs, rhs) if lhs == rhs => Some(Arc::clone(lhs_tz)), |
1189 | | // can't cast across timezones |
1190 | | _ => { |
1191 | 0 | return None; |
1192 | | } |
1193 | | } |
1194 | | } |
1195 | 0 | (Some(lhs_tz), None) => Some(Arc::clone(lhs_tz)), |
1196 | 0 | (None, Some(rhs_tz)) => Some(Arc::clone(rhs_tz)), |
1197 | 0 | (None, None) => None, |
1198 | | }; |
1199 | | |
1200 | 0 | let unit = timeunit_coercion(lhs_unit, rhs_unit); |
1201 | 0 |
|
1202 | 0 | Some(Timestamp(unit, tz)) |
1203 | | } |
1204 | 0 | _ => temporal_coercion(lhs_type, rhs_type), |
1205 | | } |
1206 | 0 | } |
1207 | | |
1208 | 0 | fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
1209 | | use arrow::datatypes::DataType::*; |
1210 | | use arrow::datatypes::IntervalUnit::*; |
1211 | | use arrow::datatypes::TimeUnit::*; |
1212 | | |
1213 | 0 | match (lhs_type, rhs_type) { |
1214 | | (Interval(_) | Duration(_), Interval(_) | Duration(_)) => { |
1215 | 0 | Some(Interval(MonthDayNano)) |
1216 | | } |
1217 | 0 | (Date64, Date32) | (Date32, Date64) => Some(Date64), |
1218 | | (Timestamp(_, None), Date64) | (Date64, Timestamp(_, None)) => { |
1219 | 0 | Some(Timestamp(Nanosecond, None)) |
1220 | | } |
1221 | 0 | (Timestamp(_, _tz), Date64) | (Date64, Timestamp(_, _tz)) => { |
1222 | 0 | Some(Timestamp(Nanosecond, None)) |
1223 | | } |
1224 | | (Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => { |
1225 | 0 | Some(Timestamp(Nanosecond, None)) |
1226 | | } |
1227 | 0 | (Timestamp(_, _tz), Date32) | (Date32, Timestamp(_, _tz)) => { |
1228 | 0 | Some(Timestamp(Nanosecond, None)) |
1229 | | } |
1230 | 0 | _ => None, |
1231 | | } |
1232 | 0 | } |
1233 | | |
1234 | 0 | fn timeunit_coercion(lhs_unit: &TimeUnit, rhs_unit: &TimeUnit) -> TimeUnit { |
1235 | | use arrow::datatypes::TimeUnit::*; |
1236 | 0 | match (lhs_unit, rhs_unit) { |
1237 | 0 | (Second, Millisecond) => Second, |
1238 | 0 | (Second, Microsecond) => Second, |
1239 | 0 | (Second, Nanosecond) => Second, |
1240 | 0 | (Millisecond, Second) => Second, |
1241 | 0 | (Millisecond, Microsecond) => Millisecond, |
1242 | 0 | (Millisecond, Nanosecond) => Millisecond, |
1243 | 0 | (Microsecond, Second) => Second, |
1244 | 0 | (Microsecond, Millisecond) => Millisecond, |
1245 | 0 | (Microsecond, Nanosecond) => Microsecond, |
1246 | 0 | (Nanosecond, Second) => Second, |
1247 | 0 | (Nanosecond, Millisecond) => Millisecond, |
1248 | 0 | (Nanosecond, Microsecond) => Microsecond, |
1249 | 0 | (l, r) => { |
1250 | 0 | assert_eq!(l, r); |
1251 | 0 | *l |
1252 | | } |
1253 | | } |
1254 | 0 | } |
1255 | | |
1256 | | /// coercion rules from NULL type. Since NULL can be casted to any other type in arrow, |
1257 | | /// either lhs or rhs is NULL, if NULL can be casted to type of the other side, the coercion is valid. |
1258 | 0 | fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { |
1259 | 0 | match (lhs_type, rhs_type) { |
1260 | 0 | (DataType::Null, other_type) | (other_type, DataType::Null) => { |
1261 | 0 | if can_cast_types(&DataType::Null, other_type) { |
1262 | 0 | Some(other_type.clone()) |
1263 | | } else { |
1264 | 0 | None |
1265 | | } |
1266 | | } |
1267 | 0 | _ => None, |
1268 | | } |
1269 | 0 | } |
1270 | | |
1271 | | #[cfg(test)] |
1272 | | mod tests { |
1273 | | use super::*; |
1274 | | |
1275 | | use datafusion_common::assert_contains; |
1276 | | |
1277 | | #[test] |
1278 | | fn test_coercion_error() -> Result<()> { |
1279 | | let result_type = |
1280 | | get_input_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8); |
1281 | | |
1282 | | let e = result_type.unwrap_err(); |
1283 | | assert_eq!(e.strip_backtrace(), "Error during planning: Cannot coerce arithmetic expression Float32 + Utf8 to valid types"); |
1284 | | Ok(()) |
1285 | | } |
1286 | | |
1287 | | #[test] |
1288 | | fn test_decimal_binary_comparison_coercion() -> Result<()> { |
1289 | | let input_decimal = DataType::Decimal128(20, 3); |
1290 | | let input_types = [ |
1291 | | DataType::Int8, |
1292 | | DataType::Int16, |
1293 | | DataType::Int32, |
1294 | | DataType::Int64, |
1295 | | DataType::Float32, |
1296 | | DataType::Float64, |
1297 | | DataType::Decimal128(38, 10), |
1298 | | DataType::Decimal128(20, 8), |
1299 | | DataType::Null, |
1300 | | ]; |
1301 | | let result_types = [ |
1302 | | DataType::Decimal128(20, 3), |
1303 | | DataType::Decimal128(20, 3), |
1304 | | DataType::Decimal128(20, 3), |
1305 | | DataType::Decimal128(23, 3), |
1306 | | DataType::Decimal128(24, 7), |
1307 | | DataType::Decimal128(32, 15), |
1308 | | DataType::Decimal128(38, 10), |
1309 | | DataType::Decimal128(25, 8), |
1310 | | DataType::Decimal128(20, 3), |
1311 | | ]; |
1312 | | let comparison_op_types = [ |
1313 | | Operator::NotEq, |
1314 | | Operator::Eq, |
1315 | | Operator::Gt, |
1316 | | Operator::GtEq, |
1317 | | Operator::Lt, |
1318 | | Operator::LtEq, |
1319 | | ]; |
1320 | | for (i, input_type) in input_types.iter().enumerate() { |
1321 | | let expect_type = &result_types[i]; |
1322 | | for op in comparison_op_types { |
1323 | | let (lhs, rhs) = get_input_types(&input_decimal, &op, input_type)?; |
1324 | | assert_eq!(expect_type, &lhs); |
1325 | | assert_eq!(expect_type, &rhs); |
1326 | | } |
1327 | | } |
1328 | | // negative test |
1329 | | let result_type = |
1330 | | get_input_types(&input_decimal, &Operator::Eq, &DataType::Boolean); |
1331 | | assert!(result_type.is_err()); |
1332 | | Ok(()) |
1333 | | } |
1334 | | |
1335 | | #[test] |
1336 | | fn test_decimal_mathematics_op_type() { |
1337 | | assert_eq!( |
1338 | | coerce_numeric_type_to_decimal(&DataType::Int8).unwrap(), |
1339 | | DataType::Decimal128(3, 0) |
1340 | | ); |
1341 | | assert_eq!( |
1342 | | coerce_numeric_type_to_decimal(&DataType::Int16).unwrap(), |
1343 | | DataType::Decimal128(5, 0) |
1344 | | ); |
1345 | | assert_eq!( |
1346 | | coerce_numeric_type_to_decimal(&DataType::Int32).unwrap(), |
1347 | | DataType::Decimal128(10, 0) |
1348 | | ); |
1349 | | assert_eq!( |
1350 | | coerce_numeric_type_to_decimal(&DataType::Int64).unwrap(), |
1351 | | DataType::Decimal128(20, 0) |
1352 | | ); |
1353 | | assert_eq!( |
1354 | | coerce_numeric_type_to_decimal(&DataType::Float32).unwrap(), |
1355 | | DataType::Decimal128(14, 7) |
1356 | | ); |
1357 | | assert_eq!( |
1358 | | coerce_numeric_type_to_decimal(&DataType::Float64).unwrap(), |
1359 | | DataType::Decimal128(30, 15) |
1360 | | ); |
1361 | | } |
1362 | | |
1363 | | #[test] |
1364 | | fn test_dictionary_type_coercion() { |
1365 | | use DataType::*; |
1366 | | |
1367 | | let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32)); |
1368 | | let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); |
1369 | | assert_eq!( |
1370 | | dictionary_comparison_coercion(&lhs_type, &rhs_type, true), |
1371 | | Some(Int32) |
1372 | | ); |
1373 | | assert_eq!( |
1374 | | dictionary_comparison_coercion(&lhs_type, &rhs_type, false), |
1375 | | Some(Int32) |
1376 | | ); |
1377 | | |
1378 | | // Since we can coerce values of Int16 to Utf8 can support this |
1379 | | let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); |
1380 | | let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); |
1381 | | assert_eq!( |
1382 | | dictionary_comparison_coercion(&lhs_type, &rhs_type, true), |
1383 | | Some(Utf8) |
1384 | | ); |
1385 | | |
1386 | | // Since we can coerce values of Utf8 to Binary can support this |
1387 | | let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); |
1388 | | let rhs_type = Dictionary(Box::new(Int8), Box::new(Binary)); |
1389 | | assert_eq!( |
1390 | | dictionary_comparison_coercion(&lhs_type, &rhs_type, true), |
1391 | | Some(Binary) |
1392 | | ); |
1393 | | |
1394 | | let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); |
1395 | | let rhs_type = Utf8; |
1396 | | assert_eq!( |
1397 | | dictionary_comparison_coercion(&lhs_type, &rhs_type, false), |
1398 | | Some(Utf8) |
1399 | | ); |
1400 | | assert_eq!( |
1401 | | dictionary_comparison_coercion(&lhs_type, &rhs_type, true), |
1402 | | Some(lhs_type.clone()) |
1403 | | ); |
1404 | | |
1405 | | let lhs_type = Utf8; |
1406 | | let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); |
1407 | | assert_eq!( |
1408 | | dictionary_comparison_coercion(&lhs_type, &rhs_type, false), |
1409 | | Some(Utf8) |
1410 | | ); |
1411 | | assert_eq!( |
1412 | | dictionary_comparison_coercion(&lhs_type, &rhs_type, true), |
1413 | | Some(rhs_type.clone()) |
1414 | | ); |
1415 | | } |
1416 | | |
1417 | | /// Test coercion rules for binary operators |
1418 | | /// |
1419 | | /// Applies coercion rules for `$LHS_TYPE $OP $RHS_TYPE` and asserts that the |
1420 | | /// the result type is `$RESULT_TYPE` |
1421 | | macro_rules! test_coercion_binary_rule { |
1422 | | ($LHS_TYPE:expr, $RHS_TYPE:expr, $OP:expr, $RESULT_TYPE:expr) => {{ |
1423 | | let (lhs, rhs) = get_input_types(&$LHS_TYPE, &$OP, &$RHS_TYPE)?; |
1424 | | assert_eq!(lhs, $RESULT_TYPE); |
1425 | | assert_eq!(rhs, $RESULT_TYPE); |
1426 | | }}; |
1427 | | } |
1428 | | |
1429 | | /// Test coercion rules for like |
1430 | | /// |
1431 | | /// Applies coercion rules for both |
1432 | | /// * `$LHS_TYPE LIKE $RHS_TYPE` |
1433 | | /// * `$RHS_TYPE LIKE $LHS_TYPE` |
1434 | | /// |
1435 | | /// And asserts the result type is `$RESULT_TYPE` |
1436 | | macro_rules! test_like_rule { |
1437 | | ($LHS_TYPE:expr, $RHS_TYPE:expr, $RESULT_TYPE:expr) => {{ |
1438 | | println!("Coercing {} LIKE {}", $LHS_TYPE, $RHS_TYPE); |
1439 | | let result = like_coercion(&$LHS_TYPE, &$RHS_TYPE); |
1440 | | assert_eq!(result, $RESULT_TYPE); |
1441 | | // reverse the order |
1442 | | let result = like_coercion(&$RHS_TYPE, &$LHS_TYPE); |
1443 | | assert_eq!(result, $RESULT_TYPE); |
1444 | | }}; |
1445 | | } |
1446 | | |
1447 | | #[test] |
1448 | | fn test_date_timestamp_arithmetic_error() -> Result<()> { |
1449 | | let (lhs, rhs) = get_input_types( |
1450 | | &DataType::Timestamp(TimeUnit::Nanosecond, None), |
1451 | | &Operator::Minus, |
1452 | | &DataType::Timestamp(TimeUnit::Millisecond, None), |
1453 | | )?; |
1454 | | assert_eq!(lhs.to_string(), "Timestamp(Millisecond, None)"); |
1455 | | assert_eq!(rhs.to_string(), "Timestamp(Millisecond, None)"); |
1456 | | |
1457 | | let err = get_input_types(&DataType::Date32, &Operator::Plus, &DataType::Date64) |
1458 | | .unwrap_err() |
1459 | | .to_string(); |
1460 | | |
1461 | | assert_contains!( |
1462 | | &err, |
1463 | | "Cannot get result type for temporal operation Date64 + Date64" |
1464 | | ); |
1465 | | |
1466 | | Ok(()) |
1467 | | } |
1468 | | |
1469 | | #[test] |
1470 | | fn test_like_coercion() { |
1471 | | // string coerce to strings |
1472 | | test_like_rule!(DataType::Utf8, DataType::Utf8, Some(DataType::Utf8)); |
1473 | | test_like_rule!( |
1474 | | DataType::LargeUtf8, |
1475 | | DataType::Utf8, |
1476 | | Some(DataType::LargeUtf8) |
1477 | | ); |
1478 | | test_like_rule!( |
1479 | | DataType::Utf8, |
1480 | | DataType::LargeUtf8, |
1481 | | Some(DataType::LargeUtf8) |
1482 | | ); |
1483 | | test_like_rule!( |
1484 | | DataType::LargeUtf8, |
1485 | | DataType::LargeUtf8, |
1486 | | Some(DataType::LargeUtf8) |
1487 | | ); |
1488 | | |
1489 | | // Also coerce binary to strings |
1490 | | test_like_rule!(DataType::Binary, DataType::Utf8, Some(DataType::Utf8)); |
1491 | | test_like_rule!( |
1492 | | DataType::LargeBinary, |
1493 | | DataType::Utf8, |
1494 | | Some(DataType::LargeUtf8) |
1495 | | ); |
1496 | | test_like_rule!( |
1497 | | DataType::Binary, |
1498 | | DataType::LargeUtf8, |
1499 | | Some(DataType::LargeUtf8) |
1500 | | ); |
1501 | | test_like_rule!( |
1502 | | DataType::LargeBinary, |
1503 | | DataType::LargeUtf8, |
1504 | | Some(DataType::LargeUtf8) |
1505 | | ); |
1506 | | } |
1507 | | |
1508 | | #[test] |
1509 | | fn test_type_coercion() -> Result<()> { |
1510 | | test_coercion_binary_rule!( |
1511 | | DataType::Utf8, |
1512 | | DataType::Date32, |
1513 | | Operator::Eq, |
1514 | | DataType::Date32 |
1515 | | ); |
1516 | | test_coercion_binary_rule!( |
1517 | | DataType::Utf8, |
1518 | | DataType::Date64, |
1519 | | Operator::Lt, |
1520 | | DataType::Date64 |
1521 | | ); |
1522 | | test_coercion_binary_rule!( |
1523 | | DataType::Utf8, |
1524 | | DataType::Time32(TimeUnit::Second), |
1525 | | Operator::Eq, |
1526 | | DataType::Time32(TimeUnit::Second) |
1527 | | ); |
1528 | | test_coercion_binary_rule!( |
1529 | | DataType::Utf8, |
1530 | | DataType::Time32(TimeUnit::Millisecond), |
1531 | | Operator::Eq, |
1532 | | DataType::Time32(TimeUnit::Millisecond) |
1533 | | ); |
1534 | | test_coercion_binary_rule!( |
1535 | | DataType::Utf8, |
1536 | | DataType::Time64(TimeUnit::Microsecond), |
1537 | | Operator::Eq, |
1538 | | DataType::Time64(TimeUnit::Microsecond) |
1539 | | ); |
1540 | | test_coercion_binary_rule!( |
1541 | | DataType::Utf8, |
1542 | | DataType::Time64(TimeUnit::Nanosecond), |
1543 | | Operator::Eq, |
1544 | | DataType::Time64(TimeUnit::Nanosecond) |
1545 | | ); |
1546 | | test_coercion_binary_rule!( |
1547 | | DataType::Utf8, |
1548 | | DataType::Timestamp(TimeUnit::Second, None), |
1549 | | Operator::Lt, |
1550 | | DataType::Timestamp(TimeUnit::Nanosecond, None) |
1551 | | ); |
1552 | | test_coercion_binary_rule!( |
1553 | | DataType::Utf8, |
1554 | | DataType::Timestamp(TimeUnit::Millisecond, None), |
1555 | | Operator::Lt, |
1556 | | DataType::Timestamp(TimeUnit::Nanosecond, None) |
1557 | | ); |
1558 | | test_coercion_binary_rule!( |
1559 | | DataType::Utf8, |
1560 | | DataType::Timestamp(TimeUnit::Microsecond, None), |
1561 | | Operator::Lt, |
1562 | | DataType::Timestamp(TimeUnit::Nanosecond, None) |
1563 | | ); |
1564 | | test_coercion_binary_rule!( |
1565 | | DataType::Utf8, |
1566 | | DataType::Timestamp(TimeUnit::Nanosecond, None), |
1567 | | Operator::Lt, |
1568 | | DataType::Timestamp(TimeUnit::Nanosecond, None) |
1569 | | ); |
1570 | | test_coercion_binary_rule!( |
1571 | | DataType::Utf8, |
1572 | | DataType::Utf8, |
1573 | | Operator::RegexMatch, |
1574 | | DataType::Utf8 |
1575 | | ); |
1576 | | test_coercion_binary_rule!( |
1577 | | DataType::Utf8, |
1578 | | DataType::Utf8, |
1579 | | Operator::RegexNotMatch, |
1580 | | DataType::Utf8 |
1581 | | ); |
1582 | | test_coercion_binary_rule!( |
1583 | | DataType::Utf8, |
1584 | | DataType::Utf8, |
1585 | | Operator::RegexNotIMatch, |
1586 | | DataType::Utf8 |
1587 | | ); |
1588 | | test_coercion_binary_rule!( |
1589 | | DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), |
1590 | | DataType::Utf8, |
1591 | | Operator::RegexMatch, |
1592 | | DataType::Utf8 |
1593 | | ); |
1594 | | test_coercion_binary_rule!( |
1595 | | DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), |
1596 | | DataType::Utf8, |
1597 | | Operator::RegexIMatch, |
1598 | | DataType::Utf8 |
1599 | | ); |
1600 | | test_coercion_binary_rule!( |
1601 | | DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), |
1602 | | DataType::Utf8, |
1603 | | Operator::RegexNotMatch, |
1604 | | DataType::Utf8 |
1605 | | ); |
1606 | | test_coercion_binary_rule!( |
1607 | | DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), |
1608 | | DataType::Utf8, |
1609 | | Operator::RegexNotIMatch, |
1610 | | DataType::Utf8 |
1611 | | ); |
1612 | | test_coercion_binary_rule!( |
1613 | | DataType::Int16, |
1614 | | DataType::Int64, |
1615 | | Operator::BitwiseAnd, |
1616 | | DataType::Int64 |
1617 | | ); |
1618 | | test_coercion_binary_rule!( |
1619 | | DataType::UInt64, |
1620 | | DataType::UInt64, |
1621 | | Operator::BitwiseAnd, |
1622 | | DataType::UInt64 |
1623 | | ); |
1624 | | test_coercion_binary_rule!( |
1625 | | DataType::Int8, |
1626 | | DataType::UInt32, |
1627 | | Operator::BitwiseAnd, |
1628 | | DataType::Int64 |
1629 | | ); |
1630 | | test_coercion_binary_rule!( |
1631 | | DataType::UInt32, |
1632 | | DataType::Int32, |
1633 | | Operator::BitwiseAnd, |
1634 | | DataType::Int64 |
1635 | | ); |
1636 | | test_coercion_binary_rule!( |
1637 | | DataType::UInt16, |
1638 | | DataType::Int16, |
1639 | | Operator::BitwiseAnd, |
1640 | | DataType::Int32 |
1641 | | ); |
1642 | | test_coercion_binary_rule!( |
1643 | | DataType::UInt32, |
1644 | | DataType::UInt32, |
1645 | | Operator::BitwiseAnd, |
1646 | | DataType::UInt32 |
1647 | | ); |
1648 | | test_coercion_binary_rule!( |
1649 | | DataType::UInt16, |
1650 | | DataType::UInt32, |
1651 | | Operator::BitwiseAnd, |
1652 | | DataType::UInt32 |
1653 | | ); |
1654 | | Ok(()) |
1655 | | } |
1656 | | |
1657 | | #[test] |
1658 | | fn test_type_coercion_arithmetic() -> Result<()> { |
1659 | | // integer |
1660 | | test_coercion_binary_rule!( |
1661 | | DataType::Int32, |
1662 | | DataType::UInt32, |
1663 | | Operator::Plus, |
1664 | | DataType::Int32 |
1665 | | ); |
1666 | | test_coercion_binary_rule!( |
1667 | | DataType::Int32, |
1668 | | DataType::UInt16, |
1669 | | Operator::Minus, |
1670 | | DataType::Int32 |
1671 | | ); |
1672 | | test_coercion_binary_rule!( |
1673 | | DataType::Int8, |
1674 | | DataType::Int64, |
1675 | | Operator::Multiply, |
1676 | | DataType::Int64 |
1677 | | ); |
1678 | | // float |
1679 | | test_coercion_binary_rule!( |
1680 | | DataType::Float32, |
1681 | | DataType::Int32, |
1682 | | Operator::Plus, |
1683 | | DataType::Float32 |
1684 | | ); |
1685 | | test_coercion_binary_rule!( |
1686 | | DataType::Float32, |
1687 | | DataType::Float64, |
1688 | | Operator::Multiply, |
1689 | | DataType::Float64 |
1690 | | ); |
1691 | | // TODO add other data type |
1692 | | Ok(()) |
1693 | | } |
1694 | | |
1695 | | fn test_math_decimal_coercion_rule( |
1696 | | lhs_type: DataType, |
1697 | | rhs_type: DataType, |
1698 | | expected_lhs_type: DataType, |
1699 | | expected_rhs_type: DataType, |
1700 | | ) { |
1701 | | // The coerced types for lhs and rhs, if any of them is not decimal |
1702 | | let (lhs_type, rhs_type) = math_decimal_coercion(&lhs_type, &rhs_type).unwrap(); |
1703 | | assert_eq!(lhs_type, expected_lhs_type); |
1704 | | assert_eq!(rhs_type, expected_rhs_type); |
1705 | | } |
1706 | | |
1707 | | #[test] |
1708 | | fn test_coercion_arithmetic_decimal() -> Result<()> { |
1709 | | test_math_decimal_coercion_rule( |
1710 | | DataType::Decimal128(10, 2), |
1711 | | DataType::Decimal128(10, 2), |
1712 | | DataType::Decimal128(10, 2), |
1713 | | DataType::Decimal128(10, 2), |
1714 | | ); |
1715 | | |
1716 | | test_math_decimal_coercion_rule( |
1717 | | DataType::Int32, |
1718 | | DataType::Decimal128(10, 2), |
1719 | | DataType::Decimal128(10, 0), |
1720 | | DataType::Decimal128(10, 2), |
1721 | | ); |
1722 | | |
1723 | | test_math_decimal_coercion_rule( |
1724 | | DataType::Int32, |
1725 | | DataType::Decimal128(10, 2), |
1726 | | DataType::Decimal128(10, 0), |
1727 | | DataType::Decimal128(10, 2), |
1728 | | ); |
1729 | | |
1730 | | test_math_decimal_coercion_rule( |
1731 | | DataType::Int32, |
1732 | | DataType::Decimal128(10, 2), |
1733 | | DataType::Decimal128(10, 0), |
1734 | | DataType::Decimal128(10, 2), |
1735 | | ); |
1736 | | |
1737 | | test_math_decimal_coercion_rule( |
1738 | | DataType::Int32, |
1739 | | DataType::Decimal128(10, 2), |
1740 | | DataType::Decimal128(10, 0), |
1741 | | DataType::Decimal128(10, 2), |
1742 | | ); |
1743 | | |
1744 | | test_math_decimal_coercion_rule( |
1745 | | DataType::Int32, |
1746 | | DataType::Decimal128(10, 2), |
1747 | | DataType::Decimal128(10, 0), |
1748 | | DataType::Decimal128(10, 2), |
1749 | | ); |
1750 | | |
1751 | | Ok(()) |
1752 | | } |
1753 | | |
1754 | | #[test] |
1755 | | fn test_type_coercion_compare() -> Result<()> { |
1756 | | // boolean |
1757 | | test_coercion_binary_rule!( |
1758 | | DataType::Boolean, |
1759 | | DataType::Boolean, |
1760 | | Operator::Eq, |
1761 | | DataType::Boolean |
1762 | | ); |
1763 | | // float |
1764 | | test_coercion_binary_rule!( |
1765 | | DataType::Float32, |
1766 | | DataType::Int64, |
1767 | | Operator::Eq, |
1768 | | DataType::Float32 |
1769 | | ); |
1770 | | test_coercion_binary_rule!( |
1771 | | DataType::Float32, |
1772 | | DataType::Float64, |
1773 | | Operator::GtEq, |
1774 | | DataType::Float64 |
1775 | | ); |
1776 | | // signed integer |
1777 | | test_coercion_binary_rule!( |
1778 | | DataType::Int8, |
1779 | | DataType::Int32, |
1780 | | Operator::LtEq, |
1781 | | DataType::Int32 |
1782 | | ); |
1783 | | test_coercion_binary_rule!( |
1784 | | DataType::Int64, |
1785 | | DataType::Int32, |
1786 | | Operator::LtEq, |
1787 | | DataType::Int64 |
1788 | | ); |
1789 | | // unsigned integer |
1790 | | test_coercion_binary_rule!( |
1791 | | DataType::UInt32, |
1792 | | DataType::UInt8, |
1793 | | Operator::Gt, |
1794 | | DataType::UInt32 |
1795 | | ); |
1796 | | // numeric/decimal |
1797 | | test_coercion_binary_rule!( |
1798 | | DataType::Int64, |
1799 | | DataType::Decimal128(10, 0), |
1800 | | Operator::Eq, |
1801 | | DataType::Decimal128(20, 0) |
1802 | | ); |
1803 | | test_coercion_binary_rule!( |
1804 | | DataType::Int64, |
1805 | | DataType::Decimal128(10, 2), |
1806 | | Operator::Lt, |
1807 | | DataType::Decimal128(22, 2) |
1808 | | ); |
1809 | | test_coercion_binary_rule!( |
1810 | | DataType::Float64, |
1811 | | DataType::Decimal128(10, 3), |
1812 | | Operator::Gt, |
1813 | | DataType::Decimal128(30, 15) |
1814 | | ); |
1815 | | test_coercion_binary_rule!( |
1816 | | DataType::Int64, |
1817 | | DataType::Decimal128(10, 0), |
1818 | | Operator::Eq, |
1819 | | DataType::Decimal128(20, 0) |
1820 | | ); |
1821 | | test_coercion_binary_rule!( |
1822 | | DataType::Decimal128(14, 2), |
1823 | | DataType::Decimal128(10, 3), |
1824 | | Operator::GtEq, |
1825 | | DataType::Decimal128(15, 3) |
1826 | | ); |
1827 | | |
1828 | | // Binary |
1829 | | test_coercion_binary_rule!( |
1830 | | DataType::Binary, |
1831 | | DataType::Binary, |
1832 | | Operator::Eq, |
1833 | | DataType::Binary |
1834 | | ); |
1835 | | test_coercion_binary_rule!( |
1836 | | DataType::Utf8, |
1837 | | DataType::Binary, |
1838 | | Operator::Eq, |
1839 | | DataType::Binary |
1840 | | ); |
1841 | | test_coercion_binary_rule!( |
1842 | | DataType::Binary, |
1843 | | DataType::Utf8, |
1844 | | Operator::Eq, |
1845 | | DataType::Binary |
1846 | | ); |
1847 | | |
1848 | | // LargeBinary |
1849 | | test_coercion_binary_rule!( |
1850 | | DataType::LargeBinary, |
1851 | | DataType::LargeBinary, |
1852 | | Operator::Eq, |
1853 | | DataType::LargeBinary |
1854 | | ); |
1855 | | test_coercion_binary_rule!( |
1856 | | DataType::Binary, |
1857 | | DataType::LargeBinary, |
1858 | | Operator::Eq, |
1859 | | DataType::LargeBinary |
1860 | | ); |
1861 | | test_coercion_binary_rule!( |
1862 | | DataType::LargeBinary, |
1863 | | DataType::Binary, |
1864 | | Operator::Eq, |
1865 | | DataType::LargeBinary |
1866 | | ); |
1867 | | test_coercion_binary_rule!( |
1868 | | DataType::Utf8, |
1869 | | DataType::LargeBinary, |
1870 | | Operator::Eq, |
1871 | | DataType::LargeBinary |
1872 | | ); |
1873 | | test_coercion_binary_rule!( |
1874 | | DataType::LargeBinary, |
1875 | | DataType::Utf8, |
1876 | | Operator::Eq, |
1877 | | DataType::LargeBinary |
1878 | | ); |
1879 | | test_coercion_binary_rule!( |
1880 | | DataType::LargeUtf8, |
1881 | | DataType::LargeBinary, |
1882 | | Operator::Eq, |
1883 | | DataType::LargeBinary |
1884 | | ); |
1885 | | test_coercion_binary_rule!( |
1886 | | DataType::LargeBinary, |
1887 | | DataType::LargeUtf8, |
1888 | | Operator::Eq, |
1889 | | DataType::LargeBinary |
1890 | | ); |
1891 | | |
1892 | | // Timestamps |
1893 | | let utc: Option<Arc<str>> = Some("UTC".into()); |
1894 | | test_coercion_binary_rule!( |
1895 | | DataType::Timestamp(TimeUnit::Second, utc.clone()), |
1896 | | DataType::Timestamp(TimeUnit::Second, utc.clone()), |
1897 | | Operator::Eq, |
1898 | | DataType::Timestamp(TimeUnit::Second, utc.clone()) |
1899 | | ); |
1900 | | test_coercion_binary_rule!( |
1901 | | DataType::Timestamp(TimeUnit::Second, utc.clone()), |
1902 | | DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), |
1903 | | Operator::Eq, |
1904 | | DataType::Timestamp(TimeUnit::Second, utc.clone()) |
1905 | | ); |
1906 | | test_coercion_binary_rule!( |
1907 | | DataType::Timestamp(TimeUnit::Second, Some("America/New_York".into())), |
1908 | | DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), |
1909 | | Operator::Eq, |
1910 | | DataType::Timestamp(TimeUnit::Second, Some("America/New_York".into())) |
1911 | | ); |
1912 | | test_coercion_binary_rule!( |
1913 | | DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), |
1914 | | DataType::Timestamp(TimeUnit::Second, utc), |
1915 | | Operator::Eq, |
1916 | | DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())) |
1917 | | ); |
1918 | | |
1919 | | // list |
1920 | | let inner_field = Arc::new(Field::new("item", DataType::Int64, true)); |
1921 | | test_coercion_binary_rule!( |
1922 | | DataType::List(Arc::clone(&inner_field)), |
1923 | | DataType::List(Arc::clone(&inner_field)), |
1924 | | Operator::Eq, |
1925 | | DataType::List(Arc::clone(&inner_field)) |
1926 | | ); |
1927 | | test_coercion_binary_rule!( |
1928 | | DataType::List(Arc::clone(&inner_field)), |
1929 | | DataType::LargeList(Arc::clone(&inner_field)), |
1930 | | Operator::Eq, |
1931 | | DataType::LargeList(Arc::clone(&inner_field)) |
1932 | | ); |
1933 | | test_coercion_binary_rule!( |
1934 | | DataType::LargeList(Arc::clone(&inner_field)), |
1935 | | DataType::List(Arc::clone(&inner_field)), |
1936 | | Operator::Eq, |
1937 | | DataType::LargeList(Arc::clone(&inner_field)) |
1938 | | ); |
1939 | | test_coercion_binary_rule!( |
1940 | | DataType::LargeList(Arc::clone(&inner_field)), |
1941 | | DataType::LargeList(Arc::clone(&inner_field)), |
1942 | | Operator::Eq, |
1943 | | DataType::LargeList(Arc::clone(&inner_field)) |
1944 | | ); |
1945 | | test_coercion_binary_rule!( |
1946 | | DataType::FixedSizeList(Arc::clone(&inner_field), 10), |
1947 | | DataType::FixedSizeList(Arc::clone(&inner_field), 10), |
1948 | | Operator::Eq, |
1949 | | DataType::FixedSizeList(Arc::clone(&inner_field), 10) |
1950 | | ); |
1951 | | test_coercion_binary_rule!( |
1952 | | DataType::FixedSizeList(Arc::clone(&inner_field), 10), |
1953 | | DataType::LargeList(Arc::clone(&inner_field)), |
1954 | | Operator::Eq, |
1955 | | DataType::LargeList(Arc::clone(&inner_field)) |
1956 | | ); |
1957 | | test_coercion_binary_rule!( |
1958 | | DataType::LargeList(Arc::clone(&inner_field)), |
1959 | | DataType::FixedSizeList(Arc::clone(&inner_field), 10), |
1960 | | Operator::Eq, |
1961 | | DataType::LargeList(Arc::clone(&inner_field)) |
1962 | | ); |
1963 | | test_coercion_binary_rule!( |
1964 | | DataType::List(Arc::clone(&inner_field)), |
1965 | | DataType::FixedSizeList(Arc::clone(&inner_field), 10), |
1966 | | Operator::Eq, |
1967 | | DataType::List(Arc::clone(&inner_field)) |
1968 | | ); |
1969 | | test_coercion_binary_rule!( |
1970 | | DataType::FixedSizeList(Arc::clone(&inner_field), 10), |
1971 | | DataType::List(Arc::clone(&inner_field)), |
1972 | | Operator::Eq, |
1973 | | DataType::List(Arc::clone(&inner_field)) |
1974 | | ); |
1975 | | |
1976 | | // TODO add other data type |
1977 | | Ok(()) |
1978 | | } |
1979 | | |
1980 | | #[test] |
1981 | | fn test_type_coercion_logical_op() -> Result<()> { |
1982 | | test_coercion_binary_rule!( |
1983 | | DataType::Boolean, |
1984 | | DataType::Boolean, |
1985 | | Operator::And, |
1986 | | DataType::Boolean |
1987 | | ); |
1988 | | |
1989 | | test_coercion_binary_rule!( |
1990 | | DataType::Boolean, |
1991 | | DataType::Boolean, |
1992 | | Operator::Or, |
1993 | | DataType::Boolean |
1994 | | ); |
1995 | | test_coercion_binary_rule!( |
1996 | | DataType::Boolean, |
1997 | | DataType::Null, |
1998 | | Operator::And, |
1999 | | DataType::Boolean |
2000 | | ); |
2001 | | test_coercion_binary_rule!( |
2002 | | DataType::Boolean, |
2003 | | DataType::Null, |
2004 | | Operator::Or, |
2005 | | DataType::Boolean |
2006 | | ); |
2007 | | test_coercion_binary_rule!( |
2008 | | DataType::Null, |
2009 | | DataType::Null, |
2010 | | Operator::Or, |
2011 | | DataType::Boolean |
2012 | | ); |
2013 | | test_coercion_binary_rule!( |
2014 | | DataType::Null, |
2015 | | DataType::Null, |
2016 | | Operator::And, |
2017 | | DataType::Boolean |
2018 | | ); |
2019 | | test_coercion_binary_rule!( |
2020 | | DataType::Null, |
2021 | | DataType::Boolean, |
2022 | | Operator::And, |
2023 | | DataType::Boolean |
2024 | | ); |
2025 | | test_coercion_binary_rule!( |
2026 | | DataType::Null, |
2027 | | DataType::Boolean, |
2028 | | Operator::Or, |
2029 | | DataType::Boolean |
2030 | | ); |
2031 | | Ok(()) |
2032 | | } |
2033 | | } |