Skip to content

Commit

Permalink
[api] Optimized text embedding post processing performance (#3459)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Sep 9, 2024
1 parent 0c68d70 commit f91a696
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import ai.djl.ndarray.BytesSupplier;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
Expand All @@ -34,7 +33,7 @@
import java.util.List;

/** A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}. */
public class CrossEncoderServingTranslator implements NoBatchifyTranslator<Input, Output> {
public class CrossEncoderServingTranslator implements Translator<Input, Output> {

private Translator<StringPair, float[]> translator;

Expand All @@ -56,74 +55,13 @@ public void prepare(TranslatorContext ctx) throws Exception {
/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
PairList<String, BytesSupplier> content = input.getContent();
if (content.isEmpty()) {
throw new TranslateException("Input data is empty.");
ReRankingInput in = ReRankingInput.parseInput(input);
if (in.batch != null) {
ctx.setAttachment("batch", Boolean.TRUE);
return translator.batchProcessInput(ctx, in.batch);
}

String contentType = input.getProperty("Content-Type", null);
if (contentType != null) {
int pos = contentType.indexOf(';');
if (pos > 0) {
contentType = contentType.substring(0, pos);
}
}
StringPair pair = null;
if ("application/json".equals(contentType)) {
String json = input.getData().getAsString();
try {
JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class);
if (element.isJsonArray()) {
ctx.setAttachment("batch", Boolean.TRUE);
JsonArray array = element.getAsJsonArray();
int size = array.size();
List<StringPair> inputs = new ArrayList<>(size);
for (int i = 0; i < size; ++i) {
JsonObject obj = array.get(i).getAsJsonObject();
inputs.add(parseStringPair(obj));
}
return translator.batchProcessInput(ctx, inputs);
} else if (element.isJsonObject()) {
JsonObject obj = element.getAsJsonObject();
JsonElement query = obj.get("query");
if (query != null) {
String key = query.getAsString();
JsonArray texts = obj.get("texts").getAsJsonArray();
int size = texts.size();
List<StringPair> inputs = new ArrayList<>(size);
for (int i = 0; i < size; ++i) {
String value = texts.get(i).getAsString();
inputs.add(new StringPair(key, value));
}
ctx.setAttachment("batch", Boolean.TRUE);
return translator.batchProcessInput(ctx, inputs);
} else {
pair = parseStringPair(obj);
}
} else {
throw new TranslateException("Unexpected json type");
}
} catch (JsonParseException e) {
throw new TranslateException("Input is not a valid json.", e);
}
} else {
String text = input.getAsString("text");
String textPair = input.getAsString("text_pair");
if (text != null && textPair != null) {
pair = new StringPair(text, textPair);
}
String key = input.getAsString("key");
String value = input.getAsString("value");
if (key != null && value != null) {
pair = new StringPair(key, value);
}
}

if (pair == null) {
throw new TranslateException("Missing key or value in input.");
}

NDList ret = translator.processInput(ctx, pair);
NDList ret = translator.processInput(ctx, in.pair);
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
NDList[] batch = {ret};
Expand All @@ -132,6 +70,27 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
return ret;
}

/**
 * {@inheritDoc}
 *
 * <p>Every request is parsed and flattened into one list of pairs that is handed to the wrapped
 * translator in a single call. The number of pairs contributed by each request is stored in
 * the {@code mapping} attachment: the client batch size for a batched request, or {@code -1}
 * for a request that carried a single pair.
 */
@Override
@SuppressWarnings("PMD.SignatureDeclareThrowsException")
public NDList batchProcessInput(TranslatorContext ctx, List<Input> inputs) throws Exception {
    int requestCount = inputs.size();
    int[] sizes = new int[requestCount];
    List<StringPair> flattened = new ArrayList<>(requestCount);
    for (int idx = 0; idx < requestCount; ++idx) {
        ReRankingInput parsed = ReRankingInput.parseInput(inputs.get(idx));
        if (parsed.batch == null) {
            // single pair request, marked with -1 so the output side does not wrap it
            sizes[idx] = -1;
            flattened.add(parsed.pair);
            continue;
        }
        // client side batch: remember how many consecutive results belong to this request
        sizes[idx] = parsed.batch.size();
        flattened.addAll(parsed.batch);
    }
    ctx.setAttachment("mapping", sizes);
    return translator.batchProcessInput(ctx, flattened);
}

/** {@inheritDoc} */
@Override
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
Expand All @@ -149,17 +108,126 @@ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception
return output;
}

private StringPair parseStringPair(JsonObject json) throws TranslateException {
JsonElement text = json.get("text");
JsonElement textPair = json.get("text_pair");
if (text != null && textPair != null) {
return new StringPair(text.getAsString(), textPair.getAsString());
/**
 * {@inheritDoc}
 *
 * <p>Regroups the flat result list returned by the wrapped translator into one {@link Output}
 * per request, driven by the {@code mapping} attachment: an entry of {@code -1} consumes one
 * result and emits it as a json array, any other entry consumes that many results and emits
 * them as a json array of arrays.
 */
@Override
@SuppressWarnings("PMD.SignatureDeclareThrowsException")
public List<Output> batchProcessOutput(TranslatorContext ctx, NDList list) throws Exception {
    List<float[]> results = translator.batchProcessOutput(ctx, list);
    int[] sizes = (int[]) ctx.getAttachment("mapping");
    List<Output> responses = new ArrayList<>(sizes.length);
    int cursor = 0;
    for (int i = 0; i < sizes.length; ++i) {
        int size = sizes[i];
        BytesSupplier body;
        if (size != -1) {
            // client side batching: collect the next "size" results for this request
            float[][] group = new float[size][];
            for (int j = 0; j < size; ++j) {
                group[j] = results.get(cursor);
                ++cursor;
            }
            body = BytesSupplier.wrapAsJson(group);
        } else {
            // non-batching: exactly one result belongs to this request
            body = BytesSupplier.wrapAsJson(results.get(cursor));
            ++cursor;
        }
        Output response = new Output();
        response.addProperty("Content-Type", "application/json");
        response.add(body);
        responses.add(response);
    }
    return responses;
}

private static final class ReRankingInput {

private StringPair pair;
private List<StringPair> batch;

/**
 * Creates an input holding a single pair; {@code batch} is not assigned here and stays
 * {@code null}, which callers use to tell the two forms apart.
 *
 * @param pair the pair to score
 */
ReRankingInput(StringPair pair) {
this.pair = pair;
}

/**
 * Creates an input holding a client side batch of pairs; {@code pair} is not assigned here
 * and stays {@code null}.
 *
 * @param batch the pairs to score
 */
ReRankingInput(List<StringPair> batch) {
this.batch = batch;
}
JsonElement key = json.get("key");
JsonElement value = json.get("value");
if (key != null && value != null) {
return new StringPair(key.getAsString(), value.getAsString());

/**
 * Parses a serving {@link Input} into either a single pair or a client side batch.
 *
 * <p>When the content type is {@code application/json} (parameters after {@code ;} are
 * ignored), the body may be:
 *
 * <ul>
 *   <li>a json array of objects, each parsed by {@code parseStringPair} (batch);
 *   <li>a json object with a {@code query} string and a {@code texts} array, producing one
 *       pair per text (batch);
 *   <li>any other json object, parsed by {@code parseStringPair} (single pair).
 * </ul>
 *
 * <p>For any other content type the pair is read from the {@code text}/{@code text_pair} or
 * {@code key}/{@code value} entries of the input; {@code key}/{@code value} takes precedence
 * when both are present.
 *
 * @param input the request to parse
 * @return the parsed re-ranking input
 * @throws TranslateException if the input is empty, is not valid json, has an unexpected json
 *     structure, or does not contain a complete pair
 */
static ReRankingInput parseInput(Input input) throws TranslateException {
    PairList<String, BytesSupplier> content = input.getContent();
    if (content.isEmpty()) {
        throw new TranslateException("Input data is empty.");
    }

    String contentType = input.getProperty("Content-Type", null);
    if (contentType != null) {
        // drop media type parameters such as "; charset=utf-8"
        int pos = contentType.indexOf(';');
        if (pos > 0) {
            contentType = contentType.substring(0, pos);
        }
    }
    StringPair pair = null;
    if ("application/json".equals(contentType)) {
        String json = input.getData().getAsString();
        try {
            JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class);
            if (element == null) {
                // Gson yields null for an empty document; report it instead of failing with NPE
                throw new TranslateException("Input is not a valid json.");
            }
            if (element.isJsonArray()) {
                JsonArray array = element.getAsJsonArray();
                int size = array.size();
                List<StringPair> batch = new ArrayList<>(size);
                for (int i = 0; i < size; ++i) {
                    JsonElement item = array.get(i);
                    if (!item.isJsonObject()) {
                        // getAsJsonObject() would throw an unchecked IllegalStateException
                        throw new TranslateException("Unexpected json type");
                    }
                    batch.add(parseStringPair(item.getAsJsonObject()));
                }
                return new ReRankingInput(batch);
            } else if (element.isJsonObject()) {
                JsonObject obj = element.getAsJsonObject();
                JsonElement query = obj.get("query");
                if (query != null) {
                    String key = query.getAsString();
                    JsonElement texts = obj.get("texts");
                    if (texts == null || !texts.isJsonArray()) {
                        // a missing field returns null from JsonObject.get()
                        throw new TranslateException("Missing texts in json.");
                    }
                    JsonArray array = texts.getAsJsonArray();
                    int size = array.size();
                    List<StringPair> batch = new ArrayList<>(size);
                    for (int i = 0; i < size; ++i) {
                        String value = array.get(i).getAsString();
                        batch.add(new StringPair(key, value));
                    }
                    return new ReRankingInput(batch);
                } else {
                    pair = parseStringPair(obj);
                }
            } else {
                throw new TranslateException("Unexpected json type");
            }
        } catch (JsonParseException e) {
            throw new TranslateException("Input is not a valid json.", e);
        }
    } else {
        String text = input.getAsString("text");
        String textPair = input.getAsString("text_pair");
        if (text != null && textPair != null) {
            pair = new StringPair(text, textPair);
        }
        // key/value is checked last, so it overrides text/text_pair when both are supplied
        String key = input.getAsString("key");
        String value = input.getAsString("value");
        if (key != null && value != null) {
            pair = new StringPair(key, value);
        }
    }

    if (pair == null) {
        throw new TranslateException("Missing key or value in input.");
    }
    return new ReRankingInput(pair);
}

/**
 * Extracts a pair from a json object.
 *
 * <p>The {@code text}/{@code text_pair} fields are tried first, then {@code key}/{@code value};
 * the first combination with both fields present is used.
 *
 * @param json the json object to read
 * @return the pair built from the first complete field combination
 * @throws TranslateException if neither field combination is fully present
 */
private static StringPair parseStringPair(JsonObject json) throws TranslateException {
    String[][] fieldNames = {{"text", "text_pair"}, {"key", "value"}};
    for (String[] names : fieldNames) {
        JsonElement first = json.get(names[0]);
        JsonElement second = json.get(names[1]);
        if (first == null || second == null) {
            continue;
        }
        return new StringPair(first.getAsString(), second.getAsString());
    }
    throw new TranslateException("Missing text or text_pair in json.");
}
throw new TranslateException("Missing text or text_pair in json.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception

/** {@inheritDoc} */
@Override
@SuppressWarnings({"PMD.SignatureDeclareThrowsException", "unchecked"})
@SuppressWarnings("PMD.SignatureDeclareThrowsException")
public List<Output> batchProcessOutput(TranslatorContext ctx, NDList list) throws Exception {
List<float[]> outputs = translator.batchProcessOutput(ctx, list);
int[] mapping = (int[]) ctx.getAttachment("mapping");
Expand Down
2 changes: 1 addition & 1 deletion extensions/tokenizers/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ tasks {
downloadPath gzipInto file
}

if ("text_embedding" != task)
if (task !in arrayOf("text_embedding", "text_classification"))
continue

file = prefix / task / "ai.djl.huggingface.rust.json"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -88,14 +89,29 @@ public float[] processOutput(TranslatorContext ctx, NDList list) {
/** {@inheritDoc} */
@Override
public List<float[]> batchProcessOutput(TranslatorContext ctx, NDList list) {
NDList[] batches = batchifier.unbatchify(list);
List<float[]> ret = new ArrayList<>(batches.length);
for (NDList batch : batches) {
NDArray result = batch.get(0);
if (sigmoid) {
if (sigmoid) {
NDList[] batches = batchifier.unbatchify(list);
List<float[]> ret = new ArrayList<>(batches.length);
for (NDList batch : batches) {
NDArray result = batch.get(0);
result = result.getNDArrayInternal().sigmoid();
ret.add(result.toFloatArray());
}
ret.add(result.toFloatArray());
return ret;
}
NDArray array = list.get(0);
int batchSize = Math.toIntExact(array.size(0));
float[] buf = list.get(0).toFloatArray();
if (batchSize == 1) {
return Collections.singletonList(buf);
}

int length = buf.length / batchSize;
List<float[]> ret = new ArrayList<>(batchSize);
for (int i = 0; i < batchSize; ++i) {
float[] f = new float[length];
System.arraycopy(buf, i * length, f, 0, length);
ret.add(f);
}
return ret;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -155,14 +156,20 @@ public float[] processOutput(TranslatorContext ctx, NDList list) {
/** {@inheritDoc} */
@Override
public List<float[]> batchProcessOutput(TranslatorContext ctx, NDList list) {
int batchSize = Math.toIntExact(list.head().size(0));
NDArray attentionMask = (NDArray) ctx.getAttachment("attentionMask");
NDArray output = processEmbedding(list, attentionMask);
int batchSize = Math.toIntExact(output.size(0));
float[] buf = output.toFloatArray();
if (batchSize == 1) {
return Collections.singletonList(buf);
}

int length = buf.length / batchSize;
List<float[]> ret = new ArrayList<>(batchSize);
NDList splitList = output.split(batchSize);
for (int i = 0; i < batchSize; i++) {
NDArray array = splitList.get(i);
ret.add(array.toFloatArray());
for (int i = 0; i < batchSize; ++i) {
float[] f = new float[length];
System.arraycopy(buf, i * length, f, 0, length);
ret.add(f);
}
return ret;
}
Expand Down
Loading

0 comments on commit f91a696

Please sign in to comment.