diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 1eb63885..5fbc3355 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -26,6 +26,7 @@ from opensearch_py_ml.ml_commons.model_execute import ModelExecute from opensearch_py_ml.ml_commons.model_uploader import ModelUploader from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input +from opensearch_py_ml.ml_commons.validators.stats import validate_stats_input class MLCommonClient: @@ -608,7 +609,13 @@ def delete_task(self, task_id: str) -> object: url=API_URL, ) - def get_stats(self, node_id: Optional[str] = "", stat_id: Optional[str] = ""): + def get_stats( + self, + node_id: Optional[str] = "", + stat_id: Optional[str] = "", + payload: Optional[dict] = None, + ): + validate_stats_input(node_id, stat_id, payload) if node_id and stat_id: url = f"{ML_BASE_URI}/{node_id}/stats/{stat_id}" elif node_id: @@ -618,7 +625,9 @@ def get_stats(self, node_id: Optional[str] = "", stat_id: Optional[str] = ""): else: url = f"{ML_BASE_URI}/stats" - return self._client.transport.perform_request(method="GET", url=url) + return self._client.transport.perform_request( + method="GET", url=url, body=payload + ) def _get_profile(self, payload: Optional[dict] = None): """ diff --git a/opensearch_py_ml/ml_commons/validators/stats.py b/opensearch_py_ml/ml_commons/validators/stats.py new file mode 100644 index 00000000..a2c4c8f7 --- /dev/null +++ b/opensearch_py_ml/ml_commons/validators/stats.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. + +"""Module for validating Stats API parameters """ + + +def validate_stats_input(node_id, stat_id, payload): + if payload: + if node_id or stat_id: + raise ValueError( + "Stats API does not accept node_id or stat_id with payload" + ) + if payload is not None and not isinstance(payload, dict): + raise ValueError("payload needs to be a dictionary or None") + if node_id is not None and not isinstance(node_id, str): + raise ValueError("node_id needs to be a string or None") + if stat_id is not None and not isinstance(stat_id, str): + raise ValueError("stat_id needs to be a string or None")