Skip to content

Commit

Permalink
Clean up Seq usages.
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-zheng authored and feiwang3311 committed Oct 13, 2018
1 parent 15400e5 commit 59acc20
Showing 1 changed file with 20 additions and 23 deletions.
43 changes: 20 additions & 23 deletions src/test/scala/lantern/TestTensorDifferentiation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1344,19 +1344,18 @@ class AdLMSVectorTest extends LanternFunSuite {
test("op_conv") {

val deb = new LanternDriverC[String, Unit] {
import scala.collection.Seq;

@virtualize
def snippet(a: Rep[String]): Rep[Unit] = {
val input = Tensor.ones(1, 3, 8, 8)
val kernel = Tensor.ones(1, 3, 3, 3)
val bias = Tensor.ones(1)
val strides: Seq[Int] = List(2, 2).toSeq
val pads: Seq[Int] = List(0,0,0,0).toSeq
val strides = Seq(2, 2)
val pads = Seq(0,0,0,0)
val output = input.conv2D_batch(kernel, Some(bias), strides, pads)

// assert equal
val expect = Tensor.fromData(scala.collection.Seq(1,1,3,3), 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f)
val expect = Tensor.fromData(Seq(1,1,3,3), 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f)
Tensor.assertEqual(expect, output, "expect and output are")
}
}
Expand All @@ -1372,19 +1371,18 @@ class AdLMSVectorTest extends LanternFunSuite {
test("op_conv_pad") {

val deb = new LanternDriverC[String, Unit] {
import scala.collection.Seq;

@virtualize
def snippet(a: Rep[String]): Rep[Unit] = {
val input = Tensor.ones(1, 1, 4, 4)
val kernel = Tensor.ones(1, 1, 3, 3)
val bias = Tensor.zeros(1)
val strides: Seq[Int] = List(3, 3).toSeq
val pads: Seq[Int] = List(1, 1, 1, 1).toSeq
val strides = Seq(3, 3)
val pads = Seq(1, 1, 1, 1)
val output = input.conv2D_batch(kernel, Some(bias), strides, pads)

// assert equal
val expect = Tensor.fromData(scala.collection.Seq(1,1,2,2), 4.0f, 4.0f, 4.0f, 4.0f)
val expect = Tensor.fromData(Seq(1,1,2,2), 4.0f, 4.0f, 4.0f, 4.0f)
Tensor.assertEqual(expect, output, "expect and output are")
}
}
Expand All @@ -1398,17 +1396,16 @@ class AdLMSVectorTest extends LanternFunSuite {

test("op_conv_pad_nobias") {
val deb = new LanternDriverC[String, Unit] {
import scala.collection.Seq;

@virtualize
def snippet(a: Rep[String]): Rep[Unit] = {
val input = Tensor.ones(1, 1, 4, 4)
val kernel = Tensor.ones(1, 1, 3, 3)
val strides: Seq[Int] = List(3, 3).toSeq
val pads: Seq[Int] = List(1, 1, 1, 1).toSeq
val strides = Seq(3, 3)
val pads = Seq(1, 1, 1, 1)
val output = input.conv2D_batch(kernel, None, strides, pads)

val expect = Tensor.fromData(scala.collection.Seq(1,1,2,2), 4.0f, 4.0f, 4.0f, 4.0f)
val expect = Tensor.fromData(Seq(1,1,2,2), 4.0f, 4.0f, 4.0f, 4.0f)
Tensor.assertEqual(expect, output, "expect and output are")
}
}
Expand All @@ -1428,8 +1425,8 @@ class AdLMSVectorTest extends LanternFunSuite {
val input = TensorR(Tensor.ones(1,1,4,4))
val kernel = TensorR(Tensor.ones(1,1,3,3))
val bias = TensorR(Tensor.zeros(1))
val strides: scala.collection.Seq[Int] = List(1,1).toSeq
val pads: scala.collection.Seq[Int] = List(0,0,0,0).toSeq
val strides = Seq(1,1)
val pads = Seq(0,0,0,0)

def lossFun(x: TensorR) = {
val output = input.convBBP(kernel, Some(bias), strides, pads)
Expand All @@ -1439,10 +1436,10 @@ class AdLMSVectorTest extends LanternFunSuite {
gradR_loss(lossFun)(Tensor.zeros(1))

// assert equal
val expect_input_grad = Tensor.fromData(scala.collection.Seq(1,1,4,4),
val expect_input_grad = Tensor.fromData(Seq(1,1,4,4),
1.0f, 2.0f, 2.0f, 1.0f, 2.0f, 4.0f, 4.0f, 2.0f, 2.0f, 4.0f, 4.0f, 2.0f, 1.0f, 2.0f, 2.0f, 1.0f)
val expect_kernel_grad = Tensor.fill(Seq(1, 1, 3, 3), 4.0f)
val expect_bias_grad = Tensor.fromData(scala.collection.Seq(1), 4.0f)
val expect_bias_grad = Tensor.fromData(Seq(1), 4.0f)
Tensor.assertEqual(expect_input_grad * 2.0f, input.d, "expect and input.gradient are")
Tensor.assertEqual(expect_kernel_grad * 2.0f, kernel.d, "expect and kernel.gradient are")
Tensor.assertEqual(expect_bias_grad * 2.0f, bias.d, "expect and bias.gradient are")
Expand All @@ -1459,8 +1456,8 @@ class AdLMSVectorTest extends LanternFunSuite {
val input = TensorR(Tensor.ones(1,1,4,4))
val kernel = TensorR(Tensor.ones(1,1,3,3))
val bias = TensorR(Tensor.zeros(1))
val strides: scala.collection.Seq[Int] = List(3,3).toSeq
val pads: scala.collection.Seq[Int] = List(1,1,1,1).toSeq
val strides = Seq(3,3)
val pads = Seq(1,1,1,1)

def lossFun(x: TensorR) = {
val output = input.convBBP(kernel, Some(bias), strides, pads)
Expand All @@ -1469,11 +1466,11 @@ class AdLMSVectorTest extends LanternFunSuite {
val loss = gradR_loss(lossFun)(Tensor.zeros(1))

// assert equal
val expect_input_grad = Tensor.fromData(scala.collection.Seq(1,1,4,4),
val expect_input_grad = Tensor.fromData(Seq(1,1,4,4),
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f)
val expect_kernel_grad = Tensor.fromData(scala.collection.Seq(1,1,3,3),
val expect_kernel_grad = Tensor.fromData(Seq(1,1,3,3),
1.0f, 2.0f, 1.0f, 2.0f, 4.0f, 2.0f, 1.0f, 2.0f, 1.0f)
val expect_bias_grad = Tensor.fromData(scala.collection.Seq(1), 4.0f)
val expect_bias_grad = Tensor.fromData(Seq(1), 4.0f)
Tensor.assertEqual(expect_input_grad, input.d, "expect and input.gradient are")
Tensor.assertEqual(expect_kernel_grad, kernel.d, "expect and kernel.gradient are")
Tensor.assertEqual(expect_bias_grad, bias.d, "expect and bias.gradient are")
Expand All @@ -1492,7 +1489,7 @@ class AdLMSVectorTest extends LanternFunSuite {
def snippet(a: Rep[String]): Rep[Unit] = {
val input = TensorR(Tensor.ones(1,1,4,4))
def lossFun(x: TensorR) = {
input.averagePoolBK(List(2, 2).toSeq, List(2, 2).toSeq, None).sum()
input.averagePoolBK(Seq(2, 2), Seq(2, 2), None).sum()
}
gradR_loss(lossFun)(Tensor.zeros(1))
// assert equal
Expand All @@ -1501,7 +1498,7 @@ class AdLMSVectorTest extends LanternFunSuite {

input.clear_grad()
def lossFun2(x: TensorR) = {
input.averagePoolBK(List(2, 2).toSeq, List(1, 1).toSeq, None).sum()
input.averagePoolBK(Seq(2, 2), Seq(1, 1), None).sum()
}
gradR_loss(lossFun2)(Tensor.zeros(1))
// assert equal
Expand Down

0 comments on commit 59acc20

Please sign in to comment.