Skip to content

Commit

Permalink
Feat: add gpustack image tools
Browse files Browse the repository at this point in the history
  • Loading branch information
alexcodelf committed Jan 25, 2025
1 parent 59b3e67 commit 47d45d7
Show file tree
Hide file tree
Showing 9 changed files with 565 additions and 0 deletions.
1 change: 1 addition & 0 deletions api/core/tools/provider/_position.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
- cogview
- comfyui
- getimgai
- gpustack
- siliconflow
- spark
- stepfun
Expand Down
14 changes: 14 additions & 0 deletions api/core/tools/provider/builtin/gpustack/_assets/icon.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
26 changes: 26 additions & 0 deletions api/core/tools/provider/builtin/gpustack/gpustack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import requests

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class GPUStackProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
base_url = credentials.get("base_url", "").removesuffix("/").removesuffix("/v1-openai")
api_key = credentials.get("api_key", "")
tls_verify = credentials.get("tls_verify", True)

if not base_url:
raise ToolProviderCredentialValidationError("GPUStack base_url is required")
if not api_key:
raise ToolProviderCredentialValidationError("GPUStack api_key is required")
headers = {
"accept": "application/json",
"authorization": f"Bearer {api_key}",
}

response = requests.get(f"{base_url}/v1-openai/models", headers=headers, verify=tls_verify)
if response.status_code != 200:
raise ToolProviderCredentialValidationError(
f"Failed to validate GPUStack API key, status code: {response.status_code}-{response.text}"
)
44 changes: 44 additions & 0 deletions api/core/tools/provider/builtin/gpustack/gpustack.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
identity:
author: gpustack
name: gpustack
label:
en_US: GPUStack
zh_Hans: GPUStack
description:
en_US: GPUStack is an open-source GPU cluster manager for running AI models, providing efficient resource management and model deployment capabilities.
zh_Hans: GPUStack 是一款开源的 GPU 集群管理工具,专为 AI 模型部署和运行而设计,提供高效的资源管理和模型部署能力。
icon: icon.svg
tags:
- image
credentials_for_provider:
base_url:
type: text-input
required: true
label:
en_US: Server URL
zh_Hans: 服务器 URL
placeholder:
en_US: http://your-server-address.com
help:
en_US: Please input GPUStack server's URL
zh_Hans: 请输入 GPUStack 服务器的 URL
api_key:
type: secret-input
required: true
label:
en_US: API Key
zh_Hans: API Key
placeholder:
en_US: Please input your GPUStack API Key
zh_Hans: 请输入你的 GPUStack API Key
url: https://docs.gpustack.ai/latest/user-guide/api-key-management/
tls_verify:
type: boolean
required: false
label:
en_US: TLS Verify
zh_Hans: 证书验证
help:
en_US: Whether to verify the TLS certificate of the GPUStack server.
zh_Hans: 是否验证 GPUStack 服务器的 TLS 证书。
default: true
47 changes: 47 additions & 0 deletions api/core/tools/provider/builtin/gpustack/tools/image_edit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import io
from typing import Any, Union

import requests

from core.file.enums import FileType
from core.file.file_manager import download
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

from .utils import get_base_url, get_common_params, handle_api_error, handle_image_response


class ImageEditTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
image = tool_parameters.get("image")
if image.type != FileType.IMAGE:
return [self.create_text_message("Not a valid image file")]

try:
params = get_common_params(tool_parameters)
params["strength"] = tool_parameters.get("strength", 0.75)

image_binary = io.BytesIO(download(image))
files = {"image": ("image.png", image_binary, "image/png")}

base_url = get_base_url(self.runtime.credentials["base_url"])
response = requests.post(
f"{base_url}/v1-openai/images/edits",
headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
data=params,
files=files,
verify=self.runtime.credentials.get("tls_verify", True),
)

if not response.ok:
return self.create_text_message(handle_api_error(response))

result = []
return handle_image_response(result, response, self)

