Skip to content

Commit

Permalink
Image generation (#1)
Browse files Browse the repository at this point in the history
* add model output

* Remove HTTP model (移除 HTTP 模型)

* add test

* bump version

---------

Co-authored-by: wangyuxin <[email protected]>
  • Loading branch information
wangyuxinwhy and wangyuxin authored Nov 15, 2023
1 parent 065a12f commit 1bd33f2
Show file tree
Hide file tree
Showing 40 changed files with 1,431 additions and 1,010 deletions.
12 changes: 9 additions & 3 deletions generate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
ZhipuCharacterChatParameters,
ZhipuChat,
ZhipuChatParameters,
generate_text,
load_chat_model,
)
from generate.chat_completion.function_call import function
from generate.chat_engine import ChatEngine
Expand All @@ -31,7 +29,13 @@
MinimaxSpeechParameters,
OpenAISpeech,
OpenAISpeechParameters,
)
from generate.utils import (
generate_image,
generate_speech,
generate_text,
load_chat_model,
load_image_generation_model,
load_speech_model,
)
from generate.version import __version__
Expand Down Expand Up @@ -66,8 +70,10 @@
'MinimaxProSpeechParameters',
'function',
'load_chat_model',
'generate_text',
'load_speech_model',
'load_image_generation_model',
'generate_text',
'generate_speech',
'generate_image',
'__version__',
]
34 changes: 5 additions & 29 deletions generate/chat_completion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

from typing import Any, Type
from typing import Type

from generate.chat_completion.base import ChatCompletionModel
from generate.chat_completion.function_call import function, get_json_schema
from generate.chat_completion.http_chat import HttpChatModel, HttpModelInitKwargs
from generate.chat_completion.message import Prompt
from generate.chat_completion.model_output import ChatCompletionModelOutput, ChatCompletionModelStreamOutput
from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
from generate.chat_completion.models import (
AzureChat,
BaichuanChat,
Expand All @@ -28,7 +26,7 @@
ZhipuChat,
ZhipuChatParameters,
)
from generate.chat_completion.printer import MessagePrinter, RichMessagePrinter, SimpleMessagePrinter
from generate.chat_completion.printer import MessagePrinter, SimpleMessagePrinter
from generate.parameters import ModelParameters

ChatModels: list[tuple[Type[ChatCompletionModel], Type[ModelParameters]]] = [
Expand All @@ -49,31 +47,11 @@
}


def load_chat_model(model_id: str, **kwargs: Any) -> ChatCompletionModel:
if '/' not in model_id:
model_type = model_id
return ChatModelRegistry[model_type][0](**kwargs) # type: ignore
model_type, name = model_id.split('/')
model_cls = ChatModelRegistry[model_type][0]
return model_cls.from_name(name, **kwargs)


def list_chat_model_types() -> list[str]:
return list(ChatModelRegistry.keys())


def generate_text(prompt: Prompt, model_id: str = 'openai/gpt-3.5-turbo', **kwargs: Any) -> ChatCompletionModelOutput:
model = load_chat_model(model_id, **kwargs)
return model.generate(prompt, **kwargs)


