From 1236f07d7d802758f9bcc7dcf87a2df166d7e32c Mon Sep 17 00:00:00 2001 From: Dave Moten Date: Fri, 15 May 2015 10:56:49 +1000 Subject: [PATCH] add request overflow checks and prevent Long.MAX_VALUE requests being decremented in OperatorGroupBy, added unit test that failed with previous code --- .../internal/operators/OperatorGroupBy.java | 16 +++++-- .../operators/OperatorGroupByTest.java | 47 +++++++++++++++++++ 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/src/main/java/rx/internal/operators/OperatorGroupBy.java b/src/main/java/rx/internal/operators/OperatorGroupBy.java index ef548066f3..93631569df 100644 --- a/src/main/java/rx/internal/operators/OperatorGroupBy.java +++ b/src/main/java/rx/internal/operators/OperatorGroupBy.java @@ -194,7 +194,7 @@ public void onError(Throwable e) { // If we already have items queued when a request comes in we vend those and decrement the outstanding request count void requestFromGroupedObservable(long n, GroupState group) { - group.requested.getAndAdd(n); + BackpressureUtils.getAndAddRequest(group.requested, n); if (group.count.getAndIncrement() == 0) { pollQueue(group); } @@ -330,13 +330,19 @@ private void cleanupGroup(Object key) { private void emitItem(GroupState groupState, Object item) { Queue q = groupState.buffer; AtomicLong keyRequested = groupState.requested; + //don't need to check for requested being Long.MAX_VALUE because this + //field is capped at MAX_QUEUE_SIZE REQUESTED.decrementAndGet(this); // short circuit buffering if (keyRequested != null && keyRequested.get() > 0 && (q == null || q.isEmpty())) { @SuppressWarnings("unchecked") Observer obs = (Observer)groupState.getObserver(); nl.accept(obs, item); - keyRequested.decrementAndGet(); + if (keyRequested.get() != Long.MAX_VALUE) { + // best endeavours check (no CAS loop here) because we mainly care about + // the initial request being Long.MAX_VALUE and that value being conserved. + keyRequested.decrementAndGet(); + } } else { q.add(item); BUFFERED_COUNT.incrementAndGet(this); @@ -381,7 +387,11 @@ private void drainIfPossible(GroupState groupState) { @SuppressWarnings("unchecked") Observer obs = (Observer)groupState.getObserver(); nl.accept(obs, t); - groupState.requested.decrementAndGet(); + if (groupState.requested.get()!=Long.MAX_VALUE) { + // best endeavours check (no CAS loop here) because we mainly care about + // the initial request being Long.MAX_VALUE and that value being conserved. + groupState.requested.decrementAndGet(); + } BUFFERED_COUNT.decrementAndGet(this); // if we have used up all the events we requested from upstream then figure out what to ask for this time based on the empty space in the buffer diff --git a/src/test/java/rx/internal/operators/OperatorGroupByTest.java b/src/test/java/rx/internal/operators/OperatorGroupByTest.java index 42023508d3..b14b7ad373 100644 --- a/src/test/java/rx/internal/operators/OperatorGroupByTest.java +++ b/src/test/java/rx/internal/operators/OperatorGroupByTest.java @@ -34,6 +34,7 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -1454,4 +1455,50 @@ public Integer call(Integer i) { assertEquals(Arrays.asList(e), inner1.getOnErrorEvents()); assertEquals(Arrays.asList(e), inner2.getOnErrorEvents()); } + + @Test + public void testRequestOverflow() { + final AtomicBoolean completed = new AtomicBoolean(false); + Observable + .just(1, 2, 3) + // group into one group + .groupBy(new Func1() { + @Override + public Integer call(Integer t) { + return 1; + } + }) + // flatten + .concatMap(new Func1, Observable>() { + @Override + public Observable call(GroupedObservable g) { + return g; + } + }) + .subscribe(new Subscriber() { + + @Override + public void onStart() { + request(2); + } + + @Override + public void onCompleted() { + completed.set(true); + + } + + @Override + public void onError(Throwable e) { + + } + + @Override + public void onNext(Integer t) { + System.out.println(t); + //provoke possible request overflow + request(Long.MAX_VALUE-1); + }}); + assertTrue(completed.get()); + } }