diff --git a/diesel/src/sqlite/connection/stmt.rs b/diesel/src/sqlite/connection/stmt.rs index 2cdcf3dda08f..449485ed5238 100644 --- a/diesel/src/sqlite/connection/stmt.rs +++ b/diesel/src/sqlite/connection/stmt.rs @@ -11,7 +11,6 @@ use crate::sqlite::{Sqlite, SqliteType}; use crate::util::OnceCell; use std::ffi::{CStr, CString}; use std::io::{stderr, Write}; -use std::mem::ManuallyDrop; use std::os::raw as libc; use std::ptr::{self, NonNull}; @@ -58,9 +57,10 @@ impl Statement { unsafe fn bind( &mut self, tpe: SqliteType, - value: &SqliteBindValue<'_>, + value: SqliteBindValue<'_>, bind_index: i32, - ) -> QueryResult<()> { + ) -> QueryResult>> { + let mut ret_ptr = None; let result = match (tpe, value) { (_, SqliteBindValue::Null) => { ffi::sqlite3_bind_null(self.inner_statement.as_ptr(), bind_index) @@ -72,13 +72,27 @@ impl Statement { bytes.len() as libc::c_int, ffi::SQLITE_STATIC(), ), - (SqliteType::Binary, SqliteBindValue::Binary(bytes)) => ffi::sqlite3_bind_blob( - self.inner_statement.as_ptr(), - bind_index, - bytes.as_ptr() as *const libc::c_void, - bytes.len() as libc::c_int, - ffi::SQLITE_STATIC(), - ), + (SqliteType::Binary, SqliteBindValue::Binary(mut bytes)) => { + let len = bytes.len(); + // We need a seperate pointer here to pass it to sqlite + // as the returned pointer is a pointer to a dyn sized **slice** + // and not the pointer to the first element of the slice + let ptr; + ret_ptr = if len > 0 { + ptr = bytes.as_mut_ptr(); + NonNull::new(Box::into_raw(bytes)) + } else { + ptr = std::ptr::null_mut(); + None + }; + ffi::sqlite3_bind_blob( + self.inner_statement.as_ptr(), + bind_index, + ptr as *const libc::c_void, + len as libc::c_int, + ffi::SQLITE_STATIC(), + ) + } (SqliteType::Text, SqliteBindValue::BorrowedString(bytes)) => ffi::sqlite3_bind_text( self.inner_statement.as_ptr(), bind_index, @@ -86,25 +100,40 @@ impl Statement { bytes.len() as libc::c_int, ffi::SQLITE_STATIC(), ), - (SqliteType::Text, SqliteBindValue::String(bytes)) => ffi::sqlite3_bind_text( - self.inner_statement.as_ptr(), - bind_index, - bytes.as_ptr() as *const libc::c_char, - bytes.len() as libc::c_int, - ffi::SQLITE_STATIC(), - ), + (SqliteType::Text, SqliteBindValue::String(bytes)) => { + let mut bytes = Box::<[u8]>::from(bytes); + let len = bytes.len(); + // We need a seperate pointer here to pass it to sqlite + // as the returned pointer is a pointer to a dyn sized **slice** + // and not the pointer to the first element of the slice + let ptr; + ret_ptr = if len > 0 { + ptr = bytes.as_mut_ptr(); + NonNull::new(Box::into_raw(bytes)) + } else { + ptr = std::ptr::null_mut(); + None + }; + ffi::sqlite3_bind_text( + self.inner_statement.as_ptr(), + bind_index, + ptr as *const libc::c_char, + len as libc::c_int, + ffi::SQLITE_STATIC(), + ) + } (SqliteType::Float, SqliteBindValue::F64(value)) | (SqliteType::Double, SqliteBindValue::F64(value)) => ffi::sqlite3_bind_double( self.inner_statement.as_ptr(), bind_index, - *value as libc::c_double, + value as libc::c_double, ), (SqliteType::SmallInt, SqliteBindValue::I32(value)) | (SqliteType::Integer, SqliteBindValue::I32(value)) => { - ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, *value) + ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, value) } (SqliteType::Long, SqliteBindValue::I64(value)) => { - ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, *value) + ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, value) } (t, b) => { return Err(Error::SerializationError( @@ -112,7 +141,18 @@ impl Statement { )) } }; - ensure_sqlite_ok(result, self.raw_connection()) + match ensure_sqlite_ok(result, self.raw_connection()) { + Ok(()) => Ok(ret_ptr), + Err(e) => { + if let Some(ptr) = ret_ptr { + // This is a `NonNul` ptr so it cannot be null + // It points to a slice internally as we did not apply + // any cast above. + std::mem::drop(Box::from_raw(ptr.as_ptr())) + } + Err(e) + } + } } fn reset(&mut self) { @@ -180,23 +220,32 @@ impl Drop for Statement { } } +// A warning for future editiors: +// Changing this code to something "simplier" may +// introduce undefined behaviour. Make sure you read +// the following discussions for details about +// the current version: +// +// * https://github.com/weiznich/diesel/pull/7 +// * https://users.rust-lang.org/t/code-review-for-unsafe-code-in-diesel/66798/ +// * https://github.com/rust-lang/unsafe-code-guidelines/issues/194 struct BoundStatement<'stmt, 'query> { statement: MaybeCached<'stmt, Statement>, // we need to store the query here to ensure noone does // drop it till the end ot the statement // We use a boxed queryfragment here just to erase the - // generic type, we use ManuallyDrop to communicate + // generic type, we use NonNull to communicate // that this is a shared buffer - query: ManuallyDrop + 'query>>, + query: Option + 'query>>, // we need to store any owned bind values speratly, as they are not - // contained in the query itself. We use ManuallyDrop to + // contained in the query itself. We use NonNull to // communicate that this is a shared buffer - binds_to_free: ManuallyDrop>)>>, + binds_to_free: Vec<(i32, Option>)>, } impl<'stmt, 'query> BoundStatement<'stmt, 'query> { fn bind( - mut statement: MaybeCached<'stmt, Statement>, + statement: MaybeCached<'stmt, Statement>, query: T, ) -> QueryResult> where @@ -204,102 +253,78 @@ impl<'stmt, 'query> BoundStatement<'stmt, 'query> { { // Don't use a trait object here to prevent using a virtual function call // For sqlite this can introduce a measurable overhead - let mut query = ManuallyDrop::new(Box::new(query)); + let query = Box::new(query); let mut bind_collector = SqliteBindCollector::new(); query.collect_binds(&mut bind_collector, &mut ())?; let SqliteBindCollector { binds } = bind_collector; - let binds_to_free = match Self::bind_buffers(binds, &mut statement) { - Ok(value) => value, - Err(e) => { - unsafe { - // We return from this function afterwards and - // any buffer is already unbound by `bind_buffers` - // so it's safe to drop query now - ManuallyDrop::drop(&mut query); - } - return Err(e); - } - }; - - Ok(Self { + let mut ret = BoundStatement { statement, - binds_to_free, - query: ManuallyDrop::new( - // Cast to a trait object here, to erase the generic parameter T - ManuallyDrop::into_inner(query) as Box + 'query>, + query: None, + binds_to_free: Vec::with_capacity( + binds + .iter() + .filter(|&(b, _)| { + matches!( + b, + SqliteBindValue::BorrowedBinary(_) + | SqliteBindValue::BorrowedString(_) + | SqliteBindValue::String(_) + | SqliteBindValue::Binary(_) + ) + }) + .count(), ), - }) + }; + + ret.bind_buffers(binds)?; + + let query = query as Box + 'query>; + ret.query = NonNull::new(Box::into_raw(query)); + + Ok(ret) } // This is a seperate function so that // not the whole construtor is generic over the query type T. // This hopefully prevents binary bloat. - fn bind_buffers( - binds: Vec<(SqliteBindValue<'_>, SqliteType)>, - statement: &mut MaybeCached<'stmt, Statement>, - ) -> QueryResult>)>>> { - let mut binds_to_free = ManuallyDrop::new(Vec::with_capacity( - binds - .iter() - .filter(|&(b, _)| { - matches!( - b, - SqliteBindValue::BorrowedBinary(_) - | SqliteBindValue::BorrowedString(_) - | SqliteBindValue::String(_) - | SqliteBindValue::Binary(_) - ) - }) - .count(), - )); + fn bind_buffers(&mut self, binds: Vec<(SqliteBindValue<'_>, SqliteType)>) -> QueryResult<()> { for (bind_idx, (bind, tpe)) in (1..).zip(binds) { + if matches!( + bind, + SqliteBindValue::BorrowedString(_) | SqliteBindValue::BorrowedBinary(_) + ) { + // Store the id's of borrowed binds to unbind them on drop + self.binds_to_free.push((bind_idx, None)); + } + // It's safe to call bind here as: // * The type and value matches // * We ensure that corresponding buffers lives long enough below // * The statement is not used yet by `step` or anything else - let res = unsafe { statement.bind(tpe, &bind, bind_idx) }; - - if let Err(e) = res { - Self::unbind_buffers(statement, &binds_to_free); - unsafe { - // It's safe to drop binds_to_free here as - // we've already unbound the buffers - ManuallyDrop::drop(&mut binds_to_free); - } - return Err(e); - } - - // We want to unbind the buffers later to ensure - // that sqlite does not access uninitilized memory - match bind { - SqliteBindValue::BorrowedString(_) | SqliteBindValue::BorrowedBinary(_) => { - binds_to_free.push((bind_idx, None)); - } - SqliteBindValue::Binary(b) => { - binds_to_free.push((bind_idx, Some(SqliteBindValue::Binary(b)))); - } - SqliteBindValue::String(b) => { - binds_to_free.push((bind_idx, Some(SqliteBindValue::String(b)))); - } - SqliteBindValue::I32(_) - | SqliteBindValue::I64(_) - | SqliteBindValue::F64(_) - | SqliteBindValue::Null => {} + let res = unsafe { self.statement.bind(tpe, bind, bind_idx) }?; + if let Some(ptr) = res { + // Store the id + pointer for a owned bind + // as we must unbind and free them on drop + self.binds_to_free.push((bind_idx, Some(ptr))); } } - Ok(binds_to_free) + Ok(()) } +} + +impl<'stmt, 'query> Drop for BoundStatement<'stmt, 'query> { + fn drop(&mut self) { + // First reset the statement, otherwise the bind calls + // below will fails + self.statement.reset(); - fn unbind_buffers( - stmt: &mut MaybeCached<'stmt, Statement>, - binds_to_free: &[(i32, Option>)], - ) { - for (idx, _buffer) in binds_to_free { + for (idx, buffer) in std::mem::take(&mut self.binds_to_free) { unsafe { // It's always safe to bind null values, as there is no buffer that needs to outlife something - stmt.bind(SqliteType::Text, &SqliteBindValue::Null, *idx) + self.statement + .bind(SqliteType::Text, SqliteBindValue::Null, idx) .expect( "Binding a null value should never fail. \ If you ever see this error message please open \ @@ -307,25 +332,23 @@ impl<'stmt, 'query> BoundStatement<'stmt, 'query> { code how to trigger this message.", ); } - } - } -} -impl<'stmt, 'query> Drop for BoundStatement<'stmt, 'query> { - fn drop(&mut self) { - // First reset the statement, otherwise the bind calls - // below will fails - self.statement.reset(); + if let Some(buffer) = buffer { + unsafe { + // Constructing the `Box` here is safe as we + // got the pointer from a box + it is guarenteed to be not null. + std::mem::drop(Box::from_raw(buffer.as_ptr())); + } + } + } - // Reset the binds that may point to memory that will be/needs to be freed - Self::unbind_buffers(&mut self.statement, &self.binds_to_free); - unsafe { - // We unbound the corresponding buffers above, so it's fine to drop the - // owned binds now - ManuallyDrop::drop(&mut self.binds_to_free); - // We've dropped everything that could reference the query - // so it's safe to drop the query here - ManuallyDrop::drop(&mut self.query); + if let Some(query) = self.query { + unsafe { + // Constructing the `Box` here is safe as we + // got the pointer from a box + it is guarenteed to be not null. + std::mem::drop(Box::from_raw(query.as_ptr())); + } + self.query = None; } } }