hackergeek committed on
Commit fb8bdfe · verified · 1 Parent(s): 3a080d9

Update app.py

Files changed (1)
  1. app.py +52 -30
app.py CHANGED
@@ -8,7 +8,7 @@ from sentence_transformers import SentenceTransformer
 from inspect import signature
 
 # =====================================================
-# OPTION: Use ephemeral /tmp cache
+# Cache setup
 # =====================================================
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
 os.environ["HF_HOME"] = "/tmp/hf_home"
@@ -16,7 +16,7 @@ os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_datasets"
 os.environ["HF_MODULES_CACHE"] = "/tmp/hf_modules"
 
 # =====================================================
-# 1️⃣ Model setup
+# Model setup
 # =====================================================
 GEN_MODEL_PRIVATE = "hackergeek/qwen3-harrison-rag"
 GEN_MODEL_PUBLIC = "Qwen/Qwen2.5-1.5B-Instruct"
@@ -33,7 +33,7 @@ except ImportError:
     accelerate_available = False
     print("⚠️ `accelerate` not installed. Large private models with device_map='auto' may fail.")
 
-# --- Helper to load model safely ---
+# --- Load model helper ---
 def load_model(model_name, token=None):
     dtype_value = torch.float16 if torch.cuda.is_available() else torch.float32
     try:
@@ -45,10 +45,8 @@ def load_model(model_name, token=None):
             "cache_dir": "/tmp/hf_cache",
             "low_cpu_mem_usage": True,
         }
-
        if accelerate_available:
             load_kwargs["device_map"] = "auto"
-
         if token:
             load_kwargs["token"] = token
 
@@ -58,7 +56,7 @@ def load_model(model_name, token=None):
     except Exception as e:
         raise RuntimeError(f"Failed to load model '{model_name}': {e}")
 
-# --- Attempt to load private model, fallback to public ---
+# --- Attempt private model, fallback to public ---
 try:
     tokenizer, model = load_model(GEN_MODEL_PRIVATE, token=HF_TOKEN)
     print(f"✅ Loaded private model: {GEN_MODEL_PRIVATE}")
@@ -71,57 +69,81 @@ except Exception as e:
 embedder = SentenceTransformer(EMB_MODEL, cache_folder="/tmp/hf_cache")
 
 # =====================================================
-# 2️⃣ Retrieval + generation logic (deterministic)
+# Retrieval + generation logic
 # =====================================================
 index = faiss.IndexFlatL2(384)
 chunks = ["This is a sample context chunk. Replace with real documents."]
 
-def retrieve_context(query, max_k=5):
+def retrieve_context(query, max_k=5, distance_threshold=0.5, max_tokens=1500):
     q_emb = embedder.encode([query], convert_to_numpy=True)
     if index.ntotal == 0:
         return "No context available."
     D, I = index.search(q_emb, max_k)
-    # Sort by distance ascending to make retrieval deterministic
-    sorted_idx = [i for _, i in sorted(zip(D[0], I[0]))]
-    return "\n\n".join([chunks[i] for i in sorted_idx])
-
-def calculate_max_tokens(query, min_tokens=50, max_tokens=600, factor=3):
+    sorted_idx = [i for _, i in sorted(zip(D[0], I[0]))]  # deterministic
+    context = []
+    total_tokens = 0
+    for idx in sorted_idx:
+        # skip distant chunks
+        if D[0][list(sorted_idx).index(idx)] > distance_threshold:
+            continue
+        chunk_tokens = len(tokenizer(chunks[idx])["input_ids"])
+        if total_tokens + chunk_tokens > max_tokens:
+            break
+        context.append(chunks[idx])
+        total_tokens += chunk_tokens
+    return "\n\n".join(context) if context else chunks[sorted_idx[0]]
+
+def calculate_max_tokens(query, min_tokens=50, max_tokens=800, factor=3):
     query_tokens = len(tokenizer(query)["input_ids"])
     dynamic_tokens = query_tokens * factor
     return min(max(dynamic_tokens, min_tokens), max_tokens)
 
-def generate_response(query, history):
-    # Set fixed seeds for reproducibility
+def generate_full_answer(query, history):
     torch.manual_seed(42)
     np.random.seed(42)
 
     context = retrieve_context(query)
-    system_prompt = (
+    prompt = (
         "You are a helpful assistant that uses the retrieved context to answer questions.\n\n"
         f"Context:\n{context}\n\n"
         f"User: {query}\nAssistant:"
     )
-    inputs = tokenizer(system_prompt, return_tensors="pt").to(model.device)
 
-    max_new_tokens = calculate_max_tokens(query)
+    full_response = ""
+    remaining_prompt = prompt
 
-    # Deterministic generation: do_sample=False
-    output_ids = model.generate(
-        **inputs,
-        max_new_tokens=max_new_tokens,
-        do_sample=False,  # deterministic
-        pad_token_id=tokenizer.eos_token_id
-    )
-    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-    return output.split("Assistant:")[-1].strip()
+    while True:
+        inputs = tokenizer(remaining_prompt, return_tensors="pt").to(model.device)
+        max_new_tokens = calculate_max_tokens(query)
+
+        output_ids = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=False,  # deterministic
+            pad_token_id=tokenizer.eos_token_id
+        )
+        partial_answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        partial_answer = partial_answer.split("Assistant:")[-1].strip()
+
+        # Append partial answer
+        full_response += partial_answer
+
+        # Stop if last character is sentence-ending punctuation
+        if full_response.endswith(('.', '!', '?')):
+            break
+
+        # Continue generating by feeding back the last output
+        remaining_prompt = full_response
+
+    return full_response
 
 def chat_fn(user_message, history):
-    response = generate_response(user_message, history)
+    response = generate_full_answer(user_message, history)
     history = history + [(user_message, response)]
     return history, history
 
 # =====================================================
-# 3️⃣ Gradio UI
+# Gradio UI
 # =====================================================
 with gr.Blocks(title="Qwen3-Harrison-RAG Chatbot") as demo:
     gr.Markdown("""
@@ -136,7 +158,7 @@ with gr.Blocks(title="Qwen3-Harrison-RAG Chatbot") as demo:
     clear.click(lambda: None, None, chatbot, queue=False)
 
 # =====================================================
-# 4️⃣ Launch
+# Launch
 # =====================================================
 if __name__ == "__main__":
     demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
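
Note on the new retrieve_context: as far as the hunks above show, nothing ever adds embeddings to index, so retrieval falls through to "No context available." A minimal indexing sketch that reuses the embedder, index, and chunks globals from app.py; the add_documents helper and the sample texts are illustrative, not part of this commit:

def add_documents(texts):
    """Hypothetical helper (not in the commit): rebuild the FAISS index from real chunks."""
    global chunks
    chunks = list(texts)  # keep chunk text aligned with FAISS row ids
    embs = embedder.encode(chunks, convert_to_numpy=True).astype("float32")
    index.reset()         # drop any previously indexed vectors
    index.add(embs)       # IndexFlatL2(384) expects 384-dim float32 vectors

# Illustrative usage with placeholder documents:
add_documents([
    "First real document chunk goes here.",
    "Second real document chunk goes here.",
])

With the index populated, the added distance_threshold and max_tokens checks in retrieve_context then decide which of the top-max_k hits make it into the prompt.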