Skip to content

Commit

Permalink
Merge pull request #94 from ROBERT-MCDOWELL/v2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
DrewThomasson authored Dec 19, 2024
2 parents 24ced3f + 11b6580 commit b8c8200
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 84 deletions.
4 changes: 2 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def main():
args = parser.parse_args()

# Check if the port is already in use to prevent multiple launches
if not args.headless and is_port_in_use(gradio_interface_port):
print(f'Error: Port {gradio_interface_port} is already in use. The web interface may already be running.')
if not args.headless and is_port_in_use(interface_port):
print(f'Error: Port {interface_port} is already in use. The web interface may already be running.')
sys.exit(1)

args.script_mode = args.script_mode if args.script_mode else NATIVE
Expand Down
12 changes: 9 additions & 3 deletions lib/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@
requirements_file = os.path.abspath(os.path.join('.','requirements.txt'))

docker_utils_image = 'utils'
gradio_interface_port = 7860
gradio_shared_expire = 72 # hours
concurrency_limit = 8 # or None for unlimited
interface_port = 7860
interface_shared_expire = 72 # hours
interface_concurrency_limit = 8 # or None for unlimited
interface_component_options = {
"gr_tab_preferences": True,
"gr_voice_file": True,
"gr_custom_model_file": True,
"gr_custom_model_url": True
}

python_env_dir = os.path.abspath(os.path.join('.','python_env'))
models_dir = os.path.abspath(os.path.join('.','models'))
Expand Down
179 changes: 100 additions & 79 deletions lib/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def get_session(self, session_id):
}, manager=self.manager)
return self.sessions[session_id]

context = None
context = ConversionContext()
is_gui_process = False

class DependencyError(Exception):
Expand Down Expand Up @@ -185,50 +185,70 @@ def check_fine_tuned(fine_tuned, language):
except Exception as e:
raise RuntimeError(e)

def download_custom_model(url, dest_dir, session):
def download_custom_model(url, dest, session):
try:
parsed_url = urlparse(url)
fname = os.path.basename(parsed_url.path)
if not os.path.exists(dest_dir):
os.makedirs(dest_dir, exist_ok=True)
if not os.path.exists(dest):
os.makedirs(dest, exist_ok=True)
response = requests.get(url, stream=True)
response.raise_for_status() # Raise an error for bad responses
f_path = os.path.join(dest_dir,fname)
with open(f_path, 'wb') as file:
file_src = os.path.join(dest,fname)
with open(file_src, 'wb') as file:
for chunk in response.iter_content(chunk_size=8192):
file.write(chunk)
print(f'File saved at: {f_path}')
return extract_custom_model(f_path, dest_dir, session)
print(f'File saved at: {file_src}')
return file_src
except Exception as e:
raise RuntimeError(f'Error while downloading the file: {e}')

def extract_custom_model(f_path, dest_dir, session):
raise RuntimeError(f'download_custom_model(): {e}')

def analyze_uploaded_file(zip_path, required_files=None):
if required_files is None:
required_files = ['config.json', 'vocab.json', 'model.pth', 'ref.wav']
executable_extensions = {'.exe', '.bat', '.cmd', '.sh', '.msi', '.dll', '.com'}
try:
model_dir = os.path.join(dest_dir, f_path.replace('.zip',''))
os.makedirs(model_dir, exist_ok=True)
with zipfile.ZipFile(f_path, 'r') as zip_ref:
with zipfile.ZipFile(zip_path, 'r') as zf:
files_in_zip = set()
executables_found = False
for file_info in zf.infolist():
file_name = file_info.filename
if file_info.is_dir():
continue # Skip directories
base_name = os.path.basename(file_name)
files_in_zip.add(base_name)
_, ext = os.path.splitext(base_name.lower())
if ext in executable_extensions:
executables_found = True
break
missing_files = [f for f in required_files if f not in files_in_zip]
is_valid = not executables_found and not missing_files
return is_valid
except zipfile.BadZipFile:
raise ValueError("error: The file is not a valid ZIP archive.")
except Exception as e:
raise RuntimeError(f'analyze_uploaded_file(): {e}')

