Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate references to base adapter #689

Merged
merged 8 commits into from
Jan 10, 2024
Merged
21 changes: 11 additions & 10 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from redshift_connector.utils.oids import get_datatype_name

from dbt.adapters.sql import SQLConnectionManager
from dbt.contracts.connection import AdapterResponse, Connection, Credentials
from dbt.contracts.util import Replaceable
from dbt.dataclass_schema import dbtClassMixin, StrEnum, ValidationError
from dbt.events import AdapterLogger
from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials
from dbt.adapters.events.logging import AdapterLogger
from dbt.common.contracts.util import Replaceable
from dbt.common.dataclass_schema import dbtClassMixin, StrEnum, ValidationError
from dbt.common.helper_types import Port
from dbt.exceptions import DbtRuntimeError, CompilationError
import dbt.flags
from dbt.helper_types import Port
import dbt.mp_context


class SSLConfigError(CompilationError):
Expand All @@ -33,7 +34,7 @@ def get_message(self) -> str:
logger = AdapterLogger("Redshift")


drop_lock: Lock = dbt.flags.MP_CONTEXT.Lock() # type: ignore
drop_lock: Lock = dbt.mp_context._MP_CONTEXT.Lock() # type: ignore
VersusFacit marked this conversation as resolved.
Show resolved Hide resolved


