Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement multithreading primitives on Windows #11647

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions spec/std/thread/mutex_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ describe Thread::Mutex do
a = 0
mutex = Thread::Mutex.new

threads = 10.times.map do
threads = Array.new(10) do
Thread.new do
mutex.synchronize { a += 1 }
end
end.to_a
end

threads.each(&.join)
a.should eq(10)
Expand All @@ -27,15 +27,16 @@ describe Thread::Mutex do
mutex = Thread::Mutex.new
mutex.try_lock.should be_true
mutex.try_lock.should be_false
expect_raises(RuntimeError, "pthread_mutex_lock: ") { mutex.lock }
expect_raises(RuntimeError) { mutex.lock }
mutex.unlock
Thread.new { mutex.synchronize { } }.join
end

it "won't unlock from another thread" do
mutex = Thread::Mutex.new
mutex.lock

expect_raises(RuntimeError, "pthread_mutex_unlock: ") do
expect_raises(RuntimeError) do
Thread.new { mutex.unlock }.join
end

Expand Down
6 changes: 3 additions & 3 deletions spec/win32_std_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,9 @@ require "./std/syscall_spec.cr"
# require "./std/system/user_spec.cr" (failed codegen)
require "./std/system_error_spec.cr"
require "./std/system_spec.cr"
# require "./std/thread/condition_variable_spec.cr" (failed codegen)
# require "./std/thread/mutex_spec.cr" (failed codegen)
# require "./std/thread_spec.cr" (failed codegen)
require "./std/thread/condition_variable_spec.cr"
require "./std/thread/mutex_spec.cr"
require "./std/thread_spec.cr"
require "./std/time/custom_formats_spec.cr"
require "./std/time/format_spec.cr"
require "./std/time/location_spec.cr"
Expand Down
2 changes: 1 addition & 1 deletion src/crystal/system/thread.cr
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ class Thread
end

require "./thread_linked_list"
require "./thread_condition_variable"

{% if flag?(:wasi) %}
require "./wasi/thread"
{% elsif flag?(:unix) %}
require "./unix/pthread"
require "./unix/pthread_condition_variable"
{% elsif flag?(:win32) %}
require "./win32/thread"
{% else %}
Expand Down
29 changes: 29 additions & 0 deletions src/crystal/system/thread_condition_variable.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
class Thread
class ConditionVariable
# Creates a new condition variable.
# def initialize

# Unblocks one thread that is waiting on `self`.
# def signal : Nil

# Unblocks all threads that are waiting on `self`.
# def broadcast : Nil

# Causes the calling thread to wait on `self` and unlock the given *mutex*
# atomically.
# def wait(mutex : Thread::Mutex) : Nil

# Causes the calling thread to wait on `self` and unlock the given *mutex*
# atomically within the given *time* span. Yields to the given block if a
# timeout occurs.
# def wait(mutex : Thread::Mutex, time : Time::Span, & : ->)
end
end

{% if flag?(:unix) %}
require "./unix/pthread_condition_variable"
{% elsif flag?(:win32) %}
require "./win32/thread_condition_variable"
{% else %}
{% raise "thread condition variable not supported" %}
{% end %}
2 changes: 1 addition & 1 deletion src/crystal/system/thread_mutex.cr
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ end
{% elsif flag?(:win32) %}
require "./win32/thread_mutex"
{% else %}
{% raise "thread not supported" %}
{% raise "thread mutex not supported" %}
{% end %}
2 changes: 1 addition & 1 deletion src/crystal/system/unix/pthread_condition_variable.cr
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Thread
raise RuntimeError.from_os_error("pthread_cond_wait", Errno.new(ret)) unless ret == 0
end

def wait(mutex : Thread::Mutex, time : Time::Span)
def wait(mutex : Thread::Mutex, time : Time::Span, & : ->)
ret =
{% if flag?(:darwin) %}
ts = uninitialized LibC::Timespec
Expand Down
2 changes: 1 addition & 1 deletion src/crystal/system/win32/process.cr
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct Crystal::System::Process
end

def wait
if LibC.WaitForSingleObject(@process_handle, LibC::INFINITE) != 0
if LibC.WaitForSingleObject(@process_handle, LibC::INFINITE) != LibC::WAIT_OBJECT_0
raise RuntimeError.from_winerror("WaitForSingleObject")
end

Expand Down
89 changes: 72 additions & 17 deletions src/crystal/system/win32/thread.cr
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
require "c/processthreadsapi"
require "c/synchapi"

