Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Issue 4153] Adds a backoff strategy to OpenAI API calls #3

Open
wants to merge 4 commits into
base: 5.25
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ If present, they take precedence over the analogous APOC configs.
By default, is `/embeddings`, `/completions` and `/chat/completions` for respectively the `apoc.ml.openai.embedding`, `apoc.ml.openai.completion` and `apoc.ml.openai.chat` procedures.
| jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response.
The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure.
| enableBackOffRetries | If set to true, enables the backoff retry strategy for handling failures. (default: false)
| backOffRetries | Sets the maximum number of retry attempts before the operation throws an exception. (default: 5)
| exponentialBackoff | If set to true, applies an exponential progression to the wait time between retries. If set to false, the wait time increases linearly. (default: false)
|===


Expand Down
19 changes: 17 additions & 2 deletions extended/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import apoc.ApocConfig;
import apoc.Extended;
import apoc.result.MapResult;
import apoc.util.ExtendedUtil;
import apoc.util.JsonUtil;
import apoc.util.Util;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
Expand Down Expand Up @@ -35,6 +37,9 @@ public class OpenAI {
public static final String JSON_PATH_CONF_KEY = "jsonPath";
public static final String PATH_CONF_KEY = "path";
public static final String GPT_4O_MODEL = "gpt-4o";
public static final String ENABLE_BACK_OFF_RETRIES_CONF_KEY = "enableBackOffRetries";
public static final String ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY = "exponentialBackoff";
public static final String BACK_OFF_RETRIES_CONF_KEY = "backOffRetries";

@Context
public ApocConfig apocConfig;
Expand All @@ -58,6 +63,9 @@ public EmbeddingResult(long index, String text, List<Double> embedding) {

static Stream<Object> executeRequest(String apiKey, Map<String, Object> configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException {
apiKey = (String) configuration.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey));
boolean enableBackOffRetries = Util.toBoolean( configuration.get(ENABLE_BACK_OFF_RETRIES_CONF_KEY) );
Integer backOffRetries = Util.toInteger(configuration.getOrDefault(BACK_OFF_RETRIES_CONF_KEY, 5));
boolean exponentialBackoff = Util.toBoolean( configuration.get(ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY) );
if (apiKey == null || apiKey.isBlank())
throw new IllegalArgumentException("API Key must not be empty");

Expand All @@ -77,7 +85,7 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
path = (String) configuration.getOrDefault(PATH_CONF_KEY, path);
OpenAIRequestHandler apiType = type.get();

jsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);
String sJsonPath = (String) configuration.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Farei anch'io un'altra variabile, ma visto che l'implementazione viene dall'alto meglio rimettere jsonPath come prima

Copy link
Owner Author

@mpetrini-larus mpetrini-larus Nov 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

si ma poi non lo posso passare alla lambda sotto 😕
se la lascio com'era non è "effettivamente-statica" e la lambda non la digerisce

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Java maledetto, allora va bene così

headers.put("Content-Type", "application/json");
apiType.addApiKey(headers, apiKey);

Expand All @@ -87,7 +95,14 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
// eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model
// therefore is better to join the not-empty path pieces
var url = apiType.getFullUrl(path, configuration, apocConfig);
return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of(), urlAccessChecker);
return ExtendedUtil.withBackOffRetries(
() -> JsonUtil.loadJson(url, headers, payload, sJsonPath, true, List.of(), urlAccessChecker),
enableBackOffRetries, backOffRetries,
exeception -> {
if(!exeception.getMessage().contains("response code: 429"))
throw new RuntimeException(exeception);
}
);
}

private static void handleAPIProvider(OpenAIRequestHandler.Type type,
Expand Down
74 changes: 65 additions & 9 deletions extended/src/main/java/apoc/util/ExtendedUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import com.fasterxml.jackson.core.json.JsonWriteFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.lang3.StringUtils;
import org.eclipse.collections.api.block.function.Function0;
import org.neo4j.exceptions.Neo4jException;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.ExecutionPlanDescription;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.QueryExecutionType;
import org.neo4j.graphdb.Result;
import org.neo4j.logging.Log;
import org.neo4j.procedure.Mode;
import org.neo4j.values.storable.DateTimeValue;
import org.neo4j.values.storable.DateValue;
Expand All @@ -30,14 +32,11 @@
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.temporal.TemporalAccessor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.time.temporal.TemporalUnit;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
Expand Down Expand Up @@ -353,5 +352,62 @@ public static float[] listOfNumbersToFloatArray(List<? extends Number> embedding
}
return floats;
}


public static <T> T withBackOffRetries(
Supplier<T> func, boolean retry, int backoffRetry,
Consumer<Exception> exceptionHandler
){
return withBackOffRetries(func, retry, backoffRetry, exceptionHandler, false);
}


