Skip to content

Commit

Permalink
remote inference: add connector; fine tune ML model and tensor class (#…
Browse files Browse the repository at this point in the history
…1051)

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Jul 9, 2023
1 parent 8139f8f commit 4e1fb9b
Show file tree
Hide file tree
Showing 74 changed files with 1,953 additions and 227 deletions.
4 changes: 4 additions & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ dependencies {
compileOnly "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
compileOnly "org.opensearch:common-utils:${common_utils_version}"
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'

implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
implementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
implementation group: 'org.json', name: 'json', version: '20230227'
}

jacocoTestReport {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,27 @@
import java.util.HashMap;
import java.util.Map;

public enum ModelAccessMode {
public enum AccessMode {
PUBLIC("public"),
PRIVATE("private"),
RESTRICTED("restricted");

@Getter
private String value;

ModelAccessMode(String value) {
AccessMode(String value) {
this.value = value;
}

private static final Map<String, ModelAccessMode> cache = new HashMap<>();
private static final Map<String, AccessMode> cache = new HashMap<>();

static {
for (ModelAccessMode modelAccessMode : values()) {
for (AccessMode modelAccessMode : values()) {
cache.put(modelAccessMode.value, modelAccessMode);
}
}

public static ModelAccessMode from(String value) {
public static AccessMode from(String value) {
try {
return cache.get(value);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ public enum FunctionName {
RCF_SUMMARIZE,
LOGISTIC_REGRESSION,
TEXT_EMBEDDING,
METRICS_CORRELATION;
METRICS_CORRELATION,
REMOTE;

public static FunctionName from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common;

import lombok.extern.log4j.Log4j2;
import org.opensearch.ml.common.annotation.Connector;
import org.opensearch.ml.common.annotation.ExecuteInput;
import org.opensearch.ml.common.annotation.ExecuteOutput;
import org.opensearch.ml.common.annotation.InputDataSet;
Expand All @@ -32,6 +33,7 @@ public class MLCommonsClassLoader {
private static Map<Enum<?>, Class<?>> executeInputClassMap = new HashMap<>();
private static Map<Enum<?>, Class<?>> executeOutputClassMap = new HashMap<>();
private static Map<Enum<?>, Class<?>> mlInputClassMap = new HashMap<>();
private static Map<String, Class<?>> connectorClassMap = new HashMap<>();

static {
try {
Expand All @@ -54,11 +56,26 @@ public static void loadClassMapping() {
loadExecuteInputClassMapping();
loadExecuteOutputClassMapping();
loadMLInputClassMapping();
loadConnectorClassMapping();
} finally {
Thread.currentThread().setContextClassLoader(originalClassLoader);
}
}

private static void loadConnectorClassMapping() {
Reflections reflections = new Reflections("org.opensearch.ml.common.connector");
Set<Class<?>> classes = reflections.getTypesAnnotatedWith(Connector.class);
for (Class<?> clazz : classes) {
Connector connector = clazz.getAnnotation(Connector.class);
if (connector != null) {
String name = connector.value();
if (name != null && name.length() > 0) {
connectorClassMap.put(name, clazz);
}
}
}
}

/**
* Load ML algorithm parameter and ML output class.
*/
Expand Down Expand Up @@ -195,7 +212,7 @@ public static <T extends Enum<T>, S, I extends Object> S initExecuteOutputInstan
}

@SuppressWarnings("unchecked")
private static <T extends Enum<T>, S, I extends Object> S init(Map<Enum<?>, Class<?>> map, T type, I in, Class<?> constructorParamClass) {
private static <T, S, I extends Object> S init(Map<T, Class<?>> map, T type, I in, Class<?> constructorParamClass) {
Class<?> clazz = map.get(type);
if (clazz == null) {
throw new IllegalArgumentException("Can't find class for type " + type);
Expand All @@ -205,8 +222,8 @@ private static <T extends Enum<T>, S, I extends Object> S init(Map<Enum<?>, Clas
return (S) constructor.newInstance(in);
} catch (Exception e) {
Throwable cause = e.getCause();
if (cause instanceof MLException) {
throw (MLException)cause;
if (cause instanceof MLException || cause instanceof IllegalArgumentException) {
throw (RuntimeException)cause;
} else {
log.error("Failed to init instance for type " + type, e);
return null;
Expand All @@ -218,14 +235,19 @@ public static boolean canInitMLInput(FunctionName functionName) {
return mlInputClassMap.containsKey(functionName);
}

public static <S> S initConnector(String name, Object[] initArgs,
Class<?>... constructorParameterTypes) {
return init(connectorClassMap, name, initArgs, constructorParameterTypes);
}

@SuppressWarnings("unchecked")
public static <T extends Enum<T>, S> S initMLInput(T type, Object[] initArgs,
Class<?>... constructorParameterTypes) {
return init(mlInputClassMap, type, initArgs, constructorParameterTypes);
}

private static <T extends Enum<T>, S> S init(Map<Enum<?>, Class<?>> map, T type,
Object[] initArgs, Class<?>... constructorParameterTypes) {
private static <T, S> S init(Map<T, Class<?>> map, T type,
Object[] initArgs, Class<?>... constructorParameterTypes) {
Class<?> clazz = map.get(type);
if (clazz == null) {
throw new IllegalArgumentException("Can't find class for type " + type);
Expand Down
50 changes: 48 additions & 2 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
Expand All @@ -25,13 +26,17 @@
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.USER;
import static org.opensearch.ml.common.connector.Connector.createConnector;

@Getter
public class MLModel implements ToXContentObject {
@Deprecated
public static final String ALGORITHM_FIELD = "algorithm";
public static final String FUNCTION_NAME_FIELD = "function_name";
public static final String MODEL_NAME_FIELD = "name";
public static final String MODEL_GROUP_ID_FIELD = "model_group_id";
// We use int type for version in first release 1.3. In 2.4, we changed to
Expand Down Expand Up @@ -70,6 +75,8 @@ public class MLModel implements ToXContentObject {
public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count";
public static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes";
public static final String DEPLOY_TO_ALL_NODES_FIELD = "deploy_to_all_nodes";
public static final String CONNECTOR_FIELD = "connector";
public static final String CONNECTOR_ID_FIELD = "connector_id";

private String name;
private String modelGroupId;
Expand Down Expand Up @@ -102,6 +109,11 @@ public class MLModel implements ToXContentObject {

private String[] planningWorkerNodes; // plan to deploy model to these nodes
private boolean deployToAllNodes;

@Setter
private Connector connector;
private String connectorId;

@Builder(toBuilder = true)
public MLModel(String name,
String modelGroupId,
Expand All @@ -126,7 +138,9 @@ public MLModel(String name,
Integer planningWorkerNodeCount,
Integer currentWorkerNodeCount,
String[] planningWorkerNodes,
boolean deployToAllNodes) {
boolean deployToAllNodes,
Connector connector,
String connectorId) {
this.name = name;
this.modelGroupId = modelGroupId;
this.algorithm = algorithm;
Expand All @@ -152,6 +166,8 @@ public MLModel(String name,
this.currentWorkerNodeCount = currentWorkerNodeCount;
this.planningWorkerNodes = planningWorkerNodes;
this.deployToAllNodes = deployToAllNodes;
this.connector = connector;
this.connectorId = connectorId;
}

public MLModel(StreamInput input) throws IOException{
Expand Down Expand Up @@ -191,6 +207,11 @@ public MLModel(StreamInput input) throws IOException{
planningWorkerNodes = input.readOptionalStringArray();
deployToAllNodes = input.readBoolean();
modelGroupId = input.readOptionalString();
if (input.readBoolean()) {
String connectorProtocol = input.readString();
connector = MLCommonsClassLoader.initConnector(connectorProtocol, new Object[]{input}, String.class, StreamInput.class);
}
connectorId = input.readOptionalString();
}
}

Expand Down Expand Up @@ -240,6 +261,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalStringArray(planningWorkerNodes);
out.writeBoolean(deployToAllNodes);
out.writeOptionalString(modelGroupId);
if (connector != null) {
out.writeBoolean(true);
out.writeString(connector.getProtocol());
connector.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(connectorId);
}

@Override
Expand Down Expand Up @@ -320,6 +349,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (deployToAllNodes) {
builder.field(DEPLOY_TO_ALL_NODES_FIELD, deployToAllNodes);
}
if (connector != null) {
builder.field(CONNECTOR_FIELD, connector);
}
if (connectorId != null) {
builder.field(CONNECTOR_ID_FIELD, connectorId);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -356,6 +391,8 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
Integer currentWorkerNodeCount = null;
List<String> planningWorkerNodes = new ArrayList<>();
boolean deployToAllNodes = false;
Connector connector = null;
String connectorId = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -391,7 +428,8 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
user = User.parse(parser);
break;
case ALGORITHM_FIELD:
algorithm = FunctionName.from(parser.text());
case FUNCTION_NAME_FIELD:
algorithm = FunctionName.from(parser.text().toUpperCase(Locale.ROOT));
break;
case MODEL_ID_FIELD:
modelId = parser.text();
Expand Down Expand Up @@ -436,6 +474,12 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case DEPLOY_TO_ALL_NODES_FIELD:
deployToAllNodes = parser.booleanValue();
break;
case CONNECTOR_FIELD:
connector = createConnector(parser);
break;
case CONNECTOR_ID_FIELD:
connectorId = parser.text();
break;
case CREATED_TIME_FIELD:
createdTime = Instant.ofEpochMilli(parser.longValue());
break;
Expand Down Expand Up @@ -491,6 +535,8 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.currentWorkerNodeCount(currentWorkerNodeCount)
.planningWorkerNodes(planningWorkerNodes.toArray(new String[0]))
.deployToAllNodes(deployToAllNodes)
.connector(connector)
.connectorId(connectorId)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Connector {
String value();
}
Loading

0 comments on commit 4e1fb9b

Please sign in to comment.