diff --git a/jflyte/pom.xml b/jflyte/pom.xml
index 0d86ce35a..25c8b6213 100644
--- a/jflyte/pom.xml
+++ b/jflyte/pom.xml
@@ -148,6 +148,11 @@
junit-vintage-engine
test
+
+ org.junit.jupiter
+ junit-jupiter-params
+ test
+
org.hamcrest
hamcrest
diff --git a/jflyte/src/main/java/org/flyte/jflyte/GrpcRetries.java b/jflyte/src/main/java/org/flyte/jflyte/GrpcRetries.java
index ace9aaa51..f958d26e6 100644
--- a/jflyte/src/main/java/org/flyte/jflyte/GrpcRetries.java
+++ b/jflyte/src/main/java/org/flyte/jflyte/GrpcRetries.java
@@ -19,8 +19,12 @@
import com.google.auto.value.AutoValue;
import com.google.errorprone.annotations.Var;
import io.grpc.Status;
+import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
+import java.util.Set;
import java.util.concurrent.Callable;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -29,6 +33,10 @@
abstract class GrpcRetries {
private static final Logger LOG = LoggerFactory.getLogger(GrpcRetries.class);
+ private static final Set RETRYABLE_CODES =
+ Stream.of(Code.UNAVAILABLE, Code.DEADLINE_EXCEEDED, Code.INTERNAL)
+ .collect(Collectors.toSet());
+
public abstract int maxRetries();
public abstract long maxDelayMilliseconds();
@@ -56,12 +64,11 @@ public T retry(Retryable retryable) {
} catch (StatusRuntimeException e) {
Status.Code code = e.getStatus().getCode();
- boolean isRetryable =
- code == Status.Code.UNAVAILABLE || code == Status.Code.DEADLINE_EXCEEDED;
+ boolean isRetryable = isRetryable(code);
if (attempt < maxRetries() && isRetryable) {
long delay =
- Math.min(maxDelayMilliseconds(), (1 << attempt) * initialDelayMilliseconds());
+ Math.min(maxDelayMilliseconds(), (1L << attempt) * initialDelayMilliseconds());
LOG.warn("Retrying in " + delay + " ms", e);
try {
@@ -78,6 +85,10 @@ public T retry(Retryable retryable) {
} while (true);
}
+ private static boolean isRetryable(Code code) {
+ return RETRYABLE_CODES.contains(code);
+ }
+
static GrpcRetries create() {
return create(
/* maxRetries= */ 10,
diff --git a/jflyte/src/test/java/org/flyte/jflyte/GrpcRetriesTest.java b/jflyte/src/test/java/org/flyte/jflyte/GrpcRetriesTest.java
index e70f22fbc..f28c099be 100644
--- a/jflyte/src/test/java/org/flyte/jflyte/GrpcRetriesTest.java
+++ b/jflyte/src/test/java/org/flyte/jflyte/GrpcRetriesTest.java
@@ -20,9 +20,12 @@
import static org.junit.jupiter.api.Assertions.assertThrows;
import io.grpc.Status;
+import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import java.util.concurrent.atomic.AtomicLong;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
public class GrpcRetriesTest {
@@ -50,8 +53,11 @@ void testMaxAttempts() {
assertEquals(Status.DEADLINE_EXCEEDED, e.getStatus());
}
- @Test
- void testSuccessfulRetry() {
+ @ParameterizedTest
+ @EnumSource(
+ value = Code.class,
+ names = {"DEADLINE_EXCEEDED", "UNAVAILABLE", "INTERNAL"})
+ void testSuccessfulRetry(Code code) {
AtomicLong attempts = new AtomicLong();
GrpcRetries retries =
GrpcRetries.create(
@@ -64,7 +70,7 @@ void testSuccessfulRetry() {
retries.retry(
() -> {
if (attempts.incrementAndGet() <= 5L) {
- throw new StatusRuntimeException(Status.DEADLINE_EXCEEDED);
+ throw new StatusRuntimeException(code.toStatus());
} else {
return 10;
}
@@ -93,10 +99,10 @@ void testNonRetryable() {
retries.retry(
() -> {
attempts.incrementAndGet();
- throw new StatusRuntimeException(Status.INTERNAL);
+ throw new StatusRuntimeException(Status.INVALID_ARGUMENT);
}));
- assertEquals(Status.INTERNAL, e.getStatus());
+ assertEquals(Status.INVALID_ARGUMENT, e.getStatus());
assertEquals(1, attempts.get());
}
diff --git a/pom.xml b/pom.xml
index ecc26bb00..889b4c975 100644
--- a/pom.xml
+++ b/pom.xml
@@ -264,6 +264,11 @@
junit-vintage-engine
${junit.version}
+
+ org.junit.jupiter
+ junit-jupiter-params
+ ${junit.version}
+
org.hamcrest
hamcrest