diff --git a/changelog.d/14311.feature b/changelog.d/14311.feature new file mode 100644 index 000000000000..94c8a83212d1 --- /dev/null +++ b/changelog.d/14311.feature @@ -0,0 +1 @@ +Allow use of postgres and sqllite full-text search operators in search queries. \ No newline at end of file diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 594b935614f7..e9588d175518 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -80,11 +80,11 @@ def store_search_entries_txn( if not self.hs.config.server.enable_search: return if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" - ) + sql = """ + INSERT INTO event_search + (event_id, room_id, key, vector, stream_ordering, origin_server_ts) + VALUES (?,?,?,to_tsvector('english', ?),?,?) + """ args1 = ( ( @@ -101,20 +101,20 @@ def store_search_entries_txn( txn.execute_batch(sql, args1) elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - args2 = ( - ( - entry.event_id, - entry.room_id, - entry.key, - _clean_value_for_search(entry.value), - ) - for entry in entries + self.db_pool.simple_insert_many_txn( + txn, + table="event_search", + keys=("event_id", "room_id", "key", "value"), + values=( + ( + entry.event_id, + entry.room_id, + entry.key, + _clean_value_for_search(entry.value), + ) + for entry in entries + ), ) - txn.execute_batch(sql, args2) else: # This should be unreachable. @@ -162,15 +162,17 @@ async def _background_reindex_search( TYPES = ["m.room.name", "m.room.message", "m.room.topic"] def reindex_search_txn(txn: LoggingTransaction) -> int: - sql = ( - "SELECT stream_ordering, event_id, room_id, type, json, " - " origin_server_ts FROM events" - " JOIN event_json USING (room_id, event_id)" - " WHERE ? <= stream_ordering AND stream_ordering < ?" - " AND (%s)" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),) + sql = """ + SELECT stream_ordering, event_id, room_id, type, json, origin_server_ts + FROM events + JOIN event_json USING (room_id, event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND (%s) + ORDER BY stream_ordering DESC + LIMIT ? + """ % ( + " OR ".join("type = '%s'" % (t,) for t in TYPES), + ) txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) @@ -284,8 +286,10 @@ def create_index(conn: LoggingDatabaseConnection) -> None: try: c.execute( - "CREATE INDEX CONCURRENTLY event_search_fts_idx" - " ON event_search USING GIN (vector)" + """ + CREATE INDEX CONCURRENTLY event_search_fts_idx + ON event_search USING GIN (vector) + """ ) except psycopg2.ProgrammingError as e: logger.warning( @@ -323,12 +327,16 @@ def create_index(conn: LoggingDatabaseConnection) -> None: # We create with NULLS FIRST so that when we search *backwards* # we get the ones with non null origin_server_ts *first* c.execute( - "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search(" - "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + """ + CREATE INDEX CONCURRENTLY event_search_room_order + ON event_search(room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST) + """ ) c.execute( - "CREATE INDEX CONCURRENTLY event_search_order ON event_search(" - "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)" + """ + CREATE INDEX CONCURRENTLY event_search_order + ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST) + """ ) conn.set_session(autocommit=False) @@ -345,14 +353,14 @@ def create_index(conn: LoggingDatabaseConnection) -> None: ) def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]: - sql = ( - "UPDATE event_search AS es SET stream_ordering = e.stream_ordering," - " origin_server_ts = e.origin_server_ts" - " FROM events AS e" - " WHERE e.event_id = es.event_id" - " AND ? <= e.stream_ordering AND e.stream_ordering < ?" - " RETURNING es.stream_ordering" - ) + sql = """ + UPDATE event_search AS es + SET stream_ordering = e.stream_ordering, origin_server_ts = e.origin_server_ts + FROM events AS e + WHERE e.event_id = es.event_id + AND ? <= e.stream_ordering AND e.stream_ordering < ? + RETURNING es.stream_ordering + """ min_stream_id = max_stream_id - batch_size txn.execute(sql, (min_stream_id, max_stream_id)) @@ -456,33 +464,33 @@ async def search_msgs( if isinstance(self.database_engine, PostgresEngine): search_query = search_term tsquery_func = self.database_engine.tsquery_func - sql = ( - f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank," - " room_id, event_id" - " FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?)" - ) + sql = f""" + SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank, + room_id, event_id + FROM event_search + WHERE vector @@ {tsquery_func}('english', ?) + """ args = [search_query, search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?)" - ) + count_sql = f""" + SELECT room_id, count(*) as count FROM event_search + WHERE vector @@ {tsquery_func}('english', ?) + """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): search_query = _parse_query_for_sqlite(search_term) - sql = ( - "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" - " FROM event_search" - " WHERE value MATCH ?" - ) + sql = """ + SELECT rank(matchinfo(event_search)) as rank, room_id, event_id + FROM event_search + WHERE value MATCH ? + """ args = [search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ?" - ) + count_sql = """ + SELECT room_id, count(*) as count FROM event_search + WHERE value MATCH ? + """ count_args = [search_query] + count_args else: # This should be unreachable. @@ -588,26 +596,27 @@ async def search_rooms( raise SynapseError(400, "Invalid pagination token") clauses.append( - "(origin_server_ts < ?" - " OR (origin_server_ts = ? AND stream_ordering < ?))" + """ + (origin_server_ts < ? OR (origin_server_ts = ? AND stream_ordering < ?)) + """ ) args.extend([origin_server_ts, origin_server_ts, stream]) if isinstance(self.database_engine, PostgresEngine): search_query = search_term tsquery_func = self.database_engine.tsquery_func - sql = ( - f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank," - " origin_server_ts, stream_ordering, room_id, event_id" - " FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?) AND " - ) + sql = f""" + SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank, + origin_server_ts, stream_ordering, room_id, event_id + FROM event_search + WHERE vector @@ {tsquery_func}('english', ?) AND + """ args = [search_query, search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?) AND " - ) + count_sql = f""" + SELECT room_id, count(*) as count FROM event_search + WHERE vector @@ {tsquery_func}('english', ?) AND + """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): @@ -619,23 +628,24 @@ async def search_rooms( # in the events table to get the topological ordering. We need # to use the indexes in this order because sqlite refuses to # MATCH unless it uses the full text search index - sql = ( - "SELECT rank(matchinfo) as rank, room_id, event_id," - " origin_server_ts, stream_ordering" - " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" - " FROM event_search" - " WHERE value MATCH ?" - " )" - " CROSS JOIN events USING (event_id)" - " WHERE " + sql = """ + SELECT + rank(matchinfo) as rank, room_id, event_id, origin_server_ts, stream_ordering + FROM ( + SELECT key, event_id, matchinfo(event_search) as matchinfo + FROM event_search + WHERE value MATCH ? ) + CROSS JOIN events USING (event_id) + WHERE + """ search_query = _parse_query_for_sqlite(search_term) args = [search_query] + args - count_sql = ( - "SELECT room_id, count(*) as count FROM event_search" - " WHERE value MATCH ? AND " - ) + count_sql = """ + SELECT room_id, count(*) as count FROM event_search + WHERE value MATCH ? AND + """ count_args = [search_query] + count_args else: # This should be unreachable. @@ -647,10 +657,10 @@ async def search_rooms( # We add an arbitrary limit here to ensure we don't try to pull the # entire table from the database. if isinstance(self.database_engine, PostgresEngine): - sql += ( - " ORDER BY origin_server_ts DESC NULLS LAST," - " stream_ordering DESC NULLS LAST LIMIT ?" - ) + sql += """ + ORDER BY origin_server_ts DESC NULLS LAST, stream_ordering DESC NULLS LAST + LIMIT ? + """ elif isinstance(self.database_engine, Sqlite3Engine): sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?" else: