Skip to content

Commit

Permalink
Support building decimal columns with Table.TestBuilder (#6770)
Browse files Browse the repository at this point in the history
This PR is about to support building decimal columns with Table.TestBuilder, which is widely used in automatic tests of spark-rapids plugin.
  • Loading branch information
sperlingxx authored Nov 18, 2020
1 parent 01b8b5c commit 546b9c3
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
- PR #6761 Add Java/JNI bindings for round
- PR #6786 Add nested type support to ColumnVector#getDeviceMemorySize
- PR #6780 Move `cudf::cast` tests to separate test file
- PR #6770 Support building decimal columns with Table.TestBuilder

## Bug Fixes

Expand Down
75 changes: 75 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import ai.rapids.cudf.HostColumnVector.StructType;

import java.io.File;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -2510,6 +2512,58 @@ public TestBuilder timestampSecondsColumn(Long... values) {
return this;
}

public TestBuilder decimal32Column(int scale, Integer... unscaledValues) {
types.add(new BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, scale)));
typeErasedData.add(unscaledValues);
return this;
}

public TestBuilder decimal32Column(int scale, RoundingMode mode, Double... values) {
types.add(new BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, scale)));
BigDecimal[] data = Arrays.stream(values).map((x) -> {
if (x == null) return null;
return BigDecimal.valueOf(x).setScale(-scale, mode);
}).toArray(BigDecimal[]::new);
typeErasedData.add(data);
return this;
}

public TestBuilder decimal32Column(int scale, RoundingMode mode, String... values) {
types.add(new BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, scale)));
BigDecimal[] data = Arrays.stream(values).map((x) -> {
if (x == null) return null;
return new BigDecimal(x).setScale(-scale, mode);
}).toArray(BigDecimal[]::new);
typeErasedData.add(data);
return this;
}

public TestBuilder decimal64Column(int scale, Long... unscaledValues) {
types.add(new BasicType(true, DType.create(DType.DTypeEnum.DECIMAL64, scale)));
typeErasedData.add(unscaledValues);
return this;
}

public TestBuilder decimal64Column(int scale, RoundingMode mode, Double... values) {
types.add(new BasicType(true, DType.create(DType.DTypeEnum.DECIMAL64, scale)));
BigDecimal[] data = Arrays.stream(values).map((x) -> {
if (x == null) return null;
return BigDecimal.valueOf(x).setScale(-scale, mode);
}).toArray(BigDecimal[]::new);
typeErasedData.add(data);
return this;
}

public TestBuilder decimal64Column(int scale, RoundingMode mode, String... values) {
types.add(new BasicType(true, DType.create(DType.DTypeEnum.DECIMAL64, scale)));
BigDecimal[] data = Arrays.stream(values).map((x) -> {
if (x == null) return null;
return new BigDecimal(x).setScale(-scale, mode);
}).toArray(BigDecimal[]::new);
typeErasedData.add(data);
return this;
}

private static ColumnVector from(DType type, Object dataArray) {
ColumnVector ret = null;
switch (type.typeId) {
Expand Down Expand Up @@ -2552,6 +2606,27 @@ private static ColumnVector from(DType type, Object dataArray) {
case FLOAT64:
ret = ColumnVector.fromBoxedDoubles((Double[]) dataArray);
break;
case DECIMAL32:
case DECIMAL64:
int scale = type.getScale();
if (dataArray instanceof Integer[]) {
BigDecimal[] data = Arrays.stream(((Integer[]) dataArray))
.map((i) -> i == null ? null : BigDecimal.valueOf(i, -scale))
.toArray(BigDecimal[]::new);
ret = ColumnVector.build(type, data.length, (b) -> b.appendBoxed(data));
} else if (dataArray instanceof Long[]) {
BigDecimal[] data = Arrays.stream(((Long[]) dataArray))
.map((i) -> i == null ? null : BigDecimal.valueOf(i, -scale))
.toArray(BigDecimal[]::new);
ret = ColumnVector.build(type, data.length, (b) -> b.appendBoxed(data));
} else if (dataArray instanceof BigDecimal[]) {
BigDecimal[] data = (BigDecimal[]) dataArray;
ret = ColumnVector.build(type, data.length, (b) -> b.appendBoxed(data));
} else {
throw new IllegalArgumentException(
"Data array of invalid type(" + dataArray.getClass() + ") to build decimal column");
}
break;
default:
throw new IllegalArgumentException(type + " is not supported yet");
}
Expand Down
Loading

0 comments on commit 546b9c3

Please sign in to comment.