# HARRISON_GPT / app.py
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# =====================================================
# Environment setup
# =====================================================
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HOME"] = "/tmp/hf_home"
# =====================================================
# Model configuration
# =====================================================
GEN_MODEL = "hackergeek/qwen3-harrison-rag"
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("⚠️ No Hugging Face token found. Set one using:")
    print("   export HF_TOKEN='your_hf_token_here'")
# =====================================================
# Load private model
# =====================================================
def load_private_model(model_name, token):
    """Load the tokenizer and causal LM, using fp16 on GPU and fp32 on CPU."""
    dtype_value = torch.float16 if torch.cuda.is_available() else torch.float32
    load_kwargs = {
        "torch_dtype": dtype_value,  # widely supported name; newer transformers releases also accept "dtype"
        "cache_dir": "/tmp/hf_cache",
        "low_cpu_mem_usage": True,
    }
    # device_map="auto" requires the `accelerate` package; fall back to default placement without it.
    try:
        import accelerate  # noqa: F401
        load_kwargs["device_map"] = "auto"
    except ImportError:
        print("⚠️ `accelerate` not installed — default device placement used.")
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=token, **load_kwargs)
    return tokenizer, model
tokenizer, model = load_private_model(GEN_MODEL, token=HF_TOKEN)
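# Loaded once at import time so the weights are downloaded and placed before the Gradio UI starts serving requests.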
# =====================================================
# Dynamic token allocation
# =====================================================
def calculate_max_tokens(query, min_tokens=1000, max_tokens=8192, factor=8):
    """Scale the generation budget with query length, clamped to [min_tokens, max_tokens]."""
    query_tokens = len(tokenizer(query)["input_ids"])
    dynamic_tokens = query_tokens * factor
    return min(max(dynamic_tokens, min_tokens), max_tokens)
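# Example: a 50-token query gives 50 * 8 = 400, which is raised to min_tokens=1000;
# a 1,500-token query gives 12,000, which is capped at max_tokens=8192.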
# =====================================================
# Generate long, complete, structured answers
# =====================================================
def generate_answer(query, history):
    if not query.strip():
        return history

    # Correct common typos
    corrected_query = query.replace("COPP", "COPD")

    # Step 1: Rephrase for precise retrieval
    rephrase_prompt = (
        "You are a medical assistant. Rephrase this query for precise retrieval:\n\n"
        f"Query: {corrected_query}\n\nRephrased query:"
    )
    inputs = tokenizer(rephrase_prompt, return_tensors="pt").to(model.device)
    rephrased_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    rephrased_query = tokenizer.decode(
        rephrased_ids[0], skip_special_tokens=True
    ).split("Rephrased query:")[-1].strip()

    # Step 2: Generate detailed structured answer
    max_tokens = calculate_max_tokens(rephrased_query)
    prompt = (
        "You are a retrieval-augmented medical assistant. Provide a **long, detailed, structured** medical answer "
        "as if writing a concise clinical guideline. Use markdown headings and bullet points. "
        "Each section should include multiple complete sentences and clear explanations.\n\n"
        "Follow this structure:\n"
        "### Definition / Description\n"
        "### Epidemiology / Causes\n"
        "### Symptoms & Signs\n"
        "### Diagnosis / Investigations\n"
        "### Complications\n"
        "### Treatment & Management\n"
        "### Prognosis / Prevention\n"
        "### Key Notes / References\n\n"
        "At the end, include a **🩺 Quick Summary** with 3–5 key takeaways written in plain English "
        "that a non-medical reader could understand.\n\n"
        f"User query: {rephrased_query}\n\nAnswer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
    )
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    answer = output.split("Answer:")[-1].strip()

    # Clean up potential triple line breaks
    while "\n\n\n" in answer:
        answer = answer.replace("\n\n\n", "\n\n")

    history = history + [(query, answer)]
    return history
# =====================================================
# Gradio interface
# =====================================================
with gr.Blocks(title="Qwen3-Harrison-RAG Chatbot") as demo:
    gr.Markdown("""
# 🧠 Qwen3-Harrison-RAG Medical Chatbot
This model provides **guideline-style medical answers** with structured sections and a **Quick Summary**.
*For educational and informational purposes only — not a substitute for professional medical advice.*
""")
    chatbot = gr.Chatbot(height=480, show_label=False)
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask a detailed medical question...", scale=4)
        clear = gr.Button("Clear", scale=1)
    msg.submit(generate_answer, [msg, chatbot], [chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)
# =====================================================
# Launch
# =====================================================
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)), debug=True)