Skip to content

Commit

Permalink
add tests for register model group
Browse files Browse the repository at this point in the history
Signed-off-by: kalyan <[email protected]>
  • Loading branch information
rawwar committed Nov 4, 2023
1 parent 61c3090 commit a89bb39
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 8 deletions.
16 changes: 9 additions & 7 deletions opensearch_py_ml/ml_commons/model_access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# GitHub history for details.

from opensearchpy import OpenSearch
from opensearchpy.exceptions import RequestError
import json
from typing import List, Optional
from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI
Expand All @@ -28,12 +29,11 @@ def register_model_group(
description: Optional[str] = None,
access_mode: Optional[str] = "private",
backend_roles: Optional[List[str]] = None,
add_all_backend_roles: Optional[bool] = False,
add_all_backend_roles: Optional[bool] = False
):
validate_create_model_group_parameters(
name, description, access_mode, backend_roles, add_all_backend_roles
)
# import pdb;pdb.set_trace()

body = {"name": name, "add_all_backend_roles": add_all_backend_roles}
if description:
Expand Down Expand Up @@ -79,20 +79,22 @@ def search_model_group(self, query):
)

def search_model_group_by_name(self, model_group_name, _source=None, size=1):
query = {"query": {"term": {"name": model_group_name}}, "size": size}
query = {"query": {"match": {"name": model_group_name}}, "size": size}
if _source:
query["_source"] = _source
return self.search_model_group(query)

def delete_model_group(
self, model_group_id: str = None, model_group_name: str = None
self, model_group_id: str = None, model_group_name: str = None, ignore_if_not_exists=True
):
validate_delete_model_group_parameters(model_group_name, model_group_id)
validate_delete_model_group_parameters(model_group_id, model_group_name)
if model_group_name:
model_group = self.search_model_group_by_name(model_group_name)
try:
model_group_id = model_group["hits"]["hits"][0]["_source"]["name"]
except KeyError:
model_group_id = model_group["hits"]["hits"][0]["_id"]
except (KeyError, IndexError):
if ignore_if_not_exists:
return None
raise Exception(f"Model group with name: {model_group_name} not found")
return self.client.transport.perform_request(
method="DELETE", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}"
Expand Down
3 changes: 3 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
import opensearch_py_ml as oml

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_PREFIX = "oml_test__"
INDEX_PREFIX = f"{TEST_PREFIX}_index_"
MODEL_GROUP_PREFIX = f"{TEST_PREFIX}_model_group_"

# Create pandas and opensearch_py_ml data frames
from tests import (
Expand Down
95 changes: 94 additions & 1 deletion tests/ml_commons/test_model_access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,97 @@
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

import pytest
import pytest
import time
from unittest import TestCase
from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl
from tests import OPENSEARCH_TEST_CLIENT
from opensearchpy.exceptions import RequestError


@pytest.fixture
def client():
return ModelAccessControl(OPENSEARCH_TEST_CLIENT)

@pytest.fixture
def test_model_group(client):
model_group_name = "__test__model_group_1"
client.delete_model_group(model_group_name=model_group_name)
# time.sleep(0.5)
client.register_model_group(
name=model_group_name,
description="test model group for opensearch-py-ml test cases",
)
yield model_group_name

client.delete_model_group(model_group_name=model_group_name)


def test_register_model_group(client):

model_group_name1 = "__test__model_group_A"
# import pdb;pdb.set_trace()
try:
_ = client.delete_model_group(model_group_name=model_group_name1)
time.sleep(2)
res = client.register_model_group(name=model_group_name1)
assert isinstance(res, dict)
assert "model_group_id" in res
assert "status" in res
assert res['status'] == "CREATED"
except Exception as ex:
assert False,f"Failed to register model group due to {ex}"

model_group_name2 = "__test__model_group_B"

try:
_ = client.delete_model_group(model_group_name=model_group_name2)
time.sleep(2)
res = client.register_model_group(
name=model_group_name2,
description="test",
access_mode="restricted",
backend_roles=["admin"],
)
assert "model_group_id" in res
assert "status" in res
assert res['status'] == "CREATED"
except Exception as ex:
assert False,f"Failed to register restricted model group due to {ex}"

model_group_name3 = "__test__model_group_C"
with pytest.raises(RequestError) as exec_info:
_ = client.delete_model_group(model_group_name=model_group_name3)
time.sleep(2)
res = client.register_model_group(
name=model_group_name3,
description="test",
access_mode="restricted",
add_all_backend_roles=True
)
assert exec_info.value.status_code == 400
assert exec_info.match("Admin users cannot add all backend roles to a model group")

with pytest.raises(RequestError) as exec_info:
client.register_model_group(name=model_group_name2)
assert exec_info.value.status_code == 400
assert exec_info.match("The name you provided is already being used by a model group")

for each in "ABC":
client.delete_model_group(model_group_name=f"__test__model_group_{each}")


def test_update_model_group(client, test_model_group):
pass


def test_delete_model_group(client):
pass


def test_search_model_group(client):
pass


def test_search_model_group_by_name(client):
pass
28 changes: 28 additions & 0 deletions tests/setup_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from opensearchpy import OpenSearch
from opensearchpy.exceptions import RequestError
from .common import OPENSEARCH_TEST_CLIENT, MODEL_GROUP_PREFIX
import time




def delete_test_model_groups(os: OpenSearch):
model_group_query = {"query": {"match_phrase_prefix": {"name": MODEL_GROUP_PREFIX}}}

try:
result = os.transport.perform_request(
method="GET", url="/_plugins/_ml/model_groups/_search", body=model_group_query
)

for each in result["hits"]["hits"]:
try:
os.transport.perform_request(
method="DELETE", url=f"/_plugins/_ml/model_groups/{each['_id']}"
)
time.sleep(0.2)
except Exception as ex:
print(f"Failed to delete model group id: {each['_id']}")
except RequestError as ex:
print(f"Failed due to request error: {ex}")
except Exception as ex:
print(f"Unexpected Model group deletion failure due to: {ex}")

0 comments on commit a89bb39

Please sign in to comment.