Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-4652][VL] Fix min_by/max_by result mismatch when RDD partition num > 1 #5711

Merged
merged 3 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -194,18 +194,12 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
}

test("min_by/max_by") {
withTempPath {
path =>
Seq((5: Integer, 6: Integer), (null: Integer, 11: Integer), (null: Integer, 5: Integer))
.toDF("a", "b")
.write
.parquet(path.getCanonicalPath)
spark.read
.parquet(path.getCanonicalPath)
.createOrReplaceTempView("test")
runQueryAndCompare("select min_by(a, b), max_by(a, b) from test") {
checkGlutenOperatorMatch[HashAggregateExecTransformer]
}
withSQLConf(("spark.sql.leafNodeDefaultParallelism", "2")) {
runQueryAndCompare(
"select min_by(a, b), max_by(a, b) from " +
"values (5, 6), (null, 11), (null, 5) test(a, b)") {
checkGlutenOperatorMatch[HashAggregateExecTransformer]
}
}
}

Expand Down
23 changes: 13 additions & 10 deletions cpp/velox/operators/functions/RegistrationAllFunctions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
* limitations under the License.
*/
#include "operators/functions/RegistrationAllFunctions.h"

#include "operators/functions/Arithmetic.h"
#include "operators/functions/RowConstructorWithAllNull.h"
#include "operators/functions/RowConstructorWithNull.h"
#include "operators/functions/RowFunctionWithNull.h"

#include "velox/expression/SpecialFormRegistry.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/RegistrationHelpers.h"
Expand All @@ -45,29 +44,32 @@ void registerFunctionOverwrite() {
velox::registerFunction<RoundFunction, double, double, int32_t>({"round"});
velox::registerFunction<RoundFunction, float, float, int32_t>({"round"});

auto kRowConstructorWithNull = RowConstructorWithNullCallToSpecialForm::kRowConstructorWithNull;
velox::exec::registerVectorFunction(
"row_constructor_with_null",
kRowConstructorWithNull,
std::vector<std::shared_ptr<velox::exec::FunctionSignature>>{},
std::make_unique<RowFunctionWithNull</*allNull=*/false>>(),
RowFunctionWithNull</*allNull=*/false>::metadata());
velox::exec::registerFunctionCallToSpecialForm(
RowConstructorWithNullCallToSpecialForm::kRowConstructorWithNull,
std::make_unique<RowConstructorWithNullCallToSpecialForm>());
kRowConstructorWithNull, std::make_unique<RowConstructorWithNullCallToSpecialForm>(kRowConstructorWithNull));

auto kRowConstructorWithAllNull = RowConstructorWithNullCallToSpecialForm::kRowConstructorWithAllNull;
velox::exec::registerVectorFunction(
"row_constructor_with_all_null",
kRowConstructorWithAllNull,
std::vector<std::shared_ptr<velox::exec::FunctionSignature>>{},
std::make_unique<RowFunctionWithNull</*allNull=*/true>>(),
RowFunctionWithNull</*allNull=*/true>::metadata());
velox::exec::registerFunctionCallToSpecialForm(
RowConstructorWithAllNullCallToSpecialForm::kRowConstructorWithAllNull,
std::make_unique<RowConstructorWithAllNullCallToSpecialForm>());
kRowConstructorWithAllNull,
std::make_unique<RowConstructorWithNullCallToSpecialForm>(kRowConstructorWithAllNull));
velox::functions::sparksql::registerBitwiseFunctions("spark_");
}
} // namespace

void registerAllFunctions() {
// The registration order matters. Spark sql functions are registered after
// presto sql functions to overwrite the registration for same named functions.
// presto sql functions to overwrite the registration for same named
// functions.
velox::functions::prestosql::registerAllScalarFunctions();
velox::functions::sparksql::registerFunctions("");
velox::aggregate::prestosql::registerAllAggregateFunctions(
Expand All @@ -76,7 +78,8 @@ void registerAllFunctions() {
"", true /*registerCompanionFunctions*/, true /*overwrite*/);
velox::window::prestosql::registerAllWindowFunctions();
velox::functions::window::sparksql::registerWindowFunctions("");
// Using function overwrite to handle function names mismatch between Spark and Velox.
// Using function overwrite to handle function names mismatch between Spark
// and Velox.
registerFunctionOverwrite();
}

Expand Down
37 changes: 0 additions & 37 deletions cpp/velox/operators/functions/RowConstructorWithAllNull.h

This file was deleted.

10 changes: 1 addition & 9 deletions cpp/velox/operators/functions/RowConstructorWithNull.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ facebook::velox::TypePtr RowConstructorWithNullCallToSpecialForm::resolveType(
}

facebook::velox::exec::ExprPtr RowConstructorWithNullCallToSpecialForm::constructSpecialForm(
const std::string& name,
const facebook::velox::TypePtr& type,
std::vector<facebook::velox::exec::ExprPtr>&& compiledChildren,
bool trackCpuUsage,
const facebook::velox::core::QueryConfig& config) {
auto name = this->rowFunctionName;
auto [function, metadata] = facebook::velox::exec::vectorFunctionFactories().withRLock(
[&config, &name](auto& functionMap) -> std::pair<
std::shared_ptr<facebook::velox::exec::VectorFunction>,
Expand All @@ -52,12 +52,4 @@ facebook::velox::exec::ExprPtr RowConstructorWithNullCallToSpecialForm::construc
return std::make_shared<facebook::velox::exec::Expr>(
type, std::move(compiledChildren), function, metadata, name, trackCpuUsage);
}

facebook::velox::exec::ExprPtr RowConstructorWithNullCallToSpecialForm::constructSpecialForm(
const facebook::velox::TypePtr& type,
std::vector<facebook::velox::exec::ExprPtr>&& compiledChildren,
bool trackCpuUsage,
const facebook::velox::core::QueryConfig& config) {
return constructSpecialForm(kRowConstructorWithNull, type, std::move(compiledChildren), trackCpuUsage, config);
}
} // namespace gluten
8 changes: 8 additions & 0 deletions cpp/velox/operators/functions/RowConstructorWithNull.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
namespace gluten {
class RowConstructorWithNullCallToSpecialForm : public facebook::velox::exec::FunctionCallToSpecialForm {
public:
RowConstructorWithNullCallToSpecialForm(const std::string& rowFunctionName) {
this->rowFunctionName = rowFunctionName;
}

facebook::velox::TypePtr resolveType(const std::vector<facebook::velox::TypePtr>& argTypes) override;

facebook::velox::exec::ExprPtr constructSpecialForm(
Expand All @@ -32,6 +36,7 @@ class RowConstructorWithNullCallToSpecialForm : public facebook::velox::exec::Fu
const facebook::velox::core::QueryConfig& config) override;

static constexpr const char* kRowConstructorWithNull = "row_constructor_with_null";
static constexpr const char* kRowConstructorWithAllNull = "row_constructor_with_all_null";

protected:
facebook::velox::exec::ExprPtr constructSpecialForm(
Expand All @@ -40,5 +45,8 @@ class RowConstructorWithNullCallToSpecialForm : public facebook::velox::exec::Fu
std::vector<facebook::velox::exec::ExprPtr>&& compiledChildren,
bool trackCpuUsage,
const facebook::velox::core::QueryConfig& config);

private:
std::string rowFunctionName;
};
} // namespace gluten
Loading