Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remote inference: add connector; fine tune ML model and tensor class #1051

Merged
merged 1 commit into from
Jul 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ dependencies {
compileOnly "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
compileOnly "org.opensearch:common-utils:${common_utils_version}"
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'

implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
implementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
implementation group: 'org.json', name: 'json', version: '20230227'
}

jacocoTestReport {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,27 @@
import java.util.HashMap;
import java.util.Map;

public enum ModelAccessMode {
public enum AccessMode {
PUBLIC("public"),
PRIVATE("private"),
RESTRICTED("restricted");

@Getter
private String value;

ModelAccessMode(String value) {
AccessMode(String value) {
this.value = value;
}

private static final Map<String, ModelAccessMode> cache = new HashMap<>();
private static final Map<String, AccessMode> cache = new HashMap<>();

static {
for (ModelAccessMode modelAccessMode : values()) {
for (AccessMode modelAccessMode : values()) {
cache.put(modelAccessMode.value, modelAccessMode);
}
}

public static ModelAccessMode from(String value) {
public static AccessMode from(String value) {
try {
return cache.get(value);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ public enum FunctionName {
RCF_SUMMARIZE,
LOGISTIC_REGRESSION,
TEXT_EMBEDDING,
METRICS_CORRELATION;
METRICS_CORRELATION,
REMOTE;

public static FunctionName from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common;

import lombok.extern.log4j.Log4j2;
import org.opensearch.ml.common.annotation.Connector;
import org.opensearch.ml.common.annotation.ExecuteInput;
import org.opensearch.ml.common.annotation.ExecuteOutput;
import org.opensearch.ml.common.annotation.InputDataSet;
Expand All @@ -32,6 +33,7 @@ public class MLCommonsClassLoader {
private static Map<Enum<?>, Class<?>> executeInputClassMap = new HashMap<>();
private static Map<Enum<?>, Class<?>> executeOutputClassMap = new HashMap<>();
private static Map<Enum<?>, Class<?>> mlInputClassMap = new HashMap<>();
private static Map<String, Class<?>> connectorClassMap = new HashMap<>();

static {
try {
Expand All @@ -54,11 +56,26 @@ public static void loadClassMapping() {
loadExecuteInputClassMapping();
loadExecuteOutputClassMapping();
loadMLInputClassMapping();
loadConnectorClassMapping();
} finally {
Thread.currentThread().setContextClassLoader(originalClassLoader);
}
}

private static void loadConnectorClassMapping() {
Reflections reflections = new Reflections("org.opensearch.ml.common.connector");
Set<Class<?>> classes = reflections.getTypesAnnotatedWith(Connector.class);
for (Class<?> clazz : classes) {
Connector connector = clazz.getAnnotation(Connector.class);
if (connector != null) {
String name = connector.value();
if (name != null && name.length() > 0) {
connectorClassMap.put(name, clazz);
}
}
}
}

/**
* Load ML algorithm parameter and ML output class.
*/
Expand Down Expand Up @@ -195,7 +212,7 @@ public static <T extends Enum<T>, S, I extends Object> S initExecuteOutputInstan
}

@SuppressWarnings("unchecked")
private static <T extends Enum<T>, S, I extends Object> S init(Map<Enum<?>, Class<?>> map, T type, I in, Class<?> constructorParamClass) {
private static <T, S, I extends Object> S init(Map<T, Class<?>> map, T type, I in, Class<?> constructorParamClass) {
Class<?> clazz = map.get(type);
if (clazz == null) {
throw new IllegalArgumentException("Can't find class for type " + type);
Expand All @@ -205,8 +222,8 @@ private static <T extends Enum<T>, S, I extends Object> S init(Map<Enum<?>, Clas
return (S) constructor.newInstance(in);
} catch (Exception e) {
Throwable cause = e.getCause();
if (cause instanceof MLException) {
throw (MLException)cause;
if (cause instanceof MLException || cause instanceof IllegalArgumentException) {
throw (RuntimeException)cause;
} else {
log.error("Failed to init instance for type " + type, e);
return null;
Expand All @@ -218,14 +235,19 @@ public static boolean canInitMLInput(FunctionName functionName) {
return mlInputClassMap.containsKey(functionName);
}

public static <S> S initConnector(String name, Object[] initArgs,
Class<?>... constructorParameterTypes) {
return init(connectorClassMap, name, initArgs, constructorParameterTypes);
}

@SuppressWarnings("unchecked")
public static <T extends Enum<T>, S> S initMLInput(T type, Object[] initArgs,
Class<?>... constructorParameterTypes) {
return init(mlInputClassMap, type, initArgs, constructorParameterTypes);
}

private static <T extends Enum<T>, S> S init(Map<Enum<?>, Class<?>> map, T type,
Object[] initArgs, Class<?>... constructorParameterTypes) {
private static <T, S> S init(Map<T, Class<?>> map, T type,
Object[] initArgs, Class<?>... constructorParameterTypes) {
Class<?> clazz = map.get(type);
if (clazz == null) {
throw new IllegalArgumentException("Can't find class for type " + type);
Expand Down
50 changes: 48 additions & 2 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
Expand All @@ -25,13 +26,17 @@
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.USER;
import static org.opensearch.ml.common.connector.Connector.createConnector;

@Getter
public class MLModel implements ToXContentObject {
@Deprecated
public static final String ALGORITHM_FIELD = "algorithm";
public static final String FUNCTION_NAME_FIELD = "function_name";
public static final String MODEL_NAME_FIELD = "name";
public static final String MODEL_GROUP_ID_FIELD = "model_group_id";
// We use int type for version in first release 1.3. In 2.4, we changed to
Expand Down Expand Up @@ -70,6 +75,8 @@ public class MLModel implements ToXContentObject {
public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count";
public static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes";
public static final String DEPLOY_TO_ALL_NODES_FIELD = "deploy_to_all_nodes";
public static final String CONNECTOR_FIELD = "connector";
public static final String CONNECTOR_ID_FIELD = "connector_id";

private String name;
private String modelGroupId;
Expand Down Expand Up @@ -102,6 +109,11 @@ public class MLModel implements ToXContentObject {

private String[] planningWorkerNodes; // plan to deploy model to these nodes
private boolean deployToAllNodes;

@Setter
private Connector connector;
private String connectorId;

@Builder(toBuilder = true)
public MLModel(String name,
String modelGroupId,
Expand All @@ -126,7 +138,9 @@ public MLModel(String name,
Integer planningWorkerNodeCount,
Integer currentWorkerNodeCount,
String[] planningWorkerNodes,
boolean deployToAllNodes) {
boolean deployToAllNodes,
Connector connector,
String connectorId) {
this.name = name;
this.modelGroupId = modelGroupId;
this.algorithm = algorithm;
Expand All @@ -152,6 +166,8 @@ public MLModel(String name,
this.currentWorkerNodeCount = currentWorkerNodeCount;
this.planningWorkerNodes = planningWorkerNodes;
this.deployToAllNodes = deployToAllNodes;
this.connector = connector;
this.connectorId = connectorId;
}

public MLModel(StreamInput input) throws IOException{
Expand Down Expand Up @@ -191,6 +207,11 @@ public MLModel(StreamInput input) throws IOException{
planningWorkerNodes = input.readOptionalStringArray();
deployToAllNodes = input.readBoolean();
modelGroupId = input.readOptionalString();
if (input.readBoolean()) {
String connectorProtocol = input.readString();
connector = MLCommonsClassLoader.initConnector(connectorProtocol, new Object[]{input}, String.class, StreamInput.class);
}
connectorId = input.readOptionalString();
}
}

Expand Down Expand Up @@ -240,6 +261,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalStringArray(planningWorkerNodes);
out.writeBoolean(deployToAllNodes);
out.writeOptionalString(modelGroupId);
if (connector != null) {
out.writeBoolean(true);
out.writeString(connector.getProtocol());
connector.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(connectorId);
}

@Override
Expand Down Expand Up @@ -320,6 +349,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (deployToAllNodes) {
builder.field(DEPLOY_TO_ALL_NODES_FIELD, deployToAllNodes);
}
if (connector != null) {
builder.field(CONNECTOR_FIELD, connector);
}
if (connectorId != null) {
builder.field(CONNECTOR_ID_FIELD, connectorId);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -356,6 +391,8 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
Integer currentWorkerNodeCount = null;
List<String> planningWorkerNodes = new ArrayList<>();
boolean deployToAllNodes = false;
Connector connector = null;
String connectorId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -391,7 +428,8 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
user = User.parse(parser);
break;
case ALGORITHM_FIELD:
algorithm = FunctionName.from(parser.text());
case FUNCTION_NAME_FIELD:
algorithm = FunctionName.from(parser.text().toUpperCase(Locale.ROOT));
break;
case MODEL_ID_FIELD:
modelId = parser.text();
Expand Down Expand Up @@ -436,6 +474,12 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case DEPLOY_TO_ALL_NODES_FIELD:
deployToAllNodes = parser.booleanValue();
break;
case CONNECTOR_FIELD:
connector = createConnector(parser);
break;
case CONNECTOR_ID_FIELD:
connectorId = parser.text();
break;
case CREATED_TIME_FIELD:
createdTime = Instant.ofEpochMilli(parser.longValue());
break;
Expand Down Expand Up @@ -491,6 +535,8 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.currentWorkerNodeCount(currentWorkerNodeCount)
.planningWorkerNodes(planningWorkerNodes.toArray(new String[0]))
.deployToAllNodes(deployToAllNodes)
.connector(connector)
.connectorId(connectorId)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Connector {
String value();
}
Loading