Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start cuBLAS backend support. #18

Merged
merged 11 commits into from
Oct 7, 2018
195 changes: 152 additions & 43 deletions src/main/scala/lantern/ad_lms_vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -190,22 +190,47 @@ trait TensorExp extends Dsl with Diff {
} */

/**
* Defines tensor-specific operations.
* Eventually, a tensor operation IR may be introduced to enable analyses/transformations.
* A code generation backend for tensor operations.
*
* Note: Eventually, a tensor operation IR may be introduced to enable analyses and
* transformations such as operator fusion and matrix chain multiplication optimization.
*/
trait Backend {
def dot(x: Tensor, y: Tensor): Tensor
// TODO: Add more ops.
// Compute vector-vector dot product, i.e. inner product.
// [V] dot [V] => [1] (scalar)
def vectorVectorDot(x: Tensor, y: Tensor): Tensor

// Compute matrix-vector dot product.
// [M1 x M2] dot [M2] => [M1]
def matrixVectorDot(x: Tensor, y: Tensor): Tensor

// Compute matrix-matrix dot product.
// [M1 x M2] dot [M2 x M3] => [M1 x M3]
def matrixMatrixDot(x: Tensor, y: Tensor): Tensor

def dot(x: Tensor, y: Tensor): Tensor =
(x.rank, y.rank) match {
case (1, 1) => vectorVectorDot(x, y)
case (2, 1) => matrixVectorDot(x, y)
case (2, 2) => matrixMatrixDot(x, y)
case _ => throw new IllegalArgumentException(s"Incompatible shapes: ${x.shape}, ${y.shape}")
}

// TODO: Add more ops:
// - Elementwise binary ops (+, -, *, /).
// - GPU backends need to address broadcasting.
// - `BackendCublas` can define addition using `cublasSaxpy`.
// - Conv2d.
// - Activation functions (e.g. relu).
// - Fused multiply add operations?
}

