Skip to content

Commit

Permalink
[Inference API] Make message content optional in unified API (#118998) (
Browse files Browse the repository at this point in the history
#119226)

* Allow for null/empty content field

* remove tests which checked for null content

* [CI] Auto commit changes from spotless

* Improvements from review

---------

Co-authored-by: elasticsearchmachine <[email protected]>
(cherry picked from commit 79a8226)

# Conflicts:
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java
  • Loading branch information
maxhniebergall authored Dec 23, 2024
1 parent 24caae7 commit 1a697b5
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,12 @@ public record Message(Content content, String role, @Nullable String name, @Null
);

static {
PARSER.declareField(constructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE_ARRAY);
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> parseContent(p),
new ParseField("content"),
ObjectParser.ValueType.VALUE_ARRAY
);
PARSER.declareString(constructorArg(), new ParseField("role"));
PARSER.declareString(optionalConstructorArg(), new ParseField("name"));
PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id"));
Expand All @@ -143,7 +148,7 @@ private static Content parseContent(XContentParser parser) throws IOException {

public Message(StreamInput in) throws IOException {
this(
in.readNamedWriteable(Content.class),
in.readOptionalNamedWriteable(Content.class),
in.readString(),
in.readOptionalString(),
in.readOptionalString(),
Expand All @@ -153,7 +158,7 @@ public Message(StreamInput in) throws IOException {

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(content);
out.writeOptionalNamedWriteable(content);
out.writeString(role);
out.writeOptionalString(name);
out.writeOptionalString(toolCallId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) {
builder.startObject();
{
if (message.content() instanceof UnifiedCompletionRequest.ContentString contentString) {
if (message.content() == null) {
// content is optional
} else if (message.content() instanceof UnifiedCompletionRequest.ContentString contentString) {
builder.field(CONTENT_FIELD, contentString.content());
} else if (message.content() instanceof UnifiedCompletionRequest.ContentObjects contentObjects) {
builder.startArray(CONTENT_FIELD);
Expand All @@ -77,10 +79,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
}
builder.endArray();
} else {
throw new IllegalArgumentException(
Strings.format("Unsupported message.content class received: %s", message.content().getClass().getSimpleName())
);
}

builder.field(ROLE_FIELD, message.role());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -702,122 +702,62 @@ public void testSerializationWithBooleanFields() throws IOException {
assertJsonEquals(expectedJsonFalse, jsonStringFalse);
}

// 9. Serialization with Missing Required Fields
// Test with missing required fields to ensure appropriate exceptions are thrown.
public void testSerializationWithMissingRequiredFields() {
// Create a message with missing content (required field)
// 9. a test without the content field to show that the content field is optional
public void testSerializationWithoutContentField() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
null, // missing content
OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD,
null,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
messageList.add(message);
// Create the unified request
UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(
messageList,
null, // model
null, // maxCompletionTokens
null, // stop
null, // temperature
null, // toolChoice
null, // tools
null // topP
);

// Create the unified chat input
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);

OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null);

// Create the entity
OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model);

// Attempt to serialize to XContent and expect an exception
try {
XContentBuilder builder = JsonXContent.contentBuilder();
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
fail("Expected an exception due to missing required fields");
} catch (NullPointerException | IOException e) {
// Expected exception
}
}

// 10. Serialization with Mixed Valid and Invalid Data
// Test with a mix of valid and invalid data to ensure the serializer handles it gracefully.
public void testSerializationWithMixedValidAndInvalidData() throws IOException {
// Create a valid message
UnifiedCompletionRequest.Message validMessage = new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("Valid content"),
OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD,
"validName",
"validToolCallId",
Collections.singletonList(
new UnifiedCompletionRequest.ToolCall(
"validId",
new UnifiedCompletionRequest.ToolCall.FunctionField("validArguments", "validFunctionName"),
"validType"
)
)
);

// Create an invalid message with null content
UnifiedCompletionRequest.Message invalidMessage = new UnifiedCompletionRequest.Message(
null, // invalid content
OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD,
"invalidName",
"invalidToolCallId",
"assistant",
"name\nwith\nnewlines",
"tool_call_id\twith\ttabs",
Collections.singletonList(
new UnifiedCompletionRequest.ToolCall(
"invalidId",
new UnifiedCompletionRequest.ToolCall.FunctionField("invalidArguments", "invalidFunctionName"),
"invalidType"
"id\\with\\backslashes",
new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"),
"type"
)
)
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
messageList.add(validMessage);
messageList.add(invalidMessage);
// Create the unified request with both valid and invalid messages
UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(
messageList,
"model-name",
100L, // maxCompletionTokens
Collections.singletonList("stop"),
0.9f, // temperature
new UnifiedCompletionRequest.ToolChoiceString("tool_choice"),
Collections.singletonList(
new UnifiedCompletionRequest.Tool(
"type",
new UnifiedCompletionRequest.Tool.FunctionField(
"Fetches the weather in the given location",
"get_weather",
createParameters(),
true
)
)
),
0.8f // topP
);
messageList.add(message);
UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null);

// Create the unified chat input
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null);

OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null);

// Create the entity
OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model);

// Serialize to XContent and verify
try {
XContentBuilder builder = JsonXContent.contentBuilder();
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
fail("Expected an exception due to invalid data");
} catch (NullPointerException | IOException e) {
// Expected exception
}
XContentBuilder builder = JsonXContent.contentBuilder();
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);

String jsonString = Strings.toString(builder);
String expectedJson = """
{
"messages": [
{
"role": "assistant",
"name": "name\\nwith\\nnewlines",
"tool_call_id": "tool_call_id\\twith\\ttabs",
"tool_calls": [
{
"id": "id\\\\with\\\\backslashes",
"function": {
"arguments": "arguments\\"with\\"quotes",
"name": "function_name/with/slashes"
},
"type": "type"
}
]
}
],
"model": "test-endpoint",
"n": 1,
"stream": true,
"stream_options": {
"include_usage": true
}
}
""";
assertJsonEquals(jsonString, expectedJson);
}

public static Map<String, Object> createParameters() {
Expand Down

0 comments on commit 1a697b5

Please sign in to comment.