-
Notifications
You must be signed in to change notification settings - Fork 1
/
nodes.py
88 lines (75 loc) · 3.01 KB
/
nodes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import logging
import google.generativeai as genai
from torch import Tensor
from .utils import images_to_pillow, temporary_env_var
class GeminiNode:
@classmethod
def INPUT_TYPES(cls): # noqa
return {
"required": {
"prompt": ("STRING", {"default": "Why number 42 is important?", "multiline": True}),
"safety_settings": (["BLOCK_NONE", "BLOCK_ONLY_HIGH", "BLOCK_MEDIUM_AND_ABOVE"],),
"response_type": (["text", "json"],),
"model": (["gemini-1.5-flash-002", "gemini-1.5-pro-002"],),
},
"optional": {
"api_key": ("STRING", {}),
"proxy": ("STRING", {}),
"image_1": ("IMAGE",),
"image_2": ("IMAGE",),
"image_3": ("IMAGE",),
"system_instruction": ("STRING", {}),
"error_fallback_value": ("STRING", {"lazy": True}),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("text",)
FUNCTION = "ask_gemini"
CATEGORY = "Gemini"
def __init__(self):
self.text_output: str | None = None
def ask_gemini(self, **kwargs):
return (kwargs["error_fallback_value"] if self.text_output is None else self.text_output,)
def check_lazy_status(
self,
prompt: str,
safety_settings: str,
response_type: str,
model: str,
api_key: str | None = None,
proxy: str | None = None,
image_1: Tensor | list[Tensor] | None = None,
image_2: Tensor | list[Tensor] | None = None,
image_3: Tensor | list[Tensor] | None = None,
system_instruction: str | None = None,
error_fallback_value: str | None = None,
):
self.text_output = None
if not system_instruction:
system_instruction = None
images_to_send = []
for image in [image_1, image_2, image_3]:
if image is not None:
images_to_send.extend(images_to_pillow(image))
genai.configure(api_key=api_key, transport="rest")
model = genai.GenerativeModel(model, safety_settings=safety_settings, system_instruction=system_instruction)
generation_config = genai.GenerationConfig(
response_mime_type="application/json" if response_type == "json" else "text/plain"
)
try:
with temporary_env_var("HTTP_PROXY", proxy), temporary_env_var("HTTPS_PROXY", proxy):
response = model.generate_content([prompt, *images_to_send], generation_config=generation_config)
self.text_output = response.text
except Exception:
if error_fallback_value is None:
logging.getLogger("ComfyUI-Gemini").debug("ComfyUI-Gemini: exception occurred:", exc_info=True)
return ["error_fallback_value"]
if error_fallback_value == "":
raise
return []
# Maps the node's registration key to the class that implements it.
NODE_CLASS_MAPPINGS = dict(Ask_Gemini=GeminiNode)
# Maps the same registration key to the name shown for the node.
NODE_DISPLAY_NAME_MAPPINGS = dict(Ask_Gemini="Ask Gemini")