Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewgapp committed Dec 1, 2024
1 parent c30bdae commit d48fbe0
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
22 changes: 11 additions & 11 deletions crates/duckdb/src/vtab/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ impl From<duckdb_function_info> for ScalarFunctionInfo {

impl ScalarFunctionInfo {
pub unsafe fn get_scalar_extra_info<T>(&self) -> &T {
&*(duckdb_scalar_function_get_extra_info(self.0) as *const T)
&*(duckdb_scalar_function_get_extra_info(self.0).cast())
}

pub unsafe fn set_error(&self, error: &str) {
Expand Down Expand Up @@ -471,6 +471,16 @@ impl Drop for ScalarFunction {
use libduckdb_sys as ffi;

impl ScalarFunction {
/// Creates a new empty scalar function.
pub fn new(name: impl Into<String>) -> Result<Self, Error> {
let name: String = name.into();
let f_ptr = unsafe { duckdb_create_scalar_function() };
let c_name = CString::new(name).expect("name should contain valid utf-8");
unsafe { duckdb_scalar_function_set_name(f_ptr, c_name.as_ptr()) };

Ok(Self { ptr: f_ptr })
}

/// Adds a parameter to the scalar function.
///
/// # Arguments
Expand Down Expand Up @@ -514,16 +524,6 @@ impl ScalarFunction {
self
}

/// Creates a new empty scalar function.
pub fn new(name: impl Into<String>) -> Result<Self, Error> {
let name: String = name.into();
let f_ptr = unsafe { duckdb_create_scalar_function() };
let c_name = CString::new(name).expect("name should contain valid utf-8");
unsafe { duckdb_scalar_function_set_name(f_ptr, c_name.as_ptr()) };

Ok(Self { ptr: f_ptr })
}

/// Assigns extra information to the scalar function that can be fetched during binding, etc.
///
/// # Arguments
Expand Down
23 changes: 19 additions & 4 deletions crates/duckdb/src/vtab/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ impl Connection {
let scalar_function = ScalarFunction::new(name)?;
signature.register_with_scalar(&scalar_function);
scalar_function.set_function(Some(scalar_func::<S>));
// scalar_function.set_extra_info::<S::State>();
scalar_function.set_extra_info::<S::State>();
set.add_function(scalar_function)?;
}
self.db.borrow_mut().register_scalar_function_set(set)
Expand Down Expand Up @@ -584,7 +584,9 @@ mod test {
impl ArrowScalar for ArrowOverloaded {
type State = MockState;

fn invoke(_: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
fn invoke(s: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
assert_eq!("some meta", s.info);

let a = input.column(0);
let b = input.column(1);

Expand Down Expand Up @@ -637,16 +639,29 @@ mod test {
}
}

#[derive(Debug)]
struct TestState {
#[allow(dead_code)]
inner: i32,
}

impl Default for TestState {
fn default() -> Self {
TestState { inner: 42 }
}
}

struct EchoScalar {}

impl VScalar for EchoScalar {
type State = ();
type State = TestState;

unsafe fn invoke(
_: &Self::State,
s: &Self::State,
input: &mut DataChunkHandle,
output: &mut dyn WritableVector,
) -> Result<(), Box<dyn std::error::Error>> {
assert_eq!(s.inner, 42);
let values = input.flat_vector(0);
let values = values.as_slice_with_len::<duckdb_string_t>(input.len());
let strings = values
Expand Down

0 comments on commit d48fbe0

Please sign in to comment.