diff --git a/vastdb/_internal.py b/vastdb/_internal.py index 272cec8..ae54e14 100644 --- a/vastdb/_internal.py +++ b/vastdb/_internal.py @@ -34,7 +34,7 @@ Or, ) from ibis.expr.operations.relations import Field -from ibis.expr.operations.strings import StringContains +from ibis.expr.operations.strings import StartsWith, StringContains from ibis.expr.operations.structs import StructField import vast_flatbuf.org.apache.arrow.computeir.flatbuf.BinaryLiteral as fb_binary_lit @@ -103,7 +103,7 @@ from vast_flatbuf.tabular.ListSchemasResponse import ListSchemasResponse as list_schemas from vast_flatbuf.tabular.ListTablesResponse import ListTablesResponse as list_tables -from . import errors +from . import errors, util from .config import BackoffConfig UINT64_MAX = 18446744073709551615 @@ -168,6 +168,7 @@ def serialize(self, builder: 'flatbuffers.builder.Builder'): IsNull: self.build_is_null, Not: self.build_is_not_null, StringContains: self.build_match_substring, + StartsWith: self.build_starts_with, Between: self.build_between, } @@ -207,6 +208,14 @@ def serialize(self, builder: 'flatbuffers.builder.Builder'): elif builder_func == self.build_between: column, lower, upper = inner_op.args literals = (None,) + elif builder_func == self.build_starts_with: + column, prefix = inner_op.args + literals = (None,) + if prefix.value: + lower_bytes, upper_bytes = util.prefix_to_range(prefix.value) + else: + # `col.starts_with('')` is equivalent to `col IS NOT NULL` + builder_func = self.build_is_not_null else: column, arg = inner_op.args if isinstance(arg, tuple): @@ -249,6 +258,9 @@ def serialize(self, builder: 'flatbuffers.builder.Builder'): if builder_func == self.build_between: args_offsets.append(self.build_literal(field=node.field, value=lower.value)) args_offsets.append(self.build_literal(field=node.field, value=upper.value)) + if builder_func == self.build_starts_with: + args_offsets.append(self.build_literal(field=node.field, value=lower_bytes)) + args_offsets.append(self.build_literal(field=node.field, value=upper_bytes)) inner_offsets.append(builder_func(*args_offsets)) @@ -550,6 +562,13 @@ def build_between(self, column: int, lower: int, upper: int): ] return self.build_and(offsets) + def build_starts_with(self, column: int, lower: int, upper: int): + offsets = [ + self.build_greater_equal(column, lower), + self.build_less(column, upper), + ] + return self.build_and(offsets) + class FieldNodesState: def __init__(self) -> None: diff --git a/vastdb/tests/test_tables.py b/vastdb/tests/test_tables.py index 9d77da5..7fc7bf9 100644 --- a/vastdb/tests/test_tables.py +++ b/vastdb/tests/test_tables.py @@ -837,3 +837,48 @@ def test_catalog_snapshots_select(session, clean_bucket_name): rows = t.select().read_all() if not rows: raise NotReady + + +def test_starts_with(session, clean_bucket_name): + columns = pa.schema([ + ('s', pa.utf8()), + ('i', pa.int16()), + ]) + + expected = pa.table(schema=columns, data=[ + ['a', 'ab', 'abc', None, 'abd', 'α', '', 'b'], + [0, 1, 2, 3, 4, 5, 6, 7], + ]) + + with prepare_data(session, clean_bucket_name, 's', 't', expected) as table: + def select(prefix): + res = table.select(predicate=table['s'].startswith(prefix)).read_all() + return res.to_pydict() + + assert select('')['s'] == ['a', 'ab', 'abc', 'abd', 'α', '', 'b'] + assert select('a')['s'] == ['a', 'ab', 'abc', 'abd'] + assert select('b')['s'] == ['b'] + assert select('ab')['s'] == ['ab', 'abc', 'abd'] + assert select('abc')['s'] == ['abc'] + assert select('α')['s'] == ['α'] + + res = table.select(predicate=(table['s'].startswith('ab') | (table['s'].isnull()))).read_all() + assert res.to_pydict()['s'] == ['ab', 'abc', None, 'abd'] + + res = table.select(predicate=(table['s'].startswith('ab') | (table['s'] == 'b'))).read_all() + assert res.to_pydict()['s'] == ['ab', 'abc', 'abd', 'b'] + + res = table.select(predicate=((table['s'] == 'b') | table['s'].startswith('ab'))).read_all() + assert res.to_pydict()['s'] == ['ab', 'abc', 'abd', 'b'] + + res = table.select(predicate=(table['s'].startswith('ab') & (table['s'] != 'abc'))).read_all() + assert res.to_pydict()['s'] == ['ab', 'abd'] + + res = table.select(predicate=((table['s'] != 'abc') & table['s'].startswith('ab'))).read_all() + assert res.to_pydict()['s'] == ['ab', 'abd'] + + res = table.select(predicate=((table['i'] > 3) & table['s'].startswith('ab'))).read_all() + assert res.to_pydict() == {'i': [4], 's': ['abd']} + + res = table.select(predicate=(table['s'].startswith('ab')) & (table['i'] > 3)).read_all() + assert res.to_pydict() == {'i': [4], 's': ['abd']} diff --git a/vastdb/tests/test_util.py b/vastdb/tests/test_util.py index 679de02..660c73f 100644 --- a/vastdb/tests/test_util.py +++ b/vastdb/tests/test_util.py @@ -43,3 +43,16 @@ def _parse(bufs): for buf in bufs: with pa.ipc.open_stream(buf) as reader: yield from reader + + +def test_prefix(): + assert util.prefix_to_range('a') == (b'a', b'b') + assert util.prefix_to_range('abc') == (b'abc', b'abd') + assert util.prefix_to_range('abc\x00') == (b'abc\x00', b'abc\x01') + assert util.prefix_to_range('abc\x7f') == (b'abc\x7f', b'abc\x80') + assert util.prefix_to_range('/a/b/c') == (b'/a/b/c', b'/a/b/d') + assert util.prefix_to_range('/123α') == (b'/123\xce\xb1', b'/123\xce\xb2') + assert util.prefix_to_range('/123αA') == (b'/123\xce\xb1A', b'/123\xce\xb1B') + assert util.prefix_to_range('\U0010ffff') == (b'\xf4\x8f\xbf\xbf', b'\xf4\x8f\xbf\xc0') # max unicode codepoint + with pytest.raises(AssertionError): + util.prefix_to_range('') diff --git a/vastdb/util.py b/vastdb/util.py index f38de45..b702e72 100644 --- a/vastdb/util.py +++ b/vastdb/util.py @@ -157,3 +157,13 @@ def sort_record_batch_if_needed(record_batch, sort_column): return record_batch.sort_by(sort_column) else: return record_batch + + +def prefix_to_range(prefix: str): + """Compute (L, U) such that `s.starts_with(prefix)` is equivalent to `L <= s.encode() < H`.""" + assert prefix, "Empty prefix is not convertible to range predicate" + lower = prefix.encode() + upper = bytearray(lower) + # https://en.wikipedia.org/wiki/UTF-8#Encoding guarantees that the last byte is not 0xFF + upper[-1] = upper[-1] + 1 + return (lower, bytes(upper))