Skip to content

Commit

Permalink
✨ add inventorization for ML-related objects (#56)
Browse files Browse the repository at this point in the history
* update readme

* add tests

* add impl for models

* switch to getattr function

* switch to getattr
  • Loading branch information
renardeinside authored Jul 26, 2023
1 parent ccd8b4a commit 6a18748
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 9 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,8 @@ Workflows:

ML:

- [ ] MLflow experiments
- [ ] MLflow registry
- [ ] Legacy Mlflow model endpoints (?)
- [x] MLflow experiments
- [x] MLflow registry

SQL:

Expand Down
7 changes: 3 additions & 4 deletions src/uc_migration_toolkit/managers/inventory/inventorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,10 @@ def preload(self):
logger.info(f"Object metadata prepared for {len(self._objects)} objects.")

def _process_single_object(self, _object: InventoryObject) -> PermissionsInventoryItem:
permissions = self._permissions_function(
self._request_object_type, _object.__getattribute__(self._id_attribute)
)
object_id = str(getattr(_object, self._id_attribute))
permissions = self._permissions_function(self._request_object_type, object_id)
inventory_item = PermissionsInventoryItem(
object_id=str(_object.__getattribute__(self._id_attribute)),
object_id=object_id,
logical_object_type=self._logical_object_type,
request_object_type=self._request_object_type,
object_permissions=permissions.as_dict(),
Expand Down
17 changes: 17 additions & 0 deletions src/uc_migration_toolkit/managers/inventory/listing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from collections.abc import Iterator

from databricks.sdk.service.ml import ModelDatabricks

from uc_migration_toolkit.providers.client import provider


class CustomListing:
    """
    Namespace for listing helpers that need more than a single SDK listing call.
    """

    @staticmethod
    def list_models() -> Iterator[ModelDatabricks]:
        """
        Lazily yield one ``ModelDatabricks`` object per model returned by the registry listing.

        Every listed model is looked up again by name via ``get_model`` and the
        ``registered_model_databricks`` part of that response is yielded.
        NOTE(review): presumably the plain listing does not expose the model ``id``
        that the permissions inventorizer keys on - confirm against the SDK.
        """
        for listed in provider.ws.model_registry.list_models():
            # One extra request per model to obtain the detailed representation.
            details = provider.ws.model_registry.get_model(listed.name)
            yield details.registered_model_databricks
13 changes: 13 additions & 0 deletions src/uc_migration_toolkit/managers/inventory/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from uc_migration_toolkit.managers.group import MigrationGroupsProvider
from uc_migration_toolkit.managers.inventory.inventorizer import StandardInventorizer
from uc_migration_toolkit.managers.inventory.listing import CustomListing
from uc_migration_toolkit.managers.inventory.table import InventoryTableManager
from uc_migration_toolkit.managers.inventory.types import (
LogicalObjectType,
Expand Down Expand Up @@ -70,6 +71,18 @@ def get_inventorizers():
listing_function=provider.ws.jobs.list,
id_attribute="job_id",
),
StandardInventorizer(
logical_object_type=LogicalObjectType.EXPERIMENT,
request_object_type=RequestObjectType.EXPERIMENTS,
listing_function=provider.ws.experiments.list_experiments,
id_attribute="experiment_id",
),
StandardInventorizer(
logical_object_type=LogicalObjectType.MODEL,
request_object_type=RequestObjectType.REGISTERED_MODELS,
listing_function=CustomListing.list_models,
id_attribute="id",
),
]

def inventorize_permissions(self):
Expand Down
2 changes: 2 additions & 0 deletions src/uc_migration_toolkit/managers/inventory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __repr__(self):


class LogicalObjectType(StrEnum):
MODEL = "MODEL"
EXPERIMENT = "EXPERIMENT"
JOB = "JOB"
PIPELINE = "PIPELINE"
CLUSTER = "CLUSTER"
Expand Down
73 changes: 72 additions & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
import pytest
from _pytest.fixtures import SubRequest
from databricks.sdk import AccountClient
from databricks.sdk.core import DatabricksError
from databricks.sdk.service.compute import (
ClusterDetails,
CreateInstancePoolResponse,
CreatePolicyResponse,
)
from databricks.sdk.service.iam import PermissionLevel
from databricks.sdk.service.jobs import CreateResponse
from databricks.sdk.service.ml import CreateExperimentResponse, ModelDatabricks
from databricks.sdk.service.ml import PermissionLevel as ModelPermissionLevel
from databricks.sdk.service.pipelines import (
CreatePipelineResponse,
NotebookLibrary,
Expand Down Expand Up @@ -50,6 +53,8 @@
# Environment variables are always strings, while the defaults below are ints.
# The experiment and model fixtures pass their counts to range(), which rejects
# a str, so every count is normalized to int here.
# NOTE(review): the remaining counts and NUM_THREADS are assumed to be consumed
# as integers as well - confirm against their call sites.
NUM_TEST_CLUSTER_POLICIES = int(os.environ.get("NUM_TEST_CLUSTER_POLICIES", 3))
NUM_TEST_PIPELINES = int(os.environ.get("NUM_TEST_PIPELINES", 3))
NUM_TEST_JOBS = int(os.environ.get("NUM_TEST_JOBS", 3))
NUM_TEST_EXPERIMENTS = int(os.environ.get("NUM_TEST_EXPERIMENTS", 3))
NUM_TEST_MODELS = int(os.environ.get("NUM_TEST_MODELS", 3))

NUM_THREADS = int(os.environ.get("NUM_TEST_THREADS", 20))
DB_CONNECT_CLUSTER_NAME = os.environ.get("DB_CONNECT_CLUSTER_NAME", "ucx-integration-testing")
Expand Down Expand Up @@ -332,16 +337,82 @@ def clusters(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[ClusterD
logger.debug("Test clusters deleted")


@pytest.fixture(scope="session", autouse=True)
def experiments(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[CreateExperimentResponse]:
    """
    Session-scoped fixture that creates the test MLflow experiments and deletes them on teardown.

    Creates ``NUM_TEST_EXPERIMENTS`` experiments named ``/experiments/{env.test_uid}-test-{i}``,
    passes them to ``_set_random_permissions`` with the CAN_MANAGE / CAN_READ / CAN_EDIT levels,
    and yields the list of creation responses. After the yield, every experiment is deleted by
    its ``experiment_id``.
    """
    logger.debug("Creating test experiments")

    try:
        ws.workspace.mkdirs("/experiments")
    except DatabricksError as e:
        # Deliberately best-effort: setup continues, but the failure is logged rather than
        # silently dropped, so a later experiment-creation error can be traced back to it.
        logger.debug(f"Could not create the /experiments directory, continuing anyway: {e}")

    test_experiments = [
        ws.experiments.create_experiment(name=f"/experiments/{env.test_uid}-test-{i}")
        for i in range(NUM_TEST_EXPERIMENTS)
    ]

    _set_random_permissions(
        test_experiments,
        "experiment_id",
        RequestObjectType.EXPERIMENTS,
        env,
        ws,
        permission_levels=[PermissionLevel.CAN_MANAGE, PermissionLevel.CAN_READ, PermissionLevel.CAN_EDIT],
    )

    yield test_experiments

    logger.debug("Deleting test experiments")
    executables = [partial(ws.experiments.delete_experiment, e.experiment_id) for e in test_experiments]
    Threader(executables).run()
    logger.debug("Test experiments deleted")


@pytest.fixture(scope="session", autouse=True)
def models(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[ModelDatabricks]:
    """
    Session-scoped fixture that creates the test registered models and deletes them on teardown.

    Creates ``NUM_TEST_MODELS`` models named ``{env.test_uid}-test-{i}``, passes them to
    ``_set_random_permissions`` with the model-specific permission levels, and yields the list
    of ``ModelDatabricks`` objects. After the yield, every model is deleted by name.
    """
    logger.debug("Creating models")

    test_models: list[ModelDatabricks] = []
    for i in range(NUM_TEST_MODELS):
        created = ws.model_registry.create_model(f"{env.test_uid}-test-{i}")
        # Re-fetch by name to get the ModelDatabricks representation, whose ``id`` attribute
        # is what the permission helpers below key on.
        # NOTE(review): presumably the creation response does not carry that id - confirm.
        fetched = ws.model_registry.get_model(created.registered_model.name)
        test_models.append(fetched.registered_model_databricks)

    _set_random_permissions(
        test_models,
        "id",
        RequestObjectType.REGISTERED_MODELS,
        env,
        ws,
        permission_levels=[
            ModelPermissionLevel.CAN_READ,
            ModelPermissionLevel.CAN_MANAGE,
            ModelPermissionLevel.CAN_MANAGE_PRODUCTION_VERSIONS,
            ModelPermissionLevel.CAN_MANAGE_STAGING_VERSIONS,
        ],
    )

    yield test_models

    logger.debug("Deleting test models")
    # Delete through the injected ``ws`` client, the same one that created the models,
    # instead of reaching for the global provider.
    executables = [partial(ws.model_registry.delete_model, m.name) for m in test_models]
    Threader(executables).run()
    logger.debug("Test models deleted")


@pytest.fixture(scope="session", autouse=True)
def verifiable_objects(
clusters, instance_pools, cluster_policies, pipelines, jobs
clusters, instance_pools, cluster_policies, pipelines, jobs, experiments, models
) -> tuple[list, str, RequestObjectType]:
_verifiable_objects = [
(clusters, "cluster_id", RequestObjectType.CLUSTERS),
(instance_pools, "instance_pool_id", RequestObjectType.INSTANCE_POOLS),
(cluster_policies, "policy_id", RequestObjectType.CLUSTER_POLICIES),
(pipelines, "pipeline_id", RequestObjectType.PIPELINES),
(jobs, "job_id", RequestObjectType.JOBS),
(experiments, "experiment_id", RequestObjectType.EXPERIMENTS),
(models, "id", RequestObjectType.REGISTERED_MODELS),
]
yield _verifiable_objects

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _verify_group_permissions(
toolkit: GroupMigrationToolkit,
target: Literal["backup", "account"],
):
logger.debug("Verifying that the permissions were applied to backup groups")
logger.debug(f"Verifying that the permissions of object {request_object_type} were applied to {target} groups")

for _object in objects:
_object_permissions = ws.permissions.get(request_object_type, getattr(_object, id_attribute))
Expand Down

0 comments on commit 6a18748

Please sign in to comment.