diff --git a/opensearch_py_ml/ml_commons/connector.py b/opensearch_py_ml/ml_commons/connector.py index 2eb425dd7..bd8c72749 100644 --- a/opensearch_py_ml/ml_commons/connector.py +++ b/opensearch_py_ml/ml_commons/connector.py @@ -6,6 +6,7 @@ # GitHub history for details. from opensearchpy import OpenSearch + from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI @@ -40,12 +41,11 @@ def search_connectors(self, search_query: dict): return self.client.transport.perform_request( method="POST", url=f"{ML_BASE_URI}/connectors/_search", body=search_query ) - - + def get_connector(self, connector_id: str): if not isinstance(connector_id, str): raise ValueError("connector_id needs to be a string") - + return self.client.transport.perform_request( method="GET", url=f"{ML_BASE_URI}/connectors/{connector_id}" ) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 9a6929376..e9bd09606 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -13,6 +13,7 @@ from deprecated.sphinx import deprecated from opensearchpy import OpenSearch +from opensearch_py_ml.ml_commons.connector import Connector from opensearch_py_ml.ml_commons.ml_common_utils import ( ML_BASE_URI, MODEL_FORMAT_FIELD, @@ -24,7 +25,7 @@ from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl from opensearch_py_ml.ml_commons.model_execute import ModelExecute from opensearch_py_ml.ml_commons.model_uploader import ModelUploader -from opensearch_py_ml.ml_commons.connector import Connector + class MLCommonClient: """ diff --git a/tests/ml_commons/test_connectors.py b/tests/ml_commons/test_connectors.py index f93eb981c..f3d977db5 100644 --- a/tests/ml_commons/test_connectors.py +++ b/tests/ml_commons/test_connectors.py @@ -1,10 +1,18 @@ -from opensearch_py_ml.ml_commons.connectors import Connector -from opensearchpy.exceptions import NotFoundError, RequestError -from tests import OPENSEARCH_TEST_CLIENT -from packaging.version import parse as parse_version +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. + import os + import pytest +from opensearchpy.exceptions import NotFoundError, RequestError +from packaging.version import parse as parse_version +from opensearch_py_ml.ml_commons.connector import Connector +from tests import OPENSEARCH_TEST_CLIENT OPENSEARCH_VERSION = parse_version(os.environ.get("OPENSEARCH_VERSION", "2.11.0")) CONNECTOR_MIN_VERSION = parse_version("2.11.0") @@ -41,13 +49,12 @@ def test_connector(client: Connector): } ], } - + res = client.create_standalone_connector(payload) - connector_id = res['connector_id'] + connector_id = res["connector_id"] yield connector_id - - _safe_delete_connector(client, connector_id) + _safe_delete_connector(client, connector_id) @pytest.mark.skipif( @@ -95,4 +102,4 @@ def test_get_connector(): reason="Connectors are supported in OpenSearch 2.9.0 and above", ) def test_delete_connector(): - pass \ No newline at end of file + pass