diff --git a/tests/unit/db/vertica/test_vertica_db_driver.py b/tests/unit/db/vertica/test_vertica_db_driver.py index 21f1434d0..ef6415a33 100644 --- a/tests/unit/db/vertica/test_vertica_db_driver.py +++ b/tests/unit/db/vertica/test_vertica_db_driver.py @@ -1,5 +1,5 @@ from .base_test_vertica_db_driver import BaseTestVerticaDBDriver -from mock import Mock +from mock import Mock, MagicMock from ...records.format_hints import (vertica_format_hints) import sqlalchemy @@ -8,7 +8,10 @@ class TestVerticaDBDriver(BaseTestVerticaDBDriver): def test_unload(self): mock_result = Mock(name='result') mock_result.rows = 579 - self.mock_db_engine.execute.return_value.fetchall.return_value = [mock_result] + mock_connection = MagicMock(name='connection') + self.mock_db_engine.connect.return_value \ + .__enter__.return_value = mock_connection + mock_connection.execute.return_value.fetchall.return_value = [mock_result] self.mock_records_unload_plan.processing_instructions.fail_if_dont_understand = True self.mock_records_unload_plan.processing_instructions.fail_if_cant_handle_hint = True @@ -25,7 +28,10 @@ def test_unload(self): def test_unload_to_non_s3(self): mock_result = Mock(name='result') mock_result.rows = 579 - self.mock_db_engine.execute.return_value.fetchall.return_value = [mock_result] + mock_connection = MagicMock(name='connection') + self.mock_db_engine.connect.return_value \ + .__enter__.return_value = mock_connection + mock_connection.execute.return_value.fetchall.return_value = [mock_result] self.mock_records_unload_plan.processing_instructions.fail_if_dont_understand = True self.mock_records_unload_plan.processing_instructions.fail_if_cant_handle_hint = True