Skip to content

Commit

Permalink
Support 'starts_with' pushdown
Browse files Browse the repository at this point in the history
  • Loading branch information
rz-vastdata committed Aug 28, 2024
1 parent 7de597a commit 63a3161
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 2 deletions.
23 changes: 21 additions & 2 deletions vastdb/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
Or,
)
from ibis.expr.operations.relations import Field
from ibis.expr.operations.strings import StringContains
from ibis.expr.operations.strings import StartsWith, StringContains
from ibis.expr.operations.structs import StructField

import vast_flatbuf.org.apache.arrow.computeir.flatbuf.BinaryLiteral as fb_binary_lit
Expand Down Expand Up @@ -103,7 +103,7 @@
from vast_flatbuf.tabular.ListSchemasResponse import ListSchemasResponse as list_schemas
from vast_flatbuf.tabular.ListTablesResponse import ListTablesResponse as list_tables

from . import errors
from . import errors, util
from .config import BackoffConfig

UINT64_MAX = 18446744073709551615
Expand Down Expand Up @@ -168,6 +168,7 @@ def serialize(self, builder: 'flatbuffers.builder.Builder'):
IsNull: self.build_is_null,
Not: self.build_is_not_null,
StringContains: self.build_match_substring,
StartsWith: self.build_starts_with,
Between: self.build_between,
}

Expand Down Expand Up @@ -207,6 +208,14 @@ def serialize(self, builder: 'flatbuffers.builder.Builder'):
elif builder_func == self.build_between:
column, lower, upper = inner_op.args
literals = (None,)
elif builder_func == self.build_starts_with:
column, prefix = inner_op.args
literals = (None,)
if prefix.value:
lower_bytes, upper_bytes = util.prefix_to_range(prefix.value)
else:
# `col.starts_with('')` is equivalent to `col IS NOT NULL`
builder_func = self.build_is_not_null
else:
column, arg = inner_op.args
if isinstance(arg, tuple):
Expand Down Expand Up @@ -249,6 +258,9 @@ def serialize(self, builder: 'flatbuffers.builder.Builder'):
if builder_func == self.build_between:
args_offsets.append(self.build_literal(field=node.field, value=lower.value))
args_offsets.append(self.build_literal(field=node.field, value=upper.value))
if builder_func == self.build_starts_with:
args_offsets.append(self.build_literal(field=node.field, value=lower_bytes))
args_offsets.append(self.build_literal(field=node.field, value=upper_bytes))

inner_offsets.append(builder_func(*args_offsets))

Expand Down Expand Up @@ -550,6 +562,13 @@ def build_between(self, column: int, lower: int, upper: int):
]
return self.build_and(offsets)

def build_starts_with(self, column: int, lower: int, upper: int):
    """Build the conjunction `column >= lower AND column < upper`.

    The three arguments are passed straight through to the comparison
    builders; the result is whatever `build_and` returns for the two
    comparison nodes.
    """
    at_least_lower = self.build_greater_equal(column, lower)
    below_upper = self.build_less(column, upper)
    return self.build_and([at_least_lower, below_upper])


class FieldNodesState:
def __init__(self) -> None:
Expand Down
45 changes: 45 additions & 0 deletions vastdb/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,3 +837,48 @@ def test_catalog_snapshots_select(session, clean_bucket_name):
rows = t.select().read_all()
if not rows:
raise NotReady


