from optimum.intel.openvino import OVModelForCausalLM
from transformers import AutoTokenizer, AutoConfig
from threading import Thread
from transformers import TextIteratorStreamer
import streamlit as st
import warnings
warnings.filterwarnings(action='ignore')
import os
import datetime
import random
import string
from time import sleep
import tiktoken
import asyncio  # asyncio added for asynchronous handling
# Required requirements.txt file:
# optimum[openvino]
# transformers
# streamlit
# tiktoken
# (asyncio is part of the Python standard library and does not need to be listed)
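# To run locally, assuming this file is saved as app.py (the default entry point of a
# Streamlit Space), something like:
#   pip install "optimum[openvino]" transformers streamlit tiktoken
#   streamlit run app.py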
# Encoding used to estimate token counts
encoding = tiktoken.get_encoding("cl100k_base")
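# Note: cl100k_base is an OpenAI tokenizer, not the Gemma tokenizer, so the token
# counts and the tokens-per-second figure shown in the sidebar are estimates only.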
# Model name and model ID (kept in one place)
model_name = "Gemma2-2B-it"
model_id = "AIFunOver/gemma-2-2b-it-openvino-4bit"  # Hugging Face Hub model ID
# Basic page settings
st.set_page_config(
    page_title=f"Your LocalGPT ✨ with {model_name}",
    page_icon="💎",
    layout="wide")
# Initialize session state (state persists while the Hugging Face Space keeps running)
if "hf_model" not in st.session_state:
    st.session_state.hf_model = model_name
if "messages" not in st.session_state:
    st.session_state.messages = []
if "chatMessages" not in st.session_state:
    st.session_state.chatMessages = []
if "repeat" not in st.session_state:
    st.session_state.repeat = 1.35
if "temperature" not in st.session_state:
    st.session_state.temperature = 0.1
if "maxlength" not in st.session_state:
    st.session_state.maxlength = 500
if "speed" not in st.session_state:
    st.session_state.speed = 0.0
if "numOfTurns" not in st.session_state:
    st.session_state.numOfTurns = 0
if "maxTurns" not in st.session_state:
    st.session_state.maxTurns = 5  # must be an odd number, greater than or equal to 5
if "logfilename" not in st.session_state: | |
## Logger file | |
logfile = f'logs/Gemma2-2B_{genRANstring(5)}_log.txt' # Space ๋ฃจํธ์ logs ํด๋์ ์ ์ฅ | |
st.session_state.logfilename = logfile | |
# Write in the history the first 2 sessions | |
writehistory(st.session_state.logfilename,f'{str(datetime.datetime.now())}\n\nYour own LocalGPT with ๐ {model_name}\n---\n๐ง ๐ซก: You are a helpful assistant.') | |
writehistory(st.session_state.logfilename,f'๐: How may I help you today?') | |
def writehistory(filename,text): | |
try: | |
with open(filename, 'a', encoding='utf-8') as f: | |
f.write(text) | |
f.write('\n') | |
f.close() | |
except Exception as e: | |
print(f"Error writing to log file: {e}") # Log error to console | |
def genRANstring(n): | |
""" | |
n = int number of char to randomize | |
""" | |
N = n | |
res = ''.join(random.choices(string.ascii_uppercase + | |
string.digits, k=N)) | |
return res | |
#
def create_chat():
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        ov_model = OVModelForCausalLM.from_pretrained(
            model_id,
            device='CPU',
            ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},  # OpenVINO config
            config=AutoConfig.from_pretrained(model_id)
        )
        # Credit to https://github.com/openvino-dev-samples/chatglm3.openvino/blob/main/chat.py
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        return tokenizer, ov_model, streamer
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None, None  # Return None values to indicate failure

def countTokens(text):
    encoding = tiktoken.get_encoding("cl100k_base")  # context_count = len(encoding.encode(yourtext))
    numoftokens = len(encoding.encode(text))
    return numoftokens
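# Assumption / sketch: create_chat() runs on every Streamlit rerun, so the model is
# reloaded after each interaction. Caching it with st.cache_resource would keep it in
# memory across reruns, e.g.:
#
#   @st.cache_resource
#   def load_chat():          # load_chat is a hypothetical helper name
#       return create_chat()
#
# (sharing one TextIteratorStreamer across sessions may need extra care, so treat
# this only as a sketch)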
# AVATARS - using emojis instead of images
av_us = "🤗"   # User avatar emoji
av_ass = "🤖"  # Assistant avatar emoji

nCTX = 8192
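# nCTX is only shown in the header below; it does not constrain generation in this script.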
### START STREAMLIT UI
# Create a header element - using markdown instead of an image
st.header(f"💎 {model_name} Chatbot")
st.markdown(f"> *💎 {model_name} with a {nCTX}-token context window* - Turn based chat available with a maximum capacity of :orange[**{st.session_state.maxTurns} messages**].", unsafe_allow_html=True)
st.markdown("#### Powered by OpenVINO")
# CREATE THE SIDEBAR - using markdown and text instead of images
with st.sidebar:
    st.subheader("Configuration")  # Sidebar header
    # st.image('images/banner.png', use_column_width=True)  # Removed image
    st.markdown("---")
    st.markdown("**Model Parameters**")
    st.session_state.temperature = st.slider('Temperature:', min_value=0.0, max_value=1.0, value=0.65, step=0.01)
    st.session_state.maxlength = st.slider('Length reply:', min_value=150, max_value=2000,
                                           value=550, step=50)
    st.session_state.repeat = st.slider('Repeat Penalty:', min_value=0.0, max_value=2.0, value=1.176, step=0.02)
    st.markdown("---")
    st.markdown("**Chat Options**")
    st.session_state.turns = st.toggle('Turn based', value=False, help='Activate Conversational Turn Chat with History',
                                       disabled=False, label_visibility="visible")
    st.markdown(f"*Number of Max Turns*: {st.session_state.maxTurns}")
    actualTurns = st.markdown("*Chat History Length*: :green[Good]")
    statspeed = st.markdown(f'💫 speed: {st.session_state.speed} t/s')
    btnClear = st.button("Clear History", type="primary", use_container_width=True)
    st.markdown("---")
    st.markdown("**Logs**")
    st.markdown(f"**Logfile**: {st.session_state.logfilename}")
tokenizer, ov_model, streamer = create_chat()

if tokenizer and ov_model and streamer:  # Only proceed if model loading was successful
    # Display chat messages from history on app rerun
    for message in st.session_state.chatMessages:
        if message["role"] == "user":
            with st.chat_message(message["role"], avatar=av_us):
                st.markdown(message["content"])
        else:
            with st.chat_message(message["role"], avatar=av_ass):
                st.markdown(message["content"])
    # Accept user input using a text_area inside a form
    with st.form(key='chat_form', clear_on_submit=False):  # clear_on_submit=False keeps the typed prompt in the form
        myprompt = st.text_area("What is an AI model?", key="prompt_input", height=100)  # text_area instead of chat_input
        submitted = st.form_submit_button("Send")  # Streamlit forms need a submit button to deliver their values

    if submitted and myprompt:  # generate only when the form is submitted with some text
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": myprompt})
        st.session_state.chatMessages.append({"role": "user", "content": myprompt})
        st.session_state.numOfTurns = len(st.session_state.messages)
        # Display user message in chat message container
        with st.chat_message("user", avatar=av_us):
            st.markdown(myprompt)
        usertext = f"user: {myprompt}"
        writehistory(st.session_state.logfilename, usertext)
        # Display assistant response in chat message container
        with st.chat_message("assistant", avatar=av_ass):
            message_placeholder = st.empty()
            with st.spinner("Thinking..."):
                start = datetime.datetime.now()
                conv_messages = []
                if st.session_state.turns:
                    # Turn based chat: keep at most the last maxTurns messages as context
                    if st.session_state.numOfTurns > st.session_state.maxTurns:
                        conv_messages = st.session_state.messages[-st.session_state.maxTurns:]
                        actualTurns.markdown("*Chat History Length*: :red[Trimmed]")
                    else:
                        conv_messages = st.session_state.messages
                else:
                    # Single turn chat: only the latest user message is sent to the model
                    conv_messages.append(st.session_state.messages[-1])
                model_inputs = tokenizer.apply_chat_template(conv_messages,
                                                             add_generation_prompt=True,
                                                             tokenize=True,
                                                             return_tensors="pt")
                generate_kwargs = dict(input_ids=model_inputs,
                                       max_new_tokens=st.session_state.maxlength,
                                       temperature=st.session_state.temperature,
                                       do_sample=True,
                                       top_p=0.5,
                                       repetition_penalty=st.session_state.repeat,
                                       streamer=streamer)
                # Run generation in a background thread and stream tokens as they arrive
                async def generate_response():
                    t1 = Thread(target=ov_model.generate, kwargs=generate_kwargs)
                    t1.start()
                    start_time = datetime.datetime.now()
                    partial_text = ""
                    full_response = ""
                    first_token = 0
                    for chunk in streamer:
                        if first_token == 0:
                            ttft = datetime.datetime.now() - start_time  # time to first token
                            first_token = 1
                        for char in chunk:
                            partial_text += char
                            message_placeholder.markdown(partial_text + "💡")
                            sleep(0.005)  # typewriter effect; lower the delay for faster display
                        full_response += chunk
                        # Update the live speed statistic in the sidebar
                        delta_time = datetime.datetime.now() - start_time
                        total_seconds = delta_time.total_seconds()
                        prompt_tokens = len(encoding.encode(myprompt))
                        assistant_tokens = len(encoding.encode(full_response))
                        total_tokens = prompt_tokens + assistant_tokens
                        st.session_state.speed = total_tokens / total_seconds
                        statspeed.markdown(f'💫 speed: {st.session_state.speed:.2f} t/s')
                    message_placeholder.markdown(full_response)  # Display only the response, without stats
                    asstext = f"assistant: {full_response}"
                    writehistory(st.session_state.logfilename, asstext)
                    st.session_state.messages.append({"role": "assistant", "content": full_response})
                    st.session_state.chatMessages.append({"role": "assistant", "content": full_response})  # Store just the response
                    st.session_state.numOfTurns = len(st.session_state.messages)
                asyncio.run(generate_response())  # run the async wrapper
    if btnClear:  # When the Clear History button is clicked
        st.session_state.messages = []
        st.session_state.chatMessages = []
        st.session_state.numOfTurns = 0
        st.rerun()  # Rerun the Streamlit app
else:
    st.error("Model initialization failed. Please check the logs for details.")