Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow different types of query variables (@@var) rather than just string #1943

Merged
merged 6 commits into from
Mar 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions datafusion-expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pub enum Expr {
/// A named reference to a qualified filed in a schema.
Column(Column),
/// A named reference to a variable in a registry.
ScalarVariable(Vec<String>),
ScalarVariable(DataType, Vec<String>),
/// A constant value.
Literal(ScalarValue),
/// A binary expression such as "age > 21"
Expand Down Expand Up @@ -399,7 +399,7 @@ impl fmt::Debug for Expr {
match self {
Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias),
Expr::Column(c) => write!(f, "{}", c),
Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")),
Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'll be better to print the data type, such as
Expr::ScalarVariable(data_type, var_names) => write!(f, "{}, data type: {}", var_names.join("."), data_type),

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I avoided this because it looks like the Debug implementation for Expr tries to print SQL-compatible output, but, for example, "@v0, data type: DataType::Utf8" wouldn't be SQL-like

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I avoided this because it looks like the Debug implementation for Expr tries to print SQL-compatible output, but, for example, "@v0, data type: DataType::Utf8" wouldn't be SQL-like

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like the Debug implementation for Expr tries to print SQL-compatible output

It's a good point, makes sense to me.

Expr::Literal(v) => write!(f, "{:?}", v),
Expr::Case {
expr,
Expand Down Expand Up @@ -562,7 +562,7 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
match e {
Expr::Alias(_, name) => Ok(name.clone()),
Expr::Column(c) => Ok(c.flat_name()),
Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")),
Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")),
Expr::Literal(value) => Ok(format!("{:?}", value)),
Expr::BinaryExpr { left, op, right } => {
let left = create_name(left, input_schema)?;
Expand Down
2 changes: 1 addition & 1 deletion datafusion-proto/src/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),
}
}
Expr::ScalarVariable(_) => unimplemented!(),
Expr::ScalarVariable(_, _) => unimplemented!(),
Expr::ScalarFunction { ref fun, ref args } => {
let fun: protobuf::ScalarFunction = fun.try_into()?;
let args: Vec<Self> = args
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> {
}
Expr::Literal(_)
| Expr::Alias(_, _)
| Expr::ScalarVariable(_)
| Expr::ScalarVariable(_, _)
| Expr::Not(_)
| Expr::IsNotNull(_)
| Expr::IsNull(_)
Expand Down
32 changes: 25 additions & 7 deletions datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use std::path::PathBuf;
use std::string::String;
use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use arrow::datatypes::{DataType, SchemaRef};

use crate::catalog::{
catalog::{CatalogProvider, MemoryCatalogProvider},
Expand Down Expand Up @@ -1190,6 +1190,23 @@ impl ContextProvider for ExecutionContextState {
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
self.aggregate_functions.get(name).cloned()
}

fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
if variable_names.is_empty() {
return None;
}

let provider_type = if &variable_names[0][0..2] == "@@" {
VarType::System
} else {
VarType::UserDefined
};

self.execution_props
.var_providers
.as_ref()
.and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
}
}

impl FunctionRegistry for ExecutionContextState {
Expand Down Expand Up @@ -1300,14 +1317,15 @@ mod tests {
ctx.register_table("dual", provider)?;

let results =
plan_and_collect(&mut ctx, "SELECT @@version, @name FROM dual").await?;
plan_and_collect(&mut ctx, "SELECT @@version, @name, @integer + 1 FROM dual")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

.await?;

let expected = vec![
"+----------------------+------------------------+",
"| @@version | @name |",
"+----------------------+------------------------+",
"| system-var-@@version | user-defined-var-@name |",
"+----------------------+------------------------+",
"+----------------------+------------------------+------------------------+",
"| @@version | @name | @integer Plus Int64(1) |",
"+----------------------+------------------------+------------------------+",
"| system-var-@@version | user-defined-var-@name | 42 |",
"+----------------------+------------------------+------------------------+",
];
assert_batches_eq!(expected, &results);

Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/logical_plan/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl ExprRewritable for Expr {
let expr = match self {
Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name),
Expr::Column(_) => self.clone(),
Expr::ScalarVariable(names) => Expr::ScalarVariable(names),
Expr::ScalarVariable(ty, names) => Expr::ScalarVariable(ty, names),
Expr::Literal(value) => Expr::Literal(value),
Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr {
left: rewrite_boxed(left, rewriter)?,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/src/logical_plan/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl ExprSchemable for Expr {
expr.get_type(schema)
}
Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
Expr::ScalarVariable(_) => Ok(DataType::Utf8),
Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
Expr::Literal(l) => Ok(l.get_datatype()),
Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => {
Expand Down Expand Up @@ -162,7 +162,7 @@ impl ExprSchemable for Expr {
}
}
Expr::Cast { expr, .. } => expr.nullable(input_schema),
Expr::ScalarVariable(_)
Expr::ScalarVariable(_, _)
| Expr::TryCast { .. }
| Expr::ScalarFunction { .. }
| Expr::ScalarUDF { .. }
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/logical_plan/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl ExprVisitable for Expr {
| Expr::Sort { expr, .. }
| Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
Expr::Column(_)
| Expr::ScalarVariable(_)
| Expr::ScalarVariable(_, _)
| Expr::Literal(_)
| Expr::Wildcard => Ok(visitor),
Expr::BinaryExpr { left, right, .. } => {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/optimizer/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ impl ExprIdentifierVisitor<'_> {
desc.push_str("Column-");
desc.push_str(&column.flat_name());
}
Expr::ScalarVariable(var_names) => {
Expr::ScalarVariable(_, var_names) => {
desc.push_str("ScalarVariable-");
desc.push_str(&var_names.join("."));
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/optimizer/simplify_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ impl<'a> ConstEvaluator<'a> {
Expr::Alias(..)
| Expr::AggregateFunction { .. }
| Expr::AggregateUDF { .. }
| Expr::ScalarVariable(_)
| Expr::ScalarVariable(_, _)
| Expr::Column(_)
| Expr::WindowFunction { .. }
| Expr::Sort { .. }
Expand Down
6 changes: 3 additions & 3 deletions datafusion/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
Expr::Column(qc) => {
self.accum.insert(qc.clone());
}
Expr::ScalarVariable(var_names) => {
Expr::ScalarVariable(_, var_names) => {
self.accum.insert(Column::from_name(var_names.join(".")));
}
Expr::Alias(_, _)
Expand Down Expand Up @@ -331,7 +331,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
}
Ok(expr_list)
}
Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_) => Ok(vec![]),
Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_, _) => Ok(vec![]),
Expr::Between {
expr, low, high, ..
} => Ok(vec![
Expand Down Expand Up @@ -476,7 +476,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
Expr::Column(_)
| Expr::Literal(_)
| Expr::InList { .. }
| Expr::ScalarVariable(_) => Ok(expr.clone()),
| Expr::ScalarVariable(_, _) => Ok(expr.clone()),
Expr::Sort {
asc, nulls_first, ..
} => Ok(Expr::Sort {
Expand Down
4 changes: 2 additions & 2 deletions datafusion/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
}
}
Expr::Alias(_, name) => Ok(name.clone()),
Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")),
Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")),
Expr::Literal(value) => Ok(format!("{:?}", value)),
Expr::BinaryExpr { left, op, right } => {
let left = create_physical_name(left, false)?;
Expand Down Expand Up @@ -883,7 +883,7 @@ pub fn create_physical_expr(
Ok(Arc::new(Column::new(&c.name, idx)))
}
Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))),
Expr::ScalarVariable(variable_names) => {
Expr::ScalarVariable(_, variable_names) => {
if &variable_names[0][0..2] == "@@" {
match execution_props.get_var_provider(VarType::System) {
Some(provider) => {
Expand Down
28 changes: 26 additions & 2 deletions datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ pub trait ContextProvider {
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
/// Getter for a UDAF description
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>>;
/// Getter for system/user-defined variable type
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType>;
}

/// SQL query planner
Expand Down Expand Up @@ -1412,7 +1414,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
if id.value.starts_with('@') {
// TODO: figure out if ScalarVariables should be insensitive.
let var_names = vec![id.value.clone()];
Ok(Expr::ScalarVariable(var_names))
let ty = self
.schema_provider
.get_variable_type(&var_names)
.ok_or_else(|| {
DataFusionError::Execution(format!(
"variable {:?} has no type information",
var_names
))
})?;
Ok(Expr::ScalarVariable(ty, var_names))
} else {
// Don't use `col()` here because it will try to
// interpret names with '.' as if they were
Expand Down Expand Up @@ -1440,7 +1451,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let mut var_names: Vec<_> = ids.iter().map(normalize_ident).collect();

if &var_names[0][0..1] == "@" {
Ok(Expr::ScalarVariable(var_names))
let ty = self
.schema_provider
.get_variable_type(&var_names)
.ok_or_else(|| {
DataFusionError::Execution(format!(
"variable {:?} has no type information",
var_names
))
})?;
Ok(Expr::ScalarVariable(ty, var_names))
} else {
match (var_names.pop(), var_names.pop()) {
(Some(name), Some(relation)) if var_names.is_empty() => {
Expand Down Expand Up @@ -3938,6 +3958,10 @@ mod tests {
fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
unimplemented!()
}

fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
unimplemented!()
}
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/sql/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ where
asc: *asc,
nulls_first: *nulls_first,
}),
Expr::Column { .. } | Expr::Literal(_) | Expr::ScalarVariable(_) => {
Expr::Column { .. } | Expr::Literal(_) | Expr::ScalarVariable(_, _) => {
Ok(expr.clone())
}
Expr::Wildcard => Ok(Expr::Wildcard),
Expand Down
21 changes: 19 additions & 2 deletions datafusion/src/test/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use crate::error::Result;
use crate::scalar::ScalarValue;
use crate::variable::VarProvider;
use arrow::datatypes::DataType;

/// System variable
#[derive(Default)]
Expand All @@ -38,6 +39,10 @@ impl VarProvider for SystemVar {
let s = format!("{}-{}", "system-var", var_names.concat());
Ok(ScalarValue::Utf8(Some(s)))
}

fn get_type(&self, _: &[String]) -> Option<DataType> {
Some(DataType::Utf8)
}
}

/// user defined variable
Expand All @@ -54,7 +59,19 @@ impl UserDefinedVar {
impl VarProvider for UserDefinedVar {
/// Get user defined variable value
fn get_value(&self, var_names: Vec<String>) -> Result<ScalarValue> {
let s = format!("{}-{}", "user-defined-var", var_names.concat());
Ok(ScalarValue::Utf8(Some(s)))
if var_names[0] != "@integer" {
let s = format!("{}-{}", "user-defined-var", var_names.concat());
Ok(ScalarValue::Utf8(Some(s)))
} else {
Ok(ScalarValue::Int32(Some(41)))
}
}

fn get_type(&self, var_names: &[String]) -> Option<DataType> {
if var_names[0] != "@integer" {
Some(DataType::Utf8)
} else {
Some(DataType::Int32)
}
}
}
4 changes: 4 additions & 0 deletions datafusion/src/variable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

use crate::error::Result;
use crate::scalar::ScalarValue;
use arrow::datatypes::DataType;

/// Variable type, system/user defined
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand All @@ -33,4 +34,7 @@ pub enum VarType {
pub trait VarProvider {
/// Get variable value
fn get_value(&self, var_names: Vec<String>) -> Result<ScalarValue>;

/// Return the type of the given variable
fn get_type(&self, var_names: &[String]) -> Option<DataType>;
}