-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
Tensor.assertEqual
error messages, gardening.
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1556,14 +1556,15 @@ trait TensorExp extends Dsl with Diff { | |
|
||
@virtualize | ||
def assertEqual(a: Tensor, b: Tensor, mark: String = "", tal: Float = 0.000001f) = { | ||
assert(a.shape == b.shape, s"ERROR: $mark not equal in dimensionsi ${a.shape.seq} != ${b.shape.seq}\\n") | ||
val errorPrefix = if (mark.nonEmpty) s"ERROR ($mark)" else "ERROR" | ||
assert(a.shape == b.shape, s"$errorPrefix: tensor shapes are not equal, ${a.shape.seq} != ${b.shape.seq}\\n") | ||
|
||
val i = var_new(0) | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
GSAir
Collaborator
|
||
while (i < a.scalarCount && { val diff = a.data(i) - b.data(i); diff > -tal && diff < tal }) { | ||
i += 1 | ||
} | ||
if (i < a.scalarCount) { | ||
printf("ERROR: %s not equal in some data - %.4f != %.4f (%d)\\n", mark, a.data(i), b.data(i), i) | ||
printf("%s: tensor data are not equal at index %d, %.4f != %.4f\\n", errorPrefix, i, a.data(i), b.data(i)) | ||
error("") | ||
} | ||
} | ||
|
I believe the "assert data are equal" diagnostic would be more clear and useful if it printed the entire contents of both tensors, rather than the first unequal scalar.
I suppose this requires a nice way to print staged arrays (and hence tensors).
Is "calling
printf
on all scalars in a loop" the best way to implement array printing? @feiwang3311 @GSAirOne concern is that, for large tensors, printing the entire contents is not desirable because it floods the screen. I'd argue this is a non-issue because
assertEqual
is mostly used for testing purposes on small tensors.