Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Put the output of the model into existing NDArrays when provided #618

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions api/src/main/java/ai/djl/inference/Predictor.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,24 @@ public Predictor(Model model, Translator<I, O> translator, boolean copy) {
*/
@SuppressWarnings("PMD.AvoidRethrowingException")
public O predict(I input) throws TranslateException {
return batchPredict(Collections.singletonList(input)).get(0);
return batchPredict(Collections.singletonList(input), null).get(0);
}

private NDList predict(NDList ndList) {
/**
* Predicts an item for inference, optionally passing an {@link NDList} through to the forward
* pass for the model output.
*
* @param input the input
* @param output the {@link NDList} handed down to the block's forward call for the output, or
*     {@code null} when the caller does not provide one
* @return the output object defined by the user
* @throws TranslateException if an error occurs during prediction
*/
@SuppressWarnings("PMD.AvoidRethrowingException")
public O predict(I input, NDList output) throws TranslateException {
// A single item is predicted as a batch of one; the first (only) result is returned.
return batchPredict(Collections.singletonList(input), output).get(0);
}

private NDList predict(NDList ndList, NDList output) {
logger.trace("Predictor input data: {}", ndList);
return block.forward(parameterStore, ndList, false);
return block.forward(parameterStore, ndList, output, false);
}

/**
Expand All @@ -142,6 +154,18 @@ private NDList predict(NDList ndList) {
*/
@SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches"})
public List<O> batchPredict(List<I> inputs) throws TranslateException {
// Delegates to the two-argument overload; null means no caller-provided output NDList.
return batchPredict(inputs, null);
}

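/**
* Predicts a batch for inference, optionally passing an {@link NDList} through to the forward
* pass for the model output.
*
* @param inputs a list of inputs
* @param output the {@link NDList} handed down to the forward pass for the output, or {@code
*     null} when the caller does not provide one
* @return a list of output objects defined by the user
* @throws TranslateException if an error occurs during prediction
*/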
@SuppressWarnings("PMD.AvoidRethrowingException")
public List<O> batchPredict(List<I> inputs, NDList output) throws TranslateException {
long begin = System.nanoTime();
try (PredictorContext context = new PredictorContext()) {
if (!prepared) {
Expand All @@ -157,7 +181,7 @@ public List<O> batchPredict(List<I> inputs) throws TranslateException {
NDList ndList = translator.processInput(context, input);
preprocessEnd(ndList);

NDList result = predict(ndList);
NDList result = predict(ndList, output);
predictEnd(result);

ret.add(translator.processOutput(context, result));
Expand All @@ -170,7 +194,7 @@ public List<O> batchPredict(List<I> inputs) throws TranslateException {
NDList inputBatch = processInputs(context, inputs);
preprocessEnd(inputBatch);

NDList result = predict(inputBatch);
NDList result = predict(inputBatch, output);
predictEnd(result);

List<O> ret = processOutputs(context, result);
Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/modality/nlp/Encoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public Encoder(byte version, Block block) {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
return block.forward(parameterStore, inputs, training, params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public NDList forward(ParameterStore parameterStore, NDList inputs, boolean trai
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
if (training) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public List<String> unembedText(NDArray textEmbedding) {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
return trainableWordEmbedding.forward(parameterStore, inputs, training, params);
Expand Down
29 changes: 28 additions & 1 deletion api/src/main/java/ai/djl/nn/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,33 @@ public interface Block {
* @return the output of the forward pass
*/
default NDList forward(ParameterStore parameterStore, NDList inputs, boolean training) {
return forward(parameterStore, inputs, training, null);
return forward(parameterStore, inputs, null, training, null);
}

/**
* Applies the operating function of the block once, passing along an {@link NDList} for the
* output. This method should be called only on blocks that are initialized.
*
* <p>This overload forwards to the five-argument {@code forward} with {@code null} params.
*
* @param parameterStore the parameter store
* @param inputs the input NDList
* @param output the NDList passed through to the block for its output; may be {@code null}
* @param training true for a training forward pass
* @return the output of the forward pass
*/
default NDList forward(ParameterStore parameterStore, NDList inputs, NDList output, boolean training) {
return forward(parameterStore, inputs, output, training, null);
}

/**
* Applies the operating function of the block once. This method should be called only on blocks
* that are initialized.
*
* <p>This overload forwards to the five-argument {@code forward} with a {@code null} output
* NDList, so existing callers that do not supply an output keep working.
*
* @param parameterStore the parameter store
* @param inputs the input NDList
* @param training true for a training forward pass
* @param params additional parameters passed through unchanged to the block
* @return the output of the forward pass
*/
default NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
return forward(parameterStore, inputs, null, training, params);
}

/**
Expand All @@ -130,6 +156,7 @@ default NDList forward(ParameterStore parameterStore, NDList inputs, boolean tra
NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params);

Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/nn/LambdaBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public static LambdaBlock singleton(Function<NDArray, NDArray> lambda) {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
return lambda.apply(inputs);
Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/nn/ParallelBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ public final ParallelBlock add(Function<NDList, NDList> f) {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
return function.apply(
Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/nn/SequentialBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ public void replaceLastBlock(Block block) {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
NDList current = inputs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ public Convolution(ConvolutionBuilder<?> builder) {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
NDArray input = inputs.singletonOrThrow();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ public Deconvolution(DeconvolutionBuilder<?> builder) {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
NDArray input = inputs.singletonOrThrow();
Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public ConstantEmbedding(NDArray embedding) {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
NDManager manager = inputs.get(0).getManager();
Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/nn/core/Embedding.java
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
NDList opInputs = opInputs(parameterStore, inputs, training);
Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/nn/core/Linear.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public class Linear extends AbstractBlock {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
NDArray input = inputs.singletonOrThrow();
Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/nn/core/Prelu.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public Prelu() {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
NDArray input = inputs.singletonOrThrow();
Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/nn/norm/BatchNorm.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ public class BatchNorm extends AbstractBlock {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
NDArray input = inputs.singletonOrThrow();
Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/nn/norm/Dropout.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public class Dropout extends AbstractBlock {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
return dropout(inputs.singletonOrThrow(), rate, training);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ public final void setStateOutputs(boolean stateOutputs) {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList o,
boolean training,
PairList<String, Object> params) {
inputs = opInputs(parameterStore, inputs, training);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ private NDArray createAttentionHeadsFromEmbeddings(
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
// E=embedding size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ private SingleShotDetection(Builder builder) {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
NDList networkOutput = inputs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ public void initState(NDList encoderStates) {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList o,
boolean training,
PairList<String, Object> params) {
if (training) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ public void removeLastBlock() {
public NDList forward(
ParameterStore parameterStore,
NDList inputs,
NDList output,
boolean training,
PairList<String, Object> params) {
// TODO refactor the forward to not take ParameterStore
Expand All @@ -106,7 +107,7 @@ public NDList forward(
for (NDArray array : inputs) {
inputDescriptions.add(array.getName(), array.getShape());
}
NDList outputs = IValueUtils.forward(this, inputs, training);
NDList outputs = IValueUtils.forward(this, inputs, output, training);
for (NDArray array : outputs) {
outputDescriptions.add(array.getName(), array.getShape());
}
Expand All @@ -115,7 +116,7 @@ public NDList forward(
}
}
}
return IValueUtils.forward(this, inputs, training);
return IValueUtils.forward(this, inputs, output, training);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,16 @@ public static PtNDArray toNDArray(long iValueHandle, PtNDManager manager) {
return new PtNDArray(manager, ndHandle);
}

/**
 * Copies the tensor wrapped by an IValue into an existing array.
 *
 * <p>Both arguments are raw native handles ({@code long}), which is what the caller in {@code
 * forwardHelper} passes and what the native binding {@code iValueToTensorCopy(long, long)}
 * accepts.
 *
 * @param arrayHandle native handle of the destination array
 * @param iValueHandle native handle of the source IValue
 */
public static void toNDArrayCopy(long arrayHandle, long iValueHandle) {
    // The native method is declared as iValueToTensorCopy(iValueHandle, tensorHandle):
    // the IValue handle goes first and the destination tensor handle second.
    // NOTE(review): confirm this order against the JNI implementation.
    PyTorchLibrary.LIB.iValueToTensorCopy(iValueHandle, arrayHandle);
}

/**
* Extract IValue to {@link NDList}.
*
Expand Down Expand Up @@ -190,15 +200,20 @@ public static Map<Long, Long> toIValueMap(long iValueHandle) {
return map;
}

private static NDList forwardHelper(long iValueHandle, PtNDManager manager) {
private static NDList forwardHelper(long iValueHandle, NDList output, PtNDManager manager) {
NDList list = new NDList();
if (isNDArray(iValueHandle)) {
if (output != null) {
toNDArrayCopy(((PtNDArray)output.get(0)).getHandle(), iValueHandle);
PyTorchLibrary.LIB.torchDeleteIValue(iValueHandle);
return output;
}
list.add(toNDArray(iValueHandle, manager));
} else if (isNDList(iValueHandle)) {
list.addAll(toNDList(iValueHandle, manager));
} else if (isList(iValueHandle) || isTuple(iValueHandle)) {
for (long handle : toIValueArray(iValueHandle)) {
list.addAll(forwardHelper(handle, manager));
list.addAll(forwardHelper(handle, output, manager));
}
} else if (isMap(iValueHandle)) {
// Only allows <String, NDArray> type of map
Expand Down Expand Up @@ -231,14 +246,14 @@ private static NDList forwardHelper(long iValueHandle, PtNDManager manager) {
* @param isTrain is running on training mode
* @return result {@link NDList}
*/
public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain) {
public static NDList forward(PtSymbolBlock block, NDList inputs, NDList output, boolean isTrain) {
long[] arrayHandles =
inputs.stream().mapToLong(input -> ((PtNDArray) input).getHandle()).toArray();
String[] names = inputs.stream().map(NDArray::getName).toArray(String[]::new);
long[] iValueInputs = getInputs(arrayHandles, names);
long result = PyTorchLibrary.LIB.moduleForward(block.getHandle(), iValueInputs, isTrain);
PtNDManager manager = (PtNDManager) inputs.get(0).getManager();
return forwardHelper(result, manager);
return forwardHelper(result, output, manager);
}

private static boolean isNameList(String name) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,8 @@ native long moduleLoad(

native long iValueToTensor(long iValueHandle);

native void iValueToTensorCopy(long iValueHandle, long tensorHandle);

native long[] iValueToTensorList(long iValueHandle);

native long[] iValueToList(long iValueHandle);
Expand Down