Expose PyLogicalPlan::to_variant and add pytest coverage for it

Review notes (free text ahead of the first "diff --git" header; patch tools
skip it):

* datafusion/tests/test_expr.py (new file)
    A pytest fixture builds a SessionContext and registers
    testing/data/csv/aggregate_test_100.csv as table "test". Each test runs
    one SQL query, takes df.logical_plan(), and asserts the class returned
    by to_variant() on the root and, where checked, on plan.input():
      - test_projection : Projection, whose input is a TableScan
      - test_filter     : Projection, whose input is a Filter
      - test_limit      : Limit
      - test_aggregate  : Projection, whose input is an Aggregate
      - test_sort       : Sort

* src/sql/logical.rs
    Imports py_runtime_err and the six Py* wrapper types, and adds
    PyLogicalPlan::to_variant. The method matches on the wrapped
    LogicalPlan; for Projection, TableScan, Filter, Limit, Sort and
    Aggregate it clones the inner plan node, converts it to the matching
    Py* wrapper and returns it as a Python object. Every other variant
    yields the error built by py_runtime_err with the plan's Debug output.

Changes made while restoring this patch:

  1. Line breaks and indentation were missing from the copy under review.
     They are reconstructed here (4-space indentation for both Python and
     Rust). NOTE(review): whitespace is reconstructed by convention, so
     confirm the context lines against src/sql/logical.rs before applying.
  2. The return type of to_variant read "PyResult" with no type argument.
     It is written as PyResult<PyObject> to match the into_py(py) calls in
     every Ok arm. NOTE(review): confirm against the original commit.
  3. to_variant already receives a `py: Python` token, so the surrounding
     Python::with_gil(|_| ...) call (whose closure ignored its argument) is
     replaced by a plain match. The number of added lines is unchanged, so
     the hunk headers still hold. The "index" line for logical.rs still
     names the blob of the original commit.
  4. The trailing context line "pub fn inputs(&self) -> Vec {" is kept as
     found. NOTE(review): its element type presumably went missing the same
     way as in item 2; verify against the target file.

diff --git a/datafusion/tests/test_expr.py b/datafusion/tests/test_expr.py
new file mode 100644
index 000000000..4a7db879a
--- /dev/null
+++ b/datafusion/tests/test_expr.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datafusion import SessionContext
+from datafusion.expr import (
+    Projection,
+    Filter,
+    Aggregate,
+    Limit,
+    Sort,
+    TableScan,
+)
+import pytest
+
+
+@pytest.fixture
+def test_ctx():
+    ctx = SessionContext()
+    ctx.register_csv("test", "testing/data/csv/aggregate_test_100.csv")
+    return ctx
+
+
+def test_projection(test_ctx):
+    df = test_ctx.sql("select c1, 123, c1 < 123 from test")
+    plan = df.logical_plan()
+
+    plan = plan.to_variant()
+    assert isinstance(plan, Projection)
+
+    plan = plan.input().to_variant()
+    assert isinstance(plan, TableScan)
+
+
+def test_filter(test_ctx):
+    df = test_ctx.sql("select c1 from test WHERE c1 > 5")
+    plan = df.logical_plan()
+
+    plan = plan.to_variant()
+    assert isinstance(plan, Projection)
+
+    plan = plan.input().to_variant()
+    assert isinstance(plan, Filter)
+
+
+def test_limit(test_ctx):
+    df = test_ctx.sql("select c1 from test LIMIT 10")
+    plan = df.logical_plan()
+
+    plan = plan.to_variant()
+    assert isinstance(plan, Limit)
+
+
+def test_aggregate(test_ctx):
+    df = test_ctx.sql("select c1, COUNT(*) from test GROUP BY c1")
+    plan = df.logical_plan()
+
+    plan = plan.to_variant()
+    assert isinstance(plan, Projection)
+
+    plan = plan.input().to_variant()
+    assert isinstance(plan, Aggregate)
+
+
+def test_sort(test_ctx):
+    df = test_ctx.sql("select c1 from test order by c1")
+    plan = df.logical_plan()
+
+    plan = plan.to_variant()
+    assert isinstance(plan, Sort)
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index dcd7baa58..08d19619d 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -17,6 +17,13 @@
 
 use std::sync::Arc;
 
+use crate::errors::py_runtime_err;
+use crate::expr::aggregate::PyAggregate;
+use crate::expr::filter::PyFilter;
+use crate::expr::limit::PyLimit;
+use crate::expr::projection::PyProjection;
+use crate::expr::sort::PySort;
+use crate::expr::table_scan::PyTableScan;
 use datafusion_expr::LogicalPlan;
 use pyo3::prelude::*;
 
@@ -37,6 +44,22 @@ impl PyLogicalPlan {
 
 #[pymethods]
 impl PyLogicalPlan {
+    /// Return the specific logical operator
+    fn to_variant(&self, py: Python) -> PyResult<PyObject> {
+        match self.plan.as_ref() {
+            LogicalPlan::Projection(plan) => Ok(PyProjection::from(plan.clone()).into_py(py)),
+            LogicalPlan::TableScan(plan) => Ok(PyTableScan::from(plan.clone()).into_py(py)),
+            LogicalPlan::Filter(plan) => Ok(PyFilter::from(plan.clone()).into_py(py)),
+            LogicalPlan::Limit(plan) => Ok(PyLimit::from(plan.clone()).into_py(py)),
+            LogicalPlan::Sort(plan) => Ok(PySort::from(plan.clone()).into_py(py)),
+            LogicalPlan::Aggregate(plan) => Ok(PyAggregate::from(plan.clone()).into_py(py)),
+            other => Err(py_runtime_err(format!(
+                "Cannot convert this plan to a LogicalNode: {:?}",
+                other
+            ))),
+        }
+    }
+
     /// Get the inputs to this plan
     pub fn inputs(&self) -> Vec {
         let mut inputs = vec![];