diff --git a/.changes/unreleased/Under the Hood-20240516-105757.yaml b/.changes/unreleased/Under the Hood-20240516-105757.yaml new file mode 100644 index 00000000..a5a47c8c --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240516-105757.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Improve memory efficiency of process_results() +time: 2024-05-16T10:57:57.480672-04:00 +custom: + Author: peterallenwebb + Issue: "217" diff --git a/dbt/adapters/sql/connections.py b/dbt/adapters/sql/connections.py index 9adaafce..d8699fd3 100644 --- a/dbt/adapters/sql/connections.py +++ b/dbt/adapters/sql/connections.py @@ -1,6 +1,6 @@ import abc import time -from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, TYPE_CHECKING from dbt_common.events.contextvars import get_node_info from dbt_common.events.functions import fire_event @@ -112,27 +112,24 @@ def get_response(cls, cursor: Any) -> AdapterResponse: @classmethod def process_results( cls, column_names: Iterable[str], rows: Iterable[Any] - ) -> List[Dict[str, Any]]: - # TODO CT-211 + ) -> Iterator[Dict[str, Any]]: unique_col_names = dict() # type: ignore[var-annotated] - # TODO CT-211 for idx in range(len(column_names)): # type: ignore[arg-type] - # TODO CT-211 col_name = column_names[idx] # type: ignore[index] if col_name in unique_col_names: unique_col_names[col_name] += 1 - # TODO CT-211 column_names[idx] = f"{col_name}_{unique_col_names[col_name]}" # type: ignore[index] # noqa else: - # TODO CT-211 unique_col_names[column_names[idx]] = 1 # type: ignore[index] - return [dict(zip(column_names, row)) for row in rows] + + for row in rows: + yield dict(zip(column_names, row)) @classmethod def get_result_from_cursor(cls, cursor: Any, limit: Optional[int]) -> "agate.Table": from dbt_common.clients.agate_helper import table_from_data_flat - data: List[Any] = [] + data: Iterable[Any] = [] column_names: List[str] = [] if cursor.description is not None: diff --git a/tests/unit/test_sql_result.py b/tests/unit/test_sql_result.py index 12c173cb..454e6572 100644 --- a/tests/unit/test_sql_result.py +++ b/tests/unit/test_sql_result.py @@ -8,13 +8,13 @@ def test_duplicated_columns(self): cols_with_one_dupe = ["a", "b", "a", "d"] rows = [(1, 2, 3, 4)] self.assertEqual( - SQLConnectionManager.process_results(cols_with_one_dupe, rows), + list(SQLConnectionManager.process_results(cols_with_one_dupe, rows)), [{"a": 1, "b": 2, "a_2": 3, "d": 4}], ) cols_with_more_dupes = ["a", "a", "a", "b"] rows = [(1, 2, 3, 4)] self.assertEqual( - SQLConnectionManager.process_results(cols_with_more_dupes, rows), + list(SQLConnectionManager.process_results(cols_with_more_dupes, rows)), [{"a": 1, "a_2": 2, "a_3": 3, "b": 4}], )