diff --git a/pgx-macros/src/lib.rs b/pgx-macros/src/lib.rs index c5e7b3e891..92179d9073 100644 --- a/pgx-macros/src/lib.rs +++ b/pgx-macros/src/lib.rs @@ -402,6 +402,7 @@ Optionally accepts the following attributes: * `parallel_unsafe`: Corresponds to [`PARALLEL UNSAFE`](https://www.postgresql.org/docs/current/sql-createfunction.html). * `parallel_restricted`: Corresponds to [`PARALLEL RESTRICTED`](https://www.postgresql.org/docs/current/sql-createfunction.html). * `no_guard`: Do not use `#[pg_guard]` with the function. +* `sql`: Same arguments as `#[pgx(sql = ..)]` Functions can accept and return any type which `pgx` supports. `pgx` supports many PostgreSQL types by default. New types can be defined via [`macro@PostgresType`] or [`macro@PostgresEnum`]. @@ -919,7 +920,9 @@ enum DogNames { #[proc_macro_derive(PostgresEq, attributes(pgx))] pub fn postgres_eq(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); - impl_postgres_eq(ast).into() + impl_postgres_eq(ast) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } /** @@ -943,16 +946,9 @@ enum DogNames { #[proc_macro_derive(PostgresOrd, attributes(pgx))] pub fn postgres_ord(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); - let to_sql_config = match sql_entity_graph::ToSqlConfig::from_attributes(ast.attrs.as_slice()) { - Err(e) => { - let msg = e.to_string(); - return TokenStream::from(quote! { - compile_error!(#msg); - }); - } - Ok(maybe_conf) => maybe_conf.unwrap_or_default(), - }; - impl_postgres_ord(ast, to_sql_config).into() + impl_postgres_ord(ast) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } /** @@ -973,16 +969,9 @@ enum DogNames { #[proc_macro_derive(PostgresHash, attributes(pgx))] pub fn postgres_hash(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); - let to_sql_config = match sql_entity_graph::ToSqlConfig::from_attributes(ast.attrs.as_slice()) { - Err(e) => { - let msg = e.to_string(); - return TokenStream::from(quote! { - compile_error!(#msg); - }); - } - Ok(maybe_conf) => maybe_conf.unwrap_or_default(), - }; - impl_postgres_hash(ast, to_sql_config).into() + impl_postgres_hash(ast) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } /** diff --git a/pgx-macros/src/operators.rs b/pgx-macros/src/operators.rs index a46f67942a..fb17c44474 100644 --- a/pgx-macros/src/operators.rs +++ b/pgx-macros/src/operators.rs @@ -1,20 +1,18 @@ use pgx_utils::{operator_common::*, sql_entity_graph}; + use quote::ToTokens; use syn::DeriveInput; -pub(crate) fn impl_postgres_eq(ast: DeriveInput) -> proc_macro2::TokenStream { +pub(crate) fn impl_postgres_eq(ast: DeriveInput) -> syn::Result { let mut stream = proc_macro2::TokenStream::new(); stream.extend(eq(&ast.ident)); stream.extend(ne(&ast.ident)); - stream + Ok(stream) } -pub(crate) fn impl_postgres_ord( - ast: DeriveInput, - to_sql_config: sql_entity_graph::ToSqlConfig, -) -> proc_macro2::TokenStream { +pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> syn::Result { let mut stream = proc_macro2::TokenStream::new(); stream.extend(lt(&ast.ident)); @@ -23,24 +21,19 @@ pub(crate) fn impl_postgres_ord( stream.extend(ge(&ast.ident)); stream.extend(cmp(&ast.ident)); - let sql_graph_entity_item = - sql_entity_graph::PostgresOrd::new(ast.ident.clone(), to_sql_config); + let sql_graph_entity_item = sql_entity_graph::PostgresOrd::from_derive_input(ast)?; sql_graph_entity_item.to_tokens(&mut stream); - stream + Ok(stream) } -pub(crate) fn impl_postgres_hash( - ast: DeriveInput, - to_sql_config: sql_entity_graph::ToSqlConfig, -) -> proc_macro2::TokenStream { +pub(crate) fn impl_postgres_hash(ast: DeriveInput) -> syn::Result { let mut stream = proc_macro2::TokenStream::new(); stream.extend(hash(&ast.ident)); - let sql_graph_entity_item = - sql_entity_graph::PostgresHash::new(ast.ident.clone(), to_sql_config); + let sql_graph_entity_item = sql_entity_graph::PostgresHash::from_derive_input(ast)?; sql_graph_entity_item.to_tokens(&mut stream); - stream + Ok(stream) } diff --git a/pgx-tests/src/tests/schema_tests.rs b/pgx-tests/src/tests/schema_tests.rs index 592963fe06..e345f86ba4 100644 --- a/pgx-tests/src/tests/schema_tests.rs +++ b/pgx-tests/src/tests/schema_tests.rs @@ -11,12 +11,10 @@ mod test_schema { #[pg_extern] fn func_in_diff_schema() {} - #[pg_extern] - #[pgx(sql = false)] + #[pg_extern(sql = false)] fn func_elided_from_schema() {} - #[pg_extern] - #[pgx(sql = "generate_function")] + #[pg_extern(sql = "generate_function")] fn func_generated_with_custom_sql() {} #[derive(Debug, PostgresType, Serialize, Deserialize)] diff --git a/pgx-utils/src/sql_entity_graph/pg_extern/attribute.rs b/pgx-utils/src/sql_entity_graph/pg_extern/attribute.rs index 5e932acfe5..ada2c21cd2 100644 --- a/pgx-utils/src/sql_entity_graph/pg_extern/attribute.rs +++ b/pgx-utils/src/sql_entity_graph/pg_extern/attribute.rs @@ -1,4 +1,4 @@ -use crate::sql_entity_graph::PositioningRef; +use crate::sql_entity_graph::{PositioningRef, ToSqlConfig}; use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::{ @@ -22,7 +22,17 @@ impl Parse for PgxAttributes { impl ToTokens for PgxAttributes { fn to_tokens(&self, tokens: &mut TokenStream2) { - let attrs = &self.attrs; + let attrs = self + .attrs + .iter() + .filter(|a| { + if let Attribute::Sql(_) = a { + false + } else { + true + } + }) + .collect::>(); let quoted = quote! { vec![#attrs] }; @@ -46,6 +56,7 @@ pub enum Attribute { Name(syn::LitStr), Cost(syn::Expr), Requires(Punctuated), + Sql(ToSqlConfig), } impl ToTokens for Attribute { @@ -85,6 +96,10 @@ impl ToTokens for Attribute { .collect::>(); quote! { pgx::datum::sql_entity_graph::ExternArgs::Requires(vec![#(#items_iter),*],) } } + // This attribute is handled separately + Attribute::Sql(_) => { + return; + } }; tokens.append_all(quoted); } @@ -129,6 +144,15 @@ impl Parse for Attribute { let _bracket = syn::bracketed!(content in input); Self::Requires(content.parse_terminated(PositioningRef::parse)?) } + "sql" => { + let _eq: Token![=] = input.parse()?; + if let Ok(b) = input.parse::() { + Self::Sql(ToSqlConfig::from(b.value)) + } else { + let sql = input.parse::()?; + Self::Sql(ToSqlConfig::from(sql)) + } + } _ => return Err(syn::Error::new(Span::call_site(), "Invalid option")), }; Ok(found) diff --git a/pgx-utils/src/sql_entity_graph/pg_extern/mod.rs b/pgx-utils/src/sql_entity_graph/pg_extern/mod.rs index d4a1ade04f..5f30995820 100644 --- a/pgx-utils/src/sql_entity_graph/pg_extern/mod.rs +++ b/pgx-utils/src/sql_entity_graph/pg_extern/mod.rs @@ -197,9 +197,18 @@ impl PgExtern { pub fn new(attr: TokenStream2, item: TokenStream2) -> Result { let attrs = syn::parse2::(attr.clone()).ok(); + let to_sql_config = attrs + .as_ref() + .and_then(|pgx_attrs| { + for a in pgx_attrs.attrs.iter() { + if let Attribute::Sql(config) = a { + return Some(config.clone()); + } + } + None + }) + .unwrap_or_default(); let func = syn::parse2::(item)?; - let to_sql_config = - ToSqlConfig::from_attributes(func.attrs.as_slice())?.unwrap_or_default(); Ok(Self { attrs, attr_tokens: attr, diff --git a/pgx-utils/src/sql_entity_graph/to_sql.rs b/pgx-utils/src/sql_entity_graph/to_sql.rs index d413bd816e..a8afa43d70 100644 --- a/pgx-utils/src/sql_entity_graph/to_sql.rs +++ b/pgx-utils/src/sql_entity_graph/to_sql.rs @@ -1,14 +1,42 @@ +use std::hash::Hash; + use proc_macro2::TokenStream as TokenStream2; use quote::{quote, ToTokens, TokenStreamExt}; use syn::spanned::Spanned; use syn::{AttrStyle, Attribute, Lit, Meta, MetaList, MetaNameValue, NestedMeta}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ToSqlConfig { pub enabled: bool, pub callback: Option, pub content: Option, } +impl From for ToSqlConfig { + fn from(enabled: bool) -> Self { + Self { + enabled, + callback: None, + content: None, + } + } +} +impl From for ToSqlConfig { + fn from(content: syn::LitStr) -> Self { + if let Ok(path) = content.parse::() { + return Self { + enabled: true, + callback: Some(path), + content: None, + }; + } else { + return Self { + enabled: true, + callback: None, + content: Some(content), + }; + } + } +} impl Default for ToSqlConfig { fn default() -> Self { Self {