Skip to content

Commit

Permalink
change
Browse files Browse the repository at this point in the history
Signed-off-by: kalyanr <[email protected]>
  • Loading branch information
rawwar committed Dec 16, 2023
1 parent 76db648 commit af4de7b
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
39 changes: 39 additions & 0 deletions opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
MODEL_VERSION_FIELD,
TIMEOUT,
)
from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input
from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl
from opensearch_py_ml.ml_commons.model_connector import Connector
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
Expand Down Expand Up @@ -608,3 +609,41 @@ def delete_task(self, task_id: str) -> object:
method="DELETE",
url=API_URL,
)

def _get_profile(self, payload: Optional[dict] = None):
validate_profile_input(None, payload)
return self.client.transport.perform_request(
method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload
)

def _get_models_profile(
self, path_parameter: Optional[str] = "", payload: Optional[dict] = None
):
self._validate_input(path_parameter, payload)

url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}"
return self.client.transport.perform_request(
method="GET", url=url, body=payload
)

def _get_tasks_profile(
self, path_parameter: Optional[str] = "", payload: Optional[dict] = None
):
self._validate_input(path_parameter, payload)

url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}"
return self.client.transport.perform_request(
method="GET", url=url, body=payload
)

def get_profile(self, profile_type="all", id=None, request_body=None):
if profile_type == "all":
return self._get_profile(request_body)
elif profile_type == "model":
return self._get_models_profile(id, request_body)
elif profile_type == "task":
pass
else:
raise ValueError(
"Invalid profile type. Profile type must be 'all', 'model' or 'task'."
)
16 changes: 16 additions & 0 deletions opensearch_py_ml/ml_commons/validators/profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

"""Module for validating Profile API parameters """


def validate_profile_input(path_parameter, payload):
if path_parameter is not None and not isinstance(path_parameter, str):
raise ValueError("path_parameter needs to be a string or None")

if payload is not None and not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary or None")

0 comments on commit af4de7b

Please sign in to comment.