Skip to content

Commit

Permalink
Merge pull request #177 from JuliaWeb/so/issue_174_step3
Browse files Browse the repository at this point in the history
Resolve #174 - Step 3. seperate implementation/interface for read/write
  • Loading branch information
quinnj authored Oct 12, 2018
2 parents 3d1693d + 4e59a65 commit d92baca
Showing 1 changed file with 139 additions and 32 deletions.
171 changes: 139 additions & 32 deletions src/ssl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,21 @@ end
# Handshake

function handshake(ctx::SSLContext)

ctx.isopen && throw(ArgumentError("handshake() already done!"))

while true
n = ssl_handshake(ctx)
if n == 0
break
elseif n == MBEDTLS_ERR_SSL_WANT_READ
if eof(ctx.bio)
throw(EOFError())
end
else
ssl_abandon(ctx)
throw(MbedException(n))
end
if n != MBEDTLS_ERR_SSL_WANT_READ
mbed_err(n)
end
wait(ctx)
end
ctx.isopen = true

Expand Down Expand Up @@ -101,8 +107,40 @@ function handshake(ctx::SSLContext)
end


# Fatal Errors

"""
The documentation for `ssl_read`, `ssl_write` and `ssl_close_notify` all say:
> If this function returns something other than 0 or
> MBEDTLS_ERR_SSL_WANT_READ/WRITE, you must stop using the SSL context
> for reading or writing, and either free it or call
This function ensures that the `SSLContext` is won't be used again.
"""
function ssl_abandon(ctx::SSLContext)
ctx.isopen = false
close(ctx.bio)
end


# Base ::IO Connection State Methods

"""
True unless:
- TLS `close_notify` was received, or
- the peer closed the connection (and the TLS buffer is empty), or
- an un-handled exception occurred while reading.
"""
Base.isreadable(ctx::SSLContext) = true

"""
True unless:
- `close(::SSLContext)` is called, or
- the peer closed the connection.
"""
Base.iswritable(ctx::SSLContext) = ctx.isopen && isopen(ctx.bio)

"""
Same as `iswritable(ctx)`.
> "...a closed stream may still have data to read in its buffer,
Expand Down Expand Up @@ -169,15 +207,30 @@ all the data. The TLS library encrypts the data and passes it to the `f_send`
function which sends it to the underlying connection (`ctx.bio`).
See `f_send` and `set_bio!` below.
"""
function Base.unsafe_write(ctx::SSLContext, msg::Ptr{UInt8}, N::UInt)
nw = 0
while nw < N
ret = ssl_write(ctx, msg, N - nw)
ret < 0 && mbed_err(ret)
nw += ret
msg += ret
function ssl_unsafe_write(ctx::SSLContext, buf::Ptr{UInt8}, nbytes::UInt)

iswritable(ctx) ||
throw(ArgumentError("`unsafe_write` requires `iswritable(::SSLContext)`"))

nwritten = 0
while nwritten < nbytes
n = ssl_write(ctx, buf + nwritten, nbytes - nwritten)
if n == MBEDTLS_ERR_SSL_WANT_READ || n == MBEDTLS_ERR_SSL_WANT_WRITE
@assert false "Should not get to here because `f_send` " *
"never returns ...WANT_READ/WRITE."
yield()
continue
elseif n == MBEDTLS_ERR_NET_CONN_RESET
ssl_abandon(ctx)
Base.check_open(ctx.bio)
@assert false
elseif n < 0
ssl_abandon(ctx)
throw(MbedException(n))
end
nwritten += n
end
return Int(nw)
return Int(nwritten)
end


Expand All @@ -204,24 +257,48 @@ end

# Receiving Data

function Base.unsafe_read(ctx::SSLContext, buf::Ptr{UInt8}, nbytes::UInt; err=true)
function ssl_unsafe_read(ctx::SSLContext, buf::Ptr{UInt8}, nbytes::UInt)

isreadable(ctx) ||
throw(ArgumentError("`ssl_unsafe_read` requires `isreadable(::SSLContext)`"))

