diff --git a/rxjava-core/src/main/java/rx/operators/OperationTake.java b/rxjava-core/src/main/java/rx/operators/OperationTake.java index 881071e5b6c..061e930ff04 100644 --- a/rxjava-core/src/main/java/rx/operators/OperationTake.java +++ b/rxjava-core/src/main/java/rx/operators/OperationTake.java @@ -15,20 +15,30 @@ */ package rx.operators; -import static org.junit.Assert.*; -import static org.mockito.Matchers.*; -import static org.mockito.Mockito.*; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; +import org.mockito.InOrder; import rx.Observable; import rx.Observable.OnSubscribeFunc; import rx.Observer; import rx.Subscription; import rx.subscriptions.Subscriptions; +import rx.util.functions.Func1; /** * Returns an Observable that emits the first num items emitted by the source @@ -114,6 +124,7 @@ private class ItemObserver implements Observer { private final Observer observer; private final AtomicInteger counter = new AtomicInteger(); + private volatile boolean hasEmitedError = false; public ItemObserver(Observer observer) { this.observer = observer; @@ -121,6 +132,9 @@ public ItemObserver(Observer observer) { @Override public void onCompleted() { + if (hasEmitedError) { + return; + } if (counter.getAndSet(num) < num) { observer.onCompleted(); } @@ -128,6 +142,9 @@ public void onCompleted() { @Override public void onError(Throwable e) { + if (hasEmitedError) { + return; + } if (counter.getAndSet(num) < num) { observer.onError(e); } @@ -135,9 +152,19 @@ public void onError(Throwable e) { @Override public void onNext(T args) { + if (hasEmitedError) { + return; + } final int count = counter.incrementAndGet(); if (count <= num) { - observer.onNext(args); + try { + observer.onNext(args); + } catch (Throwable ex) { + hasEmitedError = true; + observer.onError(ex); + subscription.unsubscribe(); + return; + } if (count == num) { observer.onCompleted(); } @@ -184,6 +211,47 @@ public void testTake2() { verify(aObserver, times(1)).onCompleted(); } + @Test(expected = IllegalArgumentException.class) + public void testTakeWithError() { + Observable.from(1, 2, 3).take(1).map(new Func1() { + public Integer call(Integer t1) { + throw new IllegalArgumentException("some error"); + } + }).toBlockingObservable().single(); + } + + @Test + public void testTakeWithErrorHappeningInOnNext() { + Observable w = Observable.from(1, 2, 3).take(2).map(new Func1() { + public Integer call(Integer t1) { + throw new IllegalArgumentException("some error"); + } + }); + + @SuppressWarnings("unchecked") + Observer observer = mock(Observer.class); + w.subscribe(observer); + InOrder inOrder = inOrder(observer); + inOrder.verify(observer, times(1)).onError(any(IllegalArgumentException.class)); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void testTakeWithErrorHappeningInTheLastOnNext() { + Observable w = Observable.from(1, 2, 3).take(1).map(new Func1() { + public Integer call(Integer t1) { + throw new IllegalArgumentException("some error"); + } + }); + + @SuppressWarnings("unchecked") + Observer observer = mock(Observer.class); + w.subscribe(observer); + InOrder inOrder = inOrder(observer); + inOrder.verify(observer, times(1)).onError(any(IllegalArgumentException.class)); + inOrder.verifyNoMoreInteractions(); + } + @Test public void testTakeDoesntLeakErrors() { Observable source = Observable.create(new OnSubscribeFunc()