Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Search: Add prefix matching support
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Dec 2, 2015
1 parent 27c5e1b commit 477da77
Showing 1 changed file with 32 additions and 5 deletions.
37 changes: 32 additions & 5 deletions synapse/storage/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,10 @@ def search_msgs(self, room_ids, search_term, keys):
list of dicts
"""
clauses = []
args = []
if isinstance(self.database_engine, PostgresEngine):
args = [_postgres_parse_query(search_term)]
else:
args = [_sqlite_parse_query(search_term)]

# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
Expand All @@ -162,7 +165,7 @@ def search_msgs(self, room_ids, search_term, keys):
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search"
" FROM to_tsquery('english', ?) as query, event_search"
" WHERE vector @@ query"
)
elif isinstance(self.database_engine, Sqlite3Engine):
Expand All @@ -183,7 +186,7 @@ def search_msgs(self, room_ids, search_term, keys):
sql += " ORDER BY rank DESC LIMIT 500"

results = yield self._execute(
"search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
"search_msgs", self.cursor_to_dict, sql, *args
)

results = filter(lambda row: row["room_id"] in room_ids, results)
Expand Down Expand Up @@ -226,7 +229,11 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None
list of dicts
"""
clauses = []
args = [search_term]

if isinstance(self.database_engine, PostgresEngine):
args = [_postgres_parse_query(search_term)]
else:
args = [_sqlite_parse_query(search_term)]

# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
Expand Down Expand Up @@ -263,7 +270,7 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None
sql = (
"SELECT ts_rank_cd(vector, query) as rank,"
" origin_server_ts, stream_ordering, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search"
" FROM to_tsquery('english', ?) as query, event_search"
" NATURAL JOIN events"
" WHERE vector @@ query AND "
)
Expand Down Expand Up @@ -399,3 +406,23 @@ def _to_postgres_options(options_dict):
return "'%s'" % (
",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
)


def _postgres_parse_query(search_term):
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to `to_tsquery(..)` postgres func. We use this so that
we can add prefix matching, which isn't something `plainto_tsquery` supports.
"""
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)

return " & ".join(result + ":*" for result in results)


def _sqlite_parse_query(search_term):
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to sqlite `MATCH`. We use this so that we can do prefix
matching.
"""
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)

return " & ".join(result + "*" for result in results)

0 comments on commit 477da77

Please sign in to comment.