Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JSON schema conversion: ⚡️ faster repetitions, min/maxLength for strings, cap number length #6555

Merged
merged 17 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 138 additions & 95 deletions common/json-schema-to-grammar.cpp

Large diffs are not rendered by default.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I'm on board with the rename to use underscores -- while there are a few other files with underscores (such as pydantic_models_to_grammar.py), most seem to use hyphens (pydantic-models-to-grammar-examples.py, etc), and it seems like the old filename is possibly better?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally I wanted all filenames in the repo to use hyphens. But later I found out that Python does not work well when there are hyphens in the filenames (e.g. I think you cannot include a Python file that has hyphens). So I think it's better to eventually rename all Python files to use underscores in their filenames

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tbh I did all this as a prerequisite for #6389, in which i need to import the converter from Python. I also found out llama-cpp-python inlines that file in their codebase, since it's hard / not trivial to import (short of using importlib, which feels dirty).

Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,94 @@
import sys
from typing import Any, Dict, List, Set, Tuple, Union

def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False):
if not separator_rule:
if min_items == 0 and max_items == 1:
return f'{item_rule}?'
elif min_items == 1 and max_items is None:
return f'{item_rule}+'

result = ''

if min_items > 0:
if item_rule_is_literal and separator_rule is None:
result = '"' + (item_rule[1:-1] * min_items) + '"'
else:
result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items)

def opt_repetitions(up_to_n, prefix_with_sep=False):
'''
- n=4, no sep: '(a (a (a (a)?)?)?)?'
- n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
- n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
'''

content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule
if up_to_n == 0:
return ''
elif up_to_n == 1:
return f'({content})?'
elif separator_rule and not prefix_with_sep:
return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?'
else:
return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n)

if min_items > 0 and max_items != min_items:
result += ' '

if max_items is not None:
result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
else:
item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'

if min_items == 0 and separator_rule:
result = f'({item_rule} {item_operator}*)?'
else:
result += f'{item_operator}*'

return result


class BuiltinRule:
def __init__(self, content: str, deps: list = None):
self.content = content
self.deps = deps or []

_up_to_15_digits = _build_repetition('[0-9]', 0, 15)

# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'

PRIMITIVE_RULES = {
'boolean': '("true" | "false") space',
'number': '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
'integer': '("-"? ([0-9] | [1-9] [0-9]*)) space',
'value' : 'object | array | string | number | boolean',
'object' : '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
'array' : '"[" space ( value ("," space value)* )? "]" space',
'uuid' : '"\\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + ' "\\"" space',
'string': r''' "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space''',
'null': '"null" space',
'boolean' : BuiltinRule('("true" | "false") space', []),
'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []),
'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []),
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []),
'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []),
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
'null' : BuiltinRule('"null" space', []),
}
OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value']

# TODO: support "uri", "email" string formats
DATE_RULES = {
'date' : '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )',
'time' : '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
'date-time': 'date "T" time',
'date-string': '"\\"" date "\\"" space',
'time-string': '"\\"" time "\\"" space',
'date-time-string': '"\\"" date-time "\\"" space',
STRING_FORMAT_RULES = {
'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
}

RESERVED_NAMES = set(["root", *PRIMITIVE_RULES.keys(), *DATE_RULES.keys()])
DOTALL = '[\\U00000000-\\U0010FFFF]'
DOT = '[^\\x0A\\x0D]'

RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])

INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
Expand All @@ -46,16 +103,16 @@
NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')

DATE_PATTERN = '[0-9]{4}-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])'
TIME_PATTERN = '([01][0-9]|2[0-3])(:[0-5][0-9]){2}(\\.[0-9]{1,3})?(Z|[+-](([01][0-9]|2[0-3]):[0-5][0-9]))' # Cap millisecond precision w/ 3 digits

class SchemaConverter:
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
self._prop_order = prop_order
self._allow_fetch = allow_fetch
self._dotall = dotall
self._raw_pattern = raw_pattern
self._rules = {'space': SPACE_RULE}
self._rules = {
'space': SPACE_RULE,
}
self._refs = {}
self._refs_being_resolved = set()

Expand All @@ -65,6 +122,29 @@ def _format_literal(self, literal):
)
return f'"{escaped}"'

def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
'''
not_literal('a') -> '[^a]'
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
'''
assert len(literal) > 0, 'Empty literal not supported'
def recurse(i: int):
c = literal[i]
if maybe_escaped_underscores and c == '_':
yield f'[^{c}\\\\]'
yield ' | '
yield f'"\\\\"? "{c}"'
else:
yield f'[^{c}]'
if i < len(literal) - 1:
yield ' | '
yield self._format_literal(c)
yield ' ('
yield from recurse(i + 1)
yield ')?'

return ''.join(('(', *recurse(0), ')'))

