Skip to content

Commit

Permalink
Merge pull request #4349 from liunux4odoo/fix
Browse files Browse the repository at this point in the history
修复文生图工具
  • Loading branch information
liunux4odoo authored Jun 27, 2024
2 parents 755d889 + 500e26d commit 2704e7f
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 33 deletions.
2 changes: 2 additions & 0 deletions libs/chatchat-server/chatchat/configs/_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ def __init__(self):
},
"text2images": {
"use": False,
"model": "sd-turbo",
"size": "256*256",
},
# text2sql使用建议
# 1、因大模型生成的sql可能与预期有偏差,请务必在测试环境中进行充分测试、评估;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,43 +1,31 @@
import base64
import json
from datetime import datetime
import os
import uuid
from typing import List
from typing import List, Literal

import openai
from PIL import Image

from chatchat.configs import MEDIA_PATH
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import MsgType, get_tool_config
from chatchat.server.utils import MsgType, get_tool_config, get_model_info

from .tools_registry import BaseToolOutput, regist_tool


def get_image_model_config() -> dict:
# from chatchat.configs import LLM_MODEL_CONFIG, ONLINE_LLM_MODEL
# TODO ONLINE_LLM_MODEL的配置被删除,此处业务需要修改
# model = LLM_MODEL_CONFIG.get("image_model")
# if model:
# name = list(model.keys())[0]
# if config := ONLINE_LLM_MODEL.get(name):
# config = {**list(model.values())[0], **config}
# config.setdefault("model_name", name)
# return config
pass


@regist_tool(title="文生图", return_direct=True)
def text2images(
prompt: str,
n: int = Field(1, description="需生成图片的数量"),
width: int = Field(512, description="生成图片的宽度"),
height: int = Field(512, description="生成图片的高度"),
width: Literal[256, 512, 1024] = Field(512, description="生成图片的宽度"),
height: Literal[256, 512, 1024] = Field(512, description="生成图片的高度"),
) -> List[str]:
"""根据用户的描述生成图片"""

model_config = get_image_model_config()
assert model_config is not None, "请正确配置文生图模型"
tool_config = get_tool_config("text2images")
model_config = get_model_info(tool_config["model"])
assert model_config, "请正确配置文生图模型"

client = openai.Client(
base_url=model_config["api_base_url"],
Expand All @@ -54,7 +42,10 @@ def text2images(
images = []
for x in resp.data:
uid = uuid.uuid4().hex
filename = f"image/{uid}.png"
today = datetime.now().strftime("%Y-%m-%d")
path = os.path.join(MEDIA_PATH, "image", today)
os.makedirs(path, exist_ok=True)
filename = f"image/{today}/{uid}.png"
with open(os.path.join(MEDIA_PATH, filename), "wb") as fp:
fp.write(base64.b64decode(x.b64_json))
images.append(filename)
Expand Down
24 changes: 16 additions & 8 deletions libs/chatchat-server/chatchat/server/api_server/chat_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from fastapi import APIRouter, Request
from langchain.prompts.prompt import PromptTemplate
from sse_starlette import EventSourceResponse

from chatchat.server.api_server.api_schemas import AgentStatus, MsgType, OpenAIChatInput
from chatchat.server.chat.chat import chat
Expand All @@ -17,7 +18,8 @@
get_tool_config,
)

from .openai_routes import openai_request
from .openai_routes import openai_request, OpenAIChatOutput


chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"])

Expand Down Expand Up @@ -124,16 +126,22 @@ async def chat_completions(
{
**extra_json,
"content": f"{tool_result}",
"tool_call": tool.get_name(),
"tool_output": tool_result.data,
"is_ref": True,
"is_ref": False if tool.return_direct else True,
}
]
return await openai_request(
client.chat.completions.create,
body,
extra_json=extra_json,
header=header,
)
if tool.return_direct:
def temp_gen():
yield OpenAIChatOutput(**header[0]).model_dump_json()
return EventSourceResponse(temp_gen())
else:
return await openai_request(
client.chat.completions.create,
body,
extra_json=extra_json,
header=header,
)

# agent chat with tool calls
if body.tools:
Expand Down
11 changes: 7 additions & 4 deletions libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,8 @@ def on_conv_change():
tool_choice=tool_choice,
extra_body=extra_body,
):
# from pprint import pprint
# pprint(d)
# import rich
# rich.print(d)
message_id = d.message_id
metadata = {
"message_id": message_id,
Expand Down Expand Up @@ -421,7 +421,7 @@ def on_conv_change():
if getattr(d, "is_ref", False):
context = str(d.tool_output)
if isinstance(d.tool_output, dict):
docs = d.tool_output.get("docs")
docs = d.tool_output.get("docs", [])
source_documents = []
for inum, doc in enumerate(docs):
doc = DocumentWithVSId.parse_obj(doc)
Expand Down Expand Up @@ -450,6 +450,9 @@ def on_conv_change():
)
)
chat_box.insert_msg("")
elif getattr(d, "tool_call") == "text2images": # TODO:特定工具特别处理,需要更通用的处理方式
for img in d.tool_output.get("images", []):
chat_box.insert_msg(Image(f"{api.base_url}/media/{img}"), pos=-2)
else:
text += d.choices[0].delta.content or ""
chat_box.update_msg(
Expand Down Expand Up @@ -514,4 +517,4 @@ def on_conv_change():
use_container_width=True,
)

# st.write(chat_box.context)
# st.write(chat_box.history)

0 comments on commit 2704e7f

Please sign in to comment.