Skip to content

Commit

Permalink
Implement multithreading primitives on Windows (#11647)
Browse files Browse the repository at this point in the history
  • Loading branch information
HertzDevil authored Dec 11, 2022
1 parent 6acf031 commit a3f9199
Show file tree
Hide file tree
Showing 18 changed files with 286 additions and 52 deletions.
9 changes: 5 additions & 4 deletions spec/std/thread/mutex_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ describe Thread::Mutex do
a = 0
mutex = Thread::Mutex.new

threads = 10.times.map do
threads = Array.new(10) do
Thread.new do
mutex.synchronize { a += 1 }
end
end.to_a
end

threads.each(&.join)
a.should eq(10)
Expand All @@ -29,15 +29,16 @@ describe Thread::Mutex do
mutex = Thread::Mutex.new
mutex.try_lock.should be_true
mutex.try_lock.should be_false
expect_raises(RuntimeError, "pthread_mutex_lock: ") { mutex.lock }
expect_raises(RuntimeError) { mutex.lock }
mutex.unlock
Thread.new { mutex.synchronize { } }.join
end

it "won't unlock from another thread" do
mutex = Thread::Mutex.new
mutex.lock

expect_raises(RuntimeError, "pthread_mutex_unlock: ") do
expect_raises(RuntimeError) do
Thread.new { mutex.unlock }.join
end

Expand Down
2 changes: 1 addition & 1 deletion src/crystal/system/thread.cr
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ class Thread
end

require "./thread_linked_list"
require "./thread_condition_variable"

{% if flag?(:wasi) %}
require "./wasi/thread"
{% elsif flag?(:unix) %}
require "./unix/pthread"
require "./unix/pthread_condition_variable"
{% elsif flag?(:win32) %}
require "./win32/thread"
{% else %}
Expand Down
31 changes: 31 additions & 0 deletions src/crystal/system/thread_condition_variable.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
class Thread
class ConditionVariable
# Creates a new condition variable.
# def initialize

# Unblocks one thread that is waiting on `self`.
# def signal : Nil

# Unblocks all threads that are waiting on `self`.
# def broadcast : Nil

# Causes the calling thread to wait on `self` and unlock the given *mutex*
# atomically.
# def wait(mutex : Thread::Mutex) : Nil

# Causes the calling thread to wait on `self` and unlock the given *mutex*
# atomically within the given *time* span. Yields to the given block if a
# timeout occurs.
# def wait(mutex : Thread::Mutex, time : Time::Span, & : ->)
end
end

{% if flag?(:wasi) %}
require "./wasi/thread_condition_variable"
{% elsif flag?(:unix) %}
require "./unix/pthread_condition_variable"
{% elsif flag?(:win32) %}
require "./win32/thread_condition_variable"
{% else %}
{% raise "thread condition variable not supported" %}
{% end %}
2 changes: 1 addition & 1 deletion src/crystal/system/thread_mutex.cr
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ end
{% elsif flag?(:win32) %}
require "./win32/thread_mutex"
{% else %}
{% raise "thread not supported" %}
{% raise "thread mutex not supported" %}
{% end %}
2 changes: 1 addition & 1 deletion src/crystal/system/unix/pthread_condition_variable.cr
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Thread
raise RuntimeError.from_os_error("pthread_cond_wait", Errno.new(ret)) unless ret == 0
end

def wait(mutex : Thread::Mutex, time : Time::Span)
def wait(mutex : Thread::Mutex, time : Time::Span, & : ->)
ret =
{% if flag?(:darwin) %}
ts = uninitialized LibC::Timespec
Expand Down
16 changes: 0 additions & 16 deletions src/crystal/system/wasi/thread.cr
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,4 @@ class Thread
# TODO: Implement
Pointer(Void).null
end

# :nodoc:
# TODO: Implement
class ConditionVariable
def signal : Nil
end

def broadcast : Nil
end

def wait(mutex : Thread::Mutex) : Nil
end

def wait(mutex : Thread::Mutex, time : Time::Span, &)
end
end
end
16 changes: 16 additions & 0 deletions src/crystal/system/wasi/thread_condition_variable.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# TODO: Implement
class Thread
class ConditionVariable
def signal : Nil
end

def broadcast : Nil
end

def wait(mutex : Thread::Mutex) : Nil
end

def wait(mutex : Thread::Mutex, time : Time::Span, &)
end
end
end
2 changes: 1 addition & 1 deletion src/crystal/system/win32/process.cr
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct Crystal::System::Process
end

def wait
if LibC.WaitForSingleObject(@process_handle, LibC::INFINITE) != 0
if LibC.WaitForSingleObject(@process_handle, LibC::INFINITE) != LibC::WAIT_OBJECT_0
raise RuntimeError.from_winerror("WaitForSingleObject")
end

Expand Down
85 changes: 68 additions & 17 deletions src/crystal/system/win32/thread.cr
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
require "c/processthreadsapi"
require "c/synchapi"

# TODO: Implement for multithreading.
class Thread
# all thread objects, so the GC can see them (it doesn't scan thread locals)
@@threads = Thread::LinkedList(Thread).new
protected class_getter(threads) { Thread::LinkedList(Thread).new }

@th : LibC::HANDLE
@exception : Exception?
@detached = Atomic(UInt8).new(0)
@main_fiber : Fiber?
Expand All @@ -16,42 +17,87 @@ class Thread
property previous : Thread?

def self.unsafe_each
@@threads.unsafe_each { |thread| yield thread }
threads.unsafe_each { |thread| yield thread }
end

# Starts a new system thread.
def initialize(&@func : ->)
@th = uninitialized LibC::HANDLE

@th = GC.beginthreadex(
security: Pointer(Void).null,
stack_size: LibC::UInt.zero,
start_address: ->(data : Void*) { data.as(Thread).start; LibC::UInt.zero },
arglist: self.as(Void*),
initflag: LibC::UInt.zero,
thrdaddr: Pointer(LibC::UInt).null)
end

# Used once to initialize the thread object representing the main thread of
# the process (that already exists).
def initialize
# `GetCurrentThread` returns a _constant_ and is only meaningful as an
# argument to Win32 APIs; to uniquely identify it we must duplicate the handle
@th = uninitialized LibC::HANDLE
cur_proc = LibC.GetCurrentProcess
LibC.DuplicateHandle(cur_proc, LibC.GetCurrentThread, cur_proc, pointerof(@th), 0, true, LibC::DUPLICATE_SAME_ACCESS)

@func = ->{}
@main_fiber = Fiber.new(stack_address, self)
@@threads.push(self)

Thread.threads.push(self)
end

private def detach
if @detached.compare_and_set(0, 1).last
yield
end
end

@@current : Thread? = nil
# Suspends the current thread until this thread terminates.
def join : Nil
detach do
if LibC.WaitForSingleObject(@th, LibC::INFINITE) != LibC::WAIT_OBJECT_0
@exception ||= RuntimeError.from_winerror("WaitForSingleObject")
end
if LibC.CloseHandle(@th) == 0
@exception ||= RuntimeError.from_winerror("CloseHandle")
end
end

# Associates the Thread object to the running system thread.
protected def self.current=(@@current : Thread) : Thread
if exception = @exception
raise exception
end
end

@[ThreadLocal]
@@current : Thread?

# Returns the Thread object associated to the running system thread.
def self.current : Thread
@@current || raise "BUG: Thread.current returned NULL"
@@current ||= new
end

# Associates the Thread object to the running system thread.
protected def self.current=(@@current : Thread) : Thread
end

# Create the thread object for the current thread (aka the main thread of the
# process).
#
# TODO: consider moving to `kernel.cr` or `crystal/main.cr`
self.current = new
def self.yield : Nil
LibC.SwitchToThread
end

# Returns the Fiber representing the thread's main stack.
def main_fiber
def main_fiber : Fiber
@main_fiber.not_nil!
end

# :nodoc:
def scheduler
def scheduler : Crystal::Scheduler
@scheduler ||= Crystal::Scheduler.new(main_fiber)
end

protected def start
Thread.threads.push(self)
Thread.current = self
@main_fiber = fiber = Fiber.new(stack_address, self)

Expand All @@ -60,9 +106,9 @@ class Thread
rescue ex
@exception = ex
ensure
@@threads.delete(self)
Thread.threads.delete(self)
Fiber.inactive(fiber)
detach_self
detach { LibC.CloseHandle(@th) }
end
end

Expand All @@ -71,4 +117,9 @@ class Thread

Pointer(Void).new(low_limit)
end

# :nodoc:
def to_unsafe
@th
end
end
41 changes: 41 additions & 0 deletions src/crystal/system/win32/thread_condition_variable.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
require "c/synchapi"

# :nodoc:
class Thread
# :nodoc:
class ConditionVariable
def initialize
@cond = uninitialized LibC::CONDITION_VARIABLE
LibC.InitializeConditionVariable(self)
end

def signal : Nil
LibC.WakeConditionVariable(self)
end

def broadcast : Nil
LibC.WakeAllConditionVariable(self)
end

def wait(mutex : Thread::Mutex) : Nil
ret = LibC.SleepConditionVariableCS(self, mutex, LibC::INFINITE)
raise RuntimeError.from_winerror("SleepConditionVariableCS") if ret == 0
end

def wait(mutex : Thread::Mutex, time : Time::Span, & : ->)
ret = LibC.SleepConditionVariableCS(self, mutex, time.total_milliseconds)
return if ret != 0

error = WinError.value
if error == WinError::ERROR_TIMEOUT
yield
else
raise RuntimeError.from_os_error("SleepConditionVariableCS", error)
end
end

def to_unsafe
pointerof(@cond)
end
end
end
57 changes: 55 additions & 2 deletions src/crystal/system/win32/thread_mutex.cr
Original file line number Diff line number Diff line change
@@ -1,8 +1,61 @@
# TODO: Implement
require "c/synchapi"

# :nodoc:
class Thread
# :nodoc:
# for Win32 condition variable interop we must use either a critical section
# or a slim reader/writer lock, not a Win32 mutex
# also note critical sections are reentrant; to match the behaviour in
# `../unix/pthread_mutex.cr` we must do extra housekeeping ourselves
class Mutex
def initialize
@cs = uninitialized LibC::CRITICAL_SECTION
LibC.InitializeCriticalSectionAndSpinCount(self, 1000)
end

def lock : Nil
LibC.EnterCriticalSection(self)
if @cs.recursionCount > 1
LibC.LeaveCriticalSection(self)
raise RuntimeError.new "Attempt to lock a mutex recursively (deadlock)"
end
end

def try_lock : Bool
if LibC.TryEnterCriticalSection(self) != 0
if @cs.recursionCount > 1
LibC.LeaveCriticalSection(self)
false
else
true
end
else
false
end
end

def unlock : Nil
# `owningThread` is declared as `LibC::HANDLE` for historical reasons, so
# the following comparison is correct
unless @cs.owningThread == LibC::HANDLE.new(LibC.GetCurrentThreadId)
raise RuntimeError.new "Attempt to unlock a mutex locked by another thread"
end
LibC.LeaveCriticalSection(self)
end

def synchronize
yield
lock
yield self
ensure
unlock
end

def finalize
LibC.DeleteCriticalSection(self)
end

def to_unsafe
pointerof(@cs)
end
end
end
Loading

0 comments on commit a3f9199

Please sign in to comment.