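"""Gradio demo for DeciLM-6B-Instruct (https://huggingface.co/Deci/DeciLM-6b-instruct):
wraps each user instruction in the model's prompt template, generates a completion
on GPU, and returns only the text after the '### Response:' marker."""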
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
model_id = 'Deci/DeciLM-6b-instruct'
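# Alpaca-style instruction prompt used by DeciLM-6B-Instruct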
SYSTEM_PROMPT_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
"""
DESCRIPTION = """
# <p style="text-align: center; color: #292b47;"> 🤖 <span style='color: #3264ff;'>DeciLM-6B-Instruct:</span> A Fast Instruction-Tuned Model 💨 </p>
<span style='color: #292b47;'>Welcome to <a href="https://huggingface.co/Deci/DeciLM-6b-instruct" style="color: #3264ff;">DeciLM-6B-Instruct</a>! DeciLM-6B-Instruct is a 6B-parameter instruction-tuned language model released under the Llama license. Because it is instruction-tuned rather than chat-tuned, prompt it with an instruction that describes a task, and it will generate a response that completes the task.</span>
<p><span style='color: #292b47;'>Learn more about the base model, <a href="https://deci.ai/blog/decilm-15-times-faster-than-llama2-nas-generated-llm-with-variable-gqa/" style="color: #3264ff;">DeciLM-6B</a>.</span></p>
"""
if not torch.cuda.is_available():
    DESCRIPTION += '\n\nYou need a GPU to run this demo. Try it in Colab: https://bit.ly/decilm-instruct-nb'
if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map='auto',
        trust_remote_code=True,
        use_auth_token=token,
    )
else:
    model = None
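# Llama-style tokenizers ship without a pad token, so reuse EOS for padding.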
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
tokenizer.pad_token = tokenizer.eos_token
# Wrap the user's message in the instruction prompt template
def get_prompt_with_template(message: str) -> str:
    return SYSTEM_PROMPT_TEMPLATE.format(instruction=message)
# Generate the model's full output (prompt + completion) for a message
def generate_model_response(message: str) -> str:
    prompt = get_prompt_with_template(message)
    inputs = tokenizer(prompt, return_tensors='pt')
    if torch.cuda.is_available():
        inputs = inputs.to('cuda')
    output = model.generate(
        **inputs,
        max_new_tokens=3000,
        num_beams=2,             # beam search with 2 beams
        no_repeat_ngram_size=4,  # never repeat the same 4-gram
        early_stopping=True,     # stop once all beams have finished
        do_sample=True,          # sample within each beam
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Extract the content after "### Response:"
def extract_response_content(full_response: str) -> str:
    response_start_index = full_response.find("### Response:")
    if response_start_index != -1:
        return full_response[response_start_index + len("### Response:"):].strip()
    else:
        return full_response
# End-to-end helper wired to the UI: generate, then keep only the response
def get_response_with_template(message: str) -> str:
    full_response = generate_model_response(message)
    return extract_response_content(full_response)
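# e.g. get_response_with_template('Explain beam search in one sentence.')
# returns only the model's answer, without the prompt scaffolding.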
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')
    with gr.Group():
        chatbot = gr.Textbox(label='DeciLM-6B-Instruct Output:')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type an instruction...',
                scale=10,
                elem_id="textbox",
            )
            submit_button = gr.Button(
                '💬 Submit',
                variant='primary',
                scale=1,
                min_width=0,
                elem_id="submit_button",
            )
    # Clear both the input and the output boxes
    clear_button = gr.Button(
        '🗑️ Clear',
        variant='secondary',
    )
    clear_button.click(
        fn=lambda: ('', ''),
        outputs=[textbox, chatbot],
        queue=False,
        api_name=False,
    )
    submit_button.click(
        fn=get_response_with_template,
        inputs=textbox,
        outputs=chatbot,
        queue=False,
        api_name=False,
    )
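    # cache_examples=True runs the function on every example at startup,
    # so clicking an example shows a pre-computed result.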
    gr.Examples(
        examples=[
            'Write detailed instructions for making chocolate chip pancakes.',
            'Write a 250-word article about your love of pancakes.',
            'Explain the plot of Back to the Future in three sentences.',
            'How do I make a trap beat?',
            'A step-by-step guide to learning Python in one month.',
        ],
        inputs=textbox,
        outputs=chatbot,
        fn=get_response_with_template,
        cache_examples=True,
        elem_id="examples",
    )
    gr.HTML(label="Keep in touch", value="<img src='https://huggingface.co/spaces/Deci/DeciLM-6b-instruct/resolve/main/deci-coder-banner.png' alt='Keep in touch' style='display: block; color: #292b47; margin: auto; max-width: 800px;'>")

demo.launch()