forked from explosion/thinc
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This change adds `AppleOps` to Thinc, to ensure that the AMX unit is always used on Apple Silicon Macs. Before this change, a user would get much worse performance if they forgot to install `thinc-apple-ops`. The `apple_ops` and `_accelerate` modules are built conditionally. When detecting the best CPU implementation, we rely on a `try...except` import to determine whether Apple ops are available. Even though x86_64 Macs do not have an AMX unit, Accelerate is competitive with BLIS, so it does not hurt to enable Apple ops on all Macs.
- Loading branch information
Showing
13 changed files
with
281 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Declarations for the subset of the CBLAS interface exposed by Apple's
# Accelerate framework. Pointer arguments that a routine only reads are
# declared ``const`` so that callers holding ``const float*`` buffers can pass
# them without discarding the qualifier; output / in-place buffers stay
# mutable. Every routine is ``nogil`` so it may be called with the GIL released.
cdef extern from "Accelerate/Accelerate.h":
    enum CBLAS_ORDER:     CblasRowMajor, CblasColMajor
    enum CBLAS_TRANSPOSE: CblasNoTrans, CblasTrans, CblasConjTrans
    enum CBLAS_UPLO:      CblasUpper, CblasLower
    enum CBLAS_DIAG:      CblasNonUnit, CblasUnit
    enum CBLAS_SIDE:      CblasLeft, CblasRight

    # BLAS level 1 routines (vector-vector)

    # x and y are both overwritten (swapped).
    void cblas_sswap(int M, float *x, int incX, float *y, int incY) nogil
    # x is scaled in place.
    void cblas_sscal(int N, float alpha, float *x, int incX) nogil
    # x is read, y is written.
    void cblas_scopy(int N, const float *x, int incX, float *y, int incY) nogil
    # x is read, y is updated in place.
    void cblas_saxpy(int N, float alpha, const float *x, int incX,
                     float *y, int incY) nogil
    # Reductions: all vector arguments are read-only.
    float cblas_sdot(int N, const float *x, int incX,
                     const float *y, int incY) nogil
    float cblas_snrm2(int N, const float *x, int incX) nogil
    float cblas_sasum(int N, const float *x, int incX) nogil
    int cblas_isamax(int N, const float *x, int incX) nogil

    # BLAS level 2 routines (matrix-vector)

    # A and x are read, y is updated in place.
    void cblas_sgemv(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA, int M, int N,
                     float alpha, const float *A, int lda,
                     const float *x, int incX,
                     float beta, float *y, int incY) nogil

    # x and y are read, A is updated in place.
    void cblas_sger(CBLAS_ORDER Order, int M, int N, float alpha,
                    const float *x, int incX, const float *y, int incY,
                    float *A, int lda) nogil

    # BLAS level 3 routines (matrix-matrix)

    # A and B are read, C is updated in place.
    void cblas_sgemm(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA,
                     CBLAS_TRANSPOSE TransB, int M, int N, int K,
                     float alpha, const float *A, int lda,
                     const float *B, int ldb,
                     float beta, float *C, int ldc) nogil
|
||
|
||
# Prototype of the GIL-free single-precision matrix multiply wrapper that is
# implemented in the accompanying .pyx file. The implementation always passes
# CblasRowMajor to cblas_sgemm and maps the two boolean flags onto
# CblasTrans / CblasNoTrans.
cdef void sgemm(
    bint TransA, bint TransB,
    int M, int N, int K,
    float alpha,
    const float* A, int lda,
    const float *B, int ldb,
    float beta,
    float* C, int ldc
) nogil