# TODO: Implement for multithreading.
class Thread
# all thread objects, so the GC can see them (it doesn't scan thread locals)
@@threads = Thread::LinkedList(Thread).new
protected class_getter(threads) { Thread::LinkedList(Thread).new }

@th : LibC::HANDLE
@exception : Exception?
@detached = Atomic(UInt8).new(0)
@main_fiber : Fiber?
Expand All @@ -16,42 +17,91 @@ class Thread
property previous : Thread?

def self.unsafe_each
@@threads.unsafe_each { |thread| yield thread }
threads.unsafe_each { |thread| yield thread }
end

# Starts a new system thread.
def initialize(&@func : ->)
@th = uninitialized LibC::HANDLE

@th = GC.beginthreadex(
straight-shoota marked this conversation as resolved.
Show resolved Hide resolved
security: Pointer(Void).null,
stack_size: LibC::UInt.zero,
start_address: ->(data : Void*) { data.as(Thread).start; LibC::UInt.zero },
arglist: self.as(Void*),
initflag: LibC::UInt.zero,
thrdaddr: Pointer(LibC::UInt).null)

if @th.null?
raise RuntimeError.from_errno("_beginthreadex")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, raising on null both here and in beginthreadex seems a bit redundant.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed.

end
HertzDevil marked this conversation as resolved.
Show resolved Hide resolved
end

# Used once to initialize the thread object representing the main thread of
# the process (that already exists).
def initialize
# `GetCurrentThread` returns a _constant_ and is only meaningful as an
# argument to Win32 APIs; to uniquely identify it we must duplicate the handle
@th = uninitialized LibC::HANDLE
cur_proc = LibC.GetCurrentProcess
LibC.DuplicateHandle(cur_proc, LibC.GetCurrentThread, cur_proc, pointerof(@th), 0, true, LibC::DUPLICATE_SAME_ACCESS)

@func = ->{}
@main_fiber = Fiber.new(stack_address, self)
@@threads.push(self)

Thread.threads.push(self)
end

@@current : Thread? = nil
private def detach
if @detached.compare_and_set(0, 1).last
yield
end
end

# Associates the Thread object to the running system thread.
protected def self.current=(@@current : Thread) : Thread
# Suspends the current thread until this thread terminates.
def join : Nil
detach do
if LibC.WaitForSingleObject(@th, LibC::INFINITE) != LibC::WAIT_OBJECT_0
@exception ||= RuntimeError.from_winerror("WaitForSingleObject")
end
if LibC.CloseHandle(@th) == 0
@exception ||= RuntimeError.from_winerror("CloseHandle")
end
end

if exception = @exception
raise exception
end
end

@[ThreadLocal]
@@current : Thread?

# Returns the Thread object associated to the running system thread.
def self.current : Thread
@@current || raise "BUG: Thread.current returned NULL"
@@current ||= new
end

# Create the thread object for the current thread (aka the main thread of the
# process).
#
# TODO: consider moving to `kernel.cr` or `crystal/main.cr`
self.current = new
# Associates the Thread object to the running system thread.
protected def self.current=(@@current : Thread) : Thread
end

def self.yield : Nil
LibC.SwitchToThread
end

# Returns the Fiber representing the thread's main stack.
def main_fiber
def main_fiber : Fiber
@main_fiber.not_nil!
end

# :nodoc:
def scheduler
def scheduler : Crystal::Scheduler
@scheduler ||= Crystal::Scheduler.new(main_fiber)
end

protected def start
Thread.threads.push(self)
Thread.current = self
@main_fiber = fiber = Fiber.new(stack_address, self)

Expand All @@ -60,9 +110,9 @@ class Thread
rescue ex
@exception = ex
ensure
@@threads.delete(self)
Thread.threads.delete(self)
Fiber.inactive(fiber)
detach_self
detach { LibC.CloseHandle(@th) }
end
end

Expand All @@ -71,4 +121,9 @@ class Thread

Pointer(Void).new(low_limit)
end

# :nodoc:
def to_unsafe
@th
end
end
41 changes: 41 additions & 0 deletions src/crystal/system/win32/thread_condition_variable.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
require "c/synchapi"

# :nodoc:
class Thread
# :nodoc:
class ConditionVariable
def initialize
@cond = uninitialized LibC::CONDITION_VARIABLE
LibC.InitializeConditionVariable(self)
end

def signal : Nil
LibC.WakeConditionVariable(self)
end

def broadcast : Nil
LibC.WakeAllConditionVariable(self)
end

