Skip to content

Commit

Permalink
Add ErrorSpec::skip_comparison for exhaustive tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657777584
  • Loading branch information
Gregory Pataky authored and copybara-github committed Jul 31, 2024
1 parent a4f6b06 commit 2cea900
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
28 changes: 28 additions & 0 deletions xla/tests/exhaustive/exhaustive_op_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,22 @@ std::string StringifyNum(const std::array<NativeT, N>& inputs) {
return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")");
}

template <typename ErrorGenerator>
void PrintSkipped(int64_t* skipped, const ErrorGenerator& err_generator) {
// We send some fixed amount of skipped messages to the log. The remainder we
// squelch unless we're at vlog level 2.
constexpr int64_t kMaxMismatchesLoggedToErr = 1000;

(*skipped)++;
if (*skipped < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) {
LOG(WARNING) << err_generator();
} else if (*skipped == kMaxMismatchesLoggedToErr) {
LOG(WARNING) << "Not printing any more skipped messages; pass "
"--vmodule=exhaustive_op_test=2 to see "
"all of them.";
}
}

template <typename ErrorGenerator>
void PrintMismatch(int64_t* mismatches, const ErrorGenerator& err_generator) {
// We send a few mismatches to gunit so they show up nicely in test logs.
Expand All @@ -366,6 +382,7 @@ void PrintMismatch(int64_t* mismatches, const ErrorGenerator& err_generator) {
"all of them.";
}
}

} // namespace

template <PrimitiveType T, size_t N>
Expand Down Expand Up @@ -400,6 +417,7 @@ void ExhaustiveOpTestBase<T, N>::ExpectNear(

absl::Span<const NativeT> result_arr = result_literal.data<NativeT>();

int64_t skipped = 0;
int64_t mismatches = 0;

for (int64_t i = 0; i < result_arr.size(); ++i) {
Expand All @@ -416,6 +434,16 @@ void ExhaustiveOpTestBase<T, N>::ExpectNear(
static_cast<NativeT>(CallOperation(evaluate_op, inputs_ref_ty));
ErrorSpec error_spec = CallErrorSpec(error_spec_gen, inputs);

if (error_spec.skip_comparison) {
PrintSkipped(&skipped, [&] {
return absl::StrFormat(
"skipping tolerance check for input %s due to "
"ErrorSpec::skip_comparison",
StringifyNum<NativeT, ComponentIntegralNativeT, N>(inputs));
});
continue;
}

if (check_valid_range != nullptr && !check_valid_range(inputs, actual)) {
PrintMismatch(&mismatches, [&] {
return absl::StrFormat(
Expand Down
7 changes: 7 additions & 0 deletions xla/tests/exhaustive/exhaustive_op_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ struct ErrorSpec {
// spec; this only covers the case when both `expected` and `actual` are
// equal to 0.
bool strict_signed_zeros = false;
// If true, this will skip comparing the output of the test to the expected
// value. This should be used only as a last resort, since it is effectively
// turning off the test for a specific input value set.
bool skip_comparison = false;
};

// Representations of the reference function passed in by the user.
Expand Down Expand Up @@ -617,6 +621,9 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
// Testing will ignore inputs for which known_incorrect_fn_ returns true.
// The argument to the function is the raw bits for the data being test,
// zero extended to 64 bits if the data type is less than 64 bits.
//
// DEPRECATED: Please see ErrorSpec::skip_comparison for an easier framework
// to skip nearness checks for certain unary or binary inputs.
std::function<bool(int64_t)> known_incorrect_fn_;

// If true, allows denormals to be flushed to non-sign-preserving 0.
Expand Down

0 comments on commit 2cea900

Please sign in to comment.