diff --git a/Cargo.toml b/Cargo.toml index 22f15b2bf98..ef685f338f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -171,6 +171,7 @@ compute_substring = [] compute_take = [] compute_temporal = [] compute_window = ["compute_concatenate"] +compute_lower = [] compute = [ "compute_aggregate", "compute_arithmetics", @@ -196,6 +197,7 @@ compute = [ "compute_take", "compute_temporal", "compute_window", + "compute_lower", ] # base64 + io_ipc because arrow schemas are stored as base64-encoded ipc format. io_parquet = ["parquet2", "io_ipc", "base64", "futures"] @@ -298,4 +300,4 @@ harness = false [[bench]] name = "bitwise" -harness = false \ No newline at end of file +harness = false diff --git a/src/compute/lower.rs b/src/compute/lower.rs new file mode 100644 index 00000000000..5a9978179b4 --- /dev/null +++ b/src/compute/lower.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines a kernel that lower-cases the values of a \[Large\]StringArray + +use super::utils::utf8_apply; +use crate::array::*; +use crate::{ + datatypes::DataType, + error::{ArrowError, Result}, +}; + +/// Returns a new `Array` where each of the elements is lower-cased. 
+/// This function errors when the passed array is not a \[Large\]String array. +pub fn lower(array: &dyn Array) -> Result> { + match array.data_type() { + DataType::LargeUtf8 => Ok(Box::new(utf8_apply( + str::to_lowercase, + array + .as_any() + .downcast_ref::>() + .expect("A large string is expected"), + ))), + DataType::Utf8 => Ok(Box::new(utf8_apply( + str::to_lowercase, + array + .as_any() + .downcast_ref::>() + .expect("A string is expected"), + ))), + _ => Err(ArrowError::InvalidArgumentError(format!( + "lower does not support type {:?}", + array.data_type() + ))), + } +} + +/// Checks whether an array of type `data_type` supports the `lower` operation +/// +/// # Examples +/// ``` +/// use arrow2::compute::lower::can_lower; +/// use arrow2::datatypes::{DataType}; +/// +/// let data_type = DataType::Utf8; +/// assert_eq!(can_lower(&data_type), true); +/// +/// let data_type = DataType::Null; +/// assert_eq!(can_lower(&data_type), false); +/// ``` +pub fn can_lower(data_type: &DataType) -> bool { + matches!(data_type, DataType::LargeUtf8 | DataType::Utf8) +} diff --git a/src/compute/mod.rs b/src/compute/mod.rs index b44433b9ef7..5608e3ebfe5 100644 --- a/src/compute/mod.rs +++ b/src/compute/mod.rs @@ -58,6 +58,9 @@ pub mod like; #[cfg(feature = "compute_limit")] #[cfg_attr(docsrs, doc(cfg(feature = "compute_limit")))] pub mod limit; +#[cfg(feature = "compute_lower")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_lower")))] +pub mod lower; #[cfg(feature = "compute_merge_sort")] #[cfg_attr(docsrs, doc(cfg(feature = "compute_merge_sort")))] pub mod merge_sort; diff --git a/src/compute/utils.rs b/src/compute/utils.rs index 864eb27d40d..69ed5b7a5a8 100644 --- a/src/compute/utils.rs +++ b/src/compute/utils.rs @@ -30,6 +30,14 @@ pub fn unary_utf8_boolean bool>( BooleanArray::from_data(DataType::Boolean, values, validity) } +/// Applies `f` (a `Fn(&str) -> String`) to every value of a `Utf8Array`, keeping the input's validity. 
+pub fn utf8_apply String>(f: F, array: &Utf8Array) -> Utf8Array { + let iter = array.values_iter().map(f); + + let new = Utf8Array::::from_trusted_len_values_iter(iter); + new.with_validity(array.validity().cloned()) +} + // Errors iff the two arrays have a different length. #[inline] pub fn check_same_len(lhs: &dyn Array, rhs: &dyn Array) -> Result<()> { diff --git a/tests/it/compute/lower.rs b/tests/it/compute/lower.rs new file mode 100644 index 00000000000..d8f594174fb --- /dev/null +++ b/tests/it/compute/lower.rs @@ -0,0 +1,186 @@ +use arrow2::{array::*, compute::lower::*, error::Result}; + +fn with_nulls_utf8() -> Result<()> { + let cases = vec![ + // already lower-case input (identity) + ( + vec![Some("hello"), None, Some("world")], + vec![Some("hello"), None, Some("world")], + ), + // mixed-case input + ( + vec![Some("Hello"), None, Some("wOrld")], + vec![Some("hello"), None, Some("world")], + ), + // all upper-case input + ( + vec![Some("HELLO"), None, Some("WORLD")], + vec![Some("hello"), None, Some("world")], + ), + // multilingual UTF-8 text + ( + vec![ + None, + Some("السلام عليكم"), + Some("Dobrý den"), + Some("שָׁלוֹם"), + Some("नमस्ते"), + Some("こんにちは"), + Some("안녕하세요"), + Some("你好"), + Some("Olá"), + Some("Здравствуйте"), + Some("Hola"), + ], + vec![ + None, + Some("السلام عليكم"), + Some("dobrý den"), + Some("שָׁלוֹם"), + Some("नमस्ते"), + Some("こんにちは"), + Some("안녕하세요"), + Some("你好"), + Some("olá"), + Some("здравствуйте"), + Some("hola"), + ], + ), + ]; + + cases + .into_iter() + .try_for_each::<_, Result<()>>(|(array, expected)| { + let array = Utf8Array::::from(&array); + let result = lower(&array)?; + assert_eq!(array.len(), result.len()); + + let result = result.as_any().downcast_ref::>().unwrap(); + let expected = Utf8Array::::from(&expected); + + assert_eq!(&expected, result); + Ok(()) + })?; + + Ok(()) +} + +#[test] +fn with_nulls_string() -> Result<()> { + with_nulls_utf8::() +} + +#[test] +fn with_nulls_large_string() -> Result<()> { + with_nulls_utf8::() +} + +fn without_nulls_utf8() 
-> Result<()> { + let cases = vec![ + // already lower-case input (identity) + (vec!["hello", "world"], vec!["hello", "world"]), + // mixed-case input + (vec!["Hello", "wOrld"], vec!["hello", "world"]), + // all upper-case input + (vec!["HELLO", "WORLD"], vec!["hello", "world"]), + // multilingual UTF-8 text + ( + vec![ + "السلام عليكم", + "Dobrý den", + "שָׁלוֹם", + "नमस्ते", + "こんにちは", + "안녕하세요", + "你好", + "Olá", + "Здравствуйте", + "Hola", + ], + vec![ + "السلام عليكم", + "dobrý den", + "שָׁלוֹם", + "नमस्ते", + "こんにちは", + "안녕하세요", + "你好", + "olá", + "здравствуйте", + "hola", + ], + ), + ]; + + cases + .into_iter() + .try_for_each::<_, Result<()>>(|(array, expected)| { + let array = Utf8Array::::from_slice(&array); + let result = lower(&array)?; + assert_eq!(array.len(), result.len()); + + let result = result.as_any().downcast_ref::>().unwrap(); + let expected = Utf8Array::::from_slice(&expected); + assert_eq!(&expected, result); + Ok(()) + })?; + + Ok(()) +} + +#[test] +fn without_nulls_string() -> Result<()> { + without_nulls_utf8::() +} + +#[test] +fn without_nulls_large_string() -> Result<()> { + without_nulls_utf8::() +} + +#[test] +fn consistency() { + use arrow2::datatypes::DataType::*; + use arrow2::datatypes::TimeUnit; + let datatypes = vec![ + Null, + Boolean, + UInt8, + UInt16, + UInt32, + UInt64, + Int8, + Int16, + Int32, + Int64, + Float32, + Float64, + Timestamp(TimeUnit::Second, None), + Timestamp(TimeUnit::Millisecond, None), + Timestamp(TimeUnit::Microsecond, None), + Timestamp(TimeUnit::Nanosecond, None), + Time64(TimeUnit::Microsecond), + Time64(TimeUnit::Nanosecond), + Date32, + Time32(TimeUnit::Second), + Time32(TimeUnit::Millisecond), + Date64, + Utf8, + LargeUtf8, + Binary, + LargeBinary, + Duration(TimeUnit::Second), + Duration(TimeUnit::Millisecond), + Duration(TimeUnit::Microsecond), + Duration(TimeUnit::Nanosecond), + ]; + + datatypes.into_iter().for_each(|d1| { + let array = new_null_array(d1.clone(), 10); + if can_lower(&d1) { + assert!(lower(array.as_ref()).is_ok()); + } else { + 
assert!(lower(array.as_ref()).is_err()); + } + }); +} diff --git a/tests/it/compute/mod.rs b/tests/it/compute/mod.rs index ff55ab74d7b..d4bd0b008eb 100644 --- a/tests/it/compute/mod.rs +++ b/tests/it/compute/mod.rs @@ -28,6 +28,8 @@ mod length; mod like; #[cfg(feature = "compute_limit")] mod limit; +#[cfg(feature = "compute_lower")] +mod lower; #[cfg(feature = "compute_merge_sort")] mod merge_sort; #[cfg(feature = "compute_partition")]