Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…ch-py-ml into kalyan/model-profile

Signed-off-by: kalyanr <[email protected]>
  • Loading branch information
rawwar committed Dec 6, 2023
2 parents 6692fe2 + 1666baa commit d9bab10
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Add example notebook for tracing and registering a CLIPTextModel to OpenSearch with the Neural Search plugin by @patrickbarnhart in ([#283](https://github.com/opensearch-project/opensearch-py-ml/pull/283))
- Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))
- Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332))
- Add support for model connectors by @rawwar in ([#345](https://github.com/opensearch-project/opensearch-py-ml/pull/345))

### Changed
- Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
Expand Down
2 changes: 2 additions & 0 deletions opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TIMEOUT,
)
from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl
from opensearch_py_ml.ml_commons.model_connector import Connector
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
from opensearch_py_ml.ml_commons.model_profile import ModelProfile
Expand All @@ -38,6 +39,7 @@ def __init__(self, os_client: OpenSearch):
self._model_uploader = ModelUploader(os_client)
self._model_execute = ModelExecute(os_client)
self.model_access_control = ModelAccessControl(os_client)
self.connector = Connector(os_client)
self.profile = ModelProfile(os_client)

def execute(self, algorithm_name: str, input_json: dict) -> dict:
Expand Down
51 changes: 51 additions & 0 deletions opensearch_py_ml/ml_commons/model_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

from opensearchpy import OpenSearch

from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI


class Connector:
def __init__(self, os_client: OpenSearch):
self.client = os_client

def create_standalone_connector(self, payload: dict):
if not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary")

return self.client.transport.perform_request(
method="POST", url=f"{ML_BASE_URI}/connectors/_create", body=payload
)

def list_connectors(self):
search_query = {"query": {"match_all": {}}}
return self.search_connectors(search_query)

def search_connectors(self, search_query: dict):
if not isinstance(search_query, dict):
raise ValueError("search_query needs to be a dictionary")

return self.client.transport.perform_request(
method="POST", url=f"{ML_BASE_URI}/connectors/_search", body=search_query
)

def get_connector(self, connector_id: str):
if not isinstance(connector_id, str):
raise ValueError("connector_id needs to be a string")

return self.client.transport.perform_request(
method="GET", url=f"{ML_BASE_URI}/connectors/{connector_id}"
)

def delete_connector(self, connector_id: str):
if not isinstance(connector_id, str):
raise ValueError("connector_id needs to be a string")

return self.client.transport.perform_request(
method="DELETE", url=f"{ML_BASE_URI}/connectors/{connector_id}"
)
158 changes: 158 additions & 0 deletions tests/ml_commons/test_model_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

import os

import pytest
from opensearchpy.exceptions import NotFoundError, RequestError
from packaging.version import parse as parse_version

from opensearch_py_ml.ml_commons.model_connector import Connector
from tests import OPENSEARCH_TEST_CLIENT

OPENSEARCH_VERSION = parse_version(os.environ.get("OPENSEARCH_VERSION", "2.11.0"))
CONNECTOR_MIN_VERSION = parse_version("2.9.0")


@pytest.fixture
def client():
return Connector(OPENSEARCH_TEST_CLIENT)


def _safe_delete_connector(client, connector_id):
try:
client.delete_connector(connector_id=connector_id)
except NotFoundError:
pass


@pytest.fixture
def connector_payload():
return {
"name": "Test Connector",
"description": "Connector for testing",
"version": 1,
"protocol": "http",
"parameters": {"endpoint": "api.openai.com", "model": "gpt-3.5-turbo"},
"credential": {"openAI_key": "..."},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://${parameters.endpoint}/v1/chat/completions",
"headers": {"Authorization": "Bearer ${credential.openAI_key}"},
"request_body": '{ "model": "${parameters.model}", "messages": ${parameters.messages} }',
}
],
}


@pytest.fixture
def test_connector(client: Connector, connector_payload: dict):
res = client.create_standalone_connector(connector_payload)
connector_id = res["connector_id"]
yield connector_id

_safe_delete_connector(client, connector_id)


@pytest.mark.skipif(
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
reason="Connectors are supported in OpenSearch 2.9.0 and above",
)
def test_create_standalone_connector(client: Connector, connector_payload: dict):
res = client.create_standalone_connector(connector_payload)
assert "connector_id" in res

_safe_delete_connector(client, res["connector_id"])

with pytest.raises(ValueError):
client.create_standalone_connector("")


@pytest.mark.skipif(
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
reason="Connectors are supported in OpenSearch 2.9.0 and above",
)
def test_list_connectors(client, test_connector):
try:
res = client.list_connectors()
assert len(res["hits"]["hits"]) > 0

# check if test_connector id is in the response
found = False
for each in res["hits"]["hits"]:
if each["_id"] == test_connector:
found = True
break
assert found, "Test connector not found in list connectors response"
except Exception as ex:
assert False, f"Failed to list connectors due to {ex}"


@pytest.mark.skipif(
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
reason="Connectors are supported in OpenSearch 2.9.0 and above",
)
def test_search_connectors(client, test_connector):
try:
query = {"query": {"match": {"name": "Test Connector"}}}
res = client.search_connectors(query)
assert len(res["hits"]["hits"]) > 0

# check if test_connector id is in the response
found = False
for each in res["hits"]["hits"]:
if each["_id"] == test_connector:
found = True
break
assert found, "Test connector not found in search connectors response"
except Exception as ex:
assert False, f"Failed to search connectors due to {ex}"

with pytest.raises(ValueError):
client.search_connectors("test")


@pytest.mark.skipif(
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
reason="Connectors are supported in OpenSearch 2.9.0 and above",
)
def test_get_connector(client, test_connector):
try:
res = client.get_connector(connector_id=test_connector)
assert res["name"] == "Test Connector"
except Exception as ex:
assert False, f"Failed to get connector due to {ex}"

with pytest.raises(ValueError):
client.get_connector(connector_id=None)

with pytest.raises(RequestError) as exec_info:
client.get_connector(connector_id="test-unknown")
assert exec_info.value.status_code == 400


@pytest.mark.skipif(
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
reason="Connectors are supported in OpenSearch 2.9.0 and above",
)
def test_delete_connector(client, test_connector):
try:
res = client.delete_connector(connector_id=test_connector)
assert res["result"] == "deleted"
except Exception as ex:
assert False, f"Failed to delete connector due to {ex}"

try:
res = client.delete_connector(connector_id="unknown")
assert res["result"] == "not_found"
except Exception as ex:
assert False, f"Failed to delete connector due to {ex}"

with pytest.raises(ValueError):
client.delete_connector(connector_id={"test": "fail"})

0 comments on commit d9bab10

Please sign in to comment.