diff --git a/src/ssl.jl b/src/ssl.jl index f8d845e..8f67e72 100644 --- a/src/ssl.jl +++ b/src/ssl.jl @@ -24,7 +24,6 @@ Base.show(io::IO, c::SSLConfig) = print(io, "MbedTLS.SSLConfig()") mutable struct SSLContext <: IO data::Ptr{Cvoid} datalock::ReentrantLock - nonblocking::Bool config::SSLConfig isopen::Bool bio @@ -33,7 +32,6 @@ mutable struct SSLContext <: IO ctx = new() ctx.data = Libc.malloc(1000) # 488 ctx.datalock = ReentrantLock() - ctx.nonblocking = false ccall((:mbedtls_ssl_init, libmbedtls), Cvoid, (Ptr{Cvoid},), ctx.data) @compat finalizer(ctx->begin ccall((:mbedtls_ssl_free, libmbedtls), Cvoid, (Ptr{Cvoid},), ctx.data) @@ -113,24 +111,23 @@ end function f_send(c_ctx, c_msg, sz) jl_ctx = unsafe_pointer_to_objref(c_ctx) - jl_msg = unsafe_wrap(Array, c_msg, sz) - return Cint(write(jl_ctx.bio, jl_msg)) + return Cint(unsafe_write(jl_ctx, c_msg, sz)) end function f_recv(c_ctx, c_msg, sz) jl_ctx = unsafe_pointer_to_objref(c_ctx) - jl_msg = unsafe_wrap(Array, c_msg, sz) - if jl_ctx.nonblocking && nb_available(jl_ctx.bio) == 0 - n = 0 - else - n = readbytes!(jl_ctx.bio, jl_msg, sz) + n = nb_available(jl_ctx) + if n == 0 + return Cint(MBEDTLS_ERR_SSL_WANT_READ) end + n = min(sz, n) + unsafe_read(jl_ctx, c_msg, n) return Cint(n) end function set_bio!(ssl_ctx::SSLContext, jl_ctx::T) where {T<:IO} ssl_ctx.bio = jl_ctx - set_bio!(ssl_ctx, pointer_from_objref(ssl_ctx), c_send[], c_recv[]) + set_bio!(ssl_ctx, pointer_from_objref(ssl_ctx.bio), c_send[], c_recv[]) nothing end @@ -164,13 +161,25 @@ function set_dbg_level(level) nothing end +Base.wait(ctx::SSLContext) = (eof(ctx.bio); nothing) + # eof blocks if the receive buffer is empty + function handshake(ctx::SSLContext) - @lockdata ctx begin - @err_check ccall((:mbedtls_ssl_handshake, libmbedtls), Cint, - (Ptr{Cvoid},), ctx.data) - ctx.isopen = true + while true + n = @lockdata ctx begin + ccall((:mbedtls_ssl_handshake, libmbedtls), Cint, + (Ptr{Cvoid},), ctx.data) + end + if n == 0 + break + end + if n != MBEDTLS_ERR_SSL_WANT_READ + mbed_err(n) + end + wait(ctx) end - nothing + ctx.isopen = true + return end function set_alpn!(conf::SSLConfig, protos) @@ -216,9 +225,11 @@ function Base.unsafe_read(ctx::SSLContext, buf::Ptr{UInt8}, nbytes::UInt; err=tr if n == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || n == 0 ctx.isopen = false err ? throw(EOFError()) : return nread - end - if n != MBEDTLS_ERR_SSL_WANT_READ - n < 0 && mbed_err(n) + elseif n == MBEDTLS_ERR_SSL_WANT_READ + wait(ctx) + elseif n < 0 + mbed_err(n) + else nread += n end end @@ -281,19 +292,21 @@ function get_ciphersuite(ctx::SSLContext) end function Base.nb_available(ctx::SSLContext) + @lockdata ctx begin - # First try to read from the socket and decrypt incoming data if - # possible. https://esp32.com/viewtopic.php?t=1101#p4884 - ctx.nonblocking = true - try - ccall((:mbedtls_ssl_read, libmbedtls), - Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t), ctx.data, C_NULL, 0) - finally - ctx.nonblocking = false - end - n = ccall((:mbedtls_ssl_get_bytes_avail, libmbedtls), - Csize_t, (Ptr{Cvoid},), ctx.data) - return Int(n) + + # First do a zero-byte read. + # This causes MbedTLS to call f_recv (which is always non-blocking) + # and decrypt any bytes that are already in the LibuvStream read buffer. + # https://esp32.com/viewtopic.php?t=1101#p4884 + ccall((:mbedtls_ssl_read, libmbedtls), Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t), + ctx.data, C_NULL, 0) + + # Now that the bufferd bytes have been processed, find out how many + # decrypted bytes are available. + return Int(ccall((:mbedtls_ssl_get_bytes_avail, libmbedtls), + Csize_t, (Ptr{Cvoid},), ctx.data)) end end