Skip to content

Commit

Permalink
Merge pull request #1054 from DaanRademaker/change_registered_models_…
Browse files Browse the repository at this point in the history
…to_databricks_sdk

change api model registry to databricks-sdk
  • Loading branch information
dan1elt0m authored Nov 10, 2023
2 parents 3c201b1 + e6dcdff commit bbf737b
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 276 deletions.
152 changes: 31 additions & 121 deletions aws-lambda/src/databricks_cdk/resources/mlflow/registered_model.py
Original file line number Diff line number Diff line change
@@ -1,110 +1,27 @@
from enum import Enum
from typing import List, Optional

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.ml import ModelTag
from pydantic import BaseModel

from databricks_cdk.utils import CnfResponse, delete_request, get_request, patch_request, post_request


class RegisteredModelTag(BaseModel):
key: str
value: str


class ModelVersionTag(BaseModel):
key: str
value: str
from databricks_cdk.utils import CnfResponse, get_workspace_client


class RegisteredModelProperties(BaseModel):
name: str
tags: Optional[List[RegisteredModelTag]] = []
tags: Optional[List[ModelTag]] = []
description: Optional[str]
workspace_url: str


class ModelVersionStatus(str, Enum):
PENDING_REGISTRATION = "PENDING_REGISTRATION"
FAILED_REGISTRATION = "FAILED_REGISTRATION"
READY = "READY"


class ModelVersion(BaseModel):
name: str
version: str
creation_timestamp: int
last_updated_timestamp: int
user_id: str
current_stage: str
description: Optional[str]
source: str
run_id: str
status: ModelVersionStatus
status_message: Optional[str]
tags: Optional[List[ModelVersionTag]]
run_link: Optional[str]


class RegisteredModel(BaseModel):
name: str
creation_timestamp: int
last_updated_timestamp: int
description: Optional[str]
latest_versions: Optional[List[ModelVersion]]
tags: Optional[List[RegisteredModelTag]]
user_id: Optional[str] # Currently not returned


class RegisteredModelCreateResponse(CnfResponse):
physical_resource_id: str


def get_registered_model_url(workspace_url: str):
"""Get the mlflow registered-models url"""
return f"{workspace_url}/api/2.0/mlflow/registered-models"


def _create_registered_model(registered_model_url: str, properties: RegisteredModelProperties) -> str:
"""Creates a registered model"""
response = post_request(
f"{registered_model_url}/create",
{
"name": properties.name,
"tags": [{"key": t.key, "value": t.value} for t in properties.tags],
"description": properties.description,
},
)
return response["registered_model"]["name"]


def _get_registered_model(registered_model_url: str, name: str) -> Optional[RegisteredModel]:
"""Gets the registered model"""
response = get_request(f"{registered_model_url}/get?name={name}")
if response:
return RegisteredModel.parse_obj(response["registered_model"])

return None


def _update_registered_model_description(registered_model_url: str, registered_model_name: str, description: str):
"""Updates the registered model description"""
return patch_request(
f"{registered_model_url}/update",
body={"name": registered_model_name, "description": description},
)


def _update_registered_model_name(registered_model_url: str, current_name: str, new_name: str) -> str:
"""Updates the registered model name"""
return post_request(f"{registered_model_url}/rename", {"name": current_name, "new_name": new_name})[
"registered_model"
]["name"]


