From 0d39aa39f3b00992fd0758936c9e42c2228cf1c1 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 11 Jun 2021 12:31:50 +0200 Subject: [PATCH] Port the mysql backend to use iterators --- diesel/src/mysql/connection/mod.rs | 45 +++++++++++++------- diesel/src/mysql/connection/stmt/iterator.rs | 3 +- diesel/src/mysql/connection/stmt/mod.rs | 8 ++-- 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index e5d4b486d4df..efae0ca6a6bf 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -20,6 +20,7 @@ pub struct MysqlConnection { raw_connection: RawConnection, transaction_state: AnsiTransactionManager, statement_cache: StatementCache, + current_statement: Option, } unsafe impl Send for MysqlConnection {} @@ -50,6 +51,7 @@ impl Connection for MysqlConnection { raw_connection, transaction_state: AnsiTransactionManager::default(), statement_cache: StatementCache::new(), + current_statement: None, }; conn.set_config_options() .map_err(CouldntSetupConfiguration)?; @@ -73,11 +75,22 @@ impl Connection for MysqlConnection { T::Query: QueryFragment + QueryId, Self::Backend: QueryMetadata, { - let mut stmt = self.prepare_query(&source.as_query())?; - let mut metadata = Vec::new(); - Mysql::row_metadata(&mut (), &mut metadata); - let results = unsafe { stmt.results(metadata)? }; - Ok(results) + self.with_prepared_query(&source.as_query(), |stmt, current_statement| { + let mut metadata = Vec::new(); + Mysql::row_metadata(&mut (), &mut metadata); + let stmt = match stmt { + MaybeCached::CannotCache(stmt) => { + *current_statement = Some(stmt); + current_statement + .as_mut() + .expect("We set it literally above") + } + MaybeCached::Cached(stmt) => stmt, + }; + + let results = unsafe { stmt.results(metadata)? }; + Ok(results) + }) } #[doc(hidden)] @@ -85,11 +98,12 @@ impl Connection for MysqlConnection { where T: QueryFragment + QueryId, { - let stmt = self.prepare_query(source)?; - unsafe { - stmt.execute()?; - } - Ok(stmt.affected_rows()) + self.with_prepared_query(source, |stmt, _| { + unsafe { + stmt.execute()?; + } + Ok(stmt.affected_rows()) + }) } #[doc(hidden)] @@ -99,10 +113,11 @@ impl Connection for MysqlConnection { } impl MysqlConnection { - fn prepare_query(&mut self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - { + fn with_prepared_query<'a, T: QueryFragment + QueryId, R>( + &'a mut self, + source: &'_ T, + f: impl FnOnce(MaybeCached<'a, Statement>, &'a mut Option) -> QueryResult, + ) -> QueryResult { let cache = &mut self.statement_cache; let conn = &mut self.raw_connection; @@ -114,7 +129,7 @@ impl MysqlConnection { .into_iter() .zip(bind_collector.binds); stmt.bind(binds)?; - Ok(stmt) + f(stmt, &mut self.current_statement) } fn set_config_options(&mut self) -> QueryResult<()> { diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index e6a78e48298b..854e5c86bee7 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -2,7 +2,6 @@ use std::marker::PhantomData; use std::rc::Rc; use super::{Binds, Statement, StatementMetadata}; -use super::metadata::MysqlFieldMetadata; use crate::mysql::{Mysql, MysqlType}; use crate::result::QueryResult; use crate::row::*; @@ -152,7 +151,7 @@ pub struct MysqlField<'a> { bind: Rc, metadata: Rc, idx: usize, - _marker: PhantomData<&'a (Binds, StatementMetadata)> + _marker: PhantomData<&'a (Binds, StatementMetadata)>, } impl<'a> Field for MysqlField<'a> { diff --git a/diesel/src/mysql/connection/stmt/mod.rs b/diesel/src/mysql/connection/stmt/mod.rs index 31b556bf2bcb..6c828a786d7b 100644 --- a/diesel/src/mysql/connection/stmt/mod.rs +++ b/diesel/src/mysql/connection/stmt/mod.rs @@ -9,7 +9,7 @@ use std::os::raw as libc; use std::ptr::NonNull; use self::iterator::*; -use super::bind::{BindData, Binds}; +use super::bind::Binds; use crate::mysql::MysqlType; use crate::result::{DatabaseErrorKind, Error, QueryResult}; @@ -80,10 +80,10 @@ impl Statement { /// This function should be called instead of `execute` for queries which /// have a return value. After calling this function, `execute` can never /// be called on this statement. - pub unsafe fn results( - self, + pub unsafe fn results<'a>( + &'a mut self, types: Vec>, - ) -> QueryResult { + ) -> QueryResult> { StatementIterator::new(self, types) }