Skip to content

Commit

Permalink
Make DNS resolution asynchronous on Windows
Browse files Browse the repository at this point in the history
  • Loading branch information
HertzDevil committed Sep 6, 2024
1 parent db2ecd7 commit 45d0f0a
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 34 deletions.
32 changes: 28 additions & 4 deletions spec/std/socket/addrinfo_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ describe Socket::Addrinfo, tags: "network" do
end
end
end

it "raises helpful message on getaddrinfo failure" do
expect_raises(Socket::Addrinfo::Error, "Hostname lookup for badhostname failed: ") do
Socket::Addrinfo.resolve("badhostname", 80, type: Socket::Type::DGRAM)
end
end

{% if flag?(:win32) %}
it "raises timeout error" do
expect_raises(IO::TimeoutError) do
Socket::Addrinfo.resolve("badhostname", 80, type: Socket::Type::STREAM, timeout: 0.milliseconds)
end
end
{% end %}
end

describe ".tcp" do
Expand All @@ -37,11 +51,13 @@ describe Socket::Addrinfo, tags: "network" do
end
end

it "raises helpful message on getaddrinfo failure" do
expect_raises(Socket::Addrinfo::Error, "Hostname lookup for badhostname failed: ") do
Socket::Addrinfo.resolve("badhostname", 80, type: Socket::Type::DGRAM)
{% if flag?(:win32) %}
it "raises timeout error" do
expect_raises(IO::TimeoutError) do
Socket::Addrinfo.tcp("badhostname", 80, timeout: 0.milliseconds)
end
end
end
{% end %}
end

describe ".udp" do
Expand All @@ -56,6 +72,14 @@ describe Socket::Addrinfo, tags: "network" do
typeof(addrinfo).should eq(Socket::Addrinfo)
end
end

{% if flag?(:win32) %}
it "raises timeout error" do
expect_raises(IO::TimeoutError) do
Socket::Addrinfo.udp("badhostname", 80, timeout: 0.milliseconds)
end
end
{% end %}
end

describe "#ip_address" do
Expand Down
51 changes: 43 additions & 8 deletions src/crystal/system/win32/addrinfo.cr
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module Crystal::System::Addrinfo
alias Handle = LibC::Addrinfo*
alias Handle = LibC::ADDRINFOEXW*

@addr : LibC::SockaddrIn6

Expand Down Expand Up @@ -30,7 +30,7 @@ module Crystal::System::Addrinfo
end

def self.getaddrinfo(domain, service, family, type, protocol, timeout) : Handle
hints = LibC::Addrinfo.new
hints = LibC::ADDRINFOEXW.new
hints.ai_family = (family || ::Socket::Family::UNSPEC).to_i32
hints.ai_socktype = type
hints.ai_protocol = protocol
Expand All @@ -43,19 +43,54 @@ module Crystal::System::Addrinfo
end
end

ret = LibC.getaddrinfo(domain, service.to_s, pointerof(hints), out ptr)
unless ret.zero?
error = WinError.new(ret.to_u32!)
raise ::Socket::Addrinfo::Error.from_os_error(nil, error, domain: domain, type: type, protocol: protocol, service: service)
# we use `LibC::INVALID_HANDLE_VALUE` because there is no I/O handle
# associated with a GetAddrInfoExW call, but later we'll replace it with a
# cancellation handle
Crystal::IOCP::OverlappedOperation.run(LibC::INVALID_HANDLE_VALUE) do |operation|
# This assumes the `OVERLAPPED` struct's `internalHigh` field is unused,
# so that we could pass closure data to the completion routine
operation.internal_high = Crystal::EventLoop.current.iocp.address.to_u64!

completion_routine = LibC::LPLOOKUPSERVICE_COMPLETION_ROUTINE.new do |dwError, dwBytes, lpOverlapped|
orig_operation = Crystal::IOCP::OverlappedOperation.unbox(lpOverlapped)
iocp = LibC::HANDLE.new(orig_operation.internal_high)
LibC.PostQueuedCompletionStatus(iocp, 0, 0, lpOverlapped)
end

