diff --git a/jflyte/pom.xml b/jflyte/pom.xml index 0d86ce35a..25c8b6213 100644 --- a/jflyte/pom.xml +++ b/jflyte/pom.xml @@ -148,6 +148,11 @@ junit-vintage-engine test + + org.junit.jupiter + junit-jupiter-params + test + org.hamcrest hamcrest diff --git a/jflyte/src/main/java/org/flyte/jflyte/GrpcRetries.java b/jflyte/src/main/java/org/flyte/jflyte/GrpcRetries.java index ace9aaa51..f958d26e6 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/GrpcRetries.java +++ b/jflyte/src/main/java/org/flyte/jflyte/GrpcRetries.java @@ -19,8 +19,12 @@ import com.google.auto.value.AutoValue; import com.google.errorprone.annotations.Var; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; +import java.util.Set; import java.util.concurrent.Callable; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -29,6 +33,10 @@ abstract class GrpcRetries { private static final Logger LOG = LoggerFactory.getLogger(GrpcRetries.class); + private static final Set RETRYABLE_CODES = + Stream.of(Code.UNAVAILABLE, Code.DEADLINE_EXCEEDED, Code.INTERNAL) + .collect(Collectors.toSet()); + public abstract int maxRetries(); public abstract long maxDelayMilliseconds(); @@ -56,12 +64,11 @@ public T retry(Retryable retryable) { } catch (StatusRuntimeException e) { Status.Code code = e.getStatus().getCode(); - boolean isRetryable = - code == Status.Code.UNAVAILABLE || code == Status.Code.DEADLINE_EXCEEDED; + boolean isRetryable = isRetryable(code); if (attempt < maxRetries() && isRetryable) { long delay = - Math.min(maxDelayMilliseconds(), (1 << attempt) * initialDelayMilliseconds()); + Math.min(maxDelayMilliseconds(), (1L << attempt) * initialDelayMilliseconds()); LOG.warn("Retrying in " + delay + " ms", e); try { @@ -78,6 +85,10 @@ public T retry(Retryable retryable) { } while (true); } + private static boolean isRetryable(Code code) { + return RETRYABLE_CODES.contains(code); + } + static GrpcRetries create() { return create( /* maxRetries= */ 10, diff --git a/jflyte/src/test/java/org/flyte/jflyte/GrpcRetriesTest.java b/jflyte/src/test/java/org/flyte/jflyte/GrpcRetriesTest.java index e70f22fbc..f28c099be 100644 --- a/jflyte/src/test/java/org/flyte/jflyte/GrpcRetriesTest.java +++ b/jflyte/src/test/java/org/flyte/jflyte/GrpcRetriesTest.java @@ -20,9 +20,12 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; import java.util.concurrent.atomic.AtomicLong; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; public class GrpcRetriesTest { @@ -50,8 +53,11 @@ void testMaxAttempts() { assertEquals(Status.DEADLINE_EXCEEDED, e.getStatus()); } - @Test - void testSuccessfulRetry() { + @ParameterizedTest + @EnumSource( + value = Code.class, + names = {"DEADLINE_EXCEEDED", "UNAVAILABLE", "INTERNAL"}) + void testSuccessfulRetry(Code code) { AtomicLong attempts = new AtomicLong(); GrpcRetries retries = GrpcRetries.create( @@ -64,7 +70,7 @@ void testSuccessfulRetry() { retries.retry( () -> { if (attempts.incrementAndGet() <= 5L) { - throw new StatusRuntimeException(Status.DEADLINE_EXCEEDED); + throw new StatusRuntimeException(code.toStatus()); } else { return 10; } @@ -93,10 +99,10 @@ void testNonRetryable() { retries.retry( () -> { attempts.incrementAndGet(); - throw new StatusRuntimeException(Status.INTERNAL); + throw new StatusRuntimeException(Status.INVALID_ARGUMENT); })); - assertEquals(Status.INTERNAL, e.getStatus()); + assertEquals(Status.INVALID_ARGUMENT, e.getStatus()); assertEquals(1, attempts.get()); } diff --git a/pom.xml b/pom.xml index ecc26bb00..889b4c975 100644 --- a/pom.xml +++ b/pom.xml @@ -264,6 +264,11 @@ junit-vintage-engine ${junit.version} + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + org.hamcrest hamcrest