/Users/andrewlamb/Software/datafusion/datafusion/expr/src/type_coercion/functions.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use super::binary::{binary_numeric_coercion, comparison_coercion}; |
19 | | use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; |
20 | | use arrow::{ |
21 | | compute::can_cast_types, |
22 | | datatypes::{DataType, TimeUnit}, |
23 | | }; |
24 | | use datafusion_common::{ |
25 | | exec_err, internal_datafusion_err, internal_err, plan_err, |
26 | | utils::{coerced_fixed_size_list_to_list, list_ndims}, |
27 | | Result, |
28 | | }; |
29 | | use datafusion_expr_common::signature::{ |
30 | | ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD, |
31 | | }; |
32 | | use std::sync::Arc; |
33 | | |
34 | | /// Performs type coercion for scalar function arguments. |
35 | | /// |
36 | | /// Returns the data types to which each argument must be coerced to |
37 | | /// match `signature`. |
38 | | /// |
39 | | /// For more details on coercion in general, please see the |
40 | | /// [`type_coercion`](crate::type_coercion) module. |
41 | 0 | pub fn data_types_with_scalar_udf( |
42 | 0 | current_types: &[DataType], |
43 | 0 | func: &ScalarUDF, |
44 | 0 | ) -> Result<Vec<DataType>> { |
45 | 0 | let signature = func.signature(); |
46 | 0 |
|
47 | 0 | if current_types.is_empty() { |
48 | 0 | if signature.type_signature.supports_zero_argument() { |
49 | 0 | return Ok(vec![]); |
50 | | } else { |
51 | 0 | return plan_err!("{} does not support zero arguments.", func.name()); |
52 | | } |
53 | 0 | } |
54 | | |
55 | 0 | let valid_types = |
56 | 0 | get_valid_types_with_scalar_udf(&signature.type_signature, current_types, func)?; |
57 | | |
58 | 0 | if valid_types |
59 | 0 | .iter() |
60 | 0 | .any(|data_type| data_type == current_types) |
61 | | { |
62 | 0 | return Ok(current_types.to_vec()); |
63 | 0 | } |
64 | 0 |
|
65 | 0 | try_coerce_types(valid_types, current_types, &signature.type_signature) |
66 | 0 | } |
67 | | |
68 | | /// Performs type coercion for aggregate function arguments. |
69 | | /// |
70 | | /// Returns the data types to which each argument must be coerced to |
71 | | /// match `signature`. |
72 | | /// |
73 | | /// For more details on coercion in general, please see the |
74 | | /// [`type_coercion`](crate::type_coercion) module. |
75 | 0 | pub fn data_types_with_aggregate_udf( |
76 | 0 | current_types: &[DataType], |
77 | 0 | func: &AggregateUDF, |
78 | 0 | ) -> Result<Vec<DataType>> { |
79 | 0 | let signature = func.signature(); |
80 | 0 |
|
81 | 0 | if current_types.is_empty() { |
82 | 0 | if signature.type_signature.supports_zero_argument() { |
83 | 0 | return Ok(vec![]); |
84 | | } else { |
85 | 0 | return plan_err!("{} does not support zero arguments.", func.name()); |
86 | | } |
87 | 0 | } |
88 | | |
89 | 0 | let valid_types = get_valid_types_with_aggregate_udf( |
90 | 0 | &signature.type_signature, |
91 | 0 | current_types, |
92 | 0 | func, |
93 | 0 | )?; |
94 | 0 | if valid_types |
95 | 0 | .iter() |
96 | 0 | .any(|data_type| data_type == current_types) |
97 | | { |
98 | 0 | return Ok(current_types.to_vec()); |
99 | 0 | } |
100 | 0 |
|
101 | 0 | try_coerce_types(valid_types, current_types, &signature.type_signature) |
102 | 0 | } |
103 | | |
104 | | /// Performs type coercion for window function arguments. |
105 | | /// |
106 | | /// Returns the data types to which each argument must be coerced to |
107 | | /// match `signature`. |
108 | | /// |
109 | | /// For more details on coercion in general, please see the |
110 | | /// [`type_coercion`](crate::type_coercion) module. |
111 | 0 | pub fn data_types_with_window_udf( |
112 | 0 | current_types: &[DataType], |
113 | 0 | func: &WindowUDF, |
114 | 0 | ) -> Result<Vec<DataType>> { |
115 | 0 | let signature = func.signature(); |
116 | 0 |
|
117 | 0 | if current_types.is_empty() { |
118 | 0 | if signature.type_signature.supports_zero_argument() { |
119 | 0 | return Ok(vec![]); |
120 | | } else { |
121 | 0 | return plan_err!("{} does not support zero arguments.", func.name()); |
122 | | } |
123 | 0 | } |
124 | | |
125 | 0 | let valid_types = |
126 | 0 | get_valid_types_with_window_udf(&signature.type_signature, current_types, func)?; |
127 | 0 | if valid_types |
128 | 0 | .iter() |
129 | 0 | .any(|data_type| data_type == current_types) |
130 | | { |
131 | 0 | return Ok(current_types.to_vec()); |
132 | 0 | } |
133 | 0 |
|
134 | 0 | try_coerce_types(valid_types, current_types, &signature.type_signature) |
135 | 0 | } |
136 | | |
137 | | /// Performs type coercion for function arguments. |
138 | | /// |
139 | | /// Returns the data types to which each argument must be coerced to |
140 | | /// match `signature`. |
141 | | /// |
142 | | /// For more details on coercion in general, please see the |
143 | | /// [`type_coercion`](crate::type_coercion) module. |
144 | 0 | pub fn data_types( |
145 | 0 | current_types: &[DataType], |
146 | 0 | signature: &Signature, |
147 | 0 | ) -> Result<Vec<DataType>> { |
148 | 0 | if current_types.is_empty() { |
149 | 0 | if signature.type_signature.supports_zero_argument() { |
150 | 0 | return Ok(vec![]); |
151 | | } else { |
152 | 0 | return plan_err!( |
153 | 0 | "signature {:?} does not support zero arguments.", |
154 | 0 | &signature.type_signature |
155 | 0 | ); |
156 | | } |
157 | 0 | } |
158 | | |
159 | 0 | let valid_types = get_valid_types(&signature.type_signature, current_types)?; |
160 | 0 | if valid_types |
161 | 0 | .iter() |
162 | 0 | .any(|data_type| data_type == current_types) |
163 | | { |
164 | 0 | return Ok(current_types.to_vec()); |
165 | 0 | } |
166 | 0 |
|
167 | 0 | try_coerce_types(valid_types, current_types, &signature.type_signature) |
168 | 0 | } |
169 | | |
170 | 0 | fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { |
171 | 0 | if let TypeSignature::OneOf(signatures) = type_signature { |
172 | 0 | return signatures.iter().all(is_well_supported_signature); |
173 | 0 | } |
174 | | |
175 | 0 | matches!( |
176 | 0 | type_signature, |
177 | | TypeSignature::UserDefined |
178 | | | TypeSignature::Numeric(_) |
179 | | | TypeSignature::Coercible(_) |
180 | | | TypeSignature::Any(_) |
181 | | ) |
182 | 0 | } |
183 | | |
184 | 0 | fn try_coerce_types( |
185 | 0 | valid_types: Vec<Vec<DataType>>, |
186 | 0 | current_types: &[DataType], |
187 | 0 | type_signature: &TypeSignature, |
188 | 0 | ) -> Result<Vec<DataType>> { |
189 | 0 | let mut valid_types = valid_types; |
190 | 0 |
|
191 | 0 | // Well-supported signature that returns exact valid types. |
192 | 0 | if !valid_types.is_empty() && is_well_supported_signature(type_signature) { |
193 | | // exact valid types |
194 | 0 | assert_eq!(valid_types.len(), 1); |
195 | 0 | let valid_types = valid_types.swap_remove(0); |
196 | 0 | if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) { |
197 | 0 | return Ok(t); |
198 | 0 | } |
199 | | } else { |
200 | | // Try and coerce the argument types to match the signature, returning the |
201 | | // coerced types from the first matching signature. |
202 | 0 | for valid_types in valid_types { |
203 | 0 | if let Some(types) = maybe_data_types(&valid_types, current_types) { |
204 | 0 | return Ok(types); |
205 | 0 | } |
206 | | } |
207 | | } |
208 | | |
209 | | // none possible -> Error |
210 | 0 | plan_err!( |
211 | 0 | "Coercion from {:?} to the signature {:?} failed.", |
212 | 0 | current_types, |
213 | 0 | type_signature |
214 | 0 | ) |
215 | 0 | } |
216 | | |
217 | 0 | fn get_valid_types_with_scalar_udf( |
218 | 0 | signature: &TypeSignature, |
219 | 0 | current_types: &[DataType], |
220 | 0 | func: &ScalarUDF, |
221 | 0 | ) -> Result<Vec<Vec<DataType>>> { |
222 | 0 | let valid_types = match signature { |
223 | 0 | TypeSignature::UserDefined => match func.coerce_types(current_types) { |
224 | 0 | Ok(coerced_types) => vec![coerced_types], |
225 | 0 | Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), |
226 | | }, |
227 | 0 | TypeSignature::OneOf(signatures) => signatures |
228 | 0 | .iter() |
229 | 0 | .filter_map(|t| get_valid_types_with_scalar_udf(t, current_types, func).ok()) |
230 | 0 | .flatten() |
231 | 0 | .collect::<Vec<_>>(), |
232 | 0 | _ => get_valid_types(signature, current_types)?, |
233 | | }; |
234 | | |
235 | 0 | Ok(valid_types) |
236 | 0 | } |
237 | | |
238 | 0 | fn get_valid_types_with_aggregate_udf( |
239 | 0 | signature: &TypeSignature, |
240 | 0 | current_types: &[DataType], |
241 | 0 | func: &AggregateUDF, |
242 | 0 | ) -> Result<Vec<Vec<DataType>>> { |
243 | 0 | let valid_types = match signature { |
244 | 0 | TypeSignature::UserDefined => match func.coerce_types(current_types) { |
245 | 0 | Ok(coerced_types) => vec![coerced_types], |
246 | 0 | Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), |
247 | | }, |
248 | 0 | TypeSignature::OneOf(signatures) => signatures |
249 | 0 | .iter() |
250 | 0 | .filter_map(|t| { |
251 | 0 | get_valid_types_with_aggregate_udf(t, current_types, func).ok() |
252 | 0 | }) |
253 | 0 | .flatten() |
254 | 0 | .collect::<Vec<_>>(), |
255 | 0 | _ => get_valid_types(signature, current_types)?, |
256 | | }; |
257 | | |
258 | 0 | Ok(valid_types) |
259 | 0 | } |
260 | | |
261 | 0 | fn get_valid_types_with_window_udf( |
262 | 0 | signature: &TypeSignature, |
263 | 0 | current_types: &[DataType], |
264 | 0 | func: &WindowUDF, |
265 | 0 | ) -> Result<Vec<Vec<DataType>>> { |
266 | 0 | let valid_types = match signature { |
267 | 0 | TypeSignature::UserDefined => match func.coerce_types(current_types) { |
268 | 0 | Ok(coerced_types) => vec![coerced_types], |
269 | 0 | Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), |
270 | | }, |
271 | 0 | TypeSignature::OneOf(signatures) => signatures |
272 | 0 | .iter() |
273 | 0 | .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok()) |
274 | 0 | .flatten() |
275 | 0 | .collect::<Vec<_>>(), |
276 | 0 | _ => get_valid_types(signature, current_types)?, |
277 | | }; |
278 | | |
279 | 0 | Ok(valid_types) |
280 | 0 | } |
281 | | |
282 | | /// Returns a Vec of all possible valid argument types for the given signature. |
283 | 0 | fn get_valid_types( |
284 | 0 | signature: &TypeSignature, |
285 | 0 | current_types: &[DataType], |
286 | 0 | ) -> Result<Vec<Vec<DataType>>> { |
287 | 0 | fn array_element_and_optional_index( |
288 | 0 | current_types: &[DataType], |
289 | 0 | ) -> Result<Vec<Vec<DataType>>> { |
290 | 0 | // make sure there's 2 or 3 arguments |
291 | 0 | if !(current_types.len() == 2 || current_types.len() == 3) { |
292 | 0 | return Ok(vec![vec![]]); |
293 | 0 | } |
294 | 0 |
|
295 | 0 | let first_two_types = ¤t_types[0..2]; |
296 | 0 | let mut valid_types = array_append_or_prepend_valid_types(first_two_types, true)?; |
297 | | |
298 | | // Early return if there are only 2 arguments |
299 | 0 | if current_types.len() == 2 { |
300 | 0 | return Ok(valid_types); |
301 | 0 | } |
302 | 0 |
|
303 | 0 | let valid_types_with_index = valid_types |
304 | 0 | .iter() |
305 | 0 | .map(|t| { |
306 | 0 | let mut t = t.clone(); |
307 | 0 | t.push(DataType::Int64); |
308 | 0 | t |
309 | 0 | }) |
310 | 0 | .collect::<Vec<_>>(); |
311 | 0 |
|
312 | 0 | valid_types.extend(valid_types_with_index); |
313 | 0 |
|
314 | 0 | Ok(valid_types) |
315 | 0 | } |
316 | | |
317 | 0 | fn array_append_or_prepend_valid_types( |
318 | 0 | current_types: &[DataType], |
319 | 0 | is_append: bool, |
320 | 0 | ) -> Result<Vec<Vec<DataType>>> { |
321 | 0 | if current_types.len() != 2 { |
322 | 0 | return Ok(vec![vec![]]); |
323 | 0 | } |
324 | | |
325 | 0 | let (array_type, elem_type) = if is_append { |
326 | 0 | (¤t_types[0], ¤t_types[1]) |
327 | | } else { |
328 | 0 | (¤t_types[1], ¤t_types[0]) |
329 | | }; |
330 | | |
331 | | // We follow Postgres on `array_append(Null, T)`, which is not valid. |
332 | 0 | if array_type.eq(&DataType::Null) { |
333 | 0 | return Ok(vec![vec![]]); |
334 | 0 | } |
335 | 0 |
|
336 | 0 | // We need to find the coerced base type, mainly for cases like: |
337 | 0 | // `array_append(List(null), i64)` -> `List(i64)` |
338 | 0 | let array_base_type = datafusion_common::utils::base_type(array_type); |
339 | 0 | let elem_base_type = datafusion_common::utils::base_type(elem_type); |
340 | 0 | let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); |
341 | | |
342 | 0 | let new_base_type = new_base_type.ok_or_else(|| { |
343 | 0 | internal_datafusion_err!( |
344 | 0 | "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." |
345 | 0 | ) |
346 | 0 | })?; |
347 | | |
348 | 0 | let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only( |
349 | 0 | array_type, |
350 | 0 | &new_base_type, |
351 | 0 | ); |
352 | 0 |
|
353 | 0 | match new_array_type { |
354 | 0 | DataType::List(ref field) |
355 | 0 | | DataType::LargeList(ref field) |
356 | 0 | | DataType::FixedSizeList(ref field, _) => { |
357 | 0 | let new_elem_type = field.data_type(); |
358 | 0 | if is_append { |
359 | 0 | Ok(vec![vec![new_array_type.clone(), new_elem_type.clone()]]) |
360 | | } else { |
361 | 0 | Ok(vec![vec![new_elem_type.to_owned(), new_array_type.clone()]]) |
362 | | } |
363 | | } |
364 | 0 | _ => Ok(vec![vec![]]), |
365 | | } |
366 | 0 | } |
367 | 0 | fn array(array_type: &DataType) -> Option<DataType> { |
368 | 0 | match array_type { |
369 | | DataType::List(_) |
370 | | | DataType::LargeList(_) |
371 | | | DataType::FixedSizeList(_, _) => { |
372 | 0 | let array_type = coerced_fixed_size_list_to_list(array_type); |
373 | 0 | Some(array_type) |
374 | | } |
375 | 0 | _ => None, |
376 | | } |
377 | 0 | } |
378 | | |
379 | 0 | let valid_types = match signature { |
380 | 0 | TypeSignature::Variadic(valid_types) => valid_types |
381 | 0 | .iter() |
382 | 0 | .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) |
383 | 0 | .collect(), |
384 | 0 | TypeSignature::Numeric(number) => { |
385 | 0 | if *number < 1 { |
386 | 0 | return plan_err!( |
387 | 0 | "The signature expected at least one argument but received {}", |
388 | 0 | current_types.len() |
389 | 0 | ); |
390 | 0 | } |
391 | 0 | if *number != current_types.len() { |
392 | 0 | return plan_err!( |
393 | 0 | "The signature expected {} arguments but received {}", |
394 | 0 | number, |
395 | 0 | current_types.len() |
396 | 0 | ); |
397 | 0 | } |
398 | 0 |
|
399 | 0 | let mut valid_type = current_types.first().unwrap().clone(); |
400 | 0 | for t in current_types.iter().skip(1) { |
401 | 0 | if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) { |
402 | 0 | valid_type = coerced_type; |
403 | 0 | } else { |
404 | 0 | return plan_err!( |
405 | 0 | "{} and {} are not coercible to a common numeric type", |
406 | 0 | valid_type, |
407 | 0 | t |
408 | 0 | ); |
409 | | } |
410 | | } |
411 | | |
412 | 0 | vec![vec![valid_type; *number]] |
413 | | } |
414 | 0 | TypeSignature::Coercible(target_types) => { |
415 | 0 | if target_types.is_empty() { |
416 | 0 | return plan_err!( |
417 | 0 | "The signature expected at least one argument but received {}", |
418 | 0 | current_types.len() |
419 | 0 | ); |
420 | 0 | } |
421 | 0 | if target_types.len() != current_types.len() { |
422 | 0 | return plan_err!( |
423 | 0 | "The signature expected {} arguments but received {}", |
424 | 0 | target_types.len(), |
425 | 0 | current_types.len() |
426 | 0 | ); |
427 | 0 | } |
428 | | |
429 | 0 | for (data_type, target_type) in current_types.iter().zip(target_types.iter()) |
430 | | { |
431 | 0 | if !can_cast_types(data_type, target_type) { |
432 | 0 | return plan_err!("{data_type} is not coercible to {target_type}"); |
433 | 0 | } |
434 | | } |
435 | | |
436 | 0 | vec![target_types.to_owned()] |
437 | | } |
438 | 0 | TypeSignature::Uniform(number, valid_types) => valid_types |
439 | 0 | .iter() |
440 | 0 | .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) |
441 | 0 | .collect(), |
442 | | TypeSignature::UserDefined => { |
443 | 0 | return internal_err!( |
444 | 0 | "User-defined signature should be handled by function-specific coerce_types." |
445 | 0 | ) |
446 | | } |
447 | | TypeSignature::VariadicAny => { |
448 | 0 | vec![current_types.to_vec()] |
449 | | } |
450 | 0 | TypeSignature::Exact(valid_types) => vec![valid_types.clone()], |
451 | 0 | TypeSignature::ArraySignature(ref function_signature) => match function_signature |
452 | | { |
453 | | ArrayFunctionSignature::ArrayAndElement => { |
454 | 0 | array_append_or_prepend_valid_types(current_types, true)? |
455 | | } |
456 | | ArrayFunctionSignature::ElementAndArray => { |
457 | 0 | array_append_or_prepend_valid_types(current_types, false)? |
458 | | } |
459 | | ArrayFunctionSignature::ArrayAndIndex => { |
460 | 0 | if current_types.len() != 2 { |
461 | 0 | return Ok(vec![vec![]]); |
462 | 0 | } |
463 | 0 | array(¤t_types[0]).map_or_else( |
464 | 0 | || vec![vec![]], |
465 | 0 | |array_type| vec![vec![array_type, DataType::Int64]], |
466 | 0 | ) |
467 | | } |
468 | | ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => { |
469 | 0 | array_element_and_optional_index(current_types)? |
470 | | } |
471 | | ArrayFunctionSignature::Array => { |
472 | 0 | if current_types.len() != 1 { |
473 | 0 | return Ok(vec![vec![]]); |
474 | 0 | } |
475 | 0 |
|
476 | 0 | array(¤t_types[0]) |
477 | 0 | .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) |
478 | | } |
479 | | ArrayFunctionSignature::MapArray => { |
480 | 0 | if current_types.len() != 1 { |
481 | 0 | return Ok(vec![vec![]]); |
482 | 0 | } |
483 | 0 |
|
484 | 0 | match ¤t_types[0] { |
485 | 0 | DataType::Map(_, _) => vec![vec![current_types[0].clone()]], |
486 | 0 | _ => vec![vec![]], |
487 | | } |
488 | | } |
489 | | }, |
490 | 0 | TypeSignature::Any(number) => { |
491 | 0 | if current_types.len() != *number { |
492 | 0 | return plan_err!( |
493 | 0 | "The function expected {} arguments but received {}", |
494 | 0 | number, |
495 | 0 | current_types.len() |
496 | 0 | ); |
497 | 0 | } |
498 | 0 | vec![(0..*number).map(|i| current_types[i].clone()).collect()] |
499 | | } |
500 | 0 | TypeSignature::OneOf(types) => types |
501 | 0 | .iter() |
502 | 0 | .filter_map(|t| get_valid_types(t, current_types).ok()) |
503 | 0 | .flatten() |
504 | 0 | .collect::<Vec<_>>(), |
505 | | }; |
506 | | |
507 | 0 | Ok(valid_types) |
508 | 0 | } |
509 | | |
510 | | /// Try to coerce the current argument types to match the given `valid_types`. |
511 | | /// |
512 | | /// For example, if a function `func` accepts arguments of `(int64, int64)`, |
513 | | /// but was called with `(int32, int64)`, this function could match the |
514 | | /// valid_types by coercing the first argument to `int64`, and would return |
515 | | /// `Some([int64, int64])`. |
516 | 0 | fn maybe_data_types( |
517 | 0 | valid_types: &[DataType], |
518 | 0 | current_types: &[DataType], |
519 | 0 | ) -> Option<Vec<DataType>> { |
520 | 0 | if valid_types.len() != current_types.len() { |
521 | 0 | return None; |
522 | 0 | } |
523 | 0 |
|
524 | 0 | let mut new_type = Vec::with_capacity(valid_types.len()); |
525 | 0 | for (i, valid_type) in valid_types.iter().enumerate() { |
526 | 0 | let current_type = ¤t_types[i]; |
527 | 0 |
|
528 | 0 | if current_type == valid_type { |
529 | 0 | new_type.push(current_type.clone()) |
530 | | } else { |
531 | | // attempt to coerce. |
532 | | // TODO: Replace with `can_cast_types` after failing cases are resolved |
533 | | // (they need new signature that returns exactly valid types instead of list of possible valid types). |
534 | 0 | if let Some(coerced_type) = coerced_from(valid_type, current_type) { |
535 | 0 | new_type.push(coerced_type) |
536 | | } else { |
537 | | // not possible |
538 | 0 | return None; |
539 | | } |
540 | | } |
541 | | } |
542 | 0 | Some(new_type) |
543 | 0 | } |
544 | | |
545 | | /// Check if the current argument types can be coerced to match the given `valid_types` |
546 | | /// unlike `maybe_data_types`, this function does not coerce the types. |
547 | | /// TODO: I think this function should replace `maybe_data_types` after signature are well-supported. |
548 | 0 | fn maybe_data_types_without_coercion( |
549 | 0 | valid_types: &[DataType], |
550 | 0 | current_types: &[DataType], |
551 | 0 | ) -> Option<Vec<DataType>> { |
552 | 0 | if valid_types.len() != current_types.len() { |
553 | 0 | return None; |
554 | 0 | } |
555 | 0 |
|
556 | 0 | let mut new_type = Vec::with_capacity(valid_types.len()); |
557 | 0 | for (i, valid_type) in valid_types.iter().enumerate() { |
558 | 0 | let current_type = ¤t_types[i]; |
559 | 0 |
|
560 | 0 | if current_type == valid_type { |
561 | 0 | new_type.push(current_type.clone()) |
562 | 0 | } else if can_cast_types(current_type, valid_type) { |
563 | | // validate the valid type is castable from the current type |
564 | 0 | new_type.push(valid_type.clone()) |
565 | | } else { |
566 | 0 | return None; |
567 | | } |
568 | | } |
569 | 0 | Some(new_type) |
570 | 0 | } |
571 | | |
572 | | /// Return true if a value of type `type_from` can be coerced |
573 | | /// (losslessly converted) into a value of `type_to` |
574 | | /// |
575 | | /// See the module level documentation for more detail on coercion. |
576 | 0 | pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { |
577 | 0 | if type_into == type_from { |
578 | 0 | return true; |
579 | 0 | } |
580 | 0 | if let Some(coerced) = coerced_from(type_into, type_from) { |
581 | 0 | return coerced == *type_into; |
582 | 0 | } |
583 | 0 | false |
584 | 0 | } |
585 | | |
586 | | /// Find the coerced type for the given `type_into` and `type_from`. |
587 | | /// Returns `None` if coercion is not possible. |
588 | | /// |
589 | | /// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32. |
590 | | /// |
591 | | /// Unlike [comparison_coercion], the coerced type is usually `wider` for lossless conversion. |
592 | 0 | fn coerced_from<'a>( |
593 | 0 | type_into: &'a DataType, |
594 | 0 | type_from: &'a DataType, |
595 | 0 | ) -> Option<DataType> { |
596 | | use self::DataType::*; |
597 | | |
598 | | // match Dictionary first |
599 | 0 | match (type_into, type_from) { |
600 | | // coerced dictionary first |
601 | 0 | (_, Dictionary(_, value_type)) |
602 | 0 | if coerced_from(type_into, value_type).is_some() => |
603 | 0 | { |
604 | 0 | Some(type_into.clone()) |
605 | | } |
606 | 0 | (Dictionary(_, value_type), _) |
607 | 0 | if coerced_from(value_type, type_from).is_some() => |
608 | 0 | { |
609 | 0 | Some(type_into.clone()) |
610 | | } |
611 | | // coerced into type_into |
612 | 0 | (Int8, Null | Int8) => Some(type_into.clone()), |
613 | 0 | (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()), |
614 | 0 | (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()), |
615 | | (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => { |
616 | 0 | Some(type_into.clone()) |
617 | | } |
618 | 0 | (UInt8, Null | UInt8) => Some(type_into.clone()), |
619 | 0 | (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()), |
620 | 0 | (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()), |
621 | 0 | (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()), |
622 | | ( |
623 | | Float32, |
624 | | Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 |
625 | | | Float32, |
626 | 0 | ) => Some(type_into.clone()), |
627 | | ( |
628 | | Float64, |
629 | | Null |
630 | | | Int8 |
631 | | | Int16 |
632 | | | Int32 |
633 | | | Int64 |
634 | | | UInt8 |
635 | | | UInt16 |
636 | | | UInt32 |
637 | | | UInt64 |
638 | | | Float32 |
639 | | | Float64 |
640 | | | Decimal128(_, _), |
641 | 0 | ) => Some(type_into.clone()), |
642 | | ( |
643 | | Timestamp(TimeUnit::Nanosecond, None), |
644 | | Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8, |
645 | 0 | ) => Some(type_into.clone()), |
646 | 0 | (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()), |
647 | | // We can go into a Utf8View from a Utf8 or LargeUtf8 |
648 | 0 | (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()), |
649 | | // Any type can be coerced into strings |
650 | 0 | (Utf8 | LargeUtf8, _) => Some(type_into.clone()), |
651 | 0 | (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()), |
652 | | |
653 | 0 | (List(_), FixedSizeList(_, _)) => Some(type_into.clone()), |
654 | | |
655 | | // Only accept list and largelist with the same number of dimensions unless the type is Null. |
656 | | // List or LargeList with different dimensions should be handled in TypeSignature or other places before this |
657 | | (List(_) | LargeList(_), _) |
658 | 0 | if datafusion_common::utils::base_type(type_from).eq(&Null) |
659 | 0 | || list_ndims(type_from) == list_ndims(type_into) => |
660 | | { |
661 | 0 | Some(type_into.clone()) |
662 | | } |
663 | | // should be able to coerce wildcard fixed size list to non wildcard fixed size list |
664 | | ( |
665 | 0 | FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD), |
666 | 0 | FixedSizeList(f_from, size_from), |
667 | 0 | ) => match coerced_from(f_into.data_type(), f_from.data_type()) { |
668 | 0 | Some(data_type) if &data_type != f_into.data_type() => { |
669 | 0 | let new_field = |
670 | 0 | Arc::new(f_into.as_ref().clone().with_data_type(data_type)); |
671 | 0 | Some(FixedSizeList(new_field, *size_from)) |
672 | | } |
673 | 0 | Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)), |
674 | 0 | _ => None, |
675 | | }, |
676 | 0 | (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => { |
677 | 0 | match type_from { |
678 | 0 | Timestamp(_, Some(from_tz)) => { |
679 | 0 | Some(Timestamp(*unit, Some(Arc::clone(from_tz)))) |
680 | | } |
681 | | Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => { |
682 | | // In the absence of any other information assume the time zone is "+00" (UTC). |
683 | 0 | Some(Timestamp(*unit, Some("+00".into()))) |
684 | | } |
685 | 0 | _ => None, |
686 | | } |
687 | | } |
688 | | (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => { |
689 | 0 | Some(type_into.clone()) |
690 | | } |
691 | 0 | _ => None, |
692 | | } |
693 | 0 | } |
694 | | |
#[cfg(test)]
mod tests {

