Skip to content

Commit

Permalink
add payload to stats api
Browse files Browse the repository at this point in the history
Signed-off-by: kalyanr <[email protected]>
  • Loading branch information
rawwar committed Mar 6, 2024
1 parent 4d28f24 commit c0cf990
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
13 changes: 11 additions & 2 deletions opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input
from opensearch_py_ml.ml_commons.validators.stats import validate_stats_input


class MLCommonClient:
Expand Down Expand Up @@ -608,7 +609,13 @@ def delete_task(self, task_id: str) -> object:
url=API_URL,
)

def get_stats(self, node_id: Optional[str] = "", stat_id: Optional[str] = ""):
def get_stats(
self,
node_id: Optional[str] = "",
stat_id: Optional[str] = "",
payload: Optional[dict] = None,
):
validate_stats_input(node_id, stat_id, payload)
if node_id and stat_id:
url = f"{ML_BASE_URI}/{node_id}/stats/{stat_id}"
elif node_id:
Expand All @@ -618,7 +625,9 @@ def get_stats(self, node_id: Optional[str] = "", stat_id: Optional[str] = ""):
else:
url = f"{ML_BASE_URI}/stats"

return self._client.transport.perform_request(method="GET", url=url)
return self._client.transport.perform_request(
method="GET", url=url, body=payload
)

def _get_profile(self, payload: Optional[dict] = None):
"""
Expand Down
22 changes: 22 additions & 0 deletions opensearch_py_ml/ml_commons/validators/stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

"""Module for validating Stats API parameters """


def validate_stats_input(node_id, stat_id, payload):
if payload:
if node_id or stat_id:
raise ValueError(
"Stats API does not accept node_id or stat_id with payload"
)
if payload is not None and not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary or None")
if node_id is not None and not isinstance(node_id, str):
raise ValueError("node_id needs to be a string or None")
if stat_id is not None and not isinstance(stat_id, str):
raise ValueError("stat_id needs to be a string or None")

0 comments on commit c0cf990

Please sign in to comment.