Skip to content

Commit

Permalink
feat: interruptible infer (#433)
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama authored Jun 24, 2024
1 parent 5222976 commit c23e514
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 23 deletions.
2 changes: 1 addition & 1 deletion ChatTTS/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .core import Chat
from .core import Chat
8 changes: 8 additions & 0 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def __init__(self, logger=logging.getLogger(__name__)):
logger,
)

self.context = GPT.Context()

def has_loaded(self, use_decoder = False):
not_finish = False
check_list = ["vocos", "_vocos_decode", 'gpt', 'tokenizer']
Expand Down Expand Up @@ -155,6 +157,7 @@ def infer(
params_refine_text = RefineTextParams(),
params_infer_code = InferCodeParams(),
):
self.context.set(False)
res_gen = self._infer(
text,
stream,
Expand All @@ -171,6 +174,9 @@ def infer(
return res_gen
else:
return next(res_gen)

def interrupt(self):
    """Request that an in-flight infer() call stop early.

    Sets the shared GPT.Context flag to True; the generation loop polls
    this flag each step and breaks out when it is set, so interruption is
    cooperative (takes effect at the next step), not immediate. infer()
    clears the flag again at the start of each new call.
    """
    self.context.set(True)

def _load(
self,
Expand Down Expand Up @@ -422,6 +428,7 @@ def _infer_code(
infer_text = False,
return_hidden=return_hidden,
stream = stream,
context=self.context,
)

del_all(text_token)
Expand Down Expand Up @@ -467,6 +474,7 @@ def _refine_text(
logits_processors = logits_processors,
infer_text = True,
stream = False,
context=self.context,
)

del_all(text_token)
Expand Down
21 changes: 19 additions & 2 deletions ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ def __init__(
super().__init__()

self.logger = logger

self.device = device
self.device_gpt = device if "mps" not in str(device) else torch.device("cpu")

self.num_vq = num_vq
self.num_audio_tokens = num_audio_tokens

Expand All @@ -76,6 +78,17 @@ def __init__(
name='weight',
) for _ in range(self.num_vq)],
)

class Context:
    """Shared interruption flag between a caller and a generation loop.

    A caller flips the flag with ``set(True)``; the generation loop polls
    ``get()`` every step and stops early once it reads True.
    """

    def __init__(self):
        # No interruption has been requested yet.
        self._interrupt = False

    def set(self, v: bool):
        """Store the interruption request state."""
        self._interrupt = v

    def get(self) -> bool:
        """Report whether an interruption has been requested."""
        return self._interrupt


def _build_llama(self, config: omegaconf.DictConfig, device: torch.device) -> LlamaModel:

Expand Down Expand Up @@ -266,6 +279,7 @@ def generate(
return_attn=False,
return_hidden=False,
stream=False,
context=Context(),
):

with torch.no_grad():
Expand Down Expand Up @@ -407,12 +421,15 @@ def generate(
)
del minus_prev_end_index

if finish.all(): break
if finish.all() or context.get(): break

pbar.update(1)

if not finish.all():
self.logger.warning(f'incomplete result. hit max_new_token: {max_new_token}')
if context.get():
self.logger.warning('generation is interrupted')
else:
self.logger.warning(f'incomplete result. hit max_new_token: {max_new_token}')

del finish

Expand Down
57 changes: 41 additions & 16 deletions examples/web/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

custom_path: Optional[str] = None

has_interrupted = False

# 音色选项:用于预置合适的音色
voices = {
"Default": {"seed": 2},
Expand Down Expand Up @@ -83,24 +85,31 @@ def reload_chat(coef: Optional[str]) -> str:
gr.Info("Reload succeess.")
return chat.coef

def set_generate_buttons(generate_button, interrupt_button, is_reset=False):
    """Build paired gr.update objects toggling the two action buttons.

    When ``is_reset`` is True the Generate button is shown and enabled and
    the Interrupt button hidden and disabled; when False the roles swap
    (generation in progress, only Interrupt is usable).
    """
    generate_state = gr.update(
        value=generate_button, visible=is_reset, interactive=is_reset
    )
    interrupt_state = gr.update(
        value=interrupt_button, visible=not is_reset, interactive=not is_reset
    )
    return generate_state, interrupt_state

def refine_text(text, text_seed_input, refine_text_flag):
if not refine_text_flag:
return text
def refine_text(text, text_seed_input, refine_text_flag, generate_button, interrupt_button):
global chat, has_interrupted
has_interrupted = False

global chat
if not refine_text_flag:
return text, *set_generate_buttons(generate_button, interrupt_button, is_reset=True)

with TorchSeedContext(text_seed_input):
text = chat.infer(text,
skip_refine_text=False,
refine_text_only=True,
)
return text[0] if isinstance(text, list) else text
text = chat.infer(
text,
skip_refine_text=False,
refine_text_only=True,
)
return text[0] if isinstance(text, list) else text, *set_generate_buttons(generate_button, interrupt_button, is_reset=True)

def text_output_listener(generate_button, interrupt_button):
    # Fired when the output textbox changes: switch the UI into the
    # "generating" state (is_reset defaults to False, so Generate is
    # hidden/disabled and Interrupt shown/enabled).
    return set_generate_buttons(generate_button, interrupt_button)

def generate_audio(text, temperature, top_P, top_K, audio_seed_input, stream):
if not text: return None
global chat, has_interrupted

global chat
if not text or text == "𝕃𝕠𝕒𝕕𝕚𝕟𝕘..." or has_interrupted: return None

with TorchSeedContext(audio_seed_input):
rand_spk = chat.sample_random_speaker()
Expand All @@ -119,10 +128,26 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, stream):
params_infer_code=params_infer_code,
stream=stream,
)

if stream:
for gen in wav:
yield 24000, unsafe_float_to_int16(gen[0][0])
return
if stream:
for gen in wav:
audio = gen[0]
if audio is not None and len(audio) > 0:
yield 24000, unsafe_float_to_int16(audio[0])
del audio
return

yield 24000, unsafe_float_to_int16(np.array(wav[0]).flatten())

def interrupt_generate():
    """Interrupt-button handler: stop the running inference.

    Records the interruption in the module-level flag (generate_audio
    checks it and refuses to start) and forwards the request to the
    model via chat.interrupt().
    """
    global chat, has_interrupted

    has_interrupted = True
    chat.interrupt()

def set_buttons_after_generate(generate_button, interrupt_button, audio_output):
    """Restore the button pair after a generation attempt.

    The third positional argument of set_generate_buttons (is_reset) is
    True — i.e. Generate is shown again — only when audio was actually
    produced or the user pressed Interrupt.
    """
    global has_interrupted

    return set_generate_buttons(
        generate_button, interrupt_button,
        audio_output is not None or has_interrupted,
    )
12 changes: 8 additions & 4 deletions examples/web/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def main():
auto_play_checkbox = gr.Checkbox(label="Auto Play", value=False, scale=1)
stream_mode_checkbox = gr.Checkbox(label="Stream Mode", value=False, scale=1)
generate_button = gr.Button("Generate", scale=2, variant="primary")
interrupt_button = gr.Button("Interrupt", scale=2, variant="stop", visible=False, interactive=False)

text_output = gr.Textbox(label="Output Text", interactive=False)

Expand All @@ -64,10 +65,12 @@ def main():

reload_chat_button.click(reload_chat, inputs=dvae_coef_text, outputs=dvae_coef_text)

generate_button.click(fn=lambda: "", outputs=text_output)
generate_button.click(fn=lambda: "𝕃𝕠𝕒𝕕𝕚𝕟𝕘...", outputs=text_output)
generate_button.click(refine_text,
inputs=[text_input, text_seed_input, refine_text_checkbox],
outputs=text_output)
inputs=[text_input, text_seed_input, refine_text_checkbox, generate_button, interrupt_button],
outputs=[text_output, generate_button, interrupt_button])

interrupt_button.click(interrupt_generate)

@gr.render(inputs=[auto_play_checkbox, stream_mode_checkbox])
def make_audio(autoplay, stream):
Expand All @@ -79,9 +82,10 @@ def make_audio(autoplay, stream):
interactive=False,
show_label=True,
)
text_output.change(text_output_listener, inputs=[generate_button, interrupt_button], outputs=[generate_button, interrupt_button])
text_output.change(generate_audio,
inputs=[text_output, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, stream_mode_checkbox],
outputs=audio_output)
outputs=audio_output).then(fn=set_buttons_after_generate, inputs=[generate_button, interrupt_button, audio_output], outputs=[generate_button, interrupt_button])

gr.Examples(
examples=[
Expand Down

0 comments on commit c23e514

Please sign in to comment.