Skip to content

Commit

Permalink
Structure (#59)
Browse files Browse the repository at this point in the history
* Add output_exclude_none parameter to StructureGenerateModel constructor
* Refactor structure method in ChatCompletionModel to remove instruction parameter

---------

Co-authored-by: wangyuxin <[email protected]>
  • Loading branch information
wangyuxinwhy and wangyuxin authored Apr 12, 2024
1 parent 1a88c82 commit 11f2bf1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
3 changes: 1 addition & 2 deletions generate/chat_completion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,12 @@ def stream_generate(self, prompt: Prompt, **kwargs: Any) -> Iterator[ChatComplet
return sync_aiter(self.async_stream_generate(prompt, **kwargs))

def structure(
self, output_structure_type: Type[O], instruction: str | None = None, **kwargs: Unpack['StructureModelKwargs']
self, output_structure_type: Type[O], **kwargs: Unpack['StructureModelKwargs']
) -> 'StructureGenerateModel[Self, O]':
from generate.modifiers.structure import StructureGenerateModel

return StructureGenerateModel(
self,
instruction=instruction,
output_structure_type=output_structure_type,
**kwargs,
)
Expand Down
6 changes: 5 additions & 1 deletion generate/modifiers/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ class StructureModelOutput(ModelOutput, Generic[O]):


class StructureModelKwargs(TypedDict, Generic[O], total=False):
instruction: Optional[str]
examples: Optional[Iterable[Example[O]]]
system_template: str
max_num_reask: int
output_exclude_none: bool


class StructureGenerateModel(GenerateModel[str, StructureModelOutput[O]], Generic[M, O]):
Expand All @@ -127,6 +129,7 @@ def __init__(
instruction: str | None = None,
examples: Optional[Iterable[Example[O]]] = None,
system_template: str = system_template,
output_exclude_none: bool = True,
max_num_reask: int = 2,
) -> None:
self.model = model
Expand All @@ -135,6 +138,7 @@ def __init__(
self.examples = examples or []
self.system_template = system_template
self.max_num_reask = max_num_reask
self.output_exclude_none = output_exclude_none

self.model_type = self.model.model_type # type: ignore

Expand All @@ -152,7 +156,7 @@ def messages(self) -> List[UnionMessage]:
messages.append(self.system_message)
for example in self.examples:
messages.extend(ensure_messages(example.prompt))
messages.append(AssistantMessage(content=example.output.model_dump_json(exclude_none=True)))
messages.append(AssistantMessage(content=example.output.model_dump_json(exclude_none=self.output_exclude_none)))
return messages

@property
Expand Down

0 comments on commit 11f2bf1

Please sign in to comment.