diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientThrottledException.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientThrottledException.java
new file mode 100644
index 000000000000..3e003d35424c
--- /dev/null
+++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientThrottledException.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kinesis;
+
+import com.amazonaws.AmazonClientException;
+
+/** Thrown when the Kinesis client was throttled due to rate limits. */
+class KinesisClientThrottledException extends TransientKinesisException {
+
+ public KinesisClientThrottledException(String s, AmazonClientException e) {
+ super(s, e);
+ }
+}
diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java
index efa673e9ceb7..8506c12caa10 100644
--- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java
+++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java
@@ -40,6 +40,7 @@
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingDeque;
+import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
@@ -95,7 +96,6 @@
*
*
{@code
* public class MyCustomKinesisClientProvider implements AWSClientsProvider {
- * {@literal @}Override
* public AmazonKinesis getKinesisClient() {
* // set up your client here
* }
@@ -149,12 +149,10 @@
* this.customWatermarkPolicy = new WatermarkPolicyFactory.CustomWatermarkPolicy(WatermarkParameters.create());
* }
*
- * @Override
* public Instant getWatermark() {
* return customWatermarkPolicy.getWatermark();
* }
*
- * @Override
* public void update(KinesisRecord record) {
* customWatermarkPolicy.update(record);
* }
@@ -162,7 +160,6 @@
*
* // custom factory
* class MyCustomPolicyFactory implements WatermarkPolicyFactory {
- * @Override
* public WatermarkPolicy createWatermarkPolicy() {
* return new MyCustomPolicy();
* }
@@ -174,6 +171,69 @@
* .withCustomWatermarkPolicy(new MyCustomPolicyFactory())
* }
*
+ * By default Kinesis IO will poll the Kinesis getRecords() API as fast as possible which may
+ * lead to excessive read throttling. To limit the rate of getRecords() calls you can set a rate
+ * limit policy. For example, the default fixed delay policy will limit the rate to one API call per
+ * second per shard:
+ *
+ *
{@code
+ * p.apply(KinesisIO.read()
+ * .withStreamName("streamName")
+ * .withInitialPositionInStream(InitialPositionInStream.LATEST)
+ * .withFixedDelayRateLimitPolicy())
+ * }
+ *
+ * You can also use a fixed delay policy with a specified delay interval, for example:
+ *
+ *
{@code
+ * p.apply(KinesisIO.read()
+ * .withStreamName("streamName")
+ * .withInitialPositionInStream(InitialPositionInStream.LATEST)
+ * .withFixedDelayRateLimitPolicy(Duration.millis(500))
+ * }
+ *
+ * If you need to change the polling interval of a Kinesis pipeline at runtime, for example to
+ * compensate for adding and removing additional consumers to the stream, then you can supply the
+ * delay interval as a function so that you can obtain the current delay interval from some external
+ * source:
+ *
+ *
{@code
+ * p.apply(KinesisIO.read()
+ * .withStreamName("streamName")
+ * .withInitialPositionInStream(InitialPositionInStream.LATEST)
+ * .withDynamicDelayRateLimitPolicy(() -> Duration.millis())
+ * }
+ *
+ * Finally, you can create a custom rate limit policy that responds to successful read calls
+ * and/or read throttling exceptions with your own rate-limiting logic:
+ *
+ *
{@code
+ * // custom policy
+ * public class MyCustomPolicy implements RateLimitPolicy {
+ *
+ * public void onSuccess(List records) throws InterruptedException {
+ * // handle successful getRecords() call
+ * }
+ *
+ * public void onThrottle(KinesisClientThrottledException e) throws InterruptedException {
+ * // handle Kinesis read throttling exception
+ * }
+ * }
+ *
+ * // custom factory
+ * class MyCustomPolicyFactory implements RateLimitPolicyFactory {
+ *
+ * public RateLimitPolicy getRateLimitPolicy() {
+ * return new MyCustomPolicy();
+ * }
+ * }
+ *
+ * p.apply(KinesisIO.read()
+ * .withStreamName("streamName")
+ * .withInitialPositionInStream(InitialPositionInStream.LATEST)
+ * .withCustomRateLimitPolicy(new MyCustomPolicyFactory())
+ * }
+ *
* Writing to Kinesis
*
* Example usage:
@@ -240,6 +300,7 @@ public static Read read() {
.setMaxNumRecords(Long.MAX_VALUE)
.setUpToDateThreshold(Duration.ZERO)
.setWatermarkPolicyFactory(WatermarkPolicyFactory.withArrivalTimePolicy())
+ .setRateLimitPolicyFactory(RateLimitPolicyFactory.withoutLimiter())
.setMaxCapacityPerShard(ShardReadersPool.DEFAULT_CAPACITY_PER_SHARD)
.build();
}
@@ -274,6 +335,8 @@ public abstract static class Read extends PTransform delay) {
+ checkArgument(delay != null, "delay cannot be null");
+ return toBuilder().setRateLimitPolicyFactory(RateLimitPolicyFactory.withDelay(delay)).build();
+ }
+
+ /**
+ * Specifies the {@code RateLimitPolicyFactory} for a custom rate limiter.
+ *
+ * @param rateLimitPolicyFactory Custom rate limit policy factory.
+ */
+ public Read withCustomRateLimitPolicy(RateLimitPolicyFactory rateLimitPolicyFactory) {
+ checkArgument(rateLimitPolicyFactory != null, "rateLimitPolicyFactory cannot be null");
+ return toBuilder().setRateLimitPolicyFactory(rateLimitPolicyFactory).build();
+ }
+
/** Specifies the maximum number of messages per one shard. */
public Read withMaxCapacityPerShard(Integer maxCapacity) {
checkArgument(maxCapacity > 0, "maxCapacity must be positive, but was: %s", maxCapacity);
@@ -442,6 +546,7 @@ public PCollection expand(PBegin input) {
getInitialPosition(),
getUpToDateThreshold(),
getWatermarkPolicyFactory(),
+ getRateLimitPolicyFactory(),
getRequestRecordsLimit(),
getMaxCapacityPerShard()));
diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java
index 9e869f5c11fa..2c4222a112bb 100644
--- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java
+++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java
@@ -40,6 +40,7 @@ class KinesisReader extends UnboundedSource.UnboundedReader {
private final KinesisSource source;
private final CheckpointGenerator initialCheckpointGenerator;
private final WatermarkPolicyFactory watermarkPolicyFactory;
+ private final RateLimitPolicyFactory rateLimitPolicyFactory;
private final Duration upToDateThreshold;
private final Duration backlogBytesCheckThreshold;
private CustomOptional currentRecord = CustomOptional.absent();
@@ -53,6 +54,7 @@ class KinesisReader extends UnboundedSource.UnboundedReader {
CheckpointGenerator initialCheckpointGenerator,
KinesisSource source,
WatermarkPolicyFactory watermarkPolicyFactory,
+ RateLimitPolicyFactory rateLimitPolicyFactory,
Duration upToDateThreshold,
Integer maxCapacityPerShard) {
this(
@@ -60,6 +62,7 @@ class KinesisReader extends UnboundedSource.UnboundedReader {
initialCheckpointGenerator,
source,
watermarkPolicyFactory,
+ rateLimitPolicyFactory,
upToDateThreshold,
Duration.standardSeconds(30),
maxCapacityPerShard);
@@ -70,6 +73,7 @@ class KinesisReader extends UnboundedSource.UnboundedReader {
CheckpointGenerator initialCheckpointGenerator,
KinesisSource source,
WatermarkPolicyFactory watermarkPolicyFactory,
+ RateLimitPolicyFactory rateLimitPolicyFactory,
Duration upToDateThreshold,
Duration backlogBytesCheckThreshold,
Integer maxCapacityPerShard) {
@@ -77,6 +81,7 @@ class KinesisReader extends UnboundedSource.UnboundedReader {
this.initialCheckpointGenerator =
checkNotNull(initialCheckpointGenerator, "initialCheckpointGenerator");
this.watermarkPolicyFactory = watermarkPolicyFactory;
+ this.rateLimitPolicyFactory = rateLimitPolicyFactory;
this.source = source;
this.upToDateThreshold = upToDateThreshold;
this.backlogBytesCheckThreshold = backlogBytesCheckThreshold;
@@ -185,6 +190,7 @@ ShardReadersPool createShardReadersPool() throws TransientKinesisException {
kinesis,
initialCheckpointGenerator.generate(kinesis),
watermarkPolicyFactory,
+ rateLimitPolicyFactory,
maxCapacityPerShard);
}
}
diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java
index a9d05f3deef2..98f7a88af005 100644
--- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java
+++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java
@@ -38,6 +38,7 @@ class KinesisSource extends UnboundedSource split(int desiredNumSplits, PipelineOptions options)
streamName,
upToDateThreshold,
watermarkPolicyFactory,
+ rateLimitPolicyFactory,
limit,
maxCapacityPerShard));
}
@@ -126,6 +132,7 @@ public UnboundedReader createReader(
checkpointGenerator,
this,
watermarkPolicyFactory,
+ rateLimitPolicyFactory,
upToDateThreshold,
maxCapacityPerShard);
}
@@ -139,6 +146,8 @@ public Coder getCheckpointMarkCoder() {
public void validate() {
checkNotNull(awsClientsProvider);
checkNotNull(initialCheckpointGenerator);
+ checkNotNull(watermarkPolicyFactory);
+ checkNotNull(rateLimitPolicyFactory);
}
@Override
diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicy.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicy.java
new file mode 100644
index 000000000000..8ee1e81558f7
--- /dev/null
+++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicy.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kinesis;
+
+import java.util.List;
+
+public interface RateLimitPolicy {
+
+ /**
+ * Called after Kinesis records are successfully retrieved.
+ *
+ * @param records The list of retrieved records.
+ */
+ default void onSuccess(List records) throws InterruptedException {}
+
+ /**
+ * Called after the Kinesis client is throttled.
+ *
+ * @param e The {@code KinesisClientThrottledException} thrown by the client.
+ */
+ default void onThrottle(KinesisClientThrottledException e) throws InterruptedException {}
+}
diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactory.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactory.java
new file mode 100644
index 000000000000..54bc78835c30
--- /dev/null
+++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactory.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kinesis;
+
+import java.io.Serializable;
+import java.util.List;
+import java.util.function.Supplier;
+import org.joda.time.Duration;
+
+/**
+ * Implement this interface to create a {@code RateLimitPolicy}. Used to create a rate limiter for
+ * each shard.
+ */
+public interface RateLimitPolicyFactory extends Serializable {
+
+ RateLimitPolicy getRateLimitPolicy();
+
+ static RateLimitPolicyFactory withoutLimiter() {
+ return () -> new RateLimitPolicy() {};
+ }
+
+ static RateLimitPolicyFactory withFixedDelay() {
+ return DelayIntervalRateLimiter::new;
+ }
+
+ static RateLimitPolicyFactory withFixedDelay(Duration delay) {
+ return () -> new DelayIntervalRateLimiter(() -> delay);
+ }
+
+ static RateLimitPolicyFactory withDelay(Supplier delay) {
+ return () -> new DelayIntervalRateLimiter(delay);
+ }
+
+ class DelayIntervalRateLimiter implements RateLimitPolicy {
+
+ private static final Supplier DEFAULT_DELAY = () -> Duration.standardSeconds(1);
+
+ private final Supplier delay;
+
+ public DelayIntervalRateLimiter() {
+ this(DEFAULT_DELAY);
+ }
+
+ public DelayIntervalRateLimiter(Supplier delay) {
+ this.delay = delay;
+ }
+
+ @Override
+ public void onSuccess(List records) throws InterruptedException {
+ Thread.sleep(delay.get().getMillis());
+ }
+ }
+}
diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java
index 195101ce9c0d..7348c6eb3073 100644
--- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java
+++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java
@@ -75,6 +75,7 @@ class ShardReadersPool {
private final SimplifiedKinesisClient kinesis;
private final WatermarkPolicyFactory watermarkPolicyFactory;
+ private final RateLimitPolicyFactory rateLimitPolicyFactory;
private final KinesisReaderCheckpoint initialCheckpoint;
private final int queueCapacityPerShard;
private final AtomicBoolean poolOpened = new AtomicBoolean(true);
@@ -83,10 +84,12 @@ class ShardReadersPool {
SimplifiedKinesisClient kinesis,
KinesisReaderCheckpoint initialCheckpoint,
WatermarkPolicyFactory watermarkPolicyFactory,
+ RateLimitPolicyFactory rateLimitPolicyFactory,
int queueCapacityPerShard) {
this.kinesis = kinesis;
this.initialCheckpoint = initialCheckpoint;
this.watermarkPolicyFactory = watermarkPolicyFactory;
+ this.rateLimitPolicyFactory = rateLimitPolicyFactory;
this.queueCapacityPerShard = queueCapacityPerShard;
this.executorService = Executors.newCachedThreadPool();
this.numberOfRecordsInAQueueByShard = new ConcurrentHashMap<>();
@@ -115,11 +118,12 @@ void start() throws TransientKinesisException {
void startReadingShards(Iterable shardRecordsIterators) {
for (final ShardRecordsIterator recordsIterator : shardRecordsIterators) {
numberOfRecordsInAQueueByShard.put(recordsIterator.getShardId(), new AtomicInteger());
- executorService.submit(() -> readLoop(recordsIterator));
+ executorService.submit(
+ () -> readLoop(recordsIterator, rateLimitPolicyFactory.getRateLimitPolicy()));
}
}
- private void readLoop(ShardRecordsIterator shardRecordsIterator) {
+ private void readLoop(ShardRecordsIterator shardRecordsIterator, RateLimitPolicy rateLimiter) {
while (poolOpened.get()) {
try {
List kinesisRecords;
@@ -143,10 +147,20 @@ private void readLoop(ShardRecordsIterator shardRecordsIterator) {
recordsQueue.put(kinesisRecord);
numberOfRecordsInAQueueByShard.get(kinesisRecord.getShardId()).incrementAndGet();
}
+ rateLimiter.onSuccess(kinesisRecords);
+ } catch (KinesisClientThrottledException e) {
+ try {
+ rateLimiter.onThrottle(e);
+ } catch (InterruptedException ex) {
+ LOG.warn("Thread was interrupted, finishing the read loop", ex);
+ Thread.currentThread().interrupt();
+ break;
+ }
} catch (TransientKinesisException e) {
LOG.warn("Transient exception occurred.", e);
} catch (InterruptedException e) {
LOG.warn("Thread was interrupted, finishing the read loop", e);
+ Thread.currentThread().interrupt();
break;
} catch (Throwable e) {
LOG.error("Unexpected exception occurred", e);
diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java
index 7cf4cd5627a1..38a3f5a29e83 100644
--- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java
+++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java
@@ -211,7 +211,7 @@ private T wrapExceptions(Callable callable) throws TransientKinesisExcept
} catch (ExpiredIteratorException e) {
throw e;
} catch (LimitExceededException | ProvisionedThroughputExceededException e) {
- throw new TransientKinesisException(
+ throw new KinesisClientThrottledException(
"Too many requests to Kinesis. Wait some time and retry.", e);
} catch (AmazonServiceException e) {
if (e.getErrorType() == AmazonServiceException.ErrorType.Service) {
diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java
index 060af47e269e..812a598a154e 100644
--- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java
+++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java
@@ -69,6 +69,7 @@ private KinesisReader createReader(Duration backlogBytesCheckThreshold) {
generator,
kinesisSource,
WatermarkPolicyFactory.withArrivalTimePolicy(),
+ RateLimitPolicyFactory.withoutLimiter(),
Duration.ZERO,
backlogBytesCheckThreshold,
ShardReadersPool.DEFAULT_CAPACITY_PER_SHARD) {
diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactoryTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactoryTest.java
new file mode 100644
index 000000000000..dccfc8cb92cb
--- /dev/null
+++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactoryTest.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kinesis;
+
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.never;
+import static org.powermock.api.mockito.PowerMockito.verifyStatic;
+
+import java.util.concurrent.atomic.AtomicLong;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.joda.time.Duration;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.api.mockito.PowerMockito;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+
+@RunWith(PowerMockRunner.class)
+@PrepareForTest(RateLimitPolicyFactory.class)
+public class RateLimitPolicyFactoryTest {
+
+ @Test
+ public void defaultPolicyShouldDoNothing() throws Exception {
+ PowerMockito.spy(Thread.class);
+ PowerMockito.doNothing().when(Thread.class);
+ Thread.sleep(anyLong());
+ RateLimitPolicy rateLimitPolicy = RateLimitPolicyFactory.withoutLimiter().getRateLimitPolicy();
+ rateLimitPolicy.onSuccess(ImmutableList.of());
+ verifyStatic(Thread.class, never());
+ Thread.sleep(anyLong());
+ }
+
+ @Test
+ public void shouldDelayDefaultInterval() throws Exception {
+ PowerMockito.spy(Thread.class);
+ PowerMockito.doNothing().when(Thread.class);
+ Thread.sleep(anyLong());
+ RateLimitPolicy rateLimitPolicy = RateLimitPolicyFactory.withFixedDelay().getRateLimitPolicy();
+ rateLimitPolicy.onSuccess(ImmutableList.of());
+ verifyStatic(Thread.class);
+ Thread.sleep(eq(1000L));
+ }
+
+ @Test
+ public void shouldDelayFixedInterval() throws Exception {
+ PowerMockito.spy(Thread.class);
+ PowerMockito.doNothing().when(Thread.class);
+ Thread.sleep(anyLong());
+ RateLimitPolicy rateLimitPolicy =
+ RateLimitPolicyFactory.withFixedDelay(Duration.millis(500)).getRateLimitPolicy();
+ rateLimitPolicy.onSuccess(ImmutableList.of());
+ verifyStatic(Thread.class);
+ Thread.sleep(eq(500L));
+ }
+
+ @Test
+ public void shouldDelayDynamicInterval() throws Exception {
+ PowerMockito.spy(Thread.class);
+ PowerMockito.doNothing().when(Thread.class);
+ Thread.sleep(anyLong());
+ AtomicLong delay = new AtomicLong(0L);
+ RateLimitPolicy rateLimitPolicy =
+ RateLimitPolicyFactory.withDelay(() -> Duration.millis(delay.getAndUpdate(d -> d ^ 1)))
+ .getRateLimitPolicy();
+ rateLimitPolicy.onSuccess(ImmutableList.of());
+ verifyStatic(Thread.class);
+ Thread.sleep(eq(0L));
+ Thread.sleep(eq(1L));
+ Thread.sleep(eq(0L));
+ Thread.sleep(eq(1L));
+ }
+}
diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java
index 0d9e9a37ed92..59ab92ea2224 100644
--- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java
+++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java
@@ -21,6 +21,8 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
+import static org.mockito.Matchers.same;
+import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
@@ -54,7 +56,9 @@ public class ShardReadersPoolTest {
@Mock private ShardCheckpoint firstCheckpoint, secondCheckpoint;
@Mock private SimplifiedKinesisClient kinesis;
@Mock private KinesisRecord a, b, c, d;
- @Mock private WatermarkPolicyFactory factory;
+ @Mock private WatermarkPolicyFactory watermarkPolicyFactory;
+ @Mock private RateLimitPolicyFactory rateLimitPolicyFactory;
+ @Mock private RateLimitPolicy customRateLimitPolicy;
private KinesisReaderCheckpoint checkpoint;
private ShardReadersPool shardReadersPool;
@@ -75,12 +79,18 @@ public void setUp() throws TransientKinesisException {
when(thirdIterator.getShardId()).thenReturn("shard3");
when(fourthIterator.getShardId()).thenReturn("shard4");
- WatermarkPolicy policy = WatermarkPolicyFactory.withArrivalTimePolicy().createWatermarkPolicy();
+ WatermarkPolicy watermarkPolicy =
+ WatermarkPolicyFactory.withArrivalTimePolicy().createWatermarkPolicy();
+ RateLimitPolicy rateLimitPolicy = RateLimitPolicyFactory.withoutLimiter().getRateLimitPolicy();
checkpoint = new KinesisReaderCheckpoint(ImmutableList.of(firstCheckpoint, secondCheckpoint));
- shardReadersPool = Mockito.spy(new ShardReadersPool(kinesis, checkpoint, factory, 100));
+ shardReadersPool =
+ Mockito.spy(
+ new ShardReadersPool(
+ kinesis, checkpoint, watermarkPolicyFactory, rateLimitPolicyFactory, 100));
- when(factory.createWatermarkPolicy()).thenReturn(policy);
+ when(watermarkPolicyFactory.createWatermarkPolicy()).thenReturn(watermarkPolicy);
+ when(rateLimitPolicyFactory.getRateLimitPolicy()).thenReturn(rateLimitPolicy);
doReturn(firstIterator).when(shardReadersPool).createShardIterator(kinesis, firstCheckpoint);
doReturn(secondIterator).when(shardReadersPool).createShardIterator(kinesis, secondCheckpoint);
@@ -178,8 +188,10 @@ public void shouldInterruptPuttingRecordsToQueueAndStopShortly()
new KinesisReaderCheckpoint(ImmutableList.of(firstCheckpoint, secondCheckpoint));
WatermarkPolicyFactory watermarkPolicyFactory = WatermarkPolicyFactory.withArrivalTimePolicy();
+ RateLimitPolicyFactory rateLimitPolicyFactory = RateLimitPolicyFactory.withoutLimiter();
ShardReadersPool shardReadersPool =
- new ShardReadersPool(kinesis, checkpoint, watermarkPolicyFactory, 2);
+ new ShardReadersPool(
+ kinesis, checkpoint, watermarkPolicyFactory, rateLimitPolicyFactory, 2);
shardReadersPool.start();
Stopwatch stopwatch = Stopwatch.createStarted();
@@ -237,12 +249,14 @@ public void shouldStopReadersPoolAlsoWhenExceptionsOccurDuringStopping() throws
public void shouldReturnAbsentOptionalWhenStartedWithNoIterators() throws Exception {
KinesisReaderCheckpoint checkpoint = new KinesisReaderCheckpoint(Collections.emptyList());
WatermarkPolicyFactory watermarkPolicyFactory = WatermarkPolicyFactory.withArrivalTimePolicy();
+ RateLimitPolicyFactory rateLimitPolicyFactory = RateLimitPolicyFactory.withoutLimiter();
shardReadersPool =
Mockito.spy(
new ShardReadersPool(
kinesis,
checkpoint,
watermarkPolicyFactory,
+ rateLimitPolicyFactory,
ShardReadersPool.DEFAULT_CAPACITY_PER_SHARD));
doReturn(firstIterator)
.when(shardReadersPool)
@@ -286,4 +300,34 @@ public void shouldReturnTheLeastWatermarkOfAllShards() throws TransientKinesisEx
verify(firstIterator, times(2)).getShardWatermark();
verify(secondIterator, times(2)).getShardWatermark();
}
+
+ @Test
+ public void shouldCallRateLimitPolicy()
+ throws TransientKinesisException, KinesisShardClosedException, InterruptedException {
+ KinesisClientThrottledException e = new KinesisClientThrottledException("", null);
+ when(firstIterator.readNextBatch())
+ .thenThrow(e)
+ .thenReturn(ImmutableList.of(a, b))
+ .thenReturn(Collections.emptyList());
+ when(secondIterator.readNextBatch())
+ .thenReturn(singletonList(c))
+ .thenReturn(singletonList(d))
+ .thenReturn(Collections.emptyList());
+ when(rateLimitPolicyFactory.getRateLimitPolicy()).thenReturn(customRateLimitPolicy);
+
+ shardReadersPool.start();
+ List fetchedRecords = new ArrayList<>();
+ while (fetchedRecords.size() < 4) {
+ CustomOptional nextRecord = shardReadersPool.nextRecord();
+ if (nextRecord.isPresent()) {
+ fetchedRecords.add(nextRecord.get());
+ }
+ }
+
+ verify(customRateLimitPolicy).onThrottle(same(e));
+ verify(customRateLimitPolicy).onSuccess(eq(ImmutableList.of(a, b)));
+ verify(customRateLimitPolicy).onSuccess(eq(singletonList(c)));
+ verify(customRateLimitPolicy).onSuccess(eq(singletonList(d)));
+ verify(customRateLimitPolicy, atLeastOnce()).onSuccess(eq(Collections.emptyList()));
+ }
}
diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java
index 7858de779aa6..1c1d9e027e10 100644
--- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java
+++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java
@@ -116,13 +116,13 @@ public void shouldHandleExpiredIterationExceptionForGetShardIterator() {
@Test
public void shouldHandleLimitExceededExceptionForGetShardIterator() {
shouldHandleGetShardIteratorError(
- new LimitExceededException(""), TransientKinesisException.class);
+ new LimitExceededException(""), KinesisClientThrottledException.class);
}
@Test
public void shouldHandleProvisionedThroughputExceededExceptionForGetShardIterator() {
shouldHandleGetShardIteratorError(
- new ProvisionedThroughputExceededException(""), TransientKinesisException.class);
+ new ProvisionedThroughputExceededException(""), KinesisClientThrottledException.class);
}
@Test
@@ -190,13 +190,14 @@ public void shouldHandleExpiredIterationExceptionForShardListing() {
@Test
public void shouldHandleLimitExceededExceptionForShardListing() {
- shouldHandleShardListingError(new LimitExceededException(""), TransientKinesisException.class);
+ shouldHandleShardListingError(
+ new LimitExceededException(""), KinesisClientThrottledException.class);
}
@Test
public void shouldHandleProvisionedThroughputExceededExceptionForShardListing() {
shouldHandleShardListingError(
- new ProvisionedThroughputExceededException(""), TransientKinesisException.class);
+ new ProvisionedThroughputExceededException(""), KinesisClientThrottledException.class);
}
@Test
@@ -281,13 +282,13 @@ public void shouldNotCallCloudWatchWhenSpecifiedPeriodTooShort() throws Exceptio
@Test
public void shouldHandleLimitExceededExceptionForGetBacklogBytes() {
shouldHandleGetBacklogBytesError(
- new LimitExceededException(""), TransientKinesisException.class);
+ new LimitExceededException(""), KinesisClientThrottledException.class);
}
@Test
public void shouldHandleProvisionedThroughputExceededExceptionForGetBacklogBytes() {
shouldHandleGetBacklogBytesError(
- new ProvisionedThroughputExceededException(""), TransientKinesisException.class);
+ new ProvisionedThroughputExceededException(""), KinesisClientThrottledException.class);
}
@Test