Skip to content

Commit

Permalink
Handle zero denominator in divide and modulus for byte data type (#272)…
Browse files Browse the repository at this point in the history
… (opensearch-project#1716) (opensearch-project#1734)

* Fixed bug of byte/short values not handling divide/modulus arithmetic equations

Signed-off-by: Matthew Wells <[email protected]>
(cherry picked from commit 2c80631)

Co-authored-by: Matthew Wells <[email protected]>
  • Loading branch information
1 parent f133840 commit 38fc833
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ private static DefaultFunctionResolver addFunction() {
private static DefaultFunctionResolver divideBase(FunctionName functionName) {
return define(functionName,
impl(nullMissingHandling(
(v1, v2) -> new ExprByteValue(v1.byteValue() / v2.byteValue())),
(v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() :
new ExprByteValue(v1.byteValue() / v2.byteValue())),
BYTE, BYTE, BYTE),
impl(nullMissingHandling(
(v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() :
Expand Down Expand Up @@ -140,7 +141,7 @@ private static DefaultFunctionResolver divideFunction() {
}

/**
* Definition of modulo(x, y) function.
* Definition of modulus(x, y) function.
* Returns the number x modulo by number y
* The supported signature of modulo function is
* (x: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE, y: BYTE/SHORT/INTEGER/LONG/FLOAT/DOUBLE)
Expand All @@ -149,7 +150,8 @@ private static DefaultFunctionResolver divideFunction() {
private static DefaultFunctionResolver modulusBase(FunctionName functionName) {
return define(functionName,
impl(nullMissingHandling(
(v1, v2) -> new ExprByteValue(v1.byteValue() % v2.byteValue())),
(v1, v2) -> v2.byteValue() == 0 ? ExprNullValue.of() :
new ExprByteValue(v1.byteValue() % v2.byteValue())),
BYTE, BYTE, BYTE),
impl(nullMissingHandling(
(v1, v2) -> v2.shortValue() == 0 ? ExprNullValue.of() :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public void mod(ExprValue op1, ExprValue op2) {
assertEquals(String.format("mod(%s, %s)", op1.toString(), op2.toString()),
expression.toString());

expression = DSL.mod(literal(op1), literal(new ExprShortValue(0)));
expression = DSL.mod(literal(op1), literal(new ExprByteValue(0)));
assertTrue(expression.valueOf(valueEnv()).isNull());
assertEquals(String.format("mod(%s, 0)", op1.toString()), expression.toString());
}
Expand All @@ -128,7 +128,7 @@ public void modulus(ExprValue op1, ExprValue op2) {
assertEquals(String.format("%%(%s, %s)", op1.toString(), op2.toString()),
expression.toString());

expression = DSL.modulus(literal(op1), literal(new ExprShortValue(0)));
expression = DSL.modulus(literal(op1), literal(new ExprByteValue(0)));
assertTrue(expression.valueOf(valueEnv()).isNull());
assertEquals(String.format("%%(%s, 0)", op1.toString()), expression.toString());
}
Expand All @@ -144,7 +144,7 @@ public void modulusFunction(ExprValue op1, ExprValue op2) {
assertEquals(String.format("modulus(%s, %s)", op1.toString(), op2.toString()),
expression.toString());

expression = DSL.modulusFunction(literal(op1), literal(new ExprShortValue(0)));
expression = DSL.modulusFunction(literal(op1), literal(new ExprByteValue(0)));
assertTrue(expression.valueOf(valueEnv()).isNull());
assertEquals(String.format("modulus(%s, 0)", op1.toString()), expression.toString());
}
Expand Down Expand Up @@ -183,7 +183,7 @@ public void divide(ExprValue op1, ExprValue op2) {
assertEquals(String.format("/(%s, %s)", op1.toString(), op2.toString()),
expression.toString());

expression = DSL.divide(literal(op1), literal(new ExprShortValue(0)));
expression = DSL.divide(literal(op1), literal(new ExprByteValue(0)));
assertTrue(expression.valueOf(valueEnv()).isNull());
assertEquals(String.format("/(%s, 0)", op1.toString()), expression.toString());
}
Expand All @@ -199,7 +199,7 @@ public void divideFunction(ExprValue op1, ExprValue op2) {
assertEquals(String.format("divide(%s, %s)", op1.toString(), op2.toString()),
expression.toString());

expression = DSL.divideFunction(literal(op1), literal(new ExprShortValue(0)));
expression = DSL.divideFunction(literal(op1), literal(new ExprByteValue(0)));
assertTrue(expression.valueOf(valueEnv()).isNull());
assertEquals(String.format("divide(%s, 0)", op1.toString()), expression.toString());
}
Expand Down

0 comments on commit 38fc833

Please sign in to comment.