Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix asynchronous service lora #281

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/bizyair/commands/processors/prompt_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import deque
from typing import Any, Dict, List

from bizyair.common import client
from bizyair.common import client, get_api_key
from bizyair.common.caching import BizyAirTaskCache, CacheConfig
from bizyair.common.env_var import (
BIZYAIR_DEBUG,
Expand Down Expand Up @@ -88,7 +88,9 @@ class PromptProcessor(Processor):
def _exec_info(self, prompt: Dict[str, Dict[str, Any]]):
exec_info = {
"model_version_ids": [],
"api_key": get_api_key(),
}

model_version_id_prefix = config_manager.get_model_version_id_prefix()
for node_id, node_data in prompt.items():
for k, v in node_data.get("inputs", {}).items():
Expand Down
14 changes: 12 additions & 2 deletions src/bizyair/commands/servers/prompt_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

from bizyair.common.caching import BizyAirTaskCache, CacheConfig
from bizyair.common.client import send_request
from bizyair.common.env_var import BIZYAIR_DEBUG, BIZYAIR_SERVER_ADDRESS
from bizyair.common.env_var import (
BIZYAIR_DEBUG,
BIZYAIR_DEV_GET_TASK_RESULT_SERVER,
BIZYAIR_SERVER_ADDRESS,
)
from bizyair.common.utils import truncate_long_strings
from bizyair.configs.conf import config_manager
from bizyair.image_utils import decode_data, encode_data
Expand All @@ -25,7 +29,13 @@ def get_task_result(task_id: str, offset: int = 0) -> dict:
import requests

task_api = config_manager.get_task_api()
url = f"{BIZYAIR_SERVER_ADDRESS}/{task_api.task_result_endpoint}/{task_id}"
if BIZYAIR_DEV_GET_TASK_RESULT_SERVER:
url = f"{BIZYAIR_DEV_GET_TASK_RESULT_SERVER}{task_api.task_result_endpoint}/{task_id}"
else:
url = f"{BIZYAIR_SERVER_ADDRESS}{task_api.task_result_endpoint}/{task_id}"

if BIZYAIR_DEBUG:
print(f"Debug: get task result url: {url}")
response_json = send_request(
method="GET", url=url, data=json.dumps({"offset": offset}).encode("utf-8")
)
Expand Down
3 changes: 3 additions & 0 deletions src/bizyair/common/env_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,6 @@ def create_api_key_file(api_key):
# Development Settings
BIZYAIR_DEV_REQUEST_URL = env("BIZYAIR_DEV_REQUEST_URL", str, None)
BIZYAIR_DEBUG = env("BIZYAIR_DEBUG", bool, False)
BIZYAIR_DEV_GET_TASK_RESULT_SERVER = env(
"BIZYAIR_DEV_GET_TASK_RESULT_SERVER", str, None
)
2 changes: 1 addition & 1 deletion src/bizyair/configs/models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ model_types:

task_api:
# Base URL for task-related API calls
task_result_endpoint: bizy_task
task_result_endpoint: /bizy_task


model_rules:
Expand Down
Loading