diff --git a/records_mover/records/sources/table.py b/records_mover/records/sources/table.py index a279fb78a..426e70ca2 100644 --- a/records_mover/records/sources/table.py +++ b/records_mover/records/sources/table.py @@ -14,7 +14,7 @@ from ...url.resolver import UrlResolver from records_mover.url.base import BaseDirectoryUrl import logging -from typing import Iterator, List, TYPE_CHECKING +from typing import Generator, Iterator, List, TYPE_CHECKING if TYPE_CHECKING: from .dataframes import DataframesRecordsSource # noqa from pandas import DataFrame # noqa @@ -89,13 +89,17 @@ def to_dataframes_source(self, logger.info(f"Exporting in chunks of up to {chunksize} rows by {num_columns} columns") quoted_table = quote_schema_and_table(db, self.schema_name, self.table_name) - chunks: Iterator['DataFrame'] = \ + chunks: Generator['DataFrame', None, None] = \ pandas.read_sql(f"SELECT * FROM {quoted_table}", con=db, chunksize=chunksize) - yield DataframesRecordsSource(dfs=self.with_cast_dataframe_types(records_schema, chunks), - records_schema=records_schema, - processing_instructions=processing_instructions) + try: + yield DataframesRecordsSource(dfs=self.with_cast_dataframe_types(records_schema, + chunks), + records_schema=records_schema, + processing_instructions=processing_instructions) + finally: + chunks.close() def with_cast_dataframe_types(self, records_schema: RecordsSchema, diff --git a/tests/unit/records/sources/test_table.py b/tests/unit/records/sources/test_table.py index d9b9576d6..027c6c8d8 100644 --- a/tests/unit/records/sources/test_table.py +++ b/tests/unit/records/sources/test_table.py @@ -35,6 +35,7 @@ def test_to_dataframes_source(self, mock_columns = [mock_column] mock_db.dialect.get_columns.return_value = mock_columns mock_quoted_table = mock_quote_schema_and_table.return_value + mock_chunks = mock_read_sql.return_value with self.table_records_source.to_dataframes_source(mock_processing_instructions) as\ df_source: self.assertEqual(df_source, mock_DataframesRecordsSource.return_value) @@ -51,6 +52,8 @@ def test_to_dataframes_source(self, assert_called_with(dfs=ANY, processing_instructions=mock_processing_instructions, records_schema=mock_records_schema) + mock_chunks.close.assert_not_called() + mock_chunks.close.assert_called() @patch('records_mover.records.sources.table.RecordsUnloadPlan') @patch('records_mover.records.sources.table.MoveResult')