diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8e7b7596..32adb424 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 - Add workflow and scripts for automating model listing updating process by @thanawan-atc in ([#210](https://github.com/opensearch-project/opensearch-py-ml/pull/210))
 - Add script to trigger ml-models-release jenkins workflow with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
 - Add example notebook for tracing and registering a CLIPTextModel to OpenSearch with the Neural Search plugin by @patrickbarnhart in ([#283](https://github.com/opensearch-project/opensearch-py-ml/pull/283))
+- Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))
 
 ### Changed
 - Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py
index b4198d10..57f9fd10 100644
--- a/opensearch_py_ml/ml_commons/ml_commons_client.py
+++ b/opensearch_py_ml/ml_commons/ml_commons_client.py
@@ -8,7 +8,7 @@
 
 import json
 import time
-from typing import Any, List, Union
+from typing import Any, List, Optional, Union
 
 from deprecated.sphinx import deprecated
 from opensearchpy import OpenSearch
@@ -105,6 +105,40 @@ def upload_model(
 
         return model_id
 
+    def train_model(
+        self,
+        algorithm_name: str,
+        input_json: Union[dict, str],
+        is_async: Optional[bool] = False,
+    ) -> dict:
+        """
+        Train a model with one of the ML Commons algorithms through the train API.
+
+        :param algorithm_name: name of the algorithm to train with, e.g. "kmeans"
+        :type algorithm_name: str
+        :param input_json: request body for the train API, given either as a dict
+            or as a JSON-encoded string
+        :type input_json: dict or str
+        :param is_async: when truthy, the request is sent with ``async=true``
+        :type is_async: bool
+        :return: the response returned by the train API
+        :rtype: dict
+        :raises json.JSONDecodeError: if a string ``input_json`` is not valid JSON
+        """
+        # The body may arrive as a JSON string; decode it before sending.
+        if not isinstance(input_json, dict):
+            input_json = json.loads(input_json)
+
+        # Only set the query parameter when asynchronous training is requested.
+        params = {"async": "true"} if is_async else {}
+
+        return self._client.transport.perform_request(
+            method="POST",
+            url=f"{ML_BASE_URI}/_train/{algorithm_name}",
+            body=input_json,
+            params=params,
+        )
+
     def register_model(
         self,
         model_path: str,
diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py
index 9452d609..86c5af24 100644
--- a/tests/ml_commons/test_ml_commons_client.py
+++ b/tests/ml_commons/test_ml_commons_client.py
@@ -7,9 +7,13 @@
 
 import os
 import shutil
+from json import JSONDecodeError
 from os.path import exists
 
-from opensearchpy import OpenSearch
+import pytest
+from opensearchpy import OpenSearch, helpers
+from opensearchpy.exceptions import RequestError
+from sklearn.datasets import load_iris
 
 from opensearch_py_ml.ml_commons import MLCommonClient
 from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
@@ -44,6 +48,56 @@
 PRETRAINED_MODEL_FORMAT = "TORCH_SCRIPT"
 
 
+@pytest.fixture
+def iris_index():
+    """Yield the name of a temporary index loaded with the iris dataset."""
+    index_name = "test__index__iris_data"
+    index_mapping = {
+        "mappings": {
+            "properties": {
+                "sepal_length": {"type": "float"},
+                "sepal_width": {"type": "float"},
+                "petal_length": {"type": "float"},
+                "petal_width": {"type": "float"},
+                "species": {"type": "keyword"},
+            }
+        }
+    }
+
+    # Start from a clean index in case a previous run left one behind.
+    if ml_client._client.indices.exists(index=index_name):
+        ml_client._client.indices.delete(index=index_name)
+    ml_client._client.indices.create(index=index_name, body=index_mapping)
+
+    iris = load_iris()
+    iris_species = [iris.target_names[i] for i in iris.target]
+
+    actions = [
+        {
+            "_index": index_name,
+            "_source": {
+                "sepal_length": sepal_length,
+                "sepal_width": sepal_width,
+                "petal_length": petal_length,
+                "petal_width": petal_width,
+                "species": species,
+            },
+        }
+        for (sepal_length, sepal_width, petal_length, petal_width), species in zip(
+            iris.data, iris_species
+        )
+    ]
+
+    helpers.bulk(ml_client._client, actions)
+    # Bulk-indexed documents only become searchable after a refresh, so force one
+    # here instead of sleeping and hoping the periodic refresh has already run.
+    ml_client._client.indices.refresh(index=index_name)
+
+    yield index_name
+
+    ml_client._client.indices.delete(index=index_name)
+
+
 def clean_test_folder(TEST_FOLDER):
     if os.path.exists(TEST_FOLDER):
         for files in os.listdir(TEST_FOLDER):
@@ -72,6 +126,41 @@ def test_init():
     assert isinstance(ml_client._model_uploader, ModelUploader)
 
 
+def test_train(iris_index):
+    algorithm_name = "kmeans"
+    # The same request body is used for the synchronous and asynchronous calls.
+    input_json = {
+        "parameters": {"centroids": 3, "iterations": 10, "distance_type": "COSINE"},
+        "input_query": {
+            "_source": ["petal_length", "petal_width"],
+            "size": 10000,
+        },
+        "input_index": [iris_index],
+    }
+
+    # A synchronous call is expected to answer with a model id once training is done.
+    response = ml_client.train_model(algorithm_name, input_json)
+    assert isinstance(response, dict)
+    assert "model_id" in response
+    assert "status" in response
+    assert response["status"] == "COMPLETED"
+
+    # An asynchronous call is expected to answer with the id of the created task.
+    response = ml_client.train_model(algorithm_name, input_json, is_async=True)
+    assert isinstance(response, dict)
+    assert "task_id" in response
+    assert "status" in response
+    assert response["status"] == "CREATED"
+
+    # A string body that is not valid JSON fails while decoding, before any request.
+    with pytest.raises(JSONDecodeError):
+        ml_client.train_model(algorithm_name, "", is_async=True)
+
+    # An empty dict skips decoding and is sent as-is; the request must be rejected.
+    with pytest.raises(RequestError):
+        ml_client.train_model(algorithm_name, {}, is_async=True)
+
+
 def test_execute():
     raised = False
     try: