Add support for Apple's Metal Performance Shaders (MPS) in PyTorch #2037

Merged
merged 3 commits into from
Sep 28, 2022
32 changes: 26 additions & 6 deletions api/src/main/java/ai/djl/Device.java
@@ -21,10 +21,10 @@
import java.util.regex.Pattern;

/**
- * The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code
- * NDArray}.
+ * The {@code Device} class provides the specified assignment for CPU/GPU/MPS processing on the
+ * {@code NDArray}.
*
- * <p>Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU with
+ * <p>Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU/MPS with
* deviceType and deviceId provided.
*
* @see <a href="https://d2l.djl.ai/chapter_deep-learning-computation/use-gpu.html">The D2L chapter
@@ -36,6 +36,7 @@ public final class Device {

private static final Device CPU = new Device(Type.CPU, -1);
private static final Device GPU = Device.of(Type.GPU, 0);
+ private static final Device MPS = Device.of(Type.MPS, -1);
Contributor

MPS is PyTorch-specific for the time being, so it's better to keep it PyTorch-only. We can add it to the Device class when it becomes standard.

Contributor Author

@demq demq Sep 28, 2022

I am not too sure how to avoid this part "cleanly" without having to rewrite a bunch of code in PyTorch's JniUtils.java, as most of the functions there use a Device object to infer the PtDeviceType:

public static PtNDArray createZerosNdArray(
        PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) {
    int layoutVal = layoutMapper(fmt, device);
    return new PtNDArray(
            manager,
            PyTorchLibrary.LIB.torchZeros(
                    shape.getShape(),
                    dType.ordinal(),
                    layoutVal,
                    new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
                    false));
}

One way is to change PtDeviceType::toDeviceType(Device device) to somehow return the code for the MPS device on osx-aarch64 systems, but I don't know how to do that:

public static int toDeviceType(Device device) {
    String deviceType = device.getDeviceType();
    if (Device.Type.CPU.equals(deviceType)) {
        return 0;
    } else if (Device.Type.GPU.equals(deviceType)) {
        return 1;
    } else {
        throw new IllegalArgumentException("Unsupported device: " + device.toString());
    }
}
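
A minimal sketch of one way the mapping could stay inside the PyTorch engine, as the reviewer suggests: the engine compares against its own "mps" string instead of a shared Device.Type constant. SketchDevice and PtDeviceTypeSketch are hypothetical stand-ins for ai.djl.Device and PtDeviceType, not DJL classes. The codes 0 and 1 come from the snippet above, and 13 is the value used for MPS in the PtDeviceType change in this PR.

```java
import java.util.Objects;

/** Hypothetical stand-in for ai.djl.Device, reduced to the two fields used here. */
final class SketchDevice {
    private final String deviceType;
    private final int deviceId;

    SketchDevice(String deviceType, int deviceId) {
        this.deviceType = Objects.requireNonNull(deviceType, "deviceType");
        this.deviceId = deviceId;
    }

    String getDeviceType() {
        return deviceType;
    }

    int getDeviceId() {
        return deviceId;
    }

    @Override
    public String toString() {
        // Illustrative format only: the type, followed by the id when one is set.
        return deviceId < 0 ? deviceType : deviceType + deviceId;
    }
}

/** Sketch of an engine-private mapping between device type strings and integer codes. */
final class PtDeviceTypeSketch {
    private static final String CPU = "cpu";
    private static final String GPU = "gpu";
    // Assumption: the engine owns this constant, so no public Device.Type.MPS is required.
    private static final String MPS = "mps";

    private PtDeviceTypeSketch() {}

    static int toDeviceType(SketchDevice device) {
        String deviceType = device.getDeviceType();
        if (CPU.equals(deviceType)) {
            return 0;
        } else if (GPU.equals(deviceType)) {
            return 1;
        } else if (MPS.equals(deviceType)) {
            return 13; // MPS code as used in this PR
        }
        throw new IllegalArgumentException("Unsupported device: " + device);
    }

    static String fromDeviceType(int deviceType) {
        switch (deviceType) {
            case 0:
                return CPU;
            case 1:
                return GPU;
            case 13:
                return MPS;
            default:
                throw new IllegalArgumentException("Unsupported deviceType: " + deviceType);
        }
    }

    public static void main(String[] args) {
        SketchDevice mps = new SketchDevice("mps", -1);
        int code = toDeviceType(mps);
        System.out.println(mps + " -> " + code + " -> " + fromDeviceType(code));
    }
}
```

With this shape, callers would still build the device from the plain string (for example Device.of("mps", -1), using the existing String-based factory), and the Device class itself would not need a new public constant.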


private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)");

@@ -45,7 +46,7 @@ public final class Device {
/**
* Creates a {@code Device} with basic information.
*
- * @param deviceType the device type, typically CPU or GPU
+ * @param deviceType the device type, typically CPU, GPU, or MPS
* @param deviceId the deviceId on the hardware. For example, if you have multiple GPUs, you can
* choose which GPU to process the NDArray
*/
@@ -57,7 +58,7 @@ private Device(String deviceType, int deviceId) {
/**
* Returns a {@code Device} with device type and device id.
*
- * @param deviceType the device type, typically CPU or GPU
+ * @param deviceType the device type, typically CPU, GPU, or MPS
* @param deviceId the deviceId on the hardware.
* @return a {@code Device} instance
*/
@@ -83,7 +84,7 @@ public static Device fromName(String deviceName) {
/**
* Parses a deviceName string into a device.
*
- * <p>The main format of a device name string is "cpu", "gpu0", or "nc1". This is simply
+ * <p>The main format of a device name string is "cpu", "gpu0","mps", or "nc1". This is simply
* deviceType concatenated with the deviceId. If no deviceId is used, -1 will be assumed.
*
* <p>There are also several simplified formats. The "-1", deviceNames corresponds to cpu.
@@ -150,6 +151,15 @@ public boolean isGpu() {
return Type.GPU.equals(deviceType);
}

+ /**
+ * Returns if the {@code Device} is MPS.
+ *
+ * @return if the {@code Device} is MPS.
+ */
+ public boolean isMps() {
+ return Type.MPS.equals(deviceType);
+ }

/** {@inheritDoc} */
@Override
public String toString() {
@@ -209,9 +219,19 @@ public static Device gpu(int deviceId) {
return of(Type.GPU, deviceId);
}

+ /**
+ * Returns the default Metal Performance Shaders (MPS) Device.
+ *
+ * @return the default MPS Device
+ */
+ public static Device mps() {
+ return MPS;
+ }

/** Contains device type string constants. */
public interface Type {
String CPU = "cpu";
String GPU = "gpu";
+ String MPS = "mps";
Contributor

We don't need to make changes in the Device class. Let's keep it private to PyTorch for now.

Contributor Author

@demq demq Sep 28, 2022

I agree that this might create false expectations among users that Apple MPS is supported for all engines, but the current solution means that Apple silicon users don't need to rewrite any of their existing code. Perhaps we could add documentation stating that MPS only works with PyTorch?

Users can simply specify "mps" as the device type in their model Criteria when running on an M1/M2 Mac and get the acceleration working. Otherwise, developers would need to add PyTorch- and macOS-specific code to support "mps", making it less likely to be used. The device type, on the other hand, is usually just a system setting or CLI argument, used for example to switch GPU acceleration on or off.
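
To illustrate the point about the device being a plain setting or CLI argument, here is a rough sketch of how a name such as "mps" could be split into a type and an id. It follows only what the Device javadoc in this diff states (the DEVICE_NAME pattern, -1 when no id is given, and "-1" meaning cpu). DeviceNameSketch and its Parsed holder are hypothetical, and this is not the actual Device.fromName implementation.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Hypothetical helper that mimics the device name format described in the javadoc. */
final class DeviceNameSketch {
    // Same pattern as DEVICE_NAME in the Device class shown in this diff.
    private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)");

    private DeviceNameSketch() {}

    /** Simple holder for the parsed parts of a device name. */
    static final class Parsed {
        final String type;
        final int id;

        Parsed(String type, int id) {
            this.type = type;
            this.id = id;
        }
    }

    static Parsed parse(String deviceName) {
        // Simplified format from the javadoc: "-1" corresponds to cpu.
        if ("-1".equals(deviceName)) {
            return new Parsed("cpu", -1);
        }
        Matcher matcher = DEVICE_NAME.matcher(deviceName);
        if (!matcher.matches()) {
            throw new IllegalArgumentException("Invalid device name: " + deviceName);
        }
        String type = matcher.group(1);
        String digits = matcher.group(2);
        // If no deviceId is given, -1 is assumed.
        int id = digits.isEmpty() ? -1 : Integer.parseInt(digits);
        return new Parsed(type, id);
    }

    public static void main(String[] args) {
        // The device name would typically arrive as a CLI argument or system setting.
        String name = args.length > 0 ? args[0] : "mps";
        Parsed parsed = parse(name);
        System.out.println("type=" + parsed.type + ", id=" + parsed.id);
    }
}
```

Under these assumptions, switching between "cpu", "gpu0", and "mps" is a one-string change for the caller, which is the convenience argued for above.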

}
}
@@ -32,6 +32,8 @@ public static int toDeviceType(Device device) {
return 0;
} else if (Device.Type.GPU.equals(deviceType)) {
return 1;
+ } else if (Device.Type.MPS.equals(deviceType)) {
+ return 13;
} else {
throw new IllegalArgumentException("Unsupported device: " + device.toString());
}
@@ -49,6 +51,8 @@ public static String fromDeviceType(int deviceType) {
return Device.Type.CPU;
case 1:
return Device.Type.GPU;
+ case 13:
+ return Device.Type.MPS;
default:
throw new IllegalArgumentException("Unsupported deviceType: " + deviceType);
}