Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Search: Add prefix matching support #412

Merged
merged 4 commits into from
Dec 2, 2015
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions synapse/storage/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,10 @@ def search_msgs(self, room_ids, search_term, keys):
list of dicts
"""
clauses = []
args = []

search_query = search_query = _parse_query(self.database_engine, search_term)

args = [search_query]

# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
Expand All @@ -162,7 +165,7 @@ def search_msgs(self, room_ids, search_term, keys):
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search"
" FROM to_tsquery('english', ?) as query, event_search"
" WHERE vector @@ query"
)
elif isinstance(self.database_engine, Sqlite3Engine):
Expand All @@ -183,7 +186,7 @@ def search_msgs(self, room_ids, search_term, keys):
sql += " ORDER BY rank DESC LIMIT 500"

results = yield self._execute(
"search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
"search_msgs", self.cursor_to_dict, sql, *args
)

results = filter(lambda row: row["room_id"] in room_ids, results)
Expand All @@ -197,7 +200,7 @@ def search_msgs(self, room_ids, search_term, keys):

highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_term, events)
highlights = yield self._find_highlights_in_postgres(search_query, events)

defer.returnValue({
"results": [
Expand Down Expand Up @@ -226,7 +229,10 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None
list of dicts
"""
clauses = []
args = [search_term]

search_query = search_query = _parse_query(self.database_engine, search_term)

args = [search_query]

# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
Expand Down Expand Up @@ -263,7 +269,7 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None
sql = (
"SELECT ts_rank_cd(vector, query) as rank,"
" origin_server_ts, stream_ordering, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search"
" FROM to_tsquery('english', ?) as query, event_search"
" NATURAL JOIN events"
" WHERE vector @@ query AND "
)
Expand Down Expand Up @@ -313,7 +319,7 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None

highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_term, events)
highlights = yield self._find_highlights_in_postgres(search_query, events)

defer.returnValue({
"results": [
Expand All @@ -330,15 +336,15 @@ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None
"highlights": highlights,
})

def _find_highlights_in_postgres(self, search_term, events):
def _find_highlights_in_postgres(self, search_query, events):
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
This is used to give a list of words that clients can match against to
highlight the matching parts.
Args:
search_term (str)
search_query (str)
events (list): A list of events
Returns:
Expand Down Expand Up @@ -370,14 +376,14 @@ def f(txn):
while stop_sel in value:
stop_sel += ">"

query = "SELECT ts_headline(?, plainto_tsquery('english', ?), %s)" % (
query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
_to_postgres_options({
"StartSel": start_sel,
"StopSel": stop_sel,
"MaxFragments": "50",
})
)
txn.execute(query, (value, search_term,))
txn.execute(query, (value, search_query,))
headline, = txn.fetchall()[0]

# Now we need to pick the possible highlights out of the haedline
Expand All @@ -399,3 +405,22 @@ def _to_postgres_options(options_dict):
return "'%s'" % (
",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
)


def _parse_query(database_engine, search_term):
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
that is supported by default.
"""

# Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)

if isinstance(database_engine, PostgresEngine):
return " & ".join(result + ":*" for result in results)
elif isinstance(database_engine, Sqlite3Engine):
return " & ".join(result + "*" for result in results)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")