Skip to content

Commit

Permalink
partial work of resnet
Browse files Browse the repository at this point in the history
  • Loading branch information
feiwang3311 committed Oct 5, 2018
1 parent 1bd2130 commit bf5709b
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 8 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ cache:

before_script:
- scripts/download_squeezenet.sh
- scripts/download_resnet.sh

script:
- sbt test
Expand Down
1 change: 0 additions & 1 deletion scripts/download_squeezenet.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,3 @@ then
tar xzf $HOME/tmp/squeezenet.tar.gz -C $HOME/onnx_models
rm -f $HOME/tmp/squeezenet.tar.gz
fi

58 changes: 51 additions & 7 deletions src/test/scala/lantern/test_onnx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,24 @@ import java.nio.FloatBuffer;

class ONNXTest extends FunSuite {

val model_file = s"""${sys.env("HOME")}/onnx_models/squeezenet/model.onnx"""
val model_dir = s"""${sys.env("HOME")}/onnx_models/squeezenet/"""
val model_file_all = (name: String) => s"""${sys.env("HOME")}/onnx_models/$name/model.onnx"""
val model_dir_all = (name: String) => s"""${sys.env("HOME")}/onnx_models/$name/"""
val gene_dir = "/tmp/"
def runTest(snippet: DslDriverC[String, Unit]) = {
val test = new PrintWriter(new File("/tmp/snippet.cpp"))
test.println(snippet.code)
test.flush()
new java.io.File("/tmp/snippet").delete
import scala.sys.process._
System.out.println("Compile C++ code")
(s"g++ -std=c++11 -O1 /tmp/snippet.cpp -o /tmp/snippet": ProcessBuilder).lines.foreach(System.out.println)
System.out.println("Run C++ code")
(s"/tmp/snippet a": ProcessBuilder).lines.foreach(System.out.println)
}

test("onnx_reading_basic") {

val model_file = model_file_all("squeezenet")
System.out.println(s"testing reading onnx models from $model_file")
val model = onnx_ml.ModelProto.parseFrom(new FileInputStream(model_file))

Expand Down Expand Up @@ -215,6 +227,7 @@ class ONNXTest extends FunSuite {
@virtualize
def snippet(a: Rep[String]): Rep[Unit] = {

val model_file = model_file_all("squeezenet")
val model = onnx_ml.ModelProto.parseFrom(new FileInputStream(model_file))
val graph = model.getGraph

Expand Down Expand Up @@ -484,16 +497,16 @@ class ONNXTest extends FunSuite {
}
}

println(s"testing reading ONNX models from $model_file")

val squeezenet_file = new PrintWriter(new File(gene_dir + "squeezenet.cpp"))
squeezenet_file.println(squeezenet.code)
squeezenet_file.flush()

}

test("inference") {
test("squeezenet_inference") {

val model_file = model_file_all("squeezenet")
val model_dir = model_dir_all("squeezenet")
System.out.println(s"testing reading ONNX model using library from $model_file")

val inference_func = new DslDriverC[String, Unit] with ONNXLib {
Expand All @@ -510,17 +523,20 @@ class ONNXTest extends FunSuite {
val input = readTensor(input_file).tensor
val output = readTensor(output_file).tensor
val output1 = func(input)
Tensor.assertEqual(output, output1.resize(1, 1000, 1, 1))
}
}

val squeezenet_file = new PrintWriter(new File(gene_dir + "squeezenet.cpp"))
squeezenet_file.println(inference_func.code)
squeezenet_file.flush()

runTest(inference_func)
}

test("training") {
test("squeezenet_training") {

val model_file = model_file_all("squeezenet")
val model_dir = model_dir_all("squeezenet")
System.out.println(s"testing reading ONNX model using library from $model_file for training")

val training_func = new DslDriverC[String, Unit] with ONNXLib {
Expand Down Expand Up @@ -551,4 +567,32 @@ class ONNXTest extends FunSuite {
squeezenet_file.println(training_func.code)
squeezenet_file.flush()
}

test("resnet_inference") {

val model_file = model_file_all("resnet")
val model_dir = model_dir_all("resnet")
System.out.println(s"testing reading ONNX model using library from $model_file")

val inference_func = new DslDriverC[String, Unit] with ONNXLib {

@virtualize
def snippet(a: Rep[String]): Rep[Unit] = {
val model = readONNX(model_file)
val (func, x_dims) = (model.inference_func, model.x_dims)

// get test data as TensorProto
val input_file = model_dir + "test_data_set_0/input_0.pb"
val output_file = model_dir + "test_data_set_0/output_0.pb"
val input = readTensor(input_file).tensor
val output = readTensor(output_file).tensor
val output1 = func(input)
Tensor.assertEqual(output, output1)
}
}
val resnet_file = new PrintWriter(new File(gene_dir + "resnet.cpp"))
resnet_file.println(inference_func.code)
resnet_file.flush()
runTest(inference_func)
}
}

0 comments on commit bf5709b

Please sign in to comment.