Merge pull request #5011 from oobabooga/dev
Merge dev branch
oobabooga authored Dec 20, 2023
2 parents 5b791ca + fadb295 commit c1f78db
Showing 20 changed files with 120 additions and 171 deletions.
18 changes: 2 additions & 16 deletions extensions/coqui_tts/script.py
@@ -6,27 +6,13 @@
from pathlib import Path

import gradio as gr
from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer

from modules import chat, shared, ui_chat
from modules.logging_colors import logger
from modules.ui import create_refresh_button
from modules.utils import gradio

try:
from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer
except ModuleNotFoundError:
logger.error(
"Could not find the TTS module. Make sure to install the requirements for the coqui_tts extension."
"\n"
"\nLinux / Mac:\npip install -r extensions/coqui_tts/requirements.txt\n"
"\nWindows:\npip install -r extensions\\coqui_tts\\requirements.txt\n"
"\n"
"If you used the one-click installer, paste the command above in the terminal window launched after running the \"cmd_\" script. On Windows, that's \"cmd_windows.bat\"."
)

raise

os.environ["COQUI_TOS_AGREED"] = "1"

params = {
2 changes: 2 additions & 0 deletions extensions/openai/script.py
@@ -1,5 +1,6 @@
import asyncio
import json
import logging
import os
import traceback
from threading import Thread
@@ -367,6 +368,7 @@ def on_start(public_url: str):
if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')

logging.getLogger("uvicorn.error").propagate = False
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)


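The added line stops records from the "uvicorn.error" logger from propagating to the root logger, which otherwise prints them a second time through its own handler. A minimal standalone sketch of that behaviour (the messages and formats below are illustrative, not from the repository):

import logging

# Root logger with its own handler, as set up here by logging.basicConfig.
logging.basicConfig(format="root    | %(message)s", level=logging.INFO)

# Child logger with its own handler, analogous to "uvicorn.error".
child = logging.getLogger("uvicorn.error")
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("uvicorn | %(message)s"))
child.addHandler(handler)

child.error("connection reset")   # printed twice: by the child handler and again via the root handler

child.propagate = False
child.error("connection reset")   # printed once: the record no longer reaches the root handler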
2 changes: 1 addition & 1 deletion modules/GPTQ_loader.py
@@ -126,7 +126,7 @@ def load_quantized(model_name):
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
pt_path = find_quantized_model_file(model_name)
if not pt_path:
logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
logger.error("Could not find the quantized model in .pt or .safetensors format. Exiting.")
exit()
else:
logger.info(f"Found the following quantized model: {pt_path}")
2 changes: 1 addition & 1 deletion modules/LoRA.py
@@ -138,7 +138,7 @@ def add_lora_transformers(lora_names):

# Add a LoRA when another LoRA is already present
if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
logger.info(f"Adding the LoRA(s) named {added_set} to the model")
for lora in added_set:
shared.model.load_adapter(get_lora_path(lora), lora)

3 changes: 2 additions & 1 deletion modules/chat.py
@@ -95,7 +95,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
else:
renderer = chat_renderer
if state['context'].strip() != '':
messages.append({"role": "system", "content": state['context']})
context = replace_character_names(state['context'], state['name1'], state['name2'])
messages.append({"role": "system", "content": context})

insert_pos = len(messages)
for user_msg, assistant_msg in reversed(history):
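The system context now goes through replace_character_names before being used as the system message, so name placeholders in a character's context are resolved like the rest of the prompt. A rough sketch of what such a substitution helper does, assuming the usual {{user}}/{{char}} placeholders (the real helper lives elsewhere in modules/chat.py and may cover more cases):

def replace_character_names(text: str, name1: str, name2: str) -> str:
    # name1 is the user's display name, name2 the character's name.
    return text.replace('{{user}}', name1).replace('{{char}}', name2)

# replace_character_names("{{char}} is chatting with {{user}}.", "You", "Chiharu")
# -> 'Chiharu is chatting with You.'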
9 changes: 7 additions & 2 deletions modules/extensions.py
@@ -31,9 +31,14 @@ def load_extensions():
for i, name in enumerate(shared.args.extensions):
if name in available_extensions:
if name != 'api':
logger.info(f'Loading the extension "{name}"...')
logger.info(f'Loading the extension "{name}"')
try:
exec(f"import extensions.{name}.script")
try:
exec(f"import extensions.{name}.script")
except ModuleNotFoundError:
logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n\nIf you used the one-click installer, paste the command above in the terminal window opened after launching the cmd script for your OS.")
raise

extension = getattr(extensions, name).script
apply_settings(extension, name)
if extension not in setup_called and hasattr(extension, "setup"):
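The per-extension import is now wrapped so that a missing dependency produces installation instructions instead of a bare traceback, which is why the local try/except removed from coqui_tts above is no longer needed. A simplified standalone sketch of the pattern, using importlib rather than exec:

import importlib
import logging

logger = logging.getLogger(__name__)

def import_extension(name: str):
    # Import extensions.<name>.script, pointing the user at the extension's
    # requirements file if one of its dependencies is missing.
    try:
        return importlib.import_module(f"extensions.{name}.script")
    except ModuleNotFoundError:
        logger.error(
            f"Could not import the requirements for '{name}'. Install them with:\n\n"
            f"pip install -r extensions/{name}/requirements.txt --upgrade"
        )
        raise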
176 changes: 63 additions & 113 deletions modules/logging_colors.py
@@ -1,117 +1,67 @@
# Copied from https://stackoverflow.com/a/1336640

import logging
import platform

logging.basicConfig(
format='%(asctime)s %(levelname)s:%(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)


def add_coloring_to_emit_windows(fn):
# add methods we need to the class
def _out_handle(self):
import ctypes
return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
out_handle = property(_out_handle)

def _set_color(self, code):
import ctypes

# Constants from the Windows API
self.STD_OUTPUT_HANDLE = -11
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)

setattr(logging.StreamHandler, '_set_color', _set_color)

def new(*args):
FOREGROUND_BLUE = 0x0001 # text color contains blue.
FOREGROUND_GREEN = 0x0002 # text color contains green.
FOREGROUND_RED = 0x0004 # text color contains red.
FOREGROUND_INTENSITY = 0x0008 # text color is intensified.
FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
# winbase.h
# STD_INPUT_HANDLE = -10
# STD_OUTPUT_HANDLE = -11
# STD_ERROR_HANDLE = -12

# wincon.h
# FOREGROUND_BLACK = 0x0000
FOREGROUND_BLUE = 0x0001
FOREGROUND_GREEN = 0x0002
# FOREGROUND_CYAN = 0x0003
FOREGROUND_RED = 0x0004
FOREGROUND_MAGENTA = 0x0005
FOREGROUND_YELLOW = 0x0006
# FOREGROUND_GREY = 0x0007
FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.

# BACKGROUND_BLACK = 0x0000
# BACKGROUND_BLUE = 0x0010
# BACKGROUND_GREEN = 0x0020
# BACKGROUND_CYAN = 0x0030
# BACKGROUND_RED = 0x0040
# BACKGROUND_MAGENTA = 0x0050
BACKGROUND_YELLOW = 0x0060
# BACKGROUND_GREY = 0x0070
BACKGROUND_INTENSITY = 0x0080 # background color is intensified.

levelno = args[1].levelno
if (levelno >= 50):
color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
elif (levelno >= 40):
color = FOREGROUND_RED | FOREGROUND_INTENSITY
elif (levelno >= 30):
color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
elif (levelno >= 20):
color = FOREGROUND_GREEN
elif (levelno >= 10):
color = FOREGROUND_MAGENTA
else:
color = FOREGROUND_WHITE
args[0]._set_color(color)

ret = fn(*args)
args[0]._set_color(FOREGROUND_WHITE)
# print "after"
return ret
return new


def add_coloring_to_emit_ansi(fn):
# add methods we need to the class
def new(*args):
levelno = args[1].levelno
if (levelno >= 50):
color = '\x1b[31m' # red
elif (levelno >= 40):
color = '\x1b[31m' # red
elif (levelno >= 30):
color = '\x1b[33m' # yellow
elif (levelno >= 20):
color = '\x1b[32m' # green
elif (levelno >= 10):
color = '\x1b[35m' # pink
else:
color = '\x1b[0m' # normal
args[1].msg = color + args[1].msg + '\x1b[0m' # normal
# print "after"
return fn(*args)
return new

logger = logging.getLogger('text-generation-webui')

if platform.system() == 'Windows':
# Windows does not support ANSI escapes and we are using API calls to set the console color
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
else:
# all non-Windows platforms are supporting ANSI escapes so we use them
logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
# log = logging.getLogger()
# log.addFilter(log_filter())
# //hdlr = logging.StreamHandler()
# //hdlr.setFormatter(formatter())

logger = logging.getLogger('text-generation-webui')
logger.setLevel(logging.DEBUG)
def setup_logging():
'''
Copied from: https://github.com/vladmandic/automatic
All credits to vladmandic.
'''

class RingBuffer(logging.StreamHandler):
def __init__(self, capacity):
super().__init__()
self.capacity = capacity
self.buffer = []
self.formatter = logging.Formatter('{ "asctime":"%(asctime)s", "created":%(created)f, "facility":"%(name)s", "pid":%(process)d, "tid":%(thread)d, "level":"%(levelname)s", "module":"%(module)s", "func":"%(funcName)s", "msg":"%(message)s" }')

def emit(self, record):
msg = self.format(record)
# self.buffer.append(json.loads(msg))
self.buffer.append(msg)
if len(self.buffer) > self.capacity:
self.buffer.pop(0)

def get(self):
return self.buffer

from rich.console import Console
from rich.logging import RichHandler
from rich.pretty import install as pretty_install
from rich.theme import Theme
from rich.traceback import install as traceback_install

level = logging.DEBUG
logger.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd`
console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
"traceback.border": "black",
"traceback.border.syntax_error": "black",
"inspect.value.border": "black",
}))
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null
pretty_install(console=console)
traceback_install(console=console, extra_lines=1, max_frames=10, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
while logger.hasHandlers() and len(logger.handlers) > 0:
logger.removeHandler(logger.handlers[0])

# handlers
rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=level, console=console)
rh.setLevel(level)
logger.addHandler(rh)

rb = RingBuffer(100) # 100 entries default in log ring buffer
rb.setLevel(level)
logger.addHandler(rb)
logger.buffer = rb.buffer

# overrides
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("diffusers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("lycoris").handlers = logger.handlers


setup_logging()
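
The rewritten module sends console output through a RichHandler and also keeps the last 100 formatted records in a RingBuffer handler exposed as logger.buffer. A minimal sketch of the ring-buffer idea using collections.deque (the version above subclasses StreamHandler and uses a JSON-style formatter instead):

import logging
from collections import deque

class RingBufferHandler(logging.Handler):
    # Keep only the most recent `capacity` formatted records in memory.
    def __init__(self, capacity: int = 100):
        super().__init__()
        self.buffer = deque(maxlen=capacity)  # the deque discards the oldest entry automatically

    def emit(self, record: logging.LogRecord) -> None:
        self.buffer.append(self.format(record))

log = logging.getLogger("demo")
log.setLevel(logging.DEBUG)
rb = RingBufferHandler(capacity=3)
rb.setFormatter(logging.Formatter("%(levelname)s | %(message)s"))
log.addHandler(rb)

for i in range(5):
    log.info("message %d", i)

print(list(rb.buffer))  # ['INFO | message 2', 'INFO | message 3', 'INFO | message 4']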
10 changes: 7 additions & 3 deletions modules/models.py
@@ -54,7 +54,7 @@


def load_model(model_name, loader=None):
logger.info(f"Loading {model_name}...")
logger.info(f"Loading {model_name}")
t0 = time.time()

shared.is_seq2seq = False
@@ -413,8 +413,12 @@ def ExLlamav2_HF_loader(model_name):


def HQQ_loader(model_name):
from hqq.engine.hf import HQQModelForCausalLM
from hqq.core.quantize import HQQLinear, HQQBackend
try:
from hqq.core.quantize import HQQBackend, HQQLinear
from hqq.engine.hf import HQQModelForCausalLM
except ModuleNotFoundError:
logger.error("HQQ is not installed. You can install it with:\n\npip install hqq")
return None

logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}")

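Here the missing-dependency guard returns None instead of re-raising, so the caller can fail gracefully rather than crash with a traceback. A short sketch of that convention in isolation (the module name and messages are illustrative):

import importlib
import logging

logger = logging.getLogger(__name__)

def load_optional_backend(module_name: str):
    # Return the backend module, or None with an install hint if it is missing.
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError:
        logger.error(f"{module_name} is not installed. You can install it with:\n\npip install {module_name}")
        return None

backend = load_optional_backend("hqq")
if backend is None:
    logger.error("Skipping this loader because the optional dependency is missing.")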
34 changes: 19 additions & 15 deletions modules/shared.py
@@ -204,22 +204,26 @@
if hasattr(args, arg):
provided_arguments.append(arg)

# Deprecation warnings
deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast']
for k in deprecated_args:
if getattr(args, k):
logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')

# Security warnings
if args.trust_remote_code:
logger.warning('trust_remote_code is enabled. This is dangerous.')
if 'COLAB_GPU' not in os.environ and not args.nowebui:
if args.share:
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
if args.multi_user:
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')


def do_cmd_flags_warnings():

# Deprecation warnings
for k in deprecated_args:
if getattr(args, k):
logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')

# Security warnings
if args.trust_remote_code:
logger.warning('trust_remote_code is enabled. This is dangerous.')
if 'COLAB_GPU' not in os.environ and not args.nowebui:
if args.share:
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
if args.multi_user:
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')


def fix_loader_name(name):
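The deprecation and security warnings move from module import time into a do_cmd_flags_warnings() function, so the caller decides when they are emitted (the call site is not part of this diff). A small sketch of deferring such import-time side effects behind a function, with a placeholder flag:

import argparse
import logging

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()
parser.add_argument('--notebook', action='store_true')  # placeholder deprecated flag
args = parser.parse_args(['--notebook'])

deprecated_args = ['notebook']

def do_cmd_flags_warnings():
    # Nothing is printed at import time; the warnings appear only when this is called.
    for k in deprecated_args:
        if getattr(args, k, False):
            logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')

do_cmd_flags_warnings()  # e.g. invoked once from the startup script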