# Prototype of the GIL-free wrapper that forwards to cblas_saxpy.
cdef void saxpy(
    int N, float alpha,
    const float* X, int incX,
    float *Y, int incY
) nogil
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
cimport numpy as np | ||
from libc.stdint cimport uintptr_t | ||
|
||
import numpy | ||
|
||
|
||
cpdef np.ndarray gemm(float[:, ::1] A, float[:, ::1] B,
                      bint trans1=False, bint trans2=False,
                      np.ndarray out=None):
    """Multiply two C-contiguous float32 matrices with Accelerate's sgemm.

    Computes ``op(A) @ op(B)`` where ``op(X)`` is ``X.T`` when the matching
    ``trans`` flag is set and ``X`` otherwise.

    A, B: C-contiguous 2d float32 buffers.
    trans1: Use the transpose of ``A``.
    trans2: Use the transpose of ``B``.
    out: Optional C-contiguous float32 array of shape ``(M, N)`` that receives
        the result. A new array is allocated when it is ``None``.
    RETURNS: The array holding the product (``out`` when it was provided).
    RAISES: ValueError when ``out`` has the wrong shape or when the inner
        dimensions of ``op(A)`` and ``op(B)`` differ.
    """
    # Logical dimensions after applying the requested transpositions:
    # op(A) is (nM, nK) and op(B) is (nK_b, nN).
    cdef int nM = A.shape[0] if not trans1 else A.shape[1]
    cdef int nK = A.shape[1] if not trans1 else A.shape[0]
    cdef int nK_b = B.shape[0] if not trans2 else B.shape[1]
    cdef int nN = B.shape[1] if not trans2 else B.shape[0]

    # Typed view on the output buffer; rebound below when we allocate.
    cdef float[:, ::1] C = out

    if out is None:
        out = numpy.empty((nM, nN), dtype="f")
        C = out
    else:
        if C.shape[0] != nM or C.shape[1] != nN:
            msg = "Shape mismatch for output matrix, was: (%d, %d), expected (%d, %d)"
            raise ValueError(msg % (C.shape[0], C.shape[1], nM, nN))

    if nK != nK_b:
        msg = "Shape mismatch for gemm: (%d, %d), (%d, %d)"
        raise ValueError(msg % (nM, nK, nK_b, nN))

    # An empty result has nothing to compute or to initialise.
    if nM == 0 or nN == 0:
        return out

    # With an empty inner dimension the product is an all-zero (nM, nN)
    # matrix. The BLAS call is skipped here (a leading dimension taken from a
    # zero-sized axis would be 0), so the output must be cleared explicitly;
    # otherwise the uninitialised memory from numpy.empty (or the stale
    # contents of a caller-supplied ``out``) would be returned.
    if nK == 0:
        C[:, :] = 0.0
        return out

    cblas_sgemm(
        CblasRowMajor,
        CblasTrans if trans1 else CblasNoTrans,
        CblasTrans if trans2 else CblasNoTrans,
        nM,
        nN,
        nK,
        1.0,
        &A[0, 0],
        A.shape[1],  # row-major leading dimension = stored row length
        &B[0, 0],
        B.shape[1],
        0.0,         # beta == 0: previous contents of C are ignored
        &C[0, 0],
        C.shape[1]
    )
    return out
|
||
|
||
cdef void sgemm(bint TransA, bint TransB, int M, int N, int K,
                float alpha, const float* A, int lda, const float *B,
                int ldb, float beta, float* C, int ldc) nogil:
    # GIL-free adapter around cblas_sgemm: the storage order is fixed to
    # row-major and the boolean transpose flags are translated into the
    # CBLAS_TRANSPOSE enum; every other argument is forwarded unchanged.
    cdef CBLAS_TRANSPOSE op_a = CblasNoTrans
    cdef CBLAS_TRANSPOSE op_b = CblasNoTrans

    if TransA:
        op_a = CblasTrans
    if TransB:
        op_b = CblasTrans

    cblas_sgemm(CblasRowMajor, op_a, op_b, M, N, K,
                alpha, A, lda, B, ldb, beta, C, ldc)
|
||
|
||
cdef void saxpy(int N, float alpha, const float* X, int incX,
                float *Y, int incY) nogil:
    # GIL-free adapter: hands every argument to Accelerate's cblas_saxpy
    # unchanged.
    cblas_saxpy(
        N,
        alpha,
        X, incX,
        Y, incY,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import Optional | ||
|
||
import numpy | ||
|
||
from ._accelerate import gemm | ||
|
||
from ._accelerate cimport saxpy, sgemm | ||
from .cblas cimport CBlas, set_saxpy, set_sgemm | ||
|
||
from .. import registry | ||
from ..types import Floats2d | ||
from .numpy_ops import NumpyOps | ||
|
||
|
||
@registry.ops("AppleOps")
class AppleOps(NumpyOps):
    """Ops backend that routes selected operations through Apple's native
    libraries. Anything not overridden here is inherited from NumpyOps."""

    name = "apple"
    xp = numpy

    def cblas(self) -> CBlas:
        """Build a CBlas object whose sgemm and saxpy entries are the
        wrappers cimported from the _accelerate module.
        """
        cdef CBlas accelerate_blas = CBlas()
        set_sgemm(accelerate_blas, sgemm)
        set_saxpy(accelerate_blas, saxpy)
        return accelerate_blas

    def gemm(
        self,
        x: Floats2d,
        y: Floats2d,
        out: Optional[Floats2d] = None,
        trans1: bool = False,
        trans2: bool = False,
    ) -> Floats2d:
        """Matrix-multiply x and y via the _accelerate gemm routine.

        trans1 / trans2 request the transpose of x / y respectively; when
        out is given it is passed through as the output array.
        """
        product = gemm(x, y, trans1=trans1, trans2=trans2, out=out)
        return product
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import numpy | ||
import pytest | ||
|
||
from thinc.compat import has_apple_ops | ||
|
||
try:
    import thinc.backends._accelerate as accelerate
except ImportError:
    # The extension is only built on some platforms. Every test below is
    # skipped via ``has_apple_ops`` when it is missing, so a placeholder is
    # enough. Only ImportError is caught so that unrelated failures (and
    # KeyboardInterrupt / SystemExit) are not silently swallowed.
    accelerate = None
|
||
|
||
@pytest.mark.skipif(not has_apple_ops, reason="Apple ops not available")
def test_basic_sgemm():
    """gemm returns the correct product, with and without a preallocated out."""
    A = numpy.random.randn(5, 4).astype("f")
    B = numpy.random.randn(4, 7).astype("f")
    expected = numpy.dot(A, B)

    C = accelerate.gemm(A, B)
    assert C.shape == (A.shape[0], B.shape[1])
    # Compare against an independent reference; checking only that two gemm
    # calls agree with each other would not catch a consistently wrong result.
    # The tolerance allows for float32 rounding / accumulation-order effects.
    numpy.testing.assert_allclose(C, expected, rtol=1e-4, atol=1e-5)

    C_out = numpy.empty((5, 7), dtype="f")
    accelerate.gemm(A, B, out=C_out)

    numpy.testing.assert_allclose(C, C_out)
|
||
|
||
@pytest.mark.skipif(not has_apple_ops, reason="Apple ops not available")
def test_incorrect_output_size():
    """A wrongly shaped ``out`` array is rejected with a ValueError."""
    lhs = numpy.ndarray((5, 4), dtype="f")
    rhs = numpy.ndarray((4, 7), dtype="f")

    # The correct output shape is (5, 7): first a wrong row count, then a
    # wrong column count.
    for bad_shape in ((3, 7), (5, 3)):
        bad_out = numpy.ndarray(bad_shape, dtype="f")
        with pytest.raises(ValueError, match=r"Shape mismatch for output matrix"):
            accelerate.gemm(lhs, rhs, out=bad_out)
|
||
|
||
@pytest.mark.skipif(not has_apple_ops, reason="Apple ops not available")
@pytest.mark.parametrize(
    "A_shape,B_shape,transA,transB",
    [
        [(0, 0), (0, 0), False, False],
        [(0, 0), (0, 0), True, False],
        [(0, 0), (0, 0), False, True],
        [(0, 0), (0, 0), True, True],
        [(0, 5), (5, 0), False, False],
        [(5, 0), (5, 0), False, True],
        [(5, 0), (5, 0), True, False],
    ],
)
def test_zero_size(A_shape, B_shape, transA, transB):
    """gemm on inputs with a zero-sized axis yields the same shape as numpy."""
    A = numpy.ndarray(A_shape, dtype="f")
    B = numpy.ndarray(B_shape, dtype="f")
    # Apply each transpose flag independently. The previous if/elif chain
    # tested ``transA`` before the both-transposed case, so ``A.T @ B.T`` was
    # unreachable and ``A.T @ B`` was used as the reference instead.
    lhs = A.T if transA else A
    rhs = B.T if transB else B
    C = numpy.dot(lhs, rhs)
    C_ = accelerate.gemm(A, B, trans1=transA, trans2=transB)
    assert C.shape == C_.shape
assert C.shape == C_.shape | ||
|
||
|
||
@pytest.mark.skipif(not has_apple_ops, reason="Apple ops not available")
@pytest.mark.parametrize(
    "A_shape,B_shape,transA,transB",
    [
        ((4, 5), (4, 5), False, False),
        ((5, 4), (4, 5), True, False),
        ((4, 5), (5, 4), False, True),
        ((5, 4), (5, 4), True, True),
    ],
)
def test_incorrect_shapes(A_shape, B_shape, transA, transB):
    """Operands whose inner dimensions disagree are rejected with a ValueError."""
    lhs = numpy.ndarray(A_shape, dtype="f")
    rhs = numpy.ndarray(B_shape, dtype="f")
    flags = {"trans1": transA, "trans2": transB}
    with pytest.raises(ValueError, match=r"Shape mismatch"):
        accelerate.gemm(lhs, rhs, **flags)
Oops, something went wrong.