forked from explosion/thinc
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
168 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Declarations for the subset of the CBLAS interface shipped with Apple's
# Accelerate framework that this package calls, plus the prototypes of the
# thin wrappers implemented in the accompanying .pyx file.
#
# Pointers to operands that the routines only read are declared ``const`` so
# that the wrappers below (which accept ``const float*`` inputs) can forward
# their arguments without casts. Callers holding plain ``float*`` pointers are
# unaffected, since ``float*`` converts implicitly to ``const float*``.
cdef extern from "Accelerate/Accelerate.h":
    enum CBLAS_ORDER: CblasRowMajor, CblasColMajor
    enum CBLAS_TRANSPOSE: CblasNoTrans, CblasTrans, CblasConjTrans
    enum CBLAS_UPLO: CblasUpper, CblasLower
    enum CBLAS_DIAG: CblasNonUnit, CblasUnit
    enum CBLAS_SIDE: CblasLeft, CblasRight

    # BLAS level 1 routines (vector-vector)

    # Both vectors are written by swap, so neither is const.
    void cblas_sswap(int M, float *x, int incX, float *y, int incY) nogil
    # x is scaled in place.
    void cblas_sscal(int N, float alpha, float *x, int incX) nogil
    void cblas_scopy(int N, const float *x, int incX, float *y, int incY) nogil
    void cblas_saxpy(int N, float alpha, const float *x, int incX,
                     float *y, int incY) nogil
    float cblas_sdot(int N, const float *x, int incX,
                     const float *y, int incY) nogil
    float cblas_snrm2(int N, const float *x, int incX) nogil
    float cblas_sasum(int N, const float *x, int incX) nogil
    int cblas_isamax(int N, const float *x, int incX) nogil

    # BLAS level 2 routines (matrix-vector)
    void cblas_sgemv(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA, int M, int N,
                     float alpha, const float *A, int lda,
                     const float *x, int incX,
                     float beta, float *y, int incY) nogil

    # A is updated in place by the rank-1 update, so it stays mutable.
    void cblas_sger(CBLAS_ORDER Order, int M, int N, float alpha,
                    const float *x, int incX, const float *y, int incY,
                    float *A, int lda) nogil

    # BLAS level 3 routines (matrix-matrix)
    void cblas_sgemm(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA,
                     CBLAS_TRANSPOSE TransB, int M, int N, int K,
                     float alpha, const float *A, int lda,
                     const float *B, int ldb,
                     float beta, float *C, int ldc) nogil


# Row-major single-precision GEMM wrapper; the boolean flags select whether
# A and/or B are passed to CBLAS as transposed.
cdef void sgemm(bint TransA, bint TransB, int M, int N, int K,
                float alpha, const float* A, int lda, const float *B,
                int ldb, float beta, float* C, int ldc) nogil


