Skip to content

Commit

Permalink
[ML] Do not block queue while waiting for PyTorch results (#80074)
Browse files Browse the repository at this point in the history
Since we introduced queueing of the inference requests, we should
not be waiting for the result in the operation that runs in the queue
as that blocks the queue. In particular, this means we cannot benefit
from parallel forwarding when `model_threads` is set to more than one.

In addition, since we now handle timeouts by having a scheduled run,
we do not ever need to wait on a latch for pending results. We just
need to store the listener and notify it when the result is processed.
  • Loading branch information
dimitris-athanasiou authored Oct 29, 2021
1 parent 4f2ff7f commit 5a41fa4
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -292,13 +292,13 @@ static class InferenceAction extends AbstractRunnable {

/**
 * Invoked by the scheduled timeout task when no inference result arrived in time.
 * If the listener has not yet been notified, the pending result is dropped
 * (without notifying its stored listener) and the caller gets a 429 failure.
 * Otherwise this is a no-op beyond a debug log.
 */
void onTimeout() {
    if (notified.compareAndSet(false, true)) {
        // Drop the stored pending-result listener so a late response is ignored;
        // this request's listener is failed here instead.
        processContext.getResultProcessor().ignoreResposeWithoutNotifying(String.valueOf(requestId));
        listener.onFailure(
            new ElasticsearchStatusException("timeout [{}] waiting for inference result", RestStatus.TOO_MANY_REQUESTS, timeout)
        );
        return;
    }
    // Result (or failure) won the race; nothing left to do.
    logger.debug("[{}] request [{}] received timeout after [{}] but listener already alerted", modelId, requestId, timeout);
}

void onSuccess(InferenceResults inferenceResults) {
Expand All @@ -307,17 +307,21 @@ void onSuccess(InferenceResults inferenceResults) {
listener.onResponse(inferenceResults);
return;
}
logger.debug("request [{}] received inference response but listener already notified", requestId);
logger.debug("[{}] request [{}] received inference response but listener already notified", modelId, requestId);
}

/**
 * Fails the request. Cancels the timeout task, and if the listener has not
 * already been notified, removes the pending result registration (so the
 * result processor will not notify it later) and propagates the exception.
 *
 * @param e the failure to deliver to the listener
 */
@Override
public void onFailure(Exception e) {
    // The request is completing now; the timeout task is no longer needed.
    timeoutHandler.cancel();
    if (notified.compareAndSet(false, true)) {
        // Prevent a late PyTorch response from notifying the listener twice.
        processContext.getResultProcessor().ignoreResposeWithoutNotifying(String.valueOf(requestId));
        listener.onFailure(e);
        return;
    }
    logger.debug(
        () -> new ParameterizedMessage("[{}] request [{}] received failure but listener already notified", modelId, requestId),
        e
    );
}

@Override
Expand All @@ -332,66 +336,53 @@ protected void doRun() throws Exception {
processor.validateInputs(text);
assert config instanceof NlpConfig;
NlpTask.Request request = processor.getRequestBuilder((NlpConfig) config).buildRequest(text, requestIdStr);
logger.trace(() -> "Inference Request " + request.processInput.utf8ToString());
logger.debug(() -> "Inference Request " + request.processInput.utf8ToString());
if (request.tokenization.anyTruncated()) {
logger.debug("[{}] [{}] input truncated", modelId, requestId);
}
PyTorchResultProcessor.PendingResult pendingResult = processContext.getResultProcessor().registerRequest(requestIdStr);
processContext.getResultProcessor()
.registerRequest(
requestIdStr,
ActionListener.wrap(
pyTorchResult -> processResult(
pyTorchResult,
processContext,
request.tokenization,
processor.getResultProcessor((NlpConfig) config),
ActionListener.wrap(this::onSuccess, this::onFailure)
),
this::onFailure
)
);
processContext.process.get().writeInferenceRequest(request.processInput);
waitForResult(
processContext,
pendingResult,
request.tokenization,
requestIdStr,
timeout,
processor.getResultProcessor((NlpConfig) config),
ActionListener.wrap(this::onSuccess, this::onFailure)
);
} catch (IOException e) {
logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.task.getModelId()), e);
onFailure(ExceptionsHelper.serverError("error writing to process", e));
} catch (Exception e) {
onFailure(e);
} finally {
processContext.getResultProcessor().requestIgnored(String.valueOf(requestId));
}
}

