Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Barrier parity waiting and algorithm tweaks #69

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions include/cuda/std/barrier
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,22 @@ private:
return __this->__try_wait(__phase);
}
};
struct __poll_tester_parity {
barrier const* __this;
bool __parity;

_LIBCUDACXX_INLINE_VISIBILITY
__poll_tester_parity(barrier const* __this_, bool __parity_)
: __this(__this_)
, __parity(__parity_)
{}

inline _LIBCUDACXX_INLINE_VISIBILITY
bool operator()() const
{
return __this->__try_wait_parity(__parity);
}
};

_LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait(arrival_token __phase) const {
Expand All @@ -127,6 +143,28 @@ private:
}
}

_LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait_parity(bool __parity) const {
#if __CUDA_ARCH__ >= 800
if (__isShared(&__barrier)) {
int __ready = 0;
asm volatile ("{\n\t"
".reg .pred p;\n\t"
"mbarrier.test_wait.parity.shared.b64 p, [%1], %2;\n\t"
"selp.b32 %0, 1, 0, p;\n\t"
"}"
: "=r"(__ready)
: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier))), "r"(static_cast<std::uint32_t>(__parity))
: "memory");
return bool(__ready);
}
else
#endif
{
return __barrier.__try_wait_parity(__parity);
}
}

template<thread_scope>
friend class pipeline;

Expand Down Expand Up @@ -245,6 +283,12 @@ public:
{
return (1 << 20) - 1;
}

inline _LIBCUDACXX_INLINE_VISIBILITY
friend void barrier_wait_for_parity(barrier const* __self, bool __parity)
{
_CUDA_VSTD::__libcpp_thread_poll_with_backoff(__poll_tester_parity(__self, __parity));
}
};

_LIBCUDACXX_END_NAMESPACE_CUDA
Expand Down
126 changes: 102 additions & 24 deletions libcxx/include/barrier
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,20 @@ public:
template<int _Sco>
class __barrier_base<__empty_completion, _Sco> {

static constexpr uint64_t __arrived_shift = 32;
static constexpr uint64_t __guard_shift = 31;
static constexpr uint64_t __phase_shift = 63;

static constexpr uint64_t __expected_unit = 1ull;
static constexpr uint64_t __arrived_unit = 1ull << 32;
static constexpr uint64_t __expected_mask = __arrived_unit - 1;
static constexpr uint64_t __phase_bit = 1ull << 63;
static constexpr uint64_t __arrived_mask = (__phase_bit - 1) & ~__expected_mask;
static constexpr uint64_t __arrived_unit = 1ull << __arrived_shift;

static constexpr uint64_t __phase_bit = 1ull << __phase_shift;
static constexpr uint64_t __arrived_sign_bit = __phase_bit >> 1;
static constexpr uint64_t __guard_bit = 1ull << __guard_shift;

static constexpr uint64_t __expected_mask = __guard_bit - 1;

_LIBCUDACXX_BARRIER_ALIGNMENTS __atomic_base<uint64_t, _Sco> __phase_arrived_expected;
mutable _LIBCUDACXX_BARRIER_ALIGNMENTS __atomic_base<uint64_t, _Sco> __phase_arrived_expected;

public:
using arrival_token = uint64_t;
Expand All @@ -301,12 +308,53 @@ private:
return __this->__try_wait(__phase);
}
};
struct __poll_tester_parity {
__barrier_base const* __this;
bool __parity;

_LIBCUDACXX_INLINE_VISIBILITY
__poll_tester_parity(__barrier_base const* __this_, bool __parity_)
: __this(__this_)
, __parity(__parity_)
{}

inline _LIBCUDACXX_INLINE_VISIBILITY
bool operator()() const
{
return __this->__try_wait_parity(__parity);
}
};

static inline _LIBCUDACXX_INLINE_VISIBILITY
constexpr uint64_t __init(ptrdiff_t __count) _NOEXCEPT
{
return (((1u << 31) - __count) << 32)
| ((1u << 31) - __count);
return ((__guard_bit - __count) << __arrived_shift)
| (__guard_bit - __count);
}

inline _LIBCUDACXX_INLINE_VISIBILITY
void __update(uint64_t __old, uint64_t __with_inc, memory_order __order) const
{
while((__old & __arrived_sign_bit) == 0) {
auto const __shifted_expected = (__old & __expected_mask) << __arrived_shift;
auto const __new = __old + __with_inc + __shifted_expected;
if(__phase_arrived_expected.compare_exchange_weak(__old, __new, __order)) {
if((__old ^ __new) & __phase_bit) {
if(__with_inc)
__update(__new, 0, memory_order_relaxed);
__phase_arrived_expected.notify_all();
}
return;
}
}
if(__with_inc) {
__old = __phase_arrived_expected.fetch_add(__with_inc, __order);
auto const __new = __old + __with_inc;
if((__old ^ __new) & __phase_bit) {
__update(__new, 0, memory_order_relaxed);
__phase_arrived_expected.notify_all();
}
}
}

public:
Expand All @@ -322,28 +370,52 @@ public:
__barrier_base(__barrier_base const&) = delete;
__barrier_base& operator=(__barrier_base const&) = delete;

_LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait(arrival_token __phase) const
inline _LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait_phase(uint64_t __phase) const
{
auto const __current = __phase_arrived_expected.load(memory_order_acquire);
if((__current & __arrived_sign_bit) == 0)
__update(__current, 0, memory_order_acquire);
return (__current & __phase_bit) != __phase;
}
inline _LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait_parity(bool __parity) const
{
return __try_wait_phase(__parity ? __phase_bit : 0);
}
inline _LIBCUDACXX_INLINE_VISIBILITY
bool __try_wait(arrival_token __old) const
{
return __try_wait_phase(__old & __phase_bit);
}

#ifndef _LIBCUDACXX_HAS_PLATFORM_WAIT
inline _LIBCUDACXX_INLINE_VISIBILITY
friend void __expect_extra_arrive(__barrier_base* __self, uint64_t __count = 1)
{
auto const __inc = __count << __arrived_shift;
auto const __old = __self->__phase_arrived_expected.load(memory_order_relaxed);
__self->__update(__old, ~__inc + 1, memory_order_relaxed);
}
inline _LIBCUDACXX_INLINE_VISIBILITY
friend uint64_t __extra_arrive(__barrier_base* __self, uint64_t __count = 1)
{
uint64_t const __current = __phase_arrived_expected.load(memory_order_acquire);
return ((__current & __phase_bit) != __phase);
auto const __inc = __count << __arrived_shift;
return __self->__phase_arrived_expected.fetch_add(__inc, memory_order_release);
}
#endif

_LIBCUDACXX_NODISCARD_ATTRIBUTE inline _LIBCUDACXX_INLINE_VISIBILITY
arrival_token arrive(ptrdiff_t __update = 1)
arrival_token arrive(ptrdiff_t __count = 1)
{
auto const __inc = __arrived_unit * __update;
auto const __old = __phase_arrived_expected.fetch_add(__inc, memory_order_acq_rel);
if((__old ^ (__old + __inc)) & __phase_bit) {
__phase_arrived_expected.fetch_add((__old & __expected_mask) << 32, memory_order_relaxed);
__phase_arrived_expected.notify_all();
}
return __old & __phase_bit;
auto const __old = __phase_arrived_expected.load(memory_order_relaxed);
__update(__old, __count << __arrived_shift, memory_order_release);
return __old;
}
inline _LIBCUDACXX_INLINE_VISIBILITY
void wait(arrival_token&& __phase) const
void wait(arrival_token&& __old) const
{
__libcpp_thread_poll_with_backoff(__poll_tester(this, _CUDA_VSTD::move(__phase)));
__libcpp_thread_poll_with_backoff(__poll_tester(this, _CUDA_VSTD::move(__old)));
}
inline _LIBCUDACXX_INLINE_VISIBILITY
void arrive_and_wait()
Expand All @@ -353,15 +425,21 @@ public:
inline _LIBCUDACXX_INLINE_VISIBILITY
void arrive_and_drop()
{
__phase_arrived_expected.fetch_add(__expected_unit, memory_order_relaxed);
(void)arrive();
auto const __old = __phase_arrived_expected.load(memory_order_relaxed);
__update(__old, __arrived_unit + __expected_unit, memory_order_release);
}

_LIBCUDACXX_INLINE_VISIBILITY
static constexpr ptrdiff_t max() noexcept
{
return numeric_limits<int32_t>::max();
return numeric_limits<int32_t>::max() >> 1;
}

inline _LIBCUDACXX_INLINE_VISIBILITY
friend void wait_for_parity(__barrier_base const* __self, bool __parity)
{
__libcpp_thread_poll_with_backoff(__poll_tester_parity(__self, __parity));
}
};

#endif //_LIBCUDACXX_HAS_NO_TREE_BARRIER
Expand Down