Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#285: Add support for Model profile #358

Merged
merged 20 commits into from
Jan 3, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))
- Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332))
- Add support for model connectors by @rawwar in ([#345](https://github.com/opensearch-project/opensearch-py-ml/pull/345))
- Add support for model profiles by @rawwar in ([#358](https://github.com/opensearch-project/opensearch-py-ml/pull/358))

### Changed
- Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
Expand Down
91 changes: 91 additions & 0 deletions opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
MODEL_VERSION_FIELD,
TIMEOUT,
)
from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input
from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl
from opensearch_py_ml.ml_commons.model_connector import Connector
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
from opensearch_py_ml.ml_commons.model_profile import ModelProfile
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader


Expand All @@ -39,6 +41,7 @@
self._model_execute = ModelExecute(os_client)
self.model_access_control = ModelAccessControl(os_client)
self.connector = Connector(os_client)
self.model_profile = ModelProfile(os_client)

def execute(self, algorithm_name: str, input_json: dict) -> dict:
"""
Expand Down Expand Up @@ -606,3 +609,91 @@
method="DELETE",
url=API_URL,
)

def _get_profile(self, payload: Optional[dict] = None):
"""
Get the profile using the given payload.

:param payload: The payload to be used for getting the profile. Defaults to None.
:type payload: Optional[dict]
:return: The response from the server after performing the request.
:rtype: Any
"""
validate_profile_input(None, payload)
return self.client.transport.perform_request(

Check warning on line 623 in opensearch_py_ml/ml_commons/ml_commons_client.py

View check run for this annotation

Codecov / codecov/patch

opensearch_py_ml/ml_commons/ml_commons_client.py#L622-L623

Added lines #L622 - L623 were not covered by tests
method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload
)

def _get_models_profile(
self, model_id: Optional[str] = "", payload: Optional[dict] = None
):
"""
Get the profile of a model.

Args:
model_id (str, optional): The ID of the model. Defaults to "".
payload (dict, optional): Additional payload for the request. Defaults to None.

Returns:
dict: The response from the API.
"""
self._validate_input(model_id, payload)

Check warning on line 640 in opensearch_py_ml/ml_commons/ml_commons_client.py

View check run for this annotation

Codecov / codecov/patch

opensearch_py_ml/ml_commons/ml_commons_client.py#L640

Added line #L640 was not covered by tests

url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{model_id if model_id else ''}"
return self.client.transport.perform_request(

Check warning on line 643 in opensearch_py_ml/ml_commons/ml_commons_client.py

View check run for this annotation

Codecov / codecov/patch

opensearch_py_ml/ml_commons/ml_commons_client.py#L642-L643

Added lines #L642 - L643 were not covered by tests
method="GET", url=url, body=payload
)

def _get_tasks_profile(
self, task_id: Optional[str] = "", payload: Optional[dict] = None
):
"""
Retrieves the profile of a task from the API.

Parameters:
task_id (str, optional): The ID of the task to retrieve the profile for. Defaults to an empty string.
payload (dict, optional): Additional payload for the request. Defaults to None.

Returns:
dict: The profile of the task.

Raises:
ValueError: If the input validation fails.

"""
self._validate_input(task_id, payload)

Check warning on line 664 in opensearch_py_ml/ml_commons/ml_commons_client.py

View check run for this annotation

Codecov / codecov/patch

opensearch_py_ml/ml_commons/ml_commons_client.py#L664

Added line #L664 was not covered by tests

url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{task_id if task_id else ''}"
return self.client.transport.perform_request(

Check warning on line 667 in opensearch_py_ml/ml_commons/ml_commons_client.py

View check run for this annotation

Codecov / codecov/patch

opensearch_py_ml/ml_commons/ml_commons_client.py#L666-L667

Added lines #L666 - L667 were not covered by tests
method="GET", url=url, body=payload
)

def get_profile(self, profile_type: str = "all", ids: Optional[List[str]] = None, request_body: Optional[dict] = None) -> dict:
rawwar marked this conversation as resolved.
Show resolved Hide resolved
"""
Get profile information based on the profile type.

Args:
profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'.
rawwar marked this conversation as resolved.
Show resolved Hide resolved
ids: A list of profile IDs to retrieve. Default is None.
request_body: The request body containing additional information. Default is None.

Returns:
The profile information.

Raises:
ValueError: If the profile_type is not 'all', 'model', or 'task'.
"""
if profile_type == "all":
return self._get_profile(request_body)
elif profile_type == "model":
if ids:
ids = ",".join(ids)
return self._get_models_profile(ids, request_body)
elif profile_type == "task":
if ids:
ids = ",".join(ids)
return self._get_tasks_profile(ids, request_body)

Check warning on line 695 in opensearch_py_ml/ml_commons/ml_commons_client.py

View check run for this annotation

Codecov / codecov/patch

opensearch_py_ml/ml_commons/ml_commons_client.py#L686-L695

Added lines #L686 - L695 were not covered by tests
else:
raise ValueError(

Check warning on line 697 in opensearch_py_ml/ml_commons/ml_commons_client.py

View check run for this annotation

Codecov / codecov/patch

opensearch_py_ml/ml_commons/ml_commons_client.py#L697

Added line #L697 was not covered by tests
"Invalid profile type. Profile type must be 'all', 'model' or 'task'."
)
52 changes: 52 additions & 0 deletions opensearch_py_ml/ml_commons/model_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

from typing import Optional

from opensearchpy import OpenSearch

from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI


class ModelProfile:
API_ENDPOINT = "profile"

def __init__(self, os_client: OpenSearch):
self.client = os_client

def _validate_input(self, path_parameter, payload):
if path_parameter is not None and not isinstance(path_parameter, str):
raise ValueError("path_parameter needs to be a string or None")

if payload is not None and not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary or None")

def get_profile(self, payload: Optional[dict] = None):
self._validate_input(None, payload)
return self.client.transport.perform_request(
method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload
)

def get_models_profile(
self, path_parameter: Optional[str] = "", payload: Optional[dict] = None
):
self._validate_input(path_parameter, payload)

url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}"
return self.client.transport.perform_request(
method="GET", url=url, body=payload
)

def get_tasks_profile(
self, path_parameter: Optional[str] = "", payload: Optional[dict] = None
):
self._validate_input(path_parameter, payload)

url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}"
return self.client.transport.perform_request(
method="GET", url=url, body=payload
)
16 changes: 16 additions & 0 deletions opensearch_py_ml/ml_commons/validators/profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

"""Module for validating Profile API parameters """


def validate_profile_input(path_parameter, payload):
if path_parameter is not None and not isinstance(path_parameter, str):
raise ValueError("path_parameter needs to be a string or None")

Check warning on line 13 in opensearch_py_ml/ml_commons/validators/profile.py

View check run for this annotation

Codecov / codecov/patch

opensearch_py_ml/ml_commons/validators/profile.py#L12-L13

Added lines #L12 - L13 were not covered by tests

if payload is not None and not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary or None")

Check warning on line 16 in opensearch_py_ml/ml_commons/validators/profile.py

View check run for this annotation

Codecov / codecov/patch

opensearch_py_ml/ml_commons/validators/profile.py#L15-L16

Added lines #L15 - L16 were not covered by tests
50 changes: 50 additions & 0 deletions tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sklearn.datasets import load_iris

from opensearch_py_ml.ml_commons import MLCommonClient
from opensearch_py_ml.ml_commons.model_profile import ModelProfile
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel
from tests import OPENSEARCH_TEST_CLIENT
Expand Down Expand Up @@ -573,3 +574,52 @@ def test_search():
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching model"


# Model Profile Tests. These tests will need some model train/predict run data. Hence, need
# to be run at the end after the training/prediction tests are done.


@pytest.fixture
def profile_client():
client = ModelProfile(OPENSEARCH_TEST_CLIENT)
return client


def test_get_profile(profile_client):
with pytest.raises(ValueError):
profile_client.get_profile("")

result = profile_client.get_profile()
assert isinstance(result, dict)
if len(result) > 0:
assert "nodes" in result


def test_get_models_profile(profile_client):
with pytest.raises(ValueError):
profile_client.get_models_profile(10)

with pytest.raises(ValueError):
profile_client.get_models_profile("", 10)

result = profile_client.get_models_profile()
assert isinstance(result, dict)
if len(result) > 0:
assert "nodes" in result
for _, node_val in result["nodes"].items():
assert "models" in node_val


def test_get_tasks_profile(profile_client):
with pytest.raises(ValueError):
profile_client.get_tasks_profile(10)

with pytest.raises(ValueError):
profile_client.get_tasks_profile("", 10)

result = profile_client.get_tasks_profile()
if len(result) > 0:
assert "nodes" in result
for _, node_val in result["nodes"].items():
assert "tasks" in node_val
Loading