Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, File, UploadFile, HTTPException, Form | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from huggingface_hub import snapshot_download | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import os | |
import logging | |
from dotenv import load_dotenv | |
import base64 | |
import io | |
import re | |
# Set up logging to track application behavior and debug issues
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load Hugging Face API token from environment variables.
# NOTE(review): `load_dotenv` is imported but never called in the visible code,
# so a local .env file is not read here — confirm whether that is intended.
API_TOKEN = os.getenv("HF_TOKEN")
if not API_TOKEN:
    # The message names the variable that is actually read above (HF_TOKEN).
    raise ValueError("HF_TOKEN environment variable not set. Set it in Space secrets.")

# Initialize FastAPI application
app = FastAPI()

# Enable CORS to allow frontend-backend communication from any origin.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive policy — confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount the 'static' directory to serve frontend assets (e.g., index.html, script.js, style.css)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Configure the Hugging Face model for code generation
MODEL_NAME = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
model_dir = "./qwen_model"

# Log the working directory and model path for debugging
logger.info(f"Current working directory: {os.getcwd()}")
logger.info(f"Model directory path: {os.path.abspath(model_dir)}")

# Download the model if it doesn't exist
try:
    os.makedirs(model_dir, exist_ok=True)
    if not os.listdir(model_dir):  # Only download if directory is empty
        logger.info(f"Downloading model {MODEL_NAME} to {model_dir}")
        snapshot_download(repo_id=MODEL_NAME, token=API_TOKEN, local_dir=model_dir)
        logger.info(f"Model downloaded. Directory contents: {os.listdir(model_dir)}")
    else:
        logger.info(f"Model directory {model_dir} already contains files: {os.listdir(model_dir)}")
except Exception as e:
    logger.error(f"Failed to download model: {str(e)}")
    # Chain the original exception so the root cause stays in the traceback.
    raise ValueError(f"Model download failed: {str(e)}") from e

# Load the model and tokenizer for code generation
logger.info(f"Loading model from {model_dir}")
tokenizer = AutoTokenizer.from_pretrained(model_dir, token=API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(model_dir, token=API_TOKEN)

# Set pad_token_id to eos_token_id to avoid tokenizer warnings
tokenizer.pad_token_id = tokenizer.eos_token_id

# Create a text generation pipeline using the loaded model
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)  # CPU for free Space, change to 0 for GPU

# Create directory for uploaded Excel files
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
# Endpoint to handle Excel file uploads
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def upload_file(file: UploadFile = File(...)):
    """Save an uploaded .xlsx file into UPLOAD_DIR.

    Args:
        file: The uploaded file; its client-supplied name must end in ".xlsx".

    Returns:
        A dict ``{"filename": <stored name>}`` where the stored name is the
        final path component of the client-supplied filename.

    Raises:
        HTTPException: 400 if the name is missing or does not end in ".xlsx".
    """
    # The filename comes from the client. Keep only its last path component
    # (treating backslashes as separators too) so that a name such as
    # "../../app.xlsx" cannot be written outside UPLOAD_DIR.
    client_name = file.filename or ""
    safe_name = os.path.basename(client_name.replace("\\", "/"))

    # Ensure the uploaded file is an Excel file (.xlsx)
    if not safe_name.endswith(".xlsx"):
        raise HTTPException(status_code=400, detail="File must be an Excel file (.xlsx)")

    # Save the file to the uploads directory
    file_path = os.path.join(UPLOAD_DIR, safe_name)
    with open(file_path, "wb") as buffer:
        buffer.write(await file.read())

    logger.info(f"File uploaded: {safe_name}")
    return {"filename": safe_name}
# Endpoint to generate a visualization based on a user prompt
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.

# Generated lines matching any of these are removed before execution: external
# data loading, plt.show(), and statements that rebind the name `df` itself.
# The `df` pattern is anchored to the start of the statement so that lines like
# `new_df = ...` or `df[df == 1]` are kept.
_DISALLOWED_LINE_PATTERNS = (
    re.compile(r"pd\.read_csv"),
    re.compile(r"pd\.read_excel"),
    re.compile(r"plt\.show"),
    re.compile(r"^df\s*=(?!=)"),
)


def _filter_code_lines(code, drop_patterns=()):
    """Return *code* without blank lines, comment-only lines and dropped lines.

    Leading indentation of the kept lines is preserved (only indentation common
    to all kept lines is removed), so block structure such as loops and
    conditionals stays valid Python.
    """
    import textwrap

    kept = []
    for line in code.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        if any(pattern.search(stripped) for pattern in drop_patterns):
            continue
        kept.append(line.rstrip())
    return textwrap.dedent("\n".join(kept))


