[NFC] Improve Tensor.assertEqual error messages, gardening.
dan-zheng committed Oct 4, 2018
1 parent 3ab4472 commit 53fa930
Showing 2 changed files with 7 additions and 6 deletions.
5 changes: 3 additions & 2 deletions src/main/scala/lantern/ad_lms_vector.scala

@@ -1556,14 +1556,15 @@ trait TensorExp extends Dsl with Diff {

   @virtualize
   def assertEqual(a: Tensor, b: Tensor, mark: String = "", tal: Float = 0.000001f) = {
-    assert(a.shape == b.shape, s"ERROR: $mark not equal in dimensionsi ${a.shape.seq} != ${b.shape.seq}\\n")
+    val errorPrefix = if (mark.nonEmpty) s"ERROR ($mark)" else "ERROR"
+    assert(a.shape == b.shape, s"$errorPrefix: tensor shapes are not equal, ${a.shape.seq} != ${b.shape.seq}\\n")

     val i = var_new(0)

dan-zheng (Author, Collaborator) commented on Oct 4, 2018:

I believe the "assert data are equal" diagnostic would be clearer and more useful if it printed the entire contents of both tensors, rather than just the first unequal scalar.

I suppose this requires a nice way to print staged arrays (and hence tensors).
Is "calling printf on all scalars in a loop" the best way to implement array printing? @feiwang3311 @GSAir

One concern is that, for large tensors, printing the entire contents is not desirable because it floods the screen. I'd argue this is a non-issue because assertEqual is mostly used for testing purposes on small tensors.

GSAir (Collaborator) commented on Oct 4, 2018:

There is a print function in Tensor that does some pretty printing automatically if it is a 3D or a 4D tensor; otherwise it uses a printRaw method.

Since we are generating C, there isn't much better than printf in a loop :)
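For illustration, here is a printRaw-style loop in plain (unstaged) Scala. This sketch is hypothetical: the real Lantern printRaw stages its printf calls into the generated C, whereas this version just builds a string.

```scala
// Hypothetical, unstaged sketch of "printf in a loop" array printing.
// The real Lantern printRaw emits these prints into generated C code.
object PrintRawSketch {
  def printRaw(data: Array[Float], perLine: Int = 8): String =
    data.grouped(perLine)                               // break into rows
      .map(row => row.map(x => f"$x%.4f").mkString(" ")) // format each scalar
      .mkString("\n")

  def main(args: Array[String]): Unit =
    println(printRaw(Array(0f, 1f, 2f, 3f), perLine = 2))
}
```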

dan-zheng (Author, Collaborator) commented on Oct 4, 2018:

Nice!

One idea to generalize printing for n-d tensors is implementing "element tensor indexing":

// Rewrite `def apply(i: Rep[Int]) = data(i)` to return element tensors,
// rather than indexing into data.

val matrix = Tensor.fromData(Seq(3, 4), 0, 1, 2, ..., 11)
// matrix: [[ 0,  1,  2,  3],
//          [ 4,  5,  6,  7],
//          [ 8,  9, 10, 11]]

val first = matrix(0) // The first element tensor in `matrix`: [0, 1, 2, 3].
matrix(1) // The second element tensor: [4, 5, 6, 7].

matrix.data(0) // It's still possible to directly index into scalars.

Then, print can be defined as:

def print(separator: String = ", ") = {
  if (isScalar) printf("%f", this.data(0))
  else {
    // Recurse over element tensors.
    printf("[")
    for (i <- 0 until this.shape(0): Rep[Range]) {
      this(i).print(separator)
      printf(separator)
    }
    printf("]")
  }
}
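A plain (unstaged) Scala analogue of this recursive scheme may make the idea concrete. The names here are hypothetical and this is a sketch, not Lantern code:

```scala
// Hypothetical unstaged sketch: recursively render a flat row-major array
// as a nested tensor, given its shape.
object TensorPrintSketch {
  def render(data: Array[Float], shape: Seq[Int]): String =
    if (shape.isEmpty) f"${data(0)}%.1f"   // scalar case
    else {
      val subSize = shape.tail.product     // scalars per element tensor
      (0 until shape.head)
        .map(i => render(data.slice(i * subSize, (i + 1) * subSize), shape.tail))
        .mkString("[", ", ", "]")          // separator only between elements
    }

  def main(args: Array[String]): Unit =
    println(render(Array.tabulate(12)(_.toFloat), Seq(3, 4)))
}
```

One small difference: mkString places the separator only between elements, whereas the staged loop above would also emit one after the last element tensor.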

GSAir (Collaborator) commented on Oct 4, 2018:

Yes, the apply function shouldn't be used that way anyway. But if you change it, we have to make sure that it is in fact never used that way.

I was actually talking about that with Fei at lunch. We should create a Dimensions object that handles all the arithmetic on the strides/dimensions. It will make the extraction of sub-tensors easier and the code cleaner.
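A minimal sketch of what such a Dimensions object could look like (the name comes from the comment above; the API is hypothetical):

```scala
// Hypothetical sketch: a Dimensions helper that owns all stride arithmetic.
case class Dimensions(dims: Seq[Int]) {
  val scalarCount: Int = dims.product
  // Row-major strides: strides(i) = product of all dims after position i.
  val strides: Seq[Int] = dims.scanRight(1)(_ * _).tail

  // Flat offset of a multi-dimensional index.
  def offset(index: Seq[Int]): Int =
    index.zip(strides).map { case (i, s) => i * s }.sum

  // Shape of an element tensor obtained by indexing the first axis.
  def tail: Dimensions = Dimensions(dims.tail)
}

object DimensionsSketch {
  def main(args: Array[String]): Unit = {
    val d = Dimensions(Seq(3, 4, 5))
    println(d.strides)              // strides of a 3x4x5 tensor
    println(d.offset(Seq(1, 2, 3))) // flat index of element (1, 2, 3)
  }
}
```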

dan-zheng (Author, Collaborator) commented on Oct 4, 2018:

Nice! Swift for TensorFlow follows a similar design: "dimensions" are represented by a TensorShape struct.

I think we should support two types of (basic, non-strided) indexing:

  • tensor(i): returns an element tensor, whose shape is tensor.shape.tail.
  • tensor(range): returns a tensor slice, whose shape is [range.length] + tensor.shape.tail.
    • This is useful for minibatching.

Element tensors and slices should probably be cost-free views into the data of their parent tensor. Otherwise, minibatching will be highly inefficient. We can implement copy-on-write semantics.
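A hedged sketch of how such cost-free views might be structured over a shared data array (all names hypothetical; this ignores copy-on-write and staging):

```scala
// Hypothetical sketch: element tensors and slices as views that share the
// parent's data, tracking only an offset and a shape. No data is copied.
class TensorView(val data: Array[Float], val offset: Int, val shape: Seq[Int]) {
  private def elemSize: Int = shape.tail.product // scalars per element tensor

  // tensor(i): element tensor with shape = shape.tail.
  def apply(i: Int): TensorView =
    new TensorView(data, offset + i * elemSize, shape.tail)

  // tensor(range): slice with shape = [range.length] + shape.tail.
  def apply(r: Range): TensorView =
    new TensorView(data, offset + r.start * elemSize, r.length +: shape.tail)

  def scalar: Float = data(offset) // read a zero-dimensional view
}

object ViewSketch {
  def main(args: Array[String]): Unit = {
    val matrix = new TensorView(Array.tabulate(12)(_.toFloat), 0, Seq(3, 4))
    println(matrix(1)(2).scalar)     // row 1, column 2 of a 3x4 matrix
    println(matrix(1 until 3).shape) // a 2-row minibatch slice
  }
}
```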

TiarkRompf (Collaborator) commented on Oct 5, 2018:

Makes sense. Should this be a feature ticket?

dan-zheng (Author, Collaborator) commented on Oct 5, 2018:

@TiarkRompf Indexing issue: #15.

     while (i < a.scalarCount && { val diff = a.data(i) - b.data(i); diff > -tal && diff < tal }) {
       i += 1
     }
     if (i < a.scalarCount) {
-      printf("ERROR: %s not equal in some data - %.4f != %.4f (%d)\\n", mark, a.data(i), b.data(i), i)
+      printf("%s: tensor data are not equal at index %d, %.4f != %.4f\\n", errorPrefix, i, a.data(i), b.data(i))
       error("")
     }
   }
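For illustration, the comparison logic in the hunk above, transcribed into plain (unstaged) Scala as a sketch (the staged version instead generates a C while loop):

```scala
// Hypothetical unstaged sketch of assertEqual's tolerance check:
// find the first index where the elementwise difference exceeds tal, if any.
object AssertEqualSketch {
  def firstMismatch(a: Array[Float], b: Array[Float],
                    tal: Float = 0.000001f): Option[Int] = {
    require(a.length == b.length, "tensor shapes are not equal")
    a.indices.find { i =>
      val diff = a(i) - b(i)
      !(diff > -tal && diff < tal) // same condition as the staged while loop
    }
  }

  def main(args: Array[String]): Unit = {
    println(firstMismatch(Array(1f, 2f, 3f), Array(1f, 2f, 3.5f))) // mismatch at 2
    println(firstMismatch(Array(1f, 2f), Array(1f, 2f)))           // no mismatch
  }
}
```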
8 changes: 4 additions & 4 deletions src/test/scala/lantern/test_ad_lms_vector.scala

@@ -1360,7 +1360,7 @@ class AdLMSVectorTest extends FunSuite {
}

   val gene_dir = "/tmp/"
-  def testByRun(snippet: DslDriverC[String, Unit]) = {
+  def runTest(snippet: DslDriverC[String, Unit]) = {
     val test = new PrintWriter(new File("/tmp/snippet.cpp"))
     test.println(snippet.code)
     test.flush()
@@ -1398,7 +1398,7 @@ class AdLMSVectorTest extends FunSuite {
debug_file.flush()

     // test runtime assertion of the generated file
-    testByRun(deb)
+    runTest(deb)
}

test("op_conv_pad") {
@@ -1425,7 +1425,7 @@ class AdLMSVectorTest extends FunSuite {
debug_file.println(deb.code)
debug_file.flush()

-    testByRun(deb)
+    runTest(deb)
}

test("backprop_op_conv") {
@@ -1460,7 +1460,7 @@ class AdLMSVectorTest extends FunSuite {
debug_file.println(deb.code)
debug_file.flush()

-    testByRun(deb)
+    runTest(deb)
}

test("backprop_op_conv_pad") {
