diff --git a/api/src/main/java/ai/djl/Device.java b/api/src/main/java/ai/djl/Device.java
index ce9b29ae5ba..b99e5a64826 100644
--- a/api/src/main/java/ai/djl/Device.java
+++ b/api/src/main/java/ai/djl/Device.java
@@ -14,11 +14,16 @@
import ai.djl.engine.Engine;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
/**
* The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code
@@ -30,7 +35,7 @@
* @see The D2L chapter
* on GPU devices
*/
-public final class Device {
+public class Device {
private static final Map CACHE = new ConcurrentHashMap<>();
@@ -39,8 +44,8 @@ public final class Device {
private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)");
- private String deviceType;
- private int deviceId;
+ protected String deviceType;
+ protected int deviceId;
/**
* Creates a {@code Device} with basic information.
@@ -101,6 +106,13 @@ public static Device fromName(String deviceName, Engine engine) {
return engine.defaultDevice();
}
+ if (deviceName.contains("+")) {
+ String[] split = deviceName.split("\\+");
+ List subDevices =
+ Arrays.stream(split).map(n -> fromName(n, engine)).collect(Collectors.toList());
+ return new MultiDevice(subDevices);
+ }
+
Matcher matcher = DEVICE_NAME.matcher(deviceName);
if (matcher.matches()) {
String deviceType = matcher.group(1);
@@ -214,4 +226,91 @@ public interface Type {
String CPU = "cpu";
String GPU = "gpu";
}
+
+ /** A combined {@link Device} representing the composition of multiple other devices. */
+ public static class MultiDevice extends Device {
+
+ List devices;
+
+ /**
+ * Constructs a {@link MultiDevice} with a range of new devices.
+ *
+ * @param deviceType the type of the sub-devices
+ * @param startInclusive the start (inclusive) of the devices range
+ * @param endExclusive the end (exclusive) of the devices range
+ */
+ public MultiDevice(String deviceType, int startInclusive, int endExclusive) {
+ this(
+ IntStream.range(startInclusive, endExclusive)
+ .mapToObj(i -> Device.of(deviceType, i))
+ .collect(Collectors.toList()));
+ }
+
+ /**
+ * Constructs a {@link MultiDevice} from sub devices.
+ *
+ * @param devices the sub devices
+ */
+ public MultiDevice(Device... devices) {
+ this(Arrays.asList(devices));
+ }
+
+ /**
+ * Constructs a {@link MultiDevice} from sub devices.
+ *
+ * @param devices the sub devices
+ */
+ public MultiDevice(List devices) {
+ super(null, -1);
+ devices.sort(
+ Comparator.comparing(Device::getDeviceType, String.CASE_INSENSITIVE_ORDER)
+ .thenComparingInt(Device::getDeviceId));
+ this.deviceType =
+ String.join(
+ "+",
+ (Iterable)
+ () ->
+ devices.stream()
+ .map(d -> d.getDeviceType() + d.getDeviceId())
+ .iterator());
+ this.devices = devices;
+ }
+
+ /**
+ * Returns the sub devices.
+ *
+ * @return the sub devices
+ */
+ public List getDevices() {
+ return devices;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ if (!super.equals(o)) {
+ return false;
+ }
+ MultiDevice that = (MultiDevice) o;
+ return Objects.equals(devices, that.devices);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), devices);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String toString() {
+ return deviceType + "()";
+ }
+ }
}
diff --git a/api/src/main/java/ai/djl/training/ParameterStore.java b/api/src/main/java/ai/djl/training/ParameterStore.java
index 7029282c46e..15c83bde8ca 100644
--- a/api/src/main/java/ai/djl/training/ParameterStore.java
+++ b/api/src/main/java/ai/djl/training/ParameterStore.java
@@ -14,6 +14,7 @@
package ai.djl.training;
import ai.djl.Device;
+import ai.djl.Device.MultiDevice;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Parameter;
@@ -64,6 +65,10 @@ public void setParameterServer(ParameterServer parameterServer, Device[] devices
this.parameterServer = parameterServer;
deviceMap.clear();
for (int i = 0; i < devices.length; ++i) {
+ if (devices[i] instanceof MultiDevice) {
+ throw new IllegalArgumentException(
+ "The parameter store does not support MultiDevices");
+ }
if (deviceMap.put(devices[i], i) != null) {
throw new IllegalArgumentException("Duplicated devices are not allowed.");
}
diff --git a/api/src/test/java/ai/djl/DeviceTest.java b/api/src/test/java/ai/djl/DeviceTest.java
index 92a0474c6e7..63572810875 100644
--- a/api/src/test/java/ai/djl/DeviceTest.java
+++ b/api/src/test/java/ai/djl/DeviceTest.java
@@ -13,6 +13,7 @@
package ai.djl;
+import ai.djl.Device.MultiDevice;
import ai.djl.engine.Engine;
import org.testng.Assert;
@@ -37,6 +38,8 @@ public void testDevice() {
System.setProperty("test_key", "test");
Engine.debugEnvironment();
+
+ Assert.assertEquals(2, new MultiDevice(Device.gpu(1), Device.gpu(2)).getDevices().size());
}
@Test
@@ -54,5 +57,9 @@ public void testDeviceName() {
Device defaultDevice = Engine.getInstance().defaultDevice();
Assert.assertEquals(Device.fromName(""), defaultDevice);
Assert.assertEquals(Device.fromName(null), defaultDevice);
+
+ Assert.assertEquals(
+ Device.fromName("gpu1+gpu2"), new MultiDevice(Device.gpu(2), Device.gpu(1)));
+ Assert.assertEquals(Device.fromName("gpu1+gpu2"), new MultiDevice("gpu", 1, 3));
}
}