import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# =====================================================
# Environment setup
# =====================================================
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HOME"] = "/tmp/hf_home"

# =====================================================
# Model configuration
# =====================================================
GEN_MODEL = "hackergeek/qwen3-harrison-rag"
HF_TOKEN = os.getenv("HF_TOKEN")

if not HF_TOKEN:
    print("⚠️ No Hugging Face token found. Set one using:")
    print(" export HF_TOKEN='your_hf_token_here'")

# =====================================================
# Load private model
# =====================================================
def load_private_model(model_name, token):
    dtype_value = torch.float16 if torch.cuda.is_available() else torch.float32
    load_kwargs = {
        "dtype": dtype_value,
        "cache_dir": "/tmp/hf_cache",
        "low_cpu_mem_usage": True,
    }
    try:
        import accelerate
        load_kwargs["device_map"] = "auto"
    except ImportError:
        print("⚠️ `accelerate` not installed — default device placement used.")
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=token, **load_kwargs)
    return tokenizer, model

tokenizer, model = load_private_model(GEN_MODEL, token=HF_TOKEN)
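# Loading notes (general transformers/accelerate behaviour, not specific to this model):
# - with accelerate installed, device_map="auto" lets from_pretrained place the weights on
#   the available GPU(s), so model.device below already points at the right device;
# - without accelerate the model stays on CPU, and an explicit model.to("cuda") would be
#   needed to use a GPU;
# - older transformers releases spell the dtype argument torch_dtype= and may ignore an
#   unrecognised "dtype" key, so rename it there if the model loads in the wrong precision.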

# =====================================================
# Dynamic token allocation
# =====================================================
def calculate_max_tokens(query, min_tokens=1000, max_tokens=8192, factor=8):
    query_tokens = len(tokenizer(query)["input_ids"])
    dynamic_tokens = query_tokens * factor
    return min(max(dynamic_tokens, min_tokens), max_tokens)
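# Rough worked example (token counts are tokenizer-dependent, so the numbers are only
# illustrative): a short query of ~20 input ids asks for 20 * 8 = 160 new tokens and is
# raised to the 1000-token floor, a ~300-id query asks for 2400, and anything above
# 1024 ids is capped at the 8192-token ceiling.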

# =====================================================
# Generate long, complete, structured answers
# =====================================================
def generate_answer(query, history):
    # The Clear button resets the Chatbot value to None, so normalise it here
    history = history or []
    if not query.strip():
        return history, history

    # Correct common typos
    corrected_query = query.replace("COPP", "COPD")

    # Step 1: Rephrase for precise retrieval
    rephrase_prompt = (
        "You are a medical assistant. Rephrase this query for precise retrieval:\n\n"
        f"Query: {corrected_query}\n\nRephrased query:"
    )
    inputs = tokenizer(rephrase_prompt, return_tensors="pt").to(model.device)
    rephrased_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    rephrased_query = tokenizer.decode(
        rephrased_ids[0], skip_special_tokens=True
    ).split("Rephrased query:")[-1].strip()

    # Step 2: Generate detailed structured answer
    max_tokens = calculate_max_tokens(rephrased_query)
    prompt = (
        "You are a retrieval-augmented medical assistant. Provide a **long, detailed, structured** medical answer "
        "as if writing a concise clinical guideline. Use markdown headings and bullet points. "
        "Each section should include multiple complete sentences and clear explanations.\n\n"
        "Follow this structure:\n"
        "### Definition / Description\n"
        "### Epidemiology / Causes\n"
        "### Symptoms & Signs\n"
        "### Diagnosis / Investigations\n"
        "### Complications\n"
        "### Treatment & Management\n"
        "### Prognosis / Prevention\n"
        "### Key Notes / References\n\n"
        "At the end, include a **🩺 Quick Summary** with 3–5 key takeaways written in plain English "
        "that a non-medical reader could understand.\n\n"
        f"User query: {rephrased_query}\n\nAnswer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
    )
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    answer = output.split("Answer:")[-1].strip()

    # Clean up potential triple breaks
    while "\n\n\n" in answer:
        answer = answer.replace("\n\n\n", "\n\n")

    history = history + [(query, answer)]
    return history, history
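# Illustrative only, not executed as part of the app: the callback can be exercised
# outside Gradio with an empty history, which mirrors the first chat message, e.g.
#     history, _ = generate_answer("What are the main complications of COPD?", [])
#     print(history[-1][1])
# Note that the rephrase-then-answer round trip means two full generate passes per question.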

# =====================================================
# Gradio interface
# =====================================================
with gr.Blocks(title="Qwen3-Harrison-RAG Chatbot") as demo:
    gr.Markdown("""
    # 🧠 Qwen3-Harrison-RAG Medical Chatbot
    This model provides **guideline-style medical answers** with structured sections and a **Quick Summary**.
    *For educational and informational purposes only — not a substitute for professional medical advice.*
    """)

    chatbot = gr.Chatbot(height=480, show_label=False)

    with gr.Row():
        msg = gr.Textbox(placeholder="Ask a detailed medical question...", scale=4)
        clear = gr.Button("Clear", scale=1)

    msg.submit(generate_answer, [msg, chatbot], [chatbot, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)
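    # Wiring note: generate_answer returns the updated history twice because both entries in
    # the outputs list point at the same Chatbot component; a single [chatbot] output (and a
    # single return value) should behave the same. The Clear button sets the Chatbot back to
    # None, which is why generate_answer normalises a None history above.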

# =====================================================
# Launch
# =====================================================
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)), debug=True)