Skip to content

Commit

Permalink
updated custom_model now managed by session, fixed various bugs
Browse files Browse the repository at this point in the history
ROBERT-MCDOWELL committed Dec 21, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent b8c8200 commit bc38df7
Showing 3 changed files with 169 additions and 91 deletions.
34 changes: 12 additions & 22 deletions app.py
Original file line number Diff line number Diff line change
@@ -95,11 +95,8 @@ def main():
options = [
'--script_mode', '--share', '--headless',
'--session', '--ebook', '--ebooks_dir',
'--voice', '--language', '--device',
#'--custom_model',
#'--custom_model_url',
'--temperature',
'--length_penalty', '--repetition_penalty',
'--voice', '--language', '--device', '--custom_model',
'--temperature', '--length_penalty', '--repetition_penalty',
'--top_k', '--top_p', '--speed',
'--enable_text_splitting', '--fine_tuned',
'--version', '--help'
@@ -122,32 +119,25 @@ def main():
help=f'Language for the audiobook conversion. Options: {lang_list_str}. Default to English (eng).')
parser.add_argument(options[8], type=str, default='cpu', choices=['cpu', 'gpu'],
help=f'Type of processor unit for the audiobook conversion. If not specified: check first if gpu available, if not cpu is selected.')
"""
parser.add_argument(options[9], type=str,
help='Path to the custom model file (.pth). Required if using a custom model.')
parser.add_argument(options[10], type=str,
help=("URL to download the custom model as a zip file. Optional, but will be used if provided. "
"Examples include David Attenborough's model: "
"'https://huggingface.co/drewThomasson/xtts_David_Attenborough_fine_tune/resolve/main/Finished_model_files.zip?download=true'. "
"More XTTS fine-tunes can be found on my Hugging Face at 'https://huggingface.co/drewThomasson'."))
"""
parser.add_argument(options[9], type=float, default=0.65,
help=f'Path to the custom model (.zip file containing {default_model_files}). Required if using a custom model.')
parser.add_argument(options[10], type=float, default=0.65,
help='Temperature for the model. Default to 0.65. Higher temperatures lead to more creative outputs.')
parser.add_argument(options[10], type=float, default=1.0,
parser.add_argument(options[11], type=float, default=1.0,
help='A length penalty applied to the autoregressive decoder. Default to 1.0. Not applied to custom models.')
parser.add_argument(options[11], type=float, default=2.5,
parser.add_argument(options[12], type=float, default=2.5,
help='A penalty that prevents the autoregressive decoder from repeating itself. Default to 2.5')
parser.add_argument(options[12], type=int, default=50,
parser.add_argument(options[13], type=int, default=50,
help='Top-k sampling. Lower values mean more likely outputs and increased audio generation speed. Default to 50')
parser.add_argument(options[13], type=float, default=0.8,
parser.add_argument(options[14], type=float, default=0.8,
help='Top-p sampling. Lower values mean more likely outputs and increased audio generation speed. Default to 0.8')
parser.add_argument(options[14], type=float, default=1.0,
parser.add_argument(options[15], type=float, default=1.0,
help='Speed factor for the speech generation. Default to 1.0')
parser.add_argument(options[15], type=str, default=default_fine_tuned,
parser.add_argument(options[16], type=str, default=default_fine_tuned,
help='Name of the fine tuned model. Optional, uses the standard model according to the TTS engine and language.')
parser.add_argument(options[16], action='store_true',
parser.add_argument(options[17], action='store_true',
help='Enable splitting text into sentences. Default to False.')
parser.add_argument(options[17], action='version',version=f'ebook2audiobook version {version}',
parser.add_argument(options[18], action='version',version=f'ebook2audiobook version {version}',
help='Show the version of the script and exit')

for arg in sys.argv:
8 changes: 6 additions & 2 deletions lib/conf.py
Original file line number Diff line number Diff line change
@@ -12,20 +12,23 @@
requirements_file = os.path.abspath(os.path.join('.','requirements.txt'))

docker_utils_image = 'utils'

interface_host = '0.0.0.0'
interface_port = 7860
interface_shared_expire = 72 # hours
interface_concurrency_limit = 8 # or None for unlimited
interface_component_options = {
"gr_tab_preferences": True,
"gr_voice_file": True,
"gr_custom_model_file": True,
"gr_custom_model_url": True
"gr_group_custom_model": True
}

python_env_dir = os.path.abspath(os.path.join('.','python_env'))

models_dir = os.path.abspath(os.path.join('.','models'))
ebooks_dir = os.path.abspath(os.path.join('.','ebooks'))
processes_dir = os.path.abspath(os.path.join('.','tmp'))

audiobooks_gradio_dir = os.path.abspath(os.path.join('.','audiobooks','gui','gradio'))
audiobooks_host_dir = os.path.abspath(os.path.join('.','audiobooks','gui','host'))
audiobooks_cli_dir = os.path.abspath(os.path.join('.','audiobooks','cli'))
@@ -52,6 +55,7 @@

default_tts_engine = 'xtts'
default_fine_tuned = 'std'
default_model_files = ['config.json', 'vocab.json', 'model.pth', 'ref.wav']

models = {
"xtts": {
218 changes: 151 additions & 67 deletions lib/functions.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
import ebooklib
import gradio as gr
import hashlib
import json
import numpy as np
import os
import regex as re
@@ -80,6 +81,7 @@ def get_session(self, session_id):
self.sessions[session_id] = recursive_proxy({
"script_mode": NATIVE,
"client": None,
"language": default_language_code,
"audiobooks_dir": None,
"tmp_dir": None,
"src": None,
@@ -185,27 +187,43 @@ def check_fine_tuned(fine_tuned, language):
except Exception as e:
raise RuntimeError(e)

def download_custom_model(url, dest, session):
async def download_custom_model(url, dest, session):
try:
progress_bar = None
if is_gui_process == True:
progress_bar = gr.Progress(track_tqdm=True)
parsed_url = urlparse(url)
fname = os.path.basename(parsed_url.path)
if not os.path.exists(dest):
os.makedirs(dest, exist_ok=True)
response = requests.get(url, stream=True)
response.raise_for_status() # Raise an error for bad responses
file_src = os.path.join(dest,fname)
with open(file_src, 'wb') as file:
response.raise_for_status()
total_size = int(response.headers.get('Content-Length', 0))
file_src = os.path.join(dest, fname)
downloaded = 0
if os.path.exists(file_src):
os.remove(file_src)
with open(file_src, 'wb') as file, tqdm(total=total_size, unit='B', unit_scale=True, desc='Downloading File', initial=downloaded) as t:
for chunk in response.iter_content(chunk_size=8192):
file.write(chunk)
if chunk: # Filter out keep-alive chunks
file.write(chunk)
chunk_length = len(chunk)
downloaded += chunk_length
t.update(chunk_length)
t.refresh()
if progress_bar is not None:
progress_bar(downloaded / total_size)
yield file_src, progress_bar
print(f'File saved at: {file_src}')
return file_src
yield file_src, progress_bar
return
except Exception as e:
raise RuntimeError(f'download_custom_model(): {e}')
raise RuntimeError(f'download_custom_model() failed: {e}')

def analyze_uploaded_file(zip_path, required_files=None):
if required_files is None:
required_files = ['config.json', 'vocab.json', 'model.pth', 'ref.wav']
executable_extensions = {'.exe', '.bat', '.cmd', '.sh', '.msi', '.dll', '.com'}
required_files = default_model_files
executable_extensions = {'.exe', '.bat', '.cmd', '.bash', '.bin', '.sh', '.msi', '.dll', '.com'}
try:
with zipfile.ZipFile(zip_path, 'r') as zf:
files_in_zip = set()
@@ -222,30 +240,59 @@ def analyze_uploaded_file(zip_path, required_files=None):
break
missing_files = [f for f in required_files if f not in files_in_zip]
is_valid = not executables_found and not missing_files
return is_valid
return is_valid,
except zipfile.BadZipFile:
raise ValueError("error: The file is not a valid ZIP archive.")
except Exception as e:
raise RuntimeError(f'analyze_uploaded_file(): {e}')

def extract_custom_model(file_src, dest, session):
async def extract_custom_model(file_src, dest=None, session=None, required_files=None):
try:
progress_bar = None
if is_gui_process:
progress_bar = gr.Progress(track_tqdm=True)
if dest is None:
dest = session['custom_model_dir'] = os.path.join(models_dir, '__sessions', f"model-{session['id']}")
os.makedirs(dest, exist_ok=True)
if required_files is None:
required_files = default_model_files

dir_src = os.path.dirname(file_src)
dir_dest = os.path.join(dest, file_src.replace('.zip',''))
os.makedirs(dir_dest, exist_ok=True)
dir_name = os.path.basename(file_src).replace('.zip', '')

with zipfile.ZipFile(file_src, 'r') as zip_ref:
files = zip_ref.namelist()
with tqdm(total=len(files), unit='file', desc='Extracting Files') as t:
for file in files:
if session['cancellation_requested']:
msg = 'Cancel requested'
raise ValueError()
if os.path.isfile(file):
extract = zip_ref.extract(file, dir_dest)
t.update(1)
files_length = len(files)
dir_tts = 'fairseq'
xtts_config = 'config.json'

# Check the model type
config_data = {}
if xtts_config in zip_ref.namelist():
with zip_ref.open(xtts_config) as file:
config_data = json.load(file)
if config_data.get('model') == 'xtts':
dir_tts = 'xtts'

dir_dest = os.path.join(dest, dir_tts, dir_name)
os.makedirs(dir_dest, exist_ok=True)

# Initialize progress bar
with tqdm(total=100, unit='%') as t: # Track progress as a percentage
for i, file in enumerate(files):
if file in required_files:
zip_ref.extract(file, dir_dest)
progress_percentage = ((i + 1) / files_length) * 100
t.n = int(progress_percentage)
t.refresh()
if progress_bar is not None:
progress_bar(downloaded / total_size)
yield dir_name, progress_bar

os.remove(file_src)
print(f'Extracted files to {dir_dest}')
return dir_dest
yield dir_name, progress_bar
return
except Exception as e:
raise DependencyError(e)

@@ -513,7 +560,7 @@ def convert_chapters_to_audio(session):
params['sentence'] = sentence
print(f'Sentence: {sentence}...')
if convert_sentence_to_audio(params, session):
t.update(1) # Increment progress bar by 1
t.update(1)
percentage = (current_sentence / total_sentences) * 100
t.set_description(f'Processing {percentage:.2f}%')
t.refresh()
@@ -568,7 +615,6 @@ def convert_sentence_to_audio(params, session):
params['tts'].tts_with_vc_to_file(
text=params['sentence'],
file_path=params['sentence_audio_file'],
#language=session['language'], # can be used only if multilingual model
speaker_wav=params['voice_file'].replace('_24khz','_22khz'),
split_sentences=session['enable_text_splitting']
)
@@ -897,7 +943,6 @@ def convert_ebook(args):
speed = args['speed']
enable_text_splitting = args['enable_text_splitting'] if args['enable_text_splitting'] is not None else True
custom_model_file = args['custom_model']
custom_model_url = args['custom_model_url'] if custom_model_file is None else None
fine_tuned = args['fine_tuned'] if check_fine_tuned(args['fine_tuned'], args['language']) else False

if not fine_tuned:
@@ -923,15 +968,8 @@ def convert_ebook(args):
if not is_gui_process:
print(f'*********** Session: {session_id}', '************* Store it in case of interruption or crash you can resume the conversion')
session['custom_model_dir'] = os.path.join(models_dir,'__sessions',f"model-{session['id']}")
if custom_model_file or custom_model_url:
if custom_model_url:
print(f'Get custom model: {custom_model_url}')
file_src = download_custom_model(custom_model_url, session['custom_model_dir'], session)
if session['custom_model']:
if analyze_uploaded_file(file_src):
session['custom_model'] = extract_custom_model(file_src, session['custom_model_dir'], session)
else:
session['custom_model'] = extract_custom_model(custom_model_file, session['custom_model_dir'], session)
if custom_model_file:
session['custom_model'], progression_status = extract_custom_model(custom_model_file, session['custom_model_dir'])
if not session['custom_model']:
raise ValueError(f'{custom_model_file} could not be extracted or mandatory files are missing')

@@ -1022,6 +1060,7 @@ def web_interface(args):
)
for lang, details in language_mapping.items()
]
custom_model_options = None
fine_tuned_options = list(models['xtts'].keys())
default_language_name = next((name for name, key in language_options if key == default_language_code), None)

@@ -1075,10 +1114,10 @@ def web_interface(args):
padding: 0 !important;
margin: 0 !important;
}
#component-7, #component-19, #component-22 {
#component-7, #component-10, #component-20 {
height: 140px !important;
}
#component-46 {
#component-47 {
height: 100px !important;
}
</style>
@@ -1098,19 +1137,21 @@ def web_interface(args):
with gr.Column(scale=3):
with gr.Group():
gr_ebook_file = gr.File(label='EBook File (.epub, .mobi, .azw3, fb2, lrf, rb, snb, tcr, .pdf, .txt, .rtf, doc, .docx, .html, .odt, .azw)', file_types=['.epub', '.mobi', '.azw3', 'fb2', 'lrf', 'rb', 'snb', 'tcr', '.pdf', '.txt', '.rtf', 'doc', '.docx', '.html', '.odt', '.azw'])
gr_device = gr.Radio(label='Processor Unit', choices=['CPU', 'GPU'], value='CPU')
with gr.Group():
gr_session_status = gr.Textbox(label='Session')
gr_voice_file = gr.File(label='*Cloning Voice (a .wav 24000hz for XTTS base model and 22050hz for FAIRSEQ base model, no more than 6 sec)', file_types=['.wav'], visible=interface_component_options['gr_voice_file'])
gr.Markdown('<p>&nbsp;&nbsp;* Optional</p>')
with gr.Group():
gr_device = gr.Radio(label='Processor Unit', choices=['CPU', 'GPU'], value='CPU')
with gr.Group():
gr_language = gr.Dropdown(label='Language', choices=[name for name, _ in language_options], value=default_language_name)
with gr.Column(scale=3):
with gr.Group():
gr_voice_file = gr.File(label='*Cloning Voice (a .wav 24000hz for XTTS base model and 22050hz for FAIRSEQ base model, no more than 6 sec)', file_types=['.wav'], visible=interface_component_options['gr_voice_file'])
gr_group_custom_model = gr.Group(visible=interface_component_options['gr_group_custom_model'])
with gr_group_custom_model:
gr_custom_model_file = gr.File(label='*Custom XTTS Model (a .zip containing config.json, vocab.json, model.pth, ref.wav)', file_types=['.zip'])
gr_custom_model_list = gr.Dropdown(label='', choices=['none'], interactive=True)
gr.Markdown('<p>&nbsp;&nbsp;* Optional</p>')
with gr.Group():
gr_custom_model_file = gr.File(label='*XTTS Model (a .zip containing config.json, vocab.json, model.pth, ref.wav)', file_types=['.zip'], visible=interface_component_options['gr_custom_model_file'])
gr_custom_model_url = gr.Textbox(placeholder='https://www.example.com/model.zip', label='Model from URL*', visible=interface_component_options['gr_custom_model_url'])
gr.Markdown('<p>&nbsp;&nbsp;* Optional</p>')
gr_session_status = gr.Textbox(label='Session')
with gr.Group():
gr_tts_engine = gr.Markdown(f'&nbsp;&nbsp;&nbsp;&nbsp;TTS Base:&nbsp;&nbsp;{default_tts_engine.upper()}')
gr_fine_tuned = gr.Dropdown(label='Fine Tuned Models', choices=fine_tuned_options, value=default_fine_tuned, interactive=True)
@@ -1290,7 +1331,8 @@ async def change_gr_ebook_file(f, session_id):
yield hide_modal()
return

def change_gr_language(selected: str):
def change_gr_language(selected: str, session_id: str):
nonlocal custom_model_options
if selected == 'zzzz':
new_language_name = default_language_name
new_language_key = default_language_code
@@ -1302,33 +1344,66 @@ def change_gr_language(selected: str):
for model_name, model_details in models.get(tts_engine_value, {}).items()
if model_details.get('lang') == 'multi' or model_details.get('lang') == new_language_key
]
custom_model_options = ['none']
if context and session_id:
session = context.get_session(session_id)
session['language'] = new_language_key
custom_model_tts = check_custom_model_tts(session)
custom_model_tts_dir = os.path.join(session['custom_model_dir'], custom_model_tts)
if os.path.exists(custom_model_tts_dir):
custom_model_options += os.listdir(custom_model_tts_dir)
return (
gr.update(value=new_language_name),
gr.update(value=f'&nbsp;&nbsp;&nbsp;&nbsp;tts base: {tts_engine_value.upper()}'),
gr.update(choices=fine_tuned_options, value=fine_tuned_options[0] if fine_tuned_options else None)
gr.update(choices=fine_tuned_options, value=fine_tuned_options[0] if fine_tuned_options else 'none'),
gr.update(choices=custom_model_options, value=custom_model_options[0])
)

def check_custom_model_tts(session):
custom_model_tts = 'xtts'
if not language_xtts.get(session['language']):
custom_model_tts = 'fairseq'
custom_model_tts_dir = os.path.join(session['custom_model_dir'], custom_model_tts)
if not os.path.isdir(custom_model_tts_dir):
os.makedirs(custom_model_tts_dir, exist_ok=True)
return custom_model_tts

def change_gr_custom_model_list(custom_model_list):
if custom_model_list == 'none':
return gr.update(visible=True)
return gr.update(visible=False)

async def change_gr_custom_model_file(custom_model_file, session_id):
try:
nonlocal custom_model_options, gr_custom_model_file, gr_conversion_progress
if context and session_id:
session = context.get_session(session_id)
if custom_model_file is not None:
if analyze_uploaded_file(custom_model_file):
model_dir = extract_custom_model(f, session)
if model_dir:
yield gr.update(value=None), gr.update(value=None), gr.update(value='')
return
yield gr.update(value=None), gr.update(value=None), gr.update(value='Invalid file! Please upload a valid ZIP.')
if analyze_uploaded_file(custom_model_file):
session['custom_model'], progress_status = extract_custom_model(custom_model_file, None, session)
if session['custom_model']:
custom_model_tts_dir = check_custom_model_tts(session)
custom_model_options = ['none'] + os.listdir(os.path.join(session['custom_model_dir'], custom_model_tts_dir))
yield (
gr.update(visible=False),
gr.update(choices=custom_model_options, value=session['custom_model']),
gr.update(value=f"{session['custom_model']} added to the custom list")
)
gr_custom_model_file = gr.File(label='*XTTS Model (a .zip containing config.json, vocab.json, model.pth, ref.wav)', value=None, file_types=['.zip'])
return
yield gr.update(), gr.update(), gr.update(value='Invalid file! Please upload a valid ZIP.')
return
except Exception as e:
yield gr.update(value=None), gr.update(value=None), gr.update(value=e)
return
yield gr.update(), gr.update(), gr.update(value=f'Error: {str(e)}')
return

def change_gr_data(data):
data['event'] = 'change_data'
return data

def change_gr_read_data(data):
nonlocal audiobooks_dir
nonlocal custom_model_options
warning_text_extra = ''
if not data:
data = {'session_id': str(uuid.uuid4())}
@@ -1339,18 +1414,23 @@ def change_gr_read_data(data):
warning_text = data['session_id']
event = data.get('event', '')
if event != 'load':
return [gr.update(), gr.update(), gr.update()]
return [gr.update(), gr.update(), gr.update(), gr.update(), gr.update()]
session = context.get_session(data['session_id'])
session['custom_model_dir'] = os.path.join(models_dir,'__sessions',f"model-{session['id']}")
os.makedirs(session['custom_model_dir'], exist_ok=True)
custom_model_tts_dir = check_custom_model_tts(session)
custom_model_options = ['none'] + os.listdir(os.path.join(session['custom_model_dir'],custom_model_tts_dir))
if is_gui_shared:
warning_text_extra = f' Note: access limit time: {interface_shared_expire} hours'
audiobooks_dir = os.path.join(audiobooks_gradio_dir, f"web-{data['session_id']}")
delete_old_web_folders(audiobooks_gradio_dir)
else:
audiobooks_dir = os.path.join(audiobooks_host_dir, f"web-{data['session_id']}")
return [data, f'{warning_text}{warning_text_extra}', data['session_id'], update_audiobooks_ddn()]
return [data, f'{warning_text}{warning_text_extra}', data['session_id'], update_audiobooks_ddn(), gr.update(choices=custom_model_options, value='none')]

def process_conversion(
def submit_convert_btn(
session, device, ebook_file, voice_file, language,
custom_model_file, custom_model_url, temperature, length_penalty,
custom_model_file, temperature, length_penalty,
repetition_penalty, top_k, top_p, speed, enable_text_splitting, fine_tuned
):
nonlocal is_converting
@@ -1364,8 +1444,7 @@ def process_conversion(
"audiobooks_dir": audiobooks_dir,
"voice": voice_file.name if voice_file else None,
"language": next((key for name, key in language_options if name == language), None),
"custom_model": custom_model_file.name if custom_model_file else None,
"custom_model_url": custom_model_url if custom_model_file is None else None,
"custom_model": next((key for name, key in language_options if name != 'none'), None),
"temperature": float(temperature),
"length_penalty": float(length_penalty),
"repetition_penalty": float(repetition_penalty),
@@ -1404,19 +1483,24 @@ def process_conversion(
outputs=[gr_modal_html]
)
gr_language.change(
lambda selected: change_gr_language(dict(language_options).get(selected, 'Unknown')),
inputs=gr_language,
outputs=[gr_language, gr_tts_engine, gr_fine_tuned]
fn=lambda selected, session_id: change_gr_language(dict(language_options).get(selected, 'Unknown'), session_id),
inputs=[gr_language, gr_session],
outputs=[gr_language, gr_tts_engine, gr_fine_tuned, gr_custom_model_list]
)
gr_audiobooks_ddn.change(
fn=change_gr_audiobooks_ddn,
inputs=gr_audiobooks_ddn,
outputs=[gr_audiobook_link, gr_audio_player, gr_audio_player]
)
gr_custom_model_list.change(
fn=change_gr_custom_model_list,
inputs=[gr_custom_model_list],
outputs=[gr_fine_tuned]
)
gr_custom_model_file.change(
fn=change_gr_custom_model_file,
inputs=[gr_custom_model_file, gr_session],
outputs=[gr_custom_model_file, gr_custom_model_url, gr_conversion_progress]
outputs=[gr_fine_tuned, gr_custom_model_list, gr_conversion_progress]
)
gr_session.change(
fn=change_gr_data,
@@ -1437,13 +1521,13 @@ def process_conversion(
gr_read_data.change(
fn=change_gr_read_data,
inputs=gr_read_data,
outputs=[gr_data, gr_session_status, gr_session, gr_audiobooks_ddn]
outputs=[gr_data, gr_session_status, gr_session, gr_audiobooks_ddn, gr_custom_model_list]
)
gr_convert_btn.click(
fn=process_conversion,
fn=submit_convert_btn,
inputs=[
gr_session, gr_device, gr_ebook_file, gr_voice_file, gr_language,
gr_custom_model_file, gr_custom_model_url, gr_temperature, gr_length_penalty,
gr_custom_model_list, gr_temperature, gr_length_penalty,
gr_repetition_penalty, gr_top_k, gr_top_p, gr_speed, gr_enable_text_splitting, gr_fine_tuned
],
outputs=[gr_conversion_progress, gr_modal_html]
@@ -1470,7 +1554,7 @@ def process_conversion(
)

try:
interface.queue(default_concurrency_limit=interface_concurrency_limit).launch(server_name="0.0.0.0", server_port=interface_port, share=is_gui_shared)
interface.queue(default_concurrency_limit=interface_concurrency_limit).launch(server_name=interface_host, server_port=interface_port, share=is_gui_shared)
except OSError as e:
print(f'Connection error: {e}')
except socket.error as e:

0 comments on commit bc38df7

Please sign in to comment.