Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: substr() on StringView column's behavior is inconsistent with the old version #12383

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 63 additions & 56 deletions datafusion/functions/src/unicode/substr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,37 @@ pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

// Return the exact byte index for [start, end), set count to -1 to ignore count
fn get_true_start_end(input: &str, start: usize, count: i64) -> (usize, usize) {
// Convert the given `start` and `count` to valid byte indices within `input` string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😍 for the comments

// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)`
// `start` is 1-based, if `count` is not provided count to the end of the string
// Input indices are character-based, and return values are byte indices
// The input bounds can be outside string bounds, this function will return
// the intersection between input bounds and valid string bounds
//
// * Example
// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx]
// `get_true_start_end('Hi🌏', 1, None) -> (0, 6)`
// `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)`
// `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)`
fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, usize) {
let start = start - 1;
let end = match count {
Some(count) => start + count as i64,
None => input.len() as i64,
};
let count_to_end = count.is_some();

let start = start.clamp(0, input.len() as i64) as usize;
let end = end.clamp(0, input.len() as i64) as usize;
let count = end - start;

let (mut st, mut ed) = (input.len(), input.len());
let mut start_counting = false;
let mut cnt = 0;
for (char_cnt, (byte_cnt, _)) in input.char_indices().enumerate() {
if char_cnt == start {
st = byte_cnt;
if count != -1 {
if count_to_end {
start_counting = true;
} else {
break;
Expand All @@ -153,20 +175,15 @@ fn make_and_append_view(
start: u32,
) {
let substr_len = substr.len();
if substr_len == 0 {
null_builder.append_null();
views_buffer.push(0);
let sub_view = if substr_len > 12 {
let view = ByteView::from(*raw);
make_view(substr.as_bytes(), view.buffer_index, view.offset + start)
} else {
let sub_view = if substr_len > 12 {
let view = ByteView::from(*raw);
make_view(substr.as_bytes(), view.buffer_index, view.offset + start)
} else {
// inline value does not need block id or offset
make_view(substr.as_bytes(), 0, 0)
};
views_buffer.push(sub_view);
null_builder.append_non_null();
}
// inline value does not need block id or offset
make_view(substr.as_bytes(), 0, 0)
Comment on lines +178 to +183
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems the inline check here is duplicated with the one within make_view?
I guess inline len(12) may be the internal logic in string view, and we should not expose and use it in datafusion?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say the point we are checking size here is to help with the operation of "modifying the views directly". But I agree that maybe we could add more API upstream in arrow to hide the logic.

Copy link
Contributor

@Rachelint Rachelint Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say the point we are checking size here is to help with the operation of "modifying the views directly". But I agree that maybe we could add more API upstream in arrow to hide the logic.

Maybe we can just let it run as same as the > 12 branch? ByteView::from(*raw) seems cheap, maybe it is not more expensive than a if else? I am not sure.

Copy link
Contributor

@Kev1n8 Kev1n8 Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see what you mean. I think ByteVIew is not equal to inlined_view. The former one indicates the str whose size > 12, while the inlined str is represented by a simple u128.

Maybe we can just let it run as same as the > 12 branch?

It might actually work since from<u128> for ByteView would load meaningless(which are actually the str) prefix, buffer_index, offset, which does no harm. But I think it would be ambiguous. It's a good suggestion tho.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see what you mean. I think ByteVIew is not equal to inlined_view. The former one indicates the str whose size > 12, while the inlined str is represented by a simple u128.

Maybe we can just let it run as same as the > 12 branch?

It might actually work since from<u128> for ByteView would load meaningless prefix, buffer_index, offset, which does no harm. But I think it would be ambiguous. It's a good suggestion tho.

Ok, it is indeed more clear to distinguish inlined and not inlned.

};
views_buffer.push(sub_view);
null_builder.append_non_null();
}

// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
Expand All @@ -180,32 +197,26 @@ fn string_view_substr(

let start_array = as_int64_array(&args[0])?;

// In either case of `substr(s, i)` or `substr(s, i, cnt)`
// If any of input argument is `NULL`, the result is `NULL`
match args.len() {
1 => {
for (idx, (raw, start)) in string_view_array
.views()
for ((str_opt, raw_view), start_opt) in string_view_array
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old implementation will treat empty string in view as NULL (within make_and_append_view()), so it handled NULL case correctly but handled the empty string case wrong.
Since it's not possible to use view to check if the current element is null, so here we also have to iterate on the string array to check whether each element is NULL

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another potential thing that might be faster is to iterate on the null array directly (rather than the Option<&str>) so the code doesn't have to figure out where the &str comes from

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for correcting my mistakes. I really appreciate it! I'll be more careful with my code next time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another potential thing that might be faster is to iterate on the null array directly (rather than the Option<&str>) so the code doesn't have to figure out where the &str comes from

Great point, there are also other potential optimizations like specializing not null columns, scalar literal arguments, and ASCII case. So I prefer only to make a simple and correct implementation in this PR, and evaluate those optimizations as a future task

.iter()
.zip(string_view_array.views().iter())
.zip(start_array.iter())
.enumerate()
{
if let Some(start) = start {
let start = (start - 1).max(0) as usize;

// Safety:
// idx is always smaller or equal to string_view_array.views.len()
unsafe {
let str = string_view_array.value_unchecked(idx);
let (start, end) = get_true_start_end(str, start, -1);
let substr = &str[start..end];
if let (Some(str), Some(start)) = (str_opt, start_opt) {
let (start, end) = get_true_start_end(str, start, None);
let substr = &str[start..end];

make_and_append_view(
&mut views_buf,
&mut null_builder,
raw,
substr,
start as u32,
);
}
make_and_append_view(
&mut views_buf,
&mut null_builder,
raw_view,
substr,
start as u32,
);
} else {
null_builder.append_null();
views_buf.push(0);
Expand All @@ -214,35 +225,31 @@ fn string_view_substr(
}
2 => {
let count_array = as_int64_array(&args[1])?;
for (idx, ((raw, start), count)) in string_view_array
.views()
for (((str_opt, raw_view), start_opt), count_opt) in string_view_array
.iter()
.zip(string_view_array.views().iter())
.zip(start_array.iter())
.zip(count_array.iter())
.enumerate()
{
if let (Some(start), Some(count)) = (start, count) {
let start = (start - 1).max(0) as usize;
if let (Some(str), Some(start), Some(count)) =
(str_opt, start_opt, count_opt)
{
if count < 0 {
return exec_err!(
"negative substring length not allowed: substr(<str>, {start}, {count})"
);
} else {
// Safety:
// idx is always smaller or equal to string_view_array.views.len()
unsafe {
let str = string_view_array.value_unchecked(idx);
let (start, end) = get_true_start_end(str, start, count);
let substr = &str[start..end];

make_and_append_view(
&mut views_buf,
&mut null_builder,
raw,
substr,
start as u32,
);
}
let (start, end) =
get_true_start_end(str, start, Some(count as u64));
let substr = &str[start..end];

make_and_append_view(
&mut views_buf,
&mut null_builder,
raw_view,
substr,
start as u32,
);
}
} else {
null_builder.append_null();
Expand Down
124 changes: 115 additions & 9 deletions datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -472,11 +472,6 @@ SELECT substr('alphabet', 30)
----
(empty)

query T
SELECT substr('alphabet', CAST(NULL AS int))
----
NULL

query T
SELECT substr('alphabet', 3, 2)
----
Expand All @@ -487,22 +482,133 @@ SELECT substr('alphabet', 3, 20)
----
phabet

# test range ouside of string length
query TTTTTTTTTTTT
SELECT
substr('hi🌏', 1, 3),
substr('hi🌏', 1, 4),
substr('hi🌏', 1, 100),
substr('hi🌏', 0, 1),
substr('hi🌏', 0, 2),
substr('hi🌏', 0, 4),
substr('hi🌏', 0, 5),
substr('hi🌏', -10, 100),
substr('hi🌏', -10, 12),
substr('hi🌏', -10, 5),
substr('hi🌏', 10, 0),
substr('hi🌏', 10, 10);
----
hi🌏 hi🌏 hi🌏 (empty) h hi🌏 hi🌏 hi🌏 h (empty) (empty) (empty)

query TTTTTTTTTTTT
SELECT
substr('', 1, 3),
substr('', 1, 4),
substr('', 1, 100),
substr('', 0, 1),
substr('', 0, 2),
substr('', 0, 4),
substr('', 0, 5),
substr('', -10, 100),
substr('', -10, 12),
substr('', -10, 5),
substr('', 10, 0),
substr('', 10, 10);
----
(empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty)

# Nulls
query TTTTTTTTTT
SELECT
substr('alphabet', NULL),
substr(NULL, 1),
substr(NULL, NULL),
substr('alphabet', CAST(NULL AS int), -20),
substr('alphabet', 3, CAST(NULL AS int)),
substr(NULL, 3, -4),
substr(NULL, NULL, 4),
substr(NULL, 1, NULL),
substr('', NULL, NULL),
substr(NULL, NULL, NULL);
----
NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL

query T
SELECT substr('alphabet', CAST(NULL AS int), 20)
SELECT substr('Hello🌏世界', 5)
----
NULL
o🌏世界

query T
SELECT substr('alphabet', 3, CAST(NULL AS int))
SELECT substr('Hello🌏世界', 5, 3)
----
NULL
o🌏世

statement ok
create table test_substr (
c1 VARCHAR
) as values ('foo'), ('hello🌏世界'), ('💩'), ('ThisIsAVeryLongASCIIString'), (''), (NULL);

statement ok
create table test_substr_stringview as
select c1 as c1, arrow_cast(c1, 'Utf8View') as c1_view from test_substr;

# `substr()` on `StringViewArray`'s implementation operates directly on view's
# logical pointers, so check it's consistent with `StringArray`
query BBBBBBBBBBBBBB
select
substr(c1, 1) = substr(c1_view, 1),
substr(c1, 3) = substr(c1_view, 3),
substr(c1, 100) = substr(c1_view, 100),
substr(c1, -1) = substr(c1_view, -1),
substr(c1, 0, 0) = substr(c1_view, 0, 0),
substr(c1, -1, 2) = substr(c1_view, -1, 2),
substr(c1, -2, 10) = substr(c1_view, -2, 10),
substr(c1, -100, 200) = substr(c1_view, -100, 200),
substr(c1, -10, 10) = substr(c1_view, -10, 10),
substr(c1, -100, 10) = substr(c1_view, -100, 10),
substr(c1, 1, 100) = substr(c1_view, 1, 100),
substr(c1, 5, 3) = substr(c1_view, 5, 3),
substr(c1, 100, 200) = substr(c1_view, 100, 200),
substr(c1, 8, 0) = substr(c1_view, 8, 0)
from test_substr_stringview;
----
true true true true true true true true true true true true true true
true true true true true true true true true true true true true true
true true true true true true true true true true true true true true
true true true true true true true true true true true true true true
true true true true true true true true true true true true true true
NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL

# Check for non-ASCII strings
query TT
select substr(c1_view, 1), substr(c1_view, 5,3) from test_substr_stringview;
----
foo (empty)
hello🌏世界 o🌏世
💩 (empty)
ThisIsAVeryLongASCIIString IsA
(empty) (empty)
NULL NULL

statement ok
drop table test_substr;

statement ok
drop table test_substr_stringview;


statement error The SUBSTR function can only accept strings, but got Int64.
SELECT substr(1, 3)

statement error The SUBSTR function can only accept strings, but got Int64.
SELECT substr(1, 3, 4)

statement error Execution error: negative substring length not allowed
select substr(arrow_cast('foo', 'Utf8View'), 1, -1);

statement error Execution error: negative substring length not allowed
select substr('', 1, -1);

query T
SELECT translate('12345', '143', 'ax')
----
Expand Down