More flexible cluster configuration #467

Merged · 5 commits · Sep 23, 2022
8 changes: 8 additions & 0 deletions .changes/unreleased/Features-20220923-101248.yaml
@@ -0,0 +1,8 @@
kind: Features
body: Support job cluster in notebook submission method, remove requirement for user
for python model submission
time: 2022-09-23T10:12:48.288911-07:00
custom:
Author: ChenyuLInx
Issue: "444"
PR: "467"
4 changes: 4 additions & 0 deletions dbt/adapters/spark/connections.py
@@ -82,6 +82,10 @@ def __pre_deserialize__(cls, data):
data["database"] = None
return data

@property
def cluster_id(self):
return self.cluster

def __post_init__(self):
# spark classifies database and schema as the same thing
if self.database is not None and self.database != self.schema:
10 changes: 5 additions & 5 deletions dbt/adapters/spark/impl.py
@@ -18,8 +18,8 @@
from dbt.adapters.spark import SparkRelation
from dbt.adapters.spark import SparkColumn
from dbt.adapters.spark.python_submissions import (
DBNotebookPythonJobHelper,
DBCommandsApiPythonJobHelper,
JobClusterPythonJobHelper,
AllPurposeClusterPythonJobHelper,
)
from dbt.adapters.base import BaseRelation
from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER
@@ -377,13 +377,13 @@ def generate_python_submission_response(self, submission_result: Any) -> Adapter

@property
def default_python_submission_method(self) -> str:
return "commands"
return "all_purpose_cluster"

@property
def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
return {
"notebook": DBNotebookPythonJobHelper,
"commands": DBCommandsApiPythonJobHelper,
"job_cluster": JobClusterPythonJobHelper,
"all_purpose_cluster": AllPurposeClusterPythonJobHelper,
}

def standardize_grants_dict(self, grants_table: agate.Table) -> dict:
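
To show how the two properties above fit together, here is an illustrative sketch of the helper-selection step; the function name and the way credentials are passed in are assumptions, not the actual dbt-core code.

def select_python_job_helper(adapter, parsed_model: dict, credentials):
    # Fall back to the adapter default ("all_purpose_cluster" after this PR)
    # when the model config does not set submission_method explicitly.
    method = parsed_model["config"].get(
        "submission_method", adapter.default_python_submission_method
    )
    helper_cls = adapter.python_submission_helpers[method]
    return helper_cls(parsed_model, credentials)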
210 changes: 113 additions & 97 deletions dbt/adapters/spark/python_submissions.py
@@ -7,78 +7,43 @@
import dbt.exceptions
from dbt.adapters.base import PythonJobHelper
from dbt.adapters.spark import SparkCredentials
from dbt.adapters.spark import __version__

DEFAULT_POLLING_INTERVAL = 5
DEFAULT_POLLING_INTERVAL = 10
SUBMISSION_LANGUAGE = "python"
DEFAULT_TIMEOUT = 60 * 60 * 24
DBT_SPARK_VERSION = __version__.version


class BaseDatabricksHelper(PythonJobHelper):
def __init__(self, parsed_model: Dict, credentials: SparkCredentials) -> None:
self.check_credentials(credentials)
self.credentials = credentials
self.identifier = parsed_model["alias"]
self.schema = getattr(parsed_model, "schema", self.credentials.schema)
self.schema = parsed_model["schema"]
self.parsed_model = parsed_model
self.timeout = self.get_timeout()
self.polling_interval = DEFAULT_POLLING_INTERVAL
self.check_credentials()
self.auth_header = {
"Authorization": f"Bearer {self.credentials.token}",
"User-Agent": f"dbt-labs-dbt-spark/{DBT_SPARK_VERSION} (Databricks)",
}

@property
def cluster_id(self) -> str:
return self.parsed_model.get("cluster_id", self.credentials.cluster_id)
Contributor @ueshin commented on Sep 28, 2022:


@ChenyuLInx I'm updating and testing dbt-databricks, and found out that this is a mistake.

I think this should be:

self.parsed_model["config"].get("cluster_id", self.credentials.cluster_id)

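# For clarity, the cluster_id property with the fix suggested in the comment above
# applied (a sketch, not part of this PR's merged diff):
@property
def cluster_id(self) -> str:
    return self.parsed_model["config"].get("cluster_id", self.credentials.cluster_id)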

def get_timeout(self) -> int:
timeout = self.parsed_model["config"].get("timeout", DEFAULT_TIMEOUT)
if timeout <= 0:
raise ValueError("Timeout must be a positive integer")
return timeout

def check_credentials(self, credentials: SparkCredentials) -> None:
def check_credentials(self) -> None:
raise NotImplementedError(
"Overwrite this method to check specific requirement for current submission method"
)

def submit(self, compiled_code: str) -> None:
raise NotImplementedError(
"BasePythonJobHelper is an abstract class and you should implement submit method."
)

def polling(
self,
status_func,
status_func_kwargs,
get_state_func,
terminal_states,
expected_end_state,
get_state_msg_func,
) -> Dict:
state = None
start = time.time()
exceeded_timeout = False
response = {}
while state not in terminal_states:
if time.time() - start > self.timeout:
exceeded_timeout = True
break
# should we do exponential backoff?
time.sleep(self.polling_interval)
response = status_func(**status_func_kwargs)
state = get_state_func(response)
if exceeded_timeout:
raise dbt.exceptions.RuntimeException("python model run timed out")
if state != expected_end_state:
raise dbt.exceptions.RuntimeException(
"python model run ended in state"
f"{state} with state_message\n{get_state_msg_func(response)}"
)
return response


class DBNotebookPythonJobHelper(BaseDatabricksHelper):
def __init__(self, parsed_model: Dict, credentials: SparkCredentials) -> None:
super().__init__(parsed_model, credentials)
self.auth_header = {"Authorization": f"Bearer {self.credentials.token}"}

def check_credentials(self, credentials) -> None:
if not credentials.user:
raise ValueError("Databricks user is required for notebook submission method.")

def _create_work_dir(self, path: str) -> None:
response = requests.post(
f"https://{self.credentials.host}/api/2.0/workspace/mkdirs",
@@ -110,35 +75,35 @@ def _upload_notebook(self, path: str, compiled_code: str) -> None:
f"Error creating python notebook.\n {response.content!r}"
)

def _submit_notebook(self, path: str) -> str:
def _submit_job(self, path: str, cluster_spec: dict) -> str:
job_spec = {
"run_name": f"{self.schema}-{self.identifier}-{uuid.uuid4()}",
"notebook_task": {
"notebook_path": path,
},
}
job_spec.update(cluster_spec)
submit_response = requests.post(
f"https://{self.credentials.host}/api/2.1/jobs/runs/submit",
headers=self.auth_header,
json={
"run_name": f"{self.schema}-{self.identifier}-{uuid.uuid4()}",
"existing_cluster_id": self.credentials.cluster,
"notebook_task": {
"notebook_path": path,
},
},
json=job_spec,
)
if submit_response.status_code != 200:
raise dbt.exceptions.RuntimeException(
f"Error creating python run.\n {submit_response.content!r}"
)
return submit_response.json()["run_id"]

def submit(self, compiled_code: str) -> None:
def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> None:
# it is safe to call mkdirs even if dir already exists and have content inside
work_dir = f"/Users/{self.credentials.user}/{self.schema}/"
work_dir = f"/dbt_python_model/{self.schema}/"
self._create_work_dir(work_dir)

# add notebook
whole_file_path = f"{work_dir}{self.identifier}"
self._upload_notebook(whole_file_path, compiled_code)

# submit job
run_id = self._submit_notebook(whole_file_path)
run_id = self._submit_job(whole_file_path, cluster_spec)

self.polling(
status_func=requests.get,
@@ -167,11 +132,56 @@ def submit(self, compiled_code: str) -> None:
f"{json_run_output['error_trace']}"
)

def submit(self, compiled_code: str) -> None:
raise NotImplementedError(
"BasePythonJobHelper is an abstract class and you should implement submit method."
)

def polling(
self,
status_func,
status_func_kwargs,
get_state_func,
terminal_states,
expected_end_state,
get_state_msg_func,
) -> Dict:
state = None
start = time.time()
exceeded_timeout = False
response = {}
while state not in terminal_states:
if time.time() - start > self.timeout:
exceeded_timeout = True
break
# should we do exponential backoff?
time.sleep(self.polling_interval)
response = status_func(**status_func_kwargs)
state = get_state_func(response)
if exceeded_timeout:
raise dbt.exceptions.RuntimeException("python model run timed out")
if state != expected_end_state:
raise dbt.exceptions.RuntimeException(
"python model run ended in state"
f"{state} with state_message\n{get_state_msg_func(response)}"
)
return response


class JobClusterPythonJobHelper(BaseDatabricksHelper):
def check_credentials(self) -> None:
if not self.parsed_model["config"].get("job_cluster_config", None):
raise ValueError("job_cluster_config is required for commands submission method.")

def submit(self, compiled_code: str) -> None:
cluster_spec = {"new_cluster": self.parsed_model["config"]["job_cluster_config"]}
self._submit_through_notebook(compiled_code, cluster_spec)


class DBContext:
def __init__(self, credentials: SparkCredentials) -> None:
self.auth_header = {"Authorization": f"Bearer {credentials.token}"}
self.cluster = credentials.cluster
def __init__(self, credentials: SparkCredentials, cluster_id: str, auth_header: dict) -> None:
self.auth_header = auth_header
self.cluster_id = cluster_id
self.host = credentials.host

def create(self) -> str:
@@ -180,7 +190,7 @@ def create(self) -> str:
f"https://{self.host}/api/1.2/contexts/create",
headers=self.auth_header,
json={
"clusterId": self.cluster,
"clusterId": self.cluster_id,
"language": SUBMISSION_LANGUAGE,
},
)
@@ -196,7 +206,7 @@ def destroy(self, context_id: str) -> str:
f"https://{self.host}/api/1.2/contexts/destroy",
headers=self.auth_header,
json={
"clusterId": self.cluster,
"clusterId": self.cluster_id,
"contextId": context_id,
},
)
@@ -208,9 +218,9 @@ def destroy(self, context_id: str) -> str:


class DBCommand:
def __init__(self, credentials: SparkCredentials) -> None:
self.auth_header = {"Authorization": f"Bearer {credentials.token}"}
self.cluster = credentials.cluster
def __init__(self, credentials: SparkCredentials, cluster_id: str, auth_header: dict) -> None:
self.auth_header = auth_header
self.cluster_id = cluster_id
self.host = credentials.host

def execute(self, context_id: str, command: str) -> str:
@@ -219,7 +229,7 @@ def execute(self, context_id: str, command: str) -> str:
f"https://{self.host}/api/1.2/commands/execute",
headers=self.auth_header,
json={
"clusterId": self.cluster,
"clusterId": self.cluster_id,
"contextId": context_id,
"language": SUBMISSION_LANGUAGE,
"command": command,
@@ -237,7 +247,7 @@ def status(self, context_id: str, command_id: str) -> Dict[str, Any]:
f"https://{self.host}/api/1.2/commands/status",
headers=self.auth_header,
params={
"clusterId": self.cluster,
"clusterId": self.cluster_id,
"contextId": context_id,
"commandId": command_id,
},
@@ -249,32 +259,38 @@ def status(self, context_id: str, command_id: str) -> Dict[str, Any]:
return response.json()


class DBCommandsApiPythonJobHelper(BaseDatabricksHelper):
def check_credentials(self, credentials: SparkCredentials) -> None:
if not credentials.cluster:
raise ValueError("Databricks cluster is required for commands submission method.")
class AllPurposeClusterPythonJobHelper(BaseDatabricksHelper):
def check_credentials(self) -> None:
if not self.cluster_id:
raise ValueError(
"Databricks cluster_id is required for all_purpose_cluster submission method with running with notebook."
)

def submit(self, compiled_code: str) -> None:
context = DBContext(self.credentials)
command = DBCommand(self.credentials)
context_id = context.create()
try:
command_id = command.execute(context_id, compiled_code)
# poll until job finish
response = self.polling(
status_func=command.status,
status_func_kwargs={
"context_id": context_id,
"command_id": command_id,
},
get_state_func=lambda response: response["status"],
terminal_states=("Cancelled", "Error", "Finished"),
expected_end_state="Finished",
get_state_msg_func=lambda response: response.json()["results"]["data"],
)
if response["results"]["resultType"] == "error":
raise dbt.exceptions.RuntimeException(
f"Python model failed with traceback as:\n" f"{response['results']['cause']}"
if self.parsed_model["config"].get("create_notebook", False):
self._submit_through_notebook(compiled_code, {"existing_cluster_id": self.cluster_id})
else:
context = DBContext(self.credentials, self.cluster_id, self.auth_header)
command = DBCommand(self.credentials, self.cluster_id, self.auth_header)
context_id = context.create()
try:
command_id = command.execute(context_id, compiled_code)
# poll until job finish
response = self.polling(
status_func=command.status,
status_func_kwargs={
"context_id": context_id,
"command_id": command_id,
},
get_state_func=lambda response: response["status"],
terminal_states=("Cancelled", "Error", "Finished"),
expected_end_state="Finished",
get_state_msg_func=lambda response: response.json()["results"]["data"],
)
finally:
context.destroy(context_id)
if response["results"]["resultType"] == "error":
raise dbt.exceptions.RuntimeException(
f"Python model failed with traceback as:\n"
f"{response['results']['cause']}"
)
finally:
context.destroy(context_id)
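
Putting the job_cluster path together: a sketch of the payload _submit_job posts to /api/2.1/jobs/runs/submit once JobClusterPythonJobHelper merges its cluster_spec; the schema, model name, and cluster attributes are placeholders.

job_spec = {
    "run_name": "my_schema-my_model-<uuid4>",  # f"{schema}-{identifier}-{uuid.uuid4()}"
    "notebook_task": {
        # uploaded by _submit_through_notebook under /dbt_python_model/<schema>/
        "notebook_path": "/dbt_python_model/my_schema/my_model",
    },
    # merged in from cluster_spec = {"new_cluster": config["job_cluster_config"]}
    "new_cluster": {
        "spark_version": "11.3.x-scala2.12",  # placeholder cluster attributes
        "node_type_id": "i3.xlarge",
        "num_workers": 2,
    },
}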
6 changes: 5 additions & 1 deletion tests/functional/adapter/test_python_model.py
@@ -2,11 +2,15 @@
import pytest
from dbt.tests.util import run_dbt, write_file, run_dbt_and_capture
from dbt.tests.adapter.python_model.test_python_model import BasePythonModelTests, BasePythonIncrementalTests

from dbt.tests.adapter.python_model.test_spark import BasePySparkTests
@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
class TestPythonModelSpark(BasePythonModelTests):
pass

@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
class TestPySpark(BasePySparkTests):
pass

@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
class TestPythonIncrementalModelSpark(BasePythonIncrementalTests):
@pytest.fixture(scope="class")