Skip to content

Commit

Permalink
Initial work for decimal type in Java/JNI [skip ci] (#6514)
Browse files Browse the repository at this point in the history
This partially addresses the tasks listed in issue #6515.
The goal of this PR is to add a DType class and make the necessary changes in the Java/JNI code to support a scale in DType.
  • Loading branch information
nartal1 authored Oct 27, 2020
1 parent 5a594e6 commit 4b4d962
Show file tree
Hide file tree
Showing 17 changed files with 710 additions and 287 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
- PR #6573 Create `cudf::detail::byte_cast` for `cudf::byte_cast`
- PR #6597 Use thread-local to track CUDA device in JNI
- PR #6599 Replace `size()==0` with `empty()`, `is_empty()`
- PR #6514 Initial work for decimal type in Java/JNI

## Bug Fixes

Expand Down
52 changes: 30 additions & 22 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public long getNullCount() {

@Override
public DType getDataType() {
return DType.fromNative(getNativeTypeId(viewHandle));
return DType.fromNative(getNativeTypeId(viewHandle), getNativeTypeScale(viewHandle));
}

@Override
Expand Down Expand Up @@ -141,9 +141,8 @@ public ColumnVector(long nativePointer) {
assert nativePointer != 0;
offHeap = new OffHeapState(nativePointer);
MemoryCleaner.register(this, offHeap);
this.type = offHeap.getNativeType();
this.rows = offHeap.getNativeRowCount();

this.type = offHeap.getNativeType();
this.refCount = 0;
incRefCountInternal(true);
}
Expand Down Expand Up @@ -196,7 +195,7 @@ public ColumnVector(DType type, long rows, Optional<Long> nullCount,
childHandles[i] = nestedColumnVectors.get(i).getViewHandle();
}
offHeap = new OffHeapState(type, (int) rows, nullCount, dataBuffer, validityBuffer, offsetBuffer,
toClose, childHandles);
toClose, childHandles);
MemoryCleaner.register(this, offHeap);
this.rows = rows;
this.nullCount = nullCount;
Expand All @@ -221,6 +220,7 @@ private ColumnVector(long viewAddress, DeviceMemoryBuffer contiguousBuffer) {
this.rows = offHeap.getNativeRowCount();
// TODO: we may want to ask for the null count anyway...
this.nullCount = Optional.empty();

this.refCount = 0;
incRefCountInternal(true);
}
Expand Down Expand Up @@ -1339,12 +1339,12 @@ public ColumnVector binaryOp(BinaryOp op, BinaryOperable rhs, DType outType) {

static long binaryOp(ColumnVector lhs, ColumnVector rhs, BinaryOp op, DType outputType) {
return binaryOpVV(lhs.getNativeView(), rhs.getNativeView(),
op.nativeId, outputType.nativeId);
op.nativeId, outputType.typeId.getNativeId(), outputType.getScale());
}

static long binaryOp(ColumnVector lhs, Scalar rhs, BinaryOp op, DType outputType) {
return binaryOpVS(lhs.getNativeView(), rhs.getScalarHandle(),
op.nativeId, outputType.nativeId);
op.nativeId, outputType.typeId.getNativeId(), outputType.getScale());
}