# NOTE: we handle the timeout ourselves so we don't pass a `LibC::Timeval`
# to Win32 here
result = LibC.GetAddrInfoExW(
Crystal::System.to_wstr(domain), Crystal::System.to_wstr(service.to_s), LibC::NS_DNS, nil, pointerof(hints),
out addrinfos, nil, operation, completion_routine, out cancel_handle)

if result == 0
return addrinfos
else
case error = WinError.new(result.to_u32!)
when .wsa_io_pending?
# used in `Crystal::IOCP::OverlappedOperation#try_cancel_getaddrinfo`
operation.handle = cancel_handle
else
raise ::Socket::Addrinfo::Error.from_os_error("GetAddrInfoExW", error, domain: domain, type: type, protocol: protocol, service: service)
end
end

operation.wait_for_getaddrinfo_result(timeout) do |error|
case error
when .wsa_e_cancelled?
raise IO::TimeoutError.new("GetAddrInfoExW timed out")
else
raise ::Socket::Addrinfo::Error.from_os_error("GetAddrInfoExW", error, domain: domain, type: type, protocol: protocol, service: service)
end
end
end
ptr
end

def self.next_addrinfo(addrinfo : Handle) : Handle
addrinfo.value.ai_next
end

def self.free_addrinfo(addrinfo : Handle)
LibC.freeaddrinfo(addrinfo)
LibC.FreeAddrInfoExW(addrinfo)
end
end
45 changes: 41 additions & 4 deletions src/crystal/system/win32/iocp.cr
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ module Crystal::IOCP
@overlapped = LibC::OVERLAPPED.new
@fiber = Fiber.current
@state : State = :started
setter handle : LibC::HANDLE

def initialize(@handle : LibC::HANDLE)
end
Expand All @@ -110,8 +111,16 @@ module Crystal::IOCP
pointerof(@overlapped)
end

def internal_high
@overlapped.internalHigh
end

def internal_high=(value)
@overlapped.internalHigh = value
end

def wait_for_result(timeout, &)
wait_for_completion(timeout)
wait_for_completion(timeout) { try_cancel }

result = LibC.GetOverlappedResult(@handle, self, out bytes, 0)
if result.zero?
Expand All @@ -125,7 +134,7 @@ module Crystal::IOCP
end

def wait_for_wsa_result(timeout, &)
wait_for_completion(timeout)
wait_for_completion(timeout) { try_cancel }

flags = 0_u32
result = LibC.WSAGetOverlappedResult(LibC::SOCKET.new(@handle.address), self, out bytes, false, pointerof(flags))
Expand All @@ -139,6 +148,20 @@ module Crystal::IOCP
bytes
end

def wait_for_getaddrinfo_result(timeout, &)
wait_for_completion(timeout) { try_cancel_getaddrinfo }

result = LibC.GetAddrInfoExOverlappedResult(self)
unless result.zero?
error = WinError.new(result.to_u32!)
yield error

raise Socket::Addrinfo::Error.from_os_error("GetAddrInfoExOverlappedResult", error)
end

@overlapped.union.pointer.as(LibC::ADDRINFOEXW**).value
end

protected def schedule(&)
done!
yield @fiber
Expand Down Expand Up @@ -168,15 +191,29 @@ module Crystal::IOCP
true
end

def wait_for_completion(timeout)
def try_cancel_getaddrinfo : Bool
ret = LibC.GetAddrInfoExCancel(pointerof(@handle))
unless ret.zero?
case error = WinError.new(ret.to_u32!)
when .wsa_invalid_handle?
# Operation has already completed, do nothing
return false
else
raise Socket::Addrinfo::Error.from_os_error("GetAddrInfoExCancel", error)
end
end
true
end

def wait_for_completion(timeout, & : -> Bool)
if timeout
sleep timeout
else
Fiber.suspend
end

