/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/expressions/cast.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use std::any::Any; |
19 | | use std::fmt; |
20 | | use std::hash::{Hash, Hasher}; |
21 | | use std::sync::Arc; |
22 | | |
23 | | use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; |
24 | | |
25 | | use arrow::compute::{can_cast_types, CastOptions}; |
26 | | use arrow::datatypes::{DataType, DataType::*, Schema}; |
27 | | use arrow::record_batch::RecordBatch; |
28 | | use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; |
29 | | use datafusion_common::{not_impl_err, Result}; |
30 | | use datafusion_expr_common::columnar_value::ColumnarValue; |
31 | | use datafusion_expr_common::interval_arithmetic::Interval; |
32 | | use datafusion_expr_common::sort_properties::ExprProperties; |
33 | | |
34 | | const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { |
35 | | safe: false, |
36 | | format_options: DEFAULT_FORMAT_OPTIONS, |
37 | | }; |
38 | | |
39 | | const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions { |
40 | | safe: true, |
41 | | format_options: DEFAULT_FORMAT_OPTIONS, |
42 | | }; |
43 | | |
44 | | /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast |
45 | | #[derive(Debug, Clone)] |
46 | | pub struct CastExpr { |
47 | | /// The expression to cast |
48 | | pub expr: Arc<dyn PhysicalExpr>, |
49 | | /// The data type to cast to |
50 | | cast_type: DataType, |
51 | | /// Cast options |
52 | | cast_options: CastOptions<'static>, |
53 | | } |
54 | | |
55 | | impl CastExpr { |
56 | | /// Create a new CastExpr |
57 | 85 | pub fn new( |
58 | 85 | expr: Arc<dyn PhysicalExpr>, |
59 | 85 | cast_type: DataType, |
60 | 85 | cast_options: Option<CastOptions<'static>>, |
61 | 85 | ) -> Self { |
62 | 85 | Self { |
63 | 85 | expr, |
64 | 85 | cast_type, |
65 | 85 | cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS), |
66 | 85 | } |
67 | 85 | } |
68 | | |
69 | | /// The expression to cast |
70 | 0 | pub fn expr(&self) -> &Arc<dyn PhysicalExpr> { |
71 | 0 | &self.expr |
72 | 0 | } |
73 | | |
74 | | /// The data type to cast to |
75 | 0 | pub fn cast_type(&self) -> &DataType { |
76 | 0 | &self.cast_type |
77 | 0 | } |
78 | | |
79 | | /// The cast options |
80 | 0 | pub fn cast_options(&self) -> &CastOptions<'static> { |
81 | 0 | &self.cast_options |
82 | 0 | } |
83 | 0 | pub fn is_bigger_cast(&self, src: DataType) -> bool { |
84 | 0 | if src == self.cast_type { |
85 | 0 | return true; |
86 | 0 | } |
87 | 0 | matches!( |
88 | 0 | (src, &self.cast_type), |
89 | | (Int8, Int16 | Int32 | Int64) |
90 | | | (Int16, Int32 | Int64) |
91 | | | (Int32, Int64) |
92 | | | (UInt8, UInt16 | UInt32 | UInt64) |
93 | | | (UInt16, UInt32 | UInt64) |
94 | | | (UInt32, UInt64) |
95 | | | ( |
96 | | Int8 | Int16 | Int32 | UInt8 | UInt16 | UInt32, |
97 | | Float32 | Float64 |
98 | | ) |
99 | | | (Int64 | UInt64, Float64) |
100 | | | (Utf8, LargeUtf8) |
101 | | ) |
102 | 0 | } |
103 | | } |
104 | | |
105 | | impl fmt::Display for CastExpr { |
106 | 0 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
107 | 0 | write!(f, "CAST({} AS {:?})", self.expr, self.cast_type) |
108 | 0 | } |
109 | | } |
110 | | |
111 | | impl PhysicalExpr for CastExpr { |
112 | | /// Return a reference to Any that can be used for downcasting |
113 | 2.31k | fn as_any(&self) -> &dyn Any { |
114 | 2.31k | self |
115 | 2.31k | } |
116 | | |
117 | 3.24k | fn data_type(&self, _input_schema: &Schema) -> Result<DataType> { |
118 | 3.24k | Ok(self.cast_type.clone()) |
119 | 3.24k | } |
120 | | |
121 | 0 | fn nullable(&self, input_schema: &Schema) -> Result<bool> { |
122 | 0 | self.expr.nullable(input_schema) |
123 | 0 | } |
124 | | |
125 | 2.38k | fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { |
126 | 2.38k | let value = self.expr.evaluate(batch)?0 ; |
127 | 2.38k | value.cast_to(&self.cast_type, Some(&self.cast_options)) |
128 | 2.38k | } |
129 | | |
130 | 808 | fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> { |
131 | 808 | vec![&self.expr] |
132 | 808 | } |
133 | | |
134 | 0 | fn with_new_children( |
135 | 0 | self: Arc<Self>, |
136 | 0 | children: Vec<Arc<dyn PhysicalExpr>>, |
137 | 0 | ) -> Result<Arc<dyn PhysicalExpr>> { |
138 | 0 | Ok(Arc::new(CastExpr::new( |
139 | 0 | Arc::clone(&children[0]), |
140 | 0 | self.cast_type.clone(), |
141 | 0 | Some(self.cast_options.clone()), |
142 | 0 | ))) |
143 | 0 | } |
144 | | |
145 | 776 | fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> { |
146 | 776 | // Cast current node's interval to the right type: |
147 | 776 | children[0].cast_to(&self.cast_type, &self.cast_options) |
148 | 776 | } |
149 | | |
150 | 776 | fn propagate_constraints( |
151 | 776 | &self, |
152 | 776 | interval: &Interval, |
153 | 776 | children: &[&Interval], |
154 | 776 | ) -> Result<Option<Vec<Interval>>> { |
155 | 776 | let child_interval = children[0]; |
156 | 776 | // Get child's datatype: |
157 | 776 | let cast_type = child_interval.data_type(); |
158 | 776 | Ok(Some(vec![ |
159 | 776 | interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)?0 |
160 | | ])) |
161 | 776 | } |
162 | | |
163 | 0 | fn dyn_hash(&self, state: &mut dyn Hasher) { |
164 | 0 | let mut s = state; |
165 | 0 | self.expr.hash(&mut s); |
166 | 0 | self.cast_type.hash(&mut s); |
167 | 0 | self.cast_options.hash(&mut s); |
168 | 0 | } |
169 | | |
170 | | /// A [`CastExpr`] preserves the ordering of its child if the cast is done |
171 | | /// under the same datatype family. |
172 | 0 | fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> { |
173 | 0 | let source_datatype = children[0].range.data_type(); |
174 | 0 | let target_type = &self.cast_type; |
175 | | |
176 | 0 | let unbounded = Interval::make_unbounded(target_type)?; |
177 | 0 | if (source_datatype.is_numeric() || source_datatype == Boolean) |
178 | 0 | && target_type.is_numeric() |
179 | 0 | || source_datatype.is_temporal() && target_type.is_temporal() |
180 | 0 | || source_datatype.eq(target_type) |
181 | | { |
182 | 0 | Ok(children[0].clone().with_range(unbounded)) |
183 | | } else { |
184 | 0 | Ok(ExprProperties::new_unknown().with_range(unbounded)) |
185 | | } |
186 | 0 | } |
187 | | } |
188 | | |
189 | | impl PartialEq<dyn Any> for CastExpr { |
190 | 2.85k | fn eq(&self, other: &dyn Any) -> bool { |
191 | 2.85k | down_cast_any_ref(other) |
192 | 2.85k | .downcast_ref::<Self>() |
193 | 2.85k | .map(|x| { |
194 | 721 | self.expr.eq(&x.expr) |
195 | 289 | && self.cast_type == x.cast_type |
196 | 289 | && self.cast_options == x.cast_options |
197 | 2.85k | }721 ) |
198 | 2.85k | .unwrap_or(false) |
199 | 2.85k | } |
200 | | } |
201 | | |
202 | | /// Return a PhysicalExpression representing `expr` casted to |
203 | | /// `cast_type`, if any casting is needed. |
204 | | /// |
205 | | /// Note that such casts may lose type information |
206 | 85 | pub fn cast_with_options( |
207 | 85 | expr: Arc<dyn PhysicalExpr>, |
208 | 85 | input_schema: &Schema, |
209 | 85 | cast_type: DataType, |
210 | 85 | cast_options: Option<CastOptions<'static>>, |
211 | 85 | ) -> Result<Arc<dyn PhysicalExpr>> { |
212 | 85 | let expr_type = expr.data_type(input_schema)?0 ; |
213 | 85 | if expr_type == cast_type { |
214 | 0 | Ok(Arc::clone(&expr)) |
215 | 85 | } else if can_cast_types(&expr_type, &cast_type) { |
216 | 85 | Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) |
217 | | } else { |
218 | 0 | not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}") |
219 | | } |
220 | 85 | } |
221 | | |
222 | | /// Return a PhysicalExpression representing `expr` casted to |
223 | | /// `cast_type`, if any casting is needed. |
224 | | /// |
225 | | /// Note that such casts may lose type information |
226 | 85 | pub fn cast( |
227 | 85 | expr: Arc<dyn PhysicalExpr>, |
228 | 85 | input_schema: &Schema, |
229 | 85 | cast_type: DataType, |
230 | 85 | ) -> Result<Arc<dyn PhysicalExpr>> { |
231 | 85 | cast_with_options(expr, input_schema, cast_type, None) |
232 | 85 | } |
233 | | |
234 | | #[cfg(test)] |
235 | | mod tests { |
236 | | use super::*; |
237 | | |
238 | | use crate::expressions::column::col; |
239 | | |
240 | | use arrow::{ |
241 | | array::{ |
242 | | Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array, |
243 | | Int64Array, Int8Array, StringArray, Time64NanosecondArray, |
244 | | TimestampNanosecondArray, UInt32Array, |
245 | | }, |
246 | | datatypes::*, |
247 | | }; |
248 | | |
249 | | // runs an end-to-end test of physical type cast |
250 | | // 1. construct a record batch with a column "a" of type A |
251 | | // 2. construct a physical expression of CAST(a AS B) |
252 | | // 3. evaluate the expression |
253 | | // 4. verify that the resulting expression is of type B |
254 | | // 5. verify that the resulting values are downcastable and correct |
255 | | macro_rules! generic_decimal_to_other_test_cast { |
256 | | ($DECIMAL_ARRAY:ident, $A_TYPE:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr,$CAST_OPTIONS:expr) => {{ |
257 | | let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]); |
258 | | let batch = RecordBatch::try_new( |
259 | | Arc::new(schema.clone()), |
260 | | vec![Arc::new($DECIMAL_ARRAY)], |
261 | | )?; |
262 | | // verify that we can construct the expression |
263 | | let expression = |
264 | | cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?; |
265 | | |
266 | | // verify that its display is correct |
267 | | assert_eq!( |
268 | | format!("CAST(a@0 AS {:?})", $TYPE), |
269 | | format!("{}", expression) |
270 | | ); |
271 | | |
272 | | // verify that the expression's type is correct |
273 | | assert_eq!(expression.data_type(&schema)?, $TYPE); |
274 | | |
275 | | // compute |
276 | | let result = expression |
277 | | .evaluate(&batch)? |
278 | | .into_array(batch.num_rows()) |
279 | | .expect("Failed to convert to array"); |
280 | | |
281 | | // verify that the array's data_type is correct |
282 | | assert_eq!(*result.data_type(), $TYPE); |
283 | | |
284 | | // verify that the data itself is downcastable |
285 | | let result = result |
286 | | .as_any() |
287 | | .downcast_ref::<$TYPEARRAY>() |
288 | | .expect("failed to downcast"); |
289 | | |
290 | | // verify that the result itself is correct |
291 | | for (i, x) in $VEC.iter().enumerate() { |
292 | | match x { |
293 | | Some(x) => assert_eq!(result.value(i), *x), |
294 | | None => assert!(!result.is_valid(i)), |
295 | | } |
296 | | } |
297 | | }}; |
298 | | } |
299 | | |
300 | | // runs an end-to-end test of physical type cast |
301 | | // 1. construct a record batch with a column "a" of type A |
302 | | // 2. construct a physical expression of CAST(a AS B) |
303 | | // 3. evaluate the expression |
304 | | // 4. verify that the resulting expression is of type B |
305 | | // 5. verify that the resulting values are downcastable and correct |
306 | | macro_rules! generic_test_cast { |
307 | | ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $CAST_OPTIONS:expr) => {{ |
308 | | let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]); |
309 | | let a_vec_len = $A_VEC.len(); |
310 | | let a = $A_ARRAY::from($A_VEC); |
311 | | let batch = |
312 | | RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
313 | | |
314 | | // verify that we can construct the expression |
315 | | let expression = |
316 | | cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?; |
317 | | |
318 | | // verify that its display is correct |
319 | | assert_eq!( |
320 | | format!("CAST(a@0 AS {:?})", $TYPE), |
321 | | format!("{}", expression) |
322 | | ); |
323 | | |
324 | | // verify that the expression's type is correct |
325 | | assert_eq!(expression.data_type(&schema)?, $TYPE); |
326 | | |
327 | | // compute |
328 | | let result = expression |
329 | | .evaluate(&batch)? |
330 | | .into_array(batch.num_rows()) |
331 | | .expect("Failed to convert to array"); |
332 | | |
333 | | // verify that the array's data_type is correct |
334 | | assert_eq!(*result.data_type(), $TYPE); |
335 | | |
336 | | // verify that the len is correct |
337 | | assert_eq!(result.len(), a_vec_len); |
338 | | |
339 | | // verify that the data itself is downcastable |
340 | | let result = result |
341 | | .as_any() |
342 | | .downcast_ref::<$TYPEARRAY>() |
343 | | .expect("failed to downcast"); |
344 | | |
345 | | // verify that the result itself is correct |
346 | | for (i, x) in $VEC.iter().enumerate() { |
347 | | match x { |
348 | | Some(x) => assert_eq!(result.value(i), *x), |
349 | | None => assert!(!result.is_valid(i)), |
350 | | } |
351 | | } |
352 | | }}; |
353 | | } |
354 | | |
355 | | #[test] |
356 | | fn test_cast_decimal_to_decimal() -> Result<()> { |
357 | | let array = vec![ |
358 | | Some(1234), |
359 | | Some(2222), |
360 | | Some(3), |
361 | | Some(4000), |
362 | | Some(5000), |
363 | | None, |
364 | | ]; |
365 | | |
366 | | let decimal_array = array |
367 | | .clone() |
368 | | .into_iter() |
369 | | .collect::<Decimal128Array>() |
370 | | .with_precision_and_scale(10, 3)?; |
371 | | |
372 | | generic_decimal_to_other_test_cast!( |
373 | | decimal_array, |
374 | | Decimal128(10, 3), |
375 | | Decimal128Array, |
376 | | Decimal128(20, 6), |
377 | | [ |
378 | | Some(1_234_000), |
379 | | Some(2_222_000), |
380 | | Some(3_000), |
381 | | Some(4_000_000), |
382 | | Some(5_000_000), |
383 | | None |
384 | | ], |
385 | | None |
386 | | ); |
387 | | |
388 | | let decimal_array = array |
389 | | .into_iter() |
390 | | .collect::<Decimal128Array>() |
391 | | .with_precision_and_scale(10, 3)?; |
392 | | |
393 | | generic_decimal_to_other_test_cast!( |
394 | | decimal_array, |
395 | | Decimal128(10, 3), |
396 | | Decimal128Array, |
397 | | Decimal128(10, 2), |
398 | | [Some(123), Some(222), Some(0), Some(400), Some(500), None], |
399 | | None |
400 | | ); |
401 | | |
402 | | Ok(()) |
403 | | } |
404 | | |
405 | | #[test] |
406 | | fn test_cast_decimal_to_numeric() -> Result<()> { |
407 | | let array = vec![Some(1), Some(2), Some(3), Some(4), Some(5), None]; |
408 | | // decimal to i8 |
409 | | let decimal_array = array |
410 | | .clone() |
411 | | .into_iter() |
412 | | .collect::<Decimal128Array>() |
413 | | .with_precision_and_scale(10, 0)?; |
414 | | generic_decimal_to_other_test_cast!( |
415 | | decimal_array, |
416 | | Decimal128(10, 0), |
417 | | Int8Array, |
418 | | Int8, |
419 | | [ |
420 | | Some(1_i8), |
421 | | Some(2_i8), |
422 | | Some(3_i8), |
423 | | Some(4_i8), |
424 | | Some(5_i8), |
425 | | None |
426 | | ], |
427 | | None |
428 | | ); |
429 | | |
430 | | // decimal to i16 |
431 | | let decimal_array = array |
432 | | .clone() |
433 | | .into_iter() |
434 | | .collect::<Decimal128Array>() |
435 | | .with_precision_and_scale(10, 0)?; |
436 | | generic_decimal_to_other_test_cast!( |
437 | | decimal_array, |
438 | | Decimal128(10, 0), |
439 | | Int16Array, |
440 | | Int16, |
441 | | [ |
442 | | Some(1_i16), |
443 | | Some(2_i16), |
444 | | Some(3_i16), |
445 | | Some(4_i16), |
446 | | Some(5_i16), |
447 | | None |
448 | | ], |
449 | | None |
450 | | ); |
451 | | |
452 | | // decimal to i32 |
453 | | let decimal_array = array |
454 | | .clone() |
455 | | .into_iter() |
456 | | .collect::<Decimal128Array>() |
457 | | .with_precision_and_scale(10, 0)?; |
458 | | generic_decimal_to_other_test_cast!( |
459 | | decimal_array, |
460 | | Decimal128(10, 0), |
461 | | Int32Array, |
462 | | Int32, |
463 | | [ |
464 | | Some(1_i32), |
465 | | Some(2_i32), |
466 | | Some(3_i32), |
467 | | Some(4_i32), |
468 | | Some(5_i32), |
469 | | None |
470 | | ], |
471 | | None |
472 | | ); |
473 | | |
474 | | // decimal to i64 |
475 | | let decimal_array = array |
476 | | .into_iter() |
477 | | .collect::<Decimal128Array>() |
478 | | .with_precision_and_scale(10, 0)?; |
479 | | generic_decimal_to_other_test_cast!( |
480 | | decimal_array, |
481 | | Decimal128(10, 0), |
482 | | Int64Array, |
483 | | Int64, |
484 | | [ |
485 | | Some(1_i64), |
486 | | Some(2_i64), |
487 | | Some(3_i64), |
488 | | Some(4_i64), |
489 | | Some(5_i64), |
490 | | None |
491 | | ], |
492 | | None |
493 | | ); |
494 | | |
495 | | // decimal to float32 |
496 | | let array = vec![ |
497 | | Some(1234), |
498 | | Some(2222), |
499 | | Some(3), |
500 | | Some(4000), |
501 | | Some(5000), |
502 | | None, |
503 | | ]; |
504 | | let decimal_array = array |
505 | | .clone() |
506 | | .into_iter() |
507 | | .collect::<Decimal128Array>() |
508 | | .with_precision_and_scale(10, 3)?; |
509 | | generic_decimal_to_other_test_cast!( |
510 | | decimal_array, |
511 | | Decimal128(10, 3), |
512 | | Float32Array, |
513 | | Float32, |
514 | | [ |
515 | | Some(1.234_f32), |
516 | | Some(2.222_f32), |
517 | | Some(0.003_f32), |
518 | | Some(4.0_f32), |
519 | | Some(5.0_f32), |
520 | | None |
521 | | ], |
522 | | None |
523 | | ); |
524 | | |
525 | | // decimal to float64 |
526 | | let decimal_array = array |
527 | | .into_iter() |
528 | | .collect::<Decimal128Array>() |
529 | | .with_precision_and_scale(20, 6)?; |
530 | | generic_decimal_to_other_test_cast!( |
531 | | decimal_array, |
532 | | Decimal128(20, 6), |
533 | | Float64Array, |
534 | | Float64, |
535 | | [ |
536 | | Some(0.001234_f64), |
537 | | Some(0.002222_f64), |
538 | | Some(0.000003_f64), |
539 | | Some(0.004_f64), |
540 | | Some(0.005_f64), |
541 | | None |
542 | | ], |
543 | | None |
544 | | ); |
545 | | Ok(()) |
546 | | } |
547 | | |
548 | | #[test] |
549 | | fn test_cast_numeric_to_decimal() -> Result<()> { |
550 | | // int8 |
551 | | generic_test_cast!( |
552 | | Int8Array, |
553 | | Int8, |
554 | | vec![1, 2, 3, 4, 5], |
555 | | Decimal128Array, |
556 | | Decimal128(3, 0), |
557 | | [Some(1), Some(2), Some(3), Some(4), Some(5)], |
558 | | None |
559 | | ); |
560 | | |
561 | | // int16 |
562 | | generic_test_cast!( |
563 | | Int16Array, |
564 | | Int16, |
565 | | vec![1, 2, 3, 4, 5], |
566 | | Decimal128Array, |
567 | | Decimal128(5, 0), |
568 | | [Some(1), Some(2), Some(3), Some(4), Some(5)], |
569 | | None |
570 | | ); |
571 | | |
572 | | // int32 |
573 | | generic_test_cast!( |
574 | | Int32Array, |
575 | | Int32, |
576 | | vec![1, 2, 3, 4, 5], |
577 | | Decimal128Array, |
578 | | Decimal128(10, 0), |
579 | | [Some(1), Some(2), Some(3), Some(4), Some(5)], |
580 | | None |
581 | | ); |
582 | | |
583 | | // int64 |
584 | | generic_test_cast!( |
585 | | Int64Array, |
586 | | Int64, |
587 | | vec![1, 2, 3, 4, 5], |
588 | | Decimal128Array, |
589 | | Decimal128(20, 0), |
590 | | [Some(1), Some(2), Some(3), Some(4), Some(5)], |
591 | | None |
592 | | ); |
593 | | |
594 | | // int64 to different scale |
595 | | generic_test_cast!( |
596 | | Int64Array, |
597 | | Int64, |
598 | | vec![1, 2, 3, 4, 5], |
599 | | Decimal128Array, |
600 | | Decimal128(20, 2), |
601 | | [Some(100), Some(200), Some(300), Some(400), Some(500)], |
602 | | None |
603 | | ); |
604 | | |
605 | | // float32 |
606 | | generic_test_cast!( |
607 | | Float32Array, |
608 | | Float32, |
609 | | vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], |
610 | | Decimal128Array, |
611 | | Decimal128(10, 2), |
612 | | [Some(150), Some(250), Some(300), Some(112), Some(550)], |
613 | | None |
614 | | ); |
615 | | |
616 | | // float64 |
617 | | generic_test_cast!( |
618 | | Float64Array, |
619 | | Float64, |
620 | | vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], |
621 | | Decimal128Array, |
622 | | Decimal128(20, 4), |
623 | | [ |
624 | | Some(15000), |
625 | | Some(25000), |
626 | | Some(30000), |
627 | | Some(11235), |
628 | | Some(55000) |
629 | | ], |
630 | | None |
631 | | ); |
632 | | Ok(()) |
633 | | } |
634 | | |
635 | | #[test] |
636 | | fn test_cast_i32_u32() -> Result<()> { |
637 | | generic_test_cast!( |
638 | | Int32Array, |
639 | | Int32, |
640 | | vec![1, 2, 3, 4, 5], |
641 | | UInt32Array, |
642 | | UInt32, |
643 | | [ |
644 | | Some(1_u32), |
645 | | Some(2_u32), |
646 | | Some(3_u32), |
647 | | Some(4_u32), |
648 | | Some(5_u32) |
649 | | ], |
650 | | None |
651 | | ); |
652 | | Ok(()) |
653 | | } |
654 | | |
655 | | #[test] |
656 | | fn test_cast_i32_utf8() -> Result<()> { |
657 | | generic_test_cast!( |
658 | | Int32Array, |
659 | | Int32, |
660 | | vec![1, 2, 3, 4, 5], |
661 | | StringArray, |
662 | | Utf8, |
663 | | [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")], |
664 | | None |
665 | | ); |
666 | | Ok(()) |
667 | | } |
668 | | |
669 | | #[test] |
670 | | fn test_cast_i64_t64() -> Result<()> { |
671 | | let original = vec![1, 2, 3, 4, 5]; |
672 | | let expected: Vec<Option<i64>> = original |
673 | | .iter() |
674 | | .map(|i| Some(Time64NanosecondArray::from(vec![*i]).value(0))) |
675 | | .collect(); |
676 | | generic_test_cast!( |
677 | | Int64Array, |
678 | | Int64, |
679 | | original, |
680 | | TimestampNanosecondArray, |
681 | | Timestamp(TimeUnit::Nanosecond, None), |
682 | | expected, |
683 | | None |
684 | | ); |
685 | | Ok(()) |
686 | | } |
687 | | |
688 | | #[test] |
689 | | fn invalid_cast() { |
690 | | // Ensure a useful error happens at plan time if invalid casts are used |
691 | | let schema = Schema::new(vec![Field::new("a", Int32, false)]); |
692 | | |
693 | | let result = cast( |
694 | | col("a", &schema).unwrap(), |
695 | | &schema, |
696 | | DataType::Interval(IntervalUnit::MonthDayNano), |
697 | | ); |
698 | | result.expect_err("expected Invalid CAST"); |
699 | | } |
700 | | |
701 | | #[test] |
702 | | fn invalid_cast_with_options_error() -> Result<()> { |
703 | | // Ensure a useful error happens at plan time if invalid casts are used |
704 | | let schema = Schema::new(vec![Field::new("a", Utf8, false)]); |
705 | | let a = StringArray::from(vec!["9.1"]); |
706 | | let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
707 | | let expression = cast_with_options(col("a", &schema)?, &schema, Int32, None)?; |
708 | | let result = expression.evaluate(&batch); |
709 | | |
710 | | match result { |
711 | | Ok(_) => panic!("expected error"), |
712 | | Err(e) => { |
713 | | assert!(e |
714 | | .to_string() |
715 | | .contains("Cannot cast string '9.1' to value of Int32 type")) |
716 | | } |
717 | | } |
718 | | Ok(()) |
719 | | } |
720 | | |
721 | | #[test] |
722 | | #[ignore] // TODO: https://github.com/apache/datafusion/issues/5396 |
723 | | fn test_cast_decimal() -> Result<()> { |
724 | | let schema = Schema::new(vec![Field::new("a", Int64, false)]); |
725 | | let a = Int64Array::from(vec![100]); |
726 | | let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; |
727 | | let expression = |
728 | | cast_with_options(col("a", &schema)?, &schema, Decimal128(38, 38), None)?; |
729 | | expression.evaluate(&batch)?; |
730 | | Ok(()) |
731 | | } |
732 | | } |