import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# =====================================================
# Environment setup
# =====================================================
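# Redirect the Hugging Face caches to /tmp, which is assumed to be the writable
# location in the Spaces container running this app.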
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HOME"] = "/tmp/hf_home"
# =====================================================
# Model configuration
# =====================================================
GEN_MODEL = "hackergeek/qwen3-harrison-rag"
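# The model repo above is private, so a Hugging Face token with read access is required.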
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("⚠️ No Hugging Face token found. Set one using:")
    print("    export HF_TOKEN='your_hf_token_here'")
# =====================================================
# Load private model
# =====================================================
def load_private_model(model_name, token):
    # Use fp16 on GPU to halve memory; fall back to fp32 on CPU.
    dtype_value = torch.float16 if torch.cuda.is_available() else torch.float32
    load_kwargs = {
        "dtype": dtype_value,  # "dtype" is the newer spelling; older transformers releases expect "torch_dtype"
        "cache_dir": "/tmp/hf_cache",
        "low_cpu_mem_usage": True,
    }
    # device_map="auto" requires `accelerate`; without it the model loads on the default device.
    try:
        import accelerate  # noqa: F401
        load_kwargs["device_map"] = "auto"
    except ImportError:
        print("⚠️ `accelerate` not installed — default device placement used.")
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=token, **load_kwargs)
    return tokenizer, model
tokenizer, model = load_private_model(GEN_MODEL, token=HF_TOKEN)
# =====================================================
# Dynamic token allocation
# =====================================================
def calculate_max_tokens(query, min_tokens=1000, max_tokens=8192, factor=8):
    # Budget the reply length proportionally to the query length, clamped to [min_tokens, max_tokens].
    query_tokens = len(tokenizer(query)["input_ids"])
    dynamic_tokens = query_tokens * factor
    return min(max(dynamic_tokens, min_tokens), max_tokens)
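# Example: a 40-token query gives 40 * 8 = 320, below min_tokens, so 1000 new tokens
# are budgeted; a 2,000-token query would be capped at max_tokens (8192).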
# =====================================================
# Generate long, complete, structured answers
# =====================================================
def generate_answer(query, history):
    if not query.strip():
        return history, history

    # Correct common typos
    corrected_query = query.replace("COPP", "COPD")

    # Step 1: Rephrase for precise retrieval
    rephrase_prompt = (
        "You are a medical assistant. Rephrase this query for precise retrieval:\n\n"
        f"Query: {corrected_query}\n\nRephrased query:"
    )
    inputs = tokenizer(rephrase_prompt, return_tensors="pt").to(model.device)
    # Greedy decoding keeps the rephrasing deterministic.
    rephrased_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    rephrased_query = tokenizer.decode(
        rephrased_ids[0], skip_special_tokens=True
    ).split("Rephrased query:")[-1].strip()

    # Step 2: Generate detailed structured answer
    max_tokens = calculate_max_tokens(rephrased_query)
    prompt = (
        "You are a retrieval-augmented medical assistant. Provide a **long, detailed, structured** medical answer "
        "as if writing a concise clinical guideline. Use markdown headings and bullet points. "
        "Each section should include multiple complete sentences and clear explanations.\n\n"
        "Follow this structure:\n"
        "### Definition / Description\n"
        "### Epidemiology / Causes\n"
        "### Symptoms & Signs\n"
        "### Diagnosis / Investigations\n"
        "### Complications\n"
        "### Treatment & Management\n"
        "### Prognosis / Prevention\n"
        "### Key Notes / References\n\n"
        "At the end, include a **🩺 Quick Summary** with 3–5 key takeaways written in plain English "
        "that a non-medical reader could understand.\n\n"
        f"User query: {rephrased_query}\n\nAnswer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Sampled decoding with a repetition penalty suits long-form, guideline-style output.
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
    )
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    answer = output.split("Answer:")[-1].strip()

    # Collapse runs of three or more newlines into double newlines
    while "\n\n\n" in answer:
        answer = answer.replace("\n\n\n", "\n\n")

    history = history + [(query, answer)]
    return history, history
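# Optional sanity check (kept commented out so the Space only serves the UI);
# the sample question below is illustrative only:
#   hist, _ = generate_answer("What causes COPD exacerbations?", [])
#   print(hist[-1][1][:500])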
# =====================================================
# Gradio interface
# =====================================================
with gr.Blocks(title="Qwen3-Harrison-RAG Chatbot") as demo:
    gr.Markdown("""
# 🧠 Qwen3-Harrison-RAG Medical Chatbot
This model provides **guideline-style medical answers** with structured sections and a **Quick Summary**.
*For educational and informational purposes only — not a substitute for professional medical advice.*
""")
    chatbot = gr.Chatbot(height=480, show_label=False)
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask a detailed medical question...", scale=4)
        clear = gr.Button("Clear", scale=1)
    # generate_answer returns the updated history twice, matching the two declared outputs.
    msg.submit(generate_answer, [msg, chatbot], [chatbot, chatbot])
    # Returning None resets the Chatbot component.
    clear.click(lambda: None, None, chatbot, queue=False)
# =====================================================
# Launch
# =====================================================
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)), debug=True)