Update dependencies and add new models (#52)
Co-authored-by: wangyuxin <[email protected]>
wangyuxinwhy and wangyuxin authored Mar 13, 2024
1 parent 34ede16 commit caed54b
Showing 8 changed files with 828 additions and 700 deletions.
13 changes: 13 additions & 0 deletions generate/chat_completion/message/core.py
@@ -1,8 +1,13 @@
from __future__ import annotations

import mimetypes
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Sequence, Union

from pydantic import BaseModel, TypeAdapter
from typing_extensions import Self

from generate.utils import fetch_data


class Message(BaseModel):
@@ -38,6 +43,14 @@ class ImagePart(BaseModel):
image: bytes
image_format: Optional[str] = None

@classmethod
def from_url_or_path(cls, url_or_path: str | Path) -> Self:
    image_data = fetch_data(str(url_or_path))
    # Guess the format from the extension; fall back to None so image_format
    # is always bound before the constructor call.
    mimetype = mimetypes.guess_type(url=str(url_or_path))[0]
    image_format = mimetype.split('/')[1] if mimetype is not None else None
    return cls(image=image_data, image_format=image_format)


class UserMultiPartMessage(Message):
role: Literal['user'] = 'user'
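A minimal usage sketch of the new constructor (not part of the diff): the URL and path values are placeholders, and the import path simply mirrors this file's location (the package may also re-export ImagePart higher up).

from generate.chat_completion.message.core import ImagePart

part = ImagePart.from_url_or_path('https://example.com/cat.png')  # remote image, downloaded via fetch_data
local = ImagePart.from_url_or_path('photo.jpg')                    # local paths work as well
print(local.image_format)  # 'jpeg', guessed from the extension by mimetypes
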
10 changes: 8 additions & 2 deletions generate/chat_completion/models/anthropic.py
@@ -56,7 +56,13 @@ class AnthropicParametersDict(RemoteModelParametersDict, total=False):

class AnthropicChat(RemoteChatCompletionModel):
model_type: ClassVar[str] = 'anthropic'
available_models: ClassVar[List[str]] = ['claude-2.1', 'claude-2.0', 'claude-instant-1.2']
available_models: ClassVar[List[str]] = [
'claude-2.1',
'claude-2.0',
'claude-instant-1.2',
'claude-3-opus-20240229',
'claude-3-sonnet-20240229',
]

parameters: AnthropicChatParameters
settings: AnthropicSettings
@@ -107,7 +113,7 @@ def _convert_message(self, message: Message) -> dict[str, str]:

if isinstance(part, ImagePart):
data = base64.b64encode(part.image).decode()
media_type = part.image_format or 'image/jpeg'
media_type = 'image/jpeg' if part.image_format is None else f'image/{part.image_format}'
message_dict['content'].append(
{'type': 'image', 'source': {'type': 'base64', 'media_type': media_type, 'data': data}}
)
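The first hunk registers the two Claude 3 checkpoints; the second builds a full media type (e.g. 'image/png' rather than just 'png') for the Anthropic messages API. A small sketch, not part of the diff:

from generate.chat_completion.models.anthropic import AnthropicChat

print(AnthropicChat.available_models)
# ['claude-2.1', 'claude-2.0', 'claude-instant-1.2',
#  'claude-3-opus-20240229', 'claude-3-sonnet-20240229']

# Assuming load_chat_model resolves '<model_type>/<model_name>' ids, as the ui.py changes below suggest:
# model = load_chat_model('anthropic/claude-3-opus-20240229')
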
2 changes: 1 addition & 1 deletion generate/image_generation/models/baidu.py
@@ -43,7 +43,7 @@ def custom_model_dump(self) -> dict[str, Any]:
n = self.n or 1
output_data['image_num'] = n
if self.reference_image:
if isinstance(self.reference_image, HttpUrl):
if isinstance(self.reference_image, str):
output_data['url'] = self.reference_image
else:
output_data['image'] = self.reference_image
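For reference, a self-contained sketch of the branch above with hypothetical values: a plain string is now routed to the 'url' field, while anything else (e.g. base64-encoded image content) still goes to 'image'.

output_data = {}
reference_image = 'https://example.com/reference.png'  # hypothetical URL string
if isinstance(reference_image, str):
    output_data['url'] = reference_image
else:
    output_data['image'] = reference_image
print(output_data)  # {'url': 'https://example.com/reference.png'}
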
32 changes: 16 additions & 16 deletions generate/ui.py
@@ -1,10 +1,12 @@
from __future__ import annotations

from typing import Any, List, Optional, cast
from typing import Any, List, Optional, Type, cast

from pydantic import BaseModel

from generate.chat_completion.base import RemoteChatCompletionModel
from generate.chat_completion.models.dashscope_multimodal import DashScopeMultiModalChat
from generate.model import ModelInfo

try:
import chainlit as cl
@@ -14,7 +16,7 @@
except ImportError as e:
raise ImportError('Please install chainlit with "pip install chainlit typer"') from e

from generate import ChatCompletionModel, load_chat_model
from generate import ChatCompletionModel, ChatModelRegistry, load_chat_model
from generate.chat_completion.message import (
ImagePart,
ImageUrl,
@@ -64,26 +66,24 @@ def get_avatars() -> List[Avatar]:


def get_generate_settings() -> List[Any]:
available_model_ids = []
for model_cls, _ in ChatModelRegistry.values():
model_cls = cast(Type[RemoteChatCompletionModel], model_cls)
for model_name in model_cls.available_models:
model_info = ModelInfo(
task=model_cls.model_task,
type=model_cls.model_type,
name=model_name,
)
available_model_ids.append(model_info.model_id)
model_select = Select(
id='Model',
label='Model',
values=[
'openai',
'openai/gpt-4-vision-preview',
'dashscope',
'dashscope_multimodal',
'zhipu',
'zhipu/glm-4v',
'wenxin',
'baichuan',
'minimax_pro',
'moonshot',
'deepseek',
],
values=available_model_ids,
)
model_id = TextInput(
id='ModelId',
label='Model ID',
label='Custom Model ID',
initial='',
description='e.g. openai/gpt-4-turbo-preview; this setting overrides the Model option.',
)
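A self-contained mock of the registry loop above, not part of the diff. The real ChatModelRegistry and ModelInfo live in the generate package; ModelInfo.model_id is assumed to render as '<model_type>/<model_name>', matching ids like 'openai/gpt-4-vision-preview' from the removed hard-coded list.

from typing import ClassVar, Dict, List, Tuple

class FakeChatModel:
    model_task: ClassVar[str] = 'chat_completion'
    model_type: ClassVar[str] = 'anthropic'
    available_models: ClassVar[List[str]] = ['claude-3-opus-20240229', 'claude-3-sonnet-20240229']

FakeRegistry: Dict[str, Tuple[type, None]] = {'anthropic': (FakeChatModel, None)}

available_model_ids: List[str] = []
for model_cls, _ in FakeRegistry.values():
    for model_name in model_cls.available_models:
        available_model_ids.append(f'{model_cls.model_type}/{model_name}')  # assumed model_id format

print(available_model_ids)
# ['anthropic/claude-3-opus-20240229', 'anthropic/claude-3-sonnet-20240229']
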
22 changes: 22 additions & 0 deletions generate/utils.py
@@ -1,7 +1,9 @@
from __future__ import annotations

import asyncio
from pathlib import Path
from typing import Any, AsyncIterator, Awaitable, Generator, Generic, Iterable, TypeVar
from urllib.parse import urlparse

from generate.types import OrIterable

@@ -60,3 +62,23 @@ def unwrap_model(model: Any) -> Any:
return unwrap_model(model.model)
return model
return model


def fetch_data(url_or_file: str) -> bytes:
parsed_url = urlparse(url_or_file)
if not parsed_url.scheme and Path(url_or_file).exists():
return Path(url_or_file).read_bytes()

if parsed_url.scheme == 'file':
if Path(parsed_url.path).exists():
return Path(parsed_url.path).read_bytes()
raise FileNotFoundError(f'File {parsed_url.path} not found')

if parsed_url.scheme in ('http', 'https'):
import httpx

response = httpx.get(url_or_file)
response.raise_for_status()
return response.content

raise ValueError(f'Unsupported URL scheme {parsed_url.scheme}')
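
A quick usage sketch of the new helper (hypothetical paths and URLs; httpx is imported lazily, only for the http/https branch):

from generate.utils import fetch_data

png_bytes = fetch_data('https://example.com/cat.png')  # downloaded with httpx
local_bytes = fetch_data('/tmp/cat.png')                # read from disk
file_bytes = fetch_data('file:///tmp/cat.png')          # file:// URLs are handled too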