# Single-precision AXPY wrapper around cblas_saxpy.
cdef void saxpy(int N, float alpha, const float* X, int incX,
                float *Y, int incY) nogil
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from libc.stdint cimport uintptr_t | ||
cimport numpy as np | ||
import numpy | ||
|
||
|
||
cpdef np.ndarray gemm(float[:, ::1] A, float[:, ::1] B,
                      bint trans1=False, bint trans2=False,
                      np.ndarray out=None):
    """Multiply two single-precision matrices with Accelerate's cblas_sgemm.

    Computes ``op(A) @ op(B)`` where ``op`` transposes its argument when the
    matching ``trans1`` / ``trans2`` flag is set. The product is written with
    alpha=1.0 and beta=0.0, i.e. any previous content of ``out`` is replaced.

    A, B: C-contiguous 2d float32 buffers.
    trans1: Treat ``A`` as transposed.
    trans2: Treat ``B`` as transposed.
    out: Optional array receiving the result. It is coerced to a
        ``float[:, ::1]`` memoryview and must have the shape of the product.
        A new array is allocated when it is None.

    RETURNS: ``out`` (either the caller's array or the newly allocated one).
    RAISES: ValueError if ``out`` has the wrong shape or if the inner
        dimensions of ``op(A)`` and ``op(B)`` disagree.
    """
    cdef int nM = A.shape[0] if not trans1 else A.shape[1]
    cdef int nK = A.shape[1] if not trans1 else A.shape[0]
    cdef int nK_b = B.shape[0] if not trans2 else B.shape[1]
    cdef int nN = B.shape[1] if not trans2 else B.shape[0]

    # Typed view on the output buffer. ``out`` may still be None at this
    # point; in that case the view is rebound after allocation below.
    cdef float[:, ::1] C = out

    if out is None:
        out = numpy.empty((nM, nN), dtype="f")
        C = out
    else:
        if C.shape[0] != nM or C.shape[1] != nN:
            msg = "Shape mismatch for output matrix, was: (%d, %d), expected (%d, %d)"
            raise ValueError(msg % (C.shape[0], C.shape[1], nM, nN))

    if nK != nK_b:
        msg = "Shape mismatch for gemm: (%d, %d), (%d, %d)"
        raise ValueError(msg % (nM, nK, nK_b, nN))

    # Nothing to write when the output itself has no elements. This also
    # avoids taking the address of element [0, 0] of an empty buffer.
    if nM == 0 or nN == 0:
        return out

    # With an empty inner dimension every output element is an empty sum.
    # Since beta is 0.0 the result is all zeros, so it must be written
    # explicitly: ``out`` may be uninitialised memory from numpy.empty or
    # hold stale values from the caller.
    if nK == 0:
        out.fill(0)
        return out

    cblas_sgemm(
        CblasRowMajor,
        CblasTrans if trans1 else CblasNoTrans,
        CblasTrans if trans2 else CblasNoTrans,
        nM,
        nN,
        nK,
        1.0,
        &A[0, 0],
        # Leading dimensions are the row lengths of the buffers as stored,
        # independent of the transpose flags.
        A.shape[1],
        &B[0, 0],
        B.shape[1],
        0.0,
        &C[0, 0],
        C.shape[1]
    )
    return out
|
||
|
||
# Row-major SGEMM entry point with a boolean transpose interface.
#
# Translates the two boolean transpose flags into CBLAS_TRANSPOSE values and
# forwards every other argument unchanged to cblas_sgemm, always using
# CblasRowMajor storage order.
cdef void sgemm(bint TransA, bint TransB, int M, int N, int K,
                float alpha, const float* A, int lda, const float *B,
                int ldb, float beta, float* C, int ldc) nogil:
    cdef CBLAS_TRANSPOSE op_a = CblasTrans if TransA else CblasNoTrans
    cdef CBLAS_TRANSPOSE op_b = CblasTrans if TransB else CblasNoTrans
    cblas_sgemm(CblasRowMajor, op_a, op_b,
                M, N, K,
                alpha,
                A, lda,
                B, ldb,
                beta,
                C, ldc)
|
||
|
||
# SAXPY entry point: forwards all arguments, in order, to cblas_saxpy.
cdef void saxpy(int N, float alpha, const float* X, int incX,
                float *Y, int incY) nogil:
    cblas_saxpy(
        N,
        alpha,
        X,
        incX,
        Y,
        incY,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Optional | ||
import numpy | ||
|
||
from . import apple_blas | ||
from .apple_blas cimport saxpy, sgemm | ||
from .cblas cimport CBlas, set_saxpy, set_sgemm | ||
from .numpy_ops import NumpyOps | ||
from ..types import Floats2d | ||
from .. import registry | ||
|
||
|
||
@registry.ops("AppleOps")
class AppleOps(NumpyOps):
    """Ops implementation backed by Apple's native libraries.

    Only the operations defined here are routed to the Apple BLAS wrappers;
    everything else is inherited from NumpyOps.
    """

    name = "apple"
    xp = numpy

    def cblas(self) -> CBlas:
        """Return a CBlas object whose saxpy and sgemm slots are set to the
        wrappers cimported from apple_blas."""
        cdef CBlas blas_table = CBlas()
        set_saxpy(blas_table, saxpy)
        set_sgemm(blas_table, sgemm)
        return blas_table

    def gemm(
        self,
        x: Floats2d,
        y: Floats2d,
        out: Optional[Floats2d] = None,
        trans1: bool = False,
        trans2: bool = False,
    ) -> Floats2d:
        """Compute a General Matrix Multiplication (GeMM) of x and y by
        delegating to apple_blas.gemm.

        out: Optional array that receives the result.
        trans1: Forwarded to apple_blas.gemm as the transpose flag for x.
        trans2: Forwarded to apple_blas.gemm as the transpose flag for y.
        """
        result = apple_blas.gemm(
            x,
            y,
            out=out,
            trans1=trans1,
            trans2=trans2,
        )
        return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters