Skip to content

Commit

Permalink
Merge pull request rapidsai#8 from liurenjie1024/ray/shuffle0822
Browse files Browse the repository at this point in the history
Add Kudo && Kudo2 serializer
  • Loading branch information
wjxiz1992 authored Aug 23, 2024
2 parents 3200f01 + 022963b commit 06efeaf
Show file tree
Hide file tree
Showing 36 changed files with 3,930 additions and 159 deletions.
12 changes: 9 additions & 3 deletions java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@
<version>3.2.4</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>3.26.3</version>
<scope>test</scope>
</dependency>
</dependencies>

<properties>
Expand Down Expand Up @@ -464,9 +470,9 @@
executable="cmake">
<arg value="${basedir}/src/main/native"/>
<arg line="${cmake.ccache.opts}"/>
<arg value="-DCUDA_STATIC_RUNTIME=${CUDA_STATIC_RUNTIME}" />
<arg value="-DCUDF_USE_PER_THREAD_DEFAULT_STREAM=${CUDF_USE_PER_THREAD_DEFAULT_STREAM}" />
<arg value="-DUSE_GDS=${USE_GDS}" />
<arg value="-DCUDA_STATIC_RUNTIME=${CUDA_STATIC_RUNTIME}"/>
<arg value="-DCUDF_USE_PER_THREAD_DEFAULT_STREAM=${CUDF_USE_PER_THREAD_DEFAULT_STREAM}"/>
<arg value="-DUSE_GDS=${USE_GDS}"/>
<arg value="-DCMAKE_CXX_FLAGS=${cxx.flags}"/>
<arg value="-DCMAKE_EXPORT_COMPILE_COMMANDS=${CMAKE_EXPORT_COMPILE_COMMANDS}"/>
<arg value="-DCUDF_CPP_BUILD_DIR=${CUDF_CPP_BUILD_DIR}"/>
Expand Down
4 changes: 2 additions & 2 deletions java/src/main/java/ai/rapids/cudf/BitVectorHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
/**
* This class does bit manipulation using byte arithmetic
*/
final class BitVectorHelper {
public final class BitVectorHelper {

/**
* Shifts that to the left by the required bits then appends to this
Expand Down Expand Up @@ -74,7 +74,7 @@ private static void shiftSrcLeftAndWriteToDst(HostMemoryBuffer src, HostMemoryBu
* getValidityLengthInBytes(7) => 1 byte
* getValidityLengthInBytes(14) => 2 bytes
*/
static long getValidityLengthInBytes(long rows) {
public static long getValidityLengthInBytes(long rows) {
  // One validity bit per row: add (bitsPerByte - 1) before dividing so that a partially
  // filled trailing byte is still counted.
  final long bitsPerByte = 8;
  final long paddedBitCount = rows + (bitsPerByte - 1);
  return paddedBitCount / bitsPerByte;
}

Expand Down
8 changes: 8 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,12 @@ public ColumnVector castTo(DType type) {
return super.castTo(type);
}

/**
 * Runs the native bitmask copy for the given bit range and wraps what it returns in a
 * {@link DeviceMemoryBuffer}.
 *
 * @param startAddress address of the source bitmask, forwarded unchanged to the native call
 * @param beginBit     first bit of the range, forwarded unchanged to the native call
 * @param endBit       end bit of the range, forwarded unchanged to the native call
 *                     (NOTE(review): presumably exclusive - confirm against the native side)
 * @return the buffer produced by {@code DeviceMemoryBuffer.fromRmm} from the first three values
 *         of the native result, passed in the order they were returned
 */
public static DeviceMemoryBuffer copyBitmask(long startAddress, long beginBit, long endBit) {
  final long[] nativeResult = copyBitmaskNative(startAddress, beginBit, endBit);
  final long first = nativeResult[0];
  final long second = nativeResult[1];
  final long third = nativeResult[2];
  return DeviceMemoryBuffer.fromRmm(first, second, third);
}


/////////////////////////////////////////////////////////////////////////////
// NATIVE METHODS
/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -917,6 +923,8 @@ private static native long stringConcatenationSepCol(long[] columnViews,

static native long makeEmptyCudfColumn(int type, int scale);

// Native backend of copyBitmask(long, long, long), which forwards its three arguments here
// unchanged. That wrapper hands elements 0, 1 and 2 of the returned array, in that order, to
// DeviceMemoryBuffer.fromRmm, so the native side must return at least three values.
private static native long[] copyBitmaskNative(long startAddresses, long beginBit, long endBit);

/////////////////////////////////////////////////////////////////////////////
// HELPER CLASSES
/////////////////////////////////////////////////////////////////////////////
Expand Down
64 changes: 62 additions & 2 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -4838,7 +4838,7 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS
* creating the device side vector from host side nested vectors. Eventually this can go away or
* be refactored to hold less state like just the handles and the buffers to close.
*/
static class NestedColumnVector {
public static class NestedColumnVector {

private final DeviceMemoryBuffer data;
private final DeviceMemoryBuffer valid;
Expand All @@ -4848,7 +4848,7 @@ static class NestedColumnVector {
private final Optional<Long> nullCount;
List<NestedColumnVector> children;

private NestedColumnVector(DType type, long rows, Optional<Long> nullCount,
public NestedColumnVector(DType type, long rows, Optional<Long> nullCount,
DeviceMemoryBuffer data, DeviceMemoryBuffer valid,
DeviceMemoryBuffer offsets, List<NestedColumnVector> children) {
this.dataType = type;
Expand Down Expand Up @@ -4988,6 +4988,50 @@ List<DeviceMemoryBuffer> getBuffersToClose() {
return buffers;
}

/**
 * Collects the device buffers referenced directly by this vector, in the order data,
 * validity, offsets, leaving out any that are null. Buffers belonging to child vectors are
 * not included.
 *
 * @return a newly created list holding the non-null buffers of this vector
 */
public List<DeviceMemoryBuffer> getSelfBuffers() {
  List<DeviceMemoryBuffer> ownBuffers = new ArrayList<>();
  for (DeviceMemoryBuffer candidate : new DeviceMemoryBuffer[] {data, valid, offsets}) {
    if (candidate == null) {
      continue;
    }
    ownBuffers.add(candidate);
  }
  return ownBuffers;
}

/**
 * Exposes the child vectors without allowing the caller to modify the list.
 *
 * @return an unmodifiable view over this vector's children
 */
public List<NestedColumnVector> getChildren() {
  final List<NestedColumnVector> readOnlyChildren = Collections.unmodifiableList(children);
  return readOnlyChildren;
}

/**
 * Creates a vector that has the same type, row count, null count and buffer references
 * (data, validity, offsets) as this one, but uses the supplied children. The buffers are
 * passed along as-is, not copied.
 *
 * @param children the children of the new vector
 * @return a new {@code NestedColumnVector} built from this vector's state and {@code children}
 */
public NestedColumnVector withNewChildren(List<NestedColumnVector> children) {
  final List<NestedColumnVector> replacement = children;
  return new NestedColumnVector(
      this.dataType, this.rows, this.nullCount,
      this.data, this.valid, this.offsets,
      replacement);
}

/**
 * Builds a {@link ColumnVector} from this vector's type, row count, null count and buffers.
 * <p>
 * The buffers reported by every child through {@code getBuffersToClose()} are gathered into
 * one list and handed to the {@code ColumnVector} constructor together with a view handle
 * created for each child. The view handles are temporary: each one that was created is
 * deleted before this method returns, whether or not construction succeeded.
 *
 * @return the newly constructed column vector
 */
public ColumnVector toColumnVector() {
  final List<DeviceMemoryBuffer> childBuffers = new ArrayList<>();
  final long[] handles = new long[children.size()];
  try {
    // Gather all child buffers before any native view handle is created.
    children.forEach(child -> childBuffers.addAll(child.getBuffersToClose()));

    int slot = 0;
    for (ColumnView.NestedColumnVector child : children) {
      handles[slot] = child.createViewHandle();
      slot++;
    }

    return new ColumnVector(dataType, rows, nullCount, data,
        valid, offsets, childBuffers, handles);
  } finally {
    // A zero entry means the handle was never created (or was already released).
    for (int pos = 0; pos < handles.length; pos++) {
      if (handles[pos] == 0) {
        continue;
      }
      ColumnView.deleteColumnView(handles[pos]);
      handles[pos] = 0;
    }
  }
}

private static long getEndStringOffset(long totalRows, long index, HostMemoryBuffer offsets) {
assert index < totalRows;
return offsets.getInt((index + 1) * 4);
Expand Down Expand Up @@ -5306,4 +5350,20 @@ public ColumnVector toHex() {
assert getType().isIntegral() : "Only integers are supported";
return new ColumnVector(toHex(this.getNativeView()));
}

/**
 * Adds this column and, recursively, all of its descendants to {@code builder}.
 * <p>
 * Every added column is named {@code namePrefix} followed by a running index: this column
 * gets index 0 and each descendant gets the next index in depth-first order.
 *
 * @param namePrefix prefix used for every generated column name
 * @param builder    schema builder that receives this column
 */
public void toSchema(String namePrefix, Schema.Builder builder) {
  final int rootIndex = 0;
  this.toSchemaInner(rootIndex, namePrefix, builder);
}

/**
 * Adds this column to {@code builder} under the name {@code namePrefix + idx}, then adds each
 * child to the builder returned for this column, numbering descendants consecutively in
 * depth-first order.
 *
 * @param idx        index used to name this column
 * @param namePrefix prefix used for every generated column name
 * @param builder    builder of the parent that receives this column
 * @return the highest index that was used by this column or any of its descendants
 */
private int toSchemaInner(int idx, String namePrefix, Schema.Builder builder) {
  String name = namePrefix + idx;

  Schema.Builder thisBuilder = builder.addColumn(this.getType(), name);
  int lastIdx = idx;
  for (int i = 0; i < this.getNumChildren(); i++) {
    // getChildColumnView hands back a new view that the caller is responsible for closing;
    // release it as soon as the child's subtree has been recorded so no native view leaks.
    try (ColumnView child = this.getChildColumnView(i)) {
      lastIdx = child.toSchemaInner(lastIdx + 1, namePrefix, thisBuilder);
    }
  }

  return lastIdx;
}
}
8 changes: 8 additions & 0 deletions java/src/main/java/ai/rapids/cudf/HostColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -2149,6 +2149,13 @@ public boolean isNull() {
/**
 * Looks up one field of this struct value.
 *
 * @param index position of the field inside the underlying data record
 * @return the element stored at {@code index} in the data record
 */
public Object getField(int index) {
  final Object field = dataRecord.get(index);
  return field;
}

/**
 * Renders this struct value as {@code StructData{dataRecord=...}} for debugging output.
 */
@Override
public String toString() {
  StringBuilder text = new StringBuilder("StructData{");
  text.append("dataRecord=").append(dataRecord);
  text.append('}');
  return text.toString();
}
}

public static class StructType extends HostColumnVector.DataType {
Expand Down Expand Up @@ -2214,4 +2221,5 @@ public int getNumChildren() {
return 0;
}
}

}
18 changes: 17 additions & 1 deletion java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public int getNumChildren() {
* @param rowIndex the row number
* @return an object that would need to be casted to appropriate type based on this vector's data type
*/
Object getElement(int rowIndex) {
public Object getElement(int rowIndex) {
if (type.equals(DType.LIST)) {
return getList(rowIndex);
} else if (type.equals(DType.STRUCT)) {
Expand Down Expand Up @@ -662,4 +662,20 @@ public String toString() {
return "(ID: " + id + ")";
}
}

/**
 * Adds this column and, recursively, all of its descendants to {@code builder}.
 * <p>
 * Every added column is named {@code namePrefix} followed by a running index: this column
 * gets index 0 and each descendant gets the next index in depth-first order.
 *
 * @param namePrefix prefix used for every generated column name
 * @param builder    schema builder that receives this column
 */
public void toSchema(String namePrefix, Schema.Builder builder) {
  final int rootIndex = 0;
  this.toSchemaInner(rootIndex, namePrefix, builder);
}

/**
 * Adds this column to {@code builder} under the name {@code namePrefix + idx}, then adds each
 * child to the builder returned for this column, numbering descendants consecutively in
 * depth-first order.
 *
 * @param idx        index used to name this column
 * @param namePrefix prefix used for every generated column name
 * @param builder    builder of the parent that receives this column
 * @return the highest index that was used by this column or any of its descendants
 */
private int toSchemaInner(int idx, String namePrefix, Schema.Builder builder) {
  final Schema.Builder ownBuilder = builder.addColumn(getType(), namePrefix + idx);

  int highestUsed = idx;
  int childPos = 0;
  while (childPos < getNumChildren()) {
    // Each child starts numbering right after the last index consumed so far.
    final int nextFree = highestUsed + 1;
    highestUsed = getChildColumnView(childPos).toSchemaInner(nextFree, namePrefix, ownBuilder);
    childPos++;
  }
  return highestUsed;
}
}
2 changes: 1 addition & 1 deletion java/src/main/java/ai/rapids/cudf/HostMemoryBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ public final void copyFromHostBuffer(long destOffset, HostMemoryBuffer srcData,
* @param in input stream to copy bytes from
* @param byteLength number of bytes to copy
*/
final void copyFromStream(long destOffset, InputStream in, long byteLength) throws IOException {
public final void copyFromStream(long destOffset, InputStream in, long byteLength) throws IOException {
addressOutOfBoundsCheck(address + destOffset, byteLength, "copy from stream");
byte[] arrayBuffer = new byte[(int) Math.min(1024 * 128, byteLength)];
long left = byteLength;
Expand Down
11 changes: 11 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -5000,4 +5000,15 @@ public Table build() {
}
}
}

/**
 * Derives a {@link Schema} from the columns of this table.
 * <p>
 * Column {@code i} is asked to add itself (and its nested children) to the schema using the
 * name prefix {@code "col_" + i + "_"}.
 *
 * @return the schema built from all columns, in column order
 */
public Schema toSchema() {
  final Schema.Builder schemaBuilder = Schema.builder();

  int columnIndex = 0;
  while (columnIndex < getNumberOfColumns()) {
    final ColumnVector column = getColumn(columnIndex);
    final String prefix = "col_" + columnIndex + "_";
    column.toSchema(prefix, schemaBuilder);
    columnIndex++;
  }

  return schemaBuilder.build();
}
}
24 changes: 24 additions & 0 deletions java/src/main/java/ai/rapids/cudf/schema/SchemaVisitor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package ai.rapids.cudf.schema;

import ai.rapids.cudf.Schema;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Callbacks for visiting a {@link Schema} tree in post order: the callback for a nested type
 * receives the results already produced for its children.
 * <p>
 * The traversal itself is driven by {@code Visitors.visitSchema}.
 *
 * @param <T> result type produced for each visited column of the schema
 * @param <R> result type produced for the top level schema
 */
public interface SchemaVisitor<T, R> {
/**
 * Called once, after every child of the top level schema has been visited.
 *
 * @param schema   the top level schema
 * @param children results of the top level schema's children, in child order
 * @return the final result of the whole visit
 */
R visitTopSchema(Schema schema, List<T> children);

/**
 * Called for a STRUCT typed schema after all of its children have been visited.
 *
 * @param structType the struct schema
 * @param children   results of the struct's children, in child order
 * @return the result for the struct
 */
T visitStruct(Schema structType, List<T> children);

/**
 * Called for a LIST typed schema before its child is visited.
 *
 * @param listType the list schema
 * @return a value that is later passed to {@link #visitList} as {@code preVisitResult}
 */
T preVisitList(Schema listType);

/**
 * Called for a LIST typed schema after its child (child index 0) has been visited.
 *
 * @param listType       the list schema
 * @param preVisitResult value returned by {@link #preVisitList} for the same schema
 * @param childResult    result of visiting the list's child
 * @return the result for the list
 */
T visitList(Schema listType, T preVisitResult, T childResult);

/**
 * Called for every schema whose type is neither STRUCT nor LIST.
 *
 * @param primitiveType the schema being visited
 * @return the result for this schema
 */
T visit(Schema primitiveType);


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package ai.rapids.cudf.schema;

import ai.rapids.cudf.HostColumnVectorCore;
import ai.rapids.cudf.Schema;

import java.util.List;

/**
 * Callbacks for visiting a {@link Schema} tree together with the host column that corresponds
 * to each schema node. Nested types are visited in post order: their callback receives the
 * results already produced for their children.
 * <p>
 * The traversal itself is driven by {@code Visitors.visitSchemaWithColumns}.
 *
 * @param <T> result type produced for each visited column of the schema
 * @param <R> result type produced for the top level schema
 */
public interface SchemaWithColumnsVisitor<T, R> {
/**
 * Called once, after every child of the top level schema has been visited.
 *
 * @param schema   the top level schema
 * @param children results of the top level schema's children, in child order
 * @return the final result of the whole visit
 */
R visitTopSchema(Schema schema, List<T> children);

/**
 * Called for a STRUCT typed schema after all of its children have been visited.
 *
 * @param structType the struct schema
 * @param col        the column paired with {@code structType}
 * @param children   results of the struct's children, in child order
 * @return the result for the struct
 */
T visitStruct(Schema structType, HostColumnVectorCore col, List<T> children);

/**
 * Called for a LIST typed schema before its child is visited.
 *
 * @param listType the list schema
 * @param col      the column paired with {@code listType}
 * @return a value that is later passed to {@link #visitList} as {@code preVisitResult}
 */
T preVisitList(Schema listType, HostColumnVectorCore col);
/**
 * Called for a LIST typed schema after its child (child index 0) has been visited.
 *
 * @param listType       the list schema
 * @param col            the column paired with {@code listType}
 * @param preVisitResult value returned by {@link #preVisitList} for the same schema
 * @param childResult    result of visiting the list's child
 * @return the result for the list
 */
T visitList(Schema listType, HostColumnVectorCore col, T preVisitResult, T childResult);

/**
 * Called for every schema whose type is neither STRUCT nor LIST.
 *
 * @param primitiveType the schema being visited
 * @param col           the column paired with {@code primitiveType}
 * @return the result for this schema
 */
T visit(Schema primitiveType, HostColumnVectorCore col);
}
77 changes: 77 additions & 0 deletions java/src/main/java/ai/rapids/cudf/schema/Visitors.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package ai.rapids.cudf.schema;

import ai.rapids.cudf.HostColumnVector;
import ai.rapids.cudf.HostColumnVectorCore;
import ai.rapids.cudf.Schema;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class Visitors {
public static <T, R> R visitSchema(Schema schema, SchemaVisitor<T, R> visitor) {
Objects.requireNonNull(schema, "schema cannot be null");
Objects.requireNonNull(visitor, "visitor cannot be null");

List<T> childrenResult = IntStream.range(0, schema.getNumChildren())
.mapToObj(i -> visitSchemaInner(schema.getChild(i), visitor))
.collect(Collectors.toList());

return visitor.visitTopSchema(schema, childrenResult);
}

private static <T, R> T visitSchemaInner(Schema schema, SchemaVisitor<T, R> visitor) {
switch (schema.getType().getTypeId()) {
case STRUCT:
List<T> children = IntStream.range(0, schema.getNumChildren())
.mapToObj(childIdx -> visitSchemaInner(schema.getChild(childIdx), visitor))
.collect(Collectors.toList());
return visitor.visitStruct(schema, children);
case LIST:
T preVisitResult = visitor.preVisitList(schema);
T childResult = visitSchemaInner(schema.getChild(0), visitor);
return visitor.visitList(schema, preVisitResult, childResult);
default:
return visitor.visit(schema);
}
}


/**
* Entry point for visiting a schema with columns.
*/
public static <T, R> R visitSchemaWithColumns(Schema schema, List<HostColumnVector> cols,
SchemaWithColumnsVisitor<T, R> visitor) {
Objects.requireNonNull(schema, "schema cannot be null");
Objects.requireNonNull(cols, "cols cannot be null");
Objects.requireNonNull(visitor, "visitor cannot be null");

if (schema.getNumChildren() != cols.size()) {
throw new IllegalArgumentException("Schema children num: " + schema.getNumChildren() +
" is not same as columns num: " + cols.size());
}

List<T> childrenResult = IntStream.range(0, schema.getNumChildren())
.mapToObj(i -> visitSchema(schema.getChild(i), cols.get(i), visitor))
.collect(Collectors.toList());

return visitor.visitTopSchema(schema, childrenResult);
}

private static <T, R> T visitSchema(Schema schema, HostColumnVectorCore col, SchemaWithColumnsVisitor<T, R> visitor) {
switch (schema.getType().getTypeId()) {
case STRUCT:
List<T> children = IntStream.range(0, schema.getNumChildren())
.mapToObj(childIdx -> visitSchema(schema.getChild(childIdx), col.getChildColumnView(childIdx), visitor))
.collect(Collectors.toList());
return visitor.visitStruct(schema, col, children);
case LIST:
T preVisitResult = visitor.preVisitList(schema, col);
T childResult = visitSchema(schema.getChild(0), col.getChildColumnView(0), visitor);
return visitor.visitList(schema, col, preVisitResult, childResult);
default:
return visitor.visit(schema, col);
}
}
}
Loading

0 comments on commit 06efeaf

Please sign in to comment.