Skip to content

Commit

Permalink
Adding a default retry strategy in python submissions (#549)
Browse files Browse the repository at this point in the history
  • Loading branch information
benc-db authored Jan 16, 2024
2 parents 7421491 + 6c43ba7 commit 79e8707
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 26 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

### Fixes

- Added python model specific connection handling to prevent using invalid sessions ([547](https://github.com/databricks/dbt-databricks/pull/547))
- Allow schema to be specified in testing (thanks @case-k-git!) ([538](https://github.com/databricks/dbt-databricks/pull/538))
- Fix dbt incremental_strategy behavior by fixing schema table existing check (thanks @case-k-git!) ([530](https://github.com/databricks/dbt-databricks/pull/530))

### Under the Hood

- Adding retries around API calls in python model submission ([549](https://github.com/databricks/dbt-databricks/pull/549))

## dbt-databricks 1.7.3 (Dec 12, 2023)

### Fixes
Expand Down
53 changes: 37 additions & 16 deletions dbt/adapters/databricks/python_submissions.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from typing import Any, Dict, Tuple, Optional, Callable

from requests import Session

from dbt.adapters.databricks.__version__ import version
from dbt.adapters.databricks.connections import DatabricksCredentials
from dbt.adapters.databricks import utils

import base64
import time
import requests
import uuid

from urllib3.util.retry import Retry

from dbt.events import AdapterLogger
import dbt.exceptions
from dbt.adapters.base import PythonJobHelper
from dbt.adapters.spark import __version__
from databricks.sdk.core import CredentialsProvider
from requests.adapters import HTTPAdapter

logger = AdapterLogger("Databricks")

Expand All @@ -31,6 +35,13 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> No
self.parsed_model = parsed_model
self.timeout = self.get_timeout()
self.polling_interval = DEFAULT_POLLING_INTERVAL

# This should be passed in, but not sure where this is actually instantiated
retry_strategy = Retry(total=4, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
self.session = Session()
self.session.mount("https://", adapter)

self.check_credentials()
self.auth_header = {
"Authorization": f"Bearer {self.credentials.token}",
Expand All @@ -53,7 +64,7 @@ def check_credentials(self) -> None:
)

def _create_work_dir(self, path: str) -> None:
response = requests.post(
response = self.session.post(
f"https://{self.credentials.host}/api/2.0/workspace/mkdirs",
headers=self.auth_header,
json={
Expand All @@ -73,7 +84,7 @@ def _update_with_acls(self, cluster_dict: dict) -> dict:

def _upload_notebook(self, path: str, compiled_code: str) -> None:
b64_encoded_content = base64.b64encode(compiled_code.encode()).decode()
response = requests.post(
response = self.session.post(
f"https://{self.credentials.host}/api/2.0/workspace/import",
headers=self.auth_header,
json={
Expand Down Expand Up @@ -118,7 +129,7 @@ def _submit_job(self, path: str, cluster_spec: dict) -> str:
libraries.append(lib)

job_spec.update({"libraries": libraries}) # type: ignore
submit_response = requests.post(
submit_response = self.session.post(
f"https://{self.credentials.host}/api/2.1/jobs/runs/submit",
headers=self.auth_header,
json=job_spec,
Expand All @@ -143,7 +154,7 @@ def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> No
run_id = self._submit_job(whole_file_path, cluster_spec)

self.polling(
status_func=requests.get,
status_func=self.session.get,
status_func_kwargs={
"url": f"https://{self.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}",
"headers": self.auth_header,
Expand All @@ -155,7 +166,7 @@ def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> No
)

# get end state to return to user
run_output = requests.get(
run_output = self.session.get(
f"https://{self.credentials.host}" f"/api/2.1/jobs/runs/get-output?run_id={run_id}",
headers=self.auth_header,
)
Expand Down Expand Up @@ -217,11 +228,16 @@ def submit(self, compiled_code: str) -> None:

class DBContext:
def __init__(
self, credentials: DatabricksCredentials, cluster_id: str, auth_header: dict
self,
credentials: DatabricksCredentials,
cluster_id: str,
auth_header: dict,
session: Session,
) -> None:
self.auth_header = auth_header
self.cluster_id = cluster_id
self.host = credentials.host
self.session = session

def create(self) -> str:
# https://docs.databricks.com/dev-tools/api/1.2/index.html#create-an-execution-context
Expand All @@ -235,7 +251,7 @@ def create(self) -> str:
if current_status != "RUNNING":
self._wait_for_cluster_to_start()

response = requests.post(
response = self.session.post(
f"https://{self.host}/api/1.2/contexts/create",
headers=self.auth_header,
json={
Expand All @@ -251,7 +267,7 @@ def create(self) -> str:

def destroy(self, context_id: str) -> str:
# https://docs.databricks.com/dev-tools/api/1.2/index.html#delete-an-execution-context
response = requests.post(
response = self.session.post(
f"https://{self.host}/api/1.2/contexts/destroy",
headers=self.auth_header,
json={
Expand All @@ -268,7 +284,7 @@ def destroy(self, context_id: str) -> str:
def get_cluster_status(self) -> Dict:
# https://docs.databricks.com/dev-tools/api/latest/clusters.html#get

response = requests.get(
response = self.session.get(
f"https://{self.host}/api/2.0/clusters/get",
headers=self.auth_header,
json={"cluster_id": self.cluster_id},
Expand All @@ -291,7 +307,7 @@ def start_cluster(self) -> None:

logger.debug(f"Sending restart command for cluster id {self.cluster_id}")

response = requests.post(
response = self.session.post(
f"https://{self.host}/api/2.0/clusters/start",
headers=self.auth_header,
json={"cluster_id": self.cluster_id},
Expand Down Expand Up @@ -327,15 +343,20 @@ def get_elapsed() -> float:

class DBCommand:
def __init__(
self, credentials: DatabricksCredentials, cluster_id: str, auth_header: dict
self,
credentials: DatabricksCredentials,
cluster_id: str,
auth_header: dict,
session: Session,
) -> None:
self.auth_header = auth_header
self.cluster_id = cluster_id
self.host = credentials.host
self.session = session

def execute(self, context_id: str, command: str) -> str:
# https://docs.databricks.com/dev-tools/api/1.2/index.html#run-a-command
response = requests.post(
response = self.session.post(
f"https://{self.host}/api/1.2/commands/execute",
headers=self.auth_header,
json={
Expand All @@ -354,7 +375,7 @@ def execute(self, context_id: str, command: str) -> str:

def status(self, context_id: str, command_id: str) -> Dict[str, Any]:
# https://docs.databricks.com/dev-tools/api/1.2/index.html#get-information-about-a-command
response = requests.get(
response = self.session.get(
f"https://{self.host}/api/1.2/commands/status",
headers=self.auth_header,
params={
Expand Down Expand Up @@ -383,8 +404,8 @@ def submit(self, compiled_code: str) -> None:
config = {"existing_cluster_id": self.cluster_id}
self._submit_through_notebook(compiled_code, self._update_with_acls(config))
else:
context = DBContext(self.credentials, self.cluster_id, self.auth_header)
command = DBCommand(self.credentials, self.cluster_id, self.auth_header)
context = DBContext(self.credentials, self.cluster_id, self.auth_header, self.session)
command = DBCommand(self.credentials, self.cluster_id, self.auth_header, self.session)
context_id = context.create()
try:
command_id = command.execute(context_id, compiled_code)
Expand Down
21 changes: 12 additions & 9 deletions tests/unit/python/test_python_submissions.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
import unittest
from unittest.mock import patch, Mock
from unittest.mock import Mock

from dbt.adapters.databricks.connections import DatabricksCredentials

from dbt.adapters.databricks.python_submissions import DBContext, BaseDatabricksHelper


class TestDatabricksPythonSubmissions(unittest.TestCase):
    def test_start_cluster_returns_on_receiving_running_state(self):
        """start_cluster polls status and returns once the cluster is RUNNING."""
        mock_session = Mock()
        # The start command only needs to report success (HTTP 200).
        mock_session.post.return_value = Mock(status_code=200)
        # The status command reports the cluster as already running, so
        # start_cluster should stop after a single status check.
        status_response = Mock(status_code=200)
        status_response.json.return_value = {"state": "RUNNING"}
        mock_session.get.return_value = status_response

        context = DBContext(Mock(), None, None, mock_session)
        context.start_cluster()

        mock_session.get.assert_called_once()


class DatabricksTestHelper(BaseDatabricksHelper):
Expand Down

0 comments on commit 79e8707

Please sign in to comment.