unless @state.done?
if try_cancel
if yield
# Wait for cancellation to complete. We must not free the operation
# until it's completed.
Fiber.suspend
Expand Down
8 changes: 4 additions & 4 deletions src/http/client.cr
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,10 @@ class HTTP::Client
# ```
setter connect_timeout : Time::Span?

# **This method has no effect right now**
#
# Sets the number of seconds to wait when resolving a name, before raising an `IO::TimeoutError`.
#
# NOTE: *dns_timeout* is currently only supported on Windows.
#
# ```
# require "http/client"
#
Expand All @@ -363,10 +363,10 @@ class HTTP::Client
self.dns_timeout = dns_timeout.seconds
end

# **This method has no effect right now**
#
# Sets the number of seconds to wait when resolving a name with a `Time::Span`, before raising an `IO::TimeoutError`.
#
# NOTE: *dns_timeout* is currently only supported on Windows.
#
# ```
# require "http/client"
#
Expand Down
7 changes: 7 additions & 0 deletions src/lib_c/x86_64-windows-msvc/c/winsock2.cr
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ lib LibC
lpVendorInfo : Char*
end

NS_DNS = 12_u32

INVALID_SOCKET = ~SOCKET.new(0)
SOCKET_ERROR = -1

Expand Down Expand Up @@ -111,6 +113,11 @@ lib LibC

alias WSAOVERLAPPED_COMPLETION_ROUTINE = Proc(DWORD, DWORD, WSAOVERLAPPED*, DWORD, Void)

struct Timeval
tv_sec : Long
tv_usec : Long
end

struct Linger
l_onoff : UShort
l_linger : UShort
Expand Down
14 changes: 14 additions & 0 deletions src/lib_c/x86_64-windows-msvc/c/ws2def.cr
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,18 @@ lib LibC
ai_addr : Sockaddr*
ai_next : Addrinfo*
end

struct ADDRINFOEXW
ai_flags : Int
ai_family : Int
ai_socktype : Int
ai_protocol : Int
ai_addrlen : SizeT
ai_canonname : LPWSTR
ai_addr : Sockaddr*
ai_blob : Void*
ai_bloblen : SizeT
ai_provider : GUID*
ai_next : ADDRINFOEXW*
end
end
20 changes: 20 additions & 0 deletions src/lib_c/x86_64-windows-msvc/c/ws2tcpip.cr
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,24 @@ lib LibC
fun getaddrinfo(pNodeName : Char*, pServiceName : Char*, pHints : Addrinfo*, ppResult : Addrinfo**) : Int
fun inet_ntop(family : Int, pAddr : Void*, pStringBuf : Char*, stringBufSize : SizeT) : Char*
fun inet_pton(family : Int, pszAddrString : Char*, pAddrBuf : Void*) : Int

fun FreeAddrInfoExW(pAddrInfoEx : ADDRINFOEXW*)

alias LPLOOKUPSERVICE_COMPLETION_ROUTINE = DWORD, DWORD, WSAOVERLAPPED* ->

fun GetAddrInfoExW(
pName : LPWSTR,
pServiceName : LPWSTR,
dwNameSpace : DWORD,
lpNspId : GUID*,
hints : ADDRINFOEXW*,
ppResult : ADDRINFOEXW**,
timeout : Timeval*,
lpOverlapped : OVERLAPPED*,
lpCompletionRoutine : LPLOOKUPSERVICE_COMPLETION_ROUTINE,
lpHandle : HANDLE*,
) : Int

fun GetAddrInfoExOverlappedResult(lpOverlapped : OVERLAPPED*) : Int
fun GetAddrInfoExCancel(lpHandle : HANDLE*) : Int
end
30 changes: 19 additions & 11 deletions src/socket/addrinfo.cr
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class Socket
# specified.
# - *protocol* is the intended socket protocol (e.g. `Protocol::TCP`) and
# should be specified.
# - *timeout* is optional and specifies the maximum time to wait before
# `IO::TimeoutError` is raised. Currently this is only supported on
# Windows.
#
# Example:
# ```
Expand Down Expand Up @@ -107,21 +110,26 @@ class Socket
"Hostname lookup for #{domain} failed"
end

