Skip to content

Commit

Permalink
feat(round): supports decimal place as second parameter (#3221)
Browse files Browse the repository at this point in the history
  • Loading branch information
aceforeverd authored Apr 21, 2023
1 parent 3ab08f0 commit 094087f
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 32 deletions.
2 changes: 1 addition & 1 deletion cases/function/function/test_calculate.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ cases:
floor(c5) as r5 from {0};
expect:
order: id
columns: ["id int", "r0 bigint", "r1 bigint", "r2 bigint", "r3 double", "r4 double","r5 double"]
columns: ["id int", "r0 bigint", "r1 bigint", "r2 bigint", "r3 float", "r4 double","r5 double"]
rows:
- [1, 1, 2, 2, 1.000000, 0.000000,1.0]
- [2, NULL, NULL, 2, NULL, NULL,0.0]
Expand Down
2 changes: 1 addition & 1 deletion cases/integration_test/function/test_calculate.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ cases:
floor(c5) as r5 from {0};
expect:
order: id
columns: ["id int", "r0 bigint", "r1 bigint", "r2 bigint", "r3 double", "r4 double","r5 double"]
columns: ["id int", "r0 bigint", "r1 bigint", "r2 bigint", "r3 float", "r4 double","r5 double"]
rows:
- [1, 1, 2, 2, 1.000000, 0.000000,1.0]
- [2, NULL, NULL, 2, NULL, NULL,0.0]
Expand Down
77 changes: 71 additions & 6 deletions hybridse/src/codegen/udf_ir_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -609,12 +609,77 @@ TEST_F(UdfIRBuilderTest, PowerUdfTest) {
2147483648, 65536);
}

TEST_F(UdfIRBuilderTest, RoundUdfTest) {
CheckUdf<int32_t, int16_t>("round", round(5), 5);
CheckUdf<int32_t, int32_t>("round", round(65536), 65536);
CheckUdf<int64_t, int64_t>("round", round(2147483648), 2147483648);
CheckUdf<double, float>("round", roundf(0.5f), 0.5f);
CheckUdf<double, double>("round", round(0.5), 0.5);
TEST_F(UdfIRBuilderTest, RoundWithPositiveD) {
// We use string as expet result in case the inaccuracy of flaot points
std::initializer_list<std::pair<double, std::string>> cases = {
// before decimal position = 0
{0.5, "0.50"}, {0.12, "0.12"}, {0.123, "0.12"}, {0.1478, "0.15"},
{0, "0.00"}, {0.012, "0.01"}, {0.0012, "0.00"}, {0.0078, "0.01"},
// before decimal position > 0
{1.1, "1.10" }, {1.14, "1.14"}, {1.177, "1.18"}, {1.171, "1.17"},
{21.1, "21.10" }, {21.14, "21.14"}, {21.177, "21.18"}, {21.171, "21.17"},
{1889, "1889.00"},
};

for (auto& val : cases) {
// non-negative value
CheckUdf<double, double, int32_t>("round", std::stod(val.second), val.first, 2);
// negative value
std::string expect = "-" + val.second;
CheckUdf<double, double, int32_t>("round", std::stod(expect), -val.first, 2);
}

for (auto c : {1, 2, 3, 4, 5, 6}) {
CheckUdf<int64_t, int64_t, int32_t>("round", c, c, 2);
}
}

TEST_F(UdfIRBuilderTest, RoundWithNegD) {
// We use string as expet result in case the inaccuracy of flaot points
std::initializer_list<std::pair<double, std::string>> cases = {
{0.0, "0.0"}, {1.23, "0"}, {100.12, "100"}, {3712.55, "3700"}, {4488, "4500"},
{88, "100"}, {175.4, "200"}
};

for (auto& val : cases) {
// non-negative value
CheckUdf<double, double, int32_t>("round", std::stod(val.second), val.first, -2);
// negative value
std::string expect = "-" + val.second;
CheckUdf<double, double, int32_t>("round", std::stod(expect), -val.first, -2);
}

std::initializer_list<std::pair<int32_t, int32_t>> icases = {{0, 0}, {1, 0}, {55, 100}, {100, 100},
{145, 100}, {199, 200}, {2312, 2300}};
for (auto c : icases) {
CheckUdf<int32_t, int32_t, int32_t>("round", c.second, c.first, -2);
CheckUdf<int32_t, int32_t, int32_t>("round", -c.second, -c.first, -2);
}
}

TEST_F(UdfIRBuilderTest, RoundWithZeroD) {
std::initializer_list<std::pair<double, std::string>> cases = {
{1.12, "1"}, {1.5, "2"}, {1.77, "2"}, {0.0, "0"}, {88, "88"}
};
for (auto& val : cases) {
// non-negative value
CheckUdf<double, double, int32_t>("round", std::stod(val.second), val.first, 0);
CheckUdf<double, double>("round", std::stod(val.second), val.first);
// negative value
std::string expect = "-" + val.second;
CheckUdf<double, double, int32_t>("round", std::stod(expect), -val.first, 0);
CheckUdf<double, double>("round", std::stod(expect), -val.first);
}

std::initializer_list<int32_t> icases = {1, 2, 3, 4, 5, 100, 88};
for (auto& val : icases) {
// non-negative value
CheckUdf<int32_t, int32_t, int32_t>("round", val, val, 0);
CheckUdf<int32_t, int32_t>("round", val, val);
// negative value
CheckUdf<int32_t, int32_t, int32_t>("round", -val, -val, 0);
CheckUdf<int32_t, int32_t>("round", -val, -val);
}
}

TEST_F(UdfIRBuilderTest, SinUdfTest) {
Expand Down
42 changes: 28 additions & 14 deletions hybridse/src/udf/default_udf_library.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1625,33 +1625,47 @@ void DefaultUdfLibrary::InitMathUdf() {

RegisterExternalTemplate<v1::Round>("round")
.doc(R"(
@brief Return the nearest integer value to expr (in floating-point format),
rounding halfway cases away from zero, regardless of the current rounding mode.
@brief Returns expr rounded to d decimal places using HALF_UP rounding mode.
@param numeric_expr Expression evaluated to numeric
@param d Integer decimal place, omitted, default to 0
When `d` is a positive, `numeric_expr` is rounded to the number of decimal positions specified by `d`. When `d` is a negative , `numeric_expr` is rounded on the left side of the decimal point.
Return type is the same as the type first parameter.
Example:
@code{.sql}
SELECT round(1.23);
-- 1 (double type)
SELECT ROUND(1.23);
-- output 1
SELECT round(1.23, 1)
-- 1.2 (double type)
SELECT round(123, -1)
-- 120 (int32 type)
@endcode
@param expr
@since 0.1.0)")
.args_in<int64_t, double>();
RegisterExternalTemplate<v1::Round32>("round").args_in<int16_t, int32_t>();
RegisterExprUdf("round").args<AnyArg>(
[](UdfResolveContext* ctx, ExprNode* x) -> ExprNode* {
.args_in<int16_t, int32_t, int64_t, float, double>();

RegisterExprUdf("round").variadic_args<AnyArg>(
[](UdfResolveContext* ctx, ExprNode* x, const std::vector<ExprNode*>& other) -> ExprNode* {
if (!x->GetOutputType()->IsArithmetic()) {
ctx->SetError("round do not support type " +
x->GetOutputType()->GetName());
ctx->SetError("round do not support type " + x->GetOutputType()->GetName());
return nullptr;
}
auto nm = ctx->node_manager();
auto cast = nm->MakeCastNode(node::kDouble, x);
return nm->MakeFuncNode("round", {cast}, nullptr);
if (other.size() > 1) {
ctx->SetError("can't round with more than 2 parameters");
return nullptr;
}

node::ExprNode* decimal_place = nm->MakeConstNode(0);
if (!other.empty()) {
decimal_place = nm->MakeCastNode(node::kInt32, other.front());
}
return nm->MakeFuncNode("round", {x, decimal_place}, nullptr);
});

RegisterExternalTemplate<v1::Sqrt>("sqrt")
Expand Down
25 changes: 15 additions & 10 deletions hybridse/src/udf/udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,21 @@ struct Pow {

template <class V>
struct Round {
using Args = std::tuple<V>;

V operator()(V r) { return static_cast<V>(round(r)); }
};

template <class V>
struct Round32 {
using Args = std::tuple<V>;

int32_t operator()(V r) { return static_cast<int32_t>(round(r)); }
using Args = std::tuple<V, int32_t>;

V operator()(V val, int32_t decimal_number) {
if constexpr (std::is_integral_v<V>) {
if (decimal_number >= 0) {
return val;
} else {
double factor = std::pow(10, -decimal_number);
return static_cast<V>(std::round(val / factor) * factor);
}
} else {
// floats
return static_cast<V>(std::round(val * std::pow(10, decimal_number)) / std::pow(10, decimal_number));
}
}
};

template <class V>
Expand Down

0 comments on commit 094087f

Please sign in to comment.