diff --git a/Cargo.lock b/Cargo.lock index a58554a5f..deb0dbf03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -610,6 +610,26 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +[[package]] +name = "enum-map" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "988f0d17a0fa38291e5f41f71ea8d46a5d5497b9054d5a759fae2cbb819f2356" +dependencies = [ + "enum-map-derive", +] + +[[package]] +name = "enum-map-derive" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a4da76b3b6116d758c7ba93f7ec6a35d2e2cf24feda76c6e38a375f4d5c59f2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "env_proxy" version = "0.4.1" @@ -1407,6 +1427,7 @@ dependencies = [ "atomic-traits", "bitflags", "bitvec", + "enum-map", "heapless", "libc", "once_cell", diff --git a/pgx/Cargo.toml b/pgx/Cargo.toml index e9dd61be7..67210cfae 100644 --- a/pgx/Cargo.toml +++ b/pgx/Cargo.toml @@ -42,6 +42,7 @@ pgx-sql-entity-graph = { path = "../pgx-sql-entity-graph", version = "=0.7.4" } once_cell = "1.17.1" # polyfill until std::lazy::OnceCell stabilizes seq-macro = "0.3" # impls loops in macros uuid = { version = "1.3.0", features = [ "v4" ] } # PgLwLock and shmem +enum-map = "2.4.2" # error handling and logging thiserror = "1.0" diff --git a/pgx/src/callbacks.rs b/pgx/src/callbacks.rs index be923230c..171db1464 100644 --- a/pgx/src/callbacks.rs +++ b/pgx/src/callbacks.rs @@ -12,12 +12,13 @@ Use of this source code is governed by the MIT license that can be found in the use crate as pgx; // for #[pg_guard] support from within ourself use crate::pg_sys; use crate::prelude::*; +use enum_map::{Enum, EnumMap}; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; /// Postgres Transaction (Xact) Callback Events -#[derive(Hash, Eq, PartialEq, Clone, Debug)] +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Enum)] pub enum PgXactCallbackEvent { /// Fired when a transaction is aborted. It is mutually exclusive with `PgXactCallbackEvent::Commit` /// @@ -105,7 +106,8 @@ struct XactCallbackWrapper( ); /// Shorthand for the type representing the map of callbacks -type CallbackMap = HashMap>>>>; +type CallbackMap = + EnumMap>>>>>; /// Register a closure to be called during one of the `PgXactCallbackEvent` events. Multiple /// closures can be registered per event (one at a time), and they are called in the order in which @@ -176,12 +178,12 @@ where | PgXactCallbackEvent::Abort | PgXactCallbackEvent::ParallelCommit | PgXactCallbackEvent::ParallelAbort => XACT_HOOKS - .replace(HashMap::new()) - .expect("XACT_HOOKS was None during Commit/Abort") - .remove(&which_event), + .replace(CallbackMap::default()) + .expect("XACT_HOOKS was None during Commit/Abort")[which_event] + .take(), // not in a transaction-end event, so just borrow our map - _ => XACT_HOOKS.as_mut().expect("XACT_HOOKS was None").remove(&which_event), + _ => XACT_HOOKS.as_mut().expect("XACT_HOOKS was None")[which_event].take(), }; // if we have a vec of hooks for this event they're consumed here and executed @@ -206,7 +208,7 @@ where // if this is our first time here since the Postgres backend started, XACT_HOOKS will be None if XACT_HOOKS.is_none() { // so lets swap it out with a new HashMap, which will live for the duration of the backend - XACT_HOOKS.replace(HashMap::new()); + XACT_HOOKS.replace(Default::default()); // and register our single callback function (internally defined above) pg_sys::RegisterXactCallback(Some(callback), std::ptr::null_mut()); @@ -224,7 +226,7 @@ where let wrapped_func = Rc::new(RefCell::new(Some(XactCallbackWrapper(Box::new(f))))); // find (or create) the map Entry for the specified event and add our wrapped hook to it - let entry = hooks.entry(which_event).or_default(); + let entry = hooks[which_event].get_or_insert_with(Default::default); entry.push(Rc::clone(&wrapped_func)); // give the user the ability to unregister diff --git a/pgx/src/spi.rs b/pgx/src/spi.rs index 1f1f4f31c..1d9c5a3ba 100644 --- a/pgx/src/spi.rs +++ b/pgx/src/spi.rs @@ -15,7 +15,6 @@ use crate::{ }; use core::fmt::Formatter; use pgx_pg_sys::panic::ErrorReportable; -use std::collections::HashMap; use std::ffi::{CStr, CString}; use std::fmt::Debug; use std::marker::PhantomData; @@ -413,7 +412,8 @@ pub struct SpiHeapTupleDataEntry { /// Represents the set of `pg_sys::Datum`s in a `pg_sys::HeapTuple` pub struct SpiHeapTupleData { tupdesc: NonNull, - entries: HashMap, + // offset by 1! + entries: Vec, } impl Spi { @@ -1247,16 +1247,17 @@ impl SpiHeapTupleData { htup: *mut pg_sys::HeapTupleData, ) -> Result> { let tupdesc = NonNull::new(tupdesc).ok_or(Error::NoTupleTable)?; - let mut data = SpiHeapTupleData { tupdesc, entries: HashMap::default() }; + let mut data = SpiHeapTupleData { tupdesc, entries: Vec::new() }; let tupdesc = tupdesc.as_ptr(); unsafe { // SAFETY: we know tupdesc is not null - for i in 1..=tupdesc.as_ref().unwrap().natts { + let natts = (*tupdesc).natts; + data.entries.reserve(usize::try_from(natts as usize).unwrap_or_default()); + for i in 1..=natts { let mut is_null = false; let datum = pg_sys::SPI_getbinval(htup, tupdesc, i, &mut is_null); - - data.entries.entry(i as usize).or_insert_with(|| SpiHeapTupleDataEntry { + data.entries.push(SpiHeapTupleDataEntry { datum: if is_null { None } else { Some(datum) }, type_oid: pg_sys::SPI_gettypeid(tupdesc, i), }); @@ -1302,7 +1303,9 @@ impl SpiHeapTupleData { &self, ordinal: usize, ) -> std::result::Result<&SpiHeapTupleDataEntry, Error> { - self.entries.get(&ordinal).ok_or_else(|| Error::SpiError(SpiErrorCodes::NoAttribute)) + // Wrapping because `self.entries.get(...)` will bounds check. + let index = ordinal.wrapping_sub(1); + self.entries.get(index).ok_or_else(|| Error::SpiError(SpiErrorCodes::NoAttribute)) } /// Get a raw Datum from this HeapTuple by its field name. @@ -1341,10 +1344,8 @@ impl SpiHeapTupleData { datum: T, ) -> std::result::Result<(), Error> { self.check_ordinal_bounds(ordinal)?; - self.entries.insert( - ordinal, - SpiHeapTupleDataEntry { datum: datum.into_datum(), type_oid: T::type_oid() }, - ); + self.entries[ordinal - 1] = + SpiHeapTupleDataEntry { datum: datum.into_datum(), type_oid: T::type_oid() }; Ok(()) }