def self.os_error_message(os_error : Errno, *, type, service, protocol, **opts)
case os_error.value
def self.os_error_message(os_error : Errno | WinError, *, type, service, protocol, **opts)
# when `EAI_NONAME` etc. is an integer then only `os_error.value` can
# match; when `EAI_NONAME` is a `WinError` then `os_error` itself can
# match
case os_error.is_a?(Errno) ? os_error.value : os_error
when LibC::EAI_NONAME
"No address found"
when LibC::EAI_SOCKTYPE
"The requested socket type #{type} protocol #{protocol} is not supported"
when LibC::EAI_SERVICE
"The requested service #{service} is not available for the requested socket type #{type}"
else
{% unless flag?(:win32) %}
# There's no need for a special win32 branch because the os_error on Windows
# is of type WinError, which wouldn't match this overload anyways.

String.new(LibC.gai_strerror(os_error.value))
# Win32 also has this method, but `WinError` is already sufficient
{% if LibC.has_method?(:gai_strerror) %}
if os_error.is_a?(Errno)
return String.new(LibC.gai_strerror(os_error))
end
{% end %}

super
end
end
end
Expand All @@ -148,13 +156,13 @@ class Socket
# addrinfos = Socket::Addrinfo.tcp("example.org", 80)
# ```
def self.tcp(domain : String, service, family = Family::UNSPEC, timeout = nil) : Array(Addrinfo)
resolve(domain, service, family, Type::STREAM, Protocol::TCP)
resolve(domain, service, family, Type::STREAM, Protocol::TCP, timeout)
end

# Resolves a domain for the TCP protocol with STREAM type, and yields each
# possible `Addrinfo`. See `#resolve` for details.
def self.tcp(domain : String, service, family = Family::UNSPEC, timeout = nil, &)
resolve(domain, service, family, Type::STREAM, Protocol::TCP) { |addrinfo| yield addrinfo }
resolve(domain, service, family, Type::STREAM, Protocol::TCP, timeout) { |addrinfo| yield addrinfo }
end

# Resolves *domain* for the UDP protocol and returns an `Array` of possible
Expand All @@ -167,13 +175,13 @@ class Socket
# addrinfos = Socket::Addrinfo.udp("example.org", 53)
# ```
def self.udp(domain : String, service, family = Family::UNSPEC, timeout = nil) : Array(Addrinfo)
resolve(domain, service, family, Type::DGRAM, Protocol::UDP)
resolve(domain, service, family, Type::DGRAM, Protocol::UDP, timeout)
end

# Resolves a domain for the UDP protocol with DGRAM type, and yields each
# possible `Addrinfo`. See `#resolve` for details.
def self.udp(domain : String, service, family = Family::UNSPEC, timeout = nil, &)
resolve(domain, service, family, Type::DGRAM, Protocol::UDP) { |addrinfo| yield addrinfo }
resolve(domain, service, family, Type::DGRAM, Protocol::UDP, timeout) { |addrinfo| yield addrinfo }
end

# Returns an `IPAddress` matching this addrinfo.
Expand Down
2 changes: 1 addition & 1 deletion src/socket/tcp_socket.cr
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class TCPSocket < IPSocket
# connection time to the remote server with `connect_timeout`. Both values
# must be in seconds (integers or floats).
#
# Note that `dns_timeout` is currently ignored.
# NOTE: *dns_timeout* is currently only supported on Windows.
def initialize(host : String, port, dns_timeout = nil, connect_timeout = nil, blocking = false)
Addrinfo.tcp(host, port, timeout: dns_timeout) do |addrinfo|
super(addrinfo.family, addrinfo.type, addrinfo.protocol, blocking)
Expand Down
Loading

0 comments on commit 45d0f0a

Please sign in to comment.