diff --git a/openssl-sys/src/handwritten/ssl.rs b/openssl-sys/src/handwritten/ssl.rs index 944a476618..cdcdea5881 100644 --- a/openssl-sys/src/handwritten/ssl.rs +++ b/openssl-sys/src/handwritten/ssl.rs @@ -640,7 +640,13 @@ extern "C" { pub fn SSL_stateless(s: *mut SSL) -> c_int; pub fn SSL_connect(ssl: *mut SSL) -> c_int; pub fn SSL_read(ssl: *mut SSL, buf: *mut c_void, num: c_int) -> c_int; + #[cfg(any(ossl111, libressl350))] + pub fn SSL_read_ex(ssl: *mut SSL, buf: *mut c_void, num: usize, readbytes: *mut usize) + -> c_int; pub fn SSL_peek(ssl: *mut SSL, buf: *mut c_void, num: c_int) -> c_int; + #[cfg(any(ossl111, libressl350))] + pub fn SSL_peek_ex(ssl: *mut SSL, buf: *mut c_void, num: usize, readbytes: *mut usize) + -> c_int; #[cfg(any(ossl111, libressl340))] pub fn SSL_read_early_data( s: *mut SSL, @@ -661,6 +667,13 @@ extern "C" { extern "C" { pub fn SSL_write(ssl: *mut SSL, buf: *const c_void, num: c_int) -> c_int; + #[cfg(any(ossl111, libressl350))] + pub fn SSL_write_ex( + ssl: *mut SSL, + buf: *const c_void, + num: size_t, + written: *mut size_t, + ) -> c_int; #[cfg(any(ossl111, libressl340))] pub fn SSL_write_early_data( s: *mut SSL, diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index fb38bb3e4a..fe1e38649f 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -90,14 +90,13 @@ use libc::{c_char, c_int, c_long, c_uchar, c_uint, c_void}; use once_cell::sync::{Lazy, OnceCell}; use openssl_macros::corresponds; use std::any::TypeId; -use std::cmp; use std::collections::HashMap; use std::ffi::{CStr, CString}; use std::fmt; use std::io; use std::io::prelude::*; use std::marker::PhantomData; -use std::mem::{self, ManuallyDrop}; +use std::mem::{self, ManuallyDrop, MaybeUninit}; use std::ops::{Deref, DerefMut}; use std::panic::resume_unwind; use std::path::Path; @@ -2367,21 +2366,6 @@ impl SslRef { unsafe { ffi::SSL_get_rbio(self.as_ptr()) } } - fn read(&mut self, buf: &mut [u8]) -> c_int { - let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int; - unsafe { ffi::SSL_read(self.as_ptr(), buf.as_ptr() as *mut c_void, len) } - } - - fn peek(&mut self, buf: &mut [u8]) -> c_int { - let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int; - unsafe { ffi::SSL_peek(self.as_ptr(), buf.as_ptr() as *mut c_void, len) } - } - - fn write(&mut self, buf: &[u8]) -> c_int { - let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int; - unsafe { ffi::SSL_write(self.as_ptr(), buf.as_ptr() as *const c_void, len) } - } - fn get_error(&self, ret: c_int) -> ErrorCode { unsafe { ErrorCode::from_raw(ffi::SSL_get_error(self.as_ptr(), ret)) } } @@ -3750,26 +3734,86 @@ impl SslStream { } } + /// Like `read`, but takes a possibly-uninitialized slice. + /// + /// # Safety + /// + /// No portion of `buf` will be de-initialized by this method. If the method returns `Ok(n)`, + /// then the first `n` bytes of `buf` are guaranteed to be initialized. + #[corresponds(SSL_read_ex)] + pub fn read_uninit(&mut self, buf: &mut [MaybeUninit]) -> io::Result { + loop { + match self.ssl_read_uninit(buf) { + Ok(n) => return Ok(n), + Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => return Ok(0), + Err(ref e) if e.code() == ErrorCode::SYSCALL && e.io_error().is_none() => { + return Ok(0); + } + Err(ref e) if e.code() == ErrorCode::WANT_READ && e.io_error().is_none() => {} + Err(e) => { + return Err(e + .into_io_error() + .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))); + } + } + } + } + /// Like `read`, but returns an `ssl::Error` rather than an `io::Error`. /// /// It is particularly useful with a non-blocking socket, where the error value will identify if /// OpenSSL is waiting on read or write readiness. - #[corresponds(SSL_read)] + #[corresponds(SSL_read_ex)] pub fn ssl_read(&mut self, buf: &mut [u8]) -> Result { - // The interpretation of the return code here is a little odd with a - // zero-length write. OpenSSL will likely correctly report back to us - // that it read zero bytes, but zero is also the sentinel for "error". - // To avoid that confusion short-circuit that logic and return quickly - // if `buf` has a length of zero. - if buf.is_empty() { - return Ok(0); + // SAFETY: `ssl_read_uninit` does not de-initialize the buffer. + unsafe { + self.ssl_read_uninit(slice::from_raw_parts_mut( + buf.as_mut_ptr().cast::>(), + buf.len(), + )) } + } - let ret = self.ssl.read(buf); - if ret > 0 { - Ok(ret as usize) - } else { - Err(self.make_error(ret)) + /// Like `read_ssl`, but takes a possibly-uninitialized slice. + /// + /// # Safety + /// + /// No portion of `buf` will be de-initialized by this method. If the method returns `Ok(n)`, + /// then the first `n` bytes of `buf` are guaranteed to be initialized. + #[corresponds(SSL_read_ex)] + pub fn ssl_read_uninit(&mut self, buf: &mut [MaybeUninit]) -> Result { + cfg_if! { + if #[cfg(any(ossl111, libressl350))] { + let mut readbytes = 0; + let ret = unsafe { + ffi::SSL_read_ex( + self.ssl().as_ptr(), + buf.as_mut_ptr().cast(), + buf.len(), + &mut readbytes, + ) + }; + + if ret > 0 { + Ok(readbytes) + } else { + Err(self.make_error(ret)) + } + } else { + if buf.is_empty() { + return Ok(0); + } + + let len = usize::min(c_int::max_value() as usize, buf.len()) as c_int; + let ret = unsafe { + ffi::SSL_read(self.ssl().as_ptr(), buf.as_mut_ptr().cast(), len) + }; + if ret > 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } + } } } @@ -3777,34 +3821,78 @@ impl SslStream { /// /// It is particularly useful with a non-blocking socket, where the error value will identify if /// OpenSSL is waiting on read or write readiness. - #[corresponds(SSL_write)] + #[corresponds(SSL_write_ex)] pub fn ssl_write(&mut self, buf: &[u8]) -> Result { - // See above for why we short-circuit on zero-length buffers - if buf.is_empty() { - return Ok(0); - } + cfg_if! { + if #[cfg(any(ossl111, libressl350))] { + let mut written = 0; + let ret = unsafe { + ffi::SSL_write_ex( + self.ssl().as_ptr(), + buf.as_ptr().cast(), + buf.len(), + &mut written, + ) + }; + + if ret > 0 { + Ok(written) + } else { + Err(self.make_error(ret)) + } + } else { + if buf.is_empty() { + return Ok(0); + } - let ret = self.ssl.write(buf); - if ret > 0 { - Ok(ret as usize) - } else { - Err(self.make_error(ret)) + let len = usize::min(c_int::max_value() as usize, buf.len()) as c_int; + let ret = unsafe { + ffi::SSL_write(self.ssl().as_ptr(), buf.as_ptr().cast(), len) + }; + if ret > 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } + } } } /// Reads data from the stream, without removing it from the queue. - #[corresponds(SSL_peek)] + #[corresponds(SSL_peek_ex)] pub fn ssl_peek(&mut self, buf: &mut [u8]) -> Result { - // See above for why we short-circuit on zero-length buffers - if buf.is_empty() { - return Ok(0); - } + cfg_if! { + if #[cfg(any(ossl111, libressl350))] { + let mut readbytes = 0; + let ret = unsafe { + ffi::SSL_peek_ex( + self.ssl().as_ptr(), + buf.as_mut_ptr().cast(), + buf.len(), + &mut readbytes, + ) + }; + + if ret > 0 { + Ok(readbytes) + } else { + Err(self.make_error(ret)) + } + } else { + if buf.is_empty() { + return Ok(0); + } - let ret = self.ssl.peek(buf); - if ret > 0 { - Ok(ret as usize) - } else { - Err(self.make_error(ret)) + let len = usize::min(c_int::max_value() as usize, buf.len()) as c_int; + let ret = unsafe { + ffi::SSL_peek(self.ssl().as_ptr(), buf.as_mut_ptr().cast(), len) + }; + if ret > 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } + } } } @@ -3910,20 +3998,12 @@ impl SslStream { impl Read for SslStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - loop { - match self.ssl_read(buf) { - Ok(n) => return Ok(n), - Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => return Ok(0), - Err(ref e) if e.code() == ErrorCode::SYSCALL && e.io_error().is_none() => { - return Ok(0); - } - Err(ref e) if e.code() == ErrorCode::WANT_READ && e.io_error().is_none() => {} - Err(e) => { - return Err(e - .into_io_error() - .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))); - } - } + // SAFETY: `read_uninit` does not de-initialize the buffer + unsafe { + self.read_uninit(slice::from_raw_parts_mut( + buf.as_mut_ptr().cast::>(), + buf.len(), + )) } } }