Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

loadable extensions #372

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pkg-config = "0.3.24"
polars = "0.35.4"
polars-core = "0.35.4"
pretty_assertions = "1.4.0"
prettyplease = "0.2.20"
proc-macro2 = "1.0.56"
quote = "1.0.21"
r2d2 = "0.8.9"
Expand Down
2 changes: 2 additions & 0 deletions crates/duckdb/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ name = "duckdb"

[features]
default = []
# todo: remove this requirement on buildtime_bindgen going forward
loadable_extension = ["libduckdb-sys/loadable_extension", "buildtime_bindgen"]
bundled = ["libduckdb-sys/bundled"]
json = ["libduckdb-sys/json", "bundled"]
parquet = ["libduckdb-sys/parquet", "bundled"]
Expand Down
5 changes: 5 additions & 0 deletions crates/libduckdb-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ buildtime_bindgen = ["bindgen", "pkg-config", "vcpkg"]
json = ["bundled"]
parquet = ["bundled"]
extensions-full = ["json", "parquet"]
loadable_extension = ["prettyplease", "quote", "syn"]


[dependencies]

Expand All @@ -35,6 +37,9 @@ vcpkg = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tar = { workspace = true }
syn = { workspace = true, optional = true }
quote = { workspace = true, optional = true }
prettyplease = { workspace = true, optional = true }

