Skip to content

Commit

Permalink
Extension bindings (#266)
Browse files Browse the repository at this point in the history
* Introduce to_variant trait function to LogicalNode and create Explain LogicalNode bindings

* Cargo fmt

* bindings for Extension LogicalNode

* Add missing classes to list of exports so test_imports will pass

* Update to point to proper repo

* Update pytest to adhere to aggregate calls being wrapped in projections

* Address linter change which causes a pytest to fail
  • Loading branch information
jdye64 authored Mar 13, 2023
1 parent 2172d3f commit 9b2c3da
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 134 deletions.
268 changes: 151 additions & 117 deletions Cargo.lock

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ default = ["mimalloc"]
tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync"] }
rand = "0.8"
pyo3 = { version = "0.18.0", features = ["extension-module", "abi3", "abi3-py37"] }
datafusion = { version = "19.0.0", features = ["pyarrow", "avro"] }
datafusion-expr = "19.0.0"
datafusion-optimizer = "19.0.0"
datafusion-common = { version = "19.0.0", features = ["pyarrow"] }
datafusion-sql = "19.0.0"
datafusion-substrait = "19.0.0"
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "dd98aab", features = ["pyarrow", "avro"]}
datafusion-expr = { git = "https://github.com/apache/arrow-datafusion.git", rev = "dd98aab" }
datafusion-optimizer = { git = "https://github.com/apache/arrow-datafusion.git", rev = "dd98aab" }
datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git", rev = "dd98aab", features = ["pyarrow"]}
datafusion-sql = { git = "https://github.com/apache/arrow-datafusion.git", rev = "dd98aab" }
datafusion-substrait = { git = "https://github.com/apache/arrow-datafusion.git", rev = "dd98aab" }
uuid = { version = "1.2", features = ["v4"] }
mimalloc = { version = "*", optional = true, default-features = false }
async-trait = "0.1"
Expand Down
2 changes: 2 additions & 0 deletions datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
TryCast,
Between,
Explain,
Extension,
)

__version__ = importlib_metadata.version(__name__)
Expand Down Expand Up @@ -129,6 +130,7 @@
"TryCast",
"Between",
"Explain",
"Extension",
]


Expand Down
12 changes: 4 additions & 8 deletions datafusion/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,13 @@ def test_logical_plan(aggregate_df):
def test_optimized_logical_plan(aggregate_df):
plan = aggregate_df.optimized_logical_plan()

expected = "Projection: test.c1, SUM(test.c2)"
expected = "Aggregate: groupBy=[[test.c1]], aggr=[[SUM(test.c2)]]"

assert expected == plan.display()

expected = (
"Projection: test.c1, SUM(test.c2)\n"
" Aggregate: groupBy=[[test.c1]], aggr=[[SUM(test.c2)]]\n"
" TableScan: test projection=[c1, c2]"
"Aggregate: groupBy=[[test.c1]], aggr=[[SUM(test.c2)]]\n"
" TableScan: test projection=[c1, c2]"
)

assert expected == plan.display_indent()
Expand All @@ -366,9 +365,7 @@ def test_optimized_logical_plan(aggregate_df):
def test_execution_plan(aggregate_df):
plan = aggregate_df.execution_plan()

expected = (
"ProjectionExec: expr=[c1@0 as c1, SUM(test.c2)@1 as SUM(test.c2)]\n"
)
expected = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[SUM(test.c2)]\n" # noqa: E501

assert expected == plan.display()

Expand All @@ -382,7 +379,6 @@ def test_execution_plan(aggregate_df):

# indent plan will be different for everyone due to absolute path
# to filename, so we just check for some expected content
assert "ProjectionExec:" in indent
assert "AggregateExec:" in indent
assert "CoalesceBatchesExec:" in indent
assert "RepartitionExec:" in indent
Expand Down
4 changes: 4 additions & 0 deletions datafusion/tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
Cast,
TryCast,
Between,
Explain,
Extension,
)


Expand Down Expand Up @@ -143,6 +145,8 @@ def test_class_module_is_datafusion():
Cast,
TryCast,
Between,
Explain,
Extension,
]:
assert klass.__module__ == "datafusion.expr"

