From 4da1af0bc445a1055f972992138047592eb455a6 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 4 Oct 2018 20:47:43 -0400 Subject: [PATCH 01/10] Implement matrix-matrix multiplication, clean up dot. --- src/main/scala/lantern/ad_lms_vector.scala | 93 +++++++++++++++---- .../scala/lantern/test_ad_lms_vector.scala | 59 ++++++------ 2 files changed, 104 insertions(+), 48 deletions(-) diff --git a/src/main/scala/lantern/ad_lms_vector.scala b/src/main/scala/lantern/ad_lms_vector.scala index 326c6306..a6a2ab63 100644 --- a/src/main/scala/lantern/ad_lms_vector.scala +++ b/src/main/scala/lantern/ad_lms_vector.scala @@ -203,22 +203,63 @@ trait TensorExp extends Dsl with Diff { * Tensor ops are defined in terms of primitive operations. */ trait BackendNative extends Backend { - override def dot(x: Tensor, y: Tensor): Tensor = { - // TODO: (Fei Wang): only support 2D dot 1D and 1D dot 1D - val off = var_new(0) - val up = if (x.rank > 1) x.shape(0) else 1 - val res = NewArray[Float](up) - for (j <- DataLoop(up)) { + // Compute vector-vector dot product, i.e. inner product. + // [V1] dot [V2] => [1] (scalar) + private def vvdot(x: Tensor, y: Tensor): Tensor = { + assert(x.shape(0) == y.shape(0)) + val value = var_new(0.0f) + for (i <- DataLoop(x.shape.last)) { + value += x.data(i) * y.data(i) + } + val res = NewArray[Float](1) + res(0) = readVar(value) + Tensor(res, 1) + } + + // Compute matrix-vector dot product. + // [M1 x M2] dot [M2] => [M1] + private def mvdot(x: Tensor, y: Tensor): Tensor = { + assert(x.shape(1) == y.shape(0)) + val dim1 = x.shape(0) + val dim2 = x.shape(1) + val res = NewArray[Float](dim1) + for (i <- DataLoop(dim1)) { val value = var_new(0.0f) - for (i <- DataLoop(x.shape.last)) { - value += x.data(off) * y.data(i) - off += 1 + for (j <- DataLoop(dim2)) { + value += x.data(i * dim2 + j) * y.data(j) + } + res(i) = readVar(value) + } + Tensor(res, dim1) + } + + // Compute matrix-matrix dot product. + // [M1 x M2] dot [M2 x M3] => [M1 x M3] + private def mmdot(x: Tensor, y: Tensor): Tensor = { + assert(x.shape(1) == y.shape(0)) + val dim1 = x.shape(0) + val dim2 = x.shape(1) + val dim3 = y.shape(1) + val res = NewArray[Float](dim1 * dim3) + for (i <- DataLoop(dim1)) { + for (j <- DataLoop(dim3)) { + val value = var_new(0.0f) + for (k <- DataLoop(dim2)) { + value += x.data(i * dim2 + k) * y.data(k * dim3 + j) + } + res(i * dim3 + j) = readVar(value) } - res(j) = readVar(value) } - val dim = if (x.rank == 1) 1 else x.shape(0) - Tensor(res, dim) + Tensor(res, dim1, dim3) } + + override def dot(x: Tensor, y: Tensor): Tensor = + (x.rank, y.rank) match { + case (1, 1) => vvdot(x, y) + case (2, 1) => mvdot(x, y) + case (2, 2) => mmdot(x, y) + case _ => throw new IllegalArgumentException(s"Incompatible shapes: ${x.shape}, ${y.shape}") + } } /** @@ -396,12 +437,22 @@ trait TensorExp extends Dsl with Diff { for (i <- DataLoop(scalarCount)) this.data(i) = that.data(i) } - // NOTE: only handles (Matrix dot Vector) and (Vector dot Vector) + // `dot` represents the following: + // - vector-vector dot product. + // [V1] dot [V2] => [1] (scalar) + // - matrix-vector multiplication. + // [M1 x M2] dot [M2] => [M1] + // - matrix-matrix multiplication. 
+ // [M1 x M2] dot [M2 x M3] => [M1 x M3] def dot(that: Tensor) = { - // assert that and this have the same dimension - generate_comment(s"dot ${this.shape.seq} - ${that.shape.seq}") - assert(this.rank <= 2 && that.rank == 1, s"Only M x V or V x V allowed ${this.shape} - ${that.shape}") - assert(this.shape.last == that.shape(0), s"dimensions of vector do not match dot! ${this.shape.seq} - ${that.shape.seq}") + generate_comment(s"dot: ${this.shape.seq}, ${that.shape.seq}") + (this.rank, that.rank) match { + case (1, 1) => assert(this.shape(0) == that.shape(0), s"Incompatible shapes: ${this.shape}, ${that.shape}") + case (2, 1) => assert(this.shape(1) == that.shape(0), s"Incompatible shapes: (${this.shape}, ${that.shape}") + case (2, 2) => assert(this.shape(0) == that.shape(1), s"Incompatible shapes: (${this.shape}, ${that.shape}") + case _ => throw new IllegalArgumentException( + s"Only vector-vector, matrix-vector, and matrix-matrix multiplication are allowed (actual shapes: ${this.shape}, ${that.shape})") + } backend.dot(this, that) } @@ -1625,7 +1676,13 @@ trait TensorExp extends Dsl with Diff { that.d.minus_mult_div_square(this.x, y.d, that.x) } - // vector dot product or Matrix vector dot (viewed as multiple vector dot product) (not the common view) + // `dot` represents the following: + // - vector-vector dot product. + // [V1] dot [V2] => [1] (scalar) + // - matrix-vector multiplication. + // [M1 x M2] dot [M2] => [M1] + // - matrix-matrix multiplication. + // [M1 x M2] dot [M2 x M3] => [M1 x M3] def dot(that: TensorR): TensorR @diff = shift { (k: TensorR => Unit) => val res = x dot that.x val y = TensorR(res); k(y) diff --git a/src/test/scala/lantern/test_ad_lms_vector.scala b/src/test/scala/lantern/test_ad_lms_vector.scala index 1459da4b..384d0801 100644 --- a/src/test/scala/lantern/test_ad_lms_vector.scala +++ b/src/test/scala/lantern/test_ad_lms_vector.scala @@ -42,46 +42,45 @@ class AdLMSVectorTest extends FunSuite { array0.eval("abc") } - test("array1") { - val array1 = new DslDriverC[String, Unit] with TensorExp { - + test("vector-vector-dot") { + val vvdot = new DslDriverC[String, Unit] with TensorExp { @virtualize - def snippet(a: Rep[String]): Rep[Unit] = { + def snippet(x: Rep[String]): Rep[Unit] = { val length = 2 - val res = Tensor.randinit(length) - val res2 = Tensor.randinit(length, seed = Some(5)) - - val result = res dot res2 - - // assertions - if (res.data(0) * res2.data(0) + res.data(1) * res2.data(1) != result.data(0)) - println("ERROR: the dot product of two vectors is not correct") + val v1 = Tensor.fromData(NSeq(4), 1, 2, 3, 4) + val v2 = Tensor.fromData(NSeq(4), -1, -2, -3, -4) + val expected = Tensor.fromData(NSeq(1), -30) + Tensor.assertEqual(v1.dot(v2), expected) } } - array1.eval("abc") + runTest(vvdot) } - test("array1_1") { - val array1_1 = new DslDriverC[String, Unit] with TensorExp { - + test("matrix-vector-dot") { + val mvdot = new DslDriverC[String, Unit] with TensorExp { @virtualize - def snippet(a: Rep[String]): Rep[Unit] = { - val dim0 = 2 - val dim1 = 3 - val matrix = Tensor.rand(dim0, dim1) - val vector = Tensor.randinit(dim1, seed = Some(4)) - - //println("the result is:") - val result = matrix dot vector - //result.print() + def snippet(x: Rep[String]): Rep[Unit] = { + val m = Tensor.fromData(NSeq(2, 4), 1, 2, 3, 4, 5, 6, 7, 8) + val v = Tensor.fromData(NSeq(4), -1, -2, -3, -4) + val expected = Tensor.fromData(NSeq(2), -30, -70) + Tensor.assertEqual(m.dot(v), expected) + } + } + runTest(mvdot) + } - if (matrix(0, 0) * vector(0) + matrix(0, 1) 
* vector(1) + matrix(0, 2) * vector(2) != result(0)) - printf("ERROR: the matrix vector dot product is wrong on the first element of result, %.3f != %.3f\\n", matrix(0, 0) * vector(0) + matrix(0, 1) * vector(1) + matrix(0, 2) * vector(2), result(0)) - if (matrix(1, 0) * vector(0) + matrix(1, 1) * vector(1) + matrix(1, 2) * vector(2) != result(1)) - printf("ERROR: the matrix vector dot product is wrong on the second element of result, %.3f != %.3f\\n", matrix(1, 0) * vector(0) + matrix(1, 1) * vector(1) + matrix(1, 2) * vector(2), result(1)) + test("matrix-matrix-dot") { + val mmdot = new DslDriverC[String, Unit] with TensorExp { + @virtualize + def snippet(x: Rep[String]): Rep[Unit] = { + // Note: it's better to test with non-square matrices. + val m1 = Tensor.fromData(NSeq(2, 3), 1, 2, 3, 4, 5, 6) + val m2 = Tensor.fromData(NSeq(3, 2), 2, 3, 4, 5, 6, 7) + val expected = Tensor.fromData(NSeq(2, 2), 28, 34, 64, 79) + Tensor.assertEqual(m1.dot(m2), expected) } } - array1_1.eval("abc") + runTest(mmdot) } test("array2") { From 8c0e4e8c98179ef99f352055d40fc4c145aaad43 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 4 Oct 2018 21:23:50 -0400 Subject: [PATCH 02/10] Fix vector-vector dot doc comment. --- src/main/scala/lantern/ad_lms_vector.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/scala/lantern/ad_lms_vector.scala b/src/main/scala/lantern/ad_lms_vector.scala index a6a2ab63..62329bfa 100644 --- a/src/main/scala/lantern/ad_lms_vector.scala +++ b/src/main/scala/lantern/ad_lms_vector.scala @@ -204,7 +204,7 @@ trait TensorExp extends Dsl with Diff { */ trait BackendNative extends Backend { // Compute vector-vector dot product, i.e. inner product. - // [V1] dot [V2] => [1] (scalar) + // [V] dot [V] => [1] (scalar) private def vvdot(x: Tensor, y: Tensor): Tensor = { assert(x.shape(0) == y.shape(0)) val value = var_new(0.0f) @@ -439,7 +439,7 @@ trait TensorExp extends Dsl with Diff { // `dot` represents the following: // - vector-vector dot product. - // [V1] dot [V2] => [1] (scalar) + // [V] dot [V] => [1] (scalar) // - matrix-vector multiplication. // [M1 x M2] dot [M2] => [M1] // - matrix-matrix multiplication. @@ -1678,7 +1678,7 @@ trait TensorExp extends Dsl with Diff { // `dot` represents the following: // - vector-vector dot product. - // [V1] dot [V2] => [1] (scalar) + // [V] dot [V] => [1] (scalar) // - matrix-vector multiplication. // [M1 x M2] dot [M2] => [M1] // - matrix-matrix multiplication. From a78ffc33160cec8e98f15e9e28a79ebc872a4c60 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 4 Oct 2018 23:45:33 -0400 Subject: [PATCH 03/10] Fix typos. 
--- src/main/scala/lantern/ad_lms_vector.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/scala/lantern/ad_lms_vector.scala b/src/main/scala/lantern/ad_lms_vector.scala index 62329bfa..46c00db8 100644 --- a/src/main/scala/lantern/ad_lms_vector.scala +++ b/src/main/scala/lantern/ad_lms_vector.scala @@ -448,8 +448,8 @@ trait TensorExp extends Dsl with Diff { generate_comment(s"dot: ${this.shape.seq}, ${that.shape.seq}") (this.rank, that.rank) match { case (1, 1) => assert(this.shape(0) == that.shape(0), s"Incompatible shapes: ${this.shape}, ${that.shape}") - case (2, 1) => assert(this.shape(1) == that.shape(0), s"Incompatible shapes: (${this.shape}, ${that.shape}") - case (2, 2) => assert(this.shape(0) == that.shape(1), s"Incompatible shapes: (${this.shape}, ${that.shape}") + case (2, 1) => assert(this.shape(1) == that.shape(0), s"Incompatible shapes: ${this.shape}, ${that.shape}") + case (2, 2) => assert(this.shape(0) == that.shape(1), s"Incompatible shapes: ${this.shape}, ${that.shape}") case _ => throw new IllegalArgumentException( s"Only vector-vector, matrix-vector, and matrix-matrix multiplication are allowed (actual shapes: ${this.shape}, ${that.shape})") } From f9d5be917a7b32b37ac9fc425a55ebd96c3d8a88 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 6 Oct 2018 12:57:35 -0400 Subject: [PATCH 04/10] Add cuBLAS/cuDNN code generators and drivers. - Refactored common codegen logic into a new `DslGenBase` trait. TODO: reduce code duplication between code generators and drivers. Defining a base trait for code generators makes sense. --- src/main/scala/lantern/dslapi.scala | 337 ++++++++++++++++++++++++++-- 1 file changed, 323 insertions(+), 14 deletions(-) diff --git a/src/main/scala/lantern/dslapi.scala b/src/main/scala/lantern/dslapi.scala index 702d2864..427b2789 100644 --- a/src/main/scala/lantern/dslapi.scala +++ b/src/main/scala/lantern/dslapi.scala @@ -221,7 +221,7 @@ trait DslImpl extends DslExp { q => // TODO: currently part of this is specific to the query tests. generalize? move? 
@virtualize -trait DslGenC extends CGenNumericOpsExtra +trait DslGenBase extends CGenNumericOpsExtra with CGenPrimitiveOps with CGenBooleanOps with CGenIfThenElse with CGenEqual with CGenRangeOps with CGenOrderingOps with CGenMiscOps with CGenArrayOps with CGenStringOps @@ -336,6 +336,7 @@ trait DslGenC extends CGenNumericOpsExtra case Const(0) if x.tp == typ[Char] => "'\\0'" case _ => super.quote(x) } + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { case Error(s) => stream.println("assert(false && " + quote(s) + ");") case afs@ArrayFromSeq(xs) => stream.println(remap(afs.m) + " " + quote(sym) + "[" + xs.length + "] = {" + (xs map quote mkString ",") + "}; // ;)") @@ -345,6 +346,7 @@ trait DslGenC extends CGenNumericOpsExtra val arrType = remap(a.m) //stream.println(arrType + "* " + quote(sym) + " = " + getMemoryAllocString(quote(n), arrType)) stream.println(arrType + "* " + quote(sym) + " = " + getMemoryAllocStringArena(quote(n), arrType)) + //stream.println("unique_ptr<" + arrType + "[]> " + quote(sym) + "(new " + arrType + "[" + quote(n) + "]);") //stream.println("shared_ptr<" + arrType + "[]> " + quote(sym) + "(new " + arrType + "[" + quote(n) + "]);") case ArrayApply(x,n) => emitValDef(sym, quote(x) + "[" + quote(n) + "]") @@ -364,6 +366,13 @@ trait DslGenC extends CGenNumericOpsExtra case MathTanh(x) => emitValDef(sym, src"tanh($x)") case _ => super.emitNode(sym,rhs) } +} + +trait DslGenC extends DslGenBase { + val IR: DslExp + import IR._ + + // TODO: Reduce code duplication in `emitSource` functions. override def emitSource[A:Typ](args: List[Sym[_]], body: Block[A], functionName: String, out: java.io.PrintWriter) = { withStream(out) { stream.println("""#include @@ -442,13 +451,261 @@ int main(int argc, char *argv[]) { } Snippet(argv[1]); return 0; -}""") +} +""") } super.emitSource[A](args, body, functionName, out) } } +@virtualize +trait DslGenCublas extends DslGenBase { + val IR: DslExp + import IR._ + + // Allocate GPU memory. + def getCudaMallocString(buffer: String, count: String, memType: String): String = { + "CUDA_CALL(cudaMalloc(&" + buffer + ", " + count + " * sizeof(" + memType + ")));" + } + // Allocate unified memory, accessible by CPU and GPU. + // FIXME: I encountered "bus error" when performing CPU ops on managed memory: + // Thread 1 "snippet" received signal SIGBUS, Bus error. + // Snippet (x0=) at snippet.cpp:144 + // 144 float x32 = x30 - x31; + // I wonder if others can replicate this issue. + def getCudaMallocManagedString(buffer: String, count: String, memType: String): String = { + "CUDA_CALL(cudaMallocManaged(&" + buffer + ", " + count + " * sizeof(" + memType + ")));" + } + + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { + case a@ArrayNew(n) => + // Unified CPU/GPU memory via `cudaMallocManaged` is more convenient than `cudaMalloc`, but less performant. + // We can use a similar memory pool technique with `cudaMallocManaged`. 
+ val arrType = remap(a.m) + stream.println(arrType + "* " + quote(sym) + "; " + getCudaMallocManagedString(quote(sym), quote(n), arrType)) + // stream.println(arrType + "* " + quote(sym) + "; " + getCudaMallocString(quote(sym), quote(n), arrType)) + case _ => super.emitNode(sym,rhs) + } + + override def emitSource[A:Typ](args: List[Sym[_]], body: Block[A], functionName: String, out: java.io.PrintWriter) = { + withStream(out) { + stream.println("""#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "cublas_v2.h" + +using namespace std; + +#ifndef MAP_FILE +#define MAP_FILE MAP_SHARED +#endif + +#define CUDA_CALL(f) { \ + cudaError_t err = (f); \ + if (err != cudaSuccess) { \ + std::cerr << "Error occurred: " << err << std::endl; \ + std::exit(1); \ + } \ +} + +#define CUBLAS_CALL(f) { \ + cublasStatus_t stat = (f); \ + if (stat != CUBLAS_STATUS_SUCCESS) { \ + std::cerr << "Error occurred: " << err << std::endl; \ + exit(1); \ + } \ +} + +int fsize(int fd) { + struct stat stat; + int res = fstat(fd, &stat); + return stat.st_size; +} + +int printll(char *s) { + while (*s != '\n' && *s != ',' && *s != '\t') { + putchar(*s++); + } + return 0; +} + +long hash(char *str0, int len) { + unsigned char *str = (unsigned char *)str0; + unsigned long hash = 5381; + int c; + + while ((c = *str++) && len--) + hash = ((hash << 5) + hash) + c; /* hash * 33 + c */ + + return hash; +} + +int HEAP_SIZE = 1073741826; // 1048576; // 2147483652; // 536870912; // 268435456; // 2097152; +void *mallocBase = malloc(HEAP_SIZE); +void *mallocAddr = mallocBase; +void *waterMark = mallocBase; +void *myMalloc(size_t bytes) { + void *res = mallocAddr; + mallocAddr = (void *)((char *)mallocAddr + bytes); + return res; +} + +int timeval_subtract(struct timeval *result, struct timeval *t2, struct timeval *t1) { + long int diff = (t2->tv_usec + 1000000 * t2->tv_sec) - (t1->tv_usec + 1000000 * t1->tv_sec); + result->tv_sec = diff / 1000000; + result->tv_usec = diff % 1000000; + return (diff < 0); +} + +void Snippet(char *); + +std::random_device rd{}; +std::mt19937 gen{rd()}; +std::normal_distribution<> d{0, 1}; +cublasHandle_t handle; + +int main(int argc, char *argv[]) { + CUBLAS_CALL(cublasCreate(&handle)); + if (argc != 2) { + printf("usage: query \n"); + return 0; + } + Snippet(argv[1]); + CUBLAS_CALL(cublasDestroy(handle)); + return 0; +} +""") + } + super.emitSource[A](args, body, functionName, out) + } +} + +@virtualize +trait DslGenCudnn extends DslGenBase { + val IR: DslExp + import IR._ + + override def emitSource[A:Typ](args: List[Sym[_]], body: Block[A], functionName: String, out: java.io.PrintWriter) = { + withStream(out) { + stream.println("""#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace std; + +#ifndef MAP_FILE +#define MAP_FILE MAP_SHARED +#endif + +#define CUDA_CALL(f) { \ + cudaError_t err = (f); \ + if (err != cudaSuccess) { \ + std::cerr << "Error occurred: " << err << std::endl; \ + std::exit(1); \ + } \ +} + +#define CUDNN_CALL(f) { \ + cudnnStatus_t err = (f); \ + if (err != CUDNN_STATUS_SUCCESS) { \ + std::cerr << "Error occurred: " << err << std::endl; \ + std::exit(1); \ + } \ +} + +int fsize(int fd) { + struct stat stat; + int res = fstat(fd, &stat); + return stat.st_size; +} + +int printll(char *s) { + while (*s != '\n' 
&& *s != ',' && *s != '\t') { + putchar(*s++); + } + return 0; +} + +long hash(char *str0, int len) { + unsigned char *str = (unsigned char *)str0; + unsigned long hash = 5381; + int c; + + while ((c = *str++) && len--) + hash = ((hash << 5) + hash) + c; /* hash * 33 + c */ + + return hash; +} + +int HEAP_SIZE = 1073741826; // 1048576; // 2147483652; // 536870912; // 268435456; // 2097152; +void *mallocBase = malloc(HEAP_SIZE); +void *mallocAddr = mallocBase; +void *waterMark = mallocBase; +void *myMalloc(size_t bytes) { + void *res = mallocAddr; + mallocAddr = (void *)((char *)mallocAddr + bytes); + return res; +} + +int timeval_subtract(struct timeval *result, struct timeval *t2, struct timeval *t1) { + long int diff = (t2->tv_usec + 1000000 * t2->tv_sec) - (t1->tv_usec + 1000000 * t1->tv_sec); + result->tv_sec = diff / 1000000; + result->tv_usec = diff % 1000000; + return (diff < 0); +} + +void Snippet(char *); + +std::random_device rd{}; +std::mt19937 gen{rd()}; +std::normal_distribution<> d{0, 1}; +void Snippet(char*); + +int main(int argc, char *argv[]) { + if (argc != 2) { + printf("usage: query \n"); + return 0; + } + Snippet(argv[1]); + return 0; +} +""") + } + super.emitSource[A](args, body, functionName, out) + } +} @virtualize abstract class DslSnippet[A:Manifest, B:Manifest] extends Dsl { @@ -460,7 +717,7 @@ abstract class DslDriver[A:Manifest,B:Manifest] extends DslSnippet[A,B] with Dsl lazy val f = compile(snippet)(manifestTyp[A],manifestTyp[B]) def precompile: Unit = f - //def precompileSilently: Unit = utils.devnull(f) + // def precompileSilently: Unit = utils.devnull(f) def eval(x: A): B = f(x) @@ -472,30 +729,82 @@ abstract class DslDriver[A:Manifest,B:Manifest] extends DslSnippet[A,B] with Dsl } @virtualize -abstract class DslDriverC[A: Manifest, B: Manifest] extends DslSnippet[A, B] with DslExp { - q => +abstract class DslDriverC[A: Manifest, B: Manifest] extends DslSnippet[A, B] with DslExp { self => val codegen = new DslGenC { - val IR: q.type = q + val IR: self.type = self } + lazy val code: String = { - //implicit val mA = manifestTyp[A] - //implicit val mB = manifestTyp[B] val source = new java.io.StringWriter() codegen.emitSource(snippet, "Snippet", new java.io.PrintWriter(source)) source.toString } def eval(a: A): Unit = { - // TBD: should read result of type B? 
val out = new java.io.PrintWriter("/tmp/snippet.cpp") out.println(code) - out.close - //TODO: use precompile - (new java.io.File("/tmp/snippet")).delete + out.close() + + // TODO: Use precompile + new java.io.File("/tmp/snippet").delete import scala.sys.process._ System.out.println("Compile C++ code") - (s"g++ -std=c++11 -O1 /tmp/snippet.cpp -o /tmp/snippet": ProcessBuilder).lines.foreach(System.out.println _) //-std=c99 + (s"g++ -std=c++11 -O1 /tmp/snippet.cpp -o /tmp/snippet": ProcessBuilder).lines.foreach(System.out.println) //-std=c99 + System.out.println("Run C++ code") + (s"/tmp/snippet $a": ProcessBuilder).lines.foreach(System.out.println) + } +} + +@virtualize +abstract class DslDriverCublas[A: Manifest, B: Manifest] extends DslSnippet[A, B] with DslExp { self => + val codegen = new DslGenCublas { + val IR: self.type = self + } + + lazy val code: String = { + val source = new java.io.StringWriter() + codegen.emitSource(snippet, "Snippet", new java.io.PrintWriter(source)) + source.toString + } + + def eval(a: A): Unit = { + val out = new java.io.PrintWriter("/tmp/snippet.cpp") + out.println(code) + out.close() + + // TODO: Use precompile + new java.io.File("/tmp/snippet").delete + import scala.sys.process._ + System.out.println("Compile C++ (cuBLAS) code") + (s"nvcc -std=c++11 -O1 /tmp/snippet.cpp -o /tmp/snippet -lcublas": ProcessBuilder).lines.foreach(System.out.println) //-std=c99 + System.out.println("Run C++ code") + (s"/tmp/snippet $a": ProcessBuilder).lines.foreach(System.out.println) + } +} + +@virtualize +abstract class DslDriverCudnn[A: Manifest, B: Manifest] extends DslSnippet[A, B] with DslExp { self => + val codegen = new DslGenCudnn { + val IR: self.type = self + } + + lazy val code: String = { + val source = new java.io.StringWriter() + codegen.emitSource(snippet, "Snippet", new java.io.PrintWriter(source)) + source.toString + } + + def eval(a: A): Unit = { + val out = new java.io.PrintWriter("/tmp/snippet.cpp") + out.println(code) + out.close() + + // TODO: Use precompile + new java.io.File("/tmp/snippet").delete + import scala.sys.process._ + System.out.println("Compile C++ (cuDNN) code") + (s"nvcc -std=c++11 -O1 /tmp/snippet.cpp -o /tmp/snippet -lcudnn": ProcessBuilder).lines.foreach(System.out.println) //-std=c99 System.out.println("Run C++ code") - (s"/tmp/snippet $a": ProcessBuilder).lines.foreach(System.out.println _) + (s"/tmp/snippet $a": ProcessBuilder).lines.foreach(System.out.println) } } From 5763cf20962f4d50a9abe1b00a7815271f3b4901 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 6 Oct 2018 13:02:24 -0400 Subject: [PATCH 05/10] Implement `dot` for `BackendCublas`, refactor `Backend` trait. - `Backend` now defines a default `dot` method that dispatches to separate v*v, m*v, m*m methods. - Implement `dot` methods for cuBLAS. - v*v: cublasSdot - m*v: cublasSgemv - m*m: cublasSgemm --- src/main/scala/lantern/ad_lms_vector.scala | 140 ++++++++++++++++----- 1 file changed, 109 insertions(+), 31 deletions(-) diff --git a/src/main/scala/lantern/ad_lms_vector.scala b/src/main/scala/lantern/ad_lms_vector.scala index 849a937b..a43bcf81 100644 --- a/src/main/scala/lantern/ad_lms_vector.scala +++ b/src/main/scala/lantern/ad_lms_vector.scala @@ -190,22 +190,47 @@ trait TensorExp extends Dsl with Diff { } */ /** - * Defines tensor-specific operations. - * Eventually, a tensor operation IR may be introduced to enable analyses/transformations. + * A code generation backend for tensor operations. 
+ * + * Note: Eventually, a tensor operation IR may be introduced to enable analyses and + * transformations such as operator fusion and matrix chain multiplication optimization. */ trait Backend { - def dot(x: Tensor, y: Tensor): Tensor - // TODO: Add more ops. + // Compute vector-vector dot product, i.e. inner product. + // [V] dot [V] => [1] (scalar) + def vectorVectorDot(x: Tensor, y: Tensor): Tensor + + // Compute matrix-vector dot product. + // [M1 x M2] dot [M2] => [M1] + def matrixVectorDot(x: Tensor, y: Tensor): Tensor + + // Compute matrix-matrix dot product. + // [M1 x M2] dot [M2 x M3] => [M1 x M3] + def matrixMatrixDot(x: Tensor, y: Tensor): Tensor + + def dot(x: Tensor, y: Tensor): Tensor = + (x.rank, y.rank) match { + case (1, 1) => vectorVectorDot(x, y) + case (2, 1) => matrixVectorDot(x, y) + case (2, 2) => matrixMatrixDot(x, y) + case _ => throw new IllegalArgumentException(s"Incompatible shapes: ${x.shape}, ${y.shape}") + } + + // TODO: Add more ops: + // - Elementwise binary ops (+, -, *, /). + // - GPU backends need to address broadcasting. + // - `BackendCublas` can define addition using `cublasSaxpy`. + // - Conv2d. + // - Activation functions (e.g. relu). + // - Fused multiply add operations? } /** - * Native tensor op backend. + * Native tensor operation backend. WIP. * Tensor ops are defined in terms of primitive operations. */ - trait BackendNative extends Backend { - // Compute vector-vector dot product, i.e. inner product. - // [V] dot [V] => [1] (scalar) - private def vvdot(x: Tensor, y: Tensor): Tensor = { + class BackendNative extends Backend { + override def vectorVectorDot(x: Tensor, y: Tensor): Tensor = { assert(x.shape(0) == y.shape(0)) val value = var_new(0.0f) for (i <- DataLoop(x.shape.last)) { @@ -216,9 +241,7 @@ trait TensorExp extends Dsl with Diff { Tensor(res, 1) } - // Compute matrix-vector dot product. - // [M1 x M2] dot [M2] => [M1] - private def mvdot(x: Tensor, y: Tensor): Tensor = { + override def matrixVectorDot(x: Tensor, y: Tensor): Tensor = { assert(x.shape(1) == y.shape(0)) val dim1 = x.shape(0) val dim2 = x.shape(1) @@ -233,9 +256,7 @@ trait TensorExp extends Dsl with Diff { Tensor(res, dim1) } - // Compute matrix-matrix dot product. - // [M1 x M2] dot [M2 x M3] => [M1 x M3] - private def mmdot(x: Tensor, y: Tensor): Tensor = { + override def matrixMatrixDot(x: Tensor, y: Tensor): Tensor = { assert(x.shape(1) == y.shape(0)) val dim1 = x.shape(0) val dim2 = x.shape(1) @@ -252,21 +273,56 @@ trait TensorExp extends Dsl with Diff { } Tensor(res, dim1, dim3) } - - override def dot(x: Tensor, y: Tensor): Tensor = - (x.rank, y.rank) match { - case (1, 1) => vvdot(x, y) - case (2, 1) => mvdot(x, y) - case (2, 2) => mmdot(x, y) - case _ => throw new IllegalArgumentException(s"Incompatible shapes: ${x.shape}, ${y.shape}") - } } /** - * cuBLAS tensor op backend. WIP. + * cuBLAS tensor operation backend. WIP. 
*/ - trait BackendCUBLAS extends Backend { - // GEMM reference: + class BackendCublas extends Backend { + // Reference: + // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-dot + // + // cublasStatus_t cublasSdot(cublasHandle_t handle, int n, + // const float *x, int incx, + // const float *y, int incy, + // float *result) + def sdot(a: Rep[Array[Float]], b: Rep[Array[Float]], result: Rep[Array[Float]]) = + unchecked[Unit]("CUBLAS_CALL(cublasSdot(handle, ", a.length, ",", a, ",1,", b, ",1,", result, "))") + + override def vectorVectorDot(x: Tensor, y: Tensor): Tensor = { + val res = NewArray[Float](1) + sdot(x.data, y.data, res) + Tensor(res, 1) + } + + // Reference: + // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemv + // + // cublasStatus_t cublasSgemv(cublasHandle_t handle, cublasOperation_t trans, + // int m, int n, + // const float *alpha, + // const float *A, int lda, + // const float *x, int incx, + // const float *beta, + // float *y, int incy) + def sgemv(m: Int, n: Int, a: Rep[Array[Float]], b: Rep[Array[Float]], result: Rep[Array[Float]]) = { + val zero = NewArray[Float](1); zero(0) = 0 + val one = NewArray[Float](1); one(0) = 1 + unchecked[Unit]( + "CUBLAS_CALL(cublasSgemv(handle, CUBLAS_OP_N, ", + m, ",", n, ",", one, ",", + a, ",", m, ",", b, ",", zero, ",", result, ",", one, "))") + } + + override def matrixVectorDot(x: Tensor, y: Tensor): Tensor = { + val m = x.shape(0) + val n = x.shape(1) + val res = NewArray[Float](m) + sgemv(m, n, x.data, y.data, res) + Tensor(res, m) + } + + // Reference: // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemm // // cublasStatus_t cublasSgemm(cublasHandle_t handle, @@ -277,16 +333,38 @@ trait TensorExp extends Dsl with Diff { // const float *B, int ldb, // const float *beta, // float *C, int ldc) - def sgemm(a: Array[Float], b: Array[Float], c: Array[Float]) = unchecked[Array[Float]]("cublasSgemm(...)") + def sgemm(m: Int, n: Int, k: Int, a: Rep[Array[Float]], b: Rep[Array[Float]], result: Rep[Array[Float]]) = { + val zero = NewArray[Float](1); zero(0) = 0 + val one = NewArray[Float](1); one(0) = 1 + unchecked[Unit]( + "CUBLAS_CALL(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, ", + m, ",", n, ",", k, ",", one, ",", + a, ",", m, ",", b, ",", k, ",", zero, ",", result, ",", m, "))") + } - override def dot(x: Tensor, y: Tensor): Tensor = ??? + override def matrixMatrixDot(x: Tensor, y: Tensor): Tensor = { + val m = x.shape(0) + val n = y.shape(1) + val k = y.shape(0) + val res = NewArray[Float](m * n) + sgemm(m, n, k, x.data, y.data, res) + Tensor(res, m, n) + } } /** - * Default tensor op backend, extending `BackendNative`. + * cuDNN tensor operation backend. WIP. */ - class BackendDefault extends BackendNative - val backend: Backend = new BackendDefault + class BackendCudnn extends Backend { + override def vectorVectorDot(x: Tensor, y: Tensor): Tensor = ??? + override def matrixVectorDot(x: Tensor, y: Tensor): Tensor = ??? + override def matrixMatrixDot(x: Tensor, y: Tensor): Tensor = ??? + } + + // The current backend for code generation. + // To switch code generation to a different backend, simply change this value + // in your DSL program. + var backend: Backend = new BackendNative class Tensor(val data: Rep[Array[Float]], val dimensions: NSeq[Int]) extends Serializable { From 1b1c2a816b1015ea7e0fee3ae27dc17670df260b Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 6 Oct 2018 13:16:18 -0400 Subject: [PATCH 06/10] Add cuBLAS code generation test. 
The cuBLAS test suite is currently disabled (`isGPUAvailable` is set to false). Otherwise, Travis CI will fail. TODO: - Implement `isGPUAvailable` to actually detect whether GPU codegen is possible. - Factor test utility methods into a common base class. - Set up GPU CI. --- src/test/scala/lantern/test_cublas.scala | 90 ++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 src/test/scala/lantern/test_cublas.scala diff --git a/src/test/scala/lantern/test_cublas.scala b/src/test/scala/lantern/test_cublas.scala new file mode 100644 index 00000000..eb029ea8 --- /dev/null +++ b/src/test/scala/lantern/test_cublas.scala @@ -0,0 +1,90 @@ +package lantern + +import org.scala_lang.virtualized.virtualize +import org.scala_lang.virtualized.SourceContext + +import org.scalactic.source +import org.scalatest.{FunSuite, Tag} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.{Seq => NSeq} +import java.io.{File, PrintWriter} + +class CublasTest extends FunSuite { + // TODO: Edit this function to actually detect whether GPU codegen is possible. + // One idea: check for: + // - The existence of cuBLAS header files (, "cublas_v2.h"). + // - The existence of a GPU (perhaps run `nvidia-smi`). + def isGPUAvailable = false + + testGPU("vector-vector-dot") { + val vvdot = new DslDriverCublas[String, Unit] with TensorExp { + backend = new BackendCublas + + @virtualize + def snippet(x: Rep[String]): Rep[Unit] = { + val length = 2 + val v1 = Tensor.fromData(NSeq(4), 1, 2, 3, 4) + val v2 = Tensor.fromData(NSeq(4), -1, -2, -3, -4) + val expected = Tensor.fromData(NSeq(1), -30) + Tensor.assertEqual(v1.dot(v2), expected) + } + } + runTest(vvdot) + } + + testGPU("matrix-vector-dot") { + val mvdot = new DslDriverCublas[String, Unit] with TensorExp { + backend = new BackendCublas + + @virtualize + def snippet(x: Rep[String]): Rep[Unit] = { + val m = Tensor.fromData(NSeq(2, 4), 1, 2, 3, 4, 5, 6, 7, 8) + val v = Tensor.fromData(NSeq(4), -1, -2, -3, -4) + val expected = Tensor.fromData(NSeq(2), -30, -70) + Tensor.assertEqual(m.dot(v), expected) + } + } + runTest(mvdot, debug = true) + } + + testGPU("matrix-matrix-dot") { + val mmdot = new DslDriverCublas[String, Unit] with TensorExp { + backend = new BackendCublas + + @virtualize + def snippet(x: Rep[String]): Rep[Unit] = { + // Note: it's better to test with non-square matrices. + val m1 = Tensor.fromData(NSeq(2, 3), 1, 2, 3, 4, 5, 6) + val m2 = Tensor.fromData(NSeq(3, 2), 2, 3, 4, 5, 6, 7) + val expected = Tensor.fromData(NSeq(2, 2), 28, 34, 64, 79) + Tensor.assertEqual(m1.dot(m2), expected) + } + } + runTest(mmdot, debug = true) + } + + // Utility function wrapping `test` that checks whether GPU is available. + def testGPU(testName: String, testTags: Tag*)(testFun: => Any /* Assertion */)(implicit pos: source.Position) { + if (isGPUAvailable) + test(testName, testTags: _*)(testFun)(pos) + else + ignore(testName, testTags: _*)(testFun)(pos) + } + + // TODO: Refactor `runTest` into a "LanternTestSuite" base class. 
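+
+  // A possible sketch of the `isGPUAvailable` TODO above, following the `nvidia-smi` idea.
+  // Hypothetical helper (not wired up to `isGPUAvailable`): it assumes `nvidia-smi` is on the
+  // PATH and treats a zero exit code as evidence that a usable GPU is present.
+  def isGPUAvailableSketch: Boolean = {
+    import scala.sys.process._
+    scala.util.Try("nvidia-smi".!).map(_ == 0).getOrElse(false)
+  }
+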
+ def runTest(snippet: DslDriverCublas[String, Unit], debug: Boolean = false) { + val test = new PrintWriter(new File("/tmp/snippet.cpp")) + if (debug) { + System.err.println(snippet.code) + } + test.println(snippet.code) + test.flush() + new java.io.File("/tmp/snippet").delete + import scala.sys.process._ + System.out.println("Compile C++ code") + (s"g++ -std=c++11 -O1 /tmp/snippet.cpp -o /tmp/snippet": ProcessBuilder).lines.foreach(System.out.println) + System.out.println("Run C++ code") + (s"/tmp/snippet a": ProcessBuilder).lines.foreach(System.out.println) + } +} \ No newline at end of file From cee7c9b32abfa1127a43945e191ed7287a42645b Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 6 Oct 2018 13:24:44 -0400 Subject: [PATCH 07/10] Implement `withBackend` device placement function. `withBackend` explicitly demarcates code that should be run on a different backend. It transfers inputs/results between backends automatically. Design info: https://github.com/feiwang3311/Lantern/issues/8#issuecomment-426563742 --- src/main/scala/lantern/ad_lms_vector.scala | 57 +++++++++++++++++++++- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/src/main/scala/lantern/ad_lms_vector.scala b/src/main/scala/lantern/ad_lms_vector.scala index a43bcf81..a57d7705 100644 --- a/src/main/scala/lantern/ad_lms_vector.scala +++ b/src/main/scala/lantern/ad_lms_vector.scala @@ -366,6 +366,59 @@ trait TensorExp extends Dsl with Diff { // in your DSL program. var backend: Backend = new BackendNative + /** + * Transfer data between backends. + * @param from The current backend. + * @param to The new backend. + * @param data The data to transfer. + * @tparam T Type of the data. + */ + def transfer[T](from: Backend, to: Backend)(data: T) { + // TODO: Implement logic. It will involve `cudaMemcpy`. + (from, to) match { + case (cpu: BackendNative, gpu: BackendCudnn) => ??? + case (gpu: BackendCudnn, cpu: BackendNative) => ??? + case _ => ??? + } + } + + /** + * Call a function with given inputs, generating code for the specified backend. + * The inputs and result will be transferred between backends automatically. + * @param b The new backend. + * @param input The input to the function. + * @param f The function to call, whose code will be generated on the new backend. + * @tparam T The function input type. + * @tparam U The function output type. + */ + def withBackend[T, U](b: Backend, input: T)(f: T => U) = { + val originalBackend = backend + + // Transfer input to the new backend. + // TODO: Consider using CPU-GPU shared memory? + transfer(originalBackend, b)(input) + + // Change the backend (i.e. codegen target), then call `f`. + backend = b + val result = f(input) + + // Transfer `result` to the old backend, then reset the backend. + transfer(originalBackend, b)(result) + backend = originalBackend + } + + /** + * Call a function with given inputs, generating code for CPU. + * The inputs and result will be transferred between backends automatically. + */ + def withCPU[T, U](input: T)(f: T => U) = withBackend(new BackendNative, input)(f) + + /** + * Call a function with given inputs, generating code for GPU (via `BackendCudnn`). + * The inputs and result will be transferred between backends automatically. 
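+   *
+   * Hypothetical usage sketch (illustration only; assumes `t1` and `t2` are `Tensor`s
+   * produced on the current backend):
+   *   withGPU((t1, t2)) { case (t1, t2) => t1 dot t2 }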
+ */ + def withGPU[T, U](input: T)(f: T => U) = withBackend(new BackendCudnn, input)(f) + class Tensor(val data: Rep[Array[Float]], val dimensions: NSeq[Int]) extends Serializable { val MAX_DOUBLE = 1e10f // FIXME @@ -578,8 +631,8 @@ trait TensorExp extends Dsl with Diff { @virtualize def sum2D(dim: Int) = { - assert (this.rank == 2, "Only deal with 2D tensor") - assert (dim == 0 || dim == 1, "dim must be in range of this.nbDims") + assert(this.rank == 2, "Only deal with 2D tensor") + assert(dim == 0 || dim == 1, "dim must be in range of this.nbDims") if (dim == 0) ??? else { From b4a5187b4490d61decd55b73eeb552b2ad802abf Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 6 Oct 2018 13:59:52 -0400 Subject: [PATCH 08/10] [NFC] Remove cuBLAS function signatures from comments. These comments bloat the lines of code. Links to the cuBLAS API reference still exist, for each method. --- src/main/scala/lantern/ad_lms_vector.scala | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/src/main/scala/lantern/ad_lms_vector.scala b/src/main/scala/lantern/ad_lms_vector.scala index a57d7705..f80cf133 100644 --- a/src/main/scala/lantern/ad_lms_vector.scala +++ b/src/main/scala/lantern/ad_lms_vector.scala @@ -281,11 +281,6 @@ trait TensorExp extends Dsl with Diff { class BackendCublas extends Backend { // Reference: // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-dot - // - // cublasStatus_t cublasSdot(cublasHandle_t handle, int n, - // const float *x, int incx, - // const float *y, int incy, - // float *result) def sdot(a: Rep[Array[Float]], b: Rep[Array[Float]], result: Rep[Array[Float]]) = unchecked[Unit]("CUBLAS_CALL(cublasSdot(handle, ", a.length, ",", a, ",1,", b, ",1,", result, "))") @@ -297,14 +292,6 @@ trait TensorExp extends Dsl with Diff { // Reference: // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemv - // - // cublasStatus_t cublasSgemv(cublasHandle_t handle, cublasOperation_t trans, - // int m, int n, - // const float *alpha, - // const float *A, int lda, - // const float *x, int incx, - // const float *beta, - // float *y, int incy) def sgemv(m: Int, n: Int, a: Rep[Array[Float]], b: Rep[Array[Float]], result: Rep[Array[Float]]) = { val zero = NewArray[Float](1); zero(0) = 0 val one = NewArray[Float](1); one(0) = 1 @@ -324,15 +311,6 @@ trait TensorExp extends Dsl with Diff { // Reference: // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemm - // - // cublasStatus_t cublasSgemm(cublasHandle_t handle, - // cublasOperation_t transa, cublasOperation_t transb, - // int m, int n, int k, - // const float *alpha, - // const float *A, int lda, - // const float *B, int ldb, - // const float *beta, - // float *C, int ldc) def sgemm(m: Int, n: Int, k: Int, a: Rep[Array[Float]], b: Rep[Array[Float]], result: Rep[Array[Float]]) = { val zero = NewArray[Float](1); zero(0) = 0 val one = NewArray[Float](1); one(0) = 1 From d1a52c6dd6e5f5489a3f5f4b83a0fbbe8d0e2013 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 6 Oct 2018 14:36:54 -0400 Subject: [PATCH 09/10] [NFC] Update comment. --- src/main/scala/lantern/ad_lms_vector.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/scala/lantern/ad_lms_vector.scala b/src/main/scala/lantern/ad_lms_vector.scala index f80cf133..fce84f5a 100644 --- a/src/main/scala/lantern/ad_lms_vector.scala +++ b/src/main/scala/lantern/ad_lms_vector.scala @@ -352,7 +352,8 @@ trait TensorExp extends Dsl with Diff { * @tparam T Type of the data. 
*/ def transfer[T](from: Backend, to: Backend)(data: T) { - // TODO: Implement logic. It will involve `cudaMemcpy`. + // TODO: Implement logic. `cudaMemcpy` will be involved. + // Consider what to do when unified memory is used (i.e. `cudaMallocMananged`). (from, to) match { case (cpu: BackendNative, gpu: BackendCudnn) => ??? case (gpu: BackendCudnn, cpu: BackendNative) => ??? @@ -373,7 +374,6 @@ trait TensorExp extends Dsl with Diff { val originalBackend = backend // Transfer input to the new backend. - // TODO: Consider using CPU-GPU shared memory? transfer(originalBackend, b)(input) // Change the backend (i.e. codegen target), then call `f`. From a6d82fb63c4e52fe11375f39d9d4faac82778cc2 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 6 Oct 2018 16:22:48 -0400 Subject: [PATCH 10/10] Remove `withBackend`. Add `withBackend` in a separate PR for separation of concerns. --- src/main/scala/lantern/ad_lms_vector.scala | 53 ---------------------- 1 file changed, 53 deletions(-) diff --git a/src/main/scala/lantern/ad_lms_vector.scala b/src/main/scala/lantern/ad_lms_vector.scala index fce84f5a..62d8937a 100644 --- a/src/main/scala/lantern/ad_lms_vector.scala +++ b/src/main/scala/lantern/ad_lms_vector.scala @@ -344,59 +344,6 @@ trait TensorExp extends Dsl with Diff { // in your DSL program. var backend: Backend = new BackendNative - /** - * Transfer data between backends. - * @param from The current backend. - * @param to The new backend. - * @param data The data to transfer. - * @tparam T Type of the data. - */ - def transfer[T](from: Backend, to: Backend)(data: T) { - // TODO: Implement logic. `cudaMemcpy` will be involved. - // Consider what to do when unified memory is used (i.e. `cudaMallocMananged`). - (from, to) match { - case (cpu: BackendNative, gpu: BackendCudnn) => ??? - case (gpu: BackendCudnn, cpu: BackendNative) => ??? - case _ => ??? - } - } - - /** - * Call a function with given inputs, generating code for the specified backend. - * The inputs and result will be transferred between backends automatically. - * @param b The new backend. - * @param input The input to the function. - * @param f The function to call, whose code will be generated on the new backend. - * @tparam T The function input type. - * @tparam U The function output type. - */ - def withBackend[T, U](b: Backend, input: T)(f: T => U) = { - val originalBackend = backend - - // Transfer input to the new backend. - transfer(originalBackend, b)(input) - - // Change the backend (i.e. codegen target), then call `f`. - backend = b - val result = f(input) - - // Transfer `result` to the old backend, then reset the backend. - transfer(originalBackend, b)(result) - backend = originalBackend - } - - /** - * Call a function with given inputs, generating code for CPU. - * The inputs and result will be transferred between backends automatically. - */ - def withCPU[T, U](input: T)(f: T => U) = withBackend(new BackendNative, input)(f) - - /** - * Call a function with given inputs, generating code for GPU (via `BackendCudnn`). - * The inputs and result will be transferred between backends automatically. - */ - def withGPU[T, U](input: T)(f: T => U) = withBackend(new BackendCudnn, input)(f) - class Tensor(val data: Rep[Array[Float]], val dimensions: NSeq[Int]) extends Serializable { val MAX_DOUBLE = 1e10f // FIXME