From b0a38c4020d9d47d753d2043025ce3477e859a01 Mon Sep 17 00:00:00 2001 From: TennyZhuang Date: Thu, 9 Jun 2022 11:30:24 +0800 Subject: [PATCH] refactor(expr): introduce ExprError (#3081) * refactor(expr): introduce ExprError Signed-off-by: TennyZhuang * fix test Signed-off-by: TennyZhuang * fix clippy Signed-off-by: TennyZhuang * prefer to use try_collect Signed-off-by: TennyZhuang --- src/batch/src/executor/project.rs | 7 +- src/batch/src/executor/sort_agg.rs | 32 +++---- src/batch/src/executor/update.rs | 8 +- src/batch/src/executor/values.rs | 6 +- src/common/src/error.rs | 7 ++ src/expr/src/error.rs | 95 +++++++++++++++++++ src/expr/src/expr/build_expr_from_prost.rs | 29 ++---- src/expr/src/expr/expr_array.rs | 24 +++-- src/expr/src/expr/expr_case.rs | 17 ++-- src/expr/src/expr/expr_coalesce.rs | 22 +++-- src/expr/src/expr/expr_concat_ws.rs | 29 +++--- src/expr/src/expr/expr_field.rs | 34 ++++--- src/expr/src/expr/expr_in.rs | 16 ++-- src/expr/src/expr/expr_input_ref.rs | 21 ++-- src/expr/src/expr/expr_is_null.rs | 20 ++-- src/expr/src/expr/expr_literal.rs | 51 +++++----- src/expr/src/expr/expr_unary.rs | 20 ++-- src/expr/src/expr/mod.rs | 19 ++-- src/expr/src/expr/template.rs | 63 ++++++------ src/expr/src/lib.rs | 5 + src/expr/src/vector_op/agg/general_agg.rs | 14 +-- .../src/vector_op/agg/general_distinct_agg.rs | 12 +-- .../vector_op/agg/general_sorted_grouper.rs | 16 ++-- src/expr/src/vector_op/arithmetic_op.rs | 46 +++------ src/expr/src/vector_op/array_access.rs | 33 +++---- src/expr/src/vector_op/ascii.rs | 5 +- src/expr/src/vector_op/bitwise_op.rs | 19 ++-- src/expr/src/vector_op/cast.rs | 50 ++++------ src/expr/src/vector_op/cmp.rs | 25 ++--- src/expr/src/vector_op/conjunction.rs | 2 +- src/expr/src/vector_op/extract.rs | 14 +-- src/expr/src/vector_op/length.rs | 6 +- src/expr/src/vector_op/like.rs | 2 +- src/expr/src/vector_op/lower.rs | 11 ++- src/expr/src/vector_op/ltrim.rs | 9 +- src/expr/src/vector_op/md5.rs | 11 ++- src/expr/src/vector_op/position.rs | 11 +-- src/expr/src/vector_op/replace.rs | 15 +-- src/expr/src/vector_op/round.rs | 3 +- src/expr/src/vector_op/rtrim.rs | 9 +- src/expr/src/vector_op/split_part.rs | 73 ++++++-------- src/expr/src/vector_op/substr.rs | 19 ++-- src/expr/src/vector_op/tests.rs | 22 ++--- src/expr/src/vector_op/to_char.rs | 5 +- src/expr/src/vector_op/translate.rs | 9 +- src/expr/src/vector_op/trim.rs | 9 +- src/expr/src/vector_op/tumble.rs | 10 +- src/expr/src/vector_op/upper.rs | 11 ++- src/source/src/parser/common.rs | 10 +- src/stream/src/from_proto/project.rs | 4 +- 50 files changed, 536 insertions(+), 474 deletions(-) create mode 100644 src/expr/src/error.rs diff --git a/src/batch/src/executor/project.rs b/src/batch/src/executor/project.rs index cd02287639b8..3ee3fbfbd049 100644 --- a/src/batch/src/executor/project.rs +++ b/src/batch/src/executor/project.rs @@ -13,6 +13,7 @@ // limitations under the License. use futures_async_stream::try_stream; +use itertools::Itertools; use risingwave_common::array::column::Column; use risingwave_common::array::DataChunk; use risingwave_common::catalog::{Field, Schema}; @@ -57,7 +58,7 @@ impl ProjectExecutor { .expr .iter_mut() .map(|expr| expr.eval(&data_chunk).map(Column::new)) - .collect::>>()?; + .try_collect()?; let ret = DataChunk::new(arrays, data_chunk.cardinality()); yield ret } @@ -80,11 +81,11 @@ impl BoxedExecutorBuilder for ProjectExecutor { NodeBody::Project )?; - let project_exprs = project_node + let project_exprs: Vec<_> = project_node .get_select_list() .iter() .map(build_from_prost) - .collect::>>()?; + .try_collect()?; let fields = project_exprs .iter() diff --git a/src/batch/src/executor/sort_agg.rs b/src/batch/src/executor/sort_agg.rs index c01ee3547959..726217638562 100644 --- a/src/batch/src/executor/sort_agg.rs +++ b/src/batch/src/executor/sort_agg.rs @@ -64,22 +64,22 @@ impl BoxedExecutorBuilder for SortAggExecutor { NodeBody::SortAgg )?; - let agg_states = sort_agg_node + let agg_states: Vec<_> = sort_agg_node .get_agg_calls() .iter() .map(|x| AggStateFactory::new(x)?.create_agg_state()) - .collect::>>()?; + .try_collect()?; - let group_keys = sort_agg_node + let group_keys: Vec<_> = sort_agg_node .get_group_keys() .iter() .map(build_from_prost) - .collect::>>()?; + .try_collect()?; - let sorted_groupers = group_keys + let sorted_groupers: Vec<_> = group_keys .iter() .map(|e| create_sorted_grouper(e.return_type())) - .collect::>>()?; + .try_collect()?; let fields = group_keys .iter() @@ -124,11 +124,11 @@ impl SortAggExecutor { #[for_await] for child_chunk in self.child.execute() { let child_chunk = child_chunk?.compact()?; - let group_columns = self + let group_columns: Vec<_> = self .group_keys .iter_mut() .map(|expr| expr.eval(&child_chunk)) - .collect::>>()?; + .try_collect()?; let groups = self .sorted_groupers @@ -416,7 +416,7 @@ mod tests { }; let count_star = AggStateFactory::new(&prost)?.create_agg_state()?; - let group_exprs = (1..=2) + let group_exprs: Vec<_> = (1..=2) .map(|idx| { build_from_prost(&ExprNode { expr_type: InputRef as i32, @@ -427,7 +427,7 @@ mod tests { rex_node: Some(RexNode::InputRef(InputRefExpr { column_idx: idx })), }) }) - .collect::>>()?; + .try_collect()?; let sorted_groupers = group_exprs .iter() @@ -628,7 +628,7 @@ mod tests { }; let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?; - let group_exprs = (1..=2) + let group_exprs: Vec<_> = (1..=2) .map(|idx| { build_from_prost(&ExprNode { expr_type: InputRef as i32, @@ -639,12 +639,12 @@ mod tests { rex_node: Some(RexNode::InputRef(InputRefExpr { column_idx: idx })), }) }) - .collect::>>()?; + .try_collect()?; - let sorted_groupers = group_exprs + let sorted_groupers: Vec<_> = group_exprs .iter() .map(|e| create_sorted_grouper(e.return_type())) - .collect::>>()?; + .try_collect()?; let agg_states = vec![sum_agg]; @@ -751,7 +751,7 @@ mod tests { }; let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?; - let group_exprs = (1..=2) + let group_exprs: Vec<_> = (1..=2) .map(|idx| { build_from_prost(&ExprNode { expr_type: InputRef as i32, @@ -762,7 +762,7 @@ mod tests { rex_node: Some(RexNode::InputRef(InputRefExpr { column_idx: idx })), }) }) - .collect::>>()?; + .try_collect()?; let sorted_groupers = group_exprs .iter() diff --git a/src/batch/src/executor/update.rs b/src/batch/src/executor/update.rs index b10bccfc4d4e..13374e0872d8 100644 --- a/src/batch/src/executor/update.rs +++ b/src/batch/src/executor/update.rs @@ -99,11 +99,11 @@ impl UpdateExecutor { let len = data_chunk.cardinality(); let updated_data_chunk = { - let columns = self + let columns: Vec<_> = self .exprs .iter_mut() .map(|expr| expr.eval(&data_chunk).map(Column::new)) - .collect::>>()?; + .try_collect()?; DataChunk::new(columns, len) }; @@ -176,11 +176,11 @@ impl BoxedExecutorBuilder for UpdateExecutor { let table_id = TableId::from(&update_node.table_source_ref_id); - let exprs = update_node + let exprs: Vec<_> = update_node .get_exprs() .iter() .map(build_from_prost) - .collect::>>()?; + .try_collect()?; Ok(Box::new(Self::new( table_id, diff --git a/src/batch/src/executor/values.rs b/src/batch/src/executor/values.rs index cf6a265e4e82..f3411e4878f0 100644 --- a/src/batch/src/executor/values.rs +++ b/src/batch/src/executor/values.rs @@ -117,11 +117,7 @@ impl BoxedExecutorBuilder for ValuesExecutor { let mut rows: Vec> = Vec::with_capacity(value_node.get_tuples().len()); for row in value_node.get_tuples() { - let expr_row = row - .get_cells() - .iter() - .map(build_from_prost) - .collect::>>()?; + let expr_row: Vec<_> = row.get_cells().iter().map(build_from_prost).try_collect()?; rows.push(expr_row); } diff --git a/src/common/src/error.rs b/src/common/src/error.rs index b64ea07a272b..1a5cab92157e 100644 --- a/src/common/src/error.rs +++ b/src/common/src/error.rs @@ -96,6 +96,12 @@ pub enum ErrorCode { #[source] BoxedError, ), + #[error("Expr error: {0:?}")] + ExprError( + #[backtrace] + #[source] + BoxedError, + ), #[error("Stream error: {0:?}")] StreamError( #[backtrace] @@ -316,6 +322,7 @@ impl ErrorCode { ErrorCode::ConnectorError(_) => 25, ErrorCode::InvalidParameterValue(_) => 26, ErrorCode::UnrecognizedConfigurationParameter(_) => 27, + ErrorCode::ExprError(_) => 28, ErrorCode::UnknownError(_) => 101, } } diff --git a/src/expr/src/error.rs b/src/expr/src/error.rs new file mode 100644 index 000000000000..5b79078a14ea --- /dev/null +++ b/src/expr/src/error.rs @@ -0,0 +1,95 @@ +// Copyright 2022 Singularity Data +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub use anyhow::anyhow; +use risingwave_common::error::{ErrorCode, RwError}; +use risingwave_common::types::DataType; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ExprError { + #[error("Unsupported function: {0}")] + UnsupportedFunction(String), + + #[error("Can't cast {0} to {1}")] + Cast(&'static str, &'static str), + + // TODO: Unify Cast and Cast2. + #[error("Can't cast {0:?} to {1:?}")] + Cast2(DataType, DataType), + + #[error("Out of range")] + NumericOutOfRange, + + #[error("Parse error: {0}")] + Parse(&'static str), + + #[error("Invalid parameter {name}: {reason}")] + InvalidParam { name: &'static str, reason: String }, + + #[error("Array error: {0}")] + Array( + #[backtrace] + #[source] + RwError, + ), + + #[error(transparent)] + Internal(#[from] anyhow::Error), +} + +#[macro_export] +macro_rules! ensure { + ($cond:expr $(,)?) => { + if !$cond { + return Err($crate::ExprError::Internal($crate::error::anyhow!( + stringify!($cond) + ))); + } + }; + ($cond:expr, $msg:literal $(,)?) => { + if !$cond { + return Err($crate::ExprError::Internal($crate::error::anyhow!($msg))); + } + }; + ($cond:expr, $err:expr $(,)?) => { + if !$cond { + return Err($crate::ExprError::Internal($crate::error::anyhow!$err)); + } + }; + ($cond:expr, $fmt:expr, $($arg:tt)*) => { + if !$cond { + return Err($crate::ExprError::Internal($crate::error::anyhow!($fmt, $($arg)*))); + } + }; +} + +impl From for RwError { + fn from(s: ExprError) -> Self { + ErrorCode::ExprError(Box::new(s)).into() + } +} + +#[macro_export] +macro_rules! bail { + ($msg:literal $(,)?) => { + return Err($crate::ExprError::Internal($crate::error::anyhow!($msg))) + }; + ($err:expr $(,)?) => { + return Err($crate::ExprError::Internal($crate::error::anyhow!($err))) + }; + ($fmt:expr, $($arg:tt)*) => { + return Err($crate::ExprError::Internal($crate::error::anyhow!($fmt, $($arg)*))) + }; +} diff --git a/src/expr/src/expr/build_expr_from_prost.rs b/src/expr/src/expr/build_expr_from_prost.rs index 3c216eadd21e..fec596d503ea 100644 --- a/src/expr/src/expr/build_expr_from_prost.rs +++ b/src/expr/src/expr/build_expr_from_prost.rs @@ -13,8 +13,6 @@ // limitations under the License. use risingwave_common::array::DataChunk; -use risingwave_common::ensure; -use risingwave_common::error::{ErrorCode, Result, RwError}; use risingwave_common::types::{DataType, ToOwnedDatum}; use risingwave_pb::expr::expr_node::RexNode; use risingwave_pb::expr::ExprNode; @@ -31,15 +29,14 @@ use crate::expr::expr_unary::{ new_length_default, new_ltrim_expr, new_rtrim_expr, new_trim_expr, new_unary_expr, }; use crate::expr::{build_from_prost as expr_build_from_prost, BoxedExpression}; +use crate::{bail, ensure, Result}; fn get_children_and_return_type(prost: &ExprNode) -> Result<(Vec, DataType)> { - let ret_type = DataType::from(prost.get_return_type()?); - if let RexNode::FuncCall(func_call) = prost.get_rex_node()? { + let ret_type = DataType::from(prost.get_return_type().unwrap()); + if let RexNode::FuncCall(func_call) = prost.get_rex_node().unwrap() { Ok((func_call.get_children().to_vec(), ret_type)) } else { - Err(RwError::from(ErrorCode::InternalError( - "expects a function call".to_string(), - ))) + bail!("Expected RexNode::FuncCall"); } } @@ -47,7 +44,7 @@ pub fn build_unary_expr_prost(prost: &ExprNode) -> Result { let (children, ret_type) = get_children_and_return_type(prost)?; ensure!(children.len() == 1); let child_expr = expr_build_from_prost(&children[0])?; - new_unary_expr(prost.get_expr_type()?, ret_type, child_expr) + new_unary_expr(prost.get_expr_type().unwrap(), ret_type, child_expr) } pub fn build_binary_expr_prost(prost: &ExprNode) -> Result { @@ -56,7 +53,7 @@ pub fn build_binary_expr_prost(prost: &ExprNode) -> Result { let left_expr = expr_build_from_prost(&children[0])?; let right_expr = expr_build_from_prost(&children[1])?; Ok(new_binary_expr( - prost.get_expr_type()?, + prost.get_expr_type().unwrap(), ret_type, left_expr, right_expr, @@ -69,7 +66,7 @@ pub fn build_nullable_binary_expr_prost(prost: &ExprNode) -> Result Result { let else_clause = if len % 2 == 1 { let else_clause = expr_build_from_prost(&children[len - 1])?; if else_clause.return_type() != ret_type { - return Err(RwError::from(ErrorCode::ProtocolError( - "the return type of else and case not match".to_string(), - ))); + bail!("Type mismatched between else and case."); } Some(else_clause) } else { @@ -184,14 +179,10 @@ pub fn build_case_expr(prost: &ExprNode) -> Result { let when_expr = expr_build_from_prost(&children[when_index])?; let then_expr = expr_build_from_prost(&children[then_index])?; if when_expr.return_type() != DataType::Boolean { - return Err(RwError::from(ErrorCode::ProtocolError( - "the return type of when clause and condition not match".to_string(), - ))); + bail!("Type mismatched between when clause and condition"); } if then_expr.return_type() != ret_type { - return Err(RwError::from(ErrorCode::ProtocolError( - "the return type of then clause and case not match".to_string(), - ))); + bail!("Type mismatched between then clause and case"); } let when_clause = WhenClause::new(when_expr, then_expr); when_clauses.push(when_clause); diff --git a/src/expr/src/expr/expr_array.rs b/src/expr/src/expr/expr_array.rs index 8f818a370371..7f64483207d1 100644 --- a/src/expr/src/expr/expr_array.rs +++ b/src/expr/src/expr/expr_array.rs @@ -19,13 +19,12 @@ use risingwave_common::array::column::Column; use risingwave_common::array::{ ArrayBuilder, ArrayImpl, ArrayMeta, ArrayRef, DataChunk, ListArrayBuilder, ListValue, Row, }; -use risingwave_common::error::{Result, RwError}; use risingwave_common::types::{DataType, Datum, Scalar}; -use risingwave_common::{ensure, try_match_expand}; use risingwave_pb::expr::expr_node::{RexNode, Type}; use risingwave_pb::expr::ExprNode; use crate::expr::{build_from_prost as expr_build_from_prost, BoxedExpression, Expression}; +use crate::{bail, ensure, ExprError, Result}; #[derive(Debug)] pub struct ArrayExpression { @@ -56,11 +55,16 @@ impl Expression for ArrayExpression { ArrayMeta::List { datatype: self.element_type.clone(), }, - )?; + ) + .map_err(ExprError::Array)?; chunk .rows() - .try_for_each(|row| builder.append_row_ref(row))?; - builder.finish().map(|a| Arc::new(ArrayImpl::List(a))) + .try_for_each(|row| builder.append_row_ref(row)) + .map_err(ExprError::Array)?; + builder + .finish() + .map(|a| Arc::new(ArrayImpl::List(a))) + .map_err(ExprError::Array) } fn eval_row(&self, input: &Row) -> Result { @@ -87,13 +91,15 @@ impl ArrayExpression { } impl<'a> TryFrom<&'a ExprNode> for ArrayExpression { - type Error = RwError; + type Error = ExprError; fn try_from(prost: &'a ExprNode) -> Result { - ensure!(prost.get_expr_type()? == Type::Array); + ensure!(prost.get_expr_type().unwrap() == Type::Array); - let ret_type = DataType::from(prost.get_return_type()?); - let func_call_node = try_match_expand!(prost.get_rex_node().unwrap(), RexNode::FuncCall)?; + let ret_type = DataType::from(prost.get_return_type().unwrap()); + let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { + bail!("Expected RexNode::FuncCall"); + }; let elements = func_call_node .children .iter() diff --git a/src/expr/src/expr/expr_case.rs b/src/expr/src/expr/expr_case.rs index 319d7b80ec04..fa53407b8919 100644 --- a/src/expr/src/expr/expr_case.rs +++ b/src/expr/src/expr/expr_case.rs @@ -14,10 +14,10 @@ use itertools::Itertools; use risingwave_common::array::{ArrayRef, DataChunk, Row}; -use risingwave_common::error::Result; use risingwave_common::types::{DataType, Datum, ScalarImpl, ScalarRefImpl, ToOwnedDatum}; use crate::expr::{BoxedExpression, Expression}; +use crate::{ExprError, Result}; #[derive(Debug)] pub struct WhenClause { @@ -72,7 +72,10 @@ impl Expression for CaseExpression { ) }) .collect_vec(); - let mut output_array = self.return_type().create_array_builder(input.capacity())?; + let mut output_array = self + .return_type() + .create_array_builder(input.capacity()) + .map_err(ExprError::Array)?; for idx in 0..input.capacity() { if let Some((_, t)) = when_thens .iter() @@ -83,15 +86,17 @@ impl Expression for CaseExpression { .as_bool() }) { - output_array.append_datum(&t.to_owned_datum())?; + output_array + .append_datum(&t.to_owned_datum()) + .map_err(ExprError::Array)?; } else if let Some(els) = els.as_mut() { let t = els.datum_at(idx); - output_array.append_datum(&t)?; + output_array.append_datum(&t).map_err(ExprError::Array)?; } else { - output_array.append_null()?; + output_array.append_null().map_err(ExprError::Array)?; }; } - let output_array = output_array.finish()?.into(); + let output_array = output_array.finish().map_err(ExprError::Array)?.into(); Ok(output_array) } diff --git a/src/expr/src/expr/expr_coalesce.rs b/src/expr/src/expr/expr_coalesce.rs index 7a044c58a78f..877b3ac8c900 100644 --- a/src/expr/src/expr/expr_coalesce.rs +++ b/src/expr/src/expr/expr_coalesce.rs @@ -16,13 +16,12 @@ use std::convert::TryFrom; use std::sync::Arc; use risingwave_common::array::{ArrayRef, DataChunk, Row}; -use risingwave_common::error::{Result, RwError}; use risingwave_common::types::{DataType, Datum}; -use risingwave_common::{ensure, try_match_expand}; use risingwave_pb::expr::expr_node::{RexNode, Type}; use risingwave_pb::expr::ExprNode; use crate::expr::{build_from_prost as expr_build_from_prost, BoxedExpression, Expression}; +use crate::{bail, ensure, ExprError, Result}; #[derive(Debug)] pub struct CoalesceExpression { @@ -41,7 +40,10 @@ impl Expression for CoalesceExpression { .iter() .map(|c| c.eval(input)) .collect::>>()?; - let mut builder = self.return_type.create_array_builder(input.cardinality())?; + let mut builder = self + .return_type + .create_array_builder(input.cardinality()) + .map_err(ExprError::Array)?; let len = children_array[0].len(); for i in 0..len { @@ -53,9 +55,9 @@ impl Expression for CoalesceExpression { break; } } - builder.append_datum(&data)?; + builder.append_datum(&data).map_err(ExprError::Array)?; } - Ok(Arc::new(builder.finish()?)) + Ok(Arc::new(builder.finish().map_err(ExprError::Array)?)) } fn eval_row(&self, input: &Row) -> Result { @@ -84,13 +86,15 @@ impl CoalesceExpression { } impl<'a> TryFrom<&'a ExprNode> for CoalesceExpression { - type Error = RwError; + type Error = ExprError; fn try_from(prost: &'a ExprNode) -> Result { - ensure!(prost.get_expr_type()? == Type::Coalesce); + ensure!(prost.get_expr_type().unwrap() == Type::Coalesce,); - let ret_type = DataType::from(prost.get_return_type()?); - let func_call_node = try_match_expand!(prost.get_rex_node().unwrap(), RexNode::FuncCall)?; + let ret_type = DataType::from(prost.get_return_type().unwrap()); + let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { + bail!("Expected RexNode::FuncCall"); + }; let children = func_call_node .children diff --git a/src/expr/src/expr/expr_concat_ws.rs b/src/expr/src/expr/expr_concat_ws.rs index 819093088bed..704f267613c7 100644 --- a/src/expr/src/expr/expr_concat_ws.rs +++ b/src/expr/src/expr/expr_concat_ws.rs @@ -18,13 +18,12 @@ use std::sync::Arc; use risingwave_common::array::{ Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, Row, Utf8ArrayBuilder, }; -use risingwave_common::error::{Result, RwError}; use risingwave_common::types::{DataType, Datum, Scalar}; -use risingwave_common::{ensure, try_match_expand}; use risingwave_pb::expr::expr_node::{RexNode, Type}; use risingwave_pb::expr::ExprNode; use crate::expr::{build_from_prost as expr_build_from_prost, BoxedExpression, Expression}; +use crate::{bail, ensure, ExprError, Result}; #[derive(Debug)] pub struct ConcatWsExpression { @@ -53,13 +52,13 @@ impl Expression for ConcatWsExpression { .collect::>(); let row_len = input.cardinality(); - let mut builder = Utf8ArrayBuilder::new(row_len)?; + let mut builder = Utf8ArrayBuilder::new(row_len).map_err(ExprError::Array)?; for row_idx in 0..row_len { let sep = match sep_column.value_at(row_idx) { Some(sep) => sep, None => { - builder.append(None)?; + builder.append(None).map_err(ExprError::Array)?; continue; } }; @@ -69,21 +68,23 @@ impl Expression for ConcatWsExpression { let mut string_columns = string_columns_ref.iter(); for string_column in string_columns.by_ref() { if let Some(string) = string_column.value_at(row_idx) { - writer.write_ref(string)?; + writer.write_ref(string).map_err(ExprError::Array)?; break; } } for string_column in string_columns { if let Some(string) = string_column.value_at(row_idx) { - writer.write_ref(sep)?; - writer.write_ref(string)?; + writer.write_ref(sep).map_err(ExprError::Array)?; + writer.write_ref(string).map_err(ExprError::Array)?; } } - builder = writer.finish()?.into_inner(); + builder = writer.finish().map_err(ExprError::Array)?.into_inner(); } - Ok(Arc::new(ArrayImpl::from(builder.finish()?))) + Ok(Arc::new(ArrayImpl::from( + builder.finish().map_err(ExprError::Array)?, + ))) } fn eval_row(&self, input: &Row) -> Result { @@ -129,13 +130,15 @@ impl ConcatWsExpression { } impl<'a> TryFrom<&'a ExprNode> for ConcatWsExpression { - type Error = RwError; + type Error = ExprError; fn try_from(prost: &'a ExprNode) -> Result { - ensure!(prost.get_expr_type()? == Type::ConcatWs); + ensure!(prost.get_expr_type().unwrap() == Type::ConcatWs); - let ret_type = DataType::from(prost.get_return_type()?); - let func_call_node = try_match_expand!(prost.get_rex_node().unwrap(), RexNode::FuncCall)?; + let ret_type = DataType::from(prost.get_return_type().unwrap()); + let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { + bail!("Expected RexNode::FuncCall"); + }; let children = &func_call_node.children; let sep_expr = expr_build_from_prost(&children[0])?; diff --git a/src/expr/src/expr/expr_field.rs b/src/expr/src/expr/expr_field.rs index 4965db83d8fe..428f748752c0 100644 --- a/src/expr/src/expr/expr_field.rs +++ b/src/expr/src/expr/expr_field.rs @@ -14,14 +14,14 @@ use std::convert::TryFrom; +use anyhow::anyhow; use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk, Row}; -use risingwave_common::error::{internal_error, ErrorCode, Result, RwError}; use risingwave_common::types::{DataType, Datum}; -use risingwave_common::{ensure, ensure_eq, try_match_expand}; use risingwave_pb::expr::expr_node::{RexNode, Type}; use risingwave_pb::expr::ExprNode; use crate::expr::{build_from_prost as expr_build_from_prost, BoxedExpression, Expression}; +use crate::{bail, ensure, ExprError, Result}; /// `FieldExpression` access a field from a struct. #[derive(Debug)] @@ -41,12 +41,12 @@ impl Expression for FieldExpression { if let ArrayImpl::Struct(struct_array) = array.as_ref() { Ok(struct_array.field_at(self.index)) } else { - Err(internal_error("expects a struct array ref")) + Err(anyhow!("expects a struct array ref").into()) } } fn eval_row(&self, _input: &Row) -> Result { - Err(internal_error("expects a struct array ref")) + Err(anyhow!("expects a struct array ref").into()) } } @@ -61,23 +61,31 @@ impl FieldExpression { } impl<'a> TryFrom<&'a ExprNode> for FieldExpression { - type Error = RwError; + type Error = ExprError; fn try_from(prost: &'a ExprNode) -> Result { - ensure!(prost.get_expr_type()? == Type::Field); + ensure!(prost.get_expr_type().unwrap() == Type::Field); - let ret_type = DataType::from(prost.get_return_type()?); - let func_call_node = try_match_expand!(prost.get_rex_node().unwrap(), RexNode::FuncCall)?; + let ret_type = DataType::from(prost.get_return_type().unwrap()); + let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { + bail!("Expected RexNode::FuncCall"); + }; let children = func_call_node.children.to_vec(); // Field `func_call_node` have 2 child nodes, the first is Field `FuncCall` or // `InputRef`, the second is i32 `Literal`. - ensure_eq!(children.len(), 2); + ensure!(children.len() == 2); let input = expr_build_from_prost(&children[0])?; - let value = try_match_expand!(children[1].get_rex_node().unwrap(), RexNode::Constant)?; - let index = i32::from_be_bytes(value.body.clone().try_into().map_err(|e| { - ErrorCode::InternalError(format!("Failed to deserialize i32, reason: {:?}", e)) - })?); + let RexNode::Constant(value) = children[1].get_rex_node().unwrap() else { + bail!("Expected Constant as 1st argument"); + }; + let index = i32::from_be_bytes( + value + .body + .clone() + .try_into() + .map_err(|e| anyhow!("Failed to deserialize i32, reason: {:?}", e))?, + ); Ok(FieldExpression::new(ret_type, input, index as usize)) } } diff --git a/src/expr/src/expr/expr_in.rs b/src/expr/src/expr/expr_in.rs index d610a695528e..a8547004bdfd 100644 --- a/src/expr/src/expr/expr_in.rs +++ b/src/expr/src/expr/expr_in.rs @@ -21,6 +21,7 @@ use risingwave_common::array::{ArrayBuilder, ArrayRef, BoolArrayBuilder, DataChu use risingwave_common::types::{DataType, Datum, Scalar, ToOwnedDatum}; use crate::expr::{BoxedExpression, Expression}; +use crate::{ExprError, Result}; #[derive(Debug)] pub(crate) struct InExpression { @@ -56,10 +57,11 @@ impl Expression for InExpression { self.return_type.clone() } - fn eval(&self, input: &DataChunk) -> risingwave_common::error::Result { + fn eval(&self, input: &DataChunk) -> Result { let input_array = self.left.eval(input)?; let visibility = input.visibility(); - let mut output_array = BoolArrayBuilder::new(input.cardinality())?; + let mut output_array = + BoolArrayBuilder::new(input.cardinality()).map_err(ExprError::Array)?; match visibility { Some(bitmap) => { for (data, vis) in input_array.iter().zip_eq(bitmap.iter()) { @@ -67,20 +69,22 @@ impl Expression for InExpression { continue; } let ret = self.exists(&data.to_owned_datum()); - output_array.append(Some(ret))?; + output_array.append(Some(ret)).map_err(ExprError::Array)?; } } None => { for data in input_array.iter() { let ret = self.exists(&data.to_owned_datum()); - output_array.append(Some(ret))?; + output_array.append(Some(ret)).map_err(ExprError::Array)?; } } }; - Ok(Arc::new(output_array.finish()?.into())) + Ok(Arc::new( + output_array.finish().map_err(ExprError::Array)?.into(), + )) } - fn eval_row(&self, input: &Row) -> risingwave_common::error::Result { + fn eval_row(&self, input: &Row) -> Result { let data = self.left.eval_row(input)?; let ret = self.exists(&data); Ok(Some(ret.to_scalar_value())) diff --git a/src/expr/src/expr/expr_input_ref.rs b/src/expr/src/expr/expr_input_ref.rs index d7c671ffbede..36554de2ecad 100644 --- a/src/expr/src/expr/expr_input_ref.rs +++ b/src/expr/src/expr/expr_input_ref.rs @@ -17,13 +17,12 @@ use std::ops::Index; use std::sync::Arc; use risingwave_common::array::{ArrayRef, DataChunk, Row}; -use risingwave_common::ensure; -use risingwave_common::error::{ErrorCode, Result, RwError}; use risingwave_common::types::{DataType, Datum}; use risingwave_pb::expr::expr_node::{RexNode, Type}; use risingwave_pb::expr::ExprNode; use crate::expr::Expression; +use crate::{bail, ensure, ExprError, Result}; /// `InputRefExpression` references to a column in input relation #[derive(Debug)] @@ -46,7 +45,11 @@ impl Expression for InputRefExpression { match bitmap { Some(bitmap) => { if input.cardinality() != input.capacity() { - Ok(Arc::new(array.compact(bitmap, cardinality)?)) + Ok(Arc::new( + array + .compact(bitmap, cardinality) + .map_err(ExprError::Array)?, + )) } else { Ok(array) } @@ -72,21 +75,19 @@ impl InputRefExpression { } impl<'a> TryFrom<&'a ExprNode> for InputRefExpression { - type Error = RwError; + type Error = ExprError; fn try_from(prost: &'a ExprNode) -> Result { - ensure!(prost.get_expr_type()? == Type::InputRef); + ensure!(prost.get_expr_type().unwrap() == Type::InputRef); - let ret_type = DataType::from(prost.get_return_type()?); - if let RexNode::InputRef(input_ref_node) = prost.get_rex_node()? { + let ret_type = DataType::from(prost.get_return_type().unwrap()); + if let RexNode::InputRef(input_ref_node) = prost.get_rex_node().unwrap() { Ok(Self { return_type: ret_type, idx: input_ref_node.column_idx as usize, }) } else { - Err(RwError::from(ErrorCode::InternalError( - "expects a input ref node".to_string(), - ))) + bail!("Expect an input ref node") } } } diff --git a/src/expr/src/expr/expr_is_null.rs b/src/expr/src/expr/expr_is_null.rs index 3828688af747..4e2be25107a5 100644 --- a/src/expr/src/expr/expr_is_null.rs +++ b/src/expr/src/expr/expr_is_null.rs @@ -17,10 +17,10 @@ use std::sync::Arc; use risingwave_common::array::{ ArrayBuilder, ArrayImpl, ArrayRef, BoolArrayBuilder, DataChunk, Row, }; -use risingwave_common::error::Result; use risingwave_common::types::{DataType, Datum, Scalar}; use crate::expr::{BoxedExpression, Expression}; +use crate::{ExprError, Result}; #[derive(Debug)] pub struct IsNullExpression { @@ -58,14 +58,17 @@ impl Expression for IsNullExpression { } fn eval(&self, input: &DataChunk) -> Result { - let mut builder = BoolArrayBuilder::new(input.cardinality())?; + let mut builder = BoolArrayBuilder::new(input.cardinality()).map_err(ExprError::Array)?; self.child .eval(input)? .null_bitmap() .iter() - .try_for_each(|b| builder.append(Some(!b)))?; + .try_for_each(|b| builder.append(Some(!b))) + .map_err(ExprError::Array)?; - Ok(Arc::new(ArrayImpl::Bool(builder.finish()?))) + Ok(Arc::new(ArrayImpl::Bool( + builder.finish().map_err(ExprError::Array)?, + ))) } fn eval_row(&self, input: &Row) -> Result { @@ -81,14 +84,17 @@ impl Expression for IsNotNullExpression { } fn eval(&self, input: &DataChunk) -> Result { - let mut builder = BoolArrayBuilder::new(input.cardinality())?; + let mut builder = BoolArrayBuilder::new(input.cardinality()).map_err(ExprError::Array)?; self.child .eval(input)? .null_bitmap() .iter() - .try_for_each(|b| builder.append(Some(b)))?; + .try_for_each(|b| builder.append(Some(b))) + .map_err(ExprError::Array)?; - Ok(Arc::new(ArrayImpl::Bool(builder.finish()?))) + Ok(Arc::new(ArrayImpl::Bool( + builder.finish().map_err(ExprError::Array)?, + ))) } fn eval_row(&self, input: &Row) -> Result { diff --git a/src/expr/src/expr/expr_literal.rs b/src/expr/src/expr/expr_literal.rs index 95ab07430ed9..4b88c925c66e 100644 --- a/src/expr/src/expr/expr_literal.rs +++ b/src/expr/src/expr/expr_literal.rs @@ -15,15 +15,14 @@ use std::convert::TryFrom; use std::sync::Arc; -use prost::DecodeError; use risingwave_common::array::{Array, ArrayBuilder, ArrayBuilderImpl, ArrayRef, DataChunk, Row}; -use risingwave_common::error::{ErrorCode, Result, RwError}; +use risingwave_common::for_all_variants; use risingwave_common::types::{DataType, Datum, Scalar, ScalarImpl}; -use risingwave_common::{ensure, for_all_variants}; use risingwave_pb::expr::expr_node::{RexNode, Type}; use risingwave_pb::expr::ExprNode; use crate::expr::Expression; +use crate::{bail, ensure, ExprError, Result}; macro_rules! array_impl_literal_append { ([$arr_builder: ident, $literal: ident, $cardinality: ident], $( { $variant_name:ident, $suffix_name:ident, $array:ty, $builder:ty } ),*) => { @@ -36,9 +35,9 @@ macro_rules! array_impl_literal_append { append_literal_to_arr(inner, None, $cardinality)?; } )* - (_, _) => return Err(ErrorCode::NotImplemented( - "Do not support values in insert values executor".to_string(), None.into(), - ).into()), + (_, _) => $crate::bail!( + "Do not support values in insert values executor".to_string() + ), } }; } @@ -55,12 +54,18 @@ impl Expression for LiteralExpression { } fn eval(&self, input: &DataChunk) -> Result { - let mut array_builder = self.return_type.create_array_builder(input.cardinality())?; + let mut array_builder = self + .return_type + .create_array_builder(input.cardinality()) + .map_err(ExprError::Array)?; let cardinality = input.cardinality(); let builder = &mut array_builder; let literal = &self.literal; for_all_variants! {array_impl_literal_append, builder, literal, cardinality} - array_builder.finish().map(Arc::new) + array_builder + .finish() + .map(Arc::new) + .map_err(ExprError::Array) } fn eval_row(&self, _input: &Row) -> Result { @@ -77,7 +82,7 @@ where A1: ArrayBuilder, { for _ in 0..cardinality { - a.append(v)? + a.append(v).map_err(ExprError::Array)? } Ok(()) } @@ -122,11 +127,11 @@ impl LiteralExpression { } impl<'a> TryFrom<&'a ExprNode> for LiteralExpression { - type Error = RwError; + type Error = ExprError; fn try_from(prost: &'a ExprNode) -> Result { ensure!(prost.expr_type == Type::ConstantValue as i32); - let ret_type = DataType::from(prost.get_return_type()?); + let ret_type = DataType::from(prost.get_return_type().unwrap()); if prost.rex_node.is_none() { return Ok(Self { return_type: ret_type, @@ -134,28 +139,26 @@ impl<'a> TryFrom<&'a ExprNode> for LiteralExpression { }); } - if let RexNode::Constant(prost_value) = prost.get_rex_node()? { + if let RexNode::Constant(prost_value) = prost.get_rex_node().unwrap() { // TODO: We need to unify these - let value = - ScalarImpl::bytes_to_scalar(prost_value.get_body(), prost.get_return_type()?)?; + let value = ScalarImpl::bytes_to_scalar( + prost_value.get_body(), + prost.get_return_type().unwrap(), + ) + .map_err(ExprError::Array)?; Ok(Self { return_type: ret_type, literal: Some(value), }) } else { - Err(RwError::from(ErrorCode::ProstError(DecodeError::new( - "Cannot parse the RexNode", - )))) + bail!("Cannot parse the RexNode"); } } } #[cfg(test)] mod tests { - use std::sync::Arc; - - use risingwave_common::array::column::Column; - use risingwave_common::array::{I32Array, PrimitiveArray, StructValue}; + use risingwave_common::array::{I32Array, StructValue}; use risingwave_common::array_nonnull; use risingwave_common::types::{Decimal, IntervalUnit, IntoOrdered}; use risingwave_pb::data::data_type::{IntervalType, TypeName}; @@ -302,12 +305,6 @@ mod tests { } } - #[allow(dead_code)] - fn create_column(vec: &[Option]) -> Result { - let array = PrimitiveArray::from_slice(vec).map(|x| Arc::new(x.into()))?; - Ok(Column::new(array)) - } - #[test] fn test_literal_eval_dummy_chunk() { let literal = LiteralExpression::new(DataType::Int32, Some(1.into())); diff --git a/src/expr/src/expr/expr_unary.rs b/src/expr/src/expr/expr_unary.rs index b1b332b58971..0f4ab5c36bc2 100644 --- a/src/expr/src/expr/expr_unary.rs +++ b/src/expr/src/expr/expr_unary.rs @@ -15,7 +15,6 @@ //! For expression that only accept one value as input (e.g. CAST) use risingwave_common::array::*; -use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::*; use risingwave_pb::expr::expr_node::Type as ProstType; @@ -37,6 +36,7 @@ use crate::vector_op::round::*; use crate::vector_op::rtrim::rtrim; use crate::vector_op::trim::trim; use crate::vector_op::upper::upper; +use crate::{ExprError, Result}; /// This macro helps to create cast expression. /// It receives all the combinations of `gen_cast` and generates corresponding match cases @@ -62,11 +62,7 @@ macro_rules! gen_cast_impl { ), )* _ => { - return Err(ErrorCode::NotImplemented(format!( - "CAST({:?} AS {:?}) not supported yet!", - $child.return_type(), $ret - ), 1632.into()) - .into()); + return Err(ExprError::Cast2($child.return_type(), $ret)); } } }; @@ -174,10 +170,7 @@ macro_rules! gen_unary_impl { ), )* _ => { - return Err(ErrorCode::NotImplemented(format!( - "{:?} is not supported on ({:?}, {:?})", $expr_name, $child.return_type(), $ret, - ), 112.into()) - .into()); + return Err(ExprError::UnsupportedFunction(format!("{}({:?}) -> {:?}", $expr_name, $child.return_type(), $ret))); } } }; @@ -327,11 +320,10 @@ pub fn new_unary_expr( gen_round_expr! {"Ceil", child_expr, return_type, round_f64, round_decimal} } (expr, ret, child) => { - return Err(ErrorCode::NotImplemented(format!( - "The expression {:?}({:?}) ->{:?} using vectorized expression framework is not supported yet.", + return Err(ExprError::UnsupportedFunction(format!( + "{:?}({:?}) -> {:?}", expr, child, ret - ), 112.into()) - .into()); + ))); } }; diff --git a/src/expr/src/expr/mod.rs b/src/expr/src/expr/mod.rs index 9e416a8ce58b..7a3870d2aa5c 100644 --- a/src/expr/src/expr/mod.rs +++ b/src/expr/src/expr/mod.rs @@ -39,16 +39,16 @@ pub use agg::AggKind; pub use expr_input_ref::InputRefExpression; pub use expr_literal::*; use risingwave_common::array::{ArrayRef, DataChunk, Row}; -use risingwave_common::error::ErrorCode::InternalError; -use risingwave_common::error::Result; use risingwave_common::types::{DataType, Datum}; use risingwave_pb::expr::ExprNode; +use super::Result; use crate::expr::build_expr_from_prost::*; use crate::expr::expr_array::ArrayExpression; use crate::expr::expr_coalesce::CoalesceExpression; use crate::expr::expr_concat_ws::ConcatWsExpression; use crate::expr::expr_field::FieldExpression; +use crate::ExprError; pub type ExpressionRef = Arc; @@ -63,6 +63,7 @@ pub trait Expression: std::fmt::Debug + Sync + Send { /// * `input` - input data of the Project Executor fn eval(&self, input: &DataChunk) -> Result; + /// Evaluate the expression in row-based execution. fn eval_row(&self, input: &Row) -> Result; fn boxed(self) -> BoxedExpression @@ -78,7 +79,7 @@ pub type BoxedExpression = Box; pub fn build_from_prost(prost: &ExprNode) -> Result { use risingwave_pb::expr::expr_node::Type::*; - match prost.get_expr_type()? { + match prost.get_expr_type().unwrap() { Cast | Upper | Lower | Md5 | Not | IsTrue | IsNotTrue | IsFalse | IsNotFalse | IsNull | IsNotNull | Neg | Ascii | Abs | Ceil | Floor | Round | BitwiseNot | CharLength => { build_unary_expr_prost(prost) @@ -107,16 +108,15 @@ pub fn build_from_prost(prost: &ExprNode) -> Result { In => build_in_expr(prost), Field => FieldExpression::try_from(prost).map(Expression::boxed), Array => ArrayExpression::try_from(prost).map(Expression::boxed), - _ => Err(InternalError(format!( - "Unsupported expression type: {:?}", + _ => Err(ExprError::UnsupportedFunction(format!( + "{:?}", prost.get_expr_type() - )) - .into()), + ))), } } -#[derive(Debug)] /// Simply wrap a row level expression as an array level expression +#[derive(Debug)] pub struct RowExpression { expr: BoxedExpression, } @@ -127,7 +127,8 @@ impl RowExpression { } pub fn eval(&mut self, row: &Row, data_types: &[DataType]) -> Result { - let input = DataChunk::from_rows(slice::from_ref(row), data_types)?; + let input = + DataChunk::from_rows(slice::from_ref(row), data_types).map_err(ExprError::Array)?; self.expr.eval(&input) } diff --git a/src/expr/src/expr/template.rs b/src/expr/src/expr/template.rs index 26570b575ec3..f048cb3180fb 100644 --- a/src/expr/src/expr/template.rs +++ b/src/expr/src/expr/template.rs @@ -23,7 +23,6 @@ use risingwave_common::array::{ Array, ArrayBuilder, ArrayBuilderImpl, ArrayImpl, ArrayRef, BytesGuard, BytesWriter, DataChunk, Row, Utf8Array, }; -use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::for_all_variants; use risingwave_common::types::{option_as_scalar_ref, DataType, Datum, Scalar, ScalarImpl}; @@ -34,22 +33,22 @@ macro_rules! array_impl_add_datum { match ($arr_builder, $datum) { $( (ArrayBuilderImpl::$variant_name(inner), Some(ScalarImpl::$variant_name(v))) => { - inner.append(Some(v.as_scalar_ref()))?; + inner.append(Some(v.as_scalar_ref())).map_err($crate::ExprError::Array)?; } (ArrayBuilderImpl::$variant_name(inner), None) => { - inner.append(None)?; + inner.append(None).map_err($crate::ExprError::Array)?; } )* - (_, _) => return Err(ErrorCode::NotImplemented( - "Do not support values in insert values executor".to_string(), None.into(), - ).into()), + (_, _) => $crate::bail!( + "Do not support values in insert values executor".to_string(), + ), } }; } macro_rules! gen_eval { { $macro:ident, $ty_name:ident, $OA:ty, $($arg:ident,)* } => { - fn eval(&self, data_chunk: &DataChunk) -> Result { + fn eval(&self, data_chunk: &DataChunk) -> $crate::Result { paste! { $( let [] = self.[].eval(data_chunk)?; @@ -57,7 +56,7 @@ macro_rules! gen_eval { )* let bitmap = data_chunk.get_visibility_ref(); - let mut output_array = <$OA as Array>::Builder::new(data_chunk.capacity())?; + let mut output_array = <$OA as Array>::Builder::new(data_chunk.capacity()).map_err($crate::ExprError::Array)?; Ok(Arc::new(match bitmap { Some(bitmap) => { for (($([], )*), visible) in multizip(($([].iter(), )*)).zip_eq(bitmap.iter()) { @@ -66,13 +65,13 @@ macro_rules! gen_eval { } $macro!(self, output_array, $([],)*) } - output_array.finish()?.into() + output_array.finish().map_err($crate::ExprError::Array)?.into() } None => { for ($([], )*) in multizip(($([].iter(), )*)) { $macro!(self, output_array, $([],)*) } - output_array.finish()?.into() + output_array.finish().map_err($crate::ExprError::Array)?.into() } })) } @@ -82,25 +81,25 @@ macro_rules! gen_eval { /// resulting datums are placed in their own arrays. The arrays are then handled in the same /// way as in `eval()`. This could be optimized to work on the datums directly /// instead of placing them in arrays. - fn eval_row(&self, row: &Row) -> Result { + fn eval_row(&self, row: &Row) -> $crate::Result { paste! { $( let [] = self.[].eval_row(row)?; - let mut [] = self.[].return_type().create_array_builder(1)?; + let mut [] = self.[].return_type().create_array_builder(1).map_err($crate::ExprError::Array)?; let [] = &mut []; for_all_variants! {array_impl_add_datum, [], []} - let [] = [].finish().map(Arc::new)?; + let [] = [].finish().map(Arc::new).map_err($crate::ExprError::Array)?; let []: &$arg = [].as_ref().into(); )* - let mut output_array = <$OA as Array>::Builder::new(1)?; + let mut output_array = <$OA as Array>::Builder::new(1).map_err($crate::ExprError::Array)?; for ($([], )*) in multizip(($([].iter(), )*)) { $macro!(self, output_array, $([],)*) } - let output_arrayimpl: ArrayImpl = output_array.finish()?.into(); + let output_arrayimpl: ArrayImpl = output_array.finish().map_err($crate::ExprError::Array)?.into(); Ok(output_arrayimpl.to_datum()) } @@ -113,9 +112,9 @@ macro_rules! eval_normal { if let ($(Some($arg), )*) = ($($arg, )*) { let ret = ($self.func)($($arg, )*)?; let output = Some(ret.as_scalar_ref()); - $output_array.append(output)?; + $output_array.append(output).map_err($crate::ExprError::Array)?; } else { - $output_array.append(None)?; + $output_array.append(None).map_err($crate::ExprError::Array)?; } } } @@ -126,7 +125,7 @@ macro_rules! gen_expr_normal { pub struct $ty_name< $($arg: Array, )* OA: Array, - F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )*) -> Result, + F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )*) -> $crate::Result, > { $([]: BoxedExpression,)* return_type: DataType, @@ -136,7 +135,7 @@ macro_rules! gen_expr_normal { impl<$($arg: Array, )* OA: Array, - F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )*) -> Result + Sized + Sync + Send, + F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )*) -> $crate::Result + Sized + Sync + Send, > fmt::Debug for $ty_name<$($arg, )* OA, F> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct(stringify!($ty_name)) @@ -149,8 +148,8 @@ macro_rules! gen_expr_normal { impl<$($arg: Array, )* OA: Array, - F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )*) -> Result + Sized + Sync + Send, - > Expression for $ty_name<$($arg, )* OA, F> + F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )*) -> $crate::Result + Sized + Sync + Send, + > Expression for $ty_name<$($arg, )* OA, F> where $(for<'a> &'a $arg: std::convert::From<&'a ArrayImpl>,)* for<'a> &'a OA: std::convert::From<&'a ArrayImpl>, @@ -164,7 +163,7 @@ macro_rules! gen_expr_normal { impl<$($arg: Array, )* OA: Array, - F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )*) -> Result + Sized + Sync + Send, + F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )*) -> $crate::Result + Sized + Sync + Send, > $ty_name<$($arg, )* OA, F> { #[allow(dead_code)] pub fn new( @@ -191,7 +190,7 @@ macro_rules! eval_bytes { let guard = ($self.func)($($arg, )* writer)?; $output_array = guard.into_inner(); } else { - $output_array.append(None)?; + $output_array.append(None).map_err($crate::ExprError::Array)?; } } } @@ -201,7 +200,7 @@ macro_rules! gen_expr_bytes { paste! { pub struct $ty_name< $($arg: Array, )* - F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )* BytesWriter) -> Result, + F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )* BytesWriter) -> $crate::Result, > { $([]: BoxedExpression,)* return_type: DataType, @@ -210,7 +209,7 @@ macro_rules! gen_expr_bytes { } impl<$($arg: Array, )* - F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )* BytesWriter) -> Result + Sized + Sync + Send, + F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )* BytesWriter) -> $crate::Result + Sized + Sync + Send, > fmt::Debug for $ty_name<$($arg, )* F> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct(stringify!($ty_name)) @@ -222,7 +221,7 @@ macro_rules! gen_expr_bytes { } impl<$($arg: Array, )* - F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )* BytesWriter) -> Result + Sized + Sync + Send, + F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )* BytesWriter) -> $crate::Result + Sized + Sync + Send, > Expression for $ty_name<$($arg, )* F> where $(for<'a> &'a $arg: std::convert::From<&'a ArrayImpl>,)* @@ -235,7 +234,7 @@ macro_rules! gen_expr_bytes { } impl<$($arg: Array, )* - F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )* BytesWriter) -> Result + Sized + Sync + Send, + F: for<$($lt),*> Fn($($arg::RefItem<$lt>, )* BytesWriter) -> $crate::Result + Sized + Sync + Send, > $ty_name<$($arg, )* F> { pub fn new( $([]: BoxedExpression, )* @@ -258,7 +257,7 @@ macro_rules! eval_nullable { ($self:ident, $output_array:ident, $($arg:ident,)*) => { { let ret = ($self.func)($($arg,)*)?; - $output_array.append(option_as_scalar_ref(&ret))?; + $output_array.append(option_as_scalar_ref(&ret)).map_err($crate::ExprError::Array)?; } } } @@ -269,7 +268,7 @@ macro_rules! gen_expr_nullable { pub struct $ty_name< $($arg: Array, )* OA: Array, - F: for<$($lt),*> Fn($(Option<$arg::RefItem<$lt>>, )*) -> Result>, + F: for<$($lt),*> Fn($(Option<$arg::RefItem<$lt>>, )*) -> $crate::Result>, > { $([]: BoxedExpression,)* return_type: DataType, @@ -279,7 +278,7 @@ macro_rules! gen_expr_nullable { impl<$($arg: Array, )* OA: Array, - F: for<$($lt),*> Fn($(Option<$arg::RefItem<$lt>>, )*) -> Result> + Sized + Sync + Send, + F: for<$($lt),*> Fn($(Option<$arg::RefItem<$lt>>, )*) -> $crate::Result> + Sized + Sync + Send, > fmt::Debug for $ty_name<$($arg, )* OA, F> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct(stringify!($ty_name)) @@ -292,7 +291,7 @@ macro_rules! gen_expr_nullable { impl<$($arg: Array, )* OA: Array, - F: for<$($lt),*> Fn($(Option<$arg::RefItem<$lt>>, )*) -> Result> + Sized + Sync + Send, + F: for<$($lt),*> Fn($(Option<$arg::RefItem<$lt>>, )*) -> $crate::Result> + Sized + Sync + Send, > Expression for $ty_name<$($arg, )* OA, F> where $(for<'a> &'a $arg: std::convert::From<&'a ArrayImpl>,)* @@ -307,7 +306,7 @@ macro_rules! gen_expr_nullable { impl<$($arg: Array, )* OA: Array, - F: for<$($lt),*> Fn($(Option<$arg::RefItem<$lt>>, )*) -> Result> + Sized + Sync + Send, + F: for<$($lt),*> Fn($(Option<$arg::RefItem<$lt>>, )*) -> $crate::Result> + Sized + Sync + Send, > $ty_name<$($arg, )* OA, F> { // Compile failed due to some GAT lifetime issues so make this field private. // Check issues #742. diff --git a/src/expr/src/lib.rs b/src/expr/src/lib.rs index 33993610ae40..31776b86bd52 100644 --- a/src/expr/src/lib.rs +++ b/src/expr/src/lib.rs @@ -31,6 +31,11 @@ #![feature(backtrace)] #![feature(fn_traits)] #![feature(assert_matches)] +#![feature(let_else)] +pub mod error; pub mod expr; pub mod vector_op; + +pub use error::ExprError; +pub type Result = std::result::Result; diff --git a/src/expr/src/vector_op/agg/general_agg.rs b/src/expr/src/vector_op/agg/general_agg.rs index 49acf861f543..5fa08d7828e3 100644 --- a/src/expr/src/vector_op/agg/general_agg.rs +++ b/src/expr/src/vector_op/agg/general_agg.rs @@ -256,7 +256,7 @@ mod tests { Arc::new(input.into()), &agg_type, return_type, - ArrayBuilderImpl::Int32(I32ArrayBuilder::new(0)?), + ArrayBuilderImpl::Int32(I32ArrayBuilder::new(0).unwrap()), ); if !result.is_empty() { let actual = actual?; @@ -295,7 +295,7 @@ mod tests { Arc::new(input.into()), &agg_type, return_type, - ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0)?), + ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0).unwrap()), )?; let actual = actual.as_int64(); let actual = actual.iter().collect::>(); @@ -314,7 +314,7 @@ mod tests { Arc::new(input.into()), &agg_type, return_type, - DecimalArrayBuilder::new(0)?.into(), + DecimalArrayBuilder::new(0).unwrap().into(), )?; let actual: &DecimalArray = (&actual).into(); let actual = actual.iter().collect::>>(); @@ -334,7 +334,7 @@ mod tests { Arc::new(input.into()), &agg_type, return_type, - ArrayBuilderImpl::Float32(F32ArrayBuilder::new(0)?), + ArrayBuilderImpl::Float32(F32ArrayBuilder::new(0).unwrap()), )?; let actual = actual.as_float32(); let actual = actual.iter().collect::>(); @@ -353,7 +353,7 @@ mod tests { Arc::new(input.into()), &agg_type, return_type, - ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)?), + ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0).unwrap()), )?; let actual = actual.as_utf8(); let actual = actual.iter().collect::>(); @@ -372,7 +372,7 @@ mod tests { Arc::new(input.into()), &agg_type, return_type, - ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)?), + ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0).unwrap()), )?; let actual = actual.as_utf8(); let actual = actual.iter().collect::>(); @@ -391,7 +391,7 @@ mod tests { Arc::new(input), &agg_type, return_type, - ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0)?), + ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0).unwrap()), )?; let actual = actual.as_int64(); let actual = actual.iter().collect::>(); diff --git a/src/expr/src/vector_op/agg/general_distinct_agg.rs b/src/expr/src/vector_op/agg/general_distinct_agg.rs index e9b1463a931a..9eb2b1451198 100644 --- a/src/expr/src/vector_op/agg/general_distinct_agg.rs +++ b/src/expr/src/vector_op/agg/general_distinct_agg.rs @@ -271,7 +271,7 @@ mod tests { Arc::new(input.into()), &agg_type, return_type, - ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0)?), + ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0).unwrap()), )?; let actual = actual.as_int64(); let actual = actual.iter().collect::>(); @@ -290,7 +290,7 @@ mod tests { Arc::new(input.into()), &agg_type, return_type, - DecimalArrayBuilder::new(0)?.into(), + DecimalArrayBuilder::new(0).unwrap().into(), )?; let actual: &DecimalArray = (&actual).into(); let actual = actual.iter().collect::>>(); @@ -310,7 +310,7 @@ mod tests { Arc::new(input.into()), &agg_type, return_type, - ArrayBuilderImpl::Float32(F32ArrayBuilder::new(0)?), + ArrayBuilderImpl::Float32(F32ArrayBuilder::new(0).unwrap()), )?; let actual = actual.as_float32(); let actual = actual.iter().collect::>(); @@ -329,7 +329,7 @@ mod tests { Arc::new(input.into()), &agg_type, return_type, - ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)?), + ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0).unwrap()), )?; let actual = actual.as_utf8(); let actual = actual.iter().collect::>(); @@ -348,7 +348,7 @@ mod tests { Arc::new(input.into()), &agg_type, return_type, - ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)?), + ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0).unwrap()), )?; let actual = actual.as_utf8(); let actual = actual.iter().collect::>(); @@ -367,7 +367,7 @@ mod tests { Arc::new(input), &agg_type, return_type, - ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0)?), + ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0).unwrap()), )?; let actual = actual.as_int64(); let actual = actual.iter().collect::>(); diff --git a/src/expr/src/vector_op/agg/general_sorted_grouper.rs b/src/expr/src/vector_op/agg/general_sorted_grouper.rs index 24591f99b8bd..e54a624bc502 100644 --- a/src/expr/src/vector_op/agg/general_sorted_grouper.rs +++ b/src/expr/src/vector_op/agg/general_sorted_grouper.rs @@ -320,7 +320,7 @@ mod tests { ongoing: false, group_value: None, }; - let mut builder = I32ArrayBuilder::new(0)?; + let mut builder = I32ArrayBuilder::new(0).unwrap(); let input = I32Array::from_slice(&[Some(1), Some(1), Some(3)]).unwrap(); let eq = g.detect_groups_concrete(&input)?; @@ -334,7 +334,7 @@ mod tests { g.output_concrete(&mut builder)?; assert_eq!( - builder.finish()?.iter().collect::>(), + builder.finish().unwrap().iter().collect::>(), vec![Some(1), Some(3), Some(4)] ); Ok(()) @@ -352,11 +352,11 @@ mod tests { #[test] fn vec_agg_group() -> Result<()> { let mut g0 = GeneralSortedGrouper::::new(false, None); - let mut g0_builder = I32ArrayBuilder::new(0)?; + let mut g0_builder = I32ArrayBuilder::new(0).unwrap(); let mut g1 = GeneralSortedGrouper::::new(false, None); - let mut g1_builder = I32ArrayBuilder::new(0)?; + let mut g1_builder = I32ArrayBuilder::new(0).unwrap(); let mut a = GeneralAgg::::new(DataType::Int64, 0, sum, None); - let mut a_builder = I64ArrayBuilder::new(0)?; + let mut a_builder = I64ArrayBuilder::new(0).unwrap(); let g0_input = I32Array::from_slice(&[Some(1), Some(1), Some(3)]).unwrap(); let eq0 = g0.detect_groups_concrete(&g0_input)?; @@ -382,15 +382,15 @@ mod tests { g1.output_concrete(&mut g1_builder)?; a.output_concrete(&mut a_builder)?; assert_eq!( - g0_builder.finish()?.iter().collect::>(), + g0_builder.finish().unwrap().iter().collect::>(), vec![Some(1), Some(1), Some(3), Some(4)] ); assert_eq!( - g1_builder.finish()?.iter().collect::>(), + g1_builder.finish().unwrap().iter().collect::>(), vec![Some(7), Some(8), Some(8), Some(8)] ); assert_eq!( - a_builder.finish()?.iter().collect::>(), + a_builder.finish().unwrap().iter().collect::>(), vec![Some(1), Some(2), Some(4), Some(5)] ); Ok(()) diff --git a/src/expr/src/vector_op/arithmetic_op.rs b/src/expr/src/vector_op/arithmetic_op.rs index 99a89219ceb2..3ba9eef710cb 100644 --- a/src/expr/src/vector_op/arithmetic_op.rs +++ b/src/expr/src/vector_op/arithmetic_op.rs @@ -17,14 +17,13 @@ use std::convert::TryInto; use std::fmt::Debug; use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Signed}; -use risingwave_common::error::ErrorCode::{InternalError, NumericValueOutOfRange}; -use risingwave_common::error::{Result, RwError}; use risingwave_common::types::{ CheckedAdd as NaiveDateTimeCheckedAdd, Decimal, IntervalUnit, NaiveDateTimeWrapper, NaiveDateWrapper, }; use super::cast::date_to_timestamp; +use crate::{ExprError, Result}; #[inline(always)] pub fn general_add(l: T1, r: T2) -> Result @@ -34,8 +33,7 @@ where T3: CheckedAdd, { general_atm(l, r, |a, b| { - a.checked_add(&b) - .ok_or_else(|| RwError::from(NumericValueOutOfRange)) + a.checked_add(&b).ok_or(ExprError::NumericOutOfRange) }) } @@ -47,8 +45,7 @@ where T3: CheckedSub, { general_atm(l, r, |a, b| { - a.checked_sub(&b) - .ok_or_else(|| RwError::from(NumericValueOutOfRange)) + a.checked_sub(&b).ok_or(ExprError::NumericOutOfRange) }) } @@ -60,8 +57,7 @@ where T3: CheckedMul, { general_atm(l, r, |a, b| { - a.checked_mul(&b) - .ok_or_else(|| RwError::from(NumericValueOutOfRange)) + a.checked_mul(&b).ok_or(ExprError::NumericOutOfRange) }) } @@ -73,8 +69,7 @@ where T3: CheckedDiv, { general_atm(l, r, |a, b| { - a.checked_div(&b) - .ok_or_else(|| RwError::from(NumericValueOutOfRange)) + a.checked_div(&b).ok_or(ExprError::NumericOutOfRange) }) } @@ -86,15 +81,13 @@ where T3: CheckedRem, { general_atm(l, r, |a, b| { - a.checked_rem(&b) - .ok_or_else(|| RwError::from(NumericValueOutOfRange)) + a.checked_rem(&b).ok_or(ExprError::NumericOutOfRange) }) } #[inline(always)] pub fn general_neg(expr: T1) -> Result { - expr.checked_neg() - .ok_or_else(|| RwError::from(NumericValueOutOfRange)) + expr.checked_neg().ok_or(ExprError::NumericOutOfRange) } #[inline(always)] @@ -118,20 +111,12 @@ where F: FnOnce(T3, T3) -> Result, { // TODO: We need to improve the error message - let l: T3 = l.try_into().map_err(|_| { - RwError::from(InternalError(format!( - "Can't convert {} to {}", - type_name::(), - type_name::() - ))) - })?; - let r: T3 = r.try_into().map_err(|_| { - RwError::from(InternalError(format!( - "Can't convert {} to {}", - type_name::(), - type_name::() - ))) - })?; + let l: T3 = l + .try_into() + .map_err(|_| ExprError::Cast(type_name::(), type_name::()))?; + let r: T3 = r + .try_into() + .map_err(|_| ExprError::Cast(type_name::(), type_name::()))?; atm(l, r) } @@ -154,7 +139,7 @@ pub fn interval_timestamp_add( l: IntervalUnit, r: NaiveDateTimeWrapper, ) -> Result { - r.checked_add(l) + r.checked_add(l).map_err(ExprError::Array) } #[inline(always)] @@ -202,8 +187,7 @@ pub fn interval_int_mul(l: IntervalUnit, r: T2) -> Result + Debug, { - l.checked_mul_int(r) - .ok_or_else(|| NumericValueOutOfRange.into()) + l.checked_mul_int(r).ok_or(ExprError::NumericOutOfRange) } #[inline(always)] diff --git a/src/expr/src/vector_op/array_access.rs b/src/expr/src/vector_op/array_access.rs index 5ef8f9d417c2..0e50644bb9bd 100644 --- a/src/expr/src/vector_op/array_access.rs +++ b/src/expr/src/vector_op/array_access.rs @@ -13,17 +13,18 @@ // limitations under the License. use risingwave_common::array::ListRef; -use risingwave_common::error::Result; use risingwave_common::types::{Scalar, ToOwnedDatum}; +use crate::{ExprError, Result}; + #[inline(always)] pub fn array_access(l: Option, r: Option) -> Result> { match (l, r) { // index must be greater than 0 following a one-based numbering convention for arrays (Some(list), Some(index)) if index > 0 => { - let datumref = list.value_at(index as usize)?; + let datumref = list.value_at(index as usize).map_err(ExprError::Array)?; if let Some(scalar) = datumref.to_owned_datum() { - Ok(Some(scalar.try_into()?)) + Ok(Some(scalar.try_into().map_err(ExprError::Array)?)) } else { Ok(None) } @@ -49,10 +50,10 @@ mod tests { ]); let l1 = ListRef::ValueRef { val: &v1 }; - assert_eq!(array_access::(Some(l1), Some(1)), Ok(Some(1))); - assert_eq!(array_access::(Some(l1), Some(-1)), Ok(None)); - assert_eq!(array_access::(Some(l1), Some(0)), Ok(None)); - assert_eq!(array_access::(Some(l1), Some(4)), Ok(None)); + assert_eq!(array_access::(Some(l1), Some(1)).unwrap(), Some(1)); + assert_eq!(array_access::(Some(l1), Some(-1)).unwrap(), None); + assert_eq!(array_access::(Some(l1), Some(0)).unwrap(), None); + assert_eq!(array_access::(Some(l1), Some(4)).unwrap(), None); } #[test] @@ -74,16 +75,16 @@ mod tests { let l3 = ListRef::ValueRef { val: &v3 }; assert_eq!( - array_access::(Some(l1), Some(1)), - Ok(Some("来自".into())) + array_access::(Some(l1), Some(1)).unwrap(), + Some("来自".into()) ); assert_eq!( - array_access::(Some(l2), Some(2)), - Ok(Some("荷兰".into())) + array_access::(Some(l2), Some(2)).unwrap(), + Some("荷兰".into()) ); assert_eq!( - array_access::(Some(l3), Some(3)), - Ok(Some("的爱".into())) + array_access::(Some(l3), Some(3)).unwrap(), + Some("的爱".into()) ); } @@ -101,11 +102,11 @@ mod tests { ]); let l = ListRef::ValueRef { val: &v }; assert_eq!( - array_access::(Some(l), Some(1)), - Ok(Some(ListValue::new(vec![ + array_access::(Some(l), Some(1)).unwrap(), + Some(ListValue::new(vec![ Some(ScalarImpl::Utf8("foo".into())), Some(ScalarImpl::Utf8("bar".into())), - ]))) + ])) ); } } diff --git a/src/expr/src/vector_op/ascii.rs b/src/expr/src/vector_op/ascii.rs index 5ccf80cc51a6..0156325399b5 100644 --- a/src/expr/src/vector_op/ascii.rs +++ b/src/expr/src/vector_op/ascii.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::error::Result; +use crate::Result; #[inline(always)] pub fn ascii(s: &str) -> Result { @@ -22,11 +22,12 @@ pub fn ascii(s: &str) -> Result { #[cfg(test)] mod tests { use super::*; + #[test] fn test_ascii() { let cases = [("hello", 104), ("你好", 228), ("", 0)]; for (s, expected) in cases { - assert_eq!(ascii(s), Ok(expected)) + assert_eq!(ascii(s).unwrap(), expected) } } } diff --git a/src/expr/src/vector_op/bitwise_op.rs b/src/expr/src/vector_op/bitwise_op.rs index 053100039d47..6abebf1e34ab 100644 --- a/src/expr/src/vector_op/bitwise_op.rs +++ b/src/expr/src/vector_op/bitwise_op.rs @@ -17,10 +17,9 @@ use std::fmt::Debug; use std::ops::{BitAnd, BitOr, BitXor, Not}; use num_traits::{CheckedShl, CheckedShr}; -use risingwave_common::error::ErrorCode::{InternalError, NumericValueOutOfRange}; -use risingwave_common::error::{Result, RwError}; use crate::vector_op::arithmetic_op::general_atm; +use crate::{ExprError, Result}; // Conscious decision for shl and shr is made here to diverge from PostgreSQL. // If overflow happens, instead of truncated to zero, we return overflow error as this is @@ -34,8 +33,7 @@ where T2: TryInto + Debug, { general_shift(l, r, |a, b| { - a.checked_shl(b) - .ok_or_else(|| RwError::from(NumericValueOutOfRange)) + a.checked_shl(b).ok_or(ExprError::NumericOutOfRange) }) } @@ -46,8 +44,7 @@ where T2: TryInto + Debug, { general_shift(l, r, |a, b| { - a.checked_shr(b) - .ok_or_else(|| RwError::from(NumericValueOutOfRange)) + a.checked_shr(b).ok_or(ExprError::NumericOutOfRange) }) } @@ -59,13 +56,9 @@ where F: FnOnce(T1, u32) -> Result, { // TODO: We need to improve the error message - let r: u32 = r.try_into().map_err(|_| { - RwError::from(InternalError(format!( - "Can't convert {} to {}", - type_name::(), - type_name::() - ))) - })?; + let r: u32 = r + .try_into() + .map_err(|_| ExprError::Cast(type_name::(), type_name::()))?; atm(l, r) } diff --git a/src/expr/src/vector_op/cast.rs b/src/expr/src/vector_op/cast.rs index 6320cac42079..f1644c5a4da5 100644 --- a/src/expr/src/vector_op/cast.rs +++ b/src/expr/src/vector_op/cast.rs @@ -12,18 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -use core::convert::From; use std::any::type_name; use std::str::FromStr; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime}; use num_traits::ToPrimitive; -use risingwave_common::error::ErrorCode::{InternalError, InvalidInputSyntax}; -use risingwave_common::error::{parse_error, Result, RwError}; use risingwave_common::types::{ Decimal, NaiveDateTimeWrapper, NaiveDateWrapper, NaiveTimeWrapper, OrderedF32, OrderedF64, }; +use crate::{ExprError, Result}; + /// String literals for bool type. /// /// See [`https://www.postgresql.org/docs/9.5/datatype-boolean.html`] @@ -61,7 +60,7 @@ pub fn str_to_str(n: &str) -> Result { pub fn str_to_date(elem: &str) -> Result { Ok(NaiveDateWrapper::new( NaiveDate::parse_from_str(elem, "%Y-%m-%d") - .map_err(|_| parse_error(PARSE_ERROR_STR_TO_DATE))?, + .map_err(|_| ExprError::Parse(PARSE_ERROR_STR_TO_DATE))?, )) } @@ -73,7 +72,7 @@ pub fn str_to_time(elem: &str) -> Result { if let Ok(time) = NaiveTime::parse_from_str(elem, "%H:%M") { return Ok(NaiveTimeWrapper::new(time)); } - Err(parse_error(PARSE_ERROR_STR_TO_TIME)) + Err(ExprError::Parse(PARSE_ERROR_STR_TO_TIME)) } #[inline(always)] @@ -87,18 +86,14 @@ pub fn str_to_timestamp(elem: &str) -> Result { if let Ok(date) = NaiveDate::parse_from_str(elem, "%Y-%m-%d") { return Ok(NaiveDateTimeWrapper::new(date.and_hms(0, 0, 0))); } - Err(parse_error(PARSE_ERROR_STR_TO_TIMESTAMP)) + Err(ExprError::Parse(PARSE_ERROR_STR_TO_TIMESTAMP)) } #[inline(always)] pub fn str_to_timestampz(elem: &str) -> Result { DateTime::parse_from_str(elem, "%Y-%m-%d %H:%M:%S %:z") .map(|ret| ret.timestamp_nanos() / 1000) - .map_err(|_| { - parse_error( - "Can't cast string to timestamp (expected format is YYYY-MM-DD HH:MM:SS[.MS])", - ) - }) + .map_err(|_| ExprError::Parse(PARSE_ERROR_STR_TO_TIMESTAMP)) } #[inline(always)] @@ -107,14 +102,8 @@ where T: FromStr, ::Err: std::fmt::Display, { - elem.parse().map_err(|e| { - RwError::from(InternalError(format!( - "Can't cast {:?} to {:?}: {}", - elem, - type_name::(), - e - ))) - }) + elem.parse() + .map_err(|_| ExprError::Cast(type_name::(), type_name::())) } #[inline(always)] @@ -144,11 +133,10 @@ macro_rules! define_cast_to_primitive { { elem.[]() .ok_or_else(|| { - RwError::from(InternalError(format!( - "Can't cast {:?} to {}", - elem, + ExprError::Cast( + std::any::type_name::(), std::any::type_name::<$ty>() - ))) + ) }) .map(Into::into) } @@ -185,14 +173,8 @@ where T1: TryInto + std::fmt::Debug + Copy, >::Error: std::fmt::Display, { - elem.try_into().map_err(|e| { - RwError::from(InternalError(format!( - "Can't cast {:?} to {:?}: {}", - &elem, - type_name::(), - e - ))) - }) + elem.try_into() + .map_err(|_| ExprError::Cast(std::any::type_name::(), std::any::type_name::())) } #[inline(always)] @@ -209,7 +191,7 @@ pub fn str_to_bool(input: &str) -> Result { { Ok(false) } else { - Err(InvalidInputSyntax(format!("'{}' is not a valid bool", input)).into()) + Err(ExprError::Parse("Invalid bool")) } } @@ -238,7 +220,7 @@ mod tests { str_to_timestamp("1999-01-08 04:05:06AA") .unwrap_err() .to_string(), - parse_error(PARSE_ERROR_STR_TO_TIMESTAMP).to_string() + ExprError::Parse(PARSE_ERROR_STR_TO_TIMESTAMP).to_string() ); assert_eq!( str_to_date("1999-01-08AA").unwrap_err().to_string(), @@ -246,7 +228,7 @@ mod tests { ); assert_eq!( str_to_time("AA04:05:06").unwrap_err().to_string(), - parse_error(PARSE_ERROR_STR_TO_TIME).to_string() + ExprError::Parse(PARSE_ERROR_STR_TO_TIME).to_string() ); } diff --git a/src/expr/src/vector_op/cmp.rs b/src/expr/src/vector_op/cmp.rs index a51082cf10ea..fe56bcc51102 100644 --- a/src/expr/src/vector_op/cmp.rs +++ b/src/expr/src/vector_op/cmp.rs @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use core::convert::From; use std::any::type_name; use std::fmt::Debug; use risingwave_common::array::{ListRef, StructRef}; -use risingwave_common::error::ErrorCode::InternalError; -use risingwave_common::error::{Result, RwError}; + +use crate::{ExprError, Result}; fn general_cmp(l: T1, r: T2, cmp: F) -> Result where @@ -28,20 +27,12 @@ where F: FnOnce(T3, T3) -> bool, { // TODO: We need to improve the error message - let l: T3 = l.try_into().map_err(|_| { - RwError::from(InternalError(format!( - "Can't convert {} to {}", - type_name::(), - type_name::() - ))) - })?; - let r: T3 = r.try_into().map_err(|_| { - RwError::from(InternalError(format!( - "Can't convert {} to {}", - type_name::(), - type_name::() - ))) - })?; + let l: T3 = l + .try_into() + .map_err(|_| ExprError::Cast(type_name::(), type_name::()))?; + let r: T3 = r + .try_into() + .map_err(|_| ExprError::Cast(type_name::(), type_name::()))?; Ok(cmp(l, r)) } diff --git a/src/expr/src/vector_op/conjunction.rs b/src/expr/src/vector_op/conjunction.rs index 1df1070415b9..58cd8096016f 100644 --- a/src/expr/src/vector_op/conjunction.rs +++ b/src/expr/src/vector_op/conjunction.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::error::Result; +use crate::Result; #[inline(always)] pub fn and(l: Option, r: Option) -> Result> { diff --git a/src/expr/src/vector_op/extract.rs b/src/expr/src/vector_op/extract.rs index 2ead0911e622..758b9a158987 100644 --- a/src/expr/src/vector_op/extract.rs +++ b/src/expr/src/vector_op/extract.rs @@ -13,10 +13,10 @@ // limitations under the License. use chrono::{Datelike, Timelike}; -use risingwave_common::error::ErrorCode::InternalError; -use risingwave_common::error::{Result, RwError}; use risingwave_common::types::{Decimal, NaiveDateTimeWrapper, NaiveDateWrapper}; +use crate::{bail, Result}; + fn extract_time(time: T, time_unit: &str) -> Result where T: Timelike, @@ -25,10 +25,7 @@ where "HOUR" => Ok(time.hour().into()), "MINUTE" => Ok(time.minute().into()), "SECOND" => Ok(time.second().into()), - _ => Err(RwError::from(InternalError(format!( - "Unsupported time unit {} in extract function", - time_unit - )))), + _ => bail!("Unsupported time unit {} in extract function", time_unit), } } @@ -43,10 +40,7 @@ where // Sun = 0 and Sat = 6 "DOW" => Ok(date.weekday().num_days_from_sunday().into()), "DOY" => Ok(date.ordinal().into()), - _ => Err(RwError::from(InternalError(format!( - "Unsupported time unit {} in extract function", - time_unit - )))), + _ => bail!("Unsupported time unit {} in extract function", time_unit), } } diff --git a/src/expr/src/vector_op/length.rs b/src/expr/src/vector_op/length.rs index 0830daa561cc..1d600dbbcb90 100644 --- a/src/expr/src/vector_op/length.rs +++ b/src/expr/src/vector_op/length.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::error::Result; +use crate::Result; #[inline(always)] pub fn length_default(s: &str) -> Result { @@ -26,10 +26,10 @@ mod tests { #[test] fn test_length() { - let cases = [("hello world", Ok(11)), ("hello rust", Ok(10))]; + let cases = [("hello world", 11), ("hello rust", 10)]; for (s, expected) in cases { - assert_eq!(length_default(s), expected) + assert_eq!(length_default(s).unwrap(), expected) } } } diff --git a/src/expr/src/vector_op/like.rs b/src/expr/src/vector_op/like.rs index af0667461db9..be3829e18ad4 100644 --- a/src/expr/src/vector_op/like.rs +++ b/src/expr/src/vector_op/like.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::error::Result; +use crate::Result; #[inline(always)] pub fn like_default(s: &str, p: &str) -> Result { diff --git a/src/expr/src/vector_op/lower.rs b/src/expr/src/vector_op/lower.rs index 794b586e0d84..cfe28e79789d 100644 --- a/src/expr/src/vector_op/lower.rs +++ b/src/expr/src/vector_op/lower.rs @@ -13,11 +13,14 @@ // limitations under the License. use risingwave_common::array::{BytesGuard, BytesWriter}; -use risingwave_common::error::Result; + +use crate::{ExprError, Result}; #[inline(always)] pub fn lower(s: &str, writer: BytesWriter) -> Result { - writer.write_ref(&s.to_lowercase()) + writer + .write_ref(&s.to_lowercase()) + .map_err(ExprError::Array) } #[cfg(test)] @@ -35,10 +38,10 @@ mod tests { ]; for (s, expected) in cases { - let builder = Utf8ArrayBuilder::new(1)?; + let builder = Utf8ArrayBuilder::new(1).unwrap(); let writer = builder.writer(); let guard = lower(s, writer)?; - let array = guard.into_inner().finish()?; + let array = guard.into_inner().finish().unwrap(); let v = array.value_at(0).unwrap(); assert_eq!(v, expected); } diff --git a/src/expr/src/vector_op/ltrim.rs b/src/expr/src/vector_op/ltrim.rs index 29ac5e27c1fb..c7a5dacd58a5 100644 --- a/src/expr/src/vector_op/ltrim.rs +++ b/src/expr/src/vector_op/ltrim.rs @@ -13,14 +13,15 @@ // limitations under the License. use risingwave_common::array::{BytesGuard, BytesWriter}; -use risingwave_common::error::Result; + +use crate::{ExprError, Result}; /// Note: the behavior of `ltrim` in `PostgreSQL` and `trim_start` (or `trim_left`) in Rust /// are actually different when the string is in right-to-left languages like Arabic or Hebrew. /// Since we would like to simplify the implementation, currently we omit this case. #[inline(always)] pub fn ltrim(s: &str, writer: BytesWriter) -> Result { - writer.write_ref(s.trim_start()) + writer.write_ref(s.trim_start()).map_err(ExprError::Array) } #[cfg(test)] @@ -37,10 +38,10 @@ mod tests { ]; for (s, expected) in cases { - let builder = Utf8ArrayBuilder::new(1)?; + let builder = Utf8ArrayBuilder::new(1).unwrap(); let writer = builder.writer(); let guard = ltrim(s, writer)?; - let array = guard.into_inner().finish()?; + let array = guard.into_inner().finish().unwrap(); let v = array.value_at(0).unwrap(); assert_eq!(v, expected); } diff --git a/src/expr/src/vector_op/md5.rs b/src/expr/src/vector_op/md5.rs index 4876fc52b1d6..207cf030e45b 100644 --- a/src/expr/src/vector_op/md5.rs +++ b/src/expr/src/vector_op/md5.rs @@ -14,11 +14,14 @@ use md5 as lib_md5; use risingwave_common::array::{BytesGuard, BytesWriter}; -use risingwave_common::error::Result; + +use crate::{ExprError, Result}; #[inline(always)] pub fn md5(s: &str, writer: BytesWriter) -> Result { - writer.write_ref(&format!("{:x}", lib_md5::compute(s))) + writer + .write_ref(&format!("{:x}", lib_md5::compute(s))) + .map_err(ExprError::Array) } #[cfg(test)] @@ -39,10 +42,10 @@ mod tests { ]; for (s, expected) in cases { - let builder = Utf8ArrayBuilder::new(1)?; + let builder = Utf8ArrayBuilder::new(1).unwrap(); let writer = builder.writer(); let guard = md5(s, writer)?; - let array = guard.into_inner().finish()?; + let array = guard.into_inner().finish().unwrap(); let v = array.value_at(0).unwrap(); assert_eq!(v, expected); } diff --git a/src/expr/src/vector_op/position.rs b/src/expr/src/vector_op/position.rs index c261c6bff243..aa7c20f7d200 100644 --- a/src/expr/src/vector_op/position.rs +++ b/src/expr/src/vector_op/position.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::error::Result; +use crate::Result; #[inline(always)] /// Location of specified substring @@ -33,14 +33,13 @@ mod tests { #[test] fn test_length() { let cases = [ - ("hello world", "world", Ok(7)), - ("床前明月光", "月光", Ok(4)), - ("床前明月光", "故乡", Ok(0)), + ("hello world", "world", 7), + ("床前明月光", "月光", 4), + ("床前明月光", "故乡", 0), ]; for (str, sub_str, expected) in cases { - println!("position is {}", position(str, sub_str).unwrap()); - assert_eq!(position(str, sub_str), expected) + assert_eq!(position(str, sub_str).unwrap(), expected) } } } diff --git a/src/expr/src/vector_op/replace.rs b/src/expr/src/vector_op/replace.rs index 717f54b09e3b..92be28a7b0de 100644 --- a/src/expr/src/vector_op/replace.rs +++ b/src/expr/src/vector_op/replace.rs @@ -13,23 +13,26 @@ // limitations under the License. use risingwave_common::array::{BytesGuard, BytesWriter}; -use risingwave_common::error::Result; + +use crate::{ExprError, Result}; #[inline(always)] pub fn replace(s: &str, from_str: &str, to_str: &str, writer: BytesWriter) -> Result { if from_str.is_empty() { - return writer.write_ref(s); + return writer.write_ref(s).map_err(ExprError::Array); } let mut last = 0; let mut writer = writer.begin(); while let Some(mut start) = s[last..].find(from_str) { start += last; - writer.write_ref(&s[last..start])?; - writer.write_ref(to_str)?; + writer + .write_ref(&s[last..start]) + .map_err(ExprError::Array)?; + writer.write_ref(to_str).map_err(ExprError::Array)?; last = start + from_str.len(); } - writer.write_ref(&s[last..])?; - writer.finish() + writer.write_ref(&s[last..]).map_err(ExprError::Array)?; + writer.finish().map_err(ExprError::Array) } #[cfg(test)] diff --git a/src/expr/src/vector_op/round.rs b/src/expr/src/vector_op/round.rs index fdc44865ce4b..bb959af4efa5 100644 --- a/src/expr/src/vector_op/round.rs +++ b/src/expr/src/vector_op/round.rs @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::error::Result; use risingwave_common::types::{Decimal, OrderedF64}; +use crate::Result; + #[inline(always)] pub fn round_digits>(input: Decimal, digits: D) -> Result { let digits = digits.into(); diff --git a/src/expr/src/vector_op/rtrim.rs b/src/expr/src/vector_op/rtrim.rs index 23936bcbc11a..9b746339092c 100644 --- a/src/expr/src/vector_op/rtrim.rs +++ b/src/expr/src/vector_op/rtrim.rs @@ -13,14 +13,15 @@ // limitations under the License. use risingwave_common::array::{BytesGuard, BytesWriter}; -use risingwave_common::error::Result; + +use crate::{ExprError, Result}; /// Note: the behavior of `rtrim` in `PostgreSQL` and `trim_end` (or `trim_right`) in Rust /// are actually different when the string is in right-to-left languages like Arabic or Hebrew. /// Since we would like to simplify the implementation, currently we omit this case. #[inline(always)] pub fn rtrim(s: &str, writer: BytesWriter) -> Result { - writer.write_ref(s.trim_end()) + writer.write_ref(s.trim_end()).map_err(ExprError::Array) } #[cfg(test)] @@ -37,10 +38,10 @@ mod tests { ]; for (s, expected) in cases { - let builder = Utf8ArrayBuilder::new(1)?; + let builder = Utf8ArrayBuilder::new(1).unwrap(); let writer = builder.writer(); let guard = rtrim(s, writer)?; - let array = guard.into_inner().finish()?; + let array = guard.into_inner().finish().unwrap(); let v = array.value_at(0).unwrap(); assert_eq!(v, expected); } diff --git a/src/expr/src/vector_op/split_part.rs b/src/expr/src/vector_op/split_part.rs index c88aecd0e6ca..cbb02488bafc 100644 --- a/src/expr/src/vector_op/split_part.rs +++ b/src/expr/src/vector_op/split_part.rs @@ -13,7 +13,8 @@ // limitations under the License. use risingwave_common::array::{BytesGuard, BytesWriter}; -use risingwave_common::error::{ErrorCode, Result, RwError}; + +use crate::{ExprError, Result}; #[inline(always)] pub fn split_part( @@ -23,9 +24,10 @@ pub fn split_part( writer: BytesWriter, ) -> Result { if nth_expr == 0 { - return Err(RwError::from(ErrorCode::InvalidParameterValue( - "field position must not be zero".into(), - ))); + return Err(ExprError::InvalidParam { + name: "data", + reason: "can't be zero".to_string(), + }); }; let mut split = string_expr.split(delimiter_expr); @@ -42,12 +44,7 @@ pub fn split_part( } } else { match nth_expr.cmp(&0) { - std::cmp::Ordering::Equal => { - return Err(RwError::from(ErrorCode::InternalError( - "Impossible happened, field position must not be zero already had been checked." - .into(), - ))); - } + std::cmp::Ordering::Equal => unreachable!(), // Since `nth_expr` can not be 0, so the `abs()` of it can not be smaller than 1 // (that's `abs(1)` or `abs(-1)`). Hence the result of sub 1 can not be less than 0. @@ -65,48 +62,40 @@ pub fn split_part( } }; - writer.write_ref(nth_val) + writer.write_ref(nth_val).map_err(ExprError::Array) } #[cfg(test)] mod tests { use risingwave_common::array::{Array, ArrayBuilder, Utf8ArrayBuilder}; - use risingwave_common::error::{ErrorCode, Result, RwError}; use super::split_part; #[test] fn test_split_part() { - let cases: Vec<(&str, &str, i32, Result<&str>)> = vec![ + let cases: Vec<(&str, &str, i32, Option<&str>)> = vec![ // postgres cases - ("", "@", 1, Ok("")), - ("", "@", -1, Ok("")), - ("joeuser@mydatabase", "", 1, Ok("joeuser@mydatabase")), - ("joeuser@mydatabase", "", 2, Ok("")), - ("joeuser@mydatabase", "", -1, Ok("joeuser@mydatabase")), - ("joeuser@mydatabase", "", -2, Ok("")), - ( - "joeuser@mydatabase", - "@", - 0, - Err(RwError::from(ErrorCode::InvalidParameterValue( - "field position must not be zero".into(), - ))), - ), - ("joeuser@mydatabase", "@@", 1, Ok("joeuser@mydatabase")), - ("joeuser@mydatabase", "@@", 2, Ok("")), - ("joeuser@mydatabase", "@", 1, Ok("joeuser")), - ("joeuser@mydatabase", "@", 2, Ok("mydatabase")), - ("joeuser@mydatabase", "@", 3, Ok("")), - ("@joeuser@mydatabase@", "@", 2, Ok("joeuser")), - ("joeuser@mydatabase", "@", -1, Ok("mydatabase")), - ("joeuser@mydatabase", "@", -2, Ok("joeuser")), - ("joeuser@mydatabase", "@", -3, Ok("")), - ("@joeuser@mydatabase@", "@", -2, Ok("mydatabase")), + ("", "@", 1, Some("")), + ("", "@", -1, Some("")), + ("joeuser@mydatabase", "", 1, Some("joeuser@mydatabase")), + ("joeuser@mydatabase", "", 2, Some("")), + ("joeuser@mydatabase", "", -1, Some("joeuser@mydatabase")), + ("joeuser@mydatabase", "", -2, Some("")), + ("joeuser@mydatabase", "@", 0, None), + ("joeuser@mydatabase", "@@", 1, Some("joeuser@mydatabase")), + ("joeuser@mydatabase", "@@", 2, Some("")), + ("joeuser@mydatabase", "@", 1, Some("joeuser")), + ("joeuser@mydatabase", "@", 2, Some("mydatabase")), + ("joeuser@mydatabase", "@", 3, Some("")), + ("@joeuser@mydatabase@", "@", 2, Some("joeuser")), + ("joeuser@mydatabase", "@", -1, Some("mydatabase")), + ("joeuser@mydatabase", "@", -2, Some("joeuser")), + ("joeuser@mydatabase", "@", -3, Some("")), + ("@joeuser@mydatabase@", "@", -2, Some("mydatabase")), // other cases // makes sure that `rsplit` is not used internally when `nth` is negative - ("@@@", "@@", -1, Ok("@")), + ("@@@", "@@", -1, Some("@")), ]; for (i, case @ (string_expr, delimiter_expr, nth_expr, expected)) in @@ -118,17 +107,15 @@ mod tests { match actual { Ok(guard) => { - let expected = expected.clone().unwrap(); + let expected = expected.unwrap(); let array = guard.into_inner().finish().unwrap(); let actual = array.value_at(0).unwrap(); assert_eq!(expected, actual, "\nat case {i}: {:?}\n", case) } - Err(err) => { - let expected = expected.clone().unwrap_err().to_string(); - let actual = err.to_string(); - assert_eq!(expected, actual, "\nat case {i}: {:?}\n", case) + Err(_err) => { + assert!(expected.is_none()); } }; } diff --git a/src/expr/src/vector_op/substr.rs b/src/expr/src/vector_op/substr.rs index d5831e4dc09d..3c712cecce8d 100644 --- a/src/expr/src/vector_op/substr.rs +++ b/src/expr/src/vector_op/substr.rs @@ -15,18 +15,19 @@ use std::cmp::{max, min}; use risingwave_common::array::{BytesGuard, BytesWriter}; -use risingwave_common::error::{ErrorCode, Result}; + +use crate::{bail, ExprError, Result}; #[inline(always)] pub fn substr_start(s: &str, start: i32, writer: BytesWriter) -> Result { let start = min(max(start - 1, 0) as usize, s.len()); - writer.write_ref(&s[start..]) + writer.write_ref(&s[start..]).map_err(ExprError::Array) } #[inline(always)] pub fn substr_for(s: &str, count: i32, writer: BytesWriter) -> Result { let end = min(count as usize, s.len()); - writer.write_ref(&s[..end]) + writer.write_ref(&s[..end]).map_err(ExprError::Array) } #[inline(always)] @@ -37,15 +38,11 @@ pub fn substr_start_for( writer: BytesWriter, ) -> Result { if count < 0 { - return Err(ErrorCode::InvalidInputSyntax(format!( - "length in substr should be non-negative: {}", - count - )) - .into()); + bail!("length in substr should be non-negative: {}", count); } let begin = max(start - 1, 0) as usize; let end = min(max(start - 1 + count, 0) as usize, s.len()); - writer.write_ref(&s[begin..end]) + writer.write_ref(&s[begin..end]).map_err(ExprError::Array) } #[cfg(test)] @@ -68,7 +65,7 @@ mod tests { ]; for (s, off, len, expected) in cases { - let builder = Utf8ArrayBuilder::new(1)?; + let builder = Utf8ArrayBuilder::new(1).unwrap(); let writer = builder.writer(); let guard = match (off, len) { (Some(off), Some(len)) => { @@ -84,7 +81,7 @@ mod tests { (None, Some(len)) => substr_for(&s, len, writer)?, _ => unreachable!(), }; - let array = guard.into_inner().finish()?; + let array = guard.into_inner().finish().unwrap(); let v = array.value_at(0).unwrap(); assert_eq!(v, expected); } diff --git a/src/expr/src/vector_op/tests.rs b/src/expr/src/vector_op/tests.rs index 165cb03c0b47..17566bda7989 100644 --- a/src/expr/src/vector_op/tests.rs +++ b/src/expr/src/vector_op/tests.rs @@ -16,7 +16,6 @@ use std::assert_matches::assert_matches; use std::str::FromStr; use chrono::{NaiveDate, NaiveDateTime}; -use risingwave_common::error::ErrorCode::NumericValueOutOfRange; use risingwave_common::types::{ Decimal, IntervalUnit, NaiveDateTimeWrapper, NaiveDateWrapper, OrderedF32, OrderedF64, }; @@ -26,6 +25,7 @@ use crate::vector_op::bitwise_op::*; use crate::vector_op::cast::date_to_timestamp; use crate::vector_op::cmp::*; use crate::vector_op::conjunction::*; +use crate::ExprError; #[test] fn test_arithmetic() { @@ -146,28 +146,28 @@ fn test_bitwise() { assert_eq!(general_shl::(1i32, 0i32).unwrap(), 1i32); assert_eq!(general_shl::(1i64, 31i32).unwrap(), 2147483648i64); assert_matches!( - general_shl::(1i32, 32i32).unwrap_err().inner(), - NumericValueOutOfRange + general_shl::(1i32, 32i32).unwrap_err(), + ExprError::NumericOutOfRange, ); assert_eq!( general_shr::(-2147483648i64, 31i32).unwrap(), -1i64 ); - assert_eq!(general_shr::(1i64, 0i32), Ok(1i64)); + assert_eq!(general_shr::(1i64, 0i32).unwrap(), 1i64); // truth table assert_eq!( - general_bitand::(0b0011u32, 0b0101u32), - Ok(0b1u64) + general_bitand::(0b0011u32, 0b0101u32).unwrap(), + 0b1u64 ); assert_eq!( - general_bitor::(0b0011u32, 0b0101u32), - Ok(0b0111u64) + general_bitor::(0b0011u32, 0b0101u32).unwrap(), + 0b0111u64 ); assert_eq!( - general_bitxor::(0b0011u32, 0b0101u32), - Ok(0b0110u64) + general_bitxor::(0b0011u32, 0b0101u32).unwrap(), + 0b0110u64 ); - assert_eq!(general_bitnot::(0b01i32), Ok(-2i32)); + assert_eq!(general_bitnot::(0b01i32).unwrap(), -2i32); } #[test] diff --git a/src/expr/src/vector_op/to_char.rs b/src/expr/src/vector_op/to_char.rs index f590dc2ff8d9..1eaacae6f225 100644 --- a/src/expr/src/vector_op/to_char.rs +++ b/src/expr/src/vector_op/to_char.rs @@ -14,9 +14,10 @@ use aho_corasick::AhoCorasickBuilder; use risingwave_common::array::{BytesGuard, BytesWriter}; -use risingwave_common::error::Result; use risingwave_common::types::NaiveDateTimeWrapper; +use crate::{ExprError, Result}; + /// Compile the pg pattern to chrono pattern. // TODO: Chrono can not fully support the pg format, so consider using other implementations later. fn compile_pattern_to_chrono(tmpl: &str) -> String { @@ -50,5 +51,5 @@ pub fn to_char_timestamp( ) -> Result { let chrono_tmpl = compile_pattern_to_chrono(tmpl); let res = data.0.format(&chrono_tmpl).to_string(); - dst.write_ref(&res) + dst.write_ref(&res).map_err(ExprError::Array) } diff --git a/src/expr/src/vector_op/translate.rs b/src/expr/src/vector_op/translate.rs index 5c06d84a57c0..b8fe35263402 100644 --- a/src/expr/src/vector_op/translate.rs +++ b/src/expr/src/vector_op/translate.rs @@ -15,7 +15,8 @@ use std::collections::HashMap; use risingwave_common::array::{BytesGuard, BytesWriter}; -use risingwave_common::error::Result; + +use crate::{ExprError, Result}; #[inline(always)] pub fn translate( @@ -44,7 +45,7 @@ pub fn translate( None => Some(c), }); - writer.write_from_char_iter(iter) + writer.write_from_char_iter(iter).map_err(ExprError::Array) } #[cfg(test)] @@ -72,10 +73,10 @@ mod tests { ]; for (s, match_str, replace_str, expected) in cases { - let builder = Utf8ArrayBuilder::new(1)?; + let builder = Utf8ArrayBuilder::new(1).unwrap(); let writer = builder.writer(); let guard = translate(s, match_str, replace_str, writer)?; - let array = guard.into_inner().finish()?; + let array = guard.into_inner().finish().unwrap(); let v = array.value_at(0).unwrap(); assert_eq!(v, expected); } diff --git a/src/expr/src/vector_op/trim.rs b/src/expr/src/vector_op/trim.rs index f06bd59a7c05..1510684b3861 100644 --- a/src/expr/src/vector_op/trim.rs +++ b/src/expr/src/vector_op/trim.rs @@ -13,11 +13,12 @@ // limitations under the License. use risingwave_common::array::{BytesGuard, BytesWriter}; -use risingwave_common::error::Result; + +use crate::{ExprError, Result}; #[inline(always)] pub fn trim(s: &str, writer: BytesWriter) -> Result { - writer.write_ref(s.trim()) + writer.write_ref(s.trim()).map_err(ExprError::Array) } #[cfg(test)] @@ -34,10 +35,10 @@ mod tests { ]; for (s, expected) in cases { - let builder = Utf8ArrayBuilder::new(1)?; + let builder = Utf8ArrayBuilder::new(1).unwrap(); let writer = builder.writer(); let guard = trim(s, writer)?; - let array = guard.into_inner().finish()?; + let array = guard.into_inner().finish().unwrap(); let v = array.value_at(0).unwrap(); assert_eq!(v, expected); } diff --git a/src/expr/src/vector_op/tumble.rs b/src/expr/src/vector_op/tumble.rs index 1dc6c94c7a6c..3ebe47377c80 100644 --- a/src/expr/src/vector_op/tumble.rs +++ b/src/expr/src/vector_op/tumble.rs @@ -13,11 +13,10 @@ // limitations under the License. use chrono::NaiveDateTime; -use risingwave_common::error::ErrorCode::InternalError; -use risingwave_common::error::{Result, RwError}; use risingwave_common::types::{IntervalUnit, NaiveDateTimeWrapper, NaiveDateWrapper}; use super::cast::date_to_timestamp; +use crate::{ExprError, Result}; #[inline(always)] pub fn tumble_start_date( @@ -35,9 +34,10 @@ pub fn tumble_start_date_time( ) -> Result { let diff = time.0.timestamp(); if window.get_months() != 0 { - return Err(RwError::from(InternalError( - "unimplemented: tumble_start only support days or milliseconds".to_string(), - ))); + return Err(ExprError::InvalidParam { + name: "window", + reason: "unimplemented: tumble_start only support days or milliseconds".to_string(), + }); } let window = window.get_days() as i64 * 24 * 60 * 60 + window.get_ms() / 1000; let offset = diff / window; diff --git a/src/expr/src/vector_op/upper.rs b/src/expr/src/vector_op/upper.rs index e0efe4882e23..15aa388a1183 100644 --- a/src/expr/src/vector_op/upper.rs +++ b/src/expr/src/vector_op/upper.rs @@ -13,11 +13,14 @@ // limitations under the License. use risingwave_common::array::{BytesGuard, BytesWriter}; -use risingwave_common::error::Result; + +use crate::{ExprError, Result}; #[inline(always)] pub fn upper(s: &str, writer: BytesWriter) -> Result { - writer.write_ref(&s.to_uppercase()) + writer + .write_ref(&s.to_uppercase()) + .map_err(ExprError::Array) } #[cfg(test)] @@ -35,10 +38,10 @@ mod tests { ]; for (s, expected) in cases { - let builder = Utf8ArrayBuilder::new(1)?; + let builder = Utf8ArrayBuilder::new(1).unwrap(); let writer = builder.writer(); let guard = upper(s, writer)?; - let array = guard.into_inner().finish()?; + let array = guard.into_inner().finish().unwrap(); let v = array.value_at(0).unwrap(); assert_eq!(v, expected); } diff --git a/src/source/src/parser/common.rs b/src/source/src/parser/common.rs index 3679f5a55f92..41e835d18143 100644 --- a/src/source/src/parser/common.rs +++ b/src/source/src/parser/common.rs @@ -82,17 +82,11 @@ pub(crate) fn json_parse_value( } DataType::Date => match value.and_then(|v| v.as_str()) { None => Err(RwError::from(InternalError("parse error".to_string()))), - Some(date_str) => match str_to_date(date_str) { - Ok(date) => Ok(ScalarImpl::NaiveDate(date)), - Err(e) => Err(e), - }, + Some(date_str) => Ok(ScalarImpl::NaiveDate(str_to_date(date_str)?)), }, DataType::Timestamp => match value.and_then(|v| v.as_str()) { None => Err(RwError::from(InternalError("parse error".to_string()))), - Some(date_str) => match str_to_timestamp(date_str) { - Ok(timestamp) => Ok(ScalarImpl::NaiveDateTime(timestamp)), - Err(e) => Err(e), - }, + Some(date_str) => Ok(ScalarImpl::NaiveDateTime(str_to_timestamp(date_str)?)), }, _ => Err(ErrorCode::NotImplemented( "unsupported type for json_parse_value".to_string(), diff --git a/src/stream/src/from_proto/project.rs b/src/stream/src/from_proto/project.rs index bb783233ab56..ce5fda354505 100644 --- a/src/stream/src/from_proto/project.rs +++ b/src/stream/src/from_proto/project.rs @@ -27,11 +27,11 @@ impl ExecutorBuilder for ProjectExecutorBuilder { _stream: &mut LocalStreamManagerCore, ) -> Result { let node = try_match_expand!(node.get_node_body().unwrap(), NodeBody::Project)?; - let project_exprs = node + let project_exprs: Vec<_> = node .get_select_list() .iter() .map(build_from_prost) - .collect::>>()?; + .try_collect()?; Ok(ProjectExecutor::new( params.input.remove(0),