diff --git a/rxjava-core/src/main/java/rx/operators/OperationTake.java b/rxjava-core/src/main/java/rx/operators/OperationTake.java
index 881071e5b6c..061e930ff04 100644
--- a/rxjava-core/src/main/java/rx/operators/OperationTake.java
+++ b/rxjava-core/src/main/java/rx/operators/OperationTake.java
@@ -15,20 +15,30 @@
*/
package rx.operators;
-import static org.junit.Assert.*;
-import static org.mockito.Matchers.*;
-import static org.mockito.Mockito.*;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.inOrder;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
+import org.mockito.InOrder;
import rx.Observable;
import rx.Observable.OnSubscribeFunc;
import rx.Observer;
import rx.Subscription;
import rx.subscriptions.Subscriptions;
+import rx.util.functions.Func1;
/**
* Returns an Observable that emits the first num
items emitted by the source
@@ -114,6 +124,7 @@ private class ItemObserver implements Observer {
private final Observer super T> observer;
private final AtomicInteger counter = new AtomicInteger();
+ private volatile boolean hasEmitedError = false;
public ItemObserver(Observer super T> observer) {
this.observer = observer;
@@ -121,6 +132,9 @@ public ItemObserver(Observer super T> observer) {
@Override
public void onCompleted() {
+ if (hasEmitedError) {
+ return;
+ }
if (counter.getAndSet(num) < num) {
observer.onCompleted();
}
@@ -128,6 +142,9 @@ public void onCompleted() {
@Override
public void onError(Throwable e) {
+ if (hasEmitedError) {
+ return;
+ }
if (counter.getAndSet(num) < num) {
observer.onError(e);
}
@@ -135,9 +152,19 @@ public void onError(Throwable e) {
@Override
public void onNext(T args) {
+ if (hasEmitedError) {
+ return;
+ }
final int count = counter.incrementAndGet();
if (count <= num) {
- observer.onNext(args);
+ try {
+ observer.onNext(args);
+ } catch (Throwable ex) {
+ hasEmitedError = true;
+ observer.onError(ex);
+ subscription.unsubscribe();
+ return;
+ }
if (count == num) {
observer.onCompleted();
}
@@ -184,6 +211,47 @@ public void testTake2() {
verify(aObserver, times(1)).onCompleted();
}
+ @Test(expected = IllegalArgumentException.class)
+ public void testTakeWithError() {
+ Observable.from(1, 2, 3).take(1).map(new Func1() {
+ public Integer call(Integer t1) {
+ throw new IllegalArgumentException("some error");
+ }
+ }).toBlockingObservable().single();
+ }
+
+ @Test
+ public void testTakeWithErrorHappeningInOnNext() {
+ Observable w = Observable.from(1, 2, 3).take(2).map(new Func1() {
+ public Integer call(Integer t1) {
+ throw new IllegalArgumentException("some error");
+ }
+ });
+
+ @SuppressWarnings("unchecked")
+ Observer observer = mock(Observer.class);
+ w.subscribe(observer);
+ InOrder inOrder = inOrder(observer);
+ inOrder.verify(observer, times(1)).onError(any(IllegalArgumentException.class));
+ inOrder.verifyNoMoreInteractions();
+ }
+
+ @Test
+ public void testTakeWithErrorHappeningInTheLastOnNext() {
+ Observable w = Observable.from(1, 2, 3).take(1).map(new Func1() {
+ public Integer call(Integer t1) {
+ throw new IllegalArgumentException("some error");
+ }
+ });
+
+ @SuppressWarnings("unchecked")
+ Observer observer = mock(Observer.class);
+ w.subscribe(observer);
+ InOrder inOrder = inOrder(observer);
+ inOrder.verify(observer, times(1)).onError(any(IllegalArgumentException.class));
+ inOrder.verifyNoMoreInteractions();
+ }
+
@Test
public void testTakeDoesntLeakErrors() {
Observable source = Observable.create(new OnSubscribeFunc()