Skip to content

Commit

Permalink
Move MPS tests to same file
Browse files Browse the repository at this point in the history
  • Loading branch information
petebankhead committed Nov 25, 2023
1 parent 4d8295c commit d4d6e48
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 62 deletions.
31 changes: 0 additions & 31 deletions api/src/test/java/ai/djl/modality/ClassificationsTest.java

This file was deleted.

31 changes: 0 additions & 31 deletions api/src/test/java/ai/djl/ndarray/internal/NDArrayExTest.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.pytorch.integration;

import ai.djl.Device;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
Expand All @@ -21,6 +22,9 @@
import org.testng.SkipException;
import org.testng.annotations.Test;

import java.util.Arrays;
import java.util.List;

public class MpsTest {

@Test
Expand All @@ -36,4 +40,42 @@ public void testMps() {
Assert.assertEquals(array.getDevice().getDeviceType(), "mps");
}
}

private static boolean checkMpsCompatible() {
return "aarch64".equals(System.getProperty("os.arch")) &&
System.getProperty("os.name").startsWith("Mac");
}

@Test
public void testToTensorMPS() {
if (!checkMpsCompatible()) {
throw new SkipException("MPS toTensor test requires Apple Silicon macOS.");
}

// Test that toTensor does not fail on MPS (e.g. due to use of float64 for division)
try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) {
NDArray array = manager.create(127f).reshape(1, 1, 1, 1);;
NDArray tensor = array.getNDArrayInternal().toTensor();
Assert.assertEquals(tensor.toFloatArray(), new float[]{127f/255f});
}
}

@Test
public void testClassificationsMPS() {
if (!checkMpsCompatible()) {
throw new SkipException("MPS classification test requires Apple Silicon macOS.");
}

// Test that classifications do not fail on MPS (e.g. due to conversion of probabilities to float64)
try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) {
List<String> names = Arrays.asList("First", "Second", "Third", "Fourth", "Fifth");
NDArray tensor = manager.create(new float[]{0f, 0.125f, 1f, 0.5f, 0.25f});
Classifications classifications = new Classifications(
names,
tensor
);
Assert.assertNotNull(classifications.topK(1).equals(Arrays.asList("Third")));
}
}

}

0 comments on commit d4d6e48

Please sign in to comment.