diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc index 95c23fb4ac..b02fd7d4d7 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc @@ -36,6 +36,7 @@ If present, they take precedence over the analogous APOC configs. By default, is `/embeddings`, `/completions` and `/chat/completions` for respectively the `apoc.ml.openai.embedding`, `apoc.ml.openai.completion` and `apoc.ml.openai.chat` procedures. | jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response. The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure. +| failOnError | If true (default), the procedure fails in case of empty, blank or null input |=== diff --git a/extended/src/main/java/apoc/ml/MLUtil.java b/extended/src/main/java/apoc/ml/MLUtil.java index f3054c52d5..0e162a625e 100644 --- a/extended/src/main/java/apoc/ml/MLUtil.java +++ b/extended/src/main/java/apoc/ml/MLUtil.java @@ -1,7 +1,7 @@ package apoc.ml; public class MLUtil { - public static final String ERROR_NULL_INPUT = "The input provided is null. Please specify a valid input"; + public static final String ERROR_NULL_INPUT = "Null, blank or empty input provided. Please specify a valid input"; public static final String ENDPOINT_CONF_KEY = "endpoint"; public static final String API_VERSION_CONF_KEY = "apiVersion"; diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java index ed9389a1cd..29a54a7f7d 100644 --- a/extended/src/main/java/apoc/ml/OpenAI.java +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -4,7 +4,10 @@ import apoc.Extended; import apoc.result.MapResult; import apoc.util.JsonUtil; +import apoc.util.Util; import com.fasterxml.jackson.core.JsonProcessingException; +import org.apache.commons.collections.MapUtils; +import org.apache.commons.lang3.StringUtils; import org.neo4j.graphdb.security.URLAccessChecker; import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; @@ -12,13 +15,10 @@ import org.neo4j.procedure.Procedure; import java.net.MalformedURLException; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; +import java.util.*; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -35,6 +35,7 @@ public class OpenAI { public static final String JSON_PATH_CONF_KEY = "jsonPath"; public static final String PATH_CONF_KEY = "path"; public static final String GPT_4O_MODEL = "gpt-4o"; + public static final String FAIL_ON_ERROR_CONF = "failOnError"; @Context public ApocConfig apocConfig; @@ -147,6 +148,10 @@ public Stream getEmbedding(@Name("texts") List texts, @ "model": "text-embedding-ada-002", "usage": { "prompt_tokens": 8, "total_tokens": 8 } } */ + boolean failOnError = isFailOnError(configuration); + if (checkNullInput(texts, failOnError)) return Stream.empty(); + texts = texts.stream().filter(StringUtils::isNotBlank).toList(); + if (checkEmptyInput(texts, failOnError)) return Stream.empty(); return getEmbeddingResult(texts, apiKey, configuration, apocConfig, urlAccessChecker, (map, text) -> { Long index = (Long) map.get("index"); @@ -156,6 +161,7 @@ public Stream getEmbedding(@Name("texts") List texts, @ ); } + static Stream getEmbeddingResult(List texts, String apiKey, Map configuration, ApocConfig apocConfig, URLAccessChecker urlAccessChecker, BiFunction embeddingMapping, Function nullMapping) throws JsonProcessingException, MalformedURLException { if (texts == null) { @@ -194,9 +200,8 @@ public Stream completion(@Name("prompt") String prompt, @Name("api_ke "usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 } } */ - if (prompt == null) { - throw new RuntimeException(ERROR_NULL_INPUT); - } + boolean failOnError = isFailOnError(configuration); + if(checkBlankInput(prompt, failOnError)) return Stream.empty(); return executeRequest(apiKey, configuration, "completions", "gpt-3.5-turbo-instruct", "prompt", prompt, "$", apocConfig, urlAccessChecker) .map(v -> (Map)v).map(MapResult::new); } @@ -204,9 +209,10 @@ public Stream completion(@Name("prompt") String prompt, @Name("api_ke @Procedure("apoc.ml.openai.chat") @Description("apoc.ml.openai.chat(messages, api_key, configuration]) - prompts the completion API") public Stream chatCompletion(@Name("messages") List> messages, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - if (messages == null) { - throw new RuntimeException(ERROR_NULL_INPUT); - } + boolean failOnError = isFailOnError(configuration); + if (checkNullInput(messages, failOnError)) return Stream.empty(); + messages = messages.stream().filter(MapUtils::isNotEmpty).toList(); + if (checkEmptyInput(messages, failOnError)) return Stream.empty(); configuration.putIfAbsent("model", GPT_4O_MODEL); return executeRequest(apiKey, configuration, "chat/completions", (String) configuration.get("model"), "messages", messages, "$", apocConfig, urlAccessChecker) .map(v -> (Map)v).map(MapResult::new); @@ -220,4 +226,32 @@ public Stream chatCompletion(@Name("messages") List configuration) { + return Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true)); + } + + static boolean checkNullInput(Object input, boolean failOnError) { + return checkInput(failOnError, () -> Objects.isNull(input)); + } + + static boolean checkEmptyInput(Collection input, boolean failOnError) { + return checkInput(failOnError, () -> input.isEmpty()); + } + + static boolean checkBlankInput(String input, boolean failOnError) { + return checkInput(failOnError, () -> StringUtils.isBlank(input)); + } + + private static boolean checkInput( + boolean failOnError, + Supplier checkFunction + ){ + if (checkFunction.get()) { + if(failOnError) throw new RuntimeException(ERROR_NULL_INPUT); + return true; + } + return false; + } + } \ No newline at end of file diff --git a/extended/src/test/java/apoc/ml/MLTestUtil.java b/extended/src/test/java/apoc/ml/MLTestUtil.java index 27e6946b36..58da66813b 100644 --- a/extended/src/test/java/apoc/ml/MLTestUtil.java +++ b/extended/src/test/java/apoc/ml/MLTestUtil.java @@ -1,25 +1,15 @@ package apoc.ml; +import apoc.util.ExtendedTestUtil; import org.neo4j.graphdb.GraphDatabaseService; import java.util.Map; -import static apoc.ml.MLUtil.ERROR_NULL_INPUT; -import static apoc.util.TestUtil.testCall; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static apoc.ml.MLUtil.*; public class MLTestUtil { + public static void assertNullInputFails(GraphDatabaseService db, String query, Map params) { - try { - testCall(db, query, params, - (row) -> fail("Should fail due to null input") - ); - } catch (RuntimeException e) { - String message = e.getMessage(); - assertTrue("Current error message is: " + message, - message.contains(ERROR_NULL_INPUT) - ); - } + ExtendedTestUtil.assertFails(db, query, params, ERROR_NULL_INPUT); } } diff --git a/extended/src/test/java/apoc/ml/OpenAIIT.java b/extended/src/test/java/apoc/ml/OpenAIIT.java index e56bb2f13e..ef7f7212b4 100644 --- a/extended/src/test/java/apoc/ml/OpenAIIT.java +++ b/extended/src/test/java/apoc/ml/OpenAIIT.java @@ -16,6 +16,7 @@ import static apoc.ml.MLTestUtil.assertNullInputFails; import static apoc.ml.MLUtil.MODEL_CONF_KEY; import static apoc.ml.OpenAI.GPT_4O_MODEL; +import static apoc.ml.OpenAI.FAIL_ON_ERROR_CONF; import static apoc.ml.OpenAITestResultUtils.*; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; @@ -140,13 +141,49 @@ public void embeddingsNull() { ); } + @Test + public void chatNull() { + assertNullInputFails(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", emptyMap()) + ); + } + + @Test + public void chatReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)) + ); + } + + @Test + public void embeddingsReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.embedding(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)) + ); + } + + + @Test + public void chatWithEmptyFails() { + assertNullInputFails(db, "CALL apoc.ml.openai.chat([], $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", emptyMap()) + ); + } + + @Test + public void embeddingsWithEmptyReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.embedding([], $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)) + ); + } + @Test public void completionNull() { assertNullInputFails(db, "CALL apoc.ml.openai.completion(null, $apiKey, $conf)", Map.of("apiKey", openaiKey, "conf", emptyMap()) ); } - + @Test public void chatCompletionNull() { assertNullInputFails(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)", @@ -160,4 +197,11 @@ public void chatCompletionNullGpt35Turbo() { Map.of("apiKey", openaiKey, "conf", Map.of(MODEL_CONF_KEY, GPT_35_MODEL)) ); } + + @Test + public void completionReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.completion(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)) + ); + } } \ No newline at end of file