Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix #52: Add Environment Variables to Disable Model Functionalities #54

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
36 changes: 31 additions & 5 deletions Globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,44 @@


def getenv(var_name: str):
    """Return the value of environment variable *var_name*.

    Falls back to a per-variable default when the variable is unset;
    unknown variable names fall back to the empty string.

    Args:
        var_name: Name of the environment variable to look up.

    Returns:
        The environment value if set, otherwise the documented default.
        Note: ``UVICORN_WORKERS`` defaults to the int ``10`` (not a str),
        preserving the original behavior — TODO confirm callers expect that.
    """
    # BUG FIX: the previous version called default_values.update(...)
    # *before* default_values was defined, raising NameError on every call.
    # All defaults now live in a single literal defined before use.
    default_values = {
        "EZLOCALAI_URL": "http://localhost:8091",
        "ALLOWED_DOMAINS": "*",
        "DEFAULT_MODEL": "TheBloke/phi-2-dpo-GGUF",
        "WHISPER_MODEL": "base.en",
        "VISION_MODEL": "",
        "SD_MODEL": "",
        "IMG_DEVICE": "cpu",
        "NGROK_TOKEN": "",
        "LOG_LEVEL": "INFO",
        "LOG_FORMAT": "%(asctime)s | %(levelname)s | %(message)s",
        "UVICORN_WORKERS": 10,
        "GPU_LAYERS": "0",
        "MAIN_GPU": "0",
        "TENSOR_SPLIT": "",
        "QUANT_TYPE": "Q4_K_M",
        "LLM_MAX_TOKENS": "2048",
        "LLM_BATCH_SIZE": "16",
        # Feature toggles: each model capability can be disabled via env
        # (string "true"/"false"; callers compare .lower() == "true").
        "LLM_ENABLED": "true",
        "VISION_ENABLED": "true",
        "IMG_ENABLED": "true",
        "TTS_ENABLED": "true",
        "STT_ENABLED": "true",
    }
    # dict.get supplies "" for names with no registered default.
    return os.getenv(var_name, default_values.get(var_name, ""))
default_values = {
"EZLOCALAI_URL": "http://localhost:8091",
"EZLOCALAI_API_KEY": "none",
"ALLOWED_DOMAINS": "*",
"DEFAULT_MODEL": "TheBloke/phi-2-dpo-GGUF",
"WHISPER_MODEL": "base.en",
"VISION_MODEL": "",
"SD_MODEL": "",
"EMBEDDING_ENABLED": "true",
"IMG_ENABLED": "false",
"TTS_ENABLED": "true",
"STT_ENABLED": "true",
"IMG_DEVICE": "cpu",
"NGROK_TOKEN": "",
"LOG_LEVEL": "INFO",
Expand Down
28 changes: 20 additions & 8 deletions Pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,19 @@ def __init__(self):
self.ctts = None
self.stt = None
self.embedder = None
if self.current_llm.lower() != "none":
if (
self.current_llm.lower() != "none"
and getenv("LLM_ENABLED").lower() == "true"
):
logging.info(f"[LLM] {self.current_llm} model loading. Please wait...")
self.llm = LLM(model=self.current_llm)
logging.info(f"[LLM] {self.current_llm} model loaded successfully.")
if getenv("EMBEDDING_ENABLED").lower() == "true":
if getenv("STT_ENABLED").lower() == "true":
if (
getenv("EMBEDDING_ENABLED").lower() == "true"
and getenv("LLM_ENABLED").lower() == "true"
):
self.embedder = Embedding()
if self.current_vlm != "":
if getenv("VISION_ENABLED").lower() == "true":
logging.info(f"[VLM] {self.current_vlm} model loading. Please wait...")
try:
self.vlm = VLM(model=self.current_vlm)
Expand All @@ -48,11 +54,17 @@ def __init__(self):
self.vlm = None
if self.vlm is not None:
logging.info(f"[ezlocalai] Vision is enabled with {self.current_vlm}.")
if getenv("TTS_ENABLED").lower() == "true":
if (
getenv("TTS_ENABLED").lower() == "true"
and getenv("TTS_ENABLED").lower() == "true"
):
logging.info(f"[CTTS] xttsv2_2.0.2 model loading. Please wait...")
self.ctts = CTTS()
logging.info(f"[CTTS] xttsv2_2.0.2 model loaded successfully.")
if getenv("STT_ENABLED").lower() == "true":
if (
getenv("STT_ENABLED").lower() == "true"
and getenv("STT_ENABLED").lower() == "true"
):
self.current_stt = getenv("WHISPER_MODEL")
logging.info(f"[STT] {self.current_stt} model loading. Please wait...")
self.stt = STT(model=self.current_stt)
Expand All @@ -70,7 +82,7 @@ def __init__(self):
self.local_uri = public_url.public_url
else:
self.local_uri = getenv("EZLOCALAI_URL")
self.img_enabled = getenv("IMG_ENABLED").lower() == "true"
if getenv("TTS_ENABLED").lower() == "true":
self.img = None
if img_import_success:
logging.info(f"[IMG] Image generation is enabled.")
Expand Down Expand Up @@ -219,7 +231,7 @@ async def get_response(self, data, completion_type="chat"):
data["temperature"] = 0.5
if "top_p" not in data:
data["top_p"] = 0.9
if self.img_enabled and img_import_success and self.img:
if (self.img_enabled and img_import_success):
user_message = (
data["messages"][-1]["content"]
if completion_type == "chat"
Expand Down
Loading