From fb8b0880b97605e31d255eb41a922e793c9e7486 Mon Sep 17 00:00:00 2001
From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Date: Thu, 22 Sep 2022 07:25:11 -0700
Subject: [PATCH 01/10] Initial commit
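
Adds a dedicated parser path, logical plan node, and Python plugin wiring for
CREATE EXPERIMENT, and starts moving WITH-option handling from raw sqlparser
expressions to the new PySqlArg/PySqlKwarg wrappers so that ARRAY, MAP, and
MULTISET arguments survive the trip into Python.

A rough sketch of the statement shape this unblocks (adapted from the
integration tests touched below; the tpot class name and kwargs are just the
ones those tests use, and `c` is assumed to be an existing dask_sql.Context):

    c.sql(
        """
        CREATE EXPERIMENT my_exp WITH (
            automl_class = 'tpot.TPOTClassifier',
            automl_kwargs = (population_size = 2, generations = 2, cv = 2),
            target_column = 'target'
        ) AS (
            SELECT x, y, x*y > 0 AS target
            FROM timeseries
            LIMIT 100
        )
        """
    )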
---
 dask_planner/src/parser.rs                    | 318 +++++++++++++++++-
 dask_planner/src/sql.rs                       |  12 +
 dask_planner/src/sql/logical.rs               |   9 +
 .../src/sql/logical/create_experiment.rs      | 147 ++++++++
 dask_planner/src/sql/logical/create_model.rs  |  22 +-
 dask_planner/src/sql/parser_utils.rs          |   7 +
 .../physical/rel/custom/create_experiment.py  |  26 +-
 dask_sql/utils.py                             |  20 +-
 tests/integration/test_model.py               |   8 -
 9 files changed, 522 insertions(+), 47 deletions(-)
 create mode 100644 dask_planner/src/sql/logical/create_experiment.rs

diff --git a/dask_planner/src/parser.rs b/dask_planner/src/parser.rs
index 26ea8b49e..38fa844f6 100644
--- a/dask_planner/src/parser.rs
+++ b/dask_planner/src/parser.rs
@@ -2,8 +2,12 @@
 //!
 //! Declares a SQL parser based on sqlparser that handles custom formats that we need.
 
+use crate::sql::exceptions::py_type_err;
+use pyo3::prelude::*;
+
 use crate::dialect::DaskDialect;
 use crate::sql::parser_utils::DaskParserUtils;
+use datafusion_sql::sqlparser::ast::Ident;
 use datafusion_sql::sqlparser::{
     ast::{Expr, SelectItem, Statement as SQLStatement},
     dialect::{keywords::Keyword, Dialect},
@@ -18,6 +22,114 @@ macro_rules! parser_err {
     };
 }
 
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum CustomExpr {
+    Map(Vec<Expr>),
+    Multiset(Vec<Expr>),
+    Nested(Vec<PySqlKwarg>),
+}
+
+#[pyclass(name = "SqlArg", module = "datafusion")]
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct PySqlArg {
+    expr: Option<Expr>,
+    custom: Option<CustomExpr>,
+}
+
+impl PySqlArg {
+    pub fn new(expr: Option<Expr>, custom: Option<CustomExpr>) -> Self {
+        Self { expr, custom }
+    }
+}
+
+#[pymethods]
+impl PySqlArg {
+    #[pyo3(name = "isCollection")]
+    pub fn is_collection(&self) -> PyResult<bool> {
+        match &self.custom {
+            Some(CustomExpr::Nested(_)) => Ok(false),
+            Some(_) => Ok(true),
+            None => match &self.expr {
+                Some(expr) => match expr {
+                    Expr::Array(_) => Ok(true),
+                    _ => Ok(false),
+                },
+                None => Err(py_type_err(
+                    "PySqlArg must contain either a standard or custom AST expression",
+                )),
+            },
+        }
+    }
+
+    #[pyo3(name = "isKwargs")]
+    pub fn is_kwargs(&self) -> PyResult<bool> {
+        match &self.custom {
+            Some(CustomExpr::Nested(_)) => Ok(true),
+            Some(_) => Ok(false),
+            None => Ok(false),
+        }
+    }
+
+    #[pyo3(name = "getOperator")]
+    pub fn get_operator(&self) -> PyResult<String> {
+        match &self.custom {
+            Some(custom_expr) => match custom_expr {
+                CustomExpr::Map(_) => Ok("MAP".to_string()),
+                CustomExpr::Multiset(_) => Ok("MULTISET".to_string()),
+                CustomExpr::Nested(_) => Err(py_type_err("Expected Map or Multiset, got Nested")),
+            },
+            None => match &self.expr {
+                Some(expr) => match expr {
+                    Expr::Array(_) => Ok("ARRAY".to_string()),
+                    other => Err(py_type_err(format!("Expected Array, got {:?}", other))),
+                },
+                None => Err(py_type_err(
+                    "PySqlArg must contain either a standard or custom AST expression",
+                )),
+            },
+        }
+    }
+
+    #[pyo3(name = "getOperandList")]
+    pub fn get_operand_list(&self) -> PyResult<Vec<PySqlArg>> {
+        match &self.custom {
+            Some(custom_expr) => match custom_expr {
+                CustomExpr::Map(exprs) | CustomExpr::Multiset(exprs) => Ok(exprs
+                    .iter()
+                    .map(|e| PySqlArg::new(Some(e.clone()), None))
+                    .collect()),
+                CustomExpr::Nested(_) => Err(py_type_err("Expected Map or Multiset, got Nested")),
+            },
+            None => match &self.expr {
+                Some(expr) => match expr {
+                    Expr::Array(array) => Ok(array
+                        .elem
+                        .iter()
+                        .map(|e| PySqlArg::new(Some(e.clone()), None))
+                        .collect()),
+                    _ => Ok(vec![]),
+                },
+                None => Err(py_type_err(
+                    "PySqlArg must contain either a standard or custom AST expression",
+                )),
+            },
+        }
+    }
+}
+
+#[pyclass(name = "SqlKwarg", module = "datafusion")]
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct PySqlKwarg {
+    pub key: Ident,
+    pub value: PySqlArg,
+}
+
+impl PySqlKwarg {
+    pub fn new(key: Ident, value: PySqlArg) -> Self {
+        Self { key, value }
+    }
+}
+
 /// Dask-SQL extension DDL for `CREATE MODEL`
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct CreateModel {
@@ -30,6 +142,21 @@ pub struct CreateModel {
     /// To replace the model or not
     pub or_replace: bool,
     /// with options
+    pub with_options: Vec<PySqlKwarg>,
+}
+
+/// Dask-SQL extension DDL for `CREATE EXPERIMENT`
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct CreateExperiment {
+    /// experiment name
+    pub name: String,
+    /// input Query
+    pub select: DaskStatement,
+    /// IF NOT EXISTS
+    pub if_not_exists: bool,
+    /// To replace the model or not
+    pub or_replace: bool,
+    /// with options
     pub with_options: Vec<Expr>,
 }
 
@@ -167,6 +294,8 @@ pub enum DaskStatement {
     Statement(Box<SQLStatement>),
     /// Extension: `CREATE MODEL`
     CreateModel(Box<CreateModel>),
+    /// Extension: `CREATE EXPERIMENT`
+    CreateExperiment(Box<CreateExperiment>),
     /// Extension: `CREATE SCHEMA`
     CreateCatalogSchema(Box<CreateCatalogSchema>),
     /// Extension: `CREATE TABLE`
@@ -384,6 +513,19 @@ impl<'a> DaskParser<'a> {
                         // use custom parsing
                         self.parse_create_model(if_not_exists, or_replace)
                     }
+                    "experiment" => {
+                        // move one token forward
+                        self.parser.next_token();
+
+                        let if_not_exists = self.parser.parse_keywords(&[
+                            Keyword::IF,
+                            Keyword::NOT,
+                            Keyword::EXISTS,
+                        ]);
+
+                        // use custom parsing
+                        self.parse_create_experiment(if_not_exists, or_replace)
+                    }
                     "schema" => {
                         // move one token forward
                         self.parser.next_token();
@@ -658,13 +800,14 @@ impl<'a> DaskParser<'a> {
         ])?;
         self.parser.prev_token();
 
-        let sql_statement = self.parse_statement()?;
+        let select = self.parse_statement()?;
+
         self.parser.expect_token(&Token::RParen)?;
 
         let predict = PredictModel {
             schema_name: mdl_schema,
             name: mdl_name,
-            select: sql_statement,
+            select,
         };
         Ok(DaskStatement::PredictModel(Box::new(predict)))
     }
@@ -677,6 +820,106 @@ impl<'a> DaskParser<'a> {
     ) -> Result<DaskStatement, ParserError> {
         let model_name = self.parser.parse_object_name()?;
         self.parser.expect_keyword(Keyword::WITH)?;
+        self.parser.expect_token(&Token::LParen)?;
+
+        let with_options = self.parse_comma_separated(DaskParser::parse_key_value_pair)?;
+
+        self.parser.expect_token(&Token::RParen)?;
+
+        // Parse the nested query statement
+        self.parser.expect_keyword(Keyword::AS)?;
+        self.parser.expect_token(&Token::LParen)?;
+
+        // Limit our input to ANALYZE, DESCRIBE, SELECT, SHOW statements
+        // TODO: find a more sophisticated way to allow any statement that would return a table
+        self.parser.expect_one_of_keywords(&[
+            Keyword::SELECT,
+            Keyword::DESCRIBE,
+            Keyword::SHOW,
+            Keyword::ANALYZE,
+        ])?;
+        self.parser.prev_token();
+
+        let select = self.parse_statement()?;
+
+        self.parser.expect_token(&Token::RParen)?;
+
+        let create = CreateModel {
+            name: model_name.to_string(),
+            select,
+            if_not_exists,
+            or_replace,
+            with_options,
+        };
+        Ok(DaskStatement::CreateModel(Box::new(create)))
+    }
+
+    // copied from sqlparser crate and adapted to work with DaskParser
+    fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
+    where
+        F: FnMut(&mut DaskParser<'a>) -> Result<T, ParserError>,
+    {
+        let mut values = vec![];
+        loop {
+            values.push(f(self)?);
+            if !self.parser.consume_token(&Token::Comma) {
+                break;
+            }
+        }
+        Ok(values)
+    }
+
+    fn parse_key_value_pair(&mut self) -> Result<PySqlKwarg, ParserError> {
+        let key = self.parser.parse_identifier()?;
+        self.parser.expect_token(&Token::Eq)?;
+        match self.parser.next_token() {
+            Token::LParen => {
+                let key_value_pairs =
+                    self.parse_comma_separated(DaskParser::parse_key_value_pair)?;
+                self.parser.expect_token(&Token::RParen)?;
+                Ok(PySqlKwarg::new(
+                    key,
+                    PySqlArg::new(None, Some(CustomExpr::Nested(key_value_pairs))),
+                ))
+            }
+            Token::Word(w) if w.value.to_lowercase().as_str() == "map" => {
+                // TODO this does not support map or multiset expressions within the map
+                self.parser.expect_token(&Token::LBracket)?;
+                let values = self.parser.parse_comma_separated(Parser::parse_expr)?;
+                self.parser.expect_token(&Token::RBracket)?;
+                Ok(PySqlKwarg::new(
+                    key,
+                    PySqlArg::new(None, Some(CustomExpr::Map(values))),
+                ))
+            }
+            Token::Word(w) if w.value.to_lowercase().as_str() == "multiset" => {
+                // TODO this does not support map or multiset expressions within the multiset
+                self.parser.expect_token(&Token::LBracket)?;
+                let values = self.parser.parse_comma_separated(Parser::parse_expr)?;
+                self.parser.expect_token(&Token::RBracket)?;
+                Ok(PySqlKwarg::new(
+                    key,
+                    PySqlArg::new(None, Some(CustomExpr::Multiset(values))),
+                ))
+            }
+            _ => {
+                self.parser.prev_token();
+                Ok(PySqlKwarg::new(
+                    key,
+                    PySqlArg::new(Some(self.parser.parse_expr()?), None),
+                ))
+            }
+        }
+    }
+
+    /// Parse Dask-SQL CREATE EXPERIMENT statement
+    fn parse_create_experiment(
+        &mut self,
+        if_not_exists: bool,
+        or_replace: bool,
+    ) -> Result<DaskStatement, ParserError> {
+        let experiment_name = self.parser.parse_object_name()?;
+        self.parser.expect_keyword(Keyword::WITH)?;
 
         // `table_name` has been parsed at this point but is needed in `parse_table_factor`, reset consumption
         self.parser.prev_token();
@@ -703,14 +946,14 @@ impl<'a> DaskParser<'a> {
 
         self.parser.expect_token(&Token::RParen)?;
 
-        let create = CreateModel {
-            name: model_name.to_string(),
+        let create = CreateExperiment {
+            name: experiment_name.to_string(),
             select,
             if_not_exists,
             or_replace,
             with_options,
         };
-        Ok(DaskStatement::CreateModel(Box::new(create)))
+        Ok(DaskStatement::CreateExperiment(Box::new(create)))
     }
 
     /// Parse Dask-SQL CREATE {IF NOT EXISTS | OR REPLACE} SCHEMA ... statement
@@ -957,3 +1200,68 @@ impl<'a> DaskParser<'a> {
         })))
     }
 }
+
+#[cfg(test)]
+mod test {
+    use crate::parser::{DaskParser, DaskStatement};
+
+    #[test]
+    fn create_model() {
+        let sql = r#"CREATE MODEL my_model WITH (
+    model_class = 'mock.MagicMock',
+    target_column = 'target',
+    fit_kwargs = (
+        first_arg = 3,
+        second_arg = ARRAY [ 1, 2 ],
+        third_arg = MAP [ 'a', 1 ],
+        forth_arg = MULTISET [ 1, 1, 2, 3 ]
+    )
+) AS (
+    SELECT x, y, x*y > 0 AS target
+    FROM timeseries
+    LIMIT 100
+)"#;
+        let statements = DaskParser::parse_sql(sql).unwrap();
+        assert_eq!(1, statements.len());
+
+        match &statements[0] {
+            DaskStatement::CreateModel(create_model) => {
+                // test Debug
+                let expected = "[\
+                    PySqlKwarg { key: Ident { value: \"model_class\", quote_style: None }, value: Expr(Value(SingleQuotedString(\"mock.MagicMock\"))) }, \
+                    PySqlKwarg { key: Ident { value: \"target_column\", quote_style: None }, value: Expr(Value(SingleQuotedString(\"target\"))) }, \
+                    PySqlKwarg { key: Ident { value: \"fit_kwargs\", quote_style: None }, value: Nested([\
+                        PySqlKwarg { key: Ident { value: \"first_arg\", quote_style: None }, value: Expr(Value(Number(\"3\", false))) }, \
+                        PySqlKwarg { key: Ident { value: \"second_arg\", quote_style: None }, value: Expr(Array(Array { elem: [Value(Number(\"1\", false)), Value(Number(\"2\", false))], named: true })) }, \
+                        PySqlKwarg { key: Ident { value: \"third_arg\", quote_style: None }, value: Map([Value(SingleQuotedString(\"a\")), Value(Number(\"1\", false))]) }, \
+                        PySqlKwarg { key: Ident { value: \"forth_arg\", quote_style: None }, value: Multiset([Value(Number(\"1\", false)), Value(Number(\"1\", false)), Value(Number(\"2\", false)), Value(Number(\"3\", false))]) }\
+                    ]) }\
+                ]";
+                assert_eq!(expected, &format!("{:?}", create_model.with_options));
+
+                // test Display
+                let expected = "model_class = 'mock.MagicMock', \
+                    target_column = 'target', \
+                    fit_kwargs = (\
+                        first_arg = '3', \
+                        second_arg = ARRAY[1, 2], \
+                        third_arg = MAP [ 'a', 1 ], \
+                        forth_arg = MULTISET [ 1, 1, 2, 3 ]\
+                    )";
+                assert_eq!(
+                    expected,
+                    format!(
+                        "{}",
+                        create_model
+                            .with_options
+                            .iter()
+                            .map(|pair| format!("{}", pair))
+                            .collect::<Vec<String>>()
+                            .join(", ")
+                    )
+                )
+            }
+            _ => panic!(),
+        }
+    }
+}
diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs
index 6ae956867..6512c54d5 100644
--- a/dask_planner/src/sql.rs
+++ b/dask_planner/src/sql.rs
@@ -31,6 +31,7 @@ use std::sync::Arc;
 use crate::dialect::DaskDialect;
 use crate::parser::{DaskParser, DaskStatement};
 use crate::sql::logical::analyze_table::AnalyzeTablePlanNode;
+use crate::sql::logical::create_experiment::CreateExperimentPlanNode;
 use crate::sql::logical::create_model::CreateModelPlanNode;
 use crate::sql::logical::create_table::CreateTablePlanNode;
 use crate::sql::logical::create_view::CreateViewPlanNode;
@@ -427,6 +428,17 @@ impl DaskSQLContext {
                     with_options: create_model.with_options,
                 }),
             })),
+            DaskStatement::CreateExperiment(create_experiment) => {
+                Ok(LogicalPlan::Extension(Extension {
+                    node: Arc::new(CreateExperimentPlanNode {
+                        experiment_name: create_experiment.name,
+                        input: self._logical_relational_algebra(create_experiment.select)?,
+                        if_not_exists: create_experiment.if_not_exists,
+                        or_replace: create_experiment.or_replace,
+                        with_options: create_experiment.with_options,
+                    }),
+                }))
+            }
             DaskStatement::PredictModel(predict_model) => Ok(LogicalPlan::Extension(Extension {
                 node: Arc::new(PredictModelPlanNode {
                     model_schema: predict_model.schema_name,
diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs
index 1a7dc865b..fb50e49e7 100644
--- a/dask_planner/src/sql/logical.rs
+++ b/dask_planner/src/sql/logical.rs
@@ -5,6 +5,7 @@ use crate::sql::types::rel_data_type_field::RelDataTypeField;
 pub mod aggregate;
 pub mod analyze_table;
 pub mod create_catalog_schema;
+pub mod create_experiment;
 pub mod create_memory_table;
 pub mod create_model;
 pub mod create_table;
@@ -39,6 +40,7 @@ use pyo3::prelude::*;
 
 use self::analyze_table::AnalyzeTablePlanNode;
 use self::create_catalog_schema::CreateCatalogSchemaPlanNode;
+use self::create_experiment::CreateExperimentPlanNode;
 use self::create_model::CreateModelPlanNode;
 use self::create_table::CreateTablePlanNode;
 use self::create_view::CreateViewPlanNode;
@@ -149,6 +151,11 @@ impl PyLogicalPlan {
         to_py_plan(self.current_node.as_ref())
     }
 
+    /// LogicalPlan::CreateExperiment as PyCreateExperiment
+    pub fn create_experiment(&self) -> PyResult<create_experiment::PyCreateExperiment> {
+        to_py_plan(self.current_node.as_ref())
+    }
+
     /// LogicalPlan::DropTable as DropTable
    pub fn drop_table(&self) -> PyResult<drop_table::PyDropTable> {
         to_py_plan(self.current_node.as_ref())
@@ -295,6 +302,8 @@ impl PyLogicalPlan {
                 let node = extension.node.as_any();
                 if node.downcast_ref::<CreateModelPlanNode>().is_some() {
                     "CreateModel"
+                } else if node.downcast_ref::<CreateExperimentPlanNode>().is_some() {
+                    "CreateExperiment"
                 } else if node.downcast_ref::<CreateCatalogSchemaPlanNode>().is_some() {
                     "CreateCatalogSchema"
                 } else if node.downcast_ref::<CreateTablePlanNode>().is_some() {
diff --git a/dask_planner/src/sql/logical/create_experiment.rs b/dask_planner/src/sql/logical/create_experiment.rs
new file mode 100644
index 000000000..c40673948
--- /dev/null
+++ b/dask_planner/src/sql/logical/create_experiment.rs
@@ -0,0 +1,147 @@
+use crate::sql::exceptions::py_type_err;
+use crate::sql::logical;
+use crate::sql::parser_utils::DaskParserUtils;
+use pyo3::prelude::*;
+
+use datafusion_expr::logical_plan::UserDefinedLogicalNode;
+use datafusion_expr::{Expr, LogicalPlan};
+use datafusion_sql::sqlparser::ast::Expr as SqlParserExpr;
+
+use fmt::Debug;
+use std::collections::HashMap;
+use std::{any::Any, fmt, sync::Arc};
+
+use datafusion_common::DFSchemaRef;
+
+#[derive(Clone)]
+pub struct CreateExperimentPlanNode {
+    pub experiment_name: String,
+    pub input: LogicalPlan,
+    pub if_not_exists: bool,
+    pub or_replace: bool,
+    pub with_options: Vec<SqlParserExpr>,
+}
+
+impl Debug for CreateExperimentPlanNode {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        self.fmt_for_explain(f)
+    }
+}
+
+impl UserDefinedLogicalNode for CreateExperimentPlanNode {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.input]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        self.input.schema()
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        // there is no need to expose any expressions here since DataFusion would
+        // not be able to do anything with expressions that are specific to
+        // CREATE EXPERIMENT
+        vec![]
+    }
+
+    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "CreateExperiment: experiment_name={}",
+            self.experiment_name
+        )
+    }
+
+    fn from_template(
+        &self,
+        _exprs: &[Expr],
+        inputs: &[LogicalPlan],
+    ) -> Arc<dyn UserDefinedLogicalNode + Send + Sync> {
+        assert_eq!(inputs.len(), 1, "input size inconsistent");
+        Arc::new(CreateExperimentPlanNode {
+            experiment_name: self.experiment_name.clone(),
+            input: inputs[0].clone(),
+            if_not_exists: self.if_not_exists,
+            or_replace: self.or_replace,
+            with_options: self.with_options.clone(),
+        })
+    }
+}
+
+#[pyclass(name = "CreateExperiment", module = "dask_planner", subclass)]
+pub struct PyCreateExperiment {
+    pub(crate) create_experiment: CreateExperimentPlanNode,
+}
+
+#[pymethods]
+impl PyCreateExperiment {
+    /// Creating an experiment requires that a subquery be passed to the CREATE EXPERIMENT
+    /// statement to be used to gather the dataset which should be used for the
+    /// experiment. This function returns that portion of the statement.
+    #[pyo3(name = "getSelectQuery")]
+    fn get_select_query(&self) -> PyResult<logical::PyLogicalPlan> {
+        Ok(self.create_experiment.input.clone().into())
+    }
+
+    #[pyo3(name = "getExperimentName")]
+    fn get_experiment_name(&self) -> PyResult<String> {
+        Ok(self.create_experiment.experiment_name.clone())
+    }
+
+    #[pyo3(name = "getIfNotExists")]
+    fn get_if_not_exists(&self) -> PyResult<bool> {
+        Ok(self.create_experiment.if_not_exists)
+    }
+
+    #[pyo3(name = "getOrReplace")]
+    pub fn get_or_replace(&self) -> PyResult<bool> {
+        Ok(self.create_experiment.or_replace)
+    }
+
+    #[pyo3(name = "getSQLWithOptions")]
+    fn sql_with_options(&self) -> PyResult<HashMap<String, String>> {
+        let mut options: HashMap<String, String> = HashMap::new();
+        for elem in &self.create_experiment.with_options {
+            match elem {
+                SqlParserExpr::BinaryOp { left, op: _, right } => {
+                    options.insert(
+                        DaskParserUtils::str_from_expr(*left.clone()),
+                        DaskParserUtils::str_from_expr(*right.clone()),
+                    );
+                }
+                _ => {
+                    return Err(py_type_err(
+                        "Encountered non SqlParserExpr::BinaryOp expression, with arguments can only be of Key/Value pair types"));
+                }
+            }
+        }
+        Ok(options)
+    }
+}
+
+impl TryFrom<logical::LogicalPlan> for PyCreateExperiment {
+    type Error = PyErr;
+
+    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {
+        match logical_plan {
+            logical::LogicalPlan::Extension(extension) => {
+                if let Some(ext) = extension
+                    .node
+                    .as_any()
+                    .downcast_ref::<CreateExperimentPlanNode>()
+                {
+                    Ok(PyCreateExperiment {
+                        create_experiment: ext.clone(),
+                    })
+                } else {
+                    Err(py_type_err("unexpected plan"))
+                }
+            }
+            _ => Err(py_type_err("unexpected plan")),
+        }
+    }
+}
diff --git a/dask_planner/src/sql/logical/create_model.rs b/dask_planner/src/sql/logical/create_model.rs
index 53a65e7e0..d465a5d33 100644
--- a/dask_planner/src/sql/logical/create_model.rs
+++ b/dask_planner/src/sql/logical/create_model.rs
@@ -1,16 +1,15 @@
 use crate::sql::exceptions::py_type_err;
 use crate::sql::logical;
-use crate::sql::parser_utils::DaskParserUtils;
 use pyo3::prelude::*;
 
 use datafusion_expr::logical_plan::UserDefinedLogicalNode;
 use datafusion_expr::{Expr, LogicalPlan};
-use datafusion_sql::sqlparser::ast::Expr as SqlParserExpr;
 
 use fmt::Debug;
 use std::collections::HashMap;
 use std::{any::Any, fmt, sync::Arc};
 
+use crate::parser::{PySqlArg, PySqlKwarg};
 use datafusion_common::DFSchemaRef;
 
 #[derive(Clone)]
@@ -19,7 +18,7 @@ pub struct CreateModelPlanNode {
     pub input: LogicalPlan,
     pub if_not_exists: bool,
     pub or_replace: bool,
-    pub with_options: Vec<SqlParserExpr>,
+    pub with_options: Vec<PySqlKwarg>,
 }
 
 impl Debug for CreateModelPlanNode {
@@ -99,21 +98,10 @@ impl PyCreateModel {
     }
 
     #[pyo3(name = "getSQLWithOptions")]
-    fn sql_with_options(&self) -> PyResult<HashMap<String, String>> {
-        let mut options: HashMap<String, String> = HashMap::new();
+    fn sql_with_options(&self) -> PyResult<HashMap<String, PySqlArg>> {
+        let mut options: HashMap<String, PySqlArg> = HashMap::new();
         for elem in &self.create_model.with_options {
-            match elem {
-                SqlParserExpr::BinaryOp { left, op: _, right } => {
-                    options.insert(
-                        DaskParserUtils::str_from_expr(*left.clone()),
-                        DaskParserUtils::str_from_expr(*right.clone()),
-                    );
-                }
-                _ => {
-                    return Err(py_type_err(
-                        "Encountered non SqlParserExpr::BinaryOp expression, with arguments can only be of Key/Value pair types"));
-                }
-            }
+            options.insert(elem.key.value.clone(), elem.value.clone());
         }
         Ok(options)
     }
diff --git a/dask_planner/src/sql/parser_utils.rs b/dask_planner/src/sql/parser_utils.rs
index 040ff5943..1f8c96011 100644
--- a/dask_planner/src/sql/parser_utils.rs
+++ b/dask_planner/src/sql/parser_utils.rs
@@ -66,6 +66,13 @@ impl DaskParserUtils {
                 _ => unimplemented!("Unimplemented Value type: {:?}", value),
             },
             SqlParserExpr::Nested(nested_expr) => Self::str_from_expr(*nested_expr),
+            SqlParserExpr::BinaryOp { left, op, right } => format!(
+                "{} {} {}",
+                Self::str_from_expr(*left),
+                op,
+                Self::str_from_expr(*right)
+            ),
+            SqlParserExpr::Array(e) => e.to_string(),
             _ => unimplemented!("Unimplemented SqlParserExpr type: {:?}", expression),
         }
     }
diff --git a/dask_sql/physical/rel/custom/create_experiment.py b/dask_sql/physical/rel/custom/create_experiment.py
index 1bfd27a89..2f564272a 100644
--- a/dask_sql/physical/rel/custom/create_experiment.py
+++ b/dask_sql/physical/rel/custom/create_experiment.py
@@ -10,7 +10,7 @@
 
 if TYPE_CHECKING:
     import dask_sql
-    from dask_sql.java import org
+    from dask_sql.rust import LogicalPlan
 
 logger = logging.getLogger(__name__)
 
@@ -95,19 +95,22 @@ class CreateExperimentPlugin(BaseRelPlugin):
 
     """
 
-    class_name = "com.dask.sql.parser.SqlCreateExperiment"
+    class_name = "CreateExperiment"
 
-    def convert(
-        self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context"
-    ) -> DataContainer:
-        select = sql.getSelect()
-        schema_name, experiment_name = context.fqn(sql.getExperimentName())
-        kwargs = convert_sql_kwargs(sql.getKwargs())
+    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
+        create_experiment = rel.create_experiment()
+        select = create_experiment.getSelectQuery()
+
+        schema_name, experiment_name = (
+            context.schema_name,
+            create_experiment.getExperimentName(),
+        )
+        kwargs = convert_sql_kwargs(create_experiment.getSQLWithOptions())
 
         if experiment_name in context.schema[schema_name].experiments:
-            if sql.getIfNotExists():
+            if create_experiment.getIfNotExists():
                 return
-            elif not sql.getReplace():
+            elif not create_experiment.getReplace():
                 raise RuntimeError(
                     f"A experiment with the name {experiment_name} is already present."
                 )
@@ -139,8 +142,7 @@ def convert(
         automl_kwargs = kwargs.pop("automl_kwargs", {})
         logger.info(parameters)
 
-        select_query = context._to_sql_string(select)
-        training_df = context.sql(select_query)
+        training_df = context.sql(select)
         if not target_column:
             raise ValueError(
                 "Unsupervised Algorithm cannot be tuned Automatically,"
diff --git a/dask_sql/utils.py b/dask_sql/utils.py
index 31e2bc2eb..e8fbff045 100644
--- a/dask_sql/utils.py
+++ b/dask_sql/utils.py
@@ -136,11 +136,21 @@ def convert_sql_kwargs(
     Convert the Rust Vec of key/value pairs into a Dict containing the keys and values
     """
 
-    def convert_literal(value: str):
-        if value.lower() == "true":
-            return True
-        elif value.lower() == "false":
-            return False
+    def convert_literal(value):
+        if value.isCollection():
+            operator_mapping = {
+                "ARRAY": list,
+                "MAP": lambda x: dict(zip(x[::2], x[1::2])),
+                "MULTISET": set,
+                "ROW": tuple,
+            }
+
+            operator = operator_mapping[str(value.getOperator())]
+            operands = [convert_literal(o) for o in value.getOperandList()]
+
+            return operator(operands)
+        elif value.isKwargs():
+            return convert_sql_kwargs(value.getMap())
         else:
             return value
 
diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py
index 044a56fcc..c319a2a99 100644
--- a/tests/integration/test_model.py
+++ b/tests/integration/test_model.py
@@ -205,9 +205,6 @@ def test_create_model_with_prediction(c, training_df):
 
 
 # TODO - many ML tests fail on clusters without sklearn - can we avoid this?
-@pytest.mark.skip(
-    reason="WIP DataFusion - fails to parse ARRAY in KV pairs in WITH clause, WITH clause was previsouly ignored"
-)
 @skip_if_external_scheduler
 def test_iterative_and_prediction(c, training_df):
     c.sql(
@@ -229,7 +226,6 @@ def test_iterative_and_prediction(c, training_df):
 
 
 # TODO - many ML tests fail on clusters without sklearn - can we avoid this?
-@pytest.mark.skip(reason="WIP DataFusion")
 @skip_if_external_scheduler
 def test_show_models(c, training_df):
     c.sql(
@@ -315,7 +311,6 @@ def test_wrong_training_or_prediction(c, training_df):
     )
 
 
-@pytest.mark.skip(reason="WIP DataFusion")
 def test_correct_argument_passing(c, training_df):
     c.sql(
         """
@@ -674,7 +669,6 @@ def test_mlflow_export_lightgbm(c, training_df, tmpdir):
 
 
 # TODO - many ML tests fail on clusters without sklearn - can we avoid this?
-@pytest.mark.skip(reason="WIP DataFusion")
 @skip_if_external_scheduler
 def test_ml_experiment(c, client, training_df):
 
@@ -869,7 +863,6 @@ def test_ml_experiment(c, client, training_df):
 
 
 # TODO - many ML tests fail on clusters without sklearn - can we avoid this?
-@pytest.mark.skip(reason="WIP DataFusion")
 @skip_if_external_scheduler
 def test_experiment_automl_classifier(c, client, training_df):
     tpot = pytest.importorskip("tpot", reason="tpot not installed")
@@ -895,7 +888,6 @@ def test_experiment_automl_classifier(c, client, training_df):
 
 
 # TODO - many ML tests fail on clusters without sklearn - can we avoid this?
-@pytest.mark.skip(reason="WIP DataFusion")
 @skip_if_external_scheduler
 def test_experiment_automl_regressor(c, client, training_df):
     tpot = pytest.importorskip("tpot", reason="tpot not installed")

From cfadbcb44ce71ea91241f0576a1ece1b1f41a7f6 Mon Sep 17 00:00:00 2001
From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Date: Thu, 22 Sep 2022 09:31:32 -0700
Subject: [PATCH 02/10] Unblock test_correct_argument_passing
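
Drops the PySqlKwarg wrapper in favor of plain (String, PySqlArg) pairs and
exposes getSqlType()/getSqlValue() so the Python side can rebuild literals
through the existing sql_to_python_value mapping, instead of string-matching
"true"/"false". Numeric literals come back typed as BIGINT and are upgraded
to DOUBLE when the value string contains a decimal point.

A hedged sketch of the round trip this enables (the option names here are
made up for illustration; convert_sql_kwargs is the real entry point in
dask_sql/utils.py):

    from dask_sql.utils import convert_sql_kwargs

    # WITH (model_class = 'mock.MagicMock', epochs = 10, lr = 0.1) arrives
    # from getSQLWithOptions() as pairs of (name, PySqlArg) and converts to:
    kwargs = convert_sql_kwargs(create_model.getSQLWithOptions())
    assert kwargs == {"model_class": "mock.MagicMock", "epochs": 10, "lr": 0.1}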
---
 dask_planner/src/parser.rs                   | 114 ++++++++++++-------
 dask_planner/src/sql/logical/create_model.rs |  13 +--
 dask_sql/utils.py                            |  29 +++--
 tests/integration/test_model.py              |   4 +-
 4 files changed, 96 insertions(+), 64 deletions(-)

diff --git a/dask_planner/src/parser.rs b/dask_planner/src/parser.rs
index 38fa844f6..443421eb3 100644
--- a/dask_planner/src/parser.rs
+++ b/dask_planner/src/parser.rs
@@ -3,13 +3,13 @@
 //! Declares a SQL parser based on sqlparser that handles custom formats that we need.
 
 use crate::sql::exceptions::py_type_err;
+use crate::sql::types::SqlTypeName;
 use pyo3::prelude::*;
 
 use crate::dialect::DaskDialect;
 use crate::sql::parser_utils::DaskParserUtils;
-use datafusion_sql::sqlparser::ast::Ident;
 use datafusion_sql::sqlparser::{
-    ast::{Expr, SelectItem, Statement as SQLStatement},
+    ast::{Expr, SelectItem, Statement as SQLStatement, Value},
     dialect::{keywords::Keyword, Dialect},
     parser::{Parser, ParserError},
     tokenizer::{Token, Tokenizer},
@@ -26,7 +26,7 @@ macro_rules! parser_err {
 pub enum CustomExpr {
     Map(Vec<Expr>),
     Multiset(Vec<Expr>),
-    Nested(Vec<PySqlKwarg>),
+    Nested(Vec<(String, PySqlArg)>),
 }
 
 #[pyclass(name = "SqlArg", module = "datafusion")]
@@ -70,26 +70,6 @@ impl PySqlArg {
         }
     }
 
-    #[pyo3(name = "getOperator")]
-    pub fn get_operator(&self) -> PyResult<String> {
-        match &self.custom {
-            Some(custom_expr) => match custom_expr {
-                CustomExpr::Map(_) => Ok("MAP".to_string()),
-                CustomExpr::Multiset(_) => Ok("MULTISET".to_string()),
-                CustomExpr::Nested(_) => Err(py_type_err("Expected Map or Multiset, got Nested")),
-            },
-            None => match &self.expr {
-                Some(expr) => match expr {
-                    Expr::Array(_) => Ok("ARRAY".to_string()),
-                    other => Err(py_type_err(format!("Expected Array, got {:?}", other))),
-                },
-                None => Err(py_type_err(
-                    "PySqlArg must contain either a standard or custom AST expression",
-                )),
-            },
-        }
-    }
-
     #[pyo3(name = "getOperandList")]
     pub fn get_operand_list(&self) -> PyResult<Vec<PySqlArg>> {
         match &self.custom {
@@ -115,18 +95,66 @@ impl PySqlArg {
             },
         }
     }
-}
 
-#[pyclass(name = "SqlKwarg", module = "datafusion")]
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct PySqlKwarg {
-    pub key: Ident,
-    pub value: PySqlArg,
-}
+    #[pyo3(name = "getKwargs")]
+    pub fn get_kwargs(&self) -> PyResult<Vec<(String, PySqlArg)>> {
+        match &self.custom {
+            Some(CustomExpr::Nested(kwargs)) => Ok(kwargs.clone()),
+            _ => Ok(vec![]),
+        }
+    }
 
-impl PySqlKwarg {
-    pub fn new(key: Ident, value: PySqlArg) -> Self {
-        Self { key, value }
+    #[pyo3(name = "getSqlType")]
+    pub fn get_sql_type(&self) -> PyResult<SqlTypeName> {
+        match &self.custom {
+            Some(custom_expr) => match custom_expr {
+                CustomExpr::Map(_) => Ok(SqlTypeName::MAP),
+                CustomExpr::Multiset(_) => Ok(SqlTypeName::MULTISET),
+                CustomExpr::Nested(_) => Err(py_type_err("Expected Map or Multiset, got Nested")),
+            },
+            None => match &self.expr {
+                Some(Expr::Array(_)) => Ok(SqlTypeName::ARRAY),
+                Some(Expr::Value(scalar)) => match scalar {
+                    Value::SingleQuotedString(_) => Ok(SqlTypeName::VARCHAR),
+                    Value::Number(_, false) => Ok(SqlTypeName::BIGINT),
+                    unexpected => Err(py_type_err(format!(
+                        "Expected string, got {:?}",
+                        unexpected
+                    ))),
+                },
+                Some(unexpected) => Err(py_type_err(format!(
+                    "Expected array or scalar, got {:?}",
+                    unexpected
+                ))),
+                None => Err(py_type_err(
+                    "PySqlArg must contain either a standard or custom AST expression",
+                )),
+            },
+        }
+    }
+
+    #[pyo3(name = "getSqlValue")]
+    pub fn get_value(&self) -> PyResult<String> {
+        match &self.custom {
+            None => match &self.expr {
+                Some(Expr::Value(scalar)) => match scalar {
+                    Value::SingleQuotedString(string) => Ok(string.clone()),
+                    Value::Number(value, false) => Ok(value.to_string()),
+                    unexpected => Err(py_type_err(format!(
+                        "Expected string, got {:?}",
+                        unexpected
+                    ))),
+                },
+                unexpected => Err(py_type_err(format!(
+                    "Expected scalar value, got {:?}",
+                    unexpected
+                ))),
+            },
+            unexpected => Err(py_type_err(format!(
+                "Expected scalar value, got {:?}",
+                unexpected
+            ))),
+        }
     }
 }
 
@@ -142,7 +170,7 @@ pub struct CreateModel {
     /// To replace the model or not
     pub or_replace: bool,
     /// with options
-    pub with_options: Vec<PySqlKwarg>,
+    pub with_options: Vec<(String, PySqlArg)>,
 }
 
@@ -869,7 +897,7 @@ impl<'a> DaskParser<'a> {
         Ok(values)
     }
 
-    fn parse_key_value_pair(&mut self) -> Result<PySqlKwarg, ParserError> {
+    fn parse_key_value_pair(&mut self) -> Result<(String, PySqlArg), ParserError> {
         let key = self.parser.parse_identifier()?;
         self.parser.expect_token(&Token::Eq)?;
         match self.parser.next_token() {
@@ -877,8 +905,8 @@ impl<'a> DaskParser<'a> {
                 let key_value_pairs =
                     self.parse_comma_separated(DaskParser::parse_key_value_pair)?;
                 self.parser.expect_token(&Token::RParen)?;
-                Ok(PySqlKwarg::new(
-                    key,
+                Ok((
+                    key.value,
                     PySqlArg::new(None, Some(CustomExpr::Nested(key_value_pairs))),
                 ))
             }
@@ -887,8 +915,8 @@ impl<'a> DaskParser<'a> {
                 self.parser.expect_token(&Token::LBracket)?;
                 let values = self.parser.parse_comma_separated(Parser::parse_expr)?;
                 self.parser.expect_token(&Token::RBracket)?;
-                Ok(PySqlKwarg::new(
-                    key,
+                Ok((
+                    key.value,
                     PySqlArg::new(None, Some(CustomExpr::Map(values))),
                 ))
             }
@@ -897,15 +925,15 @@ impl<'a> DaskParser<'a> {
                 self.parser.expect_token(&Token::LBracket)?;
                 let values = self.parser.parse_comma_separated(Parser::parse_expr)?;
                 self.parser.expect_token(&Token::RBracket)?;
-                Ok(PySqlKwarg::new(
-                    key,
+                Ok((
+                    key.value,
                     PySqlArg::new(None, Some(CustomExpr::Multiset(values))),
                 ))
             }
             _ => {
                 self.parser.prev_token();
-                Ok(PySqlKwarg::new(
-                    key,
+                Ok((
+                    key.value,
                     PySqlArg::new(Some(self.parser.parse_expr()?), None),
                 ))
             }
diff --git a/dask_planner/src/sql/logical/create_model.rs b/dask_planner/src/sql/logical/create_model.rs
index d465a5d33..7a3f12bf2 100644
--- a/dask_planner/src/sql/logical/create_model.rs
+++ b/dask_planner/src/sql/logical/create_model.rs
@@ -6,10 +6,9 @@ use datafusion_expr::logical_plan::UserDefinedLogicalNode;
 use datafusion_expr::{Expr, LogicalPlan};
 
 use fmt::Debug;
-use std::collections::HashMap;
 use std::{any::Any, fmt, sync::Arc};
 
-use crate::parser::{PySqlArg, PySqlKwarg};
+use crate::parser::PySqlArg;
 use datafusion_common::DFSchemaRef;
 
 #[derive(Clone)]
@@ -19,7 +18,7 @@ pub struct CreateModelPlanNode {
     pub input: LogicalPlan,
     pub if_not_exists: bool,
     pub or_replace: bool,
-    pub with_options: Vec<PySqlKwarg>,
+    pub with_options: Vec<(String, PySqlArg)>,
 }
 
 impl Debug for CreateModelPlanNode {
@@ -99,12 +98,8 @@ impl PyCreateModel {
     }
 
     #[pyo3(name = "getSQLWithOptions")]
-    fn sql_with_options(&self) -> PyResult<HashMap<String, PySqlArg>> {
-        let mut options: HashMap<String, PySqlArg> = HashMap::new();
-        for elem in &self.create_model.with_options {
-            options.insert(elem.key.value.clone(), elem.value.clone());
-        }
-        Ok(options)
+    fn sql_with_options(&self) -> PyResult<Vec<(String, PySqlArg)>> {
+        Ok(self.create_model.with_options.clone())
     }
 }
 
diff --git a/dask_sql/utils.py b/dask_sql/utils.py
index e8fbff045..da50c1ae5 100644
--- a/dask_sql/utils.py
+++ b/dask_sql/utils.py
@@ -11,7 +11,9 @@
 import numpy as np
 import pandas as pd
 
+from dask_planner.rust import SqlTypeName
 from dask_sql.datacontainer import DataContainer
+from dask_sql.mappings import sql_to_python_value
 
 logger = logging.getLogger(__name__)
 
@@ -139,24 +141,31 @@ def convert_sql_kwargs(
     def convert_literal(value):
         if value.isCollection():
             operator_mapping = {
-                "ARRAY": list,
-                "MAP": lambda x: dict(zip(x[::2], x[1::2])),
-                "MULTISET": set,
-                "ROW": tuple,
+                "SqlTypeName.ARRAY": list,
+                "SqlTypeName.MAP": lambda x: dict(zip(x[::2], x[1::2])),
+                "SqlTypeName.MULTISET": set,
+                "SqlTypeName.ROW": tuple,
             }
 
-            operator = operator_mapping[str(value.getOperator())]
+            operator = operator_mapping[str(value.getSqlType())]
             operands = [convert_literal(o) for o in value.getOperandList()]
 
             return operator(operands)
         elif value.isKwargs():
-            return convert_sql_kwargs(value.getMap())
+            return convert_sql_kwargs(value.getKwargs())
         else:
-            return value
+            literal_type = value.getSqlType()
+            literal_value = value.getSqlValue()
 
-    return {
-        str(key): convert_literal(str(value)) for key, value in dict(sql_kwargs).items()
-    }
+            if literal_type == SqlTypeName.VARCHAR:
+                return value.getSqlValue()
+            elif literal_type == SqlTypeName.BIGINT and "." in literal_value:
+                literal_type = SqlTypeName.DOUBLE
+
+            python_value = sql_to_python_value(literal_type, literal_value)
+            return python_value
+
+    return {key: convert_literal(value) for key, value in dict(sql_kwargs).items()}
 
 
 def import_class(name: str) -> type:
diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py
index c319a2a99..65db0a38d 100644
--- a/tests/integration/test_model.py
+++ b/tests/integration/test_model.py
@@ -321,7 +321,7 @@ def test_correct_argument_passing(c, training_df):
                 first_arg = 3,
                 second_arg = ARRAY [ 1, 2 ],
                 third_arg = MAP [ 'a', 1 ],
-                forth_arg = MULTISET [ 1, 1, 2, 3 ]
+                fourth_arg = MULTISET [ 1, 1, 2, 3 ]
             )
         ) AS (
             SELECT x, y, x*y > 0 AS target
@@ -339,7 +339,7 @@ def test_correct_argument_passing(c, training_df):
     fit_function.assert_called_once()
     call_kwargs = fit_function.call_args.kwargs
     assert call_kwargs == dict(
-        first_arg=3, second_arg=[1, 2], third_arg={"a": 1}, forth_arg=set([1, 2, 3])
+        first_arg=3, second_arg=[1, 2], third_arg={"a": 1}, fourth_arg=set([1, 2, 3])
     )
 
 
From 13a0819b4ed95dcde7d99c370605ff3b4067baae Mon Sep 17 00:00:00 2001
From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Date: Thu, 22 Sep 2022 12:21:04 -0700
Subject: [PATCH 03/10] Resolve remaining tests in test_model.py
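
Routes CREATE EXPERIMENT, CREATE TABLE ... WITH, and EXPORT MODEL through the
shared parse_key_value_pair path, adds support for double-quoted identifiers,
booleans, and negative numbers, and fixes the getReplace/getOrReplace
mismatch in the Python plugin.

The updated test_correct_argument_passing exercises the full literal surface;
a trimmed sketch of what now round-trips (assuming the mock estimator that
test already uses, with `c` an existing dask_sql.Context):

    c.sql(
        """
        CREATE MODEL my_model WITH (
            model_class = 'mock.MagicMock',
            target_column = 'target',
            fit_kwargs = (
                integer = -300,
                float = 23.45,
                array = ARRAY [ 1, 2 ],
                dict = MAP [ 'a', 1 ],
                set = MULTISET [ 1, 1, 2, 3 ]
            )
        ) AS (
            SELECT x, y, x*y > 0 AS target
            FROM timeseries
            LIMIT 100
        )
        """
    )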
---
 dask_planner/src/parser.rs                    | 81 ++++++++++++-------
 .../src/sql/logical/create_experiment.rs      | 25 +-----
 dask_planner/src/sql/logical/create_table.rs  | 45 ++---------
 dask_planner/src/sql/logical/export_model.rs  | 25 +-----
 dask_planner/src/sql/parser_utils.rs          | 59 ++++----------
 .../physical/rel/custom/create_experiment.py  |  2 +-
 tests/integration/test_model.py               | 28 +++++--
 7 files changed, 106 insertions(+), 159 deletions(-)

diff --git a/dask_planner/src/parser.rs b/dask_planner/src/parser.rs
index 443421eb3..faee7ba23 100644
--- a/dask_planner/src/parser.rs
+++ b/dask_planner/src/parser.rs
@@ -9,7 +9,7 @@ use pyo3::prelude::*;
 use crate::dialect::DaskDialect;
 use crate::sql::parser_utils::DaskParserUtils;
 use datafusion_sql::sqlparser::{
-    ast::{Expr, SelectItem, Statement as SQLStatement, Value},
+    ast::{Expr, Ident, SelectItem, Statement as SQLStatement, UnaryOperator, Value},
     dialect::{keywords::Keyword, Dialect},
     parser::{Parser, ParserError},
     tokenizer::{Token, Tokenizer},
@@ -114,9 +114,21 @@ impl PySqlArg {
             None => match &self.expr {
                 Some(Expr::Array(_)) => Ok(SqlTypeName::ARRAY),
+                Some(Expr::Identifier(Ident { .. })) => Ok(SqlTypeName::VARCHAR),
                 Some(Expr::Value(scalar)) => match scalar {
-                    Value::SingleQuotedString(_) => Ok(SqlTypeName::VARCHAR),
+                    Value::Boolean(_) => Ok(SqlTypeName::BOOLEAN),
                     Value::Number(_, false) => Ok(SqlTypeName::BIGINT),
+                    Value::SingleQuotedString(_) => Ok(SqlTypeName::VARCHAR),
                     unexpected => Err(py_type_err(format!(
                         "Expected string, got {:?}",
                         unexpected
                     ))),
                 },
+                Some(Expr::UnaryOp {
+                    op: UnaryOperator::Minus,
+                    expr,
+                }) => match &**expr {
+                    Expr::Value(Value::Number(_, false)) => Ok(SqlTypeName::BIGINT),
+                    unexpected => Err(py_type_err(format!(
+                        "Expected string, got {:?}",
+                        unexpected
+                    ))),
+                },
                 Some(unexpected) => Err(py_type_err(format!(
                     "Expected array or scalar, got {:?}",
                     unexpected
                 ))),
                 None => Err(py_type_err(
                     "PySqlArg must contain either a standard or custom AST expression",
                 )),
             },
         }
     }
 
     #[pyo3(name = "getSqlValue")]
-    pub fn get_value(&self) -> PyResult<String> {
+    pub fn get_sql_value(&self) -> PyResult<String> {
         match &self.custom {
             None => match &self.expr {
+                Some(Expr::Identifier(Ident { value, .. })) => Ok(value.to_string()),
                 Some(Expr::Value(scalar)) => match scalar {
-                    Value::SingleQuotedString(string) => Ok(string.clone()),
+                    Value::Boolean(value) => Ok(value.to_string()),
+                    Value::SingleQuotedString(string) => Ok(string.to_string()),
                     Value::Number(value, false) => Ok(value.to_string()),
                     unexpected => Err(py_type_err(format!(
                         "Expected string, got {:?}",
                         unexpected
                     ))),
                 },
+                Some(Expr::UnaryOp {
+                    op: UnaryOperator::Minus,
+                    expr,
+                }) => match &**expr {
+                    Expr::Value(Value::Number(value, false)) => Ok(format!("-{}", value)),
+                    unexpected => Err(py_type_err(format!(
+                        "Expected string, got {:?}",
+                        unexpected
+                    ))),
+                },
                 unexpected => Err(py_type_err(format!(
                     "Expected scalar value, got {:?}",
                     unexpected
                 ))),
             },
@@ -185,7 +209,7 @@ pub struct CreateExperiment {
     /// To replace the model or not
     pub or_replace: bool,
     /// with options
-    pub with_options: Vec<Expr>,
+    pub with_options: Vec<(String, PySqlArg)>,
 }
 
 /// Dask-SQL extension DDL for `PREDICT`
@@ -222,7 +246,7 @@ pub struct CreateTable {
     /// or replace
     pub or_replace: bool,
     /// with options
-    pub with_options: Vec<Expr>,
+    pub with_options: Vec<(String, PySqlArg)>,
 }
 
 /// Dask-SQL extension DDL for `CREATE VIEW`
@@ -253,7 +277,7 @@ pub struct ExportModel {
     /// model name
     pub name: String,
     /// with options
-    pub with_options: Vec<Expr>,
+    pub with_options: Vec<(String, PySqlArg)>,
 }
 
 /// Dask-SQL extension DDL for `DESCRIBE MODEL`
@@ -847,11 +871,11 @@ impl<'a> DaskParser<'a> {
         or_replace: bool,
     ) -> Result<DaskStatement, ParserError> {
         let model_name = self.parser.parse_object_name()?;
+
+        // Parse WITH options
         self.parser.expect_keyword(Keyword::WITH)?;
         self.parser.expect_token(&Token::LParen)?;
-
         let with_options = self.parse_comma_separated(DaskParser::parse_key_value_pair)?;
-
         self.parser.expect_token(&Token::RParen)?;
 
         // Parse the nested query statement
@@ -947,14 +971,12 @@ impl<'a> DaskParser<'a> {
         or_replace: bool,
     ) -> Result<DaskStatement, ParserError> {
         let experiment_name = self.parser.parse_object_name()?;
-        self.parser.expect_keyword(Keyword::WITH)?;
 
-        // `table_name` has been parsed at this point but is needed in `parse_table_factor`, reset consumption
-        self.parser.prev_token();
-        self.parser.prev_token();
-
-        let table_factor = self.parser.parse_table_factor()?;
-        let with_options = DaskParserUtils::options_from_tablefactor(&table_factor);
+        // Parse WITH options
+        self.parser.expect_keyword(Keyword::WITH)?;
+        self.parser.expect_token(&Token::LParen)?;
+        let with_options = self.parse_comma_separated(DaskParser::parse_key_value_pair)?;
+        self.parser.expect_token(&Token::RParen)?;
 
         // Parse the nested query statement
         self.parser.expect_keyword(Keyword::AS)?;
@@ -1044,13 +1066,20 @@ impl<'a> DaskParser<'a> {
                     }
                 }
                 "with" => {
-                    // `table_name` has been parsed at this point but is needed in `parse_table_factor`, reset consumption
+                    // `table_name` has been parsed at this point but is needed, reset consumption
                     self.parser.prev_token();
-                    let table_factor = self.parser.parse_table_factor()?;
+
+                    // Parse schema and table name
+                    let obj_name = self.parser.parse_object_name()?;
                     let (tbl_schema, tbl_name) =
-                        DaskParserUtils::elements_from_tablefactor(&table_factor)?;
-                    let with_options = DaskParserUtils::options_from_tablefactor(&table_factor);
+                        DaskParserUtils::elements_from_objectname(&obj_name)?;
+
+                    // Parse WITH options
+                    self.parser.expect_keyword(Keyword::WITH)?;
+                    self.parser.expect_token(&Token::LParen)?;
+                    let with_options =
+                        self.parse_comma_separated(DaskParser::parse_key_value_pair)?;
+                    self.parser.expect_token(&Token::RParen)?;
 
                     let create = CreateTable {
                         table_schema: tbl_schema,
@@ -1093,14 +1122,12 @@ impl<'a> DaskParser<'a> {
         }
 
         let model_name = self.parser.parse_object_name()?;
-        self.parser.expect_keyword(Keyword::WITH)?;
 
-        // `table_name` has been parsed at this point but is needed in `parse_table_factor`, reset consumption
-        self.parser.prev_token();
-        self.parser.prev_token();
-
-        let table_factor = self.parser.parse_table_factor()?;
-        let with_options = DaskParserUtils::options_from_tablefactor(&table_factor);
+        // Parse WITH options
+        self.parser.expect_keyword(Keyword::WITH)?;
+        self.parser.expect_token(&Token::LParen)?;
+        let with_options = self.parse_comma_separated(DaskParser::parse_key_value_pair)?;
+        self.parser.expect_token(&Token::RParen)?;
 
         let export = ExportModel {
             name: model_name.to_string(),
diff --git a/dask_planner/src/sql/logical/create_experiment.rs b/dask_planner/src/sql/logical/create_experiment.rs
index c40673948..14f2843da 100644
--- a/dask_planner/src/sql/logical/create_experiment.rs
+++ b/dask_planner/src/sql/logical/create_experiment.rs
@@ -1,14 +1,12 @@
+use crate::parser::PySqlArg;
 use crate::sql::exceptions::py_type_err;
 use crate::sql::logical;
-use crate::sql::parser_utils::DaskParserUtils;
 use pyo3::prelude::*;
 
 use datafusion_expr::logical_plan::UserDefinedLogicalNode;
 use datafusion_expr::{Expr, LogicalPlan};
-use datafusion_sql::sqlparser::ast::Expr as SqlParserExpr;
 
 use fmt::Debug;
-use std::collections::HashMap;
 use std::{any::Any, fmt, sync::Arc};
 
 use datafusion_common::DFSchemaRef;
@@ -19,7 +17,7 @@ pub struct CreateExperimentPlanNode {
     pub input: LogicalPlan,
     pub if_not_exists: bool,
     pub or_replace: bool,
-    pub with_options: Vec<SqlParserExpr>,
+    pub with_options: Vec<(String, PySqlArg)>,
 }
 
 impl Debug for CreateExperimentPlanNode {
@@ -103,23 +101,8 @@ impl PyCreateExperiment {
     }
 
     #[pyo3(name = "getSQLWithOptions")]
-    fn sql_with_options(&self) -> PyResult<HashMap<String, String>> {
-        let mut options: HashMap<String, String> = HashMap::new();
-        for elem in &self.create_experiment.with_options {
-            match elem {
-                SqlParserExpr::BinaryOp { left, op: _, right } => {
-                    options.insert(
-                        DaskParserUtils::str_from_expr(*left.clone()),
-                        DaskParserUtils::str_from_expr(*right.clone()),
-                    );
-                }
-                _ => {
-                    return Err(py_type_err(
-                        "Encountered non SqlParserExpr::BinaryOp expression, with arguments can only be of Key/Value pair types"));
-                }
-            }
-        }
-        Ok(options)
+    fn sql_with_options(&self) -> PyResult<Vec<(String, PySqlArg)>> {
+        Ok(self.create_experiment.with_options.clone())
     }
 }
 
diff --git a/dask_planner/src/sql/logical/create_table.rs b/dask_planner/src/sql/logical/create_table.rs
index 97bd793ba..914792310 100644
--- a/dask_planner/src/sql/logical/create_table.rs
+++ b/dask_planner/src/sql/logical/create_table.rs
@@ -1,13 +1,13 @@
+use crate::parser::PySqlArg;
 use crate::sql::exceptions::py_type_err;
 use crate::sql::logical;
+
 use pyo3::prelude::*;
 
 use datafusion_expr::logical_plan::UserDefinedLogicalNode;
 use datafusion_expr::{Expr, LogicalPlan};
-use datafusion_sql::sqlparser::ast::{Expr as SqlParserExpr, Value};
 
 use fmt::Debug;
-use std::collections::HashMap;
 use std::{any::Any, fmt, sync::Arc};
 
 use datafusion_common::{DFSchema, DFSchemaRef};
@@ -19,7 +19,7 @@ pub struct CreateTablePlanNode {
     pub table_name: String,
     pub if_not_exists: bool,
     pub or_replace: bool,
-    pub with_options: Vec<SqlParserExpr>,
+    pub with_options: Vec<(String, PySqlArg)>,
 }
 
 impl Debug for CreateTablePlanNode {
@@ -91,43 +91,8 @@ impl PyCreateTable {
     }
 
     #[pyo3(name = "getSQLWithOptions")]
-    fn sql_with_options(&self) -> PyResult<HashMap<String, String>> {
-        let mut options: HashMap<String, String> = HashMap::new();
-        for elem in &self.create_table.with_options {
-            if let SqlParserExpr::BinaryOp { left, op: _, right } = elem {
-                let key: Result<String, PyErr> = match *left.clone() {
-                    SqlParserExpr::Identifier(ident) => Ok(ident.value),
-                    _ => Err(py_type_err(format!(
-                        "unexpected `left` Value type encountered: {:?}",
-                        left
-                    ))),
-                };
-                let val: Result<String, PyErr> = match *right.clone() {
-                    SqlParserExpr::Value(value) => match value {
-                        Value::SingleQuotedString(e) => Ok(e.replace('\'', "")),
-                        Value::DoubleQuotedString(e) => Ok(e.replace('\"', "")),
-                        Value::Boolean(e) => {
-                            if e {
-                                Ok("True".to_string())
-                            } else {
-                                Ok("False".to_string())
-                            }
-                        }
-                        Value::Number(e, ..) => Ok(e),
-                        _ => Err(py_type_err(format!(
-                            "unexpected Value type encountered: {:?}",
-                            value
-                        ))),
-                    },
-                    _ => Err(py_type_err(format!(
-                        "encountered unexpected Expr type: {:?}",
-                        right
-                    ))),
-                };
-                options.insert(key?, val?);
-            }
-        }
-        Ok(options)
+    fn sql_with_options(&self) -> PyResult<Vec<(String, PySqlArg)>> {
+        Ok(self.create_table.with_options.clone())
     }
 }
 
diff --git a/dask_planner/src/sql/logical/export_model.rs b/dask_planner/src/sql/logical/export_model.rs
index e4ff3c16d..e5e702db9 100644
--- a/dask_planner/src/sql/logical/export_model.rs
+++ b/dask_planner/src/sql/logical/export_model.rs
@@ -1,14 +1,12 @@
+use crate::parser::PySqlArg;
 use crate::sql::exceptions::py_type_err;
 use crate::sql::logical;
-use crate::sql::parser_utils::DaskParserUtils;
 use pyo3::prelude::*;
 
 use datafusion_expr::logical_plan::UserDefinedLogicalNode;
 use datafusion_expr::{Expr, LogicalPlan};
-use datafusion_sql::sqlparser::ast::Expr as SqlParserExpr;
 
 use fmt::Debug;
-use std::collections::HashMap;
 use std::{any::Any, fmt, sync::Arc};
 
 use datafusion_common::{DFSchema, DFSchemaRef};
@@ -17,7 +15,7 @@ pub struct ExportModelPlanNode {
     pub schema: DFSchemaRef,
     pub model_name: String,
-    pub with_options: Vec<SqlParserExpr>,
+    pub with_options: Vec<(String, PySqlArg)>,
 }
 
 impl Debug for ExportModelPlanNode {
@@ -77,23 +75,8 @@ impl PyExportModel {
     }
 
     #[pyo3(name = "getSQLWithOptions")]
-    fn sql_with_options(&self) -> PyResult<HashMap<String, String>> {
-        let mut options: HashMap<String, String> = HashMap::new();
-        for elem in &self.export_model.with_options {
-            match elem {
-                SqlParserExpr::BinaryOp { left, op: _, right } => {
-                    options.insert(
-                        DaskParserUtils::str_from_expr(*left.clone()),
-                        DaskParserUtils::str_from_expr(*right.clone()),
-                    );
-                }
-                _ => {
-                    return Err(py_type_err(
-                        "Encountered non SqlParserExpr::BinaryOp expression, with arguments can only be of Key/Value pair types"));
-                }
-            }
-        }
-        Ok(options)
+    fn sql_with_options(&self) -> PyResult<Vec<(String, PySqlArg)>> {
+        Ok(self.export_model.with_options.clone())
     }
 }
 
diff --git a/dask_planner/src/sql/parser_utils.rs b/dask_planner/src/sql/parser_utils.rs
index 1f8c96011..f44d7c440 100644
--- a/dask_planner/src/sql/parser_utils.rs
+++ b/dask_planner/src/sql/parser_utils.rs
@@ -1,9 +1,24 @@
-use datafusion_sql::sqlparser::ast::{Expr as SqlParserExpr, TableFactor, Value};
+use datafusion_sql::sqlparser::ast::{ObjectName, TableFactor};
 use datafusion_sql::sqlparser::parser::ParserError;
 
 pub struct DaskParserUtils;
 
 impl DaskParserUtils {
+    /// Retrieves the schema and object name from a `ObjectName` instance
+    pub fn elements_from_objectname(
+        obj_name: &ObjectName,
+    ) -> Result<(String, String), ParserError> {
+        let identities: Vec<String> = obj_name.0.iter().map(|f| f.value.clone()).collect();
+
+        match identities.len() {
+            1 => Ok(("".to_string(), identities[0].clone())),
+            2 => Ok((identities[0].clone(), identities[1].clone())),
+            _ => Err(ParserError::ParserError(
+                "TableFactor name only supports 1 or 2 elements".to_string(),
+            )),
+        }
+    }
+
     /// Retrieves the table_schema and table_name from a `TableFactor` instance
     pub fn elements_from_tablefactor(
         tbl_factor: &TableFactor,
@@ -34,46 +49,4 @@ impl DaskParserUtils {
             },
         }
     }
-
-    /// Gets the with options from the `TableFactor` instance
-    pub fn options_from_tablefactor(tbl_factor: &TableFactor) -> Vec<SqlParserExpr> {
-        match tbl_factor {
-            TableFactor::Table { with_hints, .. } => with_hints.clone(),
-            TableFactor::Derived { .. }
-            | TableFactor::NestedJoin { .. }
-            | TableFactor::TableFunction { .. }
-            | TableFactor::UNNEST { .. } => {
-                vec![]
-            }
-        }
-    }
-
-    /// Given a SqlParserExpr instance retrieve the String value from it
-    pub fn str_from_expr(expression: SqlParserExpr) -> String {
-        match expression {
-            SqlParserExpr::Identifier(ident) => ident.value,
-            SqlParserExpr::Value(value) => match value {
-                Value::SingleQuotedString(e) => e.replace('\'', ""),
-                Value::DoubleQuotedString(e) => e.replace('\"', ""),
-                Value::Boolean(e) => {
-                    if e {
-                        "True".to_string()
-                    } else {
-                        "False".to_string()
-                    }
-                }
-                Value::Number(e, ..) => e,
-                _ => unimplemented!("Unimplemented Value type: {:?}", value),
-            },
-            SqlParserExpr::Nested(nested_expr) => Self::str_from_expr(*nested_expr),
-            SqlParserExpr::BinaryOp { left, op, right } => format!(
-                "{} {} {}",
-                Self::str_from_expr(*left),
-                op,
-                Self::str_from_expr(*right)
-            ),
-            SqlParserExpr::Array(e) => e.to_string(),
-            _ => unimplemented!("Unimplemented SqlParserExpr type: {:?}", expression),
-        }
-    }
 }
diff --git a/dask_sql/physical/rel/custom/create_experiment.py b/dask_sql/physical/rel/custom/create_experiment.py
index 2f564272a..642456937 100644
--- a/dask_sql/physical/rel/custom/create_experiment.py
+++ b/dask_sql/physical/rel/custom/create_experiment.py
@@ -110,7 +110,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai
         if experiment_name in context.schema[schema_name].experiments:
             if create_experiment.getIfNotExists():
                 return
-            elif not create_experiment.getReplace():
+            elif not create_experiment.getOrReplace():
                 raise RuntimeError(
                     f"A experiment with the name {experiment_name} is already present."
                 )
diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py
index 65db0a38d..f21c14bc3 100644
--- a/tests/integration/test_model.py
+++ b/tests/integration/test_model.py
@@ -318,10 +318,13 @@ def test_correct_argument_passing(c, training_df):
             model_class = 'mock.MagicMock',
             target_column = 'target',
             fit_kwargs = (
-                first_arg = 3,
-                second_arg = ARRAY [ 1, 2 ],
-                third_arg = MAP [ 'a', 1 ],
-                fourth_arg = MULTISET [ 1, 1, 2, 3 ]
+                single_quoted_string = 'hello',
+                double_quoted_string = "hi",
+                integer = -300,
+                float = 23.45,
+                array = ARRAY [ 1, 2 ],
+                dict = MAP [ 'a', 1 ],
+                set = MULTISET [ 1, 1, 2, 3 ]
             )
         ) AS (
             SELECT x, y, x*y > 0 AS target
@@ -339,7 +342,13 @@ def test_correct_argument_passing(c, training_df):
     fit_function.assert_called_once()
     call_kwargs = fit_function.call_args.kwargs
     assert call_kwargs == dict(
-        first_arg=3, second_arg=[1, 2], third_arg={"a": 1}, fourth_arg=set([1, 2, 3])
+        single_quoted_string="hello",
+        double_quoted_string="hi",
+        integer=-300,
+        float=23.45,
+        array=[1, 2],
+        dict={"a": 1},
+        set=set([1, 2, 3]),
     )
 
 
@@ -762,7 +771,14 @@ def test_ml_experiment(c, client, training_df):
         """
         CREATE EXPERIMENT my_exp64 WITH (
             automl_class = 'that.is.not.a.python.class',
-            automl_kwargs = (population_size = 2 ,generations=2,cv=2,n_jobs=-1,use_dask=True,max_eval_time_mins=1),
+            automl_kwargs = (
+                population_size = 2,
+                generations = 2,
+                cv = 2,
+                n_jobs = -1,
+                use_dask = True,
+                max_eval_time_mins = 1
+            ),
             target_column = 'target'
         ) AS (
             SELECT x, y, x*y > 0 AS target

From a5df363703bf8e7b6b02f32a79b852504b32ae90 Mon Sep 17 00:00:00 2001
From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Date: Thu, 22 Sep 2022 14:38:20 -0700
Subject: [PATCH 04/10] Add special handling for boolean scalars
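
Booleans now serialize to "1"/"" rather than "true"/"false", since the value
string is ultimately re-cast on the Python side, where any non-empty string
is truthy. A minimal illustration of the pitfall this avoids (plain Python,
not part of the patch itself):

    assert bool("false") is True   # why "false" can't be used as-is
    assert bool("") is False       # hence False -> ""
    assert bool("1") is True       # and True -> "1"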
---
 dask_planner/src/parser.rs      | 6 +++++-
 tests/integration/test_model.py | 2 ++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/dask_planner/src/parser.rs b/dask_planner/src/parser.rs
index faee7ba23..e4706682a 100644
--- a/dask_planner/src/parser.rs
+++ b/dask_planner/src/parser.rs
@@ -151,7 +151,11 @@ impl PySqlArg {
             None => match &self.expr {
                 Some(Expr::Identifier(Ident { value, .. })) => Ok(value.to_string()),
                 Some(Expr::Value(scalar)) => match scalar {
-                    Value::Boolean(value) => Ok(value.to_string()),
+                    Value::Boolean(value) => Ok(if *value {
+                        "1".to_string()
+                    } else {
+                        "".to_string()
+                    }),
                     Value::SingleQuotedString(string) => Ok(string.to_string()),
                     Value::Number(value, false) => Ok(value.to_string()),
                     unexpected => Err(py_type_err(format!(
diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py
index f21c14bc3..3c1bd1a69 100644
--- a/tests/integration/test_model.py
+++ b/tests/integration/test_model.py
@@ -322,6 +322,7 @@ def test_correct_argument_passing(c, training_df):
                 double_quoted_string = "hi",
                 integer = -300,
                 float = 23.45,
+                boolean = False,
                 array = ARRAY [ 1, 2 ],
                 dict = MAP [ 'a', 1 ],
                 set = MULTISET [ 1, 1, 2, 3 ]
@@ -346,6 +347,7 @@ def test_correct_argument_passing(c, training_df):
         double_quoted_string="hi",
         integer=-300,
         float=23.45,
+        boolean=False,
         array=[1, 2],
         dict={"a": 1},
         set=set([1, 2, 3]),

From ccda2523eccdf555971dcd1ada71da508803f021 Mon Sep 17 00:00:00 2001
From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Date: Fri, 23 Sep 2022 07:50:20 -0700
Subject: [PATCH 05/10] Update CreateModel test
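
The parser unit test now mirrors the full literal surface exercised by
test_correct_argument_passing and asserts only the Debug representation of
the (String, PySqlArg) pairs; the Display assertion is dropped along with
PySqlKwarg, which no longer exists.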
---
 dask_planner/src/parser.rs | 55 ++++++++++++------------------------
 1 file changed, 20 insertions(+), 35 deletions(-)

diff --git a/dask_planner/src/parser.rs b/dask_planner/src/parser.rs
index e4706682a..87202c4f1 100644
--- a/dask_planner/src/parser.rs
+++ b/dask_planner/src/parser.rs
@@ -1270,10 +1270,14 @@ mod test {
     model_class = 'mock.MagicMock',
     target_column = 'target',
     fit_kwargs = (
-        first_arg = 3,
-        second_arg = ARRAY [ 1, 2 ],
-        third_arg = MAP [ 'a', 1 ],
-        forth_arg = MULTISET [ 1, 1, 2, 3 ]
+        single_quoted_string = 'hello',
+        double_quoted_string = "hi",
+        integer = -300,
+        float = 23.45,
+        boolean = False,
+        array = ARRAY [ 1, 2 ],
+        dict = MAP [ 'a', 1 ],
+        set = MULTISET [ 1, 1, 2, 3 ]
     )
 ) AS (
     SELECT x, y, x*y > 0 AS target
@@ -1285,40 +1289,21 @@ mod test {
 
         match &statements[0] {
             DaskStatement::CreateModel(create_model) => {
-                // test Debug
                 let expected = "[\
-                    PySqlKwarg { key: Ident { value: \"model_class\", quote_style: None }, value: Expr(Value(SingleQuotedString(\"mock.MagicMock\"))) }, \
-                    PySqlKwarg { key: Ident { value: \"target_column\", quote_style: None }, value: Expr(Value(SingleQuotedString(\"target\"))) }, \
-                    PySqlKwarg { key: Ident { value: \"fit_kwargs\", quote_style: None }, value: Nested([\
-                        PySqlKwarg { key: Ident { value: \"first_arg\", quote_style: None }, value: Expr(Value(Number(\"3\", false))) }, \
-                        PySqlKwarg { key: Ident { value: \"second_arg\", quote_style: None }, value: Expr(Array(Array { elem: [Value(Number(\"1\", false)), Value(Number(\"2\", false))], named: true })) }, \
-                        PySqlKwarg { key: Ident { value: \"third_arg\", quote_style: None }, value: Map([Value(SingleQuotedString(\"a\")), Value(Number(\"1\", false))]) }, \
-                        PySqlKwarg { key: Ident { value: \"forth_arg\", quote_style: None }, value: Multiset([Value(Number(\"1\", false)), Value(Number(\"1\", false)), Value(Number(\"2\", false)), Value(Number(\"3\", false))]) }\
-                    ]) }\
+                    (\"model_class\", PySqlArg { expr: Some(Value(SingleQuotedString(\"mock.MagicMock\"))), custom: None }), \
+                    (\"target_column\", PySqlArg { expr: Some(Value(SingleQuotedString(\"target\"))), custom: None }), \
+                    (\"fit_kwargs\", PySqlArg { expr: None, custom: Some(Nested([\
+                        (\"single_quoted_string\", PySqlArg { expr: Some(Value(SingleQuotedString(\"hello\"))), custom: None }), \
+                        (\"double_quoted_string\", PySqlArg { expr: Some(Identifier(Ident { value: \"hi\", quote_style: Some('\"') })), custom: None }), \
+                        (\"integer\", PySqlArg { expr: Some(UnaryOp { op: Minus, expr: Value(Number(\"300\", false)) }), custom: None }), \
+                        (\"float\", PySqlArg { expr: Some(Value(Number(\"23.45\", false))), custom: None }), \
+                        (\"boolean\", PySqlArg { expr: Some(Value(Boolean(false))), custom: None }), \
+                        (\"array\", PySqlArg { expr: Some(Array(Array { elem: [Value(Number(\"1\", false)), Value(Number(\"2\", false))], named: true })), custom: None }), \
+                        (\"dict\", PySqlArg { expr: None, custom: Some(Map([Value(SingleQuotedString(\"a\")), Value(Number(\"1\", false))])) }), \
+                        (\"set\", PySqlArg { expr: None, custom: Some(Multiset([Value(Number(\"1\", false)), Value(Number(\"1\", false)), Value(Number(\"2\", false)), Value(Number(\"3\", false))])) })\
+                    ])) })\
                 ]";
                 assert_eq!(expected, &format!("{:?}", create_model.with_options));
-
-                // test Display
-                let expected = "model_class = 'mock.MagicMock', \
-                    target_column = 'target', \
-                    fit_kwargs = (\
-                        first_arg = '3', \
-                        second_arg = ARRAY[1, 2], \
-                        third_arg = MAP [ 'a', 1 ], \
-                        forth_arg = MULTISET [ 1, 1, 2, 3 ]\
-                    )";
-                assert_eq!(
-                    expected,
-                    format!(
-                        "{}",
-                        create_model
-                            .with_options
-                            .iter()
-                            .map(|pair| format!("{}", pair))
-                            .collect::<Vec<String>>()
-                            .join(", ")
-                    )
-                )
             }
             _ => panic!(),
         }

From 644d2e9511d52ac1161166681c01bb4051ee0dc0 Mon Sep 17 00:00:00 2001
From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Date: Fri, 23 Sep 2022 09:45:13 -0700
Subject: [PATCH 06/10] Refactor match statements
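
Flattens the accessor methods on PySqlArg so that each body is a single
Ok(match ...) expression, with early `return Err(...)` in the failure arms
instead of wrapping every arm in Ok()/Err(); is_collection and is_kwargs
also collapse to matches!() calls. No behavior change intended.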
+                CustomExpr::Nested(_) => {
+                    return Err(py_type_err("Expected Map or Multiset, got Nested"))
+                }
             },
             None => match &self.expr {
                 Some(expr) => match expr {
-                    Expr::Array(array) => Ok(array
+                    Expr::Array(array) => array
                         .elem
                         .iter()
                         .map(|e| PySqlArg::new(Some(e.clone()), None))
-                        .collect()),
-                    _ => Ok(vec![]),
+                        .collect(),
+                    _ => vec![],
                 },
-                None => Err(py_type_err(
-                    "PySqlArg must contain either a standard or custom AST expression",
-                )),
+                None => {
+                    return Err(py_type_err(
+                        "PySqlArg must contain either a standard or custom AST expression",
+                    ))
+                }
             },
-        }
+        })
     }
 
     #[pyo3(name = "getKwargs")]
     pub fn get_kwargs(&self) -> PyResult<Vec<(String, PySqlArg)>> {
-        match &self.custom {
-            Some(CustomExpr::Nested(kwargs)) => Ok(kwargs.clone()),
-            _ => Ok(vec![]),
-        }
+        Ok(match &self.custom {
+            Some(CustomExpr::Nested(kwargs)) => kwargs.clone(),
+            _ => vec![],
+        })
     }
 
     #[pyo3(name = "getSqlType")]
     pub fn get_sql_type(&self) -> PyResult<SqlTypeName> {
-        match &self.custom {
+        Ok(match &self.custom {
             Some(custom_expr) => match custom_expr {
-                CustomExpr::Map(_) => Ok(SqlTypeName::MAP),
-                CustomExpr::Multiset(_) => Ok(SqlTypeName::MULTISET),
-                CustomExpr::Nested(_) => Err(py_type_err("Expected Map or Multiset, got Nested")),
+                CustomExpr::Map(_) => SqlTypeName::MAP,
+                CustomExpr::Multiset(_) => SqlTypeName::MULTISET,
+                CustomExpr::Nested(_) => {
+                    return Err(py_type_err("Expected Map or Multiset, got Nested"))
+                }
             },
             None => match &self.expr {
-                Some(Expr::Array(_)) => Ok(SqlTypeName::ARRAY),
-                Some(Expr::Identifier(Ident { .. })) => Ok(SqlTypeName::VARCHAR),
+                Some(Expr::Array(_)) => SqlTypeName::ARRAY,
+                Some(Expr::Identifier(Ident { .. })) => SqlTypeName::VARCHAR,
                 Some(Expr::Value(scalar)) => match scalar {
-                    Value::Boolean(_) => Ok(SqlTypeName::BOOLEAN),
-                    Value::Number(_, false) => Ok(SqlTypeName::BIGINT),
-                    Value::SingleQuotedString(_) => Ok(SqlTypeName::VARCHAR),
-                    unexpected => Err(py_type_err(format!(
-                        "Expected string, got {:?}",
-                        unexpected
-                    ))),
+                    Value::Boolean(_) => SqlTypeName::BOOLEAN,
+                    Value::Number(_, false) => SqlTypeName::BIGINT,
+                    Value::SingleQuotedString(_) => SqlTypeName::VARCHAR,
+                    unexpected => {
+                        return Err(py_type_err(format!(
+                            "Expected string, got {:?}",
+                            unexpected
+                        )))
+                    }
                 },
                 Some(Expr::UnaryOp {
                     op: UnaryOperator::Minus,
                     expr,
                 }) => match &**expr {
-                    Expr::Value(Value::Number(_, false)) => Ok(SqlTypeName::BIGINT),
-                    unexpected => Err(py_type_err(format!(
-                        "Expected string, got {:?}",
-                        unexpected
-                    ))),
+                    Expr::Value(Value::Number(_, false)) => SqlTypeName::BIGINT,
+                    unexpected => {
+                        return Err(py_type_err(format!(
+                            "Expected string, got {:?}",
+                            unexpected
+                        )))
+                    }
                 },
-                Some(unexpected) => Err(py_type_err(format!(
-                    "Expected array or scalar, got {:?}",
-                    unexpected
-                ))),
-                None => Err(py_type_err(
-                    "PySqlArg must contain either a standard or custom AST expression",
-                )),
+                Some(unexpected) => {
+                    return Err(py_type_err(format!(
+                        "Expected array or scalar, got {:?}",
+                        unexpected
+                    )))
+                }
+                None => {
+                    return Err(py_type_err(
+                        "PySqlArg must contain either a standard or custom AST expression",
+                    ))
+                }
             },
-        }
+        })
     }
 
     #[pyo3(name = "getSqlValue")]
     pub fn get_sql_value(&self) -> PyResult<String> {
-        match &self.custom {
+        Ok(match &self.custom {
             None => match &self.expr {
-                Some(Expr::Identifier(Ident { value, .. })) => Ok(value.to_string()),
+                Some(Expr::Identifier(Ident { value, .. })) => value.to_string(),
                 Some(Expr::Value(scalar)) => match scalar {
-                    Value::Boolean(value) => Ok(if *value {
-                        "1".to_string()
-                    } else {
-                        "".to_string()
-                    }),
-                    Value::SingleQuotedString(string) => Ok(string.to_string()),
-                    Value::Number(value, false) => Ok(value.to_string()),
-                    unexpected => Err(py_type_err(format!(
-                        "Expected string, got {:?}",
-                        unexpected
-                    ))),
+                    Value::Boolean(value) => {
+                        if *value {
+                            "1".to_string()
+                        } else {
+                            "".to_string()
+                        }
+                    }
+                    Value::SingleQuotedString(string) => string.to_string(),
+                    Value::Number(value, false) => value.to_string(),
+                    unexpected => {
+                        return Err(py_type_err(format!(
+                            "Expected string, got {:?}",
+                            unexpected
+                        )))
+                    }
                 },
                 Some(Expr::UnaryOp {
                     op: UnaryOperator::Minus,
                     expr,
                 }) => match &**expr {
-                    Expr::Value(Value::Number(value, false)) => Ok(format!("-{}", value)),
-                    unexpected => Err(py_type_err(format!(
-                        "Expected string, got {:?}",
-                        unexpected
-                    ))),
+                    Expr::Value(Value::Number(value, false)) => format!("-{}", value),
+                    unexpected => {
+                        return Err(py_type_err(format!(
+                            "Expected string, got {:?}",
+                            unexpected
+                        )))
+                    }
                 },
-                unexpected => Err(py_type_err(format!(
+                unexpected => {
+                    return Err(py_type_err(format!(
+                        "Expected scalar value, got {:?}",
+                        unexpected
+                    )))
+                }
+            },
+            unexpected => {
+                return Err(py_type_err(format!(
                     "Expected scalar value, got {:?}",
                     unexpected
-                ))),
-            },
-            unexpected => Err(py_type_err(format!(
-                "Expected scalar value, got {:?}",
-                unexpected
-            ))),
-        }
+                )))
+            }
+        })
     }
 }

From b720ad3f2e9c680cb054e8b2a57a9a4c9e6a7b96 Mon Sep 17 00:00:00 2001
From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Date: Tue, 27 Sep 2022 09:48:07 -0400
Subject: [PATCH 07/10] Make boolean value handling more concise

Co-authored-by: Andy Grove
---
 dask_planner/src/parser.rs | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/dask_planner/src/parser.rs b/dask_planner/src/parser.rs
index 7d9e8a365..5a83d7ade 100644
--- a/dask_planner/src/parser.rs
+++ b/dask_planner/src/parser.rs
@@ -160,13 +160,9 @@ impl PySqlArg {
             None => match &self.expr {
                 Some(Expr::Identifier(Ident { value, .. })) => value.to_string(),
                 Some(Expr::Value(scalar)) => match scalar {
-                    Value::Boolean(value) => {
-                        if *value {
-                            "1".to_string()
-                        } else {
-                            "".to_string()
-                        }
-                    }
+                    Value::Boolean(true) => "1".to_string(),
+                    Value::Boolean(false) => "".to_string(),
+
                     Value::SingleQuotedString(string) => string.to_string(),
                     Value::Number(value, false) => value.to_string(),
                     unexpected => {

From a322c6e8cdc9f40cebaef913fbb3970e067131b5 Mon Sep 17 00:00:00 2001
From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Date: Tue, 27 Sep 2022 06:49:04 -0700
Subject: [PATCH 08/10] Condense some is_collection code

---
 dask_planner/src/parser.rs | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/dask_planner/src/parser.rs b/dask_planner/src/parser.rs
index 5a83d7ade..0830e8ed7 100644
--- a/dask_planner/src/parser.rs
+++ b/dask_planner/src/parser.rs
@@ -47,8 +47,7 @@ impl PySqlArg {
     #[pyo3(name = "isCollection")]
     pub fn is_collection(&self) -> PyResult<bool> {
         Ok(match &self.custom {
-            Some(CustomExpr::Nested(_)) => false,
-            Some(_) => true,
+            Some(custom_expr) => !matches!(custom_expr, CustomExpr::Nested(_)),
             None => match &self.expr {
                 Some(expr) => matches!(expr, Expr::Array(_)),
                 None => {

From f57264a9d8a2262bbd924a59095c3b7748dc1a3e Mon Sep 17 00:00:00 2001
From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Date: Tue, 27 Sep 2022 08:26:12 -0700
Subject: [PATCH 09/10] Use helper function to condense error messages

---
 dask_planner/src/parser.rs | 88 +++++++++++---------------------------
 1 file changed, 24 insertions(+), 64 deletions(-)

diff --git a/dask_planner/src/parser.rs b/dask_planner/src/parser.rs
index 0830e8ed7..71e0bfef6 100644
--- a/dask_planner/src/parser.rs
+++ b/dask_planner/src/parser.rs
@@ -40,6 +40,18 @@ impl PySqlArg {
     pub fn new(expr: Option<Expr>, custom: Option<CustomExpr>) -> Self {
         Self { expr, custom }
     }
+
+    fn expected<T>(&self, expected: &str) -> PyResult<T> {
+        Err(match &self.custom {
+            Some(custom_expr) => {
+                py_type_err(format!("Expected {}, found: {:?}", expected, custom_expr))
+            }
+            None => match &self.expr {
+                Some(expr) => py_type_err(format!("Expected {}, found: {:?}", expected, expr)),
+                None => py_type_err("PySqlArg must be either a standard or custom AST expression"),
+            },
+        })
+    }
 }
 
 #[pymethods]
@@ -50,11 +62,7 @@ impl PySqlArg {
         Ok(match &self.custom {
             Some(custom_expr) => !matches!(custom_expr, CustomExpr::Nested(_)),
             None => match &self.expr {
                 Some(expr) => matches!(expr, Expr::Array(_)),
-                None => {
-                    return Err(py_type_err(
-                        "PySqlArg must contain either a standard or custom AST expression",
-                    ))
-                }
+                None => return self.expected(""),
             },
         })
     }
@@ -72,9 +80,7 @@ impl PySqlArg {
                     .iter()
                     .map(|e| PySqlArg::new(Some(e.clone()), None))
                     .collect(),
-                CustomExpr::Nested(_) => {
-                    return Err(py_type_err("Expected Map or Multiset, got Nested"))
-                }
+                _ => vec![],
             },
             None => match &self.expr {
                 Some(expr) => match expr {
@@ -85,11 +91,7 @@ impl PySqlArg {
                         .collect(),
                     _ => vec![],
                 },
-                None => {
-                    return Err(py_type_err(
-                        "PySqlArg must contain either a standard or custom AST expression",
-                    ))
-                }
+                None => return self.expected(""),
             },
         })
     }
@@ -108,9 +110,7 @@ impl PySqlArg {
             Some(custom_expr) => match custom_expr {
                 CustomExpr::Map(_) => SqlTypeName::MAP,
                 CustomExpr::Multiset(_) => SqlTypeName::MULTISET,
-                CustomExpr::Nested(_) => {
-                    return Err(py_type_err("Expected Map or Multiset, got Nested"))
-                }
+                _ => return self.expected("Map or multiset"),
             },
             None => match &self.expr {
                 Some(Expr::Array(_)) => SqlTypeName::ARRAY,
                 Some(Expr::Identifier(Ident { .. })) => SqlTypeName::VARCHAR,
                 Some(Expr::Value(scalar)) => match scalar {
                     Value::Boolean(_) => SqlTypeName::BOOLEAN,
                     Value::Number(_, false) => SqlTypeName::BIGINT,
                     Value::SingleQuotedString(_) => SqlTypeName::VARCHAR,
-                    unexpected => {
-                        return Err(py_type_err(format!(
-                            "Expected string, got {:?}",
-                            unexpected
-                        )))
-                    }
+                    _ => return self.expected("Boolean, integer, float, or single-quoted string"),
                 },
                 Some(Expr::UnaryOp {
                     op: UnaryOperator::Minus,
                     expr,
                 }) => match &**expr {
                     Expr::Value(Value::Number(_, false)) => SqlTypeName::BIGINT,
-                    unexpected => {
-                        return Err(py_type_err(format!(
-                            "Expected string, got {:?}",
-                            unexpected
-                        )))
-                    }
+                    _ => return self.expected("Integer or float"),
                 },
-                Some(unexpected) => {
-                    return Err(py_type_err(format!(
-                        "Expected array or scalar, got {:?}",
-                        unexpected
-                    )))
-                }
-                None => {
-                    return Err(py_type_err(
-                        "PySqlArg must contain either a standard or custom AST expression",
-                    ))
-                }
+                Some(_) => return self.expected("Array, identifier, or scalar"),
+                None => return self.expected(""),
             },
         })
     }
@@ -161,41 +142,20 @@ impl PySqlArg {
                 Some(Expr::Value(scalar)) => match scalar {
                     Value::Boolean(true) => "1".to_string(),
                     Value::Boolean(false) => "".to_string(),
-
                     Value::SingleQuotedString(string) => string.to_string(),
                     Value::Number(value, false) => value.to_string(),
-                    unexpected => {
-                        return Err(py_type_err(format!(
-                            "Expected string, got {:?}",
-                            unexpected
-                        )))
-                    }
+                    _ => return self.expected("Boolean, integer, float, or single-quoted string"),
                 },
                 Some(Expr::UnaryOp {
                     op: UnaryOperator::Minus,
                     expr,
                 }) => match &**expr {
                     Expr::Value(Value::Number(value, false)) => format!("-{}", value),
-                    unexpected => {
-                        return Err(py_type_err(format!(
-                            "Expected string, got {:?}",
-                            unexpected
-                        )))
-                    }
+                    _ => return self.expected("Integer or float"),
                 },
-                unexpected => {
-                    return Err(py_type_err(format!(
-                        "Expected scalar value, got {:?}",
-                        unexpected
-                    )))
-                }
+                _ => return self.expected("Array, identifier, or scalar"),
             },
-            unexpected => {
-                return Err(py_type_err(format!(
-                    "Expected scalar value, got {:?}",
-                    unexpected
-                )))
-            }
+            _ => return self.expected("Standard sqlparser AST expression"),
         })
     }
 }

From e1a060e786358e4502e6ba69876f9c2ab52cb3e8 Mon Sep 17 00:00:00 2001
From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Date: Thu, 6 Oct 2022 11:05:59 -0700
Subject: [PATCH 10/10] Replace instances of `elements_from_tablefactor`

---
 dask_planner/src/parser.rs           | 19 +++++++--------
 dask_planner/src/sql/parser_utils.rs | 35 ++--------------------------
 2 files changed, 10 insertions(+), 44 deletions(-)

diff --git a/dask_planner/src/parser.rs b/dask_planner/src/parser.rs
index 71e0bfef6..9b85a8512 100644
--- a/dask_planner/src/parser.rs
+++ b/dask_planner/src/parser.rs
@@ -817,7 +817,7 @@ impl<'a> DaskParser<'a> {
         }
 
         let (mdl_schema, mdl_name) =
-            DaskParserUtils::elements_from_tablefactor(&self.parser.parse_table_factor()?)?;
+            DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?;
         self.parser.expect_token(&Token::Comma)?;
 
         // Limit our input to ANALYZE, DESCRIBE, SELECT, SHOW statements
@@ -1048,9 +1048,9 @@ impl<'a> DaskParser<'a> {
                     self.parser.prev_token();
 
                     // Parse schema and table name
-                    let obj_name = self.parser.parse_object_name()?;
-                    let (tbl_schema, tbl_name) =
-                        DaskParserUtils::elements_from_objectname(&obj_name)?;
+                    let (tbl_schema, tbl_name) = DaskParserUtils::elements_from_object_name(
+                        &self.parser.parse_object_name()?,
+                    )?;
 
                     // Parse WITH options
                     self.parser.expect_keyword(Keyword::WITH)?;
@@ -1173,8 +1173,8 @@ impl<'a> DaskParser<'a> {
     /// Parse Dask-SQL SHOW COLUMNS FROM <table_name>
     fn parse_show_columns(&mut self) -> Result<DaskStatement, ParserError> {
         self.parser.expect_keyword(Keyword::FROM)?;
-        let table_factor = self.parser.parse_table_factor()?;
-        let (tbl_schema, tbl_name) = DaskParserUtils::elements_from_tablefactor(&table_factor)?;
+        let (tbl_schema, tbl_name) =
+            DaskParserUtils::elements_from_object_name(&self.parser.parse_object_name()?)?;
         Ok(DaskStatement::ShowColumns(Box::new(ShowColumns {
             table_name: tbl_name,
             schema_name: match tbl_schema.as_str() {
@@ -1186,13 +1186,10 @@ impl<'a> DaskParser<'a> {
     /// Parse Dask-SQL ANALYZE TABLE <table_name>
     fn parse_analyze_table(&mut self) -> Result<DaskStatement, ParserError> {
-        let table_factor = self.parser.parse_table_factor()?;
-        // parse_table_factor parses the following keyword as an alias, so we need to go back a token
-        // TODO: open an issue in sqlparser around this when possible
-        self.parser.prev_token();
+        let obj_name = self.parser.parse_object_name()?;
         self.parser
             .expect_keywords(&[Keyword::COMPUTE, Keyword::STATISTICS, Keyword::FOR])?;
-        let (tbl_schema, tbl_name) = DaskParserUtils::elements_from_tablefactor(&table_factor)?;
+        let (tbl_schema, tbl_name) = DaskParserUtils::elements_from_object_name(&obj_name)?;
         let columns = match self
             .parser
             .parse_keywords(&[Keyword::ALL, Keyword::COLUMNS])
diff --git a/dask_planner/src/sql/parser_utils.rs b/dask_planner/src/sql/parser_utils.rs
index f44d7c440..c61dfc6a4 100644
--- a/dask_planner/src/sql/parser_utils.rs
+++ b/dask_planner/src/sql/parser_utils.rs
@@ -1,11 +1,11 @@
-use datafusion_sql::sqlparser::ast::{ObjectName, TableFactor};
+use datafusion_sql::sqlparser::ast::ObjectName;
 use datafusion_sql::sqlparser::parser::ParserError;
 
 pub struct DaskParserUtils;
 
 impl DaskParserUtils {
     /// Retrieves the schema and object name from a `ObjectName` instance
-    pub fn elements_from_objectname(
+    pub fn elements_from_object_name(
         obj_name: &ObjectName,
     ) -> Result<(String, String), ParserError> {
         let identities: Vec<String> = obj_name.0.iter().map(|f| f.value.clone()).collect();
@@ -18,35 +18,4 @@ impl DaskParserUtils {
             )),
         }
     }
-
-    /// Retrieves the table_schema and table_name from a `TableFactor` instance
-    pub fn elements_from_tablefactor(
-        tbl_factor: &TableFactor,
-    ) -> Result<(String, String), ParserError> {
-        match tbl_factor {
-            TableFactor::Table {
-                name,
-                alias: _,
-                args: _,
-                with_hints: _,
-            } => {
-                let identities: Vec<String> = name.0.iter().map(|f| f.value.clone()).collect();
-
-                match identities.len() {
-                    1 => Ok(("".to_string(), identities[0].clone())),
-                    2 => Ok((identities[0].clone(), identities[1].clone())),
-                    _ => Err(ParserError::ParserError(
-                        "TableFactor name only supports 1 or 2 elements".to_string(),
-                    )),
-                }
-            }
-            TableFactor::Derived { alias, .. }
-            | TableFactor::NestedJoin { alias, .. }
-            | TableFactor::TableFunction { alias, .. }
-            | TableFactor::UNNEST { alias, .. } => match alias {
-                Some(e) => Ok(("".to_string(), e.name.value.clone())),
-                None => Ok(("".to_string(), "".to_string())),
-            },
-        }
-    }
 }
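
Reviewer note: a quick way to sanity-check the consolidated `elements_from_object_name` helper above is a small unit test covering its three match arms. The sketch below is illustrative only and is not part of any patch in this series; the test name is invented, and it assumes a test module inside the `dask_planner` crate with access to `crate::sql::parser_utils::DaskParserUtils` and the re-exported sqlparser AST types.

use datafusion_sql::sqlparser::ast::{Ident, ObjectName};

use crate::sql::parser_utils::DaskParserUtils;

#[test]
fn elements_from_object_name_splits_schema_and_table() {
    // Hypothetical sanity check, not part of this patch series.

    // One identifier: the schema defaults to the empty string.
    let bare = ObjectName(vec![Ident::new("my_table")]);
    assert_eq!(
        DaskParserUtils::elements_from_object_name(&bare).unwrap(),
        ("".to_string(), "my_table".to_string())
    );

    // Two identifiers: split into (schema, table).
    let qualified = ObjectName(vec![Ident::new("my_schema"), Ident::new("my_table")]);
    assert_eq!(
        DaskParserUtils::elements_from_object_name(&qualified).unwrap(),
        ("my_schema".to_string(), "my_table".to_string())
    );

    // Three or more identifiers: rejected with a ParserError.
    let nested = ObjectName(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]);
    assert!(DaskParserUtils::elements_from_object_name(&nested).is_err());
}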