diff --git a/diesel/Cargo.toml b/diesel/Cargo.toml index e650363cc53b..df61c69e6a8a 100644 --- a/diesel/Cargo.toml +++ b/diesel/Cargo.toml @@ -16,7 +16,7 @@ edition = "2018" byteorder = "1.0" chrono = { version = "0.4.19", optional = true, default-features = false, features = ["clock", "std"] } libc = { version = "0.2.0", optional = true } -libsqlite3-sys = { version = ">=0.8.0, <0.23.0", optional = true, features = ["min_sqlite_version_3_7_16"] } +libsqlite3-sys = { version = ">=0.8.0, <0.23.0", optional = true, features = ["bundled_bindings"] } mysqlclient-sys = { version = "0.2.0", optional = true } pq-sys = { version = "0.4.0", optional = true } quickcheck = { version = "0.9.0", optional = true } @@ -44,7 +44,7 @@ ipnetwork = ">=0.12.2, <0.19.0" quickcheck = "0.9" [features] -default = ["with-deprecated", "32-column-tables"] +default = ["32-column-tables", "with-deprecated"] extras = ["chrono", "serde_json", "uuid", "network-address", "numeric", "r2d2"] unstable = ["diesel_derives/nightly"] large-tables = ["32-column-tables"] diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index 1cc4dbf742d4..86ee9cc760c7 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -6,14 +6,12 @@ mod transaction_manager; use std::fmt::Debug; use crate::backend::Backend; -use crate::deserialize::FromSqlRow; use crate::expression::QueryMetadata; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; -use crate::query_dsl::load_dsl::CompatibleType; use crate::result::*; #[doc(hidden)] -pub use self::statement_cache::{MaybeCached, StatementCache, StatementCacheKey}; +pub use self::statement_cache::{MaybeCached, PrepareForCache, StatementCache, StatementCacheKey}; pub use self::transaction_manager::{AnsiTransactionManager, TransactionManager}; /// Perform simple operations on a backend. @@ -27,8 +25,25 @@ pub trait SimpleConnection { fn batch_execute(&mut self, query: &str) -> QueryResult<()>; } +/// This trait describes which cursor type is used by a given connection +/// implementation. This trait is only useful in combination with [`Connection`]. +/// +/// Implementation wise this is a workaround for GAT's +pub trait ConnectionGatWorkaround<'a, DB: Backend> { + /// The cursor type returned by [`Connection::load`] + /// + /// Users should handle this as opaque type that implements [`Iterator`] + type Cursor: Iterator>; + /// The row type used as [`Iterator::Item`] for the iterator implementation + /// of [`ConnectionGatWorkaround::Cursor`] + type Row: crate::row::Row<'a, DB>; +} + /// A connection to a database -pub trait Connection: SimpleConnection + Sized + Send { +pub trait Connection: SimpleConnection + Sized + Send +where + Self: for<'a> ConnectionGatWorkaround<'a, ::Backend>, +{ /// The backend this type connects to type Backend: Backend; @@ -177,12 +192,13 @@ pub trait Connection: SimpleConnection + Sized + Send { fn execute(&mut self, query: &str) -> QueryResult; #[doc(hidden)] - fn load(&mut self, source: T) -> QueryResult> + fn load( + &mut self, + source: T, + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, - T::SqlType: CompatibleType, - U: FromSqlRow, Self::Backend: QueryMetadata; #[doc(hidden)] diff --git a/diesel/src/connection/statement_cache.rs b/diesel/src/connection/statement_cache.rs index 06323589efe0..04970c4c69d1 100644 --- a/diesel/src/connection/statement_cache.rs +++ b/diesel/src/connection/statement_cache.rs @@ -107,6 +107,12 @@ pub struct StatementCache { pub cache: HashMap, Statement>, } +#[derive(Debug, Clone, Copy)] +pub enum PrepareForCache { + Yes, + No, +} + #[allow(clippy::len_without_is_empty, clippy::new_without_default)] impl StatementCache where @@ -133,7 +139,7 @@ where ) -> QueryResult> where T: QueryFragment + QueryId, - F: FnOnce(&str) -> QueryResult, + F: FnOnce(&str, PrepareForCache) -> QueryResult, { use std::collections::hash_map::Entry::{Occupied, Vacant}; @@ -141,7 +147,7 @@ where if !source.is_safe_to_cache_prepared()? { let sql = cache_key.sql(source)?; - return prepare_fn(&sql).map(MaybeCached::CannotCache); + return prepare_fn(&sql, PrepareForCache::No).map(MaybeCached::CannotCache); } let cached_result = match self.cache.entry(cache_key) { @@ -149,7 +155,7 @@ where Vacant(entry) => { let statement = { let sql = entry.key().sql(source)?; - prepare_fn(&sql) + prepare_fn(&sql, PrepareForCache::Yes) }; entry.insert(statement?) diff --git a/diesel/src/mysql/connection/bind.rs b/diesel/src/mysql/connection/bind.rs index a3b894c5bec7..14994cf84c39 100644 --- a/diesel/src/mysql/connection/bind.rs +++ b/diesel/src/mysql/connection/bind.rs @@ -11,11 +11,23 @@ use crate::mysql::types::MYSQL_TIME; use crate::mysql::{MysqlType, MysqlValue}; use crate::result::QueryResult; -pub struct Binds { +pub struct PreparedStatementBinds(Binds); + +pub struct OutputBinds(Binds); + +impl Clone for OutputBinds { + fn clone(&self) -> Self { + Self(Binds { + data: self.0.data.clone(), + }) + } +} + +struct Binds { data: Vec, } -impl Binds { +impl PreparedStatementBinds { pub fn from_input_data(input: Iter) -> QueryResult where Iter: IntoIterator>)>, @@ -25,34 +37,31 @@ impl Binds { .map(BindData::for_input) .collect::>(); - Ok(Binds { data }) + Ok(Self(Binds { data })) } - pub fn from_output_types(types: Vec>, metadata: &StatementMetadata) -> Self { + pub fn with_mysql_binds(&mut self, f: F) -> T + where + F: FnOnce(*mut ffi::MYSQL_BIND) -> T, + { + self.0.with_mysql_binds(f) + } +} + +impl OutputBinds { + pub fn from_output_types(types: &[Option], metadata: &StatementMetadata) -> Self { let data = metadata .fields() .iter() - .zip(types.into_iter().chain(std::iter::repeat(None))) + .zip(types.iter().copied().chain(std::iter::repeat(None))) .map(|(field, tpe)| BindData::for_output(tpe, field)) .collect(); - Binds { data } - } - - pub fn with_mysql_binds(&mut self, f: F) -> T - where - F: FnOnce(*mut ffi::MYSQL_BIND) -> T, - { - let mut binds = self - .data - .iter_mut() - .map(|x| unsafe { x.mysql_bind() }) - .collect::>(); - f(binds.as_mut_ptr()) + Self(Binds { data }) } pub fn populate_dynamic_buffers(&mut self, stmt: &Statement) -> QueryResult<()> { - for (i, data) in self.data.iter_mut().enumerate() { + for (i, data) in self.0.data.iter_mut().enumerate() { data.did_numeric_overflow_occur()?; // This is safe because we are re-binding the invalidated buffers // at the end of this function @@ -69,20 +78,37 @@ impl Binds { } pub fn update_buffer_lengths(&mut self) { - for data in &mut self.data { + for data in &mut self.0.data { data.update_buffer_length(); } } - pub fn len(&self) -> usize { - self.data.len() + pub fn with_mysql_binds(&mut self, f: F) -> T + where + F: FnOnce(*mut ffi::MYSQL_BIND) -> T, + { + self.0.with_mysql_binds(f) + } +} + +impl Binds { + fn with_mysql_binds(&mut self, f: F) -> T + where + F: FnOnce(*mut ffi::MYSQL_BIND) -> T, + { + let mut binds = self + .data + .iter_mut() + .map(|x| unsafe { x.mysql_bind() }) + .collect::>(); + f(binds.as_mut_ptr()) } } -impl Index for Binds { +impl Index for OutputBinds { type Output = BindData; fn index(&self, index: usize) -> &Self::Output { - &self.data[index] + &self.0.data[index] } } @@ -122,7 +148,7 @@ impl From for Flags { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BindData { tpe: ffi::enum_field_types, bytes: Vec, @@ -713,9 +739,8 @@ mod tests { ) .unwrap(); - let mut stmt = conn - .prepare_query(&crate::sql_query( - "SELECT + let mut stmt = conn.prepared_query(&crate::sql_query( + "SELECT tiny_int, small_int, medium_int, int_col, big_int, unsigned_int, zero_fill_int, numeric_col, decimal_col, float_col, double_col, bit_col, @@ -725,30 +750,21 @@ mod tests { ST_AsText(polygon_col), ST_AsText(multipoint_col), ST_AsText(multilinestring_col), ST_AsText(multipolygon_col), ST_AsText(geometry_collection), json_col FROM all_mysql_types", - )) - .unwrap(); + )).unwrap(); let metadata = stmt.metadata().unwrap(); let mut output_binds = - Binds::from_output_types(vec![None; metadata.fields().len()], &metadata); + OutputBinds::from_output_types(&vec![None; metadata.fields().len()], &metadata); stmt.execute_statement(&mut output_binds).unwrap(); stmt.populate_row_buffers(&mut output_binds).unwrap(); let results: Vec<(BindData, &_)> = output_binds + .0 .data .into_iter() .zip(metadata.fields()) .collect::>(); - macro_rules! matches { - ($expression:expr, $( $pattern:pat )|+ $( if $guard: expr )?) => { - match $expression { - $( $pattern )|+ $( if $guard )? => true, - _ => false - } - } - } - let tiny_int_col = &results[0].0; assert_eq!(tiny_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_TINY); assert!(tiny_int_col.flags.contains(Flags::NUM_FLAG)); @@ -1057,9 +1073,9 @@ mod tests { assert!(!polygon_col.flags.contains(Flags::ENUM_FLAG)); assert!(!polygon_col.flags.contains(Flags::BINARY_FLAG)); assert_eq!( - to_value::(polygon_col).unwrap(), - "MULTIPOLYGON(((28 26,28 0,84 0,84 42,28 26),(52 18,66 23,73 9,48 6,52 18)),((59 18,67 18,67 13,59 13,59 18)))" - ); + to_value::(polygon_col).unwrap(), + "MULTIPOLYGON(((28 26,28 0,84 0,84 42,28 26),(52 18,66 23,73 9,48 6,52 18)),((59 18,67 18,67 13,59 13,59 18)))" + ); let geometry_collection = &results[32].0; assert_eq!( @@ -1105,12 +1121,12 @@ mod tests { let bind = BindData::for_test_output(bind_tpe.into()); - let mut binds = Binds { data: vec![bind] }; + let mut binds = OutputBinds(Binds { data: vec![bind] }); stmt.execute_statement(&mut binds).unwrap(); stmt.populate_row_buffers(&mut binds).unwrap(); - binds.data.remove(0) + binds.0.data.remove(0) } fn input_bind( @@ -1144,9 +1160,9 @@ mod tests { is_truncated: None, }; - let binds = Binds { + let binds = PreparedStatementBinds(Binds { data: vec![id_bind, field_bind], - }; + }); stmt.input_bind(binds).unwrap(); stmt.did_an_error_occur().unwrap(); unsafe { diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index 04093d901d09..8ba6eb950a32 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -4,15 +4,14 @@ mod stmt; mod url; use self::raw::RawConnection; +use self::stmt::iterator::StatementIterator; use self::stmt::Statement; use self::url::ConnectionOptions; use super::backend::Mysql; use crate::connection::*; -use crate::deserialize::FromSqlRow; use crate::expression::QueryMetadata; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; -use crate::query_dsl::load_dsl::CompatibleType; use crate::result::*; #[allow(missing_debug_implementations, missing_copy_implementations)] @@ -33,6 +32,11 @@ impl SimpleConnection for MysqlConnection { } } +impl<'a> ConnectionGatWorkaround<'a, Mysql> for MysqlConnection { + type Cursor = self::stmt::iterator::StatementIterator<'a>; + type Row = self::stmt::iterator::MysqlRow; +} + impl Connection for MysqlConnection { type Backend = Mysql; type TransactionManager = AnsiTransactionManager; @@ -61,21 +65,21 @@ impl Connection for MysqlConnection { } #[doc(hidden)] - fn load(&mut self, source: T) -> QueryResult> + fn load( + &mut self, + source: T, + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, - T::SqlType: CompatibleType, - U: FromSqlRow, Self::Backend: QueryMetadata, { - use crate::result::Error::DeserializationError; + let stmt = self.prepared_query(&source.as_query())?; - let mut stmt = self.prepare_query(&source.as_query())?; let mut metadata = Vec::new(); Mysql::row_metadata(&mut (), &mut metadata); - let results = unsafe { stmt.results(metadata)? }; - results.map(|row| U::build_from_row(&row).map_err(DeserializationError)) + + StatementIterator::from_stmt(stmt, &metadata) } #[doc(hidden)] @@ -83,7 +87,7 @@ impl Connection for MysqlConnection { where T: QueryFragment + QueryId, { - let stmt = self.prepare_query(source)?; + let stmt = self.prepared_query(source)?; unsafe { stmt.execute()?; } @@ -97,14 +101,14 @@ impl Connection for MysqlConnection { } impl MysqlConnection { - fn prepare_query(&mut self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - { + fn prepared_query<'a, T: QueryFragment + QueryId>( + &'a mut self, + source: &'_ T, + ) -> QueryResult> { let cache = &mut self.statement_cache; let conn = &mut self.raw_connection; - let mut stmt = cache.cached_statement(source, &[], |sql| conn.prepare(sql))?; + let mut stmt = cache.cached_statement(source, &[], |sql, _| conn.prepare(sql))?; let mut bind_collector = RawBytesBindCollector::new(); source.collect_binds(&mut bind_collector, &mut ())?; let binds = bind_collector diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index 584b66f949ae..d832f17e3dcf 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -1,78 +1,175 @@ -use super::{metadata::MysqlFieldMetadata, BindData, Binds, Statement, StatementMetadata}; +use std::cell::{Ref, RefCell}; +use std::rc::Rc; + +use super::{OutputBinds, Statement, StatementMetadata}; +use crate::connection::MaybeCached; use crate::mysql::{Mysql, MysqlType}; use crate::result::QueryResult; use crate::row::*; +#[allow(missing_debug_implementations)] pub struct StatementIterator<'a> { - stmt: &'a mut Statement, - output_binds: Binds, - metadata: StatementMetadata, + stmt: MaybeCached<'a, Statement>, + last_row: Rc>, + metadata: Rc, + len: usize, } -#[allow(clippy::should_implement_trait)] // don't neet `Iterator` here impl<'a> StatementIterator<'a> { - #[allow(clippy::new_ret_no_self)] - pub fn new(stmt: &'a mut Statement, types: Vec>) -> QueryResult { + pub fn from_stmt( + mut stmt: MaybeCached<'a, Statement>, + types: &[Option], + ) -> QueryResult { let metadata = stmt.metadata()?; - let mut output_binds = Binds::from_output_types(types, &metadata); + let mut output_binds = OutputBinds::from_output_types(types, &metadata); stmt.execute_statement(&mut output_binds)?; + let size = unsafe { stmt.result_size() }?; Ok(StatementIterator { + metadata: Rc::new(metadata), + last_row: Rc::new(RefCell::new(PrivateMysqlRow::Direct(output_binds))), + len: size, stmt, - output_binds, - metadata, }) } +} + +impl<'a> Iterator for StatementIterator<'a> { + type Item = QueryResult; + + fn next(&mut self) -> Option { + // check if we own the only instance of the bind buffer + // if that's the case we can reuse the underlying allocations + // if that's not the case, we need to copy the output bind buffers + // to somewhere else + let res = if let Some(binds) = Rc::get_mut(&mut self.last_row) { + if let PrivateMysqlRow::Direct(ref mut binds) = RefCell::get_mut(binds) { + self.stmt.populate_row_buffers(binds) + } else { + // any other state than `PrivateMysqlRow::Direct` is invalid here + // and should not happen. If this ever happens this is a logic error + // in the code above + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + } + } else { + // The shared bind buffer is in use by someone else, + // this means we copy out the values and replace the used reference + // by the copied values. After this we can advance the statment + // another step + let mut last_row = { + let mut last_row = match self.last_row.try_borrow_mut() { + Ok(o) => o, + Err(_e) => { + return Some(Err(crate::result::Error::DeserializationError( + "Failed to reborrow row. Try to release any `MysqlField` or `MysqlValue` \ + that exists at this point" + .into(), + ))); + } + }; + let last_row = &mut *last_row; + let duplicated = last_row.duplicate(); + std::mem::replace(last_row, duplicated) + }; + let res = if let PrivateMysqlRow::Direct(ref mut binds) = last_row { + self.stmt.populate_row_buffers(binds) + } else { + // any other state than `PrivateMysqlRow::Direct` is invalid here + // and should not happen. If this ever happens this is a logic error + // in the code above + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + }; + self.last_row = Rc::new(RefCell::new(last_row)); + res + }; - pub fn map(mut self, mut f: F) -> QueryResult> + match res { + Ok(Some(())) => { + self.len = self.len.saturating_sub(1); + Some(Ok(MysqlRow { + metadata: self.metadata.clone(), + row: self.last_row.clone(), + })) + } + Ok(None) => None, + Err(e) => { + self.len = self.len.saturating_sub(1); + Some(Err(e)) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.len(), Some(self.len())) + } + + fn count(self) -> usize where - F: FnMut(MysqlRow) -> QueryResult, + Self: Sized, { - let mut results = Vec::new(); - while let Some(row) = self.next() { - results.push(f(row?)?); - } - Ok(results) + self.len() } +} - fn next(&mut self) -> Option> { - match self.stmt.populate_row_buffers(&mut self.output_binds) { - Ok(Some(())) => Some(Ok(MysqlRow { - col_idx: 0, - binds: &mut self.output_binds, - metadata: &self.metadata, - })), - Ok(None) => None, - Err(e) => Some(Err(e)), - } +impl<'a> ExactSizeIterator for StatementIterator<'a> { + fn len(&self) -> usize { + self.len } } #[derive(Clone)] -pub struct MysqlRow<'a> { - col_idx: usize, - binds: &'a Binds, - metadata: &'a StatementMetadata, +#[allow(missing_debug_implementations)] +pub struct MysqlRow { + row: Rc>, + metadata: Rc, } -impl<'a> Row<'a, Mysql> for MysqlRow<'a> { +enum PrivateMysqlRow { + Direct(OutputBinds), + Copied(OutputBinds), +} + +impl PrivateMysqlRow { + fn duplicate(&self) -> Self { + match self { + Self::Copied(b) | Self::Direct(b) => Self::Copied(b.clone()), + } + } +} + +impl<'a> RowGatWorkaround<'a, Mysql> for MysqlRow { type Field = MysqlField<'a>; +} + +impl<'a> Row<'a, Mysql> for MysqlRow { type InnerPartialRow = Self; fn field_count(&self) -> usize { - self.binds.len() + self.metadata.fields().len() } - fn get(&self, idx: I) -> Option + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where + 'a: 'b, Self: RowIndex, { let idx = self.idx(idx)?; Some(MysqlField { - bind: &self.binds[idx], - metadata: &self.metadata.fields()[idx], + binds: self.row.borrow(), + metadata: self.metadata.clone(), + idx, }) } @@ -81,7 +178,7 @@ impl<'a> Row<'a, Mysql> for MysqlRow<'a> { } } -impl<'a> RowIndex for MysqlRow<'a> { +impl RowIndex for MysqlRow { fn idx(&self, idx: usize) -> Option { if idx < self.field_count() { Some(idx) @@ -91,7 +188,7 @@ impl<'a> RowIndex for MysqlRow<'a> { } } -impl<'a, 'b> RowIndex<&'a str> for MysqlRow<'b> { +impl<'a> RowIndex<&'a str> for MysqlRow { fn idx(&self, idx: &'a str) -> Option { self.metadata .fields() @@ -102,21 +199,171 @@ impl<'a, 'b> RowIndex<&'a str> for MysqlRow<'b> { } } +#[allow(missing_debug_implementations)] pub struct MysqlField<'a> { - bind: &'a BindData, - metadata: &'a MysqlFieldMetadata<'a>, + binds: Ref<'a, PrivateMysqlRow>, + metadata: Rc, + idx: usize, } impl<'a> Field<'a, Mysql> for MysqlField<'a> { - fn field_name(&self) -> Option<&'a str> { - self.metadata.field_name() + fn field_name(&self) -> Option<&str> { + self.metadata.fields()[self.idx].field_name() } fn is_null(&self) -> bool { - self.bind.is_null() + match &*self.binds { + PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].is_null(), + } } - fn value(&self) -> Option> { - self.bind.value() + fn value(&self) -> Option> { + match &*self.binds { + PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].value(), + } + } +} + +#[test] +fn fun_with_row_iters() { + crate::table! { + #[allow(unused_parens)] + users(id) { + id -> Integer, + name -> Text, + } } + + use crate::deserialize::{FromSql, FromSqlRow}; + use crate::prelude::*; + use crate::row::{Field, Row}; + use crate::sql_types; + + let conn = &mut crate::test_helpers::connection(); + + crate::sql_query( + "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", + ) + .execute(conn) + .unwrap(); + crate::sql_query("DELETE FROM users;") + .execute(conn) + .unwrap(); + + crate::insert_into(users::table) + .values(vec![ + (users::id.eq(1), users::name.eq("Sean")), + (users::id.eq(2), users::name.eq("Tess")), + ]) + .execute(conn) + .unwrap(); + + let query = users::table.select((users::id, users::name)); + + let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))]; + + { + let row_iter = conn.load(&query).unwrap(); + for (row, expected) in row_iter.zip(&expected) { + let row = row.unwrap(); + + let deserialized = <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(&row) + .unwrap(); + + assert_eq!(&deserialized, expected); + } + } + + { + let collected_rows = conn.load(&query).unwrap().collect::>(); + assert_eq!(collected_rows.len(), 2); + for (row, expected) in collected_rows.iter().zip(&expected) { + let deserialized = row + .as_ref() + .map(|row| { + <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(row).unwrap() + }) + .unwrap(); + assert_eq!(&deserialized, expected); + } + } + + let mut row_iter = conn.load(&query).unwrap(); + + let first_row = row_iter.next().unwrap().unwrap(); + let first_fields = ( + Row::get(&first_row, 0).unwrap(), + Row::get(&first_row, 1).unwrap(), + ); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(first_values); + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(first_fields); + + let second_row = row_iter.next().unwrap().unwrap(); + let second_fields = ( + Row::get(&second_row, 0).unwrap(), + Row::get(&second_row, 1).unwrap(), + ); + let second_values = (second_fields.0.value(), second_fields.1.value()); + + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(second_values); + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(second_fields); + + assert!(row_iter.next().is_none()); + + let first_fields = ( + Row::get(&first_row, 0).unwrap(), + Row::get(&first_row, 1).unwrap(), + ); + let second_fields = ( + Row::get(&second_row, 0).unwrap(), + Row::get(&second_row, 1).unwrap(), + ); + + let first_values = (first_fields.0.value(), first_fields.1.value()); + let second_values = (second_fields.0.value(), second_fields.1.value()); + + assert_eq!( + >::from_nullable_sql(first_values.0).unwrap(), + expected[0].0 + ); + assert_eq!( + >::from_nullable_sql(first_values.1).unwrap(), + expected[0].1 + ); + + assert_eq!( + >::from_nullable_sql(second_values.0).unwrap(), + expected[1].0 + ); + assert_eq!( + >::from_nullable_sql(second_values.1).unwrap(), + expected[1].1 + ); + + let first_fields = ( + Row::get(&first_row, 0).unwrap(), + Row::get(&first_row, 1).unwrap(), + ); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + assert_eq!( + >::from_nullable_sql(first_values.0).unwrap(), + expected[0].0 + ); + assert_eq!( + >::from_nullable_sql(first_values.1).unwrap(), + expected[0].1 + ); } diff --git a/diesel/src/mysql/connection/stmt/mod.rs b/diesel/src/mysql/connection/stmt/mod.rs index a5f17a70792a..de774ee3db91 100644 --- a/diesel/src/mysql/connection/stmt/mod.rs +++ b/diesel/src/mysql/connection/stmt/mod.rs @@ -1,24 +1,24 @@ extern crate mysqlclient_sys as ffi; -mod iterator; +pub(super) mod iterator; mod metadata; +use std::convert::TryFrom; use std::ffi::CStr; use std::os::raw as libc; use std::ptr::NonNull; -use self::iterator::*; -use super::bind::{BindData, Binds}; +use super::bind::{OutputBinds, PreparedStatementBinds}; use crate::mysql::MysqlType; -use crate::result::{DatabaseErrorKind, QueryResult}; +use crate::result::{DatabaseErrorKind, Error, QueryResult}; pub use self::metadata::{MysqlFieldMetadata, StatementMetadata}; -#[allow(dead_code)] +#[allow(dead_code, missing_debug_implementations)] // https://github.com/rust-lang/rust/issues/81658 pub struct Statement { stmt: NonNull, - input_binds: Option, + input_binds: Option, } impl Statement { @@ -44,11 +44,14 @@ impl Statement { where Iter: IntoIterator>)>, { - let input_binds = Binds::from_input_data(binds)?; + let input_binds = PreparedStatementBinds::from_input_data(binds)?; self.input_bind(input_binds) } - pub(super) fn input_bind(&mut self, mut input_binds: Binds) -> QueryResult<()> { + pub(super) fn input_bind( + &mut self, + mut input_binds: PreparedStatementBinds, + ) -> QueryResult<()> { input_binds.with_mysql_binds(|bind_ptr| { // This relies on the invariant that the current value of `self.input_binds` // will not change without this function being called @@ -76,14 +79,11 @@ impl Statement { affected_rows as usize } - /// This function should be called instead of `execute` for queries which - /// have a return value. After calling this function, `execute` can never - /// be called on this statement. - pub unsafe fn results( - &mut self, - types: Vec>, - ) -> QueryResult { - StatementIterator::new(self, types) + /// This function should be called after `execute` only + /// otherwise it's not guranteed to return a valid result + pub(in crate::mysql::connection) unsafe fn result_size(&mut self) -> QueryResult { + let size = ffi::mysql_stmt_num_rows(self.stmt.as_ptr()); + usize::try_from(size).map_err(|e| Error::DeserializationError(Box::new(e))) } fn last_error_message(&self) -> String { @@ -153,7 +153,7 @@ impl Statement { } } - pub(super) fn execute_statement(&mut self, binds: &mut Binds) -> QueryResult<()> { + pub(super) fn execute_statement(&mut self, binds: &mut OutputBinds) -> QueryResult<()> { unsafe { binds.with_mysql_binds(|bind_ptr| self.bind_result(bind_ptr))?; self.execute()?; @@ -161,7 +161,7 @@ impl Statement { Ok(()) } - pub(super) fn populate_row_buffers(&self, binds: &mut Binds) -> QueryResult> { + pub(super) fn populate_row_buffers(&self, binds: &mut OutputBinds) -> QueryResult> { let next_row_result = unsafe { ffi::mysql_stmt_fetch(self.stmt.as_ptr()) }; match next_row_result as libc::c_uint { ffi::MYSQL_NO_DATA => Ok(None), diff --git a/diesel/src/mysql/types/mod.rs b/diesel/src/mysql/types/mod.rs index 76a646f40b30..fd7e4158adb0 100644 --- a/diesel/src/mysql/types/mod.rs +++ b/diesel/src/mysql/types/mod.rs @@ -55,7 +55,7 @@ impl FromSql for i8 { /// Represents the MySQL unsigned type. #[derive(Debug, Clone, Copy, Default, SqlType, QueryId)] -pub struct Unsigned(ST); +pub struct Unsigned(ST); impl Add for Unsigned where diff --git a/diesel/src/pg/connection/cursor.rs b/diesel/src/pg/connection/cursor.rs index 043f50f982be..c7c1329d829f 100644 --- a/diesel/src/pg/connection/cursor.rs +++ b/diesel/src/pg/connection/cursor.rs @@ -1,36 +1,39 @@ +use std::rc::Rc; + use super::result::PgResult; use super::row::PgRow; /// The type returned by various [`Connection`] methods. /// Acts as an iterator over `T`. -pub struct Cursor<'a> { +#[allow(missing_debug_implementations)] +pub struct Cursor { current_row: usize, - db_result: &'a PgResult, + db_result: Rc, } -impl<'a> Cursor<'a> { - pub(super) fn new(db_result: &'a PgResult) -> Self { +impl Cursor { + pub(super) fn new(db_result: PgResult) -> Self { Cursor { current_row: 0, - db_result, + db_result: Rc::new(db_result), } } } -impl<'a> ExactSizeIterator for Cursor<'a> { +impl ExactSizeIterator for Cursor { fn len(&self) -> usize { self.db_result.num_rows() - self.current_row } } -impl<'a> Iterator for Cursor<'a> { - type Item = PgRow<'a>; +impl Iterator for Cursor { + type Item = crate::QueryResult; fn next(&mut self) -> Option { if self.current_row < self.db_result.num_rows() { - let row = self.db_result.get_row(self.current_row); + let row = self.db_result.clone().get_row(self.current_row); self.current_row += 1; - Some(row) + Some(Ok(row)) } else { None } @@ -45,4 +48,141 @@ impl<'a> Iterator for Cursor<'a> { let len = self.len(); (len, Some(len)) } + + fn count(self) -> usize + where + Self: Sized, + { + self.len() + } +} + +#[test] +fn fun_with_row_iters() { + crate::table! { + #[allow(unused_parens)] + users(id) { + id -> Integer, + name -> Text, + } + } + + use crate::deserialize::{FromSql, FromSqlRow}; + use crate::pg::Pg; + use crate::prelude::*; + use crate::row::{Field, Row}; + use crate::sql_types; + + let conn = &mut crate::test_helpers::connection(); + + crate::sql_query( + "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", + ) + .execute(conn) + .unwrap(); + + crate::insert_into(users::table) + .values(vec![ + (users::id.eq(1), users::name.eq("Sean")), + (users::id.eq(2), users::name.eq("Tess")), + ]) + .execute(conn) + .unwrap(); + + let query = users::table.select((users::id, users::name)); + + let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))]; + + let row_iter = conn.load(&query).unwrap(); + for (row, expected) in row_iter.zip(&expected) { + let row = row.unwrap(); + + let deserialized = <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(&row) + .unwrap(); + + assert_eq!(&deserialized, expected); + } + + { + let collected_rows = conn.load(&query).unwrap().collect::>(); + + for (row, expected) in collected_rows.iter().zip(&expected) { + let deserialized = row + .as_ref() + .map(|row| { + <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(row).unwrap() + }) + .unwrap(); + + assert_eq!(&deserialized, expected); + } + } + + let mut row_iter = conn.load(&query).unwrap(); + + let first_row = row_iter.next().unwrap().unwrap(); + let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + let second_row = row_iter.next().unwrap().unwrap(); + let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap()); + let second_values = (second_fields.0.value(), second_fields.1.value()); + + assert!(row_iter.next().is_none()); + + assert_eq!( + >::from_nullable_sql(first_values.0).unwrap(), + expected[0].0 + ); + assert_eq!( + >::from_nullable_sql(first_values.1).unwrap(), + expected[0].1 + ); + + assert_eq!( + >::from_nullable_sql(second_values.0).unwrap(), + expected[1].0 + ); + assert_eq!( + >::from_nullable_sql(second_values.1).unwrap(), + expected[1].1 + ); + + let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + assert_eq!( + >::from_nullable_sql(first_values.0).unwrap(), + expected[0].0 + ); + assert_eq!( + >::from_nullable_sql(first_values.1).unwrap(), + expected[0].1 + ); + + let row_iter1 = conn.load(&query).unwrap(); + let row_iter2 = conn.load(&query).unwrap(); + + for ((row1, row2), (expected_id, expected_name)) in row_iter1.zip(row_iter2).zip(expected) { + let (id1, name1) = <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + Pg, + >>::build_from_row(&row1.unwrap()) + .unwrap(); + let (id2, name2) = <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + Pg, + >>::build_from_row(&row2.unwrap()) + .unwrap(); + assert_eq!(id1, expected_id); + assert_eq!(id2, expected_id); + assert_eq!(name1, expected_name); + assert_eq!(name2, expected_name); + } } diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index dfbf680b5e1a..36b1e35994b7 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -1,4 +1,4 @@ -mod cursor; +pub(crate) mod cursor; pub mod raw; #[doc(hidden)] pub mod result; @@ -13,15 +13,12 @@ use self::raw::RawConnection; use self::result::PgResult; use self::stmt::Statement; use crate::connection::*; -use crate::deserialize::FromSqlRow; use crate::expression::QueryMetadata; use crate::pg::metadata_lookup::{GetPgMetadataCache, PgMetadataCache}; use crate::pg::{Pg, TransactionBuilder}; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; -use crate::query_dsl::load_dsl::CompatibleType; use crate::result::ConnectionError::CouldntSetupConfiguration; -use crate::result::Error::DeserializationError; use crate::result::*; /// The connection string expected by `PgConnection::establish` @@ -46,6 +43,11 @@ impl SimpleConnection for PgConnection { } } +impl<'a> ConnectionGatWorkaround<'a, Pg> for PgConnection { + type Cursor = Cursor; + type Row = self::row::PgRow; +} + impl Connection for PgConnection { type Backend = Pg; type TransactionManager = AnsiTransactionManager; @@ -70,21 +72,17 @@ impl Connection for PgConnection { } #[doc(hidden)] - fn load(&mut self, source: T) -> QueryResult> + fn load(&mut self, source: T) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, - T::SqlType: CompatibleType, - U: FromSqlRow, Self::Backend: QueryMetadata, { self.with_prepared_query(&source.as_query(), |stmt, params, conn| { let result = stmt.execute(conn, ¶ms)?; - let cursor = Cursor::new(&result); + let cursor = Cursor::new(result); - cursor - .map(|row| U::build_from_row(&row).map_err(DeserializationError)) - .collect::>>() + Ok(cursor) }) } @@ -140,13 +138,13 @@ impl PgConnection { TransactionBuilder::new(self) } - fn with_prepared_query + QueryId, R>( - &mut self, - source: &T, + fn with_prepared_query<'a, T: QueryFragment + QueryId, R>( + &'a mut self, + source: &'_ T, f: impl FnOnce( MaybeCached, Vec>>, - &mut RawConnection, + &'a mut RawConnection, ) -> QueryResult, ) -> QueryResult { let mut bind_collector = RawBytesBindCollector::::new(); @@ -157,7 +155,7 @@ impl PgConnection { let cache_len = self.statement_cache.len(); let cache = &mut self.statement_cache; let raw_conn = &mut self.raw_connection; - let query = cache.cached_statement(source, &metadata, |sql| { + let query = cache.cached_statement(source, &metadata, |sql, _| { let query_name = if source.is_safe_to_cache_prepared()? { Some(format!("__diesel_stmt_{}", cache_len)) } else { diff --git a/diesel/src/pg/connection/result.rs b/diesel/src/pg/connection/result.rs index 34ccb2550cfc..7fdc401a33a7 100644 --- a/diesel/src/pg/connection/result.rs +++ b/diesel/src/pg/connection/result.rs @@ -4,20 +4,26 @@ use self::pq_sys::*; use std::ffi::CStr; use std::num::NonZeroU32; use std::os::raw as libc; +use std::rc::Rc; use std::{slice, str}; use super::raw::RawResult; use super::row::PgRow; use crate::result::{DatabaseErrorInformation, DatabaseErrorKind, Error, QueryResult}; +use crate::util::OnceCell; // Message after a database connection has been unexpectedly closed. const CLOSED_CONNECTION_MSG: &str = "server closed the connection unexpectedly\n\t\ This probably means the server terminated abnormally\n\tbefore or while processing the request.\n"; -pub struct PgResult { +pub(crate) struct PgResult { internal_result: RawResult, column_count: usize, row_count: usize, + // We store field names as pointer + // as we cannot put a correct lifetime here + // The value is valid as long as we haven't freed `RawResult` + column_name_map: OnceCell>>, } impl PgResult { @@ -32,6 +38,7 @@ impl PgResult { internal_result, column_count, row_count, + column_name_map: OnceCell::new(), }) } ExecStatusType::PGRES_EMPTY_QUERY => { @@ -89,7 +96,7 @@ impl PgResult { self.row_count } - pub fn get_row(&self, idx: usize) -> PgRow { + pub fn get_row(self: Rc, idx: usize) -> PgRow { PgRow::new(self, idx) } @@ -127,17 +134,33 @@ impl PgResult { } pub fn column_name(&self, col_idx: usize) -> Option<&str> { - unsafe { - let ptr = PQfname(self.internal_result.as_ptr(), col_idx as libc::c_int); - if ptr.is_null() { - None - } else { - Some(CStr::from_ptr(ptr).to_str().expect( - "Expect postgres field names to be UTF-8, because we \ + self.column_name_map + .get_or_init(|| { + (0..self.column_count) + .map(|idx| unsafe { + // https://www.postgresql.org/docs/13/libpq-exec.html#LIBPQ-PQFNAME + // states that the returned ptr is valid till the underlying result is freed + // That means we can couple the lifetime to self + let ptr = PQfname(self.internal_result.as_ptr(), idx as libc::c_int); + if ptr.is_null() { + None + } else { + Some(CStr::from_ptr(ptr).to_str().expect( + "Expect postgres field names to be UTF-8, because we \ requested UTF-8 encoding on connection setup", - )) - } - } + ) as *const str) + } + }) + .collect() + }) + .get(col_idx) + .and_then(|n| { + n.map(|n: *const str| unsafe { + // The pointer is valid for the same lifetime as &self + // so we can dereference it without any check + &*n + }) + }) } pub fn column_count(&self) -> usize { diff --git a/diesel/src/pg/connection/row.rs b/diesel/src/pg/connection/row.rs index a3d9c9d76c32..f689fa21e3fd 100644 --- a/diesel/src/pg/connection/row.rs +++ b/diesel/src/pg/connection/row.rs @@ -1,34 +1,41 @@ use super::result::PgResult; +use crate::pg::value::TypeOidLookup; use crate::pg::{Pg, PgValue}; use crate::row::*; +use std::rc::Rc; #[derive(Clone)] -pub struct PgRow<'a> { - db_result: &'a PgResult, +#[allow(missing_debug_implementations)] +pub struct PgRow { + db_result: Rc, row_idx: usize, } -impl<'a> PgRow<'a> { - pub fn new(db_result: &'a PgResult, row_idx: usize) -> Self { +impl PgRow { + pub(crate) fn new(db_result: Rc, row_idx: usize) -> Self { PgRow { db_result, row_idx } } } -impl<'a> Row<'a, Pg> for PgRow<'a> { +impl<'a> RowGatWorkaround<'a, Pg> for PgRow { type Field = PgField<'a>; +} + +impl<'a> Row<'a, Pg> for PgRow { type InnerPartialRow = Self; fn field_count(&self) -> usize { self.db_result.column_count() } - fn get(&self, idx: I) -> Option + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where + 'a: 'b, Self: RowIndex, { let idx = self.idx(idx)?; Some(PgField { - db_result: self.db_result, + db_result: &self.db_result, row_idx: self.row_idx, col_idx: idx, }) @@ -39,7 +46,7 @@ impl<'a> Row<'a, Pg> for PgRow<'a> { } } -impl<'a> RowIndex for PgRow<'a> { +impl RowIndex for PgRow { fn idx(&self, idx: usize) -> Option { if idx < self.field_count() { Some(idx) @@ -49,12 +56,13 @@ impl<'a> RowIndex for PgRow<'a> { } } -impl<'a, 'b> RowIndex<&'a str> for PgRow<'b> { +impl<'a> RowIndex<&'a str> for PgRow { fn idx(&self, field_name: &'a str) -> Option { (0..self.field_count()).find(|idx| self.db_result.column_name(*idx) == Some(field_name)) } } +#[allow(missing_debug_implementations)] pub struct PgField<'a> { db_result: &'a PgResult, row_idx: usize, @@ -62,14 +70,19 @@ pub struct PgField<'a> { } impl<'a> Field<'a, Pg> for PgField<'a> { - fn field_name(&self) -> Option<&'a str> { + fn field_name(&self) -> Option<&str> { self.db_result.column_name(self.col_idx) } - fn value(&self) -> Option> { + fn value(&self) -> Option> { let raw = self.db_result.get(self.row_idx, self.col_idx)?; - let type_oid = self.db_result.column_type(self.col_idx); - Some(PgValue::new(raw, type_oid)) + Some(PgValue::new(raw, self)) + } +} + +impl<'a> TypeOidLookup for PgField<'a> { + fn lookup(&self) -> std::num::NonZeroU32 { + self.db_result.column_type(self.col_idx) } } diff --git a/diesel/src/pg/connection/stmt/mod.rs b/diesel/src/pg/connection/stmt/mod.rs index e42382118e73..a7790ed1facd 100644 --- a/diesel/src/pg/connection/stmt/mod.rs +++ b/diesel/src/pg/connection/stmt/mod.rs @@ -10,17 +10,16 @@ use crate::result::QueryResult; pub use super::raw::RawConnection; -pub struct Statement { +pub(crate) struct Statement { name: CString, param_formats: Vec, } impl Statement { - #[allow(clippy::ptr_arg)] pub fn execute( &self, raw_connection: &mut RawConnection, - param_data: &Vec>>, + param_data: &[Option>], ) -> QueryResult { let params_pointer = param_data .iter() @@ -48,7 +47,6 @@ impl Statement { PgResult::new(internal_res?) } - #[allow(clippy::ptr_arg)] pub fn prepare( raw_connection: &mut RawConnection, sql: &str, diff --git a/diesel/src/pg/expression/array.rs b/diesel/src/pg/expression/array.rs index 8636d9b7dbe1..c135d9f01825 100644 --- a/diesel/src/pg/expression/array.rs +++ b/diesel/src/pg/expression/array.rs @@ -58,6 +58,7 @@ where impl Expression for ArrayLiteral where + ST: 'static, T: Expression, { type SqlType = sql_types::Array; diff --git a/diesel/src/pg/expression/array_comparison.rs b/diesel/src/pg/expression/array_comparison.rs index 4d3c636a2a1c..b37ad4627283 100644 --- a/diesel/src/pg/expression/array_comparison.rs +++ b/diesel/src/pg/expression/array_comparison.rs @@ -128,7 +128,7 @@ where impl_selectable_expression!(All); -pub trait AsArrayExpression { +pub trait AsArrayExpression { type Expression: Expression>; // This method is part of the public API @@ -139,6 +139,7 @@ pub trait AsArrayExpression { impl AsArrayExpression for T where + ST: 'static, T: AsExpression>, { type Expression = >>::Expression; @@ -151,6 +152,7 @@ where impl AsArrayExpression for SelectStatement where + ST: 'static, Self: SelectQuery, { type Expression = Subselect>; @@ -162,6 +164,7 @@ where impl<'a, ST, QS, DB, GB> AsArrayExpression for BoxedSelectStatement<'a, ST, QS, DB, GB> where + ST: 'static, Self: SelectQuery, { type Expression = Subselect>; diff --git a/diesel/src/pg/expression/expression_methods.rs b/diesel/src/pg/expression/expression_methods.rs index a40712aa1601..4a56a08d4185 100644 --- a/diesel/src/pg/expression/expression_methods.rs +++ b/diesel/src/pg/expression/expression_methods.rs @@ -611,7 +611,10 @@ pub trait RangeHelper: SqlType { type Inner; } -impl RangeHelper for Range { +impl RangeHelper for Range +where + Self: 'static, +{ type Inner = ST; } diff --git a/diesel/src/pg/types/array.rs b/diesel/src/pg/types/array.rs index 54098d3c7ee4..b6315eb60a1e 100644 --- a/diesel/src/pg/types/array.rs +++ b/diesel/src/pg/types/array.rs @@ -48,7 +48,7 @@ where } else { let (elem_bytes, new_bytes) = bytes.split_at(elem_size as usize); bytes = new_bytes; - T::from_sql(PgValue::new(elem_bytes, value.get_oid())) + T::from_sql(PgValue::new(elem_bytes, &value)) } }) .collect() @@ -60,7 +60,7 @@ use crate::expression::AsExpression; macro_rules! array_as_expression { ($ty:ty, $sql_type:ty) => { - impl<'a, 'b, ST, T> AsExpression<$sql_type> for $ty { + impl<'a, 'b, ST: 'static, T> AsExpression<$sql_type> for $ty { type Expression = Bound<$sql_type, Self>; fn as_expression(self) -> Self::Expression { @@ -126,6 +126,7 @@ where impl ToSql>, Pg> for [T] where [T]: ToSql, Pg>, + ST: 'static, { fn to_sql(&self, out: &mut Output) -> serialize::Result { ToSql::, Pg>::to_sql(self, out) @@ -134,6 +135,7 @@ where impl ToSql, Pg> for Vec where + ST: 'static, [T]: ToSql, Pg>, T: fmt::Debug, { @@ -144,6 +146,7 @@ where impl ToSql>, Pg> for Vec where + ST: 'static, Vec: ToSql, Pg>, { fn to_sql(&self, out: &mut Output) -> serialize::Result { diff --git a/diesel/src/pg/types/mod.rs b/diesel/src/pg/types/mod.rs index 81e843de4a15..d232b50d0015 100644 --- a/diesel/src/pg/types/mod.rs +++ b/diesel/src/pg/types/mod.rs @@ -99,7 +99,7 @@ pub mod sql_types { /// [Vec]: std::vec::Vec /// [slice]: https://doc.rust-lang.org/nightly/std/primitive.slice.html #[derive(Debug, Clone, Copy, Default, QueryId, SqlType)] - pub struct Array(ST); + pub struct Array(ST); /// The `Range` SQL type. /// @@ -117,7 +117,7 @@ pub mod sql_types { /// [`FromSql`]: crate::deserialize::FromSql /// [bound]: std::collections::Bound #[derive(Debug, Clone, Copy, Default, QueryId, SqlType)] - pub struct Range(ST); + pub struct Range(ST); #[doc(hidden)] pub type Int4range = Range; @@ -171,7 +171,7 @@ pub mod sql_types { /// [`WriteTuple`]: super::super::super::serialize::WriteTuple #[derive(Debug, Clone, Copy, Default, QueryId, SqlType)] #[postgres(oid = "2249", array_oid = "2287")] - pub struct Record(ST); + pub struct Record(ST); /// Alias for `SmallInt` pub type SmallSerial = crate::sql_types::SmallInt; diff --git a/diesel/src/pg/types/ranges.rs b/diesel/src/pg/types/ranges.rs index 02187c1a4cf3..98b443ed8d48 100644 --- a/diesel/src/pg/types/ranges.rs +++ b/diesel/src/pg/types/ranges.rs @@ -23,7 +23,7 @@ bitflags! { } } -impl AsExpression> for (Bound, Bound) { +impl AsExpression> for (Bound, Bound) { type Expression = SqlBound, Self>; fn as_expression(self) -> Self::Expression { @@ -31,7 +31,7 @@ impl AsExpression> for (Bound, Bound) { } } -impl<'a, ST, T> AsExpression> for &'a (Bound, Bound) { +impl<'a, ST: 'static, T> AsExpression> for &'a (Bound, Bound) { type Expression = SqlBound, Self>; fn as_expression(self) -> Self::Expression { @@ -39,7 +39,7 @@ impl<'a, ST, T> AsExpression> for &'a (Bound, Bound) { } } -impl AsExpression>> for (Bound, Bound) { +impl AsExpression>> for (Bound, Bound) { type Expression = SqlBound>, Self>; fn as_expression(self) -> Self::Expression { @@ -47,7 +47,7 @@ impl AsExpression>> for (Bound, Bound) { } } -impl<'a, ST, T> AsExpression>> for &'a (Bound, Bound) { +impl<'a, ST: 'static, T> AsExpression>> for &'a (Bound, Bound) { type Expression = SqlBound>, Self>; fn as_expression(self) -> Self::Expression { @@ -69,7 +69,7 @@ where let elem_size = bytes.read_i32::()?; let (elem_bytes, new_bytes) = bytes.split_at(elem_size as usize); bytes = new_bytes; - let value = T::from_sql(PgValue::new(elem_bytes, value.get_oid()))?; + let value = T::from_sql(PgValue::new(elem_bytes, &value))?; lower_bound = if flags.contains(RangeFlags::LB_INC) { Bound::Included(value) @@ -80,7 +80,7 @@ where if !flags.contains(RangeFlags::UB_INF) { let _size = bytes.read_i32::()?; - let value = T::from_sql(PgValue::new(bytes, value.get_oid()))?; + let value = T::from_sql(PgValue::new(bytes, &value))?; upper_bound = if flags.contains(RangeFlags::UB_INC) { Bound::Included(value) @@ -158,6 +158,7 @@ where impl ToSql>, Pg> for (Bound, Bound) where + ST: 'static, (Bound, Bound): ToSql, Pg>, { fn to_sql(&self, out: &mut Output) -> serialize::Result { diff --git a/diesel/src/pg/types/record.rs b/diesel/src/pg/types/record.rs index 34c923e96913..0b5e9186e7c5 100644 --- a/diesel/src/pg/types/record.rs +++ b/diesel/src/pg/types/record.rs @@ -54,7 +54,7 @@ macro_rules! tuple_impls { bytes = new_bytes; $T::from_sql(PgValue::new( elem_bytes, - oid, + &oid, ))? } },)+); @@ -147,6 +147,7 @@ where impl Expression for PgTuple where T: Expression, + T::SqlType: 'static, { type SqlType = Record; } diff --git a/diesel/src/pg/value.rs b/diesel/src/pg/value.rs index 24ae07c65392..63b42110602c 100644 --- a/diesel/src/pg/value.rs +++ b/diesel/src/pg/value.rs @@ -8,7 +8,32 @@ use std::ops::Range; #[allow(missing_debug_implementations)] pub struct PgValue<'a> { raw_value: &'a [u8], - type_oid: NonZeroU32, + type_oid_lookup: &'a dyn TypeOidLookup, +} + +pub(crate) trait TypeOidLookup { + fn lookup(&self) -> NonZeroU32; +} + +impl TypeOidLookup for F +where + F: Fn() -> NonZeroU32, +{ + fn lookup(&self) -> NonZeroU32 { + (self)() + } +} + +impl<'a> TypeOidLookup for PgValue<'a> { + fn lookup(&self) -> NonZeroU32 { + self.type_oid_lookup.lookup() + } +} + +impl TypeOidLookup for NonZeroU32 { + fn lookup(&self) -> NonZeroU32 { + *self + } } impl<'a> BinaryRawValue<'a> for Pg { @@ -20,16 +45,20 @@ impl<'a> BinaryRawValue<'a> for Pg { impl<'a> PgValue<'a> { #[cfg(test)] pub(crate) fn for_test(raw_value: &'a [u8]) -> Self { + static FAKE_OID: NonZeroU32 = unsafe { + // 42 != 0, so this is actually safe + NonZeroU32::new_unchecked(42) + }; Self { raw_value, - type_oid: NonZeroU32::new(42).unwrap(), + type_oid_lookup: &FAKE_OID, } } - pub(crate) fn new(raw_value: &'a [u8], type_oid: NonZeroU32) -> Self { + pub(crate) fn new(raw_value: &'a [u8], type_oid_lookup: &'a dyn TypeOidLookup) -> Self { Self { raw_value, - type_oid, + type_oid_lookup, } } @@ -40,7 +69,7 @@ impl<'a> PgValue<'a> { /// Get the type oid of this value pub fn get_oid(&self) -> NonZeroU32 { - self.type_oid + self.type_oid_lookup.lookup() } pub(crate) fn subslice(&self, range: Range) -> Self { diff --git a/diesel/src/query_dsl/load_dsl.rs b/diesel/src/query_dsl/load_dsl.rs index 6c4a3e9dbd30..35a1ee464ab3 100644 --- a/diesel/src/query_dsl/load_dsl.rs +++ b/diesel/src/query_dsl/load_dsl.rs @@ -1,6 +1,6 @@ use super::RunQueryDsl; use crate::backend::Backend; -use crate::connection::Connection; +use crate::connection::{Connection, ConnectionGatWorkaround}; use crate::deserialize::FromSqlRow; use crate::expression::{select_by::SelectBy, Expression, QueryMetadata, Selectable}; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; @@ -13,9 +13,19 @@ use crate::result::QueryResult; /// to call `load` from generic code. /// /// [`RunQueryDsl`]: crate::RunQueryDsl -pub trait LoadQuery: RunQueryDsl { +pub trait LoadQuery: RunQueryDsl +where + for<'a> Self: LoadQueryGatWorkaround<'a, Conn, U>, +{ /// Load this query - fn internal_load(self, conn: &mut Conn) -> QueryResult>; + fn internal_load( + self, + conn: &mut Conn, + ) -> QueryResult<>::Ret>; +} + +pub trait LoadQueryGatWorkaround<'a, Conn, U> { + type Ret: Iterator>; } use crate::expression::TypedExpressionType; @@ -53,17 +63,49 @@ where type SqlType = ST; } -impl LoadQuery for T +#[allow(missing_debug_implementations)] +pub struct LoadIter<'a, U, C, ST, DB> { + cursor: C, + _marker: std::marker::PhantomData<&'a (ST, U, DB)>, +} + +impl<'a, Conn, T, U, DB> LoadQueryGatWorkaround<'a, Conn, U> for T +where + Conn: Connection, + T: AsQuery + RunQueryDsl, + T::Query: QueryFragment + QueryId, + T::SqlType: CompatibleType, + DB: Backend + QueryMetadata + 'static, + U: FromSqlRow<>::SqlType, DB> + 'static, + >::SqlType: 'static, +{ + type Ret = LoadIter< + 'a, + U, + >::Cursor, + >::SqlType, + DB, + >; +} + +impl LoadQuery for T where - Conn: Connection, + Conn: Connection, T: AsQuery + RunQueryDsl, - T::Query: QueryFragment + QueryId, - T::SqlType: CompatibleType, - Conn::Backend: QueryMetadata, - U: FromSqlRow<>::SqlType, Conn::Backend>, + T::Query: QueryFragment + QueryId, + T::SqlType: CompatibleType, + DB: Backend + QueryMetadata + 'static, + U: FromSqlRow<>::SqlType, DB> + 'static, + >::SqlType: 'static, { - fn internal_load(self, conn: &mut Conn) -> QueryResult> { - conn.load(self) + fn internal_load( + self, + conn: &mut Conn, + ) -> QueryResult<>::Ret> { + Ok(LoadIter { + cursor: conn.load(self)?, + _marker: Default::default(), + }) } } @@ -91,3 +133,68 @@ where conn.execute_returning_count(&query) } } + +impl<'a, C, U, ST, DB, R> LoadIter<'a, U, C, ST, DB> +where + DB: Backend, + C: Iterator>, + R: crate::row::Row<'a, DB>, + U: FromSqlRow, +{ + fn map_row(row: Option>) -> Option> { + match row? { + Ok(row) => { + Some(U::build_from_row(&row).map_err(crate::result::Error::DeserializationError)) + } + Err(e) => Some(Err(e)), + } + } +} + +impl<'a, C, U, ST, DB, R> Iterator for LoadIter<'a, U, C, ST, DB> +where + DB: Backend, + C: Iterator>, + R: crate::row::Row<'a, DB>, + U: FromSqlRow, +{ + type Item = QueryResult; + + fn next(&mut self) -> Option { + Self::map_row(self.cursor.next()) + } + + fn size_hint(&self) -> (usize, Option) { + self.cursor.size_hint() + } + + fn count(self) -> usize + where + Self: Sized, + { + self.cursor.count() + } + + fn last(self) -> Option + where + Self: Sized, + { + Self::map_row(self.cursor.last()) + } + + fn nth(&mut self, n: usize) -> Option { + Self::map_row(self.cursor.nth(n)) + } +} + +impl<'a, C, U, ST, DB, R> ExactSizeIterator for LoadIter<'a, U, C, ST, DB> +where + DB: Backend, + C: ExactSizeIterator + Iterator>, + R: crate::row::Row<'a, DB>, + U: FromSqlRow, +{ + fn len(&self) -> usize { + self.cursor.len() + } +} diff --git a/diesel/src/query_dsl/mod.rs b/diesel/src/query_dsl/mod.rs index dc0ab22a4081..fd8696d6378e 100644 --- a/diesel/src/query_dsl/mod.rs +++ b/diesel/src/query_dsl/mod.rs @@ -18,7 +18,7 @@ use crate::expression::Expression; use crate::helper_types::*; use crate::query_builder::locking_clause as lock; use crate::query_source::{joins, Table}; -use crate::result::{first_or_not_found, QueryResult}; +use crate::result::QueryResult; mod belonging_to_dsl; #[doc(hidden)] @@ -1307,12 +1307,10 @@ pub trait RunQueryDsl: Sized { methods::ExecuteDsl::execute(self, conn) } - /// Executes the given query, returning a `Vec` with the returned rows. + /// Executes the given query, returning a [`Vec`] with the returned rows. /// - /// When using the query builder, - /// the return type can be - /// a tuple of the values, - /// or a struct which implements [`Queryable`]. + /// When using the query builder, the return type can be + /// a tuple of the values, or a struct which implements [`Queryable`]. /// /// When this method is called on [`sql_query`], /// the return type can only be a struct which implements [`QueryableByName`] @@ -1404,7 +1402,116 @@ pub trait RunQueryDsl: Sized { where Self: LoadQuery, { - self.internal_load(conn) + self.internal_load(conn)?.collect() + } + + /// Executes the given query, returning a [`Iterator`] with the returned rows. + /// + /// **You should normally prefer to use [`RunQueryDsl::load`] instead**. This method + /// is provided for situations where the result needs to be collected into a different + /// container than a [`Vec`] + /// + /// When using the query builder, the return type can be + /// a tuple of the values, or a struct which implements [`Queryable`]. + /// + /// When this method is called on [`sql_query`], + /// the return type can only be a struct which implements [`QueryableByName`] + /// + /// For insert, update, and delete operations where only a count of affected is needed, + /// [`execute`] should be used instead. + /// + /// [`Queryable`]: crate::deserialize::Queryable + /// [`QueryableByName`]: crate::deserialize::QueryableByName + /// [`execute`]: crate::query_dsl::RunQueryDsl::execute() + /// [`sql_query`]: crate::sql_query() + /// + /// # Examples + /// + /// ## Returning a single field + /// + /// ```rust + /// # include!("../doctest_setup.rs"); + /// # + /// # fn main() { + /// # run_test(); + /// # } + /// # + /// # fn run_test() -> QueryResult<()> { + /// # use diesel::insert_into; + /// # use schema::users::dsl::*; + /// # let connection = &mut establish_connection(); + /// let data = users.select(name) + /// .load_iter::(connection)? + /// .collect::>>()?; + /// assert_eq!(vec!["Sean", "Tess"], data); + /// # Ok(()) + /// # } + /// ``` + /// + /// ## Returning a tuple + /// + /// ```rust + /// # include!("../doctest_setup.rs"); + /// # + /// # fn main() { + /// # run_test(); + /// # } + /// # + /// # fn run_test() -> QueryResult<()> { + /// # use diesel::insert_into; + /// # use schema::users::dsl::*; + /// # let connection = &mut establish_connection(); + /// let data = users + /// .load_iter::<(i32, String)>(connection)? + /// .collect::>>()?; + /// let expected_data = vec![ + /// (1, String::from("Sean")), + /// (2, String::from("Tess")), + /// ]; + /// assert_eq!(expected_data, data); + /// # Ok(()) + /// # } + /// ``` + /// + /// ## Returning a struct + /// + /// ```rust + /// # include!("../doctest_setup.rs"); + /// # + /// #[derive(Queryable, PartialEq, Debug)] + /// struct User { + /// id: i32, + /// name: String, + /// } + /// + /// # fn main() { + /// # run_test(); + /// # } + /// # + /// # fn run_test() -> QueryResult<()> { + /// # use diesel::insert_into; + /// # use schema::users::dsl::*; + /// # let connection = &mut establish_connection(); + /// let data = users + /// .load_iter::(connection)? + /// .collect::>>()?; + /// let expected_data = vec![ + /// User { id: 1, name: String::from("Sean") }, + /// User { id: 2, name: String::from("Tess") }, + /// ]; + /// assert_eq!(expected_data, data); + /// # Ok(()) + /// # } + /// ``` + fn load_iter<'a, U>( + self, + conn: &'a mut Conn, + ) -> QueryResult> + 'a>> + where + U: 'a, + Self: LoadQuery + 'a, + { + self.internal_load(conn).map(|i| Box::new(i) as Box<_>) } /// Runs the command, and returns the affected row. @@ -1456,7 +1563,10 @@ pub trait RunQueryDsl: Sized { where Self: LoadQuery, { - first_or_not_found(self.load(conn)) + match self.internal_load(conn)?.next() { + Some(v) => v, + None => Err(crate::result::Error::NotFound), + } } /// Runs the command, returning an `Vec` with the affected rows. diff --git a/diesel/src/r2d2.rs b/diesel/src/r2d2.rs index 932e75e6112a..5d24c3b7d2a9 100644 --- a/diesel/src/r2d2.rs +++ b/diesel/src/r2d2.rs @@ -16,12 +16,11 @@ use std::convert::Into; use std::fmt; use std::marker::PhantomData; -use crate::connection::{SimpleConnection, TransactionManager}; -use crate::deserialize::FromSqlRow; +use crate::backend::Backend; +use crate::connection::{ConnectionGatWorkaround, SimpleConnection, TransactionManager}; use crate::expression::QueryMetadata; use crate::prelude::*; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; -use crate::query_dsl::load_dsl::CompatibleType; /// An r2d2 connection manager for use with Diesel. /// @@ -131,6 +130,16 @@ where } } +impl<'a, DB, M> ConnectionGatWorkaround<'a, DB> for PooledConnection +where + M: ManageConnection, + M::Connection: Connection, + DB: Backend, +{ + type Cursor = >::Cursor; + type Row = >::Row; +} + impl Connection for PooledConnection where M: ManageConnection, @@ -150,12 +159,13 @@ where (&mut **self).execute(query) } - fn load(&mut self, source: T) -> QueryResult> + fn load( + &mut self, + source: T, + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, - T::SqlType: CompatibleType, - U: FromSqlRow, Self::Backend: QueryMetadata, { (&mut **self).load(source) diff --git a/diesel/src/result.rs b/diesel/src/result.rs index 68d7c4049b6d..de8b7d0dedc7 100644 --- a/diesel/src/result.rs +++ b/diesel/src/result.rs @@ -360,10 +360,6 @@ fn error_impls_send() { let x: &Send = &err; } -pub(crate) fn first_or_not_found(records: QueryResult>) -> QueryResult { - records?.into_iter().next().ok_or(Error::NotFound) -} - /// An unexpected `NULL` was encountered during deserialization #[derive(Debug, Clone, Copy)] pub struct UnexpectedNullError; diff --git a/diesel/src/row.rs b/diesel/src/row.rs index 9d9c29c555a6..6f7635d5b0d0 100644 --- a/diesel/src/row.rs +++ b/diesel/src/row.rs @@ -21,12 +21,9 @@ pub trait RowIndex { fn idx(&self, idx: I) -> Option; } -/// Represents a single database row. -/// -/// This trait is used as an argument to [`FromSqlRow`]. -/// -/// [`FromSqlRow`]: crate::deserialize::FromSqlRow -pub trait Row<'a, DB: Backend>: RowIndex + for<'b> RowIndex<&'b str> + Sized { +/// A helper trait to indicate the life time bound for a field returned +/// by [`Row::get`] +pub trait RowGatWorkaround<'a, DB: Backend> { /// Field type returned by a `Row` implementation /// /// * Crates using existing backend should not concern themself with the @@ -35,7 +32,16 @@ pub trait Row<'a, DB: Backend>: RowIndex + for<'b> RowIndex<&'b str> + Si /// * Crates implementing custom backends should provide their own type /// meeting the required trait bounds type Field: Field<'a, DB>; +} +/// Represents a single database row. +/// +/// This trait is used as an argument to [`FromSqlRow`]. +/// +/// [`FromSqlRow`]: crate::deserialize::FromSqlRow +pub trait Row<'a, DB: Backend>: + RowIndex + for<'b> RowIndex<&'b str> + for<'b> RowGatWorkaround<'b, DB> + Sized +{ /// Return type of `PartialRow` /// /// For all implementations, beside of the `Row` implementation on `PartialRow` itself @@ -49,10 +55,21 @@ pub trait Row<'a, DB: Backend>: RowIndex + for<'b> RowIndex<&'b str> + Si /// Get the field with the provided index from the row. /// /// Returns `None` if there is no matching field for the given index - fn get(&self, idx: I) -> Option + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where + 'a: 'b, Self: RowIndex; + /// Get a deserialized value with the provided index from the row. + fn get_value(&self, idx: I) -> crate::deserialize::Result + where + Self: RowIndex, + T: FromSql, + { + let field = self.get(idx).ok_or(crate::result::UnexpectedEndOfRow)?; + >::from_nullable_sql(field.value()) + } + /// Returns a wrapping row that allows only to access fields, where the index is part of /// the provided range. #[doc(hidden)] @@ -67,11 +84,11 @@ pub trait Field<'a, DB: Backend> { /// The name of the current field /// /// Returns `None` if it's an unnamed field - fn field_name(&self) -> Option<&'a str>; + fn field_name(&self) -> Option<&str>; /// Get the value representing the current field in the raw representation /// as it is transmitted by the database - fn value(&self) -> Option>; + fn value(&self) -> Option>; /// Checks whether this field is null or not. fn is_null(&self) -> bool { @@ -109,20 +126,28 @@ impl<'a, R> PartialRow<'a, R> { } } +impl<'a, 'b, DB, R> RowGatWorkaround<'a, DB> for PartialRow<'b, R> +where + DB: Backend, + R: RowGatWorkaround<'a, DB>, +{ + type Field = R::Field; +} + impl<'a, 'b, DB, R> Row<'a, DB> for PartialRow<'b, R> where DB: Backend, R: Row<'a, DB>, { - type Field = R::Field; type InnerPartialRow = R; fn field_count(&self) -> usize { self.range.len() } - fn get(&self, idx: I) -> Option + fn get<'c, I>(&'c self, idx: I) -> Option<>::Field> where + 'a: 'c, Self: RowIndex, { let idx = self.idx(idx)?; diff --git a/diesel/src/sql_types/mod.rs b/diesel/src/sql_types/mod.rs index b929afa57e86..1868d5eb0c49 100644 --- a/diesel/src/sql_types/mod.rs +++ b/diesel/src/sql_types/mod.rs @@ -509,7 +509,7 @@ pub use diesel_derives::SqlType; /// This trait is automatically implemented by [`#[derive(SqlType)]`](derive@SqlType) /// which sets `IsNull` to [`is_nullable::NotNull`] /// -pub trait SqlType { +pub trait SqlType: 'static { /// Is this type nullable? /// /// This type should always be one of the structs in the ['is_nullable`] diff --git a/diesel/src/sqlite/backend.rs b/diesel/src/sqlite/backend.rs index 1c6af1637e82..bb7451cffa28 100644 --- a/diesel/src/sqlite/backend.rs +++ b/diesel/src/sqlite/backend.rs @@ -45,7 +45,7 @@ impl Backend for Sqlite { } impl<'a> HasRawValue<'a> for Sqlite { - type RawValue = SqliteValue<'a>; + type RawValue = SqliteValue<'a, 'a>; } impl TypeMetadata for Sqlite { diff --git a/diesel/src/sqlite/connection/functions.rs b/diesel/src/sqlite/connection/functions.rs index 79034c73a885..500fc52286c9 100644 --- a/diesel/src/sqlite/connection/functions.rs +++ b/diesel/src/sqlite/connection/functions.rs @@ -1,14 +1,21 @@ extern crate libsqlite3_sys as ffi; use super::raw::RawConnection; +use super::row::PrivateSqliteRow; use super::serialized_value::SerializedValue; -use super::{Sqlite, SqliteAggregateFunction, SqliteValue}; +use super::{Sqlite, SqliteAggregateFunction}; use crate::deserialize::{FromSqlRow, StaticallySizedRow}; use crate::result::{DatabaseErrorKind, Error, QueryResult}; -use crate::row::{Field, PartialRow, Row, RowIndex}; +use crate::row::{Field, PartialRow, Row, RowGatWorkaround, RowIndex}; use crate::serialize::{IsNull, Output, ToSql}; use crate::sql_types::HasSqlType; +use crate::sqlite::connection::sqlite_value::OwnedSqliteValue; +use crate::sqlite::SqliteValue; +use std::cell::{Ref, RefCell}; use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::ops::DerefMut; +use std::rc::Rc; pub fn register( conn: &RawConnection, @@ -85,7 +92,7 @@ where } pub(crate) fn build_sql_function_args( - args: &[*mut ffi::sqlite3_value], + args: &mut [*mut ffi::sqlite3_value], ) -> Result where Args: FromSqlRow, @@ -117,34 +124,90 @@ where }) } -#[derive(Clone)] struct FunctionRow<'a> { - args: &'a [*mut ffi::sqlite3_value], + // we use `ManuallyDrop` to prevent dropping the content of the internal vector + // as this buffer is owned by sqlite not by diesel + args: Rc>>>, + field_count: usize, + marker: PhantomData<&'a ffi::sqlite3_value>, +} + +impl<'a> Drop for FunctionRow<'a> { + fn drop(&mut self) { + if let Some(args) = Rc::get_mut(&mut self.args) { + if let PrivateSqliteRow::Duplicated { column_names, .. } = + DerefMut::deref_mut(RefCell::get_mut(args)) + { + if let Some(inner) = Rc::get_mut(column_names) { + // According the https://doc.rust-lang.org/std/mem/struct.ManuallyDrop.html#method.drop + // it's fine to just drop the values here + unsafe { std::ptr::drop_in_place(inner as *mut _) } + } + } + } + } } impl<'a> FunctionRow<'a> { - fn new(args: &'a [*mut ffi::sqlite3_value]) -> Self { - Self { args } + fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self { + let lenghts = args.len(); + let args = unsafe { + Vec::from_raw_parts( + // This cast is safe because: + // * Casting from a pointer to an array to a pointer to the first array + // element is safe + // * Casting from a raw pointer to `NonNull` is safe, + // because `NonNull` is #[repr(transparent)] + // * Casting from `NonNull` to `OwnedSqliteValue` is safe, + // as the struct is `#[repr(transparent)] + // * Casting from `NonNull` to `Option>` as the documentation + // states: "This is so that enums may use this forbidden value as a discriminant – + // Option> has the same size as *mut T" + // * The last point remains true for `OwnedSqliteValue` as `#[repr(transparent)] + // guarantees the same layout as the inner type + // * It's unsafe to drop the vector (and the vector elements) + // because of this we wrap the vector (or better the Row) + // Into `ManualDrop` to prevent the dropping + args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value + as *mut Option, + lenghts, + lenghts, + ) + }; + + Self { + field_count: lenghts, + args: Rc::new(RefCell::new(ManuallyDrop::new( + PrivateSqliteRow::Duplicated { + values: args, + column_names: Rc::from(vec![None; lenghts]), + }, + ))), + marker: PhantomData, + } } } -impl<'a> Row<'a, Sqlite> for FunctionRow<'a> { +impl<'a, 'b> RowGatWorkaround<'a, Sqlite> for FunctionRow<'b> { type Field = FunctionArgument<'a>; +} + +impl<'a> Row<'a, Sqlite> for FunctionRow<'a> { type InnerPartialRow = Self; fn field_count(&self) -> usize { - self.args.len() + self.field_count } - fn get(&self, idx: I) -> Option + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where + 'a: 'b, Self: crate::row::RowIndex, { let idx = self.idx(idx)?; - - self.args.get(idx).map(|arg| FunctionArgument { - arg: *arg, - p: PhantomData, + Some(FunctionArgument { + args: self.args.borrow(), + col_idx: idx as i32, }) } @@ -155,7 +218,7 @@ impl<'a> Row<'a, Sqlite> for FunctionRow<'a> { impl<'a> RowIndex for FunctionRow<'a> { fn idx(&self, idx: usize) -> Option { - if idx < self.args.len() { + if idx < self.field_count() { Some(idx) } else { None @@ -170,12 +233,12 @@ impl<'a, 'b> RowIndex<&'a str> for FunctionRow<'b> { } struct FunctionArgument<'a> { - arg: *mut ffi::sqlite3_value, - p: PhantomData<&'a ()>, + args: Ref<'a, ManuallyDrop>>, + col_idx: i32, } impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> { - fn field_name(&self) -> Option<&'a str> { + fn field_name(&self) -> Option<&str> { None } @@ -183,7 +246,10 @@ impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> { self.value().is_none() } - fn value(&self) -> Option> { - unsafe { SqliteValue::new(self.arg) } + fn value(&self) -> Option> { + SqliteValue::new( + Ref::map(Ref::clone(&self.args), |drop| std::ops::Deref::deref(drop)), + self.col_idx, + ) } } diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 6b7d9ac9e30a..195de331eeec 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -3,6 +3,7 @@ extern crate libsqlite3_sys as ffi; mod functions; #[doc(hidden)] pub mod raw; +mod row; mod serialized_value; mod sqlite_value; mod statement_iterator; @@ -34,6 +35,9 @@ use crate::sqlite::Sqlite; /// - Special identifiers (`:memory:`) #[allow(missing_debug_implementations)] pub struct SqliteConnection { + // statement_cache needs to be before raw_connection + // otherwise we will get errors about open statements before closing the + // connection itself statement_cache: StatementCache, raw_connection: RawConnection, transaction_state: AnsiTransactionManager, @@ -50,6 +54,11 @@ impl SimpleConnection for SqliteConnection { } } +impl<'a> ConnectionGatWorkaround<'a, Sqlite> for SqliteConnection { + type Cursor = StatementIterator<'a>; + type Row = self::row::SqliteRow<'a>; +} + impl Connection for SqliteConnection { type Backend = Sqlite; type TransactionManager = AnsiTransactionManager; @@ -81,18 +90,19 @@ impl Connection for SqliteConnection { } #[doc(hidden)] - fn load(&mut self, source: T) -> QueryResult> + fn load( + &mut self, + source: T, + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, - T::SqlType: crate::query_dsl::load_dsl::CompatibleType, - U: FromSqlRow, Self::Backend: QueryMetadata, { - let mut statement = self.prepare_query(&source.as_query())?; - let statement_use = StatementUse::new(&mut statement, true); - let iter = StatementIterator::<_, U>::new(statement_use); - iter.collect() + let stmt = self.prepared_query(&source.as_query())?; + + let statement_use = StatementUse::new(stmt); + Ok(StatementIterator::new(statement_use)) } #[doc(hidden)] @@ -100,11 +110,11 @@ impl Connection for SqliteConnection { where T: QueryFragment + QueryId, { - { - let mut statement = self.prepare_query(source)?; - let mut statement_use = StatementUse::new(&mut statement, false); - statement_use.run()?; - } + let stmt = self.prepared_query(source)?; + + let statement_use = StatementUse::new(stmt); + statement_use.run()?; + Ok(self.raw_connection.rows_affected_by_last_query()) } @@ -193,11 +203,15 @@ impl SqliteConnection { } } - fn prepare_query + QueryId>( - &mut self, - source: &T, - ) -> QueryResult> { - let mut statement = self.cached_prepared_statement(source)?; + fn prepared_query<'a, T: QueryFragment + QueryId>( + &'a mut self, + source: &'_ T, + ) -> QueryResult> { + let raw_connection = &self.raw_connection; + let cache = &mut self.statement_cache; + let mut statement = cache.cached_statement(source, &[], |sql, is_cached| { + Statement::prepare(raw_connection, sql, is_cached) + })?; let mut bind_collector = RawBytesBindCollector::::new(); source.collect_binds(&mut bind_collector, &mut ())?; @@ -210,15 +224,6 @@ impl SqliteConnection { Ok(statement) } - fn cached_prepared_statement + QueryId>( - &mut self, - source: &T, - ) -> QueryResult> { - let raw_connection = &self.raw_connection; - let cache = &mut self.statement_cache; - cache.cached_statement(source, &[], |sql| Statement::prepare(raw_connection, sql)) - } - #[doc(hidden)] pub fn register_sql_function( &mut self, diff --git a/diesel/src/sqlite/connection/raw.rs b/diesel/src/sqlite/connection/raw.rs index 23f6186f33ba..b65ca1511881 100644 --- a/diesel/src/sqlite/connection/raw.rs +++ b/diesel/src/sqlite/connection/raw.rs @@ -91,7 +91,7 @@ impl RawConnection { f: F, ) -> QueryResult<()> where - F: FnMut(&Self, &[*mut ffi::sqlite3_value]) -> QueryResult + F: FnMut(&Self, &mut [*mut ffi::sqlite3_value]) -> QueryResult + std::panic::UnwindSafe + Send + 'static, @@ -269,7 +269,7 @@ extern "C" fn run_custom_function( num_args: libc::c_int, value_ptr: *mut *mut ffi::sqlite3_value, ) where - F: FnMut(&RawConnection, &[*mut ffi::sqlite3_value]) -> QueryResult + F: FnMut(&RawConnection, &mut [*mut ffi::sqlite3_value]) -> QueryResult + std::panic::UnwindSafe + Send + 'static, @@ -278,7 +278,6 @@ extern "C" fn run_custom_function( static NULL_DATA_ERR: &str = "An unknown error occurred. sqlite3_user_data returned a null pointer. This should never happen."; static NULL_CONN_ERR: &str = "An unknown error occurred. sqlite3_context_db_handle returned a null pointer. This should never happen."; - let args = unsafe { slice::from_raw_parts(value_ptr, num_args as _) }; let conn = match unsafe { NonNull::new(ffi::sqlite3_context_db_handle(ctx)) } { // We use `ManuallyDrop` here because we do not want to run the // Drop impl of `RawConnection` as this would close the connection @@ -306,13 +305,16 @@ extern "C" fn run_custom_function( // this is sound as `F` itself and the stored string is `UnwindSafe` let callback = std::panic::AssertUnwindSafe(&mut data_ptr.callback); - let result = - std::panic::catch_unwind(move || Ok((callback.0)(&*conn, args)?)).unwrap_or_else(|p| { - Err(SqliteCallbackError::Panic( - p, - data_ptr.function_name.clone(), - )) - }); + let result = std::panic::catch_unwind(move || { + let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) }; + Ok((callback.0)(&*conn, args)?) + }) + .unwrap_or_else(|p| { + Err(SqliteCallbackError::Panic( + p, + data_ptr.function_name.clone(), + )) + }); match result { Ok(value) => value.result_of(ctx), Err(e) => { @@ -342,15 +344,16 @@ extern "C" fn run_aggregator_step_function, Sqlite: HasSqlType, { - let args = unsafe { slice::from_raw_parts(value_ptr, num_args as _) }; - let result = - std::panic::catch_unwind(move || run_aggregator_step::(ctx, args)) - .unwrap_or_else(|e| { - Err(SqliteCallbackError::Panic( - e, - format!("{}::step() paniced", std::any::type_name::()), - )) - }); + let result = std::panic::catch_unwind(move || { + let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) }; + run_aggregator_step::(ctx, args) + }) + .unwrap_or_else(|e| { + Err(SqliteCallbackError::Panic( + e, + format!("{}::step() paniced", std::any::type_name::()), + )) + }); match result { Ok(()) => {} @@ -360,7 +363,7 @@ extern "C" fn run_aggregator_step_function( ctx: *mut ffi::sqlite3_context, - args: &[*mut ffi::sqlite3_value], + args: &mut [*mut ffi::sqlite3_value], ) -> Result<(), SqliteCallbackError> where A: SqliteAggregateFunction, diff --git a/diesel/src/sqlite/connection/row.rs b/diesel/src/sqlite/connection/row.rs new file mode 100644 index 000000000000..52f2928e5d47 --- /dev/null +++ b/diesel/src/sqlite/connection/row.rs @@ -0,0 +1,296 @@ +use std::cell::{Ref, RefCell}; +use std::convert::TryFrom; +use std::rc::Rc; + +use super::sqlite_value::{OwnedSqliteValue, SqliteValue}; +use super::stmt::StatementUse; +use crate::row::{Field, PartialRow, Row, RowGatWorkaround, RowIndex}; +use crate::sqlite::Sqlite; +use crate::util::OnceCell; + +#[allow(missing_debug_implementations)] +pub struct SqliteRow<'a> { + pub(super) inner: Rc>>, + pub(super) field_count: usize, +} + +pub(super) enum PrivateSqliteRow<'a> { + Direct(StatementUse<'a>), + Duplicated { + values: Vec>, + column_names: Rc<[Option]>, + }, + TemporaryEmpty, +} + +impl<'a> PrivateSqliteRow<'a> { + pub(super) fn duplicate(&mut self, column_names: &mut Option]>>) -> Self { + match self { + PrivateSqliteRow::Direct(stmt) => { + let column_names = if let Some(column_names) = column_names { + column_names.clone() + } else { + let c: Rc<[Option]> = Rc::from( + (0..stmt.column_count()) + .map(|idx| stmt.field_name(idx).map(|s| s.to_owned())) + .collect::>(), + ); + *column_names = Some(c.clone()); + c + }; + PrivateSqliteRow::Duplicated { + values: (0..stmt.column_count()) + .map(|idx| stmt.copy_value(idx)) + .collect(), + column_names, + } + } + PrivateSqliteRow::Duplicated { + values, + column_names, + } => PrivateSqliteRow::Duplicated { + values: values + .iter() + .map(|v| v.as_ref().map(|v| v.duplicate())) + .collect(), + column_names: column_names.clone(), + }, + PrivateSqliteRow::TemporaryEmpty => PrivateSqliteRow::TemporaryEmpty, + } + } +} + +impl<'a, 'b> RowGatWorkaround<'a, Sqlite> for SqliteRow<'b> { + type Field = SqliteField<'a>; +} + +impl<'a> Row<'a, Sqlite> for SqliteRow<'a> { + type InnerPartialRow = Self; + + fn field_count(&self) -> usize { + self.field_count + } + + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> + where + 'a: 'b, + Self: RowIndex, + { + let idx = self.idx(idx)?; + Some(SqliteField { + row: self.inner.borrow(), + col_idx: i32::try_from(idx).ok()?, + field_name: OnceCell::new(), + }) + } + + fn partial_row(&self, range: std::ops::Range) -> PartialRow { + PartialRow::new(self, range) + } +} + +impl<'a> RowIndex for SqliteRow<'a> { + fn idx(&self, idx: usize) -> Option { + if idx < self.field_count { + Some(idx) + } else { + None + } + } +} + +impl<'a, 'd> RowIndex<&'d str> for SqliteRow<'a> { + fn idx(&self, field_name: &'d str) -> Option { + match &mut *self.inner.borrow_mut() { + PrivateSqliteRow::Direct(stmt) => stmt.index_for_column_name(field_name), + PrivateSqliteRow::Duplicated { column_names, .. } => column_names + .iter() + .position(|n| n.as_ref().map(|s| s as &str) == Some(field_name)), + PrivateSqliteRow::TemporaryEmpty => { + // This cannot happen as this is only a temproray state + // used inside of `StatementIterator::next()` + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + } + } + } +} + +#[allow(missing_debug_implementations)] +pub struct SqliteField<'a> { + pub(super) row: Ref<'a, PrivateSqliteRow<'a>>, + pub(super) col_idx: i32, + field_name: OnceCell>, +} + +impl<'a> Field<'a, Sqlite> for SqliteField<'a> { + fn field_name(&self) -> Option<&str> { + self.field_name + .get_or_init(|| match &*self.row { + PrivateSqliteRow::Direct(stmt) => { + let column_name = unsafe { + // This is safe due to the fact that we + // move the column name to an allocation + // on rust side as soon as possible + // + // We cannot index into a non existing column here, because + // we checked that before even constructing the corresponding + // field + let column_name = stmt.column_name(self.col_idx); + (&*column_name).to_owned() + }; + Some(column_name) + } + PrivateSqliteRow::Duplicated { column_names, .. } => column_names + .get(self.col_idx as usize) + .and_then(|n| n.clone()), + PrivateSqliteRow::TemporaryEmpty => { + // This cannot happen as this is only a temproray state + // used inside of `StatementIterator::next()` + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + } + }) + .as_ref() + .map(|s| s as &str) + } + + fn is_null(&self) -> bool { + self.value().is_none() + } + + fn value(&self) -> Option> { + SqliteValue::new(Ref::clone(&self.row), self.col_idx) + } +} + +#[test] +fn fun_with_row_iters() { + crate::table! { + #[allow(unused_parens)] + users(id) { + id -> Integer, + name -> Text, + } + } + + use crate::deserialize::{FromSql, FromSqlRow}; + use crate::prelude::*; + use crate::row::{Field, Row}; + use crate::sql_types; + + let conn = &mut crate::test_helpers::connection(); + + crate::sql_query("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);") + .execute(conn) + .unwrap(); + + crate::insert_into(users::table) + .values(vec![ + (users::id.eq(1), users::name.eq("Sean")), + (users::id.eq(2), users::name.eq("Tess")), + ]) + .execute(conn) + .unwrap(); + + let query = users::table.select((users::id, users::name)); + + let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))]; + + let row_iter = conn.load(&query).unwrap(); + for (row, expected) in row_iter.zip(&expected) { + let row = row.unwrap(); + + let deserialized = <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(&row) + .unwrap(); + + assert_eq!(&deserialized, expected); + } + + { + let collected_rows = conn.load(&query).unwrap().collect::>(); + + for (row, expected) in collected_rows.iter().zip(&expected) { + let deserialized = row + .as_ref() + .map(|row| { + <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(row).unwrap() + }) + .unwrap(); + + assert_eq!(&deserialized, expected); + } + } + + let mut row_iter = conn.load(&query).unwrap(); + + let first_row = row_iter.next().unwrap().unwrap(); + let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(first_values); + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(first_fields); + + let second_row = row_iter.next().unwrap().unwrap(); + let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap()); + let second_values = (second_fields.0.value(), second_fields.1.value()); + + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(second_values); + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(second_fields); + + assert!(row_iter.next().is_none()); + + let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); + let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap()); + + let first_values = (first_fields.0.value(), first_fields.1.value()); + let second_values = (second_fields.0.value(), second_fields.1.value()); + + assert_eq!( + >::from_nullable_sql(first_values.0).unwrap(), + expected[0].0 + ); + assert_eq!( + >::from_nullable_sql(first_values.1).unwrap(), + expected[0].1 + ); + + assert_eq!( + >::from_nullable_sql(second_values.0).unwrap(), + expected[1].0 + ); + assert_eq!( + >::from_nullable_sql(second_values.1).unwrap(), + expected[1].1 + ); + + let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + assert_eq!( + >::from_nullable_sql(first_values.0).unwrap(), + expected[0].0 + ); + assert_eq!( + >::from_nullable_sql(first_values.1).unwrap(), + expected[0].1 + ); +} diff --git a/diesel/src/sqlite/connection/sqlite_value.rs b/diesel/src/sqlite/connection/sqlite_value.rs index 8136d9087acd..67dc4d4eff75 100644 --- a/diesel/src/sqlite/connection/sqlite_value.rs +++ b/diesel/src/sqlite/connection/sqlite_value.rs @@ -1,55 +1,85 @@ extern crate libsqlite3_sys as ffi; -use std::marker::PhantomData; +use std::cell::Ref; use std::ptr::NonNull; use std::{slice, str}; -use crate::row::*; -use crate::sqlite::{Sqlite, SqliteType}; +use crate::sqlite::SqliteType; -use super::stmt::StatementUse; +use super::row::PrivateSqliteRow; /// Raw sqlite value as received from the database /// /// Use existing `FromSql` implementations to convert this into -/// rust values: +/// rust values #[allow(missing_debug_implementations, missing_copy_implementations)] -pub struct SqliteValue<'a> { +pub struct SqliteValue<'a, 'b> { + // This field exists to ensure that nobody + // can modify the underlying row while we are + // holding a reference to some row value here + _row: Ref<'a, PrivateSqliteRow<'b>>, + // we extract the raw value pointer as part of the constructor + // to safe the match statements for each method + // Acconding to benchmarks this leads to a ~20-30% speedup + // + // This is sound as long as nobody calls `stmt.step()` + // while holding this value. We ensure this by including + // a reference to the row above. value: NonNull, - p: PhantomData<&'a ()>, } -pub struct SqliteRow<'a: 'b, 'b: 'c, 'c> { - stmt: &'c StatementUse<'a, 'b>, +#[repr(transparent)] +pub struct OwnedSqliteValue { + pub(super) value: NonNull, } -impl<'a> SqliteValue<'a> { - pub(crate) unsafe fn new(inner: *mut ffi::sqlite3_value) -> Option { - NonNull::new(inner) - .map(|value| SqliteValue { - value, - p: PhantomData, - }) - .and_then(|value| { - // We check here that the actual value represented by the inner - // `sqlite3_value` is not `NULL` (is sql meaning, not ptr meaning) - if value.is_null() { - None - } else { - Some(value) - } - }) +impl Drop for OwnedSqliteValue { + fn drop(&mut self) { + unsafe { ffi::sqlite3_value_free(self.value.as_ptr()) } } +} - pub(crate) fn read_text(&self) -> &str { - unsafe { +impl<'a, 'b> SqliteValue<'a, 'b> { + pub(super) fn new(row: Ref<'a, PrivateSqliteRow<'b>>, col_idx: i32) -> Option { + let value = match &*row { + PrivateSqliteRow::Direct(stmt) => stmt.column_value(col_idx)?, + PrivateSqliteRow::Duplicated { values, .. } => { + values.get(col_idx as usize).and_then(|v| v.as_ref())?.value + } + PrivateSqliteRow::TemporaryEmpty => { + // This cannot happen as this is only a temproray state + // used inside of `StatementIterator::next()` + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + } + }; + + let ret = Self { _row: row, value }; + if ret.value_type().is_none() { + None + } else { + Some(ret) + } + } + + pub(crate) fn parse_string<'c, R>(&'c self, f: impl FnOnce(&'c str) -> R) -> R { + let s = unsafe { let ptr = ffi::sqlite3_value_text(self.value.as_ptr()); let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); let bytes = slice::from_raw_parts(ptr as *const u8, len as usize); // The string is guaranteed to be utf8 according to // https://www.sqlite.org/c3ref/value_blob.html str::from_utf8_unchecked(bytes) - } + }; + f(s) + } + + pub(crate) fn read_text(&self) -> &str { + self.parse_string(|s| s) } pub(crate) fn read_blob(&self) -> &[u8] { @@ -61,15 +91,15 @@ impl<'a> SqliteValue<'a> { } pub(crate) fn read_integer(&self) -> i32 { - unsafe { ffi::sqlite3_value_int(self.value.as_ptr()) as i32 } + unsafe { ffi::sqlite3_value_int(self.value.as_ptr()) } } pub(crate) fn read_long(&self) -> i64 { - unsafe { ffi::sqlite3_value_int64(self.value.as_ptr()) as i64 } + unsafe { ffi::sqlite3_value_int64(self.value.as_ptr()) } } pub(crate) fn read_double(&self) -> f64 { - unsafe { ffi::sqlite3_value_double(self.value.as_ptr()) as f64 } + unsafe { ffi::sqlite3_value_double(self.value.as_ptr()) } } /// Get the type of the value as returned by sqlite @@ -81,78 +111,35 @@ impl<'a> SqliteValue<'a> { ffi::SQLITE_FLOAT => Some(SqliteType::Double), ffi::SQLITE_BLOB => Some(SqliteType::Binary), ffi::SQLITE_NULL => None, - _ => unreachable!("Sqlite docs saying this is not reachable"), + _ => unreachable!( + "Sqlite's documentation state that this case ({}) is not reachable. \ + If you ever see this error message please open an issue at \ + https://github.com/diesel-rs/diesel." + ), } } - - pub(crate) fn is_null(&self) -> bool { - self.value_type().is_none() - } } -impl<'a: 'b, 'b: 'c, 'c> SqliteRow<'a, 'b, 'c> { - pub(crate) fn new(inner_statement: &'c StatementUse<'a, 'b>) -> Self { - SqliteRow { - stmt: inner_statement, +impl OwnedSqliteValue { + pub(super) fn copy_from_ptr(ptr: NonNull) -> Option { + let tpe = unsafe { ffi::sqlite3_value_type(ptr.as_ptr()) }; + if ffi::SQLITE_NULL == tpe { + return None; } - } -} - -impl<'a: 'b, 'b: 'c, 'c> Row<'c, Sqlite> for SqliteRow<'a, 'b, 'c> { - type Field = SqliteField<'a, 'b, 'c>; - type InnerPartialRow = Self; - - fn field_count(&self) -> usize { - self.stmt.column_count() as usize - } - - fn get(&self, idx: I) -> Option - where - Self: RowIndex, - { - let idx = self.idx(idx)?; - Some(SqliteField { - stmt: &self.stmt, - col_idx: idx as i32, + let value = unsafe { ffi::sqlite3_value_dup(ptr.as_ptr()) }; + Some(Self { + value: NonNull::new(value)?, }) } - fn partial_row(&self, range: std::ops::Range) -> PartialRow { - PartialRow::new(self, range) - } -} - -impl<'a: 'b, 'b: 'c, 'c> RowIndex for SqliteRow<'a, 'b, 'c> { - fn idx(&self, idx: usize) -> Option { - if idx < self.stmt.column_count() as usize { - Some(idx) - } else { - None - } - } -} - -impl<'a: 'b, 'b: 'c, 'c, 'd> RowIndex<&'d str> for SqliteRow<'a, 'b, 'c> { - fn idx(&self, field_name: &'d str) -> Option { - self.stmt.index_for_column_name(field_name) - } -} - -pub struct SqliteField<'a: 'b, 'b: 'c, 'c> { - stmt: &'c StatementUse<'a, 'b>, - col_idx: i32, -} - -impl<'a: 'b, 'b: 'c, 'c> Field<'c, Sqlite> for SqliteField<'a, 'b, 'c> { - fn field_name(&self) -> Option<&'c str> { - self.stmt.field_name(self.col_idx) - } - - fn is_null(&self) -> bool { - self.value().is_none() - } - - fn value(&self) -> Option> { - self.stmt.value(self.col_idx) + pub(super) fn duplicate(&self) -> OwnedSqliteValue { + // self.value is a `NonNull` ptr so this cannot be null + let value = unsafe { ffi::sqlite3_value_dup(self.value.as_ptr()) }; + let value = NonNull::new(value).expect( + "Sqlite documentation states this returns only null if value is null \ + or OOM. If you ever see this panic message please open an issue at \ + https://github.com/diesel-rs/diesel.", + ); + OwnedSqliteValue { value } } } diff --git a/diesel/src/sqlite/connection/statement_iterator.rs b/diesel/src/sqlite/connection/statement_iterator.rs index 1330d120b408..ae51cbe52a64 100644 --- a/diesel/src/sqlite/connection/statement_iterator.rs +++ b/diesel/src/sqlite/connection/statement_iterator.rs @@ -1,36 +1,138 @@ -use std::marker::PhantomData; +use std::cell::RefCell; +use std::rc::Rc; +use super::row::{PrivateSqliteRow, SqliteRow}; use super::stmt::StatementUse; -use crate::deserialize::FromSqlRow; -use crate::result::Error::DeserializationError; use crate::result::QueryResult; -use crate::sqlite::Sqlite; -pub struct StatementIterator<'a: 'b, 'b, ST, T> { - stmt: StatementUse<'a, 'b>, - _marker: PhantomData<(ST, T)>, +#[allow(missing_debug_implementations)] +pub struct StatementIterator<'a> { + inner: PrivateStatementIterator<'a>, + column_names: Option]>>, + field_count: usize, } -impl<'a: 'b, 'b, ST, T> StatementIterator<'a, 'b, ST, T> { - pub fn new(stmt: StatementUse<'a, 'b>) -> Self { - StatementIterator { - stmt, - _marker: PhantomData, +enum PrivateStatementIterator<'a> { + NotStarted(StatementUse<'a>), + Started(Rc>>), + TemporaryEmpty, +} + +impl<'a> StatementIterator<'a> { + pub fn new(stmt: StatementUse<'a>) -> Self { + Self { + inner: PrivateStatementIterator::NotStarted(stmt), + column_names: None, + field_count: 0, } } } -impl<'a: 'b, 'b, ST, T> Iterator for StatementIterator<'a, 'b, ST, T> -where - T: FromSqlRow, -{ - type Item = QueryResult; +impl<'a> Iterator for StatementIterator<'a> { + type Item = QueryResult>; fn next(&mut self) -> Option { - let row = match self.stmt.step() { - Ok(row) => row, - Err(e) => return Some(Err(e)), - }; - row.map(|row| T::build_from_row(&row).map_err(DeserializationError)) + use PrivateStatementIterator::{NotStarted, Started, TemporaryEmpty}; + + match std::mem::replace(&mut self.inner, TemporaryEmpty) { + NotStarted(stmt) => match stmt.step() { + Err(e) => Some(Err(e)), + Ok(None) => None, + Ok(Some(stmt)) => { + let field_count = stmt.column_count() as usize; + self.field_count = field_count; + let inner = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); + self.inner = Started(inner.clone()); + Some(Ok(SqliteRow { inner, field_count })) + } + }, + Started(mut last_row) => { + // There was already at least one iteration step + // We check here if the caller already released the row value or not + // by checking if our Rc owns the data or not + if let Some(last_row_ref) = Rc::get_mut(&mut last_row) { + // We own the statement, there is no other reference here. + // This means we don't need to copy out values from the sqlite provided + // datastructures for now + // We don't need to use the runtime borrowing system of the RefCell here + // as we have a mutable reference, so all of this below is checked at compile time + if let PrivateSqliteRow::Direct(stmt) = + std::mem::replace(last_row_ref.get_mut(), PrivateSqliteRow::TemporaryEmpty) + { + match stmt.step() { + Err(e) => Some(Err(e)), + Ok(None) => None, + Ok(Some(stmt)) => { + let field_count = self.field_count; + (*last_row_ref.get_mut()) = PrivateSqliteRow::Direct(stmt); + self.inner = Started(last_row.clone()); + Some(Ok(SqliteRow { + inner: last_row, + field_count, + })) + } + } + } else { + // any other state than `PrivateSqliteRow::Direct` is invalid here + // and should not happen. If this ever happens this is a logic error + // in the code above + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + } + } else { + // We don't own the statement. There is another existing reference, likly because + // a user stored the row in some long time container before calling next another time + // In this case we copy out the current values into a temporary store and advance + // the statement iterator internally afterwards + let last_row = { + let mut last_row = match last_row.try_borrow_mut() { + Ok(o) => o, + Err(_e) => { + self.inner = Started(last_row.clone()); + return Some(Err(crate::result::Error::DeserializationError( + "Failed to reborrow row. Try to release any `SqliteField` or `SqliteValue` \ + that exists at this point" + .into(), + ))); + } + }; + let last_row = &mut *last_row; + let duplicated = last_row.duplicate(&mut self.column_names); + std::mem::replace(last_row, duplicated) + }; + if let PrivateSqliteRow::Direct(stmt) = last_row { + match stmt.step() { + Err(e) => Some(Err(e)), + Ok(None) => None, + Ok(Some(stmt)) => { + let field_count = self.field_count; + let last_row = + Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); + self.inner = Started(last_row.clone()); + Some(Ok(SqliteRow { + inner: last_row, + field_count, + })) + } + } + } else { + // any other state than `PrivateSqliteRow::Direct` is invalid here + // and should not happen. If this ever happens this is a logic error + // in the code above + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + } + } + } + TemporaryEmpty => None, + } } } diff --git a/diesel/src/sqlite/connection/stmt.rs b/diesel/src/sqlite/connection/stmt.rs index a8625e04892b..b3d284d39894 100644 --- a/diesel/src/sqlite/connection/stmt.rs +++ b/diesel/src/sqlite/connection/stmt.rs @@ -1,17 +1,17 @@ extern crate libsqlite3_sys as ffi; -use std::ffi::{CStr, CString}; -use std::io::{stderr, Write}; -use std::os::raw as libc; -use std::ptr::{self, NonNull}; - use super::raw::RawConnection; use super::serialized_value::SerializedValue; -use super::sqlite_value::SqliteRow; -use super::SqliteValue; +use super::sqlite_value::OwnedSqliteValue; +use crate::connection::{MaybeCached, PrepareForCache}; use crate::result::Error::DatabaseError; use crate::result::*; use crate::sqlite::SqliteType; +use crate::util::OnceCell; +use std::ffi::{CStr, CString}; +use std::io::{stderr, Write}; +use std::os::raw as libc; +use std::ptr::{self, NonNull}; pub struct Statement { inner_statement: NonNull, @@ -19,14 +19,23 @@ pub struct Statement { } impl Statement { - pub fn prepare(raw_connection: &RawConnection, sql: &str) -> QueryResult { + pub fn prepare( + raw_connection: &RawConnection, + sql: &str, + is_cached: PrepareForCache, + ) -> QueryResult { let mut stmt = ptr::null_mut(); let mut unused_portion = ptr::null(); let prepare_result = unsafe { - ffi::sqlite3_prepare_v2( + ffi::sqlite3_prepare_v3( raw_connection.internal_connection.as_ptr(), CString::new(sql)?.as_ptr(), sql.len() as libc::c_int, + if matches!(is_cached, PrepareForCache::Yes) { + ffi::SQLITE_PREPARE_PERSISTENT as u32 + } else { + 0 + }, &mut stmt, &mut unused_portion, ) @@ -117,36 +126,25 @@ impl Drop for Statement { } } -pub struct StatementUse<'a: 'b, 'b> { - statement: &'a mut Statement, - column_names: Vec<&'b str>, - should_init_column_names: bool, +#[allow(missing_debug_implementations)] +pub struct StatementUse<'a> { + statement: MaybeCached<'a, Statement>, + column_names: OnceCell>, } -impl<'a, 'b> StatementUse<'a, 'b> { - pub(in crate::sqlite::connection) fn new( - statement: &'a mut Statement, - should_init_column_names: bool, - ) -> Self { +impl<'a> StatementUse<'a> { + pub(in crate::sqlite::connection) fn new(statement: MaybeCached<'a, Statement>) -> Self { StatementUse { statement, - // Init with empty vector because column names - // can change till the first call to `step()` - column_names: Vec::new(), - should_init_column_names, + column_names: OnceCell::new(), } } - pub(in crate::sqlite::connection) fn run(&mut self) -> QueryResult<()> { + pub(in crate::sqlite::connection) fn run(self) -> QueryResult<()> { self.step().map(|_| ()) } - pub(in crate::sqlite::connection) fn step<'c>( - &'c mut self, - ) -> QueryResult>> - where - 'b: 'c, - { + pub(in crate::sqlite::connection) fn step(self) -> QueryResult> { let res = unsafe { match ffi::sqlite3_step(self.statement.inner_statement.as_ptr()) { ffi::SQLITE_DONE => Ok(None), @@ -154,13 +152,7 @@ impl<'a, 'b> StatementUse<'a, 'b> { _ => Err(last_error(self.statement.raw_connection())), } }?; - if self.should_init_column_names { - self.column_names = (0..self.column_count()) - .map(|idx| unsafe { self.column_name(idx) }) - .collect(); - self.should_init_column_names = false; - } - Ok(res.map(move |()| SqliteRow::new(self))) + Ok(res.map(move |()| self)) } // The returned string pointer is valid until either the prepared statement is @@ -170,10 +162,7 @@ impl<'a, 'b> StatementUse<'a, 'b> { // on the same column. // // https://sqlite.org/c3ref/column_name.html - // - // As result of this requirements: Never use that function outside of `ColumnInformation` - // and never use `ColumnInformation` outside of `StatementUse` - unsafe fn column_name(&mut self, idx: i32) -> &'b str { + pub(super) unsafe fn column_name(&self, idx: i32) -> *const str { let name = { let column_name = ffi::sqlite3_column_name(self.statement.inner_statement.as_ptr(), idx); @@ -189,46 +178,45 @@ impl<'a, 'b> StatementUse<'a, 'b> { If you see this error message something has gone \ horribliy wrong. Please open an issue at the \ diesel repository.", - ) + ) as *const str } - pub(in crate::sqlite::connection) fn column_count(&self) -> i32 { + pub(super) fn column_count(&self) -> i32 { unsafe { ffi::sqlite3_column_count(self.statement.inner_statement.as_ptr()) } } - pub(in crate::sqlite::connection) fn index_for_column_name( - &self, - field_name: &str, - ) -> Option { - self.column_names - .iter() - .enumerate() - .find(|(_, name)| name == &&field_name) - .map(|(idx, _)| idx) - } - - pub(in crate::sqlite::connection) fn field_name<'c>(&'c self, idx: i32) -> Option<&'c str> - where - 'b: 'c, - { - self.column_names.get(idx as usize).copied() - } - - pub(in crate::sqlite::connection) fn value<'c>( - &'c self, - idx: i32, - ) -> Option> - where - 'b: 'c, - { - unsafe { - let ptr = ffi::sqlite3_column_value(self.statement.inner_statement.as_ptr(), idx); - SqliteValue::new(ptr) + pub(super) fn index_for_column_name(&mut self, field_name: &str) -> Option { + (0..self.column_count()) + .find(|idx| self.field_name(*idx) == Some(field_name)) + .map(|v| v as usize) + } + + pub(super) fn field_name(&mut self, idx: i32) -> Option<&str> { + if let Some(column_names) = self.column_names.get() { + return column_names + .get(idx as usize) + .and_then(|c| unsafe { c.as_ref() }); } + let values = (0..self.column_count()) + .map(|idx| unsafe { self.column_name(idx) }) + .collect::>(); + let ret = values.get(idx as usize).copied(); + let _ = self.column_names.set(values); + ret.and_then(|p| unsafe { p.as_ref() }) + } + + pub(super) fn copy_value(&self, idx: i32) -> Option { + OwnedSqliteValue::copy_from_ptr(self.column_value(idx)?) + } + + pub(super) fn column_value(&self, idx: i32) -> Option> { + let ptr = + unsafe { ffi::sqlite3_column_value(self.statement.inner_statement.as_ptr(), idx) }; + NonNull::new(ptr) } } -impl<'a, 'b> Drop for StatementUse<'a, 'b> { +impl<'a> Drop for StatementUse<'a> { fn drop(&mut self) { self.statement.reset(); } diff --git a/diesel/src/sqlite/types/date_and_time/chrono.rs b/diesel/src/sqlite/types/date_and_time/chrono.rs index 41136050e067..fa8867a031cc 100644 --- a/diesel/src/sqlite/types/date_and_time/chrono.rs +++ b/diesel/src/sqlite/types/date_and_time/chrono.rs @@ -13,9 +13,9 @@ const SQLITE_DATE_FORMAT: &str = "%F"; impl FromSql for NaiveDate { fn from_sql(value: backend::RawValue) -> deserialize::Result { - let text_ptr = <*const str as FromSql>::from_sql(value)?; - let text = unsafe { &*text_ptr }; - Self::parse_from_str(text, SQLITE_DATE_FORMAT).map_err(Into::into) + value + .parse_string(|s| Self::parse_from_str(s, SQLITE_DATE_FORMAT)) + .map_err(Into::into) } } @@ -28,21 +28,21 @@ impl ToSql for NaiveDate { impl FromSql for NaiveTime { fn from_sql(value: backend::RawValue) -> deserialize::Result { - let text_ptr = <*const str as FromSql>::from_sql(value)?; - let text = unsafe { &*text_ptr }; - let valid_time_formats = &[ - // Most likely - "%T%.f", // All other valid formats in order of documentation - "%R", "%RZ", "%T%.fZ", "%R%:z", "%T%.f%:z", - ]; - - for format in valid_time_formats { - if let Ok(time) = Self::parse_from_str(text, format) { - return Ok(time); + value.parse_string(|text| { + let valid_time_formats = &[ + // Most likely + "%T%.f", // All other valid formats in order of documentation + "%R", "%RZ", "%T%.fZ", "%R%:z", "%T%.f%:z", + ]; + + for format in valid_time_formats { + if let Ok(time) = Self::parse_from_str(text, format) { + return Ok(time); + } } - } - Err(format!("Invalid time {}", text).into()) + Err(format!("Invalid time {}", text).into()) + }) } } @@ -55,44 +55,43 @@ impl ToSql for NaiveTime { impl FromSql for NaiveDateTime { fn from_sql(value: backend::RawValue) -> deserialize::Result { - let text_ptr = <*const str as FromSql>::from_sql(value)?; - let text = unsafe { &*text_ptr }; - - let sqlite_datetime_formats = &[ - // Most likely format - "%F %T%.f", - // Other formats in order of appearance in docs - "%F %R", - "%F %RZ", - "%F %R%:z", - "%F %T%.fZ", - "%F %T%.f%:z", - "%FT%R", - "%FT%RZ", - "%FT%R%:z", - "%FT%T%.f", - "%FT%T%.fZ", - "%FT%T%.f%:z", - ]; - - for format in sqlite_datetime_formats { - if let Ok(dt) = Self::parse_from_str(text, format) { - return Ok(dt); + value.parse_string(|text| { + let sqlite_datetime_formats = &[ + // Most likely format + "%F %T%.f", + // Other formats in order of appearance in docs + "%F %R", + "%F %RZ", + "%F %R%:z", + "%F %T%.fZ", + "%F %T%.f%:z", + "%FT%R", + "%FT%RZ", + "%FT%R%:z", + "%FT%T%.f", + "%FT%T%.fZ", + "%FT%T%.f%:z", + ]; + + for format in sqlite_datetime_formats { + if let Ok(dt) = Self::parse_from_str(text, format) { + return Ok(dt); + } } - } - if let Ok(julian_days) = text.parse::() { - let epoch_in_julian_days = 2_440_587.5; - let seconds_in_day = 86400.0; - let timestamp = (julian_days - epoch_in_julian_days) * seconds_in_day; - let seconds = timestamp as i64; - let nanos = (timestamp.fract() * 1E9) as u32; - if let Some(timestamp) = Self::from_timestamp_opt(seconds, nanos) { - return Ok(timestamp); + if let Ok(julian_days) = text.parse::() { + let epoch_in_julian_days = 2_440_587.5; + let seconds_in_day = 86400.0; + let timestamp = (julian_days - epoch_in_julian_days) * seconds_in_day; + let seconds = timestamp as i64; + let nanos = (timestamp.fract() * 1E9) as u32; + if let Some(timestamp) = Self::from_timestamp_opt(seconds, nanos) { + return Ok(timestamp); + } } - } - Err(format!("Invalid datetime {}", text).into()) + Err(format!("Invalid datetime {}", text).into()) + }) } } diff --git a/diesel/src/sqlite/types/date_and_time/mod.rs b/diesel/src/sqlite/types/date_and_time/mod.rs index 18bdd45713f0..563a68ef83d6 100644 --- a/diesel/src/sqlite/types/date_and_time/mod.rs +++ b/diesel/src/sqlite/types/date_and_time/mod.rs @@ -9,13 +9,8 @@ use crate::sqlite::Sqlite; #[cfg(feature = "chrono")] mod chrono; -/// The returned pointer is *only* valid for the lifetime to the argument of -/// `from_sql`. This impl is intended for uses where you want to write a new -/// impl in terms of `String`, but don't want to allocate. We have to return a -/// raw pointer instead of a reference with a lifetime due to the structure of -/// `FromSql` -impl FromSql for *const str { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { +impl FromSql for String { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { FromSql::::from_sql(value) } } @@ -32,13 +27,8 @@ impl ToSql for String { } } -/// The returned pointer is *only* valid for the lifetime to the argument of -/// `from_sql`. This impl is intended for uses where you want to write a new -/// impl in terms of `String`, but don't want to allocate. We have to return a -/// raw pointer instead of a reference with a lifetime due to the structure of -/// `FromSql` -impl FromSql for *const str { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { +impl FromSql for String { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { FromSql::::from_sql(value) } } @@ -55,13 +45,8 @@ impl ToSql for String { } } -/// The returned pointer is *only* valid for the lifetime to the argument of -/// `from_sql`. This impl is intended for uses where you want to write a new -/// impl in terms of `String`, but don't want to allocate. We have to return a -/// raw pointer instead of a reference with a lifetime due to the structure of -/// `FromSql` -impl FromSql for *const str { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { +impl FromSql for String { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { FromSql::::from_sql(value) } } diff --git a/diesel/src/sqlite/types/mod.rs b/diesel/src/sqlite/types/mod.rs index 326fc40bc5fe..3d96ae2c6206 100644 --- a/diesel/src/sqlite/types/mod.rs +++ b/diesel/src/sqlite/types/mod.rs @@ -15,7 +15,7 @@ use crate::sql_types; /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { let text = value.read_text(); Ok(text as *const _) } @@ -27,44 +27,44 @@ impl FromSql for *const str { /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const [u8] { - fn from_sql(bytes: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(bytes: SqliteValue<'_, '_>) -> deserialize::Result { let bytes = bytes.read_blob(); Ok(bytes as *const _) } } impl FromSql for i16 { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { Ok(value.read_integer() as i16) } } impl FromSql for i32 { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { Ok(value.read_integer()) } } impl FromSql for bool { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { Ok(value.read_integer() != 0) } } impl FromSql for i64 { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { Ok(value.read_long()) } } impl FromSql for f32 { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { Ok(value.read_double() as f32) } } impl FromSql for f64 { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { Ok(value.read_double()) } } diff --git a/diesel/src/sqlite/types/numeric.rs b/diesel/src/sqlite/types/numeric.rs index b3dc7aa99d23..aaf2683544dc 100644 --- a/diesel/src/sqlite/types/numeric.rs +++ b/diesel/src/sqlite/types/numeric.rs @@ -8,7 +8,7 @@ use crate::sqlite::connection::SqliteValue; use crate::sqlite::Sqlite; impl FromSql for BigDecimal { - fn from_sql(bytes: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(bytes: SqliteValue<'_, '_>) -> deserialize::Result { let x = >::from_sql(bytes)?; BigDecimal::from_f64(x).ok_or_else(|| format!("{} is not valid decimal number ", x).into()) } diff --git a/diesel/src/util.rs b/diesel/src/util.rs index ed0439d49ed5..5bf0d9024540 100644 --- a/diesel/src/util.rs +++ b/diesel/src/util.rs @@ -9,3 +9,9 @@ pub trait TupleAppend { pub trait TupleSize { const SIZE: usize; } + +#[cfg(any(feature = "postgres", feature = "sqlite"))] +mod once_cell; + +#[cfg(any(feature = "postgres", feature = "sqlite"))] +pub(crate) use self::once_cell::OnceCell; diff --git a/diesel/src/util/once_cell.rs b/diesel/src/util/once_cell.rs new file mode 100644 index 000000000000..cb55e81e193e --- /dev/null +++ b/diesel/src/util/once_cell.rs @@ -0,0 +1,110 @@ +// This is a copy of the unstable `OnceCell` implementation in rusts std-library +// https://github.com/rust-lang/rust/blob/1160cf864f2a0014e3442367e1b96496bfbeadf4/library/core/src/lazy.rs#L8-L276 +// +// See https://github.com/rust-lang/rust/issues/74465 for the corresponding tracking issue + +use std::cell::UnsafeCell; + +/// A cell which can be written to only once. +/// +/// Unlike `RefCell`, a `OnceCell` only provides shared `&T` references to its value. +/// Unlike `Cell`, a `OnceCell` doesn't require copying or replacing the value to access it. +/// +/// # Examples +/// +/// ```ignore +/// +/// use crate::util::OnceCell; +/// +/// let cell = OnceCell::new(); +/// assert!(cell.get().is_none()); +/// +/// let value: &String = cell.get_or_init(|| { +/// "Hello, World!".to_string() +/// }); +/// assert_eq!(value, "Hello, World!"); +/// assert!(cell.get().is_some()); +/// ``` +pub(crate) struct OnceCell { + // Invariant: written to at most once. + inner: UnsafeCell>, +} + +impl Default for OnceCell { + fn default() -> Self { + Self::new() + } +} + +impl OnceCell { + /// Creates a new empty cell. + pub(crate) const fn new() -> OnceCell { + OnceCell { + inner: UnsafeCell::new(None), + } + } + + /// Gets the contents of the cell, initializing it with `f` if + /// the cell was empty. If the cell was empty and `f` failed, an + /// error is returned. + /// + /// # Panics + /// + /// If `f` panics, the panic is propagated to the caller, and the cell + /// remains uninitialized. + /// + /// It is an error to reentrantly initialize the cell from `f`. Doing + /// so results in a panic. + /// + /// # Examples + /// + /// ```ignore + /// + /// use crate::util::OnceCell; + /// + /// let cell = OnceCell::new(); + /// assert_eq!(cell.get_or_try_init(|| Err(())), Err(())); + /// assert!(cell.get().is_none()); + /// let value = cell.get_or_try_init(|| -> Result { + /// Ok(92) + /// }); + /// assert_eq!(value, Ok(&92)); + /// assert_eq!(cell.get(), Some(&92)) + /// ``` + pub(crate) fn get_or_init(&self, f: F) -> &T + where + F: FnOnce() -> T, + { + if let Some(val) = self.get() { + return val; + } + let val = f(); + // Note that *some* forms of reentrant initialization might lead to + // UB (see `reentrant_init` test). I believe that just removing this + // `assert`, while keeping `set/get` would be sound, but it seems + // better to panic, rather than to silently use an old value. + assert!(self.set(val).is_ok(), "reentrant init"); + self.get().unwrap() + } + + pub(crate) fn get(&self) -> Option<&T> { + // SAFETY: Safe due to `inner`'s invariant + unsafe { &*self.inner.get() }.as_ref() + } + + pub(crate) fn set(&self, value: T) -> Result<(), T> { + // SAFETY: Safe because we cannot have overlapping mutable borrows + let slot = unsafe { &*self.inner.get() }; + if slot.is_some() { + return Err(value); + } + + // SAFETY: This is the only place where we set the slot, no races + // due to reentrancy/concurrency are possible, and we've + // checked that slot is currently `None`, so this write + // maintains the `inner`'s invariant. + let slot = unsafe { &mut *self.inner.get() }; + *slot = Some(value); + Ok(()) + } +} diff --git a/diesel_bench/Cargo.toml b/diesel_bench/Cargo.toml index 9c0a252c8da5..5497918fcbba 100644 --- a/diesel_bench/Cargo.toml +++ b/diesel_bench/Cargo.toml @@ -7,7 +7,6 @@ build = "build.rs" autobenches = false [workspace] - # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] diff --git a/diesel_cli/src/infer_schema_internals/information_schema.rs b/diesel_cli/src/infer_schema_internals/information_schema.rs index 698f604eb285..841a269ba9db 100644 --- a/diesel_cli/src/infer_schema_internals/information_schema.rs +++ b/diesel_cli/src/infer_schema_internals/information_schema.rs @@ -203,11 +203,11 @@ where columns::column_name, >: QueryFragment, Conn::Backend: QueryMetadata<( - sql_types::Text, - sql_types::Text, - sql_types::Nullable, - sql_types::Text, - )>, + sql_types::Text, + sql_types::Text, + sql_types::Nullable, + sql_types::Text, + )> + 'static, { use self::information_schema::columns::dsl::*; @@ -252,7 +252,7 @@ where >, key_column_usage::ordinal_position, >: QueryFragment, - Conn::Backend: QueryMetadata, + Conn::Backend: QueryMetadata + 'static, { use self::information_schema::key_column_usage::dsl::*; use self::information_schema::table_constraints::constraint_type; @@ -281,7 +281,7 @@ pub fn load_table_names<'a, Conn>( ) -> Result, Box> where Conn: Connection, - Conn::Backend: UsesInformationSchema, + Conn::Backend: UsesInformationSchema + 'static, String: FromSql, Filter< Filter< diff --git a/diesel_compile_tests/tests/fail/array_expressions_must_be_correct_type.stderr b/diesel_compile_tests/tests/fail/array_expressions_must_be_correct_type.stderr index 9bca051f805b..927edf363d1a 100644 --- a/diesel_compile_tests/tests/fail/array_expressions_must_be_correct_type.stderr +++ b/diesel_compile_tests/tests/fail/array_expressions_must_be_correct_type.stderr @@ -66,18 +66,18 @@ error[E0277]: the trait bound `f64: QueryId` is not satisfied = note: required because of the requirements on the impl of `QueryId` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause>>` = note: required because of the requirements on the impl of `LoadQuery<_, Vec>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause>>` -error[E0277]: the trait bound `f64: QueryFragment<_>` is not satisfied +error[E0277]: the trait bound `f64: QueryFragment` is not satisfied --> $DIR/array_expressions_must_be_correct_type.rs:9:33 | 9 | select(array((1f64, 3f64))).get_result::>(&mut connection); - | ^^^^^^^^^^ the trait `QueryFragment<_>` is not implemented for `f64` + | ^^^^^^^^^^ the trait `QueryFragment` is not implemented for `f64` | - = note: required because of the requirements on the impl of `QueryFragment<_>` for `(f64, f64)` + = note: required because of the requirements on the impl of `QueryFragment` for `(f64, f64)` = note: 2 redundant requirements hidden - = note: required because of the requirements on the impl of `QueryFragment<_>` for `diesel::pg::expression::array::ArrayLiteral<(f64, f64), diesel::sql_types::Integer>` - = note: required because of the requirements on the impl of `SelectClauseQueryFragment<(), _>` for `diesel::query_builder::select_clause::SelectClause>` + = note: required because of the requirements on the impl of `QueryFragment` for `diesel::pg::expression::array::ArrayLiteral<(f64, f64), diesel::sql_types::Integer>` + = note: required because of the requirements on the impl of `SelectClauseQueryFragment<(), Pg>` for `diesel::query_builder::select_clause::SelectClause>` = note: 1 redundant requirements hidden - = note: required because of the requirements on the impl of `QueryFragment<_>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause>>` + = note: required because of the requirements on the impl of `QueryFragment` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause>>` = note: required because of the requirements on the impl of `LoadQuery<_, Vec>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause>>` error[E0277]: the trait bound `f64: diesel::Expression` is not satisfied diff --git a/diesel_compile_tests/tests/fail/array_expressions_must_be_same_type.stderr b/diesel_compile_tests/tests/fail/array_expressions_must_be_same_type.stderr index c60d03994aeb..6a9335579ad8 100644 --- a/diesel_compile_tests/tests/fail/array_expressions_must_be_same_type.stderr +++ b/diesel_compile_tests/tests/fail/array_expressions_must_be_same_type.stderr @@ -66,18 +66,18 @@ error[E0277]: the trait bound `f64: QueryId` is not satisfied = note: required because of the requirements on the impl of `QueryId` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>>` = note: required because of the requirements on the impl of `LoadQuery<_, Vec>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>>` -error[E0277]: the trait bound `f64: QueryFragment<_>` is not satisfied +error[E0277]: the trait bound `f64: QueryFragment` is not satisfied --> $DIR/array_expressions_must_be_same_type.rs:11:30 | 11 | select(array((1, 3f64))).get_result::>(&mut connection).unwrap(); - | ^^^^^^^^^^ the trait `QueryFragment<_>` is not implemented for `f64` + | ^^^^^^^^^^ the trait `QueryFragment` is not implemented for `f64` | - = note: required because of the requirements on the impl of `QueryFragment<_>` for `(diesel::expression::bound::Bound, f64)` + = note: required because of the requirements on the impl of `QueryFragment` for `(diesel::expression::bound::Bound, f64)` = note: 2 redundant requirements hidden - = note: required because of the requirements on the impl of `QueryFragment<_>` for `diesel::pg::expression::array::ArrayLiteral<(diesel::expression::bound::Bound, f64), diesel::sql_types::Integer>` - = note: required because of the requirements on the impl of `SelectClauseQueryFragment<(), _>` for `diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>` + = note: required because of the requirements on the impl of `QueryFragment` for `diesel::pg::expression::array::ArrayLiteral<(diesel::expression::bound::Bound, f64), diesel::sql_types::Integer>` + = note: required because of the requirements on the impl of `SelectClauseQueryFragment<(), Pg>` for `diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>` = note: 1 redundant requirements hidden - = note: required because of the requirements on the impl of `QueryFragment<_>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>>` + = note: required because of the requirements on the impl of `QueryFragment` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>>` = note: required because of the requirements on the impl of `LoadQuery<_, Vec>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>>` error[E0277]: the trait bound `f64: diesel::Expression` is not satisfied @@ -192,11 +192,11 @@ error[E0277]: the trait bound `{integer}: QueryId` is not satisfied = note: required because of the requirements on the impl of `QueryId` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>>` = note: required because of the requirements on the impl of `LoadQuery<_, Vec>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>>` -error[E0277]: the trait bound `{integer}: QueryFragment<_>` is not satisfied +error[E0277]: the trait bound `{integer}: QueryFragment` is not satisfied --> $DIR/array_expressions_must_be_same_type.rs:12:30 | 12 | select(array((1, 3f64))).get_result::>(&mut connection).unwrap(); - | ^^^^^^^^^^ the trait `QueryFragment<_>` is not implemented for `{integer}` + | ^^^^^^^^^^ the trait `QueryFragment` is not implemented for `{integer}` | = help: the following implementations were found: <&'a T as QueryFragment> @@ -204,12 +204,12 @@ error[E0277]: the trait bound `{integer}: QueryFragment<_>` is not satisfied <(A, B) as QueryFragment<__DB>> <(A, B, C) as QueryFragment<__DB>> and 223 others - = note: required because of the requirements on the impl of `QueryFragment<_>` for `({integer}, diesel::expression::bound::Bound)` + = note: required because of the requirements on the impl of `QueryFragment` for `({integer}, diesel::expression::bound::Bound)` = note: 2 redundant requirements hidden - = note: required because of the requirements on the impl of `QueryFragment<_>` for `diesel::pg::expression::array::ArrayLiteral<({integer}, diesel::expression::bound::Bound), diesel::sql_types::Double>` - = note: required because of the requirements on the impl of `SelectClauseQueryFragment<(), _>` for `diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>` + = note: required because of the requirements on the impl of `QueryFragment` for `diesel::pg::expression::array::ArrayLiteral<({integer}, diesel::expression::bound::Bound), diesel::sql_types::Double>` + = note: required because of the requirements on the impl of `SelectClauseQueryFragment<(), Pg>` for `diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>` = note: 1 redundant requirements hidden - = note: required because of the requirements on the impl of `QueryFragment<_>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>>` + = note: required because of the requirements on the impl of `QueryFragment` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>>` = note: required because of the requirements on the impl of `LoadQuery<_, Vec>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>>` error[E0277]: the trait bound `{integer}: diesel::Expression` is not satisfied diff --git a/diesel_compile_tests/tests/fail/array_only_usable_with_pg.stderr b/diesel_compile_tests/tests/fail/array_only_usable_with_pg.stderr index 2a425d396d30..309e78588330 100644 --- a/diesel_compile_tests/tests/fail/array_only_usable_with_pg.stderr +++ b/diesel_compile_tests/tests/fail/array_only_usable_with_pg.stderr @@ -6,21 +6,6 @@ error[E0271]: type mismatch resolving `>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause,), diesel::sql_types::Integer>>>` -error[E0277]: the trait bound `Sqlite: HasSqlType>` is not satisfied - --> $DIR/array_only_usable_with_pg.rs:8:25 - | -8 | select(array((1,))).get_result::>(&mut connection); - | ^^^^^^^^^^ the trait `HasSqlType>` is not implemented for `Sqlite` - | - = help: the following implementations were found: - > - > - > - > - and 8 others - = note: required because of the requirements on the impl of `QueryMetadata>` for `Sqlite` - = note: required because of the requirements on the impl of `LoadQuery>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause,), diesel::sql_types::Integer>>>` - error[E0271]: type mismatch resolving `::Backend == Pg` --> $DIR/array_only_usable_with_pg.rs:11:25 | @@ -28,18 +13,3 @@ error[E0271]: type mismatch resolving `>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause,), diesel::sql_types::Integer>>>` - -error[E0277]: the trait bound `Mysql: HasSqlType>` is not satisfied - --> $DIR/array_only_usable_with_pg.rs:11:25 - | -11 | select(array((1,))).get_result::>(&mut connection); - | ^^^^^^^^^^ the trait `HasSqlType>` is not implemented for `Mysql` - | - = help: the following implementations were found: - > - > - > - > - and 15 others - = note: required because of the requirements on the impl of `QueryMetadata>` for `Mysql` - = note: required because of the requirements on the impl of `LoadQuery>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause,), diesel::sql_types::Integer>>>` diff --git a/diesel_compile_tests/tests/fail/select_carries_correct_result_type_info.stderr b/diesel_compile_tests/tests/fail/select_carries_correct_result_type_info.stderr index ca17b5993493..7deae98b5423 100644 --- a/diesel_compile_tests/tests/fail/select_carries_correct_result_type_info.stderr +++ b/diesel_compile_tests/tests/fail/select_carries_correct_result_type_info.stderr @@ -22,9 +22,8 @@ error[E0277]: the trait bound `*const str: FromSql> <*const [u8] as FromSql> - <*const str as FromSql> <*const str as FromSql> - and 3 others + <*const str as FromSql> = note: required because of the requirements on the impl of `FromSql` for `std::string::String` = note: required because of the requirements on the impl of `Queryable` for `std::string::String` = note: required because of the requirements on the impl of `FromSqlRow` for `std::string::String` diff --git a/diesel_compile_tests/tests/fail/select_sql_still_ensures_result_type.stderr b/diesel_compile_tests/tests/fail/select_sql_still_ensures_result_type.stderr index 33aa2fc9a9ed..27e3009a1ea3 100644 --- a/diesel_compile_tests/tests/fail/select_sql_still_ensures_result_type.stderr +++ b/diesel_compile_tests/tests/fail/select_sql_still_ensures_result_type.stderr @@ -7,9 +7,8 @@ error[E0277]: the trait bound `*const str: FromSql` is not satisfied = help: the following implementations were found: <*const [u8] as FromSql> <*const [u8] as FromSql> - <*const str as FromSql> <*const str as FromSql> - and 3 others + <*const str as FromSql> = note: required because of the requirements on the impl of `FromSql` for `std::string::String` = note: required because of the requirements on the impl of `Queryable` for `std::string::String` = note: required because of the requirements on the impl of `FromSqlRow` for `std::string::String` diff --git a/diesel_compile_tests/tests/fail/selectable.stderr b/diesel_compile_tests/tests/fail/selectable.stderr index 1b175a43590f..555672aa6411 100644 --- a/diesel_compile_tests/tests/fail/selectable.stderr +++ b/diesel_compile_tests/tests/fail/selectable.stderr @@ -603,27 +603,3 @@ error[E0271]: type mismatch resolving `` for `SelectStatement, Grouped, diesel::expression::nullable::Nullable>>>, diesel::query_builder::select_clause::SelectClause>, diesel::query_builder::distinct_clause::NoDistinctClause, diesel::query_builder::where_clause::NoWhereClause, diesel::query_builder::order_clause::NoOrderClause, LimitOffsetClause, diesel::query_builder::group_by_clause::GroupByClause>` - -error[E0277]: the trait bound `Sqlite: HasSqlType>` is not satisfied - --> $DIR/selectable.rs:210:10 - | -210 | .load(&mut conn) - | ^^^^ the trait `HasSqlType>` is not implemented for `Sqlite` - | - = help: the following implementations were found: - > - > - > - > - and 8 others - = note: required because of the requirements on the impl of `QueryMetadata>` for `Sqlite` - = note: required because of the requirements on the impl of `LoadQuery` for `SelectStatement, Grouped, diesel::expression::nullable::Nullable>>>, diesel::query_builder::select_clause::SelectClause>, diesel::query_builder::distinct_clause::NoDistinctClause, diesel::query_builder::where_clause::NoWhereClause, diesel::query_builder::order_clause::NoOrderClause, LimitOffsetClause, diesel::query_builder::group_by_clause::GroupByClause>` - -error[E0277]: the trait bound `SelectBy: SingleValue` is not satisfied - --> $DIR/selectable.rs:210:10 - | -210 | .load(&mut conn) - | ^^^^ the trait `SingleValue` is not implemented for `SelectBy` - | - = note: required because of the requirements on the impl of `QueryMetadata>` for `Sqlite` - = note: required because of the requirements on the impl of `LoadQuery` for `SelectStatement, Grouped, diesel::expression::nullable::Nullable>>>, diesel::query_builder::select_clause::SelectClause>, diesel::query_builder::distinct_clause::NoDistinctClause, diesel::query_builder::where_clause::NoWhereClause, diesel::query_builder::order_clause::NoOrderClause, LimitOffsetClause, diesel::query_builder::group_by_clause::GroupByClause>` diff --git a/diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.rs b/diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.rs new file mode 100644 index 000000000000..e30f92909815 --- /dev/null +++ b/diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.rs @@ -0,0 +1,32 @@ +extern crate diesel; + +use diesel::prelude::*; +use diesel::sql_query; + +fn main() { + let conn = &mut SqliteConnection::establish("foo").unwrap(); + // For sqlite the returned iterator is coupled to + // a statement, which is coupled to the connection itself + // so we cannot have more than one iterator + // for the same connection + let row_iter1 = conn.load(&sql_query("bar")).unwrap(); + let row_iter2 = conn.load(&sql_query("bar")).unwrap(); + + let _ = row_iter1.zip(row_iter2); + + let conn = &mut MysqlConnection::establish("foo").unwrap(); + // The same argument applies to mysql + let row_iter1 = conn.load(&sql_query("bar")).unwrap(); + let row_iter2 = conn.load(&sql_query("bar")).unwrap(); + + let _ = row_iter1.zip(row_iter2); + + let conn = &mut PgConnection::establish("foo").unwrap(); + // It works for PgConnection as the result is not related to the + // connection in any way + let row_iter1 = conn.load(&sql_query("bar")).unwrap(); + let row_iter2 = conn.load(&sql_query("bar")).unwrap(); + + let _ = row_iter1.zip(row_iter2); + +} diff --git a/diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.stderr b/diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.stderr new file mode 100644 index 000000000000..c3fa7217fb21 --- /dev/null +++ b/diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.stderr @@ -0,0 +1,21 @@ +error[E0499]: cannot borrow `*conn` as mutable more than once at a time + --> $DIR/sqlite_and_mysql_do_not_allow_multiple_iterators.rs:13:21 + | +12 | let row_iter1 = conn.load(&sql_query("bar")).unwrap(); + | ---- first mutable borrow occurs here +13 | let row_iter2 = conn.load(&sql_query("bar")).unwrap(); + | ^^^^ second mutable borrow occurs here +14 | +15 | let _ = row_iter1.zip(row_iter2); + | --------- first borrow later used here + +error[E0499]: cannot borrow `*conn` as mutable more than once at a time + --> $DIR/sqlite_and_mysql_do_not_allow_multiple_iterators.rs:20:21 + | +19 | let row_iter1 = conn.load(&sql_query("bar")).unwrap(); + | ---- first mutable borrow occurs here +20 | let row_iter2 = conn.load(&sql_query("bar")).unwrap(); + | ^^^^ second mutable borrow occurs here +21 | +22 | let _ = row_iter1.zip(row_iter2); + | --------- first borrow later used here diff --git a/diesel_tests/tests/deserialization.rs b/diesel_tests/tests/deserialization.rs index 380c701605bd..ecc630de3948 100644 --- a/diesel_tests/tests/deserialization.rs +++ b/diesel_tests/tests/deserialization.rs @@ -1,5 +1,5 @@ use crate::schema::*; -use diesel::*; +use diesel::prelude::*; use std::borrow::Cow; #[derive(Queryable, PartialEq, Debug, Selectable)] diff --git a/diesel_tests/tests/types.rs b/diesel_tests/tests/types.rs index d36d675b6117..524bd1fdb726 100644 --- a/diesel_tests/tests/types.rs +++ b/diesel_tests/tests/types.rs @@ -1241,8 +1241,9 @@ fn third_party_crates_can_add_new_types() { assert_eq!(70_000, query_single_value::("70000")); } -fn query_single_value>(sql_str: &str) -> U +fn query_single_value(sql_str: &str) -> U where + U: FromSqlRow + 'static, TestBackend: HasSqlType, T: QueryId + SingleValue + SqlType, { diff --git a/diesel_tests/tests/types_roundtrip.rs b/diesel_tests/tests/types_roundtrip.rs index 94757819bfdf..4296e65db6eb 100644 --- a/diesel_tests/tests/types_roundtrip.rs +++ b/diesel_tests/tests/types_roundtrip.rs @@ -26,7 +26,8 @@ where + FromSqlRow::Backend> + PartialEq + Clone - + ::std::fmt::Debug, + + ::std::fmt::Debug + + 'static, >::Expression: SelectableExpression<(), SqlType = ST> + NonAggregate + QueryFragment<::Backend>