Skip to content

Commit

Permalink
Another improvment to the sqlite bind code
Browse files Browse the repository at this point in the history
Any rust container like `Box<T>`, `Vec<T>` or `String<T>` internally
contains a `Unique<T>` pointer, which communicates to the compiler that
this container is the owner of that memory location and all access goes
through that pointer. See
rust-lang/unsafe-code-guidelines#194 for
details. Passing out a pointer to the underlying buffer to sqlite could
cause UB according to this definition, at least if someone else accesses
the buffer through the originial pointer. To prevent that we temporarily
leak the Buffer and manage the pointer by ourself.

Additionally this change introduces a way to construct the
`BoundStatement` as early as possible as part of the
`BoundStatement::bind` function, so that all cleanup code can be
concetracted in the corresponding `Drop` impl
  • Loading branch information
weiznich committed Dec 15, 2021
1 parent 10cf0e5 commit bb79170
Showing 1 changed file with 139 additions and 116 deletions.
255 changes: 139 additions & 116 deletions diesel/src/sqlite/connection/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use crate::sqlite::{Sqlite, SqliteType};
use crate::util::OnceCell;
use std::ffi::{CStr, CString};
use std::io::{stderr, Write};
use std::mem::ManuallyDrop;
use std::os::raw as libc;
use std::ptr::{self, NonNull};

Expand Down Expand Up @@ -58,9 +57,10 @@ impl Statement {
unsafe fn bind(
&mut self,
tpe: SqliteType,
value: &SqliteBindValue<'_>,
value: SqliteBindValue<'_>,
bind_index: i32,
) -> QueryResult<()> {
) -> QueryResult<Option<NonNull<[u8]>>> {
let mut ret_ptr = None;
let result = match (tpe, value) {
(_, SqliteBindValue::Null) => {
ffi::sqlite3_bind_null(self.inner_statement.as_ptr(), bind_index)
Expand All @@ -72,47 +72,87 @@ impl Statement {
bytes.len() as libc::c_int,
ffi::SQLITE_STATIC(),
),
(SqliteType::Binary, SqliteBindValue::Binary(bytes)) => ffi::sqlite3_bind_blob(
self.inner_statement.as_ptr(),
bind_index,
bytes.as_ptr() as *const libc::c_void,
bytes.len() as libc::c_int,
ffi::SQLITE_STATIC(),
),
(SqliteType::Binary, SqliteBindValue::Binary(mut bytes)) => {
let len = bytes.len();
// We need a seperate pointer here to pass it to sqlite
// as the returned pointer is a pointer to a dyn sized **slice**
// and not the pointer to the first element of the slice
let ptr;
ret_ptr = if len > 0 {
ptr = bytes.as_mut_ptr();
NonNull::new(Box::into_raw(bytes))
} else {
ptr = std::ptr::null_mut();
None
};
ffi::sqlite3_bind_blob(
self.inner_statement.as_ptr(),
bind_index,
ptr as *const libc::c_void,
len as libc::c_int,
ffi::SQLITE_STATIC(),
)
}
(SqliteType::Text, SqliteBindValue::BorrowedString(bytes)) => ffi::sqlite3_bind_text(
self.inner_statement.as_ptr(),
bind_index,
bytes.as_ptr() as *const libc::c_char,
bytes.len() as libc::c_int,
ffi::SQLITE_STATIC(),
),
(SqliteType::Text, SqliteBindValue::String(bytes)) => ffi::sqlite3_bind_text(
self.inner_statement.as_ptr(),
bind_index,
bytes.as_ptr() as *const libc::c_char,
bytes.len() as libc::c_int,
ffi::SQLITE_STATIC(),
),
(SqliteType::Text, SqliteBindValue::String(bytes)) => {
let mut bytes = Box::<[u8]>::from(bytes);
let len = bytes.len();
// We need a seperate pointer here to pass it to sqlite
// as the returned pointer is a pointer to a dyn sized **slice**
// and not the pointer to the first element of the slice
let ptr;
ret_ptr = if len > 0 {
ptr = bytes.as_mut_ptr();
NonNull::new(Box::into_raw(bytes))
} else {
ptr = std::ptr::null_mut();
None
};
ffi::sqlite3_bind_text(
self.inner_statement.as_ptr(),
bind_index,
ptr as *const libc::c_char,
len as libc::c_int,
ffi::SQLITE_STATIC(),
)
}
(SqliteType::Float, SqliteBindValue::F64(value))
| (SqliteType::Double, SqliteBindValue::F64(value)) => ffi::sqlite3_bind_double(
self.inner_statement.as_ptr(),
bind_index,
*value as libc::c_double,
value as libc::c_double,
),
(SqliteType::SmallInt, SqliteBindValue::I32(value))
| (SqliteType::Integer, SqliteBindValue::I32(value)) => {
ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, *value)
ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, value)
}
(SqliteType::Long, SqliteBindValue::I64(value)) => {
ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, *value)
ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, value)
}
(t, b) => {
return Err(Error::SerializationError(
format!("Type missmatch: Expected {:?}, got {}", t, b).into(),
))
}
};
ensure_sqlite_ok(result, self.raw_connection())
match ensure_sqlite_ok(result, self.raw_connection()) {
Ok(()) => Ok(ret_ptr),
Err(e) => {
if let Some(ptr) = ret_ptr {
// This is a `NonNul` ptr so it cannot be null
// It points to a slice internally as we did not apply
// any cast above.
std::mem::drop(Box::from_raw(ptr.as_ptr()))
}
Err(e)
}
}
}

