diff --git a/api/src/main/java/ai/djl/modality/Classifications.java b/api/src/main/java/ai/djl/modality/Classifications.java index 84025ce07e1..13f3325b12e 100644 --- a/api/src/main/java/ai/djl/modality/Classifications.java +++ b/api/src/main/java/ai/djl/modality/Classifications.java @@ -88,10 +88,18 @@ public Classifications(List classNames, NDArray probabilities) { */ public Classifications(List classNames, NDArray probabilities, int topK) { this.classNames = classNames; - NDArray array = probabilities.toType(DataType.FLOAT64, false); - this.probabilities = - Arrays.stream(array.toDoubleArray()).boxed().collect(Collectors.toList()); - array.close(); + if (probabilities.getDataType() == DataType.FLOAT32) { + // Avoid converting float32 to float64 as this is not supported on MPS device + this.probabilities = new ArrayList<>(); + for (float prob : probabilities.toFloatArray()) { + this.probabilities.add((double) prob); + } + } else { + NDArray array = probabilities.toType(DataType.FLOAT64, false); + this.probabilities = + Arrays.stream(array.toDoubleArray()).boxed().collect(Collectors.toList()); + array.close(); + } this.topK = topK; } diff --git a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java index b12ac5dd07d..07e56a5ca04 100644 --- a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java +++ b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java @@ -434,7 +434,12 @@ default NDArray toTensor() { if (dim == 3) { result = result.expandDims(0); } - result = result.div(255.0).transpose(0, 3, 1, 2); + // For Apple Silicon MPS it is important not to switch to 64-bit float here + if (result.getDataType() == DataType.FLOAT32) { + result = result.div(255.0f).transpose(0, 3, 1, 2); + } else { + result = result.div(255.0).transpose(0, 3, 1, 2); + } if (dim == 3) { result = result.squeeze(0); } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java index a92a9b6a3d4..deef04907be 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java @@ -30,6 +30,7 @@ import java.io.IOException; import java.io.InputStream; +import java.nio.ByteBuffer; import java.util.Map; /** @@ -118,8 +119,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException { byte[] buf = Utils.toByteArray(is); try (NDArray array = manager.create( - new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1), DataType.UINT8)) { - array.set(buf); + ByteBuffer.wrap(buf), + new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1), + DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } @@ -132,8 +134,8 @@ private NDArray readLabel(Artifact.Item item) throws IOException { } byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) { - array.set(buf); + try (NDArray array = + manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java index 164ba9876cb..5503e721caa 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java @@ -30,6 +30,7 @@ import java.io.IOException; import java.io.InputStream; +import java.nio.ByteBuffer; import java.util.Map; /** @@ -111,8 +112,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException { } byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(new Shape(length, 28, 28, 1), DataType.UINT8)) { - array.set(buf); + try (NDArray array = + manager.create( + ByteBuffer.wrap(buf), new Shape(length, 28, 28, 1), DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } @@ -123,10 +125,9 @@ private NDArray readLabel(Artifact.Item item) throws IOException { if (is.skip(8) != 8) { throw new AssertionError("Failed skip data."); } - byte[] buf = Utils.toByteArray(is); - try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) { - array.set(buf); + try (NDArray array = + manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) { return array.toType(DataType.FLOAT32, false); } } diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java index 8b4e2326f26..5b6ed349e10 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java @@ -13,6 +13,7 @@ package ai.djl.pytorch.integration; import ai.djl.Device; +import ai.djl.modality.Classifications; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; @@ -21,6 +22,10 @@ import org.testng.SkipException; import org.testng.annotations.Test; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + public class MpsTest { @Test @@ -36,4 +41,39 @@ public void testMps() { Assert.assertEquals(array.getDevice().getDeviceType(), "mps"); } } + + private static boolean checkMpsCompatible() { + return "aarch64".equals(System.getProperty("os.arch")) + && System.getProperty("os.name").startsWith("Mac"); + } + + @Test + public void testToTensorMPS() { + if (!checkMpsCompatible()) { + throw new SkipException("MPS toTensor test requires Apple Silicon macOS."); + } + + // Test that toTensor does not fail on MPS (e.g. due to use of float64 for division) + try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) { + NDArray array = manager.create(127f).reshape(1, 1, 1, 1); + NDArray tensor = array.getNDArrayInternal().toTensor(); + Assert.assertEquals(tensor.toFloatArray(), new float[] {127f / 255f}); + } + } + + @Test + public void testClassificationsMPS() { + if (!checkMpsCompatible()) { + throw new SkipException("MPS classification test requires Apple Silicon macOS."); + } + + // Test that classifications do not fail on MPS (e.g. due to conversion of probabilities to + // float64) + try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) { + List names = Arrays.asList("First", "Second", "Third", "Fourth", "Fifth"); + NDArray tensor = manager.create(new float[] {0f, 0.125f, 1f, 0.5f, 0.25f}); + Classifications classifications = new Classifications(names, tensor); + Assert.assertEquals(classifications.topK(1), Collections.singletonList("Third")); + } + } }