Skip to content

Commit

Permalink
Follow "More flexible cluster configuration". (#194)
Browse files Browse the repository at this point in the history
### Description

Follows "More flexible cluster configuration" at dbt-labs/dbt-spark#467.

- Reuse `dbt-spark`'s implementation
- Remove the dependency on `databricks-cli`
- Internal refactorings

Co-authored-by: allisonwang-db <[email protected]>
  • Loading branch information
ueshin and allisonwang-db authored Sep 28, 2022
1 parent 3a41729 commit dbd58fb
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 205 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@
### Features
- Support python model through run command API, currently supported materializations are table and incremental. ([dbt-labs/dbt-spark#377](https://github.com/dbt-labs/dbt-spark/pull/377), [#126](https://github.com/databricks/dbt-databricks/pull/126))
- Enable Pandas and Pandas-on-Spark DataFrames for dbt python models ([dbt-labs/dbt-spark#469](https://github.com/dbt-labs/dbt-spark/pull/469), [#181](https://github.com/databricks/dbt-databricks/pull/181))
- Support job cluster in notebook submission method ([dbt-labs/dbt-spark#467](https://github.com/dbt-labs/dbt-spark/pull/467), [#194](https://github.com/databricks/dbt-databricks/pull/194))
- In `all_purpose_cluster` submission method, a config `http_path` can be specified in Python model config to switch the cluster where Python model runs.
```py
def model(dbt, _):
dbt.config(
materialized='table',
http_path='...'
)
...
```
- Use builtin timestampadd and timestampdiff functions for dateadd/datediff macros if available ([#185](https://github.com/databricks/dbt-databricks/pull/185))
- Implement testing for a test for various Python models ([#189](https://github.com/databricks/dbt-databricks/pull/189))
- Implement testing for `type_boolean` in Databricks ([dbt-labs/dbt-spark#471](https://github.com/dbt-labs/dbt-spark/pull/471), [#188](https://github.com/databricks/dbt-databricks/pull/188))
Expand Down
87 changes: 0 additions & 87 deletions dbt/adapters/databricks/api_client.py

This file was deleted.

108 changes: 56 additions & 52 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,50 @@ def __post_init__(self) -> None:
)
self.connection_parameters = connection_parameters

def validate_creds(self) -> None:
for key in ["host", "http_path", "token"]:
if not getattr(self, key):
raise dbt.exceptions.DbtProfileError(
"The config '{}' is required to connect to Databricks".format(key)
)

@classmethod
def get_invocation_env(cls) -> Optional[str]:
invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV)
if invocation_env:
# Thrift doesn't allow nested () so we need to ensure
# that the passed user agent is valid.
if not DBT_DATABRICKS_INVOCATION_ENV_REGEX.search(invocation_env):
raise dbt.exceptions.ValidationException(
f"Invalid invocation environment: {invocation_env}"
)
return invocation_env

@classmethod
def get_all_http_headers(cls, user_http_session_headers: Dict[str, str]) -> Dict[str, str]:
http_session_headers_str: Optional[str] = os.environ.get(
DBT_DATABRICKS_HTTP_SESSION_HEADERS
)

http_session_headers_dict: Dict[str, str] = (
{k: json.dumps(v) for k, v in json.loads(http_session_headers_str).items()}
if http_session_headers_str is not None
else {}
)

intersect_http_header_keys = (
user_http_session_headers.keys() & http_session_headers_dict.keys()
)

if len(intersect_http_header_keys) > 0:
raise dbt.exceptions.ValidationException(
f"Intersection with reserved http_headers in keys: {intersect_http_header_keys}"
)

http_session_headers_dict.update(user_http_session_headers)

return http_session_headers_dict

@property
def type(self) -> str:
return "databricks"
Expand Down Expand Up @@ -165,14 +209,18 @@ def _connection_keys(self, *, with_aliases: bool = False) -> Tuple[str, ...]:
connection_keys.append("session_properties")
return tuple(connection_keys)

@property
def cluster_id(self) -> Optional[str]:
m = EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX.match(self.http_path) # type: ignore[arg-type]
@classmethod
def extract_cluster_id(cls, http_path: str) -> Optional[str]:
m = EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX.match(http_path)
if m:
return m.group(1).strip()
else:
return None

@property
def cluster_id(self) -> Optional[str]:
return self.extract_cluster_id(self.http_path) # type: ignore[arg-type]


class DatabricksSQLConnectionWrapper:
"""Wrap a Databricks SQL connector in a way that no-ops transactions"""
Expand Down Expand Up @@ -437,69 +485,25 @@ def list_schemas(self, database: str, schema: Optional[str] = None) -> Table:
lambda cursor: cursor.schemas(catalog_name=database, schema_name=schema),
)

@classmethod
def validate_creds(cls, creds: DatabricksCredentials, required: List[str]) -> None:
for key in required:
if not getattr(creds, key):
raise dbt.exceptions.DbtProfileError(
"The config '{}' is required to connect to Databricks".format(key)
)

@classmethod
def validate_invocation_env(cls, invocation_env: str) -> None:
# Thrift doesn't allow nested () so we need to ensure that the passed user agent is valid
if not DBT_DATABRICKS_INVOCATION_ENV_REGEX.search(invocation_env):
raise dbt.exceptions.ValidationException(
f"Invalid invocation environment: {invocation_env}"
)

@classmethod
def get_all_http_headers(
cls, user_http_session_headers: Dict[str, str]
) -> List[Tuple[str, str]]:
http_session_headers_str: Optional[str] = os.environ.get(
DBT_DATABRICKS_HTTP_SESSION_HEADERS
)

http_session_headers_dict: Dict[str, str] = (
{k: json.dumps(v) for k, v in json.loads(http_session_headers_str).items()}
if http_session_headers_str is not None
else {}
)

intersect_http_header_keys = (
user_http_session_headers.keys() & http_session_headers_dict.keys()
)

if len(intersect_http_header_keys) > 0:
raise dbt.exceptions.ValidationException(
f"Intersection with reserved http_headers in keys: {intersect_http_header_keys}"
)

http_session_headers_dict.update(user_http_session_headers)

return list(http_session_headers_dict.items())

@classmethod
def open(cls, connection: Connection) -> Connection:
if connection.state == ConnectionState.OPEN:
logger.debug("Connection is already open, skipping open.")
return connection

creds: DatabricksCredentials = connection.credentials
cls.validate_creds(creds, ["host", "http_path", "token"])
creds.validate_creds()

user_agent_entry = f"dbt-databricks/{__version__}"

invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV)
if invocation_env is not None and len(invocation_env) > 0:
cls.validate_invocation_env(invocation_env)
invocation_env = creds.get_invocation_env()
if invocation_env:
user_agent_entry = f"{user_agent_entry}; {invocation_env}"

connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr]

http_headers: List[Tuple[str, str]] = cls.get_all_http_headers(
connection_parameters.pop("http_headers", {})
http_headers: List[Tuple[str, str]] = list(
creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
)

exc: Optional[Exception] = None
Expand Down
14 changes: 8 additions & 6 deletions dbt/adapters/databricks/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@

from dbt.adapters.databricks.column import DatabricksColumn
from dbt.adapters.databricks.connections import DatabricksConnectionManager
from dbt.adapters.databricks.python_submissions import CommandApiPythonJobHelper
from dbt.adapters.databricks.python_submissions import (
DbtDatabricksAllPurposeClusterPythonJobHelper,
DbtDatabricksJobClusterPythonJobHelper,
)
from dbt.adapters.databricks.relation import DatabricksRelation
from dbt.adapters.databricks.utils import undefined_proof

Expand Down Expand Up @@ -264,13 +267,12 @@ def run_sql_for_tests(
def valid_incremental_strategies(self) -> List[str]:
return ["append", "merge", "insert_overwrite"]

@property
def default_python_submission_method(self) -> str:
return "commands"

@property
def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
return {"commands": CommandApiPythonJobHelper}
return {
"job_cluster": DbtDatabricksJobClusterPythonJobHelper,
"all_purpose_cluster": DbtDatabricksAllPurposeClusterPythonJobHelper,
}

@contextmanager
def _catalog(self, catalog: Optional[str]) -> Iterator[None]:
Expand Down
Loading

0 comments on commit dbd58fb

Please sign in to comment.