Skip to content

Commit

Permalink
add connector tool
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jun 6, 2024
1 parent f6260c2 commit a196399
Show file tree
Hide file tree
Showing 15 changed files with 180 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public abstract class AbstractConnector implements Connector {
@Setter
protected ConnectorClientConfig connectorClientConfig;

protected Map<String, String> createPredictDecryptedHeaders(Map<String, String> headers) {
protected Map<String, String> createDecryptedHeaders(Map<String, String> headers) {
if (headers == null) {
return null;
}
Expand Down Expand Up @@ -116,9 +116,9 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
}

@Override
public Optional<ConnectorAction> findPredictAction() {
public Optional<ConnectorAction> findAction(String action) {
if (actions != null) {
return actions.stream().filter(a -> a.getActionType() == ConnectorAction.ActionType.PREDICT).findFirst();
return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst();
}
return Optional.empty();
}
Expand All @@ -131,17 +131,17 @@ public void removeCredential() {
}

@Override
public String getPredictEndpoint(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
if (!predictAction.isPresent()) {
public String getActionEndpoint(String action, Map<String, String> parameters) {
Optional<ConnectorAction> connectorAction = findAction(action);
if (!connectorAction.isPresent()) {
return null;
}
String predictEndpoint = predictAction.get().getUrl();
String actionEndpoint = connectorAction.get().getUrl();
if (parameters != null && parameters.size() > 0) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
predictEndpoint = substitutor.replace(predictEndpoint);
actionEndpoint = substitutor.replace(actionEndpoint);
}
return predictEndpoint;
return actionEndpoint;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,18 @@ public interface Connector extends ToXContentObject, Writeable {

ConnectorClientConfig getConnectorClientConfig();

String getPredictEndpoint(Map<String, String> parameters);
String getActionEndpoint(String action, Map<String, String> parameters);

String getPredictHttpMethod();
String getActionHttpMethod(String action);

<T> T createPredictPayload(Map<String, String> parameters);
<T> T createPayload(String action, Map<String, String> parameters);

void decrypt(Function<String, String> function);
void decrypt(String action, Function<String, String> function);
void encrypt(Function<String, String> function);

Connector cloneConnector();

Optional<ConnectorAction> findPredictAction();
Optional<ConnectorAction> findAction(String action);

void removeCredential();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
}

public enum ActionType {
PREDICT
PREDICT,
EXECUTE
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,10 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
}

@Override
public <T> T createPredictPayload(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
if (predictAction.isPresent() && predictAction.get().getRequestBody() != null) {
String payload = predictAction.get().getRequestBody();
public <T> T createPayload(String action, Map<String, String> parameters) {
Optional<ConnectorAction> connectorAction = findAction(action);
if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) {
String payload = connectorAction.get().getRequestBody();
payload = fillNullParameters(parameters, payload);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);
Expand All @@ -323,6 +323,7 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
return (T) parameters.get("http_body");
}


protected String fillNullParameters(Map<String, String> parameters, String payload) {
List<String> bodyParams = findStringParametersWithNullDefaultValue(payload);
String newPayload = payload;
Expand All @@ -348,15 +349,15 @@ private List<String> findStringParametersWithNullDefaultValue(String input) {
}

@Override
public void decrypt(Function<String, String> function) {
public void decrypt(String action, Function<String, String> function) {
Map<String, String> decrypted = new HashMap<>();
for (String key : credential.keySet()) {
decrypted.put(key, function.apply(credential.get(key)));
}
this.decryptedCredential = decrypted;
Optional<ConnectorAction> predictAction = findPredictAction();
Map<String, String> headers = predictAction.isPresent() ? predictAction.get().getHeaders() : null;
this.decryptedHeaders = createPredictDecryptedHeaders(headers);
Optional<ConnectorAction> connectorAction = findAction(action);
Map<String, String> headers = connectorAction.isPresent() ? connectorAction.get().getHeaders() : null;
this.decryptedHeaders = createDecryptedHeaders(headers);
}

@Override
Expand All @@ -378,8 +379,8 @@ public void encrypt(Function<String, String> function) {
}
}

public String getPredictHttpMethod() {
return findPredictAction().get().getMethod();
public String getActionHttpMethod(String action) {
return findAction(action).get().getMethod();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,32 @@
public class MLExecuteConnectorRequest extends MLTaskRequest {

String connectorId;
String connectorAction;
MLInput mlInput;
@Setter
User user;

@Builder
public MLExecuteConnectorRequest(String connectorId, MLInput mlInput, boolean dispatchTask, User user) {
public MLExecuteConnectorRequest(String connectorId, String connectorAction, MLInput mlInput, boolean dispatchTask, User user) {
super(dispatchTask);
this.mlInput = mlInput;
this.connectorAction = connectorAction == null ? "predict" : connectorAction;
this.connectorId = connectorId;
this.user = user;
}

public MLExecuteConnectorRequest(String connectorId, MLInput mlInput) {
this(connectorId, mlInput, true, null);
public MLExecuteConnectorRequest(String connectorId, String connectorAction, MLInput mlInput) {
this(connectorId, connectorAction, mlInput, true, null);
}

public MLExecuteConnectorRequest(String connectorId, MLInput mlInput, User user) {
this(connectorId, mlInput, true, user);
public MLExecuteConnectorRequest(String connectorId, String connectorAction, MLInput mlInput, User user) {
this(connectorId, connectorAction, mlInput, true, user);
}

public MLExecuteConnectorRequest(StreamInput in) throws IOException {
super(in);
this.connectorId = in.readOptionalString();
this.connectorId = in.readString();
this.connectorAction = in.readString();
this.mlInput = new MLInput(in);
if (in.readBoolean()) {
this.user = new User(in);
Expand All @@ -66,7 +69,8 @@ public MLExecuteConnectorRequest(StreamInput in) throws IOException {
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(this.connectorId);
out.writeString(this.connectorId);
out.writeString(this.connectorAction);
this.mlInput.writeTo(out);
if (user != null) {
out.writeBoolean(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ public AwsConnectorExecutor(Connector connector) {

@SuppressWarnings("removal")
@Override
public void invokeRemoteModel(
public void invokeRemoteService(
String action,
MLInput mlInput,
Map<String, String> parameters,
String payload,
Expand All @@ -80,7 +81,7 @@ public void invokeRemoteModel(
ActionListener<List<ModelTensors>> actionListener
) {
try {
SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST);
AsyncExecuteRequest executeRequest = AsyncExecuteRequest
.builder()
.request(signRequest(request))
Expand All @@ -93,17 +94,18 @@ public void invokeRemoteModel(
tensorOutputs,
connector,
scriptService,
mlGuard
mlGuard,
action
)
)
.build();
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
} catch (RuntimeException exception) {
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);
log.error("Failed to execute {} in aws connector: {}", action, exception.getMessage(), exception);
actionListener.onFailure(exception);
} catch (Throwable e) {
log.error("Failed to execute predict in aws connector", e);
actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e));
log.error("Failed to execute {} in aws connector", action, e);
actionListener.onFailure(new MLException("Fail to execute " + action + " in aws connector", e));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public class ConnectorUtils {
}

public static RemoteInferenceInputDataSet processInput(
String action,
MLInput mlInput,
Connector connector,
Map<String, String> parameters,
Expand All @@ -71,22 +72,23 @@ public static RemoteInferenceInputDataSet processInput(
if (mlInput == null) {
throw new IllegalArgumentException("Input is null");
}
Optional<ConnectorAction> predictAction = connector.findPredictAction();
Optional<ConnectorAction> predictAction = connector.findAction(action);
if (predictAction.isEmpty()) {
throw new IllegalArgumentException("no predict action found");
}
RemoteInferenceInputDataSet inputData = processMLInput(mlInput, connector, parameters, scriptService);
RemoteInferenceInputDataSet inputData = processMLInput(action, mlInput, connector, parameters, scriptService);
escapeRemoteInferenceInputData(inputData);
return inputData;
}

private static RemoteInferenceInputDataSet processMLInput(
String action,
MLInput mlInput,
Connector connector,
Map<String, String> parameters,
ScriptService scriptService
) {
String preProcessFunction = getPreprocessFunction(mlInput, connector);
String preProcessFunction = getPreprocessFunction(action, mlInput, connector);
if (preProcessFunction == null) {
if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
return (RemoteInferenceInputDataSet) mlInput.getInputDataset();
Expand Down Expand Up @@ -168,8 +170,8 @@ public static void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet in
}
}

private static String getPreprocessFunction(MLInput mlInput, Connector connector) {
Optional<ConnectorAction> predictAction = connector.findPredictAction();
private static String getPreprocessFunction(String action, MLInput mlInput, Connector connector) {
Optional<ConnectorAction> predictAction = connector.findAction(action);
String preProcessFunction = predictAction.get().getPreProcessFunction();
if (preProcessFunction != null) {
return preProcessFunction;
Expand All @@ -181,6 +183,7 @@ private static String getPreprocessFunction(MLInput mlInput, Connector connector
}

public static ModelTensors processOutput(
String action,
String modelResponse,
Connector connector,
ScriptService scriptService,
Expand All @@ -194,7 +197,7 @@ public static ModelTensors processOutput(
throw new IllegalArgumentException("guardrails triggered for LLM output");
}
List<ModelTensor> modelTensors = new ArrayList<>();
Optional<ConnectorAction> predictAction = connector.findPredictAction();
Optional<ConnectorAction> predictAction = connector.findAction(action);
if (predictAction.isEmpty()) {
throw new IllegalArgumentException("no predict action found");
}
Expand Down Expand Up @@ -263,6 +266,7 @@ public static SdkHttpFullRequest signRequest(
}

public static SdkHttpFullRequest buildSdkRequest(
String action,
Connector connector,
Map<String, String> parameters,
String payload,
Expand All @@ -279,7 +283,7 @@ public static SdkHttpFullRequest buildSdkRequest(
log.error("Content length is 0. Aborting request to remote model");
throw new IllegalArgumentException("Content length is 0. Aborting request to remote model");
}
String endpoint = connector.getPredictEndpoint(parameters);
String endpoint = connector.getActionEndpoint(action, parameters);
SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
.builder()
.method(method)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ public HttpJsonConnectorExecutor(Connector connector) {

@SuppressWarnings("removal")
@Override
public void invokeRemoteModel(
public void invokeRemoteService(
String action,
MLInput mlInput,
Map<String, String> parameters,
String payload,
Expand All @@ -85,15 +86,15 @@ public void invokeRemoteModel(
) {
try {
SdkHttpFullRequest request;
switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) {
switch (connector.getActionHttpMethod(action).toUpperCase(Locale.ROOT)) {
case "POST":
log.debug("original payload to remote model: " + payload);
validateHttpClientParameters(parameters);
request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
validateHttpClientParameters(action, parameters);
request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST);
break;
case "GET":
validateHttpClientParameters(parameters);
request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET);
validateHttpClientParameters(action, parameters);
request = ConnectorUtils.buildSdkRequest(action, connector, parameters, null, GET);
break;
default:
throw new IllegalArgumentException("unsupported http method");
Expand All @@ -110,7 +111,8 @@ public void invokeRemoteModel(
tensorOutputs,
connector,
scriptService,
mlGuard
mlGuard,
action
)
)
.build();
Expand All @@ -124,8 +126,8 @@ public void invokeRemoteModel(
}
}

private void validateHttpClientParameters(Map<String, String> parameters) throws Exception {
String endpoint = connector.getPredictEndpoint(parameters);
private void validateHttpClientParameters(String action, Map<String, String> parameters) throws Exception {
String endpoint = connector.getActionEndpoint(action, parameters);
URL url = new URL(endpoint);
String protocol = url.getProtocol();
String host = url.getHost();
Expand Down
Loading

0 comments on commit a196399

Please sign in to comment.