Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Clean up the database pagination code (#5007)
Browse files Browse the repository at this point in the history
* rewrite & simplify

* changelog

* cleanup potential sql injection
  • Loading branch information
hawkowl authored Apr 4, 2019
1 parent 616e6a1 commit a33a5ab
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 67 deletions.
1 change: 1 addition & 0 deletions changelog.d/5007.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor synapse.storage._base._simple_select_list_paginate.
20 changes: 13 additions & 7 deletions synapse/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import logging
import time

from twisted.internet import defer

from synapse.api.constants import PresenceState
from synapse.storage.devices import DeviceStore
from synapse.storage.user_erasure_store import UserErasureStore
Expand Down Expand Up @@ -453,6 +455,7 @@ def get_users(self):
desc="get_users",
)

@defer.inlineCallbacks
def get_users_paginate(self, order, start, limit):
"""Function to reterive a paginated list of users from
users list. This will return a json object, which contains
Expand All @@ -465,16 +468,19 @@ def get_users_paginate(self, order, start, limit):
Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
"""
is_guest = 0
i_start = (int)(start)
i_limit = (int)(limit)
return self.get_user_list_paginate(
users = yield self.runInteraction(
"get_users_paginate",
self._simple_select_list_paginate_txn,
table="users",
keyvalues={"is_guest": is_guest},
pagevalues=[order, i_limit, i_start],
keyvalues={"is_guest": False},
orderby=order,
start=start,
limit=limit,
retcols=["name", "password_hash", "is_guest", "admin"],
desc="get_users_paginate",
)
count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
retval = {"users": users, "total": count}
defer.returnValue(retval)

def search_users(self, term):
"""Function to search users list for one or more users with
Expand Down
110 changes: 50 additions & 60 deletions synapse/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def _simple_upsert(
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
keyvalues (dict): The unique key columns and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
Expand Down Expand Up @@ -627,7 +627,7 @@ def _simple_upsert(

# presumably we raced with another transaction: let's retry.
logger.warn(
"%s when upserting into %s; retrying: %s", e.__name__, table, e
"IntegrityError when upserting into %s; retrying: %s", table, e
)

def _simple_upsert_txn(
Expand Down Expand Up @@ -1398,21 +1398,31 @@ def get_cache_stream_token(self):
return 0

def _simple_select_list_paginate(
self, table, keyvalues, pagevalues, retcols, desc="_simple_select_list_paginate"
self,
table,
keyvalues,
orderby,
start,
limit,
retcols,
order_direction="ASC",
desc="_simple_select_list_paginate",
):
"""Executes a SELECT query on the named table with start and limit,
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
Args:
table (str): the table name
keyvalues (dict[str, Any] | None):
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
orderby (str): Column to order the results by.
start (int): Index to begin the query at.
limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
order (str): order the select by this column
start (int): start number to begin the query from
limit (int): number of rows to reterive
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
Expand All @@ -1421,15 +1431,27 @@ def _simple_select_list_paginate(
self._simple_select_list_paginate_txn,
table,
keyvalues,
pagevalues,
orderby,
start,
limit,
retcols,
order_direction=order_direction,
)

@classmethod
def _simple_select_list_paginate_txn(
cls, txn, table, keyvalues, pagevalues, retcols
cls,
txn,
table,
keyvalues,
orderby,
start,
limit,
retcols,
order_direction="ASC",
):
"""Executes a SELECT query on the named table with start and limit,
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
Expand All @@ -1439,64 +1461,32 @@ def _simple_select_list_paginate_txn(
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
pagevalues ([]):
order (str): order the select by this column
start (int): start number to begin the query from
limit (int): number of rows to reterive
orderby (str): Column to order the results by.
start (int): Index to begin the query at.
limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")

if keyvalues:
sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
" ? ASC LIMIT ? OFFSET ?",
)
txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
else:
sql = "SELECT %s FROM %s ORDER BY %s" % (
", ".join(retcols),
table,
" ? ASC LIMIT ? OFFSET ?",
)
txn.execute(sql, pagevalues)
where_clause = ""

return cls.cursor_to_dict(txn)

@defer.inlineCallbacks
def get_user_list_paginate(
self, table, keyvalues, pagevalues, retcols, desc="get_user_list_paginate"
):
"""Get a list of users from start row to a limit number of rows. This will
return a json object with users and total number of users in users list.
Args:
table (str): the table name
keyvalues (dict[str, Any] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
pagevalues ([]):
order (str): order the select by this column
start (int): start number to begin the query from
limit (int): number of rows to reterive
retcols (iterable[str]): the names of the columns to return
Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
"""
users = yield self.runInteraction(
desc,
self._simple_select_list_paginate_txn,
sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
", ".join(retcols),
table,
keyvalues,
pagevalues,
retcols,
where_clause,
orderby,
order_direction,
)
count = yield self.runInteraction(desc, self.get_user_count_txn)
retval = {"users": users, "total": count}
defer.returnValue(retval)
txn.execute(sql, list(keyvalues.values()) + [limit, start])

return cls.cursor_to_dict(txn)

def get_user_count_txn(self, txn):
"""Get a total number of registered users in the users list.
Expand Down

0 comments on commit a33a5ab

Please sign in to comment.