Skip to content

Commit

Permalink
Port the mysql backend to use iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
Georg Semmler committed Jun 11, 2021
1 parent b80abaf commit 0d39aa3
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 21 deletions.
45 changes: 30 additions & 15 deletions diesel/src/mysql/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub struct MysqlConnection {
raw_connection: RawConnection,
transaction_state: AnsiTransactionManager,
statement_cache: StatementCache<Mysql, Statement>,
current_statement: Option<Statement>,
}

unsafe impl Send for MysqlConnection {}
Expand Down Expand Up @@ -50,6 +51,7 @@ impl Connection for MysqlConnection {
raw_connection,
transaction_state: AnsiTransactionManager::default(),
statement_cache: StatementCache::new(),
current_statement: None,
};
conn.set_config_options()
.map_err(CouldntSetupConfiguration)?;
Expand All @@ -73,23 +75,35 @@ impl Connection for MysqlConnection {
T::Query: QueryFragment<Self::Backend> + QueryId,
Self::Backend: QueryMetadata<T::SqlType>,
{
let mut stmt = self.prepare_query(&source.as_query())?;
let mut metadata = Vec::new();
Mysql::row_metadata(&mut (), &mut metadata);
let results = unsafe { stmt.results(metadata)? };
Ok(results)
self.with_prepared_query(&source.as_query(), |stmt, current_statement| {
let mut metadata = Vec::new();
Mysql::row_metadata(&mut (), &mut metadata);
let stmt = match stmt {
MaybeCached::CannotCache(stmt) => {
*current_statement = Some(stmt);
current_statement
.as_mut()
.expect("We set it literally above")
}
MaybeCached::Cached(stmt) => stmt,
};

let results = unsafe { stmt.results(metadata)? };
Ok(results)
})
}

#[doc(hidden)]
fn execute_returning_count<T>(&mut self, source: &T) -> QueryResult<usize>
where
T: QueryFragment<Self::Backend> + QueryId,
{
let stmt = self.prepare_query(source)?;
unsafe {
stmt.execute()?;
}
Ok(stmt.affected_rows())
self.with_prepared_query(source, |stmt, _| {
unsafe {
stmt.execute()?;
}
Ok(stmt.affected_rows())
})
}

#[doc(hidden)]
Expand All @@ -99,10 +113,11 @@ impl Connection for MysqlConnection {
}

impl MysqlConnection {
fn prepare_query<T>(&mut self, source: &T) -> QueryResult<MaybeCached<Statement>>
where
T: QueryFragment<Mysql> + QueryId,
{
fn with_prepared_query<'a, T: QueryFragment<Mysql> + QueryId, R>(
&'a mut self,
source: &'_ T,
f: impl FnOnce(MaybeCached<'a, Statement>, &'a mut Option<Statement>) -> QueryResult<R>,
) -> QueryResult<R> {
let cache = &mut self.statement_cache;
let conn = &mut self.raw_connection;

Expand All @@ -114,7 +129,7 @@ impl MysqlConnection {
.into_iter()
.zip(bind_collector.binds);
stmt.bind(binds)?;
Ok(stmt)
f(stmt, &mut self.current_statement)
}

fn set_config_options(&mut self) -> QueryResult<()> {
Expand Down
3 changes: 1 addition & 2 deletions diesel/src/mysql/connection/stmt/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::marker::PhantomData;
use std::rc::Rc;

use super::{Binds, Statement, StatementMetadata};
use super::metadata::MysqlFieldMetadata;
use crate::mysql::{Mysql, MysqlType};
use crate::result::QueryResult;
use crate::row::*;
Expand Down Expand Up @@ -152,7 +151,7 @@ pub struct MysqlField<'a> {
bind: Rc<Binds>,
metadata: Rc<StatementMetadata>,
idx: usize,
_marker: PhantomData<&'a (Binds, StatementMetadata)>
_marker: PhantomData<&'a (Binds, StatementMetadata)>,
}

impl<'a> Field<Mysql> for MysqlField<'a> {
Expand Down
8 changes: 4 additions & 4 deletions diesel/src/mysql/connection/stmt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::os::raw as libc;
use std::ptr::NonNull;

use self::iterator::*;
use super::bind::{BindData, Binds};
use super::bind::Binds;
use crate::mysql::MysqlType;
use crate::result::{DatabaseErrorKind, Error, QueryResult};

Expand Down Expand Up @@ -80,10 +80,10 @@ impl Statement {
/// This function should be called instead of `execute` for queries which
/// have a return value. After calling this function, `execute` can never
/// be called on this statement.
pub unsafe fn results(
self,
pub unsafe fn results<'a>(
&'a mut self,
types: Vec<Option<MysqlType>>,
) -> QueryResult<StatementIterator> {
) -> QueryResult<StatementIterator<'a>> {
StatementIterator::new(self, types)
}

Expand Down

0 comments on commit 0d39aa3

Please sign in to comment.