Add support for Apple's Metal Performance Shaders (MPS) in PyTorch #2037
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,10 +21,10 @@ | |
import java.util.regex.Pattern; | ||
|
||
/** | ||
* The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code | ||
* NDArray}. | ||
* The {@code Device} class provides the specified assignment for CPU/GPU/MPS processing on the | ||
* {@code NDArray}. | ||
* | ||
* <p>Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU with | ||
* <p>Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU/MPS with | ||
* deviceType and deviceId provided. | ||
* | ||
* @see <a href="https://d2l.djl.ai/chapter_deep-learning-computation/use-gpu.html">The D2L chapter | ||
|
@@ -36,6 +36,7 @@ public final class Device { | |
|
||
private static final Device CPU = new Device(Type.CPU, -1); | ||
private static final Device GPU = Device.of(Type.GPU, 0); | ||
private static final Device MPS = Device.of(Type.MPS, -1); | ||
|
||
private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)"); | ||
|
||
|
@@ -45,7 +46,7 @@ public final class Device { | |
/** | ||
* Creates a {@code Device} with basic information. | ||
* | ||
* @param deviceType the device type, typically CPU or GPU | ||
* @param deviceType the device type, typically CPU, GPU, or MPS | ||
* @param deviceId the deviceId on the hardware. For example, if you have multiple GPUs, you can | ||
* choose which GPU to process the NDArray | ||
*/ | ||
|
@@ -57,7 +58,7 @@ private Device(String deviceType, int deviceId) { | |
/** | ||
* Returns a {@code Device} with device type and device id. | ||
* | ||
* @param deviceType the device type, typically CPU or GPU | ||
* @param deviceType the device type, typically CPU, GPU, or MPS | ||
* @param deviceId the deviceId on the hardware. | ||
* @return a {@code Device} instance | ||
*/ | ||
|
@@ -83,7 +84,7 @@ public static Device fromName(String deviceName) { | |
/** | ||
* Parses a deviceName string into a device. | ||
* | ||
* <p>The main format of a device name string is "cpu", "gpu0", or "nc1". This is simply | ||
 * <p>The main format of a device name string is "cpu", "gpu0", "mps", or "nc1". This is simply | ||
* deviceType concatenated with the deviceId. If no deviceId is used, -1 will be assumed. | ||
* | ||
* <p>There are also several simplified formats. The "-1", deviceNames corresponds to cpu. | ||
|
@@ -150,6 +151,15 @@ public boolean isGpu() { | |
return Type.GPU.equals(deviceType); | ||
} | ||
|
||
/** | ||
* Returns if the {@code Device} is MPS. | ||
* | ||
* @return if the {@code Device} is MPS. | ||
*/ | ||
public boolean isMps() { | ||
return Type.MPS.equals(deviceType); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public String toString() { | ||
|
@@ -209,9 +219,19 @@ public static Device gpu(int deviceId) { | |
return of(Type.GPU, deviceId); | ||
} | ||
|
||
/** | ||
* Returns the default Metal Performance Shaders (MPS) Device. | ||
* | ||
* @return the default MPS Device | ||
*/ | ||
public static Device mps() { | ||
return MPS; | ||
} | ||
|
||
/** Contains device type string constants. */ | ||
public interface Type { | ||
String CPU = "cpu"; | ||
String GPU = "gpu"; | ||
String MPS = "mps"; | ||
We don't need to make changes in

I agree that this might create false expectations among users that Apple MPS is supported for all the engines, but the current solution means that Apple silicon users don't need to rewrite any of their existing code. Perhaps add documentation stating that MPS only works for PyTorch? Users can simply specify "mps" as their device type in their model Criteria when running on a Mac M1/M2 and get the acceleration working. Otherwise, developers would need to add PyTorch- and OSX-specific code to support "mps", making it less likely to be used. The device type, on the other hand, is usually just a system setting or CLI argument that lets, say, GPU acceleration be switched on or off.
||
} | ||
} |
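The device-name format described in the Javadoc of the diff above (a device type concatenated with an optional device id, with -1 assumed when no id is given) can be illustrated with a minimal standalone sketch. It reuses the DEVICE_NAME pattern shown in the diff; the class name DeviceNameSketch and the helpers typeOf and idOf are hypothetical and are not part of DJL, and the simplified formats such as "-1" are deliberately not handled here.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Hypothetical standalone sketch of the device-name format; not the DJL implementation. */
final class DeviceNameSketch {

    // Same pattern as DEVICE_NAME in the diff: a lowercase type followed by an optional numeric id.
    private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)");

    private DeviceNameSketch() {}

    // Matches the whole name against the pattern, rejecting anything that does not fit.
    private static Matcher match(String deviceName) {
        Matcher m = DEVICE_NAME.matcher(deviceName);
        if (!m.matches()) {
            throw new IllegalArgumentException("Invalid device name: " + deviceName);
        }
        return m;
    }

    /** Returns the device type part, for example "mps" for "mps" or "gpu" for "gpu0". */
    static String typeOf(String deviceName) {
        return match(deviceName).group(1);
    }

    /** Returns the device id part, or -1 when the name carries no id. */
    static int idOf(String deviceName) {
        String id = match(deviceName).group(2);
        return id.isEmpty() ? -1 : Integer.parseInt(id);
    }
}
```

Under this format, "mps" needs no special parsing: it is simply a device type with no id, in the same way as "cpu".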
MPS is PyTorch-specific; for the time being, it's better to keep it PyTorch-only. We can add it to the Device class when it becomes standard.
I am not too sure how to avoid this part "cleanly" without having to rewrite a bunch of code in PyTorch's JniUtils.java, as most of the functions there use a Device object to infer the PtDeviceType:
djl/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
Lines 191 to 202 in 8fce4eb
One way is to change PtDeviceType::toDeviceType(Device device) to somehow return the code for the MPS device on osx-aarch64 systems, but I don't know how to do that:
djl/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java
Lines 28 to 38 in 8fce4eb
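One possible direction for the question above, offered only as a sketch: keep the "mps" string inside the PyTorch engine and map it to a native device-type code there, so that the shared Device class needs no new constant. The class PtDeviceTypeSketch below is hypothetical and takes the device type as a plain string so that it compiles on its own; it is not the code at the referenced lines. The numeric codes are assumptions based on PyTorch's c10::DeviceType enum (CPU = 0, CUDA = 1, MPS = 13) and should be verified against the PyTorch version in use.

```java
/** Hypothetical, self-contained sketch; not the actual DJL PtDeviceType class. */
final class PtDeviceTypeSketch {

    // Assumed values of PyTorch's c10::DeviceType enum; verify before relying on them.
    static final int CPU = 0;
    static final int CUDA = 1;
    static final int MPS = 13;

    private PtDeviceTypeSketch() {}

    /**
     * Maps a device type string to the assumed native device-type code.
     *
     * The "mps" literal lives only in this engine-level mapping, so other
     * engines are never told that MPS exists.
     */
    static int toDeviceType(String deviceType) {
        switch (deviceType) {
            case "cpu":
                return CPU;
            case "gpu":
                return CUDA;
            case "mps":
                return MPS;
            default:
                throw new IllegalArgumentException("Unsupported device: " + deviceType);
        }
    }
}
```

Because the diff shows that a Device can be created from an arbitrary type string and id via Device.of, a caller could in principle still request an "mps" device without a dedicated constant in the shared class. Whether this fits the rest of JniUtils.java is not something the discussion settles.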