diff --git a/aws-lambda/src/databricks_cdk/resources/handler.py b/aws-lambda/src/databricks_cdk/resources/handler.py index 59c304ae..c3467952 100644 --- a/aws-lambda/src/databricks_cdk/resources/handler.py +++ b/aws-lambda/src/databricks_cdk/resources/handler.py @@ -2,7 +2,6 @@ from typing import Optional import cfnresponse -from databricks_cdk.resources.service_principals.service_principal_secrets import ServicePrincipalSecretsProperties, create_or_update_service_principal_secrets, delete_service_principal_secrets from pydantic import BaseModel, ValidationError from databricks_cdk.resources.account.credentials import ( @@ -97,6 +96,11 @@ create_or_update_service_principal, delete_service_principal, ) +from databricks_cdk.resources.service_principals.service_principal_secrets import ( + ServicePrincipalSecretsProperties, + create_or_update_service_principal_secrets, + delete_service_principal_secrets, +) from databricks_cdk.resources.sql_warehouses.sql_warehouses import ( SQLWarehouseProperties, create_or_update_warehouse, @@ -231,7 +235,8 @@ def create_or_update_resource(event: DatabricksEvent) -> CnfResponse: return create_or_update_service_principal(ServicePrincipalProperties(**event.ResourceProperties)) elif action == "service-principal-secrets": return create_or_update_service_principal_secrets( - ServicePrincipalSecretsProperties(**event.ResourceProperties), event.PhysicalResourceId, + ServicePrincipalSecretsProperties(**event.ResourceProperties), + event.PhysicalResourceId, ) else: raise RuntimeError(f"Unknown action: {action}") diff --git a/aws-lambda/src/databricks_cdk/resources/service_principals/service_principal.py b/aws-lambda/src/databricks_cdk/resources/service_principals/service_principal.py index 0fe9d4e5..18b50990 100644 --- a/aws-lambda/src/databricks_cdk/resources/service_principals/service_principal.py +++ b/aws-lambda/src/databricks_cdk/resources/service_principals/service_principal.py @@ -5,7 +5,7 @@ from databricks.sdk.service.iam import ServicePrincipal from pydantic import BaseModel -from databricks_cdk.utils import CnfResponse, get_workspace_client, get_account_client +from databricks_cdk.utils import CnfResponse, get_account_client, get_workspace_client logger = logging.getLogger(__name__) diff --git a/aws-lambda/src/databricks_cdk/resources/service_principals/service_principal_secrets.py b/aws-lambda/src/databricks_cdk/resources/service_principals/service_principal_secrets.py index 3535c56a..eb6095f6 100644 --- a/aws-lambda/src/databricks_cdk/resources/service_principals/service_principal_secrets.py +++ b/aws-lambda/src/databricks_cdk/resources/service_principals/service_principal_secrets.py @@ -3,13 +3,13 @@ from typing import Optional import boto3 -from databricks.sdk.service.oauth2 import SecretInfo from databricks.sdk import AccountClient from databricks.sdk.errors import NotFound +from databricks.sdk.service.oauth2 import SecretInfo from pydantic import BaseModel -from databricks_cdk.utils import CnfResponse, get_account_client from databricks_cdk.resources.service_principals.service_principal import get_service_principal +from databricks_cdk.utils import CnfResponse, get_account_client logger = logging.getLogger(__name__) @@ -23,34 +23,29 @@ class ServicePrincipalSecretsProperties(BaseModel): def create_or_update_service_principal_secrets( - properties: ServicePrincipalSecretsProperties, - physical_resource_id: Optional[str] = None - ) -> CnfResponse: + properties: ServicePrincipalSecretsProperties, physical_resource_id: Optional[str] = None +) -> CnfResponse: """ Create or update service principal secrets on databricks. If service principal secrets already exist, it will return the existing service principal secrets. If service principal secrets doesn't exist, it will create a new one. """ account_client = get_account_client() - + if physical_resource_id: existing_service_principal_secrets = get_service_principal_secrets( service_principal_id=properties.service_principal_id, physical_resource_id=physical_resource_id, - account_client=account_client + account_client=account_client, ) - return CnfResponse( - physical_resource_id=existing_service_principal_secrets.id - ) - + return CnfResponse(physical_resource_id=existing_service_principal_secrets.id) + return create_service_principal_secrets(properties, account_client) def get_service_principal_secrets( - service_principal_id: int, - physical_resource_id: str, - account_client: AccountClient - ) -> SecretInfo: + service_principal_id: int, physical_resource_id: str, account_client: AccountClient +) -> SecretInfo: """Get service principal secrets on databricks based on physical resource id and service principal id.""" existing_service_principal_secrets = account_client.service_principal_secrets.list( service_principal_id=service_principal_id @@ -63,7 +58,9 @@ def get_service_principal_secrets( raise NotFound(f"Service principal secrets with id {physical_resource_id} not found") -def create_service_principal_secrets(properties: ServicePrincipalSecretsProperties, account_client: AccountClient) -> CnfResponse: +def create_service_principal_secrets( + properties: ServicePrincipalSecretsProperties, account_client: AccountClient +) -> CnfResponse: """ Create service principal secrets on databricks. It will create a new service principal secrets and store it in secrets manager. @@ -80,25 +77,24 @@ def create_service_principal_secrets(properties: ServicePrincipalSecretsProperti add_to_secrets_manager( secret_name=secret_name, client_id=service_principal.application_id, - client_secret=created_service_principal_secrets.secret - ) - return CnfResponse( - physical_resource_id=created_service_principal_secrets.id + client_secret=created_service_principal_secrets.secret, ) + return CnfResponse(physical_resource_id=created_service_principal_secrets.id) -def delete_service_principal_secrets(properties: ServicePrincipalSecretsProperties, physical_resource_id: str) -> CnfResponse: +def delete_service_principal_secrets( + properties: ServicePrincipalSecretsProperties, physical_resource_id: str +) -> CnfResponse: """Delete service pricncipal secrets on databricks.""" account_client = get_account_client() try: account_client.service_principal_secrets.delete( - service_principal_id=properties.service_principal_id, - secret_id=physical_resource_id + service_principal_id=properties.service_principal_id, secret_id=physical_resource_id ) except NotFound: logger.warning("Service principal secrets with id %s not found", physical_resource_id) - + service_principal = get_service_principal(properties.service_principal_id, account_client) secret_name = f"{service_principal.display_name}/{service_principal.id}" delete_from_secrets_manager(secret_name) diff --git a/aws-lambda/tests/conftest.py b/aws-lambda/tests/conftest.py index 6e8fa128..b0baf064 100644 --- a/aws-lambda/tests/conftest.py +++ b/aws-lambda/tests/conftest.py @@ -3,7 +3,7 @@ import pytest from databricks.sdk import AccountClient, CredentialsAPI, ExperimentsAPI, ModelRegistryAPI, VolumesAPI, WorkspaceClient -from databricks.sdk.service.iam import ServicePrincipalsAPI +from databricks.sdk.service.iam import ServicePrincipalsAPI, AccountServicePrincipalsAPI from databricks.sdk.service.oauth2 import ServicePrincipalSecretsAPI @@ -38,5 +38,6 @@ def account_client(): # mock all of the underlying service api's account_client.credentials = MagicMock(spec=CredentialsAPI) account_client.service_principal_secrets = MagicMock(spec=ServicePrincipalSecretsAPI) + account_client.service_principals = MagicMock(spec=AccountServicePrincipalsAPI) return account_client diff --git a/aws-lambda/tests/resources/service_principals/test_service_principal.py b/aws-lambda/tests/resources/service_principals/test_service_principal.py index 5a06dc50..330177e2 100644 --- a/aws-lambda/tests/resources/service_principals/test_service_principal.py +++ b/aws-lambda/tests/resources/service_principals/test_service_principal.py @@ -209,8 +209,15 @@ def test_update_service_principal(workspace_client): @patch("databricks_cdk.resources.service_principals.service_principal.get_workspace_client") -def test_delete_service_principal(patched_get_workspace_client, workspace_client): +@patch("databricks_cdk.resources.service_principals.service_principal.get_account_client") +def test_delete_service_principal( + patched_get_account_client, + patched_get_workspace_client, + workspace_client, + account_client, +): patched_get_workspace_client.return_value = workspace_client + patched_get_account_client.return_value = account_client mock_properties = ServicePrincipalProperties( workspace_url="https://test.cloud.databricks.com", service_principal=ServicePrincipal( @@ -223,3 +230,4 @@ def test_delete_service_principal(patched_get_workspace_client, workspace_client assert response == CnfResponse(physical_resource_id="some_id") workspace_client.service_principals.delete.assert_called_once_with(id="some_id") + account_client.service_principals.delete.assert_called_once_with(id="some_id") diff --git a/aws-lambda/tests/resources/service_principals/test_service_principal_secrets.py b/aws-lambda/tests/resources/service_principals/test_service_principal_secrets.py index 59c4871c..54da5232 100644 --- a/aws-lambda/tests/resources/service_principals/test_service_principal_secrets.py +++ b/aws-lambda/tests/resources/service_principals/test_service_principal_secrets.py @@ -2,22 +2,23 @@ import pytest from databricks.sdk.errors import NotFound -from databricks.sdk.service.oauth2 import SecretInfo, CreateServicePrincipalSecretResponse +from databricks.sdk.service.iam import ServicePrincipal +from databricks.sdk.service.oauth2 import CreateServicePrincipalSecretResponse, SecretInfo from databricks_cdk.resources.service_principals.service_principal_secrets import ( ServicePrincipalSecretsCreationError, ServicePrincipalSecretsProperties, + create_or_update_service_principal_secrets, create_service_principal_secrets, delete_service_principal_secrets, get_service_principal_secrets, - create_or_update_service_principal_secrets, ) from databricks_cdk.utils import CnfResponse -@patch("databricks_cdk.resources.service_principals.secrets.get_account_client") -@patch("databricks_cdk.resources.service_principals.secrets.create_service_principal_secrets") -@patch("databricks_cdk.resources.service_principals.secrets.get_service_principal_secrets") +@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_account_client") +@patch("databricks_cdk.resources.service_principals.service_principal_secrets.create_service_principal_secrets") +@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_service_principal_secrets") def test_create_or_update_service_principal_secrets_create( patched_get_service_principal_secrets, patched_create_service_principal_secrets, @@ -36,9 +37,9 @@ def test_create_or_update_service_principal_secrets_create( patched_get_service_principal_secrets.assert_not_called() -@patch("databricks_cdk.resources.service_principals.secrets.get_account_client") -@patch("databricks_cdk.resources.service_principals.secrets.create_service_principal_secrets") -@patch("databricks_cdk.resources.service_principals.secrets.get_service_principal_secrets") +@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_account_client") +@patch("databricks_cdk.resources.service_principals.service_principal_secrets.create_service_principal_secrets") +@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_service_principal_secrets") def test_create_or_update_service_principal_secrets_update( patched_get_service_principal_secrets, patched_create_service_principal_secrets, @@ -57,9 +58,9 @@ def test_create_or_update_service_principal_secrets_update( assert response == CnfResponse(physical_resource_id=mock_physical_resource_id) patched_get_service_principal_secrets.assert_called_once_with( - service_principal_id=mock_properties.service_principal_id, - physical_resource_id=mock_physical_resource_id, - account_client=account_client + service_principal_id=mock_properties.service_principal_id, + physical_resource_id=mock_physical_resource_id, + account_client=account_client, ) patched_create_service_principal_secrets.assert_not_called() @@ -90,15 +91,34 @@ def test_get_service_principal_secrets_error_no_secrets(account_client): get_service_principal_secrets(1, "some_id", account_client) -def test_create_service_principal_secrets(account_client): +@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_service_principal") +@patch("databricks_cdk.resources.service_principals.service_principal_secrets.add_to_secrets_manager") +def test_create_service_principal_secrets( + patched_add_to_secrets_manager, + patched_get_service_principal, + account_client, +): + patched_get_service_principal.return_value = ServicePrincipal( + application_id="some_client_id", + display_name="mock_name", + id=1, + ) mock_properties = ServicePrincipalSecretsProperties(service_principal_id=1) - account_client.service_principal_secrets.create.return_value = CreateServicePrincipalSecretResponse(id="some_id") + account_client.service_principal_secrets.create.return_value = CreateServicePrincipalSecretResponse( + id="some_id", + secret="some_secret_id", + ) response = create_service_principal_secrets(mock_properties, account_client) assert response == CnfResponse(physical_resource_id="some_id") account_client.service_principal_secrets.create.assert_called_once_with( service_principal_id=1, ) + patched_add_to_secrets_manager.assert_called_once_with( + secret_name="mock_name/1", + client_id="some_client_id", + client_secret="some_secret_id", + ) def test_create_service_principal_secrets_error(account_client): @@ -109,8 +129,20 @@ def test_create_service_principal_secrets_error(account_client): create_service_principal_secrets(mock_properties, account_client) -@patch("databricks_cdk.resources.service_principals.secrets.get_account_client") -def test_delete_service_principal(patched_get_account_client, account_client): +@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_account_client") +@patch("databricks_cdk.resources.service_principals.service_principal_secrets.delete_from_secrets_manager") +@patch("databricks_cdk.resources.service_principals.service_principal_secrets.get_service_principal") +def test_delete_service_principal( + patched_get_service_principal, + patched_delete_from_secrets_manager, + patched_get_account_client, + account_client, +): + patched_get_service_principal.return_value = ServicePrincipal( + application_id="some_id", + display_name="mock_name", + id=1, + ) patched_get_account_client.return_value = account_client mock_properties = ServicePrincipalSecretsProperties(service_principal_id=1) response = delete_service_principal_secrets(mock_properties, "some_id") @@ -120,3 +152,4 @@ def test_delete_service_principal(patched_get_account_client, account_client): service_principal_id=1, secret_id="some_id", ) + patched_delete_from_secrets_manager.assert_called_once_with("mock_name/1")