From 2df66d8e9074c84ca8b0a70a0cbfc2710daba17e Mon Sep 17 00:00:00 2001 From: Jubilee <46493976+workingjubilee@users.noreply.github.com> Date: Wed, 8 Nov 2023 14:17:25 -0800 Subject: [PATCH] Make `#[derive(PostgresType)]` impl its own FromDatum (#1381) Mainly, this removes a source of persistent and confounding type errors because of the generic blanket impl of FromDatum for all T that fulfill so-and-so bounds. These may mislead one, if one is writing code generic over FromDatum, to imagine that one needs a Serialize and Deserialize impl or bound for a given case, even when those are _not_ required. By moving these requirements onto the type that derives, this moves any confusion to the specific cases it actually applies to. This has a regrettable effect that now PostgresType _requires_ a Serialize and Deserialize impl in order to work, _unless_ one uses the hacky `#[bikeshed_postgres_type_manually_impl_from_into_datum]` attribute, which I intend to rename or otherwise fix up before pgrx reaches its 0.12.0 release. --- pgrx-examples/custom_types/src/fixed_size.rs | 1 + pgrx-macros/src/lib.rs | 76 ++++++++++++++-- pgrx-sql-entity-graph/src/lib.rs | 2 +- .../src/postgres_type/mod.rs | 27 +++--- pgrx-tests/src/tests/postgres_type_tests.rs | 4 +- pgrx/src/datum/varlena.rs | 86 ++----------------- 6 files changed, 93 insertions(+), 103 deletions(-) diff --git a/pgrx-examples/custom_types/src/fixed_size.rs b/pgrx-examples/custom_types/src/fixed_size.rs index 678411dde..a78c3dc7e 100644 --- a/pgrx-examples/custom_types/src/fixed_size.rs +++ b/pgrx-examples/custom_types/src/fixed_size.rs @@ -13,6 +13,7 @@ use pgrx::{opname, pg_operator, PgVarlena, PgVarlenaInOutFuncs, StringInfo}; use std::str::FromStr; #[derive(Copy, Clone, PostgresType)] +#[bikeshed_postgres_type_manually_impl_from_into_datum] #[pgvarlena_inoutfuncs] pub struct FixedF32Array { array: [f32; 91], diff --git a/pgrx-macros/src/lib.rs b/pgrx-macros/src/lib.rs index 8e2ccb244..1e5975db2 100644 --- a/pgrx-macros/src/lib.rs +++ b/pgrx-macros/src/lib.rs @@ -18,9 +18,10 @@ use syn::spanned::Spanned; use syn::{parse_macro_input, Attribute, Data, DeriveInput, Item, ItemImpl}; use operators::{deriving_postgres_eq, deriving_postgres_hash, deriving_postgres_ord}; -use pgrx_sql_entity_graph::{ +use pgrx_sql_entity_graph as sql_gen; +use sql_gen::{ parse_extern_attributes, CodeEnrichment, ExtensionSql, ExtensionSqlFile, ExternArgs, - PgAggregate, PgExtern, PostgresEnum, PostgresType, Schema, + PgAggregate, PgExtern, PostgresEnum, Schema, }; use crate::rewriter::PgGuardRewriter; @@ -709,7 +710,16 @@ Optionally accepts the following attributes: * `pgvarlena_inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the `PgVarlena` of this type. * `sql`: Same arguments as [`#[pgrx(sql = ..)]`](macro@pgrx). */ -#[proc_macro_derive(PostgresType, attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, pgrx))] +#[proc_macro_derive( + PostgresType, + attributes( + inoutfuncs, + pgvarlena_inoutfuncs, + bikeshed_postgres_type_manually_impl_from_into_datum, + requires, + pgrx + ) +)] pub fn postgres_type(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); @@ -752,10 +762,60 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result }; // all #[derive(PostgresType)] need to implement that trait + // and also the FromDatum and IntoDatum stream.extend(quote! { - impl #generics ::pgrx::PostgresType for #name #generics { } + impl #generics ::pgrx::datum::PostgresType for #name #generics { } }); + if !args.contains(&PostgresTypeAttribute::ManualFromIntoDatum) { + stream.extend( + quote! { + impl #generics ::pgrx::datum::IntoDatum for #name #generics { + fn into_datum(self) -> Option<::pgrx::pg_sys::Datum> { + #[allow(deprecated)] + Some(unsafe { ::pgrx::cbor_encode(&self) }.into()) + } + + fn type_oid() -> ::pgrx::pg_sys::Oid { + ::pgrx::wrappers::rust_regtypein::() + } + } + + impl #generics ::pgrx::datum::FromDatum for #name #generics { + unsafe fn from_polymorphic_datum( + datum: ::pgrx::pg_sys::Datum, + is_null: bool, + _typoid: ::pgrx::pg_sys::Oid, + ) -> Option { + if is_null { + None + } else { + #[allow(deprecated)] + ::pgrx::cbor_decode(datum.cast_mut_ptr()) + } + } + + unsafe fn from_datum_in_memory_context( + mut memory_context: ::pgrx::memcxt::PgMemoryContexts, + datum: ::pgrx::pg_sys::Datum, + is_null: bool, + _typoid: ::pgrx::pg_sys::Oid, + ) -> Option { + if is_null { + None + } else { + memory_context.switch_to(|_| { + // this gets the varlena Datum copied into this memory context + let varlena = ::pgrx::pg_sys::pg_detoast_datum_copy(datum.cast_mut_ptr()); + Self::from_datum(varlena.into(), is_null) + }) + } + } + } + } + ) + } + // and if we don't have custom inout/funcs, we use the JsonInOutFuncs trait // which implements _in and _out #[pg_extern] functions that just return the type itself if args.contains(&PostgresTypeAttribute::Default) { @@ -834,7 +894,7 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result }); } - let sql_graph_entity_item = PostgresType::from_derive_input(ast)?; + let sql_graph_entity_item = sql_gen::PostgresTypeDerive::from_derive_input(ast)?; sql_graph_entity_item.to_tokens(&mut stream); Ok(stream) @@ -933,6 +993,7 @@ enum PostgresTypeAttribute { InOutFuncs, PgVarlenaInOutFuncs, Default, + ManualFromIntoDatum, } fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet { @@ -945,11 +1006,12 @@ fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet { categorized_attributes.insert(PostgresTypeAttribute::InOutFuncs); } - "pgvarlena_inoutfuncs" => { categorized_attributes.insert(PostgresTypeAttribute::PgVarlenaInOutFuncs); } - + "bikeshed_postgres_type_manually_impl_from_into_datum" => { + categorized_attributes.insert(PostgresTypeAttribute::ManualFromIntoDatum); + } _ => { // we can just ignore attributes we don't understand } diff --git a/pgrx-sql-entity-graph/src/lib.rs b/pgrx-sql-entity-graph/src/lib.rs index 69a4d29fd..b30f288ce 100644 --- a/pgrx-sql-entity-graph/src/lib.rs +++ b/pgrx-sql-entity-graph/src/lib.rs @@ -42,7 +42,7 @@ pub use postgres_hash::PostgresHash; pub use postgres_ord::entity::PostgresOrdEntity; pub use postgres_ord::PostgresOrd; pub use postgres_type::entity::PostgresTypeEntity; -pub use postgres_type::PostgresType; +pub use postgres_type::PostgresTypeDerive; pub use schema::entity::SchemaEntity; pub use schema::Schema; pub use to_sql::entity::ToSqlConfigEntity; diff --git a/pgrx-sql-entity-graph/src/postgres_type/mod.rs b/pgrx-sql-entity-graph/src/postgres_type/mod.rs index d982e2b33..da3ad8765 100644 --- a/pgrx-sql-entity-graph/src/postgres_type/mod.rs +++ b/pgrx-sql-entity-graph/src/postgres_type/mod.rs @@ -34,11 +34,11 @@ use crate::{CodeEnrichment, ToSqlConfig}; /// ```rust /// use syn::{Macro, parse::Parse, parse_quote, parse}; /// use quote::{quote, ToTokens}; -/// use pgrx_sql_entity_graph::PostgresType; +/// use pgrx_sql_entity_graph::PostgresTypeDerive; /// /// # fn main() -> eyre::Result<()> { /// use pgrx_sql_entity_graph::CodeEnrichment; -/// let parsed: CodeEnrichment = parse_quote! { +/// let parsed: CodeEnrichment = parse_quote! { /// #[derive(PostgresType)] /// struct Example<'a> { /// demo: &'a str, @@ -49,7 +49,7 @@ use crate::{CodeEnrichment, ToSqlConfig}; /// # } /// ``` #[derive(Debug, Clone)] -pub struct PostgresType { +pub struct PostgresTypeDerive { name: Ident, generics: Generics, in_fn: Ident, @@ -57,7 +57,7 @@ pub struct PostgresType { to_sql_config: ToSqlConfig, } -impl PostgresType { +impl PostgresTypeDerive { pub fn new( name: Ident, generics: Generics, @@ -100,7 +100,7 @@ impl PostgresType { } } -impl ToEntityGraphTokens for PostgresType { +impl ToEntityGraphTokens for PostgresTypeDerive { fn to_entity_graph_tokens(&self) -> TokenStream2 { let name = &self.name; let mut static_generics = self.generics.clone(); @@ -211,17 +211,14 @@ impl ToEntityGraphTokens for PostgresType { } } -impl ToRustCodeTokens for PostgresType {} +impl ToRustCodeTokens for PostgresTypeDerive {} -impl Parse for CodeEnrichment { +impl Parse for CodeEnrichment { fn parse(input: ParseStream) -> Result { - let parsed: ItemStruct = input.parse()?; - let to_sql_config = - ToSqlConfig::from_attributes(parsed.attrs.as_slice())?.unwrap_or_default(); - let funcname_in = - Ident::new(&format!("{}_in", parsed.ident).to_lowercase(), parsed.ident.span()); - let funcname_out = - Ident::new(&format!("{}_out", parsed.ident).to_lowercase(), parsed.ident.span()); - PostgresType::new(parsed.ident, parsed.generics, funcname_in, funcname_out, to_sql_config) + let ItemStruct { attrs, ident, generics, .. } = input.parse()?; + let to_sql_config = ToSqlConfig::from_attributes(attrs.as_slice())?.unwrap_or_default(); + let in_fn = Ident::new(&format!("{}_in", ident).to_lowercase(), ident.span()); + let out_fn = Ident::new(&format!("{}_out", ident).to_lowercase(), ident.span()); + PostgresTypeDerive::new(ident, generics, in_fn, out_fn, to_sql_config) } } diff --git a/pgrx-tests/src/tests/postgres_type_tests.rs b/pgrx-tests/src/tests/postgres_type_tests.rs index 56262a5cd..46d25e4c7 100644 --- a/pgrx-tests/src/tests/postgres_type_tests.rs +++ b/pgrx-tests/src/tests/postgres_type_tests.rs @@ -13,7 +13,7 @@ use pgrx::{InOutFuncs, PgVarlena, PgVarlenaInOutFuncs, StringInfo}; use serde::{Deserialize, Serialize}; use std::str::FromStr; -#[derive(Copy, Clone, PostgresType)] +#[derive(Copy, Clone, PostgresType, Serialize, Deserialize)] #[pgvarlena_inoutfuncs] pub struct VarlenaType { a: f32, @@ -38,7 +38,7 @@ impl PgVarlenaInOutFuncs for VarlenaType { } } -#[derive(Copy, Clone, PostgresType)] +#[derive(Copy, Clone, PostgresType, Serialize, Deserialize)] #[pgvarlena_inoutfuncs] pub enum VarlenaEnumType { A, diff --git a/pgrx/src/datum/varlena.rs b/pgrx/src/datum/varlena.rs index a50f8401e..6af244da8 100644 --- a/pgrx/src/datum/varlena.rs +++ b/pgrx/src/datum/varlena.rs @@ -11,8 +11,7 @@ use crate::pg_sys::{VARATT_SHORT_MAX, VARHDRSZ_SHORT}; use crate::{ pg_sys, rust_regtypein, set_varsize, set_varsize_short, vardata_any, varsize_any, - varsize_any_exhdr, void_mut_ptr, FromDatum, IntoDatum, PgMemoryContexts, PostgresType, - StringInfo, + varsize_any_exhdr, void_mut_ptr, FromDatum, IntoDatum, PgMemoryContexts, StringInfo, }; use pgrx_sql_entity_graph::metadata::{ ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable, @@ -60,8 +59,9 @@ impl Clone for PallocdVarlena { /// use std::str::FromStr; /// /// use pgrx::prelude::*; +/// use serde::{Serialize, Deserialize}; /// -/// #[derive(Copy, Clone, PostgresType)] +/// #[derive(Copy, Clone, PostgresType, Serialize, Deserialize)] /// #[pgvarlena_inoutfuncs] /// struct MyType { /// a: f32, @@ -378,50 +378,8 @@ where } } -impl IntoDatum for T -where - T: PostgresType + Serialize, -{ - fn into_datum(self) -> Option { - Some(cbor_encode(&self).into()) - } - - fn type_oid() -> pg_sys::Oid { - crate::rust_regtypein::() - } -} - -impl<'de, T> FromDatum for T -where - T: PostgresType + Deserialize<'de>, -{ - unsafe fn from_polymorphic_datum( - datum: pg_sys::Datum, - is_null: bool, - _typoid: pg_sys::Oid, - ) -> Option { - if is_null { - None - } else { - cbor_decode(datum.cast_mut_ptr()) - } - } - - unsafe fn from_datum_in_memory_context( - memory_context: PgMemoryContexts, - datum: pg_sys::Datum, - is_null: bool, - _typoid: pg_sys::Oid, - ) -> Option { - if is_null { - None - } else { - cbor_decode_into_context(memory_context, datum.cast_mut_ptr()) - } - } -} - -fn cbor_encode(input: T) -> *const pg_sys::varlena +#[doc(hidden)] +pub unsafe fn cbor_encode(input: T) -> *const pg_sys::varlena where T: Serialize, { @@ -439,6 +397,7 @@ where varlena as *const pg_sys::varlena } +#[doc(hidden)] pub unsafe fn cbor_decode<'de, T>(input: *mut pg_sys::varlena) -> T where T: Deserialize<'de>, @@ -450,6 +409,8 @@ where serde_cbor::from_slice(slice).expect("failed to decode CBOR") } +#[doc(hidden)] +#[deprecated(since = "0.12.0", note = "just use the FromDatum impl")] pub unsafe fn cbor_decode_into_context<'de, T>( mut memory_context: PgMemoryContexts, input: *mut pg_sys::varlena, @@ -464,37 +425,6 @@ where }) } -#[allow(dead_code)] -fn json_encode(input: T) -> *const pg_sys::varlena -where - T: Serialize, -{ - let mut serialized = StringInfo::new(); - - serialized.push_bytes(&[0u8; pg_sys::VARHDRSZ]); // reserve space for the header - serde_json::to_writer(&mut serialized, &input).expect("failed to encode as JSON"); - - let size = serialized.len(); - let varlena = serialized.into_char_ptr(); - unsafe { - set_varsize(varlena as *mut pg_sys::varlena, size as i32); - } - - varlena as *const pg_sys::varlena -} - -#[allow(dead_code)] -unsafe fn json_decode<'de, T>(input: *mut pg_sys::varlena) -> T -where - T: Deserialize<'de>, -{ - let varlena = pg_sys::pg_detoast_datum_packed(input as *mut pg_sys::varlena); - let len = varsize_any_exhdr(varlena); - let data = vardata_any(varlena); - let slice = std::slice::from_raw_parts(data as *const u8, len); - serde_json::from_slice(slice).expect("failed to decode JSON") -} - unsafe impl SqlTranslatable for PgVarlena where T: SqlTranslatable + Copy,