Skip to content

Commit

Permalink
Add SharedLoraLoader node (#154)
Browse files Browse the repository at this point in the history
* Add SharedLoraLoader node

* poc of use js front end to get loras list (#156)

* poc

* refine

* refine

---------

Co-authored-by: FengWen <[email protected]>

* fix SharedLoraLoader

* Remove unused comments

---------

Co-authored-by: Yao Chi <[email protected]>
  • Loading branch information
ccssu and doombeaker authored Sep 29, 2024
1 parent 64cad67 commit a113f3e
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 41 deletions.
43 changes: 43 additions & 0 deletions js/share_lora_loader.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { api } from "../../../scripts/api.js";
import { app } from "../../scripts/app.js";
app.registerExtension({
name: "bizyair.siliconcloud.share.lora.loader",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "BizyAir_SharedLoraLoader") {
async function onTextChange(share_id, canvas, comfynode) {
console.log("share_id:", share_id);
const response = await api.fetchApi(`/bizyair/modelhost/${share_id}/models/files?type=bizyair/lora`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});

const { data: loras_list } = await response.json();
const lora_name_widget = comfynode.widgets.find(widget => widget.name === "lora_name");
if (loras_list.length > 0) {
lora_name_widget.value = loras_list[0];
lora_name_widget.options.values = loras_list;
} else {
console.log("No loras found in the response");
lora_name_widget.value = "";
lora_name_widget.options.values = [];
}
}

function setWigetCallback(){
const shareid_widget = this.widgets.find(widget => widget.name === "share_id");
if (shareid_widget) {
shareid_widget.callback = onTextChange;
} else {
console.log("share_id widget not found");
}
}
const onNodeCreated = nodeType.prototype.onNodeCreated
nodeType.prototype.onNodeCreated = function () {
onNodeCreated?.apply(this, arguments);
setWigetCallback.call(this, arguments);
};
}
},
})
50 changes: 50 additions & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,3 +798,53 @@ def INPUT_TYPES(s):
# FUNCTION = "encode"

CATEGORY = "conditioning/inpaint"


class SharedLoraLoader(BizyAir_LoraLoader):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"share_id": ("STRING", {"default": "share_id"}),
"lora_name": ([],),
"model": (data_types.MODEL,),
"clip": (data_types.CLIP,),
"strength_model": (
"FLOAT",
{"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01},
),
"strength_clip": (
"FLOAT",
{"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01},
),
}
}

RETURN_TYPES = (data_types.MODEL, data_types.CLIP)
RETURN_NAMES = ("MODEL", "CLIP")
FUNCTION = "shared_load_lora"
CATEGORY = f"{PREFIX}/loaders"
NODE_DISPLAY_NAME = "Shared Lora Loader"

@classmethod
def VALIDATE_INPUTS(cls, share_id: str, lora_name: str):
if lora_name in folder_paths.filename_path_mapping.get("loras", {}):
return True

outs = folder_paths.get_share_filename_list("loras", share_id=share_id)
if lora_name not in outs:
raise ValueError(
f"Lora {lora_name} not found in share {share_id} with {outs}"
)
return True

def shared_load_lora(
self, model, clip, lora_name, strength_model, strength_clip, **kwargs
):
return super().load_lora(
model=model,
clip=clip,
lora_name=lora_name,
strength_model=strength_model,
strength_clip=strength_clip,
)
59 changes: 29 additions & 30 deletions src/bizy_server/modelhost.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ async def submit_upload(request):
self.uploads[upload_id]["type"] = json_data["type"]
self.uploads[upload_id]["name"] = json_data["name"]
self.upload_queue.put(self.uploads[upload_id])

# enable refresh for lora
# TODO: enable refresh for other types
bizyair.path_utils.path_manager.enable_refresh_options("loras")
return OKResponse(None)

@prompt_server.routes.get(f"/{API_PREFIX}/models/files")
Expand Down Expand Up @@ -270,6 +274,7 @@ async def list_model_files(request):
@prompt_server.routes.get(f"/{API_PREFIX}" + "/{shareId}/models/files")
async def list_share_model_files(request):
shareId = request.match_info["shareId"]

if not self.is_string_valid(shareId):
return ErrResponse(INVALID_SHARE_ID)

Expand All @@ -286,12 +291,12 @@ async def list_share_model_files(request):

if "ext_name" in request.rel_url.query:
payload["ext_name"] = request.rel_url.query["ext_name"]

model_files, err = await self.get_share_model_files(
shareId=shareId, payload=payload
)
if err is not None:
return ErrResponse(err)

return OKResponse(model_files)

@prompt_server.routes.delete(f"/{API_PREFIX}/models")
Expand Down Expand Up @@ -524,41 +529,35 @@ async def get_model_files(self, payload) -> (dict, ErrorNo):
return result, None

async def get_share_model_files(self, shareId, payload) -> (dict, ErrorNo):
headers, err = self.auth_header()
if err is not None:
return None, err

server_url = f"{BIZYAIR_SERVER_ADDRESS}/{shareId}/models/files"
try:
resp = self.do_get(server_url, params=payload, headers=headers)
ret = json.loads(resp)
if ret["code"] != CODE_OK:
if ret["code"] == CODE_NO_MODEL_FOUND:
return [], None
else:
return None, ErrorNo(500, ret["code"], None, ret["message"])

if not ret["data"]:
return [], None
except Exception as e:
print(f"fail to list share model files: {str(e)}")
return None, LIST_SHARE_MODEL_FILE_ERR
def callback(ret: dict):
if ret["code"] != CODE_OK:
if ret["code"] == CODE_NO_MODEL_FOUND:
return [], None
else:
return [], ErrorNo(500, ret["code"], None, ret["message"])

files = ret["data"]["files"]
result = []
if len(files) > 0:
tree = defaultdict(lambda: {"name": "", "list": []})
if not ret or "data" not in ret or ret["data"] is None:
return [], None

for item in files:
parts = item["label_path"].split("/")
model_name = parts[0]
if model_name not in tree:
tree[model_name] = {"name": model_name, "list": [item]}
else:
tree[model_name]["list"].append(item)
result = list(tree.values())
outputs = [
x["label_path"] for x in ret["data"]["files"] if x["label_path"]
]
outputs = bizyair.path_utils.filter_files_extensions(
outputs,
extensions=bizyair.path_utils.path_manager.supported_pt_extensions,
)
return outputs, None

return result, None
ret = await bizyair.common.client.async_send_request(
method="GET", url=server_url, params=payload, callback=callback
)
return ret[0], ret[1]
except Exception as e:
print(f"fail to list share model files: {str(e)}")
return [], LIST_SHARE_MODEL_FILE_ERR

async def get_models(self, payload) -> (dict, ErrorNo):
headers, err = self.auth_header()
Expand Down
45 changes: 45 additions & 0 deletions src/bizyair/common/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import json
import pprint
import urllib.error
import urllib.request
import warnings

import aiohttp

__all__ = ["send_request"]

from dataclasses import dataclass, field
Expand Down Expand Up @@ -134,6 +137,48 @@ def send_request(
return json.loads(response_data)


async def async_send_request(
method: str = "POST",
url: str = None,
data: bytes = None,
verbose=False,
callback: callable = process_response_data,
**kwargs,
) -> dict:
headers = kwargs.pop("headers") if "headers" in kwargs else _headers()
try:
async with aiohttp.ClientSession() as session:
async with session.request(
method, url, data=data, headers=headers, **kwargs
) as response:
response_data = await response.text()
if response.status != 200:
error_message = f"HTTP Status {response.status}"
if verbose:
print(f"Error encountered: {error_message}")
if response.status == 401:
raise PermissionError(
"Key is invalid, please refer to https://cloud.siliconflow.cn to get the API key.\n"
"If you have the key, please click the 'BizyAir Key' button at the bottom right to set the key."
)
else:
raise ConnectionError(
f"Failed to connect to the server: {error_message}.\n"
+ "Please check your API key and ensure the server is reachable.\n"
+ "Also, verify your network settings and disable any proxies if necessary.\n"
+ "After checking, please restart the ComfyUI service."
)
if callback:
return callback(json.loads(response_data))
return json.loads(response_data)
except aiohttp.ClientError as e:
print(f"Error fetching data: {e}")
return {}
except Exception as e:
print(f"Error fetching data: {str(e)}")
return {}


def fetch_models_by_type(
url: str, model_type: str, *, method="GET", verbose=False
) -> dict:
Expand Down
3 changes: 3 additions & 0 deletions src/bizyair/path_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from .path_manager import (
convert_prompt_label_path_to_real_path,
disable_refresh_options,
enable_refresh_options,
get_filename_list,
guess_config,
guess_url_from_node,
)
from .utils import filter_files_extensions
73 changes: 62 additions & 11 deletions src/bizyair/path_utils/path_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pprint
import re
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Union

from ..common import fetch_models_by_type
Expand All @@ -24,6 +25,34 @@
filename_path_mapping: dict[str, dict[str, str]] = {}


@dataclass
class RefreshSettings:
loras: bool = True

def get(self, folder_name: str, default: bool = True):
return getattr(self, folder_name, default)

def set(self, folder_name: str, value: bool):
setattr(self, folder_name, value)


refresh_settings = RefreshSettings()


def enable_refresh_options(folder_names: Union[str, list[str]]):
if isinstance(folder_names, str):
folder_names = [folder_names]
for folder_name in folder_names:
refresh_settings.set(folder_name, True)


def disable_refresh_options(folder_names: Union[str, list[str]]):
if isinstance(folder_names, str):
folder_names = [folder_names]
for folder_name in folder_names:
refresh_settings.set(folder_name, False)


def _get_config_path():
src_bizyair_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
configs_path = os.path.join(src_bizyair_path, "configs")
Expand Down Expand Up @@ -88,26 +117,37 @@ def get_config_file_list(base_path=None) -> list:


def cached_filename_list(
folder_name: str, *, verbose=False, refresh=False
folder_name: str, *, share_id: str = None, verbose=False, refresh=False
) -> list[str]:
global filename_path_mapping
if refresh or folder_name not in filename_path_mapping:
model_types: Dict[str, str] = models_config["model_types"]
url = get_service_route(models_config["model_hub"]["find_model"])
if share_id:
url = f"{BIZYAIR_SERVER_ADDRESS}/{share_id}/models/files"
else:
url = get_service_route(models_config["model_hub"]["find_model"])
msg = fetch_models_by_type(
url=url, method="GET", model_type=model_types[folder_name]
)
if verbose:
pprint.pprint({"cached_filename_list": msg})

if not msg or "data" not in msg or msg["data"] is None:
try:
if not msg or "data" not in msg or msg["data"] is None:
return []

filename_path_mapping[folder_name] = {
x["label_path"]: x["real_path"]
for x in msg["data"]["files"]
if x["label_path"]
}
except Exception as e:
warnings.warn(f"Failed to get filename list: {e}")
return []

filename_path_mapping[folder_name] = {
x["label_path"]: x["real_path"]
for x in msg["data"]["files"]
if x["label_path"]
}
finally:
# TODO fix share_id vaild refresh settings
if share_id is None:
disable_refresh_options(folder_name)

return list(
filter_files_extensions(
Expand Down Expand Up @@ -139,11 +179,23 @@ def convert_prompt_label_path_to_real_path(prompt: dict[str, dict[str, any]]) ->
return new_prompt


def get_share_filename_list(folder_name, share_id, *, verbose=BIZYAIR_DEBUG):
assert share_id is not None and isinstance(share_id, str)
# TODO fix share_id vaild refresh settings
return cached_filename_list(
folder_name, share_id=share_id, verbose=verbose, refresh=True
)


def get_filename_list(folder_name, *, verbose=BIZYAIR_DEBUG):

global folder_names_and_paths
results = []
if folder_name in models_config["model_types"]:
results.extend(cached_filename_list(folder_name, verbose=verbose, refresh=True))
refresh = refresh_settings.get(folder_name, True)
results.extend(
cached_filename_list(folder_name, verbose=verbose, refresh=refresh)
)
if folder_name in folder_names_and_paths:
results.extend(folder_names_and_paths[folder_name])
if BIZYAIR_DEBUG:
Expand All @@ -153,7 +205,6 @@ def get_filename_list(folder_name, *, verbose=BIZYAIR_DEBUG):
results.extend(folder_paths.get_filename_list(folder_name))
except:
pass

return results


Expand Down

0 comments on commit a113f3e

Please sign in to comment.