/**
* Native tensor op backend.
* Native tensor operation backend. WIP.
* Tensor ops are defined in terms of primitive operations.
*/
trait BackendNative extends Backend {
// Compute vector-vector dot product, i.e. inner product.
// [V] dot [V] => [1] (scalar)
private def vvdot(x: Tensor, y: Tensor): Tensor = {
class BackendNative extends Backend {
override def vectorVectorDot(x: Tensor, y: Tensor): Tensor = {
assert(x.shape(0) == y.shape(0))
val value = var_new(0.0f)
for (i <- DataLoop(x.shape.last)) {
Expand All @@ -216,9 +241,7 @@ trait TensorExp extends Dsl with Diff {
Tensor(res, 1)
}

// Compute matrix-vector dot product.
// [M1 x M2] dot [M2] => [M1]
private def mvdot(x: Tensor, y: Tensor): Tensor = {
override def matrixVectorDot(x: Tensor, y: Tensor): Tensor = {
assert(x.shape(1) == y.shape(0))
val dim1 = x.shape(0)
val dim2 = x.shape(1)
Expand All @@ -233,9 +256,7 @@ trait TensorExp extends Dsl with Diff {
Tensor(res, dim1)
}

// Compute matrix-matrix dot product.
// [M1 x M2] dot [M2 x M3] => [M1 x M3]
private def mmdot(x: Tensor, y: Tensor): Tensor = {
override def matrixMatrixDot(x: Tensor, y: Tensor): Tensor = {
assert(x.shape(1) == y.shape(0))
val dim1 = x.shape(0)
val dim2 = x.shape(1)
Expand All @@ -252,41 +273,129 @@ trait TensorExp extends Dsl with Diff {
}
Tensor(res, dim1, dim3)
}

override def dot(x: Tensor, y: Tensor): Tensor =
(x.rank, y.rank) match {
case (1, 1) => vvdot(x, y)
case (2, 1) => mvdot(x, y)
case (2, 2) => mmdot(x, y)
case _ => throw new IllegalArgumentException(s"Incompatible shapes: ${x.shape}, ${y.shape}")
}
}

/**
* cuBLAS tensor op backend. WIP.
* cuBLAS tensor operation backend. WIP.
*/
trait BackendCUBLAS extends Backend {
// GEMM reference:
class BackendCublas extends Backend {
// Reference:
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-dot
def sdot(a: Rep[Array[Float]], b: Rep[Array[Float]], result: Rep[Array[Float]]) =
unchecked[Unit]("CUBLAS_CALL(cublasSdot(handle, ", a.length, ",", a, ",1,", b, ",1,", result, "))")

override def vectorVectorDot(x: Tensor, y: Tensor): Tensor = {
val res = NewArray[Float](1)
sdot(x.data, y.data, res)
Tensor(res, 1)
}

// Reference:
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemv
def sgemv(m: Int, n: Int, a: Rep[Array[Float]], b: Rep[Array[Float]], result: Rep[Array[Float]]) = {
val zero = NewArray[Float](1); zero(0) = 0
val one = NewArray[Float](1); one(0) = 1
unchecked[Unit](
"CUBLAS_CALL(cublasSgemv(handle, CUBLAS_OP_N, ",
m, ",", n, ",", one, ",",
a, ",", m, ",", b, ",", zero, ",", result, ",", one, "))")
}

override def matrixVectorDot(x: Tensor, y: Tensor): Tensor = {
val m = x.shape(0)
val n = x.shape(1)
val res = NewArray[Float](m)
sgemv(m, n, x.data, y.data, res)
Tensor(res, m)
}

// Reference:
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemm
//
// cublasStatus_t cublasSgemm(cublasHandle_t handle,
// cublasOperation_t transa, cublasOperation_t transb,
// int m, int n, int k,
// const float *alpha,
// const float *A, int lda,
// const float *B, int ldb,
// const float *beta,
// float *C, int ldc)
def sgemm(a: Array[Float], b: Array[Float], c: Array[Float]) = unchecked[Array[Float]]("cublasSgemm(...)")

override def dot(x: Tensor, y: Tensor): Tensor = ???
def sgemm(m: Int, n: Int, k: Int, a: Rep[Array[Float]], b: Rep[Array[Float]], result: Rep[Array[Float]]) = {
val zero = NewArray[Float](1); zero(0) = 0
val one = NewArray[Float](1); one(0) = 1
unchecked[Unit](
"CUBLAS_CALL(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, ",
m, ",", n, ",", k, ",", one, ",",
a, ",", m, ",", b, ",", k, ",", zero, ",", result, ",", m, "))")
}

override def matrixMatrixDot(x: Tensor, y: Tensor): Tensor = {
val m = x.shape(0)
val n = y.shape(1)
val k = y.shape(0)
val res = NewArray[Float](m * n)
sgemm(m, n, k, x.data, y.data, res)
Tensor(res, m, n)
}
}

/**
* Default tensor op backend, extending `BackendNative`.
* cuDNN tensor operation backend. WIP.
*/
class BackendDefault extends BackendNative
val backend: Backend = new BackendDefault
class BackendCudnn extends Backend {
override def vectorVectorDot(x: Tensor, y: Tensor): Tensor = ???
override def matrixVectorDot(x: Tensor, y: Tensor): Tensor = ???
override def matrixMatrixDot(x: Tensor, y: Tensor): Tensor = ???
}

// The current backend for code generation.
// To switch code generation to a different backend, simply change this value
// in your DSL program.
var backend: Backend = new BackendNative

/**
* Transfer data between backends.
* @param from The current backend.
* @param to The new backend.
* @param data The data to transfer.
* @tparam T Type of the data.
*/
def transfer[T](from: Backend, to: Backend)(data: T) {
// TODO: Implement logic. `cudaMemcpy` will be involved.
// Consider what to do when unified memory is used (i.e. `cudaMallocMananged`).
(from, to) match {
case (cpu: BackendNative, gpu: BackendCudnn) => ???
case (gpu: BackendCudnn, cpu: BackendNative) => ???
case _ => ???
}
}

/**
* Call a function with given inputs, generating code for the specified backend.
* The inputs and result will be transferred between backends automatically.
* @param b The new backend.
* @param input The input to the function.
* @param f The function to call, whose code will be generated on the new backend.
* @tparam T The function input type.
* @tparam U The function output type.
*/
def withBackend[T, U](b: Backend, input: T)(f: T => U) = {
val originalBackend = backend

// Transfer input to the new backend.
transfer(originalBackend, b)(input)

// Change the backend (i.e. codegen target), then call `f`.
backend = b
val result = f(input)

// Transfer `result` to the old backend, then reset the backend.
transfer(originalBackend, b)(result)
backend = originalBackend
}

/**
* Call a function with given inputs, generating code for CPU.
* The inputs and result will be transferred between backends automatically.
*/
def withCPU[T, U](input: T)(f: T => U) = withBackend(new BackendNative, input)(f)

/**
* Call a function with given inputs, generating code for GPU (via `BackendCudnn`).
* The inputs and result will be transferred between backends automatically.
*/
def withGPU[T, U](input: T)(f: T => U) = withBackend(new BackendCudnn, input)(f)

class Tensor(val data: Rep[Array[Float]], val dimensions: NSeq[Int]) extends Serializable {

Expand Down Expand Up @@ -500,8 +609,8 @@ trait TensorExp extends Dsl with Diff {

@virtualize
def sum2D(dim: Int) = {
assert (this.rank == 2, "Only deal with 2D tensor")
assert (dim == 0 || dim == 1, "dim must be in range of this.nbDims")
assert(this.rank == 2, "Only deal with 2D tensor")
assert(dim == 0 || dim == 1, "dim must be in range of this.nbDims")

if (dim == 0) ???
else {
Expand Down
Loading