Skip to content

Commit

Permalink
Support Half/BFloat16 in arange
Browse files Browse the repository at this point in the history
Partial fix for #7748.

ghstack-source-id: c7d2a59a32ebda2eb38b57a0558c4d7a2673fbd0
ghstack-comment-id: 2605368953
Pull Request resolved: #7791
  • Loading branch information
swolchok committed Jan 21, 2025
1 parent 948fba6 commit e14a3f3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions kernels/portable/cpu/op_arange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Tensor& arange_out(KernelRuntimeContext& ctx, const Scalar& end, Tensor& out) {
InvalidArgument,
out);

ET_SWITCH_REAL_TYPES(out.scalar_type(), ctx, "arange.out", CTYPE, [&]() {
ET_SWITCH_REALHBF16_TYPES(out.scalar_type(), ctx, "arange.out", CTYPE, [&]() {
auto out_data = out.mutable_data_ptr<CTYPE>();
for (size_t i = 0; i < size; i++) {
out_data[i] = static_cast<CTYPE>(i);
Expand Down Expand Up @@ -88,7 +88,7 @@ Tensor& arange_start_out(
InvalidArgument,
out);

ET_SWITCH_REAL_TYPES(
ET_SWITCH_REALHBF16_TYPES(
out.scalar_type(), ctx, "arange.start_out", CTYPE, [&]() {
auto out_data = out.mutable_data_ptr<CTYPE>();
for (size_t i = 0; i < size; i++) {
Expand Down
8 changes: 4 additions & 4 deletions kernels/test/op_arange_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ class OpArangeStartOutTest : public OperatorTest {
};

/// A generic smoke test that works for any dtype that supports zeros().
TEST_F(OpArangeOutTest, AllRealDtypesSupported) {
TEST_F(OpArangeOutTest, AllRealHBF16DtypesSupported) {
#define TEST_ENTRY(ctype, dtype) test_arange_dtype<ctype, ScalarType::dtype>();
ET_FORALL_REAL_TYPES(TEST_ENTRY);
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

Expand Down Expand Up @@ -164,10 +164,10 @@ TEST_F(OpArangeOutTest, DynamicShapeUnbound) {
}

/// A generic smoke test that works for any dtype that supports zeros().
TEST_F(OpArangeStartOutTest, AllRealDtypesSupported) {
TEST_F(OpArangeStartOutTest, AllRealHBF16DtypesSupported) {
#define TEST_ENTRY(ctype, dtype) \
test_arange_start_dtype<ctype, ScalarType::dtype>();
ET_FORALL_REAL_TYPES(TEST_ENTRY);
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

Expand Down

0 comments on commit e14a3f3

Please sign in to comment.