def extract_custom_model(file_src, dest, session):
try:
dir_src = os.path.dirname(file_src)
dir_dest = os.path.join(dest, file_src.replace('.zip',''))
os.makedirs(dir_dest, exist_ok=True)
with zipfile.ZipFile(file_src, 'r') as zip_ref:
files = zip_ref.namelist()
with tqdm(total=len(files), unit='file', desc='Extracting Files') as t:
for file in files:
if session['cancellation_requested']:
msg = 'Cancel requested'
raise ValueError()
if os.path.isfile(file):
extract = zip_ref.extract(file, model_dir)
extract = zip_ref.extract(file, dir_dest)
t.update(1)
os.remove(f_path)
print(f'Extracted files to {model_dir}')
return check_model_files(model_dir)
os.remove(file_src)
print(f'Extracted files to {dir_dest}')
return dir_dest
except Exception as e:
raise DependencyError(e)

def check_model_files(model_dir):
existing_files = ['config.json', 'model.pth', 'vocab.json', 'ref.wav']
missing_files = [file for file in os.listdir(model_dir) if not file in existing_files]
if missing_files:
return False
return model_dir

def calculate_hash(filepath, hash_algorithm='sha256'):
hash_func = hashlib.new(hash_algorithm)
with open(filepath, 'rb') as file:
Expand Down Expand Up @@ -821,7 +841,7 @@ def delete_old_web_folders(root_dir):
os.makedirs(root_dir)
print(f'Created missing directory: {root_dir}')
current_time = time.time()
age_limit = current_time - gradio_shared_expire * 60 * 60 # 24 hours in seconds
age_limit = current_time - interface_shared_expire * 60 * 60 # 24 hours in seconds
for folder_name in os.listdir(root_dir):
dir_path = os.path.join(root_dir, folder_name)
if os.path.isdir(dir_path) and folder_name.startswith('web-'):
Expand Down Expand Up @@ -859,7 +879,6 @@ def convert_ebook(args):
pass

if args['language'] is not None and args['language'] in language_mapping.keys():
context = ConversionContext()
session_id = args['session'] if args['session'] is not None else str(uuid.uuid4())
session = context.get_session(session_id)
session['id'] = session_id
Expand Down Expand Up @@ -900,22 +919,24 @@ def convert_ebook(args):
session['tmp_dir'] = os.path.join(processes_dir, f"ebook-{session['id']}")
session['chapters_dir'] = os.path.join(session['tmp_dir'], f'chapters_{hashlib.md5(args['ebook'].encode()).hexdigest()}')
session['chapters_dir_sentences'] = os.path.join(session['chapters_dir'], 'sentences')
session['custom_model_dir'] = os.path.join(models_dir,'__sessions',f"model-{session['id']}")

if not is_gui_process:
print(f'*********** Session: {session_id}', '************* Store it in case of interruption or crash you can resume the conversion')

if prepare_dirs(args['ebook'], session):
session['filename_noext'] = os.path.splitext(os.path.basename(session['src']))[0]
session['custom_model'] = None
session['custom_model_dir'] = os.path.join(models_dir,'__sessions',f"model-{session['id']}")
if custom_model_file or custom_model_url:
if custom_model_url:
print(f'Get custom model: {custom_model_url}')
session['custom_model'] = download_custom_model(custom_model_url, session['custom_model_dir'], session)
file_src = download_custom_model(custom_model_url, session['custom_model_dir'], session)
if session['custom_model']:
if analyze_uploaded_file(file_src):
session['custom_model'] = extract_custom_model(file_src, session['custom_model_dir'], session)
else:
session['custom_model'] = extract_custom_model(custom_model_file, session['custom_model_dir'], session)
if not session['custom_model']:
raise ValueError(f'{custom_model_file} could not be extracted or mandatory files are missing')

if prepare_dirs(args['ebook'], session):
session['filename_noext'] = os.path.splitext(os.path.basename(session['src']))[0]
if not torch.cuda.is_available() or device == 'cpu':
if device == 'gpu':
print('GPU is not available on your device!')
Expand Down Expand Up @@ -1011,7 +1032,7 @@ def web_interface(args):
radius_size='lg',
font_mono=['JetBrains Mono', 'monospace', 'Consolas', 'Menlo', 'Liberation Mono']
)

