diff --git a/generate/chat_completion/base.py b/generate/chat_completion/base.py index 812f732..431f803 100644 --- a/generate/chat_completion/base.py +++ b/generate/chat_completion/base.py @@ -39,13 +39,12 @@ def stream_generate(self, prompt: Prompt, **kwargs: Any) -> Iterator[ChatComplet return sync_aiter(self.async_stream_generate(prompt, **kwargs)) def structure( - self, output_structure_type: Type[O], instruction: str | None = None, **kwargs: Unpack['StructureModelKwargs'] + self, output_structure_type: Type[O], **kwargs: Unpack['StructureModelKwargs'] ) -> 'StructureGenerateModel[Self, O]': from generate.modifiers.structure import StructureGenerateModel return StructureGenerateModel( self, - instruction=instruction, output_structure_type=output_structure_type, **kwargs, ) diff --git a/generate/modifiers/structure.py b/generate/modifiers/structure.py index 65246fa..1d1e56d 100644 --- a/generate/modifiers/structure.py +++ b/generate/modifiers/structure.py @@ -112,9 +112,11 @@ class StructureModelOutput(ModelOutput, Generic[O]): class StructureModelKwargs(TypedDict, Generic[O], total=False): + instruction: Optional[str] examples: Optional[Iterable[Example[O]]] system_template: str max_num_reask: int + output_exclude_none: bool class StructureGenerateModel(GenerateModel[str, StructureModelOutput[O]], Generic[M, O]): @@ -127,6 +129,7 @@ def __init__( instruction: str | None = None, examples: Optional[Iterable[Example[O]]] = None, system_template: str = system_template, + output_exclude_none: bool = True, max_num_reask: int = 2, ) -> None: self.model = model @@ -135,6 +138,7 @@ def __init__( self.examples = examples or [] self.system_template = system_template self.max_num_reask = max_num_reask + self.output_exclude_none = output_exclude_none self.model_type = self.model.model_type # type: ignore @@ -152,7 +156,7 @@ def messages(self) -> List[UnionMessage]: messages.append(self.system_message) for example in self.examples: messages.extend(ensure_messages(example.prompt)) - messages.append(AssistantMessage(content=example.output.model_dump_json(exclude_none=True))) + messages.append(AssistantMessage(content=example.output.model_dump_json(exclude_none=self.output_exclude_none))) return messages @property