Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support User Defined Table Function
Browse files Browse the repository at this point in the history
Signed-off-by: veeupup <[email protected]>
Veeupup committed Nov 25, 2023
1 parent 393e48f commit 32ecce4
Showing 6 changed files with 381 additions and 20 deletions.
224 changes: 224 additions & 0 deletions datafusion-examples/examples/simple_udtf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::csv::reader::Format;
use arrow::csv::ReaderBuilder;
use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::function::TableFunctionImpl;
use datafusion::datasource::streaming::StreamingTable;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::context::SessionState;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::streaming::PartitionStream;
use datafusion::physical_plan::{collect, ExecutionPlan};
use datafusion::prelude::SessionContext;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType};
use std::fs::File;
use std::io::Seek;
use std::path::Path;
use std::sync::Arc;

// To define your own table function, you only need to do the following 3 things:
// 1. Implement your own TableProvider
// 2. Implement your own TableFunctionImpl and return your TableProvider
// 3. Register the function using ctx.register_udtf

/// Registers two user-defined table functions on a fresh context and
/// queries each of them through SQL, printing the results.
#[tokio::main]
async fn main() -> Result<()> {
    // Local execution context that holds the registered table functions.
    let ctx = SessionContext::new();

    ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {}));
    ctx.register_udtf("read_csv_stream", Arc::new(LocalStreamCsvTable {}));

    let testdata = datafusion::test_util::arrow_test_data();
    let csv_file = format!("{testdata}/csv/aggregate_test_100.csv");

    let queries = [
        // Extra `now()` argument: the table evaluates and prints it while scanning.
        format!("SELECT * FROM read_csv('{csv_file}', now());"),
        // Path only.
        format!("SELECT * FROM read_csv('{csv_file}');"),
        // Same file, served through the streaming table.
        format!("SELECT * FROM read_csv_stream('{csv_file}');"),
    ];

    for sql in &queries {
        let df = ctx.sql(sql.as_str()).await?;
        df.show().await?;
    }

    Ok(())
}

// Option 1: a full implementation of a `TableProvider`.
/// Table backed by record batches that were read from a local CSV file.
struct LocalCsvTable {
/// Schema reported to the planner and attached to the produced batches.
schema: SchemaRef,
/// Table-function arguments other than the file path; when non-empty,
/// `scan` evaluates the first one before producing data.
exprs: Vec<Expr>,
/// Record batches handed out by `scan` and by the streaming `execute`.
batches: Vec<RecordBatch>,
}

/// Exposes the in-memory CSV batches as a regular base table.
#[async_trait]
impl TableProvider for LocalCsvTable {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Builds a single-partition in-memory scan over the stored batches.
    ///
    /// Filters and limit are ignored; only the projection is forwarded.
    async fn scan(
        &self,
        state: &SessionState,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Extra function arguments (if any) are evaluated before scanning.
        let has_extra_args = !self.exprs.is_empty();
        if has_extra_args {
            self.interpreter_expr(state).await?;
        }

        let partitions = [self.batches.clone()];
        // Fully qualified: `PartitionStream` also defines a `schema` method.
        let table_schema = TableProvider::schema(self);
        let exec = MemoryExec::try_new(&partitions, table_schema, projection.cloned())?;
        Ok(Arc::new(exec))
    }
}

impl LocalCsvTable {
    // TODO(veeupup): make evaluating an `Expr` simpler for users
    // TODO(veeupup): support more kinds of exprs
    /// Evaluates the first extra table-function argument and prints its value.
    ///
    /// The expression is projected over a one-row empty relation, planned and
    /// executed with `state`, and the first column of the first result batch
    /// is printed. Returns `Ok(())` without doing anything when there are no
    /// extra arguments.
    ///
    /// # Errors
    /// Propagates any planning or execution error.
    async fn interpreter_expr(&self, state: &SessionState) -> Result<()> {
        use datafusion::logical_expr::expr_rewriter::normalize_col;
        use datafusion::logical_expr::utils::columnize_expr;

        // Do not index blindly: callers may hold a table without extra args.
        let expr = match self.exprs.first() {
            Some(expr) => expr.clone(),
            None => return Ok(()),
        };

        // A relation with no columns that yields exactly one row, so the
        // projection below produces a single value.
        let plan = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: true,
            schema: Arc::new(DFSchema::empty()),
        });
        let logical_plan = Projection::try_new(
            vec![columnize_expr(normalize_col(expr, &plan)?, plan.schema())],
            Arc::new(plan),
        )
        .map(LogicalPlan::Projection)?;

        let rbs = collect(
            state.create_physical_plan(&logical_plan).await?,
            Arc::new(TaskContext::from(state)),
        )
        .await?;

        // Execution may legitimately return no batches; avoid panicking then.
        if let Some(batch) = rbs.first() {
            println!("time now: {:?}", batch.column(0));
        }
        Ok(())
    }
}

struct LocalCsvTableFunc {}

