Skip to content

Commit

Permalink
Pylance (#8)
Browse files Browse the repository at this point in the history
* make pylance happy

* 更新文档

---------

Co-authored-by: wangyuxin <[email protected]>
  • Loading branch information
wangyuxinwhy and wangyuxin authored Nov 17, 2023
1 parent 4631cec commit 83e1123
Show file tree
Hide file tree
Showing 31 changed files with 269 additions and 83 deletions.
108 changes: 106 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,107 @@
# genreate
# Generate

文本生成,图像生成,语音生成
> One API to Access World-Class Generative Models.
## 简介

Generate Package 允许用户通过统一的 api 访问跨平台的生成式模型,当前支持:

* [OpenAI](https://platform.openai.com/docs/introduction)
* [Azure](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/chatgpt?tabs=python&pivots=programming-language-chat-completions)
* [阿里云-百炼](https://bailian.console.aliyun.com/)
* [百川智能](https://platform.baichuan-ai.com/docs/api)
* [腾讯云-混元](https://cloud.tencent.com/document/product/1729)
* [Minimax](https://api.minimax.chat/document/guides/chat)
* [百度智能云](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t)
* [智谱](https://open.bigmodel.cn/dev/api)
* ...

## Features

* **多模态**,支持文本生成,图像生成以及语音生成
* **跨平台**,支持 OpenAI,Azure,Minimax 在内的多家平台
* **One API**,统一了不同平台的消息格式,推理参数,接口封装,返回解析
* **异步和流式**,提供流式调用,非流式调用,同步调用,异步调用,适配不同的应用场景
* **自带电池**,提供输入检查,参数检查,计费,*ChatEngine*, *CompletionEngine*, *function* 等功能
* **高质量代码**,100% typehints,pylance strict, ruff lint & format, test coverage > 85% ...

## 基础使用

### 安装

```bash
pip install generate-core
```

### 文本生成

```python
from generate import OpenAIChat

model = OpenAIChat()
model.generate('你好,GPT!', temperature=0, seed=2023)

# ----- Output -----
ChatCompletionOutput(
model_info=ModelInfo(task='chat_completion', type='openai', name='gpt-3.5-turbo-0613'),
cost=0.000343,
extra={'usage': {'prompt_tokens': 13, 'completion_tokens': 18, 'total_tokens': 31}},
messages=[
AssistantMessage(
content='你好!有什么我可以帮助你的吗?',
role='assistant',
name=None,
content_type='text'
)
],
finish_reason='stop'
)
```

### 图像生成

```python
from generate import OpenAIImageGeneration

model = OpenAIImageGeneration()
model.generate('black hole')

# ----- Output -----
ImageGenerationOutput(
model_info=ModelInfo(task='image_generation', type='openai', name='dall-e-3'),
cost=0.56,
extra={},
images=[
GeneratedImage(
url='https://oaidalleapiprodscus.blob.core.windows.net/...',
prompt='Visualize an astronomical illustration featuring a black hole at its core. The black hole
should be portrayed with strong gravitational lensing effect that distorts the light around it. Include a
surrounding accretion disk, glowing brightly with blue and white hues, streaked with shades of red and orange,
indicating heat and intense energy. The cosmos in the background should be filled with distant stars, galaxies, and
nebulas, illuminating the vast, infinite space with specks of light.',
image_format='png',
content=b'<image bytes>'
)
]
)
```

### 语音生成

```python
from generate import MinimaxSpeech

model = MinimaxSpeech()
model.generate('你好,世界!')

# ----- Output -----
TextToSpeechOutput(
model_info=ModelInfo(task='text_to_speech', type='minimax', name='speech-01'),
cost=0.01,
extra={},
audio=b'<audio bytes>',
audio_format='mp3'
)
```

14 changes: 14 additions & 0 deletions generate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@
from generate.chat_completion.function_call import function
from generate.chat_engine import ChatEngine
from generate.completion_engine import CompletionEngine
from generate.image_generation import (
BaiduImageGeneration,
BaiduImageGenerationParameters,
OpenAIImageGeneration,
OpenAIImageGenerationParameters,
QianfanImageGeneration,
QianfanImageGenerationParameters,
)
from generate.text_to_speech import (
MinimaxProSpeech,
MinimaxProSpeechParameters,
Expand Down Expand Up @@ -68,6 +76,12 @@
'MinimaxSpeechParameters',
'MinimaxProSpeech',
'MinimaxProSpeechParameters',
'OpenAIImageGeneration',
'OpenAIImageGenerationParameters',
'BaiduImageGeneration',
'BaiduImageGenerationParameters',
'QianfanImageGeneration',
'QianfanImageGenerationParameters',
'function',
'load_chat_model',
'load_speech_model',
Expand Down
6 changes: 3 additions & 3 deletions generate/chat_completion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Type
from typing import Any, Type

from generate.chat_completion.base import ChatCompletionModel
from generate.chat_completion.function_call import function, get_json_schema
Expand Down Expand Up @@ -29,7 +29,7 @@
from generate.chat_completion.printer import MessagePrinter, SimpleMessagePrinter
from generate.model import ModelParameters

ChatModels: list[tuple[Type[ChatCompletionModel], Type[ModelParameters]]] = [
ChatModels: list[tuple[Type[ChatCompletionModel[Any]], Type[ModelParameters]]] = [
(AzureChat, OpenAIChatParameters),
(OpenAIChat, OpenAIChatParameters),
(MinimaxProChat, MinimaxProChatParameters),
Expand All @@ -42,7 +42,7 @@
(BailianChat, BailianChatParameters),
]

ChatModelRegistry: dict[str, tuple[Type[ChatCompletionModel], Type[ModelParameters]]] = {
ChatModelRegistry: dict[str, tuple[Type[ChatCompletionModel[Any]], Type[ModelParameters]]] = {
model_cls.model_type: (model_cls, parameter_cls) for model_cls, parameter_cls in ChatModels
}

Expand Down
15 changes: 8 additions & 7 deletions generate/chat_completion/function_call.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import Callable, Generic, TypeVar
from typing import Any, Callable, Dict, Generic, TypeVar, cast

from docstring_parser import parse
from pydantic import TypeAdapter, validate_call
Expand All @@ -20,7 +20,7 @@ class FunctionJsonSchema(TypedDict):
description: NotRequired[str]


def get_json_schema(function: Callable) -> FunctionJsonSchema:
def get_json_schema(function: Callable[..., Any]) -> FunctionJsonSchema:
function_name = function.__name__
docstring = parse(text=function.__doc__ or '')
parameters = TypeAdapter(function).json_schema()
Expand Down Expand Up @@ -71,7 +71,7 @@ def call_with_message(self, message: Message) -> T:
raise ValueError(f'message is not a function call: {message}')


def recusive_remove(dictionary: dict, remove_key: str) -> None:
def recusive_remove(obj: Any, remove_key: str) -> None:
"""
Recursively removes a key from a dictionary and all its nested dictionaries.
Expand All @@ -82,9 +82,10 @@ def recusive_remove(dictionary: dict, remove_key: str) -> None:
Returns:
None
"""
if isinstance(dictionary, dict):
for key in list(dictionary.keys()):
if isinstance(obj, dict):
obj = cast(Dict[str, Any], obj)
for key in list(obj.keys()):
if key == remove_key:
del dictionary[key]
del obj[key]
else:
recusive_remove(dictionary[key], remove_key)
recusive_remove(obj[key], remove_key)
7 changes: 5 additions & 2 deletions generate/chat_completion/message/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@


class Message(BaseModel):
role: str
name: Optional[str] = None
content: Any


Expand All @@ -19,6 +17,7 @@ class SystemMessage(Message):

class UserMessage(Message):
role: Literal['user'] = 'user'
name: Optional[str] = None
content_type: Literal['text'] = 'text'
content: str

Expand All @@ -43,6 +42,7 @@ class ImageUrlPart(BaseModel):

class UserMultiPartMessage(Message):
role: Literal['user'] = 'user'
name: Optional[str] = None
content_type: Literal['multi_part'] = 'multi_part'
content: List[UserPartTypes]

Expand All @@ -62,6 +62,7 @@ class ToolMessage(Message):

class AssistantMessage(Message):
role: Literal['assistant'] = 'assistant'
name: Optional[str] = None
content_type: Literal['text'] = 'text'
content: str

Expand All @@ -74,6 +75,7 @@ class FunctionCall(BaseModel):

class FunctionCallMessage(Message):
role: Literal['assistant'] = 'assistant'
name: Optional[str] = None
content_type: Literal['function_call'] = 'function_call'
content: FunctionCall

Expand All @@ -86,6 +88,7 @@ class ToolCall(BaseModel):

class ToolCallsMessage(Message):
role: Literal['assistant'] = 'assistant'
name: Optional[str] = None
content_type: Literal['tool_calls'] = 'tool_calls'
content: List[ToolCall]

Expand Down
8 changes: 2 additions & 6 deletions generate/chat_completion/message/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,11 @@ def infer_assistant_message_content_type(message_content: Any) -> Literal['text'
return 'text'
if isinstance(obj, FunctionCall):
return 'function_call'
if isinstance(obj, list):
return 'tool_calls'
raise ValueError(f'Unknown content type: {obj}')
return 'tool_calls'


def infer_user_message_content_type(message_content: Any) -> Literal['text', 'multi_part']:
obj = user_content_validator.validate_python(message_content)
if isinstance(obj, str):
return 'text'
if isinstance(obj, list):
return 'multi_part'
raise ValueError(f'Unknown content type: {obj}')
return 'multi_part'
2 changes: 1 addition & 1 deletion generate/chat_completion/models/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _stream_completion(self, messages: Messages, parameters: OpenAIChatParameter
raise NotImplementedError('Azure does not support streaming')

@override
async def _async_stream_completion(
def _async_stream_completion(
self, messages: Messages, parameters: OpenAIChatParameters
) -> AsyncIterator[ChatCompletionStreamOutput]:
raise NotImplementedError('Azure does not support streaming')
Expand Down
2 changes: 1 addition & 1 deletion generate/chat_completion/models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(

def _get_request_parameters(self, messages: Messages, parameters: BaichuanChatParameters) -> HttpxPostKwargs:
baichuan_messages: list[BaichuanMessage] = [convert_to_baichuan_message(message) for message in messages]
data = {
data: dict[str, Any] = {
'model': self.model,
'messages': baichuan_messages,
}
Expand Down
2 changes: 1 addition & 1 deletion generate/chat_completion/models/bailian.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def generate_default_request_id() -> str:


def convert_to_bailian_chat_qa_pair(messages: Messages) -> list[BailianChatQAPair]:
pairs = []
pairs: list[BailianChatQAPair] = []
for user_message, assistant_message in zip(messages[::2], messages[1::2]):
if not isinstance(user_message, UserMessage):
raise MessageTypeError(user_message, allowed_message_type=(UserMessage,))
Expand Down
2 changes: 0 additions & 2 deletions generate/chat_completion/models/minimax_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,6 @@ def convert_to_minimax_pro_message(
}

if isinstance(message, FunctionMessage):
if message.name is None:
raise MessageValueError(message, 'function name is required')
return {
'sender_type': 'FUNCTION',
'sender_name': message.name,
Expand Down
13 changes: 4 additions & 9 deletions generate/chat_completion/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
FunctionCall,
FunctionCallMessage,
FunctionMessage,
ImageUrlPart,
Message,
Messages,
MessageTypeError,
MessageValueError,
SystemMessage,
TextPart,
ToolCall,
Expand Down Expand Up @@ -65,7 +63,7 @@ class OpenAIToolCall(TypedDict):

class OpenAIMessage(TypedDict):
role: str
content: Union[str, None, List[Dict]]
content: Union[str, None, List[Dict[str, Any]]]
name: NotRequired[str]
function_call: NotRequired[OpenAIFunctionCall]
tool_call_id: NotRequired[str]
Expand Down Expand Up @@ -101,12 +99,12 @@ def _to_text_message_dict(role: str, message: Message) -> OpenAIMessage:


def _to_user_multipart_message_dict(message: UserMultiPartMessage) -> OpenAIMessage:
content = []
content: list[dict[str, Any]] = []
for part in message.content:
if isinstance(part, TextPart):
content.append({'type': 'text', 'text': part.text})
elif isinstance(part, ImageUrlPart):
image_url_part_dict = {
else:
image_url_part_dict: dict[str, Any] = {
'type': 'image_url',
'image_url': {
'url': part.image_url.url,
Expand All @@ -115,9 +113,6 @@ def _to_user_multipart_message_dict(message: UserMultiPartMessage) -> OpenAIMess
if part.image_url.detail:
image_url_part_dict['image_url']['detail'] = part.image_url.detail
content.append(image_url_part_dict)
else:
raise MessageValueError(message, f'OpenAI does not support {type(part)} ')

return {
'role': 'user',
'content': content,
Expand Down
10 changes: 4 additions & 6 deletions generate/chat_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


class ChatEngineKwargs(TypedDict, total=False):
functions: Union[List[function], dict[str, Callable], None]
functions: Union[List[function[Any, Any]], dict[str, Callable[..., Any]], None]
call_raise_error: bool
max_calls_per_turn: int

Expand All @@ -51,7 +51,7 @@ class ChatEngine(Generic[P]):
def __init__(
self,
chat_model: ChatCompletionModel[P],
functions: Union[List[function], dict[str, Callable], None] = None,
functions: Union[List[function[Any, Any]], dict[str, Callable[..., Any]], None] = None,
call_raise_error: bool = False,
max_calls_per_turn: int = 5,
stream: bool | Literal['auto'] = 'auto',
Expand All @@ -60,10 +60,8 @@ def __init__(
self._chat_model = chat_model

if isinstance(functions, list):
self._function_map = {}
self._function_map: dict[str, Callable[..., Any]] = {}
for i in functions:
if not isinstance(i, function):
raise TypeError(f'Invalid function type: {type(i)}')
self._function_map[i.json_schema['name']] = i
elif isinstance(functions, dict):
self._function_map = functions
Expand Down Expand Up @@ -127,7 +125,7 @@ def handle_model_output(self, model_output: ChatCompletionOutput, **override_par
self.history.extend(model_output.messages)
if self.printer:
for message in model_output.messages:
if self.stream and message.role == 'assistant':
if self.stream and isinstance(message, AssistantMessage):
continue
self.printer.print_message(message)

Expand Down
Loading

0 comments on commit 83e1123

Please sign in to comment.