async def generate_visualization(prompt: str = Form(...), filename: str = Form(...)):
    """Generate plotting code with the model, run it on the uploaded file, and return the plot.

    Args:
        prompt: The user's description of the desired visualization.
        filename: Name of a previously uploaded file inside UPLOAD_DIR.

    Returns:
        A dict with keys ``plot_base64`` (base64 PNG string or None),
        ``generated_code`` (code shown to the user) and ``error`` (message or None).

    Raises:
        HTTPException: 404 if the file is not found in UPLOAD_DIR, 400 if the
            Excel file cannot be read or is empty.
    """
    # Only the final path component is used, so the lookup cannot leave UPLOAD_DIR.
    safe_name = os.path.basename(filename.replace("\\", "/"))
    file_path = os.path.join(UPLOAD_DIR, safe_name)
    if not safe_name or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found on server.")

    # Load the Excel file into a pandas DataFrame
    try:
        df = pd.read_excel(file_path)
        if df.empty:
            raise ValueError("Excel file is empty.")
        logger.info(f"DataFrame columns: {df.columns.tolist()}")
        logger.info(f"DataFrame preview:\n{df.head().to_string()}")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error reading Excel file: {str(e)}") from e

    # Create a prompt for the model, specifying the DataFrame and visualization requirements.
    # Column labels are converted with str() because Excel headers may be non-string
    # (e.g. numbers or dates), which str.join would reject.
    # NOTE(review): the first bullet asks the model to use pd.read_excel while also
    # forbidding it, and such lines are filtered out below — consider rewording.
    input_text = f"""
Given the DataFrame 'df' with columns {', '.join(map(str, df.columns))} and preview:
{df.head().to_string()}
Write Python code to: {prompt}
- Use ONLY 'df =pd.read_excel({filename})' (no external data loading like pd.read_csv, pd.read_excel, or creating a new DataFrame).
- Use pandas (pd), matplotlib.pyplot (plt), or seaborn (sns).
- Include axis labels and a title.
- Output ONLY executable Python code. Do NOT include triple quotes, prose, Markdown, or text like 'Hint', 'Solution', or 'Here is the code'.
"""

    # Generate code using the model
    try:
        generated = generator(input_text, max_new_tokens=500, num_return_sequences=1)
        # The pipeline output includes the prompt; remove it to keep only the completion.
        generated_code = generated[0]["generated_text"].replace(input_text, "").strip()
        logger.info(f"Generated code:\n{generated_code}")
    except Exception as e:
        logger.error(f"Error querying model: {str(e)}")
        return {
            "plot_base64": None,
            "generated_code": "",
            "error": f"Error querying model: {str(e)}"
        }

    # Handle empty generated code
    if not generated_code:
        return {
            "plot_base64": None,
            "generated_code": "",
            "error": "No code generated by the AI model."
        }

    # Extract code block between ```python and ```, strictly requiring a valid code block
    code_block_pattern = r"```python\n(.*?)\n```"
    first_block = re.search(code_block_pattern, generated_code, re.DOTALL)
    if first_block is None:
        logger.error("No valid Python code block found in generated output.")
        return {
            "plot_base64": None,
            "generated_code": generated_code,
            "error": "No valid Python code block found in generated output."
        }

    # Take the first code block for execution and display
    raw_code_block = first_block.group(1).strip()
    logger.info(f"Raw code block:\n{raw_code_block}")

    # Code for execution: comments, blank lines, data loading, plt.show() and df
    # redefinition are removed; indentation is kept so nested blocks still parse.
    executable_code = _filter_code_lines(raw_code_block, _DISALLOWED_LINE_PATTERNS)

    # Code for display: only comments and blank lines are removed.
    display_code = _filter_code_lines(raw_code_block)
    logger.info(f"Display code (comments removed):\n{display_code}")

    # Handle empty code after cleaning for execution
    if not executable_code:
        logger.error("No valid executable code after cleaning.")
        return {
            "plot_base64": None,
            "generated_code": display_code,
            "error": "Generated code was invalid (e.g., included data loading, df redefinition, or was empty)."
        }
    logger.info(f"Executable code:\n{executable_code}")

    # Execute the code and generate the plot.
    # SECURITY: exec() runs model-generated code, steered by the user-supplied
    # prompt, inside the server process. The line filter above is not a security
    # boundary; run this in an isolated sandbox if inputs are untrusted.
    try:
        plt.close("all")  # start from a clean state, independent of earlier requests
        exec_globals = {"pd": pd, "plt": plt, "sns": sns, "df": df}
        exec(executable_code, exec_globals)

        # Save the plot to a BytesIO buffer (no disk storage)
        buffer = io.BytesIO()
        plt.savefig(buffer, format="png", bbox_inches="tight")

        # Encode the plot as base64 for frontend display
        plot_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        logger.error(f"Error executing code:\n{executable_code}\nException: {str(e)}")
        return {
            "plot_base64": None,
            "generated_code": display_code,
            "error": f"Error executing code: {str(e)}"
        }
    finally:
        # Always release figures, including when the generated code raised.
        plt.close("all")

    # Return the plot, display code (without comments), and any error message
    return {
        "plot_base64": plot_base64,
        "generated_code": display_code,
        "error": None
    }
# Serve the frontend HTML
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def serve_frontend():
    """Return the contents of static/index.html as an HTML response."""
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default text encoding.
    with open("static/index.html", "r", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())