diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java index a4ebfcb9df1..c31353766d3 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java @@ -160,7 +160,7 @@ protected double overlap(double x1, double w1, double x2, double w2) { return right - left; } - private DetectedObjects processFromBoxOutput(NDList list) { + protected DetectedObjects processFromBoxOutput(NDList list) { float[] flattened = list.get(0).toFloatArray(); ArrayList intermediateResults = new ArrayList<>(); int sizeClasses = classes.size(); @@ -280,7 +280,7 @@ public YoloV5Translator build() { } } - private static final class IntermediateResult { + protected static final class IntermediateResult { /** * A sortable score for how good the recognition is relative to others. Higher should be diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java new file mode 100644 index 00000000000..dc160ba754b --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java @@ -0,0 +1,103 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.DataType; + +import java.util.ArrayList; +import java.util.Map; + +/** + * A translator for YoloV8 models. This was tested with ONNX exported Yolo models. For details check + * here: https://github.com/ultralytics/ultralytics + */ +public class YoloV8Translator extends YoloV5Translator { + + /** + * Constructs an ImageTranslator with the provided builder. + * + * @param builder the data to build with + */ + protected YoloV8Translator(Builder builder) { + super(builder); + } + + /** + * Creates a builder to build a {@code YoloV8Translator} with specified arguments. + * + * @param arguments arguments to specify builder options + * @return a new builder + */ + public static YoloV8Translator.Builder builder(Map arguments) { + YoloV8Translator.Builder builder = new YoloV8Translator.Builder(); + builder.configPreProcess(arguments); + builder.configPostProcess(arguments); + + return builder; + } + + @Override + protected DetectedObjects processFromBoxOutput(NDList list) { + NDArray features4OneImg = list.get(0); + int sizeClasses = classes.size(); + long sizeBoxes = features4OneImg.size(1); + ArrayList intermediateResults = new ArrayList<>(); + + for (long b = 0; b < sizeBoxes; b++) { + float maxClass = 0; + int maxIndex = 0; + for (int c = 4; c < sizeClasses; c++) { + float classProb = features4OneImg.getFloat(c, b); + if (classProb > maxClass) { + maxClass = classProb; + maxIndex = c; + } + } + + if (maxClass > threshold) { + float xPos = features4OneImg.getFloat(0, b); // center x + float yPos = features4OneImg.getFloat(1, b); // center y + float w = features4OneImg.getFloat(2, b); + float h = features4OneImg.getFloat(3, b); + Rectangle rect = + new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h); + intermediateResults.add( + new IntermediateResult(classes.get(maxIndex), maxClass, maxIndex, rect)); + } + } + + return nms(intermediateResults); + } + + /** The builder for {@link YoloV8Translator}. */ + public static class Builder extends YoloV5Translator.Builder { + /** + * Builds the translator. + * + * @return the new translator + */ + @Override + public YoloV8Translator build() { + if (pipeline == null) { + addTransform( + array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255)); + } + validate(); + return new YoloV8Translator(this); + } + } +} diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactory.java new file mode 100644 index 00000000000..b5a4db00d28 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactory.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.translate.Translator; + +import java.io.Serializable; +import java.util.Map; + +/** A translatorFactory that creates a {@link YoloV8Translator} instance. */ +public class YoloV8TranslatorFactory extends ObjectDetectionTranslatorFactory + implements Serializable { + + private static final long serialVersionUID = 1L; + + /** {@inheritDoc} */ + @Override + protected Translator buildBaseTranslator( + Model model, Map arguments) { + return YoloV8Translator.builder(arguments).build(); + } +} diff --git a/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java new file mode 100644 index 00000000000..8fbbae7301b --- /dev/null +++ b/api/src/test/java/ai/djl/modality/cv/translator/YoloV8TranslatorFactoryTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.translate.BasicTranslator; +import ai.djl.translate.Translator; + +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.InputStream; +import java.net.URL; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +public class YoloV8TranslatorFactoryTest { + + private YoloV8TranslatorFactory factory; + + @BeforeClass + public void setUp() { + factory = new YoloV8TranslatorFactory(); + } + + @Test + public void testGetSupportedTypes() { + Assert.assertEquals(factory.getSupportedTypes().size(), 5); + } + + @Test + public void testNewInstance() { + Map arguments = new HashMap<>(); + try (Model model = Model.newInstance("test")) { + Translator translator1 = + factory.newInstance(Image.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator1 instanceof YoloV8Translator); + + Translator translator2 = + factory.newInstance(Path.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator2 instanceof BasicTranslator); + + Translator translator3 = + factory.newInstance(URL.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator3 instanceof BasicTranslator); + + Translator translator4 = + factory.newInstance(InputStream.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator4 instanceof BasicTranslator); + + Translator translator5 = + factory.newInstance(Input.class, Output.class, model, arguments); + Assert.assertTrue(translator5 instanceof ImageServingTranslator); + + Assert.assertThrows( + IllegalArgumentException.class, + () -> factory.newInstance(Image.class, Output.class, model, arguments)); + } + } +} diff --git a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java new file mode 100644 index 00000000000..5a474c1da31 --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java @@ -0,0 +1,132 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference; + +import ai.djl.ModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.output.BoundingBox; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.DetectedObjects.DetectedObject; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.modality.cv.translator.YoloV8TranslatorFactory; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** An example of inference using an yolov8 model. */ +public final class Yolov8Detection { + + private static final Logger logger = LoggerFactory.getLogger(Yolov8Detection.class); + + private Yolov8Detection() {} + + public static void main(String[] args) throws IOException, ModelException, TranslateException { + DetectedObjects detection = Yolov8Detection.predict(); + logger.info("{}", detection); + } + + public static DetectedObjects predict() throws IOException, ModelException, TranslateException { + String classPath = System.getProperty("java.class.path"); + String pathSeparator = System.getProperty("path.separator"); + classPath = classPath.split(pathSeparator)[0]; + Path modelPath = Paths.get(classPath + "/yolov8n.onnx"); + Path imgPath = Paths.get(classPath + "/yolov8_test.jpg"); + Image img = ImageFactory.getInstance().fromFile(imgPath); + + Map arguments = new ConcurrentHashMap<>(); + arguments.put("width", 640); + arguments.put("height", 640); + arguments.put("resize", "true"); + arguments.put("toTensor", true); + arguments.put("applyRatio", true); + arguments.put("threshold", 0.6f); + arguments.put("synsetFileName", "yolov8_synset.txt"); + + YoloV8TranslatorFactory yoloV8TranslatorFactory = new YoloV8TranslatorFactory(); + Translator translator = + yoloV8TranslatorFactory.newInstance( + Image.class, DetectedObjects.class, null, arguments); + + Criteria criteria = + Criteria.builder() + .setTypes(Image.class, DetectedObjects.class) + .optModelPath(modelPath) + .optEngine("OnnxRuntime") + .optTranslator(translator) + .optProgress(new ProgressBar()) + .build(); + + DetectedObjects detectedObjects; + DetectedObject detectedObject; + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + Path outputPath = Paths.get(classPath + "/output"); + Files.createDirectories(outputPath); + + detectedObjects = predictor.predict(img); + + if (detectedObjects.getNumberOfObjects() > 0) { + List detectedObjectList = detectedObjects.items(); + for (DetectedObject object : detectedObjectList) { + detectedObject = object; + BoundingBox boundingBox = detectedObject.getBoundingBox(); + Rectangle tectangle = boundingBox.getBounds(); + logger.info( + detectedObject.getClassName() + + " " + + detectedObject.getProbability() + + " " + + tectangle.getX() + + " " + + tectangle.getY() + + " " + + tectangle.getWidth() + + " " + + tectangle.getHeight()); + } + + saveBoundingBoxImage( + img.resize(640, 640, false), + detectedObjects, + outputPath, + imgPath.toFile().getName()); + } + + return detectedObjects; + } + } + + private static void saveBoundingBoxImage( + Image img, DetectedObjects detectedObjects, Path outputPath, String outputFileName) + throws IOException { + img.drawBoundingBoxes(detectedObjects); + + Path imagePath = outputPath.resolve(outputFileName); + img.save(Files.newOutputStream(imagePath), "png"); + } +} diff --git a/examples/src/test/resources/yolov8_synset.txt b/examples/src/test/resources/yolov8_synset.txt new file mode 100644 index 00000000000..7139f0cc628 --- /dev/null +++ b/examples/src/test/resources/yolov8_synset.txt @@ -0,0 +1,80 @@ +person +bicycle +car +motorbike +aeroplane +bus +train +truck +boat +traffic light +fire hydrant +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +backpack +umbrella +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +sofa +pottedplant +bed +diningtable +toilet +tvmonitor +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +book +clock +vase +scissors +teddy bear +hair drier +toothbrush \ No newline at end of file diff --git a/examples/src/test/resources/yolov8_test.jpg b/examples/src/test/resources/yolov8_test.jpg new file mode 100644 index 00000000000..01e43374348 Binary files /dev/null and b/examples/src/test/resources/yolov8_test.jpg differ diff --git a/examples/src/test/resources/yolov8n.onnx b/examples/src/test/resources/yolov8n.onnx new file mode 100644 index 00000000000..430f7f2beb0 Binary files /dev/null and b/examples/src/test/resources/yolov8n.onnx differ