Skip to content

Commit

Permalink
refactor(expr): introduce ExprError (#3081)
Browse files Browse the repository at this point in the history
* refactor(expr): introduce ExprError

Signed-off-by: TennyZhuang <[email protected]>

* fix test

Signed-off-by: TennyZhuang <[email protected]>

* fix clippy

Signed-off-by: TennyZhuang <[email protected]>

* prefer to use try_collect

Signed-off-by: TennyZhuang <[email protected]>
  • Loading branch information
TennyZhuang authored Jun 9, 2022
1 parent 4e9090d commit b0a38c4
Show file tree
Hide file tree
Showing 50 changed files with 536 additions and 474 deletions.
7 changes: 4 additions & 3 deletions src/batch/src/executor/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use futures_async_stream::try_stream;
use itertools::Itertools;
use risingwave_common::array::column::Column;
use risingwave_common::array::DataChunk;
use risingwave_common::catalog::{Field, Schema};
Expand Down Expand Up @@ -57,7 +58,7 @@ impl ProjectExecutor {
.expr
.iter_mut()
.map(|expr| expr.eval(&data_chunk).map(Column::new))
.collect::<Result<Vec<_>>>()?;
.try_collect()?;
let ret = DataChunk::new(arrays, data_chunk.cardinality());
yield ret
}
Expand All @@ -80,11 +81,11 @@ impl BoxedExecutorBuilder for ProjectExecutor {
NodeBody::Project
)?;

let project_exprs = project_node
let project_exprs: Vec<_> = project_node
.get_select_list()
.iter()
.map(build_from_prost)
.collect::<Result<Vec<BoxedExpression>>>()?;
.try_collect()?;

let fields = project_exprs
.iter()
Expand Down
32 changes: 16 additions & 16 deletions src/batch/src/executor/sort_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,22 @@ impl BoxedExecutorBuilder for SortAggExecutor {
NodeBody::SortAgg
)?;

let agg_states = sort_agg_node
let agg_states: Vec<_> = sort_agg_node
.get_agg_calls()
.iter()
.map(|x| AggStateFactory::new(x)?.create_agg_state())
.collect::<Result<Vec<BoxedAggState>>>()?;
.try_collect()?;

let group_keys = sort_agg_node
let group_keys: Vec<_> = sort_agg_node
.get_group_keys()
.iter()
.map(build_from_prost)
.collect::<Result<Vec<BoxedExpression>>>()?;
.try_collect()?;

let sorted_groupers = group_keys
let sorted_groupers: Vec<_> = group_keys
.iter()
.map(|e| create_sorted_grouper(e.return_type()))
.collect::<Result<Vec<BoxedSortedGrouper>>>()?;
.try_collect()?;

let fields = group_keys
.iter()
Expand Down Expand Up @@ -124,11 +124,11 @@ impl SortAggExecutor {
#[for_await]
for child_chunk in self.child.execute() {
let child_chunk = child_chunk?.compact()?;
let group_columns = self
let group_columns: Vec<_> = self
.group_keys
.iter_mut()
.map(|expr| expr.eval(&child_chunk))
.collect::<Result<Vec<_>>>()?;
.try_collect()?;

let groups = self
.sorted_groupers
Expand Down Expand Up @@ -416,7 +416,7 @@ mod tests {
};

let count_star = AggStateFactory::new(&prost)?.create_agg_state()?;
let group_exprs = (1..=2)
let group_exprs: Vec<_> = (1..=2)
.map(|idx| {
build_from_prost(&ExprNode {
expr_type: InputRef as i32,
Expand All @@ -427,7 +427,7 @@ mod tests {
rex_node: Some(RexNode::InputRef(InputRefExpr { column_idx: idx })),
})
})
.collect::<Result<Vec<BoxedExpression>>>()?;
.try_collect()?;

let sorted_groupers = group_exprs
.iter()
Expand Down Expand Up @@ -628,7 +628,7 @@ mod tests {
};

let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?;
let group_exprs = (1..=2)
let group_exprs: Vec<_> = (1..=2)
.map(|idx| {
build_from_prost(&ExprNode {
expr_type: InputRef as i32,
Expand All @@ -639,12 +639,12 @@ mod tests {
rex_node: Some(RexNode::InputRef(InputRefExpr { column_idx: idx })),
})
})
.collect::<Result<Vec<BoxedExpression>>>()?;
.try_collect()?;

let sorted_groupers = group_exprs
let sorted_groupers: Vec<_> = group_exprs
.iter()
.map(|e| create_sorted_grouper(e.return_type()))
.collect::<Result<Vec<BoxedSortedGrouper>>>()?;
.try_collect()?;

let agg_states = vec![sum_agg];

Expand Down Expand Up @@ -751,7 +751,7 @@ mod tests {
};

let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?;
let group_exprs = (1..=2)
let group_exprs: Vec<_> = (1..=2)
.map(|idx| {
build_from_prost(&ExprNode {
expr_type: InputRef as i32,
Expand All @@ -762,7 +762,7 @@ mod tests {
rex_node: Some(RexNode::InputRef(InputRefExpr { column_idx: idx })),
})
})
.collect::<Result<Vec<BoxedExpression>>>()?;
.try_collect()?;

let sorted_groupers = group_exprs
.iter()
Expand Down
8 changes: 4 additions & 4 deletions src/batch/src/executor/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ impl UpdateExecutor {
let len = data_chunk.cardinality();

let updated_data_chunk = {
let columns = self
let columns: Vec<_> = self
.exprs
.iter_mut()
.map(|expr| expr.eval(&data_chunk).map(Column::new))
.collect::<Result<Vec<_>>>()?;
.try_collect()?;

DataChunk::new(columns, len)
};
Expand Down Expand Up @@ -176,11 +176,11 @@ impl BoxedExecutorBuilder for UpdateExecutor {

let table_id = TableId::from(&update_node.table_source_ref_id);

let exprs = update_node
let exprs: Vec<_> = update_node
.get_exprs()
.iter()
.map(build_from_prost)
.collect::<Result<Vec<BoxedExpression>>>()?;
.try_collect()?;

Ok(Box::new(Self::new(
table_id,
Expand Down
6 changes: 1 addition & 5 deletions src/batch/src/executor/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,7 @@ impl BoxedExecutorBuilder for ValuesExecutor {

let mut rows: Vec<Vec<BoxedExpression>> = Vec::with_capacity(value_node.get_tuples().len());
for row in value_node.get_tuples() {
let expr_row = row
.get_cells()
.iter()
.map(build_from_prost)
.collect::<Result<Vec<BoxedExpression>>>()?;
let expr_row: Vec<_> = row.get_cells().iter().map(build_from_prost).try_collect()?;
rows.push(expr_row);
}

Expand Down
7 changes: 7 additions & 0 deletions src/common/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ pub enum ErrorCode {
#[source]
BoxedError,
),
#[error("Expr error: {0:?}")]
ExprError(
#[backtrace]
#[source]
BoxedError,
),
#[error("Stream error: {0:?}")]
StreamError(
#[backtrace]
Expand Down Expand Up @@ -316,6 +322,7 @@ impl ErrorCode {
ErrorCode::ConnectorError(_) => 25,
ErrorCode::InvalidParameterValue(_) => 26,
ErrorCode::UnrecognizedConfigurationParameter(_) => 27,
ErrorCode::ExprError(_) => 28,
ErrorCode::UnknownError(_) => 101,
}
}
Expand Down
95 changes: 95 additions & 0 deletions src/expr/src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright 2022 Singularity Data
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

pub use anyhow::anyhow;
use risingwave_common::error::{ErrorCode, RwError};
use risingwave_common::types::DataType;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum ExprError {
#[error("Unsupported function: {0}")]
UnsupportedFunction(String),

#[error("Can't cast {0} to {1}")]
Cast(&'static str, &'static str),

// TODO: Unify Cast and Cast2.
#[error("Can't cast {0:?} to {1:?}")]
Cast2(DataType, DataType),

#[error("Out of range")]
NumericOutOfRange,

#[error("Parse error: {0}")]
Parse(&'static str),

#[error("Invalid parameter {name}: {reason}")]
InvalidParam { name: &'static str, reason: String },

#[error("Array error: {0}")]
Array(
#[backtrace]
#[source]
RwError,
),

#[error(transparent)]
Internal(#[from] anyhow::Error),
}

#[macro_export]
macro_rules! ensure {
($cond:expr $(,)?) => {
if !$cond {
return Err($crate::ExprError::Internal($crate::error::anyhow!(
stringify!($cond)
)));
}
};
($cond:expr, $msg:literal $(,)?) => {
if !$cond {
return Err($crate::ExprError::Internal($crate::error::anyhow!($msg)));
}
};
($cond:expr, $err:expr $(,)?) => {
if !$cond {
return Err($crate::ExprError::Internal($crate::error::anyhow!$err));
}
};
($cond:expr, $fmt:expr, $($arg:tt)*) => {
if !$cond {
return Err($crate::ExprError::Internal($crate::error::anyhow!($fmt, $($arg)*)));
}
};
}

impl From<ExprError> for RwError {
fn from(s: ExprError) -> Self {
ErrorCode::ExprError(Box::new(s)).into()
}
}

#[macro_export]
macro_rules! bail {
($msg:literal $(,)?) => {
return Err($crate::ExprError::Internal($crate::error::anyhow!($msg)))
};
($err:expr $(,)?) => {
return Err($crate::ExprError::Internal($crate::error::anyhow!($err)))
};
($fmt:expr, $($arg:tt)*) => {
return Err($crate::ExprError::Internal($crate::error::anyhow!($fmt, $($arg)*)))
};
}
29 changes: 10 additions & 19 deletions src/expr/src/expr/build_expr_from_prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
// limitations under the License.

use risingwave_common::array::DataChunk;
use risingwave_common::ensure;
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_common::types::{DataType, ToOwnedDatum};
use risingwave_pb::expr::expr_node::RexNode;
use risingwave_pb::expr::ExprNode;
Expand All @@ -31,23 +29,22 @@ use crate::expr::expr_unary::{
new_length_default, new_ltrim_expr, new_rtrim_expr, new_trim_expr, new_unary_expr,
};
use crate::expr::{build_from_prost as expr_build_from_prost, BoxedExpression};
use crate::{bail, ensure, Result};

fn get_children_and_return_type(prost: &ExprNode) -> Result<(Vec<ExprNode>, DataType)> {
let ret_type = DataType::from(prost.get_return_type()?);
if let RexNode::FuncCall(func_call) = prost.get_rex_node()? {
let ret_type = DataType::from(prost.get_return_type().unwrap());
if let RexNode::FuncCall(func_call) = prost.get_rex_node().unwrap() {
Ok((func_call.get_children().to_vec(), ret_type))
} else {
Err(RwError::from(ErrorCode::InternalError(
"expects a function call".to_string(),
)))
bail!("Expected RexNode::FuncCall");
}
}

pub fn build_unary_expr_prost(prost: &ExprNode) -> Result<BoxedExpression> {
let (children, ret_type) = get_children_and_return_type(prost)?;
ensure!(children.len() == 1);
let child_expr = expr_build_from_prost(&children[0])?;
new_unary_expr(prost.get_expr_type()?, ret_type, child_expr)
new_unary_expr(prost.get_expr_type().unwrap(), ret_type, child_expr)
}

pub fn build_binary_expr_prost(prost: &ExprNode) -> Result<BoxedExpression> {
Expand All @@ -56,7 +53,7 @@ pub fn build_binary_expr_prost(prost: &ExprNode) -> Result<BoxedExpression> {
let left_expr = expr_build_from_prost(&children[0])?;
let right_expr = expr_build_from_prost(&children[1])?;
Ok(new_binary_expr(
prost.get_expr_type()?,
prost.get_expr_type().unwrap(),
ret_type,
left_expr,
right_expr,
Expand All @@ -69,7 +66,7 @@ pub fn build_nullable_binary_expr_prost(prost: &ExprNode) -> Result<BoxedExpress
let left_expr = expr_build_from_prost(&children[0])?;
let right_expr = expr_build_from_prost(&children[1])?;
Ok(new_nullable_binary_expr(
prost.get_expr_type()?,
prost.get_expr_type().unwrap(),
ret_type,
left_expr,
right_expr,
Expand Down Expand Up @@ -169,9 +166,7 @@ pub fn build_case_expr(prost: &ExprNode) -> Result<BoxedExpression> {
let else_clause = if len % 2 == 1 {
let else_clause = expr_build_from_prost(&children[len - 1])?;
if else_clause.return_type() != ret_type {
return Err(RwError::from(ErrorCode::ProtocolError(
"the return type of else and case not match".to_string(),
)));
bail!("Type mismatched between else and case.");
}
Some(else_clause)
} else {
Expand All @@ -184,14 +179,10 @@ pub fn build_case_expr(prost: &ExprNode) -> Result<BoxedExpression> {
let when_expr = expr_build_from_prost(&children[when_index])?;
let then_expr = expr_build_from_prost(&children[then_index])?;
if when_expr.return_type() != DataType::Boolean {
return Err(RwError::from(ErrorCode::ProtocolError(
"the return type of when clause and condition not match".to_string(),
)));
bail!("Type mismatched between when clause and condition");
}
if then_expr.return_type() != ret_type {
return Err(RwError::from(ErrorCode::ProtocolError(
"the return type of then clause and case not match".to_string(),
)));
bail!("Type mismatched between then clause and case");
}
let when_clause = WhenClause::new(when_expr, then_expr);
when_clauses.push(when_clause);
Expand Down
Loading

0 comments on commit b0a38c4

Please sign in to comment.