Pyo3 refactorings #740

Merged 4 commits on Jun 26, 2024.

src/context.rs (146 changes: 65 additions & 81 deletions)
```diff
@@ -60,7 +60,7 @@ use datafusion::prelude::{
     AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
 };
 use datafusion_common::ScalarValue;
-use pyo3::types::PyTuple;
+use pyo3::types::{PyDict, PyList, PyTuple};
 use tokio::task::JoinHandle;

 /// Configuration options for a SessionContext
```
```diff
@@ -291,24 +291,17 @@ impl PySessionContext {
     pub fn register_object_store(
         &mut self,
         scheme: &str,
-        store: &Bound<'_, PyAny>,
+        store: StorageContexts,
         host: Option<&str>,
     ) -> PyResult<()> {
-        let res: Result<(Arc<dyn ObjectStore>, String), PyErr> =
-            match StorageContexts::extract_bound(store) {
-                Ok(store) => match store {
-                    StorageContexts::AmazonS3(s3) => Ok((s3.inner, s3.bucket_name)),
-                    StorageContexts::GoogleCloudStorage(gcs) => Ok((gcs.inner, gcs.bucket_name)),
-                    StorageContexts::MicrosoftAzure(azure) => {
-                        Ok((azure.inner, azure.container_name))
-                    }
-                    StorageContexts::LocalFileSystem(local) => Ok((local.inner, "".to_string())),
-                },
-                Err(_e) => Err(PyValueError::new_err("Invalid object store")),
-            };
-
         // for most stores the "host" is the bucket name and can be inferred from the store
-        let (store, upstream_host) = res?;
+        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
+            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
+            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
+            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
+            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
+        };

         // let users override the host to match the api signature from upstream
         let derived_host = if let Some(host) = host {
             host
```
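Taking `StorageContexts` by value works because PyO3 runs `FromPyObject` conversions at the call boundary: a `#[derive(FromPyObject)]` enum tries each variant in order and raises a `TypeError` on its own when nothing matches, which is what makes the hand-rolled `extract_bound` match and its `PyValueError` fallback redundant. A minimal sketch of that mechanism, using hypothetical stand-in types rather than the real `StorageContexts` variants:

```rust
use pyo3::prelude::*;

// Hypothetical stand-ins for the real storage wrapper classes.
#[pyclass]
#[derive(Clone)]
struct PyAmazonS3 {
    bucket_name: String,
}

#[pyclass]
#[derive(Clone)]
struct PyLocalFileSystem {}

// Deriving FromPyObject on an enum makes PyO3 try each variant in order;
// a non-matching Python argument becomes a TypeError before the function
// body ever runs.
#[derive(FromPyObject)]
enum Storage {
    S3(PyAmazonS3),
    Local(PyLocalFileSystem),
}

#[pyfunction]
fn describe(store: Storage) -> String {
    match store {
        Storage::S3(s3) => format!("s3 bucket {}", s3.bucket_name),
        Storage::Local(_) => "local filesystem".to_string(),
    }
}
```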
```diff
@@ -434,105 +427,96 @@ impl PySessionContext {
     }

     /// Construct datafusion dataframe from Python list
     #[allow(clippy::wrong_self_convention)]
     pub fn from_pylist(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyList>,
         name: Option<&str>,
-        _py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object & convert to Arrow Table
-            let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
-            let args = PyTuple::new_bound(py, &[data]);
-            let table = table_class.call_method1("from_pylist", args)?.into();
-
-            // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, name, py)?;
-            Ok(df)
-        })
+        // Acquire GIL Token
+        let py = data.py();
+
+        // Instantiate pyarrow Table object & convert to Arrow Table
+        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+        let args = PyTuple::new_bound(py, &[data]);
+        let table = table_class.call_method1("from_pylist", args)?;
+
+        // Convert Arrow Table to datafusion DataFrame
+        let df = self.from_arrow_table(table, name, py)?;
+        Ok(df)
     }
```
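The enabler for this whole file is that `Bound<'_, T>` is a GIL-bound pointer: `data.py()` recovers the `Python<'_>` token from the argument itself, so both the unused `_py: Python` parameter and the `Python::with_gil` wrapper drop away, and the trailing `.into()` conversions to `PyObject` go with them. The same transformation is applied to `from_pydict`, `from_pandas`, and `from_polars` below. A minimal sketch of the pattern as a free function (hypothetical name, same pyarrow calls):

```rust
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};

// No Python argument and no with_gil: the Bound argument already proves
// the GIL is held, and its lifetime carries through to the result.
fn pylist_to_table<'py>(data: Bound<'py, PyList>) -> PyResult<Bound<'py, PyAny>> {
    let py = data.py(); // GIL token recovered from the Bound reference
    let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
    let args = PyTuple::new_bound(py, &[data]);
    table_class.call_method1("from_pylist", args)
}
```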

```diff
     /// Construct datafusion dataframe from Python dictionary
     #[allow(clippy::wrong_self_convention)]
     pub fn from_pydict(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyDict>,
         name: Option<&str>,
-        _py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object & convert to Arrow Table
-            let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
-            let args = PyTuple::new_bound(py, &[data]);
-            let table = table_class.call_method1("from_pydict", args)?.into();
-
-            // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, name, py)?;
-            Ok(df)
-        })
+        // Acquire GIL Token
+        let py = data.py();
+
+        // Instantiate pyarrow Table object & convert to Arrow Table
+        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+        let args = PyTuple::new_bound(py, &[data]);
+        let table = table_class.call_method1("from_pydict", args)?;
+
+        // Convert Arrow Table to datafusion DataFrame
+        let df = self.from_arrow_table(table, name, py)?;
+        Ok(df)
     }

     /// Construct datafusion dataframe from Arrow Table
     #[allow(clippy::wrong_self_convention)]
     pub fn from_arrow_table(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyAny>,
         name: Option<&str>,
-        _py: Python,
+        py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object & convert to batches
-            let table = data.call_method0(py, "to_batches")?;
-
-            let schema = data.getattr(py, "schema")?;
-            let schema = schema.extract::<PyArrowType<Schema>>(py)?;
-
-            // Cast PyObject to RecordBatch type
-            // Because create_dataframe() expects a vector of vectors of record batches
-            // here we need to wrap the vector of record batches in an additional vector
-            let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>(py)?;
-            let list_of_batches = PyArrowType::from(vec![batches.0]);
-            self.create_dataframe(list_of_batches, name, Some(schema), py)
-        })
+        // Instantiate pyarrow Table object & convert to batches
+        let table = data.call_method0("to_batches")?;
+
+        let schema = data.getattr("schema")?;
+        let schema = schema.extract::<PyArrowType<Schema>>()?;
+
+        // Cast PyAny to RecordBatch type
+        // Because create_dataframe() expects a vector of vectors of record batches
+        // here we need to wrap the vector of record batches in an additional vector
+        let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>()?;
+        let list_of_batches = PyArrowType::from(vec![batches.0]);
+        self.create_dataframe(list_of_batches, name, Some(schema), py)
     }
```
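`from_arrow_table` shows the other half of the migration: methods on `PyObject` (`Py<PyAny>`) each need an explicit token argument, while the same methods on `Bound<'_, PyAny>` carry it implicitly, so every `py` threaded through `call_method0`, `getattr`, and `extract` disappears. A side-by-side sketch of that call-pattern change, with hypothetical helper names:

```rust
use pyo3::prelude::*;

// Old style: PyObject methods take the token explicitly on every call.
fn schema_repr_old(py: Python<'_>, data: PyObject) -> PyResult<String> {
    let schema = data.getattr(py, "schema")?;
    schema.call_method0(py, "__repr__")?.extract(py)
}

// New style: Bound methods already know the GIL is held.
fn schema_repr_new(data: &Bound<'_, PyAny>) -> PyResult<String> {
    let schema = data.getattr("schema")?;
    schema.call_method0("__repr__")?.extract()
}
```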

```diff
     /// Construct datafusion dataframe from pandas
     #[allow(clippy::wrong_self_convention)]
     pub fn from_pandas(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyAny>,
         name: Option<&str>,
-        _py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object & convert to Arrow Table
-            let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
-            let args = PyTuple::new_bound(py, &[data]);
-            let table = table_class.call_method1("from_pandas", args)?.into();
-
-            // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, name, py)?;
-            Ok(df)
-        })
+        // Obtain GIL token
+        let py = data.py();
+
+        // Instantiate pyarrow Table object & convert to Arrow Table
+        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+        let args = PyTuple::new_bound(py, &[data]);
+        let table = table_class.call_method1("from_pandas", args)?;
+
+        // Convert Arrow Table to datafusion DataFrame
+        let df = self.from_arrow_table(table, name, py)?;
+        Ok(df)
     }

     /// Construct datafusion dataframe from polars
     #[allow(clippy::wrong_self_convention)]
     pub fn from_polars(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyAny>,
         name: Option<&str>,
-        _py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Convert Polars dataframe to Arrow Table
-            let table = data.call_method0(py, "to_arrow")?;
+        // Convert Polars dataframe to Arrow Table
+        let table = data.call_method0("to_arrow")?;

-            // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, name, py)?;
-            Ok(df)
-        })
+        // Convert Arrow Table to datafusion DataFrame
+        let df = self.from_arrow_table(table, name, data.py())?;
+        Ok(df)
     }

     pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> {
```
src/dataframe.rs (55 changes: 22 additions & 33 deletions)
```diff
@@ -423,17 +423,15 @@ impl PyDataFrame {

     /// Convert to Arrow Table
     /// Collect the batches and pass to Arrow Table
-    fn to_arrow_table(&self, py: Python) -> PyResult<PyObject> {
+    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
         let batches = self.collect(py)?.to_object(py);
         let schema: PyObject = self.schema().into_py(py);

-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object and use its from_batches method
-            let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
-            let args = PyTuple::new_bound(py, &[batches, schema]);
-            let table: PyObject = table_class.call_method1("from_batches", args)?.into();
-            Ok(table)
-        })
+        // Instantiate pyarrow Table object and use its from_batches method
+        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+        let args = PyTuple::new_bound(py, &[batches, schema]);
+        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
+        Ok(table)
     }
```
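In `dataframe.rs` the methods already receive `py: Python<'_>`, so the inner `Python::with_gil` only re-acquired a lock the caller provably holds (GIL acquisition is reentrant) while shadowing the outer token; deleting the closure is behavior-preserving. A minimal before/after sketch of that cleanup:

```rust
use pyo3::prelude::*;

fn version_redundant(_py: Python<'_>) -> PyResult<String> {
    // Reentrant no-op: the caller's token already proves the GIL is held.
    Python::with_gil(|py| py.import_bound("sys")?.getattr("version")?.extract())
}

fn version_direct(py: Python<'_>) -> PyResult<String> {
    py.import_bound("sys")?.getattr("version")?.extract()
}
```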

```diff
     fn execute_stream(&self, py: Python) -> PyResult<PyRecordBatchStream> {
@@ -464,51 +462,42 @@

     /// Convert to pandas dataframe with pyarrow
     /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
-    fn to_pandas(&self, py: Python) -> PyResult<PyObject> {
+    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
         let table = self.to_arrow_table(py)?;

-        Python::with_gil(|py| {
-            // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
-            let result = table.call_method0(py, "to_pandas")?;
-            Ok(result)
-        })
+        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
+        let result = table.call_method0(py, "to_pandas")?;
+        Ok(result)
     }

     /// Convert to Python list using pyarrow
     /// Each list item represents one row encoded as dictionary
-    fn to_pylist(&self, py: Python) -> PyResult<PyObject> {
+    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
         let table = self.to_arrow_table(py)?;

-        Python::with_gil(|py| {
-            // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
-            let result = table.call_method0(py, "to_pylist")?;
-            Ok(result)
-        })
+        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
+        let result = table.call_method0(py, "to_pylist")?;
+        Ok(result)
     }

     /// Convert to Python dictionary using pyarrow
     /// Each dictionary key is a column and the dictionary value represents the column values
     fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
         let table = self.to_arrow_table(py)?;

-        Python::with_gil(|py| {
-            // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
-            let result = table.call_method0(py, "to_pydict")?;
-            Ok(result)
-        })
+        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
+        let result = table.call_method0(py, "to_pydict")?;
+        Ok(result)
     }

     /// Convert to polars dataframe with pyarrow
     /// Collect the batches, pass to Arrow Table & then convert to polars DataFrame
-    fn to_polars(&self, py: Python) -> PyResult<PyObject> {
+    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
         let table = self.to_arrow_table(py)?;

-        Python::with_gil(|py| {
-            let dataframe = py.import_bound("polars")?.getattr("DataFrame")?;
-            let args = PyTuple::new_bound(py, &[table]);
-            let result: PyObject = dataframe.call1(args)?.into();
-            Ok(result)
-        })
+        let dataframe = py.import_bound("polars")?.getattr("DataFrame")?;
+        let args = PyTuple::new_bound(py, &[table]);
+        let result: PyObject = dataframe.call1(args)?.into();
+        Ok(result)
     }

     // Executes this DataFrame to get the total number of rows.
```
src/sql/logical.rs (4 changes: 2 additions & 2 deletions)
```diff
@@ -63,7 +63,7 @@ impl PyLogicalPlan {
 impl PyLogicalPlan {
     /// Return the specific logical operator
     pub fn to_variant(&self, py: Python) -> PyResult<PyObject> {
-        Python::with_gil(|_| match self.plan.as_ref() {
+        match self.plan.as_ref() {
             LogicalPlan::Aggregate(plan) => PyAggregate::from(plan.clone()).to_variant(py),
             LogicalPlan::Analyze(plan) => PyAnalyze::from(plan.clone()).to_variant(py),
             LogicalPlan::CrossJoin(plan) => PyCrossJoin::from(plan.clone()).to_variant(py),
@@ -85,7 +85,7 @@ impl PyLogicalPlan {
                 "Cannot convert this plan to a LogicalNode: {:?}",
                 other
             ))),
-        })
+        }
     }

     /// Get the inputs to this plan
```
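The `to_variant` change is the same cleanup in a third shape: a `Python::with_gil(|_| ...)` that ignores its own token was pure ceremony around the `match`, since the method signature already supplies `py`. For reference, a toy version of the dispatch pattern, with a hypothetical mini plan enum standing in for `LogicalPlan`:

```rust
use pyo3::exceptions::PyNotImplementedError;
use pyo3::prelude::*;

// Hypothetical wrapper for one supported variant.
#[pyclass]
#[derive(Clone)]
struct PyLimit {
    #[pyo3(get)]
    skip: usize,
}

// Hypothetical stand-in for LogicalPlan.
enum MiniPlan {
    Limit { skip: usize },
    Other,
}

// Same shape as to_variant: clone the matched plan into its #[pyclass]
// wrapper and return it as a PyObject; `py` is already in scope, so no
// with_gil is needed.
fn to_variant(py: Python<'_>, plan: &MiniPlan) -> PyResult<PyObject> {
    match plan {
        MiniPlan::Limit { skip } => Ok(PyLimit { skip: *skip }.into_py(py)),
        MiniPlan::Other => Err(PyNotImplementedError::new_err(
            "Cannot convert this plan to a LogicalNode",
        )),
    }
}
```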