From 6692fe207365af14207547c52a0158cb6a35dd28 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Wed, 6 Dec 2023 08:27:48 +0530 Subject: [PATCH 01/18] init Signed-off-by: kalyanr --- .../ml_commons/ml_commons_client.py | 2 + opensearch_py_ml/ml_commons/model_profile.py | 38 +++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 opensearch_py_ml/ml_commons/model_profile.py diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 72e2e158b..a50a7d4bb 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -24,6 +24,7 @@ from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl from opensearch_py_ml.ml_commons.model_execute import ModelExecute from opensearch_py_ml.ml_commons.model_uploader import ModelUploader +from opensearch_py_ml.ml_commons.model_profile import ModelProfile class MLCommonClient: @@ -37,6 +38,7 @@ def __init__(self, os_client: OpenSearch): self._model_uploader = ModelUploader(os_client) self._model_execute = ModelExecute(os_client) self.model_access_control = ModelAccessControl(os_client) + self.profile = ModelProfile(os_client) def execute(self, algorithm_name: str, input_json: dict) -> dict: """ diff --git a/opensearch_py_ml/ml_commons/model_profile.py b/opensearch_py_ml/ml_commons/model_profile.py new file mode 100644 index 000000000..ae0fc770d --- /dev/null +++ b/opensearch_py_ml/ml_commons/model_profile.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. + +from opensearchpy import OpenSearch + +from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI + + +class ModelProfile: + API_ENDPOINT = "profile" + + def __init__(self, os_client: OpenSearch): + self.client = os_client + + def get_profile(self, payload: dict): + if not isinstance(payload, dict): + raise ValueError("payload needs to be a dictionary") + return self.client.transport.perform_request( + method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload + ) + + def get_models_profile(self, payload: dict): + if not isinstance(payload, dict): + raise ValueError("payload needs to be a dictionary") + return self.client.transport.perform_request( + method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/models", body=payload + ) + + def get_tasks_profile(self, payload: dict): + if not isinstance(payload, dict): + raise ValueError("payload needs to be a dictionary") + return self.client.transport.perform_request( + method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks", body=payload + ) \ No newline at end of file From 85e93925135db31a5985b7984e520da1077b237d Mon Sep 17 00:00:00 2001 From: kalyanr Date: Wed, 6 Dec 2023 08:32:30 +0530 Subject: [PATCH 02/18] update changelog Signed-off-by: kalyanr --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4eb78bf78..1d99fcb7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310)) - Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332)) - Add support for model connectors by @rawwar in ([#345](https://github.com/opensearch-project/opensearch-py-ml/pull/345)) +- Add support for model profiles by @rawwar in ([#350](https://github.com/opensearch-project/opensearch-py-ml/pull/350)) ### Changed - Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211)) From f8d68142650dd7bde12fa7965bc8b745598b3bfa Mon Sep 17 00:00:00 2001 From: kalyanr Date: Thu, 7 Dec 2023 05:06:30 +0530 Subject: [PATCH 03/18] update Signed-off-by: kalyanr --- .../ml_commons/ml_commons_client.py | 2 +- opensearch_py_ml/ml_commons/model_profile.py | 37 ++++++++++----- tests/ml_commons/test_ml_commons_client.py | 45 +++++++++++++++++++ 3 files changed, 71 insertions(+), 13 deletions(-) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 9d228c1a0..af4161b35 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -40,7 +40,7 @@ def __init__(self, os_client: OpenSearch): self._model_execute = ModelExecute(os_client) self.model_access_control = ModelAccessControl(os_client) self.connector = Connector(os_client) - self.profile = ModelProfile(os_client) + self.model_profile = ModelProfile(os_client) def execute(self, algorithm_name: str, input_json: dict) -> dict: """ diff --git a/opensearch_py_ml/ml_commons/model_profile.py b/opensearch_py_ml/ml_commons/model_profile.py index ae0fc770d..18eee2989 100644 --- a/opensearch_py_ml/ml_commons/model_profile.py +++ b/opensearch_py_ml/ml_commons/model_profile.py @@ -6,6 +6,7 @@ # GitHub history for details. from opensearchpy import OpenSearch +from typing import Optional from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI @@ -16,23 +17,35 @@ class ModelProfile: def __init__(self, os_client: OpenSearch): self.client = os_client - def get_profile(self, payload: dict): - if not isinstance(payload, dict): - raise ValueError("payload needs to be a dictionary") + def _validate_input(self, path_parameter, payload): + if path_parameter is not None and not isinstance(path_parameter, str): + raise ValueError("payload needs to be a dictionary or None") + + if payload is not None and not isinstance(payload, dict): + raise ValueError("path_parameter needs to be a string or None") + + def get_profile(self, payload: Optional[dict] = None): + if payload is not None and not isinstance(payload, dict): + raise ValueError("payload needs to be a dictionary or None") return self.client.transport.perform_request( method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload ) - def get_models_profile(self, payload: dict): - if not isinstance(payload, dict): - raise ValueError("payload needs to be a dictionary") + def get_models_profile(self, path_parameter: Optional[str]='', payload: Optional[dict] = None): + + self._validate_input(path_parameter, payload) + + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}" return self.client.transport.perform_request( - method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/models", body=payload + method="GET", url=url, body=payload ) - - def get_tasks_profile(self, payload: dict): - if not isinstance(payload, dict): - raise ValueError("payload needs to be a dictionary") + + + def get_tasks_profile(self, path_parameter: Optional[str]='', payload: Optional[dict] = None): + + self._validate_input(path_parameter, payload) + + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}" return self.client.transport.perform_request( - method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks", body=payload + method="GET", url=url, body=payload ) \ No newline at end of file diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 27cd79dc9..4995ef276 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -18,6 +18,7 @@ from opensearch_py_ml.ml_commons import MLCommonClient from opensearch_py_ml.ml_commons.model_uploader import ModelUploader +from opensearch_py_ml.ml_commons.model_profile import ModelProfile from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel from tests import OPENSEARCH_TEST_CLIENT @@ -573,3 +574,47 @@ def test_search(): except: # noqa: E722 raised = True assert raised == False, "Raised Exception in searching model" + +# Model Profile Tests. These will need some model train/predict data. Hence, need to be +# at the end after the training/prediction tests are done. + +def profile_client(): + client = ModelProfile(OPENSEARCH_TEST_CLIENT) + return client + +def test_get_profile(profile_client): + + with pytest.raises(ValueError): + profile_client.get_profile("") + + result = profile_client.get_profile() + assert isinstance(result, dict) + if len(result) > 0: + assert "nodes" in result + + +def test_get_models_profile(profile_client): + + with pytest.raises(ValueError): + profile_client.get_models_profile("") + + result = profile_client.get_models_profile() + assert isinstance(result, dict) + if len(result) > 0: + assert "nodes" in result + for node_id, node_val in result['nodes']: + assert "models" in node_val + + + +def test_get_tasks_profile(profile_client): + + with pytest.raises(ValueError): + profile_client.get_tasks_profile("") + + result = profile_client.get_tasks_profile() + if len(result) > 0: + assert "nodes" in result + for node_id, node_val in result['nodes']: + assert "tasks" in node_val + \ No newline at end of file From dbd776c1971576bff5ebed2e1784c58aaa3b3be4 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Thu, 7 Dec 2023 05:21:16 +0530 Subject: [PATCH 04/18] fix Signed-off-by: kalyanr --- tests/ml_commons/test_ml_commons_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 4995ef276..24980a327 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -602,7 +602,7 @@ def test_get_models_profile(profile_client): assert isinstance(result, dict) if len(result) > 0: assert "nodes" in result - for node_id, node_val in result['nodes']: + for _, node_val in result['nodes']: assert "models" in node_val @@ -615,6 +615,6 @@ def test_get_tasks_profile(profile_client): result = profile_client.get_tasks_profile() if len(result) > 0: assert "nodes" in result - for node_id, node_val in result['nodes']: + for _, node_val in result['nodes']: assert "tasks" in node_val \ No newline at end of file From bc941fa6fd5000cf14ddfc1847fe728a0a383817 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Thu, 7 Dec 2023 07:07:52 +0530 Subject: [PATCH 05/18] fix Signed-off-by: kalyanr --- tests/ml_commons/test_ml_commons_client.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 24980a327..170a3ea65 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -578,6 +578,7 @@ def test_search(): # Model Profile Tests. These will need some model train/predict data. Hence, need to be # at the end after the training/prediction tests are done. +@pytest.fixture def profile_client(): client = ModelProfile(OPENSEARCH_TEST_CLIENT) return client @@ -596,13 +597,16 @@ def test_get_profile(profile_client): def test_get_models_profile(profile_client): with pytest.raises(ValueError): - profile_client.get_models_profile("") + profile_client.get_models_profile(10) + + with pytest.raises(ValueError): + profile_client.get_models_profile("", 10) result = profile_client.get_models_profile() assert isinstance(result, dict) if len(result) > 0: assert "nodes" in result - for _, node_val in result['nodes']: + for _, node_val in result['nodes'].items(): assert "models" in node_val @@ -610,11 +614,14 @@ def test_get_models_profile(profile_client): def test_get_tasks_profile(profile_client): with pytest.raises(ValueError): - profile_client.get_tasks_profile("") + profile_client.get_tasks_profile(10) + + with pytest.raises(ValueError): + profile_client.get_tasks_profile("", 10) result = profile_client.get_tasks_profile() if len(result) > 0: assert "nodes" in result - for _, node_val in result['nodes']: + for _, node_val in result['nodes'].items(): assert "tasks" in node_val \ No newline at end of file From 1b81c8422f1c61151f864e3914130711fa4cd172 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Thu, 7 Dec 2023 07:08:55 +0530 Subject: [PATCH 06/18] lint fix Signed-off-by: kalyanr --- .../ml_commons/ml_commons_client.py | 2 +- opensearch_py_ml/ml_commons/model_profile.py | 28 ++++++++++--------- tests/ml_commons/test_ml_commons_client.py | 24 ++++++++-------- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index af4161b35..34529d059 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -24,8 +24,8 @@ from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl from opensearch_py_ml.ml_commons.model_connector import Connector from opensearch_py_ml.ml_commons.model_execute import ModelExecute -from opensearch_py_ml.ml_commons.model_uploader import ModelUploader from opensearch_py_ml.ml_commons.model_profile import ModelProfile +from opensearch_py_ml.ml_commons.model_uploader import ModelUploader class MLCommonClient: diff --git a/opensearch_py_ml/ml_commons/model_profile.py b/opensearch_py_ml/ml_commons/model_profile.py index 18eee2989..fecdddab0 100644 --- a/opensearch_py_ml/ml_commons/model_profile.py +++ b/opensearch_py_ml/ml_commons/model_profile.py @@ -5,47 +5,49 @@ # Any modifications Copyright OpenSearch Contributors. See # GitHub history for details. -from opensearchpy import OpenSearch from typing import Optional +from opensearchpy import OpenSearch + from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI class ModelProfile: API_ENDPOINT = "profile" - + def __init__(self, os_client: OpenSearch): self.client = os_client - + def _validate_input(self, path_parameter, payload): if path_parameter is not None and not isinstance(path_parameter, str): raise ValueError("payload needs to be a dictionary or None") if payload is not None and not isinstance(payload, dict): raise ValueError("path_parameter needs to be a string or None") - + def get_profile(self, payload: Optional[dict] = None): if payload is not None and not isinstance(payload, dict): raise ValueError("payload needs to be a dictionary or None") return self.client.transport.perform_request( method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload ) - - def get_models_profile(self, path_parameter: Optional[str]='', payload: Optional[dict] = None): - + + def get_models_profile( + self, path_parameter: Optional[str] = "", payload: Optional[dict] = None + ): self._validate_input(path_parameter, payload) - + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}" return self.client.transport.perform_request( method="GET", url=url, body=payload ) - - def get_tasks_profile(self, path_parameter: Optional[str]='', payload: Optional[dict] = None): - + def get_tasks_profile( + self, path_parameter: Optional[str] = "", payload: Optional[dict] = None + ): self._validate_input(path_parameter, payload) - + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}" return self.client.transport.perform_request( method="GET", url=url, body=payload - ) \ No newline at end of file + ) diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 170a3ea65..20193a03a 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -17,8 +17,8 @@ from sklearn.datasets import load_iris from opensearch_py_ml.ml_commons import MLCommonClient -from opensearch_py_ml.ml_commons.model_uploader import ModelUploader from opensearch_py_ml.ml_commons.model_profile import ModelProfile +from opensearch_py_ml.ml_commons.model_uploader import ModelUploader from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel from tests import OPENSEARCH_TEST_CLIENT @@ -575,19 +575,21 @@ def test_search(): raised = True assert raised == False, "Raised Exception in searching model" + # Model Profile Tests. These will need some model train/predict data. Hence, need to be # at the end after the training/prediction tests are done. + @pytest.fixture def profile_client(): client = ModelProfile(OPENSEARCH_TEST_CLIENT) return client + def test_get_profile(profile_client): - with pytest.raises(ValueError): profile_client.get_profile("") - + result = profile_client.get_profile() assert isinstance(result, dict) if len(result) > 0: @@ -595,33 +597,29 @@ def test_get_profile(profile_client): def test_get_models_profile(profile_client): - with pytest.raises(ValueError): profile_client.get_models_profile(10) - + with pytest.raises(ValueError): profile_client.get_models_profile("", 10) - + result = profile_client.get_models_profile() assert isinstance(result, dict) if len(result) > 0: assert "nodes" in result - for _, node_val in result['nodes'].items(): + for _, node_val in result["nodes"].items(): assert "models" in node_val - def test_get_tasks_profile(profile_client): - with pytest.raises(ValueError): profile_client.get_tasks_profile(10) - + with pytest.raises(ValueError): profile_client.get_tasks_profile("", 10) - + result = profile_client.get_tasks_profile() if len(result) > 0: assert "nodes" in result - for _, node_val in result['nodes'].items(): + for _, node_val in result["nodes"].items(): assert "tasks" in node_val - \ No newline at end of file From bd82bf1ebb7393981f973b0d301c3c264ca6e042 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Thu, 7 Dec 2023 07:17:08 +0530 Subject: [PATCH 07/18] reuse validate input Signed-off-by: kalyanr --- opensearch_py_ml/ml_commons/model_profile.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/opensearch_py_ml/ml_commons/model_profile.py b/opensearch_py_ml/ml_commons/model_profile.py index fecdddab0..d1b385491 100644 --- a/opensearch_py_ml/ml_commons/model_profile.py +++ b/opensearch_py_ml/ml_commons/model_profile.py @@ -20,14 +20,13 @@ def __init__(self, os_client: OpenSearch): def _validate_input(self, path_parameter, payload): if path_parameter is not None and not isinstance(path_parameter, str): - raise ValueError("payload needs to be a dictionary or None") - - if payload is not None and not isinstance(payload, dict): raise ValueError("path_parameter needs to be a string or None") - def get_profile(self, payload: Optional[dict] = None): if payload is not None and not isinstance(payload, dict): raise ValueError("payload needs to be a dictionary or None") + + def get_profile(self, payload: Optional[dict] = None): + self._validate_input(None, payload) return self.client.transport.perform_request( method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload ) From ad0077ad24245054102eeb578287c4feb6ec6f51 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Thu, 7 Dec 2023 09:39:47 +0530 Subject: [PATCH 08/18] update comment Signed-off-by: kalyanr --- tests/ml_commons/test_ml_commons_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 20193a03a..8469f2d3f 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -576,8 +576,8 @@ def test_search(): assert raised == False, "Raised Exception in searching model" -# Model Profile Tests. These will need some model train/predict data. Hence, need to be -# at the end after the training/prediction tests are done. +# Model Profile Tests. These tests will need some model train/predict run data. Hence, need +# to be run at the end after the training/prediction tests are done. @pytest.fixture From af4de7b40f88c2c45043a919c2a9edf9685858ec Mon Sep 17 00:00:00 2001 From: kalyanr Date: Sat, 16 Dec 2023 11:41:23 +0530 Subject: [PATCH 09/18] change Signed-off-by: kalyanr --- .../ml_commons/ml_commons_client.py | 39 +++++++++++++++++++ .../ml_commons/validators/profile.py | 16 ++++++++ 2 files changed, 55 insertions(+) create mode 100644 opensearch_py_ml/ml_commons/validators/profile.py diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 34529d059..138bd9c59 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -21,6 +21,7 @@ MODEL_VERSION_FIELD, TIMEOUT, ) +from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl from opensearch_py_ml.ml_commons.model_connector import Connector from opensearch_py_ml.ml_commons.model_execute import ModelExecute @@ -608,3 +609,41 @@ def delete_task(self, task_id: str) -> object: method="DELETE", url=API_URL, ) + + def _get_profile(self, payload: Optional[dict] = None): + validate_profile_input(None, payload) + return self.client.transport.perform_request( + method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload + ) + + def _get_models_profile( + self, path_parameter: Optional[str] = "", payload: Optional[dict] = None + ): + self._validate_input(path_parameter, payload) + + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}" + return self.client.transport.perform_request( + method="GET", url=url, body=payload + ) + + def _get_tasks_profile( + self, path_parameter: Optional[str] = "", payload: Optional[dict] = None + ): + self._validate_input(path_parameter, payload) + + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}" + return self.client.transport.perform_request( + method="GET", url=url, body=payload + ) + + def get_profile(self, profile_type="all", id=None, request_body=None): + if profile_type == "all": + return self._get_profile(request_body) + elif profile_type == "model": + return self._get_models_profile(id, request_body) + elif profile_type == "task": + pass + else: + raise ValueError( + "Invalid profile type. Profile type must be 'all', 'model' or 'task'." + ) diff --git a/opensearch_py_ml/ml_commons/validators/profile.py b/opensearch_py_ml/ml_commons/validators/profile.py new file mode 100644 index 000000000..602edb0f9 --- /dev/null +++ b/opensearch_py_ml/ml_commons/validators/profile.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. + +"""Module for validating Profile API parameters """ + + +def validate_profile_input(path_parameter, payload): + if path_parameter is not None and not isinstance(path_parameter, str): + raise ValueError("path_parameter needs to be a string or None") + + if payload is not None and not isinstance(payload, dict): + raise ValueError("payload needs to be a dictionary or None") From b0191a83aa718746a45fbdd7a4c351fedb24ab3e Mon Sep 17 00:00:00 2001 From: kalyanr Date: Sat, 16 Dec 2023 12:19:10 +0530 Subject: [PATCH 10/18] fix Signed-off-by: kalyanr --- .../ml_commons/ml_commons_client.py | 84 +++++++++++++++---- 1 file changed, 67 insertions(+), 17 deletions(-) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 138bd9c59..2ce4d7bda 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -611,39 +611,89 @@ def delete_task(self, task_id: str) -> object: ) def _get_profile(self, payload: Optional[dict] = None): + """ + Get the profile using the given payload. + + :param payload: The payload to be used for getting the profile. Defaults to None. + :type payload: Optional[dict] + :return: The response from the server after performing the request. + :rtype: Any + """ validate_profile_input(None, payload) return self.client.transport.perform_request( method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload ) def _get_models_profile( - self, path_parameter: Optional[str] = "", payload: Optional[dict] = None + self, model_id: Optional[str] = "", payload: Optional[dict] = None ): - self._validate_input(path_parameter, payload) + """ + Get the profile of a model. + + Args: + model_id (str, optional): The ID of the model. Defaults to "". + payload (dict, optional): Additional payload for the request. Defaults to None. - url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}" + Returns: + dict: The response from the API. + """ + self._validate_input(model_id, payload) + + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{model_id if model_id else ''}" return self.client.transport.perform_request( method="GET", url=url, body=payload ) def _get_tasks_profile( - self, path_parameter: Optional[str] = "", payload: Optional[dict] = None + self, task_id: Optional[str] = "", payload: Optional[dict] = None ): - self._validate_input(path_parameter, payload) + """ + Retrieves the profile of a task from the API. + + Parameters: + task_id (str, optional): The ID of the task to retrieve the profile for. Defaults to an empty string. + payload (dict, optional): Additional payload for the request. Defaults to None. + + Returns: + dict: The profile of the task. - url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}" + Raises: + ValueError: If the input validation fails. + + """ + self._validate_input(task_id, payload) + + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{task_id if task_id else ''}" return self.client.transport.perform_request( method="GET", url=url, body=payload ) - def get_profile(self, profile_type="all", id=None, request_body=None): - if profile_type == "all": - return self._get_profile(request_body) - elif profile_type == "model": - return self._get_models_profile(id, request_body) - elif profile_type == "task": - pass - else: - raise ValueError( - "Invalid profile type. Profile type must be 'all', 'model' or 'task'." - ) +def get_profile(self, profile_type: str = "all", ids: Optional[List[str]] = None, request_body: Optional[dict] = None) -> dict: + """ + Get profile information based on the profile type. + + Args: + profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'. + ids: A list of profile IDs to retrieve. Default is None. + request_body: The request body containing additional information. Default is None. + + Returns: + The profile information. + + Raises: + ValueError: If the profile_type is not 'all', 'model', or 'task'. + """ + if profile_type == "all": + return self._get_profile(request_body) + elif profile_type == "model": + if ids: + ids = ",".join(ids) + return self._get_models_profile(ids, request_body) + elif profile_type == "task": + if ids: + ids = ",".join(ids) + return self._get_tasks_profile(ids, request_body) + else: + raise ValueError( + "Invalid profile type. Profile type must be 'all', 'model' or 'task'." + ) From dc19f2450ca5376ff396a91ef592e671392c4c20 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Sat, 16 Dec 2023 12:20:55 +0530 Subject: [PATCH 11/18] update changelog Signed-off-by: kalyanr --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d99fcb7f..86f270110 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310)) - Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332)) - Add support for model connectors by @rawwar in ([#345](https://github.com/opensearch-project/opensearch-py-ml/pull/345)) -- Add support for model profiles by @rawwar in ([#350](https://github.com/opensearch-project/opensearch-py-ml/pull/350)) +- Add support for model profiles by @rawwar in ([#358](https://github.com/opensearch-project/opensearch-py-ml/pull/358)) ### Changed - Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211)) From 0657f74f5200040c75398ca47c0d078ba4ca33d1 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Sat, 16 Dec 2023 12:21:46 +0530 Subject: [PATCH 12/18] fix Signed-off-by: kalyanr --- .../ml_commons/ml_commons_client.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 2ce4d7bda..8104e4527 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -668,32 +668,32 @@ def _get_tasks_profile( method="GET", url=url, body=payload ) -def get_profile(self, profile_type: str = "all", ids: Optional[List[str]] = None, request_body: Optional[dict] = None) -> dict: - """ - Get profile information based on the profile type. + def get_profile(self, profile_type: str = "all", ids: Optional[List[str]] = None, request_body: Optional[dict] = None) -> dict: + """ + Get profile information based on the profile type. - Args: - profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'. - ids: A list of profile IDs to retrieve. Default is None. - request_body: The request body containing additional information. Default is None. + Args: + profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'. + ids: A list of profile IDs to retrieve. Default is None. + request_body: The request body containing additional information. Default is None. - Returns: - The profile information. + Returns: + The profile information. - Raises: - ValueError: If the profile_type is not 'all', 'model', or 'task'. - """ - if profile_type == "all": - return self._get_profile(request_body) - elif profile_type == "model": - if ids: - ids = ",".join(ids) - return self._get_models_profile(ids, request_body) - elif profile_type == "task": - if ids: - ids = ",".join(ids) - return self._get_tasks_profile(ids, request_body) - else: - raise ValueError( - "Invalid profile type. Profile type must be 'all', 'model' or 'task'." - ) + Raises: + ValueError: If the profile_type is not 'all', 'model', or 'task'. + """ + if profile_type == "all": + return self._get_profile(request_body) + elif profile_type == "model": + if ids: + ids = ",".join(ids) + return self._get_models_profile(ids, request_body) + elif profile_type == "task": + if ids: + ids = ",".join(ids) + return self._get_tasks_profile(ids, request_body) + else: + raise ValueError( + "Invalid profile type. Profile type must be 'all', 'model' or 'task'." + ) From 2120e442da6fe85150df296c5827cebf49d3ddb4 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Tue, 19 Dec 2023 21:50:11 +0530 Subject: [PATCH 13/18] remove separate model profile module Signed-off-by: kalyanr --- .../ml_commons/ml_commons_client.py | 1 - opensearch_py_ml/ml_commons/model_profile.py | 52 ------------------- 2 files changed, 53 deletions(-) delete mode 100644 opensearch_py_ml/ml_commons/model_profile.py diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 8104e4527..b70b551a4 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -41,7 +41,6 @@ def __init__(self, os_client: OpenSearch): self._model_execute = ModelExecute(os_client) self.model_access_control = ModelAccessControl(os_client) self.connector = Connector(os_client) - self.model_profile = ModelProfile(os_client) def execute(self, algorithm_name: str, input_json: dict) -> dict: """ diff --git a/opensearch_py_ml/ml_commons/model_profile.py b/opensearch_py_ml/ml_commons/model_profile.py deleted file mode 100644 index d1b385491..000000000 --- a/opensearch_py_ml/ml_commons/model_profile.py +++ /dev/null @@ -1,52 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# The OpenSearch Contributors require contributions made to -# this file be licensed under the Apache-2.0 license or a -# compatible open source license. -# Any modifications Copyright OpenSearch Contributors. See -# GitHub history for details. - -from typing import Optional - -from opensearchpy import OpenSearch - -from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI - - -class ModelProfile: - API_ENDPOINT = "profile" - - def __init__(self, os_client: OpenSearch): - self.client = os_client - - def _validate_input(self, path_parameter, payload): - if path_parameter is not None and not isinstance(path_parameter, str): - raise ValueError("path_parameter needs to be a string or None") - - if payload is not None and not isinstance(payload, dict): - raise ValueError("payload needs to be a dictionary or None") - - def get_profile(self, payload: Optional[dict] = None): - self._validate_input(None, payload) - return self.client.transport.perform_request( - method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload - ) - - def get_models_profile( - self, path_parameter: Optional[str] = "", payload: Optional[dict] = None - ): - self._validate_input(path_parameter, payload) - - url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}" - return self.client.transport.perform_request( - method="GET", url=url, body=payload - ) - - def get_tasks_profile( - self, path_parameter: Optional[str] = "", payload: Optional[dict] = None - ): - self._validate_input(path_parameter, payload) - - url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}" - return self.client.transport.perform_request( - method="GET", url=url, body=payload - ) From 4a3101465ee6e4e7a8cfed5ab276d67e2757310e Mon Sep 17 00:00:00 2001 From: kalyanr Date: Wed, 20 Dec 2023 00:09:25 +0530 Subject: [PATCH 14/18] fix tests Signed-off-by: kalyanr --- .../ml_commons/ml_commons_client.py | 24 ++--- tests/ml_commons/test_ml_commons_client.py | 95 +++++++++++-------- 2 files changed, 68 insertions(+), 51 deletions(-) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index b70b551a4..9073c4a65 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -25,7 +25,6 @@ from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl from opensearch_py_ml.ml_commons.model_connector import Connector from opensearch_py_ml.ml_commons.model_execute import ModelExecute -from opensearch_py_ml.ml_commons.model_profile import ModelProfile from opensearch_py_ml.ml_commons.model_uploader import ModelUploader @@ -619,8 +618,8 @@ def _get_profile(self, payload: Optional[dict] = None): :rtype: Any """ validate_profile_input(None, payload) - return self.client.transport.perform_request( - method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload + return self._client.transport.perform_request( + method="GET", url=f"{ML_BASE_URI}/profile", body=payload ) def _get_models_profile( @@ -636,10 +635,10 @@ def _get_models_profile( Returns: dict: The response from the API. """ - self._validate_input(model_id, payload) + validate_profile_input(model_id, payload) - url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{model_id if model_id else ''}" - return self.client.transport.perform_request( + url = f"{ML_BASE_URI}/profile/models/{model_id if model_id else ''}" + return self._client.transport.perform_request( method="GET", url=url, body=payload ) @@ -660,10 +659,10 @@ def _get_tasks_profile( ValueError: If the input validation fails. """ - self._validate_input(task_id, payload) + validate_profile_input(task_id, payload) - url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{task_id if task_id else ''}" - return self.client.transport.perform_request( + url = f"{ML_BASE_URI}/profile/tasks/{task_id if task_id else ''}" + return self._client.transport.perform_request( method="GET", url=url, body=payload ) @@ -682,14 +681,15 @@ def get_profile(self, profile_type: str = "all", ids: Optional[List[str]] = None Raises: ValueError: If the profile_type is not 'all', 'model', or 'task'. """ + if profile_type == "all": return self._get_profile(request_body) elif profile_type == "model": - if ids: - ids = ",".join(ids) + if ids and isinstance(ids, list): + ids = ",".join(ids) return self._get_models_profile(ids, request_body) elif profile_type == "task": - if ids: + if ids and isinstance(ids, list): ids = ",".join(ids) return self._get_tasks_profile(ids, request_body) else: diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 8469f2d3f..b91b7c87b 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -17,7 +17,6 @@ from sklearn.datasets import load_iris from opensearch_py_ml.ml_commons import MLCommonClient -from opensearch_py_ml.ml_commons.model_profile import ModelProfile from opensearch_py_ml.ml_commons.model_uploader import ModelUploader from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel from tests import OPENSEARCH_TEST_CLIENT @@ -580,46 +579,64 @@ def test_search(): # to be run at the end after the training/prediction tests are done. -@pytest.fixture -def profile_client(): - client = ModelProfile(OPENSEARCH_TEST_CLIENT) - return client - - -def test_get_profile(profile_client): +def test_get_profile(): + res = ml_client.get_profile() + assert isinstance(res, dict) + assert "nodes" in res + test_model_id = None + test_task_id = None + for node_id, val in res['nodes'].items(): + if test_model_id is None and "models" in val: + for model_id, model_val in val['models'].items(): + test_model_id = {"node_id":node_id, "model_id":model_id} + break + if test_task_id is None and "tasks" in val: + for task_id, task_val in val['tasks'].items(): + test_task_id = {"node_id":node_id, "task_id":task_id} + break + + res = ml_client.get_profile(profile_type='model') + assert isinstance(res, dict) + assert "nodes" in res + for node_id, node_val in res['nodes'].items(): + assert "models" in node_val + + + res = ml_client.get_profile(profile_type='model', ids=[test_model_id['model_id']]) + assert isinstance(res, dict) + assert "nodes" in res + assert test_model_id['model_id'] in res['nodes'][test_model_id["node_id"]]['models'] + + res = ml_client.get_profile(profile_type='model', ids=['randomid1', 'random_id2']) + assert isinstance(res, dict) + assert len(res) == 0 + + + res = ml_client.get_profile(profile_type='task') + assert isinstance(res, dict) + if len(res) > 0: + assert "nodes" in res + for node_id, node_val in res['nodes'].items(): + assert "tasks" in node_val + + res = ml_client.get_profile(profile_type='task', ids=['random1', 'random2']) + assert isinstance(res, dict) + assert len(res) == 0 + + with pytest.raises(ValueError): - profile_client.get_profile("") - - result = profile_client.get_profile() - assert isinstance(result, dict) - if len(result) > 0: - assert "nodes" in result - - -def test_get_models_profile(profile_client): + ml_client.get_profile(profile_type='test') + with pytest.raises(ValueError): - profile_client.get_models_profile(10) - + ml_client.get_profile(profile_type='model', ids=1) + with pytest.raises(ValueError): - profile_client.get_models_profile("", 10) - - result = profile_client.get_models_profile() - assert isinstance(result, dict) - if len(result) > 0: - assert "nodes" in result - for _, node_val in result["nodes"].items(): - assert "models" in node_val - - -def test_get_tasks_profile(profile_client): + ml_client.get_profile(profile_type='model', request_body=10) + with pytest.raises(ValueError): - profile_client.get_tasks_profile(10) - + ml_client.get_profile(profile_type='task', ids=1) + with pytest.raises(ValueError): - profile_client.get_tasks_profile("", 10) - - result = profile_client.get_tasks_profile() - if len(result) > 0: - assert "nodes" in result - for _, node_val in result["nodes"].items(): - assert "tasks" in node_val + ml_client.get_profile(profile_type='task', request_body=10) + + \ No newline at end of file From bcbcb4ba0fcb5d5150d720682cc304eb30feabc2 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Wed, 20 Dec 2023 00:10:02 +0530 Subject: [PATCH 15/18] fix lint Signed-off-by: kalyanr --- .../ml_commons/ml_commons_client.py | 11 +++- tests/ml_commons/test_ml_commons_client.py | 63 +++++++++---------- 2 files changed, 37 insertions(+), 37 deletions(-) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 9073c4a65..81d70116b 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -666,7 +666,12 @@ def _get_tasks_profile( method="GET", url=url, body=payload ) - def get_profile(self, profile_type: str = "all", ids: Optional[List[str]] = None, request_body: Optional[dict] = None) -> dict: + def get_profile( + self, + profile_type: str = "all", + ids: Optional[List[str]] = None, + request_body: Optional[dict] = None, + ) -> dict: """ Get profile information based on the profile type. @@ -681,12 +686,12 @@ def get_profile(self, profile_type: str = "all", ids: Optional[List[str]] = None Raises: ValueError: If the profile_type is not 'all', 'model', or 'task'. """ - + if profile_type == "all": return self._get_profile(request_body) elif profile_type == "model": if ids and isinstance(ids, list): - ids = ",".join(ids) + ids = ",".join(ids) return self._get_models_profile(ids, request_body) elif profile_type == "task": if ids and isinstance(ids, list): diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index b91b7c87b..bb1adcde1 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -585,58 +585,53 @@ def test_get_profile(): assert "nodes" in res test_model_id = None test_task_id = None - for node_id, val in res['nodes'].items(): + for node_id, val in res["nodes"].items(): if test_model_id is None and "models" in val: - for model_id, model_val in val['models'].items(): - test_model_id = {"node_id":node_id, "model_id":model_id} + for model_id, model_val in val["models"].items(): + test_model_id = {"node_id": node_id, "model_id": model_id} break if test_task_id is None and "tasks" in val: - for task_id, task_val in val['tasks'].items(): - test_task_id = {"node_id":node_id, "task_id":task_id} + for task_id, task_val in val["tasks"].items(): + test_task_id = {"node_id": node_id, "task_id": task_id} break - - res = ml_client.get_profile(profile_type='model') + + res = ml_client.get_profile(profile_type="model") assert isinstance(res, dict) assert "nodes" in res - for node_id, node_val in res['nodes'].items(): - assert "models" in node_val - - - res = ml_client.get_profile(profile_type='model', ids=[test_model_id['model_id']]) + for node_id, node_val in res["nodes"].items(): + assert "models" in node_val + + res = ml_client.get_profile(profile_type="model", ids=[test_model_id["model_id"]]) assert isinstance(res, dict) assert "nodes" in res - assert test_model_id['model_id'] in res['nodes'][test_model_id["node_id"]]['models'] - - res = ml_client.get_profile(profile_type='model', ids=['randomid1', 'random_id2']) + assert test_model_id["model_id"] in res["nodes"][test_model_id["node_id"]]["models"] + + res = ml_client.get_profile(profile_type="model", ids=["randomid1", "random_id2"]) assert isinstance(res, dict) assert len(res) == 0 - - - res = ml_client.get_profile(profile_type='task') + + res = ml_client.get_profile(profile_type="task") assert isinstance(res, dict) if len(res) > 0: assert "nodes" in res - for node_id, node_val in res['nodes'].items(): + for node_id, node_val in res["nodes"].items(): assert "tasks" in node_val - - res = ml_client.get_profile(profile_type='task', ids=['random1', 'random2']) + + res = ml_client.get_profile(profile_type="task", ids=["random1", "random2"]) assert isinstance(res, dict) assert len(res) == 0 - - + with pytest.raises(ValueError): - ml_client.get_profile(profile_type='test') - + ml_client.get_profile(profile_type="test") + with pytest.raises(ValueError): - ml_client.get_profile(profile_type='model', ids=1) - + ml_client.get_profile(profile_type="model", ids=1) + with pytest.raises(ValueError): - ml_client.get_profile(profile_type='model', request_body=10) - + ml_client.get_profile(profile_type="model", request_body=10) + with pytest.raises(ValueError): - ml_client.get_profile(profile_type='task', ids=1) - + ml_client.get_profile(profile_type="task", ids=1) + with pytest.raises(ValueError): - ml_client.get_profile(profile_type='task', request_body=10) - - \ No newline at end of file + ml_client.get_profile(profile_type="task", request_body=10) From fefca752d9cf250d06f67d16fe34d5dda7a8a772 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Wed, 20 Dec 2023 00:10:28 +0530 Subject: [PATCH 16/18] fix lint Signed-off-by: kalyanr --- opensearch_py_ml/ml_commons/ml_commons_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 81d70116b..4991583cd 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -21,11 +21,11 @@ MODEL_VERSION_FIELD, TIMEOUT, ) -from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl from opensearch_py_ml.ml_commons.model_connector import Connector from opensearch_py_ml.ml_commons.model_execute import ModelExecute from opensearch_py_ml.ml_commons.model_uploader import ModelUploader +from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input class MLCommonClient: From 0d7b0966f510afedf02e6045fb73178c5689f99e Mon Sep 17 00:00:00 2001 From: kalyanr Date: Wed, 20 Dec 2023 00:33:00 +0530 Subject: [PATCH 17/18] fix Signed-off-by: kalyanr --- opensearch_py_ml/ml_commons/ml_commons_client.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 4991583cd..a82971b54 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -669,7 +669,7 @@ def _get_tasks_profile( def get_profile( self, profile_type: str = "all", - ids: Optional[List[str]] = None, + ids: Optional[Union[str, List[str]]] = None, request_body: Optional[dict] = None, ) -> dict: """ @@ -677,7 +677,7 @@ def get_profile( Args: profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'. - ids: A list of profile IDs to retrieve. Default is None. + ids: Either a single profile ID as a string, or a list of profile IDs to retrieve. Default is None. request_body: The request body containing additional information. Default is None. Returns: @@ -685,6 +685,17 @@ def get_profile( Raises: ValueError: If the profile_type is not 'all', 'model', or 'task'. + + Example: + get_profile() + + get_profile(profile_type='model', ids='model1') + + get_profile(profile_type='model', ids=['model1', 'model2']) + + get_profile(profile_type='task', ids='task1', request_body={"node_ids": ["KzONM8c8T4Od-NoUANQNGg"],"return_all_tasks": true,"return_all_models": true}) + + get_profile(profile_type='task', ids=['task1', 'task2'], request_body={'additional': 'info'}) """ if profile_type == "all": From a393c597d6f24882ca78108debacfd2c3ea120a5 Mon Sep 17 00:00:00 2001 From: Kalyan Date: Thu, 21 Dec 2023 13:31:10 +0530 Subject: [PATCH 18/18] Update ml_commons_client.py Signed-off-by: Kalyan --- opensearch_py_ml/ml_commons/ml_commons_client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index a82971b54..1f81c967c 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -677,6 +677,9 @@ def get_profile( Args: profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'. + 'all': Retrieves all profiles available. + 'model': Retrieves the profile(s) of the specified model(s). The model(s) to retrieve are specified by the 'ids' parameter. + 'task': Retrieves the profile(s) of the specified task(s). The task(s) to retrieve are specified by the 'ids' parameter. ids: Either a single profile ID as a string, or a list of profile IDs to retrieve. Default is None. request_body: The request body containing additional information. Default is None.