forked from opensearch-project/opensearch-py-ml
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
opensearch-project#289 - Add register, update and delete model group …
…functionality to support Model Access Control (opensearch-project#332) * init Signed-off-by: kalyan <[email protected]> * add search, delete and update Signed-off-by: kalyan <[email protected]> * add tests for register model group Signed-off-by: kalyan <[email protected]> * update cluster to 2.11 Signed-off-by: kalyan <[email protected]> * test skipif Signed-off-by: kalyan <[email protected]> * fix Signed-off-by: kalyan <[email protected]> * add tests Signed-off-by: kalyan <[email protected]> * update matrix Signed-off-by: kalyan <[email protected]> * cancel in progress Signed-off-by: kalyan <[email protected]> * update concurrency Signed-off-by: kalyan <[email protected]> * job level concurrency Signed-off-by: kalyan <[email protected]> * fix Signed-off-by: kalyan <[email protected]> * fix Signed-off-by: kalyan <[email protected]> * fix tests Signed-off-by: kalyan <[email protected]> * tests passing Signed-off-by: kalyan <[email protected]> * isort fix Signed-off-by: kalyan <[email protected]> * fix action Signed-off-by: kalyan <[email protected]> * fix action Signed-off-by: kalyan <[email protected]> * fix action Signed-off-by: kalyan <[email protected]> * fix Signed-off-by: kalyan <[email protected]> * update changelog Signed-off-by: kalyan <[email protected]> * fix os dockerfile Signed-off-by: kalyan <[email protected]> * test Signed-off-by: kalyanr <[email protected]> * pass opensearch version Signed-off-by: kalyanr <[email protected]> * fix Signed-off-by: kalyanr <[email protected]> * fix Signed-off-by: kalyanr <[email protected]> * fix Signed-off-by: kalyan <[email protected]> * update OS dockerfile Signed-off-by: kalyan <[email protected]> * fix failing tests Signed-off-by: kalyan <[email protected]> * update dockerfile for 2.11.0 Signed-off-by: kalyan <[email protected]> * remove disable warning Signed-off-by: kalyan <[email protected]> * fix upload model Signed-off-by: kalyan <[email protected]> * fix lint Signed-off-by: kalyan <[email protected]> * fix lint Signed-off-by: kalyan <[email protected]> * include reference Signed-off-by: kalyan <[email protected]> * pr fixes Signed-off-by: kalyan <[email protected]> * lint fix Signed-off-by: kalyan <[email protected]> * fix lint Signed-off-by: kalyan <[email protected]> * fix tests Signed-off-by: kalyan <[email protected]> * skip Signed-off-by: kalyan <[email protected]> * fix lint Signed-off-by: kalyan <[email protected]> * fix lint and increase coverage Signed-off-by: kalyan <[email protected]> * fix lint Signed-off-by: kalyan <[email protected]> * fix Signed-off-by: kalyan <[email protected]> * feedback fixes Signed-off-by: kalyan <[email protected]> * fix Signed-off-by: kalyan <[email protected]> * lint fix Signed-off-by: kalyan <[email protected]> * fix test cases Signed-off-by: kalyan <[email protected]> * pr feedback fixes Signed-off-by: kalyanr <[email protected]> * revert Signed-off-by: kalyanr <[email protected]> --------- Signed-off-by: kalyan <[email protected]> Signed-off-by: kalyanr <[email protected]>
- Loading branch information
Showing
11 changed files
with
663 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,18 @@ | ||
ARG OPENSEARCH_VERSION | ||
ARG OPENSEARCH_VERSION=latest | ||
FROM opensearchproject/opensearch:$OPENSEARCH_VERSION | ||
|
||
# OPENSEARCH_VERSION needs to be redefined as any arg before FROM is outside build scope. | ||
# Reference: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact | ||
ARG OPENSEARCH_VERSION=latest | ||
ARG opensearch_path=/usr/share/opensearch | ||
ARG opensearch_yml=$opensearch_path/config/opensearch.yml | ||
|
||
ARG SECURE_INTEGRATION | ||
RUN echo "plugins.ml_commons.only_run_on_ml_node: false" >> $opensearch_yml; | ||
RUN echo "plugins.ml_commons.native_memory_threshold: 100" >> $opensearch_yml; | ||
RUN if [ "$OPENSEARCH_VERSION" == "2.11.0" ] ; then \ | ||
echo "plugins.ml_commons.model_access_control_enabled: true" >> $opensearch_yml; \ | ||
echo "plugins.ml_commons.allow_registering_model_via_local_file: true" >> $opensearch_yml; \ | ||
echo "plugins.ml_commons.allow_registering_model_via_url: true" >> $opensearch_yml; \ | ||
fi | ||
RUN if [ "$SECURE_INTEGRATION" != "true" ] ; then echo "plugins.security.disabled: true" >> $opensearch_yml; fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# The OpenSearch Contributors require contributions made to | ||
# this file be licensed under the Apache-2.0 license or a | ||
# compatible open source license. | ||
# Any modifications Copyright OpenSearch Contributors. See | ||
# GitHub history for details. | ||
|
||
from typing import List, Optional | ||
|
||
from opensearchpy import OpenSearch | ||
from opensearchpy.exceptions import NotFoundError | ||
|
||
from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI | ||
from opensearch_py_ml.ml_commons.validators.model_access_control import ( | ||
validate_create_model_group_parameters, | ||
validate_delete_model_group_parameters, | ||
validate_search_model_group_parameters, | ||
validate_update_model_group_parameters, | ||
) | ||
|
||
|
||
class ModelAccessControl: | ||
API_ENDPOINT = "model_groups" | ||
|
||
def __init__(self, os_client: OpenSearch): | ||
self.client = os_client | ||
|
||
def register_model_group( | ||
self, | ||
name: str, | ||
description: Optional[str] = None, | ||
access_mode: Optional[str] = "private", | ||
backend_roles: Optional[List[str]] = None, | ||
add_all_backend_roles: Optional[bool] = False, | ||
): | ||
validate_create_model_group_parameters( | ||
name, description, access_mode, backend_roles, add_all_backend_roles | ||
) | ||
|
||
body = {"name": name, "add_all_backend_roles": add_all_backend_roles} | ||
if description: | ||
body["description"] = description | ||
if access_mode: | ||
body["access_mode"] = access_mode | ||
if backend_roles: | ||
body["backend_roles"] = backend_roles | ||
|
||
return self.client.transport.perform_request( | ||
method="POST", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/_register", body=body | ||
) | ||
|
||
def update_model_group( | ||
self, | ||
update_query: dict, | ||
model_group_id: Optional[str] = None, | ||
): | ||
validate_update_model_group_parameters(update_query, model_group_id) | ||
return self.client.transport.perform_request( | ||
method="PUT", | ||
url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}", | ||
body=update_query, | ||
) | ||
|
||
def search_model_group(self, query: dict): | ||
validate_search_model_group_parameters(query) | ||
return self.client.transport.perform_request( | ||
method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/_search", body=query | ||
) | ||
|
||
def search_model_group_by_name( | ||
self, | ||
model_group_name: str, | ||
_source: Optional[List] = None, | ||
size: Optional[int] = 1, | ||
): | ||
query = {"query": {"match": {"name": model_group_name}}, "size": size} | ||
if _source: | ||
query["_source"] = _source | ||
return self.search_model_group(query) | ||
|
||
def get_model_group_id_by_name(self, model_group_name: str): | ||
try: | ||
res = self.search_model_group_by_name(model_group_name) | ||
if res["hits"]["hits"]: | ||
return res["hits"]["hits"][0]["_id"] | ||
else: | ||
raise NotFoundError | ||
except NotFoundError: | ||
print(f"No model group found with name:{model_group_name}") | ||
return None | ||
except Exception as ex: | ||
print(f"Error in get_model_group_id_by_name: {ex}") | ||
return None | ||
|
||
def delete_model_group(self, model_group_id: str): | ||
validate_delete_model_group_parameters(model_group_id) | ||
return self.client.transport.perform_request( | ||
method="DELETE", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}" | ||
) | ||
|
||
def delete_model_group_by_name(self, model_group_name: str): | ||
model_group_id = self.get_model_group_id_by_name(model_group_name) | ||
if model_group_id is None: | ||
raise NotFoundError(f"Model group {model_group_name} not found") | ||
return self.delete_model_group(model_group_id=model_group_id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# The OpenSearch Contributors require contributions made to | ||
# this file be licensed under the Apache-2.0 license or a | ||
# compatible open source license. | ||
# Any modifications Copyright OpenSearch Contributors. See | ||
# GitHub history for details. |
97 changes: 97 additions & 0 deletions
97
opensearch_py_ml/ml_commons/validators/model_access_control.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# The OpenSearch Contributors require contributions made to | ||
# this file be licensed under the Apache-2.0 license or a | ||
# compatible open source license. | ||
# Any modifications Copyright OpenSearch Contributors. See | ||
# GitHub history for details. | ||
|
||
""" Module for validating model access control parameters """ | ||
|
||
from typing import List, Optional | ||
|
||
ACCESS_MODES = ["public", "private", "restricted"] | ||
|
||
NoneType = type(None) | ||
|
||
|
||
def _validate_model_group_name(name: str): | ||
if not name or not isinstance(name, str): | ||
raise ValueError("name is required and needs to be a string") | ||
|
||
|
||
def _validate_model_group_description(description: Optional[str]): | ||
if not isinstance(description, (NoneType, str)): | ||
raise ValueError("description needs to be a string") | ||
|
||
|
||
def _validate_model_group_access_mode(access_mode: Optional[str]): | ||
if access_mode is None: | ||
return | ||
if access_mode not in ACCESS_MODES: | ||
raise ValueError(f"access_mode can must be in {ACCESS_MODES} or None") | ||
|
||
|
||
def _validate_model_group_backend_roles(backend_roles: Optional[List[str]]): | ||
if not isinstance(backend_roles, (NoneType, list)): | ||
raise ValueError("backend_roles should either be None or a list of roles names") | ||
|
||
|
||
def _validate_model_group_add_all_backend_roles(add_all_backend_roles: Optional[bool]): | ||
if not isinstance(add_all_backend_roles, (NoneType, bool)): | ||
raise ValueError("add_all_backend_roles should be a boolean") | ||
|
||
|
||
def _validate_model_group_query(query: dict, operation: Optional[str] = None): | ||
if not isinstance(query, dict): | ||
raise ValueError("query needs to be a dictionary") | ||
|
||
if operation and not isinstance(operation, str): | ||
raise ValueError("operation needs to be a string") | ||
|
||
|
||
def validate_create_model_group_parameters( | ||
name: str, | ||
description: Optional[str] = None, | ||
access_mode: Optional[str] = "private", | ||
backend_roles: Optional[List[str]] = None, | ||
add_all_backend_roles: Optional[bool] = False, | ||
): | ||
_validate_model_group_name(name) | ||
_validate_model_group_description(description) | ||
_validate_model_group_access_mode(access_mode) | ||
_validate_model_group_backend_roles(backend_roles) | ||
_validate_model_group_add_all_backend_roles(add_all_backend_roles) | ||
|
||
if access_mode == "restricted": | ||
if not backend_roles and not add_all_backend_roles: | ||
raise ValueError( | ||
"You must specify either backend_roles or add_all_backend_roles=True for restricted access_mode" | ||
) | ||
|
||
if backend_roles and add_all_backend_roles: | ||
raise ValueError( | ||
"You cannot specify both backend_roles and add_all_backend_roles=True at the same time" | ||
) | ||
|
||
elif access_mode == "private": | ||
if backend_roles or add_all_backend_roles: | ||
raise ValueError( | ||
"You must not specify backend_roles or add_all_backend_roles=True for a private model group" | ||
) | ||
|
||
|
||
def validate_update_model_group_parameters(update_query: dict, model_group_id: str): | ||
if not isinstance(model_group_id, str): | ||
raise ValueError("Invalid model_group_id. model_group_id needs to be a string") | ||
|
||
if not isinstance(update_query, dict): | ||
raise ValueError("Invalid update_query. update_query needs to be a dictionary") | ||
|
||
|
||
def validate_delete_model_group_parameters(model_group_id: str): | ||
if not isinstance(model_group_id, str): | ||
raise ValueError("Invalid model_group_id. model_group_id needs to be a string") | ||
|
||
|
||
def validate_search_model_group_parameters(query: dict): | ||
_validate_model_group_query(query) |
Oops, something went wrong.