Skip to content

Commit

Permalink
constexpr mutex constructor (microsoft#3824)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephan T. Lavavej <[email protected]>
  • Loading branch information
cpplearner and StephanTLavavej authored Jul 20, 2023
1 parent b81f61c commit 8d18fec
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 144 deletions.
54 changes: 15 additions & 39 deletions stl/inc/mutex
Original file line number Diff line number Diff line change
Expand Up @@ -31,37 +31,20 @@ _STD_BEGIN
_EXPORT_STD class condition_variable;
_EXPORT_STD class condition_variable_any;

struct _Mtx_internal_imp_mirror {
#ifdef _CRT_WINDOWS
#ifdef _WIN64
static constexpr size_t _Critical_section_size = 16;
#else // _WIN64
static constexpr size_t _Critical_section_size = 8;
#endif // _WIN64
#else // _CRT_WINDOWS
#ifdef _WIN64
static constexpr size_t _Critical_section_size = 64;
#else // _WIN64
static constexpr size_t _Critical_section_size = 36;
#endif // _WIN64
#endif // _CRT_WINDOWS

static constexpr size_t _Critical_section_align = alignof(void*);

int _Type;
_Aligned_storage_t<_Critical_section_size, _Critical_section_align> _Cs_storage;
long _Thread_id;
int _Count;
};

static_assert(sizeof(_Mtx_internal_imp_mirror) == _Mtx_internal_imp_size, "inconsistent size for mutex");
static_assert(alignof(_Mtx_internal_imp_mirror) == _Mtx_internal_imp_alignment, "inconsistent alignment for mutex");

class _Mutex_base { // base class for all mutex types
public:
#ifdef _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR
_Mutex_base(int _Flags = 0) noexcept {
_Mtx_init_in_situ(_Mymtx(), _Flags | _Mtx_try);
}
#else // ^^^ _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR / !_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR vvv
constexpr _Mutex_base(int _Flags = 0) noexcept {
_Mtx_storage._Critical_section = {};
_Mtx_storage._Thread_id = -1;
_Mtx_storage._Type = _Flags | _Mtx_try;
_Mtx_storage._Count = 0;
}
#endif // !_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR

~_Mutex_base() noexcept {
_Mtx_destroy_in_situ(_Mymtx());
Expand Down Expand Up @@ -97,9 +80,9 @@ public:

protected:
_NODISCARD_TRY_CHANGE_STATE bool _Verify_ownership_levels() noexcept {
if (_Mtx_storage_mirror._Count == INT_MAX) {
if (_Mtx_storage._Count == INT_MAX) {
// only occurs for recursive mutexes (N4950 [thread.mutex.recursive]/3)
--_Mtx_storage_mirror._Count;
--_Mtx_storage._Count;
return false;
}

Expand All @@ -110,23 +93,16 @@ private:
friend condition_variable;
friend condition_variable_any;

union {
_Aligned_storage_t<_Mtx_internal_imp_size, _Mtx_internal_imp_alignment> _Mtx_storage;
_Mtx_internal_imp_mirror _Mtx_storage_mirror;
};
_Mtx_internal_imp_t _Mtx_storage{};

_Mtx_t _Mymtx() noexcept { // get pointer to _Mtx_internal_imp_t inside _Mtx_storage
return reinterpret_cast<_Mtx_t>(&_Mtx_storage);
_Mtx_t _Mymtx() noexcept {
return &_Mtx_storage;
}
};

static_assert(sizeof(_Mutex_base) == _Mtx_internal_imp_size, "inconsistent size for mutex");
static_assert(alignof(_Mutex_base) == _Mtx_internal_imp_alignment, "inconsistent alignment for mutex");

_EXPORT_STD class mutex : public _Mutex_base { // class for mutual exclusion
public:
/* constexpr */ mutex() noexcept // TRANSITION, ABI
: _Mutex_base() {}
mutex() noexcept = default;

mutex(const mutex&) = delete;
mutex& operator=(const mutex&) = delete;
Expand Down
52 changes: 38 additions & 14 deletions stl/inc/xthreads.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <yvals_core.h>
#if _STL_COMPILER_PREPROCESSOR
#include <climits>
#include <type_traits>
#include <xtimec.h>

#pragma pack(push, _CRT_PACKING)
Expand All @@ -25,40 +26,64 @@ struct _Thrd_t { // thread identifier for Win32
_Thrd_id_t _Id;
};

// Size and alignment for _Mtx_internal_imp_t and _Cnd_internal_imp_t
using _Smtx_t = void*;

struct _Stl_critical_section {
void* _Unused = nullptr; // TRANSITION, ABI: was the vptr
_Smtx_t _M_srw_lock = nullptr;
};

struct _Mtx_internal_imp_t {
#ifdef _CRT_WINDOWS
#ifdef _WIN64
static constexpr size_t _Critical_section_size = 16;
#else // _WIN64
static constexpr size_t _Critical_section_size = 8;
#endif // _WIN64
#else // _CRT_WINDOWS
#ifdef _WIN64
static constexpr size_t _Critical_section_size = 64;
#else // _WIN64
static constexpr size_t _Critical_section_size = 36;
#endif // _WIN64
#endif // _CRT_WINDOWS

static constexpr size_t _Critical_section_align = alignof(void*);

int _Type{};
union {
_Stl_critical_section _Critical_section{};
_STD _Aligned_storage_t<_Critical_section_size, _Critical_section_align> _Cs_storage;
};
long _Thread_id{};
int _Count{};
};

// Size and alignment for _Cnd_internal_imp_t
#ifdef _CRT_WINDOWS
#ifdef _WIN64
_INLINE_VAR constexpr size_t _Mtx_internal_imp_size = 32;
_INLINE_VAR constexpr size_t _Mtx_internal_imp_alignment = 8;
_INLINE_VAR constexpr size_t _Cnd_internal_imp_size = 16;
_INLINE_VAR constexpr size_t _Cnd_internal_imp_alignment = 8;
#else // _WIN64
_INLINE_VAR constexpr size_t _Mtx_internal_imp_size = 20;
_INLINE_VAR constexpr size_t _Mtx_internal_imp_alignment = 4;
_INLINE_VAR constexpr size_t _Cnd_internal_imp_size = 8;
_INLINE_VAR constexpr size_t _Cnd_internal_imp_alignment = 4;
#endif // _WIN64
#else // _CRT_WINDOWS
#ifdef _WIN64
_INLINE_VAR constexpr size_t _Mtx_internal_imp_size = 80;
_INLINE_VAR constexpr size_t _Mtx_internal_imp_alignment = 8;
_INLINE_VAR constexpr size_t _Cnd_internal_imp_size = 72;
_INLINE_VAR constexpr size_t _Cnd_internal_imp_alignment = 8;
#else // _WIN64
_INLINE_VAR constexpr size_t _Mtx_internal_imp_size = 48;
_INLINE_VAR constexpr size_t _Mtx_internal_imp_alignment = 4;
_INLINE_VAR constexpr size_t _Cnd_internal_imp_size = 40;
_INLINE_VAR constexpr size_t _Cnd_internal_imp_alignment = 4;
#endif // _WIN64
#endif // _CRT_WINDOWS

#ifdef _M_CEE // avoid warning LNK4248: unresolved typeref token for '_Mtx_internal_imp_t'; image may not run
using _Mtx_t = void*;
using _Mtx_t = _Mtx_internal_imp_t*;

#ifdef _M_CEE // avoid warning LNK4248: unresolved typeref token for '_Cnd_internal_imp_t'; image may not run
using _Cnd_t = void*;
#else // ^^^ defined(_M_CEE) / !defined(_M_CEE) vvv
struct _Mtx_internal_imp_t;
struct _Cnd_internal_imp_t;
using _Mtx_t = _Mtx_internal_imp_t*;
using _Cnd_t = _Cnd_internal_imp_t*;
#endif // ^^^ !defined(_M_CEE) ^^^

Expand Down Expand Up @@ -96,7 +121,6 @@ _CRTIMP2_PURE void __cdecl _Mtx_reset_owner(_Mtx_t);

// shared mutex
// these declarations must be in sync with those in sharedmutex.cpp
using _Smtx_t = void*;
void __cdecl _Smtx_lock_exclusive(_Smtx_t*);
void __cdecl _Smtx_lock_shared(_Smtx_t*);
int __cdecl _Smtx_try_lock_exclusive(_Smtx_t*);
Expand Down
4 changes: 2 additions & 2 deletions stl/src/cond.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void _Cnd_destroy(const _Cnd_t cond) { // clean up
}

int _Cnd_wait(const _Cnd_t cond, const _Mtx_t mtx) { // wait until signaled
const auto cs = static_cast<Concurrency::details::stl_critical_section_win7*>(_Mtx_getconcrtcs(mtx));
const auto cs = &mtx->_Critical_section;
_Mtx_clear_owner(mtx);
cond->_get_cv()->wait(cs);
_Mtx_reset_owner(mtx);
Expand All @@ -61,7 +61,7 @@ int _Cnd_wait(const _Cnd_t cond, const _Mtx_t mtx) { // wait until signaled
// wait until signaled or timeout
int _Cnd_timedwait(const _Cnd_t cond, const _Mtx_t mtx, const _timespec64* const target) {
int res = _Thrd_success;
const auto cs = static_cast<Concurrency::details::stl_critical_section_win7*>(_Mtx_getconcrtcs(mtx));
const auto cs = &mtx->_Critical_section;
if (target == nullptr) { // no target time specified, wait on mutex
_Mtx_clear_owner(mtx);
cond->_get_cv()->wait(cs);
Expand Down
89 changes: 39 additions & 50 deletions stl/src/mutex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,34 +37,20 @@ extern "C" [[noreturn]] _CRTIMP2_PURE void _Thrd_abort(const char* msg) { // abo
enum class __stl_sync_api_modes_enum { normal, win7, vista, concrt };
extern "C" _CRTIMP2 void __cdecl __set_stl_sync_api_mode(__stl_sync_api_modes_enum) {}

struct _Mtx_internal_imp_t { // ConcRT mutex
int type;
typename std::_Aligned_storage<Concurrency::details::stl_critical_section_max_size,
Concurrency::details::stl_critical_section_max_alignment>::type cs;
long thread_id;
int count;
[[nodiscard]] Concurrency::details::stl_critical_section_win7* _get_cs() { // get pointer to implementation
return reinterpret_cast<Concurrency::details::stl_critical_section_win7*>(&cs);
}
};

static_assert(sizeof(_Mtx_internal_imp_t) == _Mtx_internal_imp_size, "incorrect _Mtx_internal_imp_size");
static_assert(alignof(_Mtx_internal_imp_t) == _Mtx_internal_imp_alignment, "incorrect _Mtx_internal_imp_alignment");

static_assert(
std::_Mtx_internal_imp_mirror::_Critical_section_size == Concurrency::details::stl_critical_section_max_size);
static_assert(
std::_Mtx_internal_imp_mirror::_Critical_section_align == Concurrency::details::stl_critical_section_max_alignment);
[[nodiscard]] static PSRWLOCK get_srw_lock(_Mtx_t mtx) {
return reinterpret_cast<PSRWLOCK>(&mtx->_Critical_section._M_srw_lock);
}

// TRANSITION, only used when constexpr mutex constructor is not enabled
void _Mtx_init_in_situ(_Mtx_t mtx, int type) { // initialize mutex in situ
Concurrency::details::create_stl_critical_section(mtx->_get_cs());
mtx->thread_id = -1;
mtx->type = type;
mtx->count = 0;
Concurrency::details::create_stl_critical_section(&mtx->_Critical_section);
mtx->_Thread_id = -1;
mtx->_Type = type;
mtx->_Count = 0;
}

void _Mtx_destroy_in_situ(_Mtx_t mtx) { // destroy mutex in situ
_THREAD_ASSERT(mtx->count == 0, "mutex destroyed while busy");
_THREAD_ASSERT(mtx->_Count == 0, "mutex destroyed while busy");
(void) mtx;
}

Expand All @@ -91,27 +77,27 @@ void _Mtx_destroy(_Mtx_t mtx) { // destroy mutex
}

static int mtx_do_lock(_Mtx_t mtx, const _timespec64* target) { // lock mutex
if ((mtx->type & ~_Mtx_recursive) == _Mtx_plain) { // set the lock
if (mtx->thread_id != static_cast<long>(GetCurrentThreadId())) { // not current thread, do lock
mtx->_get_cs()->lock();
mtx->thread_id = static_cast<long>(GetCurrentThreadId());
if ((mtx->_Type & ~_Mtx_recursive) == _Mtx_plain) { // set the lock
if (mtx->_Thread_id != static_cast<long>(GetCurrentThreadId())) { // not current thread, do lock
AcquireSRWLockExclusive(get_srw_lock(mtx));
mtx->_Thread_id = static_cast<long>(GetCurrentThreadId());
}
++mtx->count;
++mtx->_Count;

return _Thrd_success;
} else { // handle timed or recursive mutex
int res = WAIT_TIMEOUT;
if (target == nullptr) { // no target --> plain wait (i.e. infinite timeout)
if (mtx->thread_id != static_cast<long>(GetCurrentThreadId())) {
mtx->_get_cs()->lock();
if (mtx->_Thread_id != static_cast<long>(GetCurrentThreadId())) {
AcquireSRWLockExclusive(get_srw_lock(mtx));
}

res = WAIT_OBJECT_0;

} else if (target->tv_sec < 0 || target->tv_sec == 0 && target->tv_nsec <= 0) {
// target time <= 0 --> plain trylock or timed wait for time that has passed; try to lock with 0 timeout
if (mtx->thread_id != static_cast<long>(GetCurrentThreadId())) { // not this thread, lock it
if (mtx->_get_cs()->try_lock()) {
if (mtx->_Thread_id != static_cast<long>(GetCurrentThreadId())) { // not this thread, lock it
if (TryAcquireSRWLockExclusive(get_srw_lock(mtx)) != 0) {
res = WAIT_OBJECT_0;
} else {
res = WAIT_TIMEOUT;
Expand All @@ -125,8 +111,8 @@ static int mtx_do_lock(_Mtx_t mtx, const _timespec64* target) { // lock mutex
_Timespec64_get_sys(&now);
while (now.tv_sec < target->tv_sec || now.tv_sec == target->tv_sec && now.tv_nsec < target->tv_nsec) {
// time has not expired
if (mtx->thread_id == static_cast<long>(GetCurrentThreadId())
|| mtx->_get_cs()->try_lock()) { // stop waiting
if (mtx->_Thread_id == static_cast<long>(GetCurrentThreadId())
|| TryAcquireSRWLockExclusive(get_srw_lock(mtx)) != 0) { // stop waiting
res = WAIT_OBJECT_0;
break;
} else {
Expand All @@ -138,13 +124,13 @@ static int mtx_do_lock(_Mtx_t mtx, const _timespec64* target) { // lock mutex
}

if (res == WAIT_OBJECT_0 || res == WAIT_ABANDONED) {
if (1 < ++mtx->count) { // check count
if ((mtx->type & _Mtx_recursive) != _Mtx_recursive) { // not recursive, fixup count
--mtx->count;
if (1 < ++mtx->_Count) { // check count
if ((mtx->_Type & _Mtx_recursive) != _Mtx_recursive) { // not recursive, fixup count
--mtx->_Count;
res = WAIT_TIMEOUT;
}
} else {
mtx->thread_id = static_cast<long>(GetCurrentThreadId());
mtx->_Thread_id = static_cast<long>(GetCurrentThreadId());
}
}

Expand All @@ -168,11 +154,14 @@ static int mtx_do_lock(_Mtx_t mtx, const _timespec64* target) { // lock mutex

int _Mtx_unlock(_Mtx_t mtx) { // unlock mutex
_THREAD_ASSERT(
1 <= mtx->count && mtx->thread_id == static_cast<long>(GetCurrentThreadId()), "unlock of unowned mutex");
1 <= mtx->_Count && mtx->_Thread_id == static_cast<long>(GetCurrentThreadId()), "unlock of unowned mutex");

if (--mtx->_Count == 0) { // leave critical section
mtx->_Thread_id = -1;

if (--mtx->count == 0) { // leave critical section
mtx->thread_id = -1;
mtx->_get_cs()->unlock();
auto srw_lock = get_srw_lock(mtx);
_Analysis_assume_lock_held_(*srw_lock);
ReleaseSRWLockExclusive(srw_lock);
}
return _Thrd_success; // TRANSITION, ABI: always returns _Thrd_success
}
Expand All @@ -183,7 +172,7 @@ int _Mtx_lock(_Mtx_t mtx) { // lock mutex

int _Mtx_trylock(_Mtx_t mtx) { // attempt to lock try_mutex
_timespec64 xt;
_THREAD_ASSERT((mtx->type & (_Mtx_try | _Mtx_timed)) != 0, "trylock not supported by mutex");
_THREAD_ASSERT((mtx->_Type & (_Mtx_try | _Mtx_timed)) != 0, "trylock not supported by mutex");
xt.tv_sec = 0;
xt.tv_nsec = 0;
return mtx_do_lock(mtx, &xt);
Expand All @@ -192,27 +181,27 @@ int _Mtx_trylock(_Mtx_t mtx) { // attempt to lock try_mutex
int _Mtx_timedlock(_Mtx_t mtx, const _timespec64* xt) { // attempt to lock timed mutex
int res;

_THREAD_ASSERT((mtx->type & _Mtx_timed) != 0, "timedlock not supported by mutex");
_THREAD_ASSERT((mtx->_Type & _Mtx_timed) != 0, "timedlock not supported by mutex");
res = mtx_do_lock(mtx, xt);
return res == _Thrd_busy ? _Thrd_timedout : res;
}

int _Mtx_current_owns(_Mtx_t mtx) { // test if current thread owns mutex
return mtx->count != 0 && mtx->thread_id == static_cast<long>(GetCurrentThreadId());
return mtx->_Count != 0 && mtx->_Thread_id == static_cast<long>(GetCurrentThreadId());
}

void* _Mtx_getconcrtcs(_Mtx_t mtx) { // get internal cs impl
return mtx->_get_cs();
return &mtx->_Critical_section;
}

void _Mtx_clear_owner(_Mtx_t mtx) { // set owner to nobody
mtx->thread_id = -1;
--mtx->count;
mtx->_Thread_id = -1;
--mtx->_Count;
}

void _Mtx_reset_owner(_Mtx_t mtx) { // set owner to current thread
mtx->thread_id = static_cast<long>(GetCurrentThreadId());
++mtx->count;
mtx->_Thread_id = static_cast<long>(GetCurrentThreadId());
++mtx->_Count;
}

/*
Expand Down
Loading

0 comments on commit 8d18fec

Please sign in to comment.