def wait(mutex : Thread::Mutex) : Nil
ret = LibC.SleepConditionVariableCS(self, mutex, LibC::INFINITE)
raise RuntimeError.from_winerror("SleepConditionVariableCS") if ret == 0
end

def wait(mutex : Thread::Mutex, time : Time::Span, & : ->)
ret = LibC.SleepConditionVariableCS(self, mutex, time.total_milliseconds)
return if ret != 0

error = WinError.value
if error == WinError::ERROR_TIMEOUT
yield
else
raise RuntimeError.from_os_error("SleepConditionVariableCS", error)
end
end

def to_unsafe
pointerof(@cond)
end
end
end
57 changes: 55 additions & 2 deletions src/crystal/system/win32/thread_mutex.cr
Original file line number Diff line number Diff line change
@@ -1,8 +1,61 @@
# TODO: Implement
require "c/synchapi"

# :nodoc:
class Thread
# :nodoc:
# for Win32 condition variable interop we must use either a critical section
# or a slim reader/writer lock, not a Win32 mutex
# also note critical sections are reentrant; to match the behaviour in
# `../unix/pthread_mutex.cr` we must do extra housekeeping ourselves
class Mutex
def initialize
@cs = uninitialized LibC::CRITICAL_SECTION
LibC.InitializeCriticalSectionAndSpinCount(self, 1000)
end

def lock : Nil
LibC.EnterCriticalSection(self)
if @cs.recursionCount > 1
LibC.LeaveCriticalSection(self)
raise RuntimeError.new "Attempt to lock a mutex recursively (deadlock)"
end
end

def try_lock : Bool
if LibC.TryEnterCriticalSection(self) != 0
if @cs.recursionCount > 1
LibC.LeaveCriticalSection(self)
false
else
true
end
else
false
end
end

def unlock : Nil
# `owningThread` is declared as `LibC::HANDLE` for historical reasons, so
# the following comparison is correct
unless @cs.owningThread == LibC::HANDLE.new(LibC.GetCurrentThreadId)
raise RuntimeError.new "Attempt to unlock a mutex locked by another thread"
end
LibC.LeaveCriticalSection(self)
end

def synchronize
yield
lock
yield self
ensure
unlock
end

def finalize
LibC.DeleteCriticalSection(self)
end

def to_unsafe
pointerof(@cs)
end
end
end
16 changes: 13 additions & 3 deletions src/gc/boehm.cr
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,11 @@ lib LibGC

fun size = GC_size(addr : Void*) : LibC::SizeT

{% unless flag?(:win32) || flag?(:wasm32) %}
# Boehm GC requires to use GC_pthread_create and GC_pthread_join instead of pthread_create and pthread_join
# Boehm GC requires to use its own thread manipulation routines instead of pthread's or Win32's
{% if flag?(:win32) %}
fun beginthreadex = GC_beginthreadex(security : Void*, stack_size : LibC::UInt, start_address : Void* -> LibC::UInt,
arglist : Void*, initflag : LibC::UInt, thrdaddr : LibC::UInt*) : Void*
{% elsif !flag?(:wasm32) %}
fun pthread_create = GC_pthread_create(thread : LibC::PthreadT*, attr : LibC::PthreadAttrT*, start : Void* -> Void*, arg : Void*) : LibC::Int
fun pthread_join = GC_pthread_join(thread : LibC::PthreadT, value : Void**) : LibC::Int
fun pthread_detach = GC_pthread_detach(thread : LibC::PthreadT) : LibC::Int
Expand Down Expand Up @@ -235,7 +238,14 @@ module GC
reclaimed_bytes_before_gc: stats.reclaimed_bytes_before_gc)
end

{% unless flag?(:win32) %}
{% if flag?(:win32) %}
# :nodoc:
def self.beginthreadex(security : Void*, stack_size : LibC::UInt, start_address : Void* -> LibC::UInt, arglist : Void*, initflag : LibC::UInt, thrdaddr : LibC::UInt*) : LibC::HANDLE
ret = LibGC.beginthreadex(security, stack_size, start_address, arglist, initflag, thrdaddr)
raise RuntimeError.from_errno("GC_beginthreadex") if ret.null?
ret.as(LibC::HANDLE)
end
{% else %}
# :nodoc:
def self.pthread_create(thread : LibC::PthreadT*, attr : LibC::PthreadAttrT*, start : Void* -> Void*, arg : Void*)
LibGC.pthread_create(thread, attr, start, arg)
Expand Down
Loading