Skip to content

Commit

Permalink
Mini refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
oobabooga committed Mar 15, 2023
1 parent 66256ac commit ffb8986
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def create_settings_menus(default_preset):
description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''

if shared.args.chat or shared.args.cai_chat:
with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
if shared.args.chat or shared.args.cai_chat:
with gr.Tab("Text generation", elem_id="main"):
if shared.args.cai_chat:
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
Expand Down Expand Up @@ -276,9 +276,6 @@ def create_settings_menus(default_preset):

create_settings_menus(default_preset)

if shared.args.extensions is not None:
extensions_module.create_extensions_block()

function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]

Expand Down Expand Up @@ -325,8 +322,7 @@ def create_settings_menus(default_preset):
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)

elif shared.args.notebook:
with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
elif shared.args.notebook:
with gr.Tab("Text generation", elem_id="main"):
with gr.Tab('Raw'):
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
Expand All @@ -344,18 +340,14 @@ def create_settings_menus(default_preset):
with gr.Tab("Settings", elem_id="settings"):
create_settings_menus(default_preset)

if shared.args.extensions is not None:
extensions_module.create_extensions_block()

shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")

else:
with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
else:
with gr.Tab("Text generation", elem_id="main"):
with gr.Row():
with gr.Column():
Expand All @@ -380,9 +372,6 @@ def create_settings_menus(default_preset):
with gr.Tab("Settings", elem_id="settings"):
create_settings_menus(default_preset)

if shared.args.extensions is not None:
extensions_module.create_extensions_block()

shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
Expand All @@ -391,6 +380,9 @@ def create_settings_menus(default_preset):
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")

if shared.args.extensions is not None:
extensions_module.create_extensions_block()

shared.gradio['interface'].queue()
if shared.args.listen:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
Expand Down

0 comments on commit ffb8986

Please sign in to comment.