From 5a9313316ad2ddf1218060e50d6972b5fb719e52 Mon Sep 17 00:00:00 2001 From: akarnokd Date: Fri, 21 Aug 2015 14:05:44 +0200 Subject: [PATCH] Scan backpressure and first emission fix --- .../rx/internal/operators/OperatorScan.java | 340 ++++++++++++++---- .../internal/operators/OperatorScanTest.java | 69 +++- 2 files changed, 321 insertions(+), 88 deletions(-) diff --git a/src/main/java/rx/internal/operators/OperatorScan.java b/src/main/java/rx/internal/operators/OperatorScan.java index 788842100d..1cbdb53d54 100644 --- a/src/main/java/rx/internal/operators/OperatorScan.java +++ b/src/main/java/rx/internal/operators/OperatorScan.java @@ -15,15 +15,14 @@ */ package rx.internal.operators; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.Queue; +import rx.*; import rx.Observable.Operator; -import rx.Producer; -import rx.Subscriber; -import rx.exceptions.Exceptions; -import rx.exceptions.OnErrorThrowable; -import rx.functions.Func0; -import rx.functions.Func2; +import rx.exceptions.*; +import rx.functions.*; +import rx.internal.util.atomic.SpscLinkedAtomicQueue; +import rx.internal.util.unsafe.*; /** * Returns an Observable that applies a function to the first item emitted by a source Observable, then feeds @@ -87,87 +86,290 @@ public OperatorScan(final Func2 accumulator) { @Override public Subscriber call(final Subscriber child) { - return new Subscriber(child) { - private final R initialValue = initialValueFactory.call(); + final R initialValue = initialValueFactory.call(); + + if (initialValue == NO_INITIAL_VALUE) { + return new Subscriber(child) { + boolean once; + R value; + @SuppressWarnings("unchecked") + @Override + public void onNext(T t) { + R v; + if (!once) { + once = true; + v = (R)t; + } else { + v = value; + try { + v = accumulator.call(v, t); + } catch (Throwable e) { + Exceptions.throwIfFatal(e); + child.onError(OnErrorThrowable.addValueAsLastCause(e, t)); + return; + } + } + value = v; + child.onNext(v); + } + @Override + public void onError(Throwable e) { + child.onError(e); + } + @Override + public void onCompleted() { + child.onCompleted(); + } + }; + } + + final InitialProducer ip = new InitialProducer(initialValue, child); + + Subscriber parent = new Subscriber() { private R value = initialValue; - boolean initialized = false; - @SuppressWarnings("unchecked") @Override public void onNext(T currentValue) { - emitInitialValueIfNeeded(child); - - if (this.value == NO_INITIAL_VALUE) { - // if there is NO_INITIAL_VALUE then we know it is type T for both so cast T to R - this.value = (R) currentValue; - } else { - try { - this.value = accumulator.call(this.value, currentValue); - } catch (Throwable e) { - Exceptions.throwIfFatal(e); - child.onError(OnErrorThrowable.addValueAsLastCause(e, currentValue)); - return; - } + R v = value; + try { + v = accumulator.call(v, currentValue); + } catch (Throwable e) { + Exceptions.throwIfFatal(e); + onError(OnErrorThrowable.addValueAsLastCause(e, currentValue)); + return; } - child.onNext(this.value); + value = v; + ip.onNext(v); } @Override public void onError(Throwable e) { - child.onError(e); + ip.onError(e); } @Override public void onCompleted() { - emitInitialValueIfNeeded(child); - child.onCompleted(); - } - - private void emitInitialValueIfNeeded(final Subscriber child) { - if (!initialized) { - initialized = true; - // we emit first time through if we have an initial value - if (initialValue != NO_INITIAL_VALUE) { - child.onNext(initialValue); - } - } + ip.onCompleted(); } - /** - * We want to adjust the requested value by subtracting 1 if we have an initial value - */ @Override public void setProducer(final Producer producer) { - child.setProducer(new Producer() { - - final AtomicBoolean once = new AtomicBoolean(); - - final AtomicBoolean excessive = new AtomicBoolean(); - - @Override - public void request(long n) { - if (once.compareAndSet(false, true)) { - if (initialValue == NO_INITIAL_VALUE || n == Long.MAX_VALUE) { - producer.request(n); - } else if (n == 1) { - excessive.set(true); - producer.request(1); // request at least 1 - } else { - // n != Long.MAX_VALUE && n != 1 - producer.request(n - 1); - } - } else { - // pass-thru after first time - if (n > 1 // avoid to request 0 - && excessive.compareAndSet(true, false) && n != Long.MAX_VALUE) { - producer.request(n - 1); - } else { - producer.request(n); - } + ip.setProducer(producer); + } + }; + + child.add(parent); + child.setProducer(ip); + return parent; + } + + static final class InitialProducer implements Producer, Observer { + final Subscriber child; + final Queue queue; + + boolean emitting; + /** Missed a terminal event. */ + boolean missed; + /** Missed a request. */ + long missedRequested; + /** Missed a producer. */ + Producer missedProducer; + /** The current requested amount. */ + long requested; + /** The current producer. */ + Producer producer; + + volatile boolean done; + Throwable error; + + public InitialProducer(R initialValue, Subscriber child) { + this.child = child; + Queue q; + // TODO switch to the linked-array based queue once available + if (UnsafeAccess.isUnsafeAvailable()) { + q = new SpscLinkedQueue(); // new SpscUnboundedArrayQueue(8); + } else { + q = new SpscLinkedAtomicQueue(); // new SpscUnboundedAtomicArrayQueue(8); + } + this.queue = q; + q.offer(initialValue); + } + + @Override + public void request(long n) { + if (n < 0L) { + throw new IllegalArgumentException("n >= required but it was " + n); + } else + if (n != 0L) { + synchronized (this) { + if (emitting) { + long mr = missedRequested; + long mu = mr + n; + if (mu < 0L) { + mu = Long.MAX_VALUE; } + missedRequested = mu; + return; } - }); + emitting = true; + } + + long r = requested; + long u = r + n; + if (u < 0L) { + u = Long.MAX_VALUE; + } + requested = u; + + Producer p = producer; + if (p != null) { + p.request(n); + } + + emitLoop(); } - }; + } + + @Override + public void onNext(R t) { + queue.offer(NotificationLite.instance().next(t)); + emit(); + } + + boolean checkTerminated(boolean d, boolean empty, Subscriber child) { + if (child.isUnsubscribed()) { + return true; + } + if (d) { + Throwable err = error; + if (err != null) { + child.onError(err); + return true; + } else + if (empty) { + child.onCompleted(); + return true; + } + } + return false; + } + + @Override + public void onError(Throwable e) { + error = e; + done = true; + emit(); + } + + @Override + public void onCompleted() { + done = true; + emit(); + } + + public void setProducer(Producer p) { + if (p == null) { + throw new NullPointerException(); + } + synchronized (this) { + if (emitting) { + missedProducer = p; + return; + } + emitting = true; + } + producer = p; + long r = requested; + if (r != 0L) { + p.request(r); + } + emitLoop(); + } + + void emit() { + synchronized (this) { + if (emitting) { + missed = true; + return; + } + emitting = true; + } + emitLoop(); + } + + void emitLoop() { + final Subscriber child = this.child; + final Queue queue = this.queue; + final NotificationLite nl = NotificationLite.instance(); + long r = requested; + for (;;) { + boolean max = r == Long.MAX_VALUE; + boolean d = done; + boolean empty = queue.isEmpty(); + if (checkTerminated(d, empty, child)) { + return; + } + while (r != 0L) { + d = done; + Object o = queue.poll(); + empty = o == null; + if (checkTerminated(d, empty, child)) { + return; + } + if (empty) { + break; + } + R v = nl.getValue(o); + try { + child.onNext(v); + } catch (Throwable e) { + Exceptions.throwIfFatal(e); + child.onError(OnErrorThrowable.addValueAsLastCause(e, v)); + return; + } + if (!max) { + r--; + } + } + if (!max) { + requested = r; + } + + Producer p; + long mr; + synchronized (this) { + p = missedProducer; + mr = missedRequested; + if (!missed && p == null && mr == 0L) { + emitting = false; + return; + } + missed = false; + missedProducer = null; + missedRequested = 0L; + } + + if (mr != 0L && !max) { + long u = r + mr; + if (u < 0L) { + u = Long.MAX_VALUE; + } + requested = u; + r = u; + } + + if (p != null) { + producer = p; + if (r != 0L) { + p.request(r); + } + } else { + p = producer; + if (p != null && mr != 0L) { + p.request(mr); + } + } + } + } } } diff --git a/src/test/java/rx/internal/operators/OperatorScanTest.java b/src/test/java/rx/internal/operators/OperatorScanTest.java index e05d4d9bb1..ac7772753f 100644 --- a/src/test/java/rx/internal/operators/OperatorScanTest.java +++ b/src/test/java/rx/internal/operators/OperatorScanTest.java @@ -15,33 +15,23 @@ */ package rx.internal.operators; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyInt; -import static org.mockito.Matchers.anyString; +import static org.junit.Assert.*; +import static org.mockito.Matchers.*; import static org.mockito.Mockito.*; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; +import java.util.*; +import java.util.concurrent.atomic.*; -import org.junit.Before; -import org.junit.Test; +import org.junit.*; import org.mockito.MockitoAnnotations; +import rx.*; import rx.Observable; import rx.Observer; -import rx.Producer; -import rx.Subscriber; -import rx.functions.Action2; -import rx.functions.Func0; -import rx.functions.Func1; -import rx.functions.Func2; +import rx.functions.*; +import rx.observables.AbstractOnSubscribe; import rx.observers.TestSubscriber; +import rx.subjects.PublishSubject; public class OperatorScanTest { @@ -360,4 +350,45 @@ public void onNext(Integer integer) { verify(producer.get(), never()).request(0); verify(producer.get(), times(2)).request(1); } + + @Test + public void testInitialValueEmittedNoProducer() { + PublishSubject source = PublishSubject.create(); + + TestSubscriber ts = TestSubscriber.create(); + + source.scan(0, new Func2() { + @Override + public Integer call(Integer t1, Integer t2) { + return t1 + t2; + } + }).subscribe(ts); + + ts.assertNoErrors(); + ts.assertNotCompleted(); + ts.assertValue(0); + } + + @Test + public void testInitialValueEmittedWithProducer() { + Observable source = new AbstractOnSubscribe() { + @Override + protected void next(rx.observables.AbstractOnSubscribe.SubscriptionState state) { + state.stop(); + } + }.toObservable(); + + TestSubscriber ts = TestSubscriber.create(); + + source.scan(0, new Func2() { + @Override + public Integer call(Integer t1, Integer t2) { + return t1 + t2; + } + }).subscribe(ts); + + ts.assertNoErrors(); + ts.assertNotCompleted(); + ts.assertValue(0); + } }