impl TableFunctionImpl for LocalCsvTableFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let mut new_exprs = vec![];
let mut filepath = String::new();
for expr in exprs {
match expr {
Expr::Literal(ScalarValue::Utf8(Some(ref path))) => {
filepath = path.clone()
}
expr => new_exprs.push(expr.clone()),
}
}
let (schema, batches) = read_csv_batches(filepath)?;
let table = LocalCsvTable {
schema,
exprs: new_exprs.clone(),
batches,
};
Ok(Arc::new(table))
}
}

// Option 2: let `StreamingTable` do the `TableProvider` work.
// Only `PartitionStream` has to be implemented; the stream is then wrapped in
// a `StreamingTable`.
impl PartitionStream for LocalCsvTable {
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Returns a stream that yields the stored batches in order.
    fn execute(
        &self,
        _ctx: Arc<datafusion::execution::TaskContext>,
    ) -> datafusion::physical_plan::SendableRecordBatchStream {
        let schema = Arc::clone(&self.schema);
        let batches = self.batches.clone();
        // The data could just as well come from the network or any other
        // source, including asynchronously. A custom
        // `SendableRecordBatchStream` is also possible by implementing
        // `Stream<Item = ArrowResult<RecordBatch>> + Send + Sync + 'static`.
        let stream = futures::stream::iter(batches.into_iter().map(Ok));
        Box::pin(RecordBatchStreamAdapter::new(schema, stream))
    }
}

struct LocalStreamCsvTable {}

impl TableFunctionImpl for LocalStreamCsvTable {
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let filepath = match args[0] {
Expr::Literal(ScalarValue::Utf8(Some(ref path))) => path.clone(),
_ => unimplemented!(),
};
let (schema, batches) = read_csv_batches(filepath)?;
let stream = LocalCsvTable {
schema: schema.clone(),
batches,
exprs: vec![],
};
let table = StreamingTable::try_new(schema, vec![Arc::new(stream)])?;
Ok(Arc::new(table))
}
}

/// Reads a CSV file with a header row into memory.
///
/// The schema is inferred from the file, the file is rewound, and all rows
/// are then decoded into record batches using that schema.
///
/// # Errors
/// Propagates I/O errors from opening or rewinding the file and Arrow errors
/// from schema inference or decoding.
fn read_csv_batches(csv_path: impl AsRef<Path>) -> Result<(SchemaRef, Vec<RecordBatch>)> {
    let mut file = File::open(csv_path)?;
    // Inference must agree with the reader below about the header row;
    // otherwise the header line is treated as data when inferring the schema.
    let (schema, _) = Format::default()
        .with_header(true)
        .infer_schema(&mut file, None)?;
    // Inference consumed the file; start again from the beginning.
    file.rewind()?;

    let schema = Arc::new(schema);
    let reader = ReaderBuilder::new(Arc::clone(&schema))
        .with_header(true)
        .build(file)?;
    // Stops at the first decoding error.
    let batches = reader.collect::<std::result::Result<Vec<_>, _>>()?;
    Ok((schema, batches))
}
56 changes: 56 additions & 0 deletions datafusion/core/src/datasource/function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! A table that uses a function to generate data
use super::TableProvider;

use datafusion_common::Result;
use datafusion_expr::Expr;

use std::sync::Arc;

/// A trait for table function implementations
///
/// An implementation turns the argument expressions of a table function
/// call into the [`TableProvider`] that supplies the data.
pub trait TableFunctionImpl: Sync + Send {
/// Create a table provider
///
/// `args` are the argument expressions of the call. Returns an error if
/// a provider cannot be built from them.
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>>;
}

/// A table that uses a function to generate data
///
/// Pairs a function name with the [`TableFunctionImpl`] that builds the
/// table provider for it.
pub struct TableFunction {
/// Name of the table function
name: String,
/// Function implementation
fun: Arc<dyn TableFunctionImpl>,
}

