Skip to content

Commit

Permalink
[NOID] Fixes #4156: Improves handling of empty or blank input for ope…
Browse files Browse the repository at this point in the history
…nai procedures (#4228)

* Fixes #4156: Improves handling of empty or blank input for openai procedures

* fix tests

* changed boolean conditions
  • Loading branch information
vga91 committed Dec 18, 2024
1 parent 7e0d2a5 commit a10ec9e
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ If present, they take precedence over the analogous APOC configs.
| apiType | analogous to `apoc.ml.openai.type` APOC config
| endpoint | analogous to `apoc.ml.openai.url` APOC config
| apiVersion | analogous to `apoc.ml.azure.api.version` APOC config
| failOnError | If true (default), the procedure fails in case of empty, blank or null input
|===


Expand Down
2 changes: 2 additions & 0 deletions full/src/main/java/apoc/ml/MLUtil.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package apoc.ml;

public class MLUtil {
public static final String ERROR_NULL_INPUT = "Null, blank or empty input provided. Please specify a valid input";

public static final String ENDPOINT_CONF_KEY = "endpoint";
public static final String API_VERSION_CONF_KEY = "apiVersion";
public static final String MODEL_CONF_KEY = "model";
Expand Down
44 changes: 43 additions & 1 deletion full/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,36 @@
import static apoc.ml.MLUtil.API_TYPE_CONF_KEY;
import static apoc.ml.MLUtil.API_VERSION_CONF_KEY;
import static apoc.ml.MLUtil.ENDPOINT_CONF_KEY;
import static apoc.ml.MLUtil.ERROR_NULL_INPUT;
import static apoc.ml.MLUtil.MODEL_CONF_KEY;

import apoc.ApocConfig;
import apoc.Extended;
import apoc.result.MapResult;
import apoc.util.JsonUtil;
import apoc.util.Util;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.net.MalformedURLException;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

@Extended
public class OpenAI {
public static final String FAIL_ON_ERROR_CONF = "failOnError";

@Context
public ApocConfig apocConfig;

Expand Down Expand Up @@ -106,7 +115,10 @@ public Stream<EmbeddingResult> getEmbedding(
"model": "text-embedding-ada-002",
"usage": { "prompt_tokens": 8, "total_tokens": 8 } }
*/

boolean failOnError = isFailOnError(configuration);
if (checkNullInput(texts, failOnError)) return Stream.empty();
texts = texts.stream().filter(StringUtils::isNotBlank).toList();
if (checkEmptyInput(texts, failOnError)) return Stream.empty();
return getEmbeddingResult(texts, apiKey, configuration, apocConfig, (map, text) -> {
Long index = (Long) map.get("index");
return new EmbeddingResult(index, text, (List<Double>) map.get("embedding"));
Expand Down Expand Up @@ -147,6 +159,8 @@ public Stream<MapResult> completion(
"usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 }
}
*/
boolean failOnError = isFailOnError(configuration);
if (checkBlankInput(prompt, failOnError)) return Stream.empty();
return executeRequest(
apiKey,
configuration,
Expand All @@ -167,6 +181,10 @@ public Stream<MapResult> chatCompletion(
@Name("api_key") String apiKey,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration)
throws Exception {
boolean failOnError = isFailOnError(configuration);
if (checkNullInput(messages, failOnError)) return Stream.empty();
messages = messages.stream().filter(MapUtils::isNotEmpty).toList();
if (checkEmptyInput(messages, failOnError)) return Stream.empty();
String model = (String) configuration.putIfAbsent("model", "gpt-4o");
return executeRequest(apiKey, configuration, "chat/completions", model, "messages", messages, "$", apocConfig)
.map(v -> (Map<String, Object>) v)
Expand All @@ -181,4 +199,28 @@ public Stream<MapResult> chatCompletion(
} ] }
*/
}

private static boolean isFailOnError(Map<String, Object> configuration) {
return Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true));
}

static boolean checkNullInput(Object input, boolean failOnError) {
return checkInput(failOnError, () -> Objects.isNull(input));
}

static boolean checkEmptyInput(Collection<?> input, boolean failOnError) {
return checkInput(failOnError, input::isEmpty);
}

static boolean checkBlankInput(String input, boolean failOnError) {
return checkInput(failOnError, () -> StringUtils.isBlank(input));
}

private static boolean checkInput(boolean failOnError, Supplier<Boolean> checkFunction) {
if (checkFunction.get()) {
if (failOnError) throw new RuntimeException(ERROR_NULL_INPUT);
return true;
}
return false;
}
}
59 changes: 59 additions & 0 deletions full/src/test/java/apoc/ml/OpenAIIT.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package apoc.ml;

import static apoc.ml.MLUtil.ERROR_NULL_INPUT;
import static apoc.ml.OpenAITestResultUtils.assertChatCompletion;
import static apoc.util.TestUtil.testCall;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import apoc.util.ExtendedTestUtil;
import apoc.util.TestUtil;
import java.util.List;
import java.util.Map;
Expand All @@ -13,6 +15,7 @@
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.test.rule.DbmsRule;
import org.neo4j.test.rule.ImpermanentDbmsRule;

Expand Down Expand Up @@ -138,4 +141,60 @@ public void chatCompletion() {
}
*/
}

@Test
public void embeddingsNull() {
assertNullInputFails(
db,
"CALL apoc.ml.openai.embedding(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", emptyMap()));
}

@Test
public void chatNull() {
assertNullInputFails(
db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)", Map.of("apiKey", openaiKey, "conf", emptyMap()));
}

@Test
public void chatReturnsEmptyIfFailOnErrorFalse() {
TestUtil.testCallEmpty(
db,
"CALL apoc.ml.openai.chat(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)));
}

@Test
public void embeddingsReturnsEmptyIfFailOnErrorFalse() {
TestUtil.testCallEmpty(
db,
"CALL apoc.ml.openai.embedding(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)));
}

@Test
public void chatWithEmptyFails() {
assertNullInputFails(
db, "CALL apoc.ml.openai.chat([], $apiKey, $conf)", Map.of("apiKey", openaiKey, "conf", emptyMap()));
}

@Test
public void embeddingsWithEmptyReturnsEmptyIfFailOnErrorFalse() {
TestUtil.testCallEmpty(
db,
"CALL apoc.ml.openai.embedding([], $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)));
}

@Test
public void completionReturnsEmptyIfFailOnErrorFalse() {
TestUtil.testCallEmpty(
db,
"CALL apoc.ml.openai.completion(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false)));
}

public static void assertNullInputFails(GraphDatabaseService db, String query, Map<String, Object> params) {
ExtendedTestUtil.assertFails(db, query, params, ERROR_NULL_INPUT);
}
}
13 changes: 13 additions & 0 deletions full/src/test/java/apoc/util/ExtendedTestUtil.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package apoc.util;

import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testCallAssertions;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.neo4j.test.assertion.Assert.assertEventually;

import java.util.Collections;
Expand Down Expand Up @@ -67,4 +70,14 @@ public static void testResultEventually(
timeout,
TimeUnit.SECONDS);
}

public static void assertFails(
GraphDatabaseService db, String query, Map<String, Object> params, String expectedErrMsg) {
try {
testCall(db, query, params, r -> fail("Should fail due to " + expectedErrMsg));
} catch (Exception e) {
String actualErrMsg = e.getMessage();
assertTrue("Actual err. message is: " + actualErrMsg, actualErrMsg.contains(expectedErrMsg));
}
}
}

0 comments on commit a10ec9e

Please sign in to comment.