# HARRISON_GPT / app.py
import os
import torch
import gradio as gr
import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from inspect import signature
# =====================================================
# OPTION: Use ephemeral /tmp cache
# =====================================================
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_datasets"
os.environ["HF_MODULES_CACHE"] = "/tmp/hf_modules"
# =====================================================
# 1️⃣ Model setup
# =====================================================
GEN_MODEL_PRIVATE = "hackergeek/qwen3-harrison-rag"
GEN_MODEL_PUBLIC = "Qwen/Qwen2.5-1.5B-Instruct"
EMB_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
print("⚠️ No Hugging Face token found. Private models may fail to load.")
# --- Check if accelerate is available ---
try:
    import accelerate
    accelerate_available = True
except ImportError:
    accelerate_available = False
    print("⚠️ `accelerate` not installed. Large private models with device_map='auto' may fail.")
# --- Helper to load model safely ---
def load_model(model_name, token=None):
    dtype_value = torch.float16 if torch.cuda.is_available() else torch.float32
    try:
        # Newer transformers releases accept `dtype`; older ones expect `torch_dtype`.
        param_names = signature(AutoModelForCausalLM.from_pretrained).parameters
        dtype_arg = "dtype" if "dtype" in param_names else "torch_dtype"
        load_kwargs = {
            dtype_arg: dtype_value,
            "cache_dir": "/tmp/hf_cache",
            "low_cpu_mem_usage": True,
        }
        if accelerate_available:
            load_kwargs["device_map"] = "auto"
        if token:
            load_kwargs["token"] = token
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        return tokenizer, model
    except Exception as e:
        raise RuntimeError(f"Failed to load model '{model_name}': {e}")
# --- Attempt to load private model, fallback to public ---
try:
    tokenizer, model = load_model(GEN_MODEL_PRIVATE, token=HF_TOKEN)
    print(f"✅ Loaded private model: {GEN_MODEL_PRIVATE}")
except Exception as e:
    print(f"❌ {e}\n➡️ Falling back to public model: {GEN_MODEL_PUBLIC}")
    tokenizer, model = load_model(GEN_MODEL_PUBLIC)
    print(f"✅ Loaded public model: {GEN_MODEL_PUBLIC}")
# --- Load embedding model ---
embedder = SentenceTransformer(EMB_MODEL, cache_folder="/tmp/hf_cache")
# =====================================================
# 2️⃣ Retrieval + generation logic
# =====================================================
# Placeholder FAISS index and chunks (replace with your actual documents)
index = faiss.IndexFlatL2(384)
chunks = ["This is a sample context chunk. Replace with real documents."]
def retrieve_context(query, k=5):
    q_emb = embedder.encode([query], convert_to_numpy=True)
    if index.ntotal == 0:
        return "No context available (index empty)."
    D, I = index.search(q_emb, k)
    # FAISS pads the result with -1 when fewer than k vectors are indexed; skip those slots.
    return "\n\n".join(chunks[i] for i in I[0] if i >= 0)
def generate_response(query, history):
    context = retrieve_context(query)
    system_prompt = (
        "You are a helpful assistant that uses the retrieved context to answer questions.\n\n"
        f"Context:\n{context}\n\n"
        f"User: {query}\nAssistant:"
    )
    inputs = tokenizer(system_prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=300)
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output.split("Assistant:")[-1].strip()
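# Optional alternative (a sketch, not part of the original app): instruct-tuned Qwen
# checkpoints generally expect their chat template rather than a hand-built prompt
# string. Something along these lines could replace the manual prompt above:
#
#     messages = [
#         {"role": "system", "content": f"Use the retrieved context to answer.\n\nContext:\n{context}"},
#         {"role": "user", "content": query},
#     ]
#     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
#     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)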
def chat_fn(user_message, history):
    response = generate_response(user_message, history)
    history = history + [(user_message, response)]
    return history, history
# =====================================================
# 3️⃣ Gradio UI
# =====================================================
with gr.Blocks(title="Qwen3-Harrison-RAG Chatbot") as demo:
    gr.Markdown("""
# 🤖 Qwen3-Harrison-RAG Chatbot
Ask me anything — I’ll retrieve relevant context and answer!
""")
    chatbot = gr.Chatbot(height=400)
    with gr.Row():
        msg = gr.Textbox(placeholder="Type your message here...", scale=4)
        clear = gr.Button("Clear", scale=1)
    msg.submit(chat_fn, [msg, chatbot], [chatbot, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)
# =====================================================
# 4️⃣ Launch
# =====================================================
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))