Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Sep 12, 2024
1 parent 0135cb9 commit 796d7d3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
package org.opensearch.ml.engine.algorithms.remote;

import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4;
import static software.amazon.awssdk.http.SdkHttpMethod.GET;
import static software.amazon.awssdk.http.SdkHttpMethod.POST;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

Expand Down Expand Up @@ -86,7 +88,18 @@ public void invokeRemoteService(
ActionListener<Tuple<Integer, ModelTensors>> actionListener
) {
try {
SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST);
SdkHttpFullRequest request;
switch (connector.getActionHttpMethod(action).toUpperCase(Locale.ROOT)) {
case "POST":
log.debug("original payload to remote model: " + payload);
request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST);
break;
case "GET":
request = ConnectorUtils.buildSdkRequest(action, connector, parameters, null, GET);
break;
default:
throw new IllegalArgumentException("unsupported http method");
}
AsyncExecuteRequest executeRequest = AsyncExecuteRequest
.builder()
.request(signRequest(request))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,8 @@ private void runPredict(
.getMlModelTensors()
.get(0)
.getDataAsMap();
if (dataAsMap != null
&& (dataAsMap.containsKey("TransformJobArn") || dataAsMap.containsKey("id"))) {
Integer statusCode = tensorOutput.getMlModelOutputs().get(0).getStatusCode();
if (dataAsMap != null && statusCode != null && statusCode >= 200 && statusCode < 300) {
remoteJob.putAll(dataAsMap);
mlTask.setRemoteJob(remoteJob);
mlTask.setTaskId(null);
Expand Down

0 comments on commit 796d7d3

Please sign in to comment.