/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1565,7 +1565,7 @@ public Scalar reduce(Aggregation aggregation) {
public Scalar reduce(Aggregation aggregation, DType outType) {
long nativeId = aggregation.createNativeInstance();
try {
return new Scalar(outType, reduce(getNativeView(), nativeId, outType.nativeId));
return new Scalar(outType, reduce(getNativeView(), nativeId, outType.typeId.getNativeId(), outType.getScale()));
} finally {
Aggregation.close(nativeId);
}
Expand Down Expand Up @@ -1698,7 +1698,7 @@ public ColumnVector castTo(DType type) {
// Optimization
return incRefCount();
}
return new ColumnVector(castTo(getNativeView(), type.nativeId));
return new ColumnVector(castTo(getNativeView(), type.typeId.getNativeId(), type.getScale()));
}

/**
Expand Down Expand Up @@ -1981,8 +1981,9 @@ public ColumnVector asTimestamp(DType timestampType, String format) {
"is required for .to_timestamp() operation";
assert format != null : "Format string may not be NULL";
assert timestampType.isTimestamp() : "unsupported conversion to non-timestamp DType";
// Only the native type ID is passed to the function below because timestamp types do not have a `scale`.
return new ColumnVector(stringTimestampToTimestamp(getNativeView(),
timestampType.nativeId, format));
timestampType.typeId.getNativeId(), format));
}

/**
Expand All @@ -1999,7 +2000,7 @@ public ColumnVector asTimestamp(DType timestampType, String format) {
* @return A new vector allocated on the GPU.
*/
public ColumnVector asStrings() {
switch(type) {
switch(type.typeId) {
case TIMESTAMP_SECONDS:
return asStrings("%Y-%m-%d %H:%M:%S");
case TIMESTAMP_DAYS:
Expand Down Expand Up @@ -2878,15 +2879,15 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat

private static native long pad(long nativeHandle, int width, int side, String fillChar);

private static native long binaryOpVS(long lhs, long rhs, int op, int dtype);
private static native long binaryOpVS(long lhs, long rhs, int op, int dtype, int scale);

private static native long binaryOpVV(long lhs, long rhs, int op, int dtype);
private static native long binaryOpVV(long lhs, long rhs, int op, int dtype, int scale);

private static native long byteCount(long viewHandle) throws CudfException;

private static native long extractListElement(long nativeView, int index);

private static native long castTo(long nativeHandle, int type);
private static native long castTo(long nativeHandle, int type, int scale);

private static native long byteListCast(long nativeHandle, boolean config);

Expand Down Expand Up @@ -2944,7 +2945,7 @@ private static native long rollingWindow(

private static native long ifElseSS(long predVec, long trueScalar, long falseScalar) throws CudfException;

private static native long reduce(long viewHandle, long aggregation, int dtype) throws CudfException;
private static native long reduce(long viewHandle, long aggregation, int dtype, int scale) throws CudfException;

private static native long isNullNative(long viewHandle);

Expand Down Expand Up @@ -3029,6 +3030,8 @@ private static native long bitwiseMergeAndSetValidity(long baseHandle, long[] vi

private static native int getNativeTypeId(long viewHandle) throws CudfException;

private static native int getNativeTypeScale(long viewHandle) throws CudfException;

private static native int getNativeRowCount(long viewHandle) throws CudfException;

private static native int getNativeNullCount(long viewHandle) throws CudfException;
Expand All @@ -3044,9 +3047,10 @@ private static native long bitwiseMergeAndSetValidity(long baseHandle, long[] vi
private static native long getNativeValidityAddress(long viewHandle) throws CudfException;
private static native long getNativeValidityLength(long viewHandle) throws CudfException;

private static native long makeCudfColumnView(int type, long data, long dataSize, long offsets,
private static native long makeCudfColumnView(int type, int scale, long data, long dataSize, long offsets,
long valid, int nullCount, int size, long[] childHandle);


private static native long getChildCvPointer(long viewHandle, int childIndex) throws CudfException;

private static native int getNativeNumChildren(long viewHandle) throws CudfException;
Expand Down Expand Up @@ -3076,7 +3080,7 @@ private static native long makeCudfColumnView(int type, long data, long dataSize
*/
private static native long getNativeColumnView(long cudfColumnHandle) throws CudfException;

private static native long makeEmptyCudfColumn(int type);
private static native long makeEmptyCudfColumn(int type, int scale);

private static DeviceMemoryBufferView getDataBuffer(long viewHandle) {
long address = getNativeDataAddress(viewHandle);
Expand Down Expand Up @@ -3219,13 +3223,13 @@ public OffHeapState(DType type, int rows, Optional<Long> nullCount,
toClose.addAll(buffers);
}
if (rows == 0) {
this.columnHandle = makeEmptyCudfColumn(type.nativeId);
this.columnHandle = makeEmptyCudfColumn(type.typeId.getNativeId(), type.getScale());
} else {
long cd = data == null ? 0 : data.address;
long cdSize = data == null ? 0 : data.length;
long od = offsets == null ? 0 : offsets.address;
long vd = valid == null ? 0 : valid.address;
this.viewHandle = makeCudfColumnView(type.nativeId, cd, cdSize, od, vd, nc, rows, childColumnViewHandles) ;
this.viewHandle = makeCudfColumnView(type.typeId.getNativeId(), type.getScale(), cd, cdSize, od, vd, nc, rows, childColumnViewHandles) ;
}
}

Expand Down Expand Up @@ -3270,7 +3274,11 @@ private void setNativeNullCount(int nullCount) throws CudfException {
}

public DType getNativeType() {
return DType.fromNative(getNativeTypeId(getViewHandle()));
return DType.fromNative(getNativeTypeId(getViewHandle()), getNativeTypeScale(getViewHandle()));
}

/**
 * Queries the native layer for the type scale associated with the view
 * handle returned by {@code getViewHandle()}.
 *
 * @return the value reported by {@code getNativeTypeScale} for that handle
 */
public int getNativeScale() {
return getNativeTypeScale(getViewHandle());
}

public BaseDeviceMemoryBuffer getData() {
Expand Down Expand Up @@ -3499,8 +3507,8 @@ private long getViewHandle() {
long offsetAddr = offsets == null ? 0 : offsets.address;
long validAddr = valid == null ? 0 : valid.address;
int nc = nullCount.orElse(OffHeapState.UNKNOWN_NULL_COUNT).intValue();
return makeCudfColumnView(dataType.nativeId, dataAddr, dataLen, offsetAddr, validAddr, nc,
(int)rows, childrenColViews);
return makeCudfColumnView(dataType.typeId.getNativeId(), dataType.getScale() , dataAddr, dataLen,
offsetAddr, validAddr, nc, (int)rows, childrenColViews);
}

private List<DeviceMemoryBuffer> getBuffersToClose() {
Expand Down Expand Up @@ -3534,7 +3542,7 @@ private static NestedColumnVector createNestedColumnVector(DType type, long rows
DeviceMemoryBuffer valid = null;
DeviceMemoryBuffer offsets = null;
if (dataBuffer != null) {
long dataLen = rows * type.sizeInBytes;
long dataLen = rows * type.getSizeInBytes();
if (type == DType.STRING) {
// This needs a different type
dataLen = getEndStringOffset(rows, rows - 1, offsetBuffer);
Expand Down
Loading

0 comments on commit 4b4d962

Please sign in to comment.