private void waitForResult(
private void processResult(
PyTorchResult pyTorchResult,
ProcessContext processContext,
PyTorchResultProcessor.PendingResult pendingResult,
TokenizationResult tokenization,
String requestId,
TimeValue timeout,
NlpTask.ResultProcessor inferenceResultsProcessor,
ActionListener<InferenceResults> listener
) {
try {
PyTorchResult pyTorchResult = processContext.getResultProcessor()
.waitForResult(processContext.process.get(), requestId, pendingResult, timeout);
if (pyTorchResult == null) {
listener.onFailure(
new ElasticsearchStatusException("timeout [{}] waiting for inference result", RestStatus.TOO_MANY_REQUESTS, timeout)
);
return;
}

if (pyTorchResult.isError()) {
listener.onFailure(new ElasticsearchStatusException(pyTorchResult.getError(), RestStatus.INTERNAL_SERVER_ERROR));
return;
}

logger.debug(
() -> new ParameterizedMessage("[{}] retrieved result for request [{}]", processContext.task.getModelId(), requestId)
);
InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult);
logger.debug(
() -> new ParameterizedMessage("[{}] processed result for request [{}]", processContext.task.getModelId(), requestId)
);
listener.onResponse(results);
} catch (InterruptedException e) {
listener.onFailure(e);
if (pyTorchResult.isError()) {
listener.onFailure(new ElasticsearchStatusException(pyTorchResult.getError(), RestStatus.INTERNAL_SERVER_ERROR));
return;
}

logger.debug(
() -> new ParameterizedMessage("[{}] retrieved result for request [{}]", processContext.task.getModelId(), requestId)
);
InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult);
logger.debug(
() -> new ParameterizedMessage("[{}] processed result for request [{}]", processContext.task.getModelId(), requestId)
);
listener.onResponse(results);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;

import java.time.Instant;
Expand All @@ -19,9 +19,6 @@
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

public class PyTorchResultProcessor {

Expand All @@ -31,7 +28,6 @@ public class PyTorchResultProcessor {

private final String deploymentId;
private volatile boolean isStopping;
private volatile boolean stoppedProcessing;
private final LongSummaryStatistics timingStats;
private Instant lastUsed;

Expand All @@ -40,20 +36,18 @@ public PyTorchResultProcessor(String deploymentId) {
this.timingStats = new LongSummaryStatistics();
}

public PendingResult registerRequest(String requestId) {
return pendingResults.computeIfAbsent(requestId, k -> new PendingResult());
public void registerRequest(String requestId, ActionListener<PyTorchResult> listener) {
pendingResults.computeIfAbsent(requestId, k -> new PendingResult(listener));
}

/**
* Call this method when the caller is no longer waiting on the request response.
* Note that the pending result listener will not be notified.
*
* @param requestId The request ID that is no longer being waited on
*/
public void requestIgnored(String requestId) {
PendingResult pendingResult = pendingResults.remove(requestId);
if (pendingResult != null) {
pendingResult.latch.countDown();
}
public void ignoreResposeWithoutNotifying(String requestId) {
pendingResults.remove(requestId);
}

public void process(NativePyTorchProcess process) {
Expand All @@ -67,18 +61,16 @@ public void process(NativePyTorchProcess process) {
if (pendingResult == null) {
logger.debug(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, result.getRequestId()));
} else {
pendingResult.result.set(result);
pendingResult.latch.countDown();
pendingResult.listener.onResponse(result);
}
}
} catch (Exception e) {
// No need to report error as we're stopping
if (isStopping == false) {
logger.error(new ParameterizedMessage("[{}] Error processing results", deploymentId), e);
}
pendingResults.forEach((id, pendingResults) -> {
if (pendingResults.result.compareAndSet(
null,
pendingResults.forEach(
(id, pendingResults) -> pendingResults.listener.onResponse(
new PyTorchResult(
id,
null,
Expand All @@ -87,24 +79,17 @@ public void process(NativePyTorchProcess process) {
? "inference canceled as process is stopping"
: "inference native process died unexpectedly with failure [" + e.getMessage() + "]"
)
)) {
pendingResults.latch.countDown();
}
});
)
);
pendingResults.clear();
} finally {
pendingResults.forEach((id, pendingResults) -> {
// Only set the result if it has not already been set
if (pendingResults.result.compareAndSet(
null,
pendingResults.forEach(
(id, pendingResults) -> pendingResults.listener.onResponse(
new PyTorchResult(id, null, null, "inference canceled as process is stopping")
)) {
pendingResults.latch.countDown();
}
});
)
);
pendingResults.clear();
}
stoppedProcessing = true;
logger.debug(() -> new ParameterizedMessage("[{}] Results processing finished", deploymentId));
}

Expand All @@ -119,18 +104,6 @@ private synchronized void processResult(PyTorchResult result) {
}
}

public PyTorchResult waitForResult(NativePyTorchProcess process, String requestId, PendingResult pendingResult, TimeValue timeout)
throws InterruptedException {
if (process == null || stoppedProcessing || process.isProcessAlive() == false) {
PyTorchResult storedResult = pendingResult.result.get();
return storedResult == null ? new PyTorchResult(requestId, null, null, "native process no longer started") : storedResult;
}
if (pendingResult.latch.await(timeout.millis(), TimeUnit.MILLISECONDS)) {
return pendingResult.result.get();
}
return null;
}

public synchronized Instant getLastUsed() {
return lastUsed;
}
Expand All @@ -140,7 +113,10 @@ public void stop() {
}

/**
 * Holds the listener registered for an in-flight inference request.
 * Replaces the previous latch/AtomicReference pair: instead of a caller
 * blocking on a latch, the stored listener is notified when the result
 * is processed (or when the process stops).
 */
public static class PendingResult {
    // Notified by the result-processing loop with the request's PyTorchResult,
    // or with an error result if the native process dies or is stopping.
    public final ActionListener<PyTorchResult> listener;

    public PendingResult(ActionListener<PyTorchResult> listener) {
        this.listener = Objects.requireNonNull(listener);
    }
}
}

0 comments on commit 5a41fa4

Please sign in to comment.