with gr.Blocks(theme=theme) as interface:
gr.HTML(
'''
Expand Down Expand Up @@ -1071,7 +1092,8 @@ def web_interface(args):
'''
)
with gr.Tabs():
with gr.TabItem('Input Options'):
gr_tab_main = gr.TabItem('Input Options')
with gr_tab_main:
with gr.Row():
with gr.Column(scale=3):
with gr.Group():
Expand All @@ -1083,16 +1105,17 @@ def web_interface(args):
gr_language = gr.Dropdown(label='Language', choices=[name for name, _ in language_options], value=default_language_name)
with gr.Column(scale=3):
with gr.Group():
gr_voice_file = gr.File(label='*Cloning Voice (a .wav 24000hz for major language and 22050hz for others, no more than 6 sec)', file_types=['.wav'])
gr_voice_file = gr.File(label='*Cloning Voice (a .wav 24000hz for XTTS base model and 22050hz for FAIRSEQ base model, no more than 6 sec)', file_types=['.wav'], visible=interface_component_options['gr_voice_file'])
gr.Markdown('<p>&nbsp;&nbsp;* Optional</p>')
with gr.Group():
gr_custom_model_file = gr.File(label='*XTTS Model (a .zip containing config.json, vocab.json, model.pth, ref.wav)', file_types=['.zip'], visible=True)
gr_custom_model_url = gr.Textbox(placeholder='https://www.example.com/model.zip', label='Model from URL*', visible=False)
gr_custom_model_file = gr.File(label='*XTTS Model (a .zip containing config.json, vocab.json, model.pth, ref.wav)', file_types=['.zip'], visible=interface_component_options['gr_custom_model_file'])
gr_custom_model_url = gr.Textbox(placeholder='https://www.example.com/model.zip', label='Model from URL*', visible=interface_component_options['gr_custom_model_url'])
gr.Markdown('<p>&nbsp;&nbsp;* Optional</p>')
with gr.Group():
gr_tts_engine = gr.Markdown(f'&nbsp;&nbsp;&nbsp;&nbsp;TTS Base:&nbsp;&nbsp;{default_tts_engine.upper()}')
gr_fine_tuned = gr.Dropdown(label='Fine Tuned Models', choices=fine_tuned_options, value=default_fine_tuned, interactive=True)
with gr.TabItem('Audio Generation Preferences'):
gr_tab_preferences = gr.TabItem('Audio Generation Preferences', visible=interface_component_options['gr_tab_preferences'])
with gr_tab_preferences:
gr.Markdown(
'''
### Customize Audio Generation Parameters
Expand Down Expand Up @@ -1242,32 +1265,29 @@ def change_gr_audiobooks_ddn(audiobook):
return link, link, gr.update(visible=True)
return None, None, gr.update(visible=False)

def disable_convert_btn(f):
if not hasattr(f, 'name') or f.name == "":
return "File is still uploading. Please wait."
def update_convert_btn(upload_file, custom_model_file, session_id):
session = context.get_session(session_id)
if hasattr(upload_file, 'name') and not hasattr(custom_model_file, 'name'):
yield gr.update(variant='primary', interactive=True)
else:
return gr.update('Convert', variant='primary', interactive=False)
yield gr.update(variant='primary', interactive=False)
return

def update_audiobooks_ddn():
files = refresh_audiobook_list()
return gr.update(choices=files, label='Audiobooks', value=files[0] if files else None)

async def change_gr_ebook_file(btn, f, session_id):
async def change_gr_ebook_file(f, session_id):
nonlocal is_converting
if context and session_id:
session = context.get_session(session_id)
if f is None:
if is_converting:
session['cancellation_requested'] = True
yield gr.update(interactive=False), show_modal('Cancellation requested, please wait...')
return
else:
session['cancellation_requested'] = False
yield gr.update(interactive=False), hide_modal()
yield show_modal('Cancellation requested, please wait...')
return
else:
session['cancellation_requested'] = False
yield gr.update(interactive=bool(f)), hide_modal()
session['cancellation_requested'] = False
yield hide_modal()
return

def change_gr_language(selected: str):
Expand All @@ -1276,28 +1296,32 @@ def change_gr_language(selected: str):
new_language_key = default_language_code
else:
new_language_name, new_language_key = next(((name, key) for name, key in language_options if key == selected), (None, None))

# Determine the TTS engine to use
tts_engine_value = 'xtts' if language_xtts.get(new_language_key, False) else 'fairseq'

# Get fine-tuned options filtered by language
fine_tuned_options = [
model_name
for model_name, model_details in models.get(tts_engine_value, {}).items()
if model_details.get('lang') == 'multi' or model_details.get('lang') == new_language_key
]

# Update the dropdown and other elements
return (
gr.update(value=new_language_name), # Update the language dropdown
gr.update(value=f'&nbsp;&nbsp;&nbsp;&nbsp;tts base: {tts_engine_value.upper()}'), # Update the TTS engine display
gr.update(choices=fine_tuned_options, value=fine_tuned_options[0] if fine_tuned_options else None) # Update fine-tuned options
gr.update(value=new_language_name),
gr.update(value=f'&nbsp;&nbsp;&nbsp;&nbsp;tts base: {tts_engine_value.upper()}'),
gr.update(choices=fine_tuned_options, value=fine_tuned_options[0] if fine_tuned_options else None)
)

def change_gr_custom_model_file(f):
if f is not None:
return gr.update(placeholder='https://www.example.com/model.zip', label='Model from URL*', visible=False)
return gr.update(placeholder='https://www.example.com/model.zip', label='Model from URL*', visible=True)
async def change_gr_custom_model_file(custom_model_file, session_id):
try:
if context and session_id:
session = context.get_session(session_id)
if custom_model_file is not None:
if analyze_uploaded_file(custom_model_file):
model_dir = extract_custom_model(f, session)
if model_dir:
yield gr.update(value=None), gr.update(value=None), gr.update(value='')
return
yield gr.update(value=None), gr.update(value=None), gr.update(value='Invalid file! Please upload a valid ZIP.')
except Exception as e:
yield gr.update(value=None), gr.update(value=None), gr.update(value=e)
return

def change_gr_data(data):
data['event'] = 'change_data'
Expand All @@ -1306,8 +1330,6 @@ def change_gr_data(data):
def change_gr_read_data(data):
nonlocal audiobooks_dir
warning_text_extra = ''
if is_gui_shared:
warning_text_extra = f' Note: access limit time: {gradio_shared_expire} hours'
if not data:
data = {'session_id': str(uuid.uuid4())}
warning_text = f"Session: {data['session_id']}"
Expand All @@ -1319,6 +1341,7 @@ def change_gr_read_data(data):
if event != 'load':
return [gr.update(), gr.update(), gr.update()]
if is_gui_shared:
warning_text_extra = f' Note: access limit time: {interface_shared_expire} hours'
audiobooks_dir = os.path.join(audiobooks_gradio_dir, f"web-{data['session_id']}")
delete_old_web_folders(audiobooks_gradio_dir)
else:
Expand Down Expand Up @@ -1360,7 +1383,6 @@ def process_conversion(
is_converting = True
progress_status, audiobook_file = convert_ebook(args)
is_converting = False

if audiobook_file is None:
if is_converting:
return 'Conversion cancelled.', hide_modal()
Expand All @@ -1373,9 +1395,13 @@ def process_conversion(
return DependencyError(e)

gr_ebook_file.change(
fn=update_convert_btn,
inputs=[gr_ebook_file, gr_custom_model_file, gr_session],
outputs=gr_convert_btn
).then(
fn=change_gr_ebook_file,
inputs=[gr_convert_btn, gr_ebook_file, gr_session],
outputs=[gr_convert_btn, gr_modal_html]
inputs=[gr_ebook_file, gr_session],
outputs=[gr_modal_html]
)
gr_language.change(
lambda selected: change_gr_language(dict(language_options).get(selected, 'Unknown')),
Expand All @@ -1387,12 +1413,11 @@ def process_conversion(
inputs=gr_audiobooks_ddn,
outputs=[gr_audiobook_link, gr_audio_player, gr_audio_player]
)
"""
gr_custom_model_file.change(
fn=change_gr_custom_model_file,
inputs=gr_custom_model_file,
outputs=gr_custom_model_url
"""
inputs=[gr_custom_model_file, gr_session],
outputs=[gr_custom_model_file, gr_custom_model_url, gr_conversion_progress]
)
gr_session.change(
fn=change_gr_data,
inputs=gr_data,
Expand All @@ -1415,10 +1440,6 @@ def process_conversion(
outputs=[gr_data, gr_session_status, gr_session, gr_audiobooks_ddn]
)
gr_convert_btn.click(
fn=disable_convert_btn,
inputs=gr_custom_model_file,
outputs=gr_convert_btn
).then(
fn=process_conversion,
inputs=[
gr_session, gr_device, gr_ebook_file, gr_voice_file, gr_language,
Expand Down Expand Up @@ -1449,7 +1470,7 @@ def process_conversion(
)

try:
interface.queue(default_concurrency_limit=concurrency_limit).launch(server_name="0.0.0.0", server_port=gradio_interface_port, share=is_gui_shared)
interface.queue(default_concurrency_limit=interface_concurrency_limit).launch(server_name="0.0.0.0", server_port=interface_port, share=is_gui_shared)
except OSError as e:
print(f'Connection error: {e}')
except socket.error as e:
Expand Down

0 comments on commit b8c8200

Please sign in to comment.