Skip to content

Commit

Permalink
feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
iMartyan committed Dec 23, 2024
1 parent ee9070f commit cb2f7f6
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions include/oneapi/math/rng/device/detail/geometric_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,17 @@ class distribution_base<oneapi::math::rng::device::geometric<IntType, Method>> {
std::is_same_v<IntType, std::int64_t>,
double, float>::type;

auto uni_res = engine.generate(FpType(0.0), FpType(1.0));
FpType uni_res = engine.generate(FpType(0.0), FpType(1.0));
FpType inv_ln = ln_wrapper(FpType(1.0) - p_);
inv_ln = FpType(1.0) / inv_ln;
if constexpr (EngineType::vec_size == 1) {
return static_cast<IntType>(sycl::floor(ln_wrapper(uni_res) * inv_ln));
}
else {
sycl::vec<IntType, EngineType::vec_size> vec_out;
for (int i = 0; i < EngineType::vec_size; i++)
for (int i = 0; i < EngineType::vec_size; i++) {
vec_out[i] = static_cast<IntType>(sycl::floor(ln_wrapper(uni_res[i]) * inv_ln));
}
return vec_out;
}
}
Expand All @@ -84,8 +85,8 @@ class distribution_base<oneapi::math::rng::device::geometric<IntType, Method>> {
std::is_same_v<IntType, std::int64_t>,
double, float>::type;

auto uni_res = engine.generate_single(FpType(0.0), FpType(1.0));
float inv_ln = ln_wrapper(FpType(1.0) - p_);
FpType uni_res = engine.generate_single(FpType(0.0), FpType(1.0));
FpType inv_ln = ln_wrapper(FpType(1.0) - p_);
inv_ln = FpType(1.0) / inv_ln;
return static_cast<IntType>(sycl::floor(ln_wrapper(uni_res) * inv_ln));
}
Expand Down

0 comments on commit cb2f7f6

Please sign in to comment.