diff --git a/Core/src/main/java/org/tribuo/impl/ArrayExample.java b/Core/src/main/java/org/tribuo/impl/ArrayExample.java index 376940aa3..e5fdbfdb8 100644 --- a/Core/src/main/java/org/tribuo/impl/ArrayExample.java +++ b/Core/src/main/java/org/tribuo/impl/ArrayExample.java @@ -34,6 +34,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Objects; import java.util.PriorityQueue; import java.util.Set; @@ -570,6 +571,9 @@ public boolean hasNext() { @Override public Feature next() { + if (!hasNext()) { + throw new NoSuchElementException("Iterator exhausted at position " + pos); + } Feature f = new Feature(featureNames[pos],featureValues[pos]); pos++; return f; diff --git a/Core/src/main/java/org/tribuo/impl/BinaryFeaturesExample.java b/Core/src/main/java/org/tribuo/impl/BinaryFeaturesExample.java index db8b9e0d5..eca52463e 100644 --- a/Core/src/main/java/org/tribuo/impl/BinaryFeaturesExample.java +++ b/Core/src/main/java/org/tribuo/impl/BinaryFeaturesExample.java @@ -25,6 +25,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Objects; import java.util.PriorityQueue; import java.util.Set; @@ -502,6 +503,9 @@ public boolean hasNext() { @Override public Feature next() { + if (!hasNext()) { + throw new NoSuchElementException("Iterator exhausted at position " + pos); + } Feature f = new Feature(featureNames[pos], 1.0); pos++; return f; diff --git a/Core/src/test/java/org/tribuo/ExampleTest.java b/Core/src/test/java/org/tribuo/ExampleTest.java index 81fa9b487..27cd00f89 100644 --- a/Core/src/test/java/org/tribuo/ExampleTest.java +++ b/Core/src/test/java/org/tribuo/ExampleTest.java @@ -30,6 +30,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -261,6 +262,57 @@ public void invalidListExampleTest() { // This example should be valid assertTrue(test.validateExample()); } + + @Test + public void exampleIterators() { + MockOutput output = new MockOutput("UNK"); + List features = new ArrayList<>(); + features.add(new Feature("A",1.0)); + features.add(new Feature("C",1.0)); + features.add(new Feature("B",1.0)); + + ArrayExample array = new ArrayExample<>(output,features); + assertEquals(3,array.size()); + + Iterator arrayItr = array.iterator(); + assertTrue(arrayItr.hasNext()); + assertEquals(features.get(0),arrayItr.next()); + assertTrue(arrayItr.hasNext()); + // Features are lexicographically sorted inside examples + assertEquals(features.get(2),arrayItr.next()); + assertTrue(arrayItr.hasNext()); + assertEquals(features.get(1),arrayItr.next()); + assertFalse(arrayItr.hasNext()); + assertThrows(NoSuchElementException.class, arrayItr::next); + + ListExample list = new ListExample<>(output,features); + assertEquals(3,list.size()); + + Iterator listItr = list.iterator(); + assertTrue(listItr.hasNext()); + assertEquals(features.get(0),listItr.next()); + assertTrue(listItr.hasNext()); + // Features are lexicographically sorted inside examples + assertEquals(features.get(2),listItr.next()); + assertTrue(listItr.hasNext()); + assertEquals(features.get(1),listItr.next()); + assertFalse(listItr.hasNext()); + assertThrows(NoSuchElementException.class, listItr::next); + + BinaryFeaturesExample binary = new BinaryFeaturesExample<>(output,features); + assertEquals(3,binary.size()); + + Iterator binaryItr = binary.iterator(); + assertTrue(binaryItr.hasNext()); + assertEquals(features.get(0),binaryItr.next()); + assertTrue(binaryItr.hasNext()); + // Features are lexicographically sorted inside examples + assertEquals(features.get(2),binaryItr.next()); + assertTrue(binaryItr.hasNext()); + assertEquals(features.get(1),binaryItr.next()); + assertFalse(binaryItr.hasNext()); + assertThrows(NoSuchElementException.class, binaryItr::next); + } @Test public void testBinaryFeaturesExample() {