From c24421132c00c14db0356b49e898b65e9a3e5961 Mon Sep 17 00:00:00 2001 From: James Clarke Date: Sun, 6 Nov 2022 20:58:58 +0000 Subject: [PATCH] Fix `OpenSSL::SSL::Context::Client#alpn_protocol=` --- spec/std/openssl/ssl/socket_spec.cr | 20 +++++++++++++ src/openssl/lib_ssl.cr | 1 + src/openssl/ssl/context.cr | 44 +++++++++++++++++------------ 3 files changed, 47 insertions(+), 18 deletions(-) diff --git a/spec/std/openssl/ssl/socket_spec.cr b/spec/std/openssl/ssl/socket_spec.cr index f4265cd0a1b8..ccf1046b6949 100644 --- a/spec/std/openssl/ssl/socket_spec.cr +++ b/spec/std/openssl/ssl/socket_spec.cr @@ -93,6 +93,26 @@ describe OpenSSL::SSL::Socket do ) end + it "returns selected alpn protocol" do + tcp_server = TCPServer.new("127.0.0.1", 0) + server_context, client_context = ssl_context_pair + + server_context.alpn_protocol = "h2" + client_context.alpn_protocol = "h2" + + OpenSSL::SSL::Server.open(tcp_server, server_context) do |server| + spawn do + Client.open(TCPSocket.new(tcp_server.local_address.address, tcp_server.local_address.port), client_context, hostname: "example.com") do |socket| + socket.alpn_protocol.should eq("h2") + end + end + + client = server.accept + client.alpn_protocol.should eq("h2") + client.close + end + end + it "accepts clients that only write then close the connection" do tcp_server = TCPServer.new("127.0.0.1", 0) server_context, client_context = ssl_context_pair diff --git a/src/openssl/lib_ssl.cr b/src/openssl/lib_ssl.cr index 2e781ea09919..e6faefa52d8e 100644 --- a/src/openssl/lib_ssl.cr +++ b/src/openssl/lib_ssl.cr @@ -261,6 +261,7 @@ lib LibSSL fun ssl_get0_alpn_selected = SSL_get0_alpn_selected(handle : SSL, data : Char**, len : LibC::UInt*) : Void fun ssl_ctx_set_alpn_select_cb = SSL_CTX_set_alpn_select_cb(ctx : SSLContext, cb : ALPNCallback, arg : Void*) : Void + fun ssl_ctx_set_alpn_protos = SSL_CTX_set_alpn_protos(ctx : SSLContext, protos : Char*, protos_len : Int) : Int {% end %} {% if compare_versions(OPENSSL_VERSION, "1.0.2") >= 0 %} diff --git a/src/openssl/ssl/context.cr b/src/openssl/ssl/context.cr index 0bf32ae7f959..37c54ca77eb2 100644 --- a/src/openssl/ssl/context.cr +++ b/src/openssl/ssl/context.cr @@ -98,6 +98,14 @@ abstract class OpenSSL::SSL::Context end }, hostname.as(Void*)) end + + private def alpn_protocol=(protocol : Bytes) + {% if LibSSL.has_method?(:ssl_ctx_set_alpn_protos) %} + LibSSL.ssl_ctx_set_alpn_protos(@handle, protocol, protocol.size) + {% else %} + raise NotImplementedError.new("LibSSL.ssl_ctx_set_alpn_protos") + {% end %} + end end class Server < Context @@ -173,6 +181,24 @@ abstract class OpenSSL::SSL::Context raise OpenSSL::Error.new("SSL_CTX_set_num_tickets") if ret != 1 {% end %} end + + private def alpn_protocol=(protocol : Bytes) + {% if LibSSL.has_method?(:ssl_ctx_set_alpn_select_cb) %} + alpn_cb = ->(ssl : LibSSL::SSL, o : LibC::Char**, olen : LibC::Char*, i : LibC::Char*, ilen : LibC::Int, data : Void*) { + proto = Box(Bytes).unbox(data) + ret = LibSSL.ssl_select_next_proto(o, olen, proto, 2, i, ilen) + if ret != LibSSL::OPENSSL_NPN_NEGOTIATED + LibSSL::SSL_TLSEXT_ERR_NOACK + else + LibSSL::SSL_TLSEXT_ERR_OK + end + } + @alpn_protocol = alpn_protocol = Box.box(protocol) + LibSSL.ssl_ctx_set_alpn_select_cb(@handle, alpn_cb, alpn_protocol) + {% else %} + raise NotImplementedError.new("LibSSL.ssl_ctx_set_alpn_select_cb") + {% end %} + end end protected def initialize(method : LibSSL::SSLMethod) @@ -429,24 +455,6 @@ abstract class OpenSSL::SSL::Context self.alpn_protocol = proto end - private def alpn_protocol=(protocol : Bytes) - {% if LibSSL.has_method?(:ssl_ctx_set_alpn_select_cb) %} - alpn_cb = ->(ssl : LibSSL::SSL, o : LibC::Char**, olen : LibC::Char*, i : LibC::Char*, ilen : LibC::Int, data : Void*) { - proto = Box(Bytes).unbox(data) - ret = LibSSL.ssl_select_next_proto(o, olen, proto, 2, i, ilen) - if ret != LibSSL::OPENSSL_NPN_NEGOTIATED - LibSSL::SSL_TLSEXT_ERR_NOACK - else - LibSSL::SSL_TLSEXT_ERR_OK - end - } - @alpn_protocol = alpn_protocol = Box.box(protocol) - LibSSL.ssl_ctx_set_alpn_select_cb(@handle, alpn_cb, alpn_protocol) - {% else %} - raise NotImplementedError.new("LibSSL.ssl_ctx_set_alpn_select_cb") - {% end %} - end - # Sets this context verify param to the default one of the given name. # # Depending on the OpenSSL version, the available defaults are