Skip to content

Commit

Permalink
Minor: refactor trim to clean up duplicated code (apache#8434)
Browse files Browse the repository at this point in the history
* refactor trim

* add fmt for TrimType

* fix closure

* update comment
  • Loading branch information
Weijun-H authored Dec 9, 2023
1 parent 182a37e commit 62ee8fb
Showing 1 changed file with 69 additions and 100 deletions.
169 changes: 69 additions & 100 deletions datafusion/physical-expr/src/string_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ use datafusion_common::{
};
use datafusion_common::{internal_err, DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use std::iter;
use std::sync::Arc;
use std::{
fmt::{Display, Formatter},
iter,
};
use uuid::Uuid;

/// applies a unary expression to `args[0]` that is expected to be downcastable to
Expand Down Expand Up @@ -133,53 +136,6 @@ pub fn ascii<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}

/// Removes the longest string containing only characters in characters (a space by default) from the start and end of string.
/// btrim('xyxtrimyyx', 'xyz') = 'trim'
pub fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
1 => {
let string_array = as_generic_string_array::<T>(&args[0])?;

let result = string_array
.iter()
.map(|string| {
string.map(|string: &str| {
string.trim_start_matches(' ').trim_end_matches(' ')
})
})
.collect::<GenericStringArray<T>>();

Ok(Arc::new(result) as ArrayRef)
}
2 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let characters_array = as_generic_string_array::<T>(&args[1])?;

let result = string_array
.iter()
.zip(characters_array.iter())
.map(|(string, characters)| match (string, characters) {
(None, _) => None,
(_, None) => None,
(Some(string), Some(characters)) => {
let chars: Vec<char> = characters.chars().collect();
Some(
string
.trim_start_matches(&chars[..])
.trim_end_matches(&chars[..]),
)
}
})
.collect::<GenericStringArray<T>>();

Ok(Arc::new(result) as ArrayRef)
}
other => internal_err!(
"btrim was called with {other} arguments. It requires at least 1 and at most 2."
),
}
}

/// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character.
/// chr(65) = 'A'
pub fn chr(args: &[ArrayRef]) -> Result<ArrayRef> {
Expand Down Expand Up @@ -346,44 +302,95 @@ pub fn lower(args: &[ColumnarValue]) -> Result<ColumnarValue> {
handle(args, |string| string.to_ascii_lowercase(), "lower")
}

/// Removes the longest string containing only characters in characters (a space by default) from the start of string.
/// ltrim('zzzytest', 'xyz') = 'test'
pub fn ltrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
enum TrimType {
Left,
Right,
Both,
}

impl Display for TrimType {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
TrimType::Left => write!(f, "ltrim"),
TrimType::Right => write!(f, "rtrim"),
TrimType::Both => write!(f, "btrim"),
}
}
}

fn general_trim<T: OffsetSizeTrait>(
args: &[ArrayRef],
trim_type: TrimType,
) -> Result<ArrayRef> {
let func = match trim_type {
TrimType::Left => |input, pattern: &str| {
let pattern = pattern.chars().collect::<Vec<char>>();
str::trim_start_matches::<&[char]>(input, pattern.as_ref())
},
TrimType::Right => |input, pattern: &str| {
let pattern = pattern.chars().collect::<Vec<char>>();
str::trim_end_matches::<&[char]>(input, pattern.as_ref())
},
TrimType::Both => |input, pattern: &str| {
let pattern = pattern.chars().collect::<Vec<char>>();
str::trim_end_matches::<&[char]>(
str::trim_start_matches::<&[char]>(input, pattern.as_ref()),
pattern.as_ref(),
)
},
};

let string_array = as_generic_string_array::<T>(&args[0])?;

match args.len() {
1 => {
let string_array = as_generic_string_array::<T>(&args[0])?;

let result = string_array
.iter()
.map(|string| string.map(|string: &str| string.trim_start_matches(' ')))
.map(|string| string.map(|string: &str| func(string, " ")))
.collect::<GenericStringArray<T>>();

Ok(Arc::new(result) as ArrayRef)
}
2 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let characters_array = as_generic_string_array::<T>(&args[1])?;

let result = string_array
.iter()
.zip(characters_array.iter())
.map(|(string, characters)| match (string, characters) {
(Some(string), Some(characters)) => {
let chars: Vec<char> = characters.chars().collect();
Some(string.trim_start_matches(&chars[..]))
}
(Some(string), Some(characters)) => Some(func(string, characters)),
_ => None,
})
.collect::<GenericStringArray<T>>();

Ok(Arc::new(result) as ArrayRef)
}
other => internal_err!(
"ltrim was called with {other} arguments. It requires at least 1 and at most 2."
),
other => {
internal_err!(
"{trim_type} was called with {other} arguments. It requires at least 1 and at most 2."
)
}
}
}

/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed.
/// btrim('xyxtrimyyx', 'xyz') = 'trim'
pub fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
general_trim::<T>(args, TrimType::Both)
}

/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed.
/// ltrim('zzzytest', 'xyz') = 'test'
pub fn ltrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
general_trim::<T>(args, TrimType::Left)
}

/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed.
/// rtrim('testxxzx', 'xyz') = 'test'
pub fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
general_trim::<T>(args, TrimType::Right)
}

/// Repeats string the specified number of times.
/// repeat('Pg', 4) = 'PgPgPgPg'
pub fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Expand Down Expand Up @@ -422,44 +429,6 @@ pub fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}

/// Removes the longest string containing only characters in characters (a space by default) from the end of string.
/// rtrim('testxxzx', 'xyz') = 'test'
pub fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
1 => {
let string_array = as_generic_string_array::<T>(&args[0])?;

let result = string_array
.iter()
.map(|string| string.map(|string: &str| string.trim_end_matches(' ')))
.collect::<GenericStringArray<T>>();

Ok(Arc::new(result) as ArrayRef)
}
2 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let characters_array = as_generic_string_array::<T>(&args[1])?;

let result = string_array
.iter()
.zip(characters_array.iter())
.map(|(string, characters)| match (string, characters) {
(Some(string), Some(characters)) => {
let chars: Vec<char> = characters.chars().collect();
Some(string.trim_end_matches(&chars[..]))
}
_ => None,
})
.collect::<GenericStringArray<T>>();

Ok(Arc::new(result) as ArrayRef)
}
other => internal_err!(
"rtrim was called with {other} arguments. It requires at least 1 and at most 2."
),
}
}

/// Splits string at occurrences of delimiter and returns the n'th field (counting from one).
/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def'
pub fn split_part<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Expand Down

0 comments on commit 62ee8fb

Please sign in to comment.