Skip to content

Commit

Permalink
ML-Commons train api functionality (opensearch-project#310)
Browse files Browse the repository at this point in the history
* add ml-commons train support

Signed-off-by: kalyanr <[email protected]>

* update __all__

Signed-off-by: kalyanr <[email protected]>

* fix test cases

Signed-off-by: kalyanr <[email protected]>

* sleep after bulk insert

Signed-off-by: kalyanr <[email protected]>

* fix formatting

Signed-off-by: kalyanr <[email protected]>

* remove unused imports

Signed-off-by: kalyanr <[email protected]>

* remove duplicate conftest

Signed-off-by: kalyan <[email protected]>

* delete duplicate conftest

Signed-off-by: kalyan <[email protected]>

* include pytest plugins

Signed-off-by: kalyan <[email protected]>

* revert pandas version

Signed-off-by: Kalyan <[email protected]>

* include license

Signed-off-by: kalyan <[email protected]>

* fix formatting

Signed-off-by: kalyan <[email protected]>

* fix imports order

Signed-off-by: kalyan <[email protected]>

* fix imports order

Signed-off-by: kalyan <[email protected]>

* lint fix

Signed-off-by: kalyan <[email protected]>

* update changelog

Signed-off-by: kalyan <[email protected]>

* revert testcases

Signed-off-by: kalyan <[email protected]>

* remove fixtures

Signed-off-by: kalyan <[email protected]>

* updated test cases

Signed-off-by: kalyan <[email protected]>

* lint fixes

Signed-off-by: kalyan <[email protected]>

* update fixture

Signed-off-by: kalyan <[email protected]>

* revert

Signed-off-by: kalyan <[email protected]>

* include train in MLCommons class as  a func

Signed-off-by: kalyan <[email protected]>

* remove model train

Signed-off-by: kalyan <[email protected]>

* fix tests

Signed-off-by: kalyan <[email protected]>

* revert nox

Signed-off-by: kalyan <[email protected]>

* add tests to model_train

Signed-off-by: kalyan <[email protected]>

* fix lint

Signed-off-by: kalyan <[email protected]>

* fix lint

Signed-off-by: kalyan <[email protected]>

---------

Signed-off-by: kalyanr <[email protected]>
Signed-off-by: kalyan <[email protected]>
Signed-off-by: Kalyan <[email protected]>
  • Loading branch information
rawwar authored Nov 3, 2023
1 parent da415f0 commit 4a8a3e7
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Add workflow and scripts for automating model listing updating process by @thanawan-atc in ([#210](https://github.com/opensearch-project/opensearch-py-ml/pull/210))
- Add script to trigger ml-models-release jenkins workflow with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
- Add example notebook for tracing and registering a CLIPTextModel to OpenSearch with the Neural Search plugin by @patrickbarnhart in ([#283](https://github.com/opensearch-project/opensearch-py-ml/pull/283))
- Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))

### Changed
- Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
Expand Down
22 changes: 21 additions & 1 deletion opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import json
import time
from typing import Any, List, Union
from typing import Any, List, Optional, Union

from deprecated.sphinx import deprecated
from opensearchpy import OpenSearch
Expand Down Expand Up @@ -105,6 +105,26 @@ def upload_model(

return model_id

def train_model(
self, algorithm_name: str, input_json: dict, is_async: Optional[bool] = False
) -> dict:
"""
This method trains an ML Model
"""

params = {}
if not isinstance(input_json, dict):
input_json = json.loads(input_json)
if is_async:
params["async"] = "true"

return self._client.transport.perform_request(
method="POST",
url=f"{ML_BASE_URI}/_train/{algorithm_name}",
body=input_json,
params=params,
)

def register_model(
self,
model_path: str,
Expand Down
94 changes: 93 additions & 1 deletion tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@

import os
import shutil
import time
from json import JSONDecodeError
from os.path import exists

from opensearchpy import OpenSearch
import pytest
from opensearchpy import OpenSearch, helpers
from opensearchpy.exceptions import RequestError
from sklearn.datasets import load_iris

from opensearch_py_ml.ml_commons import MLCommonClient
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
Expand Down Expand Up @@ -44,6 +49,55 @@
PRETRAINED_MODEL_FORMAT = "TORCH_SCRIPT"


@pytest.fixture
def iris_index():
index_name = "test__index__iris_data"
index_mapping = {
"mappings": {
"properties": {
"sepal_length": {"type": "float"},
"sepal_width": {"type": "float"},
"petal_length": {"type": "float"},
"petal_width": {"type": "float"},
"species": {"type": "keyword"},
}
}
}

if ml_client._client.indices.exists(index=index_name):
ml_client._client.indices.delete(index=index_name)
ml_client._client.indices.create(index=index_name, body=index_mapping)

iris = load_iris()
iris_data = iris.data
iris_target = iris.target
iris_species = [iris.target_names[i] for i in iris_target]

actions = [
{
"_index": index_name,
"_source": {
"sepal_length": sepal_length,
"sepal_width": sepal_width,
"petal_length": petal_length,
"petal_width": petal_width,
"species": species,
},
}
for (sepal_length, sepal_width, petal_length, petal_width), species in zip(
iris_data, iris_species
)
]

helpers.bulk(ml_client._client, actions)
# without the sleep, test is failing.
time.sleep(2)

yield index_name

ml_client._client.indices.delete(index=index_name)


def clean_test_folder(TEST_FOLDER):
if os.path.exists(TEST_FOLDER):
for files in os.listdir(TEST_FOLDER):
Expand Down Expand Up @@ -72,6 +126,44 @@ def test_init():
assert isinstance(ml_client._model_uploader, ModelUploader)


def test_train(iris_index):
algorithm_name = "kmeans"
input_json_sync = {
"parameters": {"centroids": 3, "iterations": 10, "distance_type": "COSINE"},
"input_query": {
"_source": ["petal_length", "petal_width"],
"size": 10000,
},
"input_index": [iris_index],
}
response = ml_client.train_model(algorithm_name, input_json_sync)
assert isinstance(response, dict)
assert "model_id" in response
assert "status" in response
assert response["status"] == "COMPLETED"

input_json_async = {
"parameters": {"centroids": 3, "iterations": 10, "distance_type": "COSINE"},
"input_query": {
"_source": ["petal_length", "petal_width"],
"size": 10000,
},
"input_index": [iris_index],
}
response = ml_client.train_model(algorithm_name, input_json_async, is_async=True)

assert isinstance(response, dict)
assert "task_id" in response
assert "status" in response
assert response["status"] == "CREATED"

with pytest.raises(JSONDecodeError):
ml_client.train_model(algorithm_name, "", is_async=True)

with pytest.raises(RequestError):
ml_client.train_model(algorithm_name, {}, is_async=True)


def test_execute():
raised = False
try:
Expand Down

0 comments on commit 4a8a3e7

Please sign in to comment.