diff --git a/api/src/main/java/ai/djl/modality/Classifications.java b/api/src/main/java/ai/djl/modality/Classifications.java index 9cb25089423e..13f3325b12ed 100644 --- a/api/src/main/java/ai/djl/modality/Classifications.java +++ b/api/src/main/java/ai/djl/modality/Classifications.java @@ -88,11 +88,12 @@ public Classifications(List classNames, NDArray probabilities) { */ public Classifications(List classNames, NDArray probabilities, int topK) { this.classNames = classNames; - if (probabilities.getDataType().equals(DataType.FLOAT32)) { + if (probabilities.getDataType() == DataType.FLOAT32) { // Avoid converting float32 to float64 as this is not supported on MPS device this.probabilities = new ArrayList<>(); - for (float prob : probabilities.toFloatArray()) + for (float prob : probabilities.toFloatArray()) { this.probabilities.add((double) prob); + } } else { NDArray array = probabilities.toType(DataType.FLOAT64, false); this.probabilities = diff --git a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java index b484114fd6b7..07e56a5ca046 100644 --- a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java +++ b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java @@ -435,10 +435,11 @@ default NDArray toTensor() { result = result.expandDims(0); } // For Apple Silicon MPS it is important not to switch to 64-bit float here - if (result.getDataType().equals(DataType.FLOAT32)) + if (result.getDataType() == DataType.FLOAT32) { result = result.div(255.0f).transpose(0, 3, 1, 2); - else + } else { result = result.div(255.0).transpose(0, 3, 1, 2); + } if (dim == 3) { result = result.squeeze(0); } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java index 0c162806084a..deef04907bed 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java @@ -117,8 +117,11 @@ private NDArray readData(Artifact.Item item, long length) throws IOException { } byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(ByteBuffer.wrap(buf), - new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1), DataType.UINT8)) { + try (NDArray array = + manager.create( + ByteBuffer.wrap(buf), + new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1), + DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } @@ -131,7 +134,8 @@ private NDArray readLabel(Artifact.Item item) throws IOException { } byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) { + try (NDArray array = + manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java index aa97ac346b4b..5503e721caae 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java @@ -112,7 +112,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException { } byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(ByteBuffer.wrap(buf), new Shape(length, 28, 28, 1), DataType.UINT8)) { + try (NDArray array = + manager.create( + ByteBuffer.wrap(buf), new Shape(length, 28, 28, 1), DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } @@ -124,7 +126,8 @@ private NDArray readLabel(Artifact.Item item) throws IOException { throw new AssertionError("Failed skip data."); } byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) { + try (NDArray array = + manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java index 5bf2aa1fee6b..95d091e8c6e9 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java @@ -42,8 +42,8 @@ public void testMps() { } private static boolean checkMpsCompatible() { - return "aarch64".equals(System.getProperty("os.arch")) && - System.getProperty("os.name").startsWith("Mac"); + return "aarch64".equals(System.getProperty("os.arch")) + && System.getProperty("os.name").startsWith("Mac"); } @Test @@ -54,9 +54,10 @@ public void testToTensorMPS() { // Test that toTensor does not fail on MPS (e.g. due to use of float64 for division) try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) { - NDArray array = manager.create(127f).reshape(1, 1, 1, 1);; + NDArray array = manager.create(127f).reshape(1, 1, 1, 1); + ; NDArray tensor = array.getNDArrayInternal().toTensor(); - Assert.assertEquals(tensor.toFloatArray(), new float[]{127f/255f}); + Assert.assertEquals(tensor.toFloatArray(), new float[] {127f / 255f}); } } @@ -66,16 +67,13 @@ public void testClassificationsMPS() { throw new SkipException("MPS classification test requires Apple Silicon macOS."); } - // Test that classifications do not fail on MPS (e.g. due to conversion of probabilities to float64) - try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) { + // Test that classifications do not fail on MPS (e.g. due to conversion of probabilities to + // float64) + try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) { List names = Arrays.asList("First", "Second", "Third", "Fourth", "Fifth"); - NDArray tensor = manager.create(new float[]{0f, 0.125f, 1f, 0.5f, 0.25f}); - Classifications classifications = new Classifications( - names, - tensor - ); + NDArray tensor = manager.create(new float[] {0f, 0.125f, 1f, 0.5f, 0.25f}); + Classifications classifications = new Classifications(names, tensor); Assert.assertNotNull(classifications.topK(1).equals(Arrays.asList("Third"))); } } - }