Skip to content

Commit

Permalink
unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
benc-db committed Nov 28, 2023
1 parent e5b1f06 commit 1fbb253
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 9 deletions.
16 changes: 8 additions & 8 deletions dbt/adapters/databricks/python_submissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def _create_work_dir(self, path: str) -> None:
f"Error creating work_dir for python notebooks\n {response.content!r}"
)

def _update_with_acls(self, cluster_dict: dict) -> dict:
acl = self.parsed_model["config"].get("access_control_list", None)
if acl:
cluster_dict.update({"access_control_list": acl})
return cluster_dict

def _upload_notebook(self, path: str, compiled_code: str) -> None:
b64_encoded_content = base64.b64encode(compiled_code.encode()).decode()
response = requests.post(
Expand Down Expand Up @@ -206,10 +212,7 @@ def check_credentials(self) -> None:

def submit(self, compiled_code: str) -> None:
cluster_spec = {"new_cluster": self.parsed_model["config"]["job_cluster_config"]}
acl = self.parsed_model["config"].get("access_control_list", None)
if acl:
cluster_spec.update({"access_control_list": acl})
self._submit_through_notebook(compiled_code, cluster_spec)
self._submit_through_notebook(compiled_code, self._update_with_acls(cluster_spec))


class DBContext:
Expand Down Expand Up @@ -378,10 +381,7 @@ def check_credentials(self) -> None:
def submit(self, compiled_code: str) -> None:
if self.parsed_model["config"].get("create_notebook", False):
config = {"existing_cluster_id": self.cluster_id}
acl = self.parsed_model["config"].get("access_control_list", None)
if acl:
config.update({"access_control_list": acl})
self._submit_through_notebook(compiled_code, config)
self._submit_through_notebook(compiled_code, self._update_with_acls(config))
else:
context = DBContext(self.credentials, self.cluster_id, self.auth_header)
command = DBCommand(self.credentials, self.cluster_id, self.auth_header)
Expand Down
41 changes: 40 additions & 1 deletion tests/unit/python/test_python_submissions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest
from unittest.mock import patch, Mock

from dbt.adapters.databricks.python_submissions import DBContext
from dbt.adapters.databricks.connections import DatabricksCredentials

from dbt.adapters.databricks.python_submissions import DBContext, BaseDatabricksHelper


class TestDatabricksPythonSubmissions(unittest.TestCase):
Expand All @@ -18,3 +20,40 @@ def test_start_cluster_returns_on_receiving_running_state(self, mock_post, mock_
context.start_cluster()

mock_get.assert_called_once()


class DatabricksTestHelper(BaseDatabricksHelper):
def __init__(self, parsed_model: dict, credentials: DatabricksCredentials):
self.parsed_model = parsed_model
self.credentials = credentials


class TestAclUpdate:
def test_empty_acl_empty_config(self):
helper = DatabricksTestHelper({"config": {}}, DatabricksCredentials())
assert helper._update_with_acls({}) == {}

def test_empty_acl_non_empty_config(self):
helper = DatabricksTestHelper({"config": {}}, DatabricksCredentials())
assert helper._update_with_acls({"a": "b"}) == {"a": "b"}

def test_non_empty_acl_empty_config(self):
expected_access_control = {
"access_control_list": [
{"user_name": "user2", "permission_level": "CAN_VIEW"},
]
}
helper = DatabricksTestHelper({"config": expected_access_control}, DatabricksCredentials())
assert helper._update_with_acls({}) == expected_access_control

def test_non_empty_acl_non_empty_config(self):
expected_access_control = {
"access_control_list": [
{"user_name": "user2", "permission_level": "CAN_VIEW"},
]
}
helper = DatabricksTestHelper({"config": expected_access_control}, DatabricksCredentials())
assert helper._update_with_acls({"a": "b"}) == {
"a": "b",
"access_control_list": expected_access_control["access_control_list"],
}

0 comments on commit 1fbb253

Please sign in to comment.