Spaces:
Running
Running
import html
from datetime import datetime

import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
# Custom CSS for UI.
# Kept in a named constant so the stylesheet can be read separately from the
# call that injects it into the page.
_CUSTOM_CSS = """
<style>
.main { background-color: #f9f9f9; padding: 20px; }
.stTextArea textarea {
border: 1px solid #ddd;
border-radius: 8px;
padding: 10px;
font-family: 'Roboto', sans-serif;
font-size: 16px;
background-color: #fff;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.stButton button {
background-color: #4a90e2;
color: white;
border-radius: 8px;
padding: 10px 20px;
font-family: 'Roboto', sans-serif;
font-size: 14px;
}
.stButton button:hover {
background-color: #357abd;
}
.code-output {
background-color: #2b2b2b;
color: #f0f0f0;
padding: 15px;
border-radius: 8px;
font-family: 'Courier New', monospace;
font-size: 14px;
margin-top: 10px;
}
.title {
font-family: 'Roboto', sans-serif;
font-size: 28px;
font-weight: bold;
color: #333;
margin-bottom: 10px;
}
.subtitle {
font-family: 'Roboto', sans-serif;
font-size: 16px;
color: #666;
margin-bottom: 20px;
}
.chat-message {
font-family: 'Roboto', sans-serif;
font-size: 16px;
color: #333;
margin-bottom: 5px;
}
</style>
"""

# unsafe_allow_html is required for the raw <style> tag to be emitted.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Cache model and tokenizer to avoid reloading
@st.cache_resource
def _load_checkpoint(checkpoint):
    """Load the tokenizer and causal-LM weights for ``checkpoint``.

    Decorated with ``st.cache_resource`` so the objects are created once and
    reused across script reruns instead of being reloaded on every widget
    interaction. Any exception propagates to the caller, so a failed load is
    not stored in the cache.

    Args:
        checkpoint: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    st.write("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    st.write("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    st.write("Model and tokenizer loaded successfully!")
    return tokenizer, model


def load_model_and_tokenizer():
    """Return the cached ``(tokenizer, model)`` pair for the code-generation model.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` after reporting
        the failure with ``st.error`` when loading raises.
    """
    checkpoint = "Salesforce/codegen-350M-mono"
    try:
        return _load_checkpoint(checkpoint)
    except Exception as e:
        # Handled outside the cached helper so the failure result is never
        # cached and the next rerun can retry the load.
        st.error(f"Failed to load model/tokenizer: {e}")
        return None, None
# Obtain the tokenizer/model pair used by the rest of the script.
tokenizer, model = load_model_and_tokenizer()
if any(component is None for component in (tokenizer, model)):
    # Loading failed (the loader already reported it); halt this script run.
    st.stop()
# Function to generate code
def generate_code(description, max_new_tokens=500):
    """Generate Python code for a natural-language task description.

    Args:
        description: Free-text description of the task, embedded in the prompt.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. Bounding new tokens (rather than total length) means a
            long description cannot use up the whole generation budget.

    Returns:
        The generated continuation with the prompt removed and surrounding
        whitespace stripped, or ``"Error: Could not generate code."`` if
        tokenization or generation raises (the error is also shown via
        ``st.error``).
    """
    prompt = f"Generate Python code for the following task: {description}\n"
    try:
        # Tokenization is inside the try so its failures are reported the same
        # way as generation failures.
        inputs = tokenizer(prompt, return_tensors="pt")
        prompt_token_count = inputs["input_ids"].shape[-1]
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Drop the prompt by token count: the decoded text is not guaranteed
        # to reproduce the prompt character-for-character, so slicing the
        # decoded string by len(prompt) can cut in the wrong place.
        completion_ids = outputs[0][prompt_token_count:]
        return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    except Exception as e:
        st.error(f"Error generating code: {e}")
        return "Error: Could not generate code."
# Seed an empty chat history the first time the script runs in a session;
# later reruns keep whatever has accumulated.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
# Page header followed by the input form and its two action buttons.
st.markdown('<div class="title">Code Generation Bot</div>', unsafe_allow_html=True)
st.markdown(
    '<div class="subtitle">Describe your task, and I’ll generate Python code for you!</div>',
    unsafe_allow_html=True,
)
with st.container():
    # Free-text task description entered by the user.
    description = st.text_area(
        "Enter your description here",
        placeholder="e.g., Write a function to calculate the factorial of a number",
        height=150,
    )
    generate_col, clear_col = st.columns([1, 1])
    with generate_col:
        if st.button("Generate"):
            if not description.strip():
                # Whitespace-only input: nothing to generate from.
                st.warning("Please enter a description first!")
            else:
                with st.spinner("Thinking..."):
                    generated_code = generate_code(description)
                    # Record the exchange so it is rendered in the history.
                    history_entry = {
                        "input": description,
                        "output": generated_code,
                        "time": datetime.now().strftime("%H:%M:%S"),
                    }
                    st.session_state.chat_history.append(history_entry)
    with clear_col:
        if st.button("Clear History"):
            st.session_state.chat_history = []
            st.success("Chat history cleared!")
# Display chat history
if st.session_state.chat_history:
    st.write("### Chat History")
    for chat in st.session_state.chat_history:
        # Both values are interpolated into raw HTML (unsafe_allow_html=True),
        # so escape them: otherwise '<', '>' and '&' in the description or in
        # the generated code are interpreted as markup.
        safe_input = html.escape(chat["input"])
        safe_output = html.escape(chat["output"])
        st.markdown(
            f'<div class="chat-message"><strong>You ({chat["time"]}):</strong> {safe_input}</div>',
            unsafe_allow_html=True,
        )
        # <pre> keeps the newlines and indentation of the generated code;
        # pre-wrap still lets long lines wrap inside the box.
        st.markdown(
            f'<pre class="code-output" style="white-space: pre-wrap;">{safe_output}</pre>',
            unsafe_allow_html=True,
        )
        st.markdown("---")
st.info("Tip: Check the generated code for accuracy before using it!")