-
Notifications
You must be signed in to change notification settings - Fork 515
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MetalPerformanceShadersGraph Bindings (#14303)
I'm very pleased to present full bindings to the MetalPerformanceShadersGraph framework! I'm happy with how everything turned out with the exception of a few notes and questions below. I re-implemented Apple's MNIST sample (from https://developer.apple.com/documentation/metalperformanceshadersgraph/training_a_neural_network_using_mps_graph) here: https://gist.github.com/praeclarum/b8077771fb341a1f9c28240113e00425 It's also added as a unit test. Fixes #14286 ### Notes * Although the API says it works on macOS 11, it has bugs and crashes with errors even with Apple’s Swift examples. It’s better on macOS 12. iOS 14 and on is fine. * `MPSGraphSparseStorageType` has terrible names. They match Apple's but I wish they were better. * I added convenience methods to `MPSNDArray` and `MPSGrapTensorData` and the `Variable` and `Constant` operations to decrease the amount of unsafe code users have to write. I currently do this for 32-bit floats, the most common data type. Co-authored-by: Alex Soto <[email protected]> Co-authored-by: Rolf Bjarne Kvinge <[email protected]> Co-authored-by: Manuel de la Pena <[email protected]>
- Loading branch information
1 parent
8cf0231
commit bd4fee0
Showing
26 changed files
with
2,463 additions
and
3,215 deletions.
There are no files selected for viewing
Validating CODEOWNERS rules …
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
using System; | ||
using System.Runtime.InteropServices; | ||
|
||
using Foundation; | ||
using ObjCRuntime; | ||
using Metal; | ||
|
||
namespace MetalPerformanceShadersGraph | ||
{ | ||
[Flags] | ||
public enum MPSGraphOptions : ulong | ||
{ | ||
None = 0, | ||
SynchronizeResults = 1, | ||
Verbose = 2, | ||
Default = SynchronizeResults, | ||
} | ||
|
||
[Native] | ||
public enum MPSGraphTensorNamedDataLayout : ulong | ||
{ | ||
Nchw = 0, | ||
Nhwc = 1, | ||
Oihw = 2, | ||
Hwio = 3, | ||
Chw = 4, | ||
Hwc = 5, | ||
Hw = 6, | ||
} | ||
|
||
[Native] | ||
public enum MPSGraphPaddingStyle : ulong | ||
{ | ||
Explicit = 0, | ||
Valid = 1, | ||
Same = 2, | ||
ExplicitOffset = 3, | ||
} | ||
|
||
[Native] | ||
public enum MPSGraphPaddingMode : long | ||
{ | ||
Constant = 0, | ||
Reflect = 1, | ||
Symmetric = 2, | ||
ClampToEdge = 3, | ||
Zero = 4, | ||
Periodic = 5, | ||
AntiPeriodic = 6, | ||
} | ||
|
||
[Native] | ||
public enum MPSGraphReductionMode : ulong | ||
{ | ||
Min = 0, | ||
Max = 1, | ||
Sum = 2, | ||
Product = 3, | ||
ArgumentMin = 4, | ||
ArgumentMax = 5, | ||
} | ||
|
||
[Native] | ||
public enum MPSGraphResizeMode : ulong | ||
{ | ||
Nearest = 0, | ||
Bilinear = 1, | ||
} | ||
|
||
[Native] | ||
public enum MPSGraphScatterMode : long | ||
{ | ||
Add = 0, | ||
Sub = 1, | ||
Mul = 2, | ||
Div = 3, | ||
Min = 4, | ||
Max = 5, | ||
Set = 6, | ||
} | ||
|
||
public enum MPSGraphDeviceType : uint | ||
{ | ||
Metal = 0, | ||
} | ||
|
||
public enum MPSGraphLossReductionType : ulong | ||
{ | ||
Axis = 0, | ||
Sum = 1, | ||
Mean = 2, | ||
} | ||
|
||
// For COO, indexTensor0 is x index and indexTensor1 is y index | ||
// For CSC, indexTensor0 and indexTensor1 correspond to rowIndex and colStarts respectively. | ||
// For CSR, indexTensor0 and indexTensor1 correspond to colIndex and rowStarts respectively. | ||
public enum MPSGraphSparseStorageType : ulong | ||
{ | ||
Coo = 0, | ||
Csc = 1, | ||
Csr = 2, | ||
} | ||
|
||
public enum MPSGraphRandomDistribution : ulong | ||
{ | ||
Uniform = 0, | ||
Normal = 1, | ||
TruncatedNormal = 2, | ||
} | ||
|
||
public enum MPSGraphRandomNormalSamplingMethod : ulong | ||
{ | ||
InvCdf = 0, | ||
BoxMuller = 1, | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#nullable enable | ||
|
||
using System; | ||
using System.Buffers; | ||
using System.Runtime.InteropServices; | ||
|
||
using Foundation; | ||
using ObjCRuntime; | ||
using Metal; | ||
using MetalPerformanceShaders; | ||
|
||
namespace MetalPerformanceShadersGraph | ||
{ | ||
public static partial class MPSGraphMemoryOps_Extensions | ||
{ | ||
public static unsafe MPSGraphTensor Constant (this MPSGraph graph, float scalar) | ||
{ | ||
return graph.Constant ((double) scalar, new [] { 1 }, MPSDataType.Float32); | ||
} | ||
|
||
public static unsafe MPSGraphTensor Constant (this MPSGraph graph, ReadOnlySpan<float> values, int[] shape) | ||
{ | ||
var length = 1; | ||
for (var i = 0; i < shape.Length; i++) | ||
length *= shape [i]; | ||
if (length != values.Length) | ||
throw new ArgumentException ($"The number of values ({values.Length}) does not match the shape length ({length})."); | ||
fixed (float* p = values) { | ||
using var data = NSData.FromBytesNoCopy ((IntPtr) p, (nuint) (values.Length * 4), freeWhenDone: false); | ||
return graph.Constant (data, shape, MPSDataType.Float32); | ||
} | ||
} | ||
|
||
public static MPSGraphTensor Variable (this MPSGraph graph, float initialValue, int[] shape, string? name = null) | ||
{ | ||
var length = 1; | ||
for (var i = 0; i < shape.Length; i++) | ||
length *= shape [i]; | ||
var pool = ArrayPool<float>.Shared; | ||
var a = pool.Rent (length); | ||
Array.Fill (a, initialValue); | ||
var v = Variable (graph, a, shape, name); | ||
pool.Return (a); | ||
return v; | ||
} | ||
|
||
public static unsafe MPSGraphTensor Variable (this MPSGraph graph, ReadOnlySpan<float> initialValues, int[] shape, string? name = null) | ||
{ | ||
var length = 1; | ||
for (var i = 0; i < shape.Length; i++) | ||
length *= shape [i]; | ||
if (length != initialValues.Length) | ||
throw new ArgumentException ($"The number of initial values ({initialValues.Length}) does not match the shape length ({length})."); | ||
fixed (float* p = initialValues) { | ||
using var data = NSData.FromBytesNoCopy ((IntPtr) p, (nuint) (initialValues.Length * 4), freeWhenDone: false); | ||
return graph.Variable (data, shape, MPSDataType.Float32, name); | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#nullable enable | ||
|
||
using System; | ||
using System.Buffers; | ||
using System.Runtime.InteropServices; | ||
|
||
using Foundation; | ||
using ObjCRuntime; | ||
using Metal; | ||
using MetalPerformanceShaders; | ||
|
||
namespace MetalPerformanceShadersGraph | ||
{ | ||
public partial class MPSGraphTensorData | ||
{ | ||
public static MPSGraphTensorData Create (IMTLDevice device, ReadOnlySpan<float> values, params int[] shape) | ||
{ | ||
var ndarray = MPSNDArray.Create (device, values, shape); | ||
return new MPSGraphTensorData (ndarray); | ||
} | ||
|
||
public static MPSGraphTensorData Create (params MPSImage[] imageBatch) | ||
{ | ||
return new MPSGraphTensorData (NSArray<MPSImage>.FromNSObjects (imageBatch)); | ||
} | ||
|
||
public void Read (Span<float> values) | ||
{ | ||
using var ndarray = this.MPSNDArray; | ||
ndarray.Read (values); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
This comment was marked as outdated.
Sorry, something went wrong.
This comment was marked as outdated.
Sorry, something went wrong.
This comment was marked as outdated.
Sorry, something went wrong.
This comment was marked as outdated.
Sorry, something went wrong.
bd4fee0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❌ [CI Build] Tests failed on VSTS: simulator tests iOS ❌
Tests failed on VSTS: simulator tests iOS.
Test results
32 tests failed, 202 tests passed.
Failed tests
Tests run: 117 Passed: 107 Inconclusive: 0 Failed: 2 Ignored: 8)
Tests run: 117 Passed: 107 Inconclusive: 0 Failed: 1 Ignored: 9)
Tests run: 129 Passed: 119 Inconclusive: 0 Failed: 2 Ignored: 8)
Tests run: 129 Passed: 118 Inconclusive: 0 Failed: 2 Ignored: 9)
Tests run: 9 Passed: 8 Inconclusive: 0 Failed: 1 Ignored: 0)
Tests run: 9 Passed: 8 Inconclusive: 0 Failed: 1 Ignored: 0)
Tests run: 117 Passed: 107 Inconclusive: 0 Failed: 2 Ignored: 8)
Tests run: 117 Passed: 106 Inconclusive: 0 Failed: 2 Ignored: 9)
Tests run: 129 Passed: 120 Inconclusive: 0 Failed: 1 Ignored: 8)
Tests run: 129 Passed: 118 Inconclusive: 0 Failed: 2 Ignored: 9)
Tests run: 20 Passed: 18 Inconclusive: 0 Failed: 1 Ignored: 1)
Tests run: 20 Passed: 18 Inconclusive: 0 Failed: 1 Ignored: 1)
Pipeline on Agent XAMBOT-1098.Monterey
Add MetalPerformanceShadersGraph Bindings (#14303)