Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor: refactor trim to clean up duplicated code #8434

Merged
merged 4 commits into from
Dec 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 69 additions & 100 deletions datafusion/physical-expr/src/string_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ use datafusion_common::{
};
use datafusion_common::{internal_err, DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use std::iter;
use std::sync::Arc;
use std::{
fmt::{Display, Formatter},
iter,
};
use uuid::Uuid;

/// applies a unary expression to `args[0]` that is expected to be downcastable to
Expand Down Expand Up @@ -133,53 +136,6 @@ pub fn ascii<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}

/// Strips, from both ends of each string, the longest run made up solely of
/// characters listed in `characters` (a single space when `characters` is not given).
/// btrim('xyxtrimyyx', 'xyz') = 'trim'
///
/// `args[0]` holds the strings and the optional `args[1]` holds, per row, the
/// characters to strip. A row is null when either of its inputs is null.
/// Any other argument count yields an internal error.
pub fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
    // Named `other` because the error message below captures it by name.
    let other = args.len();

    if other == 1 {
        let strings = as_generic_string_array::<T>(&args[0])?;
        let trimmed: GenericStringArray<T> = strings
            .iter()
            .map(|value| value.map(|text| text.trim_matches(' ')))
            .collect();
        return Ok(Arc::new(trimmed) as ArrayRef);
    }

    if other == 2 {
        let strings = as_generic_string_array::<T>(&args[0])?;
        let trim_sets = as_generic_string_array::<T>(&args[1])?;
        let trimmed: GenericStringArray<T> = strings
            .iter()
            .zip(trim_sets.iter())
            .map(|(value, set)| {
                // Either side being null makes the whole row null.
                let (text, set) = (value?, set?);
                let set: Vec<char> = set.chars().collect();
                Some(text.trim_matches(set.as_slice()))
            })
            .collect();
        return Ok(Arc::new(trimmed) as ArrayRef);
    }

    internal_err!(
        "btrim was called with {other} arguments. It requires at least 1 and at most 2."
    )
}

/// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character.
/// chr(65) = 'A'
pub fn chr(args: &[ArrayRef]) -> Result<ArrayRef> {
Expand Down Expand Up @@ -346,44 +302,95 @@ pub fn lower(args: &[ColumnarValue]) -> Result<ColumnarValue> {
handle(args, |string| string.to_ascii_lowercase(), "lower")
}

/// Removes the longest string containing only characters in characters (a space by default) from the start of string.
/// ltrim('zzzytest', 'xyz') = 'test'
pub fn ltrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
/// Selects which end(s) of a string a trim operation strips characters from.
enum TrimType {
    /// Strip from the start only.
    Left,
    /// Strip from the end only.
    Right,
    /// Strip from both the start and the end.
    Both,
}

impl Display for TrimType {
    /// Writes the name of the corresponding trim function
    /// (`ltrim`, `rtrim` or `btrim`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TrimType::Left => "ltrim",
            TrimType::Right => "rtrim",
            TrimType::Both => "btrim",
        };
        f.write_str(name)
    }
}

/// Shared implementation of `btrim`, `ltrim` and `rtrim`.
///
/// * `args[0]` - string array whose values are trimmed.
/// * `args[1]` - optional string array; each row lists the characters to strip
///   from the matching row of `args[0]`. When omitted, only a space (`" "`) is
///   stripped.
/// * `trim_type` - which end(s) to trim; its `Display` form names the function
///   in the arity error.
///
/// A row of the result is null when the string, or its character list, is null.
///
/// # Errors
///
/// Returns an internal error unless `args` holds exactly one or two arrays, and
/// propagates any error from `as_generic_string_array`.
fn general_trim<T: OffsetSizeTrait>(
    args: &[ArrayRef],
    trim_type: TrimType,
) -> Result<ArrayRef> {
    // The closures capture nothing, so all three arms unify to one function pointer.
    let func = match trim_type {
        TrimType::Left => |input, pattern: &str| {
            let pattern = pattern.chars().collect::<Vec<char>>();
            str::trim_start_matches::<&[char]>(input, pattern.as_ref())
        },
        TrimType::Right => |input, pattern: &str| {
            let pattern = pattern.chars().collect::<Vec<char>>();
            str::trim_end_matches::<&[char]>(input, pattern.as_ref())
        },
        TrimType::Both => |input, pattern: &str| {
            let pattern = pattern.chars().collect::<Vec<char>>();
            str::trim_end_matches::<&[char]>(
                str::trim_start_matches::<&[char]>(input, pattern.as_ref()),
                pattern.as_ref(),
            )
        },
    };

    // Destructure the slice instead of indexing `args[0]` before the arity is
    // known: an empty `args` now reaches the error arm rather than panicking.
    match args {
        [strings] => {
            let string_array = as_generic_string_array::<T>(strings)?;

            let result = string_array
                .iter()
                .map(|string| string.map(|string: &str| func(string, " ")))
                .collect::<GenericStringArray<T>>();

            Ok(Arc::new(result) as ArrayRef)
        }
        [strings, characters] => {
            let string_array = as_generic_string_array::<T>(strings)?;
            let characters_array = as_generic_string_array::<T>(characters)?;

            let result = string_array
                .iter()
                .zip(characters_array.iter())
                .map(|(string, characters)| match (string, characters) {
                    (Some(string), Some(characters)) => Some(func(string, characters)),
                    _ => None,
                })
                .collect::<GenericStringArray<T>>();

            Ok(Arc::new(result) as ArrayRef)
        }
        _ => {
            // Bound to `other` because the message captures it by name.
            let other = args.len();
            internal_err!(
                "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2."
            )
        }
    }
}

