diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index e0cc96263b3..f36896a3c96 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1394,9 +1394,13 @@ public ColumnView replaceChildrenWithViews(int[] indices, List newChildren = new ArrayList<>(getNumChildren()); IntStream.range(0, getNumChildren()).forEach(i -> { ColumnView view = map.remove(i); + ColumnView child = getChildColumnView(i); if (view == null) { - newChildren.add(getChildColumnView(i)); + newChildren.add(child); } else { + if (child.getRowCount() != view.getRowCount()) { + throw new IllegalArgumentException("Child row count doesn't match the old child"); + } newChildren.add(view); } }); @@ -1431,7 +1435,7 @@ public ColumnView replaceChildrenWithViews(int[] indices, */ public ColumnView replaceListChild(ColumnView child) { assert(type == DType.LIST); - return replaceChildrenWithViews(new int[]{1}, new ColumnView[]{child}); + return replaceChildrenWithViews(new int[]{0}, new ColumnView[]{child}); } /** diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 0675ece4863..d224543e574 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4005,7 +4005,7 @@ void testReplaceLeafNodeInList() { @Test void testReplaceLeafNodeInListWithIllegal() { - assertThrows(IllegalArgumentException.class, () -> { + Exception e = assertThrows(IllegalArgumentException.class, () -> { try (ColumnVector child1 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 770.892, 961.110); @@ -4023,6 +4023,7 @@ void testReplaceLeafNodeInListWithIllegal() { ColumnView replacedView = created.replaceListChild(newChild)) { } }); + assertTrue(e.getMessage().contains("Child row count doesn't match the old child")); } @Test @@ -4049,7 +4050,7 @@ void testReplaceColumnInStruct() { @Test void testReplaceIllegalIndexColumnInStruct() { - assertThrows(IllegalArgumentException.class, () -> { + Exception e = assertThrows(IllegalArgumentException.class, () -> { try (ColumnVector child1 = ColumnVector.fromInts(1, 4); ColumnVector child2 = ColumnVector.fromInts(2, 5); ColumnVector child3 = ColumnVector.fromInts(3, 6); @@ -4059,11 +4060,12 @@ void testReplaceIllegalIndexColumnInStruct() { new ColumnVector[]{replaceWith})) { } }); + assertTrue(e.getMessage().contains("One or more invalid child indices passed to be replaced")); } @Test void testReplaceSameIndexColumnInStruct() { - assertThrows(IllegalArgumentException.class, () -> { + Exception e = assertThrows(IllegalArgumentException.class, () -> { try (ColumnVector child1 = ColumnVector.fromInts(1, 4); ColumnVector child2 = ColumnVector.fromInts(2, 5); ColumnVector child3 = ColumnVector.fromInts(3, 6); @@ -4073,5 +4075,6 @@ void testReplaceSameIndexColumnInStruct() { new ColumnVector[]{replaceWith, replaceWith})) { } }); + assertTrue(e.getMessage().contains("Duplicate mapping found for replacing child index")); } }