diff --git a/opensearch_py_ml/ml_commons/connectors.py b/opensearch_py_ml/ml_commons/connector.py similarity index 99% rename from opensearch_py_ml/ml_commons/connectors.py rename to opensearch_py_ml/ml_commons/connector.py index 18535341a..2eb425dd7 100644 --- a/opensearch_py_ml/ml_commons/connectors.py +++ b/opensearch_py_ml/ml_commons/connector.py @@ -9,7 +9,7 @@ from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI -class Connectors: +class Connector: def __init__(self, os_client: OpenSearch): self.client = os_client diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 7548933cd..9a6929376 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -24,7 +24,7 @@ from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl from opensearch_py_ml.ml_commons.model_execute import ModelExecute from opensearch_py_ml.ml_commons.model_uploader import ModelUploader -from opensearch_py_ml.ml_commons.connectors import Connectors +from opensearch_py_ml.ml_commons.connector import Connector class MLCommonClient: """ @@ -37,7 +37,7 @@ def __init__(self, os_client: OpenSearch): self._model_uploader = ModelUploader(os_client) self._model_execute = ModelExecute(os_client) self.model_access_control = ModelAccessControl(os_client) - self.connectors = Connectors(os_client) + self.connector = Connector(os_client) def execute(self, algorithm_name: str, input_json: dict) -> dict: """ diff --git a/tests/ml_commons/test_connectors.py b/tests/ml_commons/test_connectors.py new file mode 100644 index 000000000..f93eb981c --- /dev/null +++ b/tests/ml_commons/test_connectors.py @@ -0,0 +1,98 @@ +from opensearch_py_ml.ml_commons.connectors import Connector +from opensearchpy.exceptions import NotFoundError, RequestError +from tests import OPENSEARCH_TEST_CLIENT +from packaging.version import parse as parse_version +import os +import pytest + + +OPENSEARCH_VERSION = parse_version(os.environ.get("OPENSEARCH_VERSION", "2.11.0")) +CONNECTOR_MIN_VERSION = parse_version("2.11.0") + + +@pytest.fixture +def client(): + return Connector(OPENSEARCH_TEST_CLIENT) + + +def _safe_delete_connector(client, connector_id): + try: + client.delete_connector(connector_id=connector_id) + except NotFoundError: + pass + + +@pytest.fixture +def test_connector(client: Connector): + payload = { + "name": "Test Connector", + "description": "Connector for testing", + "version": 1, + "protocol": "http", + "parameters": {"endpoint": "api.openai.com", "model": "gpt-3.5-turbo"}, + "credential": {"openAI_key": "..."}, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://${parameters.endpoint}/v1/chat/completions", + "headers": {"Authorization": "Bearer ${credential.openAI_key}"}, + "request_body": '{ "model": "${parameters.model}", "messages": ${parameters.messages} }', + } + ], + } + + res = client.create_standalone_connector(payload) + connector_id = res['connector_id'] + yield connector_id + + _safe_delete_connector(client, connector_id) + + + +@pytest.mark.skipif( + OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION, + reason="Connectors are supported in OpenSearch 2.9.0 and above", +) +def test_create_standalone_connector(): + pass + + +@pytest.mark.skipif( + OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION, + reason="Connectors are supported in OpenSearch 2.9.0 and above", +) +def test_create_internal_connector(): + pass + + +@pytest.mark.skipif( + OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION, + reason="Connectors are supported in OpenSearch 2.9.0 and above", +) +def test_list_connector(): + pass + + +@pytest.mark.skipif( + OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION, + reason="Connectors are supported in OpenSearch 2.9.0 and above", +) +def test_search_connector(): + pass + + +@pytest.mark.skipif( + OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION, + reason="Connectors are supported in OpenSearch 2.9.0 and above", +) +def test_get_connector(): + pass + + +@pytest.mark.skipif( + OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION, + reason="Connectors are supported in OpenSearch 2.9.0 and above", +) +def test_delete_connector(): + pass \ No newline at end of file