From af4de7b40f88c2c45043a919c2a9edf9685858ec Mon Sep 17 00:00:00 2001 From: kalyanr Date: Sat, 16 Dec 2023 11:41:23 +0530 Subject: [PATCH] change Signed-off-by: kalyanr --- .../ml_commons/ml_commons_client.py | 39 +++++++++++++++++++ .../ml_commons/validators/profile.py | 16 ++++++++ 2 files changed, 55 insertions(+) create mode 100644 opensearch_py_ml/ml_commons/validators/profile.py diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 34529d059..138bd9c59 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -21,6 +21,7 @@ MODEL_VERSION_FIELD, TIMEOUT, ) +from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl from opensearch_py_ml.ml_commons.model_connector import Connector from opensearch_py_ml.ml_commons.model_execute import ModelExecute @@ -608,3 +609,41 @@ def delete_task(self, task_id: str) -> object: method="DELETE", url=API_URL, ) + + def _get_profile(self, payload: Optional[dict] = None): + validate_profile_input(None, payload) + return self.client.transport.perform_request( + method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload + ) + + def _get_models_profile( + self, path_parameter: Optional[str] = "", payload: Optional[dict] = None + ): + self._validate_input(path_parameter, payload) + + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}" + return self.client.transport.perform_request( + method="GET", url=url, body=payload + ) + + def _get_tasks_profile( + self, path_parameter: Optional[str] = "", payload: Optional[dict] = None + ): + self._validate_input(path_parameter, payload) + + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}" + return self.client.transport.perform_request( + method="GET", url=url, body=payload + ) + + def get_profile(self, profile_type="all", id=None, request_body=None): + if profile_type == "all": + return self._get_profile(request_body) + elif profile_type == "model": + return self._get_models_profile(id, request_body) + elif profile_type == "task": + pass + else: + raise ValueError( + "Invalid profile type. Profile type must be 'all', 'model' or 'task'." + ) diff --git a/opensearch_py_ml/ml_commons/validators/profile.py b/opensearch_py_ml/ml_commons/validators/profile.py new file mode 100644 index 000000000..602edb0f9 --- /dev/null +++ b/opensearch_py_ml/ml_commons/validators/profile.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. + +"""Module for validating Profile API parameters """ + + +def validate_profile_input(path_parameter, payload): + if path_parameter is not None and not isinstance(path_parameter, str): + raise ValueError("path_parameter needs to be a string or None") + + if payload is not None and not isinstance(payload, dict): + raise ValueError("payload needs to be a dictionary or None")