Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML-Commons train api functionality #310

Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
c182e5c
add ml-commons train support
rawwar Oct 7, 2023
edafa9b
update __all__
rawwar Oct 7, 2023
443b9b0
fix test cases
rawwar Oct 8, 2023
e3c4e64
sleep after bulk insert
rawwar Oct 8, 2023
f0ce236
fix formatting
rawwar Oct 8, 2023
7e18884
remove unused imports
rawwar Oct 8, 2023
e7ead98
remove duplicate conftest
rawwar Oct 11, 2023
ac14c98
delete duplicate conftest
rawwar Oct 11, 2023
2a52fac
include pytest plugins
rawwar Oct 11, 2023
e9365a5
revert pandas version
rawwar Oct 13, 2023
cf5074a
include license
rawwar Oct 13, 2023
4112a2f
fix formatting
rawwar Oct 13, 2023
2b14a14
fix imports order
rawwar Oct 13, 2023
e09c697
fix imports order
rawwar Oct 13, 2023
7c57d2f
lint fix
rawwar Oct 13, 2023
2610dda
update changelog
rawwar Oct 13, 2023
31fac86
Merge branch 'opensearch-project:main' into kalyan/286-ml-commons-add…
rawwar Oct 18, 2023
eab5a29
Merge branch 'opensearch-project:main' into kalyan/286-ml-commons-add…
rawwar Oct 27, 2023
4e159c2
revert testcases
rawwar Oct 31, 2023
55f35c5
remove fixtures
rawwar Oct 31, 2023
7a58dc0
updated test cases
rawwar Oct 31, 2023
43e5a7e
lint fixes
rawwar Oct 31, 2023
70f0a75
update fixture
rawwar Oct 31, 2023
404c0b3
revert
rawwar Oct 31, 2023
f19e36e
Merge branch 'main' of https://github.com/opensearch-project/opensear…
rawwar Nov 3, 2023
3cb7fb6
include train in MLCommons class as a func
rawwar Nov 3, 2023
a8306fe
remove model train
rawwar Nov 3, 2023
45d7aeb
fix tests
rawwar Nov 3, 2023
9431d8f
revert nox
rawwar Nov 3, 2023
6faf462
add tests to model_train
rawwar Nov 3, 2023
eaf4bdc
fix lint
rawwar Nov 3, 2023
a6f0969
fix lint
rawwar Nov 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Add workflow and scripts for automating model listing updating process by @thanawan-atc in ([#210](https://github.com/opensearch-project/opensearch-py-ml/pull/210))
- Add script to trigger ml-models-release jenkins workflow with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
- Add example notebook for tracing and registering a CLIPTextModel to OpenSearch with the Neural Search plugin by @patrickbarnhart in ([#283](https://github.com/opensearch-project/opensearch-py-ml/pull/283))
- Add support for the ML Commons train API by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))

### Changed
- Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
Expand Down
3 changes: 2 additions & 1 deletion opensearch_py_ml/ml_commons/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from opensearch_py_ml.ml_commons.ml_commons_client import MLCommonClient
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
from opensearch_py_ml.ml_commons.model_train import ModelTrain
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader

__all__ = ["MLCommonClient", "ModelExecute", "ModelUploader"]
__all__ = ["MLCommonClient", "ModelExecute", "ModelUploader", "ModelTrain"]
13 changes: 12 additions & 1 deletion opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import json
import time
from typing import Any, List, Union
from typing import Any, List, Optional, Union

from deprecated.sphinx import deprecated
from opensearchpy import OpenSearch
Expand All @@ -22,6 +22,7 @@
TIMEOUT,
)
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
from opensearch_py_ml.ml_commons.model_train import ModelTrain
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader


Expand All @@ -35,6 +36,7 @@ def __init__(self, os_client: OpenSearch):
self._client = os_client
self._model_uploader = ModelUploader(os_client)
self._model_execute = ModelExecute(os_client)
self._model_train = ModelTrain(os_client)

def execute(self, algorithm_name: str, input_json: dict) -> dict:
"""
Expand Down Expand Up @@ -580,3 +582,12 @@ def delete_task(self, task_id: str) -> object:
method="DELETE",
url=API_URL,
)

def train_model(
    self, algorithm_name: str, input_json: dict, is_async: Optional[bool] = False
) -> dict:
    """
    Train a model through the ML Commons train API.

    The call is delegated unchanged to the ``ModelTrain`` helper stored in
    ``self._model_train``.

    :param algorithm_name: name of the algorithm to train (the tests in this
        project use ``"kmeans"``)
    :type algorithm_name: str
    :param input_json: body of the train request; the helper also accepts a
        JSON string and parses it before sending
    :type input_json: dict
    :param is_async: when truthy the helper adds ``async=true`` to the request
        parameters; defaults to ``False`` here, i.e. a synchronous request
    :type is_async: bool
    :return: the response of the train request as returned by the helper
    :rtype: dict
    """
    # is_async is always forwarded explicitly, so the helper's own default
    # for that argument never applies when going through this method.
    return self._model_train._train(algorithm_name, input_json, is_async)
44 changes: 44 additions & 0 deletions opensearch_py_ml/ml_commons/model_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

import json
from typing import Optional

from opensearchpy import OpenSearch

from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI


class ModelTrain:
    """
    Class for training models using ML Commons train API.
    """

    # Path segment appended to ML_BASE_URI to address the train API.
    API_ENDPOINT = "_train"

    def __init__(self, os_client: OpenSearch):
        """
        Store the client used to send train requests.

        :param os_client: OpenSearch client whose transport performs the request
        :type os_client: OpenSearch
        """
        self._client = os_client

    def _train(
        self, algorithm_name: str, input_json: dict, is_async: Optional[bool] = True
    ) -> dict:
        """
        Send a train request for the given algorithm.

        :param algorithm_name: name of the algorithm; it becomes the last
            segment of the request URL
        :type algorithm_name: str
        :param input_json: request body; anything that is not a dict is
            passed through ``json.loads`` first, so a JSON string is accepted
        :type input_json: dict
        :param is_async: when truthy, ``async=true`` is added to the request
            parameters. NOTE: the default here is ``True``; callers that want
            a synchronous request must pass a falsy value explicitly.
        :type is_async: bool
        :return: whatever the client's transport returns for the request
        :rtype: dict
        """
        # Accept a serialized JSON document as well as a ready-made dict.
        if not isinstance(input_json, dict):
            input_json = json.loads(input_json)

        # The query parameter is only sent for asynchronous requests.
        params = {"async": "true"} if is_async else {}

        return self._client.transport.perform_request(
            method="POST",
            url=f"{ML_BASE_URI}/{ModelTrain.API_ENDPOINT}/{algorithm_name}",
            body=input_json,
            params=params,
        )
3 changes: 0 additions & 3 deletions tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,3 @@ def test_integration_model_train_register_full_cycle():
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in deleting model"


test_integration_model_train_register_full_cycle()
103 changes: 103 additions & 0 deletions tests/ml_commons/test_model_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

import time

import pytest
from opensearchpy import OpenSearch, helpers
from sklearn.datasets import load_iris

from opensearch_py_ml.ml_commons import MLCommonClient, ModelTrain
from tests import OPENSEARCH_TEST_CLIENT

# Module-level client shared by the fixture and the tests below; it wraps the
# project's OPENSEARCH_TEST_CLIENT.
ml_client = MLCommonClient(OPENSEARCH_TEST_CLIENT)


@pytest.fixture
def iris_index():
    """
    Provide an index filled with the iris data set.

    The index is (re)created, populated with one document per iris sample,
    refreshed so that the documents are searchable, yielded to the test by
    name, and deleted again once the test has finished.
    """
    index_name = "test__index__iris_data"
    index_mapping = {
        "mappings": {
            "properties": {
                "sepal_length": {"type": "float"},
                "sepal_width": {"type": "float"},
                "petal_length": {"type": "float"},
                "petal_width": {"type": "float"},
                "species": {"type": "keyword"},
            }
        }
    }
    client = ml_client._client

    # Drop any leftover index from an earlier run so the mapping and the
    # document count are always the expected ones.
    if client.indices.exists(index=index_name):
        client.indices.delete(index=index_name)
    client.indices.create(index=index_name, body=index_mapping)

    iris = load_iris()
    iris_species = [iris.target_names[i] for i in iris.target]

    actions = [
        {
            "_index": index_name,
            "_source": {
                "sepal_length": sepal_length,
                "sepal_width": sepal_width,
                "petal_length": petal_length,
                "petal_width": petal_width,
                "species": species,
            },
        }
        for (sepal_length, sepal_width, petal_length, petal_width), species in zip(
            iris.data, iris_species
        )
    ]

    helpers.bulk(client, actions)
    # Bulk-indexed documents only become searchable after a refresh. Refresh
    # explicitly instead of sleeping and hoping the periodic refresh has run.
    client.indices.refresh(index=index_name)

    yield index_name

    client.indices.delete(index=index_name)


def test_init():
    """The shared client exposes an OpenSearch client and a ModelTrain helper."""
    assert isinstance(ml_client._client, OpenSearch)
    assert isinstance(ml_client._model_train, ModelTrain)


def test_train(iris_index):
    """
    Train kmeans on the iris index, first synchronously and then asynchronously.

    The synchronous call is expected to return a finished model, the
    asynchronous call a freshly created task.
    """
    algorithm_name = "kmeans"

    def build_request():
        # Both calls send the same payload; a fresh dict is built per call so
        # the two requests never share an object.
        return {
            "parameters": {"centroids": 3, "iterations": 10, "distance_type": "COSINE"},
            "input_query": {
                "_source": ["petal_length", "petal_width"],
                "size": 10000,
            },
            "input_index": [iris_index],
        }

    # Synchronous training (train_model defaults to is_async=False).
    response = ml_client.train_model(algorithm_name, build_request())
    assert isinstance(response, dict)
    assert "model_id" in response
    assert "status" in response
    assert response["status"] == "COMPLETED"

    # Asynchronous training only creates a task.
    response = ml_client.train_model(algorithm_name, build_request(), is_async=True)

    assert isinstance(response, dict)
    assert "task_id" in response
    assert "status" in response
    assert response["status"] == "CREATED"
Loading