Support Half/BFloat16 in softmax
Partial fix for #7748.

ghstack-source-id: bf24332b014e771a1df08c3a739c628b3140013a
ghstack-comment-id: 2608615985
Pull Request resolved: #7867
swolchok committed Jan 23, 2025
1 parent c28ae4a commit e21d862
Showing 2 changed files with 52 additions and 42 deletions.
kernels/portable/cpu/op_softmax.cpp (39 additions, 38 deletions)
@@ -42,47 +42,48 @@ Tensor& softmax_out(
   // Adjust for negative dim
   dim = dim < 0 ? dim + nonzero_dim(in) : dim;
 
-  ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() {
-    const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
-    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
+  ET_SWITCH_FLOATHBF16_TYPES(
+      in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() {
+        const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
+        CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
 
-    apply_over_dim(
-        [in_data, out_data](
-            const size_t size, const size_t stride, const size_t base) {
-          // calculate max in softmax dim. During softmax computation each
-          // value is subtracted by the maximum in value before calling exp
-          // to preserve numerical stability.
-          const CTYPE max_in = apply_unary_reduce_fn(
-              [](const CTYPE val_in, CTYPE val_accum) {
-                return std::max(val_in, val_accum);
-              },
-              in_data + base,
-              size,
-              stride);
+        apply_over_dim(
+            [in_data, out_data](
+                const size_t size, const size_t stride, const size_t base) {
+              // calculate max in softmax dim. During softmax computation each
+              // value is subtracted by the maximum in value before calling exp
+              // to preserve numerical stability.
+              const CTYPE max_in = apply_unary_reduce_fn(
+                  [](const CTYPE val_in, CTYPE val_accum) {
+                    return std::max(val_in, val_accum);
+                  },
+                  in_data + base,
+                  size,
+                  stride);
 
-          const CTYPE temp_sum = apply_unary_map_reduce_fn<CTYPE, CTYPE>(
-              [max_in](const CTYPE val_in) {
-                return std::exp(val_in - max_in);
-              },
-              [](const CTYPE mapped_in, CTYPE val_accum) {
-                return val_accum + mapped_in;
-              },
-              in_data + base,
-              size,
-              stride);
+              const CTYPE temp_sum = apply_unary_map_reduce_fn<CTYPE, CTYPE>(
+                  [max_in](const CTYPE val_in) {
+                    return std::exp(val_in - max_in);
+                  },
+                  [](const CTYPE mapped_in, CTYPE val_accum) {
+                    return val_accum + mapped_in;
+                  },
+                  in_data + base,
+                  size,
+                  stride);
 
-          apply_unary_map_fn(
-              [max_in, temp_sum](const CTYPE val_in) {
-                return std::exp(val_in - max_in) / temp_sum;
-              },
-              in_data + base,
-              out_data + base,
-              size,
-              stride);
-        },
-        in,
-        dim);
-  });
+              apply_unary_map_fn(
+                  [max_in, temp_sum](const CTYPE val_in) {
+                    return std::exp(val_in - max_in) / temp_sum;
+                  },
+                  in_data + base,
+                  out_data + base,
+                  size,
+                  stride);
+            },
+            in,
+            dim);
+      });
 
   return out;
 }
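For reference, the three passes in the kernel above (running max, sum of shifted exponentials, normalization) compute softmax in its numerically stable form:

$$\operatorname{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j.$$

Shifting by $m$ leaves the result unchanged mathematically, but it keeps every exponent at or below zero, so $e^{x_i - m} \in (0, 1]$ and exp can never overflow. That matters all the more for the narrow-range Half and BFloat16 types this commit enables.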
kernels/test/op_softmax_test.cpp (13 additions, 4 deletions)
@@ -61,7 +61,15 @@ class OpSoftmaxOutTest : public OperatorTest {
     });
     // clang-format on
 
-    EXPECT_TENSOR_CLOSE(out, expected);
+    if (DTYPE == ScalarType::BFloat16) {
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out,
+          expected,
+          1e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+    } else {
+      EXPECT_TENSOR_CLOSE(out, expected);
+    }
   }
 };
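The looser 1e-2 relative tolerance for BFloat16 reflects the format itself: bfloat16 keeps only 8 significand bits, so a single round-to-nearest step can carry a relative error of up to about 2^-8 ≈ 4e-3, and softmax performs several such roundings per element (max, exp, sum, divide). Below is a minimal standalone sketch of that rounding error, independent of the ExecuTorch test harness; round_to_bf16 is a hypothetical helper written for illustration, not an ExecuTorch API.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical helper for illustration only: round a finite float to
// the nearest bfloat16 value (round-to-nearest-even on the top 16
// bits) and widen it back to float.
float round_to_bf16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  // Standard bias trick: add 0x7FFF plus the lowest kept bit, then
  // truncate the low 16 bits; ties round to even.
  bits += 0x7FFF + ((bits >> 16) & 1);
  bits &= 0xFFFF0000u;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  // exp(-1) rounded to bfloat16: the relative error is already on the
  // order of 1e-3, so a 1e-2 rtol leaves headroom for the additional
  // roundings done by the max/sum/divide passes of softmax.
  float exact = std::exp(-1.0f);
  float bf = round_to_bf16(exact);
  std::printf(
      "exact=%.8f bf16=%.8f rel_err=%.2e\n",
      exact,
      bf,
      std::fabs(bf - exact) / exact);
}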

@@ -100,9 +108,10 @@ TEST_F(OpSoftmaxOutTest, HalfSupport) {
 }
 
 TEST_F(OpSoftmaxOutTest, AllDtypesSupported) {
-  test_dtype<float, ScalarType::Float>();
-  test_dtype<double, ScalarType::Double>();
-  // TODO: Also add tests for half, complex, quantized, and other types. Easiest
+#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+  // TODO: Also add tests for complex, quantized, and other types. Easiest
   // way to do that would be to make TensorFactory support zeros() and ones()
   // for those types.
 }
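For context on the new test loop: TEST_ENTRY plus ET_FORALL_FLOATHBF16_TYPES stamps out one test_dtype instantiation per floating-point dtype. Conceptually the invocation expands to something like the sketch below; the exact type list and the C++ type spellings live in the ExecuTorch headers, and the ones here are assumptions for illustration.

// Approximate expansion of ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY).
// Type spellings are assumed; consult the ExecuTorch headers for the
// authoritative list and order.
test_dtype<double, ScalarType::Double>();
test_dtype<float, ScalarType::Float>();
test_dtype<executorch::aten::Half, ScalarType::Half>();
test_dtype<executorch::aten::BFloat16, ScalarType::BFloat16>();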
