diff --git a/api/src/main/java/ai/djl/Device.java b/api/src/main/java/ai/djl/Device.java index ce9b29ae5ba..b99e5a64826 100644 --- a/api/src/main/java/ai/djl/Device.java +++ b/api/src/main/java/ai/djl/Device.java @@ -14,11 +14,16 @@ import ai.djl.engine.Engine; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; /** * The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code @@ -30,7 +35,7 @@ * @see The D2L chapter * on GPU devices */ -public final class Device { +public class Device { private static final Map CACHE = new ConcurrentHashMap<>(); @@ -39,8 +44,8 @@ public final class Device { private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)"); - private String deviceType; - private int deviceId; + protected String deviceType; + protected int deviceId; /** * Creates a {@code Device} with basic information. @@ -101,6 +106,13 @@ public static Device fromName(String deviceName, Engine engine) { return engine.defaultDevice(); } + if (deviceName.contains("+")) { + String[] split = deviceName.split("\\+"); + List subDevices = + Arrays.stream(split).map(n -> fromName(n, engine)).collect(Collectors.toList()); + return new MultiDevice(subDevices); + } + Matcher matcher = DEVICE_NAME.matcher(deviceName); if (matcher.matches()) { String deviceType = matcher.group(1); @@ -214,4 +226,91 @@ public interface Type { String CPU = "cpu"; String GPU = "gpu"; } + + /** A combined {@link Device} representing the composition of multiple other devices. */ + public static class MultiDevice extends Device { + + List devices; + + /** + * Constructs a {@link MultiDevice} with a range of new devices. + * + * @param deviceType the type of the sub-devices + * @param startInclusive the start (inclusive) of the devices range + * @param endExclusive the end (exclusive) of the devices range + */ + public MultiDevice(String deviceType, int startInclusive, int endExclusive) { + this( + IntStream.range(startInclusive, endExclusive) + .mapToObj(i -> Device.of(deviceType, i)) + .collect(Collectors.toList())); + } + + /** + * Constructs a {@link MultiDevice} from sub devices. + * + * @param devices the sub devices + */ + public MultiDevice(Device... devices) { + this(Arrays.asList(devices)); + } + + /** + * Constructs a {@link MultiDevice} from sub devices. + * + * @param devices the sub devices + */ + public MultiDevice(List devices) { + super(null, -1); + devices.sort( + Comparator.comparing(Device::getDeviceType, String.CASE_INSENSITIVE_ORDER) + .thenComparingInt(Device::getDeviceId)); + this.deviceType = + String.join( + "+", + (Iterable) + () -> + devices.stream() + .map(d -> d.getDeviceType() + d.getDeviceId()) + .iterator()); + this.devices = devices; + } + + /** + * Returns the sub devices. + * + * @return the sub devices + */ + public List getDevices() { + return devices; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + MultiDevice that = (MultiDevice) o; + return Objects.equals(devices, that.devices); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), devices); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return deviceType + "()"; + } + } } diff --git a/api/src/main/java/ai/djl/training/ParameterStore.java b/api/src/main/java/ai/djl/training/ParameterStore.java index 7029282c46e..15c83bde8ca 100644 --- a/api/src/main/java/ai/djl/training/ParameterStore.java +++ b/api/src/main/java/ai/djl/training/ParameterStore.java @@ -14,6 +14,7 @@ package ai.djl.training; import ai.djl.Device; +import ai.djl.Device.MultiDevice; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.nn.Parameter; @@ -64,6 +65,10 @@ public void setParameterServer(ParameterServer parameterServer, Device[] devices this.parameterServer = parameterServer; deviceMap.clear(); for (int i = 0; i < devices.length; ++i) { + if (devices[i] instanceof MultiDevice) { + throw new IllegalArgumentException( + "The parameter store does not support MultiDevices"); + } if (deviceMap.put(devices[i], i) != null) { throw new IllegalArgumentException("Duplicated devices are not allowed."); } diff --git a/api/src/test/java/ai/djl/DeviceTest.java b/api/src/test/java/ai/djl/DeviceTest.java index 92a0474c6e7..63572810875 100644 --- a/api/src/test/java/ai/djl/DeviceTest.java +++ b/api/src/test/java/ai/djl/DeviceTest.java @@ -13,6 +13,7 @@ package ai.djl; +import ai.djl.Device.MultiDevice; import ai.djl.engine.Engine; import org.testng.Assert; @@ -37,6 +38,8 @@ public void testDevice() { System.setProperty("test_key", "test"); Engine.debugEnvironment(); + + Assert.assertEquals(2, new MultiDevice(Device.gpu(1), Device.gpu(2)).getDevices().size()); } @Test @@ -54,5 +57,9 @@ public void testDeviceName() { Device defaultDevice = Engine.getInstance().defaultDevice(); Assert.assertEquals(Device.fromName(""), defaultDevice); Assert.assertEquals(Device.fromName(null), defaultDevice); + + Assert.assertEquals( + Device.fromName("gpu1+gpu2"), new MultiDevice(Device.gpu(2), Device.gpu(1))); + Assert.assertEquals(Device.fromName("gpu1+gpu2"), new MultiDevice("gpu", 1, 3)); } }