except ValueError as e:
return self.create_text_message(str(e))
except Exception as e:
return self.create_text_message(f"An error occurred: {str(e)}")
181 changes: 181 additions & 0 deletions api/core/tools/provider/builtin/gpustack/tools/image_edit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
identity:
name: image_edit
author: gpustack
label:
en_US: Image Edit
zh_Hans: 图片编辑
icon: icon.svg
description:
human:
en_US: Edit images with GPUStack's image editing model.
zh_Hans: 使用 GPUStack 的图像编辑模型编辑图片。
llm: This tool is used to edit image.
parameters:
- name: image
type: file
required: true
label:
en_US: Image
zh_Hans: 图片
human_description:
en_US: The image to be edited.
zh_Hans: 要编辑的图片。
llm_description: The image to be edited.
form: llm
- name: prompt
type: string
required: true
label:
en_US: prompt
zh_Hans: 提示词
human_description:
en_US: The text prompt used to edit the image.
zh_Hans: 用于编辑图片的文字提示词
llm_description: this prompt text will be used to edit image.
form: llm
- name: model
type: string
required: true
label:
en_US: Model
zh_Hans: 模型
human_description:
en_US: image model name that running in GPUStack.
zh_Hans: 在 GPUStack 上运行的图像模型名称。
form: form
- name: cfg_scale
type: number
required: false
default: 4.5
label:
en_US: CFG Scale
human_description:
en_US: Classifier-free guidance scale, affecting the image's adherence to the prompt.
zh_Hans: 无分类器引导比例,影响图片的对 Prompt 的贴合度。
form: form
- name: n
type: number
required: false
default: 1
label:
en_US: Number
zh_Hans: 数量
human_description:
en_US: Number of images to generate.
zh_Hans: 生成图片数量。
form: form
- name: size
type: string
required: true
default: "512x512"
label:
en_US: Image Size
zh_Hans: 图片尺寸
human_description:
en_US: The maximum size of the generated image is controlled by the deployment parameters of the model.
zh_Hans: 图片生成的最大尺寸受控于模型的部署参数。
form: form
- name: sample_method
type: select
required: true
default: euler
options:
- value: euler_a
label:
en_US: euler_a
- value: euler
label:
en_US: euler
- value: heun
label:
en_US: heun
- value: dpm2
label:
en_US: dpm2
- value: dpm++2s_a
label:
en_US: dpm++2s_a
- value: dpm++2m
label:
en_US: dpm++2m
- value: dpm++2mv2
label:
en_US: dpm++2mv2
- value: ipndm
label:
en_US: ipndm
- value: ipndm_v
label:
en_US: ipndm_v
- value: icm
label:
en_US: icm
label:
en_US: Sample Method
zh_Hans: 采样方法
human_description:
en_US: The sample method for the image generation model.
zh_Hans: 图像生成模型的采样方法。
form: form
- name: sampling_steps
type: number
required: false
default: 20
label:
en_US: Sampling Steps
zh_Hans: 采样步数
human_description:
en_US: Number of sampling steps to generate the image.
zh_Hans: 生成图片所需的采样步数。
form: form
- name: guidance
type: number
required: false
default: 4.5
label:
en_US: Guidance
human_description:
en_US: Guidance scale, affecting the quality and diversity of the image.
zh_Hans: 引导比例,影响图片的质量和多样性
form: form
- name: schedule_method
type: select
required: true
default: discrete
options:
- value: discrete
label:
en_US: discrete
- value: karras
label:
en_US: karras
- value: exponential
label:
en_US: exponential
- value: ays
label:
en_US: ays
- value: gits
label:
en_US: gits
label:
en_US: Schedule Method
zh_Hans: 调度方法
form: form
- name: strength
type: number
required: false
default: 0.75
label:
en_US: Strength
zh_Hans: 强度
human_description:
en_US: The higher the value, the greater the modification to the original image.
zh_Hans: 值越高,它对原图的修改越大。
form: form
- name: seed
type: number
required: false
label:
en_US: Seed
form: form
34 changes: 34 additions & 0 deletions api/core/tools/provider/builtin/gpustack/tools/text2image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

from .utils import get_base_url, get_common_params, handle_api_error, handle_image_response


class TextToImageTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
try:
params = get_common_params(tool_parameters)
base_url = get_base_url(self.runtime.credentials["base_url"])
response = requests.post(
f"{base_url}/v1-openai/images/generations",
headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
json=params,
verify=self.runtime.credentials.get("tls_verify", True),
)

if not response.ok:
return self.create_text_message(handle_api_error(response))

result = []
return handle_image_response(result, response, self)

except ValueError as e:
return self.create_text_message(str(e))
except Exception as e:
return self.create_text_message(f"An error occurred: {str(e)}")
Loading

0 comments on commit 47d45d7

Please sign in to comment.