fn reset(&mut self) {
Expand Down Expand Up @@ -180,152 +220,135 @@ impl Drop for Statement {
}
}

// A warning for future editiors:
// Changing this code to something "simplier" may
// introduce undefined behaviour. Make sure you read
// the following discussions for details about
// the current version:
//
// * https://github.com/weiznich/diesel/pull/7
// * https://users.rust-lang.org/t/code-review-for-unsafe-code-in-diesel/66798/
// * https://github.com/rust-lang/unsafe-code-guidelines/issues/194
struct BoundStatement<'stmt, 'query> {
statement: MaybeCached<'stmt, Statement>,
// we need to store the query here to ensure noone does
// drop it till the end ot the statement
// We use a boxed queryfragment here just to erase the
// generic type, we use ManuallyDrop to communicate
// generic type, we use NonNull to communicate
// that this is a shared buffer
query: ManuallyDrop<Box<dyn QueryFragment<Sqlite> + 'query>>,
query: Option<NonNull<dyn QueryFragment<Sqlite> + 'query>>,
// we need to store any owned bind values speratly, as they are not
// contained in the query itself. We use ManuallyDrop to
// contained in the query itself. We use NonNull to
// communicate that this is a shared buffer
binds_to_free: ManuallyDrop<Vec<(i32, Option<SqliteBindValue<'static>>)>>,
binds_to_free: Vec<(i32, Option<NonNull<[u8]>>)>,
}

impl<'stmt, 'query> BoundStatement<'stmt, 'query> {
fn bind<T>(
mut statement: MaybeCached<'stmt, Statement>,
statement: MaybeCached<'stmt, Statement>,
query: T,
) -> QueryResult<BoundStatement<'stmt, 'query>>
where
T: QueryFragment<Sqlite> + QueryId + 'query,
{
// Don't use a trait object here to prevent using a virtual function call
// For sqlite this can introduce a measurable overhead
let mut query = ManuallyDrop::new(Box::new(query));
let query = Box::new(query);

let mut bind_collector = SqliteBindCollector::new();
query.collect_binds(&mut bind_collector, &mut ())?;
let SqliteBindCollector { binds } = bind_collector;

let binds_to_free = match Self::bind_buffers(binds, &mut statement) {
Ok(value) => value,
Err(e) => {
unsafe {
// We return from this function afterwards and
// any buffer is already unbound by `bind_buffers`
// so it's safe to drop query now
ManuallyDrop::drop(&mut query);
}
return Err(e);
}
};

Ok(Self {
let mut ret = BoundStatement {
statement,
binds_to_free,
query: ManuallyDrop::new(
// Cast to a trait object here, to erase the generic parameter T
ManuallyDrop::into_inner(query) as Box<dyn QueryFragment<Sqlite> + 'query>,
query: None,
binds_to_free: Vec::with_capacity(
binds
.iter()
.filter(|&(b, _)| {
matches!(
b,
SqliteBindValue::BorrowedBinary(_)
| SqliteBindValue::BorrowedString(_)
| SqliteBindValue::String(_)
| SqliteBindValue::Binary(_)
)
})
.count(),
),
})
};

ret.bind_buffers(binds)?;

let query = query as Box<dyn QueryFragment<Sqlite> + 'query>;
ret.query = NonNull::new(Box::into_raw(query));

Ok(ret)
}

// This is a seperate function so that
// not the whole construtor is generic over the query type T.
// This hopefully prevents binary bloat.
fn bind_buffers(
binds: Vec<(SqliteBindValue<'_>, SqliteType)>,
statement: &mut MaybeCached<'stmt, Statement>,
) -> QueryResult<ManuallyDrop<Vec<(i32, Option<SqliteBindValue<'static>>)>>> {
let mut binds_to_free = ManuallyDrop::new(Vec::with_capacity(
binds
.iter()
.filter(|&(b, _)| {
matches!(
b,
SqliteBindValue::BorrowedBinary(_)
| SqliteBindValue::BorrowedString(_)
| SqliteBindValue::String(_)
| SqliteBindValue::Binary(_)
)
})
.count(),
));
fn bind_buffers(&mut self, binds: Vec<(SqliteBindValue<'_>, SqliteType)>) -> QueryResult<()> {
for (bind_idx, (bind, tpe)) in (1..).zip(binds) {
if matches!(
bind,
SqliteBindValue::BorrowedString(_) | SqliteBindValue::BorrowedBinary(_)
) {
// Store the id's of borrowed binds to unbind them on drop
self.binds_to_free.push((bind_idx, None));
}

// It's safe to call bind here as:
// * The type and value matches
// * We ensure that corresponding buffers lives long enough below
// * The statement is not used yet by `step` or anything else
let res = unsafe { statement.bind(tpe, &bind, bind_idx) };

if let Err(e) = res {
Self::unbind_buffers(statement, &binds_to_free);
unsafe {
// It's safe to drop binds_to_free here as
// we've already unbound the buffers
ManuallyDrop::drop(&mut binds_to_free);
}
return Err(e);
}

// We want to unbind the buffers later to ensure
// that sqlite does not access uninitilized memory
match bind {
SqliteBindValue::BorrowedString(_) | SqliteBindValue::BorrowedBinary(_) => {
binds_to_free.push((bind_idx, None));
}
SqliteBindValue::Binary(b) => {
binds_to_free.push((bind_idx, Some(SqliteBindValue::Binary(b))));
}
SqliteBindValue::String(b) => {
binds_to_free.push((bind_idx, Some(SqliteBindValue::String(b))));
}
SqliteBindValue::I32(_)
| SqliteBindValue::I64(_)
| SqliteBindValue::F64(_)
| SqliteBindValue::Null => {}
let res = unsafe { self.statement.bind(tpe, bind, bind_idx) }?;
if let Some(ptr) = res {
// Store the id + pointer for a owned bind
// as we must unbind and free them on drop
self.binds_to_free.push((bind_idx, Some(ptr)));
}
}
Ok(binds_to_free)
Ok(())
}
}

impl<'stmt, 'query> Drop for BoundStatement<'stmt, 'query> {
fn drop(&mut self) {
// First reset the statement, otherwise the bind calls
// below will fails
self.statement.reset();

fn unbind_buffers(
stmt: &mut MaybeCached<'stmt, Statement>,
binds_to_free: &[(i32, Option<SqliteBindValue<'static>>)],
) {
for (idx, _buffer) in binds_to_free {
for (idx, buffer) in std::mem::take(&mut self.binds_to_free) {
unsafe {
// It's always safe to bind null values, as there is no buffer that needs to outlife something
stmt.bind(SqliteType::Text, &SqliteBindValue::Null, *idx)
self.statement
.bind(SqliteType::Text, SqliteBindValue::Null, idx)
.expect(
"Binding a null value should never fail. \
If you ever see this error message please open \
an issue at diesels issue tracker containing \
code how to trigger this message.",
);
}
}
}
}

impl<'stmt, 'query> Drop for BoundStatement<'stmt, 'query> {
fn drop(&mut self) {
// First reset the statement, otherwise the bind calls
// below will fails
self.statement.reset();
if let Some(buffer) = buffer {
unsafe {
// Constructing the `Box` here is safe as we
// got the pointer from a box + it is guarenteed to be not null.
std::mem::drop(Box::from_raw(buffer.as_ptr()));
}
}
}

// Reset the binds that may point to memory that will be/needs to be freed
Self::unbind_buffers(&mut self.statement, &self.binds_to_free);
unsafe {
// We unbound the corresponding buffers above, so it's fine to drop the
// owned binds now
ManuallyDrop::drop(&mut self.binds_to_free);
// We've dropped everything that could reference the query
// so it's safe to drop the query here
ManuallyDrop::drop(&mut self.query);
if let Some(query) = self.query {
unsafe {
// Constructing the `Box` here is safe as we
// got the pointer from a box + it is guarenteed to be not null.
std::mem::drop(Box::from_raw(query.as_ptr()));
}
self.query = None;
}
}
}
Expand Down

0 comments on commit bb79170

Please sign in to comment.