Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[issue_4156] Improves handling of empty or blank input for openai procedures #2

Open
wants to merge 4 commits into
base: 5.25
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extended/src/main/java/apoc/ml/MLUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

public class MLUtil {
public static final String ERROR_NULL_INPUT = "The input provided is null. Please specify a valid input";
public static final String ERROR_EMPTY_OR_BLANK_INPUT = "The input(s) provided is/are empty or blank. Please specify a valid input";

public static final String ENDPOINT_CONF_KEY = "endpoint";
public static final String API_VERSION_CONF_KEY = "apiVersion";
Expand Down
54 changes: 43 additions & 11 deletions extended/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,22 @@
import apoc.Extended;
import apoc.result.MapResult;
import apoc.util.JsonUtil;
import apoc.util.Util;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.internal.util.Producer;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

import java.net.MalformedURLException;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand All @@ -35,6 +36,7 @@ public class OpenAI {
public static final String JSON_PATH_CONF_KEY = "jsonPath";
public static final String PATH_CONF_KEY = "path";
public static final String GPT_4O_MODEL = "gpt-4o";
public static final String FAIL_ON_ERROR_CONF = "failOnError";

@Context
public ApocConfig apocConfig;
Expand Down Expand Up @@ -147,6 +149,10 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
"model": "text-embedding-ada-002",
"usage": { "prompt_tokens": 8, "total_tokens": 8 } }
*/
boolean shuldFail = Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true));
if (!checkNullInput(texts, shuldFail)) return Stream.empty();
texts = texts.stream().filter(StringUtils::isNotBlank).toList();
if (!checkEmptyInput(texts, shuldFail)) return Stream.empty();
return getEmbeddingResult(texts, apiKey, configuration, apocConfig, urlAccessChecker,
(map, text) -> {
Long index = (Long) map.get("index");
Expand All @@ -156,6 +162,7 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
);
}


static <T> Stream<T> getEmbeddingResult(List<String> texts, String apiKey, Map<String, Object> configuration, ApocConfig apocConfig, URLAccessChecker urlAccessChecker,
BiFunction<Map, String, T> embeddingMapping, Function<String, T> nullMapping) throws JsonProcessingException, MalformedURLException {
if (texts == null) {
Expand Down Expand Up @@ -194,19 +201,19 @@ public Stream<MapResult> completion(@Name("prompt") String prompt, @Name("api_ke
"usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 }
}
*/
if (prompt == null) {
throw new RuntimeException(ERROR_NULL_INPUT);
}
boolean fail = Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true));
if(!checkBlankInput(prompt, fail)) return Stream.empty();
return executeRequest(apiKey, configuration, "completions", "gpt-3.5-turbo-instruct", "prompt", prompt, "$", apocConfig, urlAccessChecker)
.map(v -> (Map<String,Object>)v).map(MapResult::new);
}

@Procedure("apoc.ml.openai.chat")
@Description("apoc.ml.openai.chat(messages, api_key, configuration]) - prompts the completion API")
public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Object>> messages, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
if (messages == null) {
throw new RuntimeException(ERROR_NULL_INPUT);
}
boolean shuldFail = Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true));
if (!checkNullInput(messages, shuldFail)) return Stream.empty();
messages = messages.stream().filter(MapUtils::isNotEmpty).toList();
if (!checkEmptyInput(messages, shuldFail)) return Stream.empty();
configuration.putIfAbsent("model", GPT_4O_MODEL);
return executeRequest(apiKey, configuration, "chat/completions", (String) configuration.get("model"), "messages", messages, "$", apocConfig, urlAccessChecker)
.map(v -> (Map<String,Object>)v).map(MapResult::new);
Expand All @@ -220,4 +227,29 @@ public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Objec
} ] }
*/
}

/**
 * Validates that {@code input} is not null.
 *
 * @param input     the value to validate
 * @param shuldFail when true, a null input raises an exception instead of returning false
 * @return true when {@code input} is non-null; false when it is null and {@code shuldFail} is false
 * @throws RuntimeException with {@code ERROR_NULL_INPUT} when {@code input} is null and {@code shuldFail} is true
 */
static boolean checkNullInput(Object input, boolean shuldFail) {
    if (input != null) {
        return true;
    }
    if (shuldFail) {
        throw new RuntimeException(ERROR_NULL_INPUT);
    }
    return false;
}

/**
 * Validates that {@code input} contains at least one element.
 * A null collection is treated like an empty one, so this helper never throws a
 * {@link NullPointerException} when it is called without a prior null check.
 *
 * @param input     the collection to validate
 * @param shuldFail when true, an invalid input raises an exception instead of returning false
 * @return true when {@code input} has at least one element; false when it is null or empty
 *         and {@code shuldFail} is false
 * @throws RuntimeException with {@code ERROR_EMPTY_OR_BLANK_INPUT} when {@code input} is null
 *         or empty and {@code shuldFail} is true
 */
