Add SessionContext.execute() and a RecordBatchStream wrapper

Exposes execution of a single partition of a physical plan to Python:

* src/context.rs: new `PySessionContext::execute(plan, part)`, which calls
  `plan.execute(part, ctx)` inside a task spawned on a Tokio runtime and
  wraps the resulting stream in `PyRecordBatchStream`.
* src/record_batch.rs (new file): `PyRecordBatch` (with `to_pyarrow`) and
  `PyRecordBatchStream` (with `next`, which returns `None` when the
  underlying stream yields no further item).
* src/lib.rs: declares the `record_batch` module.
* datafusion/tests/test_dataframe.py: `test_execution_plan` now executes
  partition 0 of the plan and expects exactly one batch.

Review notes (not part of the patch; text before the first file header is
ignored when the patch is applied):

* NOTE(review): `execute` builds a new Tokio `Runtime` on every call, unwraps
  its construction, and lets it drop on return while the stream is handed
  back to Python. Confirm the returned stream never depends on tasks that
  were spawned on that runtime.
* NOTE(review): the `TaskContext` is built from placeholder ids, empty maps
  and `RuntimeEnv::default()`; nothing is taken from `self`. Confirm that
  ignoring the session's own configuration is intended.

diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index dcab86a49..18946888f 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -388,6 +388,15 @@ def test_execution_plan(aggregate_df):
     assert "RepartitionExec:" in indent
     assert "CsvExec:" in indent
 
+    ctx = SessionContext()
+    stream = ctx.execute(plan, 0)
+    # get the one and only batch
+    batch = stream.next()
+    assert batch is not None
+    # there should be no more batches
+    batch = stream.next()
+    assert batch is None
+
 
 def test_repartition(df):
     df.repartition(2)
diff --git a/src/context.rs b/src/context.rs
index 8dcd1d6ff..1acf5f289 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -28,7 +28,9 @@ use pyo3::prelude::*;
 use crate::catalog::{PyCatalog, PyTable};
 use crate::dataframe::PyDataFrame;
 use crate::dataset::Dataset;
-use crate::errors::DataFusionError;
+use crate::errors::{py_datafusion_err, DataFusionError};
+use crate::physical_plan::PyExecutionPlan;
+use crate::record_batch::PyRecordBatchStream;
 use crate::sql::logical::PyLogicalPlan;
 use crate::store::StorageContexts;
 use crate::udaf::PyAggregateUDF;
@@ -39,14 +41,17 @@ use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::datasource::TableProvider;
 use datafusion::datasource::MemTable;
-use datafusion::execution::context::{SessionConfig, SessionContext};
+use datafusion::execution::context::{SessionConfig, SessionContext, TaskContext};
 use datafusion::execution::disk_manager::DiskManagerConfig;
 use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
+use datafusion::physical_plan::SendableRecordBatchStream;
 use datafusion::prelude::{
     AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
 };
 use datafusion_common::ScalarValue;
+use tokio::runtime::Runtime;
+use tokio::task::JoinHandle;
 
 #[pyclass(name = "SessionConfig", module = "datafusion", subclass, unsendable)]
 #[derive(Clone, Default)]
@@ -579,6 +584,30 @@ impl PySessionContext {
             Err(err) => Ok(format!("Error: {:?}", err.to_string())),
         }
     }
+
+    /// Execute a partition of an execution plan and return a stream of record batches
+    pub fn execute(
+        &self,
+        plan: PyExecutionPlan,
+        part: usize,
+        py: Python,
+    ) -> PyResult<PyRecordBatchStream> {
+        let ctx = Arc::new(TaskContext::new(
+            "task_id".to_string(),
+            "session_id".to_string(),
+            HashMap::new(),
+            HashMap::new(),
+            HashMap::new(),
+            Arc::new(RuntimeEnv::default()),
+        ));
+        // create a Tokio runtime to run the async code
+        let rt = Runtime::new().unwrap();
+        let plan = plan.plan.clone();
+        let fut: JoinHandle<datafusion_common::Result<SendableRecordBatchStream>> =
+            rt.spawn(async move { plan.execute(part, ctx) });
+        let stream = wait_for_future(py, fut).map_err(|e| py_datafusion_err(e))?;
+        Ok(PyRecordBatchStream::new(stream?))
+    }
 }
 
 impl PySessionContext {
diff --git a/src/lib.rs b/src/lib.rs
index f6d404efd..d9898db61 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -37,6 +37,7 @@ mod expr;
 mod functions;
 pub mod physical_plan;
 mod pyarrow_filter_expression;
+mod record_batch;
 pub mod sql;
 pub mod store;
 pub mod substrait;
diff --git a/src/record_batch.rs b/src/record_batch.rs
new file mode 100644
index 000000000..15b70e8ce
--- /dev/null
+++ b/src/record_batch.rs
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::utils::wait_for_future;
+use datafusion::arrow::pyarrow::PyArrowConvert;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::physical_plan::SendableRecordBatchStream;
+use futures::StreamExt;
+use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
+
+#[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
+pub struct PyRecordBatch {
+    batch: RecordBatch,
+}
+
+#[pymethods]
+impl PyRecordBatch {
+    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+        self.batch.to_pyarrow(py)
+    }
+}
+
+impl From<RecordBatch> for PyRecordBatch {
+    fn from(batch: RecordBatch) -> Self {
+        Self { batch }
+    }
+}
+
+#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
+pub struct PyRecordBatchStream {
+    stream: SendableRecordBatchStream,
+}
+
+impl PyRecordBatchStream {
+    pub fn new(stream: SendableRecordBatchStream) -> Self {
+        Self { stream }
+    }
+}
+
+#[pymethods]
+impl PyRecordBatchStream {
+    fn next(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
+        let result = self.stream.next();
+        match wait_for_future(py, result) {
+            None => Ok(None),
+            Some(Ok(b)) => Ok(Some(b.into())),
+            Some(Err(e)) => Err(e.into()),
+        }
+    }
+}