From 1666baa7f0959c974fe397de2b7e8bff1cba84c2 Mon Sep 17 00:00:00 2001 From: Kalyan Date: Wed, 6 Dec 2023 08:07:24 +0530 Subject: [PATCH] #284: Add support for connectors (#345) * add connectors Signed-off-by: kalyan * update Signed-off-by: kalyan * fix Signed-off-by: kalyan * rename Signed-off-by: kalyanr * add tests Signed-off-by: kalyan * fix Signed-off-by: kalyan * fix Signed-off-by: kalyan * lint fix Signed-off-by: kalyan * update changelog Signed-off-by: kalyan * increase test coverage Signed-off-by: kalyan --------- Signed-off-by: kalyan Signed-off-by: kalyanr --- CHANGELOG.md | 1 + .../ml_commons/ml_commons_client.py | 2 + .../ml_commons/model_connector.py | 51 ++++++ tests/ml_commons/test_model_connector.py | 158 ++++++++++++++++++ 4 files changed, 212 insertions(+) create mode 100644 opensearch_py_ml/ml_commons/model_connector.py create mode 100644 tests/ml_commons/test_model_connector.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d68dbe63..4eb78bf78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Add example notebook for tracing and registering a CLIPTextModel to OpenSearch with the Neural Search plugin by @patrickbarnhart in ([#283](https://github.com/opensearch-project/opensearch-py-ml/pull/283)) - Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310)) - Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332)) +- Add support for model connectors by @rawwar in ([#345](https://github.com/opensearch-project/opensearch-py-ml/pull/345)) ### Changed - Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211)) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 72e2e158b..2fde29294 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -22,6 +22,7 @@ TIMEOUT, ) from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl +from opensearch_py_ml.ml_commons.model_connector import Connector from opensearch_py_ml.ml_commons.model_execute import ModelExecute from opensearch_py_ml.ml_commons.model_uploader import ModelUploader @@ -37,6 +38,7 @@ def __init__(self, os_client: OpenSearch): self._model_uploader = ModelUploader(os_client) self._model_execute = ModelExecute(os_client) self.model_access_control = ModelAccessControl(os_client) + self.connector = Connector(os_client) def execute(self, algorithm_name: str, input_json: dict) -> dict: """ diff --git a/opensearch_py_ml/ml_commons/model_connector.py b/opensearch_py_ml/ml_commons/model_connector.py new file mode 100644 index 000000000..d06ce196f --- /dev/null +++ b/opensearch_py_ml/ml_commons/model_connector.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. + +from opensearchpy import OpenSearch + +from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI + + +class Connector: + def __init__(self, os_client: OpenSearch): + self.client = os_client + + def create_standalone_connector(self, payload: dict): + if not isinstance(payload, dict): + raise ValueError("payload needs to be a dictionary") + + return self.client.transport.perform_request( + method="POST", url=f"{ML_BASE_URI}/connectors/_create", body=payload + ) + + def list_connectors(self): + search_query = {"query": {"match_all": {}}} + return self.search_connectors(search_query) + + def search_connectors(self, search_query: dict): + if not isinstance(search_query, dict): + raise ValueError("search_query needs to be a dictionary") + + return self.client.transport.perform_request( + method="POST", url=f"{ML_BASE_URI}/connectors/_search", body=search_query + ) + + def get_connector(self, connector_id: str): + if not isinstance(connector_id, str): + raise ValueError("connector_id needs to be a string") + + return self.client.transport.perform_request( + method="GET", url=f"{ML_BASE_URI}/connectors/{connector_id}" + ) + + def delete_connector(self, connector_id: str): + if not isinstance(connector_id, str): + raise ValueError("connector_id needs to be a string") + + return self.client.transport.perform_request( + method="DELETE", url=f"{ML_BASE_URI}/connectors/{connector_id}" + ) diff --git a/tests/ml_commons/test_model_connector.py b/tests/ml_commons/test_model_connector.py new file mode 100644 index 000000000..a8eee71bd --- /dev/null +++ b/tests/ml_commons/test_model_connector.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. + +import os + +import pytest +from opensearchpy.exceptions import NotFoundError, RequestError +from packaging.version import parse as parse_version + +from opensearch_py_ml.ml_commons.model_connector import Connector +from tests import OPENSEARCH_TEST_CLIENT + +OPENSEARCH_VERSION = parse_version(os.environ.get("OPENSEARCH_VERSION", "2.11.0")) +CONNECTOR_MIN_VERSION = parse_version("2.9.0") + + +@pytest.fixture +def client(): + return Connector(OPENSEARCH_TEST_CLIENT) + + +def _safe_delete_connector(client, connector_id): + try: + client.delete_connector(connector_id=connector_id) + except NotFoundError: + pass + + +@pytest.fixture +def connector_payload(): + return { + "name": "Test Connector", + "description": "Connector for testing", + "version": 1, + "protocol": "http", + "parameters": {"endpoint": "api.openai.com", "model": "gpt-3.5-turbo"}, + "credential": {"openAI_key": "..."}, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://${parameters.endpoint}/v1/chat/completions", + "headers": {"Authorization": "Bearer ${credential.openAI_key}"}, + "request_body": '{ "model": "${parameters.model}", "messages": ${parameters.messages} }', + } + ], + } + + +@pytest.fixture +def test_connector(client: Connector, connector_payload: dict): + res = client.create_standalone_connector(connector_payload) + connector_id = res["connector_id"] + yield connector_id + + _safe_delete_connector(client, connector_id) + + +@pytest.mark.skipif( + OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION, + reason="Connectors are supported in OpenSearch 2.9.0 and above", +) +def test_create_standalone_connector(client: Connector, connector_payload: dict): + res = client.create_standalone_connector(connector_payload) + assert "connector_id" in res + + _safe_delete_connector(client, res["connector_id"]) + + with pytest.raises(ValueError): + client.create_standalone_connector("") + + +@pytest.mark.skipif( + OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION, + reason="Connectors are supported in OpenSearch 2.9.0 and above", +) +def test_list_connectors(client, test_connector): + try: + res = client.list_connectors() + assert len(res["hits"]["hits"]) > 0 + + # check if test_connector id is in the response + found = False + for each in res["hits"]["hits"]: + if each["_id"] == test_connector: + found = True + break + assert found, "Test connector not found in list connectors response" + except Exception as ex: + assert False, f"Failed to list connectors due to {ex}" + + +@pytest.mark.skipif( + OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION, + reason="Connectors are supported in OpenSearch 2.9.0 and above", +) +def test_search_connectors(client, test_connector): + try: + query = {"query": {"match": {"name": "Test Connector"}}} + res = client.search_connectors(query) + assert len(res["hits"]["hits"]) > 0 + + # check if test_connector id is in the response + found = False + for each in res["hits"]["hits"]: + if each["_id"] == test_connector: + found = True + break + assert found, "Test connector not found in search connectors response" + except Exception as ex: + assert False, f"Failed to search connectors due to {ex}" + + with pytest.raises(ValueError): + client.search_connectors("test") + + +@pytest.mark.skipif( + OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION, + reason="Connectors are supported in OpenSearch 2.9.0 and above", +) +def test_get_connector(client, test_connector): + try: + res = client.get_connector(connector_id=test_connector) + assert res["name"] == "Test Connector" + except Exception as ex: + assert False, f"Failed to get connector due to {ex}" + + with pytest.raises(ValueError): + client.get_connector(connector_id=None) + + with pytest.raises(RequestError) as exec_info: + client.get_connector(connector_id="test-unknown") + assert exec_info.value.status_code == 400 + + +@pytest.mark.skipif( + OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION, + reason="Connectors are supported in OpenSearch 2.9.0 and above", +) +def test_delete_connector(client, test_connector): + try: + res = client.delete_connector(connector_id=test_connector) + assert res["result"] == "deleted" + except Exception as ex: + assert False, f"Failed to delete connector due to {ex}" + + try: + res = client.delete_connector(connector_id="unknown") + assert res["result"] == "not_found" + except Exception as ex: + assert False, f"Failed to delete connector due to {ex}" + + with pytest.raises(ValueError): + client.delete_connector(connector_id={"test": "fail"})