Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from streamlit_chat import message | |
| import time | |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
# Page configuration: browser tab title/icon, wide layout, sidebar open.
# Placed before any other st.* output in this script.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="ChatGPT-Style Chatbot",
    page_icon="π€",
)
# Custom stylesheet for the app background, sidebar, text input, buttons,
# chat container, headings, selectbox, slider and expander. It is injected as
# raw HTML, which is why unsafe_allow_html=True is required.
_CUSTOM_CSS = """
<style>
.stApp {
background-image: linear-gradient(135deg, #ffcccc 0%, #ff9999 100%);
}
.sidebar .sidebar-content {
background-image: linear-gradient(135deg, #6B73FF 0%, #000DFF 100%);
color: black;
}
.stTextInput>div>div>input {
border-radius: 20px;
padding: 10px 15px;
border: 1px solid #d1d5db;
}
.stButton>button {
border-radius: 20px;
padding: 10px 25px;
background-image: linear-gradient(to right, #6B73FF 0%, #000DFF 100%);
color: black;
border: none;
font-weight: 500;
transition: all 0.3s ease;
}
.stButton>button:hover {
background-image: linear-gradient(to right, #000DFF 0%, #6B73FF 100%);
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.chat-container {
background-color: rgba(255, 230, 230, 0.95);
border-radius: 15px;
padding: 20px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.05);
border: 1px solid #e5e7eb;
}
.title {
color: #2d3748;
text-align: center;
margin-bottom: 30px;
font-weight: 600;
}
.stSelectbox>div>div>select {
border-radius: 12px;
padding: 8px 12px;
}
.stSlider>div>div>div>div {
background-color: #6B73FF;
}
.st-expander {
border-radius: 12px;
border: 1px solid #e5e7eb;
}
.stMarkdown h1 {
color: #2d3748;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Sidebar: short description, model picker and generation controls.
# model_name, max_length, temperature and top_p are module-level globals that
# the loading/generation code further down reads directly.
_MODEL_CHOICES = ["gpt2", "microsoft/DialoGPT-medium", "facebook/blenderbot-400M-distill"]

with st.sidebar:
    st.title("βοΈ Chatbot Settings")
    st.markdown("""
