From c69d22a51f8afe8838664ef0dbc865af456cc620 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Tue, 1 Jun 2021 11:42:55 +0200 Subject: [PATCH 01/32] Return a iterator from `Connection::load` This PR provides a prototypical implementation of `Connection::load` so that a iterator of `Row`'s is returned. This is currently only implemented for the postgres backend to checkout the performance impact of this change. --- .github/workflows/benches.yml | 2 +- .github/workflows/ci.yml | 3 +- Cargo.toml | 40 ++++---- diesel/Cargo.toml | 2 +- diesel/src/connection/mod.rs | 19 ++-- diesel/src/lib.rs | 2 +- diesel/src/mysql/connection/mod.rs | 36 ++++--- diesel/src/mysql/connection/stmt/mod.rs | 2 +- diesel/src/pg/connection/cursor.rs | 25 +++-- diesel/src/pg/connection/mod.rs | 31 +++--- diesel/src/pg/connection/result.rs | 46 ++++++--- diesel/src/pg/connection/row.rs | 17 ++-- diesel/src/pg/connection/stmt/mod.rs | 12 +-- diesel/src/query_dsl/load_dsl.rs | 129 ++++++++++++++++++++++-- diesel/src/query_dsl/mod.rs | 4 +- diesel/src/result.rs | 9 +- diesel/src/row.rs | 8 +- diesel/src/sqlite/connection/mod.rs | 25 +++-- diesel/src/util.rs | 4 + diesel/src/util/once_cell.rs | 112 ++++++++++++++++++++ diesel_tests/tests/types.rs | 5 +- diesel_tests/tests/types_roundtrip.rs | 5 +- 22 files changed, 416 insertions(+), 122 deletions(-) create mode 100644 diesel/src/util/once_cell.rs diff --git a/.github/workflows/benches.yml b/.github/workflows/benches.yml index 80ee0d2a061c..58a5af88df4c 100644 --- a/.github/workflows/benches.yml +++ b/.github/workflows/benches.yml @@ -12,7 +12,7 @@ jobs: strategy: fail-fast: false matrix: - backend: ["postgres", "sqlite", "mysql"] + backend: ["postgres"] #, "sqlite", "mysql"] steps: - name: Checkout sources uses: actions/checkout@v2 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 97eaee7eba93..44c0c5ab4a37 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,7 +22,8 @@ jobs: fail-fast: false matrix: rust: ["stable", "beta", "nightly"] - backend: ["postgres", "sqlite", "mysql"] + # backend: ["postgres", "sqlite", "mysql"] + backend: ["postgres"] os: [ubuntu-20.04, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: diff --git a/Cargo.toml b/Cargo.toml index 3dce5ddff7b5..56a73eccdf0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,26 +1,26 @@ [workspace] members = [ "diesel", - "diesel_cli", +# "diesel_cli", "diesel_derives", "diesel_tests", - "diesel_migrations", - "diesel_migrations/migrations_internals", - "diesel_migrations/migrations_macros", - "diesel_dynamic_schema", - "examples/mysql/all_about_inserts", - "examples/mysql/getting_started_step_1", - "examples/mysql/getting_started_step_2", - "examples/mysql/getting_started_step_3", - "examples/postgres/advanced-blog-cli", - "examples/postgres/all_about_inserts", - "examples/postgres/all_about_updates", - "examples/postgres/getting_started_step_1", - "examples/postgres/getting_started_step_2", - "examples/postgres/getting_started_step_3", - "examples/postgres/custom_types", - "examples/sqlite/all_about_inserts", - "examples/sqlite/getting_started_step_1", - "examples/sqlite/getting_started_step_2", - "examples/sqlite/getting_started_step_3", + # "diesel_migrations", + # "diesel_migrations/migrations_internals", + # "diesel_migrations/migrations_macros", + #"diesel_dynamic_schema", + # "examples/mysql/all_about_inserts", + # "examples/mysql/getting_started_step_1", + # "examples/mysql/getting_started_step_2", + # "examples/mysql/getting_started_step_3", + # "examples/postgres/advanced-blog-cli", + # "examples/postgres/all_about_inserts", + # "examples/postgres/all_about_updates", + # "examples/postgres/getting_started_step_1", + # "examples/postgres/getting_started_step_2", + # "examples/postgres/getting_started_step_3", + # "examples/postgres/custom_types", + # "examples/sqlite/all_about_inserts", + # "examples/sqlite/getting_started_step_1", + # "examples/sqlite/getting_started_step_2", + # "examples/sqlite/getting_started_step_3", ] diff --git a/diesel/Cargo.toml b/diesel/Cargo.toml index e650363cc53b..85d3fd26ec79 100644 --- a/diesel/Cargo.toml +++ b/diesel/Cargo.toml @@ -44,7 +44,7 @@ ipnetwork = ">=0.12.2, <0.19.0" quickcheck = "0.9" [features] -default = ["with-deprecated", "32-column-tables"] +default = ["postgres"] extras = ["chrono", "serde_json", "uuid", "network-address", "numeric", "r2d2"] unstable = ["diesel_derives/nightly"] large-tables = ["32-column-tables"] diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index 1cc4dbf742d4..bcfe9b4db04b 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -6,10 +6,8 @@ mod transaction_manager; use std::fmt::Debug; use crate::backend::Backend; -use crate::deserialize::FromSqlRow; use crate::expression::QueryMetadata; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; -use crate::query_dsl::load_dsl::CompatibleType; use crate::result::*; #[doc(hidden)] @@ -27,8 +25,16 @@ pub trait SimpleConnection { fn batch_execute(&mut self, query: &str) -> QueryResult<()>; } +pub trait IterableConnection<'a, DB: Backend> { + type Cursor: Iterator>; + type Row: crate::row::Row<'a, DB>; +} + /// A connection to a database -pub trait Connection: SimpleConnection + Sized + Send { +pub trait Connection: SimpleConnection + Sized + Send +where + Self: for<'a> IterableConnection<'a, ::Backend>, +{ /// The backend this type connects to type Backend: Backend; @@ -177,12 +183,13 @@ pub trait Connection: SimpleConnection + Sized + Send { fn execute(&mut self, query: &str) -> QueryResult; #[doc(hidden)] - fn load(&mut self, source: T) -> QueryResult> + fn load<'a, T>( + &'a mut self, + source: T, + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, - T::SqlType: CompatibleType, - U: FromSqlRow, Self::Backend: QueryMetadata; #[doc(hidden)] diff --git a/diesel/src/lib.rs b/diesel/src/lib.rs index 96211a2569a5..59a15bcf280d 100644 --- a/diesel/src/lib.rs +++ b/diesel/src/lib.rs @@ -95,7 +95,7 @@ // For the `specialization` feature. #![cfg_attr(feature = "unstable", allow(incomplete_features))] // Built-in Lints -#![deny(warnings)] +//#![deny(warnings)] #![warn( missing_debug_implementations, missing_copy_implementations, diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index 04093d901d09..2222f5d2bace 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -14,6 +14,7 @@ use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; use crate::query_dsl::load_dsl::CompatibleType; use crate::result::*; +use crate::row::Row; #[allow(missing_debug_implementations, missing_copy_implementations)] /// A connection to a MySQL database. Connection URLs should be in the form @@ -33,7 +34,12 @@ impl SimpleConnection for MysqlConnection { } } -impl Connection for MysqlConnection { +impl<'a> IterableConnection<'a> for MysqlConnection { + type Cursor = self::stmt::iterator::StatementIterator<'a>; + type Row = self::stmt::iterator::MysqlRow<'a>; +} + +/*impl Connection for MysqlConnection { type Backend = Mysql; type TransactionManager = AnsiTransactionManager; @@ -61,21 +67,27 @@ impl Connection for MysqlConnection { } #[doc(hidden)] - fn load(&mut self, source: T) -> QueryResult> + fn load<'a, T, ST>( + &'a mut self, + source: T, + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, - T::SqlType: CompatibleType, - U: FromSqlRow, Self::Backend: QueryMetadata, + Self: IterableConnection<'a>, + >::Cursor: + Iterator>::Row>>, + for<'b> >::Row: Row<'b, Self::Backend>, { - use crate::result::Error::DeserializationError; - - let mut stmt = self.prepare_query(&source.as_query())?; - let mut metadata = Vec::new(); - Mysql::row_metadata(&mut (), &mut metadata); - let results = unsafe { stmt.results(metadata)? }; - results.map(|row| U::build_from_row(&row).map_err(DeserializationError)) + todo!() + // use crate::result::Error::DeserializationError; + + // let mut stmt = self.prepare_query(&source.as_query())?; + // let mut metadata = Vec::new(); + // Mysql::row_metadata(&mut (), &mut metadata); + // let results = unsafe { stmt.results(metadata)? }; + // results.map(|row| U::build_from_row(&row).map_err(DeserializationError)) } #[doc(hidden)] @@ -123,7 +135,7 @@ impl MysqlConnection { self.execute("SET character_set_results = 'utf8mb4'")?; Ok(()) } -} +}*/ #[cfg(test)] mod tests { diff --git a/diesel/src/mysql/connection/stmt/mod.rs b/diesel/src/mysql/connection/stmt/mod.rs index a5f17a70792a..291dbf441c3c 100644 --- a/diesel/src/mysql/connection/stmt/mod.rs +++ b/diesel/src/mysql/connection/stmt/mod.rs @@ -1,6 +1,6 @@ extern crate mysqlclient_sys as ffi; -mod iterator; +pub mod iterator; mod metadata; use std::ffi::CStr; diff --git a/diesel/src/pg/connection/cursor.rs b/diesel/src/pg/connection/cursor.rs index 043f50f982be..593c819379f1 100644 --- a/diesel/src/pg/connection/cursor.rs +++ b/diesel/src/pg/connection/cursor.rs @@ -1,18 +1,22 @@ +use std::rc::Rc; + use super::result::PgResult; use super::row::PgRow; -/// The type returned by various [`Connection`] methods. +/// The type returned by various [`Conn +/// ection`] methods. /// Acts as an iterator over `T`. +#[allow(missing_debug_implementations)] pub struct Cursor<'a> { current_row: usize, - db_result: &'a PgResult, + db_result: Rc>, } impl<'a> Cursor<'a> { - pub(super) fn new(db_result: &'a PgResult) -> Self { + pub(super) fn new(db_result: PgResult<'a>) -> Self { Cursor { current_row: 0, - db_result, + db_result: Rc::new(db_result), } } } @@ -24,13 +28,13 @@ impl<'a> ExactSizeIterator for Cursor<'a> { } impl<'a> Iterator for Cursor<'a> { - type Item = PgRow<'a>; + type Item = crate::QueryResult>; fn next(&mut self) -> Option { if self.current_row < self.db_result.num_rows() { - let row = self.db_result.get_row(self.current_row); + let row = self.db_result.clone().get_row(self.current_row); self.current_row += 1; - Some(row) + Some(Ok(row)) } else { None } @@ -45,4 +49,11 @@ impl<'a> Iterator for Cursor<'a> { let len = self.len(); (len, Some(len)) } + + fn count(self) -> usize + where + Self: Sized, + { + self.len() + } } diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index dfbf680b5e1a..65680dc4b29d 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -1,4 +1,4 @@ -mod cursor; +pub(crate) mod cursor; pub mod raw; #[doc(hidden)] pub mod result; @@ -13,15 +13,12 @@ use self::raw::RawConnection; use self::result::PgResult; use self::stmt::Statement; use crate::connection::*; -use crate::deserialize::FromSqlRow; use crate::expression::QueryMetadata; use crate::pg::metadata_lookup::{GetPgMetadataCache, PgMetadataCache}; use crate::pg::{Pg, TransactionBuilder}; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; -use crate::query_dsl::load_dsl::CompatibleType; use crate::result::ConnectionError::CouldntSetupConfiguration; -use crate::result::Error::DeserializationError; use crate::result::*; /// The connection string expected by `PgConnection::establish` @@ -46,6 +43,11 @@ impl SimpleConnection for PgConnection { } } +impl<'a> IterableConnection<'a, Pg> for PgConnection { + type Cursor = Cursor<'a>; + type Row = self::row::PgRow<'a>; +} + impl Connection for PgConnection { type Backend = Pg; type TransactionManager = AnsiTransactionManager; @@ -70,21 +72,20 @@ impl Connection for PgConnection { } #[doc(hidden)] - fn load(&mut self, source: T) -> QueryResult> + fn load<'a, T>( + &'a mut self, + source: T, + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, - T::SqlType: CompatibleType, - U: FromSqlRow, Self::Backend: QueryMetadata, { self.with_prepared_query(&source.as_query(), |stmt, params, conn| { let result = stmt.execute(conn, ¶ms)?; - let cursor = Cursor::new(&result); + let cursor = Cursor::new(result); - cursor - .map(|row| U::build_from_row(&row).map_err(DeserializationError)) - .collect::>>() + Ok(cursor) }) } @@ -140,13 +141,13 @@ impl PgConnection { TransactionBuilder::new(self) } - fn with_prepared_query + QueryId, R>( - &mut self, - source: &T, + fn with_prepared_query<'a, T: QueryFragment + QueryId, R>( + &'a mut self, + source: &'_ T, f: impl FnOnce( MaybeCached, Vec>>, - &mut RawConnection, + &'a mut RawConnection, ) -> QueryResult, ) -> QueryResult { let mut bind_collector = RawBytesBindCollector::::new(); diff --git a/diesel/src/pg/connection/result.rs b/diesel/src/pg/connection/result.rs index 34ccb2550cfc..379b9ad3b81f 100644 --- a/diesel/src/pg/connection/result.rs +++ b/diesel/src/pg/connection/result.rs @@ -2,25 +2,29 @@ extern crate pq_sys; use self::pq_sys::*; use std::ffi::CStr; +use std::marker::PhantomData; use std::num::NonZeroU32; use std::os::raw as libc; +use std::rc::Rc; use std::{slice, str}; use super::raw::RawResult; use super::row::PgRow; use crate::result::{DatabaseErrorInformation, DatabaseErrorKind, Error, QueryResult}; +use crate::util::OnceCell; + // Message after a database connection has been unexpectedly closed. const CLOSED_CONNECTION_MSG: &str = "server closed the connection unexpectedly\n\t\ This probably means the server terminated abnormally\n\tbefore or while processing the request.\n"; -pub struct PgResult { +pub(crate) struct PgResult<'a> { internal_result: RawResult, column_count: usize, row_count: usize, } -impl PgResult { +impl<'a> PgResult<'a> { #[allow(clippy::new_ret_no_self)] pub fn new(internal_result: RawResult) -> QueryResult { let result_status = unsafe { PQresultStatus(internal_result.as_ptr()) }; @@ -32,6 +36,8 @@ impl PgResult { internal_result, column_count, row_count, + column_name_map: OnceCell::new(), + _marker: PhantomData, }) } ExecStatusType::PGRES_EMPTY_QUERY => { @@ -89,11 +95,11 @@ impl PgResult { self.row_count } - pub fn get_row(&self, idx: usize) -> PgRow { + pub fn get_row(self: Rc, idx: usize) -> PgRow<'a> { PgRow::new(self, idx) } - pub fn get(&self, row_idx: usize, col_idx: usize) -> Option<&[u8]> { + pub fn get(&self, row_idx: usize, col_idx: usize) -> Option<&'a [u8]> { if self.is_null(row_idx, col_idx) { None } else { @@ -127,17 +133,29 @@ impl PgResult { } pub fn column_name(&self, col_idx: usize) -> Option<&str> { - unsafe { - let ptr = PQfname(self.internal_result.as_ptr(), col_idx as libc::c_int); - if ptr.is_null() { - None - } else { - Some(CStr::from_ptr(ptr).to_str().expect( - "Expect postgres field names to be UTF-8, because we \ + self.column_name_map + .get_or_init(|| { + (0..self.column_count) + .map(|idx| unsafe { + let ptr = PQfname(self.internal_result.as_ptr(), idx as libc::c_int); + if ptr.is_null() { + None + } else { + Some( + CStr::from_ptr(ptr) + .to_str() + .expect( + "Expect postgres field names to be UTF-8, because we \ requested UTF-8 encoding on connection setup", - )) - } - } + ) + .to_owned(), + ) + } + }) + .collect() + }) + .get(col_idx) + .and_then(|n| n.as_ref().map(|n| n as &str)) } pub fn column_count(&self) -> usize { diff --git a/diesel/src/pg/connection/row.rs b/diesel/src/pg/connection/row.rs index a3d9c9d76c32..03ba6a03cd8a 100644 --- a/diesel/src/pg/connection/row.rs +++ b/diesel/src/pg/connection/row.rs @@ -1,15 +1,17 @@ use super::result::PgResult; use crate::pg::{Pg, PgValue}; use crate::row::*; +use std::rc::Rc; #[derive(Clone)] +#[allow(missing_debug_implementations)] pub struct PgRow<'a> { - db_result: &'a PgResult, + db_result: Rc>, row_idx: usize, } impl<'a> PgRow<'a> { - pub fn new(db_result: &'a PgResult, row_idx: usize) -> Self { + pub(crate) fn new(db_result: Rc>, row_idx: usize) -> Self { PgRow { db_result, row_idx } } } @@ -28,7 +30,7 @@ impl<'a> Row<'a, Pg> for PgRow<'a> { { let idx = self.idx(idx)?; Some(PgField { - db_result: self.db_result, + db_result: self.db_result.clone(), row_idx: self.row_idx, col_idx: idx, }) @@ -55,18 +57,19 @@ impl<'a, 'b> RowIndex<&'a str> for PgRow<'b> { } } +#[allow(missing_debug_implementations)] pub struct PgField<'a> { - db_result: &'a PgResult, + db_result: Rc>, row_idx: usize, col_idx: usize, } -impl<'a> Field<'a, Pg> for PgField<'a> { - fn field_name(&self) -> Option<&'a str> { +impl<'a> Field for PgField<'a> { + fn field_name(&self) -> Option<&str> { self.db_result.column_name(self.col_idx) } - fn value(&self) -> Option> { + fn value<'b>(&'b self) -> Option> { let raw = self.db_result.get(self.row_idx, self.col_idx)?; let type_oid = self.db_result.column_type(self.col_idx); diff --git a/diesel/src/pg/connection/stmt/mod.rs b/diesel/src/pg/connection/stmt/mod.rs index e42382118e73..ce6f59426626 100644 --- a/diesel/src/pg/connection/stmt/mod.rs +++ b/diesel/src/pg/connection/stmt/mod.rs @@ -10,18 +10,18 @@ use crate::result::QueryResult; pub use super::raw::RawConnection; -pub struct Statement { +pub(crate) struct Statement { name: CString, param_formats: Vec, } impl Statement { #[allow(clippy::ptr_arg)] - pub fn execute( - &self, - raw_connection: &mut RawConnection, - param_data: &Vec>>, - ) -> QueryResult { + pub fn execute<'a>( + &'_ self, + raw_connection: &'a mut RawConnection, + param_data: &'_ Vec>>, + ) -> QueryResult> { let params_pointer = param_data .iter() .map(|data| { diff --git a/diesel/src/query_dsl/load_dsl.rs b/diesel/src/query_dsl/load_dsl.rs index 6c4a3e9dbd30..dc85d6cc3822 100644 --- a/diesel/src/query_dsl/load_dsl.rs +++ b/diesel/src/query_dsl/load_dsl.rs @@ -1,6 +1,6 @@ use super::RunQueryDsl; use crate::backend::Backend; -use crate::connection::Connection; +use crate::connection::{Connection, IterableConnection}; use crate::deserialize::FromSqlRow; use crate::expression::{select_by::SelectBy, Expression, QueryMetadata, Selectable}; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; @@ -13,9 +13,19 @@ use crate::result::QueryResult; /// to call `load` from generic code. /// /// [`RunQueryDsl`]: crate::RunQueryDsl -pub trait LoadQuery: RunQueryDsl { +pub trait LoadQuery: RunQueryDsl +where + for<'a> Self: LoadQueryRet<'a, Conn, U>, +{ /// Load this query - fn internal_load(self, conn: &mut Conn) -> QueryResult>; + fn internal_load<'a>( + self, + conn: &'a mut Conn, + ) -> QueryResult<>::Ret>; +} + +pub trait LoadQueryRet<'a, Conn, U> { + type Ret: Iterator>; } use crate::expression::TypedExpressionType; @@ -53,17 +63,49 @@ where type SqlType = ST; } -impl LoadQuery for T +#[allow(missing_debug_implementations)] +pub struct LoadIter<'a, U, C, ST, DB> { + cursor: C, + _marker: std::marker::PhantomData<&'a (ST, U, DB)>, +} + +impl<'a, Conn, T, U, DB> LoadQueryRet<'a, Conn, U> for T +where + Conn: Connection, + T: AsQuery + RunQueryDsl, + T::Query: QueryFragment + QueryId, + T::SqlType: CompatibleType, + DB: Backend + QueryMetadata + 'static, + U: FromSqlRow<>::SqlType, DB> + 'static, + >::SqlType: 'static, +{ + type Ret = LoadIter< + 'a, + U, + >::Cursor, + >::SqlType, + DB, + >; +} + +impl LoadQuery for T where - Conn: Connection, + Conn: Connection, T: AsQuery + RunQueryDsl, - T::Query: QueryFragment + QueryId, - T::SqlType: CompatibleType, - Conn::Backend: QueryMetadata, - U: FromSqlRow<>::SqlType, Conn::Backend>, + T::Query: QueryFragment + QueryId, + T::SqlType: CompatibleType, + DB: Backend + QueryMetadata + 'static, + U: FromSqlRow<>::SqlType, DB> + 'static, + >::SqlType: 'static, { - fn internal_load(self, conn: &mut Conn) -> QueryResult> { - conn.load(self) + fn internal_load<'a>( + self, + conn: &'a mut Conn, + ) -> QueryResult<>::Ret> { + Ok(LoadIter { + cursor: conn.load(self)?, + _marker: Default::default(), + }) } } @@ -91,3 +133,68 @@ where conn.execute_returning_count(&query) } } + +impl<'a, C, U, ST, DB, R> LoadIter<'a, U, C, ST, DB> +where + DB: Backend, + C: Iterator>, + R: crate::row::Row<'a, DB>, + U: FromSqlRow, +{ + fn map_row(row: Option>) -> Option> { + match row? { + Ok(row) => { + Some(U::build_from_row(&row).map_err(crate::result::Error::DeserializationError)) + } + Err(e) => Some(Err(e)), + } + } +} + +impl<'a, C, U, ST, DB, R> Iterator for LoadIter<'a, U, C, ST, DB> +where + DB: Backend, + C: Iterator>, + R: crate::row::Row<'a, DB>, + U: FromSqlRow, +{ + type Item = QueryResult; + + fn next(&mut self) -> Option { + Self::map_row(self.cursor.next()) + } + + fn size_hint(&self) -> (usize, Option) { + self.cursor.size_hint() + } + + fn count(self) -> usize + where + Self: Sized, + { + self.cursor.count() + } + + fn last(self) -> Option + where + Self: Sized, + { + Self::map_row(self.cursor.last()) + } + + fn nth(&mut self, n: usize) -> Option { + Self::map_row(self.cursor.nth(n)) + } +} + +impl<'a, C, U, ST, DB, R> ExactSizeIterator for LoadIter<'a, U, C, ST, DB> +where + DB: Backend, + C: ExactSizeIterator + Iterator>, + R: crate::row::Row<'a, DB>, + U: FromSqlRow, +{ + fn len(&self) -> usize { + self.cursor.len() + } +} diff --git a/diesel/src/query_dsl/mod.rs b/diesel/src/query_dsl/mod.rs index dc0ab22a4081..b1e91cf4f0b8 100644 --- a/diesel/src/query_dsl/mod.rs +++ b/diesel/src/query_dsl/mod.rs @@ -1404,7 +1404,7 @@ pub trait RunQueryDsl: Sized { where Self: LoadQuery, { - self.internal_load(conn) + self.internal_load(conn)?.collect() } /// Runs the command, and returns the affected row. @@ -1456,7 +1456,7 @@ pub trait RunQueryDsl: Sized { where Self: LoadQuery, { - first_or_not_found(self.load(conn)) + first_or_not_found(self.internal_load(conn)) } /// Runs the command, returning an `Vec` with the affected rows. diff --git a/diesel/src/result.rs b/diesel/src/result.rs index 68d7c4049b6d..cbb8cec6cafa 100644 --- a/diesel/src/result.rs +++ b/diesel/src/result.rs @@ -360,8 +360,13 @@ fn error_impls_send() { let x: &Send = &err; } -pub(crate) fn first_or_not_found(records: QueryResult>) -> QueryResult { - records?.into_iter().next().ok_or(Error::NotFound) +pub(crate) fn first_or_not_found( + records: QueryResult>>, +) -> QueryResult { + match records?.next() { + Some(r) => r, + None => Err(Error::NotFound), + } } /// An unexpected `NULL` was encountered during deserialization diff --git a/diesel/src/row.rs b/diesel/src/row.rs index 9d9c29c555a6..63fcc887016a 100644 --- a/diesel/src/row.rs +++ b/diesel/src/row.rs @@ -34,7 +34,7 @@ pub trait Row<'a, DB: Backend>: RowIndex + for<'b> RowIndex<&'b str> + Si /// /// * Crates implementing custom backends should provide their own type /// meeting the required trait bounds - type Field: Field<'a, DB>; + type Field: Field; /// Return type of `PartialRow` /// @@ -63,15 +63,15 @@ pub trait Row<'a, DB: Backend>: RowIndex + for<'b> RowIndex<&'b str> + Si /// /// This trait allows retrieving information on the name of the colum and on the value of the /// field. -pub trait Field<'a, DB: Backend> { +pub trait Field { /// The name of the current field /// /// Returns `None` if it's an unnamed field - fn field_name(&self) -> Option<&'a str>; + fn field_name(&self) -> Option<&str>; /// Get the value representing the current field in the raw representation /// as it is transmitted by the database - fn value(&self) -> Option>; + fn value<'a>(&'a self) -> Option>; /// Checks whether this field is null or not. fn is_null(&self) -> bool { diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 6b7d9ac9e30a..1b5697fe48c9 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -22,6 +22,7 @@ use crate::expression::QueryMetadata; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; use crate::result::*; +use crate::row::Row; use crate::serialize::ToSql; use crate::sql_types::HasSqlType; use crate::sqlite::Sqlite; @@ -50,6 +51,11 @@ impl SimpleConnection for SqliteConnection { } } +impl<'a> IterableConnection<'a> for SqliteConnection { + type Cursor = StatementIterator<'a, 'a, (), ()>; + type Row = self::sqlite_value::SqliteRow<'a, 'a, 'a>; +} + impl Connection for SqliteConnection { type Backend = Sqlite; type TransactionManager = AnsiTransactionManager; @@ -81,18 +87,23 @@ impl Connection for SqliteConnection { } #[doc(hidden)] - fn load(&mut self, source: T) -> QueryResult> + fn load<'a, T, ST>( + &mut self, + source: T, + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, - T::SqlType: crate::query_dsl::load_dsl::CompatibleType, - U: FromSqlRow, Self::Backend: QueryMetadata, + >::Cursor: + Iterator>::Row>>, + for<'b> >::Row: Row<'b, Self::Backend>, { - let mut statement = self.prepare_query(&source.as_query())?; - let statement_use = StatementUse::new(&mut statement, true); - let iter = StatementIterator::<_, U>::new(statement_use); - iter.collect() + todo!() + // let mut statement = self.prepare_query(&source.as_query())?; + // let statement_use = StatementUse::new(&mut statement, true); + // let iter = StatementIterator::<_, U>::new(statement_use); + // iter.collect() } #[doc(hidden)] diff --git a/diesel/src/util.rs b/diesel/src/util.rs index ed0439d49ed5..2a0a1de3d14a 100644 --- a/diesel/src/util.rs +++ b/diesel/src/util.rs @@ -9,3 +9,7 @@ pub trait TupleAppend { pub trait TupleSize { const SIZE: usize; } + +mod once_cell; + +pub(crate) use self::once_cell::OnceCell; diff --git a/diesel/src/util/once_cell.rs b/diesel/src/util/once_cell.rs new file mode 100644 index 000000000000..790ba6592efa --- /dev/null +++ b/diesel/src/util/once_cell.rs @@ -0,0 +1,112 @@ +// This is a copy of the unstable `OnceCell` implementation in rusts std-library +// https://github.com/rust-lang/rust/blob/1160cf864f2a0014e3442367e1b96496bfbeadf4/library/core/src/lazy.rs#L8-L276 +// +// See https://github.com/rust-lang/rust/issues/74465 for the corresponding tracking issue + +use std::cell::UnsafeCell; + +/// A cell which can be written to only once. +/// +/// Unlike `RefCell`, a `OnceCell` only provides shared `&T` references to its value. +/// Unlike `Cell`, a `OnceCell` doesn't require copying or replacing the value to access it. +/// +/// # Examples +/// +/// ``` +/// #![feature(once_cell)] +/// +/// use std::lazy::OnceCell; +/// +/// let cell = OnceCell::new(); +/// assert!(cell.get().is_none()); +/// +/// let value: &String = cell.get_or_init(|| { +/// "Hello, World!".to_string() +/// }); +/// assert_eq!(value, "Hello, World!"); +/// assert!(cell.get().is_some()); +/// ``` +pub struct OnceCell { + // Invariant: written to at most once. + inner: UnsafeCell>, +} + +impl Default for OnceCell { + fn default() -> Self { + Self::new() + } +} + +impl OnceCell { + /// Creates a new empty cell. + pub const fn new() -> OnceCell { + OnceCell { + inner: UnsafeCell::new(None), + } + } + + /// Gets the contents of the cell, initializing it with `f` if + /// the cell was empty. If the cell was empty and `f` failed, an + /// error is returned. + /// + /// # Panics + /// + /// If `f` panics, the panic is propagated to the caller, and the cell + /// remains uninitialized. + /// + /// It is an error to reentrantly initialize the cell from `f`. Doing + /// so results in a panic. + /// + /// # Examples + /// + /// ``` + /// #![feature(once_cell)] + /// + /// use std::lazy::OnceCell; + /// + /// let cell = OnceCell::new(); + /// assert_eq!(cell.get_or_try_init(|| Err(())), Err(())); + /// assert!(cell.get().is_none()); + /// let value = cell.get_or_try_init(|| -> Result { + /// Ok(92) + /// }); + /// assert_eq!(value, Ok(&92)); + /// assert_eq!(cell.get(), Some(&92)) + /// ``` + pub fn get_or_init(&self, f: F) -> &T + where + F: FnOnce() -> T, + { + if let Some(val) = self.get() { + return val; + } + let val = f(); + // Note that *some* forms of reentrant initialization might lead to + // UB (see `reentrant_init` test). I believe that just removing this + // `assert`, while keeping `set/get` would be sound, but it seems + // better to panic, rather than to silently use an old value. + assert!(self.set(val).is_ok(), "reentrant init"); + self.get().unwrap() + } + + fn get(&self) -> Option<&T> { + // SAFETY: Safe due to `inner`'s invariant + unsafe { &*self.inner.get() }.as_ref() + } + + fn set(&self, value: T) -> Result<(), T> { + // SAFETY: Safe because we cannot have overlapping mutable borrows + let slot = unsafe { &*self.inner.get() }; + if slot.is_some() { + return Err(value); + } + + // SAFETY: This is the only place where we set the slot, no races + // due to reentrancy/concurrency are possible, and we've + // checked that slot is currently `None`, so this write + // maintains the `inner`'s invariant. + let slot = unsafe { &mut *self.inner.get() }; + *slot = Some(value); + Ok(()) + } +} diff --git a/diesel_tests/tests/types.rs b/diesel_tests/tests/types.rs index d36d675b6117..0d07cab6cd47 100644 --- a/diesel_tests/tests/types.rs +++ b/diesel_tests/tests/types.rs @@ -1241,10 +1241,11 @@ fn third_party_crates_can_add_new_types() { assert_eq!(70_000, query_single_value::("70000")); } -fn query_single_value>(sql_str: &str) -> U +fn query_single_value(sql_str: &str) -> U where + U: FromSqlRow + 'static, TestBackend: HasSqlType, - T: QueryId + SingleValue + SqlType, + T: QueryId + SingleValue + SqlType + 'static, { use diesel::dsl::sql; let connection = &mut connection(); diff --git a/diesel_tests/tests/types_roundtrip.rs b/diesel_tests/tests/types_roundtrip.rs index 94757819bfdf..b4c5995c72e1 100644 --- a/diesel_tests/tests/types_roundtrip.rs +++ b/diesel_tests/tests/types_roundtrip.rs @@ -20,13 +20,14 @@ use std::collections::Bound; pub fn test_type_round_trips(value: T) -> bool where - ST: QueryId + SqlType + TypedExpressionType + SingleValue, + ST: QueryId + SqlType + TypedExpressionType + SingleValue + 'static, ::Backend: HasSqlType, T: AsExpression + FromSqlRow::Backend> + PartialEq + Clone - + ::std::fmt::Debug, + + ::std::fmt::Debug + + 'static, >::Expression: SelectableExpression<(), SqlType = ST> + NonAggregate + QueryFragment<::Backend> From ef2d5e714f16f2fd24ce93d82cff9fba6ce51b3a Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Wed, 2 Jun 2021 12:08:09 +0200 Subject: [PATCH 02/32] WIP --- diesel/Cargo.toml | 2 +- diesel/src/expression/array_comparison.rs | 14 +++ diesel/src/mysql/connection/bind.rs | 13 +- diesel/src/mysql/connection/mod.rs | 30 ++--- diesel/src/mysql/connection/stmt/iterator.rs | 120 +++++++++++++------ diesel/src/mysql/connection/stmt/mod.rs | 12 +- 6 files changed, 128 insertions(+), 63 deletions(-) diff --git a/diesel/Cargo.toml b/diesel/Cargo.toml index 85d3fd26ec79..fe8d3ec31da7 100644 --- a/diesel/Cargo.toml +++ b/diesel/Cargo.toml @@ -44,7 +44,7 @@ ipnetwork = ">=0.12.2, <0.19.0" quickcheck = "0.9" [features] -default = ["postgres"] +default = ["postgres", "mysql"] extras = ["chrono", "serde_json", "uuid", "network-address", "numeric", "r2d2"] unstable = ["diesel_derives/nightly"] large-tables = ["32-column-tables"] diff --git a/diesel/src/expression/array_comparison.rs b/diesel/src/expression/array_comparison.rs index 7c8a3c8387c5..afc2384a9c4b 100644 --- a/diesel/src/expression/array_comparison.rs +++ b/diesel/src/expression/array_comparison.rs @@ -119,6 +119,17 @@ where } } +// impl AsInExpression for [T; N] +// where T: AsExpression, +// ST: SqlType + TypedExpressionType +// { +// type InExpression = StaticMany; + +// fn as_in_expression(self) -> Self::InExpression { +// todo!() +// } +// } + pub trait MaybeEmpty { fn is_empty(&self) -> bool; } @@ -149,6 +160,9 @@ where } } +// #[derive(Debug, Clone, ValidGrouping)] +// pub struct StaticMany([T; N]); + #[derive(Debug, Clone, ValidGrouping)] pub struct Many(Vec); diff --git a/diesel/src/mysql/connection/bind.rs b/diesel/src/mysql/connection/bind.rs index a3b894c5bec7..7b2fa7e43f77 100644 --- a/diesel/src/mysql/connection/bind.rs +++ b/diesel/src/mysql/connection/bind.rs @@ -28,11 +28,16 @@ impl Binds { Ok(Binds { data }) } - pub fn from_output_types(types: Vec>, metadata: &StatementMetadata) -> Self { + pub fn from_output_types(types: &[Option], metadata: &StatementMetadata) -> Self { let data = metadata .fields() .iter() - .zip(types.into_iter().chain(std::iter::repeat(None))) + .zip( + types + .into_iter() + .map(|o| o.as_ref()) + .chain(std::iter::repeat(None)), + ) .map(|(field, tpe)| BindData::for_output(tpe, field)) .collect(); @@ -148,7 +153,7 @@ impl BindData { } } - fn for_output(tpe: Option, metadata: &MysqlFieldMetadata) -> Self { + fn for_output(tpe: Option<&MysqlType>, metadata: &MysqlFieldMetadata) -> Self { let (tpe, flags) = if let Some(tpe) = tpe { match (tpe, metadata.field_type()) { // Those are types where we handle the conversion in diesel itself @@ -275,7 +280,7 @@ impl BindData { (metadata.field_type(), metadata.flags()) } - (tpe, _) => tpe.into(), + (tpe, _) => (*tpe).into(), } } else { (metadata.field_type(), metadata.flags()) diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index 2222f5d2bace..e5d4b486d4df 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -8,13 +8,10 @@ use self::stmt::Statement; use self::url::ConnectionOptions; use super::backend::Mysql; use crate::connection::*; -use crate::deserialize::FromSqlRow; use crate::expression::QueryMetadata; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; -use crate::query_dsl::load_dsl::CompatibleType; use crate::result::*; -use crate::row::Row; #[allow(missing_debug_implementations, missing_copy_implementations)] /// A connection to a MySQL database. Connection URLs should be in the form @@ -34,12 +31,12 @@ impl SimpleConnection for MysqlConnection { } } -impl<'a> IterableConnection<'a> for MysqlConnection { +impl<'a> IterableConnection<'a, Mysql> for MysqlConnection { type Cursor = self::stmt::iterator::StatementIterator<'a>; type Row = self::stmt::iterator::MysqlRow<'a>; } -/*impl Connection for MysqlConnection { +impl Connection for MysqlConnection { type Backend = Mysql; type TransactionManager = AnsiTransactionManager; @@ -67,27 +64,20 @@ impl<'a> IterableConnection<'a> for MysqlConnection { } #[doc(hidden)] - fn load<'a, T, ST>( + fn load<'a, T>( &'a mut self, source: T, - ) -> QueryResult<>::Cursor> + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, Self::Backend: QueryMetadata, - Self: IterableConnection<'a>, - >::Cursor: - Iterator>::Row>>, - for<'b> >::Row: Row<'b, Self::Backend>, { - todo!() - // use crate::result::Error::DeserializationError; - - // let mut stmt = self.prepare_query(&source.as_query())?; - // let mut metadata = Vec::new(); - // Mysql::row_metadata(&mut (), &mut metadata); - // let results = unsafe { stmt.results(metadata)? }; - // results.map(|row| U::build_from_row(&row).map_err(DeserializationError)) + let mut stmt = self.prepare_query(&source.as_query())?; + let mut metadata = Vec::new(); + Mysql::row_metadata(&mut (), &mut metadata); + let results = unsafe { stmt.results(metadata)? }; + Ok(results) } #[doc(hidden)] @@ -135,7 +125,7 @@ impl MysqlConnection { self.execute("SET character_set_results = 'utf8mb4'")?; Ok(()) } -}*/ +} #[cfg(test)] mod tests { diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index 584b66f949ae..e6a78e48298b 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -1,60 +1,104 @@ -use super::{metadata::MysqlFieldMetadata, BindData, Binds, Statement, StatementMetadata}; +use std::marker::PhantomData; +use std::rc::Rc; + +use super::{Binds, Statement, StatementMetadata}; +use super::metadata::MysqlFieldMetadata; use crate::mysql::{Mysql, MysqlType}; use crate::result::QueryResult; use crate::row::*; pub struct StatementIterator<'a> { stmt: &'a mut Statement, - output_binds: Binds, - metadata: StatementMetadata, + output_binds: Rc, + metadata: Rc, + types: Vec>, + size: usize, + fetched_rows: usize, } -#[allow(clippy::should_implement_trait)] // don't neet `Iterator` here impl<'a> StatementIterator<'a> { #[allow(clippy::new_ret_no_self)] pub fn new(stmt: &'a mut Statement, types: Vec>) -> QueryResult { let metadata = stmt.metadata()?; - let mut output_binds = Binds::from_output_types(types, &metadata); + let mut output_binds = Binds::from_output_types(&types, &metadata); stmt.execute_statement(&mut output_binds)?; + let size = unsafe { stmt.result_size() }?; Ok(StatementIterator { + metadata: Rc::new(metadata), + output_binds: Rc::new(output_binds), + fetched_rows: 0, + size, stmt, - output_binds, - metadata, + types, }) } +} + +impl<'a> Iterator for StatementIterator<'a> { + type Item = QueryResult>; + + fn next(&mut self) -> Option { + // check if we own the only instance of the bind buffer + // if that's the case we can reuse the underlying allocations + // if that's not the case, allocate a new buffer + let res = if let Some(binds) = Rc::get_mut(&mut self.output_binds) { + self.stmt + .populate_row_buffers(binds) + .map(|o| o.map(|()| self.output_binds.clone())) + } else { + // The shared bind buffer is in use by someone else, + // we allocate a new buffer here + let mut output_binds = Binds::from_output_types(&self.types, &self.metadata); + self.stmt + .populate_row_buffers(&mut output_binds) + .map(|o| o.map(|()| Rc::new(output_binds))) + }; + + match res { + Ok(Some(binds)) => { + self.fetched_rows += 1; + Some(Ok(MysqlRow { + col_idx: 0, + binds, + metadata: self.metadata.clone(), + _marker: Default::default(), + })) + } + Ok(None) => None, + Err(e) => { + self.fetched_rows += 1; + Some(Err(e)) + } + } + } - pub fn map(mut self, mut f: F) -> QueryResult> + fn size_hint(&self) -> (usize, Option) { + (self.len(), Some(self.len())) + } + + fn count(self) -> usize where - F: FnMut(MysqlRow) -> QueryResult, + Self: Sized, { - let mut results = Vec::new(); - while let Some(row) = self.next() { - results.push(f(row?)?); - } - Ok(results) + self.len() } +} - fn next(&mut self) -> Option> { - match self.stmt.populate_row_buffers(&mut self.output_binds) { - Ok(Some(())) => Some(Ok(MysqlRow { - col_idx: 0, - binds: &mut self.output_binds, - metadata: &self.metadata, - })), - Ok(None) => None, - Err(e) => Some(Err(e)), - } +impl<'a> ExactSizeIterator for StatementIterator<'a> { + fn len(&self) -> usize { + self.size - self.fetched_rows } } #[derive(Clone)] pub struct MysqlRow<'a> { col_idx: usize, - binds: &'a Binds, - metadata: &'a StatementMetadata, + binds: Rc, + metadata: Rc, + _marker: PhantomData<&'a mut (Binds, StatementMetadata)>, } impl<'a> Row<'a, Mysql> for MysqlRow<'a> { @@ -71,8 +115,10 @@ impl<'a> Row<'a, Mysql> for MysqlRow<'a> { { let idx = self.idx(idx)?; Some(MysqlField { - bind: &self.binds[idx], - metadata: &self.metadata.fields()[idx], + bind: self.binds.clone(), + metadata: self.metadata.clone(), + idx, + _marker: Default::default(), }) } @@ -103,20 +149,22 @@ impl<'a, 'b> RowIndex<&'a str> for MysqlRow<'b> { } pub struct MysqlField<'a> { - bind: &'a BindData, - metadata: &'a MysqlFieldMetadata<'a>, + bind: Rc, + metadata: Rc, + idx: usize, + _marker: PhantomData<&'a (Binds, StatementMetadata)> } -impl<'a> Field<'a, Mysql> for MysqlField<'a> { - fn field_name(&self) -> Option<&'a str> { - self.metadata.field_name() +impl<'a> Field for MysqlField<'a> { + fn field_name(&self) -> Option<&str> { + self.metadata.fields()[self.idx].field_name() } fn is_null(&self) -> bool { - self.bind.is_null() + (*self.bind)[self.idx].is_null() } - fn value(&self) -> Option> { - self.bind.value() + fn value<'b>(&'b self) -> Option> { + self.bind[self.idx].value() } } diff --git a/diesel/src/mysql/connection/stmt/mod.rs b/diesel/src/mysql/connection/stmt/mod.rs index 291dbf441c3c..31b556bf2bcb 100644 --- a/diesel/src/mysql/connection/stmt/mod.rs +++ b/diesel/src/mysql/connection/stmt/mod.rs @@ -3,6 +3,7 @@ extern crate mysqlclient_sys as ffi; pub mod iterator; mod metadata; +use std::convert::TryFrom; use std::ffi::CStr; use std::os::raw as libc; use std::ptr::NonNull; @@ -10,7 +11,7 @@ use std::ptr::NonNull; use self::iterator::*; use super::bind::{BindData, Binds}; use crate::mysql::MysqlType; -use crate::result::{DatabaseErrorKind, QueryResult}; +use crate::result::{DatabaseErrorKind, Error, QueryResult}; pub use self::metadata::{MysqlFieldMetadata, StatementMetadata}; @@ -80,12 +81,19 @@ impl Statement { /// have a return value. After calling this function, `execute` can never /// be called on this statement. pub unsafe fn results( - &mut self, + self, types: Vec>, ) -> QueryResult { StatementIterator::new(self, types) } + /// This function should be called after `execute` only + /// otherwise it's not guranteed to return a valid result + pub(in crate::mysql::connection) unsafe fn result_size(&mut self) -> QueryResult { + let size = ffi::mysql_stmt_num_rows(self.stmt.as_ptr()); + usize::try_from(size).map_err(|e| Error::DeserializationError(Box::new(e))) + } + fn last_error_message(&self) -> String { unsafe { CStr::from_ptr(ffi::mysql_stmt_error(self.stmt.as_ptr())) } .to_string_lossy() From 1cd7bb359f449892149b8bd73320451865b82b1f Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 11 Jun 2021 12:31:50 +0200 Subject: [PATCH 03/32] Port the mysql backend to use iterators --- diesel/src/mysql/connection/mod.rs | 45 +++++++++++++------- diesel/src/mysql/connection/stmt/iterator.rs | 3 +- diesel/src/mysql/connection/stmt/mod.rs | 8 ++-- 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index e5d4b486d4df..efae0ca6a6bf 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -20,6 +20,7 @@ pub struct MysqlConnection { raw_connection: RawConnection, transaction_state: AnsiTransactionManager, statement_cache: StatementCache, + current_statement: Option, } unsafe impl Send for MysqlConnection {} @@ -50,6 +51,7 @@ impl Connection for MysqlConnection { raw_connection, transaction_state: AnsiTransactionManager::default(), statement_cache: StatementCache::new(), + current_statement: None, }; conn.set_config_options() .map_err(CouldntSetupConfiguration)?; @@ -73,11 +75,22 @@ impl Connection for MysqlConnection { T::Query: QueryFragment + QueryId, Self::Backend: QueryMetadata, { - let mut stmt = self.prepare_query(&source.as_query())?; - let mut metadata = Vec::new(); - Mysql::row_metadata(&mut (), &mut metadata); - let results = unsafe { stmt.results(metadata)? }; - Ok(results) + self.with_prepared_query(&source.as_query(), |stmt, current_statement| { + let mut metadata = Vec::new(); + Mysql::row_metadata(&mut (), &mut metadata); + let stmt = match stmt { + MaybeCached::CannotCache(stmt) => { + *current_statement = Some(stmt); + current_statement + .as_mut() + .expect("We set it literally above") + } + MaybeCached::Cached(stmt) => stmt, + }; + + let results = unsafe { stmt.results(metadata)? }; + Ok(results) + }) } #[doc(hidden)] @@ -85,11 +98,12 @@ impl Connection for MysqlConnection { where T: QueryFragment + QueryId, { - let stmt = self.prepare_query(source)?; - unsafe { - stmt.execute()?; - } - Ok(stmt.affected_rows()) + self.with_prepared_query(source, |stmt, _| { + unsafe { + stmt.execute()?; + } + Ok(stmt.affected_rows()) + }) } #[doc(hidden)] @@ -99,10 +113,11 @@ impl Connection for MysqlConnection { } impl MysqlConnection { - fn prepare_query(&mut self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - { + fn with_prepared_query<'a, T: QueryFragment + QueryId, R>( + &'a mut self, + source: &'_ T, + f: impl FnOnce(MaybeCached<'a, Statement>, &'a mut Option) -> QueryResult, + ) -> QueryResult { let cache = &mut self.statement_cache; let conn = &mut self.raw_connection; @@ -114,7 +129,7 @@ impl MysqlConnection { .into_iter() .zip(bind_collector.binds); stmt.bind(binds)?; - Ok(stmt) + f(stmt, &mut self.current_statement) } fn set_config_options(&mut self) -> QueryResult<()> { diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index e6a78e48298b..854e5c86bee7 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -2,7 +2,6 @@ use std::marker::PhantomData; use std::rc::Rc; use super::{Binds, Statement, StatementMetadata}; -use super::metadata::MysqlFieldMetadata; use crate::mysql::{Mysql, MysqlType}; use crate::result::QueryResult; use crate::row::*; @@ -152,7 +151,7 @@ pub struct MysqlField<'a> { bind: Rc, metadata: Rc, idx: usize, - _marker: PhantomData<&'a (Binds, StatementMetadata)> + _marker: PhantomData<&'a (Binds, StatementMetadata)>, } impl<'a> Field for MysqlField<'a> { diff --git a/diesel/src/mysql/connection/stmt/mod.rs b/diesel/src/mysql/connection/stmt/mod.rs index 31b556bf2bcb..6c828a786d7b 100644 --- a/diesel/src/mysql/connection/stmt/mod.rs +++ b/diesel/src/mysql/connection/stmt/mod.rs @@ -9,7 +9,7 @@ use std::os::raw as libc; use std::ptr::NonNull; use self::iterator::*; -use super::bind::{BindData, Binds}; +use super::bind::Binds; use crate::mysql::MysqlType; use crate::result::{DatabaseErrorKind, Error, QueryResult}; @@ -80,10 +80,10 @@ impl Statement { /// This function should be called instead of `execute` for queries which /// have a return value. After calling this function, `execute` can never /// be called on this statement. - pub unsafe fn results( - self, + pub unsafe fn results<'a>( + &'a mut self, types: Vec>, - ) -> QueryResult { + ) -> QueryResult> { StatementIterator::new(self, types) } From 91124aa41af4912e4830ea4ef97ad8553c7f59f5 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 11 Jun 2021 12:36:49 +0200 Subject: [PATCH 04/32] Enable CI for mysql again --- .github/workflows/benches.yml | 2 +- .github/workflows/ci.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benches.yml b/.github/workflows/benches.yml index 58a5af88df4c..f1e08cfe14f8 100644 --- a/.github/workflows/benches.yml +++ b/.github/workflows/benches.yml @@ -12,7 +12,7 @@ jobs: strategy: fail-fast: false matrix: - backend: ["postgres"] #, "sqlite", "mysql"] + backend: ["postgres", "mysql"] #, "sqlite", "mysql"] steps: - name: Checkout sources uses: actions/checkout@v2 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 44c0c5ab4a37..9ec6d80db4ed 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,7 +23,7 @@ jobs: matrix: rust: ["stable", "beta", "nightly"] # backend: ["postgres", "sqlite", "mysql"] - backend: ["postgres"] + backend: ["postgres", "mysql"] os: [ubuntu-20.04, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: From 6071d418dcfb96699f20852ace9f4d06bbb6f2a7 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 17 Jun 2021 15:53:22 +0200 Subject: [PATCH 05/32] Adjust sqlite connection implementation to return an iterator --- Cargo.toml | 40 ++--- diesel/Cargo.toml | 2 +- diesel/src/connection/mod.rs | 9 ++ diesel/src/mysql/connection/stmt/iterator.rs | 3 + diesel/src/mysql/connection/stmt/mod.rs | 4 +- diesel/src/sqlite/backend.rs | 2 +- diesel/src/sqlite/connection/functions.rs | 6 +- diesel/src/sqlite/connection/mod.rs | 75 +++++---- diesel/src/sqlite/connection/sqlite_value.rs | 153 +++++++----------- .../sqlite/connection/statement_iterator.rs | 108 ++++++++++--- diesel/src/sqlite/connection/stmt.rs | 68 ++++---- diesel/src/sqlite/types/date_and_time/mod.rs | 6 +- diesel/src/sqlite/types/mod.rs | 16 +- diesel/src/sqlite/types/numeric.rs | 2 +- diesel/src/util/once_cell.rs | 4 +- 15 files changed, 265 insertions(+), 233 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 56a73eccdf0d..3dce5ddff7b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,26 +1,26 @@ [workspace] members = [ "diesel", -# "diesel_cli", + "diesel_cli", "diesel_derives", "diesel_tests", - # "diesel_migrations", - # "diesel_migrations/migrations_internals", - # "diesel_migrations/migrations_macros", - #"diesel_dynamic_schema", - # "examples/mysql/all_about_inserts", - # "examples/mysql/getting_started_step_1", - # "examples/mysql/getting_started_step_2", - # "examples/mysql/getting_started_step_3", - # "examples/postgres/advanced-blog-cli", - # "examples/postgres/all_about_inserts", - # "examples/postgres/all_about_updates", - # "examples/postgres/getting_started_step_1", - # "examples/postgres/getting_started_step_2", - # "examples/postgres/getting_started_step_3", - # "examples/postgres/custom_types", - # "examples/sqlite/all_about_inserts", - # "examples/sqlite/getting_started_step_1", - # "examples/sqlite/getting_started_step_2", - # "examples/sqlite/getting_started_step_3", + "diesel_migrations", + "diesel_migrations/migrations_internals", + "diesel_migrations/migrations_macros", + "diesel_dynamic_schema", + "examples/mysql/all_about_inserts", + "examples/mysql/getting_started_step_1", + "examples/mysql/getting_started_step_2", + "examples/mysql/getting_started_step_3", + "examples/postgres/advanced-blog-cli", + "examples/postgres/all_about_inserts", + "examples/postgres/all_about_updates", + "examples/postgres/getting_started_step_1", + "examples/postgres/getting_started_step_2", + "examples/postgres/getting_started_step_3", + "examples/postgres/custom_types", + "examples/sqlite/all_about_inserts", + "examples/sqlite/getting_started_step_1", + "examples/sqlite/getting_started_step_2", + "examples/sqlite/getting_started_step_3", ] diff --git a/diesel/Cargo.toml b/diesel/Cargo.toml index fe8d3ec31da7..14412097fe13 100644 --- a/diesel/Cargo.toml +++ b/diesel/Cargo.toml @@ -44,7 +44,7 @@ ipnetwork = ">=0.12.2, <0.19.0" quickcheck = "0.9" [features] -default = ["postgres", "mysql"] +default = ["32-column-tables", "without-deprecated"] extras = ["chrono", "serde_json", "uuid", "network-address", "numeric", "r2d2"] unstable = ["diesel_derives/nightly"] large-tables = ["32-column-tables"] diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index bcfe9b4db04b..48ffba21c09f 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -25,8 +25,17 @@ pub trait SimpleConnection { fn batch_execute(&mut self, query: &str) -> QueryResult<()>; } +/// This trait describes which cursor type is used by a given connection +/// implementation. This trait is only useful in combination with [`Connection`]. +/// +/// Implementation wise this is a workaround for GAT types pub trait IterableConnection<'a, DB: Backend> { + /// The cursor type returned by [`Connection::load`] + /// + /// Users should handle this as opaque type that implements [`Iterator`] type Cursor: Iterator>; + /// The row type used as [`Iterator::Item`] for the iterator implementation + /// of [`IterableConnection::Cursor`] type Row: crate::row::Row<'a, DB>; } diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index 854e5c86bee7..84058734dd11 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -6,6 +6,7 @@ use crate::mysql::{Mysql, MysqlType}; use crate::result::QueryResult; use crate::row::*; +#[allow(missing_debug_implementations)] pub struct StatementIterator<'a> { stmt: &'a mut Statement, output_binds: Rc, @@ -93,6 +94,7 @@ impl<'a> ExactSizeIterator for StatementIterator<'a> { } #[derive(Clone)] +#[allow(missing_debug_implementations)] pub struct MysqlRow<'a> { col_idx: usize, binds: Rc, @@ -147,6 +149,7 @@ impl<'a, 'b> RowIndex<&'a str> for MysqlRow<'b> { } } +#[allow(missing_debug_implementations)] pub struct MysqlField<'a> { bind: Rc, metadata: Rc, diff --git a/diesel/src/mysql/connection/stmt/mod.rs b/diesel/src/mysql/connection/stmt/mod.rs index 6c828a786d7b..4b9998349efb 100644 --- a/diesel/src/mysql/connection/stmt/mod.rs +++ b/diesel/src/mysql/connection/stmt/mod.rs @@ -1,6 +1,6 @@ extern crate mysqlclient_sys as ffi; -pub mod iterator; +pub(super) mod iterator; mod metadata; use std::convert::TryFrom; @@ -15,7 +15,7 @@ use crate::result::{DatabaseErrorKind, Error, QueryResult}; pub use self::metadata::{MysqlFieldMetadata, StatementMetadata}; -#[allow(dead_code)] +#[allow(dead_code, missing_debug_implementations)] // https://github.com/rust-lang/rust/issues/81658 pub struct Statement { stmt: NonNull, diff --git a/diesel/src/sqlite/backend.rs b/diesel/src/sqlite/backend.rs index 1c6af1637e82..41a89447e051 100644 --- a/diesel/src/sqlite/backend.rs +++ b/diesel/src/sqlite/backend.rs @@ -45,7 +45,7 @@ impl Backend for Sqlite { } impl<'a> HasRawValue<'a> for Sqlite { - type RawValue = SqliteValue<'a>; + type RawValue = &'a SqliteValue; } impl TypeMetadata for Sqlite { diff --git a/diesel/src/sqlite/connection/functions.rs b/diesel/src/sqlite/connection/functions.rs index 79034c73a885..b22cdf41570b 100644 --- a/diesel/src/sqlite/connection/functions.rs +++ b/diesel/src/sqlite/connection/functions.rs @@ -174,8 +174,8 @@ struct FunctionArgument<'a> { p: PhantomData<&'a ()>, } -impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> { - fn field_name(&self) -> Option<&'a str> { +impl<'a> Field for FunctionArgument<'a> { + fn field_name(&self) -> Option<&str> { None } @@ -183,7 +183,7 @@ impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> { self.value().is_none() } - fn value(&self) -> Option> { + fn value<'b>(&'b self) -> Option> { unsafe { SqliteValue::new(self.arg) } } } diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 1b5697fe48c9..2d4932d22116 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -3,6 +3,7 @@ extern crate libsqlite3_sys as ffi; mod functions; #[doc(hidden)] pub mod raw; +mod row; mod serialized_value; mod sqlite_value; mod statement_iterator; @@ -22,7 +23,6 @@ use crate::expression::QueryMetadata; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; use crate::result::*; -use crate::row::Row; use crate::serialize::ToSql; use crate::sql_types::HasSqlType; use crate::sqlite::Sqlite; @@ -35,7 +35,11 @@ use crate::sqlite::Sqlite; /// - Special identifiers (`:memory:`) #[allow(missing_debug_implementations)] pub struct SqliteConnection { + // Both statement_cache and current_statement needs to be before raw_connection + // otherwise we will get errors about open statements before closing the + // connection itself statement_cache: StatementCache, + current_statement: Option, raw_connection: RawConnection, transaction_state: AnsiTransactionManager, } @@ -51,9 +55,9 @@ impl SimpleConnection for SqliteConnection { } } -impl<'a> IterableConnection<'a> for SqliteConnection { - type Cursor = StatementIterator<'a, 'a, (), ()>; - type Row = self::sqlite_value::SqliteRow<'a, 'a, 'a>; +impl<'a> IterableConnection<'a, Sqlite> for SqliteConnection { + type Cursor = StatementIterator<'a, 'a>; + type Row = self::row::SqliteRow<'a, 'a>; } impl Connection for SqliteConnection { @@ -74,6 +78,7 @@ impl Connection for SqliteConnection { statement_cache: StatementCache::new(), raw_connection, transaction_state: AnsiTransactionManager::default(), + current_statement: None, }; conn.register_diesel_sql_functions() .map_err(CouldntSetupConfiguration)?; @@ -87,23 +92,28 @@ impl Connection for SqliteConnection { } #[doc(hidden)] - fn load<'a, T, ST>( - &mut self, + fn load<'a, T>( + &'a mut self, source: T, - ) -> QueryResult<>::Cursor> + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, Self::Backend: QueryMetadata, - >::Cursor: - Iterator>::Row>>, - for<'b> >::Row: Row<'b, Self::Backend>, { - todo!() - // let mut statement = self.prepare_query(&source.as_query())?; - // let statement_use = StatementUse::new(&mut statement, true); - // let iter = StatementIterator::<_, U>::new(statement_use); - // iter.collect() + self.with_prepared_query(&source.as_query(), |stmt, current_statement| { + let statement = match stmt { + MaybeCached::CannotCache(stmt) => { + *current_statement = Some(stmt); + current_statement + .as_mut() + .expect("We set it literally above") + } + MaybeCached::Cached(stmt) => stmt, + }; + let statement_use = StatementUse::new(statement); + Ok(StatementIterator::new(statement_use)) + }) } #[doc(hidden)] @@ -111,11 +121,11 @@ impl Connection for SqliteConnection { where T: QueryFragment + QueryId, { - { - let mut statement = self.prepare_query(source)?; - let mut statement_use = StatementUse::new(&mut statement, false); - statement_use.run()?; - } + self.with_prepared_query(source, |mut stmt, _| { + let statement_use = StatementUse::new(&mut stmt); + statement_use.run() + })?; + Ok(self.raw_connection.rows_affected_by_last_query()) } @@ -204,11 +214,15 @@ impl SqliteConnection { } } - fn prepare_query + QueryId>( - &mut self, - source: &T, - ) -> QueryResult> { - let mut statement = self.cached_prepared_statement(source)?; + fn with_prepared_query<'a, T: QueryFragment + QueryId, R>( + &'a mut self, + source: &'_ T, + f: impl FnOnce(MaybeCached<'a, Statement>, &'a mut Option) -> QueryResult, + ) -> QueryResult { + let raw_connection = &self.raw_connection; + let cache = &mut self.statement_cache; + let mut statement = + cache.cached_statement(source, &[], |sql| Statement::prepare(raw_connection, sql))?; let mut bind_collector = RawBytesBindCollector::::new(); source.collect_binds(&mut bind_collector, &mut ())?; @@ -218,16 +232,7 @@ impl SqliteConnection { statement.bind(tpe, value)?; } - Ok(statement) - } - - fn cached_prepared_statement + QueryId>( - &mut self, - source: &T, - ) -> QueryResult> { - let raw_connection = &self.raw_connection; - let cache = &mut self.statement_cache; - cache.cached_statement(source, &[], |sql| Statement::prepare(raw_connection, sql)) + f(statement, &mut self.current_statement) } #[doc(hidden)] diff --git a/diesel/src/sqlite/connection/sqlite_value.rs b/diesel/src/sqlite/connection/sqlite_value.rs index 8136d9087acd..2260b15213d4 100644 --- a/diesel/src/sqlite/connection/sqlite_value.rs +++ b/diesel/src/sqlite/connection/sqlite_value.rs @@ -1,50 +1,56 @@ extern crate libsqlite3_sys as ffi; -use std::marker::PhantomData; use std::ptr::NonNull; use std::{slice, str}; -use crate::row::*; -use crate::sqlite::{Sqlite, SqliteType}; +use crate::sqlite::SqliteType; -use super::stmt::StatementUse; +extern "C" { + pub fn sqlite3_value_free(value: *mut ffi::sqlite3_value); + pub fn sqlite3_value_dup(value: *const ffi::sqlite3_value) -> *mut ffi::sqlite3_value; +} /// Raw sqlite value as received from the database /// /// Use existing `FromSql` implementations to convert this into -/// rust values: +/// rust values #[allow(missing_debug_implementations, missing_copy_implementations)] -pub struct SqliteValue<'a> { - value: NonNull, - p: PhantomData<&'a ()>, +#[repr(C)] +pub struct SqliteValue { + value: ffi::sqlite3_value, } -pub struct SqliteRow<'a: 'b, 'b: 'c, 'c> { - stmt: &'c StatementUse<'a, 'b>, +pub struct OwnedSqliteValue { + pub(super) value: NonNull, } -impl<'a> SqliteValue<'a> { - pub(crate) unsafe fn new(inner: *mut ffi::sqlite3_value) -> Option { - NonNull::new(inner) - .map(|value| SqliteValue { - value, - p: PhantomData, - }) - .and_then(|value| { - // We check here that the actual value represented by the inner - // `sqlite3_value` is not `NULL` (is sql meaning, not ptr meaning) - if value.is_null() { - None - } else { - Some(value) - } - }) +impl Drop for OwnedSqliteValue { + fn drop(&mut self) { + unsafe { sqlite3_value_free(self.value.as_ptr()) } + } +} + +impl SqliteValue { + pub(crate) unsafe fn new<'a>(inner: *mut ffi::sqlite3_value) -> Option<&'a Self> { + let ptr = NonNull::new(inner as *mut SqliteValue)?; + // This cast is allowed because value is the only field + // of this struct and this cast is allowed in C + we have a `#[repr(C)]` + // on this type to fore the layout to be the same + // (I(weiznich) would like to use `#[repr(transparent)]` here instead, but + // that does not work as of rust 1.48 + let value = &*ptr.as_ptr(); + // We check if the SQL value is NULL here (in the SQL meaning, not in the ptr meaning) + if value.is_null() { + None + } else { + Some(value) + } } pub(crate) fn read_text(&self) -> &str { unsafe { - let ptr = ffi::sqlite3_value_text(self.value.as_ptr()); - let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); + let ptr = ffi::sqlite3_value_text(&self.value as *const _ as *mut ffi::sqlite3_value); + let len = ffi::sqlite3_value_bytes(&self.value as *const _ as *mut ffi::sqlite3_value); let bytes = slice::from_raw_parts(ptr as *const u8, len as usize); // The string is guaranteed to be utf8 according to // https://www.sqlite.org/c3ref/value_blob.html @@ -54,27 +60,32 @@ impl<'a> SqliteValue<'a> { pub(crate) fn read_blob(&self) -> &[u8] { unsafe { - let ptr = ffi::sqlite3_value_blob(self.value.as_ptr()); - let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); + let ptr = ffi::sqlite3_value_blob(&self.value as *const _ as *mut ffi::sqlite3_value); + let len = ffi::sqlite3_value_bytes(&self.value as *const _ as *mut ffi::sqlite3_value); slice::from_raw_parts(ptr as *const u8, len as usize) } } pub(crate) fn read_integer(&self) -> i32 { - unsafe { ffi::sqlite3_value_int(self.value.as_ptr()) as i32 } + unsafe { ffi::sqlite3_value_int(&self.value as *const _ as *mut ffi::sqlite3_value) as i32 } } pub(crate) fn read_long(&self) -> i64 { - unsafe { ffi::sqlite3_value_int64(self.value.as_ptr()) as i64 } + unsafe { + ffi::sqlite3_value_int64(&self.value as *const _ as *mut ffi::sqlite3_value) as i64 + } } pub(crate) fn read_double(&self) -> f64 { - unsafe { ffi::sqlite3_value_double(self.value.as_ptr()) as f64 } + unsafe { + ffi::sqlite3_value_double(&self.value as *const _ as *mut ffi::sqlite3_value) as f64 + } } /// Get the type of the value as returned by sqlite pub fn value_type(&self) -> Option { - let tpe = unsafe { ffi::sqlite3_value_type(self.value.as_ptr()) }; + let tpe = + unsafe { ffi::sqlite3_value_type(&self.value as *const _ as *mut ffi::sqlite3_value) }; match tpe { ffi::SQLITE_TEXT => Some(SqliteType::Text), ffi::SQLITE_INTEGER => Some(SqliteType::Long), @@ -88,71 +99,21 @@ impl<'a> SqliteValue<'a> { pub(crate) fn is_null(&self) -> bool { self.value_type().is_none() } -} - -impl<'a: 'b, 'b: 'c, 'c> SqliteRow<'a, 'b, 'c> { - pub(crate) fn new(inner_statement: &'c StatementUse<'a, 'b>) -> Self { - SqliteRow { - stmt: inner_statement, - } - } -} - -impl<'a: 'b, 'b: 'c, 'c> Row<'c, Sqlite> for SqliteRow<'a, 'b, 'c> { - type Field = SqliteField<'a, 'b, 'c>; - type InnerPartialRow = Self; - - fn field_count(&self) -> usize { - self.stmt.column_count() as usize - } - - fn get(&self, idx: I) -> Option - where - Self: RowIndex, - { - let idx = self.idx(idx)?; - Some(SqliteField { - stmt: &self.stmt, - col_idx: idx as i32, - }) - } - - fn partial_row(&self, range: std::ops::Range) -> PartialRow { - PartialRow::new(self, range) - } -} - -impl<'a: 'b, 'b: 'c, 'c> RowIndex for SqliteRow<'a, 'b, 'c> { - fn idx(&self, idx: usize) -> Option { - if idx < self.stmt.column_count() as usize { - Some(idx) - } else { - None - } - } -} -impl<'a: 'b, 'b: 'c, 'c, 'd> RowIndex<&'d str> for SqliteRow<'a, 'b, 'c> { - fn idx(&self, field_name: &'d str) -> Option { - self.stmt.index_for_column_name(field_name) + pub(crate) fn duplicate(&self) -> OwnedSqliteValue { + let value = + unsafe { sqlite3_value_dup(&self.value as *const _ as *const ffi::sqlite3_value) }; + let value = NonNull::new(value) + .expect("Sqlite documentation states this returns only null if value is null or OOM"); + OwnedSqliteValue { value } } } -pub struct SqliteField<'a: 'b, 'b: 'c, 'c> { - stmt: &'c StatementUse<'a, 'b>, - col_idx: i32, -} - -impl<'a: 'b, 'b: 'c, 'c> Field<'c, Sqlite> for SqliteField<'a, 'b, 'c> { - fn field_name(&self) -> Option<&'c str> { - self.stmt.field_name(self.col_idx) - } - - fn is_null(&self) -> bool { - self.value().is_none() - } - - fn value(&self) -> Option> { - self.stmt.value(self.col_idx) +impl OwnedSqliteValue { + pub(crate) fn duplicate(&self) -> OwnedSqliteValue { + let value = unsafe { sqlite3_value_dup(self.value.as_ptr()) }; + let value = NonNull::new(value) + .expect("Sqlite documentation states this returns only null if value is null or OOM"); + OwnedSqliteValue { value } } } diff --git a/diesel/src/sqlite/connection/statement_iterator.rs b/diesel/src/sqlite/connection/statement_iterator.rs index 1330d120b408..a7de568fa7e2 100644 --- a/diesel/src/sqlite/connection/statement_iterator.rs +++ b/diesel/src/sqlite/connection/statement_iterator.rs @@ -1,36 +1,102 @@ -use std::marker::PhantomData; +use std::cell::RefCell; +use std::rc::Rc; +use super::row::{PrivateSqliteRow, SqliteRow}; use super::stmt::StatementUse; -use crate::deserialize::FromSqlRow; -use crate::result::Error::DeserializationError; use crate::result::QueryResult; -use crate::sqlite::Sqlite; -pub struct StatementIterator<'a: 'b, 'b, ST, T> { - stmt: StatementUse<'a, 'b>, - _marker: PhantomData<(ST, T)>, +#[allow(missing_debug_implementations)] +pub struct StatementIterator<'a: 'b, 'b> { + inner: PrivateStatementIterator<'a, 'b>, + column_names: Option>>>, } -impl<'a: 'b, 'b, ST, T> StatementIterator<'a, 'b, ST, T> { +enum PrivateStatementIterator<'a: 'b, 'b> { + NotStarted(StatementUse<'a, 'b>), + Started(Rc>>), + TemporaryEmpty, +} + +impl<'a: 'b, 'b> StatementIterator<'a, 'b> { pub fn new(stmt: StatementUse<'a, 'b>) -> Self { - StatementIterator { - stmt, - _marker: PhantomData, + Self { + inner: PrivateStatementIterator::NotStarted(stmt), + column_names: None, } } } -impl<'a: 'b, 'b, ST, T> Iterator for StatementIterator<'a, 'b, ST, T> -where - T: FromSqlRow, -{ - type Item = QueryResult; +impl<'a: 'b, 'b> Iterator for StatementIterator<'a, 'b> { + type Item = QueryResult>; fn next(&mut self) -> Option { - let row = match self.stmt.step() { - Ok(row) => row, - Err(e) => return Some(Err(e)), - }; - row.map(|row| T::build_from_row(&row).map_err(DeserializationError)) + use PrivateStatementIterator::*; + + match std::mem::replace(&mut self.inner, TemporaryEmpty) { + NotStarted(stmt) => match stmt.step() { + Err(e) => Some(Err(e)), + Ok(None) => None, + Ok(Some(row)) => { + let inner = Rc::new(RefCell::new(PrivateSqliteRow::Direct(row))); + self.inner = Started(inner.clone()); + Some(Ok(SqliteRow { inner })) + } + }, + Started(mut last_row) => { + // There was already at least one iteration step + // We check here if the caller already released the row value or not + // by checking if our Rc owns the data or not + if let Some(last_row_ref) = Rc::get_mut(&mut last_row) { + // We own the statement, there is no other reference here. + // This means we don't need to copy out values from the sqlite provided + // datastructures for now + // We don't need to use the runtime borrowing system of the RefCell here + // as we have a mutable reference, so all of this below is checked at compile time + if let PrivateSqliteRow::Direct(stmt) = + std::mem::replace(last_row_ref.get_mut(), PrivateSqliteRow::TemporaryEmpty) + { + match stmt.step() { + Err(e) => Some(Err(e)), + Ok(None) => None, + Ok(Some(stmt)) => { + (*last_row_ref.get_mut()) = PrivateSqliteRow::Direct(stmt); + self.inner = Started(last_row.clone()); + Some(Ok(SqliteRow { inner: last_row })) + } + } + } else { + // any other state than `PrivateSqliteRow::Direct` is invalid here + // and should not happen. If this ever happens this is a logic error + // in the code above + unreachable!() + } + } else { + // We don't own the statement. There is another existing reference, likly because + // a user stored the row in some long time container before calling next another time + // In this case we copy out the current values into a temporary store and advance + // the statement iterator internally afterwards + if let PrivateSqliteRow::Direct(stmt) = + last_row.replace_with(|inner| inner.duplicate(&mut self.column_names)) + { + match stmt.step() { + Err(e) => Some(Err(e)), + Ok(None) => None, + Ok(Some(stmt)) => { + let last_row = + Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); + self.inner = Started(last_row.clone()); + Some(Ok(SqliteRow { inner: last_row })) + } + } + } else { + // any other state than `PrivateSqliteRow::Direct` is invalid here + // and should not happen. If this ever happens this is a logic error + // in the code above + unreachable!() + } + } + } + TemporaryEmpty => None, + } } } diff --git a/diesel/src/sqlite/connection/stmt.rs b/diesel/src/sqlite/connection/stmt.rs index a8625e04892b..73b9af6e4794 100644 --- a/diesel/src/sqlite/connection/stmt.rs +++ b/diesel/src/sqlite/connection/stmt.rs @@ -1,17 +1,16 @@ extern crate libsqlite3_sys as ffi; -use std::ffi::{CStr, CString}; -use std::io::{stderr, Write}; -use std::os::raw as libc; -use std::ptr::{self, NonNull}; - use super::raw::RawConnection; use super::serialized_value::SerializedValue; -use super::sqlite_value::SqliteRow; use super::SqliteValue; use crate::result::Error::DatabaseError; use crate::result::*; use crate::sqlite::SqliteType; +use crate::util::OnceCell; +use std::ffi::{CStr, CString}; +use std::io::{stderr, Write}; +use std::os::raw as libc; +use std::ptr::{self, NonNull}; pub struct Statement { inner_statement: NonNull, @@ -117,36 +116,25 @@ impl Drop for Statement { } } +#[allow(missing_debug_implementations)] pub struct StatementUse<'a: 'b, 'b> { statement: &'a mut Statement, - column_names: Vec<&'b str>, - should_init_column_names: bool, + column_names: OnceCell>, } impl<'a, 'b> StatementUse<'a, 'b> { - pub(in crate::sqlite::connection) fn new( - statement: &'a mut Statement, - should_init_column_names: bool, - ) -> Self { + pub(in crate::sqlite::connection) fn new(statement: &'a mut Statement) -> Self { StatementUse { statement, - // Init with empty vector because column names - // can change till the first call to `step()` - column_names: Vec::new(), - should_init_column_names, + column_names: OnceCell::new(), } } - pub(in crate::sqlite::connection) fn run(&mut self) -> QueryResult<()> { + pub(in crate::sqlite::connection) fn run(self) -> QueryResult<()> { self.step().map(|_| ()) } - pub(in crate::sqlite::connection) fn step<'c>( - &'c mut self, - ) -> QueryResult>> - where - 'b: 'c, - { + pub(in crate::sqlite::connection) fn step<'c>(self) -> QueryResult> { let res = unsafe { match ffi::sqlite3_step(self.statement.inner_statement.as_ptr()) { ffi::SQLITE_DONE => Ok(None), @@ -154,13 +142,7 @@ impl<'a, 'b> StatementUse<'a, 'b> { _ => Err(last_error(self.statement.raw_connection())), } }?; - if self.should_init_column_names { - self.column_names = (0..self.column_count()) - .map(|idx| unsafe { self.column_name(idx) }) - .collect(); - self.should_init_column_names = false; - } - Ok(res.map(move |()| SqliteRow::new(self))) + Ok(res.map(move |()| self)) } // The returned string pointer is valid until either the prepared statement is @@ -197,27 +179,33 @@ impl<'a, 'b> StatementUse<'a, 'b> { } pub(in crate::sqlite::connection) fn index_for_column_name( - &self, + &mut self, field_name: &str, ) -> Option { - self.column_names - .iter() - .enumerate() - .find(|(_, name)| name == &&field_name) - .map(|(idx, _)| idx) + (0..self.column_count()) + .find(|idx| self.field_name(*idx) == Some(field_name)) + .map(|v| v as usize) } - pub(in crate::sqlite::connection) fn field_name<'c>(&'c self, idx: i32) -> Option<&'c str> + pub(in crate::sqlite::connection) fn field_name<'c>(&'c mut self, idx: i32) -> Option<&'c str> where 'b: 'c, { - self.column_names.get(idx as usize).copied() + if let Some(column_names) = self.column_names.get() { + return column_names.get(idx as usize).copied(); + } + let values = (0..self.column_count()) + .map(|idx| unsafe { self.column_name(idx) }) + .collect::>(); + let ret = values.get(idx as usize).copied(); + let _ = self.column_names.set(values); + ret } pub(in crate::sqlite::connection) fn value<'c>( - &'c self, + &self, idx: i32, - ) -> Option> + ) -> Option<&'a super::SqliteValue> where 'b: 'c, { diff --git a/diesel/src/sqlite/types/date_and_time/mod.rs b/diesel/src/sqlite/types/date_and_time/mod.rs index 18bdd45713f0..6583061a3a6b 100644 --- a/diesel/src/sqlite/types/date_and_time/mod.rs +++ b/diesel/src/sqlite/types/date_and_time/mod.rs @@ -15,7 +15,7 @@ mod chrono; /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { FromSql::::from_sql(value) } } @@ -38,7 +38,7 @@ impl ToSql for String { /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { FromSql::::from_sql(value) } } @@ -61,7 +61,7 @@ impl ToSql for String { /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { FromSql::::from_sql(value) } } diff --git a/diesel/src/sqlite/types/mod.rs b/diesel/src/sqlite/types/mod.rs index 326fc40bc5fe..612331c5e1a8 100644 --- a/diesel/src/sqlite/types/mod.rs +++ b/diesel/src/sqlite/types/mod.rs @@ -15,7 +15,7 @@ use crate::sql_types; /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { let text = value.read_text(); Ok(text as *const _) } @@ -27,44 +27,44 @@ impl FromSql for *const str { /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const [u8] { - fn from_sql(bytes: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(bytes: &'_ SqliteValue) -> deserialize::Result { let bytes = bytes.read_blob(); Ok(bytes as *const _) } } impl FromSql for i16 { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { Ok(value.read_integer() as i16) } } impl FromSql for i32 { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { Ok(value.read_integer()) } } impl FromSql for bool { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { Ok(value.read_integer() != 0) } } impl FromSql for i64 { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { Ok(value.read_long()) } } impl FromSql for f32 { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { Ok(value.read_double() as f32) } } impl FromSql for f64 { - fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { Ok(value.read_double()) } } diff --git a/diesel/src/sqlite/types/numeric.rs b/diesel/src/sqlite/types/numeric.rs index b3dc7aa99d23..f4b4a1a955e2 100644 --- a/diesel/src/sqlite/types/numeric.rs +++ b/diesel/src/sqlite/types/numeric.rs @@ -8,7 +8,7 @@ use crate::sqlite::connection::SqliteValue; use crate::sqlite::Sqlite; impl FromSql for BigDecimal { - fn from_sql(bytes: SqliteValue<'_>) -> deserialize::Result { + fn from_sql(bytes: &'_ SqliteValue) -> deserialize::Result { let x = >::from_sql(bytes)?; BigDecimal::from_f64(x).ok_or_else(|| format!("{} is not valid decimal number ", x).into()) } diff --git a/diesel/src/util/once_cell.rs b/diesel/src/util/once_cell.rs index 790ba6592efa..2e1771a9470e 100644 --- a/diesel/src/util/once_cell.rs +++ b/diesel/src/util/once_cell.rs @@ -89,12 +89,12 @@ impl OnceCell { self.get().unwrap() } - fn get(&self) -> Option<&T> { + pub(crate) fn get(&self) -> Option<&T> { // SAFETY: Safe due to `inner`'s invariant unsafe { &*self.inner.get() }.as_ref() } - fn set(&self, value: T) -> Result<(), T> { + pub(crate) fn set(&self, value: T) -> Result<(), T> { // SAFETY: Safe because we cannot have overlapping mutable borrows let slot = unsafe { &*self.inner.get() }; if slot.is_some() { From 3ef47c5822805330e04baccef7cb82ac29b8ac95 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 17 Jun 2021 15:56:34 +0200 Subject: [PATCH 06/32] Enable the CI and fix some generic code --- .github/workflows/benches.yml | 2 +- .github/workflows/ci.yml | 3 +-- diesel_cli/src/infer_schema_internals/information_schema.rs | 6 +++--- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/benches.yml b/.github/workflows/benches.yml index f1e08cfe14f8..80ee0d2a061c 100644 --- a/.github/workflows/benches.yml +++ b/.github/workflows/benches.yml @@ -12,7 +12,7 @@ jobs: strategy: fail-fast: false matrix: - backend: ["postgres", "mysql"] #, "sqlite", "mysql"] + backend: ["postgres", "sqlite", "mysql"] steps: - name: Checkout sources uses: actions/checkout@v2 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9ec6d80db4ed..97eaee7eba93 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,8 +22,7 @@ jobs: fail-fast: false matrix: rust: ["stable", "beta", "nightly"] - # backend: ["postgres", "sqlite", "mysql"] - backend: ["postgres", "mysql"] + backend: ["postgres", "sqlite", "mysql"] os: [ubuntu-20.04, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: diff --git a/diesel_cli/src/infer_schema_internals/information_schema.rs b/diesel_cli/src/infer_schema_internals/information_schema.rs index 698f604eb285..767104e9d187 100644 --- a/diesel_cli/src/infer_schema_internals/information_schema.rs +++ b/diesel_cli/src/infer_schema_internals/information_schema.rs @@ -207,7 +207,7 @@ where sql_types::Text, sql_types::Nullable, sql_types::Text, - )>, + )> + 'static, { use self::information_schema::columns::dsl::*; @@ -252,7 +252,7 @@ where >, key_column_usage::ordinal_position, >: QueryFragment, - Conn::Backend: QueryMetadata, + Conn::Backend: QueryMetadata + 'static, { use self::information_schema::key_column_usage::dsl::*; use self::information_schema::table_constraints::constraint_type; @@ -281,7 +281,7 @@ pub fn load_table_names<'a, Conn>( ) -> Result, Box> where Conn: Connection, - Conn::Backend: UsesInformationSchema, + Conn::Backend: UsesInformationSchema + 'static, String: FromSql, Filter< Filter< From 2dfd5f44339b477dc8542830e78c5b8fb8aa88df Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 17 Jun 2021 16:01:21 +0200 Subject: [PATCH 07/32] Add missing file --- diesel/src/sqlite/connection/row.rs | 147 ++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 diesel/src/sqlite/connection/row.rs diff --git a/diesel/src/sqlite/connection/row.rs b/diesel/src/sqlite/connection/row.rs new file mode 100644 index 000000000000..20299421d1f2 --- /dev/null +++ b/diesel/src/sqlite/connection/row.rs @@ -0,0 +1,147 @@ +use std::cell::RefCell; +use std::convert::TryFrom; +use std::rc::Rc; + +use super::sqlite_value::{OwnedSqliteValue, SqliteValue}; +use super::stmt::StatementUse; +use crate::row::{Field, PartialRow, Row, RowIndex}; +use crate::sqlite::Sqlite; + +#[allow(missing_debug_implementations)] +pub struct SqliteRow<'a, 'b> { + pub(super) inner: Rc>>, +} + +pub(super) enum PrivateSqliteRow<'a, 'b> { + Direct(StatementUse<'a, 'b>), + Duplicated { + values: Vec>, + column_names: Rc>>, + }, + TemporaryEmpty, +} + +impl<'a, 'b> PrivateSqliteRow<'a, 'b> { + pub(super) fn duplicate(&mut self, column_names: &mut Option>>>) -> Self { + match self { + PrivateSqliteRow::Direct(stmt) => { + let column_names = if let Some(column_names) = column_names { + column_names.clone() + } else { + let c = Rc::new( + (0..stmt.column_count()) + .map(|idx| stmt.field_name(idx).map(|s| s.to_owned())) + .collect::>(), + ); + *column_names = Some(c.clone()); + c + }; + PrivateSqliteRow::Duplicated { + values: (0..stmt.column_count()) + .map(|idx| stmt.value(idx).map(|v| v.duplicate())) + .collect(), + column_names, + } + } + PrivateSqliteRow::Duplicated { + values, + column_names, + } => PrivateSqliteRow::Duplicated { + values: values + .iter() + .map(|v| v.as_ref().map(|v| v.duplicate())) + .collect(), + column_names: column_names.clone(), + }, + PrivateSqliteRow::TemporaryEmpty => PrivateSqliteRow::TemporaryEmpty, + } + } +} + +impl<'a, 'b> Row<'b, Sqlite> for SqliteRow<'a, 'b> { + type Field = SqliteField<'a, 'b>; + type InnerPartialRow = Self; + + fn field_count(&self) -> usize { + match &*self.inner.borrow() { + PrivateSqliteRow::Direct(stmt) => stmt.column_count() as usize, + PrivateSqliteRow::Duplicated { values, .. } => values.len(), + PrivateSqliteRow::TemporaryEmpty => unreachable!(), + } + } + + fn get(&self, idx: I) -> Option + where + Self: RowIndex, + { + let idx = self.idx(idx)?; + Some(SqliteField { + row: SqliteRow { + inner: self.inner.clone(), + }, + col_idx: i32::try_from(idx).ok()?, + }) + } + + fn partial_row(&self, range: std::ops::Range) -> PartialRow { + PartialRow::new(self, range) + } +} + +impl<'a: 'b, 'b> RowIndex for SqliteRow<'a, 'b> { + fn idx(&self, idx: usize) -> Option { + match &*self.inner.borrow() { + PrivateSqliteRow::Duplicated { .. } | PrivateSqliteRow::Direct(_) + if idx < self.field_count() => + { + Some(idx) + } + PrivateSqliteRow::Direct(_) | PrivateSqliteRow::Duplicated { .. } => None, + PrivateSqliteRow::TemporaryEmpty => unreachable!(), + } + } +} + +impl<'a: 'b, 'b, 'd> RowIndex<&'d str> for SqliteRow<'a, 'b> { + fn idx(&self, field_name: &'d str) -> Option { + match &mut *self.inner.borrow_mut() { + PrivateSqliteRow::Direct(stmt) => stmt.index_for_column_name(field_name), + PrivateSqliteRow::Duplicated { column_names, .. } => column_names + .iter() + .position(|n| n.as_ref().map(|s| s as &str) == Some(field_name)), + PrivateSqliteRow::TemporaryEmpty => { + unreachable!() + } + } + } +} + +#[allow(missing_debug_implementations)] +pub struct SqliteField<'a, 'b> { + row: SqliteRow<'a, 'b>, + col_idx: i32, +} + +impl<'a: 'b, 'b> Field for SqliteField<'a, 'b> { + fn field_name(&self) -> Option<&str> { + todo!() + // self.stmt.field_name(self.col_idx) + } + + fn is_null(&self) -> bool { + self.value().is_none() + } + + fn value<'d>(&'d self) -> Option> { + match &*self.row.inner.borrow() { + PrivateSqliteRow::Direct(stmt) => stmt.value(self.col_idx), + PrivateSqliteRow::Duplicated { values, .. } => { + values.get(self.col_idx as usize).and_then(|v| { + v.as_ref() + .and_then(|v| unsafe { SqliteValue::new(v.value.as_ptr()) }) + }) + } + PrivateSqliteRow::TemporaryEmpty => unreachable!(), + } + } +} From dd0d13a04d145a4e8d6d7350eb28e1b367cce33b Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 18 Jun 2021 17:00:43 +0200 Subject: [PATCH 08/32] Inline field count for sqlite rows, as this seems to be the performance bottle neck --- diesel/src/sqlite/connection/row.rs | 22 ++++++++----------- .../sqlite/connection/statement_iterator.rs | 19 +++++++++++----- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/diesel/src/sqlite/connection/row.rs b/diesel/src/sqlite/connection/row.rs index 20299421d1f2..aebf165b9c23 100644 --- a/diesel/src/sqlite/connection/row.rs +++ b/diesel/src/sqlite/connection/row.rs @@ -6,10 +6,12 @@ use super::sqlite_value::{OwnedSqliteValue, SqliteValue}; use super::stmt::StatementUse; use crate::row::{Field, PartialRow, Row, RowIndex}; use crate::sqlite::Sqlite; +use crate::util::OnceCell; #[allow(missing_debug_implementations)] pub struct SqliteRow<'a, 'b> { pub(super) inner: Rc>>, + pub(super) field_count: usize, } pub(super) enum PrivateSqliteRow<'a, 'b> { @@ -63,11 +65,7 @@ impl<'a, 'b> Row<'b, Sqlite> for SqliteRow<'a, 'b> { type InnerPartialRow = Self; fn field_count(&self) -> usize { - match &*self.inner.borrow() { - PrivateSqliteRow::Direct(stmt) => stmt.column_count() as usize, - PrivateSqliteRow::Duplicated { values, .. } => values.len(), - PrivateSqliteRow::TemporaryEmpty => unreachable!(), - } + self.field_count } fn get(&self, idx: I) -> Option @@ -78,6 +76,7 @@ impl<'a, 'b> Row<'b, Sqlite> for SqliteRow<'a, 'b> { Some(SqliteField { row: SqliteRow { inner: self.inner.clone(), + field_count: self.field_count, }, col_idx: i32::try_from(idx).ok()?, }) @@ -89,15 +88,12 @@ impl<'a, 'b> Row<'b, Sqlite> for SqliteRow<'a, 'b> { } impl<'a: 'b, 'b> RowIndex for SqliteRow<'a, 'b> { + #[inline] fn idx(&self, idx: usize) -> Option { - match &*self.inner.borrow() { - PrivateSqliteRow::Duplicated { .. } | PrivateSqliteRow::Direct(_) - if idx < self.field_count() => - { - Some(idx) - } - PrivateSqliteRow::Direct(_) | PrivateSqliteRow::Duplicated { .. } => None, - PrivateSqliteRow::TemporaryEmpty => unreachable!(), + if idx < self.field_count { + Some(idx) + } else { + None } } } diff --git a/diesel/src/sqlite/connection/statement_iterator.rs b/diesel/src/sqlite/connection/statement_iterator.rs index a7de568fa7e2..25462a8191f6 100644 --- a/diesel/src/sqlite/connection/statement_iterator.rs +++ b/diesel/src/sqlite/connection/statement_iterator.rs @@ -36,10 +36,11 @@ impl<'a: 'b, 'b> Iterator for StatementIterator<'a, 'b> { NotStarted(stmt) => match stmt.step() { Err(e) => Some(Err(e)), Ok(None) => None, - Ok(Some(row)) => { - let inner = Rc::new(RefCell::new(PrivateSqliteRow::Direct(row))); + Ok(Some(stmt)) => { + let field_count = stmt.column_count() as usize; + let inner = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); self.inner = Started(inner.clone()); - Some(Ok(SqliteRow { inner })) + Some(Ok(SqliteRow { inner, field_count })) } }, Started(mut last_row) => { @@ -59,9 +60,13 @@ impl<'a: 'b, 'b> Iterator for StatementIterator<'a, 'b> { Err(e) => Some(Err(e)), Ok(None) => None, Ok(Some(stmt)) => { + let field_count = stmt.column_count() as usize; (*last_row_ref.get_mut()) = PrivateSqliteRow::Direct(stmt); self.inner = Started(last_row.clone()); - Some(Ok(SqliteRow { inner: last_row })) + Some(Ok(SqliteRow { + inner: last_row, + field_count, + })) } } } else { @@ -82,10 +87,14 @@ impl<'a: 'b, 'b> Iterator for StatementIterator<'a, 'b> { Err(e) => Some(Err(e)), Ok(None) => None, Ok(Some(stmt)) => { + let field_count = stmt.column_count() as usize; let last_row = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); self.inner = Started(last_row.clone()); - Some(Ok(SqliteRow { inner: last_row })) + Some(Ok(SqliteRow { + inner: last_row, + field_count, + })) } } } else { From 6b4000ade35013cdba464cdd65cebddb69238e11 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 18 Jun 2021 17:01:26 +0200 Subject: [PATCH 09/32] Fix todo --- diesel/src/sqlite/connection/row.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/diesel/src/sqlite/connection/row.rs b/diesel/src/sqlite/connection/row.rs index aebf165b9c23..6311e513fe7f 100644 --- a/diesel/src/sqlite/connection/row.rs +++ b/diesel/src/sqlite/connection/row.rs @@ -79,6 +79,7 @@ impl<'a, 'b> Row<'b, Sqlite> for SqliteRow<'a, 'b> { field_count: self.field_count, }, col_idx: i32::try_from(idx).ok()?, + field_name: OnceCell::new(), }) } @@ -116,12 +117,23 @@ impl<'a: 'b, 'b, 'd> RowIndex<&'d str> for SqliteRow<'a, 'b> { pub struct SqliteField<'a, 'b> { row: SqliteRow<'a, 'b>, col_idx: i32, + field_name: OnceCell>, } impl<'a: 'b, 'b> Field for SqliteField<'a, 'b> { fn field_name(&self) -> Option<&str> { - todo!() - // self.stmt.field_name(self.col_idx) + self.field_name + .get_or_init(|| match &mut *self.row.inner.borrow_mut() { + PrivateSqliteRow::Direct(stmt) => { + stmt.field_name(self.col_idx).map(|s| s.to_owned()) + } + PrivateSqliteRow::Duplicated { column_names, .. } => column_names + .get(self.col_idx as usize) + .and_then(|n| n.clone()), + PrivateSqliteRow::TemporaryEmpty => unreachable!(), + }) + .as_ref() + .map(|s| s as &str) } fn is_null(&self) -> bool { From 569a9d7ec75fb57c6e96f6390db0a8d1eb289c5d Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 18 Jun 2021 17:06:56 +0200 Subject: [PATCH 10/32] Apply the iterator change for the r2d2 connection --- diesel/src/r2d2.rs | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/diesel/src/r2d2.rs b/diesel/src/r2d2.rs index 932e75e6112a..a8bc7e5861ed 100644 --- a/diesel/src/r2d2.rs +++ b/diesel/src/r2d2.rs @@ -16,12 +16,11 @@ use std::convert::Into; use std::fmt; use std::marker::PhantomData; -use crate::connection::{SimpleConnection, TransactionManager}; -use crate::deserialize::FromSqlRow; +use crate::backend::Backend; +use crate::connection::{IterableConnection, SimpleConnection, TransactionManager}; use crate::expression::QueryMetadata; use crate::prelude::*; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; -use crate::query_dsl::load_dsl::CompatibleType; /// An r2d2 connection manager for use with Diesel. /// @@ -131,6 +130,16 @@ where } } +impl<'a, DB, M> IterableConnection<'a, DB> for PooledConnection +where + M: ManageConnection, + M::Connection: Connection, + DB: Backend, +{ + type Cursor = >::Cursor; + type Row = >::Row; +} + impl Connection for PooledConnection where M: ManageConnection, @@ -150,12 +159,13 @@ where (&mut **self).execute(query) } - fn load(&mut self, source: T) -> QueryResult> + fn load<'a, T>( + &'a mut self, + source: T, + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, - T::SqlType: CompatibleType, - U: FromSqlRow, Self::Backend: QueryMetadata, { (&mut **self).load(source) From b3b4e45fefcbb0de0ca9768df52a6d99539cb826 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 18 Jun 2021 17:08:55 +0200 Subject: [PATCH 11/32] Remove experimentall code --- diesel/src/expression/array_comparison.rs | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/diesel/src/expression/array_comparison.rs b/diesel/src/expression/array_comparison.rs index afc2384a9c4b..7c8a3c8387c5 100644 --- a/diesel/src/expression/array_comparison.rs +++ b/diesel/src/expression/array_comparison.rs @@ -119,17 +119,6 @@ where } } -// impl AsInExpression for [T; N] -// where T: AsExpression, -// ST: SqlType + TypedExpressionType -// { -// type InExpression = StaticMany; - -// fn as_in_expression(self) -> Self::InExpression { -// todo!() -// } -// } - pub trait MaybeEmpty { fn is_empty(&self) -> bool; } @@ -160,9 +149,6 @@ where } } -// #[derive(Debug, Clone, ValidGrouping)] -// pub struct StaticMany([T; N]); - #[derive(Debug, Clone, ValidGrouping)] pub struct Many(Vec); From e5812852b801994ae3525d6223f3e2f5ca1ba146 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 24 Jun 2021 15:10:05 +0200 Subject: [PATCH 12/32] Fix sqlite deserialization --- diesel/src/row.rs | 8 +- diesel/src/sqlite/backend.rs | 21 +- diesel/src/sqlite/connection/functions.rs | 12 +- diesel/src/sqlite/connection/mod.rs | 4 +- diesel/src/sqlite/connection/row.rs | 77 ++++--- diesel/src/sqlite/connection/sqlite_value.rs | 196 ++++++++++++------ .../sqlite/connection/statement_iterator.rs | 39 ++-- diesel/src/sqlite/connection/stmt.rs | 83 +++++--- .../src/sqlite/types/date_and_time/chrono.rs | 99 +++++---- diesel/src/sqlite/types/date_and_time/mod.rs | 27 +-- diesel/src/sqlite/types/mod.rs | 34 ++- diesel/src/sqlite/types/numeric.rs | 2 +- ...ay_expressions_must_be_correct_type.stderr | 12 +- ...array_expressions_must_be_same_type.stderr | 24 +-- .../fail/array_only_usable_with_pg.stderr | 30 --- .../tests/fail/selectable.stderr | 24 --- diesel_tests/tests/deserialization.rs | 93 ++++++++- 17 files changed, 474 insertions(+), 311 deletions(-) diff --git a/diesel/src/row.rs b/diesel/src/row.rs index 63fcc887016a..95177bba8b31 100644 --- a/diesel/src/row.rs +++ b/diesel/src/row.rs @@ -34,7 +34,7 @@ pub trait Row<'a, DB: Backend>: RowIndex + for<'b> RowIndex<&'b str> + Si /// /// * Crates implementing custom backends should provide their own type /// meeting the required trait bounds - type Field: Field; + type Field: Field<'a, DB>; /// Return type of `PartialRow` /// @@ -63,7 +63,7 @@ pub trait Row<'a, DB: Backend>: RowIndex + for<'b> RowIndex<&'b str> + Si /// /// This trait allows retrieving information on the name of the colum and on the value of the /// field. -pub trait Field { +pub trait Field<'a, DB: Backend> { /// The name of the current field /// /// Returns `None` if it's an unnamed field @@ -71,7 +71,9 @@ pub trait Field { /// Get the value representing the current field in the raw representation /// as it is transmitted by the database - fn value<'a>(&'a self) -> Option>; + fn value<'b>(&'b self) -> Option> + where + 'a: 'b; /// Checks whether this field is null or not. fn is_null(&self) -> bool { diff --git a/diesel/src/sqlite/backend.rs b/diesel/src/sqlite/backend.rs index 41a89447e051..6f630ef0a754 100644 --- a/diesel/src/sqlite/backend.rs +++ b/diesel/src/sqlite/backend.rs @@ -38,6 +38,25 @@ pub enum SqliteType { Long, } +impl SqliteType { + pub(super) fn from_raw_sqlite(tpe: i32) -> Option { + use libsqlite3_sys as ffi; + + match tpe { + ffi::SQLITE_TEXT => Some(SqliteType::Text), + ffi::SQLITE_INTEGER => Some(SqliteType::Long), + ffi::SQLITE_FLOAT => Some(SqliteType::Double), + ffi::SQLITE_BLOB => Some(SqliteType::Binary), + ffi::SQLITE_NULL => None, + _ => unreachable!( + "Sqlite's documentation state that this case ({}) is not reachable. \ + If you ever see this error message please open an issue at \ + https://github.com/diesel-rs/diesel." + ), + } + } +} + impl Backend for Sqlite { type QueryBuilder = SqliteQueryBuilder; type BindCollector = RawBytesBindCollector; @@ -45,7 +64,7 @@ impl Backend for Sqlite { } impl<'a> HasRawValue<'a> for Sqlite { - type RawValue = &'a SqliteValue; + type RawValue = SqliteValue<'a, 'a>; } impl TypeMetadata for Sqlite { diff --git a/diesel/src/sqlite/connection/functions.rs b/diesel/src/sqlite/connection/functions.rs index b22cdf41570b..c5d6c7e528e4 100644 --- a/diesel/src/sqlite/connection/functions.rs +++ b/diesel/src/sqlite/connection/functions.rs @@ -2,7 +2,7 @@ extern crate libsqlite3_sys as ffi; use super::raw::RawConnection; use super::serialized_value::SerializedValue; -use super::{Sqlite, SqliteAggregateFunction, SqliteValue}; +use super::{Sqlite, SqliteAggregateFunction}; use crate::deserialize::{FromSqlRow, StaticallySizedRow}; use crate::result::{DatabaseErrorKind, Error, QueryResult}; use crate::row::{Field, PartialRow, Row, RowIndex}; @@ -174,7 +174,7 @@ struct FunctionArgument<'a> { p: PhantomData<&'a ()>, } -impl<'a> Field for FunctionArgument<'a> { +impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> { fn field_name(&self) -> Option<&str> { None } @@ -183,7 +183,11 @@ impl<'a> Field for FunctionArgument<'a> { self.value().is_none() } - fn value<'b>(&'b self) -> Option> { - unsafe { SqliteValue::new(self.arg) } + fn value<'b>(&'b self) -> Option> + where + 'a: 'b, + { + todo!() + // unsafe { SqliteValue::new(self.arg) } } } diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 2d4932d22116..43001752670f 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -56,8 +56,8 @@ impl SimpleConnection for SqliteConnection { } impl<'a> IterableConnection<'a, Sqlite> for SqliteConnection { - type Cursor = StatementIterator<'a, 'a>; - type Row = self::row::SqliteRow<'a, 'a>; + type Cursor = StatementIterator<'a>; + type Row = self::row::SqliteRow<'a>; } impl Connection for SqliteConnection { diff --git a/diesel/src/sqlite/connection/row.rs b/diesel/src/sqlite/connection/row.rs index 6311e513fe7f..380db103a7fd 100644 --- a/diesel/src/sqlite/connection/row.rs +++ b/diesel/src/sqlite/connection/row.rs @@ -9,13 +9,13 @@ use crate::sqlite::Sqlite; use crate::util::OnceCell; #[allow(missing_debug_implementations)] -pub struct SqliteRow<'a, 'b> { - pub(super) inner: Rc>>, +pub struct SqliteRow<'a> { + pub(super) inner: Rc>>, pub(super) field_count: usize, } -pub(super) enum PrivateSqliteRow<'a, 'b> { - Direct(StatementUse<'a, 'b>), +pub(super) enum PrivateSqliteRow<'a> { + Direct(StatementUse<'a>), Duplicated { values: Vec>, column_names: Rc>>, @@ -23,7 +23,7 @@ pub(super) enum PrivateSqliteRow<'a, 'b> { TemporaryEmpty, } -impl<'a, 'b> PrivateSqliteRow<'a, 'b> { +impl<'a> PrivateSqliteRow<'a> { pub(super) fn duplicate(&mut self, column_names: &mut Option>>>) -> Self { match self { PrivateSqliteRow::Direct(stmt) => { @@ -40,7 +40,7 @@ impl<'a, 'b> PrivateSqliteRow<'a, 'b> { }; PrivateSqliteRow::Duplicated { values: (0..stmt.column_count()) - .map(|idx| stmt.value(idx).map(|v| v.duplicate())) + .map(|idx| stmt.copy_value(idx)) .collect(), column_names, } @@ -60,8 +60,8 @@ impl<'a, 'b> PrivateSqliteRow<'a, 'b> { } } -impl<'a, 'b> Row<'b, Sqlite> for SqliteRow<'a, 'b> { - type Field = SqliteField<'a, 'b>; +impl<'a> Row<'a, Sqlite> for SqliteRow<'a> { + type Field = SqliteField<'a>; type InnerPartialRow = Self; fn field_count(&self) -> usize { @@ -88,18 +88,19 @@ impl<'a, 'b> Row<'b, Sqlite> for SqliteRow<'a, 'b> { } } -impl<'a: 'b, 'b> RowIndex for SqliteRow<'a, 'b> { - #[inline] +impl<'a> RowIndex for SqliteRow<'a> { + #[inline(always)] fn idx(&self, idx: usize) -> Option { - if idx < self.field_count { - Some(idx) - } else { - None - } + Some(idx) + // if idx < self.field_count { + // Some(idx) + // } else { + // None + // } } } -impl<'a: 'b, 'b, 'd> RowIndex<&'d str> for SqliteRow<'a, 'b> { +impl<'a, 'd> RowIndex<&'d str> for SqliteRow<'a> { fn idx(&self, field_name: &'d str) -> Option { match &mut *self.inner.borrow_mut() { PrivateSqliteRow::Direct(stmt) => stmt.index_for_column_name(field_name), @@ -107,20 +108,27 @@ impl<'a: 'b, 'b, 'd> RowIndex<&'d str> for SqliteRow<'a, 'b> { .iter() .position(|n| n.as_ref().map(|s| s as &str) == Some(field_name)), PrivateSqliteRow::TemporaryEmpty => { - unreachable!() + // This cannot happen as this is only a temproray state + // used inside of `StatementIterator::next()` + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) } } } } #[allow(missing_debug_implementations)] -pub struct SqliteField<'a, 'b> { - row: SqliteRow<'a, 'b>, - col_idx: i32, +pub struct SqliteField<'a> { + pub(super) row: SqliteRow<'a>, + pub(super) col_idx: i32, field_name: OnceCell>, } -impl<'a: 'b, 'b> Field for SqliteField<'a, 'b> { +impl<'a> Field<'a, Sqlite> for SqliteField<'a> { fn field_name(&self) -> Option<&str> { self.field_name .get_or_init(|| match &mut *self.row.inner.borrow_mut() { @@ -130,7 +138,16 @@ impl<'a: 'b, 'b> Field for SqliteField<'a, 'b> { PrivateSqliteRow::Duplicated { column_names, .. } => column_names .get(self.col_idx as usize) .and_then(|n| n.clone()), - PrivateSqliteRow::TemporaryEmpty => unreachable!(), + PrivateSqliteRow::TemporaryEmpty => { + // This cannot happen as this is only a temproray state + // used inside of `StatementIterator::next()` + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + } }) .as_ref() .map(|s| s as &str) @@ -140,16 +157,10 @@ impl<'a: 'b, 'b> Field for SqliteField<'a, 'b> { self.value().is_none() } - fn value<'d>(&'d self) -> Option> { - match &*self.row.inner.borrow() { - PrivateSqliteRow::Direct(stmt) => stmt.value(self.col_idx), - PrivateSqliteRow::Duplicated { values, .. } => { - values.get(self.col_idx as usize).and_then(|v| { - v.as_ref() - .and_then(|v| unsafe { SqliteValue::new(v.value.as_ptr()) }) - }) - } - PrivateSqliteRow::TemporaryEmpty => unreachable!(), - } + fn value<'d>(&'d self) -> Option> + where + 'a: 'd, + { + SqliteValue::new(self.row.inner.borrow(), self.col_idx) } } diff --git a/diesel/src/sqlite/connection/sqlite_value.rs b/diesel/src/sqlite/connection/sqlite_value.rs index 2260b15213d4..d62bd57d4923 100644 --- a/diesel/src/sqlite/connection/sqlite_value.rs +++ b/diesel/src/sqlite/connection/sqlite_value.rs @@ -1,10 +1,13 @@ extern crate libsqlite3_sys as ffi; +use std::cell::Ref; use std::ptr::NonNull; use std::{slice, str}; use crate::sqlite::SqliteType; +use super::row::PrivateSqliteRow; + extern "C" { pub fn sqlite3_value_free(value: *mut ffi::sqlite3_value); pub fn sqlite3_value_dup(value: *const ffi::sqlite3_value) -> *mut ffi::sqlite3_value; @@ -15,9 +18,9 @@ extern "C" { /// Use existing `FromSql` implementations to convert this into /// rust values #[allow(missing_debug_implementations, missing_copy_implementations)] -#[repr(C)] -pub struct SqliteValue { - value: ffi::sqlite3_value, +pub struct SqliteValue<'a, 'b> { + row: Ref<'a, PrivateSqliteRow<'b>>, + col_idx: i32, } pub struct OwnedSqliteValue { @@ -30,90 +33,167 @@ impl Drop for OwnedSqliteValue { } } -impl SqliteValue { - pub(crate) unsafe fn new<'a>(inner: *mut ffi::sqlite3_value) -> Option<&'a Self> { - let ptr = NonNull::new(inner as *mut SqliteValue)?; - // This cast is allowed because value is the only field - // of this struct and this cast is allowed in C + we have a `#[repr(C)]` - // on this type to fore the layout to be the same - // (I(weiznich) would like to use `#[repr(transparent)]` here instead, but - // that does not work as of rust 1.48 - let value = &*ptr.as_ptr(); - // We check if the SQL value is NULL here (in the SQL meaning, not in the ptr meaning) - if value.is_null() { - None - } else { - Some(value) +impl<'a, 'b> SqliteValue<'a, 'b> { + pub(super) fn new(row: Ref<'a, PrivateSqliteRow<'b>>, col_idx: i32) -> Option { + match &*row { + PrivateSqliteRow::Direct(stmt) => { + if stmt.column_type(col_idx).is_none() { + return None; + } + } + PrivateSqliteRow::Duplicated { values, .. } => { + if values + .get(col_idx as usize) + .and_then(|v| v.as_ref()) + .is_none() + { + return None; + } + } + PrivateSqliteRow::TemporaryEmpty => todo!(), } + Some(Self { row, col_idx }) } - pub(crate) fn read_text(&self) -> &str { - unsafe { - let ptr = ffi::sqlite3_value_text(&self.value as *const _ as *mut ffi::sqlite3_value); - let len = ffi::sqlite3_value_bytes(&self.value as *const _ as *mut ffi::sqlite3_value); - let bytes = slice::from_raw_parts(ptr as *const u8, len as usize); - // The string is guaranteed to be utf8 according to - // https://www.sqlite.org/c3ref/value_blob.html - str::from_utf8_unchecked(bytes) + pub(crate) fn parse_string(&self, f: impl FnOnce(&str) -> R) -> R { + match &*self.row { + super::row::PrivateSqliteRow::Direct(stmt) => f(stmt.read_column_as_str(self.col_idx)), + super::row::PrivateSqliteRow::Duplicated { values, .. } => f(values + .get(self.col_idx as usize) + .and_then(|o| o.as_ref()) + .expect("We checked that this value is not null") + .read_as_str()), + super::row::PrivateSqliteRow::TemporaryEmpty => todo!(), } } - pub(crate) fn read_blob(&self) -> &[u8] { - unsafe { - let ptr = ffi::sqlite3_value_blob(&self.value as *const _ as *mut ffi::sqlite3_value); - let len = ffi::sqlite3_value_bytes(&self.value as *const _ as *mut ffi::sqlite3_value); - slice::from_raw_parts(ptr as *const u8, len as usize) + pub(crate) fn read_text(&self) -> String { + self.parse_string(|s| s.to_owned()) + } + + pub(crate) fn read_blob(&self) -> Vec { + match &*self.row { + super::row::PrivateSqliteRow::Direct(stmt) => { + stmt.read_column_as_blob(self.col_idx).to_owned() + } + super::row::PrivateSqliteRow::Duplicated { values, .. } => values + .get(self.col_idx as usize) + .and_then(|o| o.as_ref()) + .expect("We checked that this value is not null") + .read_as_blob() + .to_owned(), + super::row::PrivateSqliteRow::TemporaryEmpty => todo!(), } } pub(crate) fn read_integer(&self) -> i32 { - unsafe { ffi::sqlite3_value_int(&self.value as *const _ as *mut ffi::sqlite3_value) as i32 } + match &*self.row { + super::row::PrivateSqliteRow::Direct(stmt) => stmt.read_column_as_integer(self.col_idx), + super::row::PrivateSqliteRow::Duplicated { values, .. } => values + .get(self.col_idx as usize) + .and_then(|o| o.as_ref()) + .expect("We checked that this value is not null") + .read_as_integer(), + super::row::PrivateSqliteRow::TemporaryEmpty => todo!(), + } } pub(crate) fn read_long(&self) -> i64 { - unsafe { - ffi::sqlite3_value_int64(&self.value as *const _ as *mut ffi::sqlite3_value) as i64 + match &*self.row { + super::row::PrivateSqliteRow::Direct(stmt) => stmt.read_column_as_long(self.col_idx), + super::row::PrivateSqliteRow::Duplicated { values, .. } => values + .get(self.col_idx as usize) + .and_then(|o| o.as_ref()) + .expect("We checked that this value is not null") + .read_as_long(), + super::row::PrivateSqliteRow::TemporaryEmpty => todo!(), } } pub(crate) fn read_double(&self) -> f64 { - unsafe { - ffi::sqlite3_value_double(&self.value as *const _ as *mut ffi::sqlite3_value) as f64 + match &*self.row { + super::row::PrivateSqliteRow::Direct(stmt) => stmt.read_column_as_double(self.col_idx), + super::row::PrivateSqliteRow::Duplicated { values, .. } => values + .get(self.col_idx as usize) + .and_then(|o| o.as_ref()) + .expect("We checked that this value is not null") + .read_as_double(), + super::row::PrivateSqliteRow::TemporaryEmpty => todo!(), } } /// Get the type of the value as returned by sqlite pub fn value_type(&self) -> Option { - let tpe = - unsafe { ffi::sqlite3_value_type(&self.value as *const _ as *mut ffi::sqlite3_value) }; - match tpe { - ffi::SQLITE_TEXT => Some(SqliteType::Text), - ffi::SQLITE_INTEGER => Some(SqliteType::Long), - ffi::SQLITE_FLOAT => Some(SqliteType::Double), - ffi::SQLITE_BLOB => Some(SqliteType::Binary), - ffi::SQLITE_NULL => None, - _ => unreachable!("Sqlite docs saying this is not reachable"), + match &*self.row { + super::row::PrivateSqliteRow::Direct(stmt) => stmt.column_type(self.col_idx), + super::row::PrivateSqliteRow::Duplicated { values, .. } => values + .get(self.col_idx as usize) + .and_then(|o| o.as_ref()) + .expect("We checked that this value is not null") + .value_type(), + super::row::PrivateSqliteRow::TemporaryEmpty => todo!(), } } +} - pub(crate) fn is_null(&self) -> bool { - self.value_type().is_none() - } +impl OwnedSqliteValue { + pub(super) fn copy_from_ptr(ptr: *mut ffi::sqlite3_value) -> Option { + let tpe = unsafe { ffi::sqlite3_value_type(ptr) }; + if SqliteType::from_raw_sqlite(tpe).is_none() { + return None; + } - pub(crate) fn duplicate(&self) -> OwnedSqliteValue { - let value = - unsafe { sqlite3_value_dup(&self.value as *const _ as *const ffi::sqlite3_value) }; - let value = NonNull::new(value) - .expect("Sqlite documentation states this returns only null if value is null or OOM"); - OwnedSqliteValue { value } + let value = unsafe { sqlite3_value_dup(ptr) }; + + Some(Self { + value: NonNull::new(value)?, + }) } -} -impl OwnedSqliteValue { - pub(crate) fn duplicate(&self) -> OwnedSqliteValue { + pub(super) fn duplicate(&self) -> OwnedSqliteValue { + // self.value is a `NonNull` ptr so this cannot be null let value = unsafe { sqlite3_value_dup(self.value.as_ptr()) }; - let value = NonNull::new(value) - .expect("Sqlite documentation states this returns only null if value is null or OOM"); + let value = NonNull::new(value).expect( + "Sqlite documentation states this returns only null if value is null \ + or OOM. If you ever see this panic message please open an issue at \ + https://github.com/diesel-rs/diesel.", + ); OwnedSqliteValue { value } } + + fn read_as_str(&self) -> &str { + unsafe { + let ptr = ffi::sqlite3_value_text(self.value.as_ptr()); + let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); + let bytes = slice::from_raw_parts(ptr as *const u8, len as usize); + // The string is guaranteed to be utf8 according to + // https://www.sqlite.org/c3ref/value_blob.html + str::from_utf8_unchecked(bytes) + } + } + + fn read_as_blob(&self) -> &[u8] { + unsafe { + let ptr = ffi::sqlite3_value_blob(self.value.as_ptr()); + let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); + slice::from_raw_parts(ptr as *const u8, len as usize) + } + } + + fn read_as_integer(&self) -> i32 { + unsafe { ffi::sqlite3_value_int(self.value.as_ptr()) } + } + + fn read_as_long(&self) -> i64 { + unsafe { ffi::sqlite3_value_int64(self.value.as_ptr()) } + } + + fn read_as_double(&self) -> f64 { + unsafe { ffi::sqlite3_value_double(self.value.as_ptr()) } + } + + fn value_type(&self) -> Option { + let tpe = unsafe { ffi::sqlite3_value_type(self.value.as_ptr()) }; + SqliteType::from_raw_sqlite(tpe) + } } diff --git a/diesel/src/sqlite/connection/statement_iterator.rs b/diesel/src/sqlite/connection/statement_iterator.rs index 25462a8191f6..993af2feaf02 100644 --- a/diesel/src/sqlite/connection/statement_iterator.rs +++ b/diesel/src/sqlite/connection/statement_iterator.rs @@ -6,28 +6,30 @@ use super::stmt::StatementUse; use crate::result::QueryResult; #[allow(missing_debug_implementations)] -pub struct StatementIterator<'a: 'b, 'b> { - inner: PrivateStatementIterator<'a, 'b>, +pub struct StatementIterator<'a> { + inner: PrivateStatementIterator<'a>, column_names: Option>>>, + field_count: usize, } -enum PrivateStatementIterator<'a: 'b, 'b> { - NotStarted(StatementUse<'a, 'b>), - Started(Rc>>), +enum PrivateStatementIterator<'a> { + NotStarted(StatementUse<'a>), + Started(Rc>>), TemporaryEmpty, } -impl<'a: 'b, 'b> StatementIterator<'a, 'b> { - pub fn new(stmt: StatementUse<'a, 'b>) -> Self { +impl<'a> StatementIterator<'a> { + pub fn new(stmt: StatementUse<'a>) -> Self { Self { inner: PrivateStatementIterator::NotStarted(stmt), column_names: None, + field_count: 0, } } } -impl<'a: 'b, 'b> Iterator for StatementIterator<'a, 'b> { - type Item = QueryResult>; +impl<'a> Iterator for StatementIterator<'a> { + type Item = QueryResult>; fn next(&mut self) -> Option { use PrivateStatementIterator::*; @@ -38,6 +40,7 @@ impl<'a: 'b, 'b> Iterator for StatementIterator<'a, 'b> { Ok(None) => None, Ok(Some(stmt)) => { let field_count = stmt.column_count() as usize; + self.field_count = field_count; let inner = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); self.inner = Started(inner.clone()); Some(Ok(SqliteRow { inner, field_count })) @@ -60,7 +63,7 @@ impl<'a: 'b, 'b> Iterator for StatementIterator<'a, 'b> { Err(e) => Some(Err(e)), Ok(None) => None, Ok(Some(stmt)) => { - let field_count = stmt.column_count() as usize; + let field_count = self.field_count; (*last_row_ref.get_mut()) = PrivateSqliteRow::Direct(stmt); self.inner = Started(last_row.clone()); Some(Ok(SqliteRow { @@ -73,7 +76,12 @@ impl<'a: 'b, 'b> Iterator for StatementIterator<'a, 'b> { // any other state than `PrivateSqliteRow::Direct` is invalid here // and should not happen. If this ever happens this is a logic error // in the code above - unreachable!() + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) } } else { // We don't own the statement. There is another existing reference, likly because @@ -87,7 +95,7 @@ impl<'a: 'b, 'b> Iterator for StatementIterator<'a, 'b> { Err(e) => Some(Err(e)), Ok(None) => None, Ok(Some(stmt)) => { - let field_count = stmt.column_count() as usize; + let field_count = self.field_count; let last_row = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); self.inner = Started(last_row.clone()); @@ -101,7 +109,12 @@ impl<'a: 'b, 'b> Iterator for StatementIterator<'a, 'b> { // any other state than `PrivateSqliteRow::Direct` is invalid here // and should not happen. If this ever happens this is a logic error // in the code above - unreachable!() + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) } } } diff --git a/diesel/src/sqlite/connection/stmt.rs b/diesel/src/sqlite/connection/stmt.rs index 73b9af6e4794..5e348596eac0 100644 --- a/diesel/src/sqlite/connection/stmt.rs +++ b/diesel/src/sqlite/connection/stmt.rs @@ -2,11 +2,12 @@ extern crate libsqlite3_sys as ffi; use super::raw::RawConnection; use super::serialized_value::SerializedValue; -use super::SqliteValue; +use super::sqlite_value::OwnedSqliteValue; use crate::result::Error::DatabaseError; use crate::result::*; use crate::sqlite::SqliteType; use crate::util::OnceCell; +use core::slice; use std::ffi::{CStr, CString}; use std::io::{stderr, Write}; use std::os::raw as libc; @@ -117,12 +118,12 @@ impl Drop for Statement { } #[allow(missing_debug_implementations)] -pub struct StatementUse<'a: 'b, 'b> { +pub struct StatementUse<'a> { statement: &'a mut Statement, - column_names: OnceCell>, + column_names: OnceCell>, } -impl<'a, 'b> StatementUse<'a, 'b> { +impl<'a> StatementUse<'a> { pub(in crate::sqlite::connection) fn new(statement: &'a mut Statement) -> Self { StatementUse { statement, @@ -152,10 +153,7 @@ impl<'a, 'b> StatementUse<'a, 'b> { // on the same column. // // https://sqlite.org/c3ref/column_name.html - // - // As result of this requirements: Never use that function outside of `ColumnInformation` - // and never use `ColumnInformation` outside of `StatementUse` - unsafe fn column_name(&mut self, idx: i32) -> &'b str { + unsafe fn column_name(&mut self, idx: i32) -> *const str { let name = { let column_name = ffi::sqlite3_column_name(self.statement.inner_statement.as_ptr(), idx); @@ -171,52 +169,77 @@ impl<'a, 'b> StatementUse<'a, 'b> { If you see this error message something has gone \ horribliy wrong. Please open an issue at the \ diesel repository.", - ) + ) as *const str } - pub(in crate::sqlite::connection) fn column_count(&self) -> i32 { + pub(super) fn column_count(&self) -> i32 { unsafe { ffi::sqlite3_column_count(self.statement.inner_statement.as_ptr()) } } - pub(in crate::sqlite::connection) fn index_for_column_name( - &mut self, - field_name: &str, - ) -> Option { + pub(super) fn index_for_column_name(&mut self, field_name: &str) -> Option { (0..self.column_count()) .find(|idx| self.field_name(*idx) == Some(field_name)) .map(|v| v as usize) } - pub(in crate::sqlite::connection) fn field_name<'c>(&'c mut self, idx: i32) -> Option<&'c str> - where - 'b: 'c, - { + pub(super) fn field_name<'c>(&'c mut self, idx: i32) -> Option<&'c str> { if let Some(column_names) = self.column_names.get() { - return column_names.get(idx as usize).copied(); + return column_names + .get(idx as usize) + .and_then(|c| unsafe { c.as_ref() }); } let values = (0..self.column_count()) .map(|idx| unsafe { self.column_name(idx) }) .collect::>(); let ret = values.get(idx as usize).copied(); let _ = self.column_names.set(values); - ret + ret.and_then(|p| unsafe { p.as_ref() }) + } + + pub(super) fn column_type(&self, idx: i32) -> Option { + let tpe = unsafe { ffi::sqlite3_column_type(self.statement.inner_statement.as_ptr(), idx) }; + SqliteType::from_raw_sqlite(tpe) + } + + pub(super) fn read_column_as_str(&self, idx: i32) -> &str { + unsafe { + let ptr = ffi::sqlite3_column_text(self.statement.inner_statement.as_ptr(), idx); + let len = ffi::sqlite3_column_bytes(self.statement.inner_statement.as_ptr(), idx); + let bytes = slice::from_raw_parts(ptr as *const u8, len as usize); + // The string is guaranteed to be utf8 according to + // https://www.sqlite.org/c3ref/value_blob.html + std::str::from_utf8_unchecked(bytes) + } } - pub(in crate::sqlite::connection) fn value<'c>( - &self, - idx: i32, - ) -> Option<&'a super::SqliteValue> - where - 'b: 'c, - { + pub(super) fn read_column_as_blob(&self, idx: i32) -> &[u8] { unsafe { - let ptr = ffi::sqlite3_column_value(self.statement.inner_statement.as_ptr(), idx); - SqliteValue::new(ptr) + let ptr = ffi::sqlite3_column_blob(self.statement.inner_statement.as_ptr(), idx); + let len = ffi::sqlite3_column_bytes(self.statement.inner_statement.as_ptr(), idx); + slice::from_raw_parts(ptr as *const u8, len as usize) } } + + pub(super) fn read_column_as_integer(&self, idx: i32) -> i32 { + unsafe { ffi::sqlite3_column_int(self.statement.inner_statement.as_ptr(), idx) } + } + + pub(super) fn read_column_as_long(&self, idx: i32) -> i64 { + unsafe { ffi::sqlite3_column_int64(self.statement.inner_statement.as_ptr(), idx) } + } + + pub(super) fn read_column_as_double(&self, idx: i32) -> f64 { + unsafe { ffi::sqlite3_column_double(self.statement.inner_statement.as_ptr(), idx) } + } + + pub(super) fn copy_value(&self, idx: i32) -> Option { + let ptr = + unsafe { ffi::sqlite3_column_value(self.statement.inner_statement.as_ptr(), idx) }; + OwnedSqliteValue::copy_from_ptr(ptr) + } } -impl<'a, 'b> Drop for StatementUse<'a, 'b> { +impl<'a> Drop for StatementUse<'a> { fn drop(&mut self) { self.statement.reset(); } diff --git a/diesel/src/sqlite/types/date_and_time/chrono.rs b/diesel/src/sqlite/types/date_and_time/chrono.rs index 41136050e067..fa8867a031cc 100644 --- a/diesel/src/sqlite/types/date_and_time/chrono.rs +++ b/diesel/src/sqlite/types/date_and_time/chrono.rs @@ -13,9 +13,9 @@ const SQLITE_DATE_FORMAT: &str = "%F"; impl FromSql for NaiveDate { fn from_sql(value: backend::RawValue) -> deserialize::Result { - let text_ptr = <*const str as FromSql>::from_sql(value)?; - let text = unsafe { &*text_ptr }; - Self::parse_from_str(text, SQLITE_DATE_FORMAT).map_err(Into::into) + value + .parse_string(|s| Self::parse_from_str(s, SQLITE_DATE_FORMAT)) + .map_err(Into::into) } } @@ -28,21 +28,21 @@ impl ToSql for NaiveDate { impl FromSql for NaiveTime { fn from_sql(value: backend::RawValue) -> deserialize::Result { - let text_ptr = <*const str as FromSql>::from_sql(value)?; - let text = unsafe { &*text_ptr }; - let valid_time_formats = &[ - // Most likely - "%T%.f", // All other valid formats in order of documentation - "%R", "%RZ", "%T%.fZ", "%R%:z", "%T%.f%:z", - ]; - - for format in valid_time_formats { - if let Ok(time) = Self::parse_from_str(text, format) { - return Ok(time); + value.parse_string(|text| { + let valid_time_formats = &[ + // Most likely + "%T%.f", // All other valid formats in order of documentation + "%R", "%RZ", "%T%.fZ", "%R%:z", "%T%.f%:z", + ]; + + for format in valid_time_formats { + if let Ok(time) = Self::parse_from_str(text, format) { + return Ok(time); + } } - } - Err(format!("Invalid time {}", text).into()) + Err(format!("Invalid time {}", text).into()) + }) } } @@ -55,44 +55,43 @@ impl ToSql for NaiveTime { impl FromSql for NaiveDateTime { fn from_sql(value: backend::RawValue) -> deserialize::Result { - let text_ptr = <*const str as FromSql>::from_sql(value)?; - let text = unsafe { &*text_ptr }; - - let sqlite_datetime_formats = &[ - // Most likely format - "%F %T%.f", - // Other formats in order of appearance in docs - "%F %R", - "%F %RZ", - "%F %R%:z", - "%F %T%.fZ", - "%F %T%.f%:z", - "%FT%R", - "%FT%RZ", - "%FT%R%:z", - "%FT%T%.f", - "%FT%T%.fZ", - "%FT%T%.f%:z", - ]; - - for format in sqlite_datetime_formats { - if let Ok(dt) = Self::parse_from_str(text, format) { - return Ok(dt); + value.parse_string(|text| { + let sqlite_datetime_formats = &[ + // Most likely format + "%F %T%.f", + // Other formats in order of appearance in docs + "%F %R", + "%F %RZ", + "%F %R%:z", + "%F %T%.fZ", + "%F %T%.f%:z", + "%FT%R", + "%FT%RZ", + "%FT%R%:z", + "%FT%T%.f", + "%FT%T%.fZ", + "%FT%T%.f%:z", + ]; + + for format in sqlite_datetime_formats { + if let Ok(dt) = Self::parse_from_str(text, format) { + return Ok(dt); + } } - } - if let Ok(julian_days) = text.parse::() { - let epoch_in_julian_days = 2_440_587.5; - let seconds_in_day = 86400.0; - let timestamp = (julian_days - epoch_in_julian_days) * seconds_in_day; - let seconds = timestamp as i64; - let nanos = (timestamp.fract() * 1E9) as u32; - if let Some(timestamp) = Self::from_timestamp_opt(seconds, nanos) { - return Ok(timestamp); + if let Ok(julian_days) = text.parse::() { + let epoch_in_julian_days = 2_440_587.5; + let seconds_in_day = 86400.0; + let timestamp = (julian_days - epoch_in_julian_days) * seconds_in_day; + let seconds = timestamp as i64; + let nanos = (timestamp.fract() * 1E9) as u32; + if let Some(timestamp) = Self::from_timestamp_opt(seconds, nanos) { + return Ok(timestamp); + } } - } - Err(format!("Invalid datetime {}", text).into()) + Err(format!("Invalid datetime {}", text).into()) + }) } } diff --git a/diesel/src/sqlite/types/date_and_time/mod.rs b/diesel/src/sqlite/types/date_and_time/mod.rs index 6583061a3a6b..563a68ef83d6 100644 --- a/diesel/src/sqlite/types/date_and_time/mod.rs +++ b/diesel/src/sqlite/types/date_and_time/mod.rs @@ -9,13 +9,8 @@ use crate::sqlite::Sqlite; #[cfg(feature = "chrono")] mod chrono; -/// The returned pointer is *only* valid for the lifetime to the argument of -/// `from_sql`. This impl is intended for uses where you want to write a new -/// impl in terms of `String`, but don't want to allocate. We have to return a -/// raw pointer instead of a reference with a lifetime due to the structure of -/// `FromSql` -impl FromSql for *const str { - fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { +impl FromSql for String { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { FromSql::::from_sql(value) } } @@ -32,13 +27,8 @@ impl ToSql for String { } } -/// The returned pointer is *only* valid for the lifetime to the argument of -/// `from_sql`. This impl is intended for uses where you want to write a new -/// impl in terms of `String`, but don't want to allocate. We have to return a -/// raw pointer instead of a reference with a lifetime due to the structure of -/// `FromSql` -impl FromSql for *const str { - fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { +impl FromSql for String { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { FromSql::::from_sql(value) } } @@ -55,13 +45,8 @@ impl ToSql for String { } } -/// The returned pointer is *only* valid for the lifetime to the argument of -/// `from_sql`. This impl is intended for uses where you want to write a new -/// impl in terms of `String`, but don't want to allocate. We have to return a -/// raw pointer instead of a reference with a lifetime due to the structure of -/// `FromSql` -impl FromSql for *const str { - fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { +impl FromSql for String { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { FromSql::::from_sql(value) } } diff --git a/diesel/src/sqlite/types/mod.rs b/diesel/src/sqlite/types/mod.rs index 612331c5e1a8..6e8a804f075a 100644 --- a/diesel/src/sqlite/types/mod.rs +++ b/diesel/src/sqlite/types/mod.rs @@ -9,62 +9,52 @@ use crate::deserialize::{self, FromSql}; use crate::serialize::{self, Output, ToSql}; use crate::sql_types; -/// The returned pointer is *only* valid for the lifetime to the argument of -/// `from_sql`. This impl is intended for uses where you want to write a new -/// impl in terms of `String`, but don't want to allocate. We have to return a -/// raw pointer instead of a reference with a lifetime due to the structure of -/// `FromSql` -impl FromSql for *const str { - fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { +impl FromSql for String { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { let text = value.read_text(); - Ok(text as *const _) + Ok(text) } } -/// The returned pointer is *only* valid for the lifetime to the argument of -/// `from_sql`. This impl is intended for uses where you want to write a new -/// impl in terms of `Vec`, but don't want to allocate. We have to return a -/// raw pointer instead of a reference with a lifetime due to the structure of -/// `FromSql` -impl FromSql for *const [u8] { - fn from_sql(bytes: &'_ SqliteValue) -> deserialize::Result { +impl FromSql for Vec { + fn from_sql(bytes: SqliteValue<'_, '_>) -> deserialize::Result { let bytes = bytes.read_blob(); - Ok(bytes as *const _) + Ok(bytes) } } impl FromSql for i16 { - fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { Ok(value.read_integer() as i16) } } impl FromSql for i32 { - fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { Ok(value.read_integer()) } } impl FromSql for bool { - fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { Ok(value.read_integer() != 0) } } impl FromSql for i64 { - fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { Ok(value.read_long()) } } impl FromSql for f32 { - fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { Ok(value.read_double() as f32) } } impl FromSql for f64 { - fn from_sql(value: &'_ SqliteValue) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { Ok(value.read_double()) } } diff --git a/diesel/src/sqlite/types/numeric.rs b/diesel/src/sqlite/types/numeric.rs index f4b4a1a955e2..aaf2683544dc 100644 --- a/diesel/src/sqlite/types/numeric.rs +++ b/diesel/src/sqlite/types/numeric.rs @@ -8,7 +8,7 @@ use crate::sqlite::connection::SqliteValue; use crate::sqlite::Sqlite; impl FromSql for BigDecimal { - fn from_sql(bytes: &'_ SqliteValue) -> deserialize::Result { + fn from_sql(bytes: SqliteValue<'_, '_>) -> deserialize::Result { let x = >::from_sql(bytes)?; BigDecimal::from_f64(x).ok_or_else(|| format!("{} is not valid decimal number ", x).into()) } diff --git a/diesel_compile_tests/tests/fail/array_expressions_must_be_correct_type.stderr b/diesel_compile_tests/tests/fail/array_expressions_must_be_correct_type.stderr index 9bca051f805b..927edf363d1a 100644 --- a/diesel_compile_tests/tests/fail/array_expressions_must_be_correct_type.stderr +++ b/diesel_compile_tests/tests/fail/array_expressions_must_be_correct_type.stderr @@ -66,18 +66,18 @@ error[E0277]: the trait bound `f64: QueryId` is not satisfied = note: required because of the requirements on the impl of `QueryId` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause>>` = note: required because of the requirements on the impl of `LoadQuery<_, Vec>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause>>` -error[E0277]: the trait bound `f64: QueryFragment<_>` is not satisfied +error[E0277]: the trait bound `f64: QueryFragment` is not satisfied --> $DIR/array_expressions_must_be_correct_type.rs:9:33 | 9 | select(array((1f64, 3f64))).get_result::>(&mut connection); - | ^^^^^^^^^^ the trait `QueryFragment<_>` is not implemented for `f64` + | ^^^^^^^^^^ the trait `QueryFragment` is not implemented for `f64` | - = note: required because of the requirements on the impl of `QueryFragment<_>` for `(f64, f64)` + = note: required because of the requirements on the impl of `QueryFragment` for `(f64, f64)` = note: 2 redundant requirements hidden - = note: required because of the requirements on the impl of `QueryFragment<_>` for `diesel::pg::expression::array::ArrayLiteral<(f64, f64), diesel::sql_types::Integer>` - = note: required because of the requirements on the impl of `SelectClauseQueryFragment<(), _>` for `diesel::query_builder::select_clause::SelectClause>` + = note: required because of the requirements on the impl of `QueryFragment` for `diesel::pg::expression::array::ArrayLiteral<(f64, f64), diesel::sql_types::Integer>` + = note: required because of the requirements on the impl of `SelectClauseQueryFragment<(), Pg>` for `diesel::query_builder::select_clause::SelectClause>` = note: 1 redundant requirements hidden - = note: required because of the requirements on the impl of `QueryFragment<_>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause>>` + = note: required because of the requirements on the impl of `QueryFragment` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause>>` = note: required because of the requirements on the impl of `LoadQuery<_, Vec>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause>>` error[E0277]: the trait bound `f64: diesel::Expression` is not satisfied diff --git a/diesel_compile_tests/tests/fail/array_expressions_must_be_same_type.stderr b/diesel_compile_tests/tests/fail/array_expressions_must_be_same_type.stderr index c60d03994aeb..6a9335579ad8 100644 --- a/diesel_compile_tests/tests/fail/array_expressions_must_be_same_type.stderr +++ b/diesel_compile_tests/tests/fail/array_expressions_must_be_same_type.stderr @@ -66,18 +66,18 @@ error[E0277]: the trait bound `f64: QueryId` is not satisfied = note: required because of the requirements on the impl of `QueryId` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>>` = note: required because of the requirements on the impl of `LoadQuery<_, Vec>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>>` -error[E0277]: the trait bound `f64: QueryFragment<_>` is not satisfied +error[E0277]: the trait bound `f64: QueryFragment` is not satisfied --> $DIR/array_expressions_must_be_same_type.rs:11:30 | 11 | select(array((1, 3f64))).get_result::>(&mut connection).unwrap(); - | ^^^^^^^^^^ the trait `QueryFragment<_>` is not implemented for `f64` + | ^^^^^^^^^^ the trait `QueryFragment` is not implemented for `f64` | - = note: required because of the requirements on the impl of `QueryFragment<_>` for `(diesel::expression::bound::Bound, f64)` + = note: required because of the requirements on the impl of `QueryFragment` for `(diesel::expression::bound::Bound, f64)` = note: 2 redundant requirements hidden - = note: required because of the requirements on the impl of `QueryFragment<_>` for `diesel::pg::expression::array::ArrayLiteral<(diesel::expression::bound::Bound, f64), diesel::sql_types::Integer>` - = note: required because of the requirements on the impl of `SelectClauseQueryFragment<(), _>` for `diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>` + = note: required because of the requirements on the impl of `QueryFragment` for `diesel::pg::expression::array::ArrayLiteral<(diesel::expression::bound::Bound, f64), diesel::sql_types::Integer>` + = note: required because of the requirements on the impl of `SelectClauseQueryFragment<(), Pg>` for `diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>` = note: 1 redundant requirements hidden - = note: required because of the requirements on the impl of `QueryFragment<_>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>>` + = note: required because of the requirements on the impl of `QueryFragment` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>>` = note: required because of the requirements on the impl of `LoadQuery<_, Vec>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause, f64), diesel::sql_types::Integer>>>` error[E0277]: the trait bound `f64: diesel::Expression` is not satisfied @@ -192,11 +192,11 @@ error[E0277]: the trait bound `{integer}: QueryId` is not satisfied = note: required because of the requirements on the impl of `QueryId` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>>` = note: required because of the requirements on the impl of `LoadQuery<_, Vec>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>>` -error[E0277]: the trait bound `{integer}: QueryFragment<_>` is not satisfied +error[E0277]: the trait bound `{integer}: QueryFragment` is not satisfied --> $DIR/array_expressions_must_be_same_type.rs:12:30 | 12 | select(array((1, 3f64))).get_result::>(&mut connection).unwrap(); - | ^^^^^^^^^^ the trait `QueryFragment<_>` is not implemented for `{integer}` + | ^^^^^^^^^^ the trait `QueryFragment` is not implemented for `{integer}` | = help: the following implementations were found: <&'a T as QueryFragment> @@ -204,12 +204,12 @@ error[E0277]: the trait bound `{integer}: QueryFragment<_>` is not satisfied <(A, B) as QueryFragment<__DB>> <(A, B, C) as QueryFragment<__DB>> and 223 others - = note: required because of the requirements on the impl of `QueryFragment<_>` for `({integer}, diesel::expression::bound::Bound)` + = note: required because of the requirements on the impl of `QueryFragment` for `({integer}, diesel::expression::bound::Bound)` = note: 2 redundant requirements hidden - = note: required because of the requirements on the impl of `QueryFragment<_>` for `diesel::pg::expression::array::ArrayLiteral<({integer}, diesel::expression::bound::Bound), diesel::sql_types::Double>` - = note: required because of the requirements on the impl of `SelectClauseQueryFragment<(), _>` for `diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>` + = note: required because of the requirements on the impl of `QueryFragment` for `diesel::pg::expression::array::ArrayLiteral<({integer}, diesel::expression::bound::Bound), diesel::sql_types::Double>` + = note: required because of the requirements on the impl of `SelectClauseQueryFragment<(), Pg>` for `diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>` = note: 1 redundant requirements hidden - = note: required because of the requirements on the impl of `QueryFragment<_>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>>` + = note: required because of the requirements on the impl of `QueryFragment` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>>` = note: required because of the requirements on the impl of `LoadQuery<_, Vec>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause), diesel::sql_types::Double>>>` error[E0277]: the trait bound `{integer}: diesel::Expression` is not satisfied diff --git a/diesel_compile_tests/tests/fail/array_only_usable_with_pg.stderr b/diesel_compile_tests/tests/fail/array_only_usable_with_pg.stderr index 2a425d396d30..309e78588330 100644 --- a/diesel_compile_tests/tests/fail/array_only_usable_with_pg.stderr +++ b/diesel_compile_tests/tests/fail/array_only_usable_with_pg.stderr @@ -6,21 +6,6 @@ error[E0271]: type mismatch resolving `>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause,), diesel::sql_types::Integer>>>` -error[E0277]: the trait bound `Sqlite: HasSqlType>` is not satisfied - --> $DIR/array_only_usable_with_pg.rs:8:25 - | -8 | select(array((1,))).get_result::>(&mut connection); - | ^^^^^^^^^^ the trait `HasSqlType>` is not implemented for `Sqlite` - | - = help: the following implementations were found: - > - > - > - > - and 8 others - = note: required because of the requirements on the impl of `QueryMetadata>` for `Sqlite` - = note: required because of the requirements on the impl of `LoadQuery>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause,), diesel::sql_types::Integer>>>` - error[E0271]: type mismatch resolving `::Backend == Pg` --> $DIR/array_only_usable_with_pg.rs:11:25 | @@ -28,18 +13,3 @@ error[E0271]: type mismatch resolving `>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause,), diesel::sql_types::Integer>>>` - -error[E0277]: the trait bound `Mysql: HasSqlType>` is not satisfied - --> $DIR/array_only_usable_with_pg.rs:11:25 - | -11 | select(array((1,))).get_result::>(&mut connection); - | ^^^^^^^^^^ the trait `HasSqlType>` is not implemented for `Mysql` - | - = help: the following implementations were found: - > - > - > - > - and 15 others - = note: required because of the requirements on the impl of `QueryMetadata>` for `Mysql` - = note: required because of the requirements on the impl of `LoadQuery>` for `SelectStatement<(), diesel::query_builder::select_clause::SelectClause,), diesel::sql_types::Integer>>>` diff --git a/diesel_compile_tests/tests/fail/selectable.stderr b/diesel_compile_tests/tests/fail/selectable.stderr index 1b175a43590f..555672aa6411 100644 --- a/diesel_compile_tests/tests/fail/selectable.stderr +++ b/diesel_compile_tests/tests/fail/selectable.stderr @@ -603,27 +603,3 @@ error[E0271]: type mismatch resolving `` for `SelectStatement, Grouped, diesel::expression::nullable::Nullable>>>, diesel::query_builder::select_clause::SelectClause>, diesel::query_builder::distinct_clause::NoDistinctClause, diesel::query_builder::where_clause::NoWhereClause, diesel::query_builder::order_clause::NoOrderClause, LimitOffsetClause, diesel::query_builder::group_by_clause::GroupByClause>` - -error[E0277]: the trait bound `Sqlite: HasSqlType>` is not satisfied - --> $DIR/selectable.rs:210:10 - | -210 | .load(&mut conn) - | ^^^^ the trait `HasSqlType>` is not implemented for `Sqlite` - | - = help: the following implementations were found: - > - > - > - > - and 8 others - = note: required because of the requirements on the impl of `QueryMetadata>` for `Sqlite` - = note: required because of the requirements on the impl of `LoadQuery` for `SelectStatement, Grouped, diesel::expression::nullable::Nullable>>>, diesel::query_builder::select_clause::SelectClause>, diesel::query_builder::distinct_clause::NoDistinctClause, diesel::query_builder::where_clause::NoWhereClause, diesel::query_builder::order_clause::NoOrderClause, LimitOffsetClause, diesel::query_builder::group_by_clause::GroupByClause>` - -error[E0277]: the trait bound `SelectBy: SingleValue` is not satisfied - --> $DIR/selectable.rs:210:10 - | -210 | .load(&mut conn) - | ^^^^ the trait `SingleValue` is not implemented for `SelectBy` - | - = note: required because of the requirements on the impl of `QueryMetadata>` for `Sqlite` - = note: required because of the requirements on the impl of `LoadQuery` for `SelectStatement, Grouped, diesel::expression::nullable::Nullable>>>, diesel::query_builder::select_clause::SelectClause>, diesel::query_builder::distinct_clause::NoDistinctClause, diesel::query_builder::where_clause::NoWhereClause, diesel::query_builder::order_clause::NoOrderClause, LimitOffsetClause, diesel::query_builder::group_by_clause::GroupByClause>` diff --git a/diesel_tests/tests/deserialization.rs b/diesel_tests/tests/deserialization.rs index 380c701605bd..c6004c183638 100644 --- a/diesel_tests/tests/deserialization.rs +++ b/diesel_tests/tests/deserialization.rs @@ -1,5 +1,6 @@ use crate::schema::*; -use diesel::*; +use diesel::deserialize::FromSqlRow; +use diesel::prelude::*; use std::borrow::Cow; #[derive(Queryable, PartialEq, Debug, Selectable)] @@ -27,3 +28,93 @@ fn generated_queryable_allows_lifetimes() { users.select(CowUser::as_select()).first(connection) ); } + +#[test] +fn fun_with_row_iters() { + use crate::schema::users::dsl::*; + use diesel::deserialize::FromSql; + use diesel::row::{Field, Row}; + use diesel::sql_types; + + let conn = &mut connection_with_sean_and_tess_in_users_table(); + + let query = users.select((id, name)); + + let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))]; + + let row_iter = conn.load(&query).unwrap(); + for (row, expected) in row_iter.zip(&expected) { + let row = row.unwrap(); + + let deserialized = <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(&row) + .unwrap(); + + assert_eq!(&deserialized, expected); + } + + { + let collected_rows = conn.load(&query).unwrap().collect::>(); + + for (row, expected) in collected_rows.iter().zip(&expected) { + let deserialized = row + .as_ref() + .map(|row| { + <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(row).unwrap() + }) + .unwrap(); + + assert_eq!(&deserialized, expected); + } + } + + let mut row_iter = conn.load(&query).unwrap(); + + dbg!(); + let first_row = row_iter.next().unwrap().unwrap(); + let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + dbg!(); + let second_row = row_iter.next().unwrap().unwrap(); + let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap()); + let second_values = (second_fields.0.value(), second_fields.1.value()); + + assert!(row_iter.next().is_none()); + dbg!( + >::from_nullable_sql(first_values.0) + .unwrap() + ); //, expected[0].0); + dbg!( + >::from_nullable_sql(first_values.1) + .unwrap() + ); //, expected[0].1); + + dbg!( + >::from_nullable_sql(second_values.0) + .unwrap() + ); //, expected[1].0); + dbg!( + >::from_nullable_sql(second_values.1) + .unwrap() + ); //, expected[1].1); + + let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + dbg!( + >::from_nullable_sql(first_values.0) + .unwrap() + ); //, expected[0].0); + dbg!( + >::from_nullable_sql(first_values.1) + .unwrap() + ); //, expected[0].1); + + panic!() +} From da4447b10e839d2b243cfd98d3a2e3717ffe2de4 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Mon, 28 Jun 2021 17:07:46 +0200 Subject: [PATCH 13/32] Fix the custom sql function implementation for Sqlite --- diesel/src/sqlite/connection/functions.rs | 91 +++++++++++--- diesel/src/sqlite/connection/raw.rs | 43 ++++--- diesel/src/sqlite/connection/row.rs | 116 ++++++++++++++++++ diesel/src/sqlite/connection/sqlite_value.rs | 1 + .../sqlite/connection/statement_iterator.rs | 20 ++- diesel/src/util.rs | 2 + diesel_tests/tests/deserialization.rs | 90 -------------- 7 files changed, 235 insertions(+), 128 deletions(-) diff --git a/diesel/src/sqlite/connection/functions.rs b/diesel/src/sqlite/connection/functions.rs index c5d6c7e528e4..025219d403a2 100644 --- a/diesel/src/sqlite/connection/functions.rs +++ b/diesel/src/sqlite/connection/functions.rs @@ -1,6 +1,7 @@ extern crate libsqlite3_sys as ffi; use super::raw::RawConnection; +use super::row::PrivateSqliteRow; use super::serialized_value::SerializedValue; use super::{Sqlite, SqliteAggregateFunction}; use crate::deserialize::{FromSqlRow, StaticallySizedRow}; @@ -8,7 +9,13 @@ use crate::result::{DatabaseErrorKind, Error, QueryResult}; use crate::row::{Field, PartialRow, Row, RowIndex}; use crate::serialize::{IsNull, Output, ToSql}; use crate::sql_types::HasSqlType; +use crate::sqlite::connection::sqlite_value::OwnedSqliteValue; +use crate::sqlite::SqliteValue; +use std::cell::{Ref, RefCell}; use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::ops::DerefMut; +use std::rc::Rc; pub fn register( conn: &RawConnection, @@ -85,7 +92,7 @@ where } pub(crate) fn build_sql_function_args( - args: &[*mut ffi::sqlite3_value], + args: &mut [*mut ffi::sqlite3_value], ) -> Result where Args: FromSqlRow, @@ -117,14 +124,67 @@ where }) } -#[derive(Clone)] struct FunctionRow<'a> { - args: &'a [*mut ffi::sqlite3_value], + // we use `ManuallyDrop` to prevent dropping the content of the internal vector + // as this buffer is owned by sqlite not by diesel + args: Rc>>>, + field_count: usize, + marker: PhantomData<&'a ffi::sqlite3_value>, +} + +impl<'a> Drop for FunctionRow<'a> { + fn drop(&mut self) { + if let Some(args) = Rc::get_mut(&mut self.args) { + if let PrivateSqliteRow::Duplicated { column_names, .. } = + DerefMut::deref_mut(RefCell::get_mut(args)) + { + if let Some(inner) = Rc::get_mut(column_names) { + // an empty Vector does not allocate according to the documentation + // so this prevents leaking memory + std::mem::drop(std::mem::replace(inner, Vec::new())); + } + } + } + } } impl<'a> FunctionRow<'a> { - fn new(args: &'a [*mut ffi::sqlite3_value]) -> Self { - Self { args } + fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self { + let lenghts = args.len(); + let args = unsafe { + Vec::from_raw_parts( + // This cast is safe because: + // * Casting from a pointer to an arry to a pointer to the first array + // element is safe + // * Casting from a raw pointer to `NonNull` is safe, + // because `NonNull` is #[repr(transparent)] + // * Casting from `NonNull` to `OwnedSqliteValue` is safe, + // as the struct is `#[repr(transparent)] + // * Casting from `NonNull` to `Option>` as the documentation + // states: "This is so that enums may use this forbidden value as a discriminant – + // Option> has the same size as *mut T" + // * The last point remains true for `OwnedSqliteValue` as `#[repr(transparent)] + // guarantees the same layout as the inner type + // * It's unsafe to drop the vector (and the vector elements) + // because of this we wrap the vector (or better the Row) + // Into `ManualDrop` to prevent the dropping + args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value + as *mut Option, + lenghts, + lenghts, + ) + }; + + Self { + field_count: lenghts, + args: Rc::new(RefCell::new(ManuallyDrop::new( + PrivateSqliteRow::Duplicated { + values: args, + column_names: Rc::new(vec![None; lenghts]), + }, + ))), + marker: PhantomData, + } } } @@ -133,7 +193,7 @@ impl<'a> Row<'a, Sqlite> for FunctionRow<'a> { type InnerPartialRow = Self; fn field_count(&self) -> usize { - self.args.len() + self.field_count } fn get(&self, idx: I) -> Option @@ -141,10 +201,9 @@ impl<'a> Row<'a, Sqlite> for FunctionRow<'a> { Self: crate::row::RowIndex, { let idx = self.idx(idx)?; - - self.args.get(idx).map(|arg| FunctionArgument { - arg: *arg, - p: PhantomData, + Some(FunctionArgument { + args: self.args.clone(), + col_idx: idx as i32, }) } @@ -155,7 +214,7 @@ impl<'a> Row<'a, Sqlite> for FunctionRow<'a> { impl<'a> RowIndex for FunctionRow<'a> { fn idx(&self, idx: usize) -> Option { - if idx < self.args.len() { + if idx < self.field_count() { Some(idx) } else { None @@ -170,8 +229,8 @@ impl<'a, 'b> RowIndex<&'a str> for FunctionRow<'b> { } struct FunctionArgument<'a> { - arg: *mut ffi::sqlite3_value, - p: PhantomData<&'a ()>, + args: Rc>>>, + col_idx: i32, } impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> { @@ -187,7 +246,9 @@ impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> { where 'a: 'b, { - todo!() - // unsafe { SqliteValue::new(self.arg) } + SqliteValue::new( + Ref::map(self.args.borrow(), |drop| std::ops::Deref::deref(drop)), + self.col_idx, + ) } } diff --git a/diesel/src/sqlite/connection/raw.rs b/diesel/src/sqlite/connection/raw.rs index 23f6186f33ba..b65ca1511881 100644 --- a/diesel/src/sqlite/connection/raw.rs +++ b/diesel/src/sqlite/connection/raw.rs @@ -91,7 +91,7 @@ impl RawConnection { f: F, ) -> QueryResult<()> where - F: FnMut(&Self, &[*mut ffi::sqlite3_value]) -> QueryResult + F: FnMut(&Self, &mut [*mut ffi::sqlite3_value]) -> QueryResult + std::panic::UnwindSafe + Send + 'static, @@ -269,7 +269,7 @@ extern "C" fn run_custom_function( num_args: libc::c_int, value_ptr: *mut *mut ffi::sqlite3_value, ) where - F: FnMut(&RawConnection, &[*mut ffi::sqlite3_value]) -> QueryResult + F: FnMut(&RawConnection, &mut [*mut ffi::sqlite3_value]) -> QueryResult + std::panic::UnwindSafe + Send + 'static, @@ -278,7 +278,6 @@ extern "C" fn run_custom_function( static NULL_DATA_ERR: &str = "An unknown error occurred. sqlite3_user_data returned a null pointer. This should never happen."; static NULL_CONN_ERR: &str = "An unknown error occurred. sqlite3_context_db_handle returned a null pointer. This should never happen."; - let args = unsafe { slice::from_raw_parts(value_ptr, num_args as _) }; let conn = match unsafe { NonNull::new(ffi::sqlite3_context_db_handle(ctx)) } { // We use `ManuallyDrop` here because we do not want to run the // Drop impl of `RawConnection` as this would close the connection @@ -306,13 +305,16 @@ extern "C" fn run_custom_function( // this is sound as `F` itself and the stored string is `UnwindSafe` let callback = std::panic::AssertUnwindSafe(&mut data_ptr.callback); - let result = - std::panic::catch_unwind(move || Ok((callback.0)(&*conn, args)?)).unwrap_or_else(|p| { - Err(SqliteCallbackError::Panic( - p, - data_ptr.function_name.clone(), - )) - }); + let result = std::panic::catch_unwind(move || { + let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) }; + Ok((callback.0)(&*conn, args)?) + }) + .unwrap_or_else(|p| { + Err(SqliteCallbackError::Panic( + p, + data_ptr.function_name.clone(), + )) + }); match result { Ok(value) => value.result_of(ctx), Err(e) => { @@ -342,15 +344,16 @@ extern "C" fn run_aggregator_step_function, Sqlite: HasSqlType, { - let args = unsafe { slice::from_raw_parts(value_ptr, num_args as _) }; - let result = - std::panic::catch_unwind(move || run_aggregator_step::(ctx, args)) - .unwrap_or_else(|e| { - Err(SqliteCallbackError::Panic( - e, - format!("{}::step() paniced", std::any::type_name::()), - )) - }); + let result = std::panic::catch_unwind(move || { + let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) }; + run_aggregator_step::(ctx, args) + }) + .unwrap_or_else(|e| { + Err(SqliteCallbackError::Panic( + e, + format!("{}::step() paniced", std::any::type_name::()), + )) + }); match result { Ok(()) => {} @@ -360,7 +363,7 @@ extern "C" fn run_aggregator_step_function( ctx: *mut ffi::sqlite3_context, - args: &[*mut ffi::sqlite3_value], + args: &mut [*mut ffi::sqlite3_value], ) -> Result<(), SqliteCallbackError> where A: SqliteAggregateFunction, diff --git a/diesel/src/sqlite/connection/row.rs b/diesel/src/sqlite/connection/row.rs index 380db103a7fd..3badcfb1d5c1 100644 --- a/diesel/src/sqlite/connection/row.rs +++ b/diesel/src/sqlite/connection/row.rs @@ -164,3 +164,119 @@ impl<'a> Field<'a, Sqlite> for SqliteField<'a> { SqliteValue::new(self.row.inner.borrow(), self.col_idx) } } + +#[test] +fn fun_with_row_iters() { + crate::table! { + #[allow(unused_parens)] + users(id) { + id -> Integer, + name -> Text, + } + } + + use crate::deserialize::{FromSql, FromSqlRow}; + use crate::prelude::*; + use crate::row::{Field, Row}; + use crate::sql_types; + + let conn = &mut crate::test_helpers::connection(); + + crate::sql_query("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);") + .execute(conn) + .unwrap(); + + crate::insert_into(users::table) + .values(vec![ + (users::id.eq(1), users::name.eq("Sean")), + (users::id.eq(2), users::name.eq("Tess")), + ]) + .execute(conn) + .unwrap(); + + let query = users::table.select((users::id, users::name)); + + let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))]; + + let row_iter = conn.load(&query).unwrap(); + for (row, expected) in row_iter.zip(&expected) { + let row = row.unwrap(); + + let deserialized = <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(&row) + .unwrap(); + + assert_eq!(&deserialized, expected); + } + + { + let collected_rows = conn.load(&query).unwrap().collect::>(); + + for (row, expected) in collected_rows.iter().zip(&expected) { + let deserialized = row + .as_ref() + .map(|row| { + <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(row).unwrap() + }) + .unwrap(); + + assert_eq!(&deserialized, expected); + } + } + + let mut row_iter = conn.load(&query).unwrap(); + + let first_row = row_iter.next().unwrap().unwrap(); + let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(first_values); + + let second_row = row_iter.next().unwrap().unwrap(); + let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap()); + let second_values = (second_fields.0.value(), second_fields.1.value()); + + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(second_values); + + assert!(row_iter.next().is_none()); + + let first_values = (first_fields.0.value(), first_fields.1.value()); + let second_values = (second_fields.0.value(), second_fields.1.value()); + + assert_eq!( + >::from_nullable_sql(first_values.0).unwrap(), + expected[0].0 + ); + assert_eq!( + >::from_nullable_sql(first_values.1).unwrap(), + expected[0].1 + ); + + assert_eq!( + >::from_nullable_sql(second_values.0).unwrap(), + expected[1].0 + ); + assert_eq!( + >::from_nullable_sql(second_values.1).unwrap(), + expected[1].1 + ); + + let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + assert_eq!( + >::from_nullable_sql(first_values.0).unwrap(), + expected[0].0 + ); + assert_eq!( + >::from_nullable_sql(first_values.1).unwrap(), + expected[0].1 + ); +} diff --git a/diesel/src/sqlite/connection/sqlite_value.rs b/diesel/src/sqlite/connection/sqlite_value.rs index d62bd57d4923..0cc6c20c9d80 100644 --- a/diesel/src/sqlite/connection/sqlite_value.rs +++ b/diesel/src/sqlite/connection/sqlite_value.rs @@ -23,6 +23,7 @@ pub struct SqliteValue<'a, 'b> { col_idx: i32, } +#[repr(transparent)] pub struct OwnedSqliteValue { pub(super) value: NonNull, } diff --git a/diesel/src/sqlite/connection/statement_iterator.rs b/diesel/src/sqlite/connection/statement_iterator.rs index 993af2feaf02..2241fa6570d1 100644 --- a/diesel/src/sqlite/connection/statement_iterator.rs +++ b/diesel/src/sqlite/connection/statement_iterator.rs @@ -88,9 +88,23 @@ impl<'a> Iterator for StatementIterator<'a> { // a user stored the row in some long time container before calling next another time // In this case we copy out the current values into a temporary store and advance // the statement iterator internally afterwards - if let PrivateSqliteRow::Direct(stmt) = - last_row.replace_with(|inner| inner.duplicate(&mut self.column_names)) - { + let last_row = { + let mut last_row = match last_row.try_borrow_mut() { + Ok(o) => o, + Err(_e) => { + self.inner = Started(last_row.clone()); + return Some(Err(crate::result::Error::DeserializationError( + "Failed to reborrow row. Try to release any `SqliteValue` \ + that exists at this point" + .into(), + ))); + } + }; + let last_row = &mut *last_row; + let duplicated = last_row.duplicate(&mut self.column_names); + std::mem::replace(last_row, duplicated) + }; + if let PrivateSqliteRow::Direct(stmt) = last_row { match stmt.step() { Err(e) => Some(Err(e)), Ok(None) => None, diff --git a/diesel/src/util.rs b/diesel/src/util.rs index 2a0a1de3d14a..5bf0d9024540 100644 --- a/diesel/src/util.rs +++ b/diesel/src/util.rs @@ -10,6 +10,8 @@ pub trait TupleSize { const SIZE: usize; } +#[cfg(any(feature = "postgres", feature = "sqlite"))] mod once_cell; +#[cfg(any(feature = "postgres", feature = "sqlite"))] pub(crate) use self::once_cell::OnceCell; diff --git a/diesel_tests/tests/deserialization.rs b/diesel_tests/tests/deserialization.rs index c6004c183638..968222b030bc 100644 --- a/diesel_tests/tests/deserialization.rs +++ b/diesel_tests/tests/deserialization.rs @@ -28,93 +28,3 @@ fn generated_queryable_allows_lifetimes() { users.select(CowUser::as_select()).first(connection) ); } - -#[test] -fn fun_with_row_iters() { - use crate::schema::users::dsl::*; - use diesel::deserialize::FromSql; - use diesel::row::{Field, Row}; - use diesel::sql_types; - - let conn = &mut connection_with_sean_and_tess_in_users_table(); - - let query = users.select((id, name)); - - let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))]; - - let row_iter = conn.load(&query).unwrap(); - for (row, expected) in row_iter.zip(&expected) { - let row = row.unwrap(); - - let deserialized = <(i32, String) as FromSqlRow< - (sql_types::Integer, sql_types::Text), - _, - >>::build_from_row(&row) - .unwrap(); - - assert_eq!(&deserialized, expected); - } - - { - let collected_rows = conn.load(&query).unwrap().collect::>(); - - for (row, expected) in collected_rows.iter().zip(&expected) { - let deserialized = row - .as_ref() - .map(|row| { - <(i32, String) as FromSqlRow< - (sql_types::Integer, sql_types::Text), - _, - >>::build_from_row(row).unwrap() - }) - .unwrap(); - - assert_eq!(&deserialized, expected); - } - } - - let mut row_iter = conn.load(&query).unwrap(); - - dbg!(); - let first_row = row_iter.next().unwrap().unwrap(); - let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); - let first_values = (first_fields.0.value(), first_fields.1.value()); - - dbg!(); - let second_row = row_iter.next().unwrap().unwrap(); - let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap()); - let second_values = (second_fields.0.value(), second_fields.1.value()); - - assert!(row_iter.next().is_none()); - dbg!( - >::from_nullable_sql(first_values.0) - .unwrap() - ); //, expected[0].0); - dbg!( - >::from_nullable_sql(first_values.1) - .unwrap() - ); //, expected[0].1); - - dbg!( - >::from_nullable_sql(second_values.0) - .unwrap() - ); //, expected[1].0); - dbg!( - >::from_nullable_sql(second_values.1) - .unwrap() - ); //, expected[1].1); - - let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); - let first_values = (first_fields.0.value(), first_fields.1.value()); - - dbg!( - >::from_nullable_sql(first_values.0) - .unwrap() - ); //, expected[0].0); - dbg!( - >::from_nullable_sql(first_values.1) - .unwrap() - ); //, expected[0].1); - - panic!() -} From df787a2e435bc2d052775d3687294af448f1fcdc Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Tue, 29 Jun 2021 11:15:32 +0200 Subject: [PATCH 14/32] WIP --- diesel/Cargo.toml | 2 +- diesel/src/mysql/connection/stmt/iterator.rs | 7 +++++-- diesel/src/pg/connection/row.rs | 7 +++++-- diesel/src/util/once_cell.rs | 6 ++---- diesel_bench/Cargo.toml | 9 ++++++++- 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/diesel/Cargo.toml b/diesel/Cargo.toml index 14412097fe13..3b005663ec5f 100644 --- a/diesel/Cargo.toml +++ b/diesel/Cargo.toml @@ -44,7 +44,7 @@ ipnetwork = ">=0.12.2, <0.19.0" quickcheck = "0.9" [features] -default = ["32-column-tables", "without-deprecated"] +default = ["mysql", "postgres", "sqlite"] extras = ["chrono", "serde_json", "uuid", "network-address", "numeric", "r2d2"] unstable = ["diesel_derives/nightly"] large-tables = ["32-column-tables"] diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index 84058734dd11..7decea82df9c 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -157,7 +157,7 @@ pub struct MysqlField<'a> { _marker: PhantomData<&'a (Binds, StatementMetadata)>, } -impl<'a> Field for MysqlField<'a> { +impl<'a> Field<'a, Mysql> for MysqlField<'a> { fn field_name(&self) -> Option<&str> { self.metadata.fields()[self.idx].field_name() } @@ -166,7 +166,10 @@ impl<'a> Field for MysqlField<'a> { (*self.bind)[self.idx].is_null() } - fn value<'b>(&'b self) -> Option> { + fn value<'b>(&'b self) -> Option> + where + 'a: 'b, + { self.bind[self.idx].value() } } diff --git a/diesel/src/pg/connection/row.rs b/diesel/src/pg/connection/row.rs index 03ba6a03cd8a..87b923058996 100644 --- a/diesel/src/pg/connection/row.rs +++ b/diesel/src/pg/connection/row.rs @@ -64,12 +64,15 @@ pub struct PgField<'a> { col_idx: usize, } -impl<'a> Field for PgField<'a> { +impl<'a> Field<'a, Pg> for PgField<'a> { fn field_name(&self) -> Option<&str> { self.db_result.column_name(self.col_idx) } - fn value<'b>(&'b self) -> Option> { + fn value<'b>(&'b self) -> Option> + where + 'a: 'b, + { let raw = self.db_result.get(self.row_idx, self.col_idx)?; let type_oid = self.db_result.column_type(self.col_idx); diff --git a/diesel/src/util/once_cell.rs b/diesel/src/util/once_cell.rs index 2e1771a9470e..a7a2c42543de 100644 --- a/diesel/src/util/once_cell.rs +++ b/diesel/src/util/once_cell.rs @@ -13,9 +13,8 @@ use std::cell::UnsafeCell; /// # Examples /// /// ``` -/// #![feature(once_cell)] /// -/// use std::lazy::OnceCell; +/// use crate::lazy::OnceCell; /// /// let cell = OnceCell::new(); /// assert!(cell.get().is_none()); @@ -60,9 +59,8 @@ impl OnceCell { /// # Examples /// /// ``` - /// #![feature(once_cell)] /// - /// use std::lazy::OnceCell; + /// use crate::lazy::OnceCell; /// /// let cell = OnceCell::new(); /// assert_eq!(cell.get_or_try_init(|| Err(())), Err(())); diff --git a/diesel_bench/Cargo.toml b/diesel_bench/Cargo.toml index 9c0a252c8da5..da8b3ff609ea 100644 --- a/diesel_bench/Cargo.toml +++ b/diesel_bench/Cargo.toml @@ -8,6 +8,13 @@ autobenches = false [workspace] +[workspace.profile.bench] +opt-level = 3 +debug = true +lto = true +incremental = false +codegen-units = 1 + # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] @@ -42,7 +49,7 @@ bench = true harness = false [features] -default = [] +default = ["sqlite"] postgres = ["diesel/postgres"] sqlite = ["diesel/sqlite"] mysql = ["diesel/mysql"] From 6e99fe0fb599fb9a78b2d77a619db04b9bfa49fd Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Wed, 30 Jun 2021 17:08:41 +0200 Subject: [PATCH 15/32] Multiple small fixes * Add tests for backend specific iterator behavior * Optimize the sqlite implementation * Fix the mysql implementation to work correctly for cases where someone stores the row till after the next call to Iterator::next * Add a `RunQueryDsl::load_iter` method --- diesel/src/connection/mod.rs | 2 +- diesel/src/connection/statement_cache.rs | 12 +- diesel/src/mysql/connection/bind.rs | 9 +- diesel/src/mysql/connection/mod.rs | 4 +- diesel/src/mysql/connection/stmt/iterator.rs | 260 ++++++++++++++++-- diesel/src/mysql/connection/stmt/mod.rs | 2 +- diesel/src/pg/connection/cursor.rs | 110 ++++++++ diesel/src/pg/connection/mod.rs | 2 +- diesel/src/pg/connection/row.rs | 8 +- diesel/src/query_dsl/mod.rs | 117 +++++++- diesel/src/row.rs | 47 +++- diesel/src/sqlite/backend.rs | 19 -- diesel/src/sqlite/connection/functions.rs | 24 +- diesel/src/sqlite/connection/mod.rs | 7 +- diesel/src/sqlite/connection/row.rs | 61 ++-- diesel/src/sqlite/connection/sqlite_value.rs | 187 +++++-------- .../sqlite/connection/statement_iterator.rs | 6 +- diesel/src/sqlite/connection/stmt.rs | 74 +++-- diesel/src/sqlite/types/mod.rs | 18 +- diesel/src/util/once_cell.rs | 14 +- ...ct_carries_correct_result_type_info.stderr | 3 +- ...elect_sql_still_ensures_result_type.stderr | 3 +- 22 files changed, 698 insertions(+), 291 deletions(-) diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index 48ffba21c09f..383d5c55cb98 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -11,7 +11,7 @@ use crate::query_builder::{AsQuery, QueryFragment, QueryId}; use crate::result::*; #[doc(hidden)] -pub use self::statement_cache::{MaybeCached, StatementCache, StatementCacheKey}; +pub use self::statement_cache::{MaybeCached, PrepareForCache, StatementCache, StatementCacheKey}; pub use self::transaction_manager::{AnsiTransactionManager, TransactionManager}; /// Perform simple operations on a backend. diff --git a/diesel/src/connection/statement_cache.rs b/diesel/src/connection/statement_cache.rs index 06323589efe0..04970c4c69d1 100644 --- a/diesel/src/connection/statement_cache.rs +++ b/diesel/src/connection/statement_cache.rs @@ -107,6 +107,12 @@ pub struct StatementCache { pub cache: HashMap, Statement>, } +#[derive(Debug, Clone, Copy)] +pub enum PrepareForCache { + Yes, + No, +} + #[allow(clippy::len_without_is_empty, clippy::new_without_default)] impl StatementCache where @@ -133,7 +139,7 @@ where ) -> QueryResult> where T: QueryFragment + QueryId, - F: FnOnce(&str) -> QueryResult, + F: FnOnce(&str, PrepareForCache) -> QueryResult, { use std::collections::hash_map::Entry::{Occupied, Vacant}; @@ -141,7 +147,7 @@ where if !source.is_safe_to_cache_prepared()? { let sql = cache_key.sql(source)?; - return prepare_fn(&sql).map(MaybeCached::CannotCache); + return prepare_fn(&sql, PrepareForCache::No).map(MaybeCached::CannotCache); } let cached_result = match self.cache.entry(cache_key) { @@ -149,7 +155,7 @@ where Vacant(entry) => { let statement = { let sql = entry.key().sql(source)?; - prepare_fn(&sql) + prepare_fn(&sql, PrepareForCache::Yes) }; entry.insert(statement?) diff --git a/diesel/src/mysql/connection/bind.rs b/diesel/src/mysql/connection/bind.rs index 7b2fa7e43f77..14beb1c838b4 100644 --- a/diesel/src/mysql/connection/bind.rs +++ b/diesel/src/mysql/connection/bind.rs @@ -11,6 +11,7 @@ use crate::mysql::types::MYSQL_TIME; use crate::mysql::{MysqlType, MysqlValue}; use crate::result::QueryResult; +#[derive(Clone)] pub struct Binds { data: Vec, } @@ -34,7 +35,7 @@ impl Binds { .iter() .zip( types - .into_iter() + .iter() .map(|o| o.as_ref()) .chain(std::iter::repeat(None)), ) @@ -78,10 +79,6 @@ impl Binds { data.update_buffer_length(); } } - - pub fn len(&self) -> usize { - self.data.len() - } } impl Index for Binds { @@ -127,7 +124,7 @@ impl From for Flags { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BindData { tpe: ffi::enum_field_types, bytes: Vec, diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index efae0ca6a6bf..48b379f7f2de 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -88,7 +88,7 @@ impl Connection for MysqlConnection { MaybeCached::Cached(stmt) => stmt, }; - let results = unsafe { stmt.results(metadata)? }; + let results = unsafe { stmt.results(&metadata)? }; Ok(results) }) } @@ -121,7 +121,7 @@ impl MysqlConnection { let cache = &mut self.statement_cache; let conn = &mut self.raw_connection; - let mut stmt = cache.cached_statement(source, &[], |sql| conn.prepare(sql))?; + let mut stmt = cache.cached_statement(source, &[], |sql, _| conn.prepare(sql))?; let mut bind_collector = RawBytesBindCollector::new(); source.collect_binds(&mut bind_collector, &mut ())?; let binds = bind_collector diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index 7decea82df9c..bf64bf4815da 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -1,3 +1,4 @@ +use std::cell::{Ref, RefCell}; use std::marker::PhantomData; use std::rc::Rc; @@ -9,30 +10,28 @@ use crate::row::*; #[allow(missing_debug_implementations)] pub struct StatementIterator<'a> { stmt: &'a mut Statement, - output_binds: Rc, + last_row: Rc>, metadata: Rc, - types: Vec>, size: usize, fetched_rows: usize, } impl<'a> StatementIterator<'a> { #[allow(clippy::new_ret_no_self)] - pub fn new(stmt: &'a mut Statement, types: Vec>) -> QueryResult { + pub fn new(stmt: &'a mut Statement, types: &[Option]) -> QueryResult { let metadata = stmt.metadata()?; - let mut output_binds = Binds::from_output_types(&types, &metadata); + let mut output_binds = Binds::from_output_types(types, &metadata); stmt.execute_statement(&mut output_binds)?; let size = unsafe { stmt.result_size() }?; Ok(StatementIterator { metadata: Rc::new(metadata), - output_binds: Rc::new(output_binds), + last_row: Rc::new(RefCell::new(PrivateMysqlRow::Direct(output_binds))), fetched_rows: 0, size, stmt, - types, }) } } @@ -43,28 +42,66 @@ impl<'a> Iterator for StatementIterator<'a> { fn next(&mut self) -> Option { // check if we own the only instance of the bind buffer // if that's the case we can reuse the underlying allocations - // if that's not the case, allocate a new buffer - let res = if let Some(binds) = Rc::get_mut(&mut self.output_binds) { - self.stmt - .populate_row_buffers(binds) - .map(|o| o.map(|()| self.output_binds.clone())) + // if that's not the case, we need to copy the output bind buffers + // to somewhere else + let res = if let Some(binds) = Rc::get_mut(&mut self.last_row) { + if let PrivateMysqlRow::Direct(ref mut binds) = RefCell::get_mut(binds) { + self.stmt.populate_row_buffers(binds) + } else { + // any other state than `PrivateMysqlRow::Direct` is invalid here + // and should not happen. If this ever happens this is a logic error + // in the code above + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + } } else { // The shared bind buffer is in use by someone else, - // we allocate a new buffer here - let mut output_binds = Binds::from_output_types(&self.types, &self.metadata); - self.stmt - .populate_row_buffers(&mut output_binds) - .map(|o| o.map(|()| Rc::new(output_binds))) + // this means we copy out the values and replace the used reference + // by the copied values. After this we can advance the statment + // another step + let mut last_row = { + let mut last_row = match self.last_row.try_borrow_mut() { + Ok(o) => o, + Err(_e) => { + return Some(Err(crate::result::Error::DeserializationError( + "Failed to reborrow row. Try to release any `MysqlField` or `MysqlValue` \ + that exists at this point" + .into(), + ))); + } + }; + let last_row = &mut *last_row; + let duplicated = last_row.duplicate(); + std::mem::replace(last_row, duplicated) + }; + let res = if let PrivateMysqlRow::Direct(ref mut binds) = last_row { + self.stmt.populate_row_buffers(binds) + } else { + // any other state than `PrivateMysqlRow::Direct` is invalid here + // and should not happen. If this ever happens this is a logic error + // in the code above + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + }; + self.last_row = Rc::new(RefCell::new(last_row)); + res }; match res { - Ok(Some(binds)) => { + Ok(Some(())) => { self.fetched_rows += 1; Some(Ok(MysqlRow { - col_idx: 0, - binds, metadata: self.metadata.clone(), _marker: Default::default(), + row: self.last_row.clone(), })) } Ok(None) => None, @@ -96,30 +133,45 @@ impl<'a> ExactSizeIterator for StatementIterator<'a> { #[derive(Clone)] #[allow(missing_debug_implementations)] pub struct MysqlRow<'a> { - col_idx: usize, - binds: Rc, + row: Rc>, metadata: Rc, _marker: PhantomData<&'a mut (Binds, StatementMetadata)>, } -impl<'a> Row<'a, Mysql> for MysqlRow<'a> { +enum PrivateMysqlRow { + Direct(Binds), + Copied(Binds), +} + +impl PrivateMysqlRow { + fn duplicate(&self) -> Self { + match self { + Self::Copied(b) | Self::Direct(b) => Self::Copied(b.clone()), + } + } +} + +impl<'a, 'b> RowFieldHelper<'a, Mysql> for MysqlRow<'b> { type Field = MysqlField<'a>; +} + +impl<'a> Row<'a, Mysql> for MysqlRow<'a> { type InnerPartialRow = Self; fn field_count(&self) -> usize { - self.binds.len() + self.metadata.fields().len() } - fn get(&self, idx: I) -> Option + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where + 'a: 'b, Self: RowIndex, { let idx = self.idx(idx)?; Some(MysqlField { - bind: self.binds.clone(), + binds: self.row.borrow(), metadata: self.metadata.clone(), idx, - _marker: Default::default(), }) } @@ -151,10 +203,9 @@ impl<'a, 'b> RowIndex<&'a str> for MysqlRow<'b> { #[allow(missing_debug_implementations)] pub struct MysqlField<'a> { - bind: Rc, + binds: Ref<'a, PrivateMysqlRow>, metadata: Rc, idx: usize, - _marker: PhantomData<&'a (Binds, StatementMetadata)>, } impl<'a> Field<'a, Mysql> for MysqlField<'a> { @@ -163,13 +214,162 @@ impl<'a> Field<'a, Mysql> for MysqlField<'a> { } fn is_null(&self) -> bool { - (*self.bind)[self.idx].is_null() + match &*self.binds { + PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].is_null(), + } } fn value<'b>(&'b self) -> Option> where 'a: 'b, { - self.bind[self.idx].value() + match &*self.binds { + PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].value(), + } + } +} + +#[test] +fn fun_with_row_iters() { + crate::table! { + #[allow(unused_parens)] + users(id) { + id -> Integer, + name -> Text, + } + } + + use crate::deserialize::{FromSql, FromSqlRow}; + use crate::prelude::*; + use crate::row::{Field, Row}; + use crate::sql_types; + + let conn = &mut crate::test_helpers::connection(); + + crate::sql_query( + "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", + ) + .execute(conn) + .unwrap(); + crate::sql_query("DELETE FROM users;") + .execute(conn) + .unwrap(); + + crate::insert_into(users::table) + .values(vec![ + (users::id.eq(1), users::name.eq("Sean")), + (users::id.eq(2), users::name.eq("Tess")), + ]) + .execute(conn) + .unwrap(); + + let query = users::table.select((users::id, users::name)); + + let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))]; + + { + let row_iter = conn.load(&query).unwrap(); + for (row, expected) in row_iter.zip(&expected) { + let row = row.unwrap(); + + let deserialized = <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(&row) + .unwrap(); + + assert_eq!(&deserialized, expected); + } + } + dbg!(); + + { + let collected_rows = conn.load(&query).unwrap().collect::>(); + assert_eq!(collected_rows.len(), 2); + for (row, expected) in collected_rows.iter().zip(&expected) { + let deserialized = row + .as_ref() + .map(|row| { + <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(row).unwrap() + }) + .unwrap(); + assert_eq!(&deserialized, expected); + } } + + let mut row_iter = conn.load(&query).unwrap(); + + let first_row = row_iter.next().unwrap().unwrap(); + let first_fields = ( + Row::get(&first_row, 0).unwrap(), + Row::get(&first_row, 1).unwrap(), + ); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(first_values); + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(first_fields); + + let second_row = row_iter.next().unwrap().unwrap(); + let second_fields = ( + Row::get(&second_row, 0).unwrap(), + Row::get(&second_row, 1).unwrap(), + ); + let second_values = (second_fields.0.value(), second_fields.1.value()); + + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(second_values); + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(second_fields); + + assert!(row_iter.next().is_none()); + + let first_fields = ( + Row::get(&first_row, 0).unwrap(), + Row::get(&first_row, 1).unwrap(), + ); + let second_fields = ( + Row::get(&second_row, 0).unwrap(), + Row::get(&second_row, 1).unwrap(), + ); + + let first_values = (first_fields.0.value(), first_fields.1.value()); + let second_values = (second_fields.0.value(), second_fields.1.value()); + + assert_eq!( + >::from_nullable_sql(first_values.0).unwrap(), + expected[0].0 + ); + assert_eq!( + >::from_nullable_sql(first_values.1).unwrap(), + expected[0].1 + ); + + assert_eq!( + >::from_nullable_sql(second_values.0).unwrap(), + expected[1].0 + ); + assert_eq!( + >::from_nullable_sql(second_values.1).unwrap(), + expected[1].1 + ); + + let first_fields = ( + Row::get(&first_row, 0).unwrap(), + Row::get(&first_row, 1).unwrap(), + ); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + assert_eq!( + >::from_nullable_sql(first_values.0).unwrap(), + expected[0].0 + ); + assert_eq!( + >::from_nullable_sql(first_values.1).unwrap(), + expected[0].1 + ); } diff --git a/diesel/src/mysql/connection/stmt/mod.rs b/diesel/src/mysql/connection/stmt/mod.rs index 4b9998349efb..3ad0d52ebf8d 100644 --- a/diesel/src/mysql/connection/stmt/mod.rs +++ b/diesel/src/mysql/connection/stmt/mod.rs @@ -82,7 +82,7 @@ impl Statement { /// be called on this statement. pub unsafe fn results<'a>( &'a mut self, - types: Vec>, + types: &[Option], ) -> QueryResult> { StatementIterator::new(self, types) } diff --git a/diesel/src/pg/connection/cursor.rs b/diesel/src/pg/connection/cursor.rs index 593c819379f1..e8aee68367ee 100644 --- a/diesel/src/pg/connection/cursor.rs +++ b/diesel/src/pg/connection/cursor.rs @@ -57,3 +57,113 @@ impl<'a> Iterator for Cursor<'a> { self.len() } } + +#[test] +fn fun_with_row_iters() { + crate::table! { + #[allow(unused_parens)] + users(id) { + id -> Integer, + name -> Text, + } + } + + use crate::deserialize::{FromSql, FromSqlRow}; + use crate::pg::Pg; + use crate::prelude::*; + use crate::row::{Field, Row}; + use crate::sql_types; + + let conn = &mut crate::test_helpers::connection(); + + crate::sql_query( + "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", + ) + .execute(conn) + .unwrap(); + + crate::insert_into(users::table) + .values(vec![ + (users::id.eq(1), users::name.eq("Sean")), + (users::id.eq(2), users::name.eq("Tess")), + ]) + .execute(conn) + .unwrap(); + + let query = users::table.select((users::id, users::name)); + + let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))]; + + let row_iter = conn.load(&query).unwrap(); + for (row, expected) in row_iter.zip(&expected) { + let row = row.unwrap(); + + let deserialized = <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(&row) + .unwrap(); + + assert_eq!(&deserialized, expected); + } + + { + let collected_rows = conn.load(&query).unwrap().collect::>(); + + for (row, expected) in collected_rows.iter().zip(&expected) { + let deserialized = row + .as_ref() + .map(|row| { + <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + _, + >>::build_from_row(row).unwrap() + }) + .unwrap(); + + assert_eq!(&deserialized, expected); + } + } + + let mut row_iter = conn.load(&query).unwrap(); + + let first_row = row_iter.next().unwrap().unwrap(); + let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + let second_row = row_iter.next().unwrap().unwrap(); + let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap()); + let second_values = (second_fields.0.value(), second_fields.1.value()); + + assert!(row_iter.next().is_none()); + + assert_eq!( + >::from_nullable_sql(first_values.0).unwrap(), + expected[0].0 + ); + assert_eq!( + >::from_nullable_sql(first_values.1).unwrap(), + expected[0].1 + ); + + assert_eq!( + >::from_nullable_sql(second_values.0).unwrap(), + expected[1].0 + ); + assert_eq!( + >::from_nullable_sql(second_values.1).unwrap(), + expected[1].1 + ); + + let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); + let first_values = (first_fields.0.value(), first_fields.1.value()); + + assert_eq!( + >::from_nullable_sql(first_values.0).unwrap(), + expected[0].0 + ); + assert_eq!( + >::from_nullable_sql(first_values.1).unwrap(), + expected[0].1 + ); +} diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index 65680dc4b29d..e4f3621483a7 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -158,7 +158,7 @@ impl PgConnection { let cache_len = self.statement_cache.len(); let cache = &mut self.statement_cache; let raw_conn = &mut self.raw_connection; - let query = cache.cached_statement(source, &metadata, |sql| { + let query = cache.cached_statement(source, &metadata, |sql, _| { let query_name = if source.is_safe_to_cache_prepared()? { Some(format!("__diesel_stmt_{}", cache_len)) } else { diff --git a/diesel/src/pg/connection/row.rs b/diesel/src/pg/connection/row.rs index 87b923058996..f2f670e3d628 100644 --- a/diesel/src/pg/connection/row.rs +++ b/diesel/src/pg/connection/row.rs @@ -16,16 +16,20 @@ impl<'a> PgRow<'a> { } } -impl<'a> Row<'a, Pg> for PgRow<'a> { +impl<'a, 'b> RowFieldHelper<'a, Pg> for PgRow<'b> { type Field = PgField<'a>; +} + +impl<'a> Row<'a, Pg> for PgRow<'a> { type InnerPartialRow = Self; fn field_count(&self) -> usize { self.db_result.column_count() } - fn get(&self, idx: I) -> Option + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where + 'a: 'b, Self: RowIndex, { let idx = self.idx(idx)?; diff --git a/diesel/src/query_dsl/mod.rs b/diesel/src/query_dsl/mod.rs index b1e91cf4f0b8..4bf9d50405fe 100644 --- a/diesel/src/query_dsl/mod.rs +++ b/diesel/src/query_dsl/mod.rs @@ -52,6 +52,7 @@ pub use self::join_dsl::{InternalJoinDsl, JoinOnDsl, JoinWithImplicitOnClause}; pub use self::load_dsl::CompatibleType; #[doc(hidden)] pub use self::load_dsl::LoadQuery; +use self::load_dsl::LoadQueryRet; pub use self::save_changes_dsl::{SaveChangesDsl, UpdateAndFetchResults}; /// The traits used by `QueryDsl`. @@ -1307,12 +1308,10 @@ pub trait RunQueryDsl: Sized { methods::ExecuteDsl::execute(self, conn) } - /// Executes the given query, returning a `Vec` with the returned rows. + /// Executes the given query, returning a [`Vec`] with the returned rows. /// - /// When using the query builder, - /// the return type can be - /// a tuple of the values, - /// or a struct which implements [`Queryable`]. + /// When using the query builder, the return type can be + /// a tuple of the values, or a struct which implements [`Queryable`]. /// /// When this method is called on [`sql_query`], /// the return type can only be a struct which implements [`QueryableByName`] @@ -1407,6 +1406,114 @@ pub trait RunQueryDsl: Sized { self.internal_load(conn)?.collect() } + /// Executes the given query, returning a [`Iterator`] with the returned rows. + /// + /// **You should normally prefer to use [`RunQueryDsl::load`] instead**. This method + /// is provided for situations where the result needs to be collected into a different + /// container than a [`Vec`] + /// + /// When using the query builder, the return type can be + /// a tuple of the values, or a struct which implements [`Queryable`]. + /// + /// When this method is called on [`sql_query`], + /// the return type can only be a struct which implements [`QueryableByName`] + /// + /// For insert, update, and delete operations where only a count of affected is needed, + /// [`execute`] should be used instead. + /// + /// [`Queryable`]: crate::deserialize::Queryable + /// [`QueryableByName`]: crate::deserialize::QueryableByName + /// [`execute`]: crate::query_dsl::RunQueryDsl::execute() + /// [`sql_query`]: crate::sql_query() + /// + /// # Examples + /// + /// ## Returning a single field + /// + /// ```rust + /// # include!("../doctest_setup.rs"); + /// # + /// # fn main() { + /// # run_test(); + /// # } + /// # + /// # fn run_test() -> QueryResult<()> { + /// # use diesel::insert_into; + /// # use schema::users::dsl::*; + /// # let connection = &mut establish_connection(); + /// let data = users.select(name) + /// .load_iter::(connection)? + /// .collect::>>()?; + /// assert_eq!(vec!["Sean", "Tess"], data); + /// # Ok(()) + /// # } + /// ``` + /// + /// ## Returning a tuple + /// + /// ```rust + /// # include!("../doctest_setup.rs"); + /// # + /// # fn main() { + /// # run_test(); + /// # } + /// # + /// # fn run_test() -> QueryResult<()> { + /// # use diesel::insert_into; + /// # use schema::users::dsl::*; + /// # let connection = &mut establish_connection(); + /// let data = users + /// .load_iter::<(i32, String)>(connection)? + /// .collect::>>()?; + /// let expected_data = vec![ + /// (1, String::from("Sean")), + /// (2, String::from("Tess")), + /// ]; + /// assert_eq!(expected_data, data); + /// # Ok(()) + /// # } + /// ``` + /// + /// ## Returning a struct + /// + /// ```rust + /// # include!("../doctest_setup.rs"); + /// # + /// #[derive(Queryable, PartialEq, Debug)] + /// struct User { + /// id: i32, + /// name: String, + /// } + /// + /// # fn main() { + /// # run_test(); + /// # } + /// # + /// # fn run_test() -> QueryResult<()> { + /// # use diesel::insert_into; + /// # use schema::users::dsl::*; + /// # let connection = &mut establish_connection(); + /// let data = users + /// .load_iter::(connection)? + /// .collect::>>()?; + /// let expected_data = vec![ + /// User { id: 1, name: String::from("Sean") }, + /// User { id: 2, name: String::from("Tess") }, + /// ]; + /// assert_eq!(expected_data, data); + /// # Ok(()) + /// # } + /// ``` + fn load_iter<'a, U>( + self, + conn: &'a mut Conn, + ) -> QueryResult<>::Ret> + where + Self: LoadQuery, + { + self.internal_load(conn) + } + /// Runs the command, and returns the affected row. /// /// `Err(NotFound)` will be returned if the query affected 0 rows. You can diff --git a/diesel/src/row.rs b/diesel/src/row.rs index 95177bba8b31..db0cdf3d27f3 100644 --- a/diesel/src/row.rs +++ b/diesel/src/row.rs @@ -21,12 +21,9 @@ pub trait RowIndex { fn idx(&self, idx: I) -> Option; } -/// Represents a single database row. -/// -/// This trait is used as an argument to [`FromSqlRow`]. -/// -/// [`FromSqlRow`]: crate::deserialize::FromSqlRow -pub trait Row<'a, DB: Backend>: RowIndex + for<'b> RowIndex<&'b str> + Sized { +/// A helper trait to indicate the life time bound for a field returned +/// by [`Row::get`] +pub trait RowFieldHelper<'a, DB: Backend> { /// Field type returned by a `Row` implementation /// /// * Crates using existing backend should not concern themself with the @@ -35,7 +32,16 @@ pub trait Row<'a, DB: Backend>: RowIndex + for<'b> RowIndex<&'b str> + Si /// * Crates implementing custom backends should provide their own type /// meeting the required trait bounds type Field: Field<'a, DB>; +} +/// Represents a single database row. +/// +/// This trait is used as an argument to [`FromSqlRow`]. +/// +/// [`FromSqlRow`]: crate::deserialize::FromSqlRow +pub trait Row<'a, DB: Backend>: + RowIndex + for<'b> RowIndex<&'b str> + for<'b> RowFieldHelper<'b, DB> + Sized +{ /// Return type of `PartialRow` /// /// For all implementations, beside of the `Row` implementation on `PartialRow` itself @@ -49,10 +55,25 @@ pub trait Row<'a, DB: Backend>: RowIndex + for<'b> RowIndex<&'b str> + Si /// Get the field with the provided index from the row. /// /// Returns `None` if there is no matching field for the given index - fn get(&self, idx: I) -> Option + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where + 'a: 'b, Self: RowIndex; + /// Get a deserialized value with the provided index from the row. + /// + /// Returns `None` if there is no matching field for the given index + /// Returns `Some(Err(…)` if there is an error during deserialization + /// Returns `Some(T)` if deserialization is successful + fn get_value(&self, idx: I) -> Option> + where + Self: RowIndex, + T: FromSql, + { + let field = self.get(idx)?; + Some(>::from_nullable_sql(field.value())) + } + /// Returns a wrapping row that allows only to access fields, where the index is part of /// the provided range. #[doc(hidden)] @@ -111,20 +132,28 @@ impl<'a, R> PartialRow<'a, R> { } } +impl<'a, 'b, DB, R> RowFieldHelper<'a, DB> for PartialRow<'b, R> +where + DB: Backend, + R: RowFieldHelper<'a, DB>, +{ + type Field = R::Field; +} + impl<'a, 'b, DB, R> Row<'a, DB> for PartialRow<'b, R> where DB: Backend, R: Row<'a, DB>, { - type Field = R::Field; type InnerPartialRow = R; fn field_count(&self) -> usize { self.range.len() } - fn get(&self, idx: I) -> Option + fn get<'c, I>(&'c self, idx: I) -> Option<>::Field> where + 'a: 'c, Self: RowIndex, { let idx = self.idx(idx)?; diff --git a/diesel/src/sqlite/backend.rs b/diesel/src/sqlite/backend.rs index 6f630ef0a754..bb7451cffa28 100644 --- a/diesel/src/sqlite/backend.rs +++ b/diesel/src/sqlite/backend.rs @@ -38,25 +38,6 @@ pub enum SqliteType { Long, } -impl SqliteType { - pub(super) fn from_raw_sqlite(tpe: i32) -> Option { - use libsqlite3_sys as ffi; - - match tpe { - ffi::SQLITE_TEXT => Some(SqliteType::Text), - ffi::SQLITE_INTEGER => Some(SqliteType::Long), - ffi::SQLITE_FLOAT => Some(SqliteType::Double), - ffi::SQLITE_BLOB => Some(SqliteType::Binary), - ffi::SQLITE_NULL => None, - _ => unreachable!( - "Sqlite's documentation state that this case ({}) is not reachable. \ - If you ever see this error message please open an issue at \ - https://github.com/diesel-rs/diesel." - ), - } - } -} - impl Backend for Sqlite { type QueryBuilder = SqliteQueryBuilder; type BindCollector = RawBytesBindCollector; diff --git a/diesel/src/sqlite/connection/functions.rs b/diesel/src/sqlite/connection/functions.rs index 025219d403a2..d6d103f7c07c 100644 --- a/diesel/src/sqlite/connection/functions.rs +++ b/diesel/src/sqlite/connection/functions.rs @@ -6,7 +6,7 @@ use super::serialized_value::SerializedValue; use super::{Sqlite, SqliteAggregateFunction}; use crate::deserialize::{FromSqlRow, StaticallySizedRow}; use crate::result::{DatabaseErrorKind, Error, QueryResult}; -use crate::row::{Field, PartialRow, Row, RowIndex}; +use crate::row::{Field, PartialRow, Row, RowFieldHelper, RowIndex}; use crate::serialize::{IsNull, Output, ToSql}; use crate::sql_types::HasSqlType; use crate::sqlite::connection::sqlite_value::OwnedSqliteValue; @@ -139,9 +139,9 @@ impl<'a> Drop for FunctionRow<'a> { DerefMut::deref_mut(RefCell::get_mut(args)) { if let Some(inner) = Rc::get_mut(column_names) { - // an empty Vector does not allocate according to the documentation - // so this prevents leaking memory - std::mem::drop(std::mem::replace(inner, Vec::new())); + // According the https://doc.rust-lang.org/std/mem/struct.ManuallyDrop.html#method.drop + // it's fine to just drop the values here + unsafe { std::ptr::drop_in_place(inner as *mut _) } } } } @@ -180,7 +180,7 @@ impl<'a> FunctionRow<'a> { args: Rc::new(RefCell::new(ManuallyDrop::new( PrivateSqliteRow::Duplicated { values: args, - column_names: Rc::new(vec![None; lenghts]), + column_names: Rc::from(vec![None; lenghts]), }, ))), marker: PhantomData, @@ -188,21 +188,25 @@ impl<'a> FunctionRow<'a> { } } -impl<'a> Row<'a, Sqlite> for FunctionRow<'a> { +impl<'a, 'b> RowFieldHelper<'a, Sqlite> for FunctionRow<'b> { type Field = FunctionArgument<'a>; +} + +impl<'a> Row<'a, Sqlite> for FunctionRow<'a> { type InnerPartialRow = Self; fn field_count(&self) -> usize { self.field_count } - fn get(&self, idx: I) -> Option + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where + 'a: 'b, Self: crate::row::RowIndex, { let idx = self.idx(idx)?; Some(FunctionArgument { - args: self.args.clone(), + args: self.args.borrow(), col_idx: idx as i32, }) } @@ -229,7 +233,7 @@ impl<'a, 'b> RowIndex<&'a str> for FunctionRow<'b> { } struct FunctionArgument<'a> { - args: Rc>>>, + args: Ref<'a, ManuallyDrop>>, col_idx: i32, } @@ -247,7 +251,7 @@ impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> { 'a: 'b, { SqliteValue::new( - Ref::map(self.args.borrow(), |drop| std::ops::Deref::deref(drop)), + Ref::map(Ref::clone(&self.args), |drop| std::ops::Deref::deref(drop)), self.col_idx, ) } diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 43001752670f..e41820bc1c31 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -91,6 +91,7 @@ impl Connection for SqliteConnection { Ok(self.raw_connection.rows_affected_by_last_query()) } + //#[tracing::instrument(skip(self, source))] #[doc(hidden)] fn load<'a, T>( &'a mut self, @@ -214,6 +215,7 @@ impl SqliteConnection { } } + //#[tracing::instrument(skip(self, source, f))] fn with_prepared_query<'a, T: QueryFragment + QueryId, R>( &'a mut self, source: &'_ T, @@ -221,8 +223,9 @@ impl SqliteConnection { ) -> QueryResult { let raw_connection = &self.raw_connection; let cache = &mut self.statement_cache; - let mut statement = - cache.cached_statement(source, &[], |sql| Statement::prepare(raw_connection, sql))?; + let mut statement = cache.cached_statement(source, &[], |sql, is_cached| { + Statement::prepare(raw_connection, sql, is_cached) + })?; let mut bind_collector = RawBytesBindCollector::::new(); source.collect_binds(&mut bind_collector, &mut ())?; diff --git a/diesel/src/sqlite/connection/row.rs b/diesel/src/sqlite/connection/row.rs index 3badcfb1d5c1..cc2dfcb99b8d 100644 --- a/diesel/src/sqlite/connection/row.rs +++ b/diesel/src/sqlite/connection/row.rs @@ -1,10 +1,10 @@ -use std::cell::RefCell; +use std::cell::{Ref, RefCell}; use std::convert::TryFrom; use std::rc::Rc; use super::sqlite_value::{OwnedSqliteValue, SqliteValue}; use super::stmt::StatementUse; -use crate::row::{Field, PartialRow, Row, RowIndex}; +use crate::row::{Field, PartialRow, Row, RowFieldHelper, RowIndex}; use crate::sqlite::Sqlite; use crate::util::OnceCell; @@ -18,19 +18,19 @@ pub(super) enum PrivateSqliteRow<'a> { Direct(StatementUse<'a>), Duplicated { values: Vec>, - column_names: Rc>>, + column_names: Rc<[Option]>, }, TemporaryEmpty, } impl<'a> PrivateSqliteRow<'a> { - pub(super) fn duplicate(&mut self, column_names: &mut Option>>>) -> Self { + pub(super) fn duplicate(&mut self, column_names: &mut Option]>>) -> Self { match self { PrivateSqliteRow::Direct(stmt) => { let column_names = if let Some(column_names) = column_names { column_names.clone() } else { - let c = Rc::new( + let c: Rc<[Option]> = Rc::from( (0..stmt.column_count()) .map(|idx| stmt.field_name(idx).map(|s| s.to_owned())) .collect::>(), @@ -60,24 +60,25 @@ impl<'a> PrivateSqliteRow<'a> { } } -impl<'a> Row<'a, Sqlite> for SqliteRow<'a> { +impl<'a, 'b> RowFieldHelper<'a, Sqlite> for SqliteRow<'b> { type Field = SqliteField<'a>; +} + +impl<'a> Row<'a, Sqlite> for SqliteRow<'a> { type InnerPartialRow = Self; fn field_count(&self) -> usize { self.field_count } - fn get(&self, idx: I) -> Option + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where + 'a: 'b, Self: RowIndex, { let idx = self.idx(idx)?; Some(SqliteField { - row: SqliteRow { - inner: self.inner.clone(), - field_count: self.field_count, - }, + row: self.inner.borrow(), col_idx: i32::try_from(idx).ok()?, field_name: OnceCell::new(), }) @@ -89,14 +90,12 @@ impl<'a> Row<'a, Sqlite> for SqliteRow<'a> { } impl<'a> RowIndex for SqliteRow<'a> { - #[inline(always)] fn idx(&self, idx: usize) -> Option { - Some(idx) - // if idx < self.field_count { - // Some(idx) - // } else { - // None - // } + if idx < self.field_count { + Some(idx) + } else { + None + } } } @@ -123,7 +122,7 @@ impl<'a, 'd> RowIndex<&'d str> for SqliteRow<'a> { #[allow(missing_debug_implementations)] pub struct SqliteField<'a> { - pub(super) row: SqliteRow<'a>, + pub(super) row: Ref<'a, PrivateSqliteRow<'a>>, pub(super) col_idx: i32, field_name: OnceCell>, } @@ -131,9 +130,20 @@ pub struct SqliteField<'a> { impl<'a> Field<'a, Sqlite> for SqliteField<'a> { fn field_name(&self) -> Option<&str> { self.field_name - .get_or_init(|| match &mut *self.row.inner.borrow_mut() { + .get_or_init(|| match &*self.row { PrivateSqliteRow::Direct(stmt) => { - stmt.field_name(self.col_idx).map(|s| s.to_owned()) + let column_name = unsafe { + // This is safe due to the fact that we + // move the column name to an allocation + // on rust side as soon as possible + // + // We cannot index into a non existing column here, because + // we checked that before even constructing the corresponding + // field + let column_name = stmt.column_name(self.col_idx); + (&*column_name).to_owned() + }; + Some(column_name) } PrivateSqliteRow::Duplicated { column_names, .. } => column_names .get(self.col_idx as usize) @@ -161,7 +171,7 @@ impl<'a> Field<'a, Sqlite> for SqliteField<'a> { where 'a: 'd, { - SqliteValue::new(self.row.inner.borrow(), self.col_idx) + SqliteValue::new(Ref::clone(&self.row), self.col_idx) } } @@ -237,6 +247,8 @@ fn fun_with_row_iters() { assert!(row_iter.next().unwrap().is_err()); std::mem::drop(first_values); + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(first_fields); let second_row = row_iter.next().unwrap().unwrap(); let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap()); @@ -244,9 +256,14 @@ fn fun_with_row_iters() { assert!(row_iter.next().unwrap().is_err()); std::mem::drop(second_values); + assert!(row_iter.next().unwrap().is_err()); + std::mem::drop(second_fields); assert!(row_iter.next().is_none()); + let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap()); + let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap()); + let first_values = (first_fields.0.value(), first_fields.1.value()); let second_values = (second_fields.0.value(), second_fields.1.value()); diff --git a/diesel/src/sqlite/connection/sqlite_value.rs b/diesel/src/sqlite/connection/sqlite_value.rs index 0cc6c20c9d80..076db8887768 100644 --- a/diesel/src/sqlite/connection/sqlite_value.rs +++ b/diesel/src/sqlite/connection/sqlite_value.rs @@ -19,8 +19,18 @@ extern "C" { /// rust values #[allow(missing_debug_implementations, missing_copy_implementations)] pub struct SqliteValue<'a, 'b> { - row: Ref<'a, PrivateSqliteRow<'b>>, - col_idx: i32, + // This field exists to ensure that nobody + // can modify the underlying row while we are + // holding a reference to some row value here + _row: Ref<'a, PrivateSqliteRow<'b>>, + // we extract the raw value pointer as part of the constructor + // to safe the match statements for each method + // Acconding to benchmarks this leads to a ~20-30% speedup + // + // This is sound as long as nobody calls `stmt.step()` + // while holding this value. We ensure this by including + // a reference to the row above. + value: NonNull, } #[repr(transparent)] @@ -36,116 +46,93 @@ impl Drop for OwnedSqliteValue { impl<'a, 'b> SqliteValue<'a, 'b> { pub(super) fn new(row: Ref<'a, PrivateSqliteRow<'b>>, col_idx: i32) -> Option { - match &*row { - PrivateSqliteRow::Direct(stmt) => { - if stmt.column_type(col_idx).is_none() { - return None; - } - } + let value = match &*row { + PrivateSqliteRow::Direct(stmt) => stmt.column_value(col_idx)?, PrivateSqliteRow::Duplicated { values, .. } => { - if values - .get(col_idx as usize) - .and_then(|v| v.as_ref()) - .is_none() - { - return None; - } + values.get(col_idx as usize).and_then(|v| v.as_ref())?.value + } + PrivateSqliteRow::TemporaryEmpty => { + // This cannot happen as this is only a temproray state + // used inside of `StatementIterator::next()` + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) } - PrivateSqliteRow::TemporaryEmpty => todo!(), + }; + + let ret = Self { _row: row, value }; + if ret.value_type().is_none() { + None + } else { + Some(ret) } - Some(Self { row, col_idx }) } - pub(crate) fn parse_string(&self, f: impl FnOnce(&str) -> R) -> R { - match &*self.row { - super::row::PrivateSqliteRow::Direct(stmt) => f(stmt.read_column_as_str(self.col_idx)), - super::row::PrivateSqliteRow::Duplicated { values, .. } => f(values - .get(self.col_idx as usize) - .and_then(|o| o.as_ref()) - .expect("We checked that this value is not null") - .read_as_str()), - super::row::PrivateSqliteRow::TemporaryEmpty => todo!(), - } + pub(crate) fn parse_string<'c, R>(&'c self, f: impl FnOnce(&'c str) -> R) -> R { + let s = unsafe { + let ptr = ffi::sqlite3_value_text(self.value.as_ptr()); + let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); + let bytes = slice::from_raw_parts(ptr as *const u8, len as usize); + // The string is guaranteed to be utf8 according to + // https://www.sqlite.org/c3ref/value_blob.html + str::from_utf8_unchecked(bytes) + }; + f(s) } - pub(crate) fn read_text(&self) -> String { - self.parse_string(|s| s.to_owned()) + pub(crate) fn read_text(&self) -> &str { + self.parse_string(|s| s) } - pub(crate) fn read_blob(&self) -> Vec { - match &*self.row { - super::row::PrivateSqliteRow::Direct(stmt) => { - stmt.read_column_as_blob(self.col_idx).to_owned() - } - super::row::PrivateSqliteRow::Duplicated { values, .. } => values - .get(self.col_idx as usize) - .and_then(|o| o.as_ref()) - .expect("We checked that this value is not null") - .read_as_blob() - .to_owned(), - super::row::PrivateSqliteRow::TemporaryEmpty => todo!(), + pub(crate) fn read_blob(&self) -> &[u8] { + unsafe { + let ptr = ffi::sqlite3_value_blob(self.value.as_ptr()); + let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); + slice::from_raw_parts(ptr as *const u8, len as usize) } } pub(crate) fn read_integer(&self) -> i32 { - match &*self.row { - super::row::PrivateSqliteRow::Direct(stmt) => stmt.read_column_as_integer(self.col_idx), - super::row::PrivateSqliteRow::Duplicated { values, .. } => values - .get(self.col_idx as usize) - .and_then(|o| o.as_ref()) - .expect("We checked that this value is not null") - .read_as_integer(), - super::row::PrivateSqliteRow::TemporaryEmpty => todo!(), - } + unsafe { ffi::sqlite3_value_int(self.value.as_ptr()) } } pub(crate) fn read_long(&self) -> i64 { - match &*self.row { - super::row::PrivateSqliteRow::Direct(stmt) => stmt.read_column_as_long(self.col_idx), - super::row::PrivateSqliteRow::Duplicated { values, .. } => values - .get(self.col_idx as usize) - .and_then(|o| o.as_ref()) - .expect("We checked that this value is not null") - .read_as_long(), - super::row::PrivateSqliteRow::TemporaryEmpty => todo!(), - } + unsafe { ffi::sqlite3_value_int64(self.value.as_ptr()) } } pub(crate) fn read_double(&self) -> f64 { - match &*self.row { - super::row::PrivateSqliteRow::Direct(stmt) => stmt.read_column_as_double(self.col_idx), - super::row::PrivateSqliteRow::Duplicated { values, .. } => values - .get(self.col_idx as usize) - .and_then(|o| o.as_ref()) - .expect("We checked that this value is not null") - .read_as_double(), - super::row::PrivateSqliteRow::TemporaryEmpty => todo!(), - } + unsafe { ffi::sqlite3_value_double(self.value.as_ptr()) } } + //#[tracing::instrument(skip(self))] /// Get the type of the value as returned by sqlite pub fn value_type(&self) -> Option { - match &*self.row { - super::row::PrivateSqliteRow::Direct(stmt) => stmt.column_type(self.col_idx), - super::row::PrivateSqliteRow::Duplicated { values, .. } => values - .get(self.col_idx as usize) - .and_then(|o| o.as_ref()) - .expect("We checked that this value is not null") - .value_type(), - super::row::PrivateSqliteRow::TemporaryEmpty => todo!(), + let tpe = unsafe { ffi::sqlite3_value_type(self.value.as_ptr()) }; + match tpe { + ffi::SQLITE_TEXT => Some(SqliteType::Text), + ffi::SQLITE_INTEGER => Some(SqliteType::Long), + ffi::SQLITE_FLOAT => Some(SqliteType::Double), + ffi::SQLITE_BLOB => Some(SqliteType::Binary), + ffi::SQLITE_NULL => None, + _ => unreachable!( + "Sqlite's documentation state that this case ({}) is not reachable. \ + If you ever see this error message please open an issue at \ + https://github.com/diesel-rs/diesel." + ), } } } impl OwnedSqliteValue { - pub(super) fn copy_from_ptr(ptr: *mut ffi::sqlite3_value) -> Option { - let tpe = unsafe { ffi::sqlite3_value_type(ptr) }; - if SqliteType::from_raw_sqlite(tpe).is_none() { + pub(super) fn copy_from_ptr(ptr: NonNull) -> Option { + let tpe = unsafe { ffi::sqlite3_value_type(ptr.as_ptr()) }; + if ffi::SQLITE_NULL == tpe { return None; } - - let value = unsafe { sqlite3_value_dup(ptr) }; - + let value = unsafe { sqlite3_value_dup(ptr.as_ptr()) }; Some(Self { value: NonNull::new(value)?, }) @@ -161,40 +148,4 @@ impl OwnedSqliteValue { ); OwnedSqliteValue { value } } - - fn read_as_str(&self) -> &str { - unsafe { - let ptr = ffi::sqlite3_value_text(self.value.as_ptr()); - let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); - let bytes = slice::from_raw_parts(ptr as *const u8, len as usize); - // The string is guaranteed to be utf8 according to - // https://www.sqlite.org/c3ref/value_blob.html - str::from_utf8_unchecked(bytes) - } - } - - fn read_as_blob(&self) -> &[u8] { - unsafe { - let ptr = ffi::sqlite3_value_blob(self.value.as_ptr()); - let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); - slice::from_raw_parts(ptr as *const u8, len as usize) - } - } - - fn read_as_integer(&self) -> i32 { - unsafe { ffi::sqlite3_value_int(self.value.as_ptr()) } - } - - fn read_as_long(&self) -> i64 { - unsafe { ffi::sqlite3_value_int64(self.value.as_ptr()) } - } - - fn read_as_double(&self) -> f64 { - unsafe { ffi::sqlite3_value_double(self.value.as_ptr()) } - } - - fn value_type(&self) -> Option { - let tpe = unsafe { ffi::sqlite3_value_type(self.value.as_ptr()) }; - SqliteType::from_raw_sqlite(tpe) - } } diff --git a/diesel/src/sqlite/connection/statement_iterator.rs b/diesel/src/sqlite/connection/statement_iterator.rs index 2241fa6570d1..ae51cbe52a64 100644 --- a/diesel/src/sqlite/connection/statement_iterator.rs +++ b/diesel/src/sqlite/connection/statement_iterator.rs @@ -8,7 +8,7 @@ use crate::result::QueryResult; #[allow(missing_debug_implementations)] pub struct StatementIterator<'a> { inner: PrivateStatementIterator<'a>, - column_names: Option>>>, + column_names: Option]>>, field_count: usize, } @@ -32,7 +32,7 @@ impl<'a> Iterator for StatementIterator<'a> { type Item = QueryResult>; fn next(&mut self) -> Option { - use PrivateStatementIterator::*; + use PrivateStatementIterator::{NotStarted, Started, TemporaryEmpty}; match std::mem::replace(&mut self.inner, TemporaryEmpty) { NotStarted(stmt) => match stmt.step() { @@ -94,7 +94,7 @@ impl<'a> Iterator for StatementIterator<'a> { Err(_e) => { self.inner = Started(last_row.clone()); return Some(Err(crate::result::Error::DeserializationError( - "Failed to reborrow row. Try to release any `SqliteValue` \ + "Failed to reborrow row. Try to release any `SqliteField` or `SqliteValue` \ that exists at this point" .into(), ))); diff --git a/diesel/src/sqlite/connection/stmt.rs b/diesel/src/sqlite/connection/stmt.rs index 5e348596eac0..77dc3fa88b0e 100644 --- a/diesel/src/sqlite/connection/stmt.rs +++ b/diesel/src/sqlite/connection/stmt.rs @@ -3,11 +3,11 @@ extern crate libsqlite3_sys as ffi; use super::raw::RawConnection; use super::serialized_value::SerializedValue; use super::sqlite_value::OwnedSqliteValue; +use crate::connection::PrepareForCache; use crate::result::Error::DatabaseError; use crate::result::*; use crate::sqlite::SqliteType; use crate::util::OnceCell; -use core::slice; use std::ffi::{CStr, CString}; use std::io::{stderr, Write}; use std::os::raw as libc; @@ -18,15 +18,37 @@ pub struct Statement { bind_index: libc::c_int, } +const SQLITE_PREPARE_PERSISTENT: libc::c_uint = 0x01; + +extern "C" { + fn sqlite3_prepare_v3( + db: *mut ffi::sqlite3, + zSql: *const libc::c_char, + nByte: libc::c_int, + flags: libc::c_uint, + ppStmt: *mut *mut ffi::sqlite3_stmt, + pzTail: *mut *const libc::c_char, + ) -> libc::c_int; +} + impl Statement { - pub fn prepare(raw_connection: &RawConnection, sql: &str) -> QueryResult { + pub fn prepare( + raw_connection: &RawConnection, + sql: &str, + is_cached: PrepareForCache, + ) -> QueryResult { let mut stmt = ptr::null_mut(); let mut unused_portion = ptr::null(); let prepare_result = unsafe { - ffi::sqlite3_prepare_v2( + sqlite3_prepare_v3( raw_connection.internal_connection.as_ptr(), CString::new(sql)?.as_ptr(), sql.len() as libc::c_int, + if matches!(is_cached, PrepareForCache::Yes) { + SQLITE_PREPARE_PERSISTENT + } else { + 0 + }, &mut stmt, &mut unused_portion, ) @@ -135,7 +157,7 @@ impl<'a> StatementUse<'a> { self.step().map(|_| ()) } - pub(in crate::sqlite::connection) fn step<'c>(self) -> QueryResult> { + pub(in crate::sqlite::connection) fn step(self) -> QueryResult> { let res = unsafe { match ffi::sqlite3_step(self.statement.inner_statement.as_ptr()) { ffi::SQLITE_DONE => Ok(None), @@ -153,7 +175,7 @@ impl<'a> StatementUse<'a> { // on the same column. // // https://sqlite.org/c3ref/column_name.html - unsafe fn column_name(&mut self, idx: i32) -> *const str { + pub(super) unsafe fn column_name(&self, idx: i32) -> *const str { let name = { let column_name = ffi::sqlite3_column_name(self.statement.inner_statement.as_ptr(), idx); @@ -182,7 +204,7 @@ impl<'a> StatementUse<'a> { .map(|v| v as usize) } - pub(super) fn field_name<'c>(&'c mut self, idx: i32) -> Option<&'c str> { + pub(super) fn field_name(&mut self, idx: i32) -> Option<&str> { if let Some(column_names) = self.column_names.get() { return column_names .get(idx as usize) @@ -196,46 +218,14 @@ impl<'a> StatementUse<'a> { ret.and_then(|p| unsafe { p.as_ref() }) } - pub(super) fn column_type(&self, idx: i32) -> Option { - let tpe = unsafe { ffi::sqlite3_column_type(self.statement.inner_statement.as_ptr(), idx) }; - SqliteType::from_raw_sqlite(tpe) - } - - pub(super) fn read_column_as_str(&self, idx: i32) -> &str { - unsafe { - let ptr = ffi::sqlite3_column_text(self.statement.inner_statement.as_ptr(), idx); - let len = ffi::sqlite3_column_bytes(self.statement.inner_statement.as_ptr(), idx); - let bytes = slice::from_raw_parts(ptr as *const u8, len as usize); - // The string is guaranteed to be utf8 according to - // https://www.sqlite.org/c3ref/value_blob.html - std::str::from_utf8_unchecked(bytes) - } - } - - pub(super) fn read_column_as_blob(&self, idx: i32) -> &[u8] { - unsafe { - let ptr = ffi::sqlite3_column_blob(self.statement.inner_statement.as_ptr(), idx); - let len = ffi::sqlite3_column_bytes(self.statement.inner_statement.as_ptr(), idx); - slice::from_raw_parts(ptr as *const u8, len as usize) - } - } - - pub(super) fn read_column_as_integer(&self, idx: i32) -> i32 { - unsafe { ffi::sqlite3_column_int(self.statement.inner_statement.as_ptr(), idx) } - } - - pub(super) fn read_column_as_long(&self, idx: i32) -> i64 { - unsafe { ffi::sqlite3_column_int64(self.statement.inner_statement.as_ptr(), idx) } - } - - pub(super) fn read_column_as_double(&self, idx: i32) -> f64 { - unsafe { ffi::sqlite3_column_double(self.statement.inner_statement.as_ptr(), idx) } + pub(super) fn copy_value(&self, idx: i32) -> Option { + OwnedSqliteValue::copy_from_ptr(self.column_value(idx)?) } - pub(super) fn copy_value(&self, idx: i32) -> Option { + pub(super) fn column_value(&self, idx: i32) -> Option> { let ptr = unsafe { ffi::sqlite3_column_value(self.statement.inner_statement.as_ptr(), idx) }; - OwnedSqliteValue::copy_from_ptr(ptr) + NonNull::new(ptr) } } diff --git a/diesel/src/sqlite/types/mod.rs b/diesel/src/sqlite/types/mod.rs index 6e8a804f075a..3d96ae2c6206 100644 --- a/diesel/src/sqlite/types/mod.rs +++ b/diesel/src/sqlite/types/mod.rs @@ -9,17 +9,27 @@ use crate::deserialize::{self, FromSql}; use crate::serialize::{self, Output, ToSql}; use crate::sql_types; -impl FromSql for String { +/// The returned pointer is *only* valid for the lifetime to the argument of +/// `from_sql`. This impl is intended for uses where you want to write a new +/// impl in terms of `String`, but don't want to allocate. We have to return a +/// raw pointer instead of a reference with a lifetime due to the structure of +/// `FromSql` +impl FromSql for *const str { fn from_sql(value: SqliteValue<'_, '_>) -> deserialize::Result { let text = value.read_text(); - Ok(text) + Ok(text as *const _) } } -impl FromSql for Vec { +/// The returned pointer is *only* valid for the lifetime to the argument of +/// `from_sql`. This impl is intended for uses where you want to write a new +/// impl in terms of `Vec`, but don't want to allocate. We have to return a +/// raw pointer instead of a reference with a lifetime due to the structure of +/// `FromSql` +impl FromSql for *const [u8] { fn from_sql(bytes: SqliteValue<'_, '_>) -> deserialize::Result { let bytes = bytes.read_blob(); - Ok(bytes) + Ok(bytes as *const _) } } diff --git a/diesel/src/util/once_cell.rs b/diesel/src/util/once_cell.rs index a7a2c42543de..cb55e81e193e 100644 --- a/diesel/src/util/once_cell.rs +++ b/diesel/src/util/once_cell.rs @@ -12,9 +12,9 @@ use std::cell::UnsafeCell; /// /// # Examples /// -/// ``` +/// ```ignore /// -/// use crate::lazy::OnceCell; +/// use crate::util::OnceCell; /// /// let cell = OnceCell::new(); /// assert!(cell.get().is_none()); @@ -25,7 +25,7 @@ use std::cell::UnsafeCell; /// assert_eq!(value, "Hello, World!"); /// assert!(cell.get().is_some()); /// ``` -pub struct OnceCell { +pub(crate) struct OnceCell { // Invariant: written to at most once. inner: UnsafeCell>, } @@ -38,7 +38,7 @@ impl Default for OnceCell { impl OnceCell { /// Creates a new empty cell. - pub const fn new() -> OnceCell { + pub(crate) const fn new() -> OnceCell { OnceCell { inner: UnsafeCell::new(None), } @@ -58,9 +58,9 @@ impl OnceCell { /// /// # Examples /// - /// ``` + /// ```ignore /// - /// use crate::lazy::OnceCell; + /// use crate::util::OnceCell; /// /// let cell = OnceCell::new(); /// assert_eq!(cell.get_or_try_init(|| Err(())), Err(())); @@ -71,7 +71,7 @@ impl OnceCell { /// assert_eq!(value, Ok(&92)); /// assert_eq!(cell.get(), Some(&92)) /// ``` - pub fn get_or_init(&self, f: F) -> &T + pub(crate) fn get_or_init(&self, f: F) -> &T where F: FnOnce() -> T, { diff --git a/diesel_compile_tests/tests/fail/select_carries_correct_result_type_info.stderr b/diesel_compile_tests/tests/fail/select_carries_correct_result_type_info.stderr index ca17b5993493..7deae98b5423 100644 --- a/diesel_compile_tests/tests/fail/select_carries_correct_result_type_info.stderr +++ b/diesel_compile_tests/tests/fail/select_carries_correct_result_type_info.stderr @@ -22,9 +22,8 @@ error[E0277]: the trait bound `*const str: FromSql> <*const [u8] as FromSql> - <*const str as FromSql> <*const str as FromSql> - and 3 others + <*const str as FromSql> = note: required because of the requirements on the impl of `FromSql` for `std::string::String` = note: required because of the requirements on the impl of `Queryable` for `std::string::String` = note: required because of the requirements on the impl of `FromSqlRow` for `std::string::String` diff --git a/diesel_compile_tests/tests/fail/select_sql_still_ensures_result_type.stderr b/diesel_compile_tests/tests/fail/select_sql_still_ensures_result_type.stderr index 33aa2fc9a9ed..27e3009a1ea3 100644 --- a/diesel_compile_tests/tests/fail/select_sql_still_ensures_result_type.stderr +++ b/diesel_compile_tests/tests/fail/select_sql_still_ensures_result_type.stderr @@ -7,9 +7,8 @@ error[E0277]: the trait bound `*const str: FromSql` is not satisfied = help: the following implementations were found: <*const [u8] as FromSql> <*const [u8] as FromSql> - <*const str as FromSql> <*const str as FromSql> - and 3 others + <*const str as FromSql> = note: required because of the requirements on the impl of `FromSql` for `std::string::String` = note: required because of the requirements on the impl of `Queryable` for `std::string::String` = note: required because of the requirements on the impl of `FromSqlRow` for `std::string::String` From a2d9af075519b22208dd819bdd19b2f82fe0e349 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 1 Jul 2021 21:21:28 +0200 Subject: [PATCH 16/32] Ci fixes --- diesel/Cargo.toml | 2 +- diesel/src/mysql/connection/bind.rs | 705 +++++++++--------- .../information_schema.rs | 10 +- diesel_tests/tests/deserialization.rs | 1 - 4 files changed, 354 insertions(+), 364 deletions(-) diff --git a/diesel/Cargo.toml b/diesel/Cargo.toml index 3b005663ec5f..9b5a684b620e 100644 --- a/diesel/Cargo.toml +++ b/diesel/Cargo.toml @@ -44,7 +44,7 @@ ipnetwork = ">=0.12.2, <0.19.0" quickcheck = "0.9" [features] -default = ["mysql", "postgres", "sqlite"] +default = ["32-column-tables", "with-deprecated"] extras = ["chrono", "serde_json", "uuid", "network-address", "numeric", "r2d2"] unstable = ["diesel_derives/nightly"] large-tables = ["32-column-tables"] diff --git a/diesel/src/mysql/connection/bind.rs b/diesel/src/mysql/connection/bind.rs index 14beb1c838b4..fb37de35c41e 100644 --- a/diesel/src/mysql/connection/bind.rs +++ b/diesel/src/mysql/connection/bind.rs @@ -715,9 +715,8 @@ mod tests { ) .unwrap(); - let mut stmt = conn - .prepare_query(&crate::sql_query( - "SELECT + conn.with_prepared_query(&crate::sql_query( + "SELECT tiny_int, small_int, medium_int, int_col, big_int, unsigned_int, zero_fill_int, numeric_col, decimal_col, float_col, double_col, bit_col, @@ -727,375 +726,367 @@ mod tests { ST_AsText(polygon_col), ST_AsText(multipoint_col), ST_AsText(multilinestring_col), ST_AsText(multipolygon_col), ST_AsText(geometry_collection), json_col FROM all_mysql_types", - )) - .unwrap(); - - let metadata = stmt.metadata().unwrap(); - let mut output_binds = - Binds::from_output_types(vec![None; metadata.fields().len()], &metadata); - stmt.execute_statement(&mut output_binds).unwrap(); - stmt.populate_row_buffers(&mut output_binds).unwrap(); - - let results: Vec<(BindData, &_)> = output_binds - .data - .into_iter() - .zip(metadata.fields()) - .collect::>(); - - macro_rules! matches { - ($expression:expr, $( $pattern:pat )|+ $( if $guard: expr )?) => { - match $expression { - $( $pattern )|+ $( if $guard )? => true, - _ => false - } - } - } - - let tiny_int_col = &results[0].0; - assert_eq!(tiny_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_TINY); - assert!(tiny_int_col.flags.contains(Flags::NUM_FLAG)); - assert!(!tiny_int_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!(to_value::(tiny_int_col), Ok(0))); - - let small_int_col = &results[1].0; - assert_eq!(small_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_SHORT); - assert!(small_int_col.flags.contains(Flags::NUM_FLAG)); - assert!(!small_int_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!(to_value::(small_int_col), Ok(1))); - - let medium_int_col = &results[2].0; - assert_eq!(medium_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_INT24); - assert!(medium_int_col.flags.contains(Flags::NUM_FLAG)); - assert!(!medium_int_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!(to_value::(medium_int_col), Ok(2))); - - let int_col = &results[3].0; - assert_eq!(int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG); - assert!(int_col.flags.contains(Flags::NUM_FLAG)); - assert!(!int_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!(to_value::(int_col), Ok(3))); - - let big_int_col = &results[4].0; - assert_eq!(big_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONGLONG); - assert!(big_int_col.flags.contains(Flags::NUM_FLAG)); - assert!(!big_int_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!(to_value::(big_int_col), Ok(-5))); - - let unsigned_int_col = &results[5].0; - assert_eq!(unsigned_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG); - assert!(unsigned_int_col.flags.contains(Flags::NUM_FLAG)); - assert!(unsigned_int_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!( - to_value::, u32>(unsigned_int_col), - Ok(42) - )); - - let zero_fill_int_col = &results[6].0; - assert_eq!( - zero_fill_int_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_LONG - ); - assert!(zero_fill_int_col.flags.contains(Flags::NUM_FLAG)); - assert!(zero_fill_int_col.flags.contains(Flags::ZEROFILL_FLAG)); - assert!(matches!(to_value::(zero_fill_int_col), Ok(1))); + ), |mut stmt, _| { + + let metadata = stmt.metadata().unwrap(); + let mut output_binds = + Binds::from_output_types(&vec![None; metadata.fields().len()], &metadata); + stmt.execute_statement(&mut output_binds).unwrap(); + stmt.populate_row_buffers(&mut output_binds).unwrap(); + + let results: Vec<(BindData, &_)> = output_binds + .data + .into_iter() + .zip(metadata.fields()) + .collect::>(); + + let tiny_int_col = &results[0].0; + assert_eq!(tiny_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_TINY); + assert!(tiny_int_col.flags.contains(Flags::NUM_FLAG)); + assert!(!tiny_int_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!(to_value::(tiny_int_col), Ok(0))); + + let small_int_col = &results[1].0; + assert_eq!(small_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_SHORT); + assert!(small_int_col.flags.contains(Flags::NUM_FLAG)); + assert!(!small_int_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!(to_value::(small_int_col), Ok(1))); + + let medium_int_col = &results[2].0; + assert_eq!(medium_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_INT24); + assert!(medium_int_col.flags.contains(Flags::NUM_FLAG)); + assert!(!medium_int_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!(to_value::(medium_int_col), Ok(2))); + + let int_col = &results[3].0; + assert_eq!(int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG); + assert!(int_col.flags.contains(Flags::NUM_FLAG)); + assert!(!int_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!(to_value::(int_col), Ok(3))); + + let big_int_col = &results[4].0; + assert_eq!(big_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONGLONG); + assert!(big_int_col.flags.contains(Flags::NUM_FLAG)); + assert!(!big_int_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!(to_value::(big_int_col), Ok(-5))); + + let unsigned_int_col = &results[5].0; + assert_eq!(unsigned_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG); + assert!(unsigned_int_col.flags.contains(Flags::NUM_FLAG)); + assert!(unsigned_int_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!( + to_value::, u32>(unsigned_int_col), + Ok(42) + )); + + let zero_fill_int_col = &results[6].0; + assert_eq!( + zero_fill_int_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_LONG + ); + assert!(zero_fill_int_col.flags.contains(Flags::NUM_FLAG)); + assert!(zero_fill_int_col.flags.contains(Flags::ZEROFILL_FLAG)); + assert!(matches!(to_value::(zero_fill_int_col), Ok(1))); - let numeric_col = &results[7].0; - assert_eq!( - numeric_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_NEWDECIMAL - ); - assert!(numeric_col.flags.contains(Flags::NUM_FLAG)); - assert!(!numeric_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert_eq!( - to_value::(numeric_col).unwrap(), - bigdecimal::BigDecimal::from_f32(-999.999).unwrap() - ); + let numeric_col = &results[7].0; + assert_eq!( + numeric_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_NEWDECIMAL + ); + assert!(numeric_col.flags.contains(Flags::NUM_FLAG)); + assert!(!numeric_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert_eq!( + to_value::(numeric_col).unwrap(), + bigdecimal::BigDecimal::from_f32(-999.999).unwrap() + ); - let decimal_col = &results[8].0; - assert_eq!( - decimal_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_NEWDECIMAL - ); - assert!(decimal_col.flags.contains(Flags::NUM_FLAG)); - assert!(!decimal_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert_eq!( - to_value::(decimal_col).unwrap(), - bigdecimal::BigDecimal::from_f32(3.14).unwrap() - ); + let decimal_col = &results[8].0; + assert_eq!( + decimal_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_NEWDECIMAL + ); + assert!(decimal_col.flags.contains(Flags::NUM_FLAG)); + assert!(!decimal_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert_eq!( + to_value::(decimal_col).unwrap(), + bigdecimal::BigDecimal::from_f32(3.14).unwrap() + ); - let float_col = &results[9].0; - assert_eq!(float_col.tpe, ffi::enum_field_types::MYSQL_TYPE_FLOAT); - assert!(float_col.flags.contains(Flags::NUM_FLAG)); - assert!(!float_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert_eq!(to_value::(float_col).unwrap(), 1.23); - - let double_col = &results[10].0; - assert_eq!(double_col.tpe, ffi::enum_field_types::MYSQL_TYPE_DOUBLE); - assert!(double_col.flags.contains(Flags::NUM_FLAG)); - assert!(!double_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert_eq!(to_value::(double_col).unwrap(), 4.5678); - - let bit_col = &results[11].0; - assert_eq!(bit_col.tpe, ffi::enum_field_types::MYSQL_TYPE_BIT); - assert!(!bit_col.flags.contains(Flags::NUM_FLAG)); - assert!(bit_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(!bit_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::>(bit_col).unwrap(), vec![170]); - - let date_col = &results[12].0; - assert_eq!(date_col.tpe, ffi::enum_field_types::MYSQL_TYPE_DATE); - assert!(!date_col.flags.contains(Flags::NUM_FLAG)); - assert_eq!( - to_value::(date_col).unwrap(), - chrono::NaiveDate::from_ymd_opt(1000, 1, 1).unwrap(), - ); + let float_col = &results[9].0; + assert_eq!(float_col.tpe, ffi::enum_field_types::MYSQL_TYPE_FLOAT); + assert!(float_col.flags.contains(Flags::NUM_FLAG)); + assert!(!float_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert_eq!(to_value::(float_col).unwrap(), 1.23); + + let double_col = &results[10].0; + assert_eq!(double_col.tpe, ffi::enum_field_types::MYSQL_TYPE_DOUBLE); + assert!(double_col.flags.contains(Flags::NUM_FLAG)); + assert!(!double_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert_eq!(to_value::(double_col).unwrap(), 4.5678); + + let bit_col = &results[11].0; + assert_eq!(bit_col.tpe, ffi::enum_field_types::MYSQL_TYPE_BIT); + assert!(!bit_col.flags.contains(Flags::NUM_FLAG)); + assert!(bit_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(!bit_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::>(bit_col).unwrap(), vec![170]); + + let date_col = &results[12].0; + assert_eq!(date_col.tpe, ffi::enum_field_types::MYSQL_TYPE_DATE); + assert!(!date_col.flags.contains(Flags::NUM_FLAG)); + assert_eq!( + to_value::(date_col).unwrap(), + chrono::NaiveDate::from_ymd_opt(1000, 1, 1).unwrap(), + ); - let date_time_col = &results[13].0; - assert_eq!( - date_time_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_DATETIME - ); - assert!(!date_time_col.flags.contains(Flags::NUM_FLAG)); - assert_eq!( - to_value::(date_time_col).unwrap(), - chrono::NaiveDateTime::parse_from_str("9999-12-31 12:34:45", "%Y-%m-%d %H:%M:%S") - .unwrap() - ); + let date_time_col = &results[13].0; + assert_eq!( + date_time_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_DATETIME + ); + assert!(!date_time_col.flags.contains(Flags::NUM_FLAG)); + assert_eq!( + to_value::(date_time_col).unwrap(), + chrono::NaiveDateTime::parse_from_str("9999-12-31 12:34:45", "%Y-%m-%d %H:%M:%S") + .unwrap() + ); - let timestamp_col = &results[14].0; - assert_eq!( - timestamp_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_TIMESTAMP - ); - assert!(!timestamp_col.flags.contains(Flags::NUM_FLAG)); - assert_eq!( - to_value::(timestamp_col).unwrap(), - chrono::NaiveDateTime::parse_from_str("2020-01-01 10:10:10", "%Y-%m-%d %H:%M:%S") - .unwrap() - ); + let timestamp_col = &results[14].0; + assert_eq!( + timestamp_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_TIMESTAMP + ); + assert!(!timestamp_col.flags.contains(Flags::NUM_FLAG)); + assert_eq!( + to_value::(timestamp_col).unwrap(), + chrono::NaiveDateTime::parse_from_str("2020-01-01 10:10:10", "%Y-%m-%d %H:%M:%S") + .unwrap() + ); - let time_col = &results[15].0; - assert_eq!(time_col.tpe, ffi::enum_field_types::MYSQL_TYPE_TIME); - assert!(!time_col.flags.contains(Flags::NUM_FLAG)); - assert_eq!( - to_value::(time_col).unwrap(), - chrono::NaiveTime::from_hms(23, 01, 01) - ); + let time_col = &results[15].0; + assert_eq!(time_col.tpe, ffi::enum_field_types::MYSQL_TYPE_TIME); + assert!(!time_col.flags.contains(Flags::NUM_FLAG)); + assert_eq!( + to_value::(time_col).unwrap(), + chrono::NaiveTime::from_hms(23, 01, 01) + ); - let year_col = &results[16].0; - assert_eq!(year_col.tpe, ffi::enum_field_types::MYSQL_TYPE_YEAR); - assert!(year_col.flags.contains(Flags::NUM_FLAG)); - assert!(year_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!(to_value::(year_col), Ok(2020))); - - let char_col = &results[17].0; - assert_eq!(char_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); - assert!(!char_col.flags.contains(Flags::NUM_FLAG)); - assert!(!char_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!char_col.flags.contains(Flags::SET_FLAG)); - assert!(!char_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!char_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::(char_col).unwrap(), "abc"); - - let varchar_col = &results[18].0; - assert_eq!( - varchar_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_VAR_STRING - ); - assert!(!varchar_col.flags.contains(Flags::NUM_FLAG)); - assert!(!varchar_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!varchar_col.flags.contains(Flags::SET_FLAG)); - assert!(!varchar_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!varchar_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::(varchar_col).unwrap(), "foo"); - - let binary_col = &results[19].0; - assert_eq!(binary_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); - assert!(!binary_col.flags.contains(Flags::NUM_FLAG)); - assert!(!binary_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!binary_col.flags.contains(Flags::SET_FLAG)); - assert!(!binary_col.flags.contains(Flags::ENUM_FLAG)); - assert!(binary_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::>(binary_col).unwrap(), - b"a \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0" - ); + let year_col = &results[16].0; + assert_eq!(year_col.tpe, ffi::enum_field_types::MYSQL_TYPE_YEAR); + assert!(year_col.flags.contains(Flags::NUM_FLAG)); + assert!(year_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!(to_value::(year_col), Ok(2020))); + + let char_col = &results[17].0; + assert_eq!(char_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); + assert!(!char_col.flags.contains(Flags::NUM_FLAG)); + assert!(!char_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!char_col.flags.contains(Flags::SET_FLAG)); + assert!(!char_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!char_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::(char_col).unwrap(), "abc"); + + let varchar_col = &results[18].0; + assert_eq!( + varchar_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_VAR_STRING + ); + assert!(!varchar_col.flags.contains(Flags::NUM_FLAG)); + assert!(!varchar_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!varchar_col.flags.contains(Flags::SET_FLAG)); + assert!(!varchar_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!varchar_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::(varchar_col).unwrap(), "foo"); + + let binary_col = &results[19].0; + assert_eq!(binary_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); + assert!(!binary_col.flags.contains(Flags::NUM_FLAG)); + assert!(!binary_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!binary_col.flags.contains(Flags::SET_FLAG)); + assert!(!binary_col.flags.contains(Flags::ENUM_FLAG)); + assert!(binary_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::>(binary_col).unwrap(), + b"a \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0" + ); - let varbinary_col = &results[20].0; - assert_eq!( - varbinary_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_VAR_STRING - ); - assert!(!varbinary_col.flags.contains(Flags::NUM_FLAG)); - assert!(!varbinary_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!varbinary_col.flags.contains(Flags::SET_FLAG)); - assert!(!varbinary_col.flags.contains(Flags::ENUM_FLAG)); - assert!(varbinary_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::>(varbinary_col).unwrap(), b"a "); - - let blob_col = &results[21].0; - assert_eq!(blob_col.tpe, ffi::enum_field_types::MYSQL_TYPE_BLOB); - assert!(!blob_col.flags.contains(Flags::NUM_FLAG)); - assert!(blob_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!blob_col.flags.contains(Flags::SET_FLAG)); - assert!(!blob_col.flags.contains(Flags::ENUM_FLAG)); - assert!(blob_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::>(blob_col).unwrap(), b"binary"); - - let text_col = &results[22].0; - assert_eq!(text_col.tpe, ffi::enum_field_types::MYSQL_TYPE_BLOB); - assert!(!text_col.flags.contains(Flags::NUM_FLAG)); - assert!(text_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!text_col.flags.contains(Flags::SET_FLAG)); - assert!(!text_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!text_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(text_col).unwrap(), - "some text whatever" - ); + let varbinary_col = &results[20].0; + assert_eq!( + varbinary_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_VAR_STRING + ); + assert!(!varbinary_col.flags.contains(Flags::NUM_FLAG)); + assert!(!varbinary_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!varbinary_col.flags.contains(Flags::SET_FLAG)); + assert!(!varbinary_col.flags.contains(Flags::ENUM_FLAG)); + assert!(varbinary_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::>(varbinary_col).unwrap(), b"a "); + + let blob_col = &results[21].0; + assert_eq!(blob_col.tpe, ffi::enum_field_types::MYSQL_TYPE_BLOB); + assert!(!blob_col.flags.contains(Flags::NUM_FLAG)); + assert!(blob_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!blob_col.flags.contains(Flags::SET_FLAG)); + assert!(!blob_col.flags.contains(Flags::ENUM_FLAG)); + assert!(blob_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::>(blob_col).unwrap(), b"binary"); + + let text_col = &results[22].0; + assert_eq!(text_col.tpe, ffi::enum_field_types::MYSQL_TYPE_BLOB); + assert!(!text_col.flags.contains(Flags::NUM_FLAG)); + assert!(text_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!text_col.flags.contains(Flags::SET_FLAG)); + assert!(!text_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!text_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(text_col).unwrap(), + "some text whatever" + ); - let enum_col = &results[23].0; - assert_eq!(enum_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); - assert!(!enum_col.flags.contains(Flags::NUM_FLAG)); - assert!(!enum_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!enum_col.flags.contains(Flags::SET_FLAG)); - assert!(enum_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!enum_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::(enum_col).unwrap(), "red"); - - let set_col = &results[24].0; - assert_eq!(set_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); - assert!(!set_col.flags.contains(Flags::NUM_FLAG)); - assert!(!set_col.flags.contains(Flags::BLOB_FLAG)); - assert!(set_col.flags.contains(Flags::SET_FLAG)); - assert!(!set_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!set_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::(set_col).unwrap(), "one"); - - let geom = &results[25].0; - assert_eq!(geom.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); - assert!(!geom.flags.contains(Flags::NUM_FLAG)); - assert!(!geom.flags.contains(Flags::BLOB_FLAG)); - assert!(!geom.flags.contains(Flags::SET_FLAG)); - assert!(!geom.flags.contains(Flags::ENUM_FLAG)); - assert!(!geom.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::(geom).unwrap(), "POINT(1 1)"); - - let point_col = &results[26].0; - assert_eq!(point_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); - assert!(!point_col.flags.contains(Flags::NUM_FLAG)); - assert!(!point_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!point_col.flags.contains(Flags::SET_FLAG)); - assert!(!point_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!point_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::(point_col).unwrap(), "POINT(1 1)"); - - let linestring_col = &results[27].0; - assert_eq!( - linestring_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB - ); - assert!(!linestring_col.flags.contains(Flags::NUM_FLAG)); - assert!(!linestring_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!linestring_col.flags.contains(Flags::SET_FLAG)); - assert!(!linestring_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!linestring_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(linestring_col).unwrap(), - "LINESTRING(0 0,1 1,2 2)" - ); + let enum_col = &results[23].0; + assert_eq!(enum_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); + assert!(!enum_col.flags.contains(Flags::NUM_FLAG)); + assert!(!enum_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!enum_col.flags.contains(Flags::SET_FLAG)); + assert!(enum_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!enum_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::(enum_col).unwrap(), "red"); + + let set_col = &results[24].0; + assert_eq!(set_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); + assert!(!set_col.flags.contains(Flags::NUM_FLAG)); + assert!(!set_col.flags.contains(Flags::BLOB_FLAG)); + assert!(set_col.flags.contains(Flags::SET_FLAG)); + assert!(!set_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!set_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::(set_col).unwrap(), "one"); + + let geom = &results[25].0; + assert_eq!(geom.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); + assert!(!geom.flags.contains(Flags::NUM_FLAG)); + assert!(!geom.flags.contains(Flags::BLOB_FLAG)); + assert!(!geom.flags.contains(Flags::SET_FLAG)); + assert!(!geom.flags.contains(Flags::ENUM_FLAG)); + assert!(!geom.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::(geom).unwrap(), "POINT(1 1)"); + + let point_col = &results[26].0; + assert_eq!(point_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); + assert!(!point_col.flags.contains(Flags::NUM_FLAG)); + assert!(!point_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!point_col.flags.contains(Flags::SET_FLAG)); + assert!(!point_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!point_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::(point_col).unwrap(), "POINT(1 1)"); + + let linestring_col = &results[27].0; + assert_eq!( + linestring_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB + ); + assert!(!linestring_col.flags.contains(Flags::NUM_FLAG)); + assert!(!linestring_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!linestring_col.flags.contains(Flags::SET_FLAG)); + assert!(!linestring_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!linestring_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(linestring_col).unwrap(), + "LINESTRING(0 0,1 1,2 2)" + ); - let polygon_col = &results[28].0; - assert_eq!(polygon_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); - assert!(!polygon_col.flags.contains(Flags::NUM_FLAG)); - assert!(!polygon_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!polygon_col.flags.contains(Flags::SET_FLAG)); - assert!(!polygon_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!polygon_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(polygon_col).unwrap(), - "POLYGON((0 0,10 0,10 10,0 10,0 0),(5 5,7 5,7 7,5 7,5 5))" - ); + let polygon_col = &results[28].0; + assert_eq!(polygon_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); + assert!(!polygon_col.flags.contains(Flags::NUM_FLAG)); + assert!(!polygon_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!polygon_col.flags.contains(Flags::SET_FLAG)); + assert!(!polygon_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!polygon_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(polygon_col).unwrap(), + "POLYGON((0 0,10 0,10 10,0 10,0 0),(5 5,7 5,7 7,5 7,5 5))" + ); - let multipoint_col = &results[29].0; - assert_eq!( - multipoint_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB - ); - assert!(!multipoint_col.flags.contains(Flags::NUM_FLAG)); - assert!(!multipoint_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!multipoint_col.flags.contains(Flags::SET_FLAG)); - assert!(!multipoint_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!multipoint_col.flags.contains(Flags::BINARY_FLAG)); - // older mysql and mariadb versions get back another encoding here - // we test for both as there seems to be no clear pattern when one or - // the other is returned - let multipoint_res = to_value::(multipoint_col).unwrap(); - assert!( - multipoint_res == "MULTIPOINT((0 0),(10 10),(10 20),(20 20))" - || multipoint_res == "MULTIPOINT(0 0,10 10,10 20,20 20)" - ); + let multipoint_col = &results[29].0; + assert_eq!( + multipoint_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB + ); + assert!(!multipoint_col.flags.contains(Flags::NUM_FLAG)); + assert!(!multipoint_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!multipoint_col.flags.contains(Flags::SET_FLAG)); + assert!(!multipoint_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!multipoint_col.flags.contains(Flags::BINARY_FLAG)); + // older mysql and mariadb versions get back another encoding here + // we test for both as there seems to be no clear pattern when one or + // the other is returned + let multipoint_res = to_value::(multipoint_col).unwrap(); + assert!( + multipoint_res == "MULTIPOINT((0 0),(10 10),(10 20),(20 20))" + || multipoint_res == "MULTIPOINT(0 0,10 10,10 20,20 20)" + ); - let multilinestring_col = &results[30].0; - assert_eq!( - multilinestring_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB - ); - assert!(!multilinestring_col.flags.contains(Flags::NUM_FLAG)); - assert!(!multilinestring_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!multilinestring_col.flags.contains(Flags::SET_FLAG)); - assert!(!multilinestring_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!multilinestring_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(multilinestring_col).unwrap(), - "MULTILINESTRING((10 48,10 21,10 0),(16 0,16 23,16 48))" - ); + let multilinestring_col = &results[30].0; + assert_eq!( + multilinestring_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB + ); + assert!(!multilinestring_col.flags.contains(Flags::NUM_FLAG)); + assert!(!multilinestring_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!multilinestring_col.flags.contains(Flags::SET_FLAG)); + assert!(!multilinestring_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!multilinestring_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(multilinestring_col).unwrap(), + "MULTILINESTRING((10 48,10 21,10 0),(16 0,16 23,16 48))" + ); - let polygon_col = &results[31].0; - assert_eq!(polygon_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); - assert!(!polygon_col.flags.contains(Flags::NUM_FLAG)); - assert!(!polygon_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!polygon_col.flags.contains(Flags::SET_FLAG)); - assert!(!polygon_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!polygon_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(polygon_col).unwrap(), - "MULTIPOLYGON(((28 26,28 0,84 0,84 42,28 26),(52 18,66 23,73 9,48 6,52 18)),((59 18,67 18,67 13,59 13,59 18)))" - ); + let polygon_col = &results[31].0; + assert_eq!(polygon_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); + assert!(!polygon_col.flags.contains(Flags::NUM_FLAG)); + assert!(!polygon_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!polygon_col.flags.contains(Flags::SET_FLAG)); + assert!(!polygon_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!polygon_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(polygon_col).unwrap(), + "MULTIPOLYGON(((28 26,28 0,84 0,84 42,28 26),(52 18,66 23,73 9,48 6,52 18)),((59 18,67 18,67 13,59 13,59 18)))" + ); - let geometry_collection = &results[32].0; - assert_eq!( - geometry_collection.tpe, - ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB - ); - assert!(!geometry_collection.flags.contains(Flags::NUM_FLAG)); - assert!(!geometry_collection.flags.contains(Flags::BLOB_FLAG)); - assert!(!geometry_collection.flags.contains(Flags::SET_FLAG)); - assert!(!geometry_collection.flags.contains(Flags::ENUM_FLAG)); - assert!(!geometry_collection.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(geometry_collection).unwrap(), - "GEOMETRYCOLLECTION(POINT(1 1),LINESTRING(0 0,1 1,2 2,3 3,4 4))" - ); + let geometry_collection = &results[32].0; + assert_eq!( + geometry_collection.tpe, + ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB + ); + assert!(!geometry_collection.flags.contains(Flags::NUM_FLAG)); + assert!(!geometry_collection.flags.contains(Flags::BLOB_FLAG)); + assert!(!geometry_collection.flags.contains(Flags::SET_FLAG)); + assert!(!geometry_collection.flags.contains(Flags::ENUM_FLAG)); + assert!(!geometry_collection.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(geometry_collection).unwrap(), + "GEOMETRYCOLLECTION(POINT(1 1),LINESTRING(0 0,1 1,2 2,3 3,4 4))" + ); - let json_col = &results[33].0; - // mariadb >= 10.2 and mysql >=8.0 are supporting a json type - // from those mariadb >= 10.3 and mysql >= 8.0 are reporting - // json here, so we assert that we get back json - // mariadb 10.5 returns again blob - assert!( - json_col.tpe == ffi::enum_field_types::MYSQL_TYPE_JSON - || json_col.tpe == ffi::enum_field_types::MYSQL_TYPE_BLOB - ); - assert!(!json_col.flags.contains(Flags::NUM_FLAG)); - assert!(json_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!json_col.flags.contains(Flags::SET_FLAG)); - assert!(!json_col.flags.contains(Flags::ENUM_FLAG)); - assert!(json_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(json_col).unwrap(), - "{\"key1\": \"value1\", \"key2\": \"value2\"}" - ); + let json_col = &results[33].0; + // mariadb >= 10.2 and mysql >=8.0 are supporting a json type + // from those mariadb >= 10.3 and mysql >= 8.0 are reporting + // json here, so we assert that we get back json + // mariadb 10.5 returns again blob + assert!( + json_col.tpe == ffi::enum_field_types::MYSQL_TYPE_JSON + || json_col.tpe == ffi::enum_field_types::MYSQL_TYPE_BLOB + ); + assert!(!json_col.flags.contains(Flags::NUM_FLAG)); + assert!(json_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!json_col.flags.contains(Flags::SET_FLAG)); + assert!(!json_col.flags.contains(Flags::ENUM_FLAG)); + assert!(json_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(json_col).unwrap(), + "{\"key1\": \"value1\", \"key2\": \"value2\"}" + ); + Ok(()) + }).unwrap(); } fn query_single_table( diff --git a/diesel_cli/src/infer_schema_internals/information_schema.rs b/diesel_cli/src/infer_schema_internals/information_schema.rs index 767104e9d187..841a269ba9db 100644 --- a/diesel_cli/src/infer_schema_internals/information_schema.rs +++ b/diesel_cli/src/infer_schema_internals/information_schema.rs @@ -203,11 +203,11 @@ where columns::column_name, >: QueryFragment, Conn::Backend: QueryMetadata<( - sql_types::Text, - sql_types::Text, - sql_types::Nullable, - sql_types::Text, - )> + 'static, + sql_types::Text, + sql_types::Text, + sql_types::Nullable, + sql_types::Text, + )> + 'static, { use self::information_schema::columns::dsl::*; diff --git a/diesel_tests/tests/deserialization.rs b/diesel_tests/tests/deserialization.rs index 968222b030bc..ecc630de3948 100644 --- a/diesel_tests/tests/deserialization.rs +++ b/diesel_tests/tests/deserialization.rs @@ -1,5 +1,4 @@ use crate::schema::*; -use diesel::deserialize::FromSqlRow; use diesel::prelude::*; use std::borrow::Cow; From 0b3eb5579a3b7879643bd04e451ed598f9bf4f6b Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 2 Jul 2021 11:08:05 +0200 Subject: [PATCH 17/32] Cleanup + use bindings from libsqlite3-sys --- diesel/Cargo.toml | 2 +- diesel/src/sqlite/connection/mod.rs | 2 -- diesel/src/sqlite/connection/sqlite_value.rs | 12 +++--------- diesel/src/sqlite/connection/stmt.rs | 17 ++--------------- 4 files changed, 6 insertions(+), 27 deletions(-) diff --git a/diesel/Cargo.toml b/diesel/Cargo.toml index 9b5a684b620e..df61c69e6a8a 100644 --- a/diesel/Cargo.toml +++ b/diesel/Cargo.toml @@ -16,7 +16,7 @@ edition = "2018" byteorder = "1.0" chrono = { version = "0.4.19", optional = true, default-features = false, features = ["clock", "std"] } libc = { version = "0.2.0", optional = true } -libsqlite3-sys = { version = ">=0.8.0, <0.23.0", optional = true, features = ["min_sqlite_version_3_7_16"] } +libsqlite3-sys = { version = ">=0.8.0, <0.23.0", optional = true, features = ["bundled_bindings"] } mysqlclient-sys = { version = "0.2.0", optional = true } pq-sys = { version = "0.4.0", optional = true } quickcheck = { version = "0.9.0", optional = true } diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index e41820bc1c31..edc310534bfd 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -91,7 +91,6 @@ impl Connection for SqliteConnection { Ok(self.raw_connection.rows_affected_by_last_query()) } - //#[tracing::instrument(skip(self, source))] #[doc(hidden)] fn load<'a, T>( &'a mut self, @@ -215,7 +214,6 @@ impl SqliteConnection { } } - //#[tracing::instrument(skip(self, source, f))] fn with_prepared_query<'a, T: QueryFragment + QueryId, R>( &'a mut self, source: &'_ T, diff --git a/diesel/src/sqlite/connection/sqlite_value.rs b/diesel/src/sqlite/connection/sqlite_value.rs index 076db8887768..67dc4d4eff75 100644 --- a/diesel/src/sqlite/connection/sqlite_value.rs +++ b/diesel/src/sqlite/connection/sqlite_value.rs @@ -8,11 +8,6 @@ use crate::sqlite::SqliteType; use super::row::PrivateSqliteRow; -extern "C" { - pub fn sqlite3_value_free(value: *mut ffi::sqlite3_value); - pub fn sqlite3_value_dup(value: *const ffi::sqlite3_value) -> *mut ffi::sqlite3_value; -} - /// Raw sqlite value as received from the database /// /// Use existing `FromSql` implementations to convert this into @@ -40,7 +35,7 @@ pub struct OwnedSqliteValue { impl Drop for OwnedSqliteValue { fn drop(&mut self) { - unsafe { sqlite3_value_free(self.value.as_ptr()) } + unsafe { ffi::sqlite3_value_free(self.value.as_ptr()) } } } @@ -107,7 +102,6 @@ impl<'a, 'b> SqliteValue<'a, 'b> { unsafe { ffi::sqlite3_value_double(self.value.as_ptr()) } } - //#[tracing::instrument(skip(self))] /// Get the type of the value as returned by sqlite pub fn value_type(&self) -> Option { let tpe = unsafe { ffi::sqlite3_value_type(self.value.as_ptr()) }; @@ -132,7 +126,7 @@ impl OwnedSqliteValue { if ffi::SQLITE_NULL == tpe { return None; } - let value = unsafe { sqlite3_value_dup(ptr.as_ptr()) }; + let value = unsafe { ffi::sqlite3_value_dup(ptr.as_ptr()) }; Some(Self { value: NonNull::new(value)?, }) @@ -140,7 +134,7 @@ impl OwnedSqliteValue { pub(super) fn duplicate(&self) -> OwnedSqliteValue { // self.value is a `NonNull` ptr so this cannot be null - let value = unsafe { sqlite3_value_dup(self.value.as_ptr()) }; + let value = unsafe { ffi::sqlite3_value_dup(self.value.as_ptr()) }; let value = NonNull::new(value).expect( "Sqlite documentation states this returns only null if value is null \ or OOM. If you ever see this panic message please open an issue at \ diff --git a/diesel/src/sqlite/connection/stmt.rs b/diesel/src/sqlite/connection/stmt.rs index 77dc3fa88b0e..d41cf0fba221 100644 --- a/diesel/src/sqlite/connection/stmt.rs +++ b/diesel/src/sqlite/connection/stmt.rs @@ -18,19 +18,6 @@ pub struct Statement { bind_index: libc::c_int, } -const SQLITE_PREPARE_PERSISTENT: libc::c_uint = 0x01; - -extern "C" { - fn sqlite3_prepare_v3( - db: *mut ffi::sqlite3, - zSql: *const libc::c_char, - nByte: libc::c_int, - flags: libc::c_uint, - ppStmt: *mut *mut ffi::sqlite3_stmt, - pzTail: *mut *const libc::c_char, - ) -> libc::c_int; -} - impl Statement { pub fn prepare( raw_connection: &RawConnection, @@ -40,12 +27,12 @@ impl Statement { let mut stmt = ptr::null_mut(); let mut unused_portion = ptr::null(); let prepare_result = unsafe { - sqlite3_prepare_v3( + ffi::sqlite3_prepare_v3( raw_connection.internal_connection.as_ptr(), CString::new(sql)?.as_ptr(), sql.len() as libc::c_int, if matches!(is_cached, PrepareForCache::Yes) { - SQLITE_PREPARE_PERSISTENT + ffi::SQLITE_PREPARE_PERSISTENT as u32 } else { 0 }, From 77cd7aff4ac2ada4cf8173453b0c5828a07bad58 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Mon, 5 Jul 2021 11:03:57 +0200 Subject: [PATCH 18/32] Simplify the handling of not cached prepared statements for sqlite and mysql connections --- diesel/src/mysql/connection/bind.rs | 686 +++++++++---------- diesel/src/mysql/connection/mod.rs | 43 +- diesel/src/mysql/connection/stmt/iterator.rs | 9 +- diesel/src/mysql/connection/stmt/mod.rs | 11 - diesel/src/sqlite/connection/mod.rs | 36 +- diesel/src/sqlite/connection/stmt.rs | 6 +- 6 files changed, 378 insertions(+), 413 deletions(-) diff --git a/diesel/src/mysql/connection/bind.rs b/diesel/src/mysql/connection/bind.rs index fb37de35c41e..c6ae063b2b8d 100644 --- a/diesel/src/mysql/connection/bind.rs +++ b/diesel/src/mysql/connection/bind.rs @@ -715,7 +715,7 @@ mod tests { ) .unwrap(); - conn.with_prepared_query(&crate::sql_query( + let mut stmt = conn.prepared_query(&crate::sql_query( "SELECT tiny_int, small_int, medium_int, int_col, big_int, unsigned_int, zero_fill_int, @@ -726,367 +726,365 @@ mod tests { ST_AsText(polygon_col), ST_AsText(multipoint_col), ST_AsText(multilinestring_col), ST_AsText(multipolygon_col), ST_AsText(geometry_collection), json_col FROM all_mysql_types", - ), |mut stmt, _| { - - let metadata = stmt.metadata().unwrap(); - let mut output_binds = - Binds::from_output_types(&vec![None; metadata.fields().len()], &metadata); - stmt.execute_statement(&mut output_binds).unwrap(); - stmt.populate_row_buffers(&mut output_binds).unwrap(); - - let results: Vec<(BindData, &_)> = output_binds - .data - .into_iter() - .zip(metadata.fields()) - .collect::>(); - - let tiny_int_col = &results[0].0; - assert_eq!(tiny_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_TINY); - assert!(tiny_int_col.flags.contains(Flags::NUM_FLAG)); - assert!(!tiny_int_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!(to_value::(tiny_int_col), Ok(0))); - - let small_int_col = &results[1].0; - assert_eq!(small_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_SHORT); - assert!(small_int_col.flags.contains(Flags::NUM_FLAG)); - assert!(!small_int_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!(to_value::(small_int_col), Ok(1))); - - let medium_int_col = &results[2].0; - assert_eq!(medium_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_INT24); - assert!(medium_int_col.flags.contains(Flags::NUM_FLAG)); - assert!(!medium_int_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!(to_value::(medium_int_col), Ok(2))); - - let int_col = &results[3].0; - assert_eq!(int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG); - assert!(int_col.flags.contains(Flags::NUM_FLAG)); - assert!(!int_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!(to_value::(int_col), Ok(3))); - - let big_int_col = &results[4].0; - assert_eq!(big_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONGLONG); - assert!(big_int_col.flags.contains(Flags::NUM_FLAG)); - assert!(!big_int_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!(to_value::(big_int_col), Ok(-5))); - - let unsigned_int_col = &results[5].0; - assert_eq!(unsigned_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG); - assert!(unsigned_int_col.flags.contains(Flags::NUM_FLAG)); - assert!(unsigned_int_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!( - to_value::, u32>(unsigned_int_col), - Ok(42) - )); - - let zero_fill_int_col = &results[6].0; - assert_eq!( - zero_fill_int_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_LONG - ); - assert!(zero_fill_int_col.flags.contains(Flags::NUM_FLAG)); - assert!(zero_fill_int_col.flags.contains(Flags::ZEROFILL_FLAG)); - assert!(matches!(to_value::(zero_fill_int_col), Ok(1))); + )).unwrap(); - let numeric_col = &results[7].0; - assert_eq!( - numeric_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_NEWDECIMAL - ); - assert!(numeric_col.flags.contains(Flags::NUM_FLAG)); - assert!(!numeric_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert_eq!( - to_value::(numeric_col).unwrap(), - bigdecimal::BigDecimal::from_f32(-999.999).unwrap() - ); + let metadata = stmt.metadata().unwrap(); + let mut output_binds = + Binds::from_output_types(&vec![None; metadata.fields().len()], &metadata); + stmt.execute_statement(&mut output_binds).unwrap(); + stmt.populate_row_buffers(&mut output_binds).unwrap(); - let decimal_col = &results[8].0; - assert_eq!( - decimal_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_NEWDECIMAL - ); - assert!(decimal_col.flags.contains(Flags::NUM_FLAG)); - assert!(!decimal_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert_eq!( - to_value::(decimal_col).unwrap(), - bigdecimal::BigDecimal::from_f32(3.14).unwrap() - ); + let results: Vec<(BindData, &_)> = output_binds + .data + .into_iter() + .zip(metadata.fields()) + .collect::>(); - let float_col = &results[9].0; - assert_eq!(float_col.tpe, ffi::enum_field_types::MYSQL_TYPE_FLOAT); - assert!(float_col.flags.contains(Flags::NUM_FLAG)); - assert!(!float_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert_eq!(to_value::(float_col).unwrap(), 1.23); - - let double_col = &results[10].0; - assert_eq!(double_col.tpe, ffi::enum_field_types::MYSQL_TYPE_DOUBLE); - assert!(double_col.flags.contains(Flags::NUM_FLAG)); - assert!(!double_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert_eq!(to_value::(double_col).unwrap(), 4.5678); - - let bit_col = &results[11].0; - assert_eq!(bit_col.tpe, ffi::enum_field_types::MYSQL_TYPE_BIT); - assert!(!bit_col.flags.contains(Flags::NUM_FLAG)); - assert!(bit_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(!bit_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::>(bit_col).unwrap(), vec![170]); - - let date_col = &results[12].0; - assert_eq!(date_col.tpe, ffi::enum_field_types::MYSQL_TYPE_DATE); - assert!(!date_col.flags.contains(Flags::NUM_FLAG)); - assert_eq!( - to_value::(date_col).unwrap(), - chrono::NaiveDate::from_ymd_opt(1000, 1, 1).unwrap(), - ); + let tiny_int_col = &results[0].0; + assert_eq!(tiny_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_TINY); + assert!(tiny_int_col.flags.contains(Flags::NUM_FLAG)); + assert!(!tiny_int_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!(to_value::(tiny_int_col), Ok(0))); + + let small_int_col = &results[1].0; + assert_eq!(small_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_SHORT); + assert!(small_int_col.flags.contains(Flags::NUM_FLAG)); + assert!(!small_int_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!(to_value::(small_int_col), Ok(1))); + + let medium_int_col = &results[2].0; + assert_eq!(medium_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_INT24); + assert!(medium_int_col.flags.contains(Flags::NUM_FLAG)); + assert!(!medium_int_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!(to_value::(medium_int_col), Ok(2))); + + let int_col = &results[3].0; + assert_eq!(int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG); + assert!(int_col.flags.contains(Flags::NUM_FLAG)); + assert!(!int_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!(to_value::(int_col), Ok(3))); + + let big_int_col = &results[4].0; + assert_eq!(big_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONGLONG); + assert!(big_int_col.flags.contains(Flags::NUM_FLAG)); + assert!(!big_int_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!(to_value::(big_int_col), Ok(-5))); + + let unsigned_int_col = &results[5].0; + assert_eq!(unsigned_int_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG); + assert!(unsigned_int_col.flags.contains(Flags::NUM_FLAG)); + assert!(unsigned_int_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!( + to_value::, u32>(unsigned_int_col), + Ok(42) + )); + + let zero_fill_int_col = &results[6].0; + assert_eq!( + zero_fill_int_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_LONG + ); + assert!(zero_fill_int_col.flags.contains(Flags::NUM_FLAG)); + assert!(zero_fill_int_col.flags.contains(Flags::ZEROFILL_FLAG)); + assert!(matches!(to_value::(zero_fill_int_col), Ok(1))); - let date_time_col = &results[13].0; - assert_eq!( - date_time_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_DATETIME - ); - assert!(!date_time_col.flags.contains(Flags::NUM_FLAG)); - assert_eq!( - to_value::(date_time_col).unwrap(), - chrono::NaiveDateTime::parse_from_str("9999-12-31 12:34:45", "%Y-%m-%d %H:%M:%S") - .unwrap() - ); + let numeric_col = &results[7].0; + assert_eq!( + numeric_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_NEWDECIMAL + ); + assert!(numeric_col.flags.contains(Flags::NUM_FLAG)); + assert!(!numeric_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert_eq!( + to_value::(numeric_col).unwrap(), + bigdecimal::BigDecimal::from_f32(-999.999).unwrap() + ); - let timestamp_col = &results[14].0; - assert_eq!( - timestamp_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_TIMESTAMP - ); - assert!(!timestamp_col.flags.contains(Flags::NUM_FLAG)); - assert_eq!( - to_value::(timestamp_col).unwrap(), - chrono::NaiveDateTime::parse_from_str("2020-01-01 10:10:10", "%Y-%m-%d %H:%M:%S") - .unwrap() - ); + let decimal_col = &results[8].0; + assert_eq!( + decimal_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_NEWDECIMAL + ); + assert!(decimal_col.flags.contains(Flags::NUM_FLAG)); + assert!(!decimal_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert_eq!( + to_value::(decimal_col).unwrap(), + bigdecimal::BigDecimal::from_f32(3.14).unwrap() + ); - let time_col = &results[15].0; - assert_eq!(time_col.tpe, ffi::enum_field_types::MYSQL_TYPE_TIME); - assert!(!time_col.flags.contains(Flags::NUM_FLAG)); - assert_eq!( - to_value::(time_col).unwrap(), - chrono::NaiveTime::from_hms(23, 01, 01) - ); + let float_col = &results[9].0; + assert_eq!(float_col.tpe, ffi::enum_field_types::MYSQL_TYPE_FLOAT); + assert!(float_col.flags.contains(Flags::NUM_FLAG)); + assert!(!float_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert_eq!(to_value::(float_col).unwrap(), 1.23); + + let double_col = &results[10].0; + assert_eq!(double_col.tpe, ffi::enum_field_types::MYSQL_TYPE_DOUBLE); + assert!(double_col.flags.contains(Flags::NUM_FLAG)); + assert!(!double_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert_eq!(to_value::(double_col).unwrap(), 4.5678); + + let bit_col = &results[11].0; + assert_eq!(bit_col.tpe, ffi::enum_field_types::MYSQL_TYPE_BIT); + assert!(!bit_col.flags.contains(Flags::NUM_FLAG)); + assert!(bit_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(!bit_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::>(bit_col).unwrap(), vec![170]); + + let date_col = &results[12].0; + assert_eq!(date_col.tpe, ffi::enum_field_types::MYSQL_TYPE_DATE); + assert!(!date_col.flags.contains(Flags::NUM_FLAG)); + assert_eq!( + to_value::(date_col).unwrap(), + chrono::NaiveDate::from_ymd_opt(1000, 1, 1).unwrap(), + ); - let year_col = &results[16].0; - assert_eq!(year_col.tpe, ffi::enum_field_types::MYSQL_TYPE_YEAR); - assert!(year_col.flags.contains(Flags::NUM_FLAG)); - assert!(year_col.flags.contains(Flags::UNSIGNED_FLAG)); - assert!(matches!(to_value::(year_col), Ok(2020))); - - let char_col = &results[17].0; - assert_eq!(char_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); - assert!(!char_col.flags.contains(Flags::NUM_FLAG)); - assert!(!char_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!char_col.flags.contains(Flags::SET_FLAG)); - assert!(!char_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!char_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::(char_col).unwrap(), "abc"); - - let varchar_col = &results[18].0; - assert_eq!( - varchar_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_VAR_STRING - ); - assert!(!varchar_col.flags.contains(Flags::NUM_FLAG)); - assert!(!varchar_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!varchar_col.flags.contains(Flags::SET_FLAG)); - assert!(!varchar_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!varchar_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::(varchar_col).unwrap(), "foo"); - - let binary_col = &results[19].0; - assert_eq!(binary_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); - assert!(!binary_col.flags.contains(Flags::NUM_FLAG)); - assert!(!binary_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!binary_col.flags.contains(Flags::SET_FLAG)); - assert!(!binary_col.flags.contains(Flags::ENUM_FLAG)); - assert!(binary_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::>(binary_col).unwrap(), - b"a \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0" - ); + let date_time_col = &results[13].0; + assert_eq!( + date_time_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_DATETIME + ); + assert!(!date_time_col.flags.contains(Flags::NUM_FLAG)); + assert_eq!( + to_value::(date_time_col).unwrap(), + chrono::NaiveDateTime::parse_from_str("9999-12-31 12:34:45", "%Y-%m-%d %H:%M:%S") + .unwrap() + ); - let varbinary_col = &results[20].0; - assert_eq!( - varbinary_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_VAR_STRING - ); - assert!(!varbinary_col.flags.contains(Flags::NUM_FLAG)); - assert!(!varbinary_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!varbinary_col.flags.contains(Flags::SET_FLAG)); - assert!(!varbinary_col.flags.contains(Flags::ENUM_FLAG)); - assert!(varbinary_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::>(varbinary_col).unwrap(), b"a "); - - let blob_col = &results[21].0; - assert_eq!(blob_col.tpe, ffi::enum_field_types::MYSQL_TYPE_BLOB); - assert!(!blob_col.flags.contains(Flags::NUM_FLAG)); - assert!(blob_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!blob_col.flags.contains(Flags::SET_FLAG)); - assert!(!blob_col.flags.contains(Flags::ENUM_FLAG)); - assert!(blob_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::>(blob_col).unwrap(), b"binary"); - - let text_col = &results[22].0; - assert_eq!(text_col.tpe, ffi::enum_field_types::MYSQL_TYPE_BLOB); - assert!(!text_col.flags.contains(Flags::NUM_FLAG)); - assert!(text_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!text_col.flags.contains(Flags::SET_FLAG)); - assert!(!text_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!text_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(text_col).unwrap(), - "some text whatever" - ); + let timestamp_col = &results[14].0; + assert_eq!( + timestamp_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_TIMESTAMP + ); + assert!(!timestamp_col.flags.contains(Flags::NUM_FLAG)); + assert_eq!( + to_value::(timestamp_col).unwrap(), + chrono::NaiveDateTime::parse_from_str("2020-01-01 10:10:10", "%Y-%m-%d %H:%M:%S") + .unwrap() + ); - let enum_col = &results[23].0; - assert_eq!(enum_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); - assert!(!enum_col.flags.contains(Flags::NUM_FLAG)); - assert!(!enum_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!enum_col.flags.contains(Flags::SET_FLAG)); - assert!(enum_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!enum_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::(enum_col).unwrap(), "red"); - - let set_col = &results[24].0; - assert_eq!(set_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); - assert!(!set_col.flags.contains(Flags::NUM_FLAG)); - assert!(!set_col.flags.contains(Flags::BLOB_FLAG)); - assert!(set_col.flags.contains(Flags::SET_FLAG)); - assert!(!set_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!set_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::(set_col).unwrap(), "one"); - - let geom = &results[25].0; - assert_eq!(geom.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); - assert!(!geom.flags.contains(Flags::NUM_FLAG)); - assert!(!geom.flags.contains(Flags::BLOB_FLAG)); - assert!(!geom.flags.contains(Flags::SET_FLAG)); - assert!(!geom.flags.contains(Flags::ENUM_FLAG)); - assert!(!geom.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::(geom).unwrap(), "POINT(1 1)"); - - let point_col = &results[26].0; - assert_eq!(point_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); - assert!(!point_col.flags.contains(Flags::NUM_FLAG)); - assert!(!point_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!point_col.flags.contains(Flags::SET_FLAG)); - assert!(!point_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!point_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!(to_value::(point_col).unwrap(), "POINT(1 1)"); - - let linestring_col = &results[27].0; - assert_eq!( - linestring_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB - ); - assert!(!linestring_col.flags.contains(Flags::NUM_FLAG)); - assert!(!linestring_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!linestring_col.flags.contains(Flags::SET_FLAG)); - assert!(!linestring_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!linestring_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(linestring_col).unwrap(), - "LINESTRING(0 0,1 1,2 2)" - ); + let time_col = &results[15].0; + assert_eq!(time_col.tpe, ffi::enum_field_types::MYSQL_TYPE_TIME); + assert!(!time_col.flags.contains(Flags::NUM_FLAG)); + assert_eq!( + to_value::(time_col).unwrap(), + chrono::NaiveTime::from_hms(23, 01, 01) + ); - let polygon_col = &results[28].0; - assert_eq!(polygon_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); - assert!(!polygon_col.flags.contains(Flags::NUM_FLAG)); - assert!(!polygon_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!polygon_col.flags.contains(Flags::SET_FLAG)); - assert!(!polygon_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!polygon_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(polygon_col).unwrap(), - "POLYGON((0 0,10 0,10 10,0 10,0 0),(5 5,7 5,7 7,5 7,5 5))" - ); + let year_col = &results[16].0; + assert_eq!(year_col.tpe, ffi::enum_field_types::MYSQL_TYPE_YEAR); + assert!(year_col.flags.contains(Flags::NUM_FLAG)); + assert!(year_col.flags.contains(Flags::UNSIGNED_FLAG)); + assert!(matches!(to_value::(year_col), Ok(2020))); + + let char_col = &results[17].0; + assert_eq!(char_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); + assert!(!char_col.flags.contains(Flags::NUM_FLAG)); + assert!(!char_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!char_col.flags.contains(Flags::SET_FLAG)); + assert!(!char_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!char_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::(char_col).unwrap(), "abc"); + + let varchar_col = &results[18].0; + assert_eq!( + varchar_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_VAR_STRING + ); + assert!(!varchar_col.flags.contains(Flags::NUM_FLAG)); + assert!(!varchar_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!varchar_col.flags.contains(Flags::SET_FLAG)); + assert!(!varchar_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!varchar_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::(varchar_col).unwrap(), "foo"); + + let binary_col = &results[19].0; + assert_eq!(binary_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); + assert!(!binary_col.flags.contains(Flags::NUM_FLAG)); + assert!(!binary_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!binary_col.flags.contains(Flags::SET_FLAG)); + assert!(!binary_col.flags.contains(Flags::ENUM_FLAG)); + assert!(binary_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::>(binary_col).unwrap(), + b"a \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0" + ); - let multipoint_col = &results[29].0; - assert_eq!( - multipoint_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB - ); - assert!(!multipoint_col.flags.contains(Flags::NUM_FLAG)); - assert!(!multipoint_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!multipoint_col.flags.contains(Flags::SET_FLAG)); - assert!(!multipoint_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!multipoint_col.flags.contains(Flags::BINARY_FLAG)); - // older mysql and mariadb versions get back another encoding here - // we test for both as there seems to be no clear pattern when one or - // the other is returned - let multipoint_res = to_value::(multipoint_col).unwrap(); - assert!( - multipoint_res == "MULTIPOINT((0 0),(10 10),(10 20),(20 20))" - || multipoint_res == "MULTIPOINT(0 0,10 10,10 20,20 20)" - ); + let varbinary_col = &results[20].0; + assert_eq!( + varbinary_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_VAR_STRING + ); + assert!(!varbinary_col.flags.contains(Flags::NUM_FLAG)); + assert!(!varbinary_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!varbinary_col.flags.contains(Flags::SET_FLAG)); + assert!(!varbinary_col.flags.contains(Flags::ENUM_FLAG)); + assert!(varbinary_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::>(varbinary_col).unwrap(), b"a "); + + let blob_col = &results[21].0; + assert_eq!(blob_col.tpe, ffi::enum_field_types::MYSQL_TYPE_BLOB); + assert!(!blob_col.flags.contains(Flags::NUM_FLAG)); + assert!(blob_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!blob_col.flags.contains(Flags::SET_FLAG)); + assert!(!blob_col.flags.contains(Flags::ENUM_FLAG)); + assert!(blob_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::>(blob_col).unwrap(), b"binary"); + + let text_col = &results[22].0; + assert_eq!(text_col.tpe, ffi::enum_field_types::MYSQL_TYPE_BLOB); + assert!(!text_col.flags.contains(Flags::NUM_FLAG)); + assert!(text_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!text_col.flags.contains(Flags::SET_FLAG)); + assert!(!text_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!text_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(text_col).unwrap(), + "some text whatever" + ); - let multilinestring_col = &results[30].0; - assert_eq!( - multilinestring_col.tpe, - ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB - ); - assert!(!multilinestring_col.flags.contains(Flags::NUM_FLAG)); - assert!(!multilinestring_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!multilinestring_col.flags.contains(Flags::SET_FLAG)); - assert!(!multilinestring_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!multilinestring_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(multilinestring_col).unwrap(), - "MULTILINESTRING((10 48,10 21,10 0),(16 0,16 23,16 48))" - ); + let enum_col = &results[23].0; + assert_eq!(enum_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); + assert!(!enum_col.flags.contains(Flags::NUM_FLAG)); + assert!(!enum_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!enum_col.flags.contains(Flags::SET_FLAG)); + assert!(enum_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!enum_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::(enum_col).unwrap(), "red"); + + let set_col = &results[24].0; + assert_eq!(set_col.tpe, ffi::enum_field_types::MYSQL_TYPE_STRING); + assert!(!set_col.flags.contains(Flags::NUM_FLAG)); + assert!(!set_col.flags.contains(Flags::BLOB_FLAG)); + assert!(set_col.flags.contains(Flags::SET_FLAG)); + assert!(!set_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!set_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::(set_col).unwrap(), "one"); + + let geom = &results[25].0; + assert_eq!(geom.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); + assert!(!geom.flags.contains(Flags::NUM_FLAG)); + assert!(!geom.flags.contains(Flags::BLOB_FLAG)); + assert!(!geom.flags.contains(Flags::SET_FLAG)); + assert!(!geom.flags.contains(Flags::ENUM_FLAG)); + assert!(!geom.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::(geom).unwrap(), "POINT(1 1)"); + + let point_col = &results[26].0; + assert_eq!(point_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); + assert!(!point_col.flags.contains(Flags::NUM_FLAG)); + assert!(!point_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!point_col.flags.contains(Flags::SET_FLAG)); + assert!(!point_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!point_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!(to_value::(point_col).unwrap(), "POINT(1 1)"); + + let linestring_col = &results[27].0; + assert_eq!( + linestring_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB + ); + assert!(!linestring_col.flags.contains(Flags::NUM_FLAG)); + assert!(!linestring_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!linestring_col.flags.contains(Flags::SET_FLAG)); + assert!(!linestring_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!linestring_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(linestring_col).unwrap(), + "LINESTRING(0 0,1 1,2 2)" + ); - let polygon_col = &results[31].0; - assert_eq!(polygon_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); - assert!(!polygon_col.flags.contains(Flags::NUM_FLAG)); - assert!(!polygon_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!polygon_col.flags.contains(Flags::SET_FLAG)); - assert!(!polygon_col.flags.contains(Flags::ENUM_FLAG)); - assert!(!polygon_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( + let polygon_col = &results[28].0; + assert_eq!(polygon_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); + assert!(!polygon_col.flags.contains(Flags::NUM_FLAG)); + assert!(!polygon_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!polygon_col.flags.contains(Flags::SET_FLAG)); + assert!(!polygon_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!polygon_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(polygon_col).unwrap(), + "POLYGON((0 0,10 0,10 10,0 10,0 0),(5 5,7 5,7 7,5 7,5 5))" + ); + + let multipoint_col = &results[29].0; + assert_eq!( + multipoint_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB + ); + assert!(!multipoint_col.flags.contains(Flags::NUM_FLAG)); + assert!(!multipoint_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!multipoint_col.flags.contains(Flags::SET_FLAG)); + assert!(!multipoint_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!multipoint_col.flags.contains(Flags::BINARY_FLAG)); + // older mysql and mariadb versions get back another encoding here + // we test for both as there seems to be no clear pattern when one or + // the other is returned + let multipoint_res = to_value::(multipoint_col).unwrap(); + assert!( + multipoint_res == "MULTIPOINT((0 0),(10 10),(10 20),(20 20))" + || multipoint_res == "MULTIPOINT(0 0,10 10,10 20,20 20)" + ); + + let multilinestring_col = &results[30].0; + assert_eq!( + multilinestring_col.tpe, + ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB + ); + assert!(!multilinestring_col.flags.contains(Flags::NUM_FLAG)); + assert!(!multilinestring_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!multilinestring_col.flags.contains(Flags::SET_FLAG)); + assert!(!multilinestring_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!multilinestring_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(multilinestring_col).unwrap(), + "MULTILINESTRING((10 48,10 21,10 0),(16 0,16 23,16 48))" + ); + + let polygon_col = &results[31].0; + assert_eq!(polygon_col.tpe, ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB); + assert!(!polygon_col.flags.contains(Flags::NUM_FLAG)); + assert!(!polygon_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!polygon_col.flags.contains(Flags::SET_FLAG)); + assert!(!polygon_col.flags.contains(Flags::ENUM_FLAG)); + assert!(!polygon_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( to_value::(polygon_col).unwrap(), "MULTIPOLYGON(((28 26,28 0,84 0,84 42,28 26),(52 18,66 23,73 9,48 6,52 18)),((59 18,67 18,67 13,59 13,59 18)))" ); - let geometry_collection = &results[32].0; - assert_eq!( - geometry_collection.tpe, - ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB - ); - assert!(!geometry_collection.flags.contains(Flags::NUM_FLAG)); - assert!(!geometry_collection.flags.contains(Flags::BLOB_FLAG)); - assert!(!geometry_collection.flags.contains(Flags::SET_FLAG)); - assert!(!geometry_collection.flags.contains(Flags::ENUM_FLAG)); - assert!(!geometry_collection.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(geometry_collection).unwrap(), - "GEOMETRYCOLLECTION(POINT(1 1),LINESTRING(0 0,1 1,2 2,3 3,4 4))" - ); + let geometry_collection = &results[32].0; + assert_eq!( + geometry_collection.tpe, + ffi::enum_field_types::MYSQL_TYPE_LONG_BLOB + ); + assert!(!geometry_collection.flags.contains(Flags::NUM_FLAG)); + assert!(!geometry_collection.flags.contains(Flags::BLOB_FLAG)); + assert!(!geometry_collection.flags.contains(Flags::SET_FLAG)); + assert!(!geometry_collection.flags.contains(Flags::ENUM_FLAG)); + assert!(!geometry_collection.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(geometry_collection).unwrap(), + "GEOMETRYCOLLECTION(POINT(1 1),LINESTRING(0 0,1 1,2 2,3 3,4 4))" + ); - let json_col = &results[33].0; - // mariadb >= 10.2 and mysql >=8.0 are supporting a json type - // from those mariadb >= 10.3 and mysql >= 8.0 are reporting - // json here, so we assert that we get back json - // mariadb 10.5 returns again blob - assert!( - json_col.tpe == ffi::enum_field_types::MYSQL_TYPE_JSON - || json_col.tpe == ffi::enum_field_types::MYSQL_TYPE_BLOB - ); - assert!(!json_col.flags.contains(Flags::NUM_FLAG)); - assert!(json_col.flags.contains(Flags::BLOB_FLAG)); - assert!(!json_col.flags.contains(Flags::SET_FLAG)); - assert!(!json_col.flags.contains(Flags::ENUM_FLAG)); - assert!(json_col.flags.contains(Flags::BINARY_FLAG)); - assert_eq!( - to_value::(json_col).unwrap(), - "{\"key1\": \"value1\", \"key2\": \"value2\"}" - ); - Ok(()) - }).unwrap(); + let json_col = &results[33].0; + // mariadb >= 10.2 and mysql >=8.0 are supporting a json type + // from those mariadb >= 10.3 and mysql >= 8.0 are reporting + // json here, so we assert that we get back json + // mariadb 10.5 returns again blob + assert!( + json_col.tpe == ffi::enum_field_types::MYSQL_TYPE_JSON + || json_col.tpe == ffi::enum_field_types::MYSQL_TYPE_BLOB + ); + assert!(!json_col.flags.contains(Flags::NUM_FLAG)); + assert!(json_col.flags.contains(Flags::BLOB_FLAG)); + assert!(!json_col.flags.contains(Flags::SET_FLAG)); + assert!(!json_col.flags.contains(Flags::ENUM_FLAG)); + assert!(json_col.flags.contains(Flags::BINARY_FLAG)); + assert_eq!( + to_value::(json_col).unwrap(), + "{\"key1\": \"value1\", \"key2\": \"value2\"}" + ); } fn query_single_table( diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index 48b379f7f2de..1fbb07cbb1d4 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -4,6 +4,7 @@ mod stmt; mod url; use self::raw::RawConnection; +use self::stmt::iterator::StatementIterator; use self::stmt::Statement; use self::url::ConnectionOptions; use super::backend::Mysql; @@ -20,7 +21,6 @@ pub struct MysqlConnection { raw_connection: RawConnection, transaction_state: AnsiTransactionManager, statement_cache: StatementCache, - current_statement: Option, } unsafe impl Send for MysqlConnection {} @@ -51,7 +51,6 @@ impl Connection for MysqlConnection { raw_connection, transaction_state: AnsiTransactionManager::default(), statement_cache: StatementCache::new(), - current_statement: None, }; conn.set_config_options() .map_err(CouldntSetupConfiguration)?; @@ -75,22 +74,12 @@ impl Connection for MysqlConnection { T::Query: QueryFragment + QueryId, Self::Backend: QueryMetadata, { - self.with_prepared_query(&source.as_query(), |stmt, current_statement| { - let mut metadata = Vec::new(); - Mysql::row_metadata(&mut (), &mut metadata); - let stmt = match stmt { - MaybeCached::CannotCache(stmt) => { - *current_statement = Some(stmt); - current_statement - .as_mut() - .expect("We set it literally above") - } - MaybeCached::Cached(stmt) => stmt, - }; - - let results = unsafe { stmt.results(&metadata)? }; - Ok(results) - }) + let stmt = self.prepared_query(&source.as_query())?; + + let mut metadata = Vec::new(); + Mysql::row_metadata(&mut (), &mut metadata); + + StatementIterator::from_stmt(stmt, &metadata) } #[doc(hidden)] @@ -98,12 +87,11 @@ impl Connection for MysqlConnection { where T: QueryFragment + QueryId, { - self.with_prepared_query(source, |stmt, _| { - unsafe { - stmt.execute()?; - } - Ok(stmt.affected_rows()) - }) + let stmt = self.prepared_query(source)?; + unsafe { + stmt.execute()?; + } + Ok(stmt.affected_rows()) } #[doc(hidden)] @@ -113,11 +101,10 @@ impl Connection for MysqlConnection { } impl MysqlConnection { - fn with_prepared_query<'a, T: QueryFragment + QueryId, R>( + fn prepared_query<'a, T: QueryFragment + QueryId>( &'a mut self, source: &'_ T, - f: impl FnOnce(MaybeCached<'a, Statement>, &'a mut Option) -> QueryResult, - ) -> QueryResult { + ) -> QueryResult> { let cache = &mut self.statement_cache; let conn = &mut self.raw_connection; @@ -129,7 +116,7 @@ impl MysqlConnection { .into_iter() .zip(bind_collector.binds); stmt.bind(binds)?; - f(stmt, &mut self.current_statement) + Ok(stmt) } fn set_config_options(&mut self) -> QueryResult<()> { diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index bf64bf4815da..ff666ce96458 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -3,13 +3,14 @@ use std::marker::PhantomData; use std::rc::Rc; use super::{Binds, Statement, StatementMetadata}; +use crate::connection::MaybeCached; use crate::mysql::{Mysql, MysqlType}; use crate::result::QueryResult; use crate::row::*; #[allow(missing_debug_implementations)] pub struct StatementIterator<'a> { - stmt: &'a mut Statement, + stmt: MaybeCached<'a, Statement>, last_row: Rc>, metadata: Rc, size: usize, @@ -17,8 +18,10 @@ pub struct StatementIterator<'a> { } impl<'a> StatementIterator<'a> { - #[allow(clippy::new_ret_no_self)] - pub fn new(stmt: &'a mut Statement, types: &[Option]) -> QueryResult { + pub fn from_stmt( + mut stmt: MaybeCached<'a, Statement>, + types: &[Option], + ) -> QueryResult { let metadata = stmt.metadata()?; let mut output_binds = Binds::from_output_types(types, &metadata); diff --git a/diesel/src/mysql/connection/stmt/mod.rs b/diesel/src/mysql/connection/stmt/mod.rs index 3ad0d52ebf8d..a5b5fab4b61b 100644 --- a/diesel/src/mysql/connection/stmt/mod.rs +++ b/diesel/src/mysql/connection/stmt/mod.rs @@ -8,7 +8,6 @@ use std::ffi::CStr; use std::os::raw as libc; use std::ptr::NonNull; -use self::iterator::*; use super::bind::Binds; use crate::mysql::MysqlType; use crate::result::{DatabaseErrorKind, Error, QueryResult}; @@ -77,16 +76,6 @@ impl Statement { affected_rows as usize } - /// This function should be called instead of `execute` for queries which - /// have a return value. After calling this function, `execute` can never - /// be called on this statement. - pub unsafe fn results<'a>( - &'a mut self, - types: &[Option], - ) -> QueryResult> { - StatementIterator::new(self, types) - } - /// This function should be called after `execute` only /// otherwise it's not guranteed to return a valid result pub(in crate::mysql::connection) unsafe fn result_size(&mut self) -> QueryResult { diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index edc310534bfd..0b319006a0f2 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -35,11 +35,10 @@ use crate::sqlite::Sqlite; /// - Special identifiers (`:memory:`) #[allow(missing_debug_implementations)] pub struct SqliteConnection { - // Both statement_cache and current_statement needs to be before raw_connection + // statement_cache needs to be before raw_connection // otherwise we will get errors about open statements before closing the // connection itself statement_cache: StatementCache, - current_statement: Option, raw_connection: RawConnection, transaction_state: AnsiTransactionManager, } @@ -78,7 +77,6 @@ impl Connection for SqliteConnection { statement_cache: StatementCache::new(), raw_connection, transaction_state: AnsiTransactionManager::default(), - current_statement: None, }; conn.register_diesel_sql_functions() .map_err(CouldntSetupConfiguration)?; @@ -101,19 +99,10 @@ impl Connection for SqliteConnection { T::Query: QueryFragment + QueryId, Self::Backend: QueryMetadata, { - self.with_prepared_query(&source.as_query(), |stmt, current_statement| { - let statement = match stmt { - MaybeCached::CannotCache(stmt) => { - *current_statement = Some(stmt); - current_statement - .as_mut() - .expect("We set it literally above") - } - MaybeCached::Cached(stmt) => stmt, - }; - let statement_use = StatementUse::new(statement); - Ok(StatementIterator::new(statement_use)) - }) + let stmt = self.prepared_query(&source.as_query())?; + + let statement_use = StatementUse::new(stmt); + Ok(StatementIterator::new(statement_use)) } #[doc(hidden)] @@ -121,10 +110,10 @@ impl Connection for SqliteConnection { where T: QueryFragment + QueryId, { - self.with_prepared_query(source, |mut stmt, _| { - let statement_use = StatementUse::new(&mut stmt); - statement_use.run() - })?; + let stmt = self.prepared_query(source)?; + + let statement_use = StatementUse::new(stmt); + statement_use.run()?; Ok(self.raw_connection.rows_affected_by_last_query()) } @@ -214,11 +203,10 @@ impl SqliteConnection { } } - fn with_prepared_query<'a, T: QueryFragment + QueryId, R>( + fn prepared_query<'a, T: QueryFragment + QueryId>( &'a mut self, source: &'_ T, - f: impl FnOnce(MaybeCached<'a, Statement>, &'a mut Option) -> QueryResult, - ) -> QueryResult { + ) -> QueryResult> { let raw_connection = &self.raw_connection; let cache = &mut self.statement_cache; let mut statement = cache.cached_statement(source, &[], |sql, is_cached| { @@ -233,7 +221,7 @@ impl SqliteConnection { statement.bind(tpe, value)?; } - f(statement, &mut self.current_statement) + Ok(statement) } #[doc(hidden)] diff --git a/diesel/src/sqlite/connection/stmt.rs b/diesel/src/sqlite/connection/stmt.rs index d41cf0fba221..b3d284d39894 100644 --- a/diesel/src/sqlite/connection/stmt.rs +++ b/diesel/src/sqlite/connection/stmt.rs @@ -3,7 +3,7 @@ extern crate libsqlite3_sys as ffi; use super::raw::RawConnection; use super::serialized_value::SerializedValue; use super::sqlite_value::OwnedSqliteValue; -use crate::connection::PrepareForCache; +use crate::connection::{MaybeCached, PrepareForCache}; use crate::result::Error::DatabaseError; use crate::result::*; use crate::sqlite::SqliteType; @@ -128,12 +128,12 @@ impl Drop for Statement { #[allow(missing_debug_implementations)] pub struct StatementUse<'a> { - statement: &'a mut Statement, + statement: MaybeCached<'a, Statement>, column_names: OnceCell>, } impl<'a> StatementUse<'a> { - pub(in crate::sqlite::connection) fn new(statement: &'a mut Statement) -> Self { + pub(in crate::sqlite::connection) fn new(statement: MaybeCached<'a, Statement>) -> Self { StatementUse { statement, column_names: OnceCell::new(), From b644647040979dcfd11b99a383157158fafb0d63 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Mon, 5 Jul 2021 14:24:31 +0200 Subject: [PATCH 19/32] Minor cleanups --- diesel/src/lib.rs | 2 +- diesel/src/mysql/connection/stmt/iterator.rs | 1 - diesel/src/pg/connection/cursor.rs | 3 +-- diesel_bench/Cargo.toml | 10 +--------- 4 files changed, 3 insertions(+), 13 deletions(-) diff --git a/diesel/src/lib.rs b/diesel/src/lib.rs index 59a15bcf280d..96211a2569a5 100644 --- a/diesel/src/lib.rs +++ b/diesel/src/lib.rs @@ -95,7 +95,7 @@ // For the `specialization` feature. #![cfg_attr(feature = "unstable", allow(incomplete_features))] // Built-in Lints -//#![deny(warnings)] +#![deny(warnings)] #![warn( missing_debug_implementations, missing_copy_implementations, diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index ff666ce96458..f88dc951ea5e 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -284,7 +284,6 @@ fn fun_with_row_iters() { assert_eq!(&deserialized, expected); } } - dbg!(); { let collected_rows = conn.load(&query).unwrap().collect::>(); diff --git a/diesel/src/pg/connection/cursor.rs b/diesel/src/pg/connection/cursor.rs index e8aee68367ee..d688ee8535ab 100644 --- a/diesel/src/pg/connection/cursor.rs +++ b/diesel/src/pg/connection/cursor.rs @@ -3,8 +3,7 @@ use std::rc::Rc; use super::result::PgResult; use super::row::PgRow; -/// The type returned by various [`Conn -/// ection`] methods. +/// The type returned by various [`Connection`] methods. /// Acts as an iterator over `T`. #[allow(missing_debug_implementations)] pub struct Cursor<'a> { diff --git a/diesel_bench/Cargo.toml b/diesel_bench/Cargo.toml index da8b3ff609ea..5497918fcbba 100644 --- a/diesel_bench/Cargo.toml +++ b/diesel_bench/Cargo.toml @@ -7,14 +7,6 @@ build = "build.rs" autobenches = false [workspace] - -[workspace.profile.bench] -opt-level = 3 -debug = true -lto = true -incremental = false -codegen-units = 1 - # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] @@ -49,7 +41,7 @@ bench = true harness = false [features] -default = ["sqlite"] +default = [] postgres = ["diesel/postgres"] sqlite = ["diesel/sqlite"] mysql = ["diesel/mysql"] From 2bfe6df39c3e0fcfcc8e6fa658f16f13e9e80444 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Mon, 5 Jul 2021 14:24:40 +0200 Subject: [PATCH 20/32] Skip an unnessesary clone in postgres hot path --- diesel/src/pg/connection/row.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diesel/src/pg/connection/row.rs b/diesel/src/pg/connection/row.rs index f2f670e3d628..c33d14ac08ca 100644 --- a/diesel/src/pg/connection/row.rs +++ b/diesel/src/pg/connection/row.rs @@ -34,7 +34,7 @@ impl<'a> Row<'a, Pg> for PgRow<'a> { { let idx = self.idx(idx)?; Some(PgField { - db_result: self.db_result.clone(), + db_result: &self.db_result, row_idx: self.row_idx, col_idx: idx, }) @@ -63,7 +63,7 @@ impl<'a, 'b> RowIndex<&'a str> for PgRow<'b> { #[allow(missing_debug_implementations)] pub struct PgField<'a> { - db_result: Rc>, + db_result: &'a PgResult<'a>, row_idx: usize, col_idx: usize, } From 1a1b071ea36442f519b2ff964c04a01f076312f8 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 8 Jul 2021 17:51:19 +0200 Subject: [PATCH 21/32] Address Review comments * Remove unneeded life times from `PgResult` and `MysqlRow` * Add tests to show that it is possible to have more than one iterator for `PgConnection`, but not for `SqliteConnection` and `MysqlConnection` * Fix various broken life time bounds for unsafe code on `PgResult` to be bound to the result itself * Copy `MysqlType` by value instead of using a reference --- diesel/src/connection/mod.rs | 2 +- diesel/src/mysql/connection/bind.rs | 11 ++---- diesel/src/mysql/connection/mod.rs | 2 +- diesel/src/mysql/connection/stmt/iterator.rs | 15 ++++---- diesel/src/pg/connection/cursor.rs | 34 +++++++++++++++---- diesel/src/pg/connection/mod.rs | 4 +-- diesel/src/pg/connection/result.rs | 11 +++--- diesel/src/pg/connection/row.rs | 19 ++++++----- diesel/src/pg/connection/stmt/mod.rs | 2 +- ..._mysql_don_not_allow_multiple_iterators.rs | 32 +++++++++++++++++ ...ql_don_not_allow_multiple_iterators.stderr | 21 ++++++++++++ 11 files changed, 109 insertions(+), 44 deletions(-) create mode 100644 diesel_compile_tests/tests/fail/sqlite_and_mysql_don_not_allow_multiple_iterators.rs create mode 100644 diesel_compile_tests/tests/fail/sqlite_and_mysql_don_not_allow_multiple_iterators.stderr diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index 383d5c55cb98..c3bad6a3ec32 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -28,7 +28,7 @@ pub trait SimpleConnection { /// This trait describes which cursor type is used by a given connection /// implementation. This trait is only useful in combination with [`Connection`]. /// -/// Implementation wise this is a workaround for GAT types +/// Implementation wise this is a workaround for GAT's pub trait IterableConnection<'a, DB: Backend> { /// The cursor type returned by [`Connection::load`] /// diff --git a/diesel/src/mysql/connection/bind.rs b/diesel/src/mysql/connection/bind.rs index c6ae063b2b8d..a038c7d1b45e 100644 --- a/diesel/src/mysql/connection/bind.rs +++ b/diesel/src/mysql/connection/bind.rs @@ -33,12 +33,7 @@ impl Binds { let data = metadata .fields() .iter() - .zip( - types - .iter() - .map(|o| o.as_ref()) - .chain(std::iter::repeat(None)), - ) + .zip(types.iter().copied().chain(std::iter::repeat(None))) .map(|(field, tpe)| BindData::for_output(tpe, field)) .collect(); @@ -150,7 +145,7 @@ impl BindData { } } - fn for_output(tpe: Option<&MysqlType>, metadata: &MysqlFieldMetadata) -> Self { + fn for_output(tpe: Option, metadata: &MysqlFieldMetadata) -> Self { let (tpe, flags) = if let Some(tpe) = tpe { match (tpe, metadata.field_type()) { // Those are types where we handle the conversion in diesel itself @@ -277,7 +272,7 @@ impl BindData { (metadata.field_type(), metadata.flags()) } - (tpe, _) => (*tpe).into(), + (tpe, _) => tpe.into(), } } else { (metadata.field_type(), metadata.flags()) diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index 1fbb07cbb1d4..1a283a4a78ad 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -34,7 +34,7 @@ impl SimpleConnection for MysqlConnection { impl<'a> IterableConnection<'a, Mysql> for MysqlConnection { type Cursor = self::stmt::iterator::StatementIterator<'a>; - type Row = self::stmt::iterator::MysqlRow<'a>; + type Row = self::stmt::iterator::MysqlRow; } impl Connection for MysqlConnection { diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index f88dc951ea5e..bf860603955f 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -1,5 +1,4 @@ use std::cell::{Ref, RefCell}; -use std::marker::PhantomData; use std::rc::Rc; use super::{Binds, Statement, StatementMetadata}; @@ -40,7 +39,7 @@ impl<'a> StatementIterator<'a> { } impl<'a> Iterator for StatementIterator<'a> { - type Item = QueryResult>; + type Item = QueryResult; fn next(&mut self) -> Option { // check if we own the only instance of the bind buffer @@ -103,7 +102,6 @@ impl<'a> Iterator for StatementIterator<'a> { self.fetched_rows += 1; Some(Ok(MysqlRow { metadata: self.metadata.clone(), - _marker: Default::default(), row: self.last_row.clone(), })) } @@ -135,10 +133,9 @@ impl<'a> ExactSizeIterator for StatementIterator<'a> { #[derive(Clone)] #[allow(missing_debug_implementations)] -pub struct MysqlRow<'a> { +pub struct MysqlRow { row: Rc>, metadata: Rc, - _marker: PhantomData<&'a mut (Binds, StatementMetadata)>, } enum PrivateMysqlRow { @@ -154,11 +151,11 @@ impl PrivateMysqlRow { } } -impl<'a, 'b> RowFieldHelper<'a, Mysql> for MysqlRow<'b> { +impl<'a> RowFieldHelper<'a, Mysql> for MysqlRow { type Field = MysqlField<'a>; } -impl<'a> Row<'a, Mysql> for MysqlRow<'a> { +impl<'a> Row<'a, Mysql> for MysqlRow { type InnerPartialRow = Self; fn field_count(&self) -> usize { @@ -183,7 +180,7 @@ impl<'a> Row<'a, Mysql> for MysqlRow<'a> { } } -impl<'a> RowIndex for MysqlRow<'a> { +impl RowIndex for MysqlRow { fn idx(&self, idx: usize) -> Option { if idx < self.field_count() { Some(idx) @@ -193,7 +190,7 @@ impl<'a> RowIndex for MysqlRow<'a> { } } -impl<'a, 'b> RowIndex<&'a str> for MysqlRow<'b> { +impl<'a> RowIndex<&'a str> for MysqlRow { fn idx(&self, idx: &'a str) -> Option { self.metadata .fields() diff --git a/diesel/src/pg/connection/cursor.rs b/diesel/src/pg/connection/cursor.rs index d688ee8535ab..c7c1329d829f 100644 --- a/diesel/src/pg/connection/cursor.rs +++ b/diesel/src/pg/connection/cursor.rs @@ -6,13 +6,13 @@ use super::row::PgRow; /// The type returned by various [`Connection`] methods. /// Acts as an iterator over `T`. #[allow(missing_debug_implementations)] -pub struct Cursor<'a> { +pub struct Cursor { current_row: usize, - db_result: Rc>, + db_result: Rc, } -impl<'a> Cursor<'a> { - pub(super) fn new(db_result: PgResult<'a>) -> Self { +impl Cursor { + pub(super) fn new(db_result: PgResult) -> Self { Cursor { current_row: 0, db_result: Rc::new(db_result), @@ -20,14 +20,14 @@ impl<'a> Cursor<'a> { } } -impl<'a> ExactSizeIterator for Cursor<'a> { +impl ExactSizeIterator for Cursor { fn len(&self) -> usize { self.db_result.num_rows() - self.current_row } } -impl<'a> Iterator for Cursor<'a> { - type Item = crate::QueryResult>; +impl Iterator for Cursor { + type Item = crate::QueryResult; fn next(&mut self) -> Option { if self.current_row < self.db_result.num_rows() { @@ -165,4 +165,24 @@ fn fun_with_row_iters() { >::from_nullable_sql(first_values.1).unwrap(), expected[0].1 ); + + let row_iter1 = conn.load(&query).unwrap(); + let row_iter2 = conn.load(&query).unwrap(); + + for ((row1, row2), (expected_id, expected_name)) in row_iter1.zip(row_iter2).zip(expected) { + let (id1, name1) = <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + Pg, + >>::build_from_row(&row1.unwrap()) + .unwrap(); + let (id2, name2) = <(i32, String) as FromSqlRow< + (sql_types::Integer, sql_types::Text), + Pg, + >>::build_from_row(&row2.unwrap()) + .unwrap(); + assert_eq!(id1, expected_id); + assert_eq!(id2, expected_id); + assert_eq!(name1, expected_name); + assert_eq!(name2, expected_name); + } } diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index e4f3621483a7..e43e025aa98d 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -44,8 +44,8 @@ impl SimpleConnection for PgConnection { } impl<'a> IterableConnection<'a, Pg> for PgConnection { - type Cursor = Cursor<'a>; - type Row = self::row::PgRow<'a>; + type Cursor = Cursor; + type Row = self::row::PgRow; } impl Connection for PgConnection { diff --git a/diesel/src/pg/connection/result.rs b/diesel/src/pg/connection/result.rs index 379b9ad3b81f..80e5994a8051 100644 --- a/diesel/src/pg/connection/result.rs +++ b/diesel/src/pg/connection/result.rs @@ -2,7 +2,6 @@ extern crate pq_sys; use self::pq_sys::*; use std::ffi::CStr; -use std::marker::PhantomData; use std::num::NonZeroU32; use std::os::raw as libc; use std::rc::Rc; @@ -18,13 +17,14 @@ use crate::util::OnceCell; const CLOSED_CONNECTION_MSG: &str = "server closed the connection unexpectedly\n\t\ This probably means the server terminated abnormally\n\tbefore or while processing the request.\n"; -pub(crate) struct PgResult<'a> { +pub(crate) struct PgResult { internal_result: RawResult, column_count: usize, row_count: usize, + column_name_map: OnceCell>>, } -impl<'a> PgResult<'a> { +impl PgResult { #[allow(clippy::new_ret_no_self)] pub fn new(internal_result: RawResult) -> QueryResult { let result_status = unsafe { PQresultStatus(internal_result.as_ptr()) }; @@ -37,7 +37,6 @@ impl<'a> PgResult<'a> { column_count, row_count, column_name_map: OnceCell::new(), - _marker: PhantomData, }) } ExecStatusType::PGRES_EMPTY_QUERY => { @@ -95,11 +94,11 @@ impl<'a> PgResult<'a> { self.row_count } - pub fn get_row(self: Rc, idx: usize) -> PgRow<'a> { + pub fn get_row(self: Rc, idx: usize) -> PgRow { PgRow::new(self, idx) } - pub fn get(&self, row_idx: usize, col_idx: usize) -> Option<&'a [u8]> { + pub fn get(&self, row_idx: usize, col_idx: usize) -> Option<&[u8]> { if self.is_null(row_idx, col_idx) { None } else { diff --git a/diesel/src/pg/connection/row.rs b/diesel/src/pg/connection/row.rs index c33d14ac08ca..a24fcb7a1c10 100644 --- a/diesel/src/pg/connection/row.rs +++ b/diesel/src/pg/connection/row.rs @@ -1,26 +1,27 @@ use super::result::PgResult; +use crate::pg::value::TypeOidLookup; use crate::pg::{Pg, PgValue}; use crate::row::*; use std::rc::Rc; #[derive(Clone)] #[allow(missing_debug_implementations)] -pub struct PgRow<'a> { - db_result: Rc>, +pub struct PgRow { + db_result: Rc, row_idx: usize, } -impl<'a> PgRow<'a> { - pub(crate) fn new(db_result: Rc>, row_idx: usize) -> Self { +impl PgRow { + pub(crate) fn new(db_result: Rc, row_idx: usize) -> Self { PgRow { db_result, row_idx } } } -impl<'a, 'b> RowFieldHelper<'a, Pg> for PgRow<'b> { +impl<'a> RowFieldHelper<'a, Pg> for PgRow { type Field = PgField<'a>; } -impl<'a> Row<'a, Pg> for PgRow<'a> { +impl<'a> Row<'a, Pg> for PgRow { type InnerPartialRow = Self; fn field_count(&self) -> usize { @@ -45,7 +46,7 @@ impl<'a> Row<'a, Pg> for PgRow<'a> { } } -impl<'a> RowIndex for PgRow<'a> { +impl RowIndex for PgRow { fn idx(&self, idx: usize) -> Option { if idx < self.field_count() { Some(idx) @@ -55,7 +56,7 @@ impl<'a> RowIndex for PgRow<'a> { } } -impl<'a, 'b> RowIndex<&'a str> for PgRow<'b> { +impl<'a> RowIndex<&'a str> for PgRow { fn idx(&self, field_name: &'a str) -> Option { (0..self.field_count()).find(|idx| self.db_result.column_name(*idx) == Some(field_name)) } @@ -63,7 +64,7 @@ impl<'a, 'b> RowIndex<&'a str> for PgRow<'b> { #[allow(missing_debug_implementations)] pub struct PgField<'a> { - db_result: &'a PgResult<'a>, + db_result: &'a PgResult, row_idx: usize, col_idx: usize, } diff --git a/diesel/src/pg/connection/stmt/mod.rs b/diesel/src/pg/connection/stmt/mod.rs index ce6f59426626..5fee3879be26 100644 --- a/diesel/src/pg/connection/stmt/mod.rs +++ b/diesel/src/pg/connection/stmt/mod.rs @@ -21,7 +21,7 @@ impl Statement { &'_ self, raw_connection: &'a mut RawConnection, param_data: &'_ Vec>>, - ) -> QueryResult> { + ) -> QueryResult { let params_pointer = param_data .iter() .map(|data| { diff --git a/diesel_compile_tests/tests/fail/sqlite_and_mysql_don_not_allow_multiple_iterators.rs b/diesel_compile_tests/tests/fail/sqlite_and_mysql_don_not_allow_multiple_iterators.rs new file mode 100644 index 000000000000..e30f92909815 --- /dev/null +++ b/diesel_compile_tests/tests/fail/sqlite_and_mysql_don_not_allow_multiple_iterators.rs @@ -0,0 +1,32 @@ +extern crate diesel; + +use diesel::prelude::*; +use diesel::sql_query; + +fn main() { + let conn = &mut SqliteConnection::establish("foo").unwrap(); + // For sqlite the returned iterator is coupled to + // a statement, which is coupled to the connection itself + // so we cannot have more than one iterator + // for the same connection + let row_iter1 = conn.load(&sql_query("bar")).unwrap(); + let row_iter2 = conn.load(&sql_query("bar")).unwrap(); + + let _ = row_iter1.zip(row_iter2); + + let conn = &mut MysqlConnection::establish("foo").unwrap(); + // The same argument applies to mysql + let row_iter1 = conn.load(&sql_query("bar")).unwrap(); + let row_iter2 = conn.load(&sql_query("bar")).unwrap(); + + let _ = row_iter1.zip(row_iter2); + + let conn = &mut PgConnection::establish("foo").unwrap(); + // It works for PgConnection as the result is not related to the + // connection in any way + let row_iter1 = conn.load(&sql_query("bar")).unwrap(); + let row_iter2 = conn.load(&sql_query("bar")).unwrap(); + + let _ = row_iter1.zip(row_iter2); + +} diff --git a/diesel_compile_tests/tests/fail/sqlite_and_mysql_don_not_allow_multiple_iterators.stderr b/diesel_compile_tests/tests/fail/sqlite_and_mysql_don_not_allow_multiple_iterators.stderr new file mode 100644 index 000000000000..a1fa2d21eaf3 --- /dev/null +++ b/diesel_compile_tests/tests/fail/sqlite_and_mysql_don_not_allow_multiple_iterators.stderr @@ -0,0 +1,21 @@ +error[E0499]: cannot borrow `*conn` as mutable more than once at a time + --> $DIR/sqlite_and_mysql_don_not_allow_multiple_iterators.rs:13:21 + | +12 | let row_iter1 = conn.load(&sql_query("bar")).unwrap(); + | ---- first mutable borrow occurs here +13 | let row_iter2 = conn.load(&sql_query("bar")).unwrap(); + | ^^^^ second mutable borrow occurs here +14 | +15 | let _ = row_iter1.zip(row_iter2); + | --------- first borrow later used here + +error[E0499]: cannot borrow `*conn` as mutable more than once at a time + --> $DIR/sqlite_and_mysql_don_not_allow_multiple_iterators.rs:20:21 + | +19 | let row_iter1 = conn.load(&sql_query("bar")).unwrap(); + | ---- first mutable borrow occurs here +20 | let row_iter2 = conn.load(&sql_query("bar")).unwrap(); + | ^^^^ second mutable borrow occurs here +21 | +22 | let _ = row_iter1.zip(row_iter2); + | --------- first borrow later used here From 0c67c4e603d48010db3681bab3b0459ae85c3d51 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 8 Jul 2021 17:53:58 +0200 Subject: [PATCH 22/32] Only receive the column type oid if requested Our normal deserialization workflow does not need the type OID to work, so we can change the code in such a way that the OID is only received from the result if the users actually requests it. This should speedup various benchmarks for `PgConnection` measurably. --- diesel/src/pg/connection/row.rs | 9 ++++++-- diesel/src/pg/types/array.rs | 2 +- diesel/src/pg/types/ranges.rs | 4 ++-- diesel/src/pg/types/record.rs | 2 +- diesel/src/pg/value.rs | 39 ++++++++++++++++++++++++++++----- 5 files changed, 45 insertions(+), 11 deletions(-) diff --git a/diesel/src/pg/connection/row.rs b/diesel/src/pg/connection/row.rs index a24fcb7a1c10..b28d5187f403 100644 --- a/diesel/src/pg/connection/row.rs +++ b/diesel/src/pg/connection/row.rs @@ -79,8 +79,13 @@ impl<'a> Field<'a, Pg> for PgField<'a> { 'a: 'b, { let raw = self.db_result.get(self.row_idx, self.col_idx)?; - let type_oid = self.db_result.column_type(self.col_idx); - Some(PgValue::new(raw, type_oid)) + Some(PgValue::new(raw, self)) + } +} + +impl<'a> TypeOidLookup for PgField<'a> { + fn lookup(&self) -> std::num::NonZeroU32 { + self.db_result.column_type(self.col_idx) } } diff --git a/diesel/src/pg/types/array.rs b/diesel/src/pg/types/array.rs index 54098d3c7ee4..7b6b44b3fe50 100644 --- a/diesel/src/pg/types/array.rs +++ b/diesel/src/pg/types/array.rs @@ -48,7 +48,7 @@ where } else { let (elem_bytes, new_bytes) = bytes.split_at(elem_size as usize); bytes = new_bytes; - T::from_sql(PgValue::new(elem_bytes, value.get_oid())) + T::from_sql(PgValue::new(elem_bytes, &value)) } }) .collect() diff --git a/diesel/src/pg/types/ranges.rs b/diesel/src/pg/types/ranges.rs index 02187c1a4cf3..bb965b342da9 100644 --- a/diesel/src/pg/types/ranges.rs +++ b/diesel/src/pg/types/ranges.rs @@ -69,7 +69,7 @@ where let elem_size = bytes.read_i32::()?; let (elem_bytes, new_bytes) = bytes.split_at(elem_size as usize); bytes = new_bytes; - let value = T::from_sql(PgValue::new(elem_bytes, value.get_oid()))?; + let value = T::from_sql(PgValue::new(elem_bytes, &value))?; lower_bound = if flags.contains(RangeFlags::LB_INC) { Bound::Included(value) @@ -80,7 +80,7 @@ where if !flags.contains(RangeFlags::UB_INF) { let _size = bytes.read_i32::()?; - let value = T::from_sql(PgValue::new(bytes, value.get_oid()))?; + let value = T::from_sql(PgValue::new(bytes, &value))?; upper_bound = if flags.contains(RangeFlags::UB_INC) { Bound::Included(value) diff --git a/diesel/src/pg/types/record.rs b/diesel/src/pg/types/record.rs index 34c923e96913..e6328994a702 100644 --- a/diesel/src/pg/types/record.rs +++ b/diesel/src/pg/types/record.rs @@ -54,7 +54,7 @@ macro_rules! tuple_impls { bytes = new_bytes; $T::from_sql(PgValue::new( elem_bytes, - oid, + &oid, ))? } },)+); diff --git a/diesel/src/pg/value.rs b/diesel/src/pg/value.rs index 24ae07c65392..63b42110602c 100644 --- a/diesel/src/pg/value.rs +++ b/diesel/src/pg/value.rs @@ -8,7 +8,32 @@ use std::ops::Range; #[allow(missing_debug_implementations)] pub struct PgValue<'a> { raw_value: &'a [u8], - type_oid: NonZeroU32, + type_oid_lookup: &'a dyn TypeOidLookup, +} + +pub(crate) trait TypeOidLookup { + fn lookup(&self) -> NonZeroU32; +} + +impl TypeOidLookup for F +where + F: Fn() -> NonZeroU32, +{ + fn lookup(&self) -> NonZeroU32 { + (self)() + } +} + +impl<'a> TypeOidLookup for PgValue<'a> { + fn lookup(&self) -> NonZeroU32 { + self.type_oid_lookup.lookup() + } +} + +impl TypeOidLookup for NonZeroU32 { + fn lookup(&self) -> NonZeroU32 { + *self + } } impl<'a> BinaryRawValue<'a> for Pg { @@ -20,16 +45,20 @@ impl<'a> BinaryRawValue<'a> for Pg { impl<'a> PgValue<'a> { #[cfg(test)] pub(crate) fn for_test(raw_value: &'a [u8]) -> Self { + static FAKE_OID: NonZeroU32 = unsafe { + // 42 != 0, so this is actually safe + NonZeroU32::new_unchecked(42) + }; Self { raw_value, - type_oid: NonZeroU32::new(42).unwrap(), + type_oid_lookup: &FAKE_OID, } } - pub(crate) fn new(raw_value: &'a [u8], type_oid: NonZeroU32) -> Self { + pub(crate) fn new(raw_value: &'a [u8], type_oid_lookup: &'a dyn TypeOidLookup) -> Self { Self { raw_value, - type_oid, + type_oid_lookup, } } @@ -40,7 +69,7 @@ impl<'a> PgValue<'a> { /// Get the type oid of this value pub fn get_oid(&self) -> NonZeroU32 { - self.type_oid + self.type_oid_lookup.lookup() } pub(crate) fn subslice(&self, range: Range) -> Self { From 65410cecaa67defb76cbd2d63154d958647e5fa7 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 26 Aug 2021 10:14:45 +0200 Subject: [PATCH 23/32] Address review comments --- diesel/src/connection/mod.rs | 6 +- diesel/src/mysql/connection/bind.rs | 79 +++++++++++++------ diesel/src/mysql/connection/mod.rs | 4 +- diesel/src/mysql/connection/stmt/iterator.rs | 29 +++---- diesel/src/mysql/connection/stmt/mod.rs | 15 ++-- diesel/src/mysql/types/mod.rs | 2 +- diesel/src/pg/connection/mod.rs | 4 +- diesel/src/pg/connection/result.rs | 27 ++++--- diesel/src/pg/connection/row.rs | 9 +-- diesel/src/pg/connection/stmt/mod.rs | 10 +-- diesel/src/pg/expression/array.rs | 1 + diesel/src/pg/expression/array_comparison.rs | 5 +- .../src/pg/expression/expression_methods.rs | 5 +- diesel/src/pg/types/array.rs | 5 +- diesel/src/pg/types/mod.rs | 6 +- diesel/src/pg/types/ranges.rs | 9 ++- diesel/src/pg/types/record.rs | 1 + diesel/src/query_dsl/load_dsl.rs | 14 ++-- diesel/src/query_dsl/mod.rs | 11 ++- diesel/src/result.rs | 9 --- diesel/src/row.rs | 16 ++-- diesel/src/sql_types/mod.rs | 2 +- diesel/src/sqlite/connection/functions.rs | 11 +-- diesel/src/sqlite/connection/mod.rs | 4 +- diesel/src/sqlite/connection/row.rs | 11 +-- ..._mysql_do_not_allow_multiple_iterators.rs} | 0 ...ql_do_not_allow_multiple_iterators.stderr} | 0 diesel_tests/tests/types.rs | 2 +- diesel_tests/tests/types_roundtrip.rs | 2 +- 29 files changed, 163 insertions(+), 136 deletions(-) rename diesel_compile_tests/tests/fail/{sqlite_and_mysql_don_not_allow_multiple_iterators.rs => sqlite_and_mysql_do_not_allow_multiple_iterators.rs} (100%) rename diesel_compile_tests/tests/fail/{sqlite_and_mysql_don_not_allow_multiple_iterators.stderr => sqlite_and_mysql_do_not_allow_multiple_iterators.stderr} (100%) diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index c3bad6a3ec32..cbd0516f3de5 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -29,7 +29,7 @@ pub trait SimpleConnection { /// implementation. This trait is only useful in combination with [`Connection`]. /// /// Implementation wise this is a workaround for GAT's -pub trait IterableConnection<'a, DB: Backend> { +pub trait ConnectionGatWorkaround<'a, DB: Backend> { /// The cursor type returned by [`Connection::load`] /// /// Users should handle this as opaque type that implements [`Iterator`] @@ -42,7 +42,7 @@ pub trait IterableConnection<'a, DB: Backend> { /// A connection to a database pub trait Connection: SimpleConnection + Sized + Send where - Self: for<'a> IterableConnection<'a, ::Backend>, + Self: for<'a> ConnectionGatWorkaround<'a, ::Backend>, { /// The backend this type connects to type Backend: Backend; @@ -195,7 +195,7 @@ where fn load<'a, T>( &'a mut self, source: T, - ) -> QueryResult<>::Cursor> + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, diff --git a/diesel/src/mysql/connection/bind.rs b/diesel/src/mysql/connection/bind.rs index a038c7d1b45e..efc4a4b6bc6c 100644 --- a/diesel/src/mysql/connection/bind.rs +++ b/diesel/src/mysql/connection/bind.rs @@ -11,12 +11,23 @@ use crate::mysql::types::MYSQL_TIME; use crate::mysql::{MysqlType, MysqlValue}; use crate::result::QueryResult; -#[derive(Clone)] -pub struct Binds { +pub struct PreparedStatementBinds(Binds); + +pub struct OutputBinds(Binds); + +impl Clone for OutputBinds { + fn clone(&self) -> Self { + Self(Binds { + data: self.0.data.clone(), + }) + } +} + +struct Binds { data: Vec, } -impl Binds { +impl PreparedStatementBinds { pub fn from_input_data(input: Iter) -> QueryResult where Iter: IntoIterator>)>, @@ -26,9 +37,18 @@ impl Binds { .map(BindData::for_input) .collect::>(); - Ok(Binds { data }) + Ok(Self(Binds { data })) } + pub fn with_mysql_binds(&mut self, f: F) -> T + where + F: FnOnce(*mut ffi::MYSQL_BIND) -> T, + { + self.0.with_mysql_binds(f) + } +} + +impl OutputBinds { pub fn from_output_types(types: &[Option], metadata: &StatementMetadata) -> Self { let data = metadata .fields() @@ -37,23 +57,11 @@ impl Binds { .map(|(field, tpe)| BindData::for_output(tpe, field)) .collect(); - Binds { data } - } - - pub fn with_mysql_binds(&mut self, f: F) -> T - where - F: FnOnce(*mut ffi::MYSQL_BIND) -> T, - { - let mut binds = self - .data - .iter_mut() - .map(|x| unsafe { x.mysql_bind() }) - .collect::>(); - f(binds.as_mut_ptr()) + Self(Binds { data }) } pub fn populate_dynamic_buffers(&mut self, stmt: &Statement) -> QueryResult<()> { - for (i, data) in self.data.iter_mut().enumerate() { + for (i, data) in self.0.data.iter_mut().enumerate() { data.did_numeric_overflow_occur()?; // This is safe because we are re-binding the invalidated buffers // at the end of this function @@ -70,16 +78,37 @@ impl Binds { } pub fn update_buffer_lengths(&mut self) { - for data in &mut self.data { + for data in &mut self.0.data { data.update_buffer_length(); } } + + pub fn with_mysql_binds(&mut self, f: F) -> T + where + F: FnOnce(*mut ffi::MYSQL_BIND) -> T, + { + self.0.with_mysql_binds(f) + } +} + +impl Binds { + fn with_mysql_binds(&mut self, f: F) -> T + where + F: FnOnce(*mut ffi::MYSQL_BIND) -> T, + { + let mut binds = self + .data + .iter_mut() + .map(|x| unsafe { x.mysql_bind() }) + .collect::>(); + f(binds.as_mut_ptr()) + } } -impl Index for Binds { +impl Index for OutputBinds { type Output = BindData; fn index(&self, index: usize) -> &Self::Output { - &self.data[index] + &self.0.data[index] } } @@ -1091,12 +1120,12 @@ mod tests { let bind = BindData::for_test_output(bind_tpe.into()); - let mut binds = Binds { data: vec![bind] }; + let mut binds = OutputBinds(Binds { data: vec![bind] }); stmt.execute_statement(&mut binds).unwrap(); stmt.populate_row_buffers(&mut binds).unwrap(); - binds.data.remove(0) + binds.0.data.remove(0) } fn input_bind( @@ -1130,9 +1159,9 @@ mod tests { is_truncated: None, }; - let binds = Binds { + let binds = PreparedStatementBinds(Binds { data: vec![id_bind, field_bind], - }; + }); stmt.input_bind(binds).unwrap(); stmt.did_an_error_occur().unwrap(); unsafe { diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index 1a283a4a78ad..3040b30bc55f 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -32,7 +32,7 @@ impl SimpleConnection for MysqlConnection { } } -impl<'a> IterableConnection<'a, Mysql> for MysqlConnection { +impl<'a> ConnectionGatWorkaround<'a, Mysql> for MysqlConnection { type Cursor = self::stmt::iterator::StatementIterator<'a>; type Row = self::stmt::iterator::MysqlRow; } @@ -68,7 +68,7 @@ impl Connection for MysqlConnection { fn load<'a, T>( &'a mut self, source: T, - ) -> QueryResult<>::Cursor> + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index bf860603955f..d832f17e3dcf 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -1,7 +1,7 @@ use std::cell::{Ref, RefCell}; use std::rc::Rc; -use super::{Binds, Statement, StatementMetadata}; +use super::{OutputBinds, Statement, StatementMetadata}; use crate::connection::MaybeCached; use crate::mysql::{Mysql, MysqlType}; use crate::result::QueryResult; @@ -12,8 +12,7 @@ pub struct StatementIterator<'a> { stmt: MaybeCached<'a, Statement>, last_row: Rc>, metadata: Rc, - size: usize, - fetched_rows: usize, + len: usize, } impl<'a> StatementIterator<'a> { @@ -23,7 +22,7 @@ impl<'a> StatementIterator<'a> { ) -> QueryResult { let metadata = stmt.metadata()?; - let mut output_binds = Binds::from_output_types(types, &metadata); + let mut output_binds = OutputBinds::from_output_types(types, &metadata); stmt.execute_statement(&mut output_binds)?; let size = unsafe { stmt.result_size() }?; @@ -31,8 +30,7 @@ impl<'a> StatementIterator<'a> { Ok(StatementIterator { metadata: Rc::new(metadata), last_row: Rc::new(RefCell::new(PrivateMysqlRow::Direct(output_binds))), - fetched_rows: 0, - size, + len: size, stmt, }) } @@ -99,7 +97,7 @@ impl<'a> Iterator for StatementIterator<'a> { match res { Ok(Some(())) => { - self.fetched_rows += 1; + self.len = self.len.saturating_sub(1); Some(Ok(MysqlRow { metadata: self.metadata.clone(), row: self.last_row.clone(), @@ -107,7 +105,7 @@ impl<'a> Iterator for StatementIterator<'a> { } Ok(None) => None, Err(e) => { - self.fetched_rows += 1; + self.len = self.len.saturating_sub(1); Some(Err(e)) } } @@ -127,7 +125,7 @@ impl<'a> Iterator for StatementIterator<'a> { impl<'a> ExactSizeIterator for StatementIterator<'a> { fn len(&self) -> usize { - self.size - self.fetched_rows + self.len } } @@ -139,8 +137,8 @@ pub struct MysqlRow { } enum PrivateMysqlRow { - Direct(Binds), - Copied(Binds), + Direct(OutputBinds), + Copied(OutputBinds), } impl PrivateMysqlRow { @@ -151,7 +149,7 @@ impl PrivateMysqlRow { } } -impl<'a> RowFieldHelper<'a, Mysql> for MysqlRow { +impl<'a> RowGatWorkaround<'a, Mysql> for MysqlRow { type Field = MysqlField<'a>; } @@ -162,7 +160,7 @@ impl<'a> Row<'a, Mysql> for MysqlRow { self.metadata.fields().len() } - fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where 'a: 'b, Self: RowIndex, @@ -219,10 +217,7 @@ impl<'a> Field<'a, Mysql> for MysqlField<'a> { } } - fn value<'b>(&'b self) -> Option> - where - 'a: 'b, - { + fn value(&self) -> Option> { match &*self.binds { PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].value(), } diff --git a/diesel/src/mysql/connection/stmt/mod.rs b/diesel/src/mysql/connection/stmt/mod.rs index a5b5fab4b61b..de774ee3db91 100644 --- a/diesel/src/mysql/connection/stmt/mod.rs +++ b/diesel/src/mysql/connection/stmt/mod.rs @@ -8,7 +8,7 @@ use std::ffi::CStr; use std::os::raw as libc; use std::ptr::NonNull; -use super::bind::Binds; +use super::bind::{OutputBinds, PreparedStatementBinds}; use crate::mysql::MysqlType; use crate::result::{DatabaseErrorKind, Error, QueryResult}; @@ -18,7 +18,7 @@ pub use self::metadata::{MysqlFieldMetadata, StatementMetadata}; // https://github.com/rust-lang/rust/issues/81658 pub struct Statement { stmt: NonNull, - input_binds: Option, + input_binds: Option, } impl Statement { @@ -44,11 +44,14 @@ impl Statement { where Iter: IntoIterator>)>, { - let input_binds = Binds::from_input_data(binds)?; + let input_binds = PreparedStatementBinds::from_input_data(binds)?; self.input_bind(input_binds) } - pub(super) fn input_bind(&mut self, mut input_binds: Binds) -> QueryResult<()> { + pub(super) fn input_bind( + &mut self, + mut input_binds: PreparedStatementBinds, + ) -> QueryResult<()> { input_binds.with_mysql_binds(|bind_ptr| { // This relies on the invariant that the current value of `self.input_binds` // will not change without this function being called @@ -150,7 +153,7 @@ impl Statement { } } - pub(super) fn execute_statement(&mut self, binds: &mut Binds) -> QueryResult<()> { + pub(super) fn execute_statement(&mut self, binds: &mut OutputBinds) -> QueryResult<()> { unsafe { binds.with_mysql_binds(|bind_ptr| self.bind_result(bind_ptr))?; self.execute()?; @@ -158,7 +161,7 @@ impl Statement { Ok(()) } - pub(super) fn populate_row_buffers(&self, binds: &mut Binds) -> QueryResult> { + pub(super) fn populate_row_buffers(&self, binds: &mut OutputBinds) -> QueryResult> { let next_row_result = unsafe { ffi::mysql_stmt_fetch(self.stmt.as_ptr()) }; match next_row_result as libc::c_uint { ffi::MYSQL_NO_DATA => Ok(None), diff --git a/diesel/src/mysql/types/mod.rs b/diesel/src/mysql/types/mod.rs index 76a646f40b30..fd7e4158adb0 100644 --- a/diesel/src/mysql/types/mod.rs +++ b/diesel/src/mysql/types/mod.rs @@ -55,7 +55,7 @@ impl FromSql for i8 { /// Represents the MySQL unsigned type. #[derive(Debug, Clone, Copy, Default, SqlType, QueryId)] -pub struct Unsigned(ST); +pub struct Unsigned(ST); impl Add for Unsigned where diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index e43e025aa98d..09f863e01700 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -43,7 +43,7 @@ impl SimpleConnection for PgConnection { } } -impl<'a> IterableConnection<'a, Pg> for PgConnection { +impl<'a> ConnectionGatWorkaround<'a, Pg> for PgConnection { type Cursor = Cursor; type Row = self::row::PgRow; } @@ -75,7 +75,7 @@ impl Connection for PgConnection { fn load<'a, T>( &'a mut self, source: T, - ) -> QueryResult<>::Cursor> + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, diff --git a/diesel/src/pg/connection/result.rs b/diesel/src/pg/connection/result.rs index 80e5994a8051..d18d512a6b0f 100644 --- a/diesel/src/pg/connection/result.rs +++ b/diesel/src/pg/connection/result.rs @@ -21,7 +21,10 @@ pub(crate) struct PgResult { internal_result: RawResult, column_count: usize, row_count: usize, - column_name_map: OnceCell>>, + // We store field names as pointer + // as we cannot put a correct lifetime here + // The value is valid as long as we haven't freed `RawResult` + column_name_map: OnceCell>>, } impl PgResult { @@ -136,25 +139,29 @@ impl PgResult { .get_or_init(|| { (0..self.column_count) .map(|idx| unsafe { + // https://www.postgresql.org/docs/13/libpq-exec.html#LIBPQ-PQFNAME + // states that the returned ptr is valid till the underlying result is freed + // That means we can couple the lifetime to self let ptr = PQfname(self.internal_result.as_ptr(), idx as libc::c_int); if ptr.is_null() { None } else { - Some( - CStr::from_ptr(ptr) - .to_str() - .expect( - "Expect postgres field names to be UTF-8, because we \ + Some(CStr::from_ptr(ptr).to_str().expect( + "Expect postgres field names to be UTF-8, because we \ requested UTF-8 encoding on connection setup", - ) - .to_owned(), - ) + ) as *const str) } }) .collect() }) .get(col_idx) - .and_then(|n| n.as_ref().map(|n| n as &str)) + .and_then(|n| { + n.map(|n: *const str| unsafe { + // The pointer is valid for the same lifetime as &self + // so we can dereference it without any check + &*n + }) + }) } pub fn column_count(&self) -> usize { diff --git a/diesel/src/pg/connection/row.rs b/diesel/src/pg/connection/row.rs index b28d5187f403..f689fa21e3fd 100644 --- a/diesel/src/pg/connection/row.rs +++ b/diesel/src/pg/connection/row.rs @@ -17,7 +17,7 @@ impl PgRow { } } -impl<'a> RowFieldHelper<'a, Pg> for PgRow { +impl<'a> RowGatWorkaround<'a, Pg> for PgRow { type Field = PgField<'a>; } @@ -28,7 +28,7 @@ impl<'a> Row<'a, Pg> for PgRow { self.db_result.column_count() } - fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where 'a: 'b, Self: RowIndex, @@ -74,10 +74,7 @@ impl<'a> Field<'a, Pg> for PgField<'a> { self.db_result.column_name(self.col_idx) } - fn value<'b>(&'b self) -> Option> - where - 'a: 'b, - { + fn value(&self) -> Option> { let raw = self.db_result.get(self.row_idx, self.col_idx)?; Some(PgValue::new(raw, self)) diff --git a/diesel/src/pg/connection/stmt/mod.rs b/diesel/src/pg/connection/stmt/mod.rs index 5fee3879be26..a7790ed1facd 100644 --- a/diesel/src/pg/connection/stmt/mod.rs +++ b/diesel/src/pg/connection/stmt/mod.rs @@ -16,11 +16,10 @@ pub(crate) struct Statement { } impl Statement { - #[allow(clippy::ptr_arg)] - pub fn execute<'a>( - &'_ self, - raw_connection: &'a mut RawConnection, - param_data: &'_ Vec>>, + pub fn execute( + &self, + raw_connection: &mut RawConnection, + param_data: &[Option>], ) -> QueryResult { let params_pointer = param_data .iter() @@ -48,7 +47,6 @@ impl Statement { PgResult::new(internal_res?) } - #[allow(clippy::ptr_arg)] pub fn prepare( raw_connection: &mut RawConnection, sql: &str, diff --git a/diesel/src/pg/expression/array.rs b/diesel/src/pg/expression/array.rs index 8636d9b7dbe1..c135d9f01825 100644 --- a/diesel/src/pg/expression/array.rs +++ b/diesel/src/pg/expression/array.rs @@ -58,6 +58,7 @@ where impl Expression for ArrayLiteral where + ST: 'static, T: Expression, { type SqlType = sql_types::Array; diff --git a/diesel/src/pg/expression/array_comparison.rs b/diesel/src/pg/expression/array_comparison.rs index 4d3c636a2a1c..b37ad4627283 100644 --- a/diesel/src/pg/expression/array_comparison.rs +++ b/diesel/src/pg/expression/array_comparison.rs @@ -128,7 +128,7 @@ where impl_selectable_expression!(All); -pub trait AsArrayExpression { +pub trait AsArrayExpression { type Expression: Expression>; // This method is part of the public API @@ -139,6 +139,7 @@ pub trait AsArrayExpression { impl AsArrayExpression for T where + ST: 'static, T: AsExpression>, { type Expression = >>::Expression; @@ -151,6 +152,7 @@ where impl AsArrayExpression for SelectStatement where + ST: 'static, Self: SelectQuery, { type Expression = Subselect>; @@ -162,6 +164,7 @@ where impl<'a, ST, QS, DB, GB> AsArrayExpression for BoxedSelectStatement<'a, ST, QS, DB, GB> where + ST: 'static, Self: SelectQuery, { type Expression = Subselect>; diff --git a/diesel/src/pg/expression/expression_methods.rs b/diesel/src/pg/expression/expression_methods.rs index a40712aa1601..4a56a08d4185 100644 --- a/diesel/src/pg/expression/expression_methods.rs +++ b/diesel/src/pg/expression/expression_methods.rs @@ -611,7 +611,10 @@ pub trait RangeHelper: SqlType { type Inner; } -impl RangeHelper for Range { +impl RangeHelper for Range +where + Self: 'static, +{ type Inner = ST; } diff --git a/diesel/src/pg/types/array.rs b/diesel/src/pg/types/array.rs index 7b6b44b3fe50..b6315eb60a1e 100644 --- a/diesel/src/pg/types/array.rs +++ b/diesel/src/pg/types/array.rs @@ -60,7 +60,7 @@ use crate::expression::AsExpression; macro_rules! array_as_expression { ($ty:ty, $sql_type:ty) => { - impl<'a, 'b, ST, T> AsExpression<$sql_type> for $ty { + impl<'a, 'b, ST: 'static, T> AsExpression<$sql_type> for $ty { type Expression = Bound<$sql_type, Self>; fn as_expression(self) -> Self::Expression { @@ -126,6 +126,7 @@ where impl ToSql>, Pg> for [T] where [T]: ToSql, Pg>, + ST: 'static, { fn to_sql(&self, out: &mut Output) -> serialize::Result { ToSql::, Pg>::to_sql(self, out) @@ -134,6 +135,7 @@ where impl ToSql, Pg> for Vec where + ST: 'static, [T]: ToSql, Pg>, T: fmt::Debug, { @@ -144,6 +146,7 @@ where impl ToSql>, Pg> for Vec where + ST: 'static, Vec: ToSql, Pg>, { fn to_sql(&self, out: &mut Output) -> serialize::Result { diff --git a/diesel/src/pg/types/mod.rs b/diesel/src/pg/types/mod.rs index 81e843de4a15..d232b50d0015 100644 --- a/diesel/src/pg/types/mod.rs +++ b/diesel/src/pg/types/mod.rs @@ -99,7 +99,7 @@ pub mod sql_types { /// [Vec]: std::vec::Vec /// [slice]: https://doc.rust-lang.org/nightly/std/primitive.slice.html #[derive(Debug, Clone, Copy, Default, QueryId, SqlType)] - pub struct Array(ST); + pub struct Array(ST); /// The `Range` SQL type. /// @@ -117,7 +117,7 @@ pub mod sql_types { /// [`FromSql`]: crate::deserialize::FromSql /// [bound]: std::collections::Bound #[derive(Debug, Clone, Copy, Default, QueryId, SqlType)] - pub struct Range(ST); + pub struct Range(ST); #[doc(hidden)] pub type Int4range = Range; @@ -171,7 +171,7 @@ pub mod sql_types { /// [`WriteTuple`]: super::super::super::serialize::WriteTuple #[derive(Debug, Clone, Copy, Default, QueryId, SqlType)] #[postgres(oid = "2249", array_oid = "2287")] - pub struct Record(ST); + pub struct Record(ST); /// Alias for `SmallInt` pub type SmallSerial = crate::sql_types::SmallInt; diff --git a/diesel/src/pg/types/ranges.rs b/diesel/src/pg/types/ranges.rs index bb965b342da9..98b443ed8d48 100644 --- a/diesel/src/pg/types/ranges.rs +++ b/diesel/src/pg/types/ranges.rs @@ -23,7 +23,7 @@ bitflags! { } } -impl AsExpression> for (Bound, Bound) { +impl AsExpression> for (Bound, Bound) { type Expression = SqlBound, Self>; fn as_expression(self) -> Self::Expression { @@ -31,7 +31,7 @@ impl AsExpression> for (Bound, Bound) { } } -impl<'a, ST, T> AsExpression> for &'a (Bound, Bound) { +impl<'a, ST: 'static, T> AsExpression> for &'a (Bound, Bound) { type Expression = SqlBound, Self>; fn as_expression(self) -> Self::Expression { @@ -39,7 +39,7 @@ impl<'a, ST, T> AsExpression> for &'a (Bound, Bound) { } } -impl AsExpression>> for (Bound, Bound) { +impl AsExpression>> for (Bound, Bound) { type Expression = SqlBound>, Self>; fn as_expression(self) -> Self::Expression { @@ -47,7 +47,7 @@ impl AsExpression>> for (Bound, Bound) { } } -impl<'a, ST, T> AsExpression>> for &'a (Bound, Bound) { +impl<'a, ST: 'static, T> AsExpression>> for &'a (Bound, Bound) { type Expression = SqlBound>, Self>; fn as_expression(self) -> Self::Expression { @@ -158,6 +158,7 @@ where impl ToSql>, Pg> for (Bound, Bound) where + ST: 'static, (Bound, Bound): ToSql, Pg>, { fn to_sql(&self, out: &mut Output) -> serialize::Result { diff --git a/diesel/src/pg/types/record.rs b/diesel/src/pg/types/record.rs index e6328994a702..0b5e9186e7c5 100644 --- a/diesel/src/pg/types/record.rs +++ b/diesel/src/pg/types/record.rs @@ -147,6 +147,7 @@ where impl Expression for PgTuple where T: Expression, + T::SqlType: 'static, { type SqlType = Record; } diff --git a/diesel/src/query_dsl/load_dsl.rs b/diesel/src/query_dsl/load_dsl.rs index dc85d6cc3822..247908e62218 100644 --- a/diesel/src/query_dsl/load_dsl.rs +++ b/diesel/src/query_dsl/load_dsl.rs @@ -1,6 +1,6 @@ use super::RunQueryDsl; use crate::backend::Backend; -use crate::connection::{Connection, IterableConnection}; +use crate::connection::{Connection, ConnectionGatWorkaround}; use crate::deserialize::FromSqlRow; use crate::expression::{select_by::SelectBy, Expression, QueryMetadata, Selectable}; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; @@ -15,16 +15,16 @@ use crate::result::QueryResult; /// [`RunQueryDsl`]: crate::RunQueryDsl pub trait LoadQuery: RunQueryDsl where - for<'a> Self: LoadQueryRet<'a, Conn, U>, + for<'a> Self: LoadQueryGatWorkaround<'a, Conn, U>, { /// Load this query fn internal_load<'a>( self, conn: &'a mut Conn, - ) -> QueryResult<>::Ret>; + ) -> QueryResult<>::Ret>; } -pub trait LoadQueryRet<'a, Conn, U> { +pub trait LoadQueryGatWorkaround<'a, Conn, U> { type Ret: Iterator>; } @@ -69,7 +69,7 @@ pub struct LoadIter<'a, U, C, ST, DB> { _marker: std::marker::PhantomData<&'a (ST, U, DB)>, } -impl<'a, Conn, T, U, DB> LoadQueryRet<'a, Conn, U> for T +impl<'a, Conn, T, U, DB> LoadQueryGatWorkaround<'a, Conn, U> for T where Conn: Connection, T: AsQuery + RunQueryDsl, @@ -82,7 +82,7 @@ where type Ret = LoadIter< 'a, U, - >::Cursor, + >::Cursor, >::SqlType, DB, >; @@ -101,7 +101,7 @@ where fn internal_load<'a>( self, conn: &'a mut Conn, - ) -> QueryResult<>::Ret> { + ) -> QueryResult<>::Ret> { Ok(LoadIter { cursor: conn.load(self)?, _marker: Default::default(), diff --git a/diesel/src/query_dsl/mod.rs b/diesel/src/query_dsl/mod.rs index 4bf9d50405fe..6a9deb291385 100644 --- a/diesel/src/query_dsl/mod.rs +++ b/diesel/src/query_dsl/mod.rs @@ -18,7 +18,7 @@ use crate::expression::Expression; use crate::helper_types::*; use crate::query_builder::locking_clause as lock; use crate::query_source::{joins, Table}; -use crate::result::{first_or_not_found, QueryResult}; +use crate::result::QueryResult; mod belonging_to_dsl; #[doc(hidden)] @@ -52,7 +52,7 @@ pub use self::join_dsl::{InternalJoinDsl, JoinOnDsl, JoinWithImplicitOnClause}; pub use self::load_dsl::CompatibleType; #[doc(hidden)] pub use self::load_dsl::LoadQuery; -use self::load_dsl::LoadQueryRet; +use self::load_dsl::LoadQueryGatWorkaround; pub use self::save_changes_dsl::{SaveChangesDsl, UpdateAndFetchResults}; /// The traits used by `QueryDsl`. @@ -1507,7 +1507,7 @@ pub trait RunQueryDsl: Sized { fn load_iter<'a, U>( self, conn: &'a mut Conn, - ) -> QueryResult<>::Ret> + ) -> QueryResult<>::Ret> where Self: LoadQuery, { @@ -1563,7 +1563,10 @@ pub trait RunQueryDsl: Sized { where Self: LoadQuery, { - first_or_not_found(self.internal_load(conn)) + match self.internal_load(conn)?.next() { + Some(v) => v, + None => Err(crate::result::Error::NotFound), + } } /// Runs the command, returning an `Vec` with the affected rows. diff --git a/diesel/src/result.rs b/diesel/src/result.rs index cbb8cec6cafa..de8b7d0dedc7 100644 --- a/diesel/src/result.rs +++ b/diesel/src/result.rs @@ -360,15 +360,6 @@ fn error_impls_send() { let x: &Send = &err; } -pub(crate) fn first_or_not_found( - records: QueryResult>>, -) -> QueryResult { - match records?.next() { - Some(r) => r, - None => Err(Error::NotFound), - } -} - /// An unexpected `NULL` was encountered during deserialization #[derive(Debug, Clone, Copy)] pub struct UnexpectedNullError; diff --git a/diesel/src/row.rs b/diesel/src/row.rs index db0cdf3d27f3..f085aa82befa 100644 --- a/diesel/src/row.rs +++ b/diesel/src/row.rs @@ -23,7 +23,7 @@ pub trait RowIndex { /// A helper trait to indicate the life time bound for a field returned /// by [`Row::get`] -pub trait RowFieldHelper<'a, DB: Backend> { +pub trait RowGatWorkaround<'a, DB: Backend> { /// Field type returned by a `Row` implementation /// /// * Crates using existing backend should not concern themself with the @@ -40,7 +40,7 @@ pub trait RowFieldHelper<'a, DB: Backend> { /// /// [`FromSqlRow`]: crate::deserialize::FromSqlRow pub trait Row<'a, DB: Backend>: - RowIndex + for<'b> RowIndex<&'b str> + for<'b> RowFieldHelper<'b, DB> + Sized + RowIndex + for<'b> RowIndex<&'b str> + for<'b> RowGatWorkaround<'b, DB> + Sized { /// Return type of `PartialRow` /// @@ -55,7 +55,7 @@ pub trait Row<'a, DB: Backend>: /// Get the field with the provided index from the row. /// /// Returns `None` if there is no matching field for the given index - fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where 'a: 'b, Self: RowIndex; @@ -92,9 +92,7 @@ pub trait Field<'a, DB: Backend> { /// Get the value representing the current field in the raw representation /// as it is transmitted by the database - fn value<'b>(&'b self) -> Option> - where - 'a: 'b; + fn value(&self) -> Option>; /// Checks whether this field is null or not. fn is_null(&self) -> bool { @@ -132,10 +130,10 @@ impl<'a, R> PartialRow<'a, R> { } } -impl<'a, 'b, DB, R> RowFieldHelper<'a, DB> for PartialRow<'b, R> +impl<'a, 'b, DB, R> RowGatWorkaround<'a, DB> for PartialRow<'b, R> where DB: Backend, - R: RowFieldHelper<'a, DB>, + R: RowGatWorkaround<'a, DB>, { type Field = R::Field; } @@ -151,7 +149,7 @@ where self.range.len() } - fn get<'c, I>(&'c self, idx: I) -> Option<>::Field> + fn get<'c, I>(&'c self, idx: I) -> Option<>::Field> where 'a: 'c, Self: RowIndex, diff --git a/diesel/src/sql_types/mod.rs b/diesel/src/sql_types/mod.rs index b929afa57e86..1868d5eb0c49 100644 --- a/diesel/src/sql_types/mod.rs +++ b/diesel/src/sql_types/mod.rs @@ -509,7 +509,7 @@ pub use diesel_derives::SqlType; /// This trait is automatically implemented by [`#[derive(SqlType)]`](derive@SqlType) /// which sets `IsNull` to [`is_nullable::NotNull`] /// -pub trait SqlType { +pub trait SqlType: 'static { /// Is this type nullable? /// /// This type should always be one of the structs in the ['is_nullable`] diff --git a/diesel/src/sqlite/connection/functions.rs b/diesel/src/sqlite/connection/functions.rs index d6d103f7c07c..063c07b9d6a3 100644 --- a/diesel/src/sqlite/connection/functions.rs +++ b/diesel/src/sqlite/connection/functions.rs @@ -6,7 +6,7 @@ use super::serialized_value::SerializedValue; use super::{Sqlite, SqliteAggregateFunction}; use crate::deserialize::{FromSqlRow, StaticallySizedRow}; use crate::result::{DatabaseErrorKind, Error, QueryResult}; -use crate::row::{Field, PartialRow, Row, RowFieldHelper, RowIndex}; +use crate::row::{Field, PartialRow, Row, RowGatWorkaround, RowIndex}; use crate::serialize::{IsNull, Output, ToSql}; use crate::sql_types::HasSqlType; use crate::sqlite::connection::sqlite_value::OwnedSqliteValue; @@ -188,7 +188,7 @@ impl<'a> FunctionRow<'a> { } } -impl<'a, 'b> RowFieldHelper<'a, Sqlite> for FunctionRow<'b> { +impl<'a, 'b> RowGatWorkaround<'a, Sqlite> for FunctionRow<'b> { type Field = FunctionArgument<'a>; } @@ -199,7 +199,7 @@ impl<'a> Row<'a, Sqlite> for FunctionRow<'a> { self.field_count } - fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where 'a: 'b, Self: crate::row::RowIndex, @@ -246,10 +246,7 @@ impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> { self.value().is_none() } - fn value<'b>(&'b self) -> Option> - where - 'a: 'b, - { + fn value(&self) -> Option> { SqliteValue::new( Ref::map(Ref::clone(&self.args), |drop| std::ops::Deref::deref(drop)), self.col_idx, diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 0b319006a0f2..07961478244b 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -54,7 +54,7 @@ impl SimpleConnection for SqliteConnection { } } -impl<'a> IterableConnection<'a, Sqlite> for SqliteConnection { +impl<'a> ConnectionGatWorkaround<'a, Sqlite> for SqliteConnection { type Cursor = StatementIterator<'a>; type Row = self::row::SqliteRow<'a>; } @@ -93,7 +93,7 @@ impl Connection for SqliteConnection { fn load<'a, T>( &'a mut self, source: T, - ) -> QueryResult<>::Cursor> + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, diff --git a/diesel/src/sqlite/connection/row.rs b/diesel/src/sqlite/connection/row.rs index cc2dfcb99b8d..52f2928e5d47 100644 --- a/diesel/src/sqlite/connection/row.rs +++ b/diesel/src/sqlite/connection/row.rs @@ -4,7 +4,7 @@ use std::rc::Rc; use super::sqlite_value::{OwnedSqliteValue, SqliteValue}; use super::stmt::StatementUse; -use crate::row::{Field, PartialRow, Row, RowFieldHelper, RowIndex}; +use crate::row::{Field, PartialRow, Row, RowGatWorkaround, RowIndex}; use crate::sqlite::Sqlite; use crate::util::OnceCell; @@ -60,7 +60,7 @@ impl<'a> PrivateSqliteRow<'a> { } } -impl<'a, 'b> RowFieldHelper<'a, Sqlite> for SqliteRow<'b> { +impl<'a, 'b> RowGatWorkaround<'a, Sqlite> for SqliteRow<'b> { type Field = SqliteField<'a>; } @@ -71,7 +71,7 @@ impl<'a> Row<'a, Sqlite> for SqliteRow<'a> { self.field_count } - fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> + fn get<'b, I>(&'b self, idx: I) -> Option<>::Field> where 'a: 'b, Self: RowIndex, @@ -167,10 +167,7 @@ impl<'a> Field<'a, Sqlite> for SqliteField<'a> { self.value().is_none() } - fn value<'d>(&'d self) -> Option> - where - 'a: 'd, - { + fn value(&self) -> Option> { SqliteValue::new(Ref::clone(&self.row), self.col_idx) } } diff --git a/diesel_compile_tests/tests/fail/sqlite_and_mysql_don_not_allow_multiple_iterators.rs b/diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.rs similarity index 100% rename from diesel_compile_tests/tests/fail/sqlite_and_mysql_don_not_allow_multiple_iterators.rs rename to diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.rs diff --git a/diesel_compile_tests/tests/fail/sqlite_and_mysql_don_not_allow_multiple_iterators.stderr b/diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.stderr similarity index 100% rename from diesel_compile_tests/tests/fail/sqlite_and_mysql_don_not_allow_multiple_iterators.stderr rename to diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.stderr diff --git a/diesel_tests/tests/types.rs b/diesel_tests/tests/types.rs index 0d07cab6cd47..524bd1fdb726 100644 --- a/diesel_tests/tests/types.rs +++ b/diesel_tests/tests/types.rs @@ -1245,7 +1245,7 @@ fn query_single_value(sql_str: &str) -> U where U: FromSqlRow + 'static, TestBackend: HasSqlType, - T: QueryId + SingleValue + SqlType + 'static, + T: QueryId + SingleValue + SqlType, { use diesel::dsl::sql; let connection = &mut connection(); diff --git a/diesel_tests/tests/types_roundtrip.rs b/diesel_tests/tests/types_roundtrip.rs index b4c5995c72e1..4296e65db6eb 100644 --- a/diesel_tests/tests/types_roundtrip.rs +++ b/diesel_tests/tests/types_roundtrip.rs @@ -20,7 +20,7 @@ use std::collections::Bound; pub fn test_type_round_trips(value: T) -> bool where - ST: QueryId + SqlType + TypedExpressionType + SingleValue + 'static, + ST: QueryId + SqlType + TypedExpressionType + SingleValue, ::Backend: HasSqlType, T: AsExpression + FromSqlRow::Backend> From 83eee55a4b99441972e53d9a8e260a5d63e36028 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 26 Aug 2021 10:21:34 +0200 Subject: [PATCH 24/32] Post rebase format fixes --- diesel/src/pg/connection/result.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/diesel/src/pg/connection/result.rs b/diesel/src/pg/connection/result.rs index d18d512a6b0f..7fdc401a33a7 100644 --- a/diesel/src/pg/connection/result.rs +++ b/diesel/src/pg/connection/result.rs @@ -12,7 +12,6 @@ use super::row::PgRow; use crate::result::{DatabaseErrorInformation, DatabaseErrorKind, Error, QueryResult}; use crate::util::OnceCell; - // Message after a database connection has been unexpectedly closed. const CLOSED_CONNECTION_MSG: &str = "server closed the connection unexpectedly\n\t\ This probably means the server terminated abnormally\n\tbefore or while processing the request.\n"; From ac7609dd07ce603509d4dd10204a50c3cae5ab71 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 26 Aug 2021 10:34:20 +0200 Subject: [PATCH 25/32] Fix r2d2 --- diesel/src/r2d2.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/diesel/src/r2d2.rs b/diesel/src/r2d2.rs index a8bc7e5861ed..127a17f4a026 100644 --- a/diesel/src/r2d2.rs +++ b/diesel/src/r2d2.rs @@ -17,7 +17,7 @@ use std::fmt; use std::marker::PhantomData; use crate::backend::Backend; -use crate::connection::{IterableConnection, SimpleConnection, TransactionManager}; +use crate::connection::{ConnectionGatWorkaround, SimpleConnection, TransactionManager}; use crate::expression::QueryMetadata; use crate::prelude::*; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; @@ -130,14 +130,14 @@ where } } -impl<'a, DB, M> IterableConnection<'a, DB> for PooledConnection +impl<'a, DB, M> ConnectionGatWorkaround<'a, DB> for PooledConnection where M: ManageConnection, M::Connection: Connection, DB: Backend, { - type Cursor = >::Cursor; - type Row = >::Row; + type Cursor = >::Cursor; + type Row = >::Row; } impl Connection for PooledConnection @@ -162,7 +162,7 @@ where fn load<'a, T>( &'a mut self, source: T, - ) -> QueryResult<>::Cursor> + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, From d7ae4bce205da9c5f250f9d76658bee4088f1822 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 26 Aug 2021 14:26:31 +0200 Subject: [PATCH 26/32] Rustdoc + clippy fixes --- diesel/src/connection/mod.rs | 8 ++++---- diesel/src/mysql/connection/mod.rs | 6 +++--- diesel/src/pg/connection/mod.rs | 5 +---- diesel/src/query_dsl/load_dsl.rs | 12 ++++++------ diesel/src/query_dsl/mod.rs | 6 +++--- diesel/src/r2d2.rs | 6 +++--- diesel/src/sqlite/connection/mod.rs | 6 +++--- 7 files changed, 23 insertions(+), 26 deletions(-) diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index cbd0516f3de5..86ee9cc760c7 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -35,7 +35,7 @@ pub trait ConnectionGatWorkaround<'a, DB: Backend> { /// Users should handle this as opaque type that implements [`Iterator`] type Cursor: Iterator>; /// The row type used as [`Iterator::Item`] for the iterator implementation - /// of [`IterableConnection::Cursor`] + /// of [`ConnectionGatWorkaround::Cursor`] type Row: crate::row::Row<'a, DB>; } @@ -192,10 +192,10 @@ where fn execute(&mut self, query: &str) -> QueryResult; #[doc(hidden)] - fn load<'a, T>( - &'a mut self, + fn load( + &mut self, source: T, - ) -> QueryResult<>::Cursor> + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index 3040b30bc55f..8ba6eb950a32 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -65,10 +65,10 @@ impl Connection for MysqlConnection { } #[doc(hidden)] - fn load<'a, T>( - &'a mut self, + fn load( + &mut self, source: T, - ) -> QueryResult<>::Cursor> + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index 09f863e01700..36b1e35994b7 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -72,10 +72,7 @@ impl Connection for PgConnection { } #[doc(hidden)] - fn load<'a, T>( - &'a mut self, - source: T, - ) -> QueryResult<>::Cursor> + fn load(&mut self, source: T) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, diff --git a/diesel/src/query_dsl/load_dsl.rs b/diesel/src/query_dsl/load_dsl.rs index 247908e62218..35a1ee464ab3 100644 --- a/diesel/src/query_dsl/load_dsl.rs +++ b/diesel/src/query_dsl/load_dsl.rs @@ -18,10 +18,10 @@ where for<'a> Self: LoadQueryGatWorkaround<'a, Conn, U>, { /// Load this query - fn internal_load<'a>( + fn internal_load( self, - conn: &'a mut Conn, - ) -> QueryResult<>::Ret>; + conn: &mut Conn, + ) -> QueryResult<>::Ret>; } pub trait LoadQueryGatWorkaround<'a, Conn, U> { @@ -98,10 +98,10 @@ where U: FromSqlRow<>::SqlType, DB> + 'static, >::SqlType: 'static, { - fn internal_load<'a>( + fn internal_load( self, - conn: &'a mut Conn, - ) -> QueryResult<>::Ret> { + conn: &mut Conn, + ) -> QueryResult<>::Ret> { Ok(LoadIter { cursor: conn.load(self)?, _marker: Default::default(), diff --git a/diesel/src/query_dsl/mod.rs b/diesel/src/query_dsl/mod.rs index 6a9deb291385..234cff123cd5 100644 --- a/diesel/src/query_dsl/mod.rs +++ b/diesel/src/query_dsl/mod.rs @@ -1504,10 +1504,10 @@ pub trait RunQueryDsl: Sized { /// # Ok(()) /// # } /// ``` - fn load_iter<'a, U>( + fn load_iter( self, - conn: &'a mut Conn, - ) -> QueryResult<>::Ret> + conn: &mut Conn, + ) -> QueryResult<>::Ret> where Self: LoadQuery, { diff --git a/diesel/src/r2d2.rs b/diesel/src/r2d2.rs index 127a17f4a026..5d24c3b7d2a9 100644 --- a/diesel/src/r2d2.rs +++ b/diesel/src/r2d2.rs @@ -159,10 +159,10 @@ where (&mut **self).execute(query) } - fn load<'a, T>( - &'a mut self, + fn load( + &mut self, source: T, - ) -> QueryResult<>::Cursor> + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 07961478244b..195de331eeec 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -90,10 +90,10 @@ impl Connection for SqliteConnection { } #[doc(hidden)] - fn load<'a, T>( - &'a mut self, + fn load( + &mut self, source: T, - ) -> QueryResult<>::Cursor> + ) -> QueryResult<>::Cursor> where T: AsQuery, T::Query: QueryFragment + QueryId, From 594b84f22914d646c3635cba988c05db0f5fd018 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 26 Aug 2021 14:35:26 +0200 Subject: [PATCH 27/32] Apply review commend regarding to `Row::get_value` --- diesel/src/row.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/diesel/src/row.rs b/diesel/src/row.rs index f085aa82befa..d2ce59667179 100644 --- a/diesel/src/row.rs +++ b/diesel/src/row.rs @@ -62,16 +62,19 @@ pub trait Row<'a, DB: Backend>: /// Get a deserialized value with the provided index from the row. /// - /// Returns `None` if there is no matching field for the given index - /// Returns `Some(Err(…)` if there is an error during deserialization - /// Returns `Some(T)` if deserialization is successful - fn get_value(&self, idx: I) -> Option> + /// * Returns `Ok(T)` if deserialization is successful + /// * Returns `Err(Error::DeserializationError)` if there is an error during deserialization + /// * Returns `Err(Error::NotFound)` if the row does not contain a value at this position + /// use [`result.optional()`](crate::result::OptionalExtension) to convert the result to + /// `QueryResult>` for explicit access to the `None` values + fn get_value(&self, idx: I) -> crate::result::QueryResult where Self: RowIndex, T: FromSql, { - let field = self.get(idx)?; - Some(>::from_nullable_sql(field.value())) + let field = self.get(idx).ok_or(crate::result::Error::NotFound)?; + >::from_nullable_sql(field.value()) + .map_err(crate::result::Error::DeserializationError) } /// Returns a wrapping row that allows only to access fields, where the index is part of From 065816a143b12b20af84a61d493952eb419efd94 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 26 Aug 2021 14:37:01 +0200 Subject: [PATCH 28/32] Fix trybuild errors --- .../sqlite_and_mysql_do_not_allow_multiple_iterators.stderr | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.stderr b/diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.stderr index a1fa2d21eaf3..c3fa7217fb21 100644 --- a/diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.stderr +++ b/diesel_compile_tests/tests/fail/sqlite_and_mysql_do_not_allow_multiple_iterators.stderr @@ -1,5 +1,5 @@ error[E0499]: cannot borrow `*conn` as mutable more than once at a time - --> $DIR/sqlite_and_mysql_don_not_allow_multiple_iterators.rs:13:21 + --> $DIR/sqlite_and_mysql_do_not_allow_multiple_iterators.rs:13:21 | 12 | let row_iter1 = conn.load(&sql_query("bar")).unwrap(); | ---- first mutable borrow occurs here @@ -10,7 +10,7 @@ error[E0499]: cannot borrow `*conn` as mutable more than once at a time | --------- first borrow later used here error[E0499]: cannot borrow `*conn` as mutable more than once at a time - --> $DIR/sqlite_and_mysql_don_not_allow_multiple_iterators.rs:20:21 + --> $DIR/sqlite_and_mysql_do_not_allow_multiple_iterators.rs:20:21 | 19 | let row_iter1 = conn.load(&sql_query("bar")).unwrap(); | ---- first mutable borrow occurs here From e1e7c0a50a879922b615c9e2913c815bb56ad5cc Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 26 Aug 2021 15:30:08 +0200 Subject: [PATCH 29/32] More ci fixes --- diesel/src/mysql/connection/bind.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/diesel/src/mysql/connection/bind.rs b/diesel/src/mysql/connection/bind.rs index efc4a4b6bc6c..14994cf84c39 100644 --- a/diesel/src/mysql/connection/bind.rs +++ b/diesel/src/mysql/connection/bind.rs @@ -754,11 +754,12 @@ mod tests { let metadata = stmt.metadata().unwrap(); let mut output_binds = - Binds::from_output_types(&vec![None; metadata.fields().len()], &metadata); + OutputBinds::from_output_types(&vec![None; metadata.fields().len()], &metadata); stmt.execute_statement(&mut output_binds).unwrap(); stmt.populate_row_buffers(&mut output_binds).unwrap(); let results: Vec<(BindData, &_)> = output_binds + .0 .data .into_iter() .zip(metadata.fields()) From 2dedaa4e510cfc70b6c87c4f700c637790bd9528 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 27 Aug 2021 14:05:38 +0200 Subject: [PATCH 30/32] Address comment about `Row::get_value` --- diesel/src/row.rs | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/diesel/src/row.rs b/diesel/src/row.rs index d2ce59667179..6f7635d5b0d0 100644 --- a/diesel/src/row.rs +++ b/diesel/src/row.rs @@ -61,20 +61,13 @@ pub trait Row<'a, DB: Backend>: Self: RowIndex; /// Get a deserialized value with the provided index from the row. - /// - /// * Returns `Ok(T)` if deserialization is successful - /// * Returns `Err(Error::DeserializationError)` if there is an error during deserialization - /// * Returns `Err(Error::NotFound)` if the row does not contain a value at this position - /// use [`result.optional()`](crate::result::OptionalExtension) to convert the result to - /// `QueryResult>` for explicit access to the `None` values - fn get_value(&self, idx: I) -> crate::result::QueryResult + fn get_value(&self, idx: I) -> crate::deserialize::Result where Self: RowIndex, T: FromSql, { - let field = self.get(idx).ok_or(crate::result::Error::NotFound)?; + let field = self.get(idx).ok_or(crate::result::UnexpectedEndOfRow)?; >::from_nullable_sql(field.value()) - .map_err(crate::result::Error::DeserializationError) } /// Returns a wrapping row that allows only to access fields, where the index is part of From fa913fb08b1db0404eee1caace9aa8745034db85 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Wed, 15 Sep 2021 08:45:29 +0000 Subject: [PATCH 31/32] Apply suggestions from code review --- diesel/src/sqlite/connection/functions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diesel/src/sqlite/connection/functions.rs b/diesel/src/sqlite/connection/functions.rs index 063c07b9d6a3..500fc52286c9 100644 --- a/diesel/src/sqlite/connection/functions.rs +++ b/diesel/src/sqlite/connection/functions.rs @@ -154,7 +154,7 @@ impl<'a> FunctionRow<'a> { let args = unsafe { Vec::from_raw_parts( // This cast is safe because: - // * Casting from a pointer to an arry to a pointer to the first array + // * Casting from a pointer to an array to a pointer to the first array // element is safe // * Casting from a raw pointer to `NonNull` is safe, // because `NonNull` is #[repr(transparent)] From 69876a1892644f1dddbf4889763fc741f1f3faca Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 17 Sep 2021 11:38:31 +0200 Subject: [PATCH 32/32] Return a boxed iterator from `RunQueryDsl::load_iter` This makes it much more understandable what's the type of the return type, while having a minimal performance cost. I've opted into applying this change only for this method and not for `LoadDsl::internal_load` and `Connection::load` as both are not really user facing. (I would count them as advanced methods most users never see and never use directly) --- diesel/src/query_dsl/mod.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/diesel/src/query_dsl/mod.rs b/diesel/src/query_dsl/mod.rs index 234cff123cd5..fd8696d6378e 100644 --- a/diesel/src/query_dsl/mod.rs +++ b/diesel/src/query_dsl/mod.rs @@ -52,7 +52,6 @@ pub use self::join_dsl::{InternalJoinDsl, JoinOnDsl, JoinWithImplicitOnClause}; pub use self::load_dsl::CompatibleType; #[doc(hidden)] pub use self::load_dsl::LoadQuery; -use self::load_dsl::LoadQueryGatWorkaround; pub use self::save_changes_dsl::{SaveChangesDsl, UpdateAndFetchResults}; /// The traits used by `QueryDsl`. @@ -1504,14 +1503,15 @@ pub trait RunQueryDsl: Sized { /// # Ok(()) /// # } /// ``` - fn load_iter( + fn load_iter<'a, U>( self, - conn: &mut Conn, - ) -> QueryResult<>::Ret> + conn: &'a mut Conn, + ) -> QueryResult> + 'a>> where - Self: LoadQuery, + U: 'a, + Self: LoadQuery + 'a, { - self.internal_load(conn) + self.internal_load(conn).map(|i| Box::new(i) as Box<_>) } /// Runs the command, and returns the affected row.