def _update_registered_model_tags(
registered_model_url: str,
workspace_client: WorkspaceClient,
properties: RegisteredModelProperties,
current_tags: List[RegisteredModelTag],
current_tags: List[ModelTag],
):
"""Updates the registered model tags"""
tags_to_delete = []
Expand All @@ -116,22 +33,15 @@ def _update_registered_model_tags(
tags_to_delete = [t for t in current_tags if t.key not in new_keys]

if tags_to_delete:
[
delete_request(
f"{registered_model_url}/delete-tag",
body={"name": properties.name, "key": t.key},
)
for t in tags_to_delete
]
# delete tags that are not on the cdk object anymore
[workspace_client.model_registry.delete_model_tag(properties.name, t.key) for t in tags_to_delete]

if properties.tags:
# Overwrites / updates existing tags
[
post_request(
f"{registered_model_url}/set-tag",
{"name": properties.name, "key": t.key, "value": t.value},
)
workspace_client.model_registry.set_model_tag(properties.name, t.key, t.value)
for t in properties.tags
if t.key and t.value is not None
]


Expand All @@ -146,27 +56,28 @@ def create_or_update_registered_model(
:param physical_resource_id: CDK Physical Resource Id belonging to the Registered Model (if exists). Defaults to None
:return:physical_resource_id of the Registered Model, which equals the name of the Registered Model
"""
registered_model_url = get_registered_model_url(properties.workspace_url)

if not physical_resource_id:
registered_model_name = _create_registered_model(registered_model_url, properties)
return RegisteredModelCreateResponse(physical_resource_id=registered_model_name)
workspace_client = get_workspace_client(properties.workspace_url)

if physical_resource_id is None:
response = workspace_client.model_registry.create_model(
name=properties.name, description=properties.description, tags=properties.tags
)

registered_model_url = get_registered_model_url(properties.workspace_url)
registered_model = _get_registered_model(registered_model_url, physical_resource_id)
name = response.registered_model.name if response.registered_model else None
if name is not None:
return RegisteredModelCreateResponse(physical_resource_id=name)

if not registered_model:
registered_model = workspace_client.model_registry.get_model(name=physical_resource_id)
if registered_model is None:
raise ValueError(f"Registered model cannot be found but physical_resouce_id {physical_resource_id} is provided")

if properties.name != registered_model.name:
physical_resource_id = _update_registered_model_name(
registered_model_url,
current_name=physical_resource_id,
new_name=properties.name,
)

if properties.description != registered_model.description:
_update_registered_model_description(registered_model_url, physical_resource_id, properties.description)
if registered_model.registered_model_databricks is not None and (
properties.name != registered_model.registered_model_databricks.name
or properties.description != registered_model.registered_model_databricks.description
):
workspace_client.model_registry.update_model(name=properties.name, description=properties.description)
return RegisteredModelCreateResponse(physical_resource_id=physical_resource_id)

new_tags = properties.tags
if properties.tags is not None:
Expand All @@ -177,14 +88,13 @@ def create_or_update_registered_model(
current_tags = sorted(current_tags, key=lambda t: t.key)

if new_tags != current_tags:
_update_registered_model_tags(registered_model_url, properties, registered_model.tags)
_update_registered_model_tags(workspace_client, properties, registered_model.tags)

return RegisteredModelCreateResponse(physical_resource_id=physical_resource_id)


def delete_registered_model(properties: RegisteredModelProperties, physical_resource_id: str):
"""Deletes an existing registered model"""
delete_request(
f"{get_registered_model_url(properties.workspace_url)}/delete",
body={"name": physical_resource_id},
)
workspace_client = get_workspace_client(properties.workspace_url)
workspace_client.model_registry.delete_model(name=physical_resource_id)
return CnfResponse(physical_resource_id=physical_resource_id)
12 changes: 12 additions & 0 deletions aws-lambda/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from unittest.mock import MagicMock

import pytest
from databricks.sdk import ModelRegistryAPI, WorkspaceClient


@pytest.fixture(scope="function", autouse=True)
Expand All @@ -12,3 +14,13 @@ def aws_credentials():
os.environ["AWS_SECURITY_TOKEN"] = "testing"
os.environ["AWS_SESSION_TOKEN"] = "testing"
os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"


@pytest.fixture(scope="function")
def workspace_client():
workspace_client = MagicMock(spec=WorkspaceClient)

# mock all of the underlying service api's
workspace_client.model_registry = MagicMock(spec=ModelRegistryAPI)

return workspace_client
Loading

0 comments on commit bbf737b

Please sign in to comment.