Skip to content

Commit

Permalink
pr fixes
Browse files Browse the repository at this point in the history
Signed-off-by: kalyan <[email protected]>
  • Loading branch information
rawwar committed Nov 6, 2023
1 parent 24cb85c commit 078443f
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 62 deletions.
34 changes: 11 additions & 23 deletions opensearch_py_ml/ml_commons/model_access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,8 @@ def update_model_group(
self,
update_query: dict,
model_group_id: Optional[str] = None,
model_group_name: Optional[str] = None,
):
validate_update_model_group_parameters(
update_query, model_group_id, model_group_name
)
if model_group_name:
model_group = self.search_model_group_by_name(model_group_name)
try:
if len(model_group["hits"]["hits"]) > 0:
model_group_id = model_group["hits"]["hits"][0]["_id"]
else:
raise Exception
except Exception:
raise Exception(f"Model group with name: {model_group_name} not found")
validate_update_model_group_parameters(update_query, model_group_id)
return self.client.transport.perform_request(
method="PUT",
url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}",
Expand All @@ -83,22 +71,22 @@ def search_model_group_by_name(self, model_group_name, _source=None, size=1):
if _source:
query["_source"] = _source
return self.search_model_group(query)

def get_model_group_id_by_name(self, model_group_name):
try:
res = self.search_model_group_by_name(model_group_name)
if res["hits"]["hits"]:
return res["hits"]["hits"][0]["_id"]
else:
return None
except NotFoundError:
return None

def delete_model_group(
self,
model_group_id: str = None,
model_group_name: str = None,
ignore_if_not_exists=True,
):
validate_delete_model_group_parameters(model_group_id, model_group_name)
if model_group_name:
model_group = self.search_model_group_by_name(model_group_name)
try:
model_group_id = model_group["hits"]["hits"][0]["_id"]
except (KeyError, IndexError):
if ignore_if_not_exists:
return None
raise Exception(f"Model group with name: {model_group_name} not found")
return self.client.transport.perform_request(
method="DELETE", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}"
)
6 changes: 6 additions & 0 deletions opensearch_py_ml/ml_commons/validators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
57 changes: 21 additions & 36 deletions opensearch_py_ml/ml_commons/validators/model_access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,41 @@

""" Module for validating model access control parameters """

from typing import List, Optional

ACCESS_MODES = ["public", "private", "restricted"]

NoneType = type(None)


def _validate_model_group_name(name):
def _validate_model_group_name(name: str):
if not name or not isinstance(name, str):
raise ValueError("name is required and needs to be a string")


def _validate_model_group_description(description):
def _validate_model_group_description(description: Optional[str]):
if not isinstance(description, (NoneType, str)):
raise ValueError("description needs to be a string")


def _validate_model_group_access_mode(access_mode):
def _validate_model_group_access_mode(access_mode: Optional[str]):
if access_mode is None:
return
if access_mode not in ACCESS_MODES:
raise ValueError(f"access_mode must be in {ACCESS_MODES}")
raise ValueError(f"access_mode can must be in {ACCESS_MODES} or None")


def _validate_model_group_backend_roles(backend_roles):
def _validate_model_group_backend_roles(backend_roles: Optional[List[str]]):
if not isinstance(backend_roles, (NoneType, list)):
raise ValueError("backend_roles should either be None or a list of roles names")


def _validate_model_group_add_all_backend_roles(add_all_backend_roles):
if not isinstance(add_all_backend_roles, bool):
def _validate_model_group_add_all_backend_roles(add_all_backend_roles: Optional[bool]):
if not isinstance(add_all_backend_roles, (NoneType, bool)):
raise ValueError("add_all_backend_roles should be a boolean")


def _validate_model_group_query(query, operation=None):
def _validate_model_group_query(query: dict, operation: Optional[str]=None):
if not isinstance(query, dict):
raise ValueError("query needs to be a dictionary")

Expand All @@ -47,11 +50,11 @@ def _validate_model_group_query(query, operation=None):


def validate_create_model_group_parameters(
name,
description,
access_mode,
backend_roles,
add_all_backend_roles,
name: str,
description: Optional[str] = None,
access_mode: Optional[str] = "private",
backend_roles: Optional[List[str]] = None,
add_all_backend_roles: Optional[bool] = False,
):
_validate_model_group_name(name)
_validate_model_group_description(description)
Expand All @@ -78,39 +81,21 @@ def validate_create_model_group_parameters(


def validate_update_model_group_parameters(
update_query, model_group_id, model_group_name
update_query: dict, model_group_id: str
):
if model_group_id and model_group_name:
raise ValueError(
"You cannot specify both model_group_id and model_group_name at the same time"
)

if not isinstance(model_group_id, (NoneType, str)):
if not isinstance(model_group_id, str):
raise ValueError("Invalid model_group_id. model_group_id needs to be a string")

if not isinstance(model_group_name, (NoneType, str)):
raise ValueError(
"Invalid model_group_name. model_group_name needs to be a string"
)

if not isinstance(update_query, dict):
raise ValueError("Invalid update_query. update_query needs to be a dictionary")


def validate_delete_model_group_parameters(model_group_id, model_group_name):
if model_group_id and model_group_name:
raise ValueError(
"You cannot specify both model_group_id and model_group_name at the same time"
)
def validate_delete_model_group_parameters(model_group_id: str):

if not isinstance(model_group_id, (NoneType, str)):
if not isinstance(model_group_id, str):
raise ValueError("Invalid model_group_id. model_group_id needs to be a string")

if not isinstance(model_group_name, (NoneType, str)):
raise ValueError(
"Invalid model_group_name. model_group_name needs to be a string"
)


def validate_search_model_group_parameters(query):
def validate_search_model_group_parameters(query: dict):
_validate_model_group_query(query)
25 changes: 22 additions & 3 deletions tests/ml_commons/test_model_access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from packaging.version import parse as parse_version

from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl
from opensearch_py_ml.ml_commons.validators.model_access_control import *
from tests import OPENSEARCH_TEST_CLIENT

OPENSEARCH_VERSION = parse_version(os.environ["OPENSEARCH_VERSION"])
OPENSEARCH_VERSION = parse_version(os.environ.get("OPENSEARCH_VERSION", "2.7.0"))
MAC_MIN_VERSION = parse_version("2.8.0")
MAC_UPDATE_MIN_VERSION = parse_version("2.11.0")

Expand Down Expand Up @@ -95,6 +96,18 @@ def test_register_model_group(client):
)