static boolean checkEmptyInput(Collection<?> input, boolean shuldFail) {
    return checkInput(
            input,
            shuldFail,
            () -> input == null || input.isEmpty(),
            ERROR_EMPTY_OR_BLANK_INPUT);
}

/**
 * Validates that {@code input} is neither null nor blank.
 * Null is checked first so that a null input is reported with {@code ERROR_NULL_INPUT}
 * rather than the empty/blank message; {@code StringUtils.isBlank(null)} is true and
 * would otherwise hide the distinction.
 *
 * @param input     the string to validate
 * @param shuldFail when true, an invalid input raises an exception instead of returning false
 * @return true when {@code input} is non-null and not blank; false when it is invalid
 *         and {@code shuldFail} is false
 * @throws RuntimeException with {@code ERROR_NULL_INPUT} for a null input, or with
 *         {@code ERROR_EMPTY_OR_BLANK_INPUT} for an empty/blank input, when {@code shuldFail} is true
 */
static boolean checkBlankInput(String input, boolean shuldFail) {
    // Report null separately from blank so callers get the more precise message.
    if (!checkNullInput(input, shuldFail)) {
        return false;
    }
    return checkInput(input, shuldFail, () -> StringUtils.isBlank(input), ERROR_EMPTY_OR_BLANK_INPUT);
}

/**
 * Shared validation routine behind the {@code check*Input} helpers.
 *
 * @param input            the value under validation; not read here, the supplied check captures it
 * @param shuldFail        when true, a failed validation raises an exception instead of returning false
 * @param checkFunction    yields true when the input is INVALID
 * @param exceptionMessage message of the exception raised on failure
 * @return true when the input passed validation; false when it did not and {@code shuldFail} is false
 * @throws RuntimeException carrying {@code exceptionMessage} when validation fails and {@code shuldFail} is true
 */
private static boolean checkInput(
        Object input, boolean shuldFail,
        Supplier<Boolean> checkFunction,
        String exceptionMessage
) {
    boolean invalid = checkFunction.get();
    if (!invalid) {
        return true;
    }
    if (shuldFail) {
        throw new RuntimeException(exceptionMessage);
    }
    return false;
}

}
20 changes: 7 additions & 13 deletions extended/src/test/java/apoc/ml/MLTestUtil.java
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
package apoc.ml;

import apoc.util.ExtendedTestUtil;
import org.neo4j.graphdb.GraphDatabaseService;

import java.util.Map;

import static apoc.ml.MLUtil.ERROR_NULL_INPUT;
import static apoc.util.TestUtil.testCall;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static apoc.ml.MLUtil.ERROR_EMPTY_OR_BLANK_INPUT;

/**
 * Assertion helpers shared by the ML procedure tests.
 * Both helpers delegate to {@code ExtendedTestUtil.assertFails}, passing the expected error message.
 */
public class MLTestUtil {

    /**
     * Asserts that running {@code query} fails with the null-input error message.
     *
     * @param db     database to run the query against
     * @param query  Cypher query expected to fail
     * @param params query parameters
     */
    public static void assertNullInputFails(GraphDatabaseService db, String query, Map<String, Object> params) {
        // The earlier hand-written try/catch is dropped: it duplicated this check and ran the query twice.
        ExtendedTestUtil.assertFails(db, query, params, ERROR_NULL_INPUT);
    }

    /**
     * Asserts that running {@code query} fails with the empty-or-blank-input error message.
     *
     * @param db     database to run the query against
     * @param query  Cypher query expected to fail
     * @param params query parameters
     */
    public static void assertEmptyInputFails(GraphDatabaseService db, String query, Map<String, Object> params) {
        ExtendedTestUtil.assertFails(db, query, params, ERROR_EMPTY_OR_BLANK_INPUT);
    }
}
61 changes: 60 additions & 1 deletion extended/src/test/java/apoc/ml/OpenAIIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
import java.util.Map;
import java.util.Set;

import static apoc.ml.MLTestUtil.assertEmptyInputFails;
import static apoc.ml.MLTestUtil.assertNullInputFails;
import static apoc.ml.MLUtil.MODEL_CONF_KEY;
import static apoc.ml.OpenAI.GPT_4O_MODEL;
import static apoc.ml.OpenAI.FAIL_ON_ERROR_CONF;
import static apoc.ml.OpenAITestResultUtils.*;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
Expand Down Expand Up @@ -140,13 +142,34 @@ public void embeddingsNull() {
);
}

/**
 * apoc.ml.openai.chat called with null messages and an empty configuration
 * is expected to fail with the null-input error message.
 * NOTE(review): chatCompletionNull appears to issue the same query — confirm whether both are needed.
 */
@Test
public void chatNull() {
assertNullInputFails(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", emptyMap())
);
}

/**
 * apoc.ml.openai.chat called with null messages and failOnError set to false.
 * Presumably testCallEmpty asserts that the call yields no rows — confirm against TestUtil.
 */
