From 078443f597466a253edaddadf636422da1a67000 Mon Sep 17 00:00:00 2001 From: kalyan Date: Mon, 6 Nov 2023 20:36:48 +0530 Subject: [PATCH] pr fixes Signed-off-by: kalyan --- .../ml_commons/model_access_control.py | 34 ++--- .../ml_commons/validators/__init__.py | 6 + .../validators/model_access_control.py | 57 +++----- tests/ml_commons/test_model_access_control.py | 25 +++- .../test_model_access_control_validators.py | 132 ++++++++++++++++++ 5 files changed, 192 insertions(+), 62 deletions(-) create mode 100644 opensearch_py_ml/ml_commons/validators/__init__.py create mode 100644 tests/ml_commons/test_validators/test_model_access_control_validators.py diff --git a/opensearch_py_ml/ml_commons/model_access_control.py b/opensearch_py_ml/ml_commons/model_access_control.py index 07432dc59..506e9aef5 100644 --- a/opensearch_py_ml/ml_commons/model_access_control.py +++ b/opensearch_py_ml/ml_commons/model_access_control.py @@ -52,20 +52,8 @@ def update_model_group( self, update_query: dict, model_group_id: Optional[str] = None, - model_group_name: Optional[str] = None, ): - validate_update_model_group_parameters( - update_query, model_group_id, model_group_name - ) - if model_group_name: - model_group = self.search_model_group_by_name(model_group_name) - try: - if len(model_group["hits"]["hits"]) > 0: - model_group_id = model_group["hits"]["hits"][0]["_id"] - else: - raise Exception - except Exception: - raise Exception(f"Model group with name: {model_group_name} not found") + validate_update_model_group_parameters(update_query, model_group_id) return self.client.transport.perform_request( method="PUT", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}", @@ -83,22 +71,22 @@ def search_model_group_by_name(self, model_group_name, _source=None, size=1): if _source: query["_source"] = _source return self.search_model_group(query) + + def get_model_group_id_by_name(self, model_group_name): + try: + res = self.search_model_group_by_name(model_group_name) + if res["hits"]["hits"]: + return res["hits"]["hits"][0]["_id"] + else: + return None + except NotFoundError: + return None def delete_model_group( self, model_group_id: str = None, - model_group_name: str = None, - ignore_if_not_exists=True, ): validate_delete_model_group_parameters(model_group_id, model_group_name) - if model_group_name: - model_group = self.search_model_group_by_name(model_group_name) - try: - model_group_id = model_group["hits"]["hits"][0]["_id"] - except (KeyError, IndexError): - if ignore_if_not_exists: - return None - raise Exception(f"Model group with name: {model_group_name} not found") return self.client.transport.perform_request( method="DELETE", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}" ) diff --git a/opensearch_py_ml/ml_commons/validators/__init__.py b/opensearch_py_ml/ml_commons/validators/__init__.py new file mode 100644 index 000000000..43d9d3f4f --- /dev/null +++ b/opensearch_py_ml/ml_commons/validators/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. \ No newline at end of file diff --git a/opensearch_py_ml/ml_commons/validators/model_access_control.py b/opensearch_py_ml/ml_commons/validators/model_access_control.py index cf419cbc8..75f3be49f 100644 --- a/opensearch_py_ml/ml_commons/validators/model_access_control.py +++ b/opensearch_py_ml/ml_commons/validators/model_access_control.py @@ -7,38 +7,41 @@ """ Module for validating model access control parameters """ +from typing import List, Optional ACCESS_MODES = ["public", "private", "restricted"] NoneType = type(None) -def _validate_model_group_name(name): +def _validate_model_group_name(name: str): if not name or not isinstance(name, str): raise ValueError("name is required and needs to be a string") -def _validate_model_group_description(description): +def _validate_model_group_description(description: Optional[str]): if not isinstance(description, (NoneType, str)): raise ValueError("description needs to be a string") -def _validate_model_group_access_mode(access_mode): +def _validate_model_group_access_mode(access_mode: Optional[str]): + if access_mode is None: + return if access_mode not in ACCESS_MODES: - raise ValueError(f"access_mode must be in {ACCESS_MODES}") + raise ValueError(f"access_mode can must be in {ACCESS_MODES} or None") -def _validate_model_group_backend_roles(backend_roles): +def _validate_model_group_backend_roles(backend_roles: Optional[List[str]]): if not isinstance(backend_roles, (NoneType, list)): raise ValueError("backend_roles should either be None or a list of roles names") -def _validate_model_group_add_all_backend_roles(add_all_backend_roles): - if not isinstance(add_all_backend_roles, bool): +def _validate_model_group_add_all_backend_roles(add_all_backend_roles: Optional[bool]): + if not isinstance(add_all_backend_roles, (NoneType, bool)): raise ValueError("add_all_backend_roles should be a boolean") -def _validate_model_group_query(query, operation=None): +def _validate_model_group_query(query: dict, operation: Optional[str]=None): if not isinstance(query, dict): raise ValueError("query needs to be a dictionary") @@ -47,11 +50,11 @@ def _validate_model_group_query(query, operation=None): def validate_create_model_group_parameters( - name, - description, - access_mode, - backend_roles, - add_all_backend_roles, + name: str, + description: Optional[str] = None, + access_mode: Optional[str] = "private", + backend_roles: Optional[List[str]] = None, + add_all_backend_roles: Optional[bool] = False, ): _validate_model_group_name(name) _validate_model_group_description(description) @@ -78,39 +81,21 @@ def validate_create_model_group_parameters( def validate_update_model_group_parameters( - update_query, model_group_id, model_group_name + update_query: dict, model_group_id: str ): - if model_group_id and model_group_name: - raise ValueError( - "You cannot specify both model_group_id and model_group_name at the same time" - ) - if not isinstance(model_group_id, (NoneType, str)): + if not isinstance(model_group_id, str): raise ValueError("Invalid model_group_id. model_group_id needs to be a string") - if not isinstance(model_group_name, (NoneType, str)): - raise ValueError( - "Invalid model_group_name. model_group_name needs to be a string" - ) - if not isinstance(update_query, dict): raise ValueError("Invalid update_query. update_query needs to be a dictionary") -def validate_delete_model_group_parameters(model_group_id, model_group_name): - if model_group_id and model_group_name: - raise ValueError( - "You cannot specify both model_group_id and model_group_name at the same time" - ) +def validate_delete_model_group_parameters(model_group_id: str): - if not isinstance(model_group_id, (NoneType, str)): + if not isinstance(model_group_id, str): raise ValueError("Invalid model_group_id. model_group_id needs to be a string") - if not isinstance(model_group_name, (NoneType, str)): - raise ValueError( - "Invalid model_group_name. model_group_name needs to be a string" - ) - -def validate_search_model_group_parameters(query): +def validate_search_model_group_parameters(query: dict): _validate_model_group_query(query) diff --git a/tests/ml_commons/test_model_access_control.py b/tests/ml_commons/test_model_access_control.py index 119fd73d3..6d6bdea77 100644 --- a/tests/ml_commons/test_model_access_control.py +++ b/tests/ml_commons/test_model_access_control.py @@ -13,9 +13,10 @@ from packaging.version import parse as parse_version from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl +from opensearch_py_ml.ml_commons.validators.model_access_control import * from tests import OPENSEARCH_TEST_CLIENT -OPENSEARCH_VERSION = parse_version(os.environ["OPENSEARCH_VERSION"]) +OPENSEARCH_VERSION = parse_version(os.environ.get("OPENSEARCH_VERSION", "2.7.0")) MAC_MIN_VERSION = parse_version("2.8.0") MAC_UPDATE_MIN_VERSION = parse_version("2.11.0") @@ -95,6 +96,18 @@ def test_register_model_group(client): ) +@pytest.mark.skipif( + OPENSEARCH_VERSION < MAC_MIN_VERSION, + reason="Model groups are supported in OpenSearch 2.8.0 and above", +) +def test_get_model_group_id_by_name(client, test_model_group): + model_group_id = client.get_model_group_id_by_name(test_model_group) + assert model_group_id is not None + + model_group_id = client.get_model_group_id_by_name("test-unknown") + assert model_group_id is None + + @pytest.mark.skipif( OPENSEARCH_VERSION < MAC_UPDATE_MIN_VERSION, reason="Model groups updates are supported in OpenSearch 2.11.0 and above", @@ -105,7 +118,10 @@ def test_update_model_group(client, test_model_group): "description": "updated description", } try: - res = client.update_model_group(update_query, model_group_name=test_model_group) + model_group_id = client.get_model_group_id_by_name(test_model_group) + if model_group_id is None: + raise Exception(f"No model group found with the name: {test_model_group}") + res = client.update_model_group(update_query, model_group_id=model_group_id) assert isinstance(res, dict) assert "status" in res assert res["status"] == "Updated" @@ -188,7 +204,10 @@ def test_delete_model_group(client, test_model_group): for each in "AB": model_group_name = f"__test__model_group_{each}" - res = client.delete_model_group(model_group_name=model_group_name) + model_group_id = client.get_model_group_id_by_name(model_group_name) + if model_group_id is None: + continue + res = client.delete_model_group(model_group_id=model_group_id) assert res is None or isinstance(res, dict) if isinstance(res, dict): assert "result" in res diff --git a/tests/ml_commons/test_validators/test_model_access_control_validators.py b/tests/ml_commons/test_validators/test_model_access_control_validators.py new file mode 100644 index 000000000..1effba522 --- /dev/null +++ b/tests/ml_commons/test_validators/test_model_access_control_validators.py @@ -0,0 +1,132 @@ +import pytest + +from opensearch_py_ml.ml_commons.validators.model_access_control import ( + _validate_model_group_name, + _validate_model_group_description, + _validate_model_group_access_mode, + _validate_model_group_backend_roles, + _validate_model_group_add_all_backend_roles, + _validate_model_group_query, + validate_create_model_group_parameters, + validate_update_model_group_parameters, + validate_delete_model_group_parameters, + validate_search_model_group_parameters +) + + +def test_validate_model_group_name(): + with pytest.raises(ValueError): + _validate_model_group_name(None) + + with pytest.raises(ValueError): + _validate_model_group_name("") + + with pytest.raises(ValueError): + _validate_model_group_name(123) + + res = _validate_model_group_name("ValidName") + assert res is None + + +def test_validate_model_group_description(): + with pytest.raises(ValueError): + _validate_model_group_description(123) + + res = _validate_model_group_description("") + assert res is None + + res = _validate_model_group_description(None) + assert res is None + + res = _validate_model_group_description("ValidName") + assert res is None + + +def test_validate_model_group_access_mode(): + with pytest.raises(ValueError): + _validate_model_group_access_mode(123) + + res = _validate_model_group_access_mode("private") + assert res is None + + res = _validate_model_group_access_mode("restricted") + assert res is None + + res = _validate_model_group_access_mode(None) + assert res is None + +def test_validate_model_group_backend_roles(): + with pytest.raises(ValueError): + _validate_model_group_backend_roles(123) + + res = _validate_model_group_backend_roles(["admin"]) + assert res is None + + res = _validate_model_group_backend_roles(None) + assert res is None + +def test_validate_model_group_add_all_backend_roles(): + with pytest.raises(ValueError): + _validate_model_group_add_all_backend_roles(123) + + res = _validate_model_group_add_all_backend_roles(False) + assert res is None + + res = _validate_model_group_add_all_backend_roles(True) + assert res is None + + res = _validate_model_group_add_all_backend_roles(None) + assert res is None + + +def test_validate_model_group_query(): + with pytest.raises(ValueError): + _validate_model_group_query(123) + + res = _validate_model_group_query({}) + assert res is None + + with pytest.raises(ValueError): + _validate_model_group_query(None) + + res = _validate_model_group_query({"query": {"match": {"name": "test"}}}) + assert res is None + + +def test_validate_create_model_group_parameters(): + with pytest.raises(ValueError): + validate_create_model_group_parameters(123) + + res = validate_create_model_group_parameters("test") + assert res is None + + with pytest.raises(ValueError): + validate_create_model_group_parameters("test", access_mode="restricted") + + with pytest.raises(ValueError): + validate_create_model_group_parameters("test", access_mode="private",add_all_backend_roles=True) + + +def test_validate_update_model_group_parameters(): + with pytest.raises(ValueError): + validate_update_model_group_parameters(123, 123) + + res = validate_update_model_group_parameters({"query": {}}, "test") + assert res is None + +def test_validate_delete_model_group_parameters(): + with pytest.raises(ValueError): + validate_delete_model_group_parameters(123) + + res = validate_delete_model_group_parameters("test") + assert res is None + +def test_validate_search_model_group_parameters(): + with pytest.raises(ValueError): + validate_search_model_group_parameters(123) + + res = validate_search_model_group_parameters({"query": {}}) + assert res is None + + + \ No newline at end of file