Skip to content

Commit

Permalink
fix python tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Anton Shchederkin committed Dec 18, 2024
1 parent 27ba236 commit dba0a51
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 44 deletions.
9 changes: 7 additions & 2 deletions aws-lambda/src/databricks_cdk/resources/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Optional

import cfnresponse
from databricks_cdk.resources.service_principals.service_principal_secrets import ServicePrincipalSecretsProperties, create_or_update_service_principal_secrets, delete_service_principal_secrets
from pydantic import BaseModel, ValidationError

from databricks_cdk.resources.account.credentials import (
Expand Down Expand Up @@ -97,6 +96,11 @@
create_or_update_service_principal,
delete_service_principal,
)
from databricks_cdk.resources.service_principals.service_principal_secrets import (
ServicePrincipalSecretsProperties,
create_or_update_service_principal_secrets,
delete_service_principal_secrets,
)
from databricks_cdk.resources.sql_warehouses.sql_warehouses import (
SQLWarehouseProperties,
create_or_update_warehouse,
Expand Down Expand Up @@ -231,7 +235,8 @@ def create_or_update_resource(event: DatabricksEvent) -> CnfResponse:
return create_or_update_service_principal(ServicePrincipalProperties(**event.ResourceProperties))
elif action == "service-principal-secrets":
return create_or_update_service_principal_secrets(
ServicePrincipalSecretsProperties(**event.ResourceProperties), event.PhysicalResourceId,
ServicePrincipalSecretsProperties(**event.ResourceProperties),
event.PhysicalResourceId,
)
else:
raise RuntimeError(f"Unknown action: {action}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from databricks.sdk.service.iam import ServicePrincipal
from pydantic import BaseModel

from databricks_cdk.utils import CnfResponse, get_workspace_client, get_account_client
from databricks_cdk.utils import CnfResponse, get_account_client, get_workspace_client

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from typing import Optional

import boto3
from databricks.sdk.service.oauth2 import SecretInfo
from databricks.sdk import AccountClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service.oauth2 import SecretInfo
from pydantic import BaseModel

from databricks_cdk.utils import CnfResponse, get_account_client
from databricks_cdk.resources.service_principals.service_principal import get_service_principal
from databricks_cdk.utils import CnfResponse, get_account_client

logger = logging.getLogger(__name__)

Expand All @@ -23,34 +23,29 @@ class ServicePrincipalSecretsProperties(BaseModel):


def create_or_update_service_principal_secrets(
properties: ServicePrincipalSecretsProperties,
physical_resource_id: Optional[str] = None
) -> CnfResponse:
properties: ServicePrincipalSecretsProperties, physical_resource_id: Optional[str] = None
) -> CnfResponse:
"""
Create or update service principal secrets on databricks.
If service principal secrets already exist, it will return the existing service principal secrets.
If service principal secrets doesn't exist, it will create a new one.
"""
account_client = get_account_client()

if physical_resource_id:
existing_service_principal_secrets = get_service_principal_secrets(
service_principal_id=properties.service_principal_id,
physical_resource_id=physical_resource_id,
account_client=account_client
account_client=account_client,
)
return CnfResponse(
physical_resource_id=existing_service_principal_secrets.id
)

return CnfResponse(physical_resource_id=existing_service_principal_secrets.id)

return create_service_principal_secrets(properties, account_client)


def get_service_principal_secrets(
service_principal_id: int,
physical_resource_id: str,
account_client: AccountClient
) -> SecretInfo:
service_principal_id: int, physical_resource_id: str, account_client: AccountClient
) -> SecretInfo:
"""Get service principal secrets on databricks based on physical resource id and service principal id."""
existing_service_principal_secrets = account_client.service_principal_secrets.list(
service_principal_id=service_principal_id
Expand All @@ -63,7 +58,9 @@ def get_service_principal_secrets(
raise NotFound(f"Service principal secrets with id {physical_resource_id} not found")


def create_service_principal_secrets(properties: ServicePrincipalSecretsProperties, account_client: AccountClient) -> CnfResponse:
def create_service_principal_secrets(
properties: ServicePrincipalSecretsProperties, account_client: AccountClient
) -> CnfResponse:
"""
Create service principal secrets on databricks.
It will create a new service principal secrets and store it in secrets manager.
Expand All @@ -80,25 +77,24 @@ def create_service_principal_secrets(properties: ServicePrincipalSecretsProperti
add_to_secrets_manager(
secret_name=secret_name,
client_id=service_principal.application_id,
client_secret=created_service_principal_secrets.secret
)
return CnfResponse(
physical_resource_id=created_service_principal_secrets.id
client_secret=created_service_principal_secrets.secret,
)
return CnfResponse(physical_resource_id=created_service_principal_secrets.id)


def delete_service_principal_secrets(properties: ServicePrincipalSecretsProperties, physical_resource_id: str) -> CnfResponse:
def delete_service_principal_secrets(
properties: ServicePrincipalSecretsProperties, physical_resource_id: str
) -> CnfResponse:
"""Delete service pricncipal secrets on databricks."""
account_client = get_account_client()

try:
account_client.service_principal_secrets.delete(
service_principal_id=properties.service_principal_id,
secret_id=physical_resource_id
service_principal_id=properties.service_principal_id, secret_id=physical_resource_id
)
except NotFound:
logger.warning("Service principal secrets with id %s not found", physical_resource_id)

service_principal = get_service_principal(properties.service_principal_id, account_client)
secret_name = f"{service_principal.display_name}/{service_principal.id}"
delete_from_secrets_manager(secret_name)
Expand Down
3 changes: 2 additions & 1 deletion aws-lambda/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
from databricks.sdk import AccountClient, CredentialsAPI, ExperimentsAPI, ModelRegistryAPI, VolumesAPI, WorkspaceClient
from databricks.sdk.service.iam import ServicePrincipalsAPI
from databricks.sdk.service.iam import ServicePrincipalsAPI, AccountServicePrincipalsAPI
from databricks.sdk.service.oauth2 import ServicePrincipalSecretsAPI


Expand Down Expand Up @@ -38,5 +38,6 @@ def account_client():
# mock all of the underlying service api's
account_client.credentials = MagicMock(spec=CredentialsAPI)
account_client.service_principal_secrets = MagicMock(spec=ServicePrincipalSecretsAPI)
account_client.service_principals = MagicMock(spec=AccountServicePrincipalsAPI)

return account_client
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,15 @@ def test_update_service_principal(workspace_client):


@patch("databricks_cdk.resources.service_principals.service_principal.get_workspace_client")
def test_delete_service_principal(patched_get_workspace_client, workspace_client):
@patch("databricks_cdk.resources.service_principals.service_principal.get_account_client")
def test_delete_service_principal(
patched_get_account_client,
patched_get_workspace_client,
workspace_client,
account_client,
):
patched_get_workspace_client.return_value = workspace_client
patched_get_account_client.return_value = account_client
mock_properties = ServicePrincipalProperties(
workspace_url="https://test.cloud.databricks.com",
service_principal=ServicePrincipal(
Expand All @@ -223,3 +230,4 @@ def test_delete_service_principal(patched_get_workspace_client, workspace_client

assert response == CnfResponse(physical_resource_id="some_id")
workspace_client.service_principals.delete.assert_called_once_with(id="some_id")
account_client.service_principals.delete.assert_called_once_with(id="some_id")
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@

import pytest
from databricks.sdk.errors import NotFound
from databricks.sdk.service.oauth2 import SecretInfo, CreateServicePrincipalSecretResponse
from databricks.sdk.service.iam import ServicePrincipal
from databricks.sdk.service.oauth2 import CreateServicePrincipalSecretResponse, SecretInfo

from databricks_cdk.resources.service_principals.service_principal_secrets import (
ServicePrincipalSecretsCreationError,
ServicePrincipalSecretsProperties,
create_or_update_service_principal_secrets,
create_service_principal_secrets,
delete_service_principal_secrets,
get_service_principal_secrets,
create_or_update_service_principal_secrets,
)
from databricks_cdk.utils import CnfResponse


@patch("databricks_cdk.resources.service_principals.secrets.get_account_client")
@patch("databricks_cdk.resources.service_principals.secrets.create_service_principal_secrets")
@patch("databricks_cdk.resources.service_principals.secrets.get_service_principal_secrets")
@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_account_client")
@patch("databricks_cdk.resources.service_principals.service_principal_secrets.create_service_principal_secrets")
@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_service_principal_secrets")
def test_create_or_update_service_principal_secrets_create(
patched_get_service_principal_secrets,
patched_create_service_principal_secrets,
Expand All @@ -36,9 +37,9 @@ def test_create_or_update_service_principal_secrets_create(
patched_get_service_principal_secrets.assert_not_called()


@patch("databricks_cdk.resources.service_principals.secrets.get_account_client")
@patch("databricks_cdk.resources.service_principals.secrets.create_service_principal_secrets")
@patch("databricks_cdk.resources.service_principals.secrets.get_service_principal_secrets")
@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_account_client")
@patch("databricks_cdk.resources.service_principals.service_principal_secrets.create_service_principal_secrets")
@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_service_principal_secrets")
def test_create_or_update_service_principal_secrets_update(
patched_get_service_principal_secrets,
patched_create_service_principal_secrets,
Expand All @@ -57,9 +58,9 @@ def test_create_or_update_service_principal_secrets_update(

assert response == CnfResponse(physical_resource_id=mock_physical_resource_id)
patched_get_service_principal_secrets.assert_called_once_with(
service_principal_id=mock_properties.service_principal_id,
physical_resource_id=mock_physical_resource_id,
account_client=account_client
service_principal_id=mock_properties.service_principal_id,
physical_resource_id=mock_physical_resource_id,
account_client=account_client,
)
patched_create_service_principal_secrets.assert_not_called()

Expand Down Expand Up @@ -90,15 +91,34 @@ def test_get_service_principal_secrets_error_no_secrets(account_client):
get_service_principal_secrets(1, "some_id", account_client)


def test_create_service_principal_secrets(account_client):
@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_service_principal")
@patch("databricks_cdk.resources.service_principals.service_principal_secrets.add_to_secrets_manager")
def test_create_service_principal_secrets(
patched_add_to_secrets_manager,
patched_get_service_principal,
account_client,
):
patched_get_service_principal.return_value = ServicePrincipal(
application_id="some_client_id",
display_name="mock_name",
id=1,
)
mock_properties = ServicePrincipalSecretsProperties(service_principal_id=1)
account_client.service_principal_secrets.create.return_value = CreateServicePrincipalSecretResponse(id="some_id")
account_client.service_principal_secrets.create.return_value = CreateServicePrincipalSecretResponse(
id="some_id",
secret="some_secret_id",
)
response = create_service_principal_secrets(mock_properties, account_client)

assert response == CnfResponse(physical_resource_id="some_id")
account_client.service_principal_secrets.create.assert_called_once_with(
service_principal_id=1,
)
patched_add_to_secrets_manager.assert_called_once_with(
secret_name="mock_name/1",
client_id="some_client_id",
client_secret="some_secret_id",
)


def test_create_service_principal_secrets_error(account_client):
Expand All @@ -109,8 +129,20 @@ def test_create_service_principal_secrets_error(account_client):
create_service_principal_secrets(mock_properties, account_client)


@patch("databricks_cdk.resources.service_principals.secrets.get_account_client")
def test_delete_service_principal(patched_get_account_client, account_client):
@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_account_client")
@patch("databricks_cdk.resources.service_principals.service_principal_secrets.delete_from_secrets_manager")
@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_service_principal")
def test_delete_service_principal(
patched_get_service_principal,
patched_delete_from_secrets_manager,
patched_get_account_client,
account_client,
):
patched_get_service_principal.return_value = ServicePrincipal(
application_id="some_id",
display_name="mock_name",
id=1,
)
patched_get_account_client.return_value = account_client
mock_properties = ServicePrincipalSecretsProperties(service_principal_id=1)
response = delete_service_principal_secrets(mock_properties, "some_id")
Expand All @@ -120,3 +152,4 @@ def test_delete_service_principal(patched_get_account_client, account_client):
service_principal_id=1,
secret_id="some_id",
)
patched_delete_from_secrets_manager.assert_called_once_with("mock_name/1")

0 comments on commit dba0a51

Please sign in to comment.