@pytest.mark.skipif(
OPENSEARCH_VERSION < MAC_MIN_VERSION,
reason="Model groups are supported in OpenSearch 2.8.0 and above",
)
def test_get_model_group_id_by_name(client, test_model_group):
model_group_id = client.get_model_group_id_by_name(test_model_group)
assert model_group_id is not None

model_group_id = client.get_model_group_id_by_name("test-unknown")
assert model_group_id is None


@pytest.mark.skipif(
OPENSEARCH_VERSION < MAC_UPDATE_MIN_VERSION,
reason="Model groups updates are supported in OpenSearch 2.11.0 and above",
Expand All @@ -105,7 +118,10 @@ def test_update_model_group(client, test_model_group):
"description": "updated description",
}
try:
res = client.update_model_group(update_query, model_group_name=test_model_group)
model_group_id = client.get_model_group_id_by_name(test_model_group)
if model_group_id is None:
raise Exception(f"No model group found with the name: {test_model_group}")
res = client.update_model_group(update_query, model_group_id=model_group_id)
assert isinstance(res, dict)
assert "status" in res
assert res["status"] == "Updated"
Expand Down Expand Up @@ -188,7 +204,10 @@ def test_delete_model_group(client, test_model_group):

for each in "AB":
model_group_name = f"__test__model_group_{each}"
res = client.delete_model_group(model_group_name=model_group_name)
model_group_id = client.get_model_group_id_by_name(model_group_name)
if model_group_id is None:
continue
res = client.delete_model_group(model_group_id=model_group_id)
assert res is None or isinstance(res, dict)
if isinstance(res, dict):
assert "result" in res
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import pytest

from opensearch_py_ml.ml_commons.validators.model_access_control import (
_validate_model_group_name,
_validate_model_group_description,
_validate_model_group_access_mode,
_validate_model_group_backend_roles,
_validate_model_group_add_all_backend_roles,
_validate_model_group_query,
validate_create_model_group_parameters,
validate_update_model_group_parameters,
validate_delete_model_group_parameters,
validate_search_model_group_parameters
)


def test_validate_model_group_name():
with pytest.raises(ValueError):
_validate_model_group_name(None)

with pytest.raises(ValueError):
_validate_model_group_name("")

with pytest.raises(ValueError):
_validate_model_group_name(123)

res = _validate_model_group_name("ValidName")
assert res is None


def test_validate_model_group_description():
with pytest.raises(ValueError):
_validate_model_group_description(123)

res = _validate_model_group_description("")
assert res is None

res = _validate_model_group_description(None)
assert res is None

res = _validate_model_group_description("ValidName")
assert res is None


def test_validate_model_group_access_mode():
with pytest.raises(ValueError):
_validate_model_group_access_mode(123)

res = _validate_model_group_access_mode("private")
assert res is None

res = _validate_model_group_access_mode("restricted")
assert res is None

res = _validate_model_group_access_mode(None)
assert res is None

def test_validate_model_group_backend_roles():
with pytest.raises(ValueError):
_validate_model_group_backend_roles(123)

res = _validate_model_group_backend_roles(["admin"])
assert res is None

res = _validate_model_group_backend_roles(None)
assert res is None

def test_validate_model_group_add_all_backend_roles():
with pytest.raises(ValueError):
_validate_model_group_add_all_backend_roles(123)

res = _validate_model_group_add_all_backend_roles(False)
assert res is None

res = _validate_model_group_add_all_backend_roles(True)
assert res is None

res = _validate_model_group_add_all_backend_roles(None)
assert res is None


def test_validate_model_group_query():
with pytest.raises(ValueError):
_validate_model_group_query(123)

res = _validate_model_group_query({})
assert res is None

with pytest.raises(ValueError):
_validate_model_group_query(None)

res = _validate_model_group_query({"query": {"match": {"name": "test"}}})
assert res is None


def test_validate_create_model_group_parameters():
with pytest.raises(ValueError):
validate_create_model_group_parameters(123)

res = validate_create_model_group_parameters("test")
assert res is None

with pytest.raises(ValueError):
validate_create_model_group_parameters("test", access_mode="restricted")

with pytest.raises(ValueError):
validate_create_model_group_parameters("test", access_mode="private",add_all_backend_roles=True)


def test_validate_update_model_group_parameters():
with pytest.raises(ValueError):
validate_update_model_group_parameters(123, 123)

res = validate_update_model_group_parameters({"query": {}}, "test")
assert res is None

def test_validate_delete_model_group_parameters():
with pytest.raises(ValueError):
validate_delete_model_group_parameters(123)

res = validate_delete_model_group_parameters("test")
assert res is None

def test_validate_search_model_group_parameters():
with pytest.raises(ValueError):
validate_search_model_group_parameters(123)

res = validate_search_model_group_parameters({"query": {}})
assert res is None



0 comments on commit 078443f

Please sign in to comment.