From 963332db6b9736a69d212165116fe992674d7083 Mon Sep 17 00:00:00 2001 From: Julian Gamble Date: Fri, 29 Sep 2023 02:27:23 +1000 Subject: [PATCH] PtndArrayEx.multiboxDetection() implementation (#2769) * Implement PtNDArraryEx.multiboxDetection * MultiboxDetection - code cleanup * MultiboxDetection - code cleanup * MultiboxDetection - code cleanup * MultiboxDetection - code cleanup * format code * Fix, add tests, and pass CI --------- Co-authored-by: Zach Kimberg --- api/src/main/java/ai/djl/nn/Block.java | 15 ++ .../java/ai/djl/pytorch/engine/PtModel.java | 5 +- .../ai/djl/pytorch/engine/PtNDArrayEx.java | 150 +++++++++++++++++- .../examples/training/TrainPikachuTest.java | 1 - .../SingleShotDetectionTest.java | 7 +- .../ai/djl/basicmodelzoo/BasicModelZoo.java | 2 +- 6 files changed, 172 insertions(+), 8 deletions(-) diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 3d58d501293..82beca520d1 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -313,6 +313,21 @@ default void freezeParameters(boolean freeze) { } } + /** + * Freezes or unfreezes all parameters inside the block that pass the predicate. + * + * @param freeze true to mark as frozen rather than unfrozen + * @param pred predicate returning true for each parameter that should be updated + * @see Parameter#freeze(boolean) + */ + default void freezeParameters(boolean freeze, Predicate pred) { + for (Parameter parameter : getParameters().values()) { + if (pred.test(parameter)) { + parameter.freeze(freeze); + } + } + } + /** * Validates that actual layout matches the expected layout. 
* diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index e72e98c9495..35e95f7de86 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -18,6 +18,7 @@ import ai.djl.Model; import ai.djl.ndarray.types.DataType; import ai.djl.nn.Parameter; +import ai.djl.nn.Parameter.Type; import ai.djl.pytorch.jni.JniUtils; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -189,7 +190,9 @@ public Trainer newTrainer(TrainingConfig trainingConfig) { } if (wasLoaded) { // Unfreeze parameters if training directly - block.freezeParameters(false); + block.freezeParameters( + false, + p -> p.getType() != Type.RUNNING_MEAN && p.getType() != Type.RUNNING_VAR); } for (Pair> pair : initializer) { if (pair.getKey() != null && pair.getValue() != null) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java index fa4ee81f26c..b7f92cbd1c3 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java @@ -13,6 +13,7 @@ package ai.djl.pytorch.engine; import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.NDUtils; @@ -24,6 +25,8 @@ import ai.djl.nn.recurrent.RNN; import ai.djl.pytorch.jni.JniUtils; +import java.util.Arrays; +import java.util.Comparator; import java.util.List; /** {@code PtNDArrayEx} is the PyTorch implementation of the {@link NDArrayEx}. 
*/ @@ -760,7 +763,152 @@ public NDList multiBoxDetection( float nmsThreshold, boolean forceSuppress, int nmsTopK) { - throw new UnsupportedOperationException("Not implemented"); + assert (inputs.size() == 3); + + NDArray clsProb = inputs.get(0); + NDArray locPred = inputs.get(1); + NDArray anchors = inputs.get(2).reshape(new Shape(-1, 4)); + + NDManager ndManager = array.getManager(); + + NDArray variances = ndManager.create(new float[] {0.1f, 0.1f, 0.2f, 0.2f}); + + assert (variances.size() == 4); // << "Variance size must be 4"; + final int numClasses = (int) clsProb.size(1); + final int numAnchors = (int) clsProb.size(2); + final int numBatches = (int) clsProb.size(0); + + final float[] pAnchor = anchors.toFloatArray(); + + // [id, prob, xmin, ymin, xmax, ymax] + // TODO Move to NDArray-based implementation + NDList batchOutputs = new NDList(); + for (int nbatch = 0; nbatch < numBatches; ++nbatch) { + float[][] outputs = new float[numAnchors][6]; + final float[] pClsProb = clsProb.get(nbatch).toFloatArray(); + final float[] pLocPred = locPred.get(nbatch).toFloatArray(); + + for (int i = 0; i < numAnchors; ++i) { + // find the predicted class id and probability + float score = -1; + int id = 0; + for (int j = 1; j < numClasses; ++j) { + float temp = pClsProb[j * numAnchors + i]; + if (temp > score) { + score = temp; + id = j; + } + } + + if (id > 0 && score < threshold) { + id = 0; + } + + // [id, prob, xmin, ymin, xmax, ymax] + outputs[i][0] = id - 1; + outputs[i][1] = score; + int offset = i * 4; + float[] pAnchorRow4 = new float[4]; + pAnchorRow4[0] = pAnchor[offset]; + pAnchorRow4[1] = pAnchor[offset + 1]; + pAnchorRow4[2] = pAnchor[offset + 2]; + pAnchorRow4[3] = pAnchor[offset + 3]; + float[] pLocPredRow4 = new float[4]; + pLocPredRow4[0] = pLocPred[offset]; + pLocPredRow4[1] = pLocPred[offset + 1]; + pLocPredRow4[2] = pLocPred[offset + 2]; + pLocPredRow4[3] = pLocPred[offset + 3]; + float[] outRowLast4 = + transformLocations( + pAnchorRow4, + pLocPredRow4, 
+ clip, + variances.toFloatArray()[0], + variances.toFloatArray()[1], + variances.toFloatArray()[2], + variances.toFloatArray()[3]); + outputs[i][2] = outRowLast4[0]; + outputs[i][3] = outRowLast4[1]; + outputs[i][4] = outRowLast4[2]; + outputs[i][5] = outRowLast4[3]; + } + + outputs = + Arrays.stream(outputs) + .filter(o -> o[0] >= 0) + .sorted(Comparator.comparing(o -> -o[1])) + .toArray(float[][]::new); + + // apply nms + for (int i = 0; i < outputs.length; ++i) { + for (int j = i + 1; j < outputs.length; ++j) { + if (outputs[i][0] == outputs[j][0]) { + float[] outputsIRow4 = new float[4]; + float[] outputsJRow4 = new float[4]; + outputsIRow4[0] = outputs[i][2]; + outputsIRow4[1] = outputs[i][3]; + outputsIRow4[2] = outputs[i][4]; + outputsIRow4[3] = outputs[i][5]; + outputsJRow4[0] = outputs[j][2]; + outputsJRow4[1] = outputs[j][3]; + outputsJRow4[2] = outputs[j][4]; + outputsJRow4[3] = outputs[j][5]; + float iou = calculateOverlap(outputsIRow4, outputsJRow4); + if (iou >= nmsThreshold) { + outputs[j][0] = -1; + } + } + } + } + batchOutputs.add(ndManager.create(outputs)); + } // end iter batch + + NDArray pOutNDArray = NDArrays.stack(batchOutputs); + NDList resultNDList = new NDList(); + resultNDList.add(pOutNDArray); + assert (resultNDList.size() == 1); + return resultNDList; + } + + private float[] transformLocations( + final float[] anchors, + final float[] locPred, + final boolean clip, + final float vx, + final float vy, + final float vw, + final float vh) { + float[] outRowLast4 = new float[4]; + // transform predictions to detection results + float al = anchors[0]; + float at = anchors[1]; + float ar = anchors[2]; + float ab = anchors[3]; + float aw = ar - al; + float ah = ab - at; + float ax = (al + ar) / 2.f; + float ay = (at + ab) / 2.f; + float px = locPred[0]; + float py = locPred[1]; + float pw = locPred[2]; + float ph = locPred[3]; + float ox = px * vx * aw + ax; + float oy = py * vy * ah + ay; + float ow = (float) (Math.exp(pw * vw) * aw / 2); + 
float oh = (float) (Math.exp(ph * vh) * ah / 2); + outRowLast4[0] = clip ? Math.max(0f, Math.min(1f, ox - ow)) : (ox - ow); + outRowLast4[1] = clip ? Math.max(0f, Math.min(1f, oy - oh)) : (oy - oh); + outRowLast4[2] = clip ? Math.max(0f, Math.min(1f, ox + ow)) : (ox + ow); + outRowLast4[3] = clip ? Math.max(0f, Math.min(1f, oy + oh)) : (oy + oh); + return outRowLast4; + } + + private float calculateOverlap(final float[] a, final float[] b) { + float w = Math.max(0f, Math.min(a[2], b[2]) - Math.max(a[0], b[0])); + float h = Math.max(0f, Math.min(a[3], b[3]) - Math.max(a[1], b[1])); + float i = w * h; + float u = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - i; + return u <= 0.f ? 0f : (i / u); } /** {@inheritDoc} */ diff --git a/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java b/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java index 2a61e25862e..1a5699836c8 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java @@ -27,7 +27,6 @@ public class TrainPikachuTest { @Test public void testDetection() throws IOException, MalformedModelException, TranslateException { - TestRequirements.engine("MXNet"); TestRequirements.nightly(); String[] args; diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java index b5907925ee4..008d652dc82 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java @@ -31,6 +31,7 @@ import ai.djl.nn.LambdaBlock; import ai.djl.nn.SequentialBlock; import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelZoo; import 
ai.djl.repository.zoo.ZooModel; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.EasyTrain; @@ -123,10 +124,8 @@ private TrainingConfig setupTrainingConfig() { } private ZooModel getModel() throws IOException, ModelException { - // SSD-pikachu model only available in MXNet - // TODO: Add PyTorch model to model zoo - TestUtils.requiresEngine("MXNet"); - + TestUtils.requiresEngine( + ModelZoo.getModelZoo("ai.djl.zoo").getSupportedEngines().toArray(String[]::new)); Criteria criteria = Criteria.builder() .optApplication(Application.CV.OBJECT_DETECTION) diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java index 7cdfc040c12..543ab5f1f21 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java @@ -43,8 +43,8 @@ public String getGroupId() { public Set getSupportedEngines() { Set set = new HashSet<>(); set.add("MXNet"); + set.add("PyTorch"); // TODO Currently WIP in supporting these two engines in the basic model zoo - // set.add("PyTorch"); // set.add("TensorFlow"); return set; }