Add support for Apples' Metal Performance Shaders (MPS) in pytorch #2037
Conversation
api/src/main/java/ai/djl/Device.java
Outdated
```diff
@@ -36,6 +36,7 @@ public final class Device {

     private static final Device CPU = new Device(Type.CPU, -1);
     private static final Device GPU = Device.of(Type.GPU, 0);
+    private static final Device MPS = Device.of(Type.MPS, -1);
```
MPS is PyTorch specific; for the time being, it's better to keep it PyTorch only. We can add it to the Device class when it becomes standard.
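As an illustration of the suggestion above, a minimal sketch of keeping the MPS constant private to the PyTorch engine rather than in the shared `Device` class. The class and record names here are hypothetical stand-ins, not DJL APIs:

```java
// Illustrative sketch only: instead of adding MPS to the shared ai.djl.Device
// class, the PyTorch engine could hold its own engine-private constant.
// "PtMpsSketch" and the nested Device record are hypothetical stand-ins.
public class PtMpsSketch {
    // Minimal stand-in for ai.djl.Device: a type string plus a device id.
    record Device(String type, int id) {
        static Device of(String type, int id) {
            return new Device(type, id);
        }
    }

    // Engine-private constant, analogous to Device.of("mps", -1) in the PR.
    static final Device MPS = Device.of("mps", -1);

    public static void main(String[] args) {
        System.out.println(MPS.type() + ":" + MPS.id());
    }
}
```

This keeps other engines unaware of the MPS type until it becomes standard across frameworks.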
I am not sure how to avoid this part "cleanly" without having to rewrite a bunch of code in PyTorch's JniUtils.java, as most of the functions there use a Device object to infer the PtDeviceType:
djl/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java
Lines 191 to 202 in 8fce4eb
```java
public static PtNDArray createZerosNdArray(
        PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) {
    int layoutVal = layoutMapper(fmt, device);
    return new PtNDArray(
            manager,
            PyTorchLibrary.LIB.torchZeros(
                    shape.getShape(),
                    dType.ordinal(),
                    layoutVal,
                    new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
                    false));
}
```
One way is to change PtDeviceType::toDeviceType(Device device) to somehow return the code for the MPS device on osx-aarch64 systems, but I don't know how to do that:
djl/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtDeviceType.java
Lines 28 to 38 in 8fce4eb
```java
public static int toDeviceType(Device device) {
    String deviceType = device.getDeviceType();
    if (Device.Type.CPU.equals(deviceType)) {
        return 0;
    } else if (Device.Type.GPU.equals(deviceType)) {
        return 1;
    } else {
        throw new IllegalArgumentException("Unsupported device: " + device.toString());
    }
}
```
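One possible shape for the change discussed above, as a self-contained sketch: the real method takes an `ai.djl.Device`, but a plain string stands in here so the example runs on its own. The value 13 assumes PyTorch's `c10::DeviceType::MPS` ordinal and should be verified against the libtorch headers actually in use:

```java
// Sketch only: a stand-in for PtDeviceType.toDeviceType with an added "mps"
// branch. The MPS code 13 is an assumption based on PyTorch's c10::DeviceType
// enum; verify it against the libtorch version being targeted.
public class DeviceTypeSketch {
    public static int toDeviceType(String deviceType) {
        switch (deviceType) {
            case "cpu":
                return 0;
            case "gpu":
                return 1; // maps to c10::DeviceType::CUDA
            case "mps":
                return 13; // assumed c10::DeviceType::MPS value
            default:
                throw new IllegalArgumentException("Unsupported device: " + deviceType);
        }
    }

    public static void main(String[] args) {
        System.out.println(toDeviceType("cpu"));
        System.out.println(toDeviceType("mps"));
    }
}
```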
api/src/main/java/ai/djl/Device.java
Outdated
```diff
 /** Contains device type string constants. */
 public interface Type {
     String CPU = "cpu";
     String GPU = "gpu";
+    String MPS = "mps";
```
We don't need to make changes in the Device class. Let's keep it private to PyTorch for now.
I agree that this might create false expectations from users that Apple MPS is supported for all the engines, but the current solution means that Apple silicon users don't need to rewrite any of their existing code. Perhaps we could add documentation stating that MPS will only work for PyTorch?
Users can simply specify "mps" as their device type in their model Criteria when running on a Mac M1/M2 and get the acceleration working. Otherwise, developers would need to add PyTorch- and OSX-specific code to support "mps", making it less likely to be used. The device type, on the other hand, is usually just a system setting or CLI argument, used for example to switch GPU acceleration on or off.
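To make the "system setting / CLI argument" point concrete, a small self-contained sketch of resolving a user-supplied device string, falling back to CPU for anything unrecognized. The helper name and behavior are illustrative, not a DJL API:

```java
// Hypothetical sketch: resolving a device type from a CLI argument or system
// property, as described above. parseDevice is illustrative, not a DJL API.
public class DeviceArgSketch {
    static String parseDevice(String arg) {
        if (arg == null) {
            return "cpu"; // default when nothing is specified
        }
        switch (arg.toLowerCase()) {
            case "gpu":
                return "gpu";
            case "mps":
                return "mps"; // PyTorch engine only; Apple silicon
            default:
                return "cpu"; // unrecognized values fall back to CPU
        }
    }

    public static void main(String[] args) {
        System.out.println(parseDevice(args.length > 0 ? args[0] : null));
    }
}
```

With this shape, enabling MPS is a one-flag change for the end user rather than a code change.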
Codecov Report: Base 72.08% // Head 69.53% // merging this PR decreases project coverage by 2.56%.

```
@@             Coverage Diff              @@
##             master    #2037      +/-   ##
============================================
- Coverage     72.08%   69.53%    -2.56%
- Complexity     5126     5953     +827
============================================
  Files           473      597     +124
  Lines         21970    26498    +4528
  Branches       2351     2880     +529
============================================
+ Hits          15838    18426    +2588
- Misses         4925     6687    +1762
- Partials       1207     1385     +178
```
☔ View full report at Codecov.
Description
PyTorch 1.12 supports Apple's Metal Performance Shaders (MPS) for accelerated training/inference. A simple test on HuggingFace models shows 2x-3x improvements in inference latency over the CPU cores on an Apple M1 Pro with 8 CPU and 14 GPU cores. (Benchmark chart omitted.)
This improvement has been requested by me in the DJL help Slack channel, as well as more recently in #2018.
Brief description of what this PR is about
Some caveats:
- torch::jit::load(path, device, map) is called with device = torch::nullopt, since in torch the model gets deserialized in 'legacy' mode, which only supports CPU and GPU devices. The model gets deserialized on the CPU and can then be converted to MPS. For this, option.mapLocation=false should be set in the "serving.properties" file in the traced torch model's directory, which triggers an implementation already present in the DJL native module. Alternatively, the following check can be added to the pytorch-native module:
djl/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc
Lines 49 to 52 in f097502
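For the serving.properties route described in the first caveat, a minimal example of the file placed in the traced model's directory; the engine line is included for context and the values are illustrative:

```properties
# serving.properties in the traced torch model's directory (illustrative)
engine=PyTorch
option.mapLocation=false
```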
- PYTORCH_ENABLE_MPS_FALLBACK=1 needs to be set as a workaround for a PyTorch implementation deficiency of 'aten::index.Tensor' for MPS; see General MPS op coverage tracking issue pytorch/pytorch#77764.
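The fallback variable from the last caveat is set in the environment before launching the application; a minimal example:

```shell
# Enable CPU fallback for ops not yet implemented on the MPS backend
export PYTORCH_ENABLE_MPS_FALLBACK=1
echo "$PYTORCH_ENABLE_MPS_FALLBACK"
```

With the fallback enabled, unsupported ops such as 'aten::index.Tensor' run on the CPU instead of failing, at some performance cost.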