class RedshiftConnectionMethod(StrEnum):
Expand Down Expand Up @@ -185,7 +186,7 @@ def get_connect_method(self):
# this requirement is really annoying to encode into json schema,
# so validate it here
if self.credentials.password is None:
raise dbt.exceptions.FailedToConnectError(
raise dbt.adapters.exceptions.FailedToConnectError(
"'password' field is required for 'database' credentials"
)

Expand All @@ -204,7 +205,7 @@ def connect():

elif method == RedshiftConnectionMethod.IAM:
if not self.credentials.cluster_id and "serverless" not in self.credentials.host:
raise dbt.exceptions.FailedToConnectError(
raise dbt.adapters.exceptions.FailedToConnectError(
"Failed to use IAM method. 'cluster_id' must be provided for provisioned cluster. "
"'host' must be provided for serverless endpoint."
)
Expand All @@ -227,7 +228,7 @@ def connect():
return c

else:
raise dbt.exceptions.FailedToConnectError(
raise dbt.adapters.exceptions.FailedToConnectError(
"Invalid 'method' in profile: '{}'".format(method)
)

Expand Down Expand Up @@ -349,7 +350,7 @@ def execute(
if fetch:
table = self.get_result_from_cursor(cursor, limit)
else:
table = dbt.clients.agate_helper.empty_table()
table = dbt.common.clients.agate_helper.empty_table()
return response, table

def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
Expand Down
6 changes: 3 additions & 3 deletions dbt/adapters/redshift/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from dbt.adapters.base.impl import AdapterConfig, ConstraintSupport
from dbt.adapters.base.meta import available
from dbt.adapters.sql import SQLAdapter
from dbt.contracts.connection import AdapterResponse
from dbt.adapters.contracts.connection import AdapterResponse
from dbt.contracts.graph.nodes import ConstraintType
from dbt.events import AdapterLogger
from dbt.adapters.events.logging import AdapterLogger


import dbt.exceptions
import dbt.adapters.exceptions

from dbt.adapters.redshift import RedshiftConnectionManager, RedshiftRelation

Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/redshift/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)
from dbt.context.providers import RuntimeConfigObject
from dbt.contracts.graph.nodes import ModelNode
from dbt.contracts.relation import RelationType
from dbt.adapters.base import RelationType
from dbt.exceptions import DbtRuntimeError

from dbt.adapters.redshift.relation_configs import (
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/redshift/relation_configs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import agate
from dbt.adapters.base.relation import Policy
from dbt.adapters.contracts.relation import ComponentName
from dbt.adapters.relation_configs import (
RelationConfigBase,
RelationResults,
)
from dbt.contracts.graph.nodes import ModelNode
from dbt.contracts.relation import ComponentName

from dbt.adapters.redshift.relation_configs.policies import (
RedshiftIncludePolicy,
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/redshift/relation_configs/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
RelationConfigValidationRule,
)
from dbt.contracts.graph.nodes import ModelNode
from dbt.dataclass_schema import StrEnum
from dbt.common.dataclass_schema import StrEnum
from dbt.exceptions import DbtRuntimeError

from dbt.adapters.redshift.relation_configs.base import RedshiftRelationConfigBase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
RelationConfigValidationRule,
)
from dbt.contracts.graph.nodes import ModelNode
from dbt.contracts.relation import ComponentName
from dbt.adapters.contracts.relation import ComponentName
from dbt.exceptions import DbtRuntimeError

from dbt.adapters.redshift.relation_configs.base import RedshiftRelationConfigBase
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/redshift/relation_configs/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
RelationConfigValidationRule,
)
from dbt.contracts.graph.nodes import ModelNode
from dbt.dataclass_schema import StrEnum
from dbt.common.dataclass_schema import StrEnum
from dbt.exceptions import DbtRuntimeError

from dbt.adapters.redshift.relation_configs.base import RedshiftRelationConfigBase
Expand Down
6 changes: 3 additions & 3 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# install latest changes in dbt-core + dbt-postgres
# TODO: how to switch from HEAD to x.y.latest branches after minor releases?
git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core
git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-tests-adapter&subdirectory=tests/adapter
git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-postgres&subdirectory=plugins/postgres
git+https://github.com/dbt-labs/dbt-core.git@feature/decouple-adapters-from-core#egg=dbt-core&subdirectory=core
git+https://github.com/dbt-labs/dbt-core.git@feature/decouple-adapters-from-core#egg=dbt-tests-adapter&subdirectory=tests/adapter
git+https://github.com/dbt-labs/dbt-core.git@feature/decouple-adapters-from-core#egg=dbt-postgres&subdirectory=plugins/postgres

# if version 1.x or greater -> pin to major version
# if version 0.x -> pin to minor
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import pytest
import unittest

from multiprocessing import get_context
from unittest import mock

from .utils import config_from_parts_or_dicts, inject_adapter, clear_plugin
from .mock_adapter import adapter_factory
import dbt.exceptions
import dbt.adapters.exceptions

from dbt.adapters import (
redshift,
Expand Down Expand Up @@ -191,7 +192,7 @@ def manifest_extended(manifest_fx):

@pytest.fixture
def redshift_adapter(config, get_adapter):
adapter = redshift.RedshiftAdapter(config)
adapter = redshift.RedshiftAdapter(config, get_context("spawn"))
inject_adapter(adapter, redshift.Plugin)
get_adapter.return_value = adapter
yield adapter
Expand Down
18 changes: 10 additions & 8 deletions tests/unit/test_redshift_adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import unittest

from multiprocessing import get_context
from unittest import mock
from unittest.mock import Mock, call

Expand All @@ -10,8 +12,8 @@
RedshiftAdapter,
Plugin as RedshiftPlugin,
)
from dbt.clients import agate_helper
from dbt.exceptions import FailedToConnectError
from dbt.common.clients import agate_helper
from dbt.adapters.exceptions import FailedToConnectError
from dbt.adapters.redshift.connections import RedshiftConnectMethodFactory, RedshiftSSLConfig
from .utils import (
config_from_parts_or_dicts,
Expand Down Expand Up @@ -59,7 +61,7 @@ def setUp(self):
@property
def adapter(self):
if self._adapter is None:
self._adapter = RedshiftAdapter(self.config)
self._adapter = RedshiftAdapter(self.config, get_context("spawn"))
inject_adapter(self._adapter, RedshiftPlugin)
return self._adapter

Expand Down Expand Up @@ -235,7 +237,7 @@ def test_explicit_region_failure(self):
region=None,
)

with self.assertRaises(dbt.exceptions.FailedToConnectError):
with self.assertRaises(dbt.adapters.exceptions.FailedToConnectError):
connection = self.adapter.acquire_connection("dummy")
connection.handle
redshift_connector.connect.assert_called_once_with(
Expand Down Expand Up @@ -264,7 +266,7 @@ def test_explicit_invalid_region(self):
region=None,
)

with self.assertRaises(dbt.exceptions.FailedToConnectError):
with self.assertRaises(dbt.adapters.exceptions.FailedToConnectError):
connection = self.adapter.acquire_connection("dummy")
connection.handle
redshift_connector.connect.assert_called_once_with(
Expand Down Expand Up @@ -385,7 +387,7 @@ def test_serverless_iam_failure(self):
iam_profile="test",
host="doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com",
)
with self.assertRaises(dbt.exceptions.FailedToConnectError) as context:
with self.assertRaises(dbt.adapters.exceptions.FailedToConnectError) as context:
connection = self.adapter.acquire_connection("dummy")
connection.handle
redshift_connector.connect.assert_called_once_with(
Expand Down Expand Up @@ -507,12 +509,12 @@ def test_dbname_verification_is_case_insensitive(self):
}
self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
self.adapter.cleanup_connections()
self._adapter = RedshiftAdapter(self.config)
self._adapter = RedshiftAdapter(self.config, get_context("spawn"))
self.adapter.verify_database("redshift")

def test_execute_with_fetch(self):
cursor = mock.Mock()
table = dbt.clients.agate_helper.empty_table()
table = dbt.common.clients.agate_helper.empty_table()
with mock.patch.object(self.adapter.connections, "add_query") as mock_add_query:
mock_add_query.return_value = (
None,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import agate
import pytest
from dbt.dataclass_schema import ValidationError
from dbt.common.dataclass_schema import ValidationError
from dbt.config.project import PartialProject


Expand Down Expand Up @@ -233,7 +233,7 @@ def assert_fails_validation(dct, cls):
class TestAdapterConversions(TestCase):
@staticmethod
def _get_tester_for(column_type):
from dbt.clients import agate_helper
from dbt.common.clients import agate_helper

if column_type is agate.TimeDelta: # dbt never makes this!
return agate.TimeDelta()
Expand Down
Loading