From 477da77b463baa9c2326c763911ecf8b46e1d84b Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 2 Dec 2015 11:38:51 +0000
Subject: [PATCH] Search: Add prefix matching support

---
 synapse/storage/search.py | 37 ++++++++++++++++++++++++++++++++-----
 1 file changed, 32 insertions(+), 5 deletions(-)

diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 20a62d07ffd3..0dfd7b9fb5a6 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -140,7 +140,10 @@ def search_msgs(self, room_ids, search_term, keys):
             list of dicts
         """
         clauses = []
-        args = []
+        if isinstance(self.database_engine, PostgresEngine):
+            args = [_postgres_parse_query(search_term)]
+        else:
+            args = [_sqlite_parse_query(search_term)]
 
         # Make sure we don't explode because the person is in too many rooms.
         # We filter the results below regardless.
@@ -162,7 +165,7 @@ def search_msgs(self, room_ids, search_term, keys):
         if isinstance(self.database_engine, PostgresEngine):
             sql = (
                 "SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id"
-                " FROM plainto_tsquery('english', ?) as query, event_search"
+                " FROM to_tsquery('english', ?) as query, event_search"
                 " WHERE vector @@ query"
             )
         elif isinstance(self.database_engine, Sqlite3Engine):
@@ -183,7 +186,7 @@ def search_msgs(self, room_ids, search_term, keys):
         sql += " ORDER BY rank DESC LIMIT 500"
 
         results = yield self._execute(
-            "search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
+            "search_msgs", self.cursor_to_dict, sql, *args
         )
 
         results = filter(lambda row: row["room_id"] in room_ids, results)
@@ -226,7 +229,11 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None
             list of dicts
         """
         clauses = []
-        args = [search_term]
+
+        if isinstance(self.database_engine, PostgresEngine):
+            args = [_postgres_parse_query(search_term)]
+        else:
+            args = [_sqlite_parse_query(search_term)]
 
         # Make sure we don't explode because the person is in too many rooms.
         # We filter the results below regardless.
@@ -263,7 +270,7 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None
             sql = (
                 "SELECT ts_rank_cd(vector, query) as rank,"
                 " origin_server_ts, stream_ordering, room_id, event_id"
-                " FROM plainto_tsquery('english', ?) as query, event_search"
+                " FROM to_tsquery('english', ?) as query, event_search"
                 " NATURAL JOIN events"
                 " WHERE vector @@ query AND "
             )
@@ -399,3 +406,23 @@ def _to_postgres_options(options_dict):
     return "'%s'" % (
         ",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
     )
+
+
+def _postgres_parse_query(search_term):
+    """Takes a plain unicode string from the user and converts it into a form
+    that can be passed to `to_tsquery(..)` postgres func. We use this so that
+    we can add prefix matching, which isn't something `plainto_tsquery` supports.
+    """
+    results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+    return " & ".join(result + ":*" for result in results)
+
+
+def _sqlite_parse_query(search_term):
+    """Takes a plain unicode string from the user and converts it into a form
+    that can be passed to sqlite `MATCH`. We use this so that we can do prefix
+    matching.
+    """
+    results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+    return " & ".join(result + "*" for result in results)