@Test
public void chatReturnsEmptyIfFailOnErrorFalse() {
TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
);
}

/**
 * apoc.ml.openai.embeddings called with null texts and failOnError set to false.
 * Presumably testCallEmpty asserts that the call yields no rows — confirm against TestUtil.
 */
@Test
public void embeddingsReturnsEmptyIfFailOnErrorFalse() {
TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.embeddings(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
);
}

/**
 * apoc.ml.openai.completion called with a null prompt and an empty configuration
 * is expected to fail with the null-input error message (not the empty/blank one).
 */
@Test
public void completionNull() {
assertNullInputFails(db, "CALL apoc.ml.openai.completion(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", emptyMap())
);
}

@Test
public void chatCompletionNull() {
assertNullInputFails(db, "CALL apoc.ml.openai.chat(null, $apiKey, $conf)",
Expand All @@ -160,4 +183,40 @@ public void chatCompletionNullGpt35Turbo() {
Map.of("apiKey", openaiKey, "conf", Map.of(MODEL_CONF_KEY, GPT_35_MODEL))
);
}

/**
 * apoc.ml.openai.completion called with a null prompt and failOnError set to false.
 * Presumably testCallEmpty asserts that the call yields no rows — confirm against TestUtil.
 */
@Test
public void completionReturnsEmptyIfFailOnErrorFalse() {
TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.completion(null, $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
);
}

/**
 * apoc.ml.openai.embeddings called with an empty list and an empty configuration
 * is expected to fail with the empty-or-blank-input error message.
 */
@Test
public void embeddingsWithEmptyFails() {
assertEmptyInputFails(db, "CALL apoc.ml.openai.embeddings([], $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", emptyMap())
);
}

/**
 * apoc.ml.openai.chat called with an empty message list and failOnError set to false.
 * The query targets the chat procedure so that it matches the test name
 * (it previously called the embeddings procedure).
 * Presumably testCallEmpty asserts that the call yields no rows — confirm against TestUtil.
 */
@Test
public void chatWithEmptyReturnsEmptyIfFailOnErrorFalse() {
    TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.chat([], $apiKey, $conf)",
            Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
    );
}

/**
 * apoc.ml.openai.chat called with an empty message list and an empty configuration
 * is expected to fail with the empty-or-blank-input error message.
 */
@Test
public void chatWithEmptyFails() {
assertEmptyInputFails(db, "CALL apoc.ml.openai.chat([], $apiKey, $conf)",
Map.of("apiKey", openaiKey, "conf", emptyMap())
);
}

/**
 * apoc.ml.openai.embeddings called with an empty text list and failOnError set to false.
 * The query targets the embeddings procedure so that it matches the test name
 * (it previously called the chat procedure).
 * Presumably testCallEmpty asserts that the call yields no rows — confirm against TestUtil.
 */
@Test
public void embeddingsWithEmptyReturnsEmptyIfFailOnErrorFalse() {
    TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.embeddings([], $apiKey, $conf)",
            Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
    );
}

}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aggiungi pure dei test con lista vuota tipo:

    // Suggested test: embeddings with an empty list should fail with the empty-or-blank-input message.
    // The query calls the embeddings procedure so that it matches the test name.
    @Test
    public void embeddingsWithEmptyFails() {
        assertEmptyInputFails(db, "CALL apoc.ml.openai.embeddings([], $apiKey, $conf)",
                Map.of("apiKey", openaiKey, "conf", emptyMap())
        );
    }
    
    // Suggested test: embeddings with an empty list and failOnError=false.
    // Presumably testCallEmpty asserts that the call yields no rows — confirm against TestUtil.
    @Test
    public void embeddingsWithEmptyReturnsEmptyIfFailOnErrorFalse() {
        TestUtil.testCallEmpty(db, "CALL apoc.ml.openai.embeddings([], $apiKey, $conf)",
                Map.of("apiKey", openaiKey, "conf", Map.of(FAIL_ON_ERROR_CONF, false))
        );
    }

e la classe MLTestUtil.java la cambi in questo modo, visto che mo abbiamo a disposizione questo stupendo ExtendedTestUtil.assertFails:

// Suggested shape for MLTestUtil: both helpers delegate to ExtendedTestUtil.assertFails
// with the expected error message.
public class MLTestUtil {
    
    // Asserts that the query fails with the null-input error message.
    public static void assertNullInputFails(GraphDatabaseService db, String query, Map<String, Object> params) {
        ExtendedTestUtil.assertFails(db, query, params, ERROR_NULL_INPUT);
    }
    
    // Asserts that the query fails with the empty-or-blank-input error message.
    public static void assertEmptyInputFails(GraphDatabaseService db, String query, Map<String, Object> params) {
        ExtendedTestUtil.assertFails(db, query, params, ERROR_EMPTY_OR_BLANK_INPUT);
    }
}

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dovrei aver sistemato come suggerito!
mi facci sapere se c'è altro, o mi sono dimenticato qualcosa ^^

Loading