diff --git a/extended/src/main/java/apoc/ml/MLUtil.java b/extended/src/main/java/apoc/ml/MLUtil.java index f3054c52d5..aa8f32f815 100644 --- a/extended/src/main/java/apoc/ml/MLUtil.java +++ b/extended/src/main/java/apoc/ml/MLUtil.java @@ -2,6 +2,7 @@ public class MLUtil { public static final String ERROR_NULL_INPUT = "The input provided is null. Please specify a valid input"; + public static final String ERROR_EMPTY_OR_BLANK_INPUT = "The input(s) provided is/are empty or blank. Please specify a valid input"; public static final String ENDPOINT_CONF_KEY = "endpoint"; public static final String API_VERSION_CONF_KEY = "apiVersion"; diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java index ed9389a1cd..fd3cc8a32c 100644 --- a/extended/src/main/java/apoc/ml/OpenAI.java +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -4,7 +4,11 @@ import apoc.Extended; import apoc.result.MapResult; import apoc.util.JsonUtil; +import apoc.util.Util; import com.fasterxml.jackson.core.JsonProcessingException; +import org.apache.commons.collections.MapUtils; +import org.apache.commons.lang3.StringUtils; +import org.glassfish.jersey.internal.util.Producer; import org.neo4j.graphdb.security.URLAccessChecker; import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; @@ -12,13 +16,10 @@ import org.neo4j.procedure.Procedure; import java.net.MalformedURLException; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; +import java.util.*; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -35,6 +36,7 @@ public class OpenAI { public static final String JSON_PATH_CONF_KEY = "jsonPath"; public static final String PATH_CONF_KEY = "path"; public static final String GPT_4O_MODEL = "gpt-4o"; + public static final String FAIL_ON_ERROR_CONF = "failOnError"; @Context public ApocConfig apocConfig; @@ -147,6 +149,10 @@ public Stream getEmbedding(@Name("texts") List texts, @ "model": "text-embedding-ada-002", "usage": { "prompt_tokens": 8, "total_tokens": 8 } } */ + boolean shuldFail = Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true)); + if (!checkNullInput(texts, shuldFail)) return Stream.empty(); + texts = texts.stream().filter(StringUtils::isNotBlank).toList(); + if (!checkEmptyInput(texts, shuldFail)) return Stream.empty(); return getEmbeddingResult(texts, apiKey, configuration, apocConfig, urlAccessChecker, (map, text) -> { Long index = (Long) map.get("index"); @@ -156,6 +162,7 @@ public Stream getEmbedding(@Name("texts") List texts, @ ); } + static Stream getEmbeddingResult(List texts, String apiKey, Map configuration, ApocConfig apocConfig, URLAccessChecker urlAccessChecker, BiFunction embeddingMapping, Function nullMapping) throws JsonProcessingException, MalformedURLException { if (texts == null) { @@ -194,9 +201,8 @@ public Stream completion(@Name("prompt") String prompt, @Name("api_ke "usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 } } */ - if (prompt == null) { - throw new RuntimeException(ERROR_NULL_INPUT); - } + boolean fail = Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true)); + if(!checkBlankInput(prompt, fail)) return Stream.empty(); return executeRequest(apiKey, configuration, "completions", "gpt-3.5-turbo-instruct", "prompt", prompt, "$", apocConfig, urlAccessChecker) .map(v -> (Map)v).map(MapResult::new); } @@ -204,9 +210,10 @@ public Stream completion(@Name("prompt") String prompt, @Name("api_ke @Procedure("apoc.ml.openai.chat") @Description("apoc.ml.openai.chat(messages, api_key, configuration]) - prompts the completion API") public Stream chatCompletion(@Name("messages") List> messages, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - if (messages == null) { - throw new RuntimeException(ERROR_NULL_INPUT); - } + boolean shuldFail = Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true)); + if (!checkNullInput(messages, shuldFail)) return Stream.empty(); + messages = messages.stream().filter(MapUtils::isNotEmpty).toList(); + if (!checkEmptyInput(messages, shuldFail)) return Stream.empty(); configuration.putIfAbsent("model", GPT_4O_MODEL); return executeRequest(apiKey, configuration, "chat/completions", (String) configuration.get("model"), "messages", messages, "$", apocConfig, urlAccessChecker) .map(v -> (Map)v).map(MapResult::new); @@ -220,4 +227,29 @@ public Stream chatCompletion(@Name("messages") List Objects.isNull(input), ERROR_NULL_INPUT); + } + + static boolean checkEmptyInput(Collection input, boolean shuldFail){ + return checkInput(input, shuldFail, () -> input.isEmpty(), ERROR_EMPTY_OR_BLANK_INPUT); + } + + static boolean checkBlankInput(String input, boolean shuldFail){ + return checkInput(input, shuldFail, () -> StringUtils.isBlank(input), ERROR_EMPTY_OR_BLANK_INPUT); + } + + private static boolean checkInput( + Object input, boolean shuldFail, + Supplier checkFunction, + String exceptionMessage + ){ + if (checkFunction.get()) { + if(shuldFail) throw new RuntimeException(exceptionMessage); + return false; + } + return true; + } + } \ No newline at end of file diff --git a/extended/src/test/java/apoc/ml/MLTestUtil.java b/extended/src/test/java/apoc/ml/MLTestUtil.java index 27e6946b36..180d00b4ea 100644 --- a/extended/src/test/java/apoc/ml/MLTestUtil.java +++ b/extended/src/test/java/apoc/ml/MLTestUtil.java @@ -1,25 +1,19 @@ package apoc.ml; +import apoc.util.ExtendedTestUtil; import org.neo4j.graphdb.GraphDatabaseService; import java.util.Map; import static apoc.ml.MLUtil.ERROR_NULL_INPUT; -import static apoc.util.TestUtil.testCall; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static apoc.ml.MLUtil.ERROR_EMPTY_OR_BLANK_INPUT; public class MLTestUtil { public static void assertNullInputFails(GraphDatabaseService db, String query, Map params) { - try { - testCall(db, query, params, - (row) -> fail("Should fail due to null input") - ); - } catch (RuntimeException e) { - String message = e.getMessage(); - assertTrue("Current error message is: " + message, - message.contains(ERROR_NULL_INPUT) - ); - } + ExtendedTestUtil.assertFails(db, query, params, ERROR_NULL_INPUT); + } + + public static void assertEmptyInputFails(GraphDatabaseService db, String query, Map params) { + ExtendedTestUtil.assertFails(db, query, params, ERROR_EMPTY_OR_BLANK_INPUT); } } diff --git a/extended/src/test/java/apoc/ml/OpenAIIT.java b/extended/src/test/java/apoc/ml/OpenAIIT.java index e56bb2f13e..77d8b2c792 100644 --- a/extended/src/test/java/apoc/ml/OpenAIIT.java +++ b/extended/src/test/java/apoc/ml/OpenAIIT.java @@ -13,9 +13,11 @@ import java.util.Map; import java.util.Set; +import static apoc.ml.MLTestUtil.assertEmptyInputFails; import static apoc.ml.MLTestUtil.assertNullInputFails; import static apoc.ml.MLUtil.MODEL_CONF_KEY; import static apoc.ml.OpenAI.GPT_4O_MODEL; +import static apoc.ml.OpenAI.FAIL_ON_ERROR_CONF; import static apoc.ml.OpenAITestResultUtils.*; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; @@ -140,13 +142,34 @@ public void embeddingsNull() { ); } + @Test + public void chatNull() { + assertNullInputFails(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", emptyMap()) + ); + } + + @Test + public void chatReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)) + ); + } + + @Test + public void embeddingsReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.embeddings(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)) + ); + } + @Test public void completionNull() { assertNullInputFails(db, "CALL apoc.ml.openai.completion(null, $apiKey, $conf)", Map.of("apiKey", openaiKey, "conf", emptyMap()) ); } - + @Test public void chatCompletionNull() { assertNullInputFails(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)", @@ -160,4 +183,40 @@ public void chatCompletionNullGpt35Turbo() { Map.of("apiKey", openaiKey, "conf", Map.of(MODEL_CONF_KEY, GPT_35_MODEL)) ); } + + @Test + public void completionReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.completion(null, $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)) + ); + } + + @Test + public void embeddingsWithEmptyFails() { + assertEmptyInputFails(db, "CALL apoc.ml.openai.embeddings([], $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", emptyMap()) + ); + } + + @Test + public void chatWithEmptyReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.embeddings([], $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)) + ); + } + + @Test + public void chatWithEmptyFails() { + assertEmptyInputFails(db, "CALL apoc.ml.openai.chat([], $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", emptyMap()) + ); + } + + @Test + public void embeddingsWithEmptyReturnsEmptyIfFailOnErrorFalse() { + TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.chat([], $apiKey, $conf)", + Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)) + ); + } + } \ No newline at end of file