Skip to content

Commit

Permalink
Fix timestamp handling in cast kernel (#1936) (#4033) (#4034)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold authored Apr 7, 2023
1 parent d946cc4 commit bc15cbd
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 52 deletions.
26 changes: 25 additions & 1 deletion arrow-array/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow_schema::{
DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
DECIMAL_DEFAULT_SCALE,
};
use chrono::{Duration, NaiveDate};
use chrono::{Duration, NaiveDate, NaiveDateTime};
use half::f16;
use std::marker::PhantomData;
use std::ops::{Add, Sub};
Expand Down Expand Up @@ -311,19 +311,43 @@ pub trait ArrowTimestampType: ArrowTemporalType<Native = i64> {
fn get_time_unit() -> TimeUnit {
Self::UNIT
}

/// Creates an ArrowTimestampType::Native from the provided [`NaiveDateTime`]
///
/// See [`DataType::Timestamp`] for more information on timezone handling
fn make_value(naive: NaiveDateTime) -> Option<i64>;
}

impl ArrowTimestampType for TimestampSecondType {
    const UNIT: TimeUnit = TimeUnit::Second;

    /// Converts `naive` to a second-resolution timestamp value.
    ///
    /// Only the whole-second count reported by `naive.timestamp()` is used;
    /// this implementation always returns `Some`.
    fn make_value(naive: NaiveDateTime) -> Option<i64> {
        let seconds = naive.timestamp();
        Some(seconds)
    }
}
impl ArrowTimestampType for TimestampMillisecondType {
    const UNIT: TimeUnit = TimeUnit::Millisecond;

    /// Converts `naive` to a millisecond-resolution timestamp value.
    ///
    /// The whole-second count is scaled by 1,000 and the sub-second
    /// milliseconds are added on top. Both steps are checked, so `None` is
    /// returned if either one overflows `i64`.
    fn make_value(naive: NaiveDateTime) -> Option<i64> {
        let whole_seconds = naive.timestamp();
        let subsec_millis = naive.timestamp_subsec_millis() as i64;
        whole_seconds.checked_mul(1_000)?.checked_add(subsec_millis)
    }
}
impl ArrowTimestampType for TimestampMicrosecondType {
    const UNIT: TimeUnit = TimeUnit::Microsecond;

    /// Converts `naive` to a microsecond-resolution timestamp value.
    ///
    /// The whole-second count is scaled by 1,000,000 and the sub-second
    /// microseconds are added on top. Both steps are checked, so `None` is
    /// returned if either one overflows `i64`.
    fn make_value(naive: NaiveDateTime) -> Option<i64> {
        let whole_seconds = naive.timestamp();
        let subsec_micros = naive.timestamp_subsec_micros() as i64;
        whole_seconds
            .checked_mul(1_000_000)?
            .checked_add(subsec_micros)
    }
}
impl ArrowTimestampType for TimestampNanosecondType {
const UNIT: TimeUnit = TimeUnit::Nanosecond;

fn make_value(naive: NaiveDateTime) -> Option<i64> {
let nanos = naive.timestamp().checked_mul(1_000_000_000)?;
nanos.checked_add(naive.timestamp_subsec_nanos() as i64)
}
}

impl IntervalYearMonthType {
Expand Down
139 changes: 88 additions & 51 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
//! assert_eq!(7.0, c.value(2));
//! ```
use chrono::{NaiveTime, Timelike};
use chrono::{NaiveTime, TimeZone, Timelike, Utc};
use std::cmp::Ordering;
use std::sync::Arc;

use crate::display::{array_value_to_string, ArrayFormatter, FormatOptions};
use crate::parse::{
parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
string_to_timestamp_nanos,
string_to_datetime,
};
use arrow_array::{
builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *,
Expand Down Expand Up @@ -1233,16 +1233,16 @@ pub fn cast_with_options(
cast_string_to_time64nanosecond::<i64>(array, cast_options)
}
Timestamp(TimeUnit::Second, to_tz) => {
cast_string_to_timestamp::<i64, TimestampSecondType>(array, to_tz,cast_options)
cast_string_to_timestamp::<i64, TimestampSecondType>(array, to_tz, cast_options)
}
Timestamp(TimeUnit::Millisecond, to_tz) => {
cast_string_to_timestamp::<i64, TimestampMillisecondType>(array, to_tz,cast_options)
cast_string_to_timestamp::<i64, TimestampMillisecondType>(array, to_tz, cast_options)
}
Timestamp(TimeUnit::Microsecond, to_tz) => {
cast_string_to_timestamp::<i64, TimestampMicrosecondType>(array, to_tz,cast_options)
cast_string_to_timestamp::<i64, TimestampMicrosecondType>(array, to_tz, cast_options)
}
Timestamp(TimeUnit::Nanosecond, to_tz) => {
cast_string_to_timestamp::<i64, TimestampNanosecondType>(array, to_tz,cast_options)
cast_string_to_timestamp::<i64, TimestampNanosecondType>(array, to_tz, cast_options)
}
Interval(IntervalUnit::YearMonth) => {
cast_string_to_year_month_interval::<i64>(array, cast_options)
Expand Down Expand Up @@ -2653,59 +2653,67 @@ fn cast_string_to_time64nanosecond<Offset: OffsetSizeTrait>(
}

/// Casts generic string arrays to an ArrowTimestampType (TimestampNanosecondArray, etc.)
fn cast_string_to_timestamp<
Offset: OffsetSizeTrait,
TimestampType: ArrowTimestampType<Native = i64>,
>(
fn cast_string_to_timestamp<O: OffsetSizeTrait, T: ArrowTimestampType>(
array: &dyn Array,
to_tz: &Option<Arc<str>>,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
.unwrap();

let scale_factor = match TimestampType::UNIT {
TimeUnit::Second => 1_000_000_000,
TimeUnit::Millisecond => 1_000_000,
TimeUnit::Microsecond => 1_000,
TimeUnit::Nanosecond => 1,
let array = array.as_string::<O>();
let out: PrimitiveArray<T> = match to_tz {
Some(tz) => {
let tz: Tz = tz.as_ref().parse()?;
cast_string_to_timestamp_impl(array, &tz, cast_options)?
}
None => cast_string_to_timestamp_impl(array, &Utc, cast_options)?,
};
Ok(Arc::new(out.with_timezone_opt(to_tz.clone())))
}

let array = if cast_options.safe {
let iter = string_array.iter().map(|v| {
v.and_then(|v| string_to_timestamp_nanos(v).ok().map(|t| t / scale_factor))
fn cast_string_to_timestamp_impl<
O: OffsetSizeTrait,
T: ArrowTimestampType,
Tz: TimeZone,
>(
array: &GenericStringArray<O>,
tz: &Tz,
cast_options: &CastOptions,
) -> Result<PrimitiveArray<T>, ArrowError> {
if cast_options.safe {
let iter = array.iter().map(|v| {
v.and_then(|v| {
let naive = string_to_datetime(tz, v).ok()?.naive_utc();
T::make_value(naive)
})
});
// Benefit:
// 20% performance improvement
// Soundness:
// The iterator is trustedLen because it comes from an `StringArray`.

unsafe {
PrimitiveArray::<TimestampType>::from_trusted_len_iter(iter)
.with_timezone_opt(to_tz.clone())
}
Ok(unsafe { PrimitiveArray::from_trusted_len_iter(iter) })
} else {
let vec = string_array
let vec = array
.iter()
.map(|v| {
v.map(|v| string_to_timestamp_nanos(v).map(|t| t / scale_factor))
.transpose()
v.map(|v| {
let naive = string_to_datetime(tz, v)?.naive_utc();
T::make_value(naive).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow converting {naive} to {:?}",
T::UNIT
))
})
})
.transpose()
})
.collect::<Result<Vec<Option<i64>>, _>>()?;

// Benefit:
// 20% performance improvement
// Soundness:
// The iterator is trustedLen because it comes from an `StringArray`.
unsafe {
PrimitiveArray::<TimestampType>::from_trusted_len_iter(vec.iter())
.with_timezone_opt(to_tz.clone())
}
};

Ok(Arc::new(array) as ArrayRef)
Ok(unsafe { PrimitiveArray::from_trusted_len_iter(vec.iter()) })
}
}

fn cast_string_to_year_month_interval<Offset: OffsetSizeTrait>(
Expand Down Expand Up @@ -5018,6 +5026,14 @@ mod tests {
}
}

#[test]
fn test_cast_string_to_timestamp_overflow() {
    // 247112596800 seconds multiplied by 1e9 exceeds `i64::MAX`, so this
    // instant has no `i64` nanosecond representation; the cast to a
    // second-resolution timestamp must still succeed.
    let input = StringArray::from(vec!["9800-09-08T12:00:00.123456789"]);
    let target_type = DataType::Timestamp(TimeUnit::Second, None);

    let casted = cast(&input, &target_type).unwrap();
    let seconds = casted.as_primitive::<TimestampSecondType>();

    assert_eq!(seconds.values(), &[247112596800]);
}

#[test]
fn test_cast_string_to_date32() {
let a1 = Arc::new(StringArray::from(vec![
Expand Down Expand Up @@ -8079,24 +8095,45 @@ mod tests {
let array = Arc::new(valid) as ArrayRef;
let b = cast_with_options(
&array,
&DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)),
&DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.clone())),
&CastOptions { safe: false },
)
.unwrap();

let c = b
.as_any()
.downcast_ref::<TimestampNanosecondArray>()
.unwrap();
assert_eq!(1672574706789000000, c.value(0));
assert_eq!(1672571106789000000, c.value(1));
assert_eq!(1672574706789000000, c.value(2));
assert_eq!(1672574706789000000, c.value(3));
assert_eq!(1672518906000000000, c.value(4));
assert_eq!(1672518906000000000, c.value(5));
assert_eq!(1672545906789000000, c.value(6));
assert_eq!(1672545906000000000, c.value(7));
assert_eq!(1672531200000000000, c.value(8));
let tz = tz.as_ref().parse().unwrap();

let as_tz = |v: i64| {
as_datetime_with_timezone::<TimestampNanosecondType>(v, tz).unwrap()
};

let as_utc = |v: &i64| as_tz(*v).naive_utc().to_string();
let as_local = |v: &i64| as_tz(*v).naive_local().to_string();

let values = b.as_primitive::<TimestampNanosecondType>().values();
let utc_results: Vec<_> = values.iter().map(as_utc).collect();
let local_results: Vec<_> = values.iter().map(as_local).collect();

// Absolute timestamps should be parsed preserving the same UTC instant
assert_eq!(
&utc_results[..6],
&[
"2023-01-01 12:05:06.789".to_string(),
"2023-01-01 11:05:06.789".to_string(),
"2023-01-01 12:05:06.789".to_string(),
"2023-01-01 12:05:06.789".to_string(),
"2022-12-31 20:35:06".to_string(),
"2022-12-31 20:35:06".to_string(),
]
);
// Non-absolute timestamps should be parsed preserving the same local instant
assert_eq!(
&local_results[6..],
&[
"2023-01-01 04:05:06.789".to_string(),
"2023-01-01 04:05:06".to_string(),
"2023-01-01 00:00:00".to_string()
]
)
}

test_tz("+00:00".into());
Expand Down

0 comments on commit bc15cbd

Please sign in to comment.