diff --git a/src/ssl.jl b/src/ssl.jl index 0b31965..4f8331b 100644 --- a/src/ssl.jl +++ b/src/ssl.jl @@ -24,6 +24,7 @@ Base.show(io::IO, c::SSLConfig) = print(io, "MbedTLS.SSLConfig()") mutable struct SSLContext <: IO data::Ptr{Cvoid} datalock::ReentrantLock + nonblocking::Bool config::SSLConfig isopen::Bool bio @@ -32,6 +33,7 @@ mutable struct SSLContext <: IO ctx = new() ctx.data = Libc.malloc(1000) # 488 ctx.datalock = ReentrantLock() + ctx.nonblocking = false ccall((:mbedtls_ssl_init, MBED_TLS), Cvoid, (Ptr{Cvoid},), ctx.data) @compat finalizer(ctx->begin ccall((:mbedtls_ssl_free, MBED_TLS), Cvoid, (Ptr{Cvoid},), ctx.data) @@ -112,19 +114,23 @@ end function f_send(c_ctx, c_msg, sz) jl_ctx = unsafe_pointer_to_objref(c_ctx) jl_msg = unsafe_wrap(Array, c_msg, sz) - return Cint(write(jl_ctx, jl_msg)) + return Cint(write(jl_ctx.bio, jl_msg)) end function f_recv(c_ctx, c_msg, sz) jl_ctx = unsafe_pointer_to_objref(c_ctx) jl_msg = unsafe_wrap(Array, c_msg, sz) - n = readbytes!(jl_ctx, jl_msg, sz) + if jl_ctx.nonblocking && nb_available(jl_ctx.bio) == 0 + n = 0 + else + n = readbytes!(jl_ctx.bio, jl_msg, sz) + end return Cint(n) end function set_bio!(ssl_ctx::SSLContext, jl_ctx::T) where {T<:IO} ssl_ctx.bio = jl_ctx - set_bio!(ssl_ctx, pointer_from_objref(jl_ctx), c_send[], c_recv[]) + set_bio!(ssl_ctx, pointer_from_objref(ssl_ctx), c_send[], c_recv[]) nothing end @@ -183,37 +189,37 @@ end import Base: unsafe_read, unsafe_write function Base.unsafe_write(ctx::SSLContext, msg::Ptr{UInt8}, N::UInt) - @lockdata ctx begin - nw = 0 - while nw < N - ret = ccall((:mbedtls_ssl_write, MBED_TLS), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t), - ctx.data, msg, N - nw) - ret < 0 && mbed_err(ret) - nw += ret - msg += ret + nw = 0 + while nw < N + ret = @lockdata ctx begin + ccall((:mbedtls_ssl_write, MBED_TLS), Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t), + ctx.data, msg, N - nw) end - return Int(nw) + ret < 0 && mbed_err(ret) + nw += ret + msg += ret end + return Int(nw) end Base.write(ctx::SSLContext, msg::UInt8) = write(ctx, Ref(msg)) function Base.unsafe_read(ctx::SSLContext, buf::Ptr{UInt8}, nbytes::UInt; err=true) - @lockdata ctx begin - nread::UInt = 0 - while nread < nbytes - n = ccall((:mbedtls_ssl_read, MBED_TLS), Cint, - (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t), - ctx.data, buf + nread, nbytes - nread) - if n == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || n == 0 - ctx.isopen = false - err ? throw(EOFError()) : return nread - end - if n != MBEDTLS_ERR_SSL_WANT_READ - n < 0 && mbed_err(n) - nread += n - end + nread::UInt = 0 + while nread < nbytes + n = @lockdata ctx begin + ccall((:mbedtls_ssl_read, MBED_TLS), Cint, + (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t), + ctx.data, buf + nread, nbytes - nread) + end + if n == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || n == 0 + ctx.isopen = false + err ? throw(EOFError()) : return nread + end + if n != MBEDTLS_ERR_SSL_WANT_READ + n < 0 && mbed_err(n) + nread += n end end end @@ -278,8 +284,13 @@ function Base.nb_available(ctx::SSLContext) @lockdata ctx begin # First try to read from the socket and decrypt incoming data if # possible. https://esp32.com/viewtopic.php?t=1101#p4884 - ccall((:mbedtls_ssl_read, MBED_TLS), - Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t), ctx.data, C_NULL, 0) + ctx.nonblocking = true + try + ccall((:mbedtls_ssl_read, MBED_TLS), + Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t), ctx.data, C_NULL, 0) + finally + ctx.nonblocking = false + end n = ccall((:mbedtls_ssl_get_bytes_avail, MBED_TLS), Csize_t, (Ptr{Cvoid},), ctx.data) return Int(n)