Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CREATE EXPERIMENT, expand support for WITH kwargs #796

Merged
merged 12 commits into from
Oct 7, 2022
389 changes: 356 additions & 33 deletions dask_planner/src/parser.rs

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions dask_planner/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use std::sync::Arc;
use crate::dialect::DaskDialect;
use crate::parser::{DaskParser, DaskStatement};
use crate::sql::logical::analyze_table::AnalyzeTablePlanNode;
use crate::sql::logical::create_experiment::CreateExperimentPlanNode;
use crate::sql::logical::create_model::CreateModelPlanNode;
use crate::sql::logical::create_table::CreateTablePlanNode;
use crate::sql::logical::create_view::CreateViewPlanNode;
Expand Down Expand Up @@ -427,6 +428,17 @@ impl DaskSQLContext {
with_options: create_model.with_options,
}),
})),
DaskStatement::CreateExperiment(create_experiment) => {
Ok(LogicalPlan::Extension(Extension {
node: Arc::new(CreateExperimentPlanNode {
experiment_name: create_experiment.name,
input: self._logical_relational_algebra(create_experiment.select)?,
if_not_exists: create_experiment.if_not_exists,
or_replace: create_experiment.or_replace,
with_options: create_experiment.with_options,
}),
}))
}
DaskStatement::PredictModel(predict_model) => Ok(LogicalPlan::Extension(Extension {
node: Arc::new(PredictModelPlanNode {
model_schema: predict_model.schema_name,
Expand Down
9 changes: 9 additions & 0 deletions dask_planner/src/sql/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::sql::types::rel_data_type_field::RelDataTypeField;
pub mod aggregate;
pub mod analyze_table;
pub mod create_catalog_schema;
pub mod create_experiment;
pub mod create_memory_table;
pub mod create_model;
pub mod create_table;
Expand Down Expand Up @@ -39,6 +40,7 @@ use pyo3::prelude::*;

use self::analyze_table::AnalyzeTablePlanNode;
use self::create_catalog_schema::CreateCatalogSchemaPlanNode;
use self::create_experiment::CreateExperimentPlanNode;
use self::create_model::CreateModelPlanNode;
use self::create_table::CreateTablePlanNode;
use self::create_view::CreateViewPlanNode;
Expand Down Expand Up @@ -149,6 +151,11 @@ impl PyLogicalPlan {
to_py_plan(self.current_node.as_ref())
}

/// LogicalPlan::CreateExperiment as PyCreateExperiment
pub fn create_experiment(&self) -> PyResult<create_experiment::PyCreateExperiment> {
to_py_plan(self.current_node.as_ref())
}

/// LogicalPlan::DropTable as DropTable
pub fn drop_table(&self) -> PyResult<drop_table::PyDropTable> {
to_py_plan(self.current_node.as_ref())
Expand Down Expand Up @@ -295,6 +302,8 @@ impl PyLogicalPlan {
let node = extension.node.as_any();
if node.downcast_ref::<CreateModelPlanNode>().is_some() {
"CreateModel"
} else if node.downcast_ref::<CreateExperimentPlanNode>().is_some() {
"CreateExperiment"
} else if node.downcast_ref::<CreateCatalogSchemaPlanNode>().is_some() {
"CreateCatalogSchema"
} else if node.downcast_ref::<CreateTablePlanNode>().is_some() {
Expand Down
130 changes: 130 additions & 0 deletions dask_planner/src/sql/logical/create_experiment.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use crate::parser::PySqlArg;
use crate::sql::exceptions::py_type_err;
use crate::sql::logical;
use pyo3::prelude::*;

use datafusion_expr::logical_plan::UserDefinedLogicalNode;
use datafusion_expr::{Expr, LogicalPlan};

use fmt::Debug;
use std::{any::Any, fmt, sync::Arc};

use datafusion_common::DFSchemaRef;

#[derive(Clone)]
pub struct CreateExperimentPlanNode {
pub experiment_name: String,
pub input: LogicalPlan,
pub if_not_exists: bool,
pub or_replace: bool,
pub with_options: Vec<(String, PySqlArg)>,
}

impl Debug for CreateExperimentPlanNode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.fmt_for_explain(f)
}
}

impl UserDefinedLogicalNode for CreateExperimentPlanNode {
fn as_any(&self) -> &dyn Any {
self
}

fn inputs(&self) -> Vec<&LogicalPlan> {
vec![&self.input]
}

fn schema(&self) -> &DFSchemaRef {
self.input.schema()
}

fn expressions(&self) -> Vec<Expr> {
// there is no need to expose any expressions here since DataFusion would
// not be able to do anything with expressions that are specific to
// CREATE EXPERIMENT
vec![]
}

fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"CreateExperiment: experiment_name={}",
self.experiment_name
)
}

fn from_template(
&self,
_exprs: &[Expr],
inputs: &[LogicalPlan],
) -> Arc<dyn UserDefinedLogicalNode> {
assert_eq!(inputs.len(), 1, "input size inconsistent");
Arc::new(CreateExperimentPlanNode {
experiment_name: self.experiment_name.clone(),
input: inputs[0].clone(),
if_not_exists: self.if_not_exists,
or_replace: self.or_replace,
with_options: self.with_options.clone(),
})
}
}

#[pyclass(name = "CreateExperiment", module = "dask_planner", subclass)]
pub struct PyCreateExperiment {
pub(crate) create_experiment: CreateExperimentPlanNode,
}

#[pymethods]
impl PyCreateExperiment {
/// Creating an experiment requires that a subquery be passed to the CREATE EXPERIMENT
/// statement to be used to gather the dataset which should be used for the
/// experiment. This function returns that portion of the statement.
#[pyo3(name = "getSelectQuery")]
fn get_select_query(&self) -> PyResult<logical::PyLogicalPlan> {
Ok(self.create_experiment.input.clone().into())
}

#[pyo3(name = "getExperimentName")]
fn get_experiment_name(&self) -> PyResult<String> {
Ok(self.create_experiment.experiment_name.clone())
}

#[pyo3(name = "getIfNotExists")]
fn get_if_not_exists(&self) -> PyResult<bool> {
Ok(self.create_experiment.if_not_exists)
}

#[pyo3(name = "getOrReplace")]
pub fn get_or_replace(&self) -> PyResult<bool> {
Ok(self.create_experiment.or_replace)
}

#[pyo3(name = "getSQLWithOptions")]
fn sql_with_options(&self) -> PyResult<Vec<(String, PySqlArg)>> {
Ok(self.create_experiment.with_options.clone())
}
}

impl TryFrom<logical::LogicalPlan> for PyCreateExperiment {
type Error = PyErr;

fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {
match logical_plan {
logical::LogicalPlan::Extension(extension) => {
if let Some(ext) = extension
.node
.as_any()
.downcast_ref::<CreateExperimentPlanNode>()
{
Ok(PyCreateExperiment {
create_experiment: ext.clone(),
})
} else {
Err(py_type_err("unexpected plan"))
}
}
_ => Err(py_type_err("unexpected plan")),
}
}
}
25 changes: 4 additions & 21 deletions dask_planner/src/sql/logical/create_model.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use crate::sql::exceptions::py_type_err;
use crate::sql::logical;
use crate::sql::parser_utils::DaskParserUtils;
use pyo3::prelude::*;

use datafusion_expr::logical_plan::UserDefinedLogicalNode;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_sql::sqlparser::ast::Expr as SqlParserExpr;

use fmt::Debug;
use std::collections::HashMap;
use std::{any::Any, fmt, sync::Arc};

use crate::parser::PySqlArg;
use datafusion_common::DFSchemaRef;

#[derive(Clone)]
Expand All @@ -19,7 +17,7 @@ pub struct CreateModelPlanNode {
pub input: LogicalPlan,
pub if_not_exists: bool,
pub or_replace: bool,
pub with_options: Vec<SqlParserExpr>,
pub with_options: Vec<(String, PySqlArg)>,
}

impl Debug for CreateModelPlanNode {
Expand Down Expand Up @@ -99,23 +97,8 @@ impl PyCreateModel {
}

#[pyo3(name = "getSQLWithOptions")]
fn sql_with_options(&self) -> PyResult<HashMap<String, String>> {
let mut options: HashMap<String, String> = HashMap::new();
for elem in &self.create_model.with_options {
match elem {
SqlParserExpr::BinaryOp { left, op: _, right } => {
options.insert(
DaskParserUtils::str_from_expr(*left.clone()),
DaskParserUtils::str_from_expr(*right.clone()),
);
}
_ => {
return Err(py_type_err(
"Encountered non SqlParserExpr::BinaryOp expression, with arguments can only be of Key/Value pair types"));
}
}
}
Ok(options)
fn sql_with_options(&self) -> PyResult<Vec<(String, PySqlArg)>> {
Ok(self.create_model.with_options.clone())
}
}

Expand Down
45 changes: 5 additions & 40 deletions dask_planner/src/sql/logical/create_table.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use crate::parser::PySqlArg;
use crate::sql::exceptions::py_type_err;
use crate::sql::logical;

use pyo3::prelude::*;

use datafusion_expr::logical_plan::UserDefinedLogicalNode;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_sql::sqlparser::ast::{Expr as SqlParserExpr, Value};

use fmt::Debug;
use std::collections::HashMap;
use std::{any::Any, fmt, sync::Arc};

use datafusion_common::{DFSchema, DFSchemaRef};
Expand All @@ -19,7 +19,7 @@ pub struct CreateTablePlanNode {
pub table_name: String,
pub if_not_exists: bool,
pub or_replace: bool,
pub with_options: Vec<SqlParserExpr>,
pub with_options: Vec<(String, PySqlArg)>,
}

impl Debug for CreateTablePlanNode {
Expand Down Expand Up @@ -91,43 +91,8 @@ impl PyCreateTable {
}

#[pyo3(name = "getSQLWithOptions")]
fn sql_with_options(&self) -> PyResult<HashMap<String, String>> {
let mut options: HashMap<String, String> = HashMap::new();
for elem in &self.create_table.with_options {
if let SqlParserExpr::BinaryOp { left, op: _, right } = elem {
let key: Result<String, PyErr> = match *left.clone() {
SqlParserExpr::Identifier(ident) => Ok(ident.value),
_ => Err(py_type_err(format!(
"unexpected `left` Value type encountered: {:?}",
left
))),
};
let val: Result<String, PyErr> = match *right.clone() {
SqlParserExpr::Value(value) => match value {
Value::SingleQuotedString(e) => Ok(e.replace('\'', "")),
Value::DoubleQuotedString(e) => Ok(e.replace('\"', "")),
Value::Boolean(e) => {
if e {
Ok("True".to_string())
} else {
Ok("False".to_string())
}
}
Value::Number(e, ..) => Ok(e),
_ => Err(py_type_err(format!(
"unexpected Value type encountered: {:?}",
value
))),
},
_ => Err(py_type_err(format!(
"encountered unexpected Expr type: {:?}",
right
))),
};
options.insert(key?, val?);
}
}
Ok(options)
fn sql_with_options(&self) -> PyResult<Vec<(String, PySqlArg)>> {
Ok(self.create_table.with_options.clone())
}
}

Expand Down
25 changes: 4 additions & 21 deletions dask_planner/src/sql/logical/export_model.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
use crate::parser::PySqlArg;
use crate::sql::exceptions::py_type_err;
use crate::sql::logical;
use crate::sql::parser_utils::DaskParserUtils;
use pyo3::prelude::*;

use datafusion_expr::logical_plan::UserDefinedLogicalNode;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_sql::sqlparser::ast::Expr as SqlParserExpr;

use fmt::Debug;
use std::collections::HashMap;
use std::{any::Any, fmt, sync::Arc};

use datafusion_common::{DFSchema, DFSchemaRef};
Expand All @@ -17,7 +15,7 @@ use datafusion_common::{DFSchema, DFSchemaRef};
pub struct ExportModelPlanNode {
pub schema: DFSchemaRef,
pub model_name: String,
pub with_options: Vec<SqlParserExpr>,
pub with_options: Vec<(String, PySqlArg)>,
}

impl Debug for ExportModelPlanNode {
Expand Down Expand Up @@ -77,23 +75,8 @@ impl PyExportModel {
}

#[pyo3(name = "getSQLWithOptions")]
fn sql_with_options(&self) -> PyResult<HashMap<String, String>> {
let mut options: HashMap<String, String> = HashMap::new();
for elem in &self.export_model.with_options {
match elem {
SqlParserExpr::BinaryOp { left, op: _, right } => {
options.insert(
DaskParserUtils::str_from_expr(*left.clone()),
DaskParserUtils::str_from_expr(*right.clone()),
);
}
_ => {
return Err(py_type_err(
"Encountered non SqlParserExpr::BinaryOp expression, with arguments can only be of Key/Value pair types"));
}
}
}
Ok(options)
fn sql_with_options(&self) -> PyResult<Vec<(String, PySqlArg)>> {
Ok(self.export_model.with_options.clone())
}
}

Expand Down
Loading