Skip to content

Commit

Permalink
Let BackendCudnn extend BackendCublas.
Browse files Browse the repository at this point in the history
Rationale here: feiwang3311#8 (comment)

- Move GPU test utilities to `LanternFunSuite`.
- Improve CUDA/cuBLAS/cuDNN error messages.
  - Example: "cuBLAS error occurred: 7 (lantern-snippet.cpp:150)"
- Add cuDNN test.
  • Loading branch information
dan-zheng committed Oct 7, 2018
1 parent 001cc0a commit 02d773d
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 51 deletions.
26 changes: 15 additions & 11 deletions src/main/scala/lantern/ad_lms_vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,13 @@ trait TensorExp extends Dsl with Diff {
* cuBLAS tensor operation backend. WIP.
*/
class BackendCublas extends Backend {
override def setup(): Unit = generateRawCode("cublasHandle_t handle;\nCUBLAS_CALL(cublasCreate(&handle));")
override def cleanup(): Unit = generateRawCode("CUBLAS_CALL(cublasDestroy(handle));")
override def setup(): Unit = generateRawCode("cublasHandle_t cublasHandle;\nCUBLAS_CALL(cublasCreate(&cublasHandle));")
override def cleanup(): Unit = generateRawCode("CUBLAS_CALL(cublasDestroy(cublasHandle));")

// Reference:
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-dot
def sdot(a: Rep[Array[Float]], b: Rep[Array[Float]], result: Rep[Array[Float]]) =
unchecked[Unit]("CUBLAS_CALL(cublasSdot(handle, ", a.length, ",", a, ",1,", b, ",1,", result, "))")
unchecked[Unit]("CUBLAS_CALL(cublasSdot(cublasHandle, ", a.length, ",", a, ",1,", b, ",1,", result, "))")

override def vectorVectorDot(x: Tensor, y: Tensor): Tensor = {
val res = NewArray[Float](1)
Expand All @@ -316,7 +316,7 @@ trait TensorExp extends Dsl with Diff {
val zero = NewArray[Float](1); zero(0) = 0
val one = NewArray[Float](1); one(0) = 1
unchecked[Unit](
"CUBLAS_CALL(cublasSgemv(handle, CUBLAS_OP_N, ",
"CUBLAS_CALL(cublasSgemv(cublasHandle, CUBLAS_OP_N, ",
m, ",", n, ",", one, ",",
a, ",", m, ",", b, ",", zero, ",", result, ",", one, "))")
}
Expand All @@ -335,7 +335,7 @@ trait TensorExp extends Dsl with Diff {
val zero = NewArray[Float](1); zero(0) = 0
val one = NewArray[Float](1); one(0) = 1
unchecked[Unit](
"CUBLAS_CALL(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, ",
"CUBLAS_CALL(cublasSgemm(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, ",
m, ",", n, ",", k, ",", one, ",",
a, ",", m, ",", b, ",", k, ",", zero, ",", result, ",", m, "))")
}
Expand All @@ -352,14 +352,18 @@ trait TensorExp extends Dsl with Diff {

/**
* cuDNN tensor operation backend. WIP.
* Extends `BackendCublas` to leverage cuBLAS primitives.
*/
class BackendCudnn extends Backend {
override def setup(): Unit = generateRawCode("cudnnHandle_t handle;\nCUDNN_CALL(cudnnCreate(&handle));")
override def cleanup(): Unit = generateRawCode("CUDNN_CALL(cudnnDestroy(handle));")
class BackendCudnn extends BackendCublas {
override def setup(): Unit = {
super.setup()
generateRawCode("cudnnHandle_t cudnnHandle;\nCUDNN_CALL(cudnnCreate(&cudnnHandle));")
}

override def vectorVectorDot(x: Tensor, y: Tensor): Tensor = ???
override def matrixVectorDot(x: Tensor, y: Tensor): Tensor = ???
override def matrixMatrixDot(x: Tensor, y: Tensor): Tensor = ???
override def cleanup(): Unit = {
super.cleanup()
generateRawCode("CUDNN_CALL(cudnnDestroy(cudnnHandle));")
}
}

// The current backend for code generation.
Expand Down
37 changes: 17 additions & 20 deletions src/main/scala/lantern/dslapi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -486,46 +486,43 @@ trait DslGenCublas extends DslGenBase {
case _ => super.emitNode(sym,rhs)
}

override def templateHeaders: NSeq[String] = super.templateHeaders ++ NSeq("<cuda_runtime.h>", "\"cublas_v2.h\"")
override def templateHeaders: NSeq[String] =
super.templateHeaders ++ NSeq("<cuda.h>", "<cuda_runtime.h>", "<cublas_v2.h>")

override def templateRawCode: String = super.templateRawCode + """
#define CUDA_CALL(f) { \
cudaError_t err = (f); \
if (err != cudaSuccess) { \
std::cerr << "Error occurred: " << err << std::endl; \
std::exit(1); \
fprintf(stderr, "CUDA error occurred: %s (%s:%d)\n", \
cudaGetErrorString(err), __FILE__, __LINE__); \
exit(err); \
} \
}
#define CUBLAS_CALL(f) { \
cublasStatus_t stat = (f); \
if (stat != CUBLAS_STATUS_SUCCESS) { \
std::cerr << "Error occurred: " << err << std::endl; \
exit(1); \
fprintf(stderr, "cuBLAS error occurred: %d (%s:%d)\n", \
stat, __FILE__, __LINE__); \
exit(stat); \
} \
}
"""
}

@virtualize
trait DslGenCudnn extends DslGenBase {
trait DslGenCudnn extends DslGenCublas {
val IR: DslExp
import IR._

override def templateHeaders: NSeq[String] = super.templateHeaders ++ NSeq("<cuda.h>", "<cudnn.h>")
override def templateRawCode: String = super.templateRawCode + """
#define CUDA_CALL(f) { \
cudaError_t err = (f); \
if (err != cudaSuccess) { \
std::cerr << "Error occurred: " << err << std::endl; \
std::exit(1); \
} \
}
#define CUDNN_CALL(f) { \
cudnnStatus_t err = (f); \
if (err != CUDNN_STATUS_SUCCESS) { \
std::cerr << "Error occurred: " << err << std::endl; \
std::exit(1); \
cudnnStatus_t stat = (f); \
if (stat != CUDNN_STATUS_SUCCESS) { \
fprintf(stderr, "cuDNN error occurred: %d (%s:%d)\n", \
stat, __FILE__, __LINE__); \
exit(stat); \
} \
}
"""
Expand Down Expand Up @@ -627,8 +624,8 @@ abstract class DslDriverCudnn[A: Manifest, B: Manifest] extends DslDriverBase[A,

new java.io.File(binaryFileName).delete
import scala.sys.process._
System.out.println("Compile C++ (cuDNN) code")
(s"nvcc -std=c++11 -O1 $cppFileName -o $binaryFileName -lcudnn": ProcessBuilder).lines.foreach(System.out.println) //-std=c99
System.out.println("Compile C++ (cuBLAS & cuDNN) code")
(s"nvcc -std=c++11 -O1 $cppFileName -o $binaryFileName -lcublas -lcudnn": ProcessBuilder).lines.foreach(System.out.println) //-std=c99
System.out.println("Run C++ code")
(s"$binaryFileName $a": ProcessBuilder).lines.foreach(System.out.println)
}
Expand Down
17 changes: 16 additions & 1 deletion src/test/scala/lantern/LanternFunSuite.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
package lantern

import org.scalatest.FunSuite
import org.scalactic.source
import org.scalatest.{FunSuite, Tag}

class LanternFunSuite extends FunSuite {
def runTest(driver: LanternDriver[String, Unit]) {
driver.eval("dummy")
}

// TODO: Edit this function to actually detect whether GPU codegen is possible.
// One idea: check for:
// - The existence of cuBLAS header files (<cuda_runtime.h>, <cublas_v2.h>).
// - The existence of a GPU (perhaps run `nvidia-smi`).
def isGPUAvailable = false

// Utility function wrapping `test` that checks whether GPU is available.
def testGPU(testName: String, testTags: Tag*)(testFun: => Any /* Assertion */)(implicit pos: source.Position) {
if (isGPUAvailable)
test(testName, testTags: _*)(testFun)(pos)
else
ignore(testName, testTags: _*)(testFun)(pos)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,9 @@ package lantern
import org.scala_lang.virtualized.virtualize
import org.scala_lang.virtualized.SourceContext

import org.scalactic.source
import org.scalatest.{FunSuite, Tag}

import scala.collection.mutable.ArrayBuffer
import scala.collection.{Seq => NSeq}
import java.io.{File, PrintWriter}

class CublasTest extends LanternFunSuite {
// TODO: Edit this function to actually detect whether GPU codegen is possible.
// One idea: check for:
// - The existence of cuBLAS header files (<cuda_runtime.h>, "cublas_v2.h").
// - The existence of a GPU (perhaps run `nvidia-smi`).
def isGPUAvailable = false

class TestCublas extends LanternFunSuite {
testGPU("vector-vector-dot") {
val vvdot = new LanternDriverCublas[String, Unit] {
backend = new BackendCublas
Expand Down Expand Up @@ -64,11 +53,4 @@ class CublasTest extends LanternFunSuite {
runTest(mmdot)
}

// Utility function wrapping `test` that checks whether GPU is available.
def testGPU(testName: String, testTags: Tag*)(testFun: => Any /* Assertion */)(implicit pos: source.Position) {
if (isGPUAvailable)
test(testName, testTags: _*)(testFun)(pos)
else
ignore(testName, testTags: _*)(testFun)(pos)
}
}
55 changes: 55 additions & 0 deletions src/test/scala/lantern/TestCudnn.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package lantern

import org.scala_lang.virtualized.virtualize
import org.scala_lang.virtualized.SourceContext

import scala.collection.{Seq => NSeq}

class TestCudnn extends LanternFunSuite {
testGPU("vector-vector-dot") {
val vvdot = new LanternDriverCudnn[String, Unit] {
backend = new BackendCudnn

@virtualize
def snippet(x: Rep[String]): Rep[Unit] = {
val length = 2
val v1 = Tensor.fromData(NSeq(4), 1, 2, 3, 4)
val v2 = Tensor.fromData(NSeq(4), -1, -2, -3, -4)
val expected = Tensor.fromData(NSeq(1), -30)
Tensor.assertEqual(v1.dot(v2), expected)
}
}
runTest(vvdot)
}

testGPU("matrix-vector-dot") {
val mvdot = new LanternDriverCudnn[String, Unit] {
backend = new BackendCudnn

@virtualize
def snippet(x: Rep[String]): Rep[Unit] = {
val m = Tensor.fromData(NSeq(2, 4), 1, 2, 3, 4, 5, 6, 7, 8)
val v = Tensor.fromData(NSeq(4), -1, -2, -3, -4)
val expected = Tensor.fromData(NSeq(2), -30, -70)
Tensor.assertEqual(m.dot(v), expected)
}
}
runTest(mvdot)
}

testGPU("matrix-matrix-dot") {
val mmdot = new LanternDriverCudnn[String, Unit] {
backend = new BackendCudnn

@virtualize
def snippet(x: Rep[String]): Rep[Unit] = {
// Note: it's better to test with non-square matrices.
val m1 = Tensor.fromData(NSeq(2, 3), 1, 2, 3, 4, 5, 6)
val m2 = Tensor.fromData(NSeq(3, 2), 2, 3, 4, 5, 6, 7)
val expected = Tensor.fromData(NSeq(2, 2), 28, 34, 64, 79)
Tensor.assertEqual(m1.dot(m2), expected)
}
}
runTest(mmdot)
}
}

0 comments on commit 02d773d

Please sign in to comment.