hackergeek committed
Commit 80f88de · verified · 1 Parent(s): 3fbdc08

Update app.py

Files changed (1)
  1. app.py +74 -144
app.py CHANGED
@@ -1,14 +1,10 @@
 import os
 import torch
 import gradio as gr
-import faiss
-import numpy as np
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from sentence_transformers import SentenceTransformer
-from inspect import signature

 # =====================================================
-# Cache setup
+# Environment setup
 # =====================================================
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
 os.environ["HF_HOME"] = "/tmp/hf_home"
@@ -16,174 +12,108 @@ os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_datasets"
 os.environ["HF_MODULES_CACHE"] = "/tmp/hf_modules"

 # =====================================================
-# Model setup
+# Model configuration
 # =====================================================
-GEN_MODEL_PRIVATE = "hackergeek/qwen3-harrison-rag"
-GEN_MODEL_PUBLIC = "Qwen/Qwen2.5-1.5B-Instruct"
-EMB_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
-
+GEN_MODEL = "hackergeek/qwen3-harrison-rag"
 HF_TOKEN = os.getenv("HF_TOKEN")
-if not HF_TOKEN:
-    print("⚠️ No Hugging Face token found. Private models may fail to load.")

-try:
-    import accelerate
-    accelerate_available = True
-except ImportError:
-    accelerate_available = False
-    print("⚠️ `accelerate` not installed. Large private models with device_map='auto' may fail.")
+if not HF_TOKEN:
+    print("⚠️ No Hugging Face token found. Set one using:")
+    print(" export HF_TOKEN='your_hf_token_here'")

-# --- Load model helper ---
-def load_model(model_name, token=None):
+# =====================================================
+# Load private RAG model
+# =====================================================
+def load_private_model(model_name, token):
     dtype_value = torch.float16 if torch.cuda.is_available() else torch.float32
+    load_kwargs = {
+        "dtype": dtype_value,
+        "cache_dir": "/tmp/hf_cache",
+        "low_cpu_mem_usage": True,
+    }
     try:
-        param_names = signature(AutoModelForCausalLM.from_pretrained).parameters
-        dtype_arg = "dtype" if "dtype" in param_names else "torch_dtype"
-
-        load_kwargs = {
-            dtype_arg: dtype_value,
-            "cache_dir": "/tmp/hf_cache",
-            "low_cpu_mem_usage": True,
-        }
-        if accelerate_available:
-            load_kwargs["device_map"] = "auto"
-        if token:
-            load_kwargs["token"] = token
+        import accelerate
+        load_kwargs["device_map"] = "auto"
+    except ImportError:
+        print("⚠️ `accelerate` not installed — using default device placement.")

-        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
-        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
-        return tokenizer, model
-    except Exception as e:
-        raise RuntimeError(f"Failed to load model '{model_name}': {e}")
+    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
+    model = AutoModelForCausalLM.from_pretrained(model_name, token=token, **load_kwargs)
+    return tokenizer, model

-# --- Attempt private model, fallback to public ---
 try:
-    tokenizer, model = load_model(GEN_MODEL_PRIVATE, token=HF_TOKEN)
-    print(f"✅ Loaded private model: {GEN_MODEL_PRIVATE}")
+    tokenizer, model = load_private_model(GEN_MODEL, token=HF_TOKEN)
+    print(f"✅ Loaded private RAG model: {GEN_MODEL}")
 except Exception as e:
-    print(f"❌ {e}\n➡️ Falling back to public model: {GEN_MODEL_PUBLIC}")
-    tokenizer, model = load_model(GEN_MODEL_PUBLIC)
-    print(f"✅ Loaded public model: {GEN_MODEL_PUBLIC}")
-
-# --- Load embedding model ---
-embedder = SentenceTransformer(EMB_MODEL, cache_folder="/tmp/hf_cache")
-
-# =====================================================
-# FAISS index setup
-# =====================================================
-# Example medical text; replace with full dataset
-documents = [
-    "Infliximab is a humanized monoclonal antibody used in rheumatoid arthritis. "
-    "It is administered intravenously at 3–5 mg/kg every 6–8 weeks.",
-    "Colitis ulcerum is a chronic inflammatory disorder of the colon characterized by ulcerated erosions.",
-    "COPD is a chronic obstructive pulmonary disease with progressive airflow limitation."
-]
-
-# Function to split documents into chunks
-def chunk_text(text, chunk_size=150):
-    words = text.split()
-    chunks = []
-    for i in range(0, len(words), chunk_size):
-        chunk = " ".join(words[i:i+chunk_size])
-        chunks.append(chunk)
-    return chunks
-
-# Create all chunks and embeddings
-chunks = []
-for doc in documents:
-    chunks.extend(chunk_text(doc))
-
-chunk_embeddings = embedder.encode(chunks, convert_to_numpy=True)
-index = faiss.IndexFlatL2(chunk_embeddings.shape[1])
-index.add(np.array(chunk_embeddings))
+    raise RuntimeError(f"❌ Failed to load {GEN_MODEL}: {e}")

 # =====================================================
-# Retrieval + generation
+# Dynamic token allocation
 # =====================================================
-def retrieve_context(query, max_k=5, distance_threshold=0.7, max_tokens=1500):
-    q_emb = embedder.encode([query], convert_to_numpy=True)
-    if index.ntotal == 0:
-        return "No context available."
-    D, I = index.search(q_emb, max_k)
-    sorted_idx = [i for _, i in sorted(zip(D[0], I[0]))]  # deterministic
-    context = []
-    total_tokens = 0
-    for idx in sorted_idx:
-        if D[0][list(sorted_idx).index(idx)] > distance_threshold:
-            continue
-        chunk_tokens = len(tokenizer(chunks[idx])["input_ids"])
-        if total_tokens + chunk_tokens > max_tokens:
-            break
-        context.append(chunks[idx])
-        total_tokens += chunk_tokens
-    return "\n\n".join(context) if context else "No context available."
-
-def calculate_max_tokens(query, min_tokens=50, max_tokens=800, factor=3):
+def calculate_max_tokens(query, min_tokens=100, max_tokens=600, factor=3):
+    """Dynamically scale output length to input length."""
     query_tokens = len(tokenizer(query)["input_ids"])
     dynamic_tokens = query_tokens * factor
     return min(max(dynamic_tokens, min_tokens), max_tokens)

-def generate_full_answer(query, history, max_loops=3):
-    torch.manual_seed(42)
-    np.random.seed(42)
-
-    context = retrieve_context(query)
-    prompt = (
-        "You are a helpful assistant. ONLY use the retrieved context to answer questions.\n\n"
-        f"Context:\n{context}\n\n"
-        f"User: {query}\nAssistant:"
+# =====================================================
+# RAG-aware generation logic
+# =====================================================
+def generate_answer(query, history):
+    if not query.strip():
+        return history, history
+
+    # Step 1️⃣: Rephrase user query for optimal retrieval
+    rephrase_prompt = (
+        "You are a retrieval-augmented assistant.\n"
+        "Rephrase the following user query to maximize retrieval accuracy "
+        "by keeping key entities and medical terms intact:\n\n"
+        f"User query: {query}\n\n"
+        "Rephrased query:"
+    )
+    inputs = tokenizer(rephrase_prompt, return_tensors="pt").to(model.device)
+    rephrased_ids = model.generate(**inputs, max_new_tokens=80, do_sample=False)
+    rephrased_query = tokenizer.decode(rephrased_ids[0], skip_special_tokens=True).split("Rephrased query:")[-1].strip()
+
+    # Step 2️⃣: Main retrieval + generation
+    max_tokens = calculate_max_tokens(rephrased_query)
+    system_prompt = (
+        "You are a retrieval-augmented medical assistant. "
+        "You have access to internal knowledge and context retrieval. "
+        "Always provide clear, complete, and factual medical explanations.\n\n"
+        f"Optimized query for retrieval:\n{rephrased_query}\n\n"
+        "Answer using relevant retrieved context and your reasoning.\n\n"
+        "Assistant:"
     )

-    full_response = ""
-    remaining_prompt = prompt
-    loop_count = 0
-
-    while loop_count < max_loops:
-        inputs = tokenizer(remaining_prompt, return_tensors="pt").to(model.device)
-        max_new_tokens = calculate_max_tokens(query)
-
-        output_ids = model.generate(
-            **inputs,
-            max_new_tokens=max_new_tokens,
-            do_sample=False,
-            pad_token_id=tokenizer.eos_token_id,
-            no_repeat_ngram_size=4
-        )
-        partial_answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-        partial_answer = partial_answer.split("Assistant:")[-1].strip()
-
-        new_content = partial_answer[len(full_response):].strip()
-        if not new_content:
-            break
-
-        full_response += new_content
-
-        if full_response.endswith(('.', '!', '?')):
-            break
-
-        remaining_prompt = full_response
-        loop_count += 1
-
-    return full_response
+    inputs = tokenizer(system_prompt, return_tensors="pt").to(model.device)
+    output_ids = model.generate(
+        **inputs,
+        max_new_tokens=max_tokens,
+        do_sample=False,
+        pad_token_id=tokenizer.eos_token_id,
+        no_repeat_ngram_size=4,
+        temperature=0.0,  # completely deterministic
+    )
+    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    answer = output.split("Assistant:")[-1].strip()

-def chat_fn(user_message, history):
-    response = generate_full_answer(user_message, history)
-    history = history + [(user_message, response)]
+    history = history + [(query, answer)]
     return history, history

 # =====================================================
-# Gradio UI
+# Gradio interface
 # =====================================================
 with gr.Blocks(title="Qwen3-Harrison-RAG Chatbot") as demo:
     gr.Markdown("""
-    # 🤖 Qwen3-Harrison-RAG Chatbot
-    Ask me anything — I’ll retrieve relevant context and answer!
+    # 🤖 Qwen3-Harrison-RAG Chatbot
+    Ask me anything — I’ll rephrase your question, retrieve the right context, and answer with complete reasoning.
     """)
-    chatbot = gr.Chatbot(height=400)
+    chatbot = gr.Chatbot(height=420)
     with gr.Row():
-        msg = gr.Textbox(placeholder="Type your message here...", scale=4)
+        msg = gr.Textbox(placeholder="Ask a medical or scientific question...", scale=4)
         clear = gr.Button("Clear", scale=1)
-    msg.submit(chat_fn, [msg, chatbot], [chatbot, chatbot])
+    msg.submit(generate_answer, [msg, chatbot], [chatbot, chatbot])
     clear.click(lambda: None, None, chatbot, queue=False)

 # =====================================================
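For reference, the new calculate_max_tokens helper clamps the generation budget to between 100 and 600 new tokens, allotting roughly three output tokens per input token. Below is a minimal standalone sketch of that clamping logic; the function name sketch_max_tokens and the word-count stand-in for the Hugging Face tokenizer are illustrative assumptions, not part of the commit.

# Standalone sketch of the clamping performed by calculate_max_tokens in this commit.
# The committed helper counts tokens via tokenizer(query)["input_ids"]; a plain word
# count stands in here so the example runs without loading the model.
def sketch_max_tokens(query, min_tokens=100, max_tokens=600, factor=3):
    query_tokens = len(query.split())         # stand-in for the tokenizer's token count
    dynamic_tokens = query_tokens * factor    # scale the output budget with input length
    return min(max(dynamic_tokens, min_tokens), max_tokens)

print(sketch_max_tokens("What is infliximab used for?"))  # short query -> floor of 100
print(sketch_max_tokens("word " * 300))                   # long query -> capped at 600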