def _add_rule(self, name, rule):
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
if esc_name not in self._rules or self._rules[esc_name] == rule:
Expand Down Expand Up @@ -169,10 +249,10 @@ def transform() -> Tuple[str, bool]:

def get_dot():
if self._dotall:
rule = '[\\U00000000-\\U0010FFFF]'
rule = DOTALL
else:
# Accept any character... except \n and \r line break chars (\x0A and \xOD)
rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]'
rule = DOT
return self._add_rule(f'dot', rule)

def join_seq():
Expand Down Expand Up @@ -246,26 +326,14 @@ def join_seq():

(sub, sub_is_literal) = seq[-1]

if min_times == 0 and max_times is None:
seq[-1] = (f'{sub}*', False)
elif min_times == 0 and max_times == 1:
seq[-1] = (f'{sub}?', False)
elif min_times == 1 and max_times is None:
seq[-1] = (f'{sub}+', False)
else:
if not sub_is_literal:
id = sub_rule_ids.get(sub)
if id is None:
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
sub_rule_ids[sub] = id
sub = id

seq[-1] = (
' '.join(
([f'"{sub[1:-1] * min_times}"'] if sub_is_literal else [sub] * min_times) +
([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*'])),
False
)
if not sub_is_literal:
id = sub_rule_ids.get(sub)
if id is None:
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
sub_rule_ids[sub] = id
sub = id

seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False)
else:
literal = ''
while i < length:
Expand Down Expand Up @@ -373,49 +441,47 @@ def add_component(comp_schema, is_required):
' "]" space')
else:
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
list_item_operator = f'( "," space {item_rule_name} )'
successive_items = ""
min_items = schema.get("minItems", 0)
max_items = schema.get("maxItems")
if min_items > 0:
successive_items = list_item_operator * (min_items - 1)
min_items -= 1
if max_items is not None and max_items > min_items:
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
else:
successive_items += list_item_operator + "*"
if min_items == 0:
rule = f'"[" space ( {item_rule_name} {successive_items} )? "]" space'
else:
rule = f'"[" space {item_rule_name} {successive_items} "]" space'
return self._add_rule(rule_name, rule)
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')

elif schema_type in (None, 'string') and 'pattern' in schema:
return self._visit_pattern(schema['pattern'], rule_name)

elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
return self._add_rule(
return self._add_primitive(
'root' if rule_name == 'root' else schema_format,
PRIMITIVE_RULES['uuid']
)

elif schema_type in (None, 'string') and schema_format in DATE_RULES:
for t, r in DATE_RULES.items():
self._add_rule(t, r)
return schema_format + '-string'
elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
prim_name = f'{schema_format}-string'
return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))

elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
min_len = schema.get('minLength', 0)
max_len = schema.get('maxLength')

return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')

elif (schema_type == 'object') or (len(schema) == 0):
for n in OBJECT_RULE_NAMES:
self._add_rule(n, PRIMITIVE_RULES[n])
return self._add_rule(rule_name, 'object')
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))

else:
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return self._add_rule(
'root' if rule_name == 'root' else schema_type,
PRIMITIVE_RULES[schema_type]
)
return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])

def _add_primitive(self, name: str, rule: BuiltinRule):
n = self._add_rule(name, rule.content)

for dep in rule.deps:
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
assert dep_rule, f'Rule {dep} not known'
if dep not in self._rules:
self._add_primitive(dep, dep_rule)
return n

def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
prop_order = self._prop_order
Expand All @@ -437,7 +503,7 @@ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[st
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
prop_kv_rule_names["*"] = self._add_rule(
f'{sub_name}-kv',
self._add_rule('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
)
optional_props.append("*")

Expand Down
2 changes: 1 addition & 1 deletion examples/regex-to-grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"python",
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"json-schema-to-grammar.py"),
"json_schema_to_grammar.py"),
*rest,
"-",
"--raw-pattern",
Expand Down
5 changes: 5 additions & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
* Continuous batching
* Multimodal (wip)
* Monitoring endpoints
* Schema-constrained JSON response format

The project is under active development, and we are [looking for feedback and contributors](https://github.com/ggerganov/llama.cpp/issues/4216).

Expand Down Expand Up @@ -250,6 +251,8 @@ node index.js

`grammar`: Set grammar for grammar-based sampling. Default: no grammar

`json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema.

`seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.

`ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
Expand Down Expand Up @@ -365,6 +368,8 @@ Notice that each `probs` is an array of length `n_probs`.

See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). While some OpenAI-specific features such as function calling aren't supported, llama.cpp `/completion`-specific features such as `mirostat` are supported.

The `response_format` parameter supports both plain JSON output (e.g. `{"type": "json_object"}`) and schema-constrained JSON (e.g. `{"type": "json_object", "schema": {"type": "string", "minLength": 10, "maxLength": 100}}`), similar to other OpenAI-inspired API providers.

*Examples:*

You can use either Python `openai` library with appropriate checkpoints:
Expand Down
Loading
Loading