Skip to content

Commit

Permalink
Replace LibC.ntohs and htons with native code (#13027)
Browse files Browse the repository at this point in the history
  • Loading branch information
HertzDevil authored Jan 30, 2023
1 parent f33b5fc commit bc01989
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 24 deletions.
26 changes: 24 additions & 2 deletions spec/std/socket/address_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,19 @@ describe Socket::Address do
end

describe Socket::IPAddress do
c_port = {% if IO::ByteFormat::NetworkEndian != IO::ByteFormat::SystemEndian %}
36895 # 0x901F
{% else %}
8080 # 0x1F90
{% end %}

it "transforms an IPv4 address into a C struct and back" do
addr1 = Socket::IPAddress.new("127.0.0.1", 8080)
addr2 = Socket::IPAddress.from(addr1.to_unsafe, addr1.size)

addr1_c = addr1.to_unsafe
addr1_c.as(LibC::SockaddrIn*).value.sin_port.should eq(c_port)

addr2 = Socket::IPAddress.from(addr1_c, addr1.size)
addr2.family.should eq(addr1.family)
addr2.port.should eq(addr1.port)
typeof(addr2.address).should eq(String)
Expand All @@ -45,8 +54,11 @@ describe Socket::IPAddress do

it "transforms an IPv6 address into a C struct and back" do
addr1 = Socket::IPAddress.new("2001:db8:8714:3a90::12", 8080)
addr2 = Socket::IPAddress.from(addr1.to_unsafe, addr1.size)

addr1_c = addr1.to_unsafe
addr1_c.as(LibC::SockaddrIn6*).value.sin6_port.should eq(c_port)

addr2 = Socket::IPAddress.from(addr1_c, addr1.size)
addr2.family.should eq(addr1.family)
addr2.port.should eq(addr1.port)
typeof(addr2.address).should eq(String)
Expand All @@ -59,6 +71,16 @@ describe Socket::IPAddress do
end
end

it "errors on out of range port numbers" do
expect_raises(Socket::Error, /Invalid port number/) do
Socket::IPAddress.new("localhost", -1)
end

expect_raises(Socket::Error, /Invalid port number/) do
Socket::IPAddress.new("localhost", 65536)
end
end

it "to_s" do
Socket::IPAddress.new("127.0.0.1", 80).to_s.should eq("127.0.0.1:80")
Socket::IPAddress.new("2001:db8:8714:3a90::12", 443).to_s.should eq("[2001:db8:8714:3a90::12]:443")
Expand Down
36 changes: 14 additions & 22 deletions src/socket/address.cr
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class Socket
@addr : LibC::In6Addr | LibC::InAddr

def initialize(@address : String, @port : Int32)
raise Error.new("Invalid port number: #{port}") unless 0 <= port <= UInt16::MAX

if addr = IPAddress.address_v6?(address)
@addr = addr
@family = Family::INET6
Expand Down Expand Up @@ -145,23 +147,13 @@ class Socket
protected def initialize(sockaddr : LibC::SockaddrIn6*, @size)
@family = Family::INET6
@addr = sockaddr.value.sin6_addr
@port =
{% if flag?(:dragonfly) %}
sockaddr.value.sin6_port.byte_swap.to_i
{% else %}
LibC.ntohs(sockaddr.value.sin6_port).to_i
{% end %}
@port = endian_swap(sockaddr.value.sin6_port).to_i
end

protected def initialize(sockaddr : LibC::SockaddrIn*, @size)
@family = Family::INET
@addr = sockaddr.value.sin_addr
@port =
{% if flag?(:dragonfly) %}
sockaddr.value.sin_port.byte_swap.to_i
{% else %}
LibC.ntohs(sockaddr.value.sin_port).to_i
{% end %}
@port = endian_swap(sockaddr.value.sin_port).to_i
end

# Returns `true` if *address* is a valid IPv4 or IPv6 address.
Expand Down Expand Up @@ -308,27 +300,27 @@ class Socket
private def to_sockaddr_in6(addr)
sockaddr = Pointer(LibC::SockaddrIn6).malloc
sockaddr.value.sin6_family = family
{% if flag?(:dragonfly) %}
sockaddr.value.sin6_port = port.byte_swap
{% else %}
sockaddr.value.sin6_port = LibC.htons(port)
{% end %}
sockaddr.value.sin6_port = endian_swap(port.to_u16!)
sockaddr.value.sin6_addr = addr
sockaddr.as(LibC::Sockaddr*)
end

private def to_sockaddr_in(addr)
sockaddr = Pointer(LibC::SockaddrIn).malloc
sockaddr.value.sin_family = family
{% if flag?(:dragonfly) %}
sockaddr.value.sin_port = port.byte_swap
{% else %}
sockaddr.value.sin_port = LibC.htons(port)
{% end %}
sockaddr.value.sin_port = endian_swap(port.to_u16!)
sockaddr.value.sin_addr = addr
sockaddr.as(LibC::Sockaddr*)
end

private def endian_swap(x : UInt16) : UInt16
{% if IO::ByteFormat::NetworkEndian != IO::ByteFormat::SystemEndian %}
x.byte_swap
{% else %}
x
{% end %}
end

# Returns `true` if *port* is a valid port number.
#
# Valid port numbers are in the range `0..65_535`.
Expand Down

0 comments on commit bc01989

Please sign in to comment.