From bc38df7ac12ff35fcce06e8dbb3649eadcc74723 Mon Sep 17 00:00:00 2001 From: unknown <admin@radsl.net> Date: Sat, 21 Dec 2024 15:45:47 -0800 Subject: [PATCH] updated custom_model now managed by session, fixed various bugs --- app.py | 34 +++----- lib/conf.py | 8 +- lib/functions.py | 218 ++++++++++++++++++++++++++++++++--------------- 3 files changed, 169 insertions(+), 91 deletions(-) diff --git a/app.py b/app.py index 7f6b3c85..e79a9797 100644 --- a/app.py +++ b/app.py @@ -95,11 +95,8 @@ def main(): options = [ '--script_mode', '--share', '--headless', '--session', '--ebook', '--ebooks_dir', - '--voice', '--language', '--device', - #'--custom_model', - #'--custom_model_url', - '--temperature', - '--length_penalty', '--repetition_penalty', + '--voice', '--language', '--device', '--custom_model', + '--temperature', '--length_penalty', '--repetition_penalty', '--top_k', '--top_p', '--speed', '--enable_text_splitting', '--fine_tuned', '--version', '--help' @@ -122,32 +119,25 @@ def main(): help=f'Language for the audiobook conversion. Options: {lang_list_str}. Default to English (eng).') parser.add_argument(options[8], type=str, default='cpu', choices=['cpu', 'gpu'], help=f'Type of processor unit for the audiobook conversion. If not specified: check first if gpu available, if not cpu is selected.') - """ parser.add_argument(options[9], type=str, - help='Path to the custom model file (.pth). Required if using a custom model.') - parser.add_argument(options[10], type=str, - help=("URL to download the custom model as a zip file. Optional, but will be used if provided. " - "Examples include David Attenborough's model: " - "'https://huggingface.co/drewThomasson/xtts_David_Attenborough_fine_tune/resolve/main/Finished_model_files.zip?download=true'. 
" - "More XTTS fine-tunes can be found on my Hugging Face at 'https://huggingface.co/drewThomasson'.")) - """ - parser.add_argument(options[9], type=float, default=0.65, + help=f'Path to the custom model (.zip file containing {default_model_files}). Required if using a custom model.') + parser.add_argument(options[10], type=float, default=0.65, help='Temperature for the model. Default to 0.65. Higher temperatures lead to more creative outputs.') - parser.add_argument(options[10], type=float, default=1.0, + parser.add_argument(options[11], type=float, default=1.0, help='A length penalty applied to the autoregressive decoder. Default to 1.0. Not applied to custom models.') - parser.add_argument(options[11], type=float, default=2.5, + parser.add_argument(options[12], type=float, default=2.5, help='A penalty that prevents the autoregressive decoder from repeating itself. Default to 2.5') - parser.add_argument(options[12], type=int, default=50, + parser.add_argument(options[13], type=int, default=50, help='Top-k sampling. Lower values mean more likely outputs and increased audio generation speed. Default to 50') - parser.add_argument(options[13], type=float, default=0.8, + parser.add_argument(options[14], type=float, default=0.8, help='Top-p sampling. Lower values mean more likely outputs and increased audio generation speed. Default to 0.8') - parser.add_argument(options[14], type=float, default=1.0, + parser.add_argument(options[15], type=float, default=1.0, help='Speed factor for the speech generation. Default to 1.0') - parser.add_argument(options[15], type=str, default=default_fine_tuned, + parser.add_argument(options[16], type=str, default=default_fine_tuned, help='Name of the fine tuned model. Optional, uses the standard model according to the TTS engine and language.') - parser.add_argument(options[16], action='store_true', + parser.add_argument(options[17], action='store_true', help='Enable splitting text into sentences. 
Default to False.') - parser.add_argument(options[17], action='version',version=f'ebook2audiobook version {version}', + parser.add_argument(options[18], action='version',version=f'ebook2audiobook version {version}', help='Show the version of the script and exit') for arg in sys.argv: diff --git a/lib/conf.py b/lib/conf.py index cddfd60f..95d9fa6a 100644 --- a/lib/conf.py +++ b/lib/conf.py @@ -12,20 +12,23 @@ requirements_file = os.path.abspath(os.path.join('.','requirements.txt')) docker_utils_image = 'utils' + +interface_host = '0.0.0.0' interface_port = 7860 interface_shared_expire = 72 # hours interface_concurrency_limit = 8 # or None for unlimited interface_component_options = { "gr_tab_preferences": True, "gr_voice_file": True, - "gr_custom_model_file": True, - "gr_custom_model_url": True + "gr_group_custom_model": True } python_env_dir = os.path.abspath(os.path.join('.','python_env')) + models_dir = os.path.abspath(os.path.join('.','models')) ebooks_dir = os.path.abspath(os.path.join('.','ebooks')) processes_dir = os.path.abspath(os.path.join('.','tmp')) + audiobooks_gradio_dir = os.path.abspath(os.path.join('.','audiobooks','gui','gradio')) audiobooks_host_dir = os.path.abspath(os.path.join('.','audiobooks','gui','host')) audiobooks_cli_dir = os.path.abspath(os.path.join('.','audiobooks','cli')) @@ -52,6 +55,7 @@ default_tts_engine = 'xtts' default_fine_tuned = 'std' +default_model_files = ['config.json', 'vocab.json', 'model.pth', 'ref.wav'] models = { "xtts": { diff --git a/lib/functions.py b/lib/functions.py index 4e49674b..3cba9a3e 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -4,6 +4,7 @@ import ebooklib import gradio as gr import hashlib +import json import numpy as np import os import regex as re @@ -80,6 +81,7 @@ def get_session(self, session_id): self.sessions[session_id] = recursive_proxy({ "script_mode": NATIVE, "client": None, + "language": default_language_code, "audiobooks_dir": None, "tmp_dir": None, "src": None, @@ -185,27 +187,43 
@@ def check_fine_tuned(fine_tuned, language): except Exception as e: raise RuntimeError(e) -def download_custom_model(url, dest, session): +async def download_custom_model(url, dest, session): try: + progress_bar = None + if is_gui_process == True: + progress_bar = gr.Progress(track_tqdm=True) parsed_url = urlparse(url) fname = os.path.basename(parsed_url.path) if not os.path.exists(dest): os.makedirs(dest, exist_ok=True) response = requests.get(url, stream=True) - response.raise_for_status() # Raise an error for bad responses - file_src = os.path.join(dest,fname) - with open(file_src, 'wb') as file: + response.raise_for_status() + total_size = int(response.headers.get('Content-Length', 0)) + file_src = os.path.join(dest, fname) + downloaded = 0 + if os.path.exists(file_src): + os.remove(file_src) + with open(file_src, 'wb') as file, tqdm(total=total_size, unit='B', unit_scale=True, desc='Downloading File', initial=downloaded) as t: for chunk in response.iter_content(chunk_size=8192): - file.write(chunk) + if chunk: # Filter out keep-alive chunks + file.write(chunk) + chunk_length = len(chunk) + downloaded += chunk_length + t.update(chunk_length) + t.refresh() + if progress_bar is not None: + progress_bar(downloaded / total_size) + yield file_src, progress_bar print(f'File saved at: {file_src}') - return file_src + yield file_src, progress_bar + return except Exception as e: - raise RuntimeError(f'download_custom_model(): {e}') + raise RuntimeError(f'download_custom_model() failed: {e}') def analyze_uploaded_file(zip_path, required_files=None): if required_files is None: - required_files = ['config.json', 'vocab.json', 'model.pth', 'ref.wav'] - executable_extensions = {'.exe', '.bat', '.cmd', '.sh', '.msi', '.dll', '.com'} + required_files = default_model_files + executable_extensions = {'.exe', '.bat', '.cmd', '.bash', '.bin', '.sh', '.msi', '.dll', '.com'} try: with zipfile.ZipFile(zip_path, 'r') as zf: files_in_zip = set() @@ -222,30 +240,59 @@ def 
analyze_uploaded_file(zip_path, required_files=None): break missing_files = [f for f in required_files if f not in files_in_zip] is_valid = not executables_found and not missing_files - return is_valid + return is_valid except zipfile.BadZipFile: raise ValueError("error: The file is not a valid ZIP archive.") except Exception as e: raise RuntimeError(f'analyze_uploaded_file(): {e}') -def extract_custom_model(file_src, dest, session): +async def extract_custom_model(file_src, dest=None, session=None, required_files=None): try: + progress_bar = None + if is_gui_process: + progress_bar = gr.Progress(track_tqdm=True) + if dest is None: + dest = session['custom_model_dir'] = os.path.join(models_dir, '__sessions', f"model-{session['id']}") + os.makedirs(dest, exist_ok=True) + if required_files is None: + required_files = default_model_files + dir_src = os.path.dirname(file_src) - dir_dest = os.path.join(dest, file_src.replace('.zip','')) - os.makedirs(dir_dest, exist_ok=True) + dir_name = os.path.basename(file_src).replace('.zip', '') + with zipfile.ZipFile(file_src, 'r') as zip_ref: files = zip_ref.namelist() - with tqdm(total=len(files), unit='file', desc='Extracting Files') as t: - for file in files: - if session['cancellation_requested']: - msg = 'Cancel requested' - raise ValueError() - if os.path.isfile(file): - extract = zip_ref.extract(file, dir_dest) - t.update(1) + files_length = len(files) + dir_tts = 'fairseq' + xtts_config = 'config.json' + + # Check the model type + config_data = {} + if xtts_config in zip_ref.namelist(): + with zip_ref.open(xtts_config) as file: + config_data = json.load(file) + if config_data.get('model') == 'xtts': + dir_tts = 'xtts' + + dir_dest = os.path.join(dest, dir_tts, dir_name) + os.makedirs(dir_dest, exist_ok=True) + + # Initialize progress bar + with tqdm(total=100, unit='%') as t: # Track progress as a percentage + for i, file in enumerate(files): + if file in required_files: + zip_ref.extract(file, dir_dest) + 
progress_percentage = ((i + 1) / files_length) * 100 + t.n = int(progress_percentage) + t.refresh() + if progress_bar is not None: + progress_bar((i + 1) / files_length) + yield dir_name, progress_bar + os.remove(file_src) print(f'Extracted files to {dir_dest}') - return dir_dest + yield dir_name, progress_bar + return except Exception as e: raise DependencyError(e) @@ -513,7 +560,7 @@ def convert_chapters_to_audio(session): params['sentence'] = sentence print(f'Sentence: {sentence}...') if convert_sentence_to_audio(params, session): - t.update(1) # Increment progress bar by 1 + t.update(1) percentage = (current_sentence / total_sentences) * 100 t.set_description(f'Processing {percentage:.2f}%') t.refresh() @@ -568,7 +615,6 @@ def convert_sentence_to_audio(params, session): params['tts'].tts_with_vc_to_file( text=params['sentence'], file_path=params['sentence_audio_file'], - #language=session['language'], # can be used only if multilingual model speaker_wav=params['voice_file'].replace('_24khz','_22khz'), split_sentences=session['enable_text_splitting'] ) @@ -897,7 +943,6 @@ def convert_ebook(args): speed = args['speed'] enable_text_splitting = args['enable_text_splitting'] if args['enable_text_splitting'] is not None else True custom_model_file = args['custom_model'] - custom_model_url = args['custom_model_url'] if custom_model_file is None else None fine_tuned = args['fine_tuned'] if check_fine_tuned(args['fine_tuned'], args['language']) else False if not fine_tuned: @@ -923,15 +968,8 @@ def convert_ebook(args): if not is_gui_process: print(f'*********** Session: {session_id}', '************* Store it in case of interruption or crash you can resume the conversion') session['custom_model_dir'] = os.path.join(models_dir,'__sessions',f"model-{session['id']}") - if custom_model_file or custom_model_url: - if custom_model_url: - print(f'Get custom model: {custom_model_url}') - file_src = download_custom_model(custom_model_url, session['custom_model_dir'], session) - 
if session['custom_model']: - if analyze_uploaded_file(file_src): - session['custom_model'] = extract_custom_model(file_src, session['custom_model_dir'], session) - else: - session['custom_model'] = extract_custom_model(custom_model_file, session['custom_model_dir'], session) + if custom_model_file: + session['custom_model'], progression_status = extract_custom_model(custom_model_file, session['custom_model_dir']) if not session['custom_model']: raise ValueError(f'{custom_model_file} could not be extracted or mandatory files are missing') @@ -1022,6 +1060,7 @@ def web_interface(args): ) for lang, details in language_mapping.items() ] + custom_model_options = None fine_tuned_options = list(models['xtts'].keys()) default_language_name = next((name for name, key in language_options if key == default_language_code), None) @@ -1075,10 +1114,10 @@ def web_interface(args): padding: 0 !important; margin: 0 !important; } - #component-7, #component-19, #component-22 { + #component-7, #component-10, #component-20 { height: 140px !important; } - #component-46 { + #component-47 { height: 100px !important; } </style> @@ -1098,19 +1137,21 @@ def web_interface(args): with gr.Column(scale=3): with gr.Group(): gr_ebook_file = gr.File(label='EBook File (.epub, .mobi, .azw3, fb2, lrf, rb, snb, tcr, .pdf, .txt, .rtf, doc, .docx, .html, .odt, .azw)', file_types=['.epub', '.mobi', '.azw3', 'fb2', 'lrf', 'rb', 'snb', 'tcr', '.pdf', '.txt', '.rtf', 'doc', '.docx', '.html', '.odt', '.azw']) - gr_device = gr.Radio(label='Processor Unit', choices=['CPU', 'GPU'], value='CPU') with gr.Group(): - gr_session_status = gr.Textbox(label='Session') + gr_voice_file = gr.File(label='*Cloning Voice (a .wav 24000hz for XTTS base model and 22050hz for FAIRSEQ base model, no more than 6 sec)', file_types=['.wav'], visible=interface_component_options['gr_voice_file']) + gr.Markdown('<p> * Optional</p>') + with gr.Group(): + gr_device = gr.Radio(label='Processor Unit', choices=['CPU', 'GPU'], value='CPU') 
with gr.Group(): gr_language = gr.Dropdown(label='Language', choices=[name for name, _ in language_options], value=default_language_name) with gr.Column(scale=3): - with gr.Group(): - gr_voice_file = gr.File(label='*Cloning Voice (a .wav 24000hz for XTTS base model and 22050hz for FAIRSEQ base model, no more than 6 sec)', file_types=['.wav'], visible=interface_component_options['gr_voice_file']) + gr_group_custom_model = gr.Group(visible=interface_component_options['gr_group_custom_model']) + with gr_group_custom_model: + gr_custom_model_file = gr.File(label='*Custom XTTS Model (a .zip containing config.json, vocab.json, model.pth, ref.wav)', file_types=['.zip']) + gr_custom_model_list = gr.Dropdown(label='', choices=['none'], interactive=True) gr.Markdown('<p> * Optional</p>') with gr.Group(): - gr_custom_model_file = gr.File(label='*XTTS Model (a .zip containing config.json, vocab.json, model.pth, ref.wav)', file_types=['.zip'], visible=interface_component_options['gr_custom_model_file']) - gr_custom_model_url = gr.Textbox(placeholder='https://www.example.com/model.zip', label='Model from URL*', visible=interface_component_options['gr_custom_model_url']) - gr.Markdown('<p> * Optional</p>') + gr_session_status = gr.Textbox(label='Session') with gr.Group(): gr_tts_engine = gr.Markdown(f' TTS Base: {default_tts_engine.upper()}') gr_fine_tuned = gr.Dropdown(label='Fine Tuned Models', choices=fine_tuned_options, value=default_fine_tuned, interactive=True) @@ -1290,7 +1331,8 @@ async def change_gr_ebook_file(f, session_id): yield hide_modal() return - def change_gr_language(selected: str): + def change_gr_language(selected: str, session_id: str): + nonlocal custom_model_options if selected == 'zzzz': new_language_name = default_language_name new_language_key = default_language_code @@ -1302,26 +1344,58 @@ def change_gr_language(selected: str): for model_name, model_details in models.get(tts_engine_value, {}).items() if model_details.get('lang') == 'multi' or 
model_details.get('lang') == new_language_key ] + custom_model_options = ['none'] + if context and session_id: + session = context.get_session(session_id) + session['language'] = new_language_key + custom_model_tts = check_custom_model_tts(session) + custom_model_tts_dir = os.path.join(session['custom_model_dir'], custom_model_tts) + if os.path.exists(custom_model_tts_dir): + custom_model_options += os.listdir(custom_model_tts_dir) return ( gr.update(value=new_language_name), gr.update(value=f' tts base: {tts_engine_value.upper()}'), - gr.update(choices=fine_tuned_options, value=fine_tuned_options[0] if fine_tuned_options else None) + gr.update(choices=fine_tuned_options, value=fine_tuned_options[0] if fine_tuned_options else 'none'), + gr.update(choices=custom_model_options, value=custom_model_options[0]) ) + def check_custom_model_tts(session): + custom_model_tts = 'xtts' + if not language_xtts.get(session['language']): + custom_model_tts = 'fairseq' + custom_model_tts_dir = os.path.join(session['custom_model_dir'], custom_model_tts) + if not os.path.isdir(custom_model_tts_dir): + os.makedirs(custom_model_tts_dir, exist_ok=True) + return custom_model_tts + + def change_gr_custom_model_list(custom_model_list): + if custom_model_list == 'none': + return gr.update(visible=True) + return gr.update(visible=False) + async def change_gr_custom_model_file(custom_model_file, session_id): try: + nonlocal custom_model_options, gr_custom_model_file, gr_conversion_progress if context and session_id: session = context.get_session(session_id) if custom_model_file is not None: - if analyze_uploaded_file(custom_model_file): - model_dir = extract_custom_model(f, session) - if model_dir: - yield gr.update(value=None), gr.update(value=None), gr.update(value='') - return - yield gr.update(value=None), gr.update(value=None), gr.update(value='Invalid file! 
Please upload a valid ZIP.') + if analyze_uploaded_file(custom_model_file): + session['custom_model'], progress_status = extract_custom_model(custom_model_file, None, session) + if session['custom_model']: + custom_model_tts_dir = check_custom_model_tts(session) + custom_model_options = ['none'] + os.listdir(os.path.join(session['custom_model_dir'], custom_model_tts_dir)) + yield ( + gr.update(visible=False), + gr.update(choices=custom_model_options, value=session['custom_model']), + gr.update(value=f"{session['custom_model']} added to the custom list") + ) + gr_custom_model_file = gr.File(label='*XTTS Model (a .zip containing config.json, vocab.json, model.pth, ref.wav)', value=None, file_types=['.zip']) + return + yield gr.update(), gr.update(), gr.update(value='Invalid file! Please upload a valid ZIP.') + return except Exception as e: - yield gr.update(value=None), gr.update(value=None), gr.update(value=e) - return + yield gr.update(), gr.update(), gr.update(value=f'Error: {str(e)}') + return def change_gr_data(data): data['event'] = 'change_data' @@ -1329,6 +1403,7 @@ def change_gr_data(data): def change_gr_read_data(data): nonlocal audiobooks_dir + nonlocal custom_model_options warning_text_extra = '' if not data: data = {'session_id': str(uuid.uuid4())} @@ -1339,18 +1414,23 @@ def change_gr_read_data(data): warning_text = data['session_id'] event = data.get('event', '') if event != 'load': - return [gr.update(), gr.update(), gr.update()] + return [gr.update(), gr.update(), gr.update(), gr.update(), gr.update()] + session = context.get_session(data['session_id']) + session['custom_model_dir'] = os.path.join(models_dir,'__sessions',f"model-{session['id']}") + os.makedirs(session['custom_model_dir'], exist_ok=True) + custom_model_tts_dir = check_custom_model_tts(session) + custom_model_options = ['none'] + os.listdir(os.path.join(session['custom_model_dir'],custom_model_tts_dir)) if is_gui_shared: warning_text_extra = f' Note: access limit time: 
{interface_shared_expire} hours' audiobooks_dir = os.path.join(audiobooks_gradio_dir, f"web-{data['session_id']}") delete_old_web_folders(audiobooks_gradio_dir) else: audiobooks_dir = os.path.join(audiobooks_host_dir, f"web-{data['session_id']}") - return [data, f'{warning_text}{warning_text_extra}', data['session_id'], update_audiobooks_ddn()] + return [data, f'{warning_text}{warning_text_extra}', data['session_id'], update_audiobooks_ddn(), gr.update(choices=custom_model_options, value='none')] - def process_conversion( + def submit_convert_btn( session, device, ebook_file, voice_file, language, - custom_model_file, custom_model_url, temperature, length_penalty, + custom_model_file, temperature, length_penalty, repetition_penalty, top_k, top_p, speed, enable_text_splitting, fine_tuned ): nonlocal is_converting @@ -1364,8 +1444,7 @@ def process_conversion( "audiobooks_dir": audiobooks_dir, "voice": voice_file.name if voice_file else None, "language": next((key for name, key in language_options if name == language), None), - "custom_model": custom_model_file.name if custom_model_file else None, - "custom_model_url": custom_model_url if custom_model_file is None else None, + "custom_model": custom_model_file if custom_model_file != 'none' else None, "temperature": float(temperature), "length_penalty": float(length_penalty), "repetition_penalty": float(repetition_penalty), @@ -1404,19 +1483,24 @@ def process_conversion( outputs=[gr_modal_html] ) gr_language.change( - lambda selected: change_gr_language(dict(language_options).get(selected, 'Unknown')), - inputs=gr_language, - outputs=[gr_language, gr_tts_engine, gr_fine_tuned] + fn=lambda selected, session_id: change_gr_language(dict(language_options).get(selected, 'Unknown'), session_id), + inputs=[gr_language, gr_session], + outputs=[gr_language, gr_tts_engine, gr_fine_tuned, gr_custom_model_list] ) gr_audiobooks_ddn.change( fn=change_gr_audiobooks_ddn, inputs=gr_audiobooks_ddn, 
outputs=[gr_audiobook_link, gr_audio_player, gr_audio_player] ) + gr_custom_model_list.change( + fn=change_gr_custom_model_list, + inputs=[gr_custom_model_list], + outputs=[gr_fine_tuned] + ) gr_custom_model_file.change( fn=change_gr_custom_model_file, inputs=[gr_custom_model_file, gr_session], - outputs=[gr_custom_model_file, gr_custom_model_url, gr_conversion_progress] + outputs=[gr_fine_tuned, gr_custom_model_list, gr_conversion_progress] ) gr_session.change( fn=change_gr_data, @@ -1437,13 +1521,13 @@ def process_conversion( gr_read_data.change( fn=change_gr_read_data, inputs=gr_read_data, - outputs=[gr_data, gr_session_status, gr_session, gr_audiobooks_ddn] + outputs=[gr_data, gr_session_status, gr_session, gr_audiobooks_ddn, gr_custom_model_list] ) gr_convert_btn.click( - fn=process_conversion, + fn=submit_convert_btn, inputs=[ gr_session, gr_device, gr_ebook_file, gr_voice_file, gr_language, - gr_custom_model_file, gr_custom_model_url, gr_temperature, gr_length_penalty, + gr_custom_model_list, gr_temperature, gr_length_penalty, gr_repetition_penalty, gr_top_k, gr_top_p, gr_speed, gr_enable_text_splitting, gr_fine_tuned ], outputs=[gr_conversion_progress, gr_modal_html] @@ -1470,7 +1554,7 @@ def process_conversion( ) try: - interface.queue(default_concurrency_limit=interface_concurrency_limit).launch(server_name="0.0.0.0", server_port=interface_port, share=is_gui_shared) + interface.queue(default_concurrency_limit=interface_concurrency_limit).launch(server_name=interface_host, server_port=interface_port, share=is_gui_shared) except OSError as e: print(f'Connection error: {e}') except socket.error as e: