Skip to content

Commit

Permalink
Fix asynchronous service lora (#281)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccssu authored Dec 18, 2024
1 parent df7ef21 commit 52f77cc
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/bizyair/commands/processors/prompt_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import deque
from typing import Any, Dict, List

from bizyair.common import client
from bizyair.common import client, get_api_key
from bizyair.common.caching import BizyAirTaskCache, CacheConfig
from bizyair.common.env_var import (
BIZYAIR_DEBUG,
Expand Down Expand Up @@ -88,7 +88,9 @@ class PromptProcessor(Processor):
def _exec_info(self, prompt: Dict[str, Dict[str, Any]]):
exec_info = {
"model_version_ids": [],
"api_key": get_api_key(),
}

model_version_id_prefix = config_manager.get_model_version_id_prefix()
for node_id, node_data in prompt.items():
for k, v in node_data.get("inputs", {}).items():
Expand Down
14 changes: 12 additions & 2 deletions src/bizyair/commands/servers/prompt_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

from bizyair.common.caching import BizyAirTaskCache, CacheConfig
from bizyair.common.client import send_request
from bizyair.common.env_var import BIZYAIR_DEBUG, BIZYAIR_SERVER_ADDRESS
from bizyair.common.env_var import (
BIZYAIR_DEBUG,
BIZYAIR_DEV_GET_TASK_RESULT_SERVER,
BIZYAIR_SERVER_ADDRESS,
)
from bizyair.common.utils import truncate_long_strings
from bizyair.configs.conf import config_manager
from bizyair.image_utils import decode_data, encode_data
Expand All @@ -25,7 +29,13 @@ def get_task_result(task_id: str, offset: int = 0) -> dict:
import requests

task_api = config_manager.get_task_api()
url = f"{BIZYAIR_SERVER_ADDRESS}/{task_api.task_result_endpoint}/{task_id}"
if BIZYAIR_DEV_GET_TASK_RESULT_SERVER:
url = f"{BIZYAIR_DEV_GET_TASK_RESULT_SERVER}{task_api.task_result_endpoint}/{task_id}"
else:
url = f"{BIZYAIR_SERVER_ADDRESS}{task_api.task_result_endpoint}/{task_id}"

if BIZYAIR_DEBUG:
print(f"Debug: get task result url: {url}")
response_json = send_request(
method="GET", url=url, data=json.dumps({"offset": offset}).encode("utf-8")
)
Expand Down
3 changes: 3 additions & 0 deletions src/bizyair/common/env_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,6 @@ def create_api_key_file(api_key):
# Development Settings
BIZYAIR_DEV_REQUEST_URL = env("BIZYAIR_DEV_REQUEST_URL", str, None)
BIZYAIR_DEBUG = env("BIZYAIR_DEBUG", bool, False)
BIZYAIR_DEV_GET_TASK_RESULT_SERVER = env(
"BIZYAIR_DEV_GET_TASK_RESULT_SERVER", str, None
)
2 changes: 1 addition & 1 deletion src/bizyair/configs/models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ model_types:

task_api:
# Base URL for task-related API calls
task_result_endpoint: bizy_task
task_result_endpoint: /bizy_task


model_rules:
Expand Down

0 comments on commit 52f77cc

Please sign in to comment.