diff --git a/.travis.yml b/.travis.yml index 9ffc70a0..808e2a61 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,6 +15,7 @@ cache: before_script: - scripts/download_squeezenet.sh + - scripts/download_resnet.sh script: - sbt test diff --git a/scripts/download_squeezenet.sh b/scripts/download_squeezenet.sh index a6314c13..9e5c840c 100755 --- a/scripts/download_squeezenet.sh +++ b/scripts/download_squeezenet.sh @@ -10,4 +10,3 @@ then tar xzf $HOME/tmp/squeezenet.tar.gz -C $HOME/onnx_models rm -f $HOME/tmp/squeezenet.tar.gz fi - diff --git a/src/test/scala/lantern/test_onnx.scala b/src/test/scala/lantern/test_onnx.scala index ee885536..83f297a1 100644 --- a/src/test/scala/lantern/test_onnx.scala +++ b/src/test/scala/lantern/test_onnx.scala @@ -23,12 +23,24 @@ import java.nio.FloatBuffer; class ONNXTest extends FunSuite { - val model_file = s"""${sys.env("HOME")}/onnx_models/squeezenet/model.onnx""" - val model_dir = s"""${sys.env("HOME")}/onnx_models/squeezenet/""" + val model_file_all = (name: String) => s"""${sys.env("HOME")}/onnx_models/$name/model.onnx""" + val model_dir_all = (name: String) => s"""${sys.env("HOME")}/onnx_models/$name/""" val gene_dir = "/tmp/" + def runTest(snippet: DslDriverC[String, Unit]) = { + val test = new PrintWriter(new File("/tmp/snippet.cpp")) + test.println(snippet.code) + test.flush() + new java.io.File("/tmp/snippet").delete + import scala.sys.process._ + System.out.println("Compile C++ code") + (s"g++ -std=c++11 -O1 /tmp/snippet.cpp -o /tmp/snippet": ProcessBuilder).lines.foreach(System.out.println) + System.out.println("Run C++ code") + (s"/tmp/snippet a": ProcessBuilder).lines.foreach(System.out.println) + } test("onnx_reading_basic") { + val model_file = model_file_all("squeezenet") System.out.println(s"testing reading onnx models from $model_file") val model = onnx_ml.ModelProto.parseFrom(new FileInputStream(model_file)) @@ -215,6 +227,7 @@ class ONNXTest extends FunSuite { @virtualize def snippet(a: Rep[String]): Rep[Unit] = { + val model_file = model_file_all("squeezenet") val model = onnx_ml.ModelProto.parseFrom(new FileInputStream(model_file)) val graph = model.getGraph @@ -484,16 +497,16 @@ class ONNXTest extends FunSuite { } } - println(s"testing reading ONNX models from $model_file") - val squeezenet_file = new PrintWriter(new File(gene_dir + "squeezenet.cpp")) squeezenet_file.println(squeezenet.code) squeezenet_file.flush() } - test("inference") { + test("squeezenet_inference") { + val model_file = model_file_all("squeezenet") + val model_dir = model_dir_all("squeezenet") System.out.println(s"testing reading ONNX model using library from $model_file") val inference_func = new DslDriverC[String, Unit] with ONNXLib { @@ -510,17 +523,20 @@ class ONNXTest extends FunSuite { val input = readTensor(input_file).tensor val output = readTensor(output_file).tensor val output1 = func(input) + Tensor.assertEqual(output, output1.resize(1, 1000, 1, 1)) } } val squeezenet_file = new PrintWriter(new File(gene_dir + "squeezenet.cpp")) squeezenet_file.println(inference_func.code) squeezenet_file.flush() - + runTest(inference_func) } - test("training") { + test("squeezenet_training") { + val model_file = model_file_all("squeezenet") + val model_dir = model_dir_all("squeezenet") System.out.println(s"testing reading ONNX model using library from $model_file for training") val training_func = new DslDriverC[String, Unit] with ONNXLib { @@ -551,4 +567,32 @@ class ONNXTest extends FunSuite { squeezenet_file.println(training_func.code) squeezenet_file.flush() } + + test("resnet_inference") { + + val model_file = model_file_all("resnet") + val model_dir = model_dir_all("resnet") + System.out.println(s"testing reading ONNX model using library from $model_file") + + val inference_func = new DslDriverC[String, Unit] with ONNXLib { + + @virtualize + def snippet(a: Rep[String]): Rep[Unit] = { + val model = readONNX(model_file) + val (func, x_dims) = (model.inference_func, model.x_dims) + + // get test data as TensorProto + val input_file = model_dir + "test_data_set_0/input_0.pb" + val output_file = model_dir + "test_data_set_0/output_0.pb" + val input = readTensor(input_file).tensor + val output = readTensor(output_file).tensor + val output1 = func(input) + Tensor.assertEqual(output, output1) + } + } + val resnet_file = new PrintWriter(new File(gene_dir + "resnet.cpp")) + resnet_file.println(inference_func.code) + resnet_file.flush() + runTest(inference_func) + } } \ No newline at end of file