diff --git a/src/main/java/org/apache/commons/collections4/iterators/IteratorChain.java b/src/main/java/org/apache/commons/collections4/iterators/IteratorChain.java index 01bae508b1..c4f3b34ed8 100644 --- a/src/main/java/org/apache/commons/collections4/iterators/IteratorChain.java +++ b/src/main/java/org/apache/commons/collections4/iterators/IteratorChain.java @@ -16,10 +16,13 @@ */ package org.apache.commons.collections4.iterators; +import org.apache.commons.collections4.IteratorUtils; + import java.util.Collection; import java.util.Iterator; import java.util.LinkedList; import java.util.Objects; +import java.util.Optional; import java.util.Queue; /** @@ -49,7 +52,7 @@ * @since 2.1 */ public class IteratorChain implements Iterator { - + public static long hasNextCalledCount = 0; /** The chain of iterators */ private final Queue> iteratorChain = new LinkedList<>(); @@ -151,7 +154,21 @@ public IteratorChain(final Collection> iteratorChain) { */ public void addIterator(final Iterator iterator) { checkLocked(); - iteratorChain.add(Objects.requireNonNull(iterator, "iterator")); + Objects.requireNonNull(iterator, "iterator"); + if (iterator instanceof UnmodifiableIterator) { + Optional> nestedIteratorChain = ((UnmodifiableIterator)iterator).getIteratorChain(); + if (nestedIteratorChain.isPresent()) { + for (Iterator nestedIterator : nestedIteratorChain.get().iteratorChain) { + iteratorChain.add(IteratorUtils.unmodifiableIterator(nestedIterator)); + } + } else { + iteratorChain.add(iterator); + } + } else if (iterator instanceof IteratorChain) { + iteratorChain.addAll(((IteratorChain)iterator).iteratorChain); + } else { + iteratorChain.add(iterator); + } } /** @@ -222,6 +239,7 @@ protected void updateCurrentIterator() { */ @Override public boolean hasNext() { + hasNextCalledCount++; lockChain(); updateCurrentIterator(); lastUsedIterator = currentIterator; diff --git a/src/main/java/org/apache/commons/collections4/iterators/UnmodifiableIterator.java b/src/main/java/org/apache/commons/collections4/iterators/UnmodifiableIterator.java index 5d86207d5d..de548efd5a 100644 --- a/src/main/java/org/apache/commons/collections4/iterators/UnmodifiableIterator.java +++ b/src/main/java/org/apache/commons/collections4/iterators/UnmodifiableIterator.java @@ -18,6 +18,7 @@ import java.util.Iterator; import java.util.Objects; +import java.util.Optional; import org.apache.commons.collections4.Unmodifiable; @@ -35,6 +36,16 @@ public final class UnmodifiableIterator implements Iterator, Unmodifiable /** The iterator being decorated */ private final Iterator iterator; + /** + * Obtain an Optional holding the nested IteratorChain, if it exists. + * @return Optional holding the iterator if it is an IteratorChain. + */ + Optional> getIteratorChain() { + return Optional.ofNullable(iterator instanceof IteratorChain + ? (IteratorChain)iterator + : null); + } + /** * Decorates the specified iterator such that it cannot be modified. *

diff --git a/src/test/java/org/apache/commons/collections4/iterators/IteratorChainTest.java b/src/test/java/org/apache/commons/collections4/iterators/IteratorChainTest.java index 64e290f116..1d7f49d4ff 100644 --- a/src/test/java/org/apache/commons/collections4/iterators/IteratorChainTest.java +++ b/src/test/java/org/apache/commons/collections4/iterators/IteratorChainTest.java @@ -35,12 +35,16 @@ public class IteratorChainTest extends AbstractIteratorTest { protected String[] testArray = { - "One", "Two", "Three", "Four", "Five", "Six" + "One", "Two", "Three", "Four", "Five", "Six" + }; + protected String[] testArray1234 = { + "One", "Two", "Three", "Four", "Five", "Six", "Seven", "Eight" }; protected List list1 = null; protected List list2 = null; protected List list3 = null; + protected List list4 = null; public IteratorChainTest() { super(IteratorChainTest.class.getSimpleName()); @@ -57,6 +61,9 @@ public void setUp() { list3 = new ArrayList<>(); list3.add("Five"); list3.add("Six"); + list4 = new ArrayList<>(); + list4.add("Seven"); + list4.add("Eight"); } @Override @@ -170,4 +177,60 @@ public void testEmptyChain() { ); } + @Test + public void testChainOfChains() { + final Iterator iteratorChain1 = new IteratorChain<>(list1.iterator(), list2.iterator()); + final Iterator iteratorChain2 = new IteratorChain<>(list3.iterator(), list4.iterator()); + final Iterator iteratorChainOfChains = new IteratorChain<>(iteratorChain1, iteratorChain2); + + for (final String testValue : testArray1234) { + final Object iterValue = iteratorChainOfChains.next(); + + assertEquals( "Iteration value is correct", testValue, iterValue ); + } + + assertFalse("Iterator should now be empty", iteratorChainOfChains.hasNext()); + + try { + iteratorChainOfChains.next(); + } catch (final Exception e) { + assertEquals("NoSuchElementException must be thrown", e.getClass(), NoSuchElementException.class); + } + } + + @Test + public void testChainOfUnmodifiableChains() { + final Iterator iteratorChain1 = new IteratorChain<>(list1.iterator(), list2.iterator()); + final Iterator unmodifiableChain1 = IteratorUtils.unmodifiableIterator(iteratorChain1); + final Iterator iteratorChain2 = new IteratorChain<>(list3.iterator(), list4.iterator()); + final Iterator unmodifiableChain2 = IteratorUtils.unmodifiableIterator(iteratorChain2); + final Iterator iteratorChainOfChains = new IteratorChain<>(unmodifiableChain1, unmodifiableChain2); + + for (final String testValue : testArray1234) { + final Object iterValue = iteratorChainOfChains.next(); + + assertEquals( "Iteration value is correct", testValue, iterValue ); + } + + assertFalse("Iterator should now be empty", iteratorChainOfChains.hasNext()); + + try { + iteratorChainOfChains.next(); + } catch (final Exception e) { + assertEquals("NoSuchElementException must be thrown", e.getClass(), NoSuchElementException.class); + } + + } + + @Test + public void testChainOfUnmodifiableChainsRetainsUnmodifiableBehaviourOfNestedIterator() { + final Iterator iteratorChain1 = new IteratorChain<>(list1.iterator(), list2.iterator()); + final Iterator unmodifiableChain1 = IteratorUtils.unmodifiableIterator(iteratorChain1); + final Iterator iteratorChain2 = new IteratorChain<>(list3.iterator(), list4.iterator()); + final Iterator unmodifiableChain2 = IteratorUtils.unmodifiableIterator(iteratorChain2); + final Iterator iteratorChainOfChains = new IteratorChain<>(unmodifiableChain1, unmodifiableChain2); + + iteratorChainOfChains.next(); + assertThrows(UnsupportedOperationException.class, iteratorChainOfChains::remove, + "Calling remove must fail when nested iterator is unmodifiable"); } }