From 4d8aa98fe058d395b26144da8e67ef31e31698a7 Mon Sep 17 00:00:00 2001 From: Yoh Deadfall Date: Fri, 11 Oct 2024 12:31:02 +0300 Subject: [PATCH] Added connect_mut for data changing SPI operations --- pgrx-examples/schemas/src/lib.rs | 2 +- pgrx-tests/src/tests/bgworker_tests.rs | 4 +- pgrx-tests/src/tests/guc_tests.rs | 2 +- pgrx-tests/src/tests/pg_cast_tests.rs | 2 +- pgrx-tests/src/tests/spi_tests.rs | 23 +++--- pgrx-tests/src/tests/srf_tests.rs | 4 +- pgrx-tests/src/tests/struct_type_tests.rs | 2 +- .../escaping-spiclient-1209-cursor.stderr | 4 +- .../escaping-spiclient-1209-prep-stmt.stderr | 2 +- pgrx/src/spi.rs | 75 ++++++++++++++----- pgrx/src/spi/client.rs | 39 +++------- pgrx/src/spi/cursor.rs | 6 +- 12 files changed, 94 insertions(+), 71 deletions(-) diff --git a/pgrx-examples/schemas/src/lib.rs b/pgrx-examples/schemas/src/lib.rs index 66064c7660..bdf59ef39d 100644 --- a/pgrx-examples/schemas/src/lib.rs +++ b/pgrx-examples/schemas/src/lib.rs @@ -101,7 +101,7 @@ mod tests { #[pg_test] fn test_my_some_schema_type() -> Result<(), spi::Error> { - Spi::connect(|mut c| { + Spi::connect_mut(|c| { // "MySomeSchemaType" is in 'some_schema', so it needs to be discoverable c.update("SET search_path TO some_schema,public", None, &[])?; assert_eq!( diff --git a/pgrx-tests/src/tests/bgworker_tests.rs b/pgrx-tests/src/tests/bgworker_tests.rs index 044a560a60..92117bcf52 100644 --- a/pgrx-tests/src/tests/bgworker_tests.rs +++ b/pgrx-tests/src/tests/bgworker_tests.rs @@ -25,7 +25,7 @@ pub extern "C" fn bgworker(arg: pg_sys::Datum) { if arg > 0 { BackgroundWorker::transaction(|| { Spi::run("CREATE TABLE tests.bgworker_test (v INTEGER);")?; - Spi::connect(|mut client| { + Spi::connect_mut(|client| { client .update("INSERT INTO tests.bgworker_test VALUES ($1);", None, &[arg.into()]) .map(|_| ()) @@ -66,7 +66,7 @@ pub extern "C" fn bgworker_return_value(arg: pg_sys::Datum) { }; while BackgroundWorker::wait_latch(Some(Duration::from_millis(100))) {} BackgroundWorker::transaction(|| { - Spi::connect(|mut c| { + Spi::connect_mut(|c| { c.update("INSERT INTO tests.bgworker_test_return VALUES ($1)", None, &[val.into()]) .map(|_| ()) }) diff --git a/pgrx-tests/src/tests/guc_tests.rs b/pgrx-tests/src/tests/guc_tests.rs index 08bc355df5..58a53562ad 100644 --- a/pgrx-tests/src/tests/guc_tests.rs +++ b/pgrx-tests/src/tests/guc_tests.rs @@ -202,7 +202,7 @@ mod tests { Spi::run("SET test.no_show TO false;").expect("SPI failed"); Spi::run("SET test.no_reset_all TO false;").expect("SPI failed"); assert_eq!(GUC_NO_RESET_ALL.get(), false); - Spi::connect(|mut client| { + Spi::connect_mut(|client| { let r = client.update("SHOW ALL", None, &[]).expect("SPI failed"); let mut no_reset_guc_in_show_all = false; diff --git a/pgrx-tests/src/tests/pg_cast_tests.rs b/pgrx-tests/src/tests/pg_cast_tests.rs index 856a7399c0..3a8ed1f722 100644 --- a/pgrx-tests/src/tests/pg_cast_tests.rs +++ b/pgrx-tests/src/tests/pg_cast_tests.rs @@ -57,7 +57,7 @@ mod tests { #[pg_test] fn test_pg_cast_assignment_type_cast() { - let _ = Spi::connect(|mut client| { + let _ = Spi::connect_mut(|client| { client.update("CREATE TABLE test_table(value int4);", None, &[])?; client.update("INSERT INTO test_table VALUES('{\"a\": 1}'::json->'a');", None, &[])?; diff --git a/pgrx-tests/src/tests/spi_tests.rs b/pgrx-tests/src/tests/spi_tests.rs index 05b16ee616..567b0eaf47 100644 --- a/pgrx-tests/src/tests/spi_tests.rs +++ b/pgrx-tests/src/tests/spi_tests.rs @@ -165,7 +165,7 @@ mod tests { #[pg_test] fn test_inserting_null() -> Result<(), pgrx::spi::Error> { - Spi::connect(|mut client| { + Spi::connect_mut(|client| { client.update("CREATE TABLE tests.null_test (id uuid)", None, &[]).map(|_| ()) })?; assert_eq!( @@ -188,7 +188,7 @@ mod tests { #[pg_test] fn test_cursor() -> Result<(), spi::Error> { - Spi::connect(|mut client| { + Spi::connect_mut(|client| { client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?; client.update( "INSERT INTO tests.cursor_table (id) \ @@ -208,7 +208,7 @@ mod tests { #[pg_test] fn test_cursor_prepared_statement() -> Result<(), pgrx::spi::Error> { - Spi::connect(|mut client| { + Spi::connect_mut(|client| { client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?; client.update( "INSERT INTO tests.cursor_table (id) \ @@ -245,7 +245,7 @@ mod tests { fn test_cursor_prepared_statement_panics_impl( args: &[DatumWithOid], ) -> Result<(), pgrx::spi::Error> { - Spi::connect(|mut client| { + Spi::connect_mut(|client| { client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?; client.update( "INSERT INTO tests.cursor_table (id) \ @@ -264,7 +264,7 @@ mod tests { #[pg_test] fn test_cursor_by_name() -> Result<(), pgrx::spi::Error> { - let cursor_name = Spi::connect(|mut client| { + let cursor_name = Spi::connect_mut(|client| { client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?; client.update( "INSERT INTO tests.cursor_table (id) \ @@ -318,7 +318,7 @@ mod tests { Ok::<_, spi::Error>(()) })?; - Spi::connect(|mut client| { + Spi::connect_mut(|client| { let res = client.update("SET TIME ZONE 'PST8PDT'", None, &[])?; assert_eq!(Err(spi::Error::NoTupleTable), res.columns()); @@ -334,9 +334,8 @@ mod tests { #[pg_test] fn test_spi_non_mut() -> Result<(), pgrx::spi::Error> { - // Ensures update and cursor APIs do not need mutable reference to SpiClient - Spi::connect(|mut client| { - client.update("SELECT 1", None, &[]).expect("SPI failed"); + // Ensures cursor APIs do not need mutable reference to SpiClient + Spi::connect(|client| { let cursor = client.open_cursor("SELECT 1", &[]).detach_into_name(); client.find_cursor(&cursor).map(|_| ()) }) @@ -428,7 +427,7 @@ mod tests { #[pg_test] fn test_readwrite_in_select_readwrite() -> Result<(), spi::Error> { - Spi::connect(|mut client| { + Spi::connect_mut(|client| { // This is supposed to switch connection to read-write and run it there client.update("CREATE TABLE a (id INT)", None, &[])?; // This is supposed to run in read-write @@ -459,7 +458,7 @@ mod tests { #[pg_test] fn test_spi_select_sees_update() -> spi::Result<()> { - let with_select = Spi::connect(|mut client| { + let with_select = Spi::connect_mut(|client| { client.update("CREATE TABLE asd(id int)", None, &[])?; client.update("INSERT INTO asd(id) VALUES (1)", None, &[])?; client.select("SELECT COUNT(*) FROM asd", None, &[])?.first().get_one::() @@ -485,7 +484,7 @@ mod tests { #[pg_test] fn test_spi_select_sees_update_in_other_session() -> spi::Result<()> { - Spi::connect::, _>(|mut client| { + Spi::connect_mut::, _>(|client| { client.update("CREATE TABLE asd(id int)", None, &[])?; client.update("INSERT INTO asd(id) VALUES (1)", None, &[])?; Ok(()) diff --git a/pgrx-tests/src/tests/srf_tests.rs b/pgrx-tests/src/tests/srf_tests.rs index 3fc19ea9a3..bcf9b8b32f 100644 --- a/pgrx-tests/src/tests/srf_tests.rs +++ b/pgrx-tests/src/tests/srf_tests.rs @@ -243,7 +243,7 @@ mod tests { #[pg_test] fn test_srf_setof_datum_detoasting_with_borrow() { - let cnt = Spi::connect(|mut client| { + let cnt = Spi::connect_mut(|client| { // build up a table with one large column that Postgres will be forced to TOAST client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000)) x;", None, &[])?; @@ -261,7 +261,7 @@ mod tests { #[pg_test] fn test_srf_table_datum_detoasting_with_borrow() { - let cnt = Spi::connect(|mut client| { + let cnt = Spi::connect_mut(|client| { // build up a table with one large column that Postgres will be forced to TOAST client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000)) x;", None, &[])?; diff --git a/pgrx-tests/src/tests/struct_type_tests.rs b/pgrx-tests/src/tests/struct_type_tests.rs index 5dcf8ce08c..5b464ef49a 100644 --- a/pgrx-tests/src/tests/struct_type_tests.rs +++ b/pgrx-tests/src/tests/struct_type_tests.rs @@ -57,7 +57,7 @@ mod tests { #[pg_test] fn test_complex_storage_and_retrieval() -> Result<(), pgrx::spi::Error> { - let complex = Spi::connect(|mut client| { + let complex = Spi::connect_mut(|client| { client.update( "CREATE TABLE complex_test AS SELECT s as id, (s || '.0, 2.0' || s)::complex as value FROM generate_series(1, 1000) s;\ SELECT value FROM complex_test ORDER BY id;", None, &[])?.first().get_one::>() diff --git a/pgrx-tests/tests/compile-fail/escaping-spiclient-1209-cursor.stderr b/pgrx-tests/tests/compile-fail/escaping-spiclient-1209-cursor.stderr index f4ef0709b0..4d03ef900e 100644 --- a/pgrx-tests/tests/compile-fail/escaping-spiclient-1209-cursor.stderr +++ b/pgrx-tests/tests/compile-fail/escaping-spiclient-1209-cursor.stderr @@ -4,7 +4,7 @@ error: lifetime may not live long enough 8 | let mut res = Spi::connect(|c| { | -- return type of closure is SpiTupleTable<'2> | | - | has type `SpiClient<'1>` + | has type `&SpiClient<'1>` 9 | / c.open_cursor("select 'hello world' from generate_series(1, 1000)", &[]) 10 | | .fetch(1000) 11 | | .unwrap() @@ -31,7 +31,7 @@ error: lifetime may not live long enough | -- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'2` | || | |return type of closure is SpiTupleTable<'2> - | has type `SpiClient<'1>` + | has type `&SpiClient<'1>` error[E0515]: cannot return value referencing temporary value --> tests/compile-fail/escaping-spiclient-1209-cursor.rs:16:26 diff --git a/pgrx-tests/tests/compile-fail/escaping-spiclient-1209-prep-stmt.stderr b/pgrx-tests/tests/compile-fail/escaping-spiclient-1209-prep-stmt.stderr index 434890dfe7..0cdcc09515 100644 --- a/pgrx-tests/tests/compile-fail/escaping-spiclient-1209-prep-stmt.stderr +++ b/pgrx-tests/tests/compile-fail/escaping-spiclient-1209-prep-stmt.stderr @@ -5,4 +5,4 @@ error: lifetime may not live long enough | -- ^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'2` | || | |return type of closure is std::result::Result, pgrx::spi::SpiError> - | has type `SpiClient<'1>` + | has type `&SpiClient<'1>` diff --git a/pgrx/src/spi.rs b/pgrx/src/spi.rs index cebbde7b4a..2d39929beb 100644 --- a/pgrx/src/spi.rs +++ b/pgrx/src/spi.rs @@ -21,7 +21,6 @@ mod cursor; mod query; mod tuple; pub use client::SpiClient; -use client::SpiConnection; pub use cursor::SpiCursor; pub use query::{OwnedPreparedStatement, PreparedStatement, Query}; pub use tuple::{SpiHeapTupleData, SpiHeapTupleDataEntry, SpiTupleTable}; @@ -237,13 +236,13 @@ impl Spi { } pub fn get_one(query: &str) -> Result> { - Spi::connect(|mut client| client.update(query, Some(1), &[])?.first().get_one()) + Spi::connect_mut(|client| client.update(query, Some(1), &[])?.first().get_one()) } pub fn get_two( query: &str, ) -> Result<(Option, Option)> { - Spi::connect(|mut client| client.update(query, Some(1), &[])?.first().get_two::()) + Spi::connect_mut(|client| client.update(query, Some(1), &[])?.first().get_two::()) } pub fn get_three< @@ -253,7 +252,7 @@ impl Spi { >( query: &str, ) -> Result<(Option, Option, Option)> { - Spi::connect(|mut client| { + Spi::connect_mut(|client| { client.update(query, Some(1), &[])?.first().get_three::() }) } @@ -262,14 +261,14 @@ impl Spi { query: &str, args: &[DatumWithOid<'mcx>], ) -> Result> { - Spi::connect(|mut client| client.update(query, Some(1), args)?.first().get_one()) + Spi::connect_mut(|client| client.update(query, Some(1), args)?.first().get_one()) } pub fn get_two_with_args<'mcx, A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>( query: &str, args: &[DatumWithOid<'mcx>], ) -> Result<(Option, Option)> { - Spi::connect(|mut client| client.update(query, Some(1), args)?.first().get_two::()) + Spi::connect_mut(|client| client.update(query, Some(1), args)?.first().get_two::()) } pub fn get_three_with_args< @@ -281,12 +280,12 @@ impl Spi { query: &str, args: &[DatumWithOid<'mcx>], ) -> Result<(Option, Option, Option)> { - Spi::connect(|mut client| { + Spi::connect_mut(|client| { client.update(query, Some(1), args)?.first().get_three::() }) } - /// just run an arbitrary SQL statement. + /// Just run an arbitrary SQL statement. /// /// ## Safety /// @@ -304,7 +303,7 @@ impl Spi { query: &str, args: &[DatumWithOid<'mcx>], ) -> std::result::Result<(), Error> { - Spi::connect(|mut client| client.update(query, None, args).map(|_| ())) + Spi::connect_mut(|client| client.update(query, None, args).map(|_| ())) } /// explain a query, returning its result in json form @@ -314,7 +313,7 @@ impl Spi { /// explain a query with args, returning its result in json form pub fn explain_with_args<'mcx>(query: &str, args: &[DatumWithOid<'mcx>]) -> Result { - Ok(Spi::connect(|mut client| { + Ok(Spi::connect_mut(|client| { client .update(&format!("EXPLAIN (format json) {query}"), None, args)? .first() @@ -323,7 +322,7 @@ impl Spi { .unwrap()) } - /// Execute SPI commands via the provided `SpiClient`. + /// Execute SPI read-only commands via the provided `SpiClient`. /// /// While inside the provided closure, code executes under a short-lived "SPI Memory Context", /// and Postgres will completely free that context when this function is finished. @@ -360,10 +359,51 @@ impl Spi { /// ([`pg_sys::SPI_connect()`]) **always** returns a successful response. pub fn connect(f: F) -> R where - F: FnOnce(SpiClient<'_>) -> R, /* TODO: redesign this with 2 lifetimes: - - 'conn ~= CurrentMemoryContext after connection - - 'ret ~= SPI_palloc's context - */ + F: FnOnce(&SpiClient<'_>) -> R, + { + Self::connect_mut(|client| f(client)) + } + + /// Execute SPI mutating commands via the provided `SpiClient`. + /// + /// While inside the provided closure, code executes under a short-lived "SPI Memory Context", + /// and Postgres will completely free that context when this function is finished. + /// + /// pgrx' SPI API endeavors to return Datum values from functions like `::get_one()` that are + /// automatically copied into the into the `CurrentMemoryContext` at the time of this + /// function call. + /// + /// # Examples + /// + /// ```rust,no_run + /// use pgrx::prelude::*; + /// # fn foo() -> spi::Result<()> { + /// Spi::connect_mut(|client| { + /// client.update("INSERT INTO users VALUES ('Bob')", None, &[])?; + /// Ok(()) + /// }) + /// # } + /// ``` + /// + /// Note that `SpiClient` is scoped to the connection lifetime and cannot be returned. The + /// following code will not compile: + /// + /// ```rust,compile_fail + /// use pgrx::prelude::*; + /// let cant_return_client = Spi::connect(|client| client); + /// ``` + /// + /// # Panics + /// + /// This function will panic if for some reason it's unable to "connect" to Postgres' SPI + /// system. At the time of this writing, that's actually impossible as the underlying function + /// ([`pg_sys::SPI_connect()`]) **always** returns a successful response. + pub fn connect_mut(f: F) -> R + where + F: FnOnce(&mut SpiClient<'_>) -> R, /* TODO: redesign this with 2 lifetimes: + - 'conn ~= CurrentMemoryContext after connection + - 'ret ~= SPI_palloc's context + */ { // connect to SPI // @@ -379,14 +419,13 @@ impl Spi { // otherwise this function would need to return a `Result` and that's a // fucking nightmare for users to deal with. There's ample discussion around coming to // this decision at https://github.com/pgcentralfoundation/pgrx/pull/977 - let connection = - SpiConnection::connect().expect("SPI_connect indicated an unexpected failure"); + let mut client = SpiClient::connect().expect("SPI_connect indicated an unexpected failure"); // run the provided closure within the memory context that SPI_connect() // just put us un. We'll disconnect from SPI when the closure is finished. // If there's a panic or elog(ERROR), we don't care about also disconnecting from // SPI b/c Postgres will do that for us automatically - f(connection.client()) + f(&mut client) } #[track_caller] diff --git a/pgrx/src/spi/client.rs b/pgrx/src/spi/client.rs index 557aaa50cf..e073105be6 100644 --- a/pgrx/src/spi/client.rs +++ b/pgrx/src/spi/client.rs @@ -9,10 +9,18 @@ use super::query::PreparableQuery; // TODO: should `'conn` be invariant? pub struct SpiClient<'conn> { - __marker: PhantomData<&'conn SpiConnection>, + __marker: PhantomData<&'conn ()>, } impl<'conn> SpiClient<'conn> { + /// Connect to Postgres' SPI system + pub(super) fn connect() -> SpiResult { + // SPI_connect() is documented as being able to return SPI_ERROR_CONNECT, so we have to + // assume it could. The truth seems to be that it never actually does. + Spi::check_status(unsafe { pg_sys::SPI_connect() })?; + Ok(SpiClient { __marker: PhantomData }) + } + /// Prepares a statement that is valid for the lifetime of the client pub fn prepare>( &self, @@ -156,35 +164,12 @@ impl<'conn> SpiClient<'conn> { } } -/// a struct to manage our SPI connection lifetime -pub(super) struct SpiConnection(PhantomData<*mut ()>); - -impl SpiConnection { - /// Connect to Postgres' SPI system - pub(super) fn connect() -> SpiResult { - // connect to SPI - // - // SPI_connect() is documented as being able to return SPI_ERROR_CONNECT, so we have to - // assume it could. The truth seems to be that it never actually does. The one user - // of SpiConnection::connect() returns `spi::Result` anyways, so it's no big deal - Spi::check_status(unsafe { pg_sys::SPI_connect() })?; - Ok(SpiConnection(PhantomData)) - } -} - -impl Drop for SpiConnection { - /// when SpiConnection is dropped, we make sure to disconnect from SPI +impl Drop for SpiClient<'_> { + /// When `SpiClient` is dropped, we make sure to disconnect from SPI fn drop(&mut self) { - // best efforts to disconnect from SPI + // Best efforts to disconnect from SPI // SPI_finish() would only complain if we hadn't previously called SPI_connect() and // SpiConnection should prevent that from happening (assuming users don't go unsafe{}) Spi::check_status(unsafe { pg_sys::SPI_finish() }).ok(); } } - -impl SpiConnection { - /// Return a client that with a lifetime scoped to this connection. - pub(super) fn client(&self) -> SpiClient<'_> { - SpiClient { __marker: PhantomData } - } -} diff --git a/pgrx/src/spi/cursor.rs b/pgrx/src/spi/cursor.rs index 40f13c9683..8aaa547577 100644 --- a/pgrx/src/spi/cursor.rs +++ b/pgrx/src/spi/cursor.rs @@ -32,7 +32,7 @@ type CursorName = String; /// ```rust,no_run /// use pgrx::prelude::*; /// # fn foo() -> spi::Result<()> { -/// Spi::connect(|mut client| { +/// Spi::connect_mut(|client| { /// let mut cursor = client.open_cursor("SELECT * FROM generate_series(1, 5)", &[]); /// assert_eq!(Some(1), cursor.fetch(1)?.get_one::()?); /// assert_eq!(Some(2), cursor.fetch(2)?.get_one::()?); @@ -47,13 +47,13 @@ type CursorName = String; /// ```rust,no_run /// use pgrx::prelude::*; /// # fn foo() -> spi::Result<()> { -/// let cursor_name = Spi::connect(|mut client| { +/// let cursor_name = Spi::connect_mut(|client| { /// let mut cursor = client.open_cursor("SELECT * FROM generate_series(1, 5)", &[]); /// assert_eq!(Ok(Some(1)), cursor.fetch(1)?.get_one::()); /// Ok::<_, spi::Error>(cursor.detach_into_name()) // <-- cursor gets dropped here /// // <--- first SpiTupleTable gets freed by Spi::connect at this point /// })?; -/// Spi::connect(|mut client| { +/// Spi::connect_mut(|client| { /// let mut cursor = client.find_cursor(&cursor_name)?; /// assert_eq!(Ok(Some(2)), cursor.fetch(1)?.get_one::()); /// drop(cursor); // <-- cursor gets dropped here