__all__ = [
'ChatCompletionModel',
'ChatCompletionModelOutput',
'ChatCompletionModelStreamOutput',
'ChatCompletionOutput',
'ChatCompletionStreamOutput',
'ModelParameters',
'HttpChatModel',
'HttpModelInitKwargs',
'AzureChat',
'MinimaxProChat',
'MinimaxProChatParameters',
Expand All @@ -95,8 +73,6 @@ def generate_text(prompt: Prompt, model_id: str = 'openai/gpt-3.5-turbo', **kwar
'BailianChatParameters',
'MessagePrinter',
'SimpleMessagePrinter',
'RichMessagePrinter',
'generate_text',
'get_json_schema',
'function',
]
57 changes: 24 additions & 33 deletions generate/chat_completion/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, ClassVar, Generic, Iterator, TypeVar

from typing_extensions import Self, TypeGuard

from generate.chat_completion.message import Messages, Prompt, ensure_messages
from generate.chat_completion.model_output import ChatCompletionModelOutput, ChatCompletionModelStreamOutput
from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
from generate.model import ModelInfo
from generate.parameters import ModelParameters

P = TypeVar('P', bound=ModelParameters)
logger = logging.getLogger(__name__)


class ChatCompletionModel(Generic[P], ABC):
Expand All @@ -29,66 +32,54 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:
...

@abstractmethod
def _completion(self, messages: Messages, parameters: P) -> ChatCompletionModelOutput:
def _completion(self, messages: Messages, parameters: P) -> ChatCompletionOutput:
...

@abstractmethod
async def _async_completion(self, messages: Messages, parameters: P) -> ChatCompletionModelOutput:
async def _async_completion(self, messages: Messages, parameters: P) -> ChatCompletionOutput:
...

@abstractmethod
def _stream_completion(self, messages: Messages, parameters: P) -> Iterator[ChatCompletionModelStreamOutput]:
def _stream_completion(self, messages: Messages, parameters: P) -> Iterator[ChatCompletionStreamOutput]:
...

@abstractmethod
def _async_stream_completion(self, messages: Messages, parameters: P) -> AsyncIterator[ChatCompletionModelStreamOutput]:
def _async_stream_completion(self, messages: Messages, parameters: P) -> AsyncIterator[ChatCompletionStreamOutput]:
...

@property
def model_id(self) -> str:
return f'{self.model_type}/{self.name}'
def model_info(self) -> ModelInfo:
return ModelInfo(task='chat_completion', type=self.model_type, name=self.name)

def generate(self, prompt: Prompt, **override_parameters: Any) -> ChatCompletionModelOutput:
def generate(self, prompt: Prompt, **override_parameters: Any) -> ChatCompletionOutput:
parameters = self._merge_parameters(**override_parameters)
messages = ensure_messages(prompt)
model_output = self._completion(messages, parameters)
model_output.debug['input_messages'] = messages
model_output.debug['parameters'] = parameters
return model_output
logger.debug(f'{messages=}, {parameters=}')
return self._completion(messages, parameters)

async def async_generate(self, prompt: Prompt, **override_parameters: Any) -> ChatCompletionModelOutput:
async def async_generate(self, prompt: Prompt, **override_parameters: Any) -> ChatCompletionOutput:
parameters = self._merge_parameters(**override_parameters)
messages = ensure_messages(prompt)
model_output = await self._async_completion(messages, parameters)
model_output.debug['input_messages'] = messages
model_output.debug['parameters'] = parameters
return model_output
logger.debug(f'{messages=}, {parameters=}')
return await self._async_completion(messages, parameters)

def stream_generate(self, prompt: Prompt, **override_parameters: Any) -> Iterator[ChatCompletionModelStreamOutput]:
def stream_generate(self, prompt: Prompt, **override_parameters: Any) -> Iterator[ChatCompletionStreamOutput]:
parameters = self._merge_parameters(**override_parameters)
messages = ensure_messages(prompt)
for stream_output in self._stream_completion(messages, parameters):
if stream_output.is_finish:
stream_output.debug['input_messages'] = messages
stream_output.debug['parameters'] = parameters
yield stream_output

async def async_stream_generate(
self, prompt: Prompt, **override_parameters: Any
) -> AsyncIterator[ChatCompletionModelStreamOutput]:
logger.debug(f'{messages=}, {parameters=}')
return self._stream_completion(messages, parameters)

def async_stream_generate(self, prompt: Prompt, **override_parameters: Any) -> AsyncIterator[ChatCompletionStreamOutput]:
parameters = self._merge_parameters(**override_parameters)
messages = ensure_messages(prompt)
async for stream_output in self._async_stream_completion(messages, parameters):
if stream_output.is_finish:
stream_output.debug['input_messages'] = messages
stream_output.debug['parameters'] = parameters
yield stream_output
logger.debug(f'{messages=}, {parameters=}')
return self._async_stream_completion(messages, parameters)

def _merge_parameters(self, **override_parameters: Any) -> P:
return self.parameters.__class__.model_validate(
{**self.parameters.model_dump(exclude_unset=True), **override_parameters}
)


def is_stream_model_output(model_output: ChatCompletionModelOutput) -> TypeGuard[ChatCompletionModelStreamOutput]:
def is_stream_model_output(model_output: ChatCompletionOutput) -> TypeGuard[ChatCompletionStreamOutput]:
return getattr(model_output, 'stream', None) is not None
2 changes: 1 addition & 1 deletion generate/chat_completion/function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class FunctionJsonSchema(TypedDict):

def get_json_schema(function: Callable) -> FunctionJsonSchema:
function_name = function.__name__
docstring = parse(function.__doc__ or '')
docstring = parse(text=function.__doc__ or '')
parameters = TypeAdapter(function).json_schema()
for param in docstring.params:
if (arg_name := param.arg_name) in parameters['properties'] and (description := param.description):
Expand Down
Loading

0 comments on commit 1bd33f2

Please sign in to comment.