/// Removes, from both the start and the end of each string in `args[0]`, the longest run consisting only of characters listed in `args[1]`. If `args[1]` is not supplied, only the space character (' ') is removed.
/// btrim('xyxtrimyyx', 'xyz') = 'trim'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might help to explain what characters is here -- specifically args[1] if present

pub fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
// `TrimType::Both` makes `general_trim` strip at the start and at the end; its
// `Display` form ("btrim") is the name reported in the arity error.
general_trim::<T>(args, TrimType::Both)
}

/// Removes, from the start of each string in `args[0]`, the longest run consisting only of
/// characters listed in `args[1]`. If `args[1]` is not supplied, only the space character (' ')
/// is removed.
/// ltrim('zzzytest', 'xyz') = 'test'
pub fn ltrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
// `TrimType::Left` makes `general_trim` strip at the start only; its `Display`
// form ("ltrim") is the name reported in the arity error.
general_trim::<T>(args, TrimType::Left)
}

/// Removes, from the end of each string in `args[0]`, the longest run consisting only of
/// characters listed in `args[1]`. If `args[1]` is not supplied, only the space character (' ')
/// is removed.
/// rtrim('testxxzx', 'xyz') = 'test'
pub fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
// `TrimType::Right` makes `general_trim` strip at the end only; its `Display`
// form ("rtrim") is the name reported in the arity error.
general_trim::<T>(args, TrimType::Right)
}

/// Repeats string the specified number of times.
/// repeat('Pg', 4) = 'PgPgPgPg'
pub fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Expand Down Expand Up @@ -422,44 +429,6 @@ pub fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}

/// Strips, from the end of each string, the longest run made up solely of
/// characters listed in `characters` (a single space when `characters` is not given).
/// rtrim('testxxzx', 'xyz') = 'test'
///
/// `args[0]` holds the strings and the optional `args[1]` holds, per row, the
/// characters to strip. A row is null when either of its inputs is null.
/// Any other argument count yields an internal error.
pub fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
    // Named `other` because the error message below captures it by name.
    let other = args.len();

    if other == 1 {
        let strings = as_generic_string_array::<T>(&args[0])?;
        let trimmed: GenericStringArray<T> = strings
            .iter()
            .map(|value| value.map(|text| text.trim_end_matches(' ')))
            .collect();
        return Ok(Arc::new(trimmed) as ArrayRef);
    }

    if other == 2 {
        let strings = as_generic_string_array::<T>(&args[0])?;
        let trim_sets = as_generic_string_array::<T>(&args[1])?;
        let trimmed: GenericStringArray<T> = strings
            .iter()
            .zip(trim_sets.iter())
            .map(|(value, set)| {
                // `Option::zip` is `None` unless both the string and its set are present.
                value.zip(set).map(|(text, set)| {
                    let set: Vec<char> = set.chars().collect();
                    text.trim_end_matches(set.as_slice())
                })
            })
            .collect();
        return Ok(Arc::new(trimmed) as ArrayRef);
    }

    internal_err!(
        "rtrim was called with {other} arguments. It requires at least 1 and at most 2."
    )
}

/// Splits string at occurrences of delimiter and returns the n'th field (counting from one).
/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def'
pub fn split_part<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Expand Down