### β¨ About
This is a ChatGPT-style chatbot powered by a fine-tuned LLM.
""")

    # Default selection is index 1, i.e. "microsoft/DialoGPT-medium".
    model_name = st.selectbox("Choose a model", _MODEL_CHOICES, index=1)

    # Sampling controls, collapsed by default inside an expander.
    with st.expander("π§ Advanced Settings"):
        max_length = st.slider("Max response length", min_value=50, max_value=500, value=100)
        temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7)
        top_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.9)

    st.markdown("---")
    st.markdown("π Built with β€οΈ using [Streamlit](https://streamlit.io/) and [Hugging Face](https://huggingface.co/)")
# Initialise per-session keys only when missing, so values set on an earlier
# run of the script are kept. 'past' holds user turns, 'generated' holds bot
# replies (index-aligned), 'model'/'tokenizer' stay None until loaded.
for _state_key, _initial_value in (
    ('generated', []),
    ('past', []),
    ('model', None),
    ('tokenizer', None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
def load_model(model_name):
    """Load a causal language model and its tokenizer from ``model_name``.

    ``model_name`` is passed straight to ``from_pretrained`` (e.g. a Hugging
    Face hub id such as "gpt2").

    Returns:
        ``(model, tokenizer)`` on success. On any exception the error is shown
        in the page with ``st.error`` and ``(None, None)`` is returned.
    """
    try:
        tok = AutoTokenizer.from_pretrained(model_name)
        lm = AutoModelForCausalLM.from_pretrained(model_name)
    except Exception as exc:
        # Best-effort: report in the UI rather than crash the script run.
        st.error(f"Error loading model: {exc}")
        return None, None
    return lm, tok
def generate_response(prompt):
    """Generate the bot's reply to ``prompt`` with the model in session state.

    The model input is the prior transcript rendered as alternating
    ``User: ...`` / ``Bot: ...`` lines, followed by the new user turn and an
    open ``Bot:`` cue. Generation settings come from the sidebar globals
    ``max_length`` (reply length in tokens), ``temperature`` and ``top_p``.

    Returns:
        The reply text. When no model is loaded, or generation raises, a
        human-readable message string is returned instead (never raises).
    """
    model = st.session_state['model']
    tokenizer = st.session_state['tokenizer']
    if model is None or tokenizer is None:
        return "Model not loaded. Please try again."
    try:
        # zip() stops at the shorter list, so the current user turn (already
        # appended to 'past' by the caller, with no reply yet) is not repeated.
        turns = [
            f"User: {p}\nBot: {g}"
            for p, g in zip(st.session_state['past'], st.session_state['generated'])
        ]
        turns.append(f"User: {prompt}\nBot:")
        full_prompt = "\n".join(turns)

        encoded = tokenizer(full_prompt, return_tensors="pt")
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # Keep only the most recent tokens so prompt + reply fits within the
        # model's position limit; longer inputs make generate() fail.
        context_limit = getattr(model.config, "max_position_embeddings", None)
        if context_limit:
            keep = max(context_limit - max_length, 1)
            input_ids = input_ids[:, -keep:]
            attention_mask = attention_mask[:, -keep:]
        prompt_len = input_ids.shape[-1]

        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=prompt_len + max_length,
            # Without do_sample=True generate() decodes greedily and ignores
            # temperature/top_p, so the sidebar sliders would have no effect.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, then drop any further
        # "User:" turn the model invents after its own reply.
        reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return reply.split("User:", 1)[0].strip()
    except Exception as e:
        return f"Error generating response: {e}"
# Main page heading, followed by a subtitle rendered as raw HTML so it picks up
# the `.title` class from the custom stylesheet (centred, bold, dark grey).
_SUBTITLE_HTML = """
<div class='title'>
Experience a conversation with our fine-tuned LLM chatbot
</div>
"""

st.title("π¬ ChatGPT-Style Chatbot")
st.markdown(_SUBTITLE_HTML, unsafe_allow_html=True)
# Reserve a container for the transcript here, so it renders above the load
# button and the input form even though it is filled later in the script.
chat_container = st.container()

# Load (or reload) the model currently selected in the sidebar.
if st.button("π Load Model"):
    with st.spinner(f"Loading {model_name}..."):
        loaded_model, loaded_tokenizer = load_model(model_name)
    st.session_state['model'] = loaded_model
    st.session_state['tokenizer'] = loaded_tokenizer
    # load_model returns (None, None) after reporting its own error, so only
    # announce success when both objects actually loaded.
    if loaded_model is not None and loaded_tokenizer is not None:
        st.success(f"Model {model_name} loaded successfully!")
# Render the transcript: for each bot reply, show the matching user turn
# first (same index in 'past'), then the reply.
with chat_container:
    past_turns = st.session_state['past']
    for idx, bot_reply in enumerate(st.session_state['generated']):
        message(past_turns[idx], is_user=True, key=f"{idx}_user", avatar_style="identicon")
        message(bot_reply, key=f"{idx}", avatar_style="bottts")
# Message entry form; the text box is cleared after each submit.
with st.form(key='chat_form', clear_on_submit=True):
    user_input = st.text_input("You:", key='input', placeholder="Type your message here...")
    submit_button = st.form_submit_button(label='Send β€')

# Ignore empty or whitespace-only submissions.
if submit_button and (user_input or "").strip():
    if st.session_state['model'] is None or st.session_state['tokenizer'] is None:
        st.warning("β οΈ Please load the model first!")
    else:
        # Record the user turn before generating; generate_response relies on
        # 'past' being one entry ahead of 'generated' at this point.
        st.session_state['past'].append(user_input)
        with st.spinner("π€ Thinking..."):
            response = generate_response(user_input)
        st.session_state['generated'].append(response)
        # Rerun so the transcript container (rendered earlier in this run)
        # shows the new turn. st.experimental_rerun is deprecated in favour of
        # st.rerun in newer Streamlit; use whichever this version provides.
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()