Skip to content

Commit

Permalink
Fixes #4156: Improves handling of empty or blank input for openai pro…
Browse files Browse the repository at this point in the history
…cedures (#4228)

* Fixes #4156: Improves handling of empty or blank input for openai procedures

* fix tests

* changed boolean conditions
  • Loading branch information
vga91 authored Dec 6, 2024
1 parent 4e2f977 commit d6b31c3
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 27 deletions.
1 change: 1 addition & 0 deletions docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ If present, they take precedence over the analogous APOC configs.
By default, is `/embeddings`, `/completions` and `/chat/completions` for respectively the `apoc.ml.openai.embedding`, `apoc.ml.openai.completion` and `apoc.ml.openai.chat` procedures.
| jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response.
The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure.
| failOnError | If true (default), the procedure fails in case of empty, blank or null input
|===


Expand Down
2 changes: 1 addition & 1 deletion extended/src/main/java/apoc/ml/MLUtil.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package apoc.ml;

public class MLUtil {
public static final String ERROR_NULL_INPUT = "The input provided is null. Please specify a valid input";
public static final String ERROR_NULL_INPUT = "Null, blank or empty input provided. Please specify a valid input";

public static final String ENDPOINT_CONF_KEY = "endpoint";
public static final String API_VERSION_CONF_KEY = "apiVersion";
Expand Down
56 changes: 45 additions & 11 deletions extended/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
import apoc.Extended;
import apoc.result.MapResult;
import apoc.util.JsonUtil;
import apoc.util.Util;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

import java.net.MalformedURLException;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand All @@ -35,6 +35,7 @@ public class OpenAI {
public static final String JSON_PATH_CONF_KEY = "jsonPath";
public static final String PATH_CONF_KEY = "path";
public static final String GPT_4O_MODEL = "gpt-4o";
public static final String FAIL_ON_ERROR_CONF = "failOnError";

@Context
public ApocConfig apocConfig;
Expand Down Expand Up @@ -147,6 +148,10 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
"model": "text-embedding-ada-002",
"usage": { "prompt_tokens": 8, "total_tokens": 8 } }
*/
boolean failOnError = isFailOnError(configuration);
if (checkNullInput(texts, failOnError)) return Stream.empty();
texts = texts.stream().filter(StringUtils::isNotBlank).toList();
if (checkEmptyInput(texts, failOnError)) return Stream.empty();
return getEmbeddingResult(texts, apiKey, configuration, apocConfig, urlAccessChecker,
(map, text) -> {
Long index = (Long) map.get("index");
Expand All @@ -156,6 +161,7 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
);
}


static <T> Stream<T> getEmbeddingResult(List<String> texts, String apiKey, Map<String, Object> configuration, ApocConfig apocConfig, URLAccessChecker urlAccessChecker,
BiFunction<Map, String, T> embeddingMapping, Function<String, T> nullMapping) throws JsonProcessingException, MalformedURLException {
if (texts == null) {
Expand Down Expand Up @@ -194,19 +200,19 @@ public Stream<MapResult> completion(@Name("prompt") String prompt, @Name("api_ke
"usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 }
}
*/
if (prompt == null) {
throw new RuntimeException(ERROR_NULL_INPUT);
}
boolean failOnError = isFailOnError(configuration);
if(checkBlankInput(prompt, failOnError)) return Stream.empty();
return executeRequest(apiKey, configuration, "completions", "gpt-3.5-turbo-instruct", "prompt", prompt, "$", apocConfig, urlAccessChecker)
.map(v -> (Map<String,Object>)v).map(MapResult::new);
}

@Procedure("apoc.ml.openai.chat")
@Description("apoc.ml.openai.chat(messages, api_key, configuration]) - prompts the completion API")
public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Object>> messages, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
if (messages == null) {
throw new RuntimeException(ERROR_NULL_INPUT);
}
boolean failOnError = isFailOnError(configuration);
if (checkNullInput(messages, failOnError)) return Stream.empty();
messages = messages.stream().filter(MapUtils::isNotEmpty).toList();
if (checkEmptyInput(messages, failOnError)) return Stream.empty();
configuration.putIfAbsent("model", GPT_4O_MODEL);
return executeRequest(apiKey, configuration, "chat/completions", (String) configuration.get("model"), "messages", messages, "$", apocConfig, urlAccessChecker)
.map(v -> (Map<String,Object>)v).map(MapResult::new);
Expand All @@ -220,4 +226,32 @@ public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Objec
} ] }
*/
}

