Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: kalyanr <[email protected]>
  • Loading branch information
rawwar committed Dec 6, 2023
1 parent 85e9392 commit f8d6814
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 13 deletions.
2 changes: 1 addition & 1 deletion opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, os_client: OpenSearch):
self._model_execute = ModelExecute(os_client)
self.model_access_control = ModelAccessControl(os_client)
self.connector = Connector(os_client)
self.profile = ModelProfile(os_client)
self.model_profile = ModelProfile(os_client)

def execute(self, algorithm_name: str, input_json: dict) -> dict:
"""
Expand Down
37 changes: 25 additions & 12 deletions opensearch_py_ml/ml_commons/model_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# GitHub history for details.

from opensearchpy import OpenSearch
from typing import Optional

from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI

Expand All @@ -16,23 +17,35 @@ class ModelProfile:
def __init__(self, os_client: OpenSearch):
self.client = os_client

def get_profile(self, payload: dict):
if not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary")
def _validate_input(self, path_parameter, payload):
if path_parameter is not None and not isinstance(path_parameter, str):
raise ValueError("payload needs to be a dictionary or None")

if payload is not None and not isinstance(payload, dict):
raise ValueError("path_parameter needs to be a string or None")

def get_profile(self, payload: Optional[dict] = None):
if payload is not None and not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary or None")
return self.client.transport.perform_request(
method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload
)

def get_models_profile(self, payload: dict):
if not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary")
def get_models_profile(self, path_parameter: Optional[str]='', payload: Optional[dict] = None):

self._validate_input(path_parameter, payload)

url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}"
return self.client.transport.perform_request(
method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/models", body=payload
method="GET", url=url, body=payload
)

def get_tasks_profile(self, payload: dict):
if not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary")


def get_tasks_profile(self, path_parameter: Optional[str]='', payload: Optional[dict] = None):

self._validate_input(path_parameter, payload)

url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}"
return self.client.transport.perform_request(
method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks", body=payload
method="GET", url=url, body=payload
)
45 changes: 45 additions & 0 deletions tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from opensearch_py_ml.ml_commons import MLCommonClient
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
from opensearch_py_ml.ml_commons.model_profile import ModelProfile
from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel
from tests import OPENSEARCH_TEST_CLIENT

Expand Down Expand Up @@ -573,3 +574,47 @@ def test_search():
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching model"

# Model Profile Tests. These will need some model train/predict data. Hence, need to be
# at the end after the training/prediction tests are done.

def profile_client():
client = ModelProfile(OPENSEARCH_TEST_CLIENT)
return client

def test_get_profile(profile_client):

with pytest.raises(ValueError):
profile_client.get_profile("")

result = profile_client.get_profile()
assert isinstance(result, dict)
if len(result) > 0:
assert "nodes" in result


def test_get_models_profile(profile_client):

with pytest.raises(ValueError):
profile_client.get_models_profile("")

result = profile_client.get_models_profile()
assert isinstance(result, dict)
if len(result) > 0:
assert "nodes" in result
for node_id, node_val in result['nodes']:
assert "models" in node_val



def test_get_tasks_profile(profile_client):

with pytest.raises(ValueError):
profile_client.get_tasks_profile("")

result = profile_client.get_tasks_profile()
if len(result) > 0:
assert "nodes" in result
for node_id, node_val in result['nodes']:
assert "tasks" in node_val

0 comments on commit f8d6814

Please sign in to comment.