Skip to content

Commit

Permalink
[#24971] Add a retry policy for JmsIO #24971 (#24973)
Browse files Browse the repository at this point in the history
  • Loading branch information
Amraneze authored Feb 11, 2023
1 parent 9b77bf9 commit 198b93e
Show file tree
Hide file tree
Showing 4 changed files with 515 additions and 43 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
## I/Os

* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Added in JmsIO a retry policy for failed publications (Java) ([#24971](https://github.com/apache/beam/issues/24971)).

## New Features / Improvements

Expand Down
257 changes: 215 additions & 42 deletions sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io.jms;

import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;

import com.google.auto.value.AutoValue;
Expand All @@ -34,6 +35,7 @@
import javax.jms.Connection;
import javax.jms.ConnectionFactory;
import javax.jms.Destination;
import javax.jms.JMSException;
import javax.jms.Message;
import javax.jms.MessageConsumer;
import javax.jms.MessageProducer;
Expand All @@ -47,6 +49,8 @@
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark;
import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.options.ExecutorOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.DoFn;
Expand All @@ -55,11 +59,16 @@
import org.apache.beam.sdk.transforms.SerializableBiFunction;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.BackOffUtils;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.sdk.util.Sleeper;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.checkerframework.checker.initialization.qual.Initialized;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
import org.joda.time.Instant;
Expand Down Expand Up @@ -706,6 +715,8 @@ public abstract static class Write<EventT>

abstract @Nullable SerializableFunction<EventT, String> getTopicNameMapper();

abstract @Nullable RetryConfiguration getRetryConfiguration();

abstract Builder<EventT> builder();

@AutoValue.Builder
Expand All @@ -726,6 +737,8 @@ abstract Builder<EventT> setValueMapper(
abstract Builder<EventT> setTopicNameMapper(
SerializableFunction<EventT, String> topicNameMapper);

abstract Builder<EventT> setRetryConfiguration(RetryConfiguration retryConfiguration);

abstract Write<EventT> build();
}

Expand Down Expand Up @@ -866,6 +879,48 @@ public Write<EventT> withValueMapper(
return builder().setValueMapper(valueMapper).build();
}

/**
* Specify the JMS retry configuration. The {@link JmsIO.Write} acts as a publisher on the
* topic.
*
* <p>Allows a retry for failed published messages, the user should specify the maximum number
* of retries, a duration for retrying and a maximum cumulative retries. By default, the
* duration for retrying used is 15s and the maximum cumulative is 1000 days {@link
* RetryConfiguration}
*
* <p>For example:
*
* <pre>{@code
* RetryConfiguration retryConfiguration = RetryConfiguration.create(5);
* }</pre>
*
* or
*
* <pre>{@code
* RetryConfiguration retryConfiguration =
* RetryConfiguration.create(5, Duration.standardSeconds(30), null);
* }</pre>
*
* or
*
* <pre>{@code
* RetryConfiguration retryConfiguration =
* RetryConfiguration.create(5, Duration.standardSeconds(30), Duration.standardDays(15));
* }</pre>
*
* <pre>{@code
* .apply(JmsIO.write().withPublicationRetryPolicy(publicationRetryPolicy)
* }</pre>
*
* @param retryConfiguration The retry configuration that should be used in case of failed
* publications.
* @return The corresponding {@link JmsIO.Write}.
*/
public Write<EventT> withRetryConfiguration(RetryConfiguration retryConfiguration) {
checkArgument(retryConfiguration != null, "retryConfiguration can not be null");
return builder().setRetryConfiguration(retryConfiguration).build();
}

@Override
public WriteJmsResult<EventT> expand(PCollection<EventT> input) {
checkArgument(getConnectionFactory() != null, "withConnectionFactory() is required");
Expand All @@ -878,15 +933,7 @@ public WriteJmsResult<EventT> expand(PCollection<EventT> input) {
"Only one of withQueue(queue), withTopic(topic), or withTopicNameMapper(function) must be set.");
checkArgument(getValueMapper() != null, "withValueMapper() is required");

final TupleTag<EventT> failedMessagesTag = new TupleTag<>();
final TupleTag<EventT> messagesTag = new TupleTag<>();
PCollectionTuple res =
input.apply(
ParDo.of(new WriterFn<>(this, failedMessagesTag))
.withOutputTags(messagesTag, TupleTagList.of(failedMessagesTag)));
PCollection<EventT> failedMessages = res.get(failedMessagesTag).setCoder(input.getCoder());
res.get(messagesTag).setCoder(input.getCoder());
return WriteJmsResult.in(input.getPipeline(), failedMessagesTag, failedMessages);
return input.apply(new Writer<>(this));
}

private boolean isExclusiveTopicQueue() {
Expand All @@ -897,32 +944,73 @@ private boolean isExclusiveTopicQueue() {
== 1;
return exclusiveTopicQueue;
}
}

static class Writer<T> extends PTransform<PCollection<T>, WriteJmsResult<T>> {

public static final String CONNECTION_ERRORS_METRIC_NAME = "connectionErrors";
public static final String PUBLICATION_RETRIES_METRIC_NAME = "publicationRetries";
public static final String JMS_IO_PRODUCER_METRIC_NAME = Writer.class.getCanonicalName();

private static final Logger LOG = LoggerFactory.getLogger(Writer.class);
private static final String PUBLISH_TO_JMS_STEP_NAME = "Publish to JMS";

private final JmsIO.Write<T> spec;
private final TupleTag<T> messagesTag;
private final TupleTag<T> failedMessagesTag;

Writer(JmsIO.Write<T> spec) {
this.spec = spec;
this.messagesTag = new TupleTag<>();
this.failedMessagesTag = new TupleTag<>();
}

@Override
public WriteJmsResult<T> expand(PCollection<T> input) {
PCollectionTuple failedPublishedMessagesTuple =
input.apply(
PUBLISH_TO_JMS_STEP_NAME,
ParDo.of(new JmsIOProducerFn<>(spec, failedMessagesTag))
.withOutputTags(messagesTag, TupleTagList.of(failedMessagesTag)));
PCollection<T> failedPublishedMessages =
failedPublishedMessagesTuple.get(failedMessagesTag).setCoder(input.getCoder());
failedPublishedMessagesTuple.get(messagesTag).setCoder(input.getCoder());

return WriteJmsResult.in(input.getPipeline(), failedMessagesTag, failedPublishedMessages);
}

private static class WriterFn<EventT> extends DoFn<EventT, EventT> {
private static class JmsConnection<T> implements Serializable {

private Write<EventT> spec;
private static final long serialVersionUID = 1L;

private Connection connection;
private Session session;
private MessageProducer producer;
private Destination destination;
private final TupleTag<EventT> failedMessageTag;
private transient @Initialized Session session;
private transient @Initialized Connection connection;
private transient @Initialized Destination destination;
private transient @Initialized MessageProducer producer;

public WriterFn(Write<EventT> spec, TupleTag<EventT> failedMessageTag) {
private boolean isProducerNeedsToBeCreated = true;
private final JmsIO.Write<T> spec;
private final Counter connectionErrors =
Metrics.counter(JMS_IO_PRODUCER_METRIC_NAME, CONNECTION_ERRORS_METRIC_NAME);

public JmsConnection(Write<T> spec) {
this.spec = spec;
this.failedMessageTag = failedMessageTag;
}

@Setup
public void setup() throws Exception {
if (producer == null) {
public void start() throws JMSException {
if (isProducerNeedsToBeCreated) {
ConnectionFactory connectionFactory = spec.getConnectionFactory();
if (spec.getUsername() != null) {
this.connection =
spec.getConnectionFactory()
.createConnection(spec.getUsername(), spec.getPassword());
connectionFactory.createConnection(spec.getUsername(), spec.getPassword());
} else {
this.connection = spec.getConnectionFactory().createConnection();
this.connection = connectionFactory.createConnection();
}
this.connection.setExceptionListener(
exception -> {
this.isProducerNeedsToBeCreated = true;
this.connectionErrors.inc();
});
this.connection.start();
// false means we don't use JMS transaction.
this.session = this.connection.createSession(false, Session.AUTO_ACKNOWLEDGE);
Expand All @@ -932,36 +1020,121 @@ public void setup() throws Exception {
} else if (spec.getTopic() != null) {
this.destination = session.createTopic(spec.getTopic());
}

this.producer = this.session.createProducer(null);
this.producer = this.session.createProducer(this.destination);
this.isProducerNeedsToBeCreated = false;
}
}

@ProcessElement
public void processElement(ProcessContext ctx) {
public void publishMessage(T input) throws JMSException, JmsIOException {
Destination destinationToSendTo = destination;
try {
Message message = spec.getValueMapper().apply(ctx.element(), session);
Message message = spec.getValueMapper().apply(input, session);
if (spec.getTopicNameMapper() != null) {
destinationToSendTo =
session.createTopic(spec.getTopicNameMapper().apply(ctx.element()));
destinationToSendTo = session.createTopic(spec.getTopicNameMapper().apply(input));
}
producer.send(destinationToSendTo, message);
} catch (Exception ex) {
LOG.error("Error sending message on topic {}", destinationToSendTo);
ctx.output(failedMessageTag, ctx.element());
} catch (JMSException | JmsIOException | NullPointerException exception) {
// Handle NPE in case of getValueMapper or getTopicNameMapper returns NPE
if (exception instanceof NullPointerException) {
throw new JmsIOException("An error occurred", exception);
}
throw exception;
}
}

public void close() throws JMSException {
isProducerNeedsToBeCreated = true;
if (producer != null) {
producer.close();
producer = null;
}
if (session != null) {
session.close();
session = null;
}
if (connection != null) {
try {
// If the connection failed, stopping the connection will throw a JMSException
connection.stop();
} catch (JMSException exception) {
LOG.warn("The connection couldn't be closed", exception);
}
connection.close();
connection = null;
}
}
}

static class JmsIOProducerFn<T> extends DoFn<T, T> {

private transient @Initialized FluentBackoff retryBackOff;

private final JmsIO.Write<T> spec;
private final TupleTag<T> failedMessagesTags;
private final @Initialized JmsConnection<T> jmsConnection;
private final Counter publicationRetries =
Metrics.counter(JMS_IO_PRODUCER_METRIC_NAME, PUBLICATION_RETRIES_METRIC_NAME);

JmsIOProducerFn(JmsIO.Write<T> spec, TupleTag<T> failedMessagesTags) {
this.spec = spec;
this.failedMessagesTags = failedMessagesTags;
this.jmsConnection = new JmsConnection<>(spec);
}

@Setup
public void setup() {
RetryConfiguration retryConfiguration = checkStateNotNull(spec.getRetryConfiguration());
retryBackOff =
FluentBackoff.DEFAULT
.withInitialBackoff(checkStateNotNull(retryConfiguration.getInitialDuration()))
.withMaxCumulativeBackoff(checkStateNotNull(retryConfiguration.getMaxDuration()))
.withMaxRetries(retryConfiguration.getMaxAttempts());
}

@StartBundle
public void startBundle() throws JMSException {
this.jmsConnection.start();
}

@ProcessElement
public void processElement(@Element T input, ProcessContext context) {
try {
publishMessage(input);
} catch (JMSException | JmsIOException | IOException | InterruptedException exception) {
LOG.error("Error while publishing the message", exception);
context.output(this.failedMessagesTags, input);
if (exception instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
}
}

private void publishMessage(T input)
throws JMSException, JmsIOException, IOException, InterruptedException {
Sleeper sleeper = Sleeper.DEFAULT;
BackOff backoff = checkStateNotNull(retryBackOff).backoff();
while (true) {
try {
this.jmsConnection.publishMessage(input);
break;
} catch (JMSException | JmsIOException exception) {
if (!BackOffUtils.next(sleeper, backoff)) {
throw exception;
} else {
publicationRetries.inc();
}
}
}
}

@FinishBundle
public void finishBundle() throws JMSException {
this.jmsConnection.close();
}

@Teardown
public void teardown() throws Exception {
producer.close();
producer = null;
session.close();
session = null;
connection.stop();
connection.close();
connection = null;
public void tearDown() throws JMSException {
this.jmsConnection.close();
}
}
}
Expand Down
Loading

0 comments on commit 198b93e

Please sign in to comment.