diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 20a62d07ffd3..39f600f53ce9 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -140,7 +140,10 @@ def search_msgs(self, room_ids, search_term, keys): list of dicts """ clauses = [] - args = [] + + search_query = search_query = _parse_query(self.database_engine, search_term) + + args = [search_query] # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. @@ -162,7 +165,7 @@ def search_msgs(self, room_ids, search_term, keys): if isinstance(self.database_engine, PostgresEngine): sql = ( "SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id" - " FROM plainto_tsquery('english', ?) as query, event_search" + " FROM to_tsquery('english', ?) as query, event_search" " WHERE vector @@ query" ) elif isinstance(self.database_engine, Sqlite3Engine): @@ -183,7 +186,7 @@ def search_msgs(self, room_ids, search_term, keys): sql += " ORDER BY rank DESC LIMIT 500" results = yield self._execute( - "search_msgs", self.cursor_to_dict, sql, *([search_term] + args) + "search_msgs", self.cursor_to_dict, sql, *args ) results = filter(lambda row: row["room_id"] in room_ids, results) @@ -197,7 +200,7 @@ def search_msgs(self, room_ids, search_term, keys): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_term, events) + highlights = yield self._find_highlights_in_postgres(search_query, events) defer.returnValue({ "results": [ @@ -226,7 +229,10 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None list of dicts """ clauses = [] - args = [search_term] + + search_query = search_query = _parse_query(self.database_engine, search_term) + + args = [search_query] # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. @@ -263,7 +269,7 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None sql = ( "SELECT ts_rank_cd(vector, query) as rank," " origin_server_ts, stream_ordering, room_id, event_id" - " FROM plainto_tsquery('english', ?) as query, event_search" + " FROM to_tsquery('english', ?) as query, event_search" " NATURAL JOIN events" " WHERE vector @@ query AND " ) @@ -313,7 +319,7 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_term, events) + highlights = yield self._find_highlights_in_postgres(search_query, events) defer.returnValue({ "results": [ @@ -330,7 +336,7 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None "highlights": highlights, }) - def _find_highlights_in_postgres(self, search_term, events): + def _find_highlights_in_postgres(self, search_query, events): """Given a list of events and a search term, return a list of words that match from the content of the event. @@ -338,7 +344,7 @@ def _find_highlights_in_postgres(self, search_term, events): highlight the matching parts. Args: - search_term (str) + search_query (str) events (list): A list of events Returns: @@ -370,14 +376,14 @@ def f(txn): while stop_sel in value: stop_sel += ">" - query = "SELECT ts_headline(?, plainto_tsquery('english', ?), %s)" % ( + query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( _to_postgres_options({ "StartSel": start_sel, "StopSel": stop_sel, "MaxFragments": "50", }) ) - txn.execute(query, (value, search_term,)) + txn.execute(query, (value, search_query,)) headline, = txn.fetchall()[0] # Now we need to pick the possible highlights out of the haedline @@ -399,3 +405,22 @@ def _to_postgres_options(options_dict): return "'%s'" % ( ",".join("%s=%s" % (k, v) for k, v in options_dict.items()), ) + + +def _parse_query(database_engine, search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + if isinstance(database_engine, PostgresEngine): + return " & ".join(result + ":*" for result in results) + elif isinstance(database_engine, Sqlite3Engine): + return " & ".join(result + "*" for result in results) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine")