# Hugging Face Spaces snapshot — Space status at capture time: "Runtime error"
# (reported twice by the Spaces page; not part of the application code).
import os
import time

import torch
import gradio as gr

from strings import TITLE, ABSTRACT, EXAMPLES
from gen import get_pretrained_models, get_output, setup_model_parallel

# Rendezvous settings for a single-process run (world size 1); presumably
# consumed by torch.distributed inside setup_model_parallel() — confirm.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "50505"

# Initialize model parallelism, then load the 7B checkpoint and tokenizer.
local_rank, world_size = setup_model_parallel()
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)

# Running transcript of the conversation as {"role": ..., "content": ...} dicts;
# appended to by chat() but never read back in this file.
history = []
def chat(
    user_input,
    include_input,
    truncate,
    top_p,
    temperature,
    max_gen_len,
    state_chatbot,
):
    """Generate a bot reply for *user_input* and stream it word by word.

    This is a generator used as a Gradio event handler: it yields the
    updated chat state repeatedly so the UI animates the response.

    Args:
        user_input: The prompt text typed by the user.
        include_input: If False, strip the echoed prompt from the front of
            the model output.
        truncate: If True, cut the response at its last "." to drop an
            unfinished trailing sentence.
        top_p: Nucleus-sampling threshold forwarded to get_output().
        temperature: Sampling temperature forwarded to get_output().
        max_gen_len: Maximum generation length forwarded to get_output().
        state_chatbot: List of (user, bot) message pairs held in gr.State.

    Yields:
        (state_chatbot, state_chatbot) — the same list twice, once for the
        gr.State output and once for the gr.Chatbot output.
    """
    bot_response = get_output(
        generator=generator,
        prompt=user_input,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )[0]

    # Remove the leading echo of the prompt when the caller opted out of it.
    if not include_input:
        bot_response = bot_response[len(user_input):]

    bot_response = bot_response.replace("\n", "<br>")

    # Trim the unfinished last sentence by cutting at the final period.
    if truncate:
        last_period = bot_response.rfind(".")
        # BUG FIX: the original sliced with rfind()+1 unconditionally, so a
        # response containing no "." (rfind == -1) became "" via [:0]. The
        # bare try/except around it was dead code — rfind never raises.
        if last_period != -1:
            bot_response = bot_response[: last_period + 1]

    history.append({
        "role": "user",
        "content": user_input,
    })
    history.append({
        "role": "system",
        "content": bot_response,
    })

    # Start a new (user, pending) pair, then stream words into it.
    state_chatbot = state_chatbot + [(user_input, None)]

    response = ""
    for word in bot_response.split(" "):
        time.sleep(0.1)  # artificial delay for a typing effect
        response += word + " "
        state_chatbot[-1] = (user_input, response)
        yield state_chatbot, state_chatbot
def reset_textbox():
    """Return a Gradio update that clears the prompt textbox."""
    cleared = gr.update(value="")
    return cleared
# Build the Gradio UI: markdown header, collapsible examples, chatbot pane,
# prompt textbox, and sampling-parameter controls.
with gr.Blocks(css="""#col_container {width: 95%; margin-left: auto; margin-right: auto;}
#chatbot {height: 400px; overflow: auto;}""") as demo:
    state_chatbot = gr.State([])

    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")

        with gr.Accordion("Example prompts", open=False):
            # Render the example prompts as a markdown bullet list.
            gr.Markdown("\n" + "".join(f"- {example}\n" for example in EXAMPLES))

        chatbot = gr.Chatbot(elem_id='chatbot')
        textbox = gr.Textbox(placeholder="Enter a prompt")

        with gr.Accordion("Parameters", open=False):
            include_input = gr.Checkbox(value=True, label="Do you want to include the input in the generated text?")
            truncate = gr.Checkbox(value=True, label="Truncate the unfinished last words?")
            # Label typo fixed ("Genenration" -> "Generation"); minimum=-0
            # normalized to 0 (identical value).
            max_gen_len = gr.Slider(minimum=20, maximum=512, value=256, step=1, interactive=True, label="Max Generation Length",)
            top_p = gr.Slider(minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",)
            temperature = gr.Slider(minimum=0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature",)

    # Submitting the textbox streams a chat response, then clears the box.
    textbox.submit(
        chat,
        [textbox, include_input, truncate, top_p, temperature, max_gen_len, state_chatbot],
        [state_chatbot, chatbot]
    )
    textbox.submit(reset_textbox, [], [textbox])

demo.queue(api_open=False).launch()