From a86b48405dc0d99e16d0a677259ca0093c763e36 Mon Sep 17 00:00:00 2001 From: Yorick Peterse Date: Thu, 2 Jan 2025 23:46:49 +0100 Subject: [PATCH] Implement RFC 8305 for TcpClient This changes TcpClient.new and TcpClient.with_timeout to support connecting to multiple IP addresses based on RFC 8305, also known as "Happy Eyeballs version 2". In addition, TcpClient.new now uses a default timeout of 60 seconds instead of waiting forever, and std.time.Duration is extended with a few extra methods such as Duration.positive? and Duration./. This fixes https://github.com/inko-lang/inko/issues/795. Changelog: added --- std/src/std/net/happy.inko | 237 ++++++++++++++++++++++++++++++++++++ std/src/std/net/socket.inko | 85 +++++++------ std/src/std/time.inko | 56 +++++++-- std/test/std/test_time.inko | 18 +++ 4 files changed, 350 insertions(+), 46 deletions(-) create mode 100644 std/src/std/net/happy.inko diff --git a/std/src/std/net/happy.inko b/std/src/std/net/happy.inko new file mode 100644 index 00000000..793201fc --- /dev/null +++ b/std/src/std/net/happy.inko @@ -0,0 +1,237 @@ +# Support for connecting TCP sockets using [RFC +# 8305](https://datatracker.ietf.org/doc/html/rfc8305). +import std.cmp (min) +import std.drop (Drop) +import std.io (Error) +import std.iter (Stream) +import std.net.ip (IpAddress) +import std.net.socket (TcpClient) +import std.sync (Channel) +import std.time (Duration, Instant) + +# The amount of milliseconds to wait for a socket to connect. +let TIMEOUT = 250 + +# Returns an iterator that yields IP addresses in alternating order, starting +# with an IPv6 address. +fn interleave(ips: ref Array[IpAddress]) -> Stream[IpAddress] { + let mut v6_idx = 0 + let mut v4_idx = 0 + let mut v6 = true + + Stream.new(fn move { + if v6 := v6.false? { + loop { + match ips.opt(v6_idx := v6_idx + 1) { + case Some(V6(ip)) -> return Option.Some(IpAddress.V6(ip)) + case Some(_) -> {} + case _ -> break + } + } + } + + loop { + match ips.opt(v4_idx := v4_idx + 1) { + case Some(V4(ip)) -> return Option.Some(IpAddress.V4(ip)) + case Some(_) -> {} + case _ -> return Option.None + } + } + }) +} + +fn connect( + ips: ref Array[IpAddress], + port: Int, + timeout_after: Instant, +) -> Result[TcpClient, Error] { + let size = ips.size + + # It's possible the list of IPs is passed directly from "user input" such as + # a DNS record. If this list is empty we don't want to panic and abort, but + # instead give callers a chance to handle the error. As such, we return + # `InvalidArgument` instead. + # + # When there's only one IP address we can skip the Happy Eyeballs algorithm + # and just connect to it directly. + match size { + case 0 -> throw Error.InvalidArgument + case 1 -> return TcpClient.connect(ips.get(0), port, timeout_after) + case _ -> {} + } + + let ips = interleave(ips) + let cons = Connections.new(port, timeout_after) + let mut pending = 0 + + while timeout_after.remaining.positive? { + let id = match ips.next { + case Some(ip) -> { + pending += 1 + cons.connect(ip) + } + case _ -> break + } + + let wait = Duration.from_millis(TIMEOUT) + let deadline = min(timeout_after, wait.to_instant) + + loop { + match cons.receive(deadline) { + case Some(Ok(v)) -> return Result.Ok(v) + case Some(Error(v)) if v == id -> { + # If the socket we're waiting for produces an error then there's no + # point in waiting any longer, so we just move on. + pending -= 1 + break + } + case Some(_) -> { + # If a socket we tried to use previously produces an error we just + # ignore it and continue waiting for the current socket. + pending -= 1 + } + case _ -> { + # We waited long enough and so we need to move on to the next socket. + break + } + } + } + } + + # None of the sockets could connect within the initial timeout, but they + # might connect before our supplied deadline (if this hasn't already expired + # at this point). + while pending > 0 { + match cons.receive(timeout_after) { + case Some(Ok(v)) -> return Result.Ok(v) + case Some(_) -> pending -= 1 + case _ -> break + } + } + + Result.Error( + if timeout_after.remaining.positive? { + Error.ConnectionRefused + } else { + Error.TimedOut + }, + ) +} + +# A type for connecting a `TcpClient` asynchronously. +type async Connection { + # The ID of the current connection. + # + # This is used to determine when an error is produced what socket that error + # belongs to. + let @id: Int + + # The IP address to connect to. + let @ip: IpAddress + + # The port to connect to. + let @port: Int + + # The deadline after which we should give up. + let @deadline: Instant + + # The channel to send the results back to. + let @output: Channel[Result[TcpClient, Int]] + + # A flag indicating if we should continue trying to connect or if we should + # stop. + let @run: Bool + + fn static new( + id: Int, + ip: IpAddress, + port: Int, + deadline: Instant, + output: uni Channel[Result[TcpClient, Int]], + ) -> Connection { + Connection( + id: id, + ip: ip, + port: port, + deadline: deadline, + output: output, + run: true, + ) + } + + fn async mut cancel { + @run = false + } + + fn async connect { + if @run.false? { return } + + # To support cancellation we use an internal timeout. This way we don't just + # sit around for e.g. 60 seconds even though another socket connected + # successfully. + let interval = Duration.from_millis(TIMEOUT) + let deadline = min(interval.to_instant, @deadline) + let res = recover { + match TcpClient.connect(@ip, @port, deadline) { + case Ok(v) -> Result.Ok(v) + case Error(TimedOut) if @deadline.remaining.to_nanos > 0 -> { + # We finished one cycle but there's still time left, so let's try + # again until the user-provided deadline is also exceeded. + return connect + } + case Error(_) -> { + # We wan out of time or encountered a non-timeout error (e.g. the + # connection is refused). In this case we need to report back to the + # parent process such that it doesn't hang waiting for a result + # forever. + Result.Error(@id) + } + } + } + + @output.send(res) + } +} + +type inline Connections { + # The post to connect the IPs to. + let @port: Int + + # The deadline after which all attempts should time out. + let @timeout_after: Instant + + # The channel to use for communicating results back to the parent process. + let @channel: Channel[Result[TcpClient, Int]] + + # The processes used to establish connections + let @connections: Array[Connection] + + fn static new(port: Int, timeout_after: Instant) -> Connections { + Connections( + port: port, + timeout_after: timeout_after, + channel: Channel.new, + connections: [], + ) + } + + fn mut connect(ip: IpAddress) -> Int { + let id = @connections.size + let chan = recover @channel.clone + let proc = Connection.new(id, ip, @port, @timeout_after, chan) + + proc.connect + @connections.push(proc) + id + } + + fn mut receive(timeout_after: Instant) -> Option[Result[TcpClient, Int]] { + @channel.receive_until(timeout_after) + } +} + +impl Drop for Connections { + fn mut drop { + @connections.iter_mut.each(fn (c) { c.cancel }) + } +} diff --git a/std/src/std/net/socket.inko b/std/src/std/net/socket.inko index f9d4ee64..0a183897 100644 --- a/std/src/std/net/socket.inko +++ b/std/src/std/net/socket.inko @@ -65,11 +65,12 @@ import std.fmt (Format, Formatter) import std.fs.path (Path) import std.io (Error, Read, Write, WriteInternal) import std.libc +import std.net.happy import std.net.ip (IpAddress) import std.string (ToString) import std.sys.net import std.sys.unix.net (self as sys) if unix -import std.time (Duration, ToInstant) +import std.time (Duration, Instant, ToInstant) # The maximum value valid for a listen() call. # @@ -811,24 +812,29 @@ type pub TcpClient { Result.Ok(TcpClient(socket)) } - # Creates a new `TcpClient` that's connected to an IP address and port number. - # - # If multiple IP addresses are given, this method attempts to connect to them - # in order, returning upon the first successful connection. If no connection - # can be established, the error of the last attempt is returned. - # - # This method doesn't enforce a deadline on establishing the connection. If - # you need to limit the amount of time spent waiting to establish the - # connection, use `TcpClient.with_timeout` instead. + fn static connect( + ip: IpAddress, + port: Int, + timeout_after: Instant, + ) -> Result[TcpClient, Error] { + let socket = try Socket.stream(ip.v6?) + + socket.timeout_after = timeout_after + try socket.connect(ip, port) + socket.reset_deadline + from(socket) + } + + # Creates a new `TcpClient` that's connected to an IP address and port number, + # using a default timeout. # - # # Panics + # This method uses a default timeout of 5 seconds. If you wish to use a + # custom timeout/deadline, use `TcpClient.with_timeout` instead. # - # This method panics if `ips` is empty. + # For more details, refer to the documentation of `TcpClient.with_timeout`. # # # Examples # - # Connecting a `TcpClient`: - # # ```inko # import std.net.socket (TcpClient) # import std.net.ip (IpAddress) @@ -839,29 +845,45 @@ type pub TcpClient { ips: ref Array[IpAddress], port: Int, ) -> Result[TcpClient, Error] { - try_ips(ips, fn (ip) { - let socket = try Socket.stream(ip.v6?) - - try socket.connect(ip, port) - from(socket) - }) + with_timeout(ips, port, Duration.from_secs(5)) } # Creates a new `TcpClient` but limits the amount of time spent waiting for # the connection to be established. # + # The `timeout_after` argument specifies the deadline after which the + # `connect()` system call times out. This deadline is _not_ inherited by the + # returned `TcpClient`. + # + # # Connecting to multiple IP addresses + # # If multiple IP addresses are given, this method attempts to connect to them - # in order, returning upon the first successful connection. If no connection - # can be established, the error of the last attempt is returned. + # in accordance with [RFC 8305](https://datatracker.ietf.org/doc/html/rfc8305) + # (also known as "Happy Eyeballs version 2"), with the following differences: # - # The `timeout_after` argument specifies the deadline after which the - # `connect()` times out. The deadline is cleared once connected. + # - DNS requests are performed separately and thus not subject to the Happy + # Eyeballs algorithm. + # - We always interleave IPv6 and IPv4 addresses, starting with an IPv6 + # address (so `IPv6, IPv4, IPv6, IPv4, ...`). + # - There's no way to configure this behavior, nor is it planned to add the + # ability to do so. + # + # # Errors + # + # If the connection can't be established, a `std.io.Error` error is returned. # - # See `TcpClient.new` for more information. + # If `ips` contains multiple IP addresses and a connection can't be + # established to any of the addresses, one of the following errors is + # returned: # - # # Panics + # - `Error.ConnectionRefused` if no connection could be established before the + # deadline expired + # - `Error.TimedOut` if the deadline expired # - # This method panics if `ips` is empty. + # If `ips` is empty, an `Error.InvalidArgument` error is returned instead of + # producing a panic. This is to allow handling of cases where one passes the + # output of e.g. `std.net.dns.Resolver.resolve` directly to this method + # without checking if the DNS record actually contains any IP addresses. # # # Examples # @@ -883,14 +905,7 @@ type pub TcpClient { port: Int, timeout_after: ref T, ) -> Result[TcpClient, Error] { - try_ips(ips, fn (ip) { - let socket = try Socket.stream(ip.v6?) - - socket.timeout_after = timeout_after - try socket.connect(ip, port) - socket.reset_deadline - from(socket) - }) + happy.connect(ips, port, timeout_after.to_instant) } # Returns the local address of this socket. diff --git a/std/src/std/time.inko b/std/src/std/time.inko index 363f151a..aa33e236 100644 --- a/std/src/std/time.inko +++ b/std/src/std/time.inko @@ -7,7 +7,7 @@ import std.fmt (Format, Formatter) import std.int (ToInt) import std.locale (Locale) import std.locale.en (Locale as English) -import std.ops (Add, Multiply, Subtract) +import std.ops (Add, Divide, Multiply, Subtract) import std.string (Bytes) import std.sys.unix.time (self as sys) if unix @@ -238,6 +238,34 @@ type pub copy Duration { fn pub inline to_nanos -> Int { @nanos } + + # Returns `true` if `self` is greater than zero. + # + # # Examples + # + # ```inko + # import std.time (Duration) + # + # Duration.from_secs(1).positive? # => true + # Duration.from_secs(0).positive? # => false + # ``` + fn pub inline positive? -> Bool { + @nanos > 0 + } + + # Returns `true` if `self` is zero. + # + # # Examples + # + # ```inko + # import std.time (Duration) + # + # Duration.from_secs(1).zero? # => false + # Duration.from_secs(0).zero? # => true + # ``` + fn pub inline zero? -> Bool { + @nanos == 0 + } } impl ToInstant for Duration { @@ -270,31 +298,37 @@ impl Clone[Duration] for Duration { } impl Add[Duration, Duration] for Duration { - fn pub inline +(other: ref Duration) -> Duration { + fn pub inline +(other: Duration) -> Duration { Duration(@nanos + other.nanos) } } impl Subtract[Duration, Duration] for Duration { - fn pub inline -(other: ref Duration) -> Duration { + fn pub inline -(other: Duration) -> Duration { Duration(@nanos - other.nanos) } } impl Multiply[Int, Duration] for Duration { - fn pub inline *(other: ref Int) -> Duration { + fn pub inline *(other: Int) -> Duration { Duration(@nanos * other) } } +impl Divide[Int, Duration] for Duration { + fn pub inline /(other: Int) -> Duration { + Duration(@nanos / other) + } +} + impl Compare[Duration] for Duration { - fn pub inline cmp(other: ref Duration) -> Ordering { + fn pub inline cmp(other: Duration) -> Ordering { @nanos.cmp(other.nanos) } } -impl Equal[ref Duration] for Duration { - fn pub inline ==(other: ref Duration) -> Bool { +impl Equal[Duration] for Duration { + fn pub inline ==(other: Duration) -> Bool { @nanos == other.nanos } } @@ -1325,7 +1359,7 @@ impl Add[Duration, DateTime] for DateTime { # # This method may panic if the result can't be expressed as a `DateTime` (e.g. # the year is too great). - fn pub +(other: ref Duration) -> DateTime { + fn pub +(other: Duration) -> DateTime { let timestamp = to_float + other.to_secs DateTime.from_timestamp(timestamp, utc_offset: @utc_offset).get @@ -1340,7 +1374,7 @@ impl Subtract[Duration, DateTime] for DateTime { # # This method may panic if the result can't be expressed as a `DateTime` (e.g. # the year is too great). - fn pub -(other: ref Duration) -> DateTime { + fn pub -(other: Duration) -> DateTime { let timestamp = to_float - other.to_secs DateTime.from_timestamp(timestamp, utc_offset: @utc_offset).get @@ -1459,7 +1493,7 @@ impl ToFloat for Instant { } impl Add[Duration, Instant] for Instant { - fn pub inline +(other: ref Duration) -> Instant { + fn pub inline +(other: Duration) -> Instant { let nanos = @nanos + other.nanos if nanos < 0 { negative_time_error(nanos) } @@ -1469,7 +1503,7 @@ impl Add[Duration, Instant] for Instant { } impl Subtract[Duration, Instant] for Instant { - fn pub inline -(other: ref Duration) -> Instant { + fn pub inline -(other: Duration) -> Instant { let nanos = @nanos - other.nanos if nanos < 0 { negative_time_error(nanos) } diff --git a/std/test/std/test_time.inko b/std/test/std/test_time.inko index 54a249a6..f10ccbb9 100644 --- a/std/test/std/test_time.inko +++ b/std/test/std/test_time.inko @@ -76,6 +76,16 @@ fn pub tests(t: mut Tests) { t.equal(Duration.from_secs(1).clone, Duration.from_secs(1)) }) + t.test('Duration.positive?', fn (t) { + t.true(Duration.from_secs(1).positive?) + t.false(Duration.from_secs(0).positive?) + }) + + t.test('Duration.zero?', fn (t) { + t.false(Duration.from_secs(1).zero?) + t.true(Duration.from_secs(0).zero?) + }) + t.test('Duration.+', fn (t) { t.equal( Duration.from_secs(1) + Duration.from_secs(1), @@ -116,6 +126,14 @@ fn pub tests(t: mut Tests) { Duration.from_secs(1) * -9_223_372_036_854_775_808 }) + t.test('Duration./', fn (t) { + t.equal(Duration.from_secs(2) / 2, Duration.from_secs(1)) + }) + + t.panic('Duration./ with an argument that overflows', fn { + Duration.from_secs(2) / 0 + }) + t.test('Duration.cmp', fn (t) { let a = Duration.from_secs(1) let b = Duration.from_secs(2)