Skip to content

Commit

Permalink
add python code for creating secrects
Browse files Browse the repository at this point in the history
  • Loading branch information
Anton Shchederkin committed Dec 17, 2024
1 parent 428b8d0 commit 29b95ec
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 2 deletions.
9 changes: 9 additions & 0 deletions aws-lambda/src/databricks_cdk/resources/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional

import cfnresponse
from databricks_cdk.resources.service_principals.secrets import create_or_update_service_principal_secrets, delete_service_principal_secrets
from pydantic import BaseModel, ValidationError

from databricks_cdk.resources.account.credentials import (
Expand Down Expand Up @@ -228,6 +229,10 @@ def create_or_update_resource(event: DatabricksEvent) -> CnfResponse:
return create_or_update_volume(VolumeProperties(**event.ResourceProperties), event.PhysicalResourceId)
elif action == "service-principal":
return create_or_update_service_principal(ServicePrincipalProperties(**event.ResourceProperties))
elif action == "service-principal-secrets":
return create_or_update_service_principal_secrets(
ServicePrincipalProperties(**event.ResourceProperties), event.PhysicalResourceId,
)
else:
raise RuntimeError(f"Unknown action: {action}")

Expand Down Expand Up @@ -336,6 +341,10 @@ def delete_resource(event: DatabricksEvent) -> CnfResponse:
return delete_service_principal(
ServicePrincipalProperties(**event.ResourceProperties), event.PhysicalResourceId
)
elif action == "service-principal-secrets":
return delete_service_principal_secrets(
ServicePrincipalProperties(**event.ResourceProperties), event.PhysicalResourceId
)
else:
raise RuntimeError(f"Unknown action: {action}")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import logging
from typing import Optional

from databricks.sdk.service.oauth2 import SecretInfo
from databricks.sdk import AccountClient
from databricks.sdk.errors import NotFound
from pydantic import BaseModel

from databricks_cdk.utils import CnfResponse, get_account_client

logger = logging.getLogger(__name__)


class ServicePrincipalSecretsCreationError(Exception):
pass


class ServicePrincipalSecretsProperties(BaseModel):
service_principal_id: int


def create_or_update_service_principal_secrets(
properties: ServicePrincipalSecretsProperties,
physical_resource_id: Optional[str] = None
) -> CnfResponse:
"""
Create or update service principal secrets on databricks.
If service principal secrets already exist, it will return the existing service principal secrets.
If service principal secrets doesn't exist, it will create a new one.
"""
account_client = get_account_client()

if physical_resource_id:
existing_service_principal_secrets = get_service_principal_secrets(
service_principal_id=properties.service_principal_id,
physical_resource_id=physical_resource_id,
account_client=account_client
)
return CnfResponse(
physical_resource_id=existing_service_principal_secrets.id
)

return create_service_principal_secrets(properties, account_client)


def get_service_principal_secrets(
service_principal_id: int,
physical_resource_id: str,
account_client: AccountClient
) -> SecretInfo:
"""Get service principal secrets on databricks based on physical resource id and service principal id."""
existing_service_principal_secrets = account_client.service_principal_secrets.list(
service_principal_id=service_principal_id
)

for secret_info in existing_service_principal_secrets:
if secret_info is not None and secret_info.id == physical_resource_id:
return secret_info
else:
raise NotFound(f"Service principal secrets with id {physical_resource_id} not found")


def create_service_principal_secrets(properties: ServicePrincipalSecretsProperties, account_client: AccountClient) -> CnfResponse:
"""Create service principal secrets on databricks."""
created_service_principal_secrets = account_client.service_principal_secrets.create(
service_principal_id=properties.service_principal_id
)

if created_service_principal_secrets.id is None:
raise ServicePrincipalSecretsCreationError("Failed to create service principal secrets")

return CnfResponse(
physical_resource_id=created_service_principal_secrets.id
)


def delete_service_principal_secrets(properties: ServicePrincipalSecretsProperties, physical_resource_id: str) -> CnfResponse:
"""Delete service pricncipal secrets on databricks."""
account_client = get_account_client()
try:
account_client.service_principal_secrets.delete(
service_principal_id=properties.service_principal_id,
secret_id=physical_resource_id
)
except NotFound:
logger.warning("Service principal secrets with id %s not found", physical_resource_id)

return CnfResponse(physical_resource_id=physical_resource_id)
2 changes: 2 additions & 0 deletions aws-lambda/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from databricks.sdk import AccountClient, CredentialsAPI, ExperimentsAPI, ModelRegistryAPI, VolumesAPI, WorkspaceClient
from databricks.sdk.service.iam import ServicePrincipalsAPI
from databricks.sdk.service.oauth2 import ServicePrincipalSecretsAPI


@pytest.fixture(scope="function", autouse=True)
Expand Down Expand Up @@ -36,5 +37,6 @@ def account_client():

# mock all of the underlying service api's
account_client.credentials = MagicMock(spec=CredentialsAPI)
account_client.service_principal_secrets = MagicMock(spec=ServicePrincipalSecretsAPI)

return account_client
122 changes: 122 additions & 0 deletions aws-lambda/tests/resources/service_principals/test_secrets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from unittest.mock import patch

import pytest
from databricks.sdk.errors import NotFound
from databricks.sdk.service.oauth2 import SecretInfo, CreateServicePrincipalSecretResponse

from databricks_cdk.resources.service_principals.secrets import (
ServicePrincipalSecretsCreationError,
ServicePrincipalSecretsProperties,
create_service_principal_secrets,
delete_service_principal_secrets,
get_service_principal_secrets,
create_or_update_service_principal_secrets,
)
from databricks_cdk.utils import CnfResponse


@patch("databricks_cdk.resources.service_principals.secrets.get_account_client")
@patch("databricks_cdk.resources.service_principals.secrets.create_service_principal_secrets")
@patch("databricks_cdk.resources.service_principals.secrets.get_service_principal_secrets")
def test_create_or_update_service_principal_secrets_create(
patched_get_service_principal_secrets,
patched_create_service_principal_secrets,
patched_get_account_client,
account_client,
):
patched_get_account_client.return_value = account_client
mock_properties = ServicePrincipalSecretsProperties(service_principal_id=1)

create_or_update_service_principal_secrets(properties=mock_properties)

patched_create_service_principal_secrets.assert_called_once_with(
mock_properties,
account_client,
)
patched_get_service_principal_secrets.assert_not_called()


@patch("databricks_cdk.resources.service_principals.secrets.get_account_client")
@patch("databricks_cdk.resources.service_principals.secrets.create_service_principal_secrets")
@patch("databricks_cdk.resources.service_principals.secrets.get_service_principal_secrets")
def test_create_or_update_service_principal_secrets_update(
patched_get_service_principal_secrets,
patched_create_service_principal_secrets,
patched_get_account_client,
account_client,
):
patched_get_account_client.return_value = account_client
mock_physical_resource_id = "some_id"
existing_service_principal_secrets = SecretInfo(id=mock_physical_resource_id)
patched_get_service_principal_secrets.return_value = existing_service_principal_secrets
mock_properties = ServicePrincipalSecretsProperties(service_principal_id=1)

response = create_or_update_service_principal_secrets(
properties=mock_properties, physical_resource_id=mock_physical_resource_id
)

assert response == CnfResponse(physical_resource_id=mock_physical_resource_id)
patched_get_service_principal_secrets.assert_called_once_with(
service_principal_id=mock_properties.service_principal_id,
physical_resource_id=mock_physical_resource_id,
account_client=account_client
)
patched_create_service_principal_secrets.assert_not_called()


def test_get_service_principal_secrets(account_client):
service_principal_id = 1
service_principal_secrets_id = "some_id"
service_principal_info = SecretInfo(id=service_principal_secrets_id)
account_client.service_principal_secrets.list.return_value = [service_principal_info]

response = get_service_principal_secrets(service_principal_id, service_principal_secrets_id, account_client)
assert response == service_principal_info
account_client.service_principal_secrets.list.assert_called_once_with(service_principal_id=service_principal_id)


def test_get_service_principal_secrets_error(account_client):
existing_service_principal_secrets = SecretInfo(id="some_different_id")
account_client.service_principal_secrets.list.return_value = [existing_service_principal_secrets]

with pytest.raises(NotFound):
get_service_principal_secrets(1, "some_id", account_client)


def test_get_service_principal_secrets_error_no_secrets(account_client):
account_client.service_principal_secrets.list.return_value = []

with pytest.raises(NotFound):
get_service_principal_secrets(1, "some_id", account_client)


def test_create_service_principal_secrets(account_client):
mock_properties = ServicePrincipalSecretsProperties(service_principal_id=1)
account_client.service_principal_secrets.create.return_value = CreateServicePrincipalSecretResponse(id="some_id")
response = create_service_principal_secrets(mock_properties, account_client)

assert response == CnfResponse(physical_resource_id="some_id")
account_client.service_principal_secrets.create.assert_called_once_with(
service_principal_id=1,
)


def test_create_service_principal_secrets_error(account_client):
mock_properties = ServicePrincipalSecretsProperties(service_principal_id=1)
account_client.service_principal_secrets.create.return_value = CreateServicePrincipalSecretResponse(id=None)

with pytest.raises(ServicePrincipalSecretsCreationError):
create_service_principal_secrets(mock_properties, account_client)


@patch("databricks_cdk.resources.service_principals.secrets.get_account_client")
def test_delete_service_principal(patched_get_account_client, account_client):
patched_get_account_client.return_value = account_client
mock_properties = ServicePrincipalSecretsProperties(service_principal_id=1)
response = delete_service_principal_secrets(mock_properties, "some_id")

assert response == CnfResponse(physical_resource_id="some_id")
account_client.service_principal_secrets.delete.assert_called_once_with(
service_principal_id=1,
secret_id="some_id",
)
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_create_or_update_service_principal_update(
id=mock_physical_resource_id,
roles=[ComplexValue(value="role")],
)
workspace_client.service_principals.get.return_value = existing_service_principal
patched_get_service_principal.return_value = existing_service_principal
mock_properties = ServicePrincipalProperties(
workspace_url="https://test.cloud.databricks.com",
service_principal=ServicePrincipal(
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_get_service_principal_error(workspace_client):
get_service_principal("some_id", workspace_client)


def test_create_service_principle(workspace_client):
def test_create_service_principal(workspace_client):
mock_properties = ServicePrincipalProperties(
workspace_url="https://test.cloud.databricks.com",
service_principal=ServicePrincipal(
Expand Down

0 comments on commit 29b95ec

Please sign in to comment.