Skip to content

Commit

Permalink
use unified function names
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Mar 7, 2022
1 parent 2f66ace commit a260134
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 42 deletions.
28 changes: 14 additions & 14 deletions velox/substrait/SubstraitToVeloxPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(
for (const auto& groupingExpr : groupingExprs) {
// Velox's groupings are limited to be Field, so groupingExpr is
// expected to be FieldReference.
auto fieldExpr =
exprConverter_->toVeloxExpr(groupingExpr.selection(), inputPlanNodeId);
auto fieldExpr = exprConverter_->toVeloxExpr(
groupingExpr.selection(), inputPlanNodeId);
veloxGroupingExprs.emplace_back(fieldExpr);
outIdx += 1;
}
Expand Down Expand Up @@ -100,7 +100,7 @@ std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(
auto aggExpr = std::make_shared<const core::CallTypedExpr>(
toVeloxType(aggOutType->type), std::move(aggParams), funcName);
aggExprs.emplace_back(aggExpr);

// Initialize the Aggregate Step.
if (!phaseInited) {
auto phase = aggFunction.phase();
Expand All @@ -121,7 +121,7 @@ std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(
}
outIdx += 1;
}

// Construct the Aggregate Node.
bool ignoreNullKeys = false;
std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> aggregateMasks(
Expand Down Expand Up @@ -351,7 +351,6 @@ std::string SubstraitVeloxPlanConverter::nextPlanNodeId() {
// TODO: Support different types here.
class FilterInfo {
public:

// Used to set the left bound.
void setLeft(double left, bool isExclusive) {
left_ = left;
Expand Down Expand Up @@ -413,8 +412,9 @@ connector::hive::SubfieldFilters SubstraitVeloxPlanConverter::toVeloxFilter(
flattenConditions(sFilter, scalarFunctions);
// Construct the FilterInfo for the related column.
for (const auto& scalarFunction : scalarFunctions) {
auto filterName = subParser_->findSubstraitFunction(
auto filterNameSpec = subParser_->findSubstraitFuncSpec(
functionMap_, scalarFunction.function_reference());
auto filterName = subParser_->getSubFunctionName(filterNameSpec);
int32_t colIdx;
// TODO: Add different types' support here.
double val;
Expand All @@ -439,15 +439,15 @@ connector::hive::SubfieldFilters SubstraitVeloxPlanConverter::toVeloxFilter(
"Substrait conversion not supported for arg type '{}'", typeCase);
}
}
if (filterName == "IS_NOT_NULL") {
if (filterName == "is_not_null") {
colInfoMap[colIdx]->forbidsNull();
} else if (filterName == "GREATER_THAN_OR_EQUAL") {
} else if (filterName == "gte") {
colInfoMap[colIdx]->setLeft(val, false);
} else if (filterName == "GREATER_THAN") {
} else if (filterName == "gt") {
colInfoMap[colIdx]->setLeft(val, true);
} else if (filterName == "LESS_THAN_OR_EQUAL") {
} else if (filterName == "lte") {
colInfoMap[colIdx]->setRight(val, false);
} else if (filterName == "LESS_THAN") {
} else if (filterName == "lt") {
colInfoMap[colIdx]->setRight(val, true);
} else {
VELOX_NYI(
Expand Down Expand Up @@ -498,10 +498,10 @@ void SubstraitVeloxPlanConverter::flattenConditions(
switch (typeCase) {
case ::substrait::Expression::RexTypeCase::kScalarFunction: {
auto sFunc = sFilter.scalar_function();
auto filterName = subParser_->findSubstraitFunction(
auto filterNameSpec = subParser_->findSubstraitFuncSpec(
functionMap_, sFunc.function_reference());
// TODO: Only AND relation is supported here.
if (filterName == "AND") {
// TODO: Only the "and" relation is supported here.
if (subParser_->getSubFunctionName(filterNameSpec) == "and") {
for (const auto& sCondition : sFunc.args()) {
flattenConditions(sCondition, scalarFunctions);
}
Expand Down
23 changes: 17 additions & 6 deletions velox/substrait/SubstraitUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ std::string SubstraitParser::makeNodeName(int node_id, int col_idx) {
return fmt::format("n{}_{}", node_id, col_idx);
}

std::string SubstraitParser::findSubstraitFunction(
std::string SubstraitParser::findSubstraitFuncSpec(
const std::unordered_map<uint64_t, std::string>& functionMap,
uint64_t id) const {
if (functionMap.find(id) == functionMap.end()) {
Expand All @@ -131,20 +131,31 @@ std::string SubstraitParser::findSubstraitFunction(
return map[id];
}

// Extracts the bare function name from a Substrait function specification.
// A compound specification has the form "<name>:<arg types>"; everything
// before the first ':' is returned. A simple name (no ':') is returned whole.
std::string SubstraitParser::getSubFunctionName(
    const std::string& subFuncSpec) const {
  // find() yields npos when no ':' is present, and substr(0, npos) copies the
  // entire string, so the simple-name case needs no separate branch.
  const std::size_t separatorPos = subFuncSpec.find(':');
  return subFuncSpec.substr(0, separatorPos);
}

std::string SubstraitParser::findVeloxFunction(
const std::unordered_map<uint64_t, std::string>& functionMap,
uint64_t id) const {
std::string subFunc = findSubstraitFunction(functionMap, id);
std::string veloxFunc = mapToVeloxFunction(subFunc);
return veloxFunc;
std::string subFuncSpec = findSubstraitFuncSpec(functionMap, id);
std::string subFuncName = getSubFunctionName(subFuncSpec);
return mapToVeloxFunction(subFuncName);
}

std::string SubstraitParser::mapToVeloxFunction(
const std::string& subFunc) const {
if (substraitVeloxFunctionMap.find(subFunc) ==
substraitVeloxFunctionMap.end()) {
VELOX_FAIL(
"Could not find Substrait function {} in function map.", subFunc);
// If no mapping from the Substrait function name to a Velox function name is
// found, the original Substrait function name is used.
return subFunc;
}
std::unordered_map<std::string, std::string>& map =
const_cast<std::unordered_map<std::string, std::string>&>(
Expand Down
18 changes: 11 additions & 7 deletions velox/substrait/SubstraitUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

namespace facebook::velox::substrait {

/// This class contains some common funcitons used to parse Substrait components,
/// and convert them into recognizable representations.
/// This class contains some common functions used to parse Substrait
/// components, and convert them into recognizable representations.
class SubstraitParser {
public:
/// Used to store the type name and nullability.
Expand Down Expand Up @@ -59,12 +59,16 @@ class SubstraitParser {
/// a simple name or a compound name. The compound name format is:
/// <function name>:<short_arg_type0>_<short_arg_type1>_..._<short_arg_typeN>.
/// Currently, the input types in the function specification are not used. But
/// in the future, they should be used for the validation according the specifications
/// in Substrait yaml files.
/// in the future, they should be used for the validation according to the
/// specifications in Substrait yaml files.
std::string findSubstraitFuncSpec(
const std::unordered_map<uint64_t, std::string>& functionMap,
uint64_t id) const;

/// This function is used to get the function name from the compound name.
/// When the input is a simple name, it will be returned.
std::string getSubFunctionName(const std::string& subFuncSpec) const;

/// Used to find the Velox function name according to the function id
/// from a pre-constructed function map.
std::string findVeloxFunction(
Expand All @@ -77,9 +81,9 @@ class SubstraitParser {
private:
/// Used for mapping Substrait function key words into Velox functions' key
/// words. Key: the Substrait function key word, Value: the Velox function key
/// word.
const std::unordered_map<std::string, std::string> substraitVeloxFunctionMap =
{{"MULTIPLY", "multiply"}, {"SUM", "sum"}};
/// word. For those functions with different names in Substrait and Velox,
/// a mapping relation should be added here.
const std::unordered_map<std::string, std::string> substraitVeloxFunctionMap;
};

} // namespace facebook::velox::substrait
2 changes: 1 addition & 1 deletion velox/substrait/tests/PlanConversionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class PlanConversionTest : public virtual HiveConnectorTestBase,
std::string absolutePath = "file://" + currentPath + path;
absolutePaths.emplace_back(absolutePath);
}

std::vector<u_int64_t> starts = planConverter->getStarts();
std::vector<u_int64_t> lengths = planConverter->getLengths();
// Construct the result iterator.
Expand Down
28 changes: 14 additions & 14 deletions velox/substrait/tests/sub.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,50 @@
{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
"name": "IS_NOT_NULL"
"function_anchor": 4,
"name": "lte:fp64_fp64"
}
},
{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 2,
"name": "GREATER_THAN_OR_EQUAL"
"function_anchor": 5,
"name": "sum:opt_fp64"
}
},
{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 1,
"name": "AND"
"function_anchor": 3,
"name": "lt:fp64_fp64"
}
},
{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 4,
"name": "LESS_THAN_OR_EQUAL"
"function_anchor": 0,
"name": "is_not_null:fp64"
}
},
{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 5,
"name": "SUM"
"function_anchor": 1,
"name": "and:bool_bool"
}
},
{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 6,
"name": "MULTIPLY"
"function_anchor": 2,
"name": "gte:fp64_fp64"
}
},
{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 3,
"name": "LESS_THAN"
"function_anchor": 6,
"name": "multiply:opt_fp64_fp64"
}
}
],
Expand Down

0 comments on commit a260134

Please sign in to comment.