private static boolean isFailOnError(Map<String, Object> configuration) {
return Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true));
}

static boolean checkNullInput(Object input, boolean failOnError) {
return checkInput(failOnError, () -> Objects.isNull(input));
}

static boolean checkEmptyInput(Collection<?> input, boolean failOnError) {
return checkInput(failOnError, () -> input.isEmpty());
}

static boolean checkBlankInput(String input, boolean failOnError) {
return checkInput(failOnError, () -> StringUtils.isBlank(input));
}

private static boolean checkInput(
boolean failOnError,
Supplier<Boolean> checkFunction
){
if (checkFunction.get()) {
if(failOnError) throw new RuntimeException(ERROR_NULL_INPUT);
return true;
}
return false;
}

}
18 changes: 4 additions & 14 deletions extended/src/test/java/apoc/ml/MLTestUtil.java
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
package apoc.ml;

import apoc.util.ExtendedTestUtil;
import org.neo4j.graphdb.GraphDatabaseService;

import java.util.Map;

import static apoc.ml.MLUtil.ERROR_NULL_INPUT;
import static apoc.util.TestUtil.testCall;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static apoc.ml.MLUtil.*;

public class MLTestUtil {

public static void assertNullInputFails(GraphDatabaseService db, String query, Map<String, Object> params) {
try {
testCall(db, query, params,
(row) -> fail("Should fail due to null input")
);
} catch (RuntimeException e) {
String message = e.getMessage();
assertTrue("Current error message is: " + message,
message.contains(ERROR_NULL_INPUT)
);
}
ExtendedTestUtil.assertFails(db, query, params, ERROR_NULL_INPUT);
}
}
46 changes: 45 additions & 1 deletion extended/src/test/java/apoc/ml/OpenAIIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import static apoc.ml.MLTestUtil.assertNullInputFails;
import static apoc.ml.MLUtil.MODEL_CONF_KEY;
import static apoc.ml.OpenAI.GPT_4O_MODEL;
import static apoc.ml.OpenAI.FAIL_ON_ERROR_CONF;
import static apoc.ml.OpenAITestResultUtils.*;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
Expand Down Expand Up @@ -140,13 +141,49 @@ public void embeddingsNull() {
);
}

@Test
public void chatNull() {
assertNullInputFails(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", emptyMap())
);
}

@Test
public void chatReturnsEmptyIfFailOnErrorFalse() {
TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
);
}

@Test
public void embeddingsReturnsEmptyIfFailOnErrorFalse() {
TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.embedding(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
);
}


@Test
public void chatWithEmptyFails() {
assertNullInputFails(db, "CALL apoc.ml.openai.chat([], $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", emptyMap())
);
}

@Test
public void embeddingsWithEmptyReturnsEmptyIfFailOnErrorFalse() {
TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.embedding([], $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
);
}

@Test
public void completionNull() {
assertNullInputFails(db, "CALL apoc.ml.openai.completion(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", emptyMap())
);
}

@Test
public void chatCompletionNull() {
assertNullInputFails(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)",
Expand All @@ -160,4 +197,11 @@ public void chatCompletionNullGpt35Turbo() {
Map.of("apiKey", openaiKey, "conf", Map.of(MODEL_CONF_KEY, GPT_35_MODEL))
);
}

@Test
public void completionReturnsEmptyIfFailOnErrorFalse() {
TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.completion(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
);
}
}

0 comments on commit d6b31c3

Please sign in to comment.