Skip to content

Commit

Permalink
fix #7390: push down limit filtering to adapter (#7545)
Browse files Browse the repository at this point in the history
  • Loading branch information
aranke authored May 9, 2023
1 parent 881437e commit 078a836
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 12 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230508-060926.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: push down limit filtering to adapter
time: 2023-05-08T06:09:26.455524-07:00
custom:
Author: aranke
Issue: "7390"
5 changes: 3 additions & 2 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def connection_for(self, node: ResultNode) -> Iterator[None]:

@available.parse(lambda *a, **k: ("", empty_table()))
def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None
) -> Tuple[AdapterResponse, agate.Table]:
"""Execute the given SQL. This is a thin wrapper around
ConnectionManager.execute.
Expand All @@ -283,10 +283,11 @@ def execute(
:param bool auto_begin: If set, and dbt is not currently inside a
transaction, automatically begin one.
:param bool fetch: If set, fetch results.
:param Optional[int] limit: If set, only fetch n number of rows
:return: A tuple of the query status and results (empty if fetch=False).
:rtype: Tuple[AdapterResponse, agate.Table]
"""
return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch)
return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch, limit=limit)

@available.parse(lambda *a, **k: [])
def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]:
Expand Down
11 changes: 7 additions & 4 deletions core/dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,16 @@ def process_results(
return [dict(zip(column_names, row)) for row in rows]

@classmethod
def get_result_from_cursor(cls, cursor: Any, limit: Optional[int] = None) -> agate.Table:
    """Build an agate table from the rows available on ``cursor``.

    :param Any cursor: Cursor whose ``description`` and fetch methods are
        read.
    :param Optional[int] limit: Maximum number of rows to fetch. ``None``
        (the default) or a negative number fetches every row; ``0`` fetches
        no rows.
    :return: A table of the fetched rows; it has no columns and no rows
        when ``cursor.description`` is None.
    :rtype: agate.Table
    """
    data: List[Any] = []
    column_names: List[str] = []

    if cursor.description is not None:
        column_names = [col[0] for col in cursor.description]
        if limit is None or limit < 0:
            # No limit requested; negative values follow the "-1 means all
            # rows" convention used by the show task.
            rows = cursor.fetchall()
        elif limit == 0:
            # An explicit limit of zero means "no rows". A plain truthiness
            # check on ``limit`` would fall through to fetchall() here.
            rows = []
        else:
            rows = cursor.fetchmany(limit)
        data = cls.process_results(column_names, rows)

    return dbt.clients.agate_helper.table_from_data_flat(data, column_names)
Expand All @@ -138,13 +141,13 @@ def data_type_code_to_name(cls, type_code: Union[int, str]) -> str:
)

def execute(
    self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None
) -> Tuple[AdapterResponse, agate.Table]:
    """Run ``sql`` and return the adapter response with any fetched rows.

    :param str sql: The statement to run; it is passed through
        ``self._add_query_comment`` before being submitted.
    :param bool auto_begin: Forwarded to ``self.add_query``.
    :param bool fetch: If set, read results from the cursor; otherwise an
        empty table is returned.
    :param Optional[int] limit: Forwarded to ``get_result_from_cursor`` when
        ``fetch`` is set.
    :return: A tuple of the response and the result table.
    :rtype: Tuple[AdapterResponse, agate.Table]
    """
    commented_sql = self._add_query_comment(sql)
    _connection, cursor = self.add_query(commented_sql, auto_begin)
    # The response is read before any rows are fetched, as in the original.
    response = self.get_response(cursor)

    if not fetch:
        return response, dbt.clients.agate_helper.empty_table()

    return response, self.get_result_from_cursor(cursor, limit)
Expand Down
11 changes: 5 additions & 6 deletions core/dbt/task/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ def __init__(self, config, adapter, node, node_index, num_nodes):

def execute(self, compiled_node, manifest):
start_time = time.time()

# Allow passing in -1 (or any negative number) to get all rows
limit = None if self.config.args.limit < 0 else self.config.args.limit

adapter_response, execute_result = self.adapter.execute(
compiled_node.compiled_code, fetch=True
compiled_node.compiled_code, fetch=True, limit=limit
)
end_time = time.time()

Expand Down Expand Up @@ -66,13 +70,8 @@ def task_end_messages(self, results):
)

for result in matched_results:
# Allow passing in -1 (or any negative number) to get all rows
table = result.agate_table

if self.args.limit >= 0:
table = table.limit(self.args.limit)
result.agate_table = table

# Hack to get Agate table output as string
output = io.StringIO()
if self.args.output == "json":
Expand Down
14 changes: 14 additions & 0 deletions tests/functional/show/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,20 @@ def test_second_ephemeral_model(self, project):
)
assert "col_hundo" in log_output

@pytest.mark.parametrize(
    "args,expected",
    [
        ([], 5),  # default limit
        (["--limit", 3], 3),  # fetch 3 rows
        (["--limit", -1], 7),  # fetch all rows
    ],
)
def test_limit(self, project, args, expected):
    """Check that ``dbt show --inline`` returns the expected number of rows.

    Each case appends its extra CLI ``args`` to the show command and compares
    the row count of the first result's agate table with ``expected``.
    """
    run_dbt(["build"])

    show_command = ["show", "--inline", models__second_ephemeral_model] + list(args)
    results, _ = run_dbt_and_capture(show_command)

    first_result = results.results[0]
    assert len(first_result.agate_table) == expected

def test_seed(self, project):
(results, log_output) = run_dbt_and_capture(["show", "--select", "sample_seed"])
assert "Previewing node 'sample_seed'" in log_output
Expand Down

0 comments on commit 078a836

Please sign in to comment.