Enhance Structured Output Interface #644

Merged · 12 commits · Oct 16, 2024
2 changes: 1 addition & 1 deletion clients/python/lorax/types.py
@@ -64,7 +64,7 @@ class ResponseFormat(BaseModel):
model_config = ConfigDict(use_enum_values=True)

type: ResponseFormatType
schema_spec: Union[Dict[str, Any], OrderedDict] = Field(alias="schema")
schema_spec: Optional[Union[Dict[str, Any], OrderedDict]] = Field(None, alias="schema")


class Parameters(BaseModel):
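With `schema_spec` now optional on the client side, `ResponseFormat` validates even when no schema is supplied. A minimal sketch of the new behavior (assuming the client package is installed and a LoRAX server is running locally):

```python
from lorax import Client
from lorax.types import ResponseFormat

# The schema field is now optional, so this no longer raises a
# pydantic validation error.
fmt = ResponseFormat(type="json_object")
assert fmt.schema_spec is None

# Equivalent dict form, passed straight to generate():
client = Client("http://127.0.0.1:8080")
response = client.generate(
    "Generate a character as JSON.",
    response_format={"type": "json_object"},  # no "schema" key needed
)
```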
133 changes: 124 additions & 9 deletions docs/guides/structured_output.md
@@ -35,7 +35,7 @@ valid next tokens using this FSM and sets the likelihood of invalid tokens to `-
This example follows the [JSON-structured generation example](https://outlines-dev.github.io/outlines/quickstart/#json-structured-generation) in the Outlines quickstart.

We assume that you have already deployed LoRAX using a suitable base model and installed the [LoRAX Python Client](../reference/python_client.md).
Alternatively, see [below](structured_output.md#openai-compatible-api) for an example of structured generation using an
Alternatively, see [below](structured_output.md#example-openai-compatible-api) for an example of structured generation using an
OpenAI client.

```python
@@ -60,14 +60,36 @@ class Character(BaseModel):

client = Client("http://127.0.0.1:8080")

prompt = "Generate a new character for my awesome game: name, age (between 1 and 99), armor and strength. "
response = client.generate(prompt, response_format={
# Example 1: Using a schema
prompt_with_schema = "Generate a new character for my awesome game: name, age (between 1 and 99), armor and strength."
response_with_schema = client.generate(prompt_with_schema, response_format={
"type": "json_object",
"schema": Character.model_json_schema(),
})

my_character = json.loads(response.generated_text)
print(my_character)
my_character_with_schema = json.loads(response_with_schema.generated_text)
print(my_character_with_schema)
# {
# "name": "Thorin",
# "age": 45,
# "armor": "plate",
# "strength": 90
# }

# Example 2: Without a schema (arbitrary JSON)
prompt_without_schema = "Generate a new character for my awesome game: name, age (between 1 and 99), armor and strength."
response_without_schema = client.generate(prompt_without_schema, response_format={
"type": "json_object", # No schema provided
})

my_character_without_schema = json.loads(response_without_schema.generated_text)
print(my_character_without_schema)
# {
# "characterName": "Aragon",
# "age": 38,
# "armorType": "chainmail",
# "power": 78
# }
```

You can also specify the JSON schema directly rather than using Pydantic:
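A minimal sketch of that variant (the hand-written schema below is illustrative and mirrors the shape of the `Character` model rather than copying the collapsed diff):

```python
import json

from lorax import Client

client = Client("http://127.0.0.1:8080")

# A hand-written JSON schema standing in for Character.model_json_schema().
schema = {
    "title": "Character",
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer", "minimum": 1, "maximum": 99},
        "armor": {"type": "string"},
        "strength": {"type": "integer"},
    },
    "required": ["name", "age", "armor", "strength"],
}

response = client.generate(
    "Generate a new character for my awesome game: "
    "name, age (between 1 and 99), armor and strength.",
    response_format={"type": "json_object", "schema": schema},
)
print(json.loads(response.generated_text))
```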
@@ -99,7 +121,88 @@ Structured generation of JSON following a schema is supported via the `response_

!!! note

Currently a schema is **required**. This differs from the existing OpenAI JSON mode, in which no schema is supported.
Currently, the `response_format` field in the OpenAI-compatible API differs slightly from the LoRAX request interface.
    When calling the OpenAI-compatible API, format the request exactly as specified in the official OpenAI documentation:
    https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format

#### Type 1: `text` (default)

- The standard mode: the model generates plain text, exactly as if no `response_format` had been given.

```python
from openai import OpenAI

client = OpenAI(
api_key="EMPTY",
base_url="http://127.0.0.1:8080/v1",
)

resp = client.chat.completions.create(
model="", # optional: specify an adapter ID here
messages=[
{
"role": "user",
"content": "Describe a medieval fantasy character.",
},
],
max_tokens=100,
response_format={
"type": "text", # Default response type, plain text output
},
)

print(resp.choices[0].message.content)

'''
Sir Alaric is a noble knight of the realm. At the age of 35, he dons a suit of shining plate armor, protecting his strong, muscular frame. His strength is unparalleled in the kingdom, allowing him to wield his massive greatsword with ease.
'''
```

#### Type 2: `json_object`

- This mode outputs arbitrary JSON objects without enforcing any schema, making it ideal for flexible JSON generation. It is similar to OpenAI’s JSON mode.

```python
import json

from openai import OpenAI

client = OpenAI(
api_key="EMPTY",
base_url="http://127.0.0.1:8080/v1",
)

resp = client.chat.completions.create(
model="", # optional: specify an adapter ID here
messages=[
{
"role": "user",
"content": "Generate a new character for my game: name, age, armor type, and strength.",
},
],
max_tokens=100,
response_format={
"type": "json_object", # Generate arbitrary JSON without a schema
},
)

my_character = json.loads(resp.choices[0].message.content)
print(my_character)

'''
{
"name": "Eldrin",
"age": 27,
"armor": "Dragonscale Armor",
"strength": "Fire Resistance"
}
'''
```

#### Type 3: `json_schema`

- The model returns a structured JSON object that adheres to the supplied schema, in this case the `Character` model defined earlier.

```python
import json
@@ -131,18 +234,30 @@ resp = client.chat.completions.create(
messages=[
{
"role": "user",
"content": "Generate a new character for my awesome game: name, age (between 1 and 99), armor and strength. ",
"content": "Generate a new character for my game: name, age (between 1 and 99), armor, and strength.",
},
],
max_tokens=100,
response_format={
"type": "json_object",
"schema": Character.model_json_schema(),
"type": "json_schema", # Generate structured JSON output based on a schema
"json_schema": {
"name": "Character", # Name of the schema
"schema": Character.model_json_schema(), # The JSON schema generated by Pydantic
},
},
)

my_character = json.loads(resp.choices[0].message.content)
print(my_character)

'''
{
"name": "Thorin",
"age": 45,
"armor": "plate",
"strength": 90
}
'''
```


40 changes: 38 additions & 2 deletions router/src/lib.rs
@@ -478,15 +478,51 @@ struct UsageInfo {

#[derive(Clone, Debug, Deserialize, ToSchema)]
enum ResponseFormatType {
#[serde(alias = "text")]
Text,
#[serde(alias = "json_object")]
JsonObject,
#[serde(alias = "json_schema")]
JsonSchema,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct ResponseFormat {
#[allow(dead_code)] // For now allow this field even though it is unused
r#type: ResponseFormatType,
schema: serde_json::Value, // TODO: make this optional once arbitrary JSON object is supported in Outlines

#[serde(default = "default_json_schema")]
schema: Option<serde_json::Value>,
}

// Default schema to be used when no value is provided
fn default_json_schema() -> Option<serde_json::Value> {
Some(serde_json::json!({
"additionalProperties": {
"type": ["object", "string", "integer", "number", "boolean", "null"]
},
"title": "ArbitraryJsonModel",
"type": "object"
}))
}
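To see what this permissive default accepts, here is a quick check using the third-party `jsonschema` Python package (the schema literal is copied from the function above):

```python
import jsonschema

DEFAULT_SCHEMA = {
    "additionalProperties": {
        "type": ["object", "string", "integer", "number", "boolean", "null"]
    },
    "title": "ArbitraryJsonModel",
    "type": "object",
}

# Any flat object of scalars or nested objects validates.
jsonschema.validate({"name": "Eldrin", "age": 27}, DEFAULT_SCHEMA)

# The property-type union omits "array", so list-valued properties
# are rejected by this default schema.
try:
    jsonschema.validate({"tags": ["fire", "ice"]}, DEFAULT_SCHEMA)
except jsonschema.ValidationError:
    print("arrays are not allowed by the default schema")
```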

#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
struct JsonSchema {
#[allow(dead_code)] // For now allow this field even though it is unused
description: Option<String>,
#[allow(dead_code)] // For now allow this field even though it is unused
name: String,
schema: Option<serde_json::Value>,
#[allow(dead_code)] // For now allow this field even though it is unused
strict: Option<bool>,
}

// TODO check if json_schema field is required if type is json_schema
#[derive(Clone, Debug, Deserialize, ToSchema)]
struct OpenAiResponseFormat {
#[serde(rename(deserialize = "type"))]
response_format_type: ResponseFormatType,
json_schema: Option<JsonSchema>,
}
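For reference, a `response_format` payload exercising every field of `JsonSchema` would look like the following (a hand-written illustration; only `schema` is consumed today, the remaining fields are accepted but unused):

```python
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "Character",                # required
        "description": "A game character",  # optional, currently unused
        "schema": {"type": "object"},       # optional; the default schema is used if omitted
        "strict": True,                     # optional, currently unused
    },
}
```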

#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@@ -582,9 +618,9 @@ struct ChatCompletionRequest {
#[allow(dead_code)] // For now allow this field even though it is unused
user: Option<String>,
seed: Option<u64>,
response_format: Option<OpenAiResponseFormat>,
// Additional parameters
// TODO(travis): add other LoRAX params here
response_format: Option<ResponseFormat>,
repetition_penalty: Option<f32>,
top_k: Option<i32>,
ignore_eos_token: Option<bool>,
51 changes: 45 additions & 6 deletions router/src/server.rs
@@ -4,18 +4,18 @@ use crate::config::Config;
use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig};
use crate::{
AdapterParameters, AlternativeToken, BatchClassifyRequest, BestOfSequence,
default_json_schema, AdapterParameters, AlternativeToken, BatchClassifyRequest, BestOfSequence,
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice, ChatMessage, ClassifyRequest,
CompatGenerateRequest, CompletionFinishReason, CompletionRequest, CompletionResponse,
CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, Details,
EmbedRequest, EmbedResponse, Entity, ErrorResponse, FinishReason, GenerateParameters,
GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, LogProbs, PrefillToken,
ResponseFormat, ResponseFormatType, SimpleToken, StreamDetails, StreamResponse, Token,
TokenizeRequest, TokenizeResponse, UsageInfo, Validation,
GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, JsonSchema, LogProbs,
OpenAiResponseFormat, PrefillToken, ResponseFormat, ResponseFormatType, SimpleToken,
StreamDetails, StreamResponse, Token, TokenizeRequest, TokenizeResponse, UsageInfo, Validation,
};
use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
@@ -263,6 +263,43 @@ async fn chat_completions_v1(
adapter_id = None;
}

// Convert the OpenAI-style response_format into the internal ResponseFormat
let response_format: Option<ResponseFormat> = match req.response_format {
None => None,
Some(openai_format) => {
let response_format_type = openai_format.response_format_type.clone();
match response_format_type {
// Ignore when type is text
ResponseFormatType::Text => None,

// For json_object, use the fixed schema
ResponseFormatType::JsonObject => Some(ResponseFormat {
r#type: response_format_type.clone(),
schema: default_json_schema(),
}),

// For json_schema, use schema_value if available, otherwise fallback to the fixed schema
ResponseFormatType::JsonSchema => openai_format
.json_schema
.and_then(|schema| schema.schema)
.map_or_else(
|| {
Some(ResponseFormat {
r#type: response_format_type.clone(),
schema: default_json_schema(),
})
},
|schema_value: serde_json::Value| {
Some(ResponseFormat {
r#type: response_format_type.clone(),
schema: Some(schema_value),
})
},
),
}
}
};
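In Python terms, the conversion above behaves like this sketch (function and constant names are illustrative, not part of the LoRAX API):

```python
DEFAULT_JSON_SCHEMA = {  # same literal as default_json_schema() above
    "additionalProperties": {
        "type": ["object", "string", "integer", "number", "boolean", "null"]
    },
    "title": "ArbitraryJsonModel",
    "type": "object",
}

def to_lorax_response_format(openai_format):
    """text is dropped, json_object gets the permissive default schema,
    and json_schema uses the supplied schema or falls back to the default."""
    if openai_format is None:
        return None
    kind = openai_format["type"]
    if kind == "text":
        return None
    if kind == "json_object":
        return {"type": kind, "schema": DEFAULT_JSON_SCHEMA}
    if kind == "json_schema":
        inner = openai_format.get("json_schema") or {}
        return {"type": kind, "schema": inner.get("schema") or DEFAULT_JSON_SCHEMA}
```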

let mut gen_req = CompatGenerateRequest {
inputs: inputs.to_string(),
parameters: GenerateParameters {
@@ -288,7 +325,7 @@
return_k_alternatives: None,
apply_chat_template: false,
seed: req.seed,
response_format: req.response_format,
response_format,
},
stream: req.stream.unwrap_or(false),
};
@@ -1115,6 +1152,8 @@ pub async fn run(
UsageInfo,
ResponseFormat,
ResponseFormatType,
OpenAiResponseFormat,
JsonSchema,
CompatGenerateRequest,
GenerateRequest,
GenerateParameters,
7 changes: 5 additions & 2 deletions router/src/validation.rs
@@ -331,8 +331,11 @@ impl Validation {

let mut schema: Option<String> = None;
if response_format.is_some() {
let response_format_val = response_format.unwrap();
schema = Some(response_format_val.schema.to_string())
if let Some(response_format_val) = response_format {
if let Some(schema_value) = response_format_val.schema {
schema = Some(schema_value.to_string());
}
}
}

let parameters = NextTokenChooserParameters {