import os
import torch
import gradio as gr
import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from inspect import signature

# =====================================================
# OPTION: Use ephemeral /tmp cache
# =====================================================
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_datasets"
os.environ["HF_MODULES_CACHE"] = "/tmp/hf_modules"

# =====================================================
# 1️⃣ Model setup
# =====================================================
GEN_MODEL_PRIVATE = "hackergeek/qwen3-harrison-rag"
GEN_MODEL_PUBLIC = "Qwen/Qwen2.5-1.5B-Instruct"
EMB_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("⚠️ No Hugging Face token found. Private models may fail to load.")

# --- Check if accelerate is available ---
try:
    import accelerate
    accelerate_available = True
except ImportError:
    accelerate_available = False
    print("⚠️ `accelerate` not installed. Large private models with device_map='auto' may fail.")

# --- Helper to load model safely ---
def load_model(model_name, token=None):
    dtype_value = torch.float16 if torch.cuda.is_available() else torch.float32
    try:
        # Newer transformers releases accept `dtype`; older ones expect `torch_dtype`.
        param_names = signature(AutoModelForCausalLM.from_pretrained).parameters
        dtype_arg = "dtype" if "dtype" in param_names else "torch_dtype"
        load_kwargs = {
            dtype_arg: dtype_value,
            "cache_dir": "/tmp/hf_cache",
            "low_cpu_mem_usage": True,
        }
        if accelerate_available:
            load_kwargs["device_map"] = "auto"
        if token:
            load_kwargs["token"] = token
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        return tokenizer, model
    except Exception as e:
        raise RuntimeError(f"Failed to load model '{model_name}': {e}")

# --- Attempt to load private model, fall back to public ---
try:
    tokenizer, model = load_model(GEN_MODEL_PRIVATE, token=HF_TOKEN)
    print(f"✅ Loaded private model: {GEN_MODEL_PRIVATE}")
except Exception as e:
    print(f"❌ {e}\n➡️ Falling back to public model: {GEN_MODEL_PUBLIC}")
    tokenizer, model = load_model(GEN_MODEL_PUBLIC)
    print(f"✅ Loaded public model: {GEN_MODEL_PUBLIC}")

# --- Load embedding model ---
embedder = SentenceTransformer(EMB_MODEL, cache_folder="/tmp/hf_cache")

# =====================================================
# 2️⃣ Retrieval + generation logic
# =====================================================
# Placeholder FAISS index and chunks (replace with your actual documents)
index = faiss.IndexFlatL2(384)  # all-MiniLM-L6-v2 embeddings are 384-dimensional
chunks = ["This is a sample context chunk. Replace with real documents."]
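
# --- Illustrative sketch (not part of the original Space) ---
# A minimal example of how the placeholder index could be populated with real
# documents. `build_index` is a hypothetical helper name; pass it your own list
# of document strings.
def build_index(documents):
    """Encode document strings and store them in the FAISS index."""
    global chunks
    chunks = list(documents)
    embeddings = embedder.encode(chunks, convert_to_numpy=True).astype(np.float32)
    index.reset()          # drop any previous (placeholder) vectors
    index.add(embeddings)  # IndexFlatL2 expects float32 vectors of dimension 384

# Example usage: build_index(["First document text ...", "Second document text ..."])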

def retrieve_context(query, k=5):
    if index.ntotal == 0:
        return "No context available (index empty)."
    q_emb = embedder.encode([query], convert_to_numpy=True).astype(np.float32)
    # Never ask FAISS for more neighbours than the index holds: missing slots
    # come back as -1, which would silently index chunks[-1].
    D, I = index.search(q_emb, min(k, index.ntotal))
    return "\n\n".join(chunks[i] for i in I[0])

def generate_response(query, history):
    context = retrieve_context(query)
    system_prompt = (
        "You are a helpful assistant that uses the retrieved context to answer questions.\n\n"
        f"Context:\n{context}\n\n"
        f"User: {query}\nAssistant:"
    )
    inputs = tokenizer(system_prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=300)
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output.split("Assistant:")[-1].strip()
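
# Note (assumption, not in the original Space): instruction-tuned Qwen checkpoints
# ship a chat template, so the prompt above could alternatively be built with
# tokenizer.apply_chat_template, roughly:
#
#     messages = [
#         {"role": "system", "content": f"Use the retrieved context to answer.\n\nContext:\n{context}"},
#         {"role": "user", "content": query},
#     ]
#     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
#
# The plain-text prompt used here also works; this is just an optional variant.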

def chat_fn(user_message, history):
    response = generate_response(user_message, history)
    history = history + [(user_message, response)]
    return history

# =====================================================
# 3️⃣ Gradio UI
# =====================================================
with gr.Blocks(title="Qwen3-Harrison-RAG Chatbot") as demo:
    gr.Markdown("""
    # 🤖 Qwen3-Harrison-RAG Chatbot
    Ask me anything — I’ll retrieve relevant context and answer!
    """)
    chatbot = gr.Chatbot(height=400)
    with gr.Row():
        msg = gr.Textbox(placeholder="Type your message here...", scale=4)
        clear = gr.Button("Clear", scale=1)
    msg.submit(chat_fn, [msg, chatbot], chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)

# =====================================================
# 4️⃣ Launch
# =====================================================
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))