impl TableFunction {
/// Create a new table function
pub fn new(name: String, fun: Arc<dyn TableFunctionImpl>) -> Self {
Self { name, fun }
}

/// Get the name of the table function
pub fn name(&self) -> String {
self.name.clone()
}

/// Get the function implementation and generate a table
pub fn create_table_provider(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
self.fun.call(args)
}
}
1 change: 1 addition & 0 deletions datafusion/core/src/datasource/mod.rs
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ pub mod avro_to_arrow;
pub mod default_table_source;
pub mod empty;
pub mod file_format;
pub mod function;
pub mod listing;
pub mod listing_table_factory;
pub mod memory;
30 changes: 29 additions & 1 deletion datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@ mod parquet;
use crate::{
catalog::{CatalogList, MemoryCatalogList},
datasource::{
function::{TableFunction, TableFunctionImpl},
listing::{ListingOptions, ListingTable},
provider::TableProviderFactory,
},
@@ -42,7 +43,7 @@ use datafusion_common::{
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
@@ -795,6 +796,14 @@ impl SessionContext {
.add_var_provider(variable_type, provider);
}

/// Registers a user-defined table function under `name`.
///
/// A function already registered under the same name is replaced.
pub fn register_udtf(&self, name: &str, fun: Arc<dyn TableFunctionImpl>) {
    let table_function = Arc::new(TableFunction::new(name.to_owned(), fun));
    let mut state = self.state.write();
    state
        .table_functions
        .insert(name.to_owned(), table_function);
}

/// Registers a scalar UDF within this context.
///
/// Note in SQL queries, function names are looked up using
@@ -1224,6 +1233,8 @@ pub struct SessionState {
query_planner: Arc<dyn QueryPlanner + Send + Sync>,
/// Collection of catalogs containing schemas and ultimately TableProviders
catalog_list: Arc<dyn CatalogList>,
/// Table Functions
table_functions: HashMap<String, Arc<TableFunction>>,
/// Scalar functions that are registered with the context
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate functions registered in the context
@@ -1322,6 +1333,7 @@ impl SessionState {
physical_optimizers: PhysicalOptimizer::new(),
query_planner: Arc::new(DefaultQueryPlanner {}),
catalog_list,
table_functions: HashMap::new(),
scalar_functions: HashMap::new(),
aggregate_functions: HashMap::new(),
window_functions: HashMap::new(),
@@ -1860,6 +1872,22 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
.ok_or_else(|| plan_datafusion_err!("table '{name}' not found"))
}

/// Looks up the table function registered as `name`, calls it with `args`,
/// and wraps the resulting provider as a table source.
///
/// Returns a plan error when no such table function is registered.
fn get_table_function_source(
    &self,
    name: &str,
    args: Vec<Expr>,
) -> Result<Arc<dyn TableSource>> {
    match self.state.table_functions.get(name) {
        Some(tbl_func) => {
            let provider = tbl_func.create_table_provider(&args)?;
            Ok(provider_as_source(provider))
        }
        None => Err(plan_datafusion_err!("table function '{name}' not found")),
    }
}

/// Returns the scalar UDF registered under `name`, if any.
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
    let scalar_udfs = self.state.scalar_functions();
    scalar_udfs.get(name).cloned()
}
9 changes: 9 additions & 0 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
@@ -51,6 +51,15 @@ pub trait ContextProvider {
}
/// Getter for a datasource
fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>>;
/// Getter for a table function
fn get_table_function_source(
&self,
_name: &str,
_args: Vec<Expr>,
) -> Result<Arc<dyn TableSource>> {
unimplemented!()
}

/// Getter for a UDF description
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
/// Getter for a UDAF description
81 changes: 62 additions & 19 deletions datafusion/sql/src/relation/mod.rs
Original file line number Diff line number Diff line change
@@ -16,9 +16,11 @@
// under the License.

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_common::{not_impl_err, DataFusionError, Result};
use datafusion_common::{
not_impl_err, DFSchema, DataFusionError, Result, TableReference,
};
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder};
use sqlparser::ast::TableFactor;
use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor};

mod join;

@@ -30,24 +32,65 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
let (plan, alias) = match relation {
TableFactor::Table { name, alias, .. } => {
// normalize name and alias
let table_ref = self.object_name_to_table_reference(name)?;
let table_name = table_ref.to_string();
let cte = planner_context.get_cte(&table_name);
(
match (
cte,
self.context_provider.get_table_source(table_ref.clone()),
) {
(Some(cte_plan), _) => Ok(cte_plan.clone()),
(_, Ok(provider)) => {
LogicalPlanBuilder::scan(table_ref, provider, None)?.build()
TableFactor::Table {
name, alias, args, ..
} => {
// Resolving other tables' schemas here may be a little difficult, so only values and scalar functions are supported as arguments for now
if let Some(func_args) = args {
let tbl_func_name = name.0.get(0).unwrap().value.to_string();
let mut args = vec![];
for arg in func_args {
match arg {
FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => {
let expr = self.sql_expr_to_logical_expr(
expr,
// TODO(veeupup): for now, it may be a little difficult to resolve tables' schemas before creating the provider
// maybe we can put all relations schema in
&DFSchema::empty(),
planner_context,
)?;
args.push(expr);
}
arg => {
unimplemented!(
"Unsupported function argument type: {:?}",
arg
)
}
}
(None, Err(e)) => Err(e),
}?,
alias,
)
}
let provider = self
.context_provider
.get_table_function_source(&tbl_func_name, args)?;
let plan = LogicalPlanBuilder::scan(
TableReference::Bare {
table: std::borrow::Cow::Borrowed("tmp_table"),
},
provider,
None,
)?
.build()?;
(plan, alias)
} else {
// normalize name and alias
let table_ref = self.object_name_to_table_reference(name)?;
let table_name = table_ref.to_string();
let cte = planner_context.get_cte(&table_name);
(
match (
cte,
self.context_provider.get_table_source(table_ref.clone()),
) {
(Some(cte_plan), _) => Ok(cte_plan.clone()),
(_, Ok(provider)) => {
LogicalPlanBuilder::scan(table_ref, provider, None)?
.build()
}
(None, Err(e)) => Err(e),
}?,
alias,
)
}
}
TableFactor::Derived {
subquery, alias, ..

0 comments on commit 32ecce4

Please sign in to comment.