Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added lower support (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xuanwo authored Dec 6, 2021
1 parent 021a8e3 commit 998882e
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ compute_substring = []
compute_take = []
compute_temporal = []
compute_window = ["compute_concatenate"]
compute_lower = []
compute = [
"compute_aggregate",
"compute_arithmetics",
Expand All @@ -196,6 +197,7 @@ compute = [
"compute_take",
"compute_temporal",
"compute_window",
"compute_lower",
]
# base64 + io_ipc because arrow schemas are stored as base64-encoded ipc format.
io_parquet = ["parquet2", "io_ipc", "base64", "futures"]
Expand Down Expand Up @@ -298,4 +300,4 @@ harness = false

[[bench]]
name = "bitwise"
harness = false
harness = false
67 changes: 67 additions & 0 deletions src/compute/lower.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines kernel to extract a lower case of a \[Large\]StringArray
use super::utils::utf8_apply;
use crate::array::*;
use crate::{
datatypes::DataType,
error::{ArrowError, Result},
};

/// Returns a new `Array` where each of each of the elements is lower-cased.
/// this function errors when the passed array is not a \[Large\]String array.
pub fn lower(array: &dyn Array) -> Result<Box<dyn Array>> {
match array.data_type() {
DataType::LargeUtf8 => Ok(Box::new(utf8_apply(
str::to_lowercase,
array
.as_any()
.downcast_ref::<Utf8Array<i64>>()
.expect("A large string is expected"),
))),
DataType::Utf8 => Ok(Box::new(utf8_apply(
str::to_lowercase,
array
.as_any()
.downcast_ref::<Utf8Array<i32>>()
.expect("A string is expected"),
))),
_ => Err(ArrowError::InvalidArgumentError(format!(
"lower does not support type {:?}",
array.data_type()
))),
}
}

/// Checks if an array of type `datatype` can perform lower operation
///
/// # Examples
/// ```
/// use arrow2::compute::lower::can_lower;
/// use arrow2::datatypes::{DataType};
///
/// let data_type = DataType::Utf8;
/// assert_eq!(can_lower(&data_type), true);
///
/// let data_type = DataType::Null;
/// assert_eq!(can_lower(&data_type), false);
/// ```
pub fn can_lower(data_type: &DataType) -> bool {
matches!(data_type, DataType::LargeUtf8 | DataType::Utf8)
}
3 changes: 3 additions & 0 deletions src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ pub mod like;
#[cfg(feature = "compute_limit")]
#[cfg_attr(docsrs, doc(cfg(feature = "compute_limit")))]
pub mod limit;
#[cfg(feature = "compute_lower")]
#[cfg_attr(docsrs, doc(cfg(feature = "compute_lower")))]
pub mod lower;
#[cfg(feature = "compute_merge_sort")]
#[cfg_attr(docsrs, doc(cfg(feature = "compute_merge_sort")))]
pub mod merge_sort;
Expand Down
8 changes: 8 additions & 0 deletions src/compute/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ pub fn unary_utf8_boolean<O: Offset, F: Fn(&str) -> bool>(
BooleanArray::from_data(DataType::Boolean, values, validity)
}

/// utf8_apply will apply `Fn(&str) -> String` to every value in Utf8Array.
pub fn utf8_apply<O: Offset, F: Fn(&str) -> String>(f: F, array: &Utf8Array<O>) -> Utf8Array<O> {
let iter = array.values_iter().map(f);

let new = Utf8Array::<O>::from_trusted_len_values_iter(iter);
new.with_validity(array.validity().cloned())
}

// Errors iff the two arrays have a different length.
#[inline]
pub fn check_same_len(lhs: &dyn Array, rhs: &dyn Array) -> Result<()> {
Expand Down
186 changes: 186 additions & 0 deletions tests/it/compute/lower.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
use arrow2::{array::*, compute::lower::*, error::Result};

fn with_nulls_utf8<O: Offset>() -> Result<()> {
let cases = vec![
// identity
(
vec![Some("hello"), None, Some("world")],
vec![Some("hello"), None, Some("world")],
),
// part of input
(
vec![Some("Hello"), None, Some("wOrld")],
vec![Some("hello"), None, Some("world")],
),
// all input
(
vec![Some("HELLO"), None, Some("WORLD")],
vec![Some("hello"), None, Some("world")],
),
// UTF8 characters
(
vec![
None,
Some("السلام عليكم"),
Some("Dobrý den"),
Some("שָׁלוֹם"),
Some("नमस्ते"),
Some("こんにちは"),
Some("안녕하세요"),
Some("你好"),
Some("Olá"),
Some("Здравствуйте"),
Some("Hola"),
],
vec![
None,
Some("السلام عليكم"),
Some("dobrý den"),
Some("שָׁלוֹם"),
Some("नमस्ते"),
Some("こんにちは"),
Some("안녕하세요"),
Some("你好"),
Some("olá"),
Some("здравствуйте"),
Some("hola"),
],
),
];

cases
.into_iter()
.try_for_each::<_, Result<()>>(|(array, expected)| {
let array = Utf8Array::<O>::from(&array);
let result = lower(&array)?;
assert_eq!(array.len(), result.len());

let result = result.as_any().downcast_ref::<Utf8Array<O>>().unwrap();
let expected = Utf8Array::<O>::from(&expected);

assert_eq!(&expected, result);
Ok(())
})?;

Ok(())
}

#[test]
fn with_nulls_string() -> Result<()> {
with_nulls_utf8::<i32>()
}

#[test]
fn with_nulls_large_string() -> Result<()> {
with_nulls_utf8::<i64>()
}

fn without_nulls_utf8<O: Offset>() -> Result<()> {
let cases = vec![
// identity
(vec!["hello", "world"], vec!["hello", "world"]),
// part of input
(vec!["Hello", "wOrld"], vec!["hello", "world"]),
// all input
(vec!["HELLO", "WORLD"], vec!["hello", "world"]),
// UTF8 characters
(
vec![
"السلام عليكم",
"Dobrý den",
"שָׁלוֹם",
"नमस्ते",
"こんにちは",
"안녕하세요",
"你好",
"Olá",
"Здравствуйте",
"Hola",
],
vec![
"السلام عليكم",
"dobrý den",
"שָׁלוֹם",
"नमस्ते",
"こんにちは",
"안녕하세요",
"你好",
"olá",
"здравствуйте",
"hola",
],
),
];

cases
.into_iter()
.try_for_each::<_, Result<()>>(|(array, expected)| {
let array = Utf8Array::<O>::from_slice(&array);
let result = lower(&array)?;
assert_eq!(array.len(), result.len());

let result = result.as_any().downcast_ref::<Utf8Array<O>>().unwrap();
let expected = Utf8Array::<O>::from_slice(&expected);
assert_eq!(&expected, result);
Ok(())
})?;

Ok(())
}

#[test]
fn without_nulls_string() -> Result<()> {
without_nulls_utf8::<i32>()
}

#[test]
fn without_nulls_large_string() -> Result<()> {
without_nulls_utf8::<i64>()
}

#[test]
fn consistency() {
use arrow2::datatypes::DataType::*;
use arrow2::datatypes::TimeUnit;
let datatypes = vec![
Null,
Boolean,
UInt8,
UInt16,
UInt32,
UInt64,
Int8,
Int16,
Int32,
Int64,
Float32,
Float64,
Timestamp(TimeUnit::Second, None),
Timestamp(TimeUnit::Millisecond, None),
Timestamp(TimeUnit::Microsecond, None),
Timestamp(TimeUnit::Nanosecond, None),
Time64(TimeUnit::Microsecond),
Time64(TimeUnit::Nanosecond),
Date32,
Time32(TimeUnit::Second),
Time32(TimeUnit::Millisecond),
Date64,
Utf8,
LargeUtf8,
Binary,
LargeBinary,
Duration(TimeUnit::Second),
Duration(TimeUnit::Millisecond),
Duration(TimeUnit::Microsecond),
Duration(TimeUnit::Nanosecond),
];

datatypes.into_iter().for_each(|d1| {
let array = new_null_array(d1.clone(), 10);
if can_lower(&d1) {
assert!(lower(array.as_ref()).is_ok());
} else {
assert!(lower(array.as_ref()).is_err());
}
});
}
2 changes: 2 additions & 0 deletions tests/it/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ mod length;
mod like;
#[cfg(feature = "compute_limit")]
mod limit;
#[cfg(feature = "compute_lower")]
mod lower;
#[cfg(feature = "compute_merge_sort")]
mod merge_sort;
#[cfg(feature = "compute_partition")]
Expand Down

0 comments on commit 998882e

Please sign in to comment.