Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make #[derive(PostgresType)] impl its own FromDatum #1381

1 change: 1 addition & 0 deletions pgrx-examples/custom_types/src/fixed_size.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use pgrx::{opname, pg_operator, PgVarlena, PgVarlenaInOutFuncs, StringInfo};
use std::str::FromStr;

#[derive(Copy, Clone, PostgresType)]
#[bikeshed_postgres_type_manually_impl_from_into_datum]
#[pgvarlena_inoutfuncs]
pub struct FixedF32Array {
array: [f32; 91],
Expand Down
76 changes: 69 additions & 7 deletions pgrx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ use syn::spanned::Spanned;
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Item, ItemImpl};

use operators::{deriving_postgres_eq, deriving_postgres_hash, deriving_postgres_ord};
use pgrx_sql_entity_graph::{
use pgrx_sql_entity_graph as sql_gen;
use sql_gen::{
parse_extern_attributes, CodeEnrichment, ExtensionSql, ExtensionSqlFile, ExternArgs,
PgAggregate, PgExtern, PostgresEnum, PostgresType, Schema,
PgAggregate, PgExtern, PostgresEnum, Schema,
};

use crate::rewriter::PgGuardRewriter;
Expand Down Expand Up @@ -709,7 +710,16 @@ Optionally accepts the following attributes:
* `pgvarlena_inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the `PgVarlena` of this type.
* `sql`: Same arguments as [`#[pgrx(sql = ..)]`](macro@pgrx).
*/
#[proc_macro_derive(PostgresType, attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, pgrx))]
#[proc_macro_derive(
PostgresType,
attributes(
inoutfuncs,
pgvarlena_inoutfuncs,
bikeshed_postgres_type_manually_impl_from_into_datum,
requires,
pgrx
)
)]
pub fn postgres_type(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);

Expand Down Expand Up @@ -752,10 +762,60 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream>
};

// all #[derive(PostgresType)] need to implement that trait
// and also the FromDatum and IntoDatum
stream.extend(quote! {
impl #generics ::pgrx::PostgresType for #name #generics { }
impl #generics ::pgrx::datum::PostgresType for #name #generics { }
});

if !args.contains(&PostgresTypeAttribute::ManualFromIntoDatum) {
stream.extend(
quote! {
impl #generics ::pgrx::datum::IntoDatum for #name #generics {
fn into_datum(self) -> Option<::pgrx::pg_sys::Datum> {
#[allow(deprecated)]
Some(unsafe { ::pgrx::cbor_encode(&self) }.into())
}

fn type_oid() -> ::pgrx::pg_sys::Oid {
::pgrx::wrappers::rust_regtypein::<Self>()
}
}

impl #generics ::pgrx::datum::FromDatum for #name #generics {
unsafe fn from_polymorphic_datum(
datum: ::pgrx::pg_sys::Datum,
is_null: bool,
_typoid: ::pgrx::pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
#[allow(deprecated)]
::pgrx::cbor_decode(datum.cast_mut_ptr())
}
}

unsafe fn from_datum_in_memory_context(
mut memory_context: ::pgrx::memcxt::PgMemoryContexts,
datum: ::pgrx::pg_sys::Datum,
is_null: bool,
_typoid: ::pgrx::pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
memory_context.switch_to(|_| {
// this gets the varlena Datum copied into this memory context
let varlena = ::pgrx::pg_sys::pg_detoast_datum_copy(datum.cast_mut_ptr());
Self::from_datum(varlena.into(), is_null)
})
}
}
}
}
)
}

// and if we don't have custom inout/funcs, we use the JsonInOutFuncs trait
// which implements _in and _out #[pg_extern] functions that just return the type itself
if args.contains(&PostgresTypeAttribute::Default) {
Expand Down Expand Up @@ -834,7 +894,7 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream>
});
}

let sql_graph_entity_item = PostgresType::from_derive_input(ast)?;
let sql_graph_entity_item = sql_gen::PostgresTypeDerive::from_derive_input(ast)?;
sql_graph_entity_item.to_tokens(&mut stream);

Ok(stream)
Expand Down Expand Up @@ -933,6 +993,7 @@ enum PostgresTypeAttribute {
InOutFuncs,
PgVarlenaInOutFuncs,
Default,
ManualFromIntoDatum,
}

fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet<PostgresTypeAttribute> {
Expand All @@ -945,11 +1006,12 @@ fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet<PostgresTypeAtt
"inoutfuncs" => {
categorized_attributes.insert(PostgresTypeAttribute::InOutFuncs);
}

"pgvarlena_inoutfuncs" => {
categorized_attributes.insert(PostgresTypeAttribute::PgVarlenaInOutFuncs);
}

"bikeshed_postgres_type_manually_impl_from_into_datum" => {
categorized_attributes.insert(PostgresTypeAttribute::ManualFromIntoDatum);
}
_ => {
// we can just ignore attributes we don't understand
}
Expand Down
2 changes: 1 addition & 1 deletion pgrx-sql-entity-graph/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub use postgres_hash::PostgresHash;
pub use postgres_ord::entity::PostgresOrdEntity;
pub use postgres_ord::PostgresOrd;
pub use postgres_type::entity::PostgresTypeEntity;
pub use postgres_type::PostgresType;
pub use postgres_type::PostgresTypeDerive;
pub use schema::entity::SchemaEntity;
pub use schema::Schema;
pub use to_sql::entity::ToSqlConfigEntity;
Expand Down
27 changes: 12 additions & 15 deletions pgrx-sql-entity-graph/src/postgres_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ use crate::{CodeEnrichment, ToSqlConfig};
/// ```rust
/// use syn::{Macro, parse::Parse, parse_quote, parse};
/// use quote::{quote, ToTokens};
/// use pgrx_sql_entity_graph::PostgresType;
/// use pgrx_sql_entity_graph::PostgresTypeDerive;
///
/// # fn main() -> eyre::Result<()> {
/// use pgrx_sql_entity_graph::CodeEnrichment;
/// let parsed: CodeEnrichment<PostgresType> = parse_quote! {
/// let parsed: CodeEnrichment<PostgresTypeDerive> = parse_quote! {
/// #[derive(PostgresType)]
/// struct Example<'a> {
/// demo: &'a str,
Expand All @@ -49,15 +49,15 @@ use crate::{CodeEnrichment, ToSqlConfig};
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct PostgresType {
pub struct PostgresTypeDerive {
name: Ident,
generics: Generics,
in_fn: Ident,
out_fn: Ident,
to_sql_config: ToSqlConfig,
}

impl PostgresType {
impl PostgresTypeDerive {
pub fn new(
name: Ident,
generics: Generics,
Expand Down Expand Up @@ -100,7 +100,7 @@ impl PostgresType {
}
}

impl ToEntityGraphTokens for PostgresType {
impl ToEntityGraphTokens for PostgresTypeDerive {
fn to_entity_graph_tokens(&self) -> TokenStream2 {
let name = &self.name;
let mut static_generics = self.generics.clone();
Expand Down Expand Up @@ -211,17 +211,14 @@ impl ToEntityGraphTokens for PostgresType {
}
}

impl ToRustCodeTokens for PostgresType {}
impl ToRustCodeTokens for PostgresTypeDerive {}

impl Parse for CodeEnrichment<PostgresType> {
impl Parse for CodeEnrichment<PostgresTypeDerive> {
fn parse(input: ParseStream) -> Result<Self, syn::Error> {
let parsed: ItemStruct = input.parse()?;
let to_sql_config =
ToSqlConfig::from_attributes(parsed.attrs.as_slice())?.unwrap_or_default();
let funcname_in =
Ident::new(&format!("{}_in", parsed.ident).to_lowercase(), parsed.ident.span());
let funcname_out =
Ident::new(&format!("{}_out", parsed.ident).to_lowercase(), parsed.ident.span());
PostgresType::new(parsed.ident, parsed.generics, funcname_in, funcname_out, to_sql_config)
let ItemStruct { attrs, ident, generics, .. } = input.parse()?;
let to_sql_config = ToSqlConfig::from_attributes(attrs.as_slice())?.unwrap_or_default();
let in_fn = Ident::new(&format!("{}_in", ident).to_lowercase(), ident.span());
let out_fn = Ident::new(&format!("{}_out", ident).to_lowercase(), ident.span());
PostgresTypeDerive::new(ident, generics, in_fn, out_fn, to_sql_config)
}
}
4 changes: 2 additions & 2 deletions pgrx-tests/src/tests/postgres_type_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use pgrx::{InOutFuncs, PgVarlena, PgVarlenaInOutFuncs, StringInfo};
use serde::{Deserialize, Serialize};
use std::str::FromStr;

#[derive(Copy, Clone, PostgresType)]
#[derive(Copy, Clone, PostgresType, Serialize, Deserialize)]
#[pgvarlena_inoutfuncs]
pub struct VarlenaType {
a: f32,
Expand All @@ -38,7 +38,7 @@ impl PgVarlenaInOutFuncs for VarlenaType {
}
}

#[derive(Copy, Clone, PostgresType)]
#[derive(Copy, Clone, PostgresType, Serialize, Deserialize)]
#[pgvarlena_inoutfuncs]
pub enum VarlenaEnumType {
A,
Expand Down
86 changes: 8 additions & 78 deletions pgrx/src/datum/varlena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
use crate::pg_sys::{VARATT_SHORT_MAX, VARHDRSZ_SHORT};
use crate::{
pg_sys, rust_regtypein, set_varsize, set_varsize_short, vardata_any, varsize_any,
varsize_any_exhdr, void_mut_ptr, FromDatum, IntoDatum, PgMemoryContexts, PostgresType,
StringInfo,
varsize_any_exhdr, void_mut_ptr, FromDatum, IntoDatum, PgMemoryContexts, StringInfo,
};
use pgrx_sql_entity_graph::metadata::{
ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable,
Expand Down Expand Up @@ -60,8 +59,9 @@ impl Clone for PallocdVarlena {
/// use std::str::FromStr;
///
/// use pgrx::prelude::*;
/// use serde::{Serialize, Deserialize};
///
/// #[derive(Copy, Clone, PostgresType)]
/// #[derive(Copy, Clone, PostgresType, Serialize, Deserialize)]
/// #[pgvarlena_inoutfuncs]
/// struct MyType {
/// a: f32,
Expand Down Expand Up @@ -378,50 +378,8 @@ where
}
}

impl<T> IntoDatum for T
where
T: PostgresType + Serialize,
{
fn into_datum(self) -> Option<pg_sys::Datum> {
Some(cbor_encode(&self).into())
}

fn type_oid() -> pg_sys::Oid {
crate::rust_regtypein::<T>()
}
}

impl<'de, T> FromDatum for T
where
T: PostgresType + Deserialize<'de>,
{
unsafe fn from_polymorphic_datum(
datum: pg_sys::Datum,
is_null: bool,
_typoid: pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
cbor_decode(datum.cast_mut_ptr())
}
}

unsafe fn from_datum_in_memory_context(
memory_context: PgMemoryContexts,
datum: pg_sys::Datum,
is_null: bool,
_typoid: pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
cbor_decode_into_context(memory_context, datum.cast_mut_ptr())
}
}
}

fn cbor_encode<T>(input: T) -> *const pg_sys::varlena
#[doc(hidden)]
pub unsafe fn cbor_encode<T>(input: T) -> *const pg_sys::varlena
where
T: Serialize,
{
Expand All @@ -439,6 +397,7 @@ where
varlena as *const pg_sys::varlena
}

#[doc(hidden)]
pub unsafe fn cbor_decode<'de, T>(input: *mut pg_sys::varlena) -> T
where
T: Deserialize<'de>,
Expand All @@ -450,6 +409,8 @@ where
serde_cbor::from_slice(slice).expect("failed to decode CBOR")
}

#[doc(hidden)]
#[deprecated(since = "0.12.0", note = "just use the FromDatum impl")]
pub unsafe fn cbor_decode_into_context<'de, T>(
mut memory_context: PgMemoryContexts,
input: *mut pg_sys::varlena,
Expand All @@ -464,37 +425,6 @@ where
})
}

#[allow(dead_code)]
fn json_encode<T>(input: T) -> *const pg_sys::varlena
where
T: Serialize,
{
let mut serialized = StringInfo::new();

serialized.push_bytes(&[0u8; pg_sys::VARHDRSZ]); // reserve space for the header
serde_json::to_writer(&mut serialized, &input).expect("failed to encode as JSON");

let size = serialized.len();
let varlena = serialized.into_char_ptr();
unsafe {
set_varsize(varlena as *mut pg_sys::varlena, size as i32);
}

varlena as *const pg_sys::varlena
}

#[allow(dead_code)]
unsafe fn json_decode<'de, T>(input: *mut pg_sys::varlena) -> T
where
T: Deserialize<'de>,
{
let varlena = pg_sys::pg_detoast_datum_packed(input as *mut pg_sys::varlena);
let len = varsize_any_exhdr(varlena);
let data = vardata_any(varlena);
let slice = std::slice::from_raw_parts(data as *const u8, len);
serde_json::from_slice(slice).expect("failed to decode JSON")
}

workingjubilee marked this conversation as resolved.
Show resolved Hide resolved
unsafe impl<T> SqlTranslatable for PgVarlena<T>
where
T: SqlTranslatable + Copy,
Expand Down