From 4b4d96265110ff1c70e2ddff15c19e2b918e8c29 Mon Sep 17 00:00:00 2001 From: Niranjan Artal <50492963+nartal1@users.noreply.github.com> Date: Tue, 27 Oct 2020 07:49:37 -0700 Subject: [PATCH] Initial work for decimal type in Java/JNI [skip ci] (#6514) This partially addresses tasks from issue https://github.com/rapidsai/cudf/issues/6515. The goal of this PR is to add DType class and make the necessary changes in Java/JNI code to support scale in DType. --- CHANGELOG.md | 1 + .../java/ai/rapids/cudf/ColumnVector.java | 52 +- java/src/main/java/ai/rapids/cudf/DType.java | 468 ++++++++++++++---- .../java/ai/rapids/cudf/HostColumnVector.java | 80 +-- .../ai/rapids/cudf/HostColumnVectorCore.java | 38 +- .../ai/rapids/cudf/JCudfSerialization.java | 17 +- java/src/main/java/ai/rapids/cudf/Scalar.java | 54 +- java/src/main/java/ai/rapids/cudf/Schema.java | 2 +- java/src/main/java/ai/rapids/cudf/Table.java | 10 +- java/src/main/native/src/ColumnVectorJni.cpp | 83 ++-- java/src/main/native/src/ScalarJni.cpp | 38 +- java/src/main/native/src/dtype_utils.hpp | 66 +++ .../java/ai/rapids/cudf/ColumnVectorTest.java | 42 +- .../ai/rapids/cudf/HostMemoryBufferTest.java | 14 +- .../java/ai/rapids/cudf/ReductionTest.java | 2 +- .../test/java/ai/rapids/cudf/ScalarTest.java | 28 +- .../test/java/ai/rapids/cudf/TableTest.java | 2 +- 17 files changed, 710 insertions(+), 287 deletions(-) create mode 100644 java/src/main/native/src/dtype_utils.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 69e668499f1..2c6a9cfa066 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ - PR #6573 Create `cudf::detail::byte_cast` for `cudf::byte_cast` - PR #6597 Use thread-local to track CUDA device in JNI - PR #6599 Replace `size()==0` with `empty()`, `is_empty()` +- PR #6514 Initial work for decimal type in Java/JNI ## Bug Fixes diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 84a231ff883..4c5739b5f3b 100644 --- 
a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -93,7 +93,7 @@ public long getNullCount() { @Override public DType getDataType() { - return DType.fromNative(getNativeTypeId(viewHandle)); + return DType.fromNative(getNativeTypeId(viewHandle), getNativeTypeScale(viewHandle)); } @Override @@ -141,9 +141,8 @@ public ColumnVector(long nativePointer) { assert nativePointer != 0; offHeap = new OffHeapState(nativePointer); MemoryCleaner.register(this, offHeap); - this.type = offHeap.getNativeType(); this.rows = offHeap.getNativeRowCount(); - + this.type = offHeap.getNativeType(); this.refCount = 0; incRefCountInternal(true); } @@ -196,7 +195,7 @@ public ColumnVector(DType type, long rows, Optional nullCount, childHandles[i] = nestedColumnVectors.get(i).getViewHandle(); } offHeap = new OffHeapState(type, (int) rows, nullCount, dataBuffer, validityBuffer, offsetBuffer, - toClose, childHandles); + toClose, childHandles); MemoryCleaner.register(this, offHeap); this.rows = rows; this.nullCount = nullCount; @@ -221,6 +220,7 @@ private ColumnVector(long viewAddress, DeviceMemoryBuffer contiguousBuffer) { this.rows = offHeap.getNativeRowCount(); // TODO we may want to ask for the null count anyways... 
this.nullCount = Optional.empty(); + this.refCount = 0; incRefCountInternal(true); } @@ -1339,12 +1339,12 @@ public ColumnVector binaryOp(BinaryOp op, BinaryOperable rhs, DType outType) { static long binaryOp(ColumnVector lhs, ColumnVector rhs, BinaryOp op, DType outputType) { return binaryOpVV(lhs.getNativeView(), rhs.getNativeView(), - op.nativeId, outputType.nativeId); + op.nativeId, outputType.typeId.getNativeId(), outputType.getScale()); } static long binaryOp(ColumnVector lhs, Scalar rhs, BinaryOp op, DType outputType) { return binaryOpVS(lhs.getNativeView(), rhs.getScalarHandle(), - op.nativeId, outputType.nativeId); + op.nativeId, outputType.typeId.getNativeId(), outputType.getScale()); } ///////////////////////////////////////////////////////////////////////////// @@ -1565,7 +1565,7 @@ public Scalar reduce(Aggregation aggregation) { public Scalar reduce(Aggregation aggregation, DType outType) { long nativeId = aggregation.createNativeInstance(); try { - return new Scalar(outType, reduce(getNativeView(), nativeId, outType.nativeId)); + return new Scalar(outType, reduce(getNativeView(), nativeId, outType.typeId.getNativeId(), outType.getScale())); } finally { Aggregation.close(nativeId); } @@ -1698,7 +1698,7 @@ public ColumnVector castTo(DType type) { // Optimization return incRefCount(); } - return new ColumnVector(castTo(getNativeView(), type.nativeId)); + return new ColumnVector(castTo(getNativeView(), type.typeId.getNativeId(), type.getScale())); } /** @@ -1981,8 +1981,9 @@ public ColumnVector asTimestamp(DType timestampType, String format) { "is required for .to_timestamp() operation"; assert format != null : "Format string may not be NULL"; assert timestampType.isTimestamp() : "unsupported conversion to non-timestamp DType"; + // Only nativeID is passed in the below function as timestamp type does not have `scale`. 
return new ColumnVector(stringTimestampToTimestamp(getNativeView(), - timestampType.nativeId, format)); + timestampType.typeId.getNativeId(), format)); } /** @@ -1999,7 +2000,7 @@ public ColumnVector asTimestamp(DType timestampType, String format) { * @return A new vector allocated on the GPU. */ public ColumnVector asStrings() { - switch(type) { + switch(type.typeId) { case TIMESTAMP_SECONDS: return asStrings("%Y-%m-%d %H:%M:%S"); case TIMESTAMP_DAYS: @@ -2878,15 +2879,15 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long pad(long nativeHandle, int width, int side, String fillChar); - private static native long binaryOpVS(long lhs, long rhs, int op, int dtype); + private static native long binaryOpVS(long lhs, long rhs, int op, int dtype, int scale); - private static native long binaryOpVV(long lhs, long rhs, int op, int dtype); + private static native long binaryOpVV(long lhs, long rhs, int op, int dtype, int scale); private static native long byteCount(long viewHandle) throws CudfException; private static native long extractListElement(long nativeView, int index); - private static native long castTo(long nativeHandle, int type); + private static native long castTo(long nativeHandle, int type, int scale); private static native long byteListCast(long nativeHandle, boolean config); @@ -2944,7 +2945,7 @@ private static native long rollingWindow( private static native long ifElseSS(long predVec, long trueScalar, long falseScalar) throws CudfException; - private static native long reduce(long viewHandle, long aggregation, int dtype) throws CudfException; + private static native long reduce(long viewHandle, long aggregation, int dtype, int scale) throws CudfException; private static native long isNullNative(long viewHandle); @@ -3029,6 +3030,8 @@ private static native long bitwiseMergeAndSetValidity(long baseHandle, long[] vi private static native int getNativeTypeId(long viewHandle) throws CudfException; + 
private static native int getNativeTypeScale(long viewHandle) throws CudfException; + private static native int getNativeRowCount(long viewHandle) throws CudfException; private static native int getNativeNullCount(long viewHandle) throws CudfException; @@ -3044,9 +3047,10 @@ private static native long bitwiseMergeAndSetValidity(long baseHandle, long[] vi private static native long getNativeValidityAddress(long viewHandle) throws CudfException; private static native long getNativeValidityLength(long viewHandle) throws CudfException; - private static native long makeCudfColumnView(int type, long data, long dataSize, long offsets, + private static native long makeCudfColumnView(int type, int scale, long data, long dataSize, long offsets, long valid, int nullCount, int size, long[] childHandle); + private static native long getChildCvPointer(long viewHandle, int childIndex) throws CudfException; private static native int getNativeNumChildren(long viewHandle) throws CudfException; @@ -3076,7 +3080,7 @@ private static native long makeCudfColumnView(int type, long data, long dataSize */ private static native long getNativeColumnView(long cudfColumnHandle) throws CudfException; - private static native long makeEmptyCudfColumn(int type); + private static native long makeEmptyCudfColumn(int type, int scale); private static DeviceMemoryBufferView getDataBuffer(long viewHandle) { long address = getNativeDataAddress(viewHandle); @@ -3219,13 +3223,13 @@ public OffHeapState(DType type, int rows, Optional nullCount, toClose.addAll(buffers); } if (rows == 0) { - this.columnHandle = makeEmptyCudfColumn(type.nativeId); + this.columnHandle = makeEmptyCudfColumn(type.typeId.getNativeId(), type.getScale()); } else { long cd = data == null ? 0 : data.address; long cdSize = data == null ? 0 : data.length; long od = offsets == null ? 0 : offsets.address; long vd = valid == null ? 
0 : valid.address; - this.viewHandle = makeCudfColumnView(type.nativeId, cd, cdSize, od, vd, nc, rows, childColumnViewHandles) ; + this.viewHandle = makeCudfColumnView(type.typeId.getNativeId(), type.getScale(), cd, cdSize, od, vd, nc, rows, childColumnViewHandles) ; } } @@ -3270,7 +3274,11 @@ private void setNativeNullCount(int nullCount) throws CudfException { } public DType getNativeType() { - return DType.fromNative(getNativeTypeId(getViewHandle())); + return DType.fromNative(getNativeTypeId(getViewHandle()), getNativeTypeScale(getViewHandle())); + } + + public int getNativeScale() { + return getNativeTypeScale(getViewHandle()); } public BaseDeviceMemoryBuffer getData() { @@ -3499,8 +3507,8 @@ private long getViewHandle() { long offsetAddr = offsets == null ? 0 : offsets.address; long validAddr = valid == null ? 0 : valid.address; int nc = nullCount.orElse(OffHeapState.UNKNOWN_NULL_COUNT).intValue(); - return makeCudfColumnView(dataType.nativeId, dataAddr, dataLen, offsetAddr, validAddr, nc, - (int)rows, childrenColViews); + return makeCudfColumnView(dataType.typeId.getNativeId(), dataType.getScale() , dataAddr, dataLen, + offsetAddr, validAddr, nc, (int)rows, childrenColViews); } private List getBuffersToClose() { @@ -3534,7 +3542,7 @@ private static NestedColumnVector createNestedColumnVector(DType type, long rows DeviceMemoryBuffer valid = null; DeviceMemoryBuffer offsets = null; if (dataBuffer != null) { - long dataLen = rows * type.sizeInBytes; + long dataLen = rows * type.getSizeInBytes(); if (type == DType.STRING) { // This needs a different type dataLen = getEndStringOffset(rows, rows - 1, offsetBuffer); diff --git a/java/src/main/java/ai/rapids/cudf/DType.java b/java/src/main/java/ai/rapids/cudf/DType.java index 776c835c7f4..9d32a7c40ec 100644 --- a/java/src/main/java/ai/rapids/cudf/DType.java +++ b/java/src/main/java/ai/rapids/cudf/DType.java @@ -15,78 +15,302 @@ */ package ai.rapids.cudf; +import java.math.BigDecimal; import java.util.EnumSet; 
+import java.util.Objects; + +public final class DType { + + public static final int DECIMAL32_MAX_PRECISION = 10; + public static final int DECIMAL64_MAX_PRECISION = 19; + + /* enum representing various types. Whenever a new non-decimal type is added please make sure + below sections are updated as well: + 1. Create a singleton object of the new type. + 2. Update SINGLETON_DTYPE_LOOKUP to reflect new type. The order should be maintained between + DTypeEnum and SINGLETON_DTYPE_LOOKUP */ + public enum DTypeEnum { + EMPTY(0, 0, "NOT SUPPORTED"), + INT8(1, 1, "byte"), + INT16(2, 2, "short"), + INT32(4, 3, "int"), + INT64(8, 4, "long"), + UINT8(1, 5, "uint8"), + UINT16(2, 6, "uint16"), + UINT32(4, 7, "uint32"), + UINT64(8, 8, "uint64"), + FLOAT32(4, 9, "float"), + FLOAT64(8, 10, "double"), + /** + * Byte wise true non-0/false 0. In general true will be 1. + */ + BOOL8(1, 11, "bool"), + /** + * Days since the UNIX epoch + */ + TIMESTAMP_DAYS(4, 12, "date32"), + /** + * s since the UNIX epoch + */ + TIMESTAMP_SECONDS(8, 13, "timestamp[s]"), + /** + * ms since the UNIX epoch + */ + TIMESTAMP_MILLISECONDS(8, 14, "timestamp[ms]"), + /** + * microseconds since the UNIX epoch + */ + TIMESTAMP_MICROSECONDS(8, 15, "timestamp[us]"), + /** + * ns since the UNIX epoch + */ + TIMESTAMP_NANOSECONDS(8, 16, "timestamp[ns]"), + + //We currently don't have mappings for duration type to I/O files, and these + //simpleNames might change in future when we do + DURATION_DAYS(4, 17, "int32"), + DURATION_SECONDS(8, 18, "int64"), + DURATION_MILLISECONDS(8, 19, "int64"), + DURATION_MICROSECONDS(8, 20, "int64"), + DURATION_NANOSECONDS(8, 21, "int64"), + //DICTIONARY32(4, 22, "NO IDEA"), + + STRING(0, 23, "str"), + LIST(0, 24, "list"), + DECIMAL32(4, 25, "decimal32"), + DECIMAL64(8, 26, "decimal64"), + STRUCT(0, 27, "struct"); + + final int sizeInBytes; + final int nativeId; + final String simpleName; + + DTypeEnum(int sizeInBytes, int nativeId, String simpleName) { + this.sizeInBytes = 
sizeInBytes; + this.nativeId = nativeId; + this.simpleName = simpleName; + } + + public int getNativeId() { return nativeId; } + } + + final DTypeEnum typeId; + private final int scale; + + private DType(DTypeEnum id) { + typeId = id; + scale = 0; + } -public enum DType { - EMPTY(0, 0, "NOT SUPPORTED"), - INT8(1, 1, "byte"), - INT16(2, 2, "short"), - INT32(4, 3, "int"), - INT64(8, 4, "long"), - UINT8(1, 5, "uint8"), - UINT16(2, 6, "uint16"), - UINT32(4, 7, "uint32"), - UINT64(8, 8, "uint64"), - FLOAT32(4, 9, "float"), - FLOAT64(8, 10, "double"), - /** - * Byte wise true non-0/false 0. In general true will be 1. - */ - BOOL8(1, 11, "bool"), /** - * Days since the UNIX epoch + * Constructor for Decimal Type + * @param id Enum representing data type. + * @param decimalScale Scale of fixed point decimal type */ - TIMESTAMP_DAYS(4, 12, "date32"), + private DType(DTypeEnum id, int decimalScale) { + typeId = id; + scale = decimalScale; + } + + public static final DType EMPTY = new DType(DTypeEnum.EMPTY); + public static final DType INT8 = new DType(DTypeEnum.INT8); + public static final DType INT16 = new DType(DTypeEnum.INT16); + public static final DType INT32 = new DType(DTypeEnum.INT32); + public static final DType INT64 = new DType(DTypeEnum.INT64); + public static final DType UINT8 = new DType(DTypeEnum.UINT8); + public static final DType UINT16 = new DType(DTypeEnum.UINT16); + public static final DType UINT32 = new DType(DTypeEnum.UINT32); + public static final DType UINT64 = new DType(DTypeEnum.UINT64); + public static final DType FLOAT32 = new DType(DTypeEnum.FLOAT32); + public static final DType FLOAT64 = new DType(DTypeEnum.FLOAT64); + public static final DType BOOL8 = new DType(DTypeEnum.BOOL8); + public static final DType TIMESTAMP_DAYS = new DType(DTypeEnum.TIMESTAMP_DAYS); + public static final DType TIMESTAMP_SECONDS = new DType(DTypeEnum.TIMESTAMP_SECONDS); + public static final DType TIMESTAMP_MILLISECONDS = new DType(DTypeEnum.TIMESTAMP_MILLISECONDS); + 
public static final DType TIMESTAMP_MICROSECONDS = new DType(DTypeEnum.TIMESTAMP_MICROSECONDS); + public static final DType TIMESTAMP_NANOSECONDS = new DType(DTypeEnum.TIMESTAMP_NANOSECONDS); + public static final DType DURATION_DAYS = new DType(DTypeEnum.DURATION_DAYS); + public static final DType DURATION_SECONDS = new DType(DTypeEnum.DURATION_SECONDS); + public static final DType DURATION_MILLISECONDS = new DType(DTypeEnum.DURATION_MILLISECONDS); + public static final DType DURATION_MICROSECONDS = new DType(DTypeEnum.DURATION_MICROSECONDS); + public static final DType DURATION_NANOSECONDS = new DType(DTypeEnum.DURATION_NANOSECONDS); + public static final DType STRING = new DType(DTypeEnum.STRING); + public static final DType LIST = new DType(DTypeEnum.LIST); + public static final DType STRUCT = new DType(DTypeEnum.STRUCT); + + /* This is used in fromNative method to return singleton object for non-decimal types. + Please make sure the order here is same as that of DTypeEnum. Whenever a new non-decimal + type is added in DTypeEnum, this array needs to be updated as well.*/ + private static final DType[] SINGLETON_DTYPE_LOOKUP = new DType[]{ + EMPTY, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + FLOAT32, + FLOAT64, + BOOL8, + TIMESTAMP_DAYS, + TIMESTAMP_SECONDS, + TIMESTAMP_MILLISECONDS, + TIMESTAMP_MICROSECONDS, + TIMESTAMP_NANOSECONDS, + DURATION_DAYS, + DURATION_SECONDS, + DURATION_MILLISECONDS, + DURATION_MICROSECONDS, + DURATION_NANOSECONDS, + null, // DICTIONARY32 + STRING, + LIST, + null, // DECIMAL32 + null, // DECIMAL64 + STRUCT + }; + /** - * s since the UNIX epoch + * This only works for fixed width types. Variable width types like strings the value is + * undefined and should be ignored. + * + * @return size of type in bytes. */ - TIMESTAMP_SECONDS(8, 13, "timestamp[s]"), + public int getSizeInBytes() { return typeId.sizeInBytes; } + /** - * ms since the UNIX epoch + * Returns scale for Decimal Type. 
+ * @return scale base-10 exponent to multiply the unscaled value to produce the decimal value. + * Example: Consider unscaled value = 123456 + * if scale = -2, decimal value = 123456 * 10^-2 = 1234.56 + * if scale = 2, decimal value = 123456 * 10^2 = 12345600 */ - TIMESTAMP_MILLISECONDS(8, 14, "timestamp[ms]"), + public int getScale() { return scale; } + /** - * microseconds since the UNIX epoch + * Returns string name mapped to type. + * @return name corresponding to type */ - TIMESTAMP_MICROSECONDS(8, 15, "timestamp[us]"), + public String getSimpleName() { return typeId.simpleName; } + /** - * ns since the UNIX epoch + * Return enum for this DType + * @return DTypeEnum */ - TIMESTAMP_NANOSECONDS(8, 16, "timestamp[ns]"), + public DTypeEnum getTypeId() { + return typeId; + } - //We currently don't have mappings for duration type to I/O files, and these - //simpleNames might change in future when we do - DURATION_DAYS(4, 17, "int32"), - DURATION_SECONDS(8, 18, "int64"), - DURATION_MILLISECONDS(8, 19, "int64"), - DURATION_MICROSECONDS(8, 20, "int64"), - DURATION_NANOSECONDS(8, 21, "int64"), - //DICTIONARY32(4, 22, "NO IDEA"), + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DType type = (DType) o; + return scale == type.scale && typeId == type.typeId; + } - STRING(0, 23, "str"), - LIST(0, 24, "list"), - STRUCT(0, 27, "struct"); + @Override + public int hashCode() { + return Objects.hash(typeId, scale); + } - private static final DType[] D_TYPES = DType.values(); - final int sizeInBytes; - final int nativeId; - final String simpleName; + @Override + public String toString() { + if (isDecimalType()) { + return typeId + " scale:" + scale; + } else { + return String.valueOf(typeId); + } + } - DType(int sizeInBytes, int nativeId, String simpleName) { - this.sizeInBytes = sizeInBytes; - this.nativeId = nativeId; - this.simpleName = simpleName; + /** + * Factory method for 
non-decimal DType instances. + * @param dt enum corresponding to datatype. + * @return DType + */ + public static DType create(DTypeEnum dt) { + if (dt == DTypeEnum.DECIMAL32 || dt == DTypeEnum.DECIMAL64) { + throw new IllegalArgumentException("Could not create a Decimal DType without scale"); + } + return DType.fromNative(dt.nativeId, 0); } - public boolean isTimestamp() { - return TIMESTAMPS.contains(this); + /** + * Factory method specialized for decimal DType instances. + * @param dt enum corresponding to datatype. + * @param scale base-10 exponent to multiply the unscaled value to produce the decimal value. + * Example: Consider unscaled value = 123456 + * if scale = -2, decimal value = 123456 * 10^-2 = 1234.56 + * if scale = 2, decimal value = 123456 * 10^2 = 12345600 + * @return DType + */ + public static DType create(DTypeEnum dt, int scale) { + if (dt != DTypeEnum.DECIMAL32 && dt != DTypeEnum.DECIMAL64) { + throw new IllegalArgumentException("Could not create a non-Decimal DType with scale"); + } + return DType.fromNative(dt.nativeId, scale); + } + + /** + * Factory method for DType instances + * @param nativeId nativeId of DTypeEnum + * @param scale base-10 exponent to multiply the unscaled value to produce the decimal value + * Example: Consider unscaled value = 123456 + * if scale = -2, decimal value = 123456 * 10^-2 = 1234.56 + * if scale = 2, decimal value = 123456 * 10^2 = 12345600 + * @return DType + */ + public static DType fromNative(int nativeId, int scale) { + if (nativeId >=0 && nativeId < SINGLETON_DTYPE_LOOKUP.length) { + DType ret = SINGLETON_DTYPE_LOOKUP[nativeId]; + if (ret != null) { + assert ret.typeId.nativeId == nativeId : "Something went wrong and it looks like " + + "SINGLETON_DTYPE_LOOKUP is out of sync"; + return ret; + } + if (nativeId == DTypeEnum.DECIMAL32.nativeId) { + if (-scale > DECIMAL32_MAX_PRECISION) { + throw new IllegalArgumentException( + "Scale " + (-scale) + " exceeds DECIMAL32_MAX_PRECISION " + 
DECIMAL32_MAX_PRECISION); + } + return new DType(DTypeEnum.DECIMAL32, scale); + } + if (nativeId == DTypeEnum.DECIMAL64.nativeId) { + if (-scale > DECIMAL64_MAX_PRECISION) { + throw new IllegalArgumentException( + "Scale " + (-scale) + " exceeds DECIMAL64_MAX_PRECISION " + DECIMAL64_MAX_PRECISION); + } + return new DType(DTypeEnum.DECIMAL64, scale); + } + } + throw new IllegalArgumentException("Could not translate " + nativeId + " into a DType"); + } + + /** + * Create decimal-like DType using precision and scale of Java BigDecimal. + * + * @param dec BigDecimal + * @return DType + */ + public static DType fromJavaBigDecimal(BigDecimal dec) { + // Note: the scale used by libcudf has the opposite sign from the scale of Java BigDecimal, + // so we negate the scale value before passing it into the constructor. + if (dec.precision() <= DECIMAL32_MAX_PRECISION) { + return new DType(DTypeEnum.DECIMAL32, -dec.scale()); + } else if (dec.precision() <= DECIMAL64_MAX_PRECISION) { + return new DType(DTypeEnum.DECIMAL64, -dec.scale()); + } + throw new IllegalArgumentException("Precision " + dec.precision() + + " exceeds max precision cuDF can support " + DECIMAL64_MAX_PRECISION); } /** * Returns true for timestamps with time level resolution, as opposed to day level resolution */ public boolean hasTimeResolution() { - return TIME_RESOLUTION.contains(this); + return TIME_RESOLUTION.contains(this.typeId); } /** @@ -98,7 +322,7 @@ public boolean hasTimeResolution() { * DType.TIMESTAMP_DAYS */ public boolean isBackedByInt() { - return INTS.contains(this); + return INTS.contains(this.typeId); } /** @@ -116,88 +340,116 @@ public boolean isBackedByInt() { * DType.TIMESTAMP_NANOSECONDS */ public boolean isBackedByLong() { - return LONGS.contains(this); + return LONGS.contains(this.typeId); } + /** + * Returns true if this type is backed by short type + * Namely this method will return true for the following types + * DType.INT16, + * DType.UINT16 + */ + public boolean isBackedByShort() { return 
SHORTS.contains(this.typeId); } + + /** + * Returns true if this type is backed by byte type + * Namely this method will return true for the following types + * DType.INT8, + * DType.UINT8, + * DType.BOOL8 + */ + public boolean isBackedByByte() { return BYTES.contains(this.typeId); } + + /** + * Returns true if this type is of decimal type + * Namely this method will return true for the following types + * DType.DECIMAL32, + * DType.DECIMAL64 + */ + public boolean isDecimalType() { return DECIMALS.contains(this.typeId); } + /** * Returns true for duration types */ public boolean isDurationType() { - return DURATION_TYPE.contains(this); + return DURATION_TYPE.contains(this.typeId); } /** * Returns true for nested types */ public boolean isNestedType() { - return NESTED_TYPE.contains(this); + return NESTED_TYPE.contains(this.typeId); } - public int getNativeId() { - return nativeId; + @Deprecated + public boolean isTimestamp() { + return TIMESTAMPS.contains(this.typeId); } - /** - * This only works for fixed width types. Variable width types like strings the value is - * undefined and should be ignored. 
- * @return - */ - public int getSizeInBytes() { - return sizeInBytes; + public boolean isTimestampType() { + return TIMESTAMPS.contains(this.typeId); } - public static DType fromNative(int nativeId) { - for (DType type : D_TYPES) { - if (type.nativeId == nativeId) { - return type; - } - } - throw new IllegalArgumentException("Could not translate " + nativeId + " into a DType"); - } + private static final EnumSet TIMESTAMPS = EnumSet.of( + DTypeEnum.TIMESTAMP_DAYS, + DTypeEnum.TIMESTAMP_SECONDS, + DTypeEnum.TIMESTAMP_MILLISECONDS, + DTypeEnum.TIMESTAMP_MICROSECONDS, + DTypeEnum.TIMESTAMP_NANOSECONDS); + + private static final EnumSet TIME_RESOLUTION = EnumSet.of( + DTypeEnum.TIMESTAMP_SECONDS, + DTypeEnum.TIMESTAMP_MILLISECONDS, + DTypeEnum.TIMESTAMP_MICROSECONDS, + DTypeEnum.TIMESTAMP_NANOSECONDS); + + private static final EnumSet DURATION_TYPE = EnumSet.of( + DTypeEnum.DURATION_DAYS, + DTypeEnum.DURATION_MICROSECONDS, + DTypeEnum.DURATION_MILLISECONDS, + DTypeEnum.DURATION_NANOSECONDS, + DTypeEnum.DURATION_SECONDS + ); + + private static final EnumSet LONGS = EnumSet.of( + DTypeEnum.INT64, + DTypeEnum.UINT64, + DTypeEnum.DURATION_SECONDS, + DTypeEnum.DURATION_MILLISECONDS, + DTypeEnum.DURATION_MICROSECONDS, + DTypeEnum.DURATION_NANOSECONDS, + DTypeEnum.TIMESTAMP_SECONDS, + DTypeEnum.TIMESTAMP_MILLISECONDS, + DTypeEnum.TIMESTAMP_MICROSECONDS, + DTypeEnum.TIMESTAMP_NANOSECONDS + ); + + private static final EnumSet INTS = EnumSet.of( + DTypeEnum.INT32, + DTypeEnum.UINT32, + DTypeEnum.DURATION_DAYS, + DTypeEnum.TIMESTAMP_DAYS + ); - private static final EnumSet TIMESTAMPS = EnumSet.of( - DType.TIMESTAMP_DAYS, - DType.TIMESTAMP_SECONDS, - DType.TIMESTAMP_MILLISECONDS, - DType.TIMESTAMP_MICROSECONDS, - DType.TIMESTAMP_NANOSECONDS); - - private static final EnumSet TIME_RESOLUTION = EnumSet.of( - DType.TIMESTAMP_SECONDS, - DType.TIMESTAMP_MILLISECONDS, - DType.TIMESTAMP_MICROSECONDS, - DType.TIMESTAMP_NANOSECONDS); - - private static final EnumSet DURATION_TYPE = 
EnumSet.of( - DType.DURATION_DAYS, - DType.DURATION_MICROSECONDS, - DType.DURATION_MILLISECONDS, - DType.DURATION_NANOSECONDS, - DType.DURATION_SECONDS + private static final EnumSet SHORTS = EnumSet.of( + DTypeEnum.INT16, + DTypeEnum.UINT16 ); - private static final EnumSet LONGS = EnumSet.of( - DType.INT64, - DType.UINT64, - DType.DURATION_SECONDS, - DType.DURATION_MILLISECONDS, - DType.DURATION_MICROSECONDS, - DType.DURATION_NANOSECONDS, - DType.TIMESTAMP_SECONDS, - DType.TIMESTAMP_MILLISECONDS, - DType.TIMESTAMP_MICROSECONDS, - DType.TIMESTAMP_NANOSECONDS + private static final EnumSet BYTES = EnumSet.of( + DTypeEnum.INT8, + DTypeEnum.UINT8, + DTypeEnum.BOOL8 ); - private static final EnumSet INTS = EnumSet.of( - DType.INT32, - DType.UINT32, - DType.DURATION_DAYS, - DType.TIMESTAMP_DAYS + private static final EnumSet DECIMALS = EnumSet.of( + DTypeEnum.DECIMAL32, + DTypeEnum.DECIMAL64 ); - private static final EnumSet NESTED_TYPE = EnumSet.of( - DType.LIST, - DType.STRUCT + private static final EnumSet NESTED_TYPE = EnumSet.of( + DTypeEnum.LIST, + DTypeEnum.STRUCT ); -} +} \ No newline at end of file diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index bf5c97f3c00..c9e331ced73 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -35,7 +35,7 @@ public final class HostColumnVector extends HostColumnVectorCore { /** * The size in bytes of an offset entry */ - static final int OFFSET_SIZE = DType.INT32.sizeInBytes; + static final int OFFSET_SIZE = DType.INT32.getSizeInBytes(); static { NativeDepsLoader.loadNativeDeps(); } @@ -175,7 +175,7 @@ public ColumnVector copyToDevice() { if (!type.isNestedType()) { HostMemoryBuffer hdata = this.offHeap.data; if (hdata != null) { - long dataLen = rows * type.sizeInBytes; + long dataLen = rows * type.getSizeInBytes(); if (type == DType.STRING) { // This needs a 
different type dataLen = getEndStringOffset(rows - 1); @@ -835,7 +835,7 @@ public final ColumnBuilder appendNull() { growBuffersAndRows(true, 0); setNullAt(currentIndex); currentIndex++; - currentByteIndex += type.sizeInBytes; + currentByteIndex += type.getSizeInBytes(); if (type == DType.STRING || type.isNestedType()) { offsets.setInt(currentIndex * OFFSET_SIZE, currentByteIndex); } @@ -928,9 +928,9 @@ public final ColumnBuilder append(byte value) { growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); assert type == DType.INT8 || type == DType.UINT8 || type == DType.BOOL8; assert currentIndex < rows; - data.setByte(currentIndex * type.sizeInBytes, value); + data.setByte(currentIndex * type.getSizeInBytes(), value); currentIndex++; - currentByteIndex += type.sizeInBytes; + currentByteIndex += type.getSizeInBytes(); return this; } @@ -938,9 +938,9 @@ public final ColumnBuilder append(short value) { growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); assert type == DType.INT16 || type == DType.UINT16; assert currentIndex < rows; - data.setShort(currentIndex * type.sizeInBytes, value); + data.setShort(currentIndex * type.getSizeInBytes(), value); currentIndex++; - currentByteIndex += type.sizeInBytes; + currentByteIndex += type.getSizeInBytes(); return this; } @@ -948,9 +948,9 @@ public final ColumnBuilder append(int value) { growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); assert type.isBackedByInt(); assert currentIndex < rows; - data.setInt(currentIndex * type.sizeInBytes, value); + data.setInt(currentIndex * type.getSizeInBytes(), value); currentIndex++; - currentByteIndex += type.sizeInBytes; + currentByteIndex += type.getSizeInBytes(); return this; } @@ -958,9 +958,9 @@ public final ColumnBuilder append(long value) { growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); assert type.isBackedByLong(); assert 
currentIndex < rows; - data.setLong(currentIndex * type.sizeInBytes, value); + data.setLong(currentIndex * type.getSizeInBytes(), value); currentIndex++; - currentByteIndex += type.sizeInBytes; + currentByteIndex += type.getSizeInBytes(); return this; } @@ -968,9 +968,9 @@ public final ColumnBuilder append(float value) { growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); assert type == DType.FLOAT32; assert currentIndex < rows; - data.setFloat(currentIndex * type.sizeInBytes, value); + data.setFloat(currentIndex * type.getSizeInBytes(), value); currentIndex++; - currentByteIndex += type.sizeInBytes; + currentByteIndex += type.getSizeInBytes(); return this; } @@ -978,9 +978,9 @@ public final ColumnBuilder append(double value) { growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); assert type == DType.FLOAT64; assert currentIndex < rows; - data.setDouble(currentIndex * type.sizeInBytes, value); + data.setDouble(currentIndex * type.getSizeInBytes(), value); currentIndex++; - currentByteIndex += type.sizeInBytes; + currentByteIndex += type.getSizeInBytes(); return this; } @@ -988,9 +988,9 @@ public final ColumnBuilder append(boolean value) { growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); assert type == DType.BOOL8; assert currentIndex < rows; - data.setBoolean(currentIndex * type.sizeInBytes, value); + data.setBoolean(currentIndex * type.getSizeInBytes(), value); currentIndex++; - currentByteIndex += type.sizeInBytes; + currentByteIndex += type.getSizeInBytes(); return this; } @@ -1104,7 +1104,7 @@ public static final class Builder implements AutoCloseable { // The first offset is always 0 this.offsets.setInt(0, 0); } else { - this.data = HostMemoryBuffer.allocate(rows * type.sizeInBytes); + this.data = HostMemoryBuffer.allocate(rows * type.getSizeInBytes()); } } @@ -1128,31 +1128,31 @@ public static final class Builder implements AutoCloseable { public 
final Builder append(boolean value) { assert type == DType.BOOL8; assert currentIndex < rows; - data.setByte(currentIndex * type.sizeInBytes, value ? (byte)1 : (byte)0); + data.setByte(currentIndex * type.getSizeInBytes(), value ? (byte)1 : (byte)0); currentIndex++; return this; } public final Builder append(byte value) { - assert type == DType.INT8 || type == DType.UINT8 || type == DType.BOOL8; + assert type.isBackedByByte(); assert currentIndex < rows; - data.setByte(currentIndex * type.sizeInBytes, value); + data.setByte(currentIndex * type.getSizeInBytes(), value); currentIndex++; return this; } public final Builder append(byte value, long count) { assert (count + currentIndex) <= rows; - assert type == DType.INT8 || type == DType.UINT8 || type == DType.BOOL8; - data.setMemory(currentIndex * type.sizeInBytes, count, value); + assert type.isBackedByByte(); + data.setMemory(currentIndex * type.getSizeInBytes(), count, value); currentIndex += count; return this; } public final Builder append(short value) { - assert type == DType.INT16 || type == DType.UINT16; + assert type.isBackedByShort(); assert currentIndex < rows; - data.setShort(currentIndex * type.sizeInBytes, value); + data.setShort(currentIndex * type.getSizeInBytes(), value); currentIndex++; return this; } @@ -1160,7 +1160,7 @@ public final Builder append(short value) { public final Builder append(int value) { assert type.isBackedByInt(); assert currentIndex < rows; - data.setInt(currentIndex * type.sizeInBytes, value); + data.setInt(currentIndex * type.getSizeInBytes(), value); currentIndex++; return this; } @@ -1168,7 +1168,7 @@ public final Builder append(int value) { public final Builder append(long value) { assert type.isBackedByLong(); assert currentIndex < rows; - data.setLong(currentIndex * type.sizeInBytes, value); + data.setLong(currentIndex * type.getSizeInBytes(), value); currentIndex++; return this; } @@ -1176,7 +1176,7 @@ public final Builder append(long value) { public final Builder 
append(float value) { assert type == DType.FLOAT32; assert currentIndex < rows; - data.setFloat(currentIndex * type.sizeInBytes, value); + data.setFloat(currentIndex * type.getSizeInBytes(), value); currentIndex++; return this; } @@ -1184,7 +1184,7 @@ public final Builder append(float value) { public final Builder append(double value) { assert type == DType.FLOAT64; assert currentIndex < rows; - data.setDouble(currentIndex * type.sizeInBytes, value); + data.setDouble(currentIndex * type.getSizeInBytes(), value); currentIndex++; return this; } @@ -1239,16 +1239,16 @@ public Builder appendUTF8String(byte[] value, int offset, int length) { public Builder appendArray(byte... values) { assert (values.length + currentIndex) <= rows; - assert type == DType.INT8 || type == DType.UINT8 || type == DType.BOOL8; - data.setBytes(currentIndex * type.sizeInBytes, values, 0, values.length); + assert type.isBackedByByte(); + data.setBytes(currentIndex * type.getSizeInBytes(), values, 0, values.length); currentIndex += values.length; return this; } public Builder appendArray(short... values) { - assert type == DType.INT16 || type == DType.UINT16; + assert type.isBackedByShort(); assert (values.length + currentIndex) <= rows; - data.setShorts(currentIndex * type.sizeInBytes, values, 0, values.length); + data.setShorts(currentIndex * type.getSizeInBytes(), values, 0, values.length); currentIndex += values.length; return this; } @@ -1256,7 +1256,7 @@ public Builder appendArray(short... values) { public Builder appendArray(int... values) { assert type.isBackedByInt(); assert (values.length + currentIndex) <= rows; - data.setInts(currentIndex * type.sizeInBytes, values, 0, values.length); + data.setInts(currentIndex * type.getSizeInBytes(), values, 0, values.length); currentIndex += values.length; return this; } @@ -1264,7 +1264,7 @@ public Builder appendArray(int... values) { public Builder appendArray(long... 
values) { assert type.isBackedByLong(); assert (values.length + currentIndex) <= rows; - data.setLongs(currentIndex * type.sizeInBytes, values, 0, values.length); + data.setLongs(currentIndex * type.getSizeInBytes(), values, 0, values.length); currentIndex += values.length; return this; } @@ -1272,7 +1272,7 @@ public Builder appendArray(long... values) { public Builder appendArray(float... values) { assert type == DType.FLOAT32; assert (values.length + currentIndex) <= rows; - data.setFloats(currentIndex * type.sizeInBytes, values, 0, values.length); + data.setFloats(currentIndex * type.getSizeInBytes(), values, 0, values.length); currentIndex += values.length; return this; } @@ -1280,7 +1280,7 @@ public Builder appendArray(float... values) { public Builder appendArray(double... values) { assert type == DType.FLOAT64; assert (values.length + currentIndex) <= rows; - data.setDoubles(currentIndex * type.sizeInBytes, values, 0, values.length); + data.setDoubles(currentIndex * type.getSizeInBytes(), values, 0, values.length); currentIndex += values.length; return this; } @@ -1429,15 +1429,15 @@ public final Builder appendBoxed(String... 
values) throws IndexOutOfBoundsExcept */ public final Builder append(HostColumnVector columnVector) { assert columnVector.rows <= (rows - currentIndex); - assert columnVector.type == type; + assert columnVector.type.equals(type); if (type == DType.STRING) { throw new UnsupportedOperationException( "Appending a string column vector client side is not currently supported"); } else { - data.copyFromHostBuffer(currentIndex * type.sizeInBytes, columnVector.offHeap.data, + data.copyFromHostBuffer(currentIndex * type.getSizeInBytes(), columnVector.offHeap.data, 0L, - columnVector.getRowCount() * type.sizeInBytes); + columnVector.getRowCount() * type.getSizeInBytes()); } //As this is doing the append on the host assume that a null count is available diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java b/java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java index a0e727583f7..66aee37a2bc 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java @@ -42,6 +42,7 @@ public class HostColumnVectorCore implements ColumnViewAccess protected Optional nullCount; protected List children; + public HostColumnVectorCore(DType type, long rows, Optional nullCount, HostMemoryBuffer data, HostMemoryBuffer validity, HostMemoryBuffer offsets, List nestedChildren) { @@ -273,19 +274,18 @@ private void assertsForGet(long index) { * Get the value at index. */ public byte getByte(long index) { - assert type == DType.INT8 || type == DType.UINT8 || type == DType.BOOL8 : type + - " is not stored as a byte."; + assert type.isBackedByByte() : type + " is not stored as a byte."; assertsForGet(index); - return offHeap.data.getByte(index * type.sizeInBytes); + return offHeap.data.getByte(index * type.getSizeInBytes()); } /** * Get the value at index. 
*/ public final short getShort(long index) { - assert type == DType.INT16 || type == DType.UINT16 : type + " is not stored as a short."; + assert type.isBackedByShort() : type + " is not stored as a short."; assertsForGet(index); - return offHeap.data.getShort(index * type.sizeInBytes); + return offHeap.data.getShort(index * type.getSizeInBytes()); } /** @@ -294,7 +294,7 @@ public final short getShort(long index) { public final int getInt(long index) { assert type.isBackedByInt() : type + " is not stored as a int."; assertsForGet(index); - return offHeap.data.getInt(index * type.sizeInBytes); + return offHeap.data.getInt(index * type.getSizeInBytes()); } /** @@ -309,7 +309,7 @@ long getStartStringOffset(long index) { * Get the starting element offset for the list or string at index */ long getStartListOffset(long index) { - assert type == DType.STRING || type == DType.LIST: type + + assert type.equals(DType.STRING) || type.equals(DType.LIST): type + " is not a supported string or list type."; assert (index >= 0 && index < rows) : "index is out of range 0 <= " + index + " < " + rows; return offHeap.offsets.getInt(index * 4); @@ -327,7 +327,7 @@ long getEndStringOffset(long index) { * Get the ending element offset for the list or string at index. */ long getEndListOffset(long index) { - assert type == DType.STRING || type == DType.LIST: type + + assert type.equals(DType.STRING) || type.equals(DType.LIST): type + " is not a supported string or list type."; assert (index >= 0 && index < rows) : "index is out of range 0 <= " + index + " < " + rows; // The offsets has one more entry than there are rows. @@ -341,34 +341,34 @@ public final long getLong(long index) { // Timestamps with time values are stored as longs assert type.isBackedByLong(): type + " is not stored as a long."; assertsForGet(index); - return offHeap.data.getLong(index * type.sizeInBytes); + return offHeap.data.getLong(index * type.getSizeInBytes()); } /** * Get the value at index. 
*/ public final float getFloat(long index) { - assert type == DType.FLOAT32 : type + " is not a supported float type."; + assert type.equals(DType.FLOAT32) : type + " is not a supported float type."; assertsForGet(index); - return offHeap.data.getFloat(index * type.sizeInBytes); + return offHeap.data.getFloat(index * type.getSizeInBytes()); } /** * Get the value at index. */ public final double getDouble(long index) { - assert type == DType.FLOAT64 : type + " is not a supported double type."; + assert type.equals(DType.FLOAT64) : type + " is not a supported double type."; assertsForGet(index); - return offHeap.data.getDouble(index * type.sizeInBytes); + return offHeap.data.getDouble(index * type.getSizeInBytes()); } /** * Get the boolean value at index */ public final boolean getBoolean(long index) { - assert type == DType.BOOL8 : type + " is not a supported boolean type."; + assert type.equals(DType.BOOL8) : type + " is not a supported boolean type."; assertsForGet(index); - return offHeap.data.getBoolean(index * type.sizeInBytes); + return offHeap.data.getBoolean(index * type.getSizeInBytes()); } /** @@ -376,7 +376,7 @@ public final boolean getBoolean(long index) { * ideal because it is copying the data onto the heap. */ public byte[] getUTF8(long index) { - assert type == DType.STRING : type + " is not a supported string type."; + assert type.equals(DType.STRING) : type + " is not a supported string type."; assertsForGet(index); int start = (int)getStartListOffset(index); int size = (int)getEndListOffset(index) - start; @@ -403,9 +403,9 @@ public String getJavaString(long index) { * of lists and may not have nulls. 
*/ public byte[] getBytesFromList(long rowIndex) { - assert type == DType.LIST : type + " is not a supported list of bytes type."; + assert type.equals(DType.LIST) : type + " is not a supported list of bytes type."; HostColumnVectorCore listData = children.get(0); - assert listData.type == DType.INT8 || listData.type == DType.UINT8 : type + + assert listData.type.equals(DType.INT8) || listData.type.equals(DType.UINT8) : type + " is not a supported list of bytes type."; assert !listData.hasNulls() : "byte list column with nulls are not supported"; assertsForGet(rowIndex); @@ -495,7 +495,7 @@ public boolean hasNulls() { private Object readValue(int rowIndex) { assert rowIndex < rows; int rowOffset = rowIndex * type.getSizeInBytes(); - switch (type) { + switch (type.typeId) { case INT32: // fall through case UINT32: // fall through case TIMESTAMP_DAYS: diff --git a/java/src/main/java/ai/rapids/cudf/JCudfSerialization.java b/java/src/main/java/ai/rapids/cudf/JCudfSerialization.java index 38fa0c53d23..dbc9ec9c032 100644 --- a/java/src/main/java/ai/rapids/cudf/JCudfSerialization.java +++ b/java/src/main/java/ai/rapids/cudf/JCudfSerialization.java @@ -183,7 +183,7 @@ private void readFrom(DataInputStream din) throws IOException { types = new DType[numColumns]; nullCounts = new long[numColumns]; for (int i = 0; i < numColumns; i++) { - types[i] = DType.fromNative(din.readInt()); + types[i] = DType.fromNative(din.readInt(), din.readInt()); nullCounts[i] = din.readInt(); } @@ -200,7 +200,8 @@ public void writeTo(DataWriter dout) throws IOException { // Header for each column... 
for (int i = 0; i < numColumns; i++) { - dout.writeInt(types[i].nativeId); + dout.writeInt(types[i].typeId.getNativeId()); + dout.writeInt(types[i].getScale()); dout.writeInt((int) nullCounts[i]); } dout.writeLong(dataLen); @@ -564,7 +565,7 @@ private static long getSlicedSerializedDataSizeInBytes(ColumnBufferProvider[] co totalDataSize += padFor64byteAlignment(getRawStringDataLength(column, rowOffset, numRows)); } } else { - totalDataSize += padFor64byteAlignment(column.getType().sizeInBytes * numRows); + totalDataSize += padFor64byteAlignment(column.getType().getSizeInBytes() * numRows); } } return totalDataSize; @@ -593,7 +594,7 @@ private static long getConcatedSerializedDataSizeInBytes(int numColumns, long[] } totalDataSize += padFor64byteAlignment(stringDataLen); } else { - totalDataSize += padFor64byteAlignment(types[col].sizeInBytes * numRows); + totalDataSize += padFor64byteAlignment(types[col].getSizeInBytes() * numRows); } } return totalDataSize; @@ -657,7 +658,7 @@ static ColumnOffsets[] buildIndex(SerializedTableHeader header, bufferOffset += padFor64byteAlignment(dataLen); } } else { - dataLen = type.sizeInBytes * numRows; + dataLen = type.getSizeInBytes() * numRows; data = bufferOffset; bufferOffset += padFor64byteAlignment(dataLen); } @@ -1170,8 +1171,8 @@ private static long sliceBasicData(DataWriter out, long rowOffset, long numRows) throws IOException { DType type = column.getType(); - long bytesToCopy = numRows * type.sizeInBytes; - long srcOffset = rowOffset * type.sizeInBytes; + long bytesToCopy = numRows * type.getSizeInBytes(); + long srcOffset = rowOffset * type.getSizeInBytes(); return copySlicedAndPad(out, column, BufferType.DATA, srcOffset, bytesToCopy); } @@ -1186,7 +1187,7 @@ private static void concatBasicData(DataWriter out, long currentOffset = provider.getBufferStartOffset(BufferType.DATA); int numRowsForBatch = (int) provider.getRowCount(); - int dataLeft = numRowsForBatch * type.sizeInBytes; + int dataLeft = numRowsForBatch * 
type.getSizeInBytes(); out.copyDataFrom(dataBuffer, currentOffset, dataLeft); totalCopied += dataLeft; } diff --git a/java/src/main/java/ai/rapids/cudf/Scalar.java b/java/src/main/java/ai/rapids/cudf/Scalar.java index 337a12de39d..6c9ca6a3282 100644 --- a/java/src/main/java/ai/rapids/cudf/Scalar.java +++ b/java/src/main/java/ai/rapids/cudf/Scalar.java @@ -21,6 +21,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.math.BigDecimal; +import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Objects; @@ -40,7 +42,7 @@ public final class Scalar implements AutoCloseable, BinaryOperable { private final OffHeapState offHeap; public static Scalar fromNull(DType type) { - switch (type) { + switch (type.typeId) { case EMPTY: case BOOL8: return new Scalar(type, makeBool8Scalar(false, false)); @@ -70,7 +72,7 @@ public static Scalar fromNull(DType type) { case TIMESTAMP_MILLISECONDS: case TIMESTAMP_MICROSECONDS: case TIMESTAMP_NANOSECONDS: - return new Scalar(type, makeTimestampTimeScalar(type.nativeId, 0, false)); + return new Scalar(type, makeTimestampTimeScalar(type.typeId.getNativeId(), 0, false)); case STRING: return new Scalar(type, makeStringScalar(null, false)); case DURATION_DAYS: @@ -79,7 +81,11 @@ public static Scalar fromNull(DType type) { case DURATION_MILLISECONDS: case DURATION_NANOSECONDS: case DURATION_SECONDS: - return new Scalar(type, makeDurationTimeScalar(type.nativeId, 0, false)); + return new Scalar(type, makeDurationTimeScalar(type.typeId.getNativeId(), 0, false)); + case DECIMAL32: + return new Scalar(type, makeDecimal32Scalar(0, type.getScale(), false)); + case DECIMAL64: + return new Scalar(type, makeDecimal64Scalar(0L, type.getScale(), false)); default: throw new IllegalArgumentException("Unexpected type: " + type); } @@ -227,6 +233,20 @@ public static Scalar fromDouble(Double value) { return Scalar.fromDouble(value.doubleValue()); } + public static Scalar 
fromBigDecimal(BigDecimal value) { + if (value == null) { + return Scalar.fromNull(DType.create(DType.DTypeEnum.DECIMAL64, 0)); + } + DType dt = DType.fromJavaBigDecimal(value); + long handle; + if (dt.typeId == DType.DTypeEnum.DECIMAL32) { + handle = makeDecimal32Scalar(value.unscaledValue().intValueExact(), -value.scale(), true); + } else { + handle = makeDecimal64Scalar(value.unscaledValue().longValueExact(), -value.scale(), true); + } + return new Scalar(dt, handle); + } + public static Scalar timestampDaysFromInt(int value) { return new Scalar(DType.TIMESTAMP_DAYS, makeTimestampDaysScalar(value, true)); } @@ -253,7 +273,7 @@ public static Scalar durationFromLong(DType type, long value) { } return durationDaysFromInt(intValue); } else { - return new Scalar(type, makeDurationTimeScalar(type.nativeId, value, true)); + return new Scalar(type, makeDurationTimeScalar(type.typeId.getNativeId(), value, true)); } } else { throw new IllegalArgumentException("type is not a timestamp: " + type); @@ -282,7 +302,7 @@ public static Scalar timestampFromLong(DType type, long value) { } return timestampDaysFromInt(intValue); } else { - return new Scalar(type, makeTimestampTimeScalar(type.nativeId, value, true)); + return new Scalar(type, makeTimestampTimeScalar(type.typeId.getNativeId(), value, true)); } } else { throw new IllegalArgumentException("type is not a timestamp: " + type); @@ -328,6 +348,8 @@ public static Scalar fromString(String value) { private static native long makeDurationTimeScalar(int dtype, long value, boolean isValid); private static native long makeTimestampDaysScalar(int value, boolean isValid); private static native long makeTimestampTimeScalar(int dtypeNativeId, long value, boolean isValid); + private static native long makeDecimal32Scalar(int value, int scale, boolean isValid); + private static native long makeDecimal64Scalar(long value, int scale, boolean isValid); Scalar(DType type, long scalarHandle) { @@ -426,6 +448,18 @@ public double getDouble() 
{ return getDouble(getScalarHandle()); } + /** + * Returns the scalar value as a BigDecimal. + */ + public BigDecimal getBigDecimal() { + if (this.type.typeId == DType.DTypeEnum.DECIMAL32) { + return BigDecimal.valueOf(getInt(), -type.getScale()); + } else if (this.type.typeId == DType.DTypeEnum.DECIMAL64) { + return BigDecimal.valueOf(getLong(), -type.getScale()); + } + throw new IllegalArgumentException("Couldn't getBigDecimal from nonDecimal scalar"); + } + /** * Returns the scalar value as a Java string. */ @@ -453,10 +487,10 @@ public ColumnVector binaryOp(BinaryOp op, BinaryOperable rhs, DType outType) { static long binaryOp(Scalar lhs, ColumnVector rhs, BinaryOp op, DType outputType) { return binaryOpSV(lhs.getScalarHandle(), rhs.getNativeView(), - op.nativeId, outputType.nativeId); + op.nativeId, outputType.typeId.getNativeId(), outputType.getScale()); } - private static native long binaryOpSV(long lhs, long rhs, int op, int dtype); + private static native long binaryOpSV(long lhs, long rhs, int op, int dtype, int scale); @Override public boolean equals(Object o) { @@ -467,7 +501,7 @@ public boolean equals(Object o) { boolean valid = isValid(); if (valid != other.isValid()) return false; if (!valid) return true; - switch (type) { + switch (type.typeId) { case EMPTY: return true; case BOOL8: @@ -504,7 +538,7 @@ public boolean equals(Object o) { public int hashCode() { int valueHash = 0; if (isValid()) { - switch (type) { + switch (type.typeId) { case EMPTY: valueHash = 0; break; @@ -554,7 +588,7 @@ public String toString() { sb.append(type); if (getScalarHandle() != 0) { sb.append(" value="); - switch (type) { + switch (type.typeId) { case BOOL8: sb.append(getBoolean()); break; diff --git a/java/src/main/java/ai/rapids/cudf/Schema.java b/java/src/main/java/ai/rapids/cudf/Schema.java index 4391fad91f8..f0bc3d930d9 100644 --- a/java/src/main/java/ai/rapids/cudf/Schema.java +++ b/java/src/main/java/ai/rapids/cudf/Schema.java @@ -65,7 +65,7 @@ public static 
class Builder { private final List typeNames = new ArrayList<>(); public Builder column(DType type, String name) { - typeNames.add(type.simpleName); + typeNames.add(type.getSimpleName()); names.add(name); return this; } diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index fdad5b20631..76df58cab9f 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -603,7 +603,7 @@ public static Table readParquet(File path) { */ public static Table readParquet(ParquetOptions opts, File path) { return new Table(readParquet(opts.getIncludeColumnNames(), - path.getAbsolutePath(), 0, 0, opts.timeUnit().nativeId)); + path.getAbsolutePath(), 0, 0, opts.timeUnit().typeId.getNativeId())); } /** @@ -663,7 +663,7 @@ public static Table readParquet(ParquetOptions opts, HostMemoryBuffer buffer, assert len <= buffer.getLength() - offset; assert offset >= 0 && offset < buffer.length; return new Table(readParquet(opts.getIncludeColumnNames(), - null, buffer.getAddress() + offset, len, opts.timeUnit().nativeId)); + null, buffer.getAddress() + offset, len, opts.timeUnit().typeId.getNativeId())); } /** @@ -683,7 +683,7 @@ public static Table readORC(File path) { */ public static Table readORC(ORCOptions opts, File path) { return new Table(readORC(opts.getIncludeColumnNames(), - path.getAbsolutePath(), 0, 0, opts.usingNumPyTypes(), opts.timeUnit().nativeId)); + path.getAbsolutePath(), 0, 0, opts.usingNumPyTypes(), opts.timeUnit().typeId.getNativeId())); } /** @@ -744,7 +744,7 @@ public static Table readORC(ORCOptions opts, HostMemoryBuffer buffer, assert offset >= 0 && offset < buffer.length; return new Table(readORC(opts.getIncludeColumnNames(), null, buffer.getAddress() + offset, len, opts.usingNumPyTypes(), - opts.timeUnit().nativeId)); + opts.timeUnit().typeId.getNativeId())); } private static class ParquetTableWriter implements TableWriter { @@ -2333,7 +2333,7 @@ public 
TestBuilder timestampSecondsColumn(Long... values) { private static ColumnVector from(DType type, Object dataArray) { ColumnVector ret = null; - switch (type) { + switch (type.typeId) { case STRING: ret = ColumnVector.fromStrings((String[]) dataArray); break; diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index 9b54137a30a..721640cc71f 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -56,35 +56,8 @@ #include #include "cudf_jni_apis.hpp" +#include "dtype_utils.hpp" -namespace { - -// convert a timestamp type to the corresponding duration type -cudf::data_type timestamp_to_duration(cudf::data_type dt) { - cudf::type_id duration_type_id; - switch (dt.id()) { - case cudf::type_id::TIMESTAMP_DAYS: - duration_type_id = cudf::type_id::DURATION_DAYS; - break; - case cudf::type_id::TIMESTAMP_SECONDS: - duration_type_id = cudf::type_id::DURATION_SECONDS; - break; - case cudf::type_id::TIMESTAMP_MILLISECONDS: - duration_type_id = cudf::type_id::DURATION_MILLISECONDS; - break; - case cudf::type_id::TIMESTAMP_MICROSECONDS: - duration_type_id = cudf::type_id::DURATION_MICROSECONDS; - break; - case cudf::type_id::TIMESTAMP_NANOSECONDS: - duration_type_id = cudf::type_id::DURATION_NANOSECONDS; - break; - default: - throw std::logic_error("Unexpected type in timestamp_to_duration"); - } - return cudf::data_type(duration_type_id); -} - -} // anonymous namespace extern "C" { @@ -286,14 +259,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_ifElseSS(JNIEnv *env, j JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_reduce(JNIEnv *env, jclass, jlong j_col_view, jlong j_agg, - jint j_dtype) { + jint j_dtype, jint scale) { JNI_NULL_CHECK(env, j_col_view, "column view is null", 0); JNI_NULL_CHECK(env, j_agg, "aggregation is null", 0); try { cudf::jni::auto_set_device(env); auto col = reinterpret_cast(j_col_view); auto agg = reinterpret_cast(j_agg); - 
cudf::data_type out_dtype{static_cast(j_dtype)}; + cudf::data_type out_dtype = cudf::jni::make_data_type(j_dtype, scale); + std::unique_ptr result = cudf::reduce(*col, agg->clone(), out_dtype); return reinterpret_cast(result.release()); } @@ -693,13 +667,14 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_dayOfYear(JNIEnv *env, CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_castTo(JNIEnv *env, jobject j_object, - jlong handle, jint type) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_castTo(JNIEnv *env, jclass, + jlong handle, jint type, + jint scale) { JNI_NULL_CHECK(env, handle, "native handle is null", 0); try { cudf::jni::auto_set_device(env); cudf::column_view *column = reinterpret_cast(handle); - cudf::data_type n_data_type(static_cast(type)); + cudf::data_type n_data_type = cudf::jni::make_data_type(type, scale); std::unique_ptr result; if (n_data_type.id() == cudf::type_id::STRING) { switch (column->type().id()) { @@ -756,7 +731,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_castTo(JNIEnv *env, job JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Numeric cast to non-day timestamp requires INT64", 0); } } - cudf::data_type duration_type = timestamp_to_duration(n_data_type); + cudf::data_type duration_type = cudf::jni::timestamp_to_duration(n_data_type); cudf::column_view duration_view = cudf::column_view(duration_type, column->size(), column->head(), @@ -767,7 +742,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_castTo(JNIEnv *env, job // This is a temporary workaround to allow Java to cast from timestamp types to integral types // without forcing an intermediate duration column to be manifested. 
Ultimately this style of // "reinterpret" casting will be supported via https://github.com/rapidsai/cudf/pull/5358 - cudf::data_type duration_type = timestamp_to_duration(column->type()); + cudf::data_type duration_type = cudf::jni::timestamp_to_duration(column->type()); cudf::column_view duration_view = cudf::column_view(duration_type, column->size(), column->head(), @@ -996,7 +971,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_stringConcatenation( JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_binaryOpVV(JNIEnv *env, jclass, jlong lhs_view, jlong rhs_view, - jint int_op, jint out_dtype) { + jint int_op, jint out_dtype, + jint scale) { JNI_NULL_CHECK(env, lhs_view, "lhs is null", 0); JNI_NULL_CHECK(env, rhs_view, "rhs is null", 0); try { @@ -1004,9 +980,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_binaryOpVV(JNIEnv *env, auto lhs = reinterpret_cast(lhs_view); auto rhs = reinterpret_cast(rhs_view); + cudf::data_type n_data_type = cudf::jni::make_data_type(out_dtype, scale); cudf::binary_operator op = static_cast(int_op); std::unique_ptr result = cudf::binary_operation( - *lhs, *rhs, op, cudf::data_type(static_cast(out_dtype))); + *lhs, *rhs, op, n_data_type); return reinterpret_cast(result.release()); } CATCH_STD(env, 0); @@ -1014,17 +991,19 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_binaryOpVV(JNIEnv *env, JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_binaryOpVS(JNIEnv *env, jclass, jlong lhs_view, jlong rhs_ptr, - jint int_op, jint out_dtype) { + jint int_op, jint out_dtype, + jint scale) { JNI_NULL_CHECK(env, lhs_view, "lhs is null", 0); JNI_NULL_CHECK(env, rhs_ptr, "rhs is null", 0); try { cudf::jni::auto_set_device(env); auto lhs = reinterpret_cast(lhs_view); cudf::scalar *rhs = reinterpret_cast(rhs_ptr); + cudf::data_type n_data_type = cudf::jni::make_data_type(out_dtype, scale); cudf::binary_operator op = static_cast(int_op); std::unique_ptr result = cudf::binary_operation( - *lhs, *rhs, 
op, cudf::data_type(static_cast(out_dtype))); + *lhs, *rhs, op, n_data_type); return reinterpret_cast(result.release()); } CATCH_STD(env, 0); @@ -1297,7 +1276,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_hash(JNIEnv *env, //////// JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeCudfColumnView( - JNIEnv *env, jobject j_object, jint j_type, jlong j_data, jlong j_data_size, jlong j_offset, + JNIEnv *env, jclass, jint j_type, jint scale, jlong j_data, jlong j_data_size, jlong j_offset, jlong j_valid, jint j_null_count, jint size, jlongArray j_children) { JNI_ARG_CHECK(env, (size != 0), "size is 0", 0); @@ -1305,7 +1284,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeCudfColumnView( using cudf::column_view; cudf::jni::auto_set_device(env); cudf::type_id n_type = static_cast(j_type); - cudf::data_type n_data_type(n_type); + cudf::data_type n_data_type = cudf::jni::make_data_type(j_type, scale); std::unique_ptr ret; void *data = reinterpret_cast(j_data); @@ -1364,8 +1343,20 @@ JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnVector_getNativeTypeId(JNIEnv * CATCH_STD(env, 0); } +JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnVector_getNativeTypeScale(JNIEnv *env, + jclass, + jlong handle) { + JNI_NULL_CHECK(env, handle, "native handle is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::column_view *column = reinterpret_cast(handle); + return column->type().scale(); + } + CATCH_STD(env, 0); +} + JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnVector_getNativeRowCount(JNIEnv *env, - jobject j_object, + jclass, jlong handle) { JNI_NULL_CHECK(env, handle, "native handle is null", 0); try { @@ -1619,13 +1610,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_getNativeColumnView(JNI } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeEmptyCudfColumn(JNIEnv *env, - jobject j_object, - jint j_type) { + jclass, + jint j_type, + jint scale) { try { cudf::jni::auto_set_device(env); cudf::type_id 
n_type = static_cast(j_type); - cudf::data_type n_data_type(n_type); + cudf::data_type n_data_type = cudf::jni::make_data_type(j_type, scale); + std::unique_ptr column(cudf::make_empty_column(n_data_type)); return reinterpret_cast(column.release()); } diff --git a/java/src/main/native/src/ScalarJni.cpp b/java/src/main/native/src/ScalarJni.cpp index f2242713acc..4e74cab9328 100644 --- a/java/src/main/native/src/ScalarJni.cpp +++ b/java/src/main/native/src/ScalarJni.cpp @@ -15,9 +15,11 @@ */ #include +#include #include #include "cudf_jni_apis.hpp" +#include "dtype_utils.hpp" extern "C" { @@ -393,19 +395,51 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeTimestampTimeScalar(JNIEn CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeDecimal32Scalar(JNIEnv *env, jclass, + jint value, + jint scale, + jboolean is_valid) { + try { + cudf::jni::auto_set_device(env); + auto const value_ = static_cast(value); + auto const scale_ = numeric::scale_type{static_cast(scale)}; + std::unique_ptr s = cudf::make_fixed_point_scalar(value_, scale_); + s->set_valid(is_valid); + return reinterpret_cast(s.release()); + } + CATCH_STD(env, 0); +} + + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeDecimal64Scalar(JNIEnv *env, jclass, + jlong value, + jint scale, + jboolean is_valid) { + try { + cudf::jni::auto_set_device(env); + auto const value_ = static_cast(value); + auto const scale_ = numeric::scale_type{static_cast(scale)}; + std::unique_ptr s = cudf::make_fixed_point_scalar(value_, scale_); + s->set_valid(is_valid); + return reinterpret_cast(s.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_binaryOpSV(JNIEnv *env, jclass, jlong lhs_ptr, jlong rhs_view, jint int_op, - jint out_dtype) { + jint out_dtype, jint scale) { JNI_NULL_CHECK(env, lhs_ptr, "lhs is null", 0); JNI_NULL_CHECK(env, rhs_view, "rhs is null", 0); try { cudf::jni::auto_set_device(env); cudf::scalar *lhs = reinterpret_cast(lhs_ptr); 
auto rhs = reinterpret_cast(rhs_view); + cudf::data_type n_data_type = cudf::jni::make_data_type(out_dtype, scale); cudf::binary_operator op = static_cast(int_op); std::unique_ptr result = cudf::binary_operation( - *lhs, *rhs, op, cudf::data_type(static_cast(out_dtype))); + *lhs, *rhs, op, n_data_type); return reinterpret_cast(result.release()); } CATCH_STD(env, 0); diff --git a/java/src/main/native/src/dtype_utils.hpp b/java/src/main/native/src/dtype_utils.hpp new file mode 100644 index 00000000000..bde7bd2894e --- /dev/null +++ b/java/src/main/native/src/dtype_utils.hpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include + +namespace cudf { +namespace jni { + +// convert a timestamp type to the corresponding duration type +inline cudf::data_type timestamp_to_duration(cudf::data_type dt) { + cudf::type_id duration_type_id; + switch (dt.id()) { + case cudf::type_id::TIMESTAMP_DAYS: + duration_type_id = cudf::type_id::DURATION_DAYS; + break; + case cudf::type_id::TIMESTAMP_SECONDS: + duration_type_id = cudf::type_id::DURATION_SECONDS; + break; + case cudf::type_id::TIMESTAMP_MILLISECONDS: + duration_type_id = cudf::type_id::DURATION_MILLISECONDS; + break; + case cudf::type_id::TIMESTAMP_MICROSECONDS: + duration_type_id = cudf::type_id::DURATION_MICROSECONDS; + break; + case cudf::type_id::TIMESTAMP_NANOSECONDS: + duration_type_id = cudf::type_id::DURATION_NANOSECONDS; + break; + default: + throw std::logic_error("Unexpected type in timestamp_to_duration"); + } + return cudf::data_type(duration_type_id); +} + +inline bool is_decimal_type(cudf::type_id n_type) { + return n_type == cudf::type_id::DECIMAL32 || n_type == cudf::type_id::DECIMAL64 ; +} + +// create data_type including scale for decimal type +inline cudf::data_type make_data_type(jint out_dtype, jint scale) { + cudf::type_id n_type = static_cast(out_dtype); + cudf::data_type n_data_type; + if (is_decimal_type(n_type)) { + n_data_type = cudf::data_type(n_type, scale); + } else { + n_data_type = cudf::data_type(n_type); + } + return n_data_type; +} + +} // namespace jni +} // namespace cudf \ No newline at end of file diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 82f6d48fa6b..09df2279f8f 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -703,7 +703,11 @@ void testSequenceOtherTypes() { @Test void testFromScalarZeroRows() { - for (DType type : DType.values()) { + for (DType.DTypeEnum type : DType.DTypeEnum.values()) { + 
// Decimal type not supported yet. Update this once it is supported. + if (type == DType.DTypeEnum.DECIMAL32 || type == DType.DTypeEnum.DECIMAL64) { + continue; + } Scalar s = null; try { switch (type) { @@ -747,7 +751,7 @@ void testFromScalarZeroRows() { case TIMESTAMP_MILLISECONDS: case TIMESTAMP_MICROSECONDS: case TIMESTAMP_NANOSECONDS: - s = Scalar.timestampFromLong(type, 1234567890123456789L); + s = Scalar.timestampFromLong(DType.create(type), 1234567890123456789L); break; case STRING: s = Scalar.fromString("hello, world!"); @@ -759,7 +763,7 @@ void testFromScalarZeroRows() { case DURATION_MILLISECONDS: case DURATION_MICROSECONDS: case DURATION_NANOSECONDS: - s = Scalar.durationFromLong(type, 21313); + s = Scalar.durationFromLong(DType.create(type), 21313); break; case EMPTY: case LIST: @@ -770,7 +774,7 @@ void testFromScalarZeroRows() { } try (ColumnVector c = ColumnVector.fromScalar(s, 0)) { - assertEquals(type, c.getType()); + assertEquals(DType.create(type), c.getType()); assertEquals(0, c.getRowCount()); assertEquals(0, c.getNullCount()); } @@ -793,7 +797,10 @@ void testGetNativeView() { @Test void testFromScalar() { final int rowCount = 4; - for (DType type : DType.values()) { + for (DType.DTypeEnum type : DType.DTypeEnum.values()) { + if(type == DType.DTypeEnum.DECIMAL32 || type == DType.DTypeEnum.DECIMAL64) { + continue; + } Scalar s = null; ColumnVector expected = null; ColumnVector result = null; @@ -871,25 +878,25 @@ void testFromScalar() { } case TIMESTAMP_SECONDS: { long v = 1234567890123456789L; - s = Scalar.timestampFromLong(type, v); + s = Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, v); expected = ColumnVector.timestampSecondsFromLongs(v, v, v, v); break; } case TIMESTAMP_MILLISECONDS: { long v = 1234567890123456789L; - s = Scalar.timestampFromLong(type, v); + s = Scalar.timestampFromLong(DType.TIMESTAMP_MILLISECONDS, v); expected = ColumnVector.timestampMilliSecondsFromLongs(v, v, v, v); break; } case TIMESTAMP_MICROSECONDS: { long v = 
1234567890123456789L; - s = Scalar.timestampFromLong(type, v); + s = Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, v); expected = ColumnVector.timestampMicroSecondsFromLongs(v, v, v, v); break; } case TIMESTAMP_NANOSECONDS: { long v = 1234567890123456789L; - s = Scalar.timestampFromLong(type, v); + s = Scalar.timestampFromLong(DType.TIMESTAMP_NANOSECONDS, v); expected = ColumnVector.timestampNanoSecondsFromLongs(v, v, v, v); break; } @@ -907,25 +914,25 @@ void testFromScalar() { } case DURATION_MICROSECONDS: { long v = 1123123123L; - s = Scalar.durationFromLong(type, v); + s = Scalar.durationFromLong(DType.DURATION_MICROSECONDS, v); expected = ColumnVector.durationMicroSecondsFromLongs(v, v, v, v); break; } case DURATION_MILLISECONDS: { long v = 11212432423L; - s = Scalar.durationFromLong(type, v); + s = Scalar.durationFromLong(DType.DURATION_MILLISECONDS, v); expected = ColumnVector.durationMilliSecondsFromLongs(v, v, v, v); break; } case DURATION_NANOSECONDS: { long v = 12353245343L; - s = Scalar.durationFromLong(type, v); + s = Scalar.durationFromLong(DType.DURATION_NANOSECONDS, v); expected = ColumnVector.durationNanoSecondsFromLongs(v, v, v, v); break; } case DURATION_SECONDS: { long v = 132342321123L; - s = Scalar.durationFromLong(type, v); + s = Scalar.durationFromLong(DType.DURATION_SECONDS, v); expected = ColumnVector.durationSecondsFromLongs(v, v, v, v); break; } @@ -956,14 +963,15 @@ void testFromScalar() { @Test void testFromScalarNull() { final int rowCount = 4; - for (DType type : DType.values()) { - if (type == DType.EMPTY || type == DType.LIST || type == DType.STRUCT) { + for (DType.DTypeEnum type : DType.DTypeEnum.values()) { + if (type == DType.DTypeEnum.EMPTY || type == DType.DTypeEnum.LIST || type == DType.DTypeEnum.STRUCT + || type == DType.DTypeEnum.DECIMAL32 || type == DType.DTypeEnum.DECIMAL64) { continue; } - try (Scalar s = Scalar.fromNull(type); + try (Scalar s = Scalar.fromNull(DType.create(type)); ColumnVector c = 
ColumnVector.fromScalar(s, rowCount); HostColumnVector hc = c.copyToHost()) { - assertEquals(type, c.getType()); + assertEquals(type, c.getType().typeId); assertEquals(rowCount, c.getRowCount()); assertEquals(rowCount, c.getNullCount()); for (int i = 0; i < rowCount; ++i) { diff --git a/java/src/test/java/ai/rapids/cudf/HostMemoryBufferTest.java b/java/src/test/java/ai/rapids/cudf/HostMemoryBufferTest.java index d12ab448b77..e848d4a89bf 100644 --- a/java/src/test/java/ai/rapids/cudf/HostMemoryBufferTest.java +++ b/java/src/test/java/ai/rapids/cudf/HostMemoryBufferTest.java @@ -82,8 +82,8 @@ void testDoubleFree() { public void testGetInt() { try (HostMemoryBuffer hostMemoryBuffer = HostMemoryBuffer.allocate(16)) { long offset = 1; - hostMemoryBuffer.setInt(offset * DType.INT32.sizeInBytes, 2); - assertEquals(2, hostMemoryBuffer.getInt(offset * DType.INT32.sizeInBytes)); + hostMemoryBuffer.setInt(offset * DType.INT32.getSizeInBytes(), 2); + assertEquals(2, hostMemoryBuffer.getInt(offset * DType.INT32.getSizeInBytes())); } } @@ -91,8 +91,8 @@ public void testGetInt() { public void testGetByte() { try (HostMemoryBuffer hostMemoryBuffer = HostMemoryBuffer.allocate(16)) { long offset = 1; - hostMemoryBuffer.setByte(offset * DType.INT8.sizeInBytes, (byte) 2); - assertEquals((byte) 2, hostMemoryBuffer.getByte(offset * DType.INT8.sizeInBytes)); + hostMemoryBuffer.setByte(offset * DType.INT8.getSizeInBytes(), (byte) 2); + assertEquals((byte) 2, hostMemoryBuffer.getByte(offset * DType.INT8.getSizeInBytes())); } } @@ -100,8 +100,8 @@ public void testGetByte() { public void testGetLong() { try (HostMemoryBuffer hostMemoryBuffer = HostMemoryBuffer.allocate(16)) { long offset = 1; - hostMemoryBuffer.setLong(offset * DType.INT64.sizeInBytes, 3); - assertEquals(3, hostMemoryBuffer.getLong(offset * DType.INT64.sizeInBytes)); + hostMemoryBuffer.setLong(offset * DType.INT64.getSizeInBytes(), 3); + assertEquals(3, hostMemoryBuffer.getLong(offset * DType.INT64.getSizeInBytes())); } } @@ 
-109,7 +109,7 @@ public void testGetLong() { public void testGetLongs() { try (HostMemoryBuffer hostMemoryBuffer = HostMemoryBuffer.allocate(16)) { hostMemoryBuffer.setLong(0, 3); - hostMemoryBuffer.setLong(DType.INT64.sizeInBytes, 10); + hostMemoryBuffer.setLong(DType.INT64.getSizeInBytes(), 10); long[] results = new long[2]; hostMemoryBuffer.getLongs(results, 0, 0, 2); assertEquals(3, results[0]); diff --git a/java/src/test/java/ai/rapids/cudf/ReductionTest.java b/java/src/test/java/ai/rapids/cudf/ReductionTest.java index 2641591f5c0..a71f8b3e8af 100644 --- a/java/src/test/java/ai/rapids/cudf/ReductionTest.java +++ b/java/src/test/java/ai/rapids/cudf/ReductionTest.java @@ -55,7 +55,7 @@ private static Scalar buildExpectedScalar(Aggregation op, DType baseType, Object if (BOOL_REDUCTIONS.contains(op.kind)) { return Scalar.fromBool((Boolean) expectedObject); } - switch (baseType) { + switch (baseType.typeId) { case BOOL8: return Scalar.fromBool((Boolean) expectedObject); case INT8: diff --git a/java/src/test/java/ai/rapids/cudf/ScalarTest.java b/java/src/test/java/ai/rapids/cudf/ScalarTest.java index d512168b6c0..47cad78ce5c 100644 --- a/java/src/test/java/ai/rapids/cudf/ScalarTest.java +++ b/java/src/test/java/ai/rapids/cudf/ScalarTest.java @@ -20,6 +20,9 @@ import org.junit.jupiter.api.Test; +import java.math.BigDecimal; +import java.math.MathContext; + import static org.junit.jupiter.api.Assertions.*; public class ScalarTest extends CudfTestBase { @@ -45,7 +48,13 @@ public void testIncRef() { @Test public void testNull() { - for (DType type : DType.values()) { + for (DType.DTypeEnum dataType : DType.DTypeEnum.values()) { + DType type; + if (dataType == DType.DTypeEnum.DECIMAL32 || dataType == DType.DTypeEnum.DECIMAL64) { + type = DType.create(dataType, -3); + } else { + type = DType.create(dataType); + } if (!type.isNestedType()) { try (Scalar s = Scalar.fromNull(type)) { assertEquals(type, s.getType()); @@ -118,6 +127,23 @@ public void testDouble() { } } + @Test 
+ public void testDecimal() { + BigDecimal[] bigDecimals = new BigDecimal[]{ + BigDecimal.valueOf(1234, 0), + BigDecimal.valueOf(12345678, 2), + BigDecimal.valueOf(1234567890123L, 6), + }; + for (BigDecimal bigDec: bigDecimals) { + try (Scalar s = Scalar.fromBigDecimal(bigDec)) { + assertEquals(DType.fromJavaBigDecimal(bigDec), s.getType()); + assertTrue(s.isValid()); + assertEquals(bigDec.unscaledValue().longValueExact(), s.getLong()); + assertEquals(bigDec, s.getBigDecimal()); + } + } + } + @Test public void testTimestampDays() { try (Scalar s = Scalar.timestampDaysFromInt(7)) { diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index b098d359d96..194d313f100 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -173,7 +173,7 @@ public static void assertPartialColumnsAreEqual(HostColumnVectorCore expected, l assertEquals(expected.isNull(expectedRow), cv.isNull(tableRow), "NULL for Column " + colName + " Row " + tableRow); if (!expected.isNull(expectedRow)) { - switch (type) { + switch (type.typeId) { case BOOL8: // fall through case INT8: // fall through case UINT8: