import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# =====================================================
# Environment setup
# =====================================================
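# Redirecting the cache to /tmp keeps downloads writable in containerized hosts;
# note TRANSFORMERS_CACHE is deprecated in newer transformers in favor of HF_HOME.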
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HOME"] = "/tmp/hf_home"

# =====================================================
# Model configuration
# =====================================================
GEN_MODEL = "hackergeek/qwen3-harrison-rag"
HF_TOKEN = os.getenv("HF_TOKEN")

if not HF_TOKEN:
    print("⚠️ No Hugging Face token found. Set one using:")
    print("   export HF_TOKEN='your_hf_token_here'")

# =====================================================
# Load private model
# =====================================================
def load_private_model(model_name, token):
    dtype_value = torch.float16 if torch.cuda.is_available() else torch.float32
    load_kwargs = {
        "dtype": dtype_value,
        "cache_dir": "/tmp/hf_cache",
        "low_cpu_mem_usage": True,
    }
    try:
        import accelerate
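        # With accelerate available, device_map="auto" places weights across GPUs/CPU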
        load_kwargs["device_map"] = "auto"
    except ImportError:
        print("⚠️ `accelerate` not installed — default device placement used.")

    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=token, **load_kwargs)
    return tokenizer, model

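# Note: loading a private/gated repo without a valid token raises an HTTP 401
# authorization error from the Hub, so the HF_TOKEN warning above is not cosmetic.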
tokenizer, model = load_private_model(GEN_MODEL, token=HF_TOKEN)

# =====================================================
# Dynamic token allocation
# =====================================================
def calculate_max_tokens(query, min_tokens=1000, max_tokens=8192, factor=8):
    query_tokens = len(tokenizer(query)["input_ids"])
    dynamic_tokens = query_tokens * factor
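    # Illustrative arithmetic: a 25-token query gives 25 * 8 = 200, clamped up to
    # min_tokens (1000); a 1,200-token query gives 9,600, clamped down to max_tokens (8192).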
    return min(max(dynamic_tokens, min_tokens), max_tokens)

# =====================================================
# Generate long, complete, structured answers
# =====================================================
def generate_answer(query, history):
    history = history or []  # the Clear button resets the chatbot to None, so guard against it
    if not query.strip():
        return "", history

    # Correct a common typo in user queries (COPP -> COPD)
    corrected_query = query.replace("COPP", "COPD")

    # Step 1: Rephrase for precise retrieval
    rephrase_prompt = (
        "You are a medical assistant. Rephrase this query for precise retrieval:\n\n"
        f"Query: {corrected_query}\n\nRephrased query:"
    )
    inputs = tokenizer(rephrase_prompt, return_tensors="pt").to(model.device)
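    # Greedy decoding (do_sample=False) keeps the rephrasing step deterministic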
    rephrased_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    rephrased_query = tokenizer.decode(
        rephrased_ids[0], skip_special_tokens=True
    ).split("Rephrased query:")[-1].strip()

    # Step 2: Generate detailed structured answer
    max_tokens = calculate_max_tokens(rephrased_query)
    prompt = (
        "You are a retrieval-augmented medical assistant. Provide a **long, detailed, structured** medical answer "
        "as if writing a concise clinical guideline. Use markdown headings and bullet points. "
        "Each section should include multiple complete sentences and clear explanations.\n\n"
        "Follow this structure:\n"
        "### Definition / Description\n"
        "### Epidemiology / Causes\n"
        "### Symptoms & Signs\n"
        "### Diagnosis / Investigations\n"
        "### Complications\n"
        "### Treatment & Management\n"
        "### Prognosis / Prevention\n"
        "### Key Notes / References\n\n"
        "At the end, include a **🩺 Quick Summary** with 3–5 key takeaways written in plain English "
        "that a non-medical reader could understand.\n\n"
        f"User query: {rephrased_query}\n\nAnswer:"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
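    # Sampling with moderate temperature/top_p trades determinism for fluency;
    # repetition_penalty curbs the looping that long generations tend toward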
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
    )

    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    answer = output.split("Answer:")[-1].strip()

    # Collapse runs of three or more newlines into standard paragraph breaks
    while "\n\n\n" in answer:
        answer = answer.replace("\n\n\n", "\n\n")

    history = history + [(query, answer)]
    return "", history  # the empty string clears the input textbox

# =====================================================
# Gradio interface
# =====================================================
with gr.Blocks(title="Qwen3-Harrison-RAG Chatbot") as demo:
    gr.Markdown("""
    # 🧠 Qwen3-Harrison-RAG Medical Chatbot  
    This model provides **guideline-style medical answers** with structured sections and a **Quick Summary**.  
    *For educational and informational purposes only — not a substitute for professional medical advice.*
    """)
    chatbot = gr.Chatbot(height=480, show_label=False)
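    # The (user, bot) tuple history used here is deprecated in newer Gradio releases,
    # which prefer gr.Chatbot(type="messages"); tuples still work on Gradio 4.x.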
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask a detailed medical question...", scale=4)
        clear = gr.Button("Clear", scale=1)
    msg.submit(generate_answer, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

# =====================================================
# Launch
# =====================================================
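# 0.0.0.0 makes the server reachable from outside the container; many hosts
# (e.g. Hugging Face Spaces) inject PORT, with 7860 as the conventional default.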
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)), debug=True)