Skip to content

Commit

Permalink
Fix matrix-matrix dot assertion, update tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-zheng committed Oct 8, 2018
1 parent 02d773d commit db0a80f
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 11 deletions.
3 changes: 1 addition & 2 deletions src/main/scala/lantern/ad_lms_vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,7 @@ trait TensorExp extends Dsl with Diff {
generate_comment(s"dot: ${this.shape.seq}, ${that.shape.seq}")
(this.rank, that.rank) match {
case (1, 1) => assert(this.shape(0) == that.shape(0), s"Incompatible shapes: ${this.shape}, ${that.shape}")
case (2, 1) => assert(this.shape(1) == that.shape(0), s"Incompatible shapes: ${this.shape}, ${that.shape}")
case (2, 2) => assert(this.shape(0) == that.shape(1), s"Incompatible shapes: ${this.shape}, ${that.shape}")
case (2, 1) | (2, 2) => assert(this.shape(1) == that.shape(0), s"Incompatible shapes: ${this.shape}, ${that.shape}")
case _ => throw new IllegalArgumentException(
s"Only vector-vector, matrix-vector, and matrix-matrix multiplication are allowed (actual shapes: ${this.shape}, ${that.shape})")
}
Expand Down
6 changes: 3 additions & 3 deletions src/test/scala/lantern/TestCublas.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ class TestCublas extends LanternFunSuite {

@virtualize
def snippet(x: Rep[String]): Rep[Unit] = {
// Note: it's better to test with non-square matrices.
// Note: it's better to test with matrices [M1 x M2] and [M2 x M3] where M1 != M3.
val m1 = Tensor.fromData(NSeq(2, 3), 1, 2, 3, 4, 5, 6)
val m2 = Tensor.fromData(NSeq(3, 2), 2, 3, 4, 5, 6, 7)
val expected = Tensor.fromData(NSeq(2, 2), 28, 34, 64, 79)
val m2 = Tensor.fromData(NSeq(3, 1), 2, 3, 4)
val expected = Tensor.fromData(NSeq(2, 1), 20, 47)
Tensor.assertEqual(m1.dot(m2), expected)
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/test/scala/lantern/TestCudnn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ class TestCudnn extends LanternFunSuite {

@virtualize
def snippet(x: Rep[String]): Rep[Unit] = {
// Note: it's better to test with non-square matrices.
// Note: it's better to test with matrices [M1 x M2] and [M2 x M3] where M1 != M3.
val m1 = Tensor.fromData(NSeq(2, 3), 1, 2, 3, 4, 5, 6)
val m2 = Tensor.fromData(NSeq(3, 2), 2, 3, 4, 5, 6, 7)
val expected = Tensor.fromData(NSeq(2, 2), 28, 34, 64, 79)
val m2 = Tensor.fromData(NSeq(3, 1), 2, 3, 4)
val expected = Tensor.fromData(NSeq(2, 1), 20, 47)
Tensor.assertEqual(m1.dot(m2), expected)
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/test/scala/lantern/test_ad_lms_vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ class AdLMSVectorTest extends LanternFunSuite {
val mmdot = new LanternDriverC[String, Unit] {
@virtualize
def snippet(x: Rep[String]): Rep[Unit] = {
// Note: it's better to test with non-square matrices.
// Note: it's better to test with matrices [M1 x M2] and [M2 x M3] where M1 != M3.
val m1 = Tensor.fromData(NSeq(2, 3), 1, 2, 3, 4, 5, 6)
val m2 = Tensor.fromData(NSeq(3, 2), 2, 3, 4, 5, 6, 7)
val expected = Tensor.fromData(NSeq(2, 2), 28, 34, 64, 79)
val m2 = Tensor.fromData(NSeq(3, 1), 2, 3, 4)
val expected = Tensor.fromData(NSeq(2, 1), 20, 47)
Tensor.assertEqual(m1.dot(m2), expected)
}
}
Expand Down

0 comments on commit db0a80f

Please sign in to comment.