    use crate::Volatility;

    use super::*;
    use arrow::datatypes::Field;

    /// `Utf8View` accepts both `Utf8` and `LargeUtf8` as source types.
    #[test]
    fn test_string_conversion() {
        // (type_into, type_from, expected result)
        let cases = vec![
            (DataType::Utf8View, DataType::Utf8, true),
            (DataType::Utf8View, DataType::LargeUtf8, true),
        ];

        for (type_into, type_from, expected) in cases {
            assert_eq!(can_coerce_from(&type_into, &type_from), expected);
        }
    }

    /// Exercises `maybe_data_types` with exact, coercible and impossible inputs.
    #[test]
    fn test_maybe_data_types() {
        // this vec contains: valid types, current types, expected result
        let cases = vec![
            // 2 entries, same values
            (
                vec![DataType::UInt8, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt8, DataType::UInt16]),
            ),
            // 2 entries, can coerce values
            (
                vec![DataType::UInt16, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt16, DataType::UInt16]),
            ),
            // 0 entries, all good
            (vec![], vec![], Some(vec![])),
            // 2 entries, can't coerce
            (
                vec![DataType::Boolean, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                None,
            ),
            // u16 -> u32 is possible
            (
                vec![DataType::Boolean, DataType::UInt32],
                vec![DataType::Boolean, DataType::UInt16],
                Some(vec![DataType::Boolean, DataType::UInt32]),
            ),
            // UTF8 -> Timestamp
            (
                vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ],
                vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
                Some(vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ]),
            ),
        ];

        for (valid_types, current_types, expected) in cases {
            assert_eq!(maybe_data_types(&valid_types, &current_types), expected)
        }
    }

    /// A `OneOf` signature only yields candidates from matching alternatives.
    #[test]
    fn test_get_valid_types_one_of() -> Result<()> {
        let signature =
            TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

        // Three arguments match neither alternative.
        let three_args = [DataType::Int32, DataType::Int32, DataType::Int32];
        let invalid_types = get_valid_types(&signature, &three_args)?;
        assert_eq!(invalid_types.len(), 0);

        // Two arguments match `Any(2)` only.
        let args = vec![DataType::Int32, DataType::Int32];
        let valid_types = get_valid_types(&signature, &args)?;
        assert_eq!(valid_types.len(), 1);
        assert_eq!(valid_types[0], args);

        // One argument matches `Any(1)` only.
        let args = vec![DataType::Int32];
        let valid_types = get_valid_types(&signature, &args)?;
        assert_eq!(valid_types.len(), 1);
        assert_eq!(valid_types[0], args);

        Ok(())
    }

    /// A wildcard-sized fixed size list accepts any concrete size.
    #[test]
    fn test_fixed_list_wildcard_coerce() -> Result<()> {
        let inner = Arc::new(Field::new("item", DataType::Int32, false));
        // Builds an exact signature over a fixed size list of the given size.
        let exact_list_signature = |size| {
            Signature::exact(
                vec![DataType::FixedSizeList(Arc::clone(&inner), size)],
                Volatility::Stable,
            )
        };
        let current_types = vec![
            DataType::FixedSizeList(Arc::clone(&inner), 2), // able to coerce for any size
        ];

        let signature = exact_list_signature(FIXED_SIZE_LIST_WILDCARD);
        let coerced_data_types = data_types(&current_types, &signature).unwrap();
        assert_eq!(coerced_data_types, current_types);

        // make sure it can't coerce to a different size
        let signature = exact_list_signature(3);
        let coerced_data_types = data_types(&current_types, &signature);
        assert!(coerced_data_types.is_err());

        // make sure it works with the same type.
        let signature = exact_list_signature(2);
        let coerced_data_types = data_types(&current_types, &signature).unwrap();
        assert_eq!(coerced_data_types, current_types);

        Ok(())
    }

    /// Wildcard sizes are resolved at every nesting level.
    #[test]
    fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
        // Builds `FixedSizeList(FixedSizeList(element, inner_size), outer_size)`.
        let nested_list = |element: DataType, inner_size: i32, outer_size: i32| {
            DataType::FixedSizeList(
                Arc::new(Field::new(
                    "item",
                    DataType::FixedSizeList(
                        Arc::new(Field::new("item", element, false)),
                        inner_size,
                    ),
                    false,
                )),
                outer_size,
            )
        };

        let type_into = nested_list(
            DataType::Int32,
            FIXED_SIZE_LIST_WILDCARD,
            FIXED_SIZE_LIST_WILDCARD,
        );
        let type_from = nested_list(DataType::Int8, 4, 3);

        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(nested_list(DataType::Int32, 4, 3))
        );

        Ok(())
    }

    /// Dictionary coercion is decided by the dictionary's value type.
    #[test]
    fn test_coerced_from_dictionary() {
        let dictionary =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
        let int64 = DataType::Int64;

        // Int64 cannot be coerced into a dictionary of UInt32 values.
        assert_eq!(coerced_from(&dictionary, &int64), None);

        // A dictionary of UInt32 values can be coerced into Int64.
        assert_eq!(coerced_from(&int64, &dictionary), Some(int64.clone()));
    }
}
864 | | 3, |
865 | | )) |
866 | | ); |
867 | | |
868 | | Ok(()) |
869 | | } |
870 | | |
871 | | #[test] |
872 | | fn test_coerced_from_dictionary() { |
873 | | let type_into = |
874 | | DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32)); |
875 | | let type_from = DataType::Int64; |
876 | | assert_eq!(coerced_from(&type_into, &type_from), None); |
877 | | |
878 | | let type_from = |
879 | | DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32)); |
880 | | let type_into = DataType::Int64; |
881 | | assert_eq!( |
882 | | coerced_from(&type_into, &type_from), |
883 | | Some(type_into.clone()) |
884 | | ); |
885 | | } |
886 | | } |