From 45d0f0aaa2355603b1d771bd344ebf84808df8ac Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Wed, 4 Sep 2024 13:02:48 +0800 Subject: [PATCH] Make DNS resolution asynchronous on Windows --- spec/std/socket/addrinfo_spec.cr | 32 +++++++++++-- src/crystal/system/win32/addrinfo.cr | 51 +++++++++++++++++---- src/crystal/system/win32/iocp.cr | 45 ++++++++++++++++-- src/http/client.cr | 8 ++-- src/lib_c/x86_64-windows-msvc/c/winsock2.cr | 7 +++ src/lib_c/x86_64-windows-msvc/c/ws2def.cr | 14 ++++++ src/lib_c/x86_64-windows-msvc/c/ws2tcpip.cr | 20 ++++++++ src/socket/addrinfo.cr | 30 +++++++----- src/socket/tcp_socket.cr | 2 +- src/winerror.cr | 5 +- 10 files changed, 180 insertions(+), 34 deletions(-) diff --git a/spec/std/socket/addrinfo_spec.cr b/spec/std/socket/addrinfo_spec.cr index 615058472525..109eb383562b 100644 --- a/spec/std/socket/addrinfo_spec.cr +++ b/spec/std/socket/addrinfo_spec.cr @@ -22,6 +22,20 @@ describe Socket::Addrinfo, tags: "network" do end end end + + it "raises helpful message on getaddrinfo failure" do + expect_raises(Socket::Addrinfo::Error, "Hostname lookup for badhostname failed: ") do + Socket::Addrinfo.resolve("badhostname", 80, type: Socket::Type::DGRAM) + end + end + + {% if flag?(:win32) %} + it "raises timeout error" do + expect_raises(IO::TimeoutError) do + Socket::Addrinfo.resolve("badhostname", 80, type: Socket::Type::STREAM, timeout: 0.milliseconds) + end + end + {% end %} end describe ".tcp" do @@ -37,11 +51,13 @@ describe Socket::Addrinfo, tags: "network" do end end - it "raises helpful message on getaddrinfo failure" do - expect_raises(Socket::Addrinfo::Error, "Hostname lookup for badhostname failed: ") do - Socket::Addrinfo.resolve("badhostname", 80, type: Socket::Type::DGRAM) + {% if flag?(:win32) %} + it "raises timeout error" do + expect_raises(IO::TimeoutError) do + Socket::Addrinfo.tcp("badhostname", 80, timeout: 0.milliseconds) + end end - end + {% end %} end describe ".udp" do @@ -56,6 +72,14 @@ describe Socket::Addrinfo, tags: "network" do typeof(addrinfo).should eq(Socket::Addrinfo) end end + + {% if flag?(:win32) %} + it "raises timeout error" do + expect_raises(IO::TimeoutError) do + Socket::Addrinfo.udp("badhostname", 80, timeout: 0.milliseconds) + end + end + {% end %} end describe "#ip_address" do diff --git a/src/crystal/system/win32/addrinfo.cr b/src/crystal/system/win32/addrinfo.cr index b033d61f16e7..b871098cb9d4 100644 --- a/src/crystal/system/win32/addrinfo.cr +++ b/src/crystal/system/win32/addrinfo.cr @@ -1,5 +1,5 @@ module Crystal::System::Addrinfo - alias Handle = LibC::Addrinfo* + alias Handle = LibC::ADDRINFOEXW* @addr : LibC::SockaddrIn6 @@ -30,7 +30,7 @@ module Crystal::System::Addrinfo end def self.getaddrinfo(domain, service, family, type, protocol, timeout) : Handle - hints = LibC::Addrinfo.new + hints = LibC::ADDRINFOEXW.new hints.ai_family = (family || ::Socket::Family::UNSPEC).to_i32 hints.ai_socktype = type hints.ai_protocol = protocol @@ -43,12 +43,47 @@ module Crystal::System::Addrinfo end end - ret = LibC.getaddrinfo(domain, service.to_s, pointerof(hints), out ptr) - unless ret.zero? - error = WinError.new(ret.to_u32!) - raise ::Socket::Addrinfo::Error.from_os_error(nil, error, domain: domain, type: type, protocol: protocol, service: service) + # we use `LibC::INVALID_HANDLE_VALUE` because there is no I/O handle + # associated with a GetAddrInfoExW call, but later we'll replace it with a + # cancellation handle + Crystal::IOCP::OverlappedOperation.run(LibC::INVALID_HANDLE_VALUE) do |operation| + # This assumes the `OVERLAPPED` struct's `internalHigh` field is unused, + # so that we could pass closure data to the completion routine + operation.internal_high = Crystal::EventLoop.current.iocp.address.to_u64! + + completion_routine = LibC::LPLOOKUPSERVICE_COMPLETION_ROUTINE.new do |dwError, dwBytes, lpOverlapped| + orig_operation = Crystal::IOCP::OverlappedOperation.unbox(lpOverlapped) + iocp = LibC::HANDLE.new(orig_operation.internal_high) + LibC.PostQueuedCompletionStatus(iocp, 0, 0, lpOverlapped) + end + + # NOTE: we handle the timeout ourselves so we don't pass a `LibC::Timeval` + # to Win32 here + result = LibC.GetAddrInfoExW( + Crystal::System.to_wstr(domain), Crystal::System.to_wstr(service.to_s), LibC::NS_DNS, nil, pointerof(hints), + out addrinfos, nil, operation, completion_routine, out cancel_handle) + + if result == 0 + return addrinfos + else + case error = WinError.new(result.to_u32!) + when .wsa_io_pending? + # used in `Crystal::IOCP::OverlappedOperation#try_cancel_getaddrinfo` + operation.handle = cancel_handle + else + raise ::Socket::Addrinfo::Error.from_os_error("GetAddrInfoExW", error, domain: domain, type: type, protocol: protocol, service: service) + end + end + + operation.wait_for_getaddrinfo_result(timeout) do |error| + case error + when .wsa_e_cancelled? + raise IO::TimeoutError.new("GetAddrInfoExW timed out") + else + raise ::Socket::Addrinfo::Error.from_os_error("GetAddrInfoExW", error, domain: domain, type: type, protocol: protocol, service: service) + end + end end - ptr end def self.next_addrinfo(addrinfo : Handle) : Handle @@ -56,6 +91,6 @@ module Crystal::System::Addrinfo end def self.free_addrinfo(addrinfo : Handle) - LibC.freeaddrinfo(addrinfo) + LibC.FreeAddrInfoExW(addrinfo) end end diff --git a/src/crystal/system/win32/iocp.cr b/src/crystal/system/win32/iocp.cr index ba87ed123f22..3418aeb4348e 100644 --- a/src/crystal/system/win32/iocp.cr +++ b/src/crystal/system/win32/iocp.cr @@ -87,6 +87,7 @@ module Crystal::IOCP @overlapped = LibC::OVERLAPPED.new @fiber = Fiber.current @state : State = :started + setter handle : LibC::HANDLE def initialize(@handle : LibC::HANDLE) end @@ -110,8 +111,16 @@ module Crystal::IOCP pointerof(@overlapped) end + def internal_high + @overlapped.internalHigh + end + + def internal_high=(value) + @overlapped.internalHigh = value + end + def wait_for_result(timeout, &) - wait_for_completion(timeout) + wait_for_completion(timeout) { try_cancel } result = LibC.GetOverlappedResult(@handle, self, out bytes, 0) if result.zero? @@ -125,7 +134,7 @@ module Crystal::IOCP end def wait_for_wsa_result(timeout, &) - wait_for_completion(timeout) + wait_for_completion(timeout) { try_cancel } flags = 0_u32 result = LibC.WSAGetOverlappedResult(LibC::SOCKET.new(@handle.address), self, out bytes, false, pointerof(flags)) @@ -139,6 +148,20 @@ module Crystal::IOCP bytes end + def wait_for_getaddrinfo_result(timeout, &) + wait_for_completion(timeout) { try_cancel_getaddrinfo } + + result = LibC.GetAddrInfoExOverlappedResult(self) + unless result.zero? + error = WinError.new(result.to_u32!) + yield error + + raise Socket::Addrinfo::Error.from_os_error("GetAddrInfoExOverlappedResult", error) + end + + @overlapped.union.pointer.as(LibC::ADDRINFOEXW**).value + end + protected def schedule(&) done! yield @fiber @@ -168,7 +191,21 @@ module Crystal::IOCP true end - def wait_for_completion(timeout) + def try_cancel_getaddrinfo : Bool + ret = LibC.GetAddrInfoExCancel(pointerof(@handle)) + unless ret.zero? + case error = WinError.new(ret.to_u32!) + when .wsa_invalid_handle? + # Operation has already completed, do nothing + return false + else + raise Socket::Addrinfo::Error.from_os_error("GetAddrInfoExCancel", error) + end + end + true + end + + def wait_for_completion(timeout, & : -> Bool) if timeout sleep timeout else @@ -176,7 +213,7 @@ module Crystal::IOCP end unless @state.done? - if try_cancel + if yield # Wait for cancellation to complete. We must not free the operation # until it's completed. Fiber.suspend diff --git a/src/http/client.cr b/src/http/client.cr index b641065ac930..7324bdf7d639 100644 --- a/src/http/client.cr +++ b/src/http/client.cr @@ -343,10 +343,10 @@ class HTTP::Client # ``` setter connect_timeout : Time::Span? - # **This method has no effect right now** - # # Sets the number of seconds to wait when resolving a name, before raising an `IO::TimeoutError`. # + # NOTE: *dns_timeout* is currently only supported on Windows. + # # ``` # require "http/client" # @@ -363,10 +363,10 @@ class HTTP::Client self.dns_timeout = dns_timeout.seconds end - # **This method has no effect right now** - # # Sets the number of seconds to wait when resolving a name with a `Time::Span`, before raising an `IO::TimeoutError`. # + # NOTE: *dns_timeout* is currently only supported on Windows. + # # ``` # require "http/client" # diff --git a/src/lib_c/x86_64-windows-msvc/c/winsock2.cr b/src/lib_c/x86_64-windows-msvc/c/winsock2.cr index 223c2366b072..68ce6f9ef421 100644 --- a/src/lib_c/x86_64-windows-msvc/c/winsock2.cr +++ b/src/lib_c/x86_64-windows-msvc/c/winsock2.cr @@ -20,6 +20,8 @@ lib LibC lpVendorInfo : Char* end + NS_DNS = 12_u32 + INVALID_SOCKET = ~SOCKET.new(0) SOCKET_ERROR = -1 @@ -111,6 +113,11 @@ lib LibC alias WSAOVERLAPPED_COMPLETION_ROUTINE = Proc(DWORD, DWORD, WSAOVERLAPPED*, DWORD, Void) + struct Timeval + tv_sec : Long + tv_usec : Long + end + struct Linger l_onoff : UShort l_linger : UShort diff --git a/src/lib_c/x86_64-windows-msvc/c/ws2def.cr b/src/lib_c/x86_64-windows-msvc/c/ws2def.cr index 9fc19857f4a3..41e0a1a408eb 100644 --- a/src/lib_c/x86_64-windows-msvc/c/ws2def.cr +++ b/src/lib_c/x86_64-windows-msvc/c/ws2def.cr @@ -208,4 +208,18 @@ lib LibC ai_addr : Sockaddr* ai_next : Addrinfo* end + + struct ADDRINFOEXW + ai_flags : Int + ai_family : Int + ai_socktype : Int + ai_protocol : Int + ai_addrlen : SizeT + ai_canonname : LPWSTR + ai_addr : Sockaddr* + ai_blob : Void* + ai_bloblen : SizeT + ai_provider : GUID* + ai_next : ADDRINFOEXW* + end end diff --git a/src/lib_c/x86_64-windows-msvc/c/ws2tcpip.cr b/src/lib_c/x86_64-windows-msvc/c/ws2tcpip.cr index 338063ccf6f6..3b3f61ba7fdb 100644 --- a/src/lib_c/x86_64-windows-msvc/c/ws2tcpip.cr +++ b/src/lib_c/x86_64-windows-msvc/c/ws2tcpip.cr @@ -17,4 +17,24 @@ lib LibC fun getaddrinfo(pNodeName : Char*, pServiceName : Char*, pHints : Addrinfo*, ppResult : Addrinfo**) : Int fun inet_ntop(family : Int, pAddr : Void*, pStringBuf : Char*, stringBufSize : SizeT) : Char* fun inet_pton(family : Int, pszAddrString : Char*, pAddrBuf : Void*) : Int + + fun FreeAddrInfoExW(pAddrInfoEx : ADDRINFOEXW*) + + alias LPLOOKUPSERVICE_COMPLETION_ROUTINE = DWORD, DWORD, WSAOVERLAPPED* -> + + fun GetAddrInfoExW( + pName : LPWSTR, + pServiceName : LPWSTR, + dwNameSpace : DWORD, + lpNspId : GUID*, + hints : ADDRINFOEXW*, + ppResult : ADDRINFOEXW**, + timeout : Timeval*, + lpOverlapped : OVERLAPPED*, + lpCompletionRoutine : LPLOOKUPSERVICE_COMPLETION_ROUTINE, + lpHandle : HANDLE*, + ) : Int + + fun GetAddrInfoExOverlappedResult(lpOverlapped : OVERLAPPED*) : Int + fun GetAddrInfoExCancel(lpHandle : HANDLE*) : Int end diff --git a/src/socket/addrinfo.cr b/src/socket/addrinfo.cr index cdf55c912601..ef76d0e285b6 100644 --- a/src/socket/addrinfo.cr +++ b/src/socket/addrinfo.cr @@ -23,6 +23,9 @@ class Socket # specified. # - *protocol* is the intended socket protocol (e.g. `Protocol::TCP`) and # should be specified. + # - *timeout* is optional and specifies the maximum time to wait before + # `IO::TimeoutError` is raised. Currently this is only supported on + # Windows. # # Example: # ``` @@ -107,8 +110,11 @@ class Socket "Hostname lookup for #{domain} failed" end - def self.os_error_message(os_error : Errno, *, type, service, protocol, **opts) - case os_error.value + def self.os_error_message(os_error : Errno | WinError, *, type, service, protocol, **opts) + # when `EAI_NONAME` etc. is an integer then only `os_error.value` can + # match; when `EAI_NONAME` is a `WinError` then `os_error` itself can + # match + case os_error.is_a?(Errno) ? os_error.value : os_error when LibC::EAI_NONAME "No address found" when LibC::EAI_SOCKTYPE @@ -116,12 +122,14 @@ class Socket when LibC::EAI_SERVICE "The requested service #{service} is not available for the requested socket type #{type}" else - {% unless flag?(:win32) %} - # There's no need for a special win32 branch because the os_error on Windows - # is of type WinError, which wouldn't match this overload anyways. - - String.new(LibC.gai_strerror(os_error.value)) + # Win32 also has this method, but `WinError` is already sufficient + {% if LibC.has_method?(:gai_strerror) %} + if os_error.is_a?(Errno) + return String.new(LibC.gai_strerror(os_error)) + end {% end %} + + super end end end @@ -148,13 +156,13 @@ class Socket # addrinfos = Socket::Addrinfo.tcp("example.org", 80) # ``` def self.tcp(domain : String, service, family = Family::UNSPEC, timeout = nil) : Array(Addrinfo) - resolve(domain, service, family, Type::STREAM, Protocol::TCP) + resolve(domain, service, family, Type::STREAM, Protocol::TCP, timeout) end # Resolves a domain for the TCP protocol with STREAM type, and yields each # possible `Addrinfo`. See `#resolve` for details. def self.tcp(domain : String, service, family = Family::UNSPEC, timeout = nil, &) - resolve(domain, service, family, Type::STREAM, Protocol::TCP) { |addrinfo| yield addrinfo } + resolve(domain, service, family, Type::STREAM, Protocol::TCP, timeout) { |addrinfo| yield addrinfo } end # Resolves *domain* for the UDP protocol and returns an `Array` of possible @@ -167,13 +175,13 @@ class Socket # addrinfos = Socket::Addrinfo.udp("example.org", 53) # ``` def self.udp(domain : String, service, family = Family::UNSPEC, timeout = nil) : Array(Addrinfo) - resolve(domain, service, family, Type::DGRAM, Protocol::UDP) + resolve(domain, service, family, Type::DGRAM, Protocol::UDP, timeout) end # Resolves a domain for the UDP protocol with DGRAM type, and yields each # possible `Addrinfo`. See `#resolve` for details. def self.udp(domain : String, service, family = Family::UNSPEC, timeout = nil, &) - resolve(domain, service, family, Type::DGRAM, Protocol::UDP) { |addrinfo| yield addrinfo } + resolve(domain, service, family, Type::DGRAM, Protocol::UDP, timeout) { |addrinfo| yield addrinfo } end # Returns an `IPAddress` matching this addrinfo. diff --git a/src/socket/tcp_socket.cr b/src/socket/tcp_socket.cr index 387417211a1a..4edcb3d08e5f 100644 --- a/src/socket/tcp_socket.cr +++ b/src/socket/tcp_socket.cr @@ -25,7 +25,7 @@ class TCPSocket < IPSocket # connection time to the remote server with `connect_timeout`. Both values # must be in seconds (integers or floats). # - # Note that `dns_timeout` is currently ignored. + # NOTE: *dns_timeout* is currently only supported on Windows. def initialize(host : String, port, dns_timeout = nil, connect_timeout = nil, blocking = false) Addrinfo.tcp(host, port, timeout: dns_timeout) do |addrinfo| super(addrinfo.family, addrinfo.type, addrinfo.protocol, blocking) diff --git a/src/winerror.cr b/src/winerror.cr index ab978769d553..fbb2fb553873 100644 --- a/src/winerror.cr +++ b/src/winerror.cr @@ -2305,6 +2305,7 @@ enum WinError : UInt32 ERROR_STATE_CONTAINER_NAME_SIZE_LIMIT_EXCEEDED = 15818_u32 ERROR_API_UNAVAILABLE = 15841_u32 - WSA_IO_PENDING = ERROR_IO_PENDING - WSA_IO_INCOMPLETE = ERROR_IO_INCOMPLETE + WSA_IO_PENDING = ERROR_IO_PENDING + WSA_IO_INCOMPLETE = ERROR_IO_INCOMPLETE + WSA_INVALID_HANDLE = ERROR_INVALID_HANDLE end