import streamlit as st
from streamlit_chat import message
from transformers import AutoTokenizer, AutoModelForCausalLM
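# Note: st.set_page_config must be the first Streamlit call in the script,
# before any other st.* command runs.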
# Set page configuration
st.set_page_config(
    page_title="ChatGPT-Style Chatbot",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom CSS for styling with beautiful colors
st.markdown("""
<style>
    .stApp {
        background-image: linear-gradient(135deg, #ffcccc 0%, #ff9999 100%);
    }
    .sidebar .sidebar-content {
        background-image: linear-gradient(135deg, #6B73FF 0%, #000DFF 100%);
        color: black;
    }
    .stTextInput>div>div>input {
        border-radius: 20px;
        padding: 10px 15px;
        border: 1px solid #d1d5db;
    }
    .stButton>button {
        border-radius: 20px;
        padding: 10px 25px;
        background-image: linear-gradient(to right, #6B73FF 0%, #000DFF 100%);
        color: black;
        border: none;
        font-weight: 500;
        transition: all 0.3s ease;
    }
    .stButton>button:hover {
        background-image: linear-gradient(to right, #000DFF 0%, #6B73FF 100%);
        transform: translateY(-2px);
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    }
    .chat-container {
        background-color: rgba(255, 230, 230, 0.95);
        border-radius: 15px;
        padding: 20px;
        box-shadow: 0 4px 15px rgba(0, 0, 0, 0.05);
        border: 1px solid #e5e7eb;
    }
    .title {
        color: #2d3748;
        text-align: center;
        margin-bottom: 30px;
        font-weight: 600;
    }
    .stSelectbox>div>div>select {
        border-radius: 12px;
        padding: 8px 12px;
    }
    .stSlider>div>div>div>div {
        background-color: #6B73FF;
    }
    .st-expander {
        border-radius: 12px;
        border: 1px solid #e5e7eb;
    }
    .stMarkdown h1 {
        color: #2d3748;
    }
</style>
""", unsafe_allow_html=True)
# Sidebar
with st.sidebar:
    st.title("⚙️ Chatbot Settings")
    st.markdown("""
    ### ✨ About
    This is a ChatGPT-style chatbot powered by an open-source LLM from the Hugging Face Hub.
    """)
    # Model selection
    model_name = st.selectbox(
        "Choose a model",
        ["gpt2", "microsoft/DialoGPT-medium", "facebook/blenderbot-400M-distill"],
        index=1
    )
    # Advanced settings
    with st.expander("🔧 Advanced Settings"):
        max_length = st.slider("Max response length", 50, 500, 100)
        temperature = st.slider("Temperature", 0.1, 1.0, 0.7)
        top_p = st.slider("Top-p", 0.1, 1.0, 0.9)
    st.markdown("---")
    st.markdown("🚀 Built with ❤️ using [Streamlit](https://streamlit.io/) and [Hugging Face](https://huggingface.co/)")
# Initialize chat history
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'model' not in st.session_state:
    st.session_state['model'] = None
if 'tokenizer' not in st.session_state:
    st.session_state['tokenizer'] = None
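# @st.cache_resource memoizes the load below per model_name, so a given
# checkpoint is downloaded and instantiated only once per server process.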
# Load model
@st.cache_resource
def load_model(model_name):
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None
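# Caveat: AutoModelForCausalLM is meant for decoder-only checkpoints (gpt2,
# DialoGPT). facebook/blenderbot-400M-distill is a sequence-to-sequence model;
# loading it through a causal-LM head either fails or drops the encoder, so
# AutoModelForSeq2SeqLM (with a different prompting scheme) would be the better fit.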
# Generate response
def generate_response(prompt):
    if st.session_state['model'] is None or st.session_state['tokenizer'] is None:
        return "Model not loaded. Please try again."
    try:
        # Build the conversation history as plain-text context
        history = "\n".join(
            f"User: {p}\nBot: {g}"
            for p, g in zip(st.session_state['past'], st.session_state['generated'])
        )
        full_prompt = f"{history}\nUser: {prompt}\nBot:"
        # Generate response; do_sample=True is required for temperature/top_p to have any effect
        inputs = st.session_state['tokenizer'].encode(full_prompt, return_tensors="pt")
        outputs = st.session_state['model'].generate(
            inputs,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=st.session_state['tokenizer'].eos_token_id
        )
        response = st.session_state['tokenizer'].decode(outputs[0], skip_special_tokens=True)
        # Keep only the text generated after the final "Bot:" marker
        return response.split("Bot:")[-1].strip()
    except Exception as e:
        return f"Error generating response: {e}"
# Main app
st.title("💬 ChatGPT-Style Chatbot")
st.markdown("""
<div class='title'>
Experience a conversation with an open-source LLM chatbot
</div>
""", unsafe_allow_html=True)
# Container for chat
chat_container = st.container()
# Load model button
if st.button("🚀 Load Model"):
    with st.spinner(f"Loading {model_name}..."):
        st.session_state['model'], st.session_state['tokenizer'] = load_model(model_name)
    # Only report success if loading actually worked (load_model shows its own error otherwise)
    if st.session_state['model'] is not None:
        st.success(f"Model {model_name} loaded successfully!")
# Display chat
with chat_container:
    if st.session_state['generated']:
        for i in range(len(st.session_state['generated'])):
            message(st.session_state['past'][i], is_user=True, key=str(i) + '_user', avatar_style="identicon")
            message(st.session_state['generated'][i], key=str(i), avatar_style="bottts")
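# The unique `key` arguments keep streamlit_chat from colliding component state
# when the same widget type is rendered once per chat turn.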
# User input
with st.form(key='chat_form', clear_on_submit=True):
    user_input = st.text_input("You:", key='input', placeholder="Type your message here...")
    submit_button = st.form_submit_button(label='Send ➤')

if submit_button and user_input:
    if st.session_state['model'] is None or st.session_state['tokenizer'] is None:
        st.warning("⚠️ Please load the model first!")
    else:
        # Add user message to chat history
        st.session_state['past'].append(user_input)
        # Generate response
        with st.spinner("🤔 Thinking..."):
            response = generate_response(user_input)
        # Add bot response to chat history
        st.session_state['generated'].append(response)
        # Rerun to update the chat display (st.rerun replaces the deprecated st.experimental_rerun)
        st.rerun()
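# To run locally (assumed environment): pip install streamlit streamlit-chat
# transformers torch, then launch with `streamlit run app.py`.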