nread::UInt = 0
while nread < nbytes
n = ssl_read(ctx, buf + nread, nbytes - nread)
if n == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || n == 0
close(ctx)
err ? throw(EOFError()) : return nread
elseif n == MBEDTLS_ERR_SSL_WANT_READ
wait(ctx)
elseif n < 0
mbed_err(n)
else
try
while true

n = ssl_read(ctx, buf + nread, nbytes - nread)

if n == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY ||
n == MBEDTLS_ERR_NET_CONN_RESET
ssl_abandon(ctx)
@assert ssl_get_bytes_avail(ctx) == 0 #FIXME remove this
@assert ssl_check_pending(ctx) == false #FIXME remove this
return nread
elseif n == MBEDTLS_ERR_SSL_WANT_READ
@assert ssl_get_bytes_avail(ctx) == 0 #FIXME remove this
return nread
elseif n < 0
ssl_abandon(ctx)
throw(MbedException(n))
end

nread += n
@assert nread <= nbytes

if nread == nbytes
return nread
end
end
catch e
ssl_abandon(ctx)
rethrow(e)
end

@assert false "unreachable"
end



# Receiving Encrypted Data

"""
Expand All @@ -246,27 +323,52 @@ end

# Base ::IO Write Methods

Base.unsafe_write(ctx::SSLContext, msg::Ptr{UInt8}, N::UInt) =
ssl_unsafe_write(ctx, msg, N)

Base.write(ctx::SSLContext, msg::UInt8) = write(ctx, Ref(msg))


# Base ::IO Read Methods

"""
Copy `nbytes` of decrypted data from `ctx` into `buf`.
Wait for sufficient decrypted data to be available.
Throw `EOFError` if the peer sends TLS `close_notify` or closes the
connection before `nbytes` have been copied.
"""
function Base.unsafe_read(ctx::SSLContext, buf::Ptr{UInt8}, nbytes::UInt)
nread = 0
while nread < nbytes
if eof(ctx)
throw(EOFError())
end
nread += ssl_unsafe_read(ctx, buf + nread, nbytes - nread)
end
nothing
end

"""
Copy at most `nbytes` of decrypted data from `ctx` into `buf`.
If `all=true`: wait for sufficient decrypted data to be available.
Less than `nbytes` may be copied if the peer sends TLS `close_notify` or closes
the connection.
Returns number of bytes copied into `buf` (`<= nbytes`).
"""
Base.readbytes!(ctx::SSLContext, buf::Vector{UInt8}, nbytes=length(buf)) = readbytes!(ctx, buf, UInt(nbytes))
function Base.readbytes!(ctx::SSLContext, buf::Vector{UInt8}, nbytes::UInt)
nr = unsafe_read(ctx, pointer(buf), nbytes; err=false)
if nr !== nothing
resize!(buf, nr::UInt)
else
nr = nbytes
Base.readbytes!(ctx::SSLContext, buf::Vector{UInt8}, nbytes=length(buf); kw...) =
readbytes!(ctx, buf, UInt(nbytes); kw...)

function Base.readbytes!(ctx::SSLContext, buf::Vector{UInt8}, nbytes::UInt;
all::Bool=true)
nbytes <= length(buf) || throw(ArgumentError("`buf` too small!"))
nread = 0
while nread < nbytes
nread += ssl_unsafe_read(ctx, pointer(buf) + nread, nbytes - nread)
if !all || eof(ctx)
break
end
end
return Int(nr::UInt)
return nread
end

"""
Expand All @@ -276,7 +378,12 @@ but don't wait for more data to arrive.
The amount of decrypted data that can be read at once is limited by
`MBEDTLS_SSL_MAX_CONTENT_LEN`.
"""
Base.readavailable(ctx::SSLContext) = read(ctx, bytesavailable(ctx))
function Base.readavailable(ctx::SSLContext)
n = UInt(MBEDTLS_SSL_MAX_CONTENT_LEN)
buf = Vector{UInt8}(#=undef,=# n)
n = ssl_unsafe_read(ctx, pointer(buf), n)
return resize!(buf, n)
end


# Configuration
Expand Down

0 comments on commit d92baca

Please sign in to comment.