Skip to content

Commit

Permalink
Use class_getter? + push proc definitions to c/ntdll
Browse files Browse the repository at this point in the history
  • Loading branch information
ysbaddaden committed Dec 2, 2024
1 parent d165e1e commit d2cb083
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 28 deletions.
55 changes: 27 additions & 28 deletions src/crystal/system/win32/iocp.cr
Original file line number Diff line number Diff line change
Expand Up @@ -9,46 +9,45 @@ struct Crystal::System::IOCP
@@wait_completion_packet_methods : Bool? = nil

{% if flag?(:interpreted) %}
# The interpreter doesn't like the interpreted code to dynamically load
# symbols from an external library. We thus merely check for their existence
# then simply call them, so the interpreter will load/call them properly.
def self.wait_completion_packet_methods? : Bool
unless (supported = @@wait_completion_packet_methods).nil?
return supported
end
handle = LibC.LoadLibraryExW(Crystal::System.to_wstr("ntdll.dll"), nil, 0)
return @@wait_completion_packet_methods = false if handle.null?

pointer = LibC.GetProcAddress(handle, "NtCreateWaitCompletionPacket")
return @@wait_completion_packet_methods = false if pointer.null?
# We can't load the symbols from interpreted code since it would create
# interpreted Proc. We thus merely check for the existence of the symbols,
# then let the interpreter load the symbols, which will create interpreter
# Proc (not interpreted) that can be called.
class_getter?(wait_completion_packet_methods : Bool) do
detect_wait_completion_packet_symbols
end

@@wait_completion_packet_methods = true
private def self.detect_wait_completion_packet_methods : Bool
if handle = LibC.LoadLibraryExW(Crystal::System.to_wstr("ntdll.dll"), nil, 0)
!LibC.GetProcAddress(handle, "NtCreateWaitCompletionPacket").null?
else
false
end
end
{% else %}
@@wait_completion_packet_methods : Bool? = nil
@@_NtCreateWaitCompletionPacket = uninitialized Proc(LibC::HANDLE*, LibNTDLL::ACCESS_MASK, LibC::OBJECT_ATTRIBUTES*, LibNTDLL::NTSTATUS)
@@_NtAssociateWaitCompletionPacket = uninitialized Proc(LibC::HANDLE, LibC::HANDLE, LibC::HANDLE, Void*, Void*, LibNTDLL::NTSTATUS, LibC::ULONG*, LibC::BOOLEAN*, LibNTDLL::NTSTATUS)
@@_NtCancelWaitCompletionPacket = uninitialized Proc(LibC::HANDLE, LibC::BOOLEAN, LibNTDLL::NTSTATUS)

def self.wait_completion_packet_methods? : Bool
unless (supported = @@wait_completion_packet_methods).nil?
return supported
end
@@_NtCreateWaitCompletionPacket = uninitialized LibNTDLL::NtCreateWaitCompletionPacketProc
@@_NtAssociateWaitCompletionPacket = uninitialized LibNTDLL::NtAssociateWaitCompletionPacketProc
@@_NtCancelWaitCompletionPacket = uninitialized LibNTDLL::NtCancelWaitCompletionPacketProc

class_getter?(wait_completion_packet_methods : Bool) do
load_wait_completion_packet_symbols
end

private def self.load_wait_completion_packet_methods : Bool
handle = LibC.LoadLibraryExW(Crystal::System.to_wstr("ntdll.dll"), nil, 0)
return @@wait_completion_packet_methods = false if handle.null?
return false if handle.null?

pointer = LibC.GetProcAddress(handle, "NtCreateWaitCompletionPacket")
return @@wait_completion_packet_methods = false if pointer.null?
@@_NtCreateWaitCompletionPacket = Proc(LibC::HANDLE*, LibNTDLL::ACCESS_MASK, LibC::OBJECT_ATTRIBUTES*, LibNTDLL::NTSTATUS).new(pointer, Pointer(Void).null)
return false if pointer.null?
@@_NtCreateWaitCompletionPacket = LibNTDLL::NtCreateWaitCompletionPacketProc.new(pointer, Pointer(Void).null)

pointer = LibC.GetProcAddress(handle, "NtAssociateWaitCompletionPacket")
@@_NtAssociateWaitCompletionPacket = Proc(LibC::HANDLE, LibC::HANDLE, LibC::HANDLE, Void*, Void*, LibNTDLL::NTSTATUS, LibC::ULONG*, LibC::BOOLEAN*, LibNTDLL::NTSTATUS).new(pointer, Pointer(Void).null)
@@_NtAssociateWaitCompletionPacket = LibNTDLL::NtAssociateWaitCompletionPacketProc.new(pointer, Pointer(Void).null)

pointer = LibC.GetProcAddress(handle, "NtCancelWaitCompletionPacket")
@@_NtCancelWaitCompletionPacket = Proc(LibC::HANDLE, LibC::BOOLEAN, LibNTDLL::NTSTATUS).new(pointer, Pointer(Void).null)
@@_NtCancelWaitCompletionPacket = LibNTDLL::NtCancelWaitCompletionPacketProc.new(pointer, Pointer(Void).null)

@@wait_completion_packet_methods = true
true
end
{% end %}

Expand Down
4 changes: 4 additions & 0 deletions src/lib_c/x86_64-windows-msvc/c/ntdll.cr
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ lib LibNTDLL

GENERIC_ALL = 0x10000000_u32

alias NtCreateWaitCompletionPacketProc = Proc(LibC::HANDLE*, ACCESS_MASK, LibC::OBJECT_ATTRIBUTES*, NTSTATUS)
alias NtAssociateWaitCompletionPacketProc = Proc(LibC::HANDLE, LibC::HANDLE, LibC::HANDLE, Void*, Void*, NTSTATUS, LibC::ULONG*, LibC::BOOLEAN*, NTSTATUS)
alias NtCancelWaitCompletionPacketProc = Proc(LibC::HANDLE, LibC::BOOLEAN, NTSTATUS)

fun NtCreateWaitCompletionPacket(
waitCompletionPacketHandle : LibC::HANDLE*,
desiredAccess : ACCESS_MASK,
Expand Down

0 comments on commit d2cb083

Please sign in to comment.