Skip to content

Commit

Permalink
lint fix
Browse files Browse the repository at this point in the history
Signed-off-by: kalyanr <[email protected]>
  • Loading branch information
rawwar committed Dec 7, 2023
1 parent bc941fa commit 1b81c84
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 27 deletions.
2 changes: 1 addition & 1 deletion opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl
from opensearch_py_ml.ml_commons.model_connector import Connector
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
from opensearch_py_ml.ml_commons.model_profile import ModelProfile
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader


class MLCommonClient:
Expand Down
28 changes: 15 additions & 13 deletions opensearch_py_ml/ml_commons/model_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,49 @@
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

from opensearchpy import OpenSearch
from typing import Optional

from opensearchpy import OpenSearch

from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI


class ModelProfile:
API_ENDPOINT = "profile"

def __init__(self, os_client: OpenSearch):
self.client = os_client

def _validate_input(self, path_parameter, payload):
if path_parameter is not None and not isinstance(path_parameter, str):
raise ValueError("payload needs to be a dictionary or None")

if payload is not None and not isinstance(payload, dict):
raise ValueError("path_parameter needs to be a string or None")

def get_profile(self, payload: Optional[dict] = None):
if payload is not None and not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary or None")
return self.client.transport.perform_request(
method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload
)

def get_models_profile(self, path_parameter: Optional[str]='', payload: Optional[dict] = None):


def get_models_profile(
self, path_parameter: Optional[str] = "", payload: Optional[dict] = None
):
self._validate_input(path_parameter, payload)

url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}"
return self.client.transport.perform_request(
method="GET", url=url, body=payload
)


def get_tasks_profile(self, path_parameter: Optional[str]='', payload: Optional[dict] = None):

def get_tasks_profile(
self, path_parameter: Optional[str] = "", payload: Optional[dict] = None
):
self._validate_input(path_parameter, payload)

url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}"
return self.client.transport.perform_request(
method="GET", url=url, body=payload
)
)
24 changes: 11 additions & 13 deletions tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from sklearn.datasets import load_iris

from opensearch_py_ml.ml_commons import MLCommonClient
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
from opensearch_py_ml.ml_commons.model_profile import ModelProfile
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel
from tests import OPENSEARCH_TEST_CLIENT

Expand Down Expand Up @@ -575,53 +575,51 @@ def test_search():
raised = True
assert raised == False, "Raised Exception in searching model"


# Model Profile Tests. These will need some model train/predict data. Hence, need to be
# at the end after the training/prediction tests are done.


@pytest.fixture
def profile_client():
client = ModelProfile(OPENSEARCH_TEST_CLIENT)
return client


def test_get_profile(profile_client):

with pytest.raises(ValueError):
profile_client.get_profile("")

result = profile_client.get_profile()
assert isinstance(result, dict)
if len(result) > 0:
assert "nodes" in result


def test_get_models_profile(profile_client):

with pytest.raises(ValueError):
profile_client.get_models_profile(10)

with pytest.raises(ValueError):
profile_client.get_models_profile("", 10)

result = profile_client.get_models_profile()
assert isinstance(result, dict)
if len(result) > 0:
assert "nodes" in result
for _, node_val in result['nodes'].items():
for _, node_val in result["nodes"].items():
assert "models" in node_val



def test_get_tasks_profile(profile_client):

with pytest.raises(ValueError):
profile_client.get_tasks_profile(10)

with pytest.raises(ValueError):
profile_client.get_tasks_profile("", 10)

result = profile_client.get_tasks_profile()
if len(result) > 0:
assert "nodes" in result
for _, node_val in result['nodes'].items():
for _, node_val in result["nodes"].items():
assert "tasks" in node_val

0 comments on commit 1b81c84

Please sign in to comment.