From 3b760b324da5dbe6572e1c3db8288bbaadea1914 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Sat, 6 Mar 2021 19:47:11 -0800 Subject: [PATCH 1/4] Java changes for Decimal DIV --- .../java/ai/rapids/cudf/BinaryOperable.java | 50 ++++++++++--------- .../main/java/ai/rapids/cudf/ColumnView.java | 7 +++ java/src/main/native/src/ColumnViewJni.cpp | 10 ++++ 3 files changed, 44 insertions(+), 23 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/BinaryOperable.java b/java/src/main/java/ai/rapids/cudf/BinaryOperable.java index e5e849a74c4..c598575c470 100644 --- a/java/src/main/java/ai/rapids/cudf/BinaryOperable.java +++ b/java/src/main/java/ai/rapids/cudf/BinaryOperable.java @@ -38,7 +38,7 @@ public interface BinaryOperable { * with scale=0 as scale is required. Dtype is discarded for binary operations for decimal * types in cudf as a new DType is created for output type with the new scale. */ - static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) { + static DType implicitConversion(BinaryOp op, BinaryOperable lhs, BinaryOperable rhs) { DType a = lhs.getType(); DType b = rhs.getType(); if (a.equals(DType.FLOAT64) || b.equals(DType.FLOAT64)) { @@ -86,13 +86,15 @@ static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) { int scale = 0; if (a.typeId == DType.DTypeEnum.DECIMAL32) { if (b.typeId == DType.DTypeEnum.DECIMAL32) { - return DType.create(DType.DTypeEnum.DECIMAL32, scale); + return DType.create(DType.DTypeEnum.DECIMAL32, + ColumnView.getFixedPointOutpuScale(op, lhs.getType(), rhs.getType())); } else { throw new IllegalArgumentException("Both columns must be of the same fixed_point type"); } } else if (a.typeId == DType.DTypeEnum.DECIMAL64) { if (b.typeId == DType.DTypeEnum.DECIMAL64) { - return DType.create(DType.DTypeEnum.DECIMAL64, scale); + return DType.create(DType.DTypeEnum.DECIMAL64, + ColumnView.getFixedPointOutpuScale(op, lhs.getType(), rhs.getType())); } else { throw new IllegalArgumentException("Both columns must 
be of the same fixed_point type"); } @@ -128,7 +130,7 @@ default ColumnVector add(BinaryOperable rhs, DType outType) { * Add + operator. this + rhs */ default ColumnVector add(BinaryOperable rhs) { - return add(rhs, implicitConversion(this, rhs)); + return add(rhs, implicitConversion(BinaryOp.ADD, this, rhs)); } /** @@ -144,7 +146,7 @@ default ColumnVector sub(BinaryOperable rhs, DType outType) { * Subtract one vector from another. this - rhs */ default ColumnVector sub(BinaryOperable rhs) { - return sub(rhs, implicitConversion(this, rhs)); + return sub(rhs, implicitConversion(BinaryOp.SUB, this, rhs)); } /** @@ -160,7 +162,7 @@ default ColumnVector mul(BinaryOperable rhs, DType outType) { * Multiply two vectors together. this * rhs */ default ColumnVector mul(BinaryOperable rhs) { - return mul(rhs, implicitConversion(this, rhs)); + return mul(rhs, implicitConversion(BinaryOp.MUL, this, rhs)); } /** @@ -176,7 +178,7 @@ default ColumnVector div(BinaryOperable rhs, DType outType) { * Divide one vector by another. 
this / rhs */ default ColumnVector div(BinaryOperable rhs) { - return div(rhs, implicitConversion(this, rhs)); + return div(rhs, implicitConversion(BinaryOp.DIV, this, rhs)); } /** @@ -192,7 +194,7 @@ default ColumnVector trueDiv(BinaryOperable rhs, DType outType) { * (double)this / (double)rhs */ default ColumnVector trueDiv(BinaryOperable rhs) { - return trueDiv(rhs, implicitConversion(this, rhs)); + return trueDiv(rhs, implicitConversion(BinaryOp.TRUE_DIV, this, rhs)); } /** @@ -208,7 +210,7 @@ default ColumnVector floorDiv(BinaryOperable rhs, DType outType) { * Math.floor(this/rhs) */ default ColumnVector floorDiv(BinaryOperable rhs) { - return floorDiv(rhs, implicitConversion(this, rhs)); + return floorDiv(rhs, implicitConversion(BinaryOp.FLOOR_DIV, this, rhs)); } /** @@ -224,7 +226,7 @@ default ColumnVector mod(BinaryOperable rhs, DType outType) { * this % rhs */ default ColumnVector mod(BinaryOperable rhs) { - return mod(rhs, implicitConversion(this, rhs)); + return mod(rhs, implicitConversion(BinaryOp.MOD, this, rhs)); } /** @@ -240,7 +242,7 @@ default ColumnVector pow(BinaryOperable rhs, DType outType) { * Math.pow(this, rhs) */ default ColumnVector pow(BinaryOperable rhs) { - return pow(rhs, implicitConversion(this, rhs)); + return pow(rhs, implicitConversion(BinaryOp.POW, this, rhs)); } /** @@ -338,7 +340,7 @@ default ColumnVector bitAnd(BinaryOperable rhs, DType outType) { * Bit wise and (&). this & rhs */ default ColumnVector bitAnd(BinaryOperable rhs) { - return bitAnd(rhs, implicitConversion(this, rhs)); + return bitAnd(rhs, implicitConversion(BinaryOp.BITWISE_AND, this, rhs)); } /** @@ -352,7 +354,7 @@ default ColumnVector bitOr(BinaryOperable rhs, DType outType) { * Bit wise or (|). 
this | rhs */ default ColumnVector bitOr(BinaryOperable rhs) { - return bitOr(rhs, implicitConversion(this, rhs)); + return bitOr(rhs, implicitConversion(BinaryOp.BITWISE_OR, this, rhs)); } /** @@ -366,7 +368,7 @@ default ColumnVector bitXor(BinaryOperable rhs, DType outType) { * Bit wise xor (^). this ^ rhs */ default ColumnVector bitXor(BinaryOperable rhs) { - return bitXor(rhs, implicitConversion(this, rhs)); + return bitXor(rhs, implicitConversion(BinaryOp.BITWISE_XOR, this, rhs)); } /** @@ -380,7 +382,7 @@ default ColumnVector and(BinaryOperable rhs, DType outType) { * Logical and (&&). this && rhs */ default ColumnVector and(BinaryOperable rhs) { - return and(rhs, implicitConversion(this, rhs)); + return and(rhs, implicitConversion(BinaryOp.LOGICAL_AND, this, rhs)); } /** @@ -394,7 +396,7 @@ default ColumnVector or(BinaryOperable rhs, DType outType) { * Logical or (||). this || rhs */ default ColumnVector or(BinaryOperable rhs) { - return or(rhs, implicitConversion(this, rhs)); + return or(rhs, implicitConversion(BinaryOp.LOGICAL_OR, this, rhs)); } /** @@ -421,7 +423,7 @@ default ColumnVector shiftLeft(BinaryOperable shiftBy, DType outType) { * with this[i] << shiftBy */ default ColumnVector shiftLeft(BinaryOperable shiftBy) { - return shiftLeft(shiftBy, implicitConversion(this, shiftBy)); + return shiftLeft(shiftBy, implicitConversion(BinaryOp.SHIFT_LEFT, this, shiftBy)); } /** @@ -447,7 +449,7 @@ default ColumnVector shiftRight(BinaryOperable shiftBy, DType outType) { * with this[i] >> shiftBy */ default ColumnVector shiftRight(BinaryOperable shiftBy) { - return shiftRight(shiftBy, implicitConversion(this, shiftBy)); + return shiftRight(shiftBy, implicitConversion(BinaryOp.SHIFT_RIGHT, this, shiftBy)); } /** @@ -475,7 +477,8 @@ default ColumnVector shiftRightUnsigned(BinaryOperable shiftBy, DType outType) { * with this[i] >>> shiftBy */ default ColumnVector shiftRightUnsigned(BinaryOperable shiftBy) { - return shiftRightUnsigned(shiftBy, 
implicitConversion(this, shiftBy)); + return shiftRightUnsigned(shiftBy, implicitConversion(BinaryOp.SHIFT_RIGHT_UNSIGNED, this, + shiftBy)); } /** @@ -505,7 +508,7 @@ default ColumnVector arctan2(BinaryOperable xCoordinate, DType outType) { * in radians, between the positive x axis and the ray to the point (x, y) ≠ (0, 0). */ default ColumnVector arctan2(BinaryOperable xCoordinate) { - return arctan2(xCoordinate, implicitConversion(this, xCoordinate)); + return arctan2(xCoordinate, implicitConversion(BinaryOp.ATAN2, this, xCoordinate)); } /** @@ -529,7 +532,7 @@ default ColumnVector pmod(BinaryOperable rhs, DType outputType) { * */ default ColumnVector pmod(BinaryOperable rhs) { - return pmod(rhs, implicitConversion(this, rhs)); + return pmod(rhs, implicitConversion(BinaryOp.PMOD, this, rhs)); } /** @@ -557,7 +560,7 @@ default ColumnVector maxNullAware(BinaryOperable rhs, DType outType) { * Returns the max non null value. */ default ColumnVector maxNullAware(BinaryOperable rhs) { - return maxNullAware(rhs, implicitConversion(this, rhs)); + return maxNullAware(rhs, implicitConversion(BinaryOp.NULL_MAX, this, rhs)); } /** @@ -571,6 +574,7 @@ default ColumnVector minNullAware(BinaryOperable rhs, DType outType) { * Returns the min non null value. 
*/ default ColumnVector minNullAware(BinaryOperable rhs) { - return minNullAware(rhs, implicitConversion(this, rhs)); + return minNullAware(rhs, implicitConversion(BinaryOp.NULL_MIN, this, rhs)); } + } diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 331c5b08764..284eabc6335 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -71,6 +71,13 @@ public final long getNativeView() { return viewHandle; } + public static int getFixedPointOutpuScale(BinaryOp op, DType lhsType, DType rhsType) { + assert (lhsType.isDecimalType() && rhsType.isDecimalType()); + return fixedPointOutputScale(op.nativeId, lhsType.getScale(), rhsType.getScale()); + } + + private static native int fixedPointOutputScale(int op, int lhsScale, int rhsScale); + public final DType getType() { return type; } diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index a0613f9b73f..a58474198e2 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -60,6 +60,7 @@ #include #include #include +#include "cudf/types.hpp" #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" @@ -1026,6 +1027,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVV(JNIEnv *env, j CATCH_STD(env, 0); } +JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnView_fixedPointOutputScale(JNIEnv *env, jclass, jint int_op, + jint lhs_scale, jint rhs_scale) { + try { + // we just return the scale as the types will be the same as the lhs input + return cudf::binary_operation_fixed_point_scale(static_cast<cudf::binary_operator>(int_op), lhs_scale, rhs_scale); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVS(JNIEnv *env, jclass, jlong lhs_view, jlong rhs_ptr, jint int_op, jint out_dtype, From 79dcf6140aa1d92bc160eca5f975a0c789f64611 Mon Sep 17 00:00:00 2001 From: Raza
Jafri Date: Mon, 15 Mar 2021 11:39:03 -0700 Subject: [PATCH 2/4] Addressed review comments --- java/src/main/java/ai/rapids/cudf/BinaryOperable.java | 4 ++-- java/src/main/java/ai/rapids/cudf/ColumnView.java | 2 +- java/src/main/native/src/ColumnViewJni.cpp | 11 +++++++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/BinaryOperable.java b/java/src/main/java/ai/rapids/cudf/BinaryOperable.java index c598575c470..68213c21956 100644 --- a/java/src/main/java/ai/rapids/cudf/BinaryOperable.java +++ b/java/src/main/java/ai/rapids/cudf/BinaryOperable.java @@ -87,14 +87,14 @@ static DType implicitConversion(BinaryOp op, BinaryOperable lhs, BinaryOperable if (a.typeId == DType.DTypeEnum.DECIMAL32) { if (b.typeId == DType.DTypeEnum.DECIMAL32) { return DType.create(DType.DTypeEnum.DECIMAL32, - ColumnView.getFixedPointOutpuScale(op, lhs.getType(), rhs.getType())); + ColumnView.getFixedPointOutputScale(op, lhs.getType(), rhs.getType())); } else { throw new IllegalArgumentException("Both columns must be of the same fixed_point type"); } } else if (a.typeId == DType.DTypeEnum.DECIMAL64) { if (b.typeId == DType.DTypeEnum.DECIMAL64) { return DType.create(DType.DTypeEnum.DECIMAL64, - ColumnView.getFixedPointOutpuScale(op, lhs.getType(), rhs.getType())); + ColumnView.getFixedPointOutputScale(op, lhs.getType(), rhs.getType())); } else { throw new IllegalArgumentException("Both columns must be of the same fixed_point type"); } diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 034540c5d0a..ccef6280aa3 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -129,7 +129,7 @@ public final long getNativeView() { return viewHandle; } - public static int getFixedPointOutpuScale(BinaryOp op, DType lhsType, DType rhsType) { + protected static int getFixedPointOutputScale(BinaryOp op, DType lhsType, DType
rhsType) { assert (lhsType.isDecimalType() && rhsType.isDecimalType()); return fixedPointOutputScale(op.nativeId, lhsType.getScale(), rhsType.getScale()); } diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index aa33bcea83e..0ce9d6303e4 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -1027,12 +1027,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVV(JNIEnv *env, j CATCH_STD(env, 0); } -JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnView_fixedPointOutputScale(JNIEnv *env, jclass, jint int_op, - jint lhs_scale, jint rhs_scale) { +JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnView_fixedPointOutputScale(JNIEnv *env, jclass, + jint int_op, + jint lhs_scale, + jint rhs_scale) { try { // we just return the scale as the types will be the same as the lhs input - return cudf::binary_operation_fixed_point_scale(static_cast(int_op), lhs_scale, rhs_scale); - } + return cudf::binary_operation_fixed_point_scale(static_cast(int_op), + lhs_scale, rhs_scale); + } CATCH_STD(env, 0); } From a734b4e8d31b806d6b6b7992dd7f4650180d1f09 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Mon, 15 Mar 2021 11:50:23 -0700 Subject: [PATCH 3/4] cmake change needed for building within compose --- java/src/main/native/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/java/src/main/native/CMakeLists.txt b/java/src/main/native/CMakeLists.txt index c1239fe69ea..cadf584fe90 100755 --- a/java/src/main/native/CMakeLists.txt +++ b/java/src/main/native/CMakeLists.txt @@ -54,7 +54,9 @@ message(VERBOSE "CUDF_JNI: Statically link the CUDA runtime: ${CUDA_STATIC_RUNTI message(VERBOSE "CUDF_JNI: Build with GPUDirect Storage support: ${USE_GDS}") set(CUDF_SOURCE_DIR "${PROJECT_SOURCE_DIR}/../../../../cpp") -set(CUDF_CPP_BUILD_DIR "${CUDF_SOURCE_DIR}/build") +if (NOT CUDF_CPP_BUILD_DIR) + set(CUDF_CPP_BUILD_DIR "${CUDF_SOURCE_DIR}/build") +endif() 
set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules/" From 704973a26c71605cacf184f3f82b60e037de9469 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Mon, 15 Mar 2021 14:52:13 -0700 Subject: [PATCH 4/4] addressed review comments --- java/src/main/java/ai/rapids/cudf/ColumnView.java | 2 +- java/src/main/native/CMakeLists.txt | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index ccef6280aa3..2f3f2bf80cf 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -129,7 +129,7 @@ public final long getNativeView() { return viewHandle; } - protected static int getFixedPointOutputScale(BinaryOp op, DType lhsType, DType rhsType) { + static int getFixedPointOutputScale(BinaryOp op, DType lhsType, DType rhsType) { assert (lhsType.isDecimalType() && rhsType.isDecimalType()); return fixedPointOutputScale(op.nativeId, lhsType.getScale(), rhsType.getScale()); } diff --git a/java/src/main/native/CMakeLists.txt b/java/src/main/native/CMakeLists.txt index cadf584fe90..c1239fe69ea 100755 --- a/java/src/main/native/CMakeLists.txt +++ b/java/src/main/native/CMakeLists.txt @@ -54,9 +54,7 @@ message(VERBOSE "CUDF_JNI: Statically link the CUDA runtime: ${CUDA_STATIC_RUNTI message(VERBOSE "CUDF_JNI: Build with GPUDirect Storage support: ${USE_GDS}") set(CUDF_SOURCE_DIR "${PROJECT_SOURCE_DIR}/../../../../cpp") -if (NOT CUDF_CPP_BUILD_DIR) - set(CUDF_CPP_BUILD_DIR "${CUDF_SOURCE_DIR}/build") -endif() +set(CUDF_CPP_BUILD_DIR "${CUDF_SOURCE_DIR}/build") set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules/"