From e18c227c42b4137c12c5aabdf2acea61dce6c828 Mon Sep 17 00:00:00 2001 From: flux9665 Date: Mon, 7 Oct 2024 15:09:11 +0200 Subject: [PATCH] simplify advanced demo --- run_advanced_GUI_demo.py | 67 +++++++--------------------------------- 1 file changed, 11 insertions(+), 56 deletions(-) diff --git a/run_advanced_GUI_demo.py b/run_advanced_GUI_demo.py index 63d4e5e8..48be159c 100644 --- a/run_advanced_GUI_demo.py +++ b/run_advanced_GUI_demo.py @@ -15,12 +15,10 @@ from PyQt5.QtWidgets import QComboBox from PyQt5.QtWidgets import QFileDialog from PyQt5.QtWidgets import QHBoxLayout -from PyQt5.QtWidgets import QLabel from PyQt5.QtWidgets import QLineEdit from PyQt5.QtWidgets import QMainWindow from PyQt5.QtWidgets import QMessageBox from PyQt5.QtWidgets import QPushButton -from PyQt5.QtWidgets import QSlider from PyQt5.QtWidgets import QVBoxLayout from PyQt5.QtWidgets import QWidget from huggingface_hub import hf_hub_download @@ -144,9 +142,6 @@ def __init__(self, tts_interface: ToucanTTSInterface): self.audio_file_path = None self.result_audio = None self.min_duration = 1 - self.slider_val = 100 - self.durations_are_scaled = False - self.prev_slider_val_for_denorm = 100 self.setWindowTitle("TTS Model Interface") self.setGeometry(100, 100, 1200, 900) @@ -177,7 +172,7 @@ def __init__(self, tts_interface: ToucanTTSInterface): # Initialize plots self.init_plots() - # Initialize slider and buttons + # Initialize buttons self.init_controls() # Initialize Timer for TTS Cooldown @@ -189,10 +184,6 @@ def __init__(self, tts_interface: ToucanTTSInterface): def clear_all_widgets(self): self.spectrogram_view.setParent(None) self.pitch_plot.setParent(None) - self.upper_row.setParent(None) - self.slider_label.setParent(None) - self.mod_slider.setParent(None) - self.slider_value_label.setParent(None) self.generate_button.setParent(None) self.load_audio_button.setParent(None) self.save_audio_button.setParent(None) @@ -218,6 +209,7 @@ def load_data(self, durations, pitch, spectrogram): self.durations = durations self.cumulative_durations = np.cumsum(self.durations) + self.pitch = pitch self.spectrogram = spectrogram # Display Spectrogram @@ -245,7 +237,7 @@ def load_data(self, durations, pitch, spectrogram): # Display Durations self.duration_lines = [] for i, cum_dur in enumerate(self.cumulative_durations): - line = pg.InfiniteLine(pos=cum_dur, angle=90, pen=pg.mkPen('orange', width=4)) + line = pg.InfiniteLine(pos=cum_dur, angle=90, pen=pg.mkPen('orange', width=2)) self.spectrogram_view.addItem(line) line.setMovable(True) # Use lambda with default argument to capture current i @@ -274,28 +266,6 @@ def init_controls(self): self.controls_layout = QVBoxLayout() self.main_layout.addLayout(self.controls_layout) - # Upper row layout for slider - self.upper_row = QHBoxLayout() - self.controls_layout.addLayout(self.upper_row) - - # Slider Label - self.slider_label = QLabel("Faster") - self.upper_row.addWidget(self.slider_label) - - # Slider - self.mod_slider = QSlider(Qt.Horizontal) - self.mod_slider.setMinimum(70) - self.mod_slider.setMaximum(130) - self.mod_slider.setValue(self.slider_val) - self.mod_slider.setTickPosition(QSlider.TicksBelow) - self.mod_slider.setTickInterval(10) - self.mod_slider.valueChanged.connect(self.on_slider_changed) - self.upper_row.addWidget(self.mod_slider) - - # Slider Value Display - self.slider_value_label = QLabel("Slower") - self.upper_row.addWidget(self.slider_value_label) - # Lower row layout for buttons self.lower_row = QHBoxLayout() self.controls_layout.addLayout(self.lower_row) @@ -406,18 +376,12 @@ def on_user_input_changed(self, text): # Mark that an update is required self.mark_tts_update() - def on_slider_changed(self, value): - # Update the slider label - # self.slider_value_label.setText(f"Durations at {value}%") - self.slider_val = value - # print(f"Slider changed to {scaling_factor * 100}% speed") - # Mark that an update is required - self.mark_tts_update() - def generate_new_prosody(self): """ Generate new prosody. """ + if self.text_input.text().strip() == "": + return wave, mel, durations, pitch = self.tts_backend(text=self.text_input.text(), view=False, duration_scaling_factor=1.0, @@ -433,9 +397,6 @@ def generate_new_prosody(self): prosody_creativity=0.8, return_everything=True) # reset and clear everything - self.slider_val = 100 - self.prev_slider_val_for_denorm = self.slider_val - self.durations_are_scaled = False self.clear_all_widgets() self.init_plots() self.init_controls() @@ -510,7 +471,8 @@ def save_audio_file(self): def play_audio(self): # print("playing current audio...") - sounddevice.play(self.result_audio, samplerate=24000) + if self.result_audio is not None: + sounddevice.play(self.result_audio, samplerate=24000) def update_result_audio(self, audio_array): """ @@ -525,7 +487,7 @@ def mark_tts_update(self): Marks that a TTS update is required and starts/resets the timer. """ self.tts_update_required = True - self.tts_timer.start(600) # 600 milliseconds + self.tts_timer.start(800) # 800 milliseconds delay before the model starts to compute something def run_tts(self): """ @@ -553,16 +515,12 @@ def run_tts(self): phonemes = self.tts_backend.text2phone.get_phone_string(text=text) self.phonemes = phonemes.replace(" ", "") - forced_durations = None if self.durations is None or len(self.durations) != len(self.phonemes) else insert_zeros_at_indexes(self.durations, self.word_boundaries) - if forced_durations is not None and self.durations_are_scaled: - forced_durations = torch.LongTensor([forced_duration / (self.prev_slider_val_for_denorm / 100) for forced_duration in forced_durations]).unsqueeze(0) # revert scaling - elif forced_durations is not None: - forced_durations = torch.LongTensor(forced_durations).unsqueeze(0) + forced_durations = None if self.durations is None or len(self.durations) != len(self.phonemes) else torch.LongTensor(insert_zeros_at_indexes(self.durations, self.word_boundaries)).unsqueeze(0) forced_pitch = None if self.pitch is None or len(self.pitch) != len(self.phonemes) else torch.tensor(insert_zeros_at_indexes(self.pitch, self.word_boundaries)).unsqueeze(0) wave, mel, durations, pitch = self.tts_backend(text, view=False, - duration_scaling_factor=self.slider_val / 100, + duration_scaling_factor=1.0, pitch_variance_scale=1.0, energy_variance_scale=1.0, pause_duration_scaling_factor=1.0, @@ -576,9 +534,6 @@ def run_tts(self): return_everything=True) self.word_boundaries = find_zero_indexes(durations) - self.prev_slider_val_for_denorm = self.slider_val - if self.slider_val != 100: - self.durations_are_scaled = True self.load_data(durations=durations.cpu().numpy(), pitch=pitch.cpu().numpy(), spectrogram=mel.cpu().transpose(0, 1).numpy()) @@ -602,7 +557,7 @@ def main(): } QPushButton { - background-color: #808000; + background-color: #b9770e; border: 1px solid #ffffff; color: #ffffff; padding: 8px 16px;