diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 4ba4f4fe45..406ee68f21 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -225,6 +225,10 @@ public static FunctionExpression sqrt(Expression... expressions) { return compile(BuiltinFunctionName.SQRT, expressions); } + public FunctionExpression cbrt(Expression... expressions) { + return compile(BuiltinFunctionName.CBRT, expressions); + } + public static FunctionExpression truncate(Expression... expressions) { return compile(BuiltinFunctionName.TRUNCATE, expressions); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 51d91eb372..cc3db47982 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -41,6 +41,7 @@ public enum BuiltinFunctionName { ROUND(FunctionName.of("round")), SIGN(FunctionName.of("sign")), SQRT(FunctionName.of("sqrt")), + CBRT(FunctionName.of("cbrt")), TRUNCATE(FunctionName.of("truncate")), ACOS(FunctionName.of("acos")), diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java index 0ce48af48c..0e4df086fb 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java @@ -52,6 +52,7 @@ public class MathematicalFunction { */ public static void register(BuiltinFunctionRepository repository) { repository.register(abs()); + repository.register(cbrt()); repository.register(ceil()); repository.register(ceiling()); repository.register(conv()); @@ -471,6 +472,20 @@ private static DefaultFunctionResolver sqrt() { DOUBLE, type)).collect(Collectors.toList())); } + /** + * Definition of cbrt(x) function. + * Calculate the cube root of a number x + * The supported signature is + * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + */ + private static DefaultFunctionResolver cbrt() { + return FunctionDSL.define(BuiltinFunctionName.CBRT.getName(), + ExprCoreType.numberTypes().stream() + .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( + v -> new ExprDoubleValue(Math.cbrt(v.doubleValue()))), + DOUBLE, type)).collect(Collectors.toList())); + } + /** * Definition of truncate(x, d) function. * Returns the number x, truncated to d decimal places diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java index 59e12a4155..ba44e7eacb 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java @@ -2327,4 +2327,79 @@ public void tan_missing_value() { assertEquals(DOUBLE, tan.type()); assertTrue(tan.valueOf(valueEnv()).isMissing()); } + + /** + * Test cbrt with int value. + */ + @ParameterizedTest(name = "cbrt({0})") + @ValueSource(ints = {1, 2}) + public void cbrt_int_value(Integer value) { + FunctionExpression cbrt = dsl.cbrt(DSL.literal(value)); + assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value)))); + assertEquals(String.format("cbrt(%s)", value), cbrt.toString()); + } + + /** + * Test cbrt with long value. + */ + @ParameterizedTest(name = "cbrt({0})") + @ValueSource(longs = {1L, 2L}) + public void cbrt_long_value(Long value) { + FunctionExpression cbrt = dsl.cbrt(DSL.literal(value)); + assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value)))); + assertEquals(String.format("cbrt(%s)", value), cbrt.toString()); + } + + /** + * Test cbrt with float value. + */ + @ParameterizedTest(name = "cbrt({0})") + @ValueSource(floats = {1F, 2F}) + public void cbrt_float_value(Float value) { + FunctionExpression cbrt = dsl.cbrt(DSL.literal(value)); + assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value)))); + assertEquals(String.format("cbrt(%s)", value), cbrt.toString()); + } + + /** + * Test cbrt with double value. + */ + @ParameterizedTest(name = "cbrt({0})") + @ValueSource(doubles = {1D, 2D, Double.MAX_VALUE, Double.MIN_VALUE}) + public void cbrt_double_value(Double value) { + FunctionExpression cbrt = dsl.cbrt(DSL.literal(value)); + assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value)))); + assertEquals(String.format("cbrt(%s)", value), cbrt.toString()); + } + + /** + * Test cbrt with negative value. + */ + @ParameterizedTest(name = "cbrt({0})") + @ValueSource(doubles = {-1D, -2D}) + public void cbrt_negative_value(Double value) { + FunctionExpression cbrt = dsl.cbrt(DSL.literal(value)); + assertThat(cbrt.valueOf(valueEnv()), allOf(hasType(DOUBLE), hasValue(Math.cbrt(value)))); + assertEquals(String.format("cbrt(%s)", value), cbrt.toString()); + } + + /** + * Test cbrt with null value. + */ + @Test + public void cbrt_null_value() { + FunctionExpression cbrt = dsl.cbrt(DSL.ref(INT_TYPE_NULL_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, cbrt.type()); + assertTrue(cbrt.valueOf(valueEnv()).isNull()); + } + + /** + * Test cbrt with missing value. + */ + @Test + public void cbrt_missing_value() { + FunctionExpression cbrt = dsl.cbrt(DSL.ref(INT_TYPE_MISSING_VALUE_FIELD, INTEGER)); + assertEquals(DOUBLE, cbrt.type()); + assertTrue(cbrt.valueOf(valueEnv()).isMissing()); + } } diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 788cac0433..9c26427143 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -209,9 +209,24 @@ CBRT Description >>>>>>>>>>> -Specifications: +Usage: CBRT(number) calculates the cube root of a number + +Argument type: INTEGER/LONG/FLOAT/DOUBLE + +Return type: DOUBLE -1. CBRT(NUMBER T) -> T +(Non-negative) INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE +(Negative) INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE + +Example:: + + opensearchsql> SELECT CBRT(8), CBRT(9.261), CBRT(-27); + fetched rows / total rows = 1/1 + +-----------+---------------+-------------+ + | CBRT(8) | CBRT(9.261) | CBRT(-27) | + |-----------+---------------+-------------| + | 2.0 | 2.1 | -3.0 | + +-----------+---------------+-------------+ CEIL diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java index b5ec37acf1..efa16ba9d7 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java @@ -162,4 +162,20 @@ protected JSONObject executeQuery(String query) throws IOException { Response response = client().performRequest(request); return new JSONObject(getResponseBody(response)); } + + + @Test + public void testCbrt() throws IOException { + JSONObject result = executeQuery("select cbrt(8)"); + verifySchema(result, schema("cbrt(8)", "double")); + verifyDataRows(result, rows(2.0)); + + result = executeQuery("select cbrt(9.261)"); + verifySchema(result, schema("cbrt(9.261)", "double")); + verifyDataRows(result, rows(2.1)); + + result = executeQuery("select cbrt(-27)"); + verifySchema(result, schema("cbrt(-27)", "double")); + verifyDataRows(result, rows(-3.0)); + } } diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index c803f2b5c3..b3fd29b342 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -383,7 +383,7 @@ aggregationFunctionName ; mathematicalFunctionName - : ABS | CEIL | CEILING | CONV | CRC32 | E | EXP | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI | POW | POWER + : ABS | CBRT | CEIL | CEILING | CONV | CRC32 | E | EXP | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI | POW | POWER | RAND | ROUND | SIGN | SQRT | TRUNCATE | trigonometricFunctionName ;