diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 3d9d0ee3d920..ebfd97488b28 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -161,17 +161,16 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true, // Utf8 to decimal (Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true, - (Decimal128(_, _) | Decimal256(_, _), _) => false, - (_, Decimal128(_, _) | Decimal256(_, _)) => false, (Struct(_), _) => false, (_, Struct(_)) => false, (_, Boolean) => { - DataType::is_numeric(from_type) + DataType::is_integer(from_type) || + DataType::is_floating(from_type) || from_type == &Utf8 || from_type == &LargeUtf8 } (Boolean, _) => { - DataType::is_numeric(to_type) || to_type == &Utf8 || to_type == &LargeUtf8 + DataType::is_integer(to_type) || DataType::is_floating(to_type) || to_type == &Utf8 || to_type == &LargeUtf8 } (Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_)) => true, @@ -222,8 +221,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Time64(_), Time32(to_unit)) => { matches!(to_unit, Second | Millisecond) } - (Timestamp(_, _), _) if to_type.is_integer() => true, - (_, Timestamp(_, _)) if from_type.is_integer() => true, + (Timestamp(_, _), _) if to_type.is_numeric() && to_type != &Float16 => true, + (_, Timestamp(_, _)) if from_type.is_numeric() && from_type != &Float16 => true, (Date64, Timestamp(_, None)) => true, (Date32, Timestamp(_, None)) => true, ( @@ -849,7 +848,7 @@ pub fn cast_with_options( cast_options, ) } - (Decimal128(_, scale), _) => { + (Decimal128(_, scale), _) if !to_type.is_temporal() => { // cast decimal to other type match to_type { UInt8 => cast_decimal_to_integer::( @@ -914,7 +913,7 @@ pub fn cast_with_options( ))), } } - (Decimal256(_, scale), _) => { + (Decimal256(_, scale), _) if !to_type.is_temporal() => { // cast decimal to other type match to_type { UInt8 => cast_decimal_to_integer::( @@ -979,7 +978,7 @@ pub fn cast_with_options( ))), } } - (_, Decimal128(precision, scale)) => { + (_, Decimal128(precision, scale)) if !from_type.is_temporal() => { // cast data to decimal match from_type { UInt8 => cast_integer_to_decimal::<_, Decimal128Type, _>( @@ -1068,7 +1067,7 @@ pub fn cast_with_options( ))), } } - (_, Decimal256(precision, scale)) => { + (_, Decimal256(precision, scale)) if !from_type.is_temporal() => { // cast data to decimal match from_type { UInt8 => cast_integer_to_decimal::<_, Decimal256Type, _>( @@ -1607,24 +1606,25 @@ pub fn cast_with_options( .unary::<_, Time64MicrosecondType>(|x| x / (NANOSECONDS / MICROSECONDS)), )), - (Timestamp(TimeUnit::Second, _), _) if to_type.is_integer() => { + // Timestamp to integer/floating/decimals + (Timestamp(TimeUnit::Second, _), _) if to_type.is_numeric() => { let array = cast_reinterpret_arrays::(array)?; cast_with_options(&array, to_type, cast_options) } - (Timestamp(TimeUnit::Millisecond, _), _) if to_type.is_integer() => { + (Timestamp(TimeUnit::Millisecond, _), _) if to_type.is_numeric() => { let array = cast_reinterpret_arrays::(array)?; cast_with_options(&array, to_type, cast_options) } - (Timestamp(TimeUnit::Microsecond, _), _) if to_type.is_integer() => { + (Timestamp(TimeUnit::Microsecond, _), _) if to_type.is_numeric() => { let array = cast_reinterpret_arrays::(array)?; cast_with_options(&array, to_type, cast_options) } - (Timestamp(TimeUnit::Nanosecond, _), _) if to_type.is_integer() => { + (Timestamp(TimeUnit::Nanosecond, _), _) if to_type.is_numeric() => { let array = cast_reinterpret_arrays::(array)?; cast_with_options(&array, to_type, cast_options) } - (_, Timestamp(unit, tz)) if from_type.is_integer() => { + (_, Timestamp(unit, tz)) if from_type.is_numeric() => { let array = cast_with_options(array, &Int64, cast_options)?; Ok(make_timestamp_array( array.as_primitive(), @@ -4652,6 +4652,80 @@ mod tests { assert_eq!(&actual, &expected); } + #[test] + fn test_cast_floating_to_timestamp() { + let array = Int64Array::from(vec![Some(2), Some(10), None]); + let expected = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + let array = Float32Array::from(vec![Some(2.0), Some(10.6), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = Float64Array::from(vec![Some(2.1), Some(10.2), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_timestamp_to_floating() { + let array = TimestampMillisecondArray::from(vec![Some(5), Some(1), None]) + .with_timezone("UTC".to_string()); + let expected = cast(&array, &DataType::Int64).unwrap(); + + let actual = cast(&cast(&array, &DataType::Float32).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::Float64).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_decimal_to_timestamp() { + let array = Int64Array::from(vec![Some(2), Some(10), None]); + let expected = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + let array = Decimal128Array::from(vec![Some(200), Some(1000), None]) + .with_precision_and_scale(4, 2) + .unwrap(); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = Decimal256Array::from(vec![ + Some(i256::from_i128(2000)), + Some(i256::from_i128(10000)), + None, + ]) + .with_precision_and_scale(5, 3) + .unwrap(); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_timestamp_to_decimal() { + let array = TimestampMillisecondArray::from(vec![Some(5), Some(1), None]) + .with_timezone("UTC".to_string()); + let expected = cast(&array, &DataType::Int64).unwrap(); + + let actual = cast( + &cast(&array, &DataType::Decimal128(5, 2)).unwrap(), + &DataType::Int64, + ) + .unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast( + &cast(&array, &DataType::Decimal256(10, 5)).unwrap(), + &DataType::Int64, + ) + .unwrap(); + assert_eq!(&actual, &expected); + } + #[test] fn test_cast_list_i32_to_list_u16() { let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 100000000]).into_data();