def test_starts_with(session, clean_bucket_name):
    """Exercise `startswith` predicates in `table.select`, alone and combined with other predicates."""
    schema = pa.schema([
        ('s', pa.utf8()),
        ('i', pa.int16()),
    ])
    strings = ['a', 'ab', 'abc', None, 'abd', 'α', '', 'b']
    numbers = [0, 1, 2, 3, 4, 5, 6, 7]
    expected = pa.table(schema=schema, data=[strings, numbers])

    with prepare_data(session, clean_bucket_name, 's', 't', expected) as table:
        def run(predicate):
            # Execute the query and return the result as a plain dict of columns.
            return table.select(predicate=predicate).read_all().to_pydict()

        # Plain prefix matching; the None row is never returned, even for the empty prefix.
        prefix_cases = [
            ('', ['a', 'ab', 'abc', 'abd', 'α', '', 'b']),
            ('a', ['a', 'ab', 'abc', 'abd']),
            ('b', ['b']),
            ('ab', ['ab', 'abc', 'abd']),
            ('abc', ['abc']),
            ('α', ['α']),
        ]
        for prefix, matching in prefix_cases:
            assert run(table['s'].startswith(prefix))['s'] == matching

        # Disjunctions, with the prefix predicate on either side.
        found = run(table['s'].startswith('ab') | (table['s'].isnull()))
        assert found['s'] == ['ab', 'abc', None, 'abd']

        found = run(table['s'].startswith('ab') | (table['s'] == 'b'))
        assert found['s'] == ['ab', 'abc', 'abd', 'b']

        found = run((table['s'] == 'b') | table['s'].startswith('ab'))
        assert found['s'] == ['ab', 'abc', 'abd', 'b']

        # Conjunctions on the same column, with the prefix predicate on either side.
        found = run(table['s'].startswith('ab') & (table['s'] != 'abc'))
        assert found['s'] == ['ab', 'abd']

        found = run((table['s'] != 'abc') & table['s'].startswith('ab'))
        assert found['s'] == ['ab', 'abd']

        # Conjunctions with a predicate on a different column.
        found = run((table['i'] > 3) & table['s'].startswith('ab'))
        assert found == {'i': [4], 's': ['abd']}

        found = run((table['s'].startswith('ab')) & (table['i'] > 3))
        assert found == {'i': [4], 's': ['abd']}
13 changes: 13 additions & 0 deletions vastdb/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,16 @@ def _parse(bufs):
for buf in bufs:
with pa.ipc.open_stream(buf) as reader:
yield from reader


def test_prefix():
    """Check `util.prefix_to_range` on ASCII, multi-byte and boundary prefixes."""
    cases = [
        ('a', (b'a', b'b')),
        ('abc', (b'abc', b'abd')),
        ('abc\x00', (b'abc\x00', b'abc\x01')),
        ('abc\x7f', (b'abc\x7f', b'abc\x80')),
        ('/a/b/c', (b'/a/b/c', b'/a/b/d')),
        ('/123α', (b'/123\xce\xb1', b'/123\xce\xb2')),
        ('/123αA', (b'/123\xce\xb1A', b'/123\xce\xb1B')),
        ('\U0010ffff', (b'\xf4\x8f\xbf\xbf', b'\xf4\x8f\xbf\xc0')),  # max unicode codepoint
    ]
    for prefix, expected_range in cases:
        assert util.prefix_to_range(prefix) == expected_range

    # An empty prefix has no range representation and must be rejected.
    with pytest.raises(AssertionError):
        util.prefix_to_range('')
10 changes: 10 additions & 0 deletions vastdb/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,13 @@ def sort_record_batch_if_needed(record_batch, sort_column):
return record_batch.sort_by(sort_column)
else:
return record_batch


def prefix_to_range(prefix: str):
    """Convert a non-empty string prefix into a half-open byte range.

    Compute `(lower, upper)` such that `s.startswith(prefix)` is equivalent to
    `lower <= s.encode() < upper` under bytewise comparison.

    Args:
        prefix: non-empty prefix string.

    Returns:
        A `(lower, upper)` tuple of `bytes`: `lower` is the UTF-8 encoding of
        `prefix`, and `upper` is `lower` with its last byte incremented by one
        (so `upper` itself is not necessarily valid UTF-8).

    Raises:
        AssertionError: if `prefix` is empty.
    """
    if not prefix:
        # Raised explicitly instead of using `assert`, so the check is not stripped
        # under `python -O` (which would otherwise surface as an IndexError below).
        raise AssertionError("Empty prefix is not convertible to range predicate")
    lower = prefix.encode("utf-8")
    upper = bytearray(lower)
    # https://en.wikipedia.org/wiki/UTF-8#Encoding guarantees that the last byte is not 0xFF,
    # so the increment cannot overflow the byte.
    upper[-1] += 1
    return (lower, bytes(upper))

0 comments on commit 63a3161

Please sign in to comment.