Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eliminate tuple udf #686

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 9 additions & 16 deletions pgx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub fn pg_guard(_attr: TokenStream, item: TokenStream) -> TokenStream {
// process top-level functions
// these functions get wrapped as public extern "C" functions with #[no_mangle] so they
// can also be called from C code
Item::Fn(func) => rewriter.item_fn(func, None, false, false, false).0.into(),
Item::Fn(func) => rewriter.item_fn(func, None, false, false, false).into(),
_ => {
panic!("#[pg_guard] can only be applied to extern \"C\" blocks and top-level functions")
}
Expand Down Expand Up @@ -589,30 +589,23 @@ fn rewrite_item_fn(
// make the function 'extern "C"' because this is for the #[pg_extern[ macro
func.sig.abi = Some(syn::parse_str("extern \"C\"").unwrap());
let func_span = func.span();
let (rewritten_func, need_wrapper) = rewriter.item_fn(
let rewritten_func = rewriter.item_fn(
func,
Some(sql_graph_entity_submission),
true,
is_raw,
no_guard,
);

if need_wrapper {
quote_spanned! {func_span=>
#[no_mangle]
#[doc(hidden)]
pub extern "C" fn #finfo_name() -> &'static pg_sys::Pg_finfo_record {
const V1_API: pg_sys::Pg_finfo_record = pg_sys::Pg_finfo_record { api_version: 1 };
&V1_API
}

#rewritten_func
quote_spanned! {func_span=>
#[no_mangle]
#[doc(hidden)]
pub extern "C" fn #finfo_name() -> &'static pg_sys::Pg_finfo_record {
const V1_API: pg_sys::Pg_finfo_record = pg_sys::Pg_finfo_record { api_version: 1 };
&V1_API
}
} else {
quote_spanned! {func_span=>

#rewritten_func
}
#rewritten_func
}
}

Expand Down
138 changes: 51 additions & 87 deletions pgx-utils/src/rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,11 @@ impl PgGuardRewriter {
rewrite_args: bool,
is_raw: bool,
no_guard: bool,
) -> (proc_macro2::TokenStream, bool) {
) -> proc_macro2::TokenStream {
if rewrite_args {
self.item_fn_with_rewrite(func, entity_submission, is_raw, no_guard)
} else {
(
self.item_fn_without_rewrite(func, entity_submission, no_guard),
true,
)
self.item_fn_without_rewrite(func, entity_submission, no_guard)
}
}

Expand All @@ -63,7 +60,7 @@ impl PgGuardRewriter {
entity_submission: Option<&PgExtern>,
is_raw: bool,
no_guard: bool,
) -> (proc_macro2::TokenStream, bool) {
) -> proc_macro2::TokenStream {
// remember the original visibility and signature classifications as we want
// to use those for the outer function
let vis = func.vis.clone();
Expand Down Expand Up @@ -92,12 +89,26 @@ impl PgGuardRewriter {
Ident::new("result", Span::call_site())
};

let func_call = quote! {
let #result_var_name = {
#rewritten_args
let return_type_kind = categorize_return_type(&func);

// When returning a single tuple, rewrite the function to actually return a single-item iterator, so that
// the general-purpose "return a table" function can be used.
let func_call = if let CategorizedType::Tuple(ref _types) = return_type_kind {
quote! {
let #result_var_name = {
#rewritten_args

Some(#func_name(#arg_list)).into_iter()
};
}
} else {
quote! {
let #result_var_name = {
#rewritten_args

#func_name(#arg_list)
};
#func_name(#arg_list)
};
}
};

let prolog = quote! {
Expand All @@ -106,28 +117,20 @@ impl PgGuardRewriter {
#[doc(hidden)]
#[allow(unused_variables)]
};
match categorize_return_type(&func) {
CategorizedType::Default => (
PgGuardRewriter::impl_standard_udf(
func_span,
prolog,
vis,
func_name_wrapper,
generics,
func_call,
rewritten_return_type,
entity_submission,
no_guard,
),
true,
),

CategorizedType::Tuple(_types) => (
PgGuardRewriter::impl_tuple_udf(func, entity_submission.clone()),
false,
match return_type_kind {
CategorizedType::Default => PgGuardRewriter::impl_standard_udf(
func_span,
prolog,
vis,
func_name_wrapper,
generics,
func_call,
rewritten_return_type,
entity_submission,
no_guard,
),

CategorizedType::Iterator(types) if types.len() == 1 => (
CategorizedType::Iterator(types) if types.len() == 1 => {
PgGuardRewriter::impl_setof_srf(
types,
func_span,
Expand All @@ -138,11 +141,10 @@ impl PgGuardRewriter {
func_call,
entity_submission,
false,
),
true,
),
)
}

CategorizedType::OptionalIterator(types) if types.len() == 1 => (
CategorizedType::OptionalIterator(types) if types.len() == 1 => {
PgGuardRewriter::impl_setof_srf(
types,
func_span,
Expand All @@ -153,11 +155,10 @@ impl PgGuardRewriter {
func_call,
entity_submission,
true,
),
true,
),
)
}

CategorizedType::Iterator(types) => (
CategorizedType::Tuple(types) | CategorizedType::Iterator(types) => {
PgGuardRewriter::impl_table_srf(
types,
func_span,
Expand All @@ -168,22 +169,18 @@ impl PgGuardRewriter {
func_call,
entity_submission,
false,
),
true,
),
)
}

CategorizedType::OptionalIterator(types) => (
PgGuardRewriter::impl_table_srf(
types,
func_span,
prolog,
vis,
func_name_wrapper,
generics,
func_call,
entity_submission,
true,
),
CategorizedType::OptionalIterator(types) => PgGuardRewriter::impl_table_srf(
types,
func_span,
prolog,
vis,
func_name_wrapper,
generics,
func_call,
entity_submission,
true,
),
}
Expand Down Expand Up @@ -222,39 +219,6 @@ impl PgGuardRewriter {
}
}

fn impl_tuple_udf(
mut func: ItemFn,
entity_submission: Option<&PgExtern>,
) -> proc_macro2::TokenStream {
let func_span = func.span();
let return_type = func.sig.output;
let return_type = format!("{}", quote! {#return_type});
let return_type =
proc_macro2::TokenStream::from_str(return_type.trim_start_matches("->")).unwrap();
let return_type = quote! {impl std::iter::Iterator<Item = #return_type>};
let attrs = entity_submission
.unwrap()
.extern_attrs()
.iter()
.collect::<Punctuated<_, Token![,]>>();

func.sig.output = ReturnType::Default;
let sig = func.sig;
let body = func.block;

// We do **not** put an entity submission here as there still exists a `pg_extern` attribute.
//
// This is because we quietely rewrite the function signature to `Iterator<Item = T>` and
// rely on #[pg_extern] being called again during compilation. It is important that we
// include the original #[pg_extern(<attributes>)] in the generated code.
quote_spanned! {func_span=>
#[pg_extern(#attrs)]
#sig -> #return_type {
Some(#body).into_iter()
}
}
}

fn impl_setof_srf(
types: Vec<String>,
func_span: Span,
Expand Down