import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

TOKEN_LIMIT = 2048
TEMPERATURE = 0.3
REPETITION_PENALTY = 1.05
MAX_NEW_TOKENS = 500
MODEL_NAME = "OEvortex/HelpingAI-Lite-chat"

# fmt: off
st.write("**💬 with [OEvortex/HelpingAI-Lite-chat](https://huggingface.co/OEvortex/HelpingAI-Lite-chat)**" ) | |
st.write("*The model operates on free-tier hardware, which may lead to slower performance during periods of high demand.*") | |
st.write("*I am using transformers in this space you can also use ⚡Inference API for this model*") | |
# fmt: on | |
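
# Streamlit re-executes this script on every interaction, so the conversation
# history is kept in st.session_state to survive reruns.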
if "chat_history" not in st.session_state: | |
st.session_state.chat_history = [] | |
torch.set_grad_enabled(False) | |
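
# Assumed improvement: cache the tokenizer/model with st.cache_resource so they
# are loaded once per process instead of on every Streamlit rerun.
@st.cache_resource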
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16
    )
    return tokenizer, model
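

# model.generate() accepts a `streamer` object and calls its put()/end() methods
# as tokens are produced, which is how the reply is streamed into the UI.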
def chat_func_stream(tokenizer, model, chat_history, streamer):
    input_ids = tokenizer.apply_chat_template(
        chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # check input length
    if len(input_ids[0]) > TOKEN_LIMIT:
        st.warning(
            f"We have limited computation power. Please keep your input within {TOKEN_LIMIT} tokens."
        )
        # drop the over-long user message that was just appended
        st.session_state.chat_history = st.session_state.chat_history[:-1]
        return

    model.generate(
        input_ids,
        do_sample=True,
        temperature=TEMPERATURE,
        repetition_penalty=REPETITION_PENALTY,
        max_new_tokens=MAX_NEW_TOKENS,
        streamer=streamer,
    )


def show_chat_message(container, chat_message):
    with container:
        with st.chat_message(chat_message["role"]):
            st.write(chat_message["content"])
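

# Minimal streamer for transformers' generate(): put() receives token ids
# (the prompt first, then each new token) and end() is called when generation
# finishes.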
class ResponseStreamer:
    def __init__(self, tokenizer, container, chat_history):
        self.tokenizer = tokenizer
        self.container = container
        self.chat_history = chat_history
        self.first_call_to_put = True
        self.current_response = ""
        with self.container:
            self.placeholder = st.empty()  # placeholder to hold the streamed message

    def put(self, new_token):
        # generate() passes the prompt ids to the streamer first; skip them
        if self.first_call_to_put:
            self.first_call_to_put = False
            return
        # decode the new token and accumulate it into current_response
        decoded = self.tokenizer.decode(new_token[0], skip_special_tokens=True)
        self.current_response += decoded
        # display the streamed message so far
        show_chat_message(
            self.placeholder.container(),
            {"role": "assistant", "content": self.current_response},
        )

    def end(self):
        # save the completed assistant message
        self.chat_history.append(
            {"role": "assistant", "content": self.current_response}
        )
        # reset state (not strictly needed, as the instance is recreated each turn)
        self.first_call_to_put = True
        self.current_response = ""
        # rerun to unfreeze the chat_input
        st.rerun()
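

# Main app flow: load the model, replay the stored chat history, read new user
# input, and stream the assistant's reply while the input box is disabled.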
tokenizer, model = load_model()

chat_messages_container = st.container()
for msg in st.session_state.chat_history:
    show_chat_message(chat_messages_container, msg)

input_placeholder = st.empty()  # use a placeholder as a hack to disable the input
user_input = input_placeholder.chat_input(key="user_input_original")

if user_input:
    # disable chat_input while generating
    input_placeholder.chat_input(key="user_input_disabled", disabled=True)

    new_user_message = {"role": "user", "content": user_input}
    st.session_state.chat_history.append(new_user_message)
    show_chat_message(chat_messages_container, new_user_message)

    streamer = ResponseStreamer(
        tokenizer=tokenizer,
        container=chat_messages_container,
        chat_history=st.session_state.chat_history,
    )
    chat_func_stream(tokenizer, model, st.session_state.chat_history, streamer)
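
# To run locally: streamlit run app.py (assuming this file is saved as app.py).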