public static <T> T withBackOffRetries(
Supplier<T> func, boolean retry, int backoffRetry,
Consumer<Exception> exceptionHandler, boolean exponential
) {
T result = null;
backoffRetry = backoffRetry < 1 ? 5 : backoffRetry;
int countDown = backoffRetry;
exceptionHandler = Objects.requireNonNullElse(exceptionHandler, exe -> {});
while (true) {
try {
result = func.get();
break;
} catch (Exception e) {
if(!retry || countDown < 1) throw e;
exceptionHandler.accept(e);
countDown--;
backoffSleep(
getDelay(backoffRetry, countDown, exponential)
);
}
}
return result;
}

private static void backoffSleep(long millis){
sleep(millis, "Operation interrupted during backoff");
}

public static void sleep(long millis, String interruptedMessage) {
try {
Thread.sleep(millis);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new RuntimeException(interruptedMessage, ie);
}
}

private static long getDelay(Integer backoffRetry, Integer countDown, boolean exponential){
long sleepMultiplier = exponential ?
(long) Math.pow(2, backoffRetry - countDown) : // Exponential retry progression
backoffRetry - countDown; // Linear retry progression
return Math.min(
Duration.ofSeconds(1)
.multipliedBy(sleepMultiplier)
.toMillis(),
Duration.ofSeconds(30).toMillis() // Max 30s
);
}

}
1 change: 1 addition & 0 deletions extended/src/test/java/apoc/ml/OpenAIIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,5 @@ public void chatCompletionNullGpt35Turbo() {
Map.of("apiKey", openaiKey, "conf", Map.of(MODEL_CONF_KEY, GPT_35_MODEL))
);
}

}
106 changes: 106 additions & 0 deletions extended/src/test/java/apoc/util/ExtendedUtilTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package apoc.util;

import org.junit.Test;

import static org.junit.Assert.*;
import static org.junit.Assert.assertTrue;

public class ExtendedUtilTest {

private static int i = 0;

@Test
public void testWithLinearBackOffRetriesWithSuccess() {
i = 0;
long start = System.currentTimeMillis();
int result = ExtendedUtil.withBackOffRetries(
this::testFunction,
true, -1, // test backoffRetry default value -> 5
runEx -> {
if(!runEx.getMessage().contains("Expected"))
throw new RuntimeException("Some Bad News...");
}
);
long time = System.currentTimeMillis() - start;

assertEquals(4, result);

// The method will attempt to execute the operation with a linear backoff strategy,
// sleeping for 1 second, 2 seconds, and 3 seconds between retries.
// This results in a total wait time of 6 seconds (1s + 2s + 3s) if the operation succeeds on the third attempt,
// leading to an approximate execution time of 6 seconds.
assertTrue(time > 5500);
assertTrue(time < 6500);
}

@Test
public void testWithExponentialBackOffRetriesWithSuccess() {
i=0;
long start = System.currentTimeMillis();
int result = ExtendedUtil.withBackOffRetries(
this::testFunction,
true, 0, // test backoffRetry default value -> 5
runEx -> {
if(!runEx.getMessage().contains("Expected"))
throw new RuntimeException("Some Bad News...");
},
Boolean.TRUE
);
long time = System.currentTimeMillis() - start;

assertEquals(4, result);

// The method will attempt to execute the operation with an exponential backoff strategy,
// sleeping for 2 second, 4 seconds, and 8 seconds between retries.
// This results in a total wait time of 14 seconds (2s + 4s + 8s) if the operation succeeds on the third attempt,
// leading to an approximate execution time of 14 seconds.
assertTrue(time > 13500);
assertTrue(time < 14500);
}

@Test
public void testBackOffRetriesWithError() {
i=0;
long start = System.currentTimeMillis();
assertThrows(
RuntimeException.class,
() -> ExtendedUtil.withBackOffRetries(
this::testFunction,
true, 2,
runEx -> {}
)
);
long time = System.currentTimeMillis() - start;

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scrivi pure qua qualche riga di commento come sopra:

// The method will attempt to execute the operation with an exponential backoff strategy,
...etc

// The method is configured to retry the operation twice.
// So, it will make two extra-attempts, waiting for 1 second and 2 seconds before failing and throwing an exception.
// Resulting in an approximate execution time of 3 seconds.
assertTrue(time > 2500);
assertTrue(time < 3500);
}

@Test
public void testWithoutBackOffRetriesWithError() {
i=0;
assertThrows(
RuntimeException.class,
() -> ExtendedUtil.withBackOffRetries(
this::testFunction,
false, 30,
runEx -> {}
)
);

// Retry strategy is not active and the testFunction is executed only once by raising an exception.
assertEquals(1, i);
}

private int testFunction() {
i++;
if (i == 4) {
throw new RuntimeException("Expected i not equal to 4");
}
return i;
}

}
Loading