Skip to content

Commit

Permalink
Support Half/BFloat16 in op_allclose
Browse files Browse the repository at this point in the history
Differential Revision: D68366831

Pull Request resolved: #7766
  • Loading branch information
swolchok authored Jan 24, 2025
1 parent 43580f5 commit 99489fe
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 220 deletions.
14 changes: 14 additions & 0 deletions kernels/portable/cpu/op_allclose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,20 @@ bool tensors_are_close(
a.numel(),
rtol,
atol);
} else if (a.scalar_type() == ScalarType::Half) {
return data_is_close<Half>(
a.const_data_ptr<Half>(),
b.const_data_ptr<Half>(),
a.numel(),
rtol,
atol);
} else if (a.scalar_type() == ScalarType::BFloat16) {
return data_is_close<BFloat16>(
a.const_data_ptr<BFloat16>(),
b.const_data_ptr<BFloat16>(),
a.numel(),
rtol,
atol);
} else {
// Non-floating-point types can be compared bitwise.
return memcmp(a.mutable_data_ptr(), b.mutable_data_ptr(), a.nbytes()) == 0;
Expand Down
Loading

0 comments on commit 99489fe

Please sign in to comment.