Expand Down
8 changes: 5 additions & 3 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::prelude::{
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
};
use datafusion_common::config::Extensions;
use datafusion_common::ScalarValue;
use pyo3::types::PyTuple;
use tokio::runtime::Runtime;
Expand Down Expand Up @@ -698,19 +699,20 @@ impl PySessionContext {
part: usize,
py: Python,
) -> PyResult<PyRecordBatchStream> {
let ctx = Arc::new(TaskContext::new(
let ctx = TaskContext::try_new(
"task_id".to_string(),
"session_id".to_string(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
Arc::new(RuntimeEnv::default()),
));
Extensions::default(),
);
// create a Tokio runtime to run the async code
let rt = Runtime::new().unwrap();
let plan = plan.plan.clone();
let fut: JoinHandle<datafusion_common::Result<SendableRecordBatchStream>> =
rt.spawn(async move { plan.execute(part, ctx) });
rt.spawn(async move { plan.execute(part, Arc::new(ctx?)) });
let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
Ok(PyRecordBatchStream::new(stream?))
}
Expand Down
2 changes: 2 additions & 0 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub mod cross_join;
pub mod empty_relation;
pub mod exists;
pub mod explain;
pub mod extension;
pub mod filter;
pub mod grouping_set;
pub mod in_list;
Expand Down Expand Up @@ -272,6 +273,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_class::<join::PyJoinConstraint>()?;
m.add_class::<cross_join::PyCrossJoin>()?;
m.add_class::<union::PyUnion>()?;
m.add_class::<extension::PyExtension>()?;
m.add_class::<filter::PyFilter>()?;
m.add_class::<projection::PyProjection>()?;
m.add_class::<table_scan::PyTableScan>()?;
Expand Down
52 changes: 52 additions & 0 deletions src/expr/extension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion_expr::Extension;
use pyo3::prelude::*;

use crate::sql::logical::PyLogicalPlan;

use super::logical_node::LogicalNode;

/// Python-facing wrapper around a DataFusion [`Extension`] logical plan node.
///
/// Exposed to Python under the name `Extension` in the `datafusion.expr`
/// module; the `subclass` option allows Python classes to derive from it.
#[pyclass(name = "Extension", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyExtension {
/// The wrapped DataFusion extension node.
pub node: Extension,
}

impl From<Extension> for PyExtension {
fn from(node: Extension) -> PyExtension {
PyExtension { node }
}
}

#[pymethods]
impl PyExtension {
    /// Returns the name reported by the wrapped extension node as an owned
    /// string.
    fn name(&self) -> PyResult<String> {
        let node_name = self.node.node.name();
        Ok(String::from(node_name))
    }
}

impl LogicalNode for PyExtension {
fn inputs(&self) -> Vec<PyLogicalPlan> {
vec![]
}

fn to_variant(&self, py: Python) -> PyResult<PyObject> {
Ok(self.clone().into_py(py))
}
}
2 changes: 2 additions & 0 deletions src/sql/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::expr::aggregate::PyAggregate;
use crate::expr::analyze::PyAnalyze;
use crate::expr::empty_relation::PyEmptyRelation;
use crate::expr::explain::PyExplain;
use crate::expr::extension::PyExtension;
use crate::expr::filter::PyFilter;
use crate::expr::limit::PyLimit;
use crate::expr::projection::PyProjection;
Expand Down Expand Up @@ -60,6 +61,7 @@ impl PyLogicalPlan {
LogicalPlan::Analyze(plan) => PyAnalyze::from(plan.clone()).to_variant(py),
LogicalPlan::EmptyRelation(plan) => PyEmptyRelation::from(plan.clone()).to_variant(py),
LogicalPlan::Explain(plan) => PyExplain::from(plan.clone()).to_variant(py),
LogicalPlan::Extension(plan) => PyExtension::from(plan.clone()).to_variant(py),
LogicalPlan::Filter(plan) => PyFilter::from(plan.clone()).to_variant(py),
LogicalPlan::Limit(plan) => PyLimit::from(plan.clone()).to_variant(py),
LogicalPlan::Projection(plan) => PyProjection::from(plan.clone()).to_variant(py),
Expand Down

0 comments on commit 9b2c3da

Please sign in to comment.