[dev-dependencies]
arrow = { workspace = true, features = ["ffi"] }
152 changes: 143 additions & 9 deletions crates/libduckdb-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ mod build_bundled {
#[cfg(feature = "buildtime_bindgen")]
{
use super::{bindings, HeaderLocation};
let header = HeaderLocation::FromPath(format!("{out_dir}/{lib_name}/src/include/duckdb.h"));
let header = HeaderLocation::FromPath(format!("{out_dir}/{lib_name}/src/include"));
bindings::write_to_out_dir(header, out_path);
}
#[cfg(not(feature = "buildtime_bindgen"))]
Expand Down Expand Up @@ -178,11 +178,28 @@ impl From<HeaderLocation> for String {
let prefix = env_prefix();
let mut header = env::var(format!("{prefix}_INCLUDE_DIR"))
.unwrap_or_else(|_| env::var(format!("{}_LIB_DIR", env_prefix())).unwrap());
header.push_str("/duckdb.h");
header.push_str(if cfg!(feature = "loadable_extension") {
"/duckdb_extension.h"
} else {
"/duckdb.h"
});
header
}
HeaderLocation::Wrapper => "wrapper.h".into(),
HeaderLocation::FromPath(path) => path,
HeaderLocation::Wrapper => if cfg!(feature = "loadable_extension") {
"wrapper_ext.h"
} else {
"wrapper.h"
}
.into(),
HeaderLocation::FromPath(path) => format!(
"{}/{}",
path,
if cfg!(feature = "loadable_extension") {
"duckdb_extension.h"
} else {
"duckdb.h"
}
),
}
}
}
Expand Down Expand Up @@ -266,7 +283,6 @@ mod build_linked {
match pkg_config::Config::new().print_system_libs(false).probe(link_lib) {
Ok(mut lib) => {
if let Some(mut header) = lib.include_paths.pop() {
header.push("duckdb.h");
HeaderLocation::FromPath(header.to_string_lossy().into())
} else {
HeaderLocation::Wrapper
Expand All @@ -288,7 +304,6 @@ mod build_linked {
// See if vcpkg can find it.
if let Ok(mut lib) = vcpkg::Config::new().probe(lib_name()) {
if let Some(mut header) = lib.include_paths.pop() {
header.push("duckdb.h");
return Some(HeaderLocation::FromPath(header.to_string_lossy().into()));
}
}
Expand All @@ -308,15 +323,28 @@ mod bindings {
pub fn write_to_out_dir(header: HeaderLocation, out_path: &Path) {
let header: String = header.into();
let mut output = Vec::new();
bindgen::builder()
let mut bindings = bindgen::builder()
.trust_clang_mangling(false)
.header(header.clone())
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()));

if cfg!(feature = "loadable_extension") {
bindings = bindings.ignore_functions(); // see generate_functions
}

bindings
.generate()
.unwrap_or_else(|_| panic!("could not run bindgen on header {header}"))
.write(Box::new(&mut output))
.expect("could not write output of bindgen");
let output = String::from_utf8(output).expect("bindgen output was not UTF-8?!");

let mut output = String::from_utf8(output).expect("bindgen output was not UTF-8?!");

#[cfg(feature = "loadable_extension")]
{
super::loadable_extension::generate_functions(&mut output);
}

let mut file = OpenOptions::new()
.write(true)
.truncate(true)
Expand All @@ -328,3 +356,109 @@ mod bindings {
.unwrap_or_else(|_| panic!("Could not write to {out_path:?}"));
}
}

#[cfg(all(feature = "buildtime_bindgen", feature = "loadable_extension"))]
mod loadable_extension {
fn get_duckdb_api_routines(ast: syn::File) -> syn::ItemStruct {
ast.items
.into_iter()
.find_map(|i| {
if let syn::Item::Struct(s) = i {
if s.ident == "duckdb_ext_api_v0" {
Some(s)
} else {
None
}
} else {
None
}
})
.expect("could not find duckdb_ext_api_v0")
}

pub fn generate_functions(output: &mut String) {
let ast: syn::File = syn::parse_str(output).expect("could not parse bindgen output");
let duckdb_api_routines = get_duckdb_api_routines(ast);
let duckdb_api_routines_ident = duckdb_api_routines.ident;
let p_api = quote::format_ident!("p_api");
let mut stores = Vec::new();

for field in duckdb_api_routines.fields {
let ident = field.ident.expect("unnamed field");
let span = ident.span();
let name = ident.to_string();
let ptr_name = syn::Ident::new(format!("__{}", name.to_uppercase()).as_ref(), span);
let duckdb_fn_name = syn::Ident::new(&name, span);
let method = extract_method(&field.ty).unwrap_or_else(|| panic!("unexpected type for {name}"));
let arg_names: syn::punctuated::Punctuated<&syn::Ident, syn::token::Comma> =
method.inputs.iter().map(|i| &i.name.as_ref().unwrap().0).collect();
let args = &method.inputs;
let varargs = &method.variadic;
if varargs.is_some() {
continue;
}

let ty = &method.output;
let tokens = {
quote::quote! {
static #ptr_name: ::std::sync::atomic::AtomicPtr<()> = ::std::sync::atomic::AtomicPtr::new(::std::ptr::null_mut());
pub unsafe fn #duckdb_fn_name(#args) #ty {
let fun = {
let ptr = #ptr_name.load(::std::sync::atomic::Ordering::Acquire);
assert!(!ptr.is_null(), "DuckDB API not initialized or DuckDB feature omitted");
let fun: unsafe extern "C" fn(#args #varargs) #ty = ::std::mem::transmute(ptr);
fun
};

(fun)(#arg_names)
}
}
};

output.push_str(&prettyplease::unparse(
&syn::parse2(tokens).expect("could not parse quote output"),
));

output.push('\n');
let _ = &mut stores.push(quote::quote! {
if let Some(fun) = (*#p_api).#ident {
#ptr_name.store(
fun as usize as *mut (),
::std::sync::atomic::Ordering::Release,
);
}
});
}

// todo: check version number
let tokens = quote::quote! {
pub unsafe fn duckdb_extension_init(#p_api: *mut #duckdb_api_routines_ident) -> ::std::result::Result<(), ()> {
#(#stores)*
Ok(())
}
};
output.push_str(&prettyplease::unparse(
&syn::parse2(tokens).expect("could not parse quote output"),
));
output.push('\n');
}

fn extract_method(ty: &syn::Type) -> Option<&syn::TypeBareFn> {
match ty {
syn::Type::Path(tp) => tp.path.segments.last(),
_ => None,
}
.map(|seg| match &seg.arguments {
syn::PathArguments::AngleBracketed(args) => args.args.first(),
_ => None,
})?
.map(|arg| match arg {
syn::GenericArgument::Type(t) => Some(t),
_ => None,
})?
.map(|ty| match ty {
syn::Type::BareFn(r) => Some(r),
_ => None,
})?
}
}
1 change: 1 addition & 0 deletions crates/libduckdb-sys/wrapper_ext.h
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include "duckdb/duckdb_extension.h"
Loading