Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML-Commons train api functionality #310

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
c182e5c
add ml-commons train support
rawwar Oct 7, 2023
edafa9b
update __all__
rawwar Oct 7, 2023
443b9b0
fix test cases
rawwar Oct 8, 2023
e3c4e64
sleep after bulk insert
rawwar Oct 8, 2023
f0ce236
fix formatting
rawwar Oct 8, 2023
7e18884
remove unused imports
rawwar Oct 8, 2023
e7ead98
remove duplicate conftest
rawwar Oct 11, 2023
ac14c98
delete duplicate conftest
rawwar Oct 11, 2023
2a52fac
include pytest plugins
rawwar Oct 11, 2023
e9365a5
revert pandas version
rawwar Oct 13, 2023
cf5074a
include license
rawwar Oct 13, 2023
4112a2f
fix formatting
rawwar Oct 13, 2023
2b14a14
fix imports order
rawwar Oct 13, 2023
e09c697
fix imports order
rawwar Oct 13, 2023
7c57d2f
lint fix
rawwar Oct 13, 2023
2610dda
update changelog
rawwar Oct 13, 2023
31fac86
Merge branch 'opensearch-project:main' into kalyan/286-ml-commons-add…
rawwar Oct 18, 2023
eab5a29
Merge branch 'opensearch-project:main' into kalyan/286-ml-commons-add…
rawwar Oct 27, 2023
4e159c2
revert testcases
rawwar Oct 31, 2023
55f35c5
remove fixtures
rawwar Oct 31, 2023
7a58dc0
updated test cases
rawwar Oct 31, 2023
43e5a7e
lint fixes
rawwar Oct 31, 2023
70f0a75
update fixture
rawwar Oct 31, 2023
404c0b3
revert
rawwar Oct 31, 2023
f19e36e
Merge branch 'main' of https://github.com/opensearch-project/opensear…
rawwar Nov 3, 2023
3cb7fb6
include train in MLCommons class as a func
rawwar Nov 3, 2023
a8306fe
remove model train
rawwar Nov 3, 2023
45d7aeb
fix tests
rawwar Nov 3, 2023
9431d8f
revert nox
rawwar Nov 3, 2023
6faf462
add tests to model_train
rawwar Nov 3, 2023
eaf4bdc
fix lint
rawwar Nov 3, 2023
a6f0969
fix lint
rawwar Nov 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Add workflow and scripts for automating model listing updating process by @thanawan-atc in ([#210](https://github.com/opensearch-project/opensearch-py-ml/pull/210))
- Add script to trigger ml-models-release jenkins workflow with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
- Add example notebook for tracing and registering a CLIPTextModel to OpenSearch with the Neural Search plugin by @patrickbarnhart in ([#283](https://github.com/opensearch-project/opensearch-py-ml/pull/283))
- Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))

### Changed
- Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
Expand Down
22 changes: 21 additions & 1 deletion opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import json
import time
from typing import Any, List, Union
from typing import Any, List, Optional, Union

from deprecated.sphinx import deprecated
from opensearchpy import OpenSearch
Expand Down Expand Up @@ -105,6 +105,26 @@ def upload_model(

return model_id

def train_model(
self, algorithm_name: str, input_json: dict, is_async: Optional[bool] = False
) -> dict:
"""
This method trains an ML Model
"""

params = {}
if not isinstance(input_json, dict):
input_json = json.loads(input_json)
if is_async:
params["async"] = "true"

return self._client.transport.perform_request(
method="POST",
url=f"{ML_BASE_URI}/_train/{algorithm_name}",
body=input_json,
params=params,
)

def register_model(
self,
model_path: str,
Expand Down
94 changes: 93 additions & 1 deletion tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@

import os
import shutil
import time
from json import JSONDecodeError
from os.path import exists

from opensearchpy import OpenSearch
import pytest
from opensearchpy import OpenSearch, helpers
from opensearchpy.exceptions import RequestError
from sklearn.datasets import load_iris

from opensearch_py_ml.ml_commons import MLCommonClient
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
Expand Down Expand Up @@ -44,6 +49,55 @@
PRETRAINED_MODEL_FORMAT = "TORCH_SCRIPT"


@pytest.fixture
def iris_index():
index_name = "test__index__iris_data"
index_mapping = {
"mappings": {
"properties": {
"sepal_length": {"type": "float"},
"sepal_width": {"type": "float"},
"petal_length": {"type": "float"},
"petal_width": {"type": "float"},
"species": {"type": "keyword"},
}
}
}

if ml_client._client.indices.exists(index=index_name):
ml_client._client.indices.delete(index=index_name)
ml_client._client.indices.create(index=index_name, body=index_mapping)

iris = load_iris()
iris_data = iris.data
iris_target = iris.target
iris_species = [iris.target_names[i] for i in iris_target]

actions = [
{
"_index": index_name,
"_source": {
"sepal_length": sepal_length,
"sepal_width": sepal_width,
"petal_length": petal_length,
"petal_width": petal_width,
"species": species,
},
}
for (sepal_length, sepal_width, petal_length, petal_width), species in zip(
iris_data, iris_species
)
]

helpers.bulk(ml_client._client, actions)
# without the sleep, test is failing.
time.sleep(2)

yield index_name

ml_client._client.indices.delete(index=index_name)


def clean_test_folder(TEST_FOLDER):
if os.path.exists(TEST_FOLDER):
for files in os.listdir(TEST_FOLDER):
Expand Down Expand Up @@ -72,6 +126,44 @@ def test_init():
assert isinstance(ml_client._model_uploader, ModelUploader)


def test_train(iris_index):
algorithm_name = "kmeans"
input_json_sync = {
"parameters": {"centroids": 3, "iterations": 10, "distance_type": "COSINE"},
"input_query": {
"_source": ["petal_length", "petal_width"],
"size": 10000,
},
"input_index": [iris_index],
}
response = ml_client.train_model(algorithm_name, input_json_sync)
assert isinstance(response, dict)
assert "model_id" in response
assert "status" in response
assert response["status"] == "COMPLETED"

input_json_async = {
"parameters": {"centroids": 3, "iterations": 10, "distance_type": "COSINE"},
"input_query": {
"_source": ["petal_length", "petal_width"],
"size": 10000,
},
"input_index": [iris_index],
}
response = ml_client.train_model(algorithm_name, input_json_async, is_async=True)

assert isinstance(response, dict)
assert "task_id" in response
assert "status" in response
assert response["status"] == "CREATED"

with pytest.raises(JSONDecodeError):
ml_client.train_model(algorithm_name, "", is_async=True)

with pytest.raises(RequestError):
ml_client.train_model(algorithm_name, {}, is_async=True)


def test_execute():
raised = False
try:
Expand Down
Loading