diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java index a4ebfcb9df1..c31353766d3 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java @@ -160,7 +160,7 @@ protected double overlap(double x1, double w1, double x2, double w2) { return right - left; } - private DetectedObjects processFromBoxOutput(NDList list) { + protected DetectedObjects processFromBoxOutput(NDList list) { float[] flattened = list.get(0).toFloatArray(); ArrayList intermediateResults = new ArrayList<>(); int sizeClasses = classes.size(); @@ -280,7 +280,7 @@ public YoloV5Translator build() { } } - private static final class IntermediateResult { + protected static final class IntermediateResult { /** * A sortable score for how good the recognition is relative to others. Higher should be diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java new file mode 100644 index 00000000000..dc160ba754b --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java @@ -0,0 +1,103 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.DataType; + +import java.util.ArrayList; +import java.util.Map; + +/** + * A translator for YoloV8 models. This was tested with ONNX exported Yolo models. For details check + * here: https://github.com/ultralytics/ultralytics + */ +public class YoloV8Translator extends YoloV5Translator { + + /** + * Constructs an ImageTranslator with the provided builder. + * + * @param builder the data to build with + */ + protected YoloV8Translator(Builder builder) { + super(builder); + } + + /** + * Creates a builder to build a {@code YoloV8Translator} with specified arguments. + * + * @param arguments arguments to specify builder options + * @return a new builder + */ + public static YoloV8Translator.Builder builder(Map arguments) { + YoloV8Translator.Builder builder = new YoloV8Translator.Builder(); + builder.configPreProcess(arguments); + builder.configPostProcess(arguments); + + return builder; + } + + @Override + protected DetectedObjects processFromBoxOutput(NDList list) { + NDArray features4OneImg = list.get(0); + int sizeClasses = classes.size(); + long sizeBoxes = features4OneImg.size(1); + ArrayList intermediateResults = new ArrayList<>(); + + for (long b = 0; b < sizeBoxes; b++) { + float maxClass = 0; + int maxIndex = 0; + for (int c = 4; c < sizeClasses; c++) { + float classProb = features4OneImg.getFloat(c, b); + if (classProb > maxClass) { + maxClass = classProb; + maxIndex = c; + } + } + + if (maxClass > threshold) { + float xPos = features4OneImg.getFloat(0, b); // center x + float yPos = features4OneImg.getFloat(1, b); // center y + float w = features4OneImg.getFloat(2, b); + float h = features4OneImg.getFloat(3, b); + Rectangle rect = + new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h); + intermediateResults.add( + new IntermediateResult(classes.get(maxIndex), maxClass, maxIndex, rect)); + } + } + + return nms(intermediateResults); + } + + /** The builder for {@link YoloV8Translator}. */ + public static class Builder extends YoloV5Translator.Builder { + /** + * Builds the translator. + * + * @return the new translator + */ + @Override + public YoloV8Translator build() { + if (pipeline == null) { + addTransform( + array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255)); + } + validate(); + return new YoloV8Translator(this); + } + } +} diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactory.java new file mode 100644 index 00000000000..b5a4db00d28 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactory.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.translate.Translator; + +import java.io.Serializable; +import java.util.Map; + +/** A translatorFactory that creates a {@link YoloV8Translator} instance. */ +public class YoloV8TranslatorFactory extends ObjectDetectionTranslatorFactory + implements Serializable { + + private static final long serialVersionUID = 1L; + + /** {@inheritDoc} */ + @Override + protected Translator buildBaseTranslator( + Model model, Map arguments) { + return YoloV8Translator.builder(arguments).build(); + } +} diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 3d58d501293..82beca520d1 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -313,6 +313,21 @@ default void freezeParameters(boolean freeze) { } } + /** + * Freezes or unfreezes all parameters inside the block that pass the predicate. + * + * @param freeze true to mark as frozen rather than unfrozen + * @param pred true tests if the parameter should be updated + * @see Parameter#freeze(boolean) + */ + default void freezeParameters(boolean freeze, Predicate pred) { + for (Parameter parameter : getParameters().values()) { + if (pred.test(parameter)) { + parameter.freeze(freeze); + } + } + } + /** * Validates that actual layout matches the expected layout. * diff --git a/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java new file mode 100644 index 00000000000..8fbbae7301b --- /dev/null +++ b/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.translate.BasicTranslator; +import ai.djl.translate.Translator; + +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.InputStream; +import java.net.URL; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +public class YoloV8TranslatorFactoryTest { + + private YoloV8TranslatorFactory factory; + + @BeforeClass + public void setUp() { + factory = new YoloV8TranslatorFactory(); + } + + @Test + public void testGetSupportedTypes() { + Assert.assertEquals(factory.getSupportedTypes().size(), 5); + } + + @Test + public void testNewInstance() { + Map arguments = new HashMap<>(); + try (Model model = Model.newInstance("test")) { + Translator translator1 = + factory.newInstance(Image.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator1 instanceof YoloV8Translator); + + Translator translator2 = + factory.newInstance(Path.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator2 instanceof BasicTranslator); + + Translator translator3 = + factory.newInstance(URL.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator3 instanceof BasicTranslator); + + Translator translator4 = + factory.newInstance(InputStream.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator4 instanceof BasicTranslator); + + Translator translator5 = + factory.newInstance(Input.class, Output.class, model, arguments); + Assert.assertTrue(translator5 instanceof ImageServingTranslator); + + Assert.assertThrows( + IllegalArgumentException.class, + () -> factory.newInstance(Image.class, Output.class, model, arguments)); + } + } +} diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java index a8a183de317..451fc9676e7 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java @@ -123,7 +123,10 @@ public void zeroGradients() { NDManager systemManager = MxNDManager.getSystemManager(); for (NDArray array : systemManager.getManagedArrays()) { if (array.hasGradient()) { - array.getGradient().subi(array.getGradient()); + // To prevent memory leak we must close gradient after use. + try (NDArray gradient = array.getGradient()) { + gradient.subi(gradient); + } } } } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index 89599722435..8efab7a027b 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -97,7 +97,7 @@ public int getRank() { /** {@inheritDoc} */ @Override public String getVersion() { - return "1.15.1"; + return "1.16.0"; } /** {@inheritDoc} */ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java index d090e08decb..c46671597b3 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java @@ -76,7 +76,10 @@ public void zeroGradients() { NDManager systemManager = PtNDManager.getSystemManager(); for (NDArray array : systemManager.getManagedArrays()) { if (array.hasGradient()) { - array.getGradient().subi(array.getGradient()); + // To prevent memory leak we must close gradient after use. + try (NDArray gradient = array.getGradient()) { + gradient.subi(gradient); + } } } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index e72e98c9495..35e95f7de86 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -18,6 +18,7 @@ import ai.djl.Model; import ai.djl.ndarray.types.DataType; import ai.djl.nn.Parameter; +import ai.djl.nn.Parameter.Type; import ai.djl.pytorch.jni.JniUtils; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -189,7 +190,9 @@ public Trainer newTrainer(TrainingConfig trainingConfig) { } if (wasLoaded) { // Unfreeze parameters if training directly - block.freezeParameters(false); + block.freezeParameters( + false, + p -> p.getType() != Type.RUNNING_MEAN && p.getType() != Type.RUNNING_VAR); } for (Pair> pair : initializer) { if (pair.getKey() != null && pair.getValue() != null) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java index fa4ee81f26c..b7f92cbd1c3 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java @@ -13,6 +13,7 @@ package ai.djl.pytorch.engine; import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.NDUtils; @@ -24,6 +25,8 @@ import ai.djl.nn.recurrent.RNN; import ai.djl.pytorch.jni.JniUtils; +import java.util.Arrays; +import java.util.Comparator; import java.util.List; /** {@code PtNDArrayEx} is the PyTorch implementation of the {@link NDArrayEx}. */ @@ -760,7 +763,152 @@ public NDList multiBoxDetection( float nmsThreshold, boolean forceSuppress, int nmsTopK) { - throw new UnsupportedOperationException("Not implemented"); + assert (inputs.size() == 3); + + NDArray clsProb = inputs.get(0); + NDArray locPred = inputs.get(1); + NDArray anchors = inputs.get(2).reshape(new Shape(-1, 4)); + + NDManager ndManager = array.getManager(); + + NDArray variances = ndManager.create(new float[] {0.1f, 0.1f, 0.2f, 0.2f}); + + assert (variances.size() == 4); // << "Variance size must be 4"; + final int numClasses = (int) clsProb.size(1); + final int numAnchors = (int) clsProb.size(2); + final int numBatches = (int) clsProb.size(0); + + final float[] pAnchor = anchors.toFloatArray(); + + // [id, prob, xmin, ymin, xmax, ymax] + // TODO Move to NDArray-based implementation + NDList batchOutputs = new NDList(); + for (int nbatch = 0; nbatch < numBatches; ++nbatch) { + float[][] outputs = new float[numAnchors][6]; + final float[] pClsProb = clsProb.get(nbatch).toFloatArray(); + final float[] pLocPred = locPred.get(nbatch).toFloatArray(); + + for (int i = 0; i < numAnchors; ++i) { + // find the predicted class id and probability + float score = -1; + int id = 0; + for (int j = 1; j < numClasses; ++j) { + float temp = pClsProb[j * numAnchors + i]; + if (temp > score) { + score = temp; + id = j; + } + } + + if (id > 0 && score < threshold) { + id = 0; + } + + // [id, prob, xmin, ymin, xmax, ymax] + outputs[i][0] = id - 1; + outputs[i][1] = score; + int offset = i * 4; + float[] pAnchorRow4 = new float[4]; + pAnchorRow4[0] = pAnchor[offset]; + pAnchorRow4[1] = pAnchor[offset + 1]; + pAnchorRow4[2] = pAnchor[offset + 2]; + pAnchorRow4[3] = pAnchor[offset + 3]; + float[] pLocPredRow4 = new float[4]; + pLocPredRow4[0] = pLocPred[offset]; + pLocPredRow4[1] = pLocPred[offset + 1]; + pLocPredRow4[2] = pLocPred[offset + 2]; + pLocPredRow4[3] = pLocPred[offset + 3]; + float[] outRowLast4 = + transformLocations( + pAnchorRow4, + pLocPredRow4, + clip, + variances.toFloatArray()[0], + variances.toFloatArray()[1], + variances.toFloatArray()[2], + variances.toFloatArray()[3]); + outputs[i][2] = outRowLast4[0]; + outputs[i][3] = outRowLast4[1]; + outputs[i][4] = outRowLast4[2]; + outputs[i][5] = outRowLast4[3]; + } + + outputs = + Arrays.stream(outputs) + .filter(o -> o[0] >= 0) + .sorted(Comparator.comparing(o -> -o[1])) + .toArray(float[][]::new); + + // apply nms + for (int i = 0; i < outputs.length; ++i) { + for (int j = i + 1; j < outputs.length; ++j) { + if (outputs[i][0] == outputs[j][0]) { + float[] outputsIRow4 = new float[4]; + float[] outputsJRow4 = new float[4]; + outputsIRow4[0] = outputs[i][2]; + outputsIRow4[1] = outputs[i][3]; + outputsIRow4[2] = outputs[i][4]; + outputsIRow4[3] = outputs[i][5]; + outputsJRow4[0] = outputs[j][2]; + outputsJRow4[1] = outputs[j][3]; + outputsJRow4[2] = outputs[j][4]; + outputsJRow4[3] = outputs[j][5]; + float iou = calculateOverlap(outputsIRow4, outputsJRow4); + if (iou >= nmsThreshold) { + outputs[j][0] = -1; + } + } + } + } + batchOutputs.add(ndManager.create(outputs)); + } // end iter batch + + NDArray pOutNDArray = NDArrays.stack(batchOutputs); + NDList resultNDList = new NDList(); + resultNDList.add(pOutNDArray); + assert (resultNDList.size() == 1); + return resultNDList; + } + + private float[] transformLocations( + final float[] anchors, + final float[] locPred, + final boolean clip, + final float vx, + final float vy, + final float vw, + final float vh) { + float[] outRowLast4 = new float[4]; + // transform predictions to detection results + float al = anchors[0]; + float at = anchors[1]; + float ar = anchors[2]; + float ab = anchors[3]; + float aw = ar - al; + float ah = ab - at; + float ax = (al + ar) / 2.f; + float ay = (at + ab) / 2.f; + float px = locPred[0]; + float py = locPred[1]; + float pw = locPred[2]; + float ph = locPred[3]; + float ox = px * vx * aw + ax; + float oy = py * vy * ah + ay; + float ow = (float) (Math.exp(pw * vw) * aw / 2); + float oh = (float) (Math.exp(ph * vh) * ah / 2); + outRowLast4[0] = clip ? Math.max(0f, Math.min(1f, ox - ow)) : (ox - ow); + outRowLast4[1] = clip ? Math.max(0f, Math.min(1f, oy - oh)) : (oy - oh); + outRowLast4[2] = clip ? Math.max(0f, Math.min(1f, ox + ow)) : (ox + ow); + outRowLast4[3] = clip ? Math.max(0f, Math.min(1f, oy + oh)) : (oy + oh); + return outRowLast4; + } + + private float calculateOverlap(final float[] a, final float[] b) { + float w = Math.max(0f, Math.min(a[2], b[2]) - Math.max(a[0], b[0])); + float h = Math.max(0f, Math.min(a[3], b[3]) - Math.max(a[1], b[1])); + float i = w * h; + float u = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - i; + return u <= 0.f ? 0f : (i / u); } /** {@inheritDoc} */ diff --git a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java new file mode 100644 index 00000000000..5a474c1da31 --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java @@ -0,0 +1,132 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference; + +import ai.djl.ModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.output.BoundingBox; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.DetectedObjects.DetectedObject; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.modality.cv.translator.YoloV8TranslatorFactory; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** An example of inference using an yolov8 model. */ +public final class Yolov8Detection { + + private static final Logger logger = LoggerFactory.getLogger(Yolov8Detection.class); + + private Yolov8Detection() {} + + public static void main(String[] args) throws IOException, ModelException, TranslateException { + DetectedObjects detection = Yolov8Detection.predict(); + logger.info("{}", detection); + } + + public static DetectedObjects predict() throws IOException, ModelException, TranslateException { + String classPath = System.getProperty("java.class.path"); + String pathSeparator = System.getProperty("path.separator"); + classPath = classPath.split(pathSeparator)[0]; + Path modelPath = Paths.get(classPath + "/yolov8n.onnx"); + Path imgPath = Paths.get(classPath + "/yolov8_test.jpg"); + Image img = ImageFactory.getInstance().fromFile(imgPath); + + Map arguments = new ConcurrentHashMap<>(); + arguments.put("width", 640); + arguments.put("height", 640); + arguments.put("resize", "true"); + arguments.put("toTensor", true); + arguments.put("applyRatio", true); + arguments.put("threshold", 0.6f); + arguments.put("synsetFileName", "yolov8_synset.txt"); + + YoloV8TranslatorFactory yoloV8TranslatorFactory = new YoloV8TranslatorFactory(); + Translator translator = + yoloV8TranslatorFactory.newInstance( + Image.class, DetectedObjects.class, null, arguments); + + Criteria criteria = + Criteria.builder() + .setTypes(Image.class, DetectedObjects.class) + .optModelPath(modelPath) + .optEngine("OnnxRuntime") + .optTranslator(translator) + .optProgress(new ProgressBar()) + .build(); + + DetectedObjects detectedObjects; + DetectedObject detectedObject; + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + Path outputPath = Paths.get(classPath + "/output"); + Files.createDirectories(outputPath); + + detectedObjects = predictor.predict(img); + + if (detectedObjects.getNumberOfObjects() > 0) { + List detectedObjectList = detectedObjects.items(); + for (DetectedObject object : detectedObjectList) { + detectedObject = object; + BoundingBox boundingBox = detectedObject.getBoundingBox(); + Rectangle tectangle = boundingBox.getBounds(); + logger.info( + detectedObject.getClassName() + + " " + + detectedObject.getProbability() + + " " + + tectangle.getX() + + " " + + tectangle.getY() + + " " + + tectangle.getWidth() + + " " + + tectangle.getHeight()); + } + + saveBoundingBoxImage( + img.resize(640, 640, false), + detectedObjects, + outputPath, + imgPath.toFile().getName()); + } + + return detectedObjects; + } + } + + private static void saveBoundingBoxImage( + Image img, DetectedObjects detectedObjects, Path outputPath, String outputFileName) + throws IOException { + img.drawBoundingBoxes(detectedObjects); + + Path imagePath = outputPath.resolve(outputFileName); + img.save(Files.newOutputStream(imagePath), "png"); + } +} diff --git a/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java b/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java index 2a61e25862e..1a5699836c8 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java @@ -27,7 +27,6 @@ public class TrainPikachuTest { @Test public void testDetection() throws IOException, MalformedModelException, TranslateException { - TestRequirements.engine("MXNet"); TestRequirements.nightly(); String[] args; diff --git a/examples/src/test/resources/yolov8_synset.txt b/examples/src/test/resources/yolov8_synset.txt new file mode 100644 index 00000000000..7139f0cc628 --- /dev/null +++ b/examples/src/test/resources/yolov8_synset.txt @@ -0,0 +1,80 @@ +person +bicycle +car +motorbike +aeroplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +sofa +pottedplant +bed +diningtable +toilet +tvmonitor +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush \ No newline at end of file diff --git a/examples/src/test/resources/yolov8_test.jpg b/examples/src/test/resources/yolov8_test.jpg new file mode 100644 index 00000000000..01e43374348 Binary files /dev/null and b/examples/src/test/resources/yolov8_test.jpg differ diff --git a/examples/src/test/resources/yolov8n.onnx b/examples/src/test/resources/yolov8n.onnx new file mode 100644 index 00000000000..430f7f2beb0 Binary files /dev/null and b/examples/src/test/resources/yolov8n.onnx differ diff --git a/gradle.properties b/gradle.properties index 23a6019761a..87dc4fe5a15 100644 --- a/gradle.properties +++ b/gradle.properties @@ -17,7 +17,7 @@ pytorch_version=2.0.1 tensorflow_version=2.10.1 tflite_version=2.6.2 trt_version=8.4.1 -onnxruntime_version=1.15.1 +onnxruntime_version=1.16.0 paddlepaddle_version=2.3.2 sentencepiece_version=0.1.97 tokenizers_version=0.13.3 diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java index b5907925ee4..008d652dc82 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/object_detection/SingleShotDetectionTest.java @@ -31,6 +31,7 @@ import ai.djl.nn.LambdaBlock; import ai.djl.nn.SequentialBlock; import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelZoo; import ai.djl.repository.zoo.ZooModel; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.EasyTrain; @@ -123,10 +124,8 @@ private TrainingConfig setupTrainingConfig() { } private ZooModel getModel() throws IOException, ModelException { - // SSD-pikachu model only available in MXNet - // TODO: Add PyTorch model to model zoo - TestUtils.requiresEngine("MXNet"); - + TestUtils.requiresEngine( + ModelZoo.getModelZoo("ai.djl.zoo").getSupportedEngines().toArray(String[]::new)); Criteria criteria = Criteria.builder() .optApplication(Application.CV.OBJECT_DETECTION) diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java index 7cdfc040c12..543ab5f1f21 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java @@ -43,8 +43,8 @@ public String getGroupId() { public Set getSupportedEngines() { Set set = new HashSet<>(); set.add("MXNet"); + set.add("PyTorch"); // TODO Currently WIP in supporting these two engines in the basic model zoo - // set.add("PyTorch"); // set.add("TensorFlow"); return set; } diff --git a/tools/scripts/build_ft_deps.sh b/tools/scripts/build_ft_deps.sh index 4d3cb94a103..110bedc2d92 100755 --- a/tools/scripts/build_ft_deps.sh +++ b/tools/scripts/build_ft_deps.sh @@ -41,7 +41,7 @@ cd ../../ mkdir -p FasterTransformer/build cd FasterTransformer/build git submodule init && git submodule update -cmake -DCMAKE_BUILD_TYPE=Release -DSM=70,75,80,86 -DBUILD_PYT=ON -DBUILD_MULTI_GPU=ON .. +cmake -DCMAKE_BUILD_TYPE=Release -DSM=70,75,80,86,90 -DBUILD_PYT=ON -DBUILD_MULTI_GPU=ON .. make -j$(nproc) cp lib/libth_transformer.so /tmp/binaries/ cd ../../