Skip to content

Commit

Permalink
untest resnet (too expensive for regular test)
Browse files Browse the repository at this point in the history
  • Loading branch information
feiwang3311 committed Oct 10, 2018
1 parent 1d2394c commit 536e48d
Showing 1 changed file with 61 additions and 61 deletions.
122 changes: 61 additions & 61 deletions src/test/scala/lantern/test_onnx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -508,65 +508,65 @@ class ONNXTest extends LanternFunSuite {
runTest(training_func)
}

test("resnet_inference") {

val model_file = model_file_all("resnet50")
val model_dir = model_dir_all("resnet50")
System.out.println(s"testing reading ONNX model using library from $model_file")

val inference_func = new LanternDriverC[String, Unit] with ONNXLib {

@virtualize
def snippet(a: Rep[String]): Rep[Unit] = {
val model = readONNX(model_file)
val (func, x_dims) = (model.inference_func, model.x_dims)

// get test data as TensorProto
val input_file = model_dir + "test_data_set_0/input_0.pb"
val output_file = model_dir + "test_data_set_0/output_0.pb"
val input = readTensor(input_file).tensor
val output = readTensor(output_file).tensor
val output1 = func(input)
Tensor.assertEqual(output, output1)
}
}
val resnet_file = new PrintWriter(new File(gene_dir + "resnet.cpp"))
resnet_file.println(inference_func.code)
resnet_file.flush()
runTest(inference_func)
}

test("resnet_training") {

val model_file = model_file_all("resnet50")
val model_dir = model_dir_all("resnet50")
System.out.println(s"testing reading ONNX model using library from $model_file for training")

val training_func = new LanternDriverC[String, Unit] with ONNXLib {

@virtualize
def snippet(a: Rep[String]): Rep[Unit] = {
val model = readONNX(model_file)
val (func, x_dims, y_dims) = (model.training_func, model.x_dims, model.y_dims)

// fake input and target
val input_file = model_dir + "test_data_set_0/input_0.pb"
val output_file = model_dir + "test_data_set_0/output_0.pb"
val input = readTensor(input_file).tensor
val output = readTensor(output_file).tensor

val target = NewArray[Int](x_dims(0))
for (i <- DataLoop(x_dims(0))) target(i) = 1
def lossFun(dummy: TensorR) = func(TensorR(input)).nllLossB(target).sum()

val loss = gradR_loss(lossFun)(Tensor.zeros(1))
println(loss.data(0))
}
}

val resnet_file = new PrintWriter(new File(gene_dir + "resnetTraining.cpp"))
resnet_file.println(training_func.code)
resnet_file.flush()
// runTest(training_func)
}
// test("resnet_inference") {

// val model_file = model_file_all("resnet50")
// val model_dir = model_dir_all("resnet50")
// System.out.println(s"testing reading ONNX model using library from $model_file")

// val inference_func = new LanternDriverC[String, Unit] with ONNXLib {

// @virtualize
// def snippet(a: Rep[String]): Rep[Unit] = {
// val model = readONNX(model_file)
// val (func, x_dims) = (model.inference_func, model.x_dims)

// // get test data as TensorProto
// val input_file = model_dir + "test_data_set_0/input_0.pb"
// val output_file = model_dir + "test_data_set_0/output_0.pb"
// val input = readTensor(input_file).tensor
// val output = readTensor(output_file).tensor
// val output1 = func(input)
// Tensor.assertEqual(output, output1)
// }
// }
// val resnet_file = new PrintWriter(new File(gene_dir + "resnet.cpp"))
// resnet_file.println(inference_func.code)
// resnet_file.flush()
// runTest(inference_func)
// }

// test("resnet_training") {

// val model_file = model_file_all("resnet50")
// val model_dir = model_dir_all("resnet50")
// System.out.println(s"testing reading ONNX model using library from $model_file for training")

// val training_func = new LanternDriverC[String, Unit] with ONNXLib {

// @virtualize
// def snippet(a: Rep[String]): Rep[Unit] = {
// val model = readONNX(model_file)
// val (func, x_dims, y_dims) = (model.training_func, model.x_dims, model.y_dims)

// // fake input and target
// val input_file = model_dir + "test_data_set_0/input_0.pb"
// val output_file = model_dir + "test_data_set_0/output_0.pb"
// val input = readTensor(input_file).tensor
// val output = readTensor(output_file).tensor

// val target = NewArray[Int](x_dims(0))
// for (i <- DataLoop(x_dims(0))) target(i) = 1
// def lossFun(dummy: TensorR) = func(TensorR(input)).nllLossB(target).sum()

// val loss = gradR_loss(lossFun)(Tensor.zeros(1))
// println(loss.data(0))
// }
// }

// val resnet_file = new PrintWriter(new File(gene_dir + "resnetTraining.cpp"))
// resnet_file.println(training_func.code)
// resnet_file.flush()
// // runTest(training_func)
// }
}

0 comments on commit 536e48d

Please sign in to comment.