diff --git a/rxjava-core/src/main/java/rx/subjects/PublishSubject.java b/rxjava-core/src/main/java/rx/subjects/PublishSubject.java index 6c375017c0..10e27a10d7 100644 --- a/rxjava-core/src/main/java/rx/subjects/PublishSubject.java +++ b/rxjava-core/src/main/java/rx/subjects/PublishSubject.java @@ -15,18 +15,22 @@ */ package rx.subjects; +import static org.junit.Assert.*; import static org.mockito.Matchers.*; import static org.mockito.Mockito.*; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import junit.framework.Assert; import org.junit.Test; +import org.mockito.InOrder; import org.mockito.Mockito; import rx.Notification; @@ -34,6 +38,7 @@ import rx.Observer; import rx.Subscription; import rx.operators.AtomicObservableSubscription; +import rx.subscriptions.Subscriptions; import rx.util.functions.Action1; import rx.util.functions.Func0; import rx.util.functions.Func1; @@ -62,10 +67,15 @@ public class PublishSubject extends Subject { public static PublishSubject create() { final ConcurrentHashMap> observers = new ConcurrentHashMap>(); - + final AtomicReference> terminalState = new AtomicReference>(); + Func1, Subscription> onSubscribe = new Func1, Subscription>() { @Override public Subscription call(Observer observer) { + // first check if terminal state exist + Subscription s = checkTerminalState(observer); + if(s != null) return s; + final AtomicObservableSubscription subscription = new AtomicObservableSubscription(); subscription.wrap(new Subscription() { @@ -78,41 +88,96 @@ public void unsubscribe() { // on subscribe add it to the map of outbound observers to notify observers.put(subscription, observer); + + // check terminal state again + s = checkTerminalState(observer); + if(s != null) return s; + + /** + * NOTE: There is a race condition here. + * + * 1) terminal state gets set in onError or onCompleted + * 2) observers.put adds a new observer + * 3) checkTerminalState emits onError/onCompleted + * 4) onError or onCompleted also emits onError/onCompleted since it was adds to observers + * + * Thus the terminal state could end up being sent twice. + * + * I'm going to leave this for now as AtomicObserver will protect against this + * and I'd rather not add blocking synchronization in here unless the above race condition + * truly is an issue. + */ + return subscription; } + + private Subscription checkTerminalState(Observer observer) { + Notification n = terminalState.get(); + if (n != null) { + // we are terminated to immediately emit and don't continue with subscription + if (n.isOnCompleted()) { + observer.onCompleted(); + } else { + observer.onError(n.getException()); + } + return Subscriptions.empty(); + } else { + return null; + } + } }; - return new PublishSubject(onSubscribe, observers); + return new PublishSubject(onSubscribe, observers, terminalState); } private final ConcurrentHashMap> observers; + private final AtomicReference> terminalState; - protected PublishSubject(Func1, Subscription> onSubscribe, ConcurrentHashMap> observers) { + protected PublishSubject(Func1, Subscription> onSubscribe, ConcurrentHashMap> observers, AtomicReference> terminalState) { super(onSubscribe); this.observers = observers; + this.terminalState = terminalState; } @Override public void onCompleted() { - for (Observer observer : observers.values()) { + terminalState.set(new Notification()); + for (Observer observer : snapshotOfValues()) { observer.onCompleted(); } + observers.clear(); } @Override public void onError(Exception e) { - for (Observer observer : observers.values()) { + terminalState.set(new Notification(e)); + for (Observer observer : snapshotOfValues()) { observer.onError(e); } + observers.clear(); } @Override public void onNext(T args) { - for (Observer observer : observers.values()) { + for (Observer observer : snapshotOfValues()) { observer.onNext(args); } } + /** + * Current snapshot of 'values()' so that concurrent modifications aren't included. + * + * This makes it behave deterministically in a single-threaded execution when nesting subscribes. + * + * In multi-threaded execution it will cause new subscriptions to wait until the following onNext instead + * of possibly being included in the current onNext iteration. + * + * @return List> + */ + private Collection> snapshotOfValues() { + return new ArrayList>(observers.values()); + } + public static class UnitTest { @Test public void test() { @@ -307,6 +372,75 @@ private void assertObservedUntilTwo(Observer aObserver) verify(aObserver, Mockito.never()).onCompleted(); } + /** + * Test that subscribing after onError/onCompleted immediately terminates instead of causing it to hang. + * + * Nothing is mentioned in Rx Guidelines for what to do in this case so I'm doing what seems to make sense + * which is: + * + * - cache terminal state (onError/onCompleted) + * - any subsequent subscriptions will immediately receive the terminal state rather than start a new subscription + * + */ + @Test + public void testUnsubscribeAfterOnCompleted() { + PublishSubject subject = PublishSubject.create(); + + @SuppressWarnings("unchecked") + Observer anObserver = mock(Observer.class); + subject.subscribe(anObserver); + + subject.onNext("one"); + subject.onNext("two"); + subject.onCompleted(); + + InOrder inOrder = inOrder(anObserver); + inOrder.verify(anObserver, times(1)).onNext("one"); + inOrder.verify(anObserver, times(1)).onNext("two"); + inOrder.verify(anObserver, times(1)).onCompleted(); + inOrder.verify(anObserver, Mockito.never()).onError(any(Exception.class)); + + @SuppressWarnings("unchecked") + Observer anotherObserver = mock(Observer.class); + subject.subscribe(anotherObserver); + + inOrder = inOrder(anotherObserver); + inOrder.verify(anotherObserver, Mockito.never()).onNext("one"); + inOrder.verify(anotherObserver, Mockito.never()).onNext("two"); + inOrder.verify(anotherObserver, times(1)).onCompleted(); + inOrder.verify(anotherObserver, Mockito.never()).onError(any(Exception.class)); + } + + @Test + public void testUnsubscribeAfterOnError() { + PublishSubject subject = PublishSubject.create(); + RuntimeException exception = new RuntimeException("failure"); + + @SuppressWarnings("unchecked") + Observer anObserver = mock(Observer.class); + subject.subscribe(anObserver); + + subject.onNext("one"); + subject.onNext("two"); + subject.onError(exception); + + InOrder inOrder = inOrder(anObserver); + inOrder.verify(anObserver, times(1)).onNext("one"); + inOrder.verify(anObserver, times(1)).onNext("two"); + inOrder.verify(anObserver, times(1)).onError(exception); + inOrder.verify(anObserver, Mockito.never()).onCompleted(); + + @SuppressWarnings("unchecked") + Observer anotherObserver = mock(Observer.class); + subject.subscribe(anotherObserver); + + inOrder = inOrder(anotherObserver); + inOrder.verify(anotherObserver, Mockito.never()).onNext("one"); + inOrder.verify(anotherObserver, Mockito.never()).onNext("two"); + inOrder.verify(anotherObserver, times(1)).onError(exception); + inOrder.verify(anotherObserver, Mockito.never()).onCompleted(); + } + @Test public void testUnsubscribe() { @@ -340,5 +474,58 @@ public void call(PublishSubject DefaultSubject) } }); } + + @Test + public void testNestedSubscribe() { + final PublishSubject s = PublishSubject.create(); + + final AtomicInteger countParent = new AtomicInteger(); + final AtomicInteger countChildren = new AtomicInteger(); + final AtomicInteger countTotal = new AtomicInteger(); + + final ArrayList list = new ArrayList(); + + s.mapMany(new Func1>() { + + @Override + public Observable call(final Integer v) { + countParent.incrementAndGet(); + + // then subscribe to subject again (it will not receive the previous value) + return s.map(new Func1() { + + @Override + public String call(Integer v2) { + countChildren.incrementAndGet(); + return "Parent: " + v + " Child: " + v2; + } + + }); + } + + }).subscribe(new Action1() { + + @Override + public void call(String v) { + countTotal.incrementAndGet(); + list.add(v); + } + + }); + + + for(int i=0; i<10; i++) { + s.onNext(i); + } + s.onCompleted(); + + // System.out.println("countParent: " + countParent.get()); + // System.out.println("countChildren: " + countChildren.get()); + // System.out.println("countTotal: " + countTotal.get()); + + // 9+8+7+6+5+4+3+2+1+0 == 45 + assertEquals(45, list.size()); + } + } }