From df6c0699815792786b6a29b1267042164974f38e Mon Sep 17 00:00:00 2001 From: kalyan Date: Mon, 6 Nov 2023 22:56:48 +0530 Subject: [PATCH] fix lint and increase coverage Signed-off-by: kalyan --- tests/ml_commons/test_model_access_control.py | 8 +++++++- .../test_model_access_control_validators.py | 3 +++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/ml_commons/test_model_access_control.py b/tests/ml_commons/test_model_access_control.py index 782129454..b65a06fad 100644 --- a/tests/ml_commons/test_model_access_control.py +++ b/tests/ml_commons/test_model_access_control.py @@ -7,9 +7,10 @@ import os import time +from unittest.mock import patch import pytest -from opensearchpy.exceptions import RequestError +from opensearchpy.exceptions import NotFoundError, RequestError from packaging.version import parse as parse_version from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl @@ -120,6 +121,11 @@ def test_get_model_group_id_by_name(client, test_model_group): model_group_id = client.get_model_group_id_by_name("test-unknown") assert model_group_id is None + # Mock NotFoundError as it only happens when index isn't created + with patch.object(client, "search_model_group_by_name", side_effect=NotFoundError): + model_group_id = client.get_model_group_id_by_name(test_model_group) + assert model_group_id is None + @pytest.mark.skipif( OPENSEARCH_VERSION < MAC_UPDATE_MIN_VERSION, diff --git a/tests/ml_commons/test_validators/test_model_access_control_validators.py b/tests/ml_commons/test_validators/test_model_access_control_validators.py index a53f42c94..b424e4b2e 100644 --- a/tests/ml_commons/test_validators/test_model_access_control_validators.py +++ b/tests/ml_commons/test_validators/test_model_access_control_validators.py @@ -101,6 +101,9 @@ def test_validate_model_group_query(): res = _validate_model_group_query({"query": {"match": {"name": "test"}}}) assert res is None + with pytest.raises(ValueError): + _validate_model_group_query({}, 123) + def test_validate_create_model_group_parameters(): with pytest.raises(ValueError):