diff --git a/Cargo.toml b/Cargo.toml index 65dbf60d..5caa5f22 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ pkg-config = "0.3.24" polars = "0.35.4" polars-core = "0.35.4" pretty_assertions = "1.4.0" +prettyplease = "0.2.20" proc-macro2 = "1.0.56" quote = "1.0.21" r2d2 = "0.8.9" diff --git a/crates/duckdb/Cargo.toml b/crates/duckdb/Cargo.toml index d2783176..b6cca439 100644 --- a/crates/duckdb/Cargo.toml +++ b/crates/duckdb/Cargo.toml @@ -19,6 +19,8 @@ name = "duckdb" [features] default = [] +# todo: remove this requirement on buildtime_bindgen going forward +loadable_extension = ["libduckdb-sys/loadable_extension", "buildtime_bindgen"] bundled = ["libduckdb-sys/bundled"] json = ["libduckdb-sys/json", "bundled"] parquet = ["libduckdb-sys/parquet", "bundled"] diff --git a/crates/libduckdb-sys/Cargo.toml b/crates/libduckdb-sys/Cargo.toml index d02c0ea0..4e5e933b 100644 --- a/crates/libduckdb-sys/Cargo.toml +++ b/crates/libduckdb-sys/Cargo.toml @@ -22,6 +22,8 @@ buildtime_bindgen = ["bindgen", "pkg-config", "vcpkg"] json = ["bundled"] parquet = ["bundled"] extensions-full = ["json", "parquet"] +loadable_extension = ["prettyplease", "quote", "syn"] + [dependencies] @@ -35,6 +37,9 @@ vcpkg = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tar = { workspace = true } +syn = { workspace = true, optional = true } +quote = { workspace = true, optional = true } +prettyplease = { workspace = true, optional = true } [dev-dependencies] arrow = { workspace = true, features = ["ffi"] } diff --git a/crates/libduckdb-sys/build.rs b/crates/libduckdb-sys/build.rs index 8a279e27..5596cc9a 100644 --- a/crates/libduckdb-sys/build.rs +++ b/crates/libduckdb-sys/build.rs @@ -94,7 +94,7 @@ mod build_bundled { #[cfg(feature = "buildtime_bindgen")] { use super::{bindings, HeaderLocation}; - let header = HeaderLocation::FromPath(format!("{out_dir}/{lib_name}/src/include/duckdb.h")); + let header = HeaderLocation::FromPath(format!("{out_dir}/{lib_name}/src/include")); bindings::write_to_out_dir(header, out_path); } #[cfg(not(feature = "buildtime_bindgen"))] @@ -178,11 +178,28 @@ impl From for String { let prefix = env_prefix(); let mut header = env::var(format!("{prefix}_INCLUDE_DIR")) .unwrap_or_else(|_| env::var(format!("{}_LIB_DIR", env_prefix())).unwrap()); - header.push_str("/duckdb.h"); + header.push_str(if cfg!(feature = "loadable_extension") { + "/duckdb_extension.h" + } else { + "/duckdb.h" + }); header } - HeaderLocation::Wrapper => "wrapper.h".into(), - HeaderLocation::FromPath(path) => path, + HeaderLocation::Wrapper => if cfg!(feature = "loadable_extension") { + "wrapper_ext.h" + } else { + "wrapper.h" + } + .into(), + HeaderLocation::FromPath(path) => format!( + "{}/{}", + path, + if cfg!(feature = "loadable_extension") { + "duckdb_extension.h" + } else { + "duckdb.h" + } + ), } } } @@ -266,7 +283,6 @@ mod build_linked { match pkg_config::Config::new().print_system_libs(false).probe(link_lib) { Ok(mut lib) => { if let Some(mut header) = lib.include_paths.pop() { - header.push("duckdb.h"); HeaderLocation::FromPath(header.to_string_lossy().into()) } else { HeaderLocation::Wrapper @@ -288,7 +304,6 @@ mod build_linked { // See if vcpkg can find it. if let Ok(mut lib) = vcpkg::Config::new().probe(lib_name()) { if let Some(mut header) = lib.include_paths.pop() { - header.push("duckdb.h"); return Some(HeaderLocation::FromPath(header.to_string_lossy().into())); } } @@ -308,15 +323,28 @@ mod bindings { pub fn write_to_out_dir(header: HeaderLocation, out_path: &Path) { let header: String = header.into(); let mut output = Vec::new(); - bindgen::builder() + let mut bindings = bindgen::builder() .trust_clang_mangling(false) .header(header.clone()) - .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())); + + if cfg!(feature = "loadable_extension") { + bindings = bindings.ignore_functions(); // see generate_functions + } + + bindings .generate() .unwrap_or_else(|_| panic!("could not run bindgen on header {header}")) .write(Box::new(&mut output)) .expect("could not write output of bindgen"); - let output = String::from_utf8(output).expect("bindgen output was not UTF-8?!"); + + let mut output = String::from_utf8(output).expect("bindgen output was not UTF-8?!"); + + #[cfg(feature = "loadable_extension")] + { + super::loadable_extension::generate_functions(&mut output); + } + let mut file = OpenOptions::new() .write(true) .truncate(true) @@ -328,3 +356,109 @@ mod bindings { .unwrap_or_else(|_| panic!("Could not write to {out_path:?}")); } } + +#[cfg(all(feature = "buildtime_bindgen", feature = "loadable_extension"))] +mod loadable_extension { + fn get_duckdb_api_routines(ast: syn::File) -> syn::ItemStruct { + ast.items + .into_iter() + .find_map(|i| { + if let syn::Item::Struct(s) = i { + if s.ident == "duckdb_ext_api_v0" { + Some(s) + } else { + None + } + } else { + None + } + }) + .expect("could not find duckdb_ext_api_v0") + } + + pub fn generate_functions(output: &mut String) { + let ast: syn::File = syn::parse_str(output).expect("could not parse bindgen output"); + let duckdb_api_routines = get_duckdb_api_routines(ast); + let duckdb_api_routines_ident = duckdb_api_routines.ident; + let p_api = quote::format_ident!("p_api"); + let mut stores = Vec::new(); + + for field in duckdb_api_routines.fields { + let ident = field.ident.expect("unnamed field"); + let span = ident.span(); + let name = ident.to_string(); + let ptr_name = syn::Ident::new(format!("__{}", name.to_uppercase()).as_ref(), span); + let duckdb_fn_name = syn::Ident::new(&name, span); + let method = extract_method(&field.ty).unwrap_or_else(|| panic!("unexpected type for {name}")); + let arg_names: syn::punctuated::Punctuated<&syn::Ident, syn::token::Comma> = + method.inputs.iter().map(|i| &i.name.as_ref().unwrap().0).collect(); + let args = &method.inputs; + let varargs = &method.variadic; + if varargs.is_some() { + continue; + } + + let ty = &method.output; + let tokens = { + quote::quote! { + static #ptr_name: ::std::sync::atomic::AtomicPtr<()> = ::std::sync::atomic::AtomicPtr::new(::std::ptr::null_mut()); + pub unsafe fn #duckdb_fn_name(#args) #ty { + let fun = { + let ptr = #ptr_name.load(::std::sync::atomic::Ordering::Acquire); + assert!(!ptr.is_null(), "DuckDB API not initialized or DuckDB feature omitted"); + let fun: unsafe extern "C" fn(#args #varargs) #ty = ::std::mem::transmute(ptr); + fun + }; + + (fun)(#arg_names) + } + } + }; + + output.push_str(&prettyplease::unparse( + &syn::parse2(tokens).expect("could not parse quote output"), + )); + + output.push('\n'); + let _ = &mut stores.push(quote::quote! { + if let Some(fun) = (*#p_api).#ident { + #ptr_name.store( + fun as usize as *mut (), + ::std::sync::atomic::Ordering::Release, + ); + } + }); + } + + // todo: check version number + let tokens = quote::quote! { + pub unsafe fn duckdb_extension_init(#p_api: *mut #duckdb_api_routines_ident) -> ::std::result::Result<(), ()> { + #(#stores)* + Ok(()) + } + }; + output.push_str(&prettyplease::unparse( + &syn::parse2(tokens).expect("could not parse quote output"), + )); + output.push('\n'); + } + + fn extract_method(ty: &syn::Type) -> Option<&syn::TypeBareFn> { + match ty { + syn::Type::Path(tp) => tp.path.segments.last(), + _ => None, + } + .map(|seg| match &seg.arguments { + syn::PathArguments::AngleBracketed(args) => args.args.first(), + _ => None, + })? + .map(|arg| match arg { + syn::GenericArgument::Type(t) => Some(t), + _ => None, + })? + .map(|ty| match ty { + syn::Type::BareFn(r) => Some(r), + _ => None, + })? + } +} diff --git a/crates/libduckdb-sys/wrapper_ext.h b/crates/libduckdb-sys/wrapper_ext.h new file mode 100644 index 00000000..525858d9 --- /dev/null +++ b/crates/libduckdb-sys/wrapper_ext.h @@ -0,0 +1 @@ +#include "duckdb/duckdb_extension.h" \ No newline at end of file