diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
index 95c23fb4ac..a2fc7ad71d 100644
--- a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
+++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
@@ -36,6 +36,9 @@ If present, they take precedence over the analogous APOC configs.
 
 By default, is `/embeddings`, `/completions` and `/chat/completions` for respectively the `apoc.ml.openai.embedding`, `apoc.ml.openai.completion` and `apoc.ml.openai.chat` procedures.
 | jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response. The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure.
+| enableBackOffRetries | If set to true, enables the backoff retry strategy: a request rejected with the HTTP 429 status code (too many requests) is retried after a wait. Any other failure is not retried. (default: false)
+| backOffRetries | Sets the maximum number of retry attempts before the operation throws an exception. (default: 5)
+| exponentialBackoff | If set to true, applies an exponential progression to the wait time between retries (2s, 4s, 8s, ...). If set to false, the wait time increases linearly (1s, 2s, 3s, ...). In both cases a single wait never exceeds 30 seconds. (default: false)
 |===
 
 
diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java
index ed9389a1cd..d2e8a7a8c1 100644
--- a/extended/src/main/java/apoc/ml/OpenAI.java
+++ b/extended/src/main/java/apoc/ml/OpenAI.java
@@ -3,7 +3,9 @@
 import apoc.ApocConfig;
 import apoc.Extended;
 import apoc.result.MapResult;
+import apoc.util.ExtendedUtil;
 import apoc.util.JsonUtil;
+import apoc.util.Util;
 import com.fasterxml.jackson.core.JsonProcessingException;
 import org.neo4j.graphdb.security.URLAccessChecker;
 import org.neo4j.procedure.Context;
@@ -35,6 +37,9 @@ public class OpenAI {
     public static final String JSON_PATH_CONF_KEY = "jsonPath";
     public static final String PATH_CONF_KEY = "path";
     public static final String GPT_4O_MODEL = "gpt-4o";
+    public static final String ENABLE_BACK_OFF_RETRIES_CONF_KEY = "enableBackOffRetries";
+    public static final String ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY = "exponentialBackoff";
+    public static final String BACK_OFF_RETRIES_CONF_KEY = "backOffRetries";
 
     @Context
     public ApocConfig apocConfig;
@@ -58,6 +63,11 @@ public EmbeddingResult(long index, String text, List<Double> embedding) {
 
     static Stream<Object> executeRequest(String apiKey, Map<String, Object> configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException {
         apiKey = (String) configuration.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey));
+        // Retry settings: disabled unless explicitly enabled, 5 retries and linear wait progression by default
+        boolean enableBackOffRetries = Util.toBoolean(configuration.get(ENABLE_BACK_OFF_RETRIES_CONF_KEY));
+        Integer configuredRetries = Util.toInteger(configuration.get(BACK_OFF_RETRIES_CONF_KEY));
+        int backOffRetries = configuredRetries == null ? 5 : configuredRetries;
+        boolean exponentialBackoff = Util.toBoolean(configuration.get(ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY));
         if (apiKey == null || apiKey.isBlank())
             throw new IllegalArgumentException("API Key must not be empty");
 
@@ -77,7 +87,7 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
         path = (String) configuration.getOrDefault(PATH_CONF_KEY, path);
         OpenAIRequestHandler apiType = type.get();
 
-        jsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);
+        String sJsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);
         headers.put("Content-Type", "application/json");
         apiType.addApiKey(headers, apiKey);
 
@@ -87,7 +97,18 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
         // eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model
         // therefore is better to join the not-empty path pieces
         var url = apiType.getFullUrl(path, configuration, apocConfig);
-        return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of(), urlAccessChecker);
+        return ExtendedUtil.withBackOffRetries(
+                () -> JsonUtil.loadJson(url, headers, payload, sJsonPath, true, List.of(), urlAccessChecker),
+                enableBackOffRetries, backOffRetries,
+                exception -> {
+                    // Only HTTP 429 (rate limit) responses are retried; anything else aborts right away
+                    String message = exception.getMessage();
+                    if (message == null || !message.contains("response code: 429")) {
+                        throw new RuntimeException(exception);
+                    }
+                },
+                exponentialBackoff
+        );
     }
 
     private static void handleAPIProvider(OpenAIRequestHandler.Type type,
diff --git a/extended/src/main/java/apoc/util/ExtendedUtil.java b/extended/src/main/java/apoc/util/ExtendedUtil.java
index f7fa8a5c9d..e96ddb978f 100644
--- a/extended/src/main/java/apoc/util/ExtendedUtil.java
+++ b/extended/src/main/java/apoc/util/ExtendedUtil.java
@@ -5,12 +5,14 @@
 import com.fasterxml.jackson.core.json.JsonWriteFeature;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import org.apache.commons.lang3.StringUtils;
+import org.eclipse.collections.api.block.function.Function0;
 import org.neo4j.exceptions.Neo4jException;
 import org.neo4j.graphdb.Entity;
 import org.neo4j.graphdb.ExecutionPlanDescription;
 import org.neo4j.graphdb.GraphDatabaseService;
 import org.neo4j.graphdb.QueryExecutionType;
 import org.neo4j.graphdb.Result;
+import org.neo4j.logging.Log;
 import org.neo4j.procedure.Mode;
 import org.neo4j.values.storable.DateTimeValue;
 import org.neo4j.values.storable.DateValue;
@@ -30,14 +32,11 @@
 import java.time.ZoneId;
 import java.time.ZonedDateTime;
 import java.time.temporal.TemporalAccessor;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
+import java.time.temporal.TemporalUnit;
+import java.util.*;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.LongStream;
 import java.util.stream.Stream;
@@ -353,5 +352,87 @@ public static float[] listOfNumbersToFloatArray(List<? extends Number> embedding
         }
         return floats;
     }
