Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
Signed-off-by: kalyan <[email protected]>
  • Loading branch information
rawwar committed Dec 1, 2023
1 parent 1e3ea65 commit 5f2bfb1
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 112 deletions.
7 changes: 0 additions & 7 deletions opensearch_py_ml/ml_commons/model_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@ def create_standalone_connector(self, payload: dict):
method="POST", url=f"{ML_BASE_URI}/connectors/_create", body=payload
)

def create_internal_connector(self, payload: dict):
if not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary")

return self.client.transport.perform_request(
method="POST", url=f"{ML_BASE_URI}/models/_register", body=payload
)

def list_connectors(self):
search_query = {"query": {"match_all": {}}}
Expand Down
105 changes: 0 additions & 105 deletions tests/ml_commons/test_connectors.py

This file was deleted.

152 changes: 152 additions & 0 deletions tests/ml_commons/test_model_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

import os

import pytest
from opensearchpy.exceptions import NotFoundError, RequestError
from packaging.version import parse as parse_version

from opensearch_py_ml.ml_commons.model_connector import Connector
from tests import OPENSEARCH_TEST_CLIENT

OPENSEARCH_VERSION = parse_version(os.environ.get("OPENSEARCH_VERSION", "2.11.0"))
CONNECTOR_MIN_VERSION = parse_version("2.11.0")

print("!@#", OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION)
@pytest.fixture
def client():
return Connector(OPENSEARCH_TEST_CLIENT)


def _safe_delete_connector(client, connector_id):
try:
client.delete_connector(connector_id=connector_id)
except NotFoundError:
pass


@pytest.fixture
def connector_payload():
return {
"name": "Test Connector",
"description": "Connector for testing",
"version": 1,
"protocol": "http",
"parameters": {"endpoint": "api.openai.com", "model": "gpt-3.5-turbo"},
"credential": {"openAI_key": "..."},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://${parameters.endpoint}/v1/chat/completions",
"headers": {"Authorization": "Bearer ${credential.openAI_key}"},
"request_body": '{ "model": "${parameters.model}", "messages": ${parameters.messages} }',
}
],
}


@pytest.fixture
def test_connector(client: Connector, connector_payload: dict):
res = client.create_standalone_connector(connector_payload)
connector_id = res["connector_id"]
yield connector_id

_safe_delete_connector(client, connector_id)


@pytest.mark.skipif(
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
reason="Connectors are supported in OpenSearch 2.9.0 and above",
)
def test_create_standalone_connector(client: Connector, connector_payload: dict):
res = client.create_standalone_connector(connector_payload)
assert "connector_id" in res

_safe_delete_connector(client, res["connector_id"])

with pytest.raises(ValueError):
client.create_standalone_connector("")


@pytest.mark.skipif(
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
reason="Connectors are supported in OpenSearch 2.9.0 and above",
)
def test_list_connector(client, test_connector):
try:
res = client.list_connectors()
assert len(res["hits"]["hits"]) > 0

# check if test_connector id is in the response
found = False
for each in res["hits"]["hits"]:
if each["_id"] == test_connector:
found = True
break
assert found, "Test connector not found in list connectors response"
except Exception as ex:
assert False, "Failed to list connectors due to {ex}"


@pytest.mark.skipif(
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
reason="Connectors are supported in OpenSearch 2.9.0 and above",
)
def test_search_connector(client, test_connector):
try:
query = {"query": {"match": {"name": "Test Connector"}}}
res = client.search_connector(query)
assert len(res["hits"]["hits"]) > 0

# check if test_connector id is in the response
found = False
for each in res["hits"]["hits"]:
if each["_id"] == test_connector:
found = True
break
assert found, "Test connector not found in search connectors response"
except Exception as ex:
assert False, "Failed to search connectors due to {ex}"


@pytest.mark.skipif(
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
reason="Connectors are supported in OpenSearch 2.9.0 and above",
)
def test_get_connector(client, test_connector):
try:
res = client.get_connector(connector_id=test_connector)
assert res["connector_id"] == test_connector
except Exception as ex:
assert False, "Failed to get connector due to {ex}"

with pytest.raises(ValueError):
client.get_connector(connector_id=None)

with pytest.raises(RequestError) as exec_info:
client.get_connector(connector_id="test-unknown")
assert exec_info.value.status_code == 400


@pytest.mark.skipif(
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
reason="Connectors are supported in OpenSearch 2.9.0 and above",
)
def test_delete_connector(client, test_connector):
try:
res = client.delete_connector(connector_id=test_connector)
assert res["result"] == "deleted"
except Exception as ex:
assert False, "Failed to delete connector due to {ex}"

try:
res = client.delete_connector(connector_id="unknown")
assert res["result"] == "not_found"
except Exception as ex:
assert False, "Failed to delete connector due to {ex}"

0 comments on commit 5f2bfb1

Please sign in to comment.