diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientThrottledException.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientThrottledException.java new file mode 100644 index 000000000000..3e003d35424c --- /dev/null +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisClientThrottledException.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kinesis; + +import com.amazonaws.AmazonClientException; + +/** Thrown when the Kinesis client was throttled due to rate limits. 
*/ +class KinesisClientThrottledException extends TransientKinesisException { + + public KinesisClientThrottledException(String s, AmazonClientException e) { + super(s, e); + } +} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java index efa673e9ceb7..8506c12caa10 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisIO.java @@ -40,6 +40,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingDeque; +import java.util.function.Supplier; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; @@ -95,7 +96,6 @@ * *
{@code
  * public class MyCustomKinesisClientProvider implements AWSClientsProvider {
- *   {@literal @}Override
  *   public AmazonKinesis getKinesisClient() {
  *     // set up your client here
  *   }
@@ -149,12 +149,10 @@
  *       this.customWatermarkPolicy = new WatermarkPolicyFactory.CustomWatermarkPolicy(WatermarkParameters.create());
  *     }
  *
- *     @Override
  *     public Instant getWatermark() {
  *       return customWatermarkPolicy.getWatermark();
  *     }
  *
- *     @Override
  *     public void update(KinesisRecord record) {
  *       customWatermarkPolicy.update(record);
  *     }
@@ -162,7 +160,6 @@
  *
  * // custom factory
  * class MyCustomPolicyFactory implements WatermarkPolicyFactory {
- *     @Override
  *     public WatermarkPolicy createWatermarkPolicy() {
  *       return new MyCustomPolicy();
  *     }
@@ -174,6 +171,69 @@
  *    .withCustomWatermarkPolicy(new MyCustomPolicyFactory())
  * }</pre>
  *
+ * <p>By default Kinesis IO will poll the Kinesis getRecords() API as fast as possible which may
+ * lead to excessive read throttling. To limit the rate of getRecords() calls you can set a rate
+ * limit policy. For example, the default fixed delay policy will limit the rate to one API call per
+ * second per shard:
+ *
+ * <pre>{@code
+ * p.apply(KinesisIO.read()
+ *    .withStreamName("streamName")
+ *    .withInitialPositionInStream(InitialPositionInStream.LATEST)
+ *    .withFixedDelayRateLimitPolicy())
+ * }</pre>
+ *
+ * <p>You can also use a fixed delay policy with a specified delay interval, for example:
+ *
+ * <pre>{@code
+ * p.apply(KinesisIO.read()
+ *    .withStreamName("streamName")
+ *    .withInitialPositionInStream(InitialPositionInStream.LATEST)
+ *    .withFixedDelayRateLimitPolicy(Duration.millis(500)))
+ * }</pre>
+ *
+ * <p>If you need to change the polling interval of a Kinesis pipeline at runtime, for example to
+ * compensate for adding and removing additional consumers to the stream, then you can supply the
+ * delay interval as a function so that you can obtain the current delay interval from some external
+ * source:
+ *
+ * <pre>{@code
+ * p.apply(KinesisIO.read()
+ *    .withStreamName("streamName")
+ *    .withInitialPositionInStream(InitialPositionInStream.LATEST)
+ *    .withDynamicDelayRateLimitPolicy(() -> Duration.millis(currentDelayMillis())))
+ * }</pre>
+ *
+ * <p>Finally, you can create a custom rate limit policy that responds to successful read calls
+ * and/or read throttling exceptions with your own rate-limiting logic:
+ *
+ * <pre>{@code
+ * // custom policy
+ * public class MyCustomPolicy implements RateLimitPolicy {
+ *
+ *   public void onSuccess(List<KinesisRecord> records) throws InterruptedException {
+ *     // handle successful getRecords() call
+ *   }
+ *
+ *   public void onThrottle(KinesisClientThrottledException e) throws InterruptedException {
+ *     // handle Kinesis read throttling exception
+ *   }
+ * }
+ *
+ * // custom factory
+ * class MyCustomPolicyFactory implements RateLimitPolicyFactory {
+ *
+ *   public RateLimitPolicy getRateLimitPolicy() {
+ *     return new MyCustomPolicy();
+ *   }
+ * }
+ *
+ * p.apply(KinesisIO.read()
+ *    .withStreamName("streamName")
+ *    .withInitialPositionInStream(InitialPositionInStream.LATEST)
+ *    .withCustomRateLimitPolicy(new MyCustomPolicyFactory()))
+ * }
+ * *

Writing to Kinesis

* *

Example usage: @@ -240,6 +300,7 @@ public static Read read() { .setMaxNumRecords(Long.MAX_VALUE) .setUpToDateThreshold(Duration.ZERO) .setWatermarkPolicyFactory(WatermarkPolicyFactory.withArrivalTimePolicy()) + .setRateLimitPolicyFactory(RateLimitPolicyFactory.withoutLimiter()) .setMaxCapacityPerShard(ShardReadersPool.DEFAULT_CAPACITY_PER_SHARD) .build(); } @@ -274,6 +335,8 @@ public abstract static class Read extends PTransform delay) { + checkArgument(delay != null, "delay cannot be null"); + return toBuilder().setRateLimitPolicyFactory(RateLimitPolicyFactory.withDelay(delay)).build(); + } + + /** + * Specifies the {@code RateLimitPolicyFactory} for a custom rate limiter. + * + * @param rateLimitPolicyFactory Custom rate limit policy factory. + */ + public Read withCustomRateLimitPolicy(RateLimitPolicyFactory rateLimitPolicyFactory) { + checkArgument(rateLimitPolicyFactory != null, "rateLimitPolicyFactory cannot be null"); + return toBuilder().setRateLimitPolicyFactory(rateLimitPolicyFactory).build(); + } + /** Specifies the maximum number of messages per one shard. 
*/ public Read withMaxCapacityPerShard(Integer maxCapacity) { checkArgument(maxCapacity > 0, "maxCapacity must be positive, but was: %s", maxCapacity); @@ -442,6 +546,7 @@ public PCollection expand(PBegin input) { getInitialPosition(), getUpToDateThreshold(), getWatermarkPolicyFactory(), + getRateLimitPolicyFactory(), getRequestRecordsLimit(), getMaxCapacityPerShard())); diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java index 9e869f5c11fa..2c4222a112bb 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisReader.java @@ -40,6 +40,7 @@ class KinesisReader extends UnboundedSource.UnboundedReader { private final KinesisSource source; private final CheckpointGenerator initialCheckpointGenerator; private final WatermarkPolicyFactory watermarkPolicyFactory; + private final RateLimitPolicyFactory rateLimitPolicyFactory; private final Duration upToDateThreshold; private final Duration backlogBytesCheckThreshold; private CustomOptional currentRecord = CustomOptional.absent(); @@ -53,6 +54,7 @@ class KinesisReader extends UnboundedSource.UnboundedReader { CheckpointGenerator initialCheckpointGenerator, KinesisSource source, WatermarkPolicyFactory watermarkPolicyFactory, + RateLimitPolicyFactory rateLimitPolicyFactory, Duration upToDateThreshold, Integer maxCapacityPerShard) { this( @@ -60,6 +62,7 @@ class KinesisReader extends UnboundedSource.UnboundedReader { initialCheckpointGenerator, source, watermarkPolicyFactory, + rateLimitPolicyFactory, upToDateThreshold, Duration.standardSeconds(30), maxCapacityPerShard); @@ -70,6 +73,7 @@ class KinesisReader extends UnboundedSource.UnboundedReader { CheckpointGenerator initialCheckpointGenerator, KinesisSource source, WatermarkPolicyFactory watermarkPolicyFactory, + 
RateLimitPolicyFactory rateLimitPolicyFactory, Duration upToDateThreshold, Duration backlogBytesCheckThreshold, Integer maxCapacityPerShard) { @@ -77,6 +81,7 @@ class KinesisReader extends UnboundedSource.UnboundedReader { this.initialCheckpointGenerator = checkNotNull(initialCheckpointGenerator, "initialCheckpointGenerator"); this.watermarkPolicyFactory = watermarkPolicyFactory; + this.rateLimitPolicyFactory = rateLimitPolicyFactory; this.source = source; this.upToDateThreshold = upToDateThreshold; this.backlogBytesCheckThreshold = backlogBytesCheckThreshold; @@ -185,6 +190,7 @@ ShardReadersPool createShardReadersPool() throws TransientKinesisException { kinesis, initialCheckpointGenerator.generate(kinesis), watermarkPolicyFactory, + rateLimitPolicyFactory, maxCapacityPerShard); } } diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java index a9d05f3deef2..98f7a88af005 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/KinesisSource.java @@ -38,6 +38,7 @@ class KinesisSource extends UnboundedSource split(int desiredNumSplits, PipelineOptions options) streamName, upToDateThreshold, watermarkPolicyFactory, + rateLimitPolicyFactory, limit, maxCapacityPerShard)); } @@ -126,6 +132,7 @@ public UnboundedReader createReader( checkpointGenerator, this, watermarkPolicyFactory, + rateLimitPolicyFactory, upToDateThreshold, maxCapacityPerShard); } @@ -139,6 +146,8 @@ public Coder getCheckpointMarkCoder() { public void validate() { checkNotNull(awsClientsProvider); checkNotNull(initialCheckpointGenerator); + checkNotNull(watermarkPolicyFactory); + checkNotNull(rateLimitPolicyFactory); } @Override diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicy.java 
b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicy.java new file mode 100644 index 000000000000..8ee1e81558f7 --- /dev/null +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicy.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kinesis; + +import java.util.List; + +public interface RateLimitPolicy { + + /** + * Called after Kinesis records are successfully retrieved. + * + * @param records The list of retrieved records. + */ + default void onSuccess(List records) throws InterruptedException {} + + /** + * Called after the Kinesis client is throttled. + * + * @param e The {@code KinesisClientThrottledException} thrown by the client. 
+ */ + default void onThrottle(KinesisClientThrottledException e) throws InterruptedException {} +} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactory.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactory.java new file mode 100644 index 000000000000..54bc78835c30 --- /dev/null +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactory.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kinesis; + +import java.io.Serializable; +import java.util.List; +import java.util.function.Supplier; +import org.joda.time.Duration; + +/** + * Implement this interface to create a {@code RateLimitPolicy}. Used to create a rate limiter for + * each shard. 
+ */ +public interface RateLimitPolicyFactory extends Serializable { + + RateLimitPolicy getRateLimitPolicy(); + + static RateLimitPolicyFactory withoutLimiter() { + return () -> new RateLimitPolicy() {}; + } + + static RateLimitPolicyFactory withFixedDelay() { + return DelayIntervalRateLimiter::new; + } + + static RateLimitPolicyFactory withFixedDelay(Duration delay) { + return () -> new DelayIntervalRateLimiter(() -> delay); + } + + static RateLimitPolicyFactory withDelay(Supplier delay) { + return () -> new DelayIntervalRateLimiter(delay); + } + + class DelayIntervalRateLimiter implements RateLimitPolicy { + + private static final Supplier DEFAULT_DELAY = () -> Duration.standardSeconds(1); + + private final Supplier delay; + + public DelayIntervalRateLimiter() { + this(DEFAULT_DELAY); + } + + public DelayIntervalRateLimiter(Supplier delay) { + this.delay = delay; + } + + @Override + public void onSuccess(List records) throws InterruptedException { + Thread.sleep(delay.get().getMillis()); + } + } +} diff --git a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java index 195101ce9c0d..7348c6eb3073 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/ShardReadersPool.java @@ -75,6 +75,7 @@ class ShardReadersPool { private final SimplifiedKinesisClient kinesis; private final WatermarkPolicyFactory watermarkPolicyFactory; + private final RateLimitPolicyFactory rateLimitPolicyFactory; private final KinesisReaderCheckpoint initialCheckpoint; private final int queueCapacityPerShard; private final AtomicBoolean poolOpened = new AtomicBoolean(true); @@ -83,10 +84,12 @@ class ShardReadersPool { SimplifiedKinesisClient kinesis, KinesisReaderCheckpoint initialCheckpoint, WatermarkPolicyFactory watermarkPolicyFactory, + 
RateLimitPolicyFactory rateLimitPolicyFactory, int queueCapacityPerShard) { this.kinesis = kinesis; this.initialCheckpoint = initialCheckpoint; this.watermarkPolicyFactory = watermarkPolicyFactory; + this.rateLimitPolicyFactory = rateLimitPolicyFactory; this.queueCapacityPerShard = queueCapacityPerShard; this.executorService = Executors.newCachedThreadPool(); this.numberOfRecordsInAQueueByShard = new ConcurrentHashMap<>(); @@ -115,11 +118,12 @@ void start() throws TransientKinesisException { void startReadingShards(Iterable shardRecordsIterators) { for (final ShardRecordsIterator recordsIterator : shardRecordsIterators) { numberOfRecordsInAQueueByShard.put(recordsIterator.getShardId(), new AtomicInteger()); - executorService.submit(() -> readLoop(recordsIterator)); + executorService.submit( + () -> readLoop(recordsIterator, rateLimitPolicyFactory.getRateLimitPolicy())); } } - private void readLoop(ShardRecordsIterator shardRecordsIterator) { + private void readLoop(ShardRecordsIterator shardRecordsIterator, RateLimitPolicy rateLimiter) { while (poolOpened.get()) { try { List kinesisRecords; @@ -143,10 +147,20 @@ private void readLoop(ShardRecordsIterator shardRecordsIterator) { recordsQueue.put(kinesisRecord); numberOfRecordsInAQueueByShard.get(kinesisRecord.getShardId()).incrementAndGet(); } + rateLimiter.onSuccess(kinesisRecords); + } catch (KinesisClientThrottledException e) { + try { + rateLimiter.onThrottle(e); + } catch (InterruptedException ex) { + LOG.warn("Thread was interrupted, finishing the read loop", ex); + Thread.currentThread().interrupt(); + break; + } } catch (TransientKinesisException e) { LOG.warn("Transient exception occurred.", e); } catch (InterruptedException e) { LOG.warn("Thread was interrupted, finishing the read loop", e); + Thread.currentThread().interrupt(); break; } catch (Throwable e) { LOG.error("Unexpected exception occurred", e); diff --git 
a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java index 7cf4cd5627a1..38a3f5a29e83 100644 --- a/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java +++ b/sdks/java/io/kinesis/src/main/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClient.java @@ -211,7 +211,7 @@ private T wrapExceptions(Callable callable) throws TransientKinesisExcept } catch (ExpiredIteratorException e) { throw e; } catch (LimitExceededException | ProvisionedThroughputExceededException e) { - throw new TransientKinesisException( + throw new KinesisClientThrottledException( "Too many requests to Kinesis. Wait some time and retry.", e); } catch (AmazonServiceException e) { if (e.getErrorType() == AmazonServiceException.ErrorType.Service) { diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java index 060af47e269e..812a598a154e 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/KinesisReaderTest.java @@ -69,6 +69,7 @@ private KinesisReader createReader(Duration backlogBytesCheckThreshold) { generator, kinesisSource, WatermarkPolicyFactory.withArrivalTimePolicy(), + RateLimitPolicyFactory.withoutLimiter(), Duration.ZERO, backlogBytesCheckThreshold, ShardReadersPool.DEFAULT_CAPACITY_PER_SHARD) { diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactoryTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactoryTest.java new file mode 100644 index 000000000000..dccfc8cb92cb --- /dev/null +++ 
b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/RateLimitPolicyFactoryTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kinesis; + +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.powermock.api.mockito.PowerMockito.verifyStatic; + +import java.util.concurrent.atomic.AtomicLong; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.joda.time.Duration; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +@RunWith(PowerMockRunner.class) +@PrepareForTest(RateLimitPolicyFactory.class) +public class RateLimitPolicyFactoryTest { + + @Test + public void defaultPolicyShouldDoNothing() throws Exception { + PowerMockito.spy(Thread.class); + PowerMockito.doNothing().when(Thread.class); + Thread.sleep(anyLong()); + RateLimitPolicy rateLimitPolicy = RateLimitPolicyFactory.withoutLimiter().getRateLimitPolicy(); + 
rateLimitPolicy.onSuccess(ImmutableList.of()); + verifyStatic(Thread.class, never()); + Thread.sleep(anyLong()); + } + + @Test + public void shouldDelayDefaultInterval() throws Exception { + PowerMockito.spy(Thread.class); + PowerMockito.doNothing().when(Thread.class); + Thread.sleep(anyLong()); + RateLimitPolicy rateLimitPolicy = RateLimitPolicyFactory.withFixedDelay().getRateLimitPolicy(); + rateLimitPolicy.onSuccess(ImmutableList.of()); + verifyStatic(Thread.class); + Thread.sleep(eq(1000L)); + } + + @Test + public void shouldDelayFixedInterval() throws Exception { + PowerMockito.spy(Thread.class); + PowerMockito.doNothing().when(Thread.class); + Thread.sleep(anyLong()); + RateLimitPolicy rateLimitPolicy = + RateLimitPolicyFactory.withFixedDelay(Duration.millis(500)).getRateLimitPolicy(); + rateLimitPolicy.onSuccess(ImmutableList.of()); + verifyStatic(Thread.class); + Thread.sleep(eq(500L)); + } + + @Test + public void shouldDelayDynamicInterval() throws Exception { + PowerMockito.spy(Thread.class); + PowerMockito.doNothing().when(Thread.class); + Thread.sleep(anyLong()); + AtomicLong delay = new AtomicLong(0L); + RateLimitPolicy rateLimitPolicy = + RateLimitPolicyFactory.withDelay(() -> Duration.millis(delay.getAndUpdate(d -> d ^ 1))) + .getRateLimitPolicy(); + rateLimitPolicy.onSuccess(ImmutableList.of()); + verifyStatic(Thread.class); + Thread.sleep(eq(0L)); + Thread.sleep(eq(1L)); + Thread.sleep(eq(0L)); + Thread.sleep(eq(1L)); + } +} diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java index 0d9e9a37ed92..59ab92ea2224 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/ShardReadersPoolTest.java @@ -21,6 +21,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static 
org.mockito.Matchers.any; import static org.mockito.Matchers.eq; +import static org.mockito.Matchers.same; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; @@ -54,7 +56,9 @@ public class ShardReadersPoolTest { @Mock private ShardCheckpoint firstCheckpoint, secondCheckpoint; @Mock private SimplifiedKinesisClient kinesis; @Mock private KinesisRecord a, b, c, d; - @Mock private WatermarkPolicyFactory factory; + @Mock private WatermarkPolicyFactory watermarkPolicyFactory; + @Mock private RateLimitPolicyFactory rateLimitPolicyFactory; + @Mock private RateLimitPolicy customRateLimitPolicy; private KinesisReaderCheckpoint checkpoint; private ShardReadersPool shardReadersPool; @@ -75,12 +79,18 @@ public void setUp() throws TransientKinesisException { when(thirdIterator.getShardId()).thenReturn("shard3"); when(fourthIterator.getShardId()).thenReturn("shard4"); - WatermarkPolicy policy = WatermarkPolicyFactory.withArrivalTimePolicy().createWatermarkPolicy(); + WatermarkPolicy watermarkPolicy = + WatermarkPolicyFactory.withArrivalTimePolicy().createWatermarkPolicy(); + RateLimitPolicy rateLimitPolicy = RateLimitPolicyFactory.withoutLimiter().getRateLimitPolicy(); checkpoint = new KinesisReaderCheckpoint(ImmutableList.of(firstCheckpoint, secondCheckpoint)); - shardReadersPool = Mockito.spy(new ShardReadersPool(kinesis, checkpoint, factory, 100)); + shardReadersPool = + Mockito.spy( + new ShardReadersPool( + kinesis, checkpoint, watermarkPolicyFactory, rateLimitPolicyFactory, 100)); - when(factory.createWatermarkPolicy()).thenReturn(policy); + when(watermarkPolicyFactory.createWatermarkPolicy()).thenReturn(watermarkPolicy); + when(rateLimitPolicyFactory.getRateLimitPolicy()).thenReturn(rateLimitPolicy); doReturn(firstIterator).when(shardReadersPool).createShardIterator(kinesis, firstCheckpoint); 
doReturn(secondIterator).when(shardReadersPool).createShardIterator(kinesis, secondCheckpoint); @@ -178,8 +188,10 @@ public void shouldInterruptPuttingRecordsToQueueAndStopShortly() new KinesisReaderCheckpoint(ImmutableList.of(firstCheckpoint, secondCheckpoint)); WatermarkPolicyFactory watermarkPolicyFactory = WatermarkPolicyFactory.withArrivalTimePolicy(); + RateLimitPolicyFactory rateLimitPolicyFactory = RateLimitPolicyFactory.withoutLimiter(); ShardReadersPool shardReadersPool = - new ShardReadersPool(kinesis, checkpoint, watermarkPolicyFactory, 2); + new ShardReadersPool( + kinesis, checkpoint, watermarkPolicyFactory, rateLimitPolicyFactory, 2); shardReadersPool.start(); Stopwatch stopwatch = Stopwatch.createStarted(); @@ -237,12 +249,14 @@ public void shouldStopReadersPoolAlsoWhenExceptionsOccurDuringStopping() throws public void shouldReturnAbsentOptionalWhenStartedWithNoIterators() throws Exception { KinesisReaderCheckpoint checkpoint = new KinesisReaderCheckpoint(Collections.emptyList()); WatermarkPolicyFactory watermarkPolicyFactory = WatermarkPolicyFactory.withArrivalTimePolicy(); + RateLimitPolicyFactory rateLimitPolicyFactory = RateLimitPolicyFactory.withoutLimiter(); shardReadersPool = Mockito.spy( new ShardReadersPool( kinesis, checkpoint, watermarkPolicyFactory, + rateLimitPolicyFactory, ShardReadersPool.DEFAULT_CAPACITY_PER_SHARD)); doReturn(firstIterator) .when(shardReadersPool) @@ -286,4 +300,34 @@ public void shouldReturnTheLeastWatermarkOfAllShards() throws TransientKinesisEx verify(firstIterator, times(2)).getShardWatermark(); verify(secondIterator, times(2)).getShardWatermark(); } + + @Test + public void shouldCallRateLimitPolicy() + throws TransientKinesisException, KinesisShardClosedException, InterruptedException { + KinesisClientThrottledException e = new KinesisClientThrottledException("", null); + when(firstIterator.readNextBatch()) + .thenThrow(e) + .thenReturn(ImmutableList.of(a, b)) + .thenReturn(Collections.emptyList()); + 
when(secondIterator.readNextBatch()) + .thenReturn(singletonList(c)) + .thenReturn(singletonList(d)) + .thenReturn(Collections.emptyList()); + when(rateLimitPolicyFactory.getRateLimitPolicy()).thenReturn(customRateLimitPolicy); + + shardReadersPool.start(); + List fetchedRecords = new ArrayList<>(); + while (fetchedRecords.size() < 4) { + CustomOptional nextRecord = shardReadersPool.nextRecord(); + if (nextRecord.isPresent()) { + fetchedRecords.add(nextRecord.get()); + } + } + + verify(customRateLimitPolicy).onThrottle(same(e)); + verify(customRateLimitPolicy).onSuccess(eq(ImmutableList.of(a, b))); + verify(customRateLimitPolicy).onSuccess(eq(singletonList(c))); + verify(customRateLimitPolicy).onSuccess(eq(singletonList(d))); + verify(customRateLimitPolicy, atLeastOnce()).onSuccess(eq(Collections.emptyList())); + } } diff --git a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java index 7858de779aa6..1c1d9e027e10 100644 --- a/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java +++ b/sdks/java/io/kinesis/src/test/java/org/apache/beam/sdk/io/kinesis/SimplifiedKinesisClientTest.java @@ -116,13 +116,13 @@ public void shouldHandleExpiredIterationExceptionForGetShardIterator() { @Test public void shouldHandleLimitExceededExceptionForGetShardIterator() { shouldHandleGetShardIteratorError( - new LimitExceededException(""), TransientKinesisException.class); + new LimitExceededException(""), KinesisClientThrottledException.class); } @Test public void shouldHandleProvisionedThroughputExceededExceptionForGetShardIterator() { shouldHandleGetShardIteratorError( - new ProvisionedThroughputExceededException(""), TransientKinesisException.class); + new ProvisionedThroughputExceededException(""), KinesisClientThrottledException.class); } @Test @@ -190,13 +190,14 @@ public void 
shouldHandleExpiredIterationExceptionForShardListing() { @Test public void shouldHandleLimitExceededExceptionForShardListing() { - shouldHandleShardListingError(new LimitExceededException(""), TransientKinesisException.class); + shouldHandleShardListingError( + new LimitExceededException(""), KinesisClientThrottledException.class); } @Test public void shouldHandleProvisionedThroughputExceededExceptionForShardListing() { shouldHandleShardListingError( - new ProvisionedThroughputExceededException(""), TransientKinesisException.class); + new ProvisionedThroughputExceededException(""), KinesisClientThrottledException.class); } @Test @@ -281,13 +282,13 @@ public void shouldNotCallCloudWatchWhenSpecifiedPeriodTooShort() throws Exceptio @Test public void shouldHandleLimitExceededExceptionForGetBacklogBytes() { shouldHandleGetBacklogBytesError( - new LimitExceededException(""), TransientKinesisException.class); + new LimitExceededException(""), KinesisClientThrottledException.class); } @Test public void shouldHandleProvisionedThroughputExceededExceptionForGetBacklogBytes() { shouldHandleGetBacklogBytesError( - new ProvisionedThroughputExceededException(""), TransientKinesisException.class); + new ProvisionedThroughputExceededException(""), KinesisClientThrottledException.class); } @Test