-    
+
+    /** Number of retries applied when the caller passes a value lower than 1. */
+    private static final int DEFAULT_BACKOFF_RETRIES = 5;
+
+    /** Upper bound, in seconds, of the wait between two consecutive attempts. */
+    private static final long MAX_BACKOFF_SECONDS = 30;
+
+    /**
+     * Same as {@link #withBackOffRetries(Supplier, boolean, int, Consumer, boolean)},
+     * with the wait time between retries increasing linearly.
+     */
+    public static <T> T withBackOffRetries(
+            Supplier<T> func, boolean retry, int backoffRetry,
+            Consumer<Exception> exceptionHandler
+    ) {
+        return withBackOffRetries(func, retry, backoffRetry, exceptionHandler, false);
+    }
+
+    /**
+     * Executes {@code func} and, when {@code retry} is true, executes it again after each failure,
+     * waiting 1s, 2s, 3s, ... (or 2s, 4s, 8s, ... when {@code exponential} is true) between attempts.
+     * A single wait never exceeds 30 seconds.
+     *
+     * @param func             the operation to execute
+     * @param retry            if false, the first failure is rethrown without invoking the handler
+     * @param backoffRetry     maximum number of retries after the first attempt; a value lower than 1 means 5
+     * @param exceptionHandler receives every failure that is about to be retried and may throw
+     *                         to stop retrying; {@code null} means "retry on any failure"
+     * @param exponential      if true the wait doubles at each retry, otherwise it grows linearly
+     * @return the result of the first successful execution of {@code func}
+     * @throws RuntimeException the last failure once the retries are exhausted, anything thrown by
+     *                          {@code exceptionHandler}, or a wrapper of the interruption of the wait
+     */
+    public static <T> T withBackOffRetries(
+            Supplier<T> func, boolean retry, int backoffRetry,
+            Consumer<Exception> exceptionHandler, boolean exponential
+    ) {
+        final int maxRetries = backoffRetry < 1 ? DEFAULT_BACKOFF_RETRIES : backoffRetry;
+        final Consumer<Exception> handler = Objects.requireNonNullElse(exceptionHandler, exe -> {});
+        for (int attempt = 1; ; attempt++) {
+            try {
+                return func.get();
+            } catch (Exception e) {
+                // "attempt" failures so far: give up once every allowed retry has been consumed
+                if (!retry || attempt > maxRetries) {
+                    throw e;
+                }
+                handler.accept(e);
+                backoffSleep(getDelay(attempt, exponential));
+            }
+        }
+    }
+
+    private static void backoffSleep(long millis) {
+        sleep(millis, "Operation interrupted during backoff");
+    }
+
+    /**
+     * Sleeps for {@code millis} milliseconds. If the thread is interrupted, the interrupt flag is
+     * restored and a {@link RuntimeException} with the given message is thrown.
+     */
+    public static void sleep(long millis, String interruptedMessage) {
+        try {
+            Thread.sleep(millis);
+        } catch (InterruptedException ie) {
+            Thread.currentThread().interrupt();
+            throw new RuntimeException(interruptedMessage, ie);
+        }
+    }
+
+    /**
+     * Returns the wait, in milliseconds, preceding the retry number {@code attempt} (1-based).
+     */
+    private static long getDelay(int attempt, boolean exponential) {
+        long seconds = attempt; // Linear retry progression
+        if (exponential) {
+            // Exponential retry progression (2^attempt), saturated so that the shift cannot overflow
+            seconds = attempt < Long.SIZE - 1 ? 1L << attempt : Long.MAX_VALUE;
+        }
+        // Clamp before converting: Duration.toMillis() throws ArithmeticException on overflow
+        return Duration.ofSeconds(Math.min(seconds, MAX_BACKOFF_SECONDS)).toMillis();
+    }
+
 }
diff --git a/extended/src/test/java/apoc/ml/OpenAIIT.java b/extended/src/test/java/apoc/ml/OpenAIIT.java
index e56bb2f13e..d09faa5a10 100644
--- a/extended/src/test/java/apoc/ml/OpenAIIT.java
+++ b/extended/src/test/java/apoc/ml/OpenAIIT.java
@@ -160,4 +160,5 @@ public void chatCompletionNullGpt35Turbo() {
                 Map.of("apiKey", openaiKey, "conf", Map.of(MODEL_CONF_KEY, GPT_35_MODEL))
         );
     }
+
 }
\ No newline at end of file
diff --git a/extended/src/test/java/apoc/util/ExtendedUtilTest.java b/extended/src/test/java/apoc/util/ExtendedUtilTest.java
new file mode 100644
index 0000000000..b36603f793
--- /dev/null
+++ b/extended/src/test/java/apoc/util/ExtendedUtilTest.java
@@ -0,0 +1,112 @@
+package apoc.util;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for {@code ExtendedUtil.withBackOffRetries}. The tests really sleep between retries,
+ * so the elapsed time is asserted with a tolerance of half a second.
+ */
+public class ExtendedUtilTest {
+
+    private static int i = 0;
+
+    @Test
+    public void testWithLinearBackOffRetriesWithSuccess() {
+        i = 0;
+        long start = System.currentTimeMillis();
+        int result = ExtendedUtil.withBackOffRetries(
+                this::testFunction,
+                true, -1, // test backoffRetry default value -> 5
+                runEx -> {
+                    if(!runEx.getMessage().contains("Expected"))
+                        throw new RuntimeException("Some Bad News...");
+                }
+        );
+        long time = System.currentTimeMillis() - start;
+
+        assertEquals(4, result);
+
+        // The method will attempt to execute the operation with a linear backoff strategy,
+        // sleeping for 1 second, 2 seconds, and 3 seconds between retries.
+        // This results in a total wait time of 6 seconds (1s + 2s + 3s), since the operation succeeds
+        // on the fourth attempt (third retry), leading to an approximate execution time of 6 seconds.
+        assertTrue(time > 5500);
+        assertTrue(time < 6500);
+    }
+
+    @Test
+    public void testWithExponentialBackOffRetriesWithSuccess() {
+        i = 0;
+        long start = System.currentTimeMillis();
+        int result = ExtendedUtil.withBackOffRetries(
+                this::testFunction,
+                true, 0, // test backoffRetry default value -> 5
+                runEx -> {
+                    if(!runEx.getMessage().contains("Expected"))
+                        throw new RuntimeException("Some Bad News...");
+                },
+                true
+        );
+        long time = System.currentTimeMillis() - start;
+
+        assertEquals(4, result);
+
+        // The method will attempt to execute the operation with an exponential backoff strategy,
+        // sleeping for 2 seconds, 4 seconds, and 8 seconds between retries.
+        // This results in a total wait time of 14 seconds (2s + 4s + 8s), since the operation succeeds
+        // on the fourth attempt (third retry), leading to an approximate execution time of 14 seconds.
+        assertTrue(time > 13500);
+        assertTrue(time < 14500);
+    }
+
+    @Test
+    public void testBackOffRetriesWithError() {
+        i = 0;
+        long start = System.currentTimeMillis();
+        assertThrows(
+                RuntimeException.class,
+                () -> ExtendedUtil.withBackOffRetries(
+                        this::testFunction,
+                        true, 2,
+                        runEx -> {}
+                )
+        );
+        long time = System.currentTimeMillis() - start;
+
+        // The method is configured to retry the operation twice.
+        // So, it will make two extra-attempts, waiting for 1 second and 2 seconds before failing and throwing an exception.
+        // Resulting in an approximate execution time of 3 seconds.
+        assertTrue(time > 2500);
+        assertTrue(time < 3500);
+    }
+
+    @Test
+    public void testWithoutBackOffRetriesWithError() {
+        i = 0;
+        assertThrows(
+                RuntimeException.class,
+                () -> ExtendedUtil.withBackOffRetries(
+                        this::testFunction,
+                        false, 30,
+                        runEx -> {}
+                )
+        );
+
+        // Retry strategy is not active and the testFunction is executed only once by raising an exception.
+        assertEquals(1, i);
+    }
+
+    /** Fails on the first three invocations after a reset of {@code i}, then returns 4 on the fourth one. */
+    private int testFunction() {
+        i++;
+        if (i != 4) {
+            throw new RuntimeException("Expected i not equal to 4");
+        }
+        return i;
+    }
+
+}