shreyess commited on
Commit
85d096e
·
verified ·
1 Parent(s): d89eaa3

Upload folder using huggingface_hub

Browse files
__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .inference.re_call import ReCall
2
+ from .inference.r1_searcher import R1Searcher, R1SearchConfig
3
+ from .inference.zerosearch import ZeroSearchInference, ZeroSearchConfig
4
+ from .inference.o1_searcher import O1Cfg, O1Searcher
5
+ from .inference.simpledeepsearch import SDSCfg, SDSearcher
6
+ __all__ = ["ReCall", "R1Searcher", "ZeroSearchInference", "ZeroSearchConfig", "R1SearchConfig", "O1Cfg", "O1Searcher", "SDSCfg", "SDSearcher"]
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (613 Bytes). View file
 
inference/__init__.py ADDED
File without changes
inference/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (187 Bytes). View file
 
inference/__pycache__/o1_searcher.cpython-310.pyc ADDED
Binary file (15.9 kB). View file
 
inference/__pycache__/r1_searcher.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
inference/__pycache__/re_call.cpython-310.pyc ADDED
Binary file (27.6 kB). View file
 
inference/__pycache__/simpledeepsearch.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
inference/__pycache__/zerosearch.cpython-310.pyc ADDED
Binary file (7.86 kB). View file
 
inference/o1_searcher.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """o1_searcher_inference.py — Serper‑based Search‑o1 re‑implementation
3
+ with *original* in‑house summarisation workflow, step‑replacement logic and
4
+ bug‑fixes for duplicate queries / ValueError.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import os, re, json, time, string, pathlib
9
+ from dataclasses import dataclass
10
+ from typing import List, Dict, Optional, Tuple
11
+ import requests, trafilatura
12
+ import threading
13
+ from openai import OpenAI, APIStatusError
14
+ # -----------------------------------------------------------------------------
15
+ # Optional NLTK sentence tokenizer (fallback to regex) -------------------------
16
+ try:
17
+ from nltk.tokenize import sent_tokenize # type: ignore
18
+ except Exception: # ImportError *or* missing punkt data
19
+ def sent_tokenize(x: str):
20
+ return re.split(r"(?<=[.!?]) +", x)
21
+
22
+ def _oa() -> OpenAI:
23
+ th = threading.current_thread()
24
+ if not hasattr(th, "_oa"):
25
+ th._oa = OpenAI()
26
+ return th._oa
27
+
28
+
29
+ # -----------------------------------------------------------------------------
30
+ # Special tags & constants -----------------------------------------------------
31
+ BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
32
+ END_SEARCH_QUERY = "<|end_search_query|>"
33
+ BEGIN_DOCUMENT_QUERY = "<|begin_of_document|>"
34
+ END_DOCUMENT_QUERY = "<|end_of_document|>"
35
+ THINK_OPEN, THINK_CLOSE = "<think>", "</think>"
36
+ EOS_TOKEN = "<|im_end|>"
37
+ ANSWER_OPEN, ANSWER_CLOSE = "<answer>", "</answer>"
38
+ STOP_STRINGS = [END_SEARCH_QUERY, ANSWER_CLOSE, EOS_TOKEN, "<|endoftext|>"]
39
+ ALLOWED_DATASETS = {"musique", "frames", "simpleqa", "browsercomp"}
40
+ # tokenizer =
41
+ TOKENIZER_DIR = "/home/fractal_admin/shreyas/models/Qwen3-4B"
42
+ # ───────────────────────── BASIC UTILS ──────────────────────────────
43
+ def retry(max_attempts: int = 4, sleep: int = 1, fallback=None):
44
+ """Tiny retry decorator with fixed back‑off."""
45
+
46
+ def decorator(func):
47
+ def wrapper(*args, **kwargs):
48
+ for i in range(max_attempts):
49
+ try:
50
+ return func(*args, **kwargs)
51
+ except Exception as exc:
52
+ if i == max_attempts - 1:
53
+ #print(f"[retry] {func.__name__} failed – giving up: {exc}")
54
+ return fallback
55
+ #print(f"[retry] {func.__name__}: attempt {i+1}/{max_attempts} → {exc}")
56
+ time.sleep(sleep)
57
+
58
+ return wrapper
59
+
60
+ return decorator
61
+
62
+
63
+ # ───────────────────────── tokenizer ────────────────────────────────────────
64
+ try:
65
+ from transformers import AutoTokenizer
66
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, trust_remote_code=True)
67
+ except Exception as e:
68
+ import sys
69
+ sys.exit(f"❌ Could not load Qwen3 tokenizer: {e}")
70
+
71
+ # -----------------------------------------------------------------------------
72
+ # Helper functions -------------------------------------------------------------
73
+
74
+ def remove_punc(t: str) -> str:
75
+ return t.translate(str.maketrans("", "", string.punctuation))
76
+
77
+ # legacy aliases for older checkpoints ---------------------------------------
78
+ _nopunc = remove_punc
79
+
80
+
81
+ def f1(a: set, b: set) -> float:
82
+ inter = len(a & b)
83
+ return 0.0 if inter == 0 else 2 * inter / (len(a) + len(b))
84
+
85
+ # legacy alias
86
+ _f1 = f1
87
+
88
+
89
+ def extract_snippet_ctx(text: str, snippet: str, win: int = 2500) -> str:
90
+ """Return *window*‑sized context around the sentence most similar to snippet."""
91
+ text = text[:50_000]
92
+ sn_set = set(remove_punc(snippet.lower()).split())
93
+ best, best_score = None, 0.20
94
+ for sent in sent_tokenize(text):
95
+ score = f1(sn_set, set(remove_punc(sent.lower()).split()))
96
+ if score > best_score:
97
+ best, best_score = sent, score
98
+ if best:
99
+ pos = text.find(best)
100
+ return text[max(0, pos - win): pos + len(best) + win]
101
+ return text[: 2 * win]
102
+
103
+ # -----------------------------------------------------------------------------
104
+ # Config dataclass -------------------------------------------------------------
105
+ @dataclass
106
+ class O1Cfg:
107
+ serper_api_key: str = "7bfe51ead1a1766b656c1355b292d1d29c15c114"
108
+ gl: str = "us"; hl: str = "en"
109
+ top_k: int = 10; max_doc_len: int = 3000
110
+ max_search: int = 10; max_turn: int = 15
111
+ use_jina: bool = True
112
+ jina_tpl: str = "https://r.jina.ai/http://{}"
113
+ # generation params
114
+ temperature: float = 0.7; top_p: float = 0.8; top_k_sampling: int = 20
115
+ rep_pen: float = 1.05; thinker_max_tokens: int = 32768
116
+ summariser_model: str = "gpt-4o-mini"
117
+
118
+ # -----------------------------------------------------------------------------
119
+ # Serper search + page fetch ---------------------------------------------------
120
+
121
+ def serper_search(q: str, num: int, key: str, gl="us", hl="en") -> List[Dict]:
122
+ hdr = {"X-API-KEY": key, "Content-Type": "application/json"}
123
+ body = {"q": q, "num": num, "gl": gl, "hl": hl}
124
+ r = requests.post("https://google.serper.dev/search", json=body, headers=hdr, timeout=20)
125
+ r.raise_for_status(); return r.json().get("organic", [])
126
+
127
+
128
+ def fetch_page(url: str, cfg: O1Cfg, snippet: str = "") -> str:
129
+ try:
130
+ txt = ""
131
+ if cfg.use_jina:
132
+ r = requests.get(cfg.jina_tpl.format(url), timeout=15)
133
+ if r.ok and len(r.text.strip()) > 100:
134
+ txt = r.text.strip()
135
+ if txt == "":
136
+ r = requests.get(url, timeout=15); r.raise_for_status()
137
+ txt = trafilatura.extract(r.text, output_format="txt") or ""
138
+ if snippet:
139
+ txt = extract_snippet_ctx(txt, snippet, cfg.max_doc_len)
140
+
141
+ return txt
142
+ except Exception:
143
+ return ""
144
+
145
+ # -----------------------------------------------------------------------------
146
+ # replace_recent_steps --------------------------------------------------------
147
+
148
+ def replace_recent_steps(origin: str, patch: str) -> str:
149
+ """Apply *patch* (containing numbered `Step N:` lines) to *origin*."""
150
+ step_re = re.compile(r"Step\s+(\d+):\s*")
151
+
152
+ def parse(block: str) -> Dict[int, str]:
153
+ cur, buf, out = None, [], {}
154
+ for line in block.splitlines():
155
+ m = step_re.match(line)
156
+ if m:
157
+ if cur is not None:
158
+ out[cur] = "\n".join(buf).strip()
159
+ cur, buf = int(m.group(1)), [line[m.end():].strip()]
160
+ elif cur is not None:
161
+ buf.append(line)
162
+ if cur is not None:
163
+ out[cur] = "\n".join(buf).strip()
164
+ return out
165
+
166
+ base = parse(origin); mod = parse(patch)
167
+ for k, v in mod.items():
168
+ if "DELETE THIS STEP" in v:
169
+ base.pop(k, None)
170
+ else:
171
+ base[k] = v
172
+ return "\n\n".join(base[k] for k in sorted(base))
173
+
174
+ # -----------------------------------------------------------------------------
175
+ # Prompts ----------------------------------------------------------------------
176
+ # from prompts import get_webpage_to_reasonchain_instruction # keep original helper
177
+
178
+ # -----------------------------------------------------------------------------
179
+ # Main agent -------------------------------------------------------------------
180
+ class O1Searcher:
181
+ # STOP_TOKENS = [
182
+ # "<|im_end|>",
183
+ # "<|endoftext|>",
184
+ # "<|end_of_query|>",
185
+ # " <|end_of_query|>",
186
+ # "<|end_of_query|>\n",
187
+ # "<|end_of_query|>\n\n",
188
+ # " <|end_of_query|>\n",
189
+ # " <|end_of_query|>\n\n",
190
+ # ]
191
+ get_webpage_to_reasonchain_instruction = """**Task Instruction:**
192
+
193
+ You are tasked with reading and analyzing web pages based on the following inputs: **Previous Reasoning Steps**, **Current Search Query**, and **Searched Web Pages**. Your objective is to extract relevant and helpful information for **Current Search Query** from the **Searched Web Pages** and seamlessly integrate this information into the **Previous Reasoning Steps** to continue reasoning for the original question.
194
+
195
+ **Guidelines:**
196
+
197
+ 1. **Analyze the Searched Web Pages:**
198
+ - Carefully review the content of each searched web page.
199
+ - Identify factual information that is relevant to the **Current Search Query** and can aid in the reasoning process for the original question.
200
+
201
+ 2. **Extract Relevant Information:**
202
+ - Select the information from the Searched Web Pages that directly contributes to advancing the **Previous Reasoning Steps**.
203
+ - Ensure that the extracted information is accurate and relevant.
204
+
205
+ 3. **Output Format:**
206
+ - **If the web pages provide helpful information for current search query:** Present the information beginning with **Final Information** as shown below.
207
+ **Final Information**
208
+
209
+ [Helpful information]
210
+
211
+ - **If the web pages do not provide any helpful information for current search query:** Output the following text.
212
+
213
+ **Final Information**
214
+
215
+ No helpful information found.
216
+
217
+ **Inputs:**
218
+ - **Previous Reasoning Steps:**
219
+ {prev_reasoning}
220
+
221
+ - **Current Search Query:**
222
+ {search_query}
223
+
224
+ - **Searched Web Pages:**
225
+ {document}
226
+
227
+ Now you should analyze each web page and find helpful information based on the current search query {search_query} and previous reasoning steps.
228
+ Return the Helpful information in the <information></information> tags
229
+ """
230
+ SUMMARY_PROMPT = (
231
+ """## Task Description:\n"
232
+ "Given the search query and the content of the searched webpage, "
233
+ "extract information relevant to the query and write one summary paragraph."\n\n"
234
+ "## Guidelines:\n"
235
+ "(1) The extracted content should be relevant to the query.\n"
236
+ "(2) The form of the extracted content **must be a summary paragraph** rather than a direct answer.\n"
237
+ "(3) If the webpage content is unrelated to the query, output \"None\".\n\n"
238
+ "## Output Format:\n"
239
+ "[Exacted Content]: <summary‑paragraph‑or‑None>\n\n"
240
+ "## Inputs:\n"
241
+ "[Search Query]\n{search_query}\n\n"
242
+ "[Webpage Content]\n{document}\n\n"
243
+ "## Output:\n"""
244
+ )
245
+
246
+ sys_prompt_multiqa = (
247
+ "You are a reasoning assistant with the ability to perform web searches to help "
248
+ "you answer the user's question accurately. You have special tools:\n\n"
249
+ "- To perform a search: write <|begin_search_query|> your query here <|end_search_query|>.\n"
250
+ "Then, the system will search and analyze relevant web pages, then provide you with helpful information in the format <|begin_search_result|> ...search results... <|end_search_result|>.\n\n"
251
+ f"You can repeat the search process multiple times if necessary. The maximum number of search attempts is limited to 16.\n\n"
252
+ "Once you have all the information you need, continue your reasoning.\n\n"
253
+ "Example:\n"
254
+ "Question: \"Alice David is the voice of Lara Croft in a video game developed by which company?\"\n"
255
+ "Assistant thinking steps:\n"
256
+ "- I need to find out who voices Lara Croft in the video game.\n"
257
+ "- Then, I need to determine which company developed that video game.\n\n"
258
+ "Assistant:\n"
259
+ "<|begin_search_query|>Alice David Lara Croft voice<|end_search_query|>\n\n"
260
+ "(System returns processed information from relevant web pages)\n\n"
261
+ "Assistant thinks: The search results indicate that Alice David is the voice of Lara Croft in a specific video game. Now, I need to find out which company developed that game.\n\n"
262
+ "Assistant:\n"
263
+ "<|begin_search_query|>video game developed by Alice David Lara Croft<|end_search_query|>\n\n"
264
+ "(System returns processed information from relevant web pages)\n\n"
265
+ "Assistant continues reasoning with the new information...\n\n"
266
+ "Remember:\n"
267
+ "- Use <|begin_search_query|> to request a web search and end with <|end_search_query|>.\n"
268
+ "- When done searching, continue your reasoning.\n\n"
269
+ "Always give you final answer between <answer></answer> tags"
270
+ )
271
+
272
+ def __init__(self, cfg: O1Cfg, thinker_url: str):
273
+ if not cfg.serper_api_key:
274
+ raise ValueError("SERPER_API_KEY required")
275
+ self.cfg, self.model_url = cfg, thinker_url.rstrip("/")
276
+ self.search_cache: Dict[str, List[Dict]] = {}
277
+ self.page_cache: Dict[Tuple[str, str], str] = {}
278
+ self.openai = _oa()
279
+
280
+ # --- low‑level generation call ------------------------------------------
281
+ @retry(4,1)
282
+ def _generate(self, prompt: str) -> str:
283
+ prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
284
+ max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100
285
+ resp = requests.post(
286
+ f"{self.model_url}/generate",
287
+ json={
288
+ "text": prompt,
289
+ "sampling_params": {
290
+ "temperature": self.cfg.temperature,
291
+ "top_p": self.cfg.top_p,
292
+ "max_new_tokens": max_tokens_left,
293
+ "repetition_penalty": self.cfg.rep_pen,
294
+ "stop": STOP_STRINGS,
295
+ },
296
+
297
+ },
298
+ timeout=60,
299
+ ).json()
300
+ # resp.raise_for_status()
301
+ generated = resp["text"] # what you have now
302
+ matched = resp["meta_info"]["finish_reason"].get("matched")
303
+ reason = resp["meta_info"]["finish_reason"].get("type")
304
+
305
+ # ⇢ append the tag back only if it was removed
306
+ if reason == "stop" and matched in STOP_STRINGS:
307
+ if not "<|end_of_query|>" in generated:
308
+ generated += matched
309
+ if reason == "stop" and matched == 151645:
310
+ if not generated.endswith("<|im_end|>"):
311
+ generated += "<|im_end|>"
312
+ if reason == "stop" and matched == 151643:
313
+ if not generated.endswith("<|endoftext|>"):
314
+ generated += "<|endoftext|>"
315
+
316
+ return generated
317
+
318
+ # @retry(fallback="None")
319
+ def _summarise_openai(self, query: str, doc: str) -> str:
320
+ prompt = self.SUMMARY_PROMPT.format(search_query=query, document=doc)
321
+ resp = self.openai.chat.completions.create(
322
+ model=self.cfg.summariser_model,
323
+ messages=[{"role": "user", "content": prompt}],
324
+ max_tokens=1024,
325
+ temperature=0.0,
326
+ )
327
+ # print(resp)
328
+ text = resp.choices[0].message.content
329
+ return text.split("[Exacted Content]:")[-1].strip()
330
+
331
+
332
+
333
+ def _generate_summary(self, prompt: str) -> str:
334
+ summary_url = "http://0.0.0.0:1241"
335
+ prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
336
+ max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100
337
+ resp = requests.post(
338
+ f"{summary_url}/generate",
339
+ json={
340
+ "text": prompt,
341
+ "sampling_params": {
342
+ "temperature": self.cfg.temperature,
343
+ "max_new_tokens": 8192,#max_tokens_left,
344
+ "stop": STOP_STRINGS,
345
+ },
346
+
347
+ },
348
+ timeout=60,
349
+ ).json()
350
+ generated = resp["text"] # what you have now
351
+ matched = resp["meta_info"]["finish_reason"].get("matched")
352
+ reason = resp["meta_info"]["finish_reason"].get("type")
353
+ # ##print("-"*100)
354
+ # ##print(resp)
355
+ # ##print(matched)
356
+ # ##print("-"*100)
357
+ # ⇢ append the tag back only if it was removed
358
+ if reason == "stop" and matched in STOP_STRINGS:
359
+ if not "<|end_of_query|>" in generated:
360
+ generated += matched + EOS_TOKEN
361
+ if reason == "stop" and matched == 151645:
362
+ if not generated.endswith("<|im_end|>"):
363
+ generated += "<|im_end|>"
364
+ if reason == "stop" and matched == 151643:
365
+ if not generated.endswith("<|endoftext|>"):
366
+ generated += "<|endoftext|>"
367
+ return generated
368
+ # --- public entry -------------------------------------------------------
369
+ def run(self, question: str):
370
+ prompt = (
371
+ f"<|im_start|>system\n{self.sys_prompt_multiqa}<|im_end|>\n"
372
+ f"<|im_start|>user\n{question}<|im_end|>\n"
373
+ f"<|im_start|>assistant\n{THINK_OPEN}"
374
+ )
375
+ full_trace = prompt # <-- Track full trace
376
+ queries: List[str] = []
377
+ seen_queries: set[str] = set()
378
+
379
+ for i in range(self.cfg.max_turn):
380
+ chunk = self._generate(prompt)
381
+ prompt += chunk
382
+
383
+ if ANSWER_CLOSE in chunk:
384
+ break
385
+
386
+ ##print(f"step-{i}")
387
+ ##print(chunk)
388
+
389
+ query = self._extract_query(chunk)
390
+ ##print(query)
391
+ if not query or len(queries) >= self.cfg.max_search:
392
+ break
393
+ if query in seen_queries:
394
+ continue
395
+ queries.append(query)
396
+ seen_queries.add(query)
397
+
398
+ doc = self._retrieve_doc(query)
399
+ prev_reasoning = self._extract_reasoning(prompt)
400
+ # summary = "\n<|im_start|>user" + self._summarise_openai(query, doc) + EOS_TOKEN + "\n<|im_start|>assistant" + THINK_OPEN
401
+ summary = "\n<|im_start|>user" + self._summarise(prev_reasoning, query, doc) + EOS_TOKEN + "\n<|im_start|>assistant" + THINK_OPEN
402
+ ##print("summary")
403
+ # print(summary)
404
+ prompt += summary # <-- Log summary to trace
405
+
406
+ # new_reasoning = replace_recent_steps(prev_reasoning, summary)
407
+
408
+ # if prev_reasoning:
409
+ # prompt = prompt.rsplit(prev_reasoning, 1)[0] + new_reasoning + THINK_CLOSE + THINK_OPEN
410
+ # else:
411
+ # prompt += new_reasoning + THINK_CLOSE + THINK_OPEN
412
+
413
+ # full_trace += + THINK_CLOSE + THINK_OPEN + "\n" # <-- Log reasoning to trace
414
+ else:
415
+ final = f"{ANSWER_OPEN}I don't know.{ANSWER_CLOSE}"
416
+ prompt += final
417
+ # full_trace += final
418
+
419
+ return prompt, queries
420
+
421
+
422
+ # ---------------------------------------------------------------------
423
+ # helpers --------------------------------------------------------------
424
+ def _extract_query(self, txt: str) -> Optional[str]:
425
+ if BEGIN_SEARCH_QUERY not in txt or END_SEARCH_QUERY not in txt:
426
+ return None
427
+ frag = txt.split(BEGIN_SEARCH_QUERY)[-1].split(END_SEARCH_QUERY)[0]
428
+ # strip quotes / ellipsis / tabs
429
+ return re.sub(r"[\"'…\t]", " ", frag.split("<|")[0]).strip()
430
+
431
+ def _retrieve_doc(self, query: str) -> str:
432
+ if query not in self.search_cache:
433
+ self.search_cache[query] = serper_search(query, self.cfg.top_k, self.cfg.serper_api_key,
434
+ gl=self.cfg.gl, hl=self.cfg.hl)
435
+ for hit in self.search_cache[query]:
436
+ # ##print("hit")
437
+ # ##print(hit)
438
+ url, sn = hit.get("link", ""), hit.get("snippet", "")
439
+ if not url:
440
+ continue
441
+ key = (url, sn)
442
+ if key not in self.page_cache:
443
+ self.page_cache[key] = fetch_page(url, self.cfg, sn)
444
+ if self.page_cache[key]:
445
+ return self.page_cache[key]
446
+ return ""
447
+
448
+ def _summarise(self, prev: str, query: str, doc: str) -> str:
449
+ rid_prompt = self.get_webpage_to_reasonchain_instruction.format(prev_reasoning = prev, search_query = query, document = doc)
450
+ chat = f"<|im_start|>user\\n{rid_prompt}\\n<|im_end|>\\n<|im_start|>assistant\\n"
451
+ resp = self._generate_summary(chat)
452
+ # ##print("summarization out \n", resp)
453
+ return BEGIN_DOCUMENT_QUERY + self._extract_summary(resp) + END_DOCUMENT_QUERY
454
+ # ##print("summary")
455
+ # ##print(resp)
456
+ # match = re.search(r"Final Information\*\*\s*\n(.+?)<\|im_end\|>", resp)
457
+ # if match:
458
+ # final_info = match.group(1).strip()
459
+ # ##print(final_info)
460
+ # return final_info
461
+
462
+ def _extract_summary(self, prompt: str) -> str:
463
+ if "<information>" in prompt:
464
+ summary = prompt.split("<information>")[-1].split("</information>")[0] if THINK_OPEN in prompt else ""
465
+ return summary
466
+ else:
467
+ match = re.search(r"\*\*Final Information\*\*\s*\n(.+?)<\|im_end\|>", prompt)
468
+ if match:
469
+ final_info = match.group(1).strip()
470
+ return final_info
471
+ return prompt
472
+
473
+ def _extract_reasoning(self, prompt: str) -> str:
474
+ return prompt.split(THINK_OPEN)[-1].split(THINK_CLOSE)[0] if THINK_OPEN in prompt else ""
475
+
476
+ # -----------------------------------------------------------------------------
477
+ # CLI -------------------------------------------------------------------------
478
+ # if __name__ == "__main__":
479
+ # # import argparse, json
480
+ # # parser = argparse.ArgumentParser()
481
+ # # parser.add_argument("question"); parser.add_argument("--dataset", required=True, choices=sorted(ALLOWED_DATASETS)); parser.add_argument("--model-url", required
inference/oss.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # recall_oss_harmony.py
2
+ # ReCall loop using the Harmony chat format renderer, calling vLLM's
3
+ # OpenAI-compatible **/v1/completions** endpoint with a rendered prompt.
4
+
5
+ import json
6
+ import re
7
+ import time
8
+ from functools import wraps
9
+ from typing import List, Optional
10
+
11
+ import requests
12
+ from openai_harmony import (
13
+ load_harmony_encoding,
14
+ HarmonyEncodingName,
15
+ Role,
16
+ Message,
17
+ Conversation,
18
+ )
19
+
20
+ def retry(max: int = 5, sleep: float = 1.0, fallback=None):
21
+ def decorator(fn):
22
+ @wraps(fn)
23
+ def wrapper(*args, **kwargs):
24
+ for i in range(max):
25
+ try:
26
+ return fn(*args, **kwargs)
27
+ except Exception as e:
28
+ print(f"[retry] attempt {i+1}/{max} failed: {e}")
29
+ if i + 1 == max:
30
+ print(f"[retry] giving up – returning {fallback!r}")
31
+ return fallback
32
+ if sleep:
33
+ time.sleep(sleep)
34
+ return wrapper
35
+ return decorator
36
+
37
+
38
+ class ReCallOSSHarmony:
39
+ TOOL_CALL_RE = re.compile(r"<tool_call>((?:(?!</tool_call>).)*)</tool_call>", re.DOTALL)
40
+
41
+ def __init__(
42
+ self,
43
+ executor_url: str,
44
+ base_url: str,
45
+ model_name: str,
46
+ api_key: Optional[str] = None,
47
+ request_timeout: int = 120,
48
+ ):
49
+ self.executor_url = executor_url.rstrip("/")
50
+ self.base_url = base_url.rstrip("/")
51
+ self.model_name = model_name
52
+ self.api_key = api_key
53
+ self.timeout = request_timeout
54
+ self.enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
55
+
56
+ # ---------------- HTTP ----------------
57
+ def _headers(self):
58
+ h = {"Content-Type": "application/json"}
59
+ if self.api_key:
60
+ h["Authorization"] = f"Bearer {self.api_key}"
61
+ return h
62
+
63
+ @retry(max=5, sleep=1, fallback={"choices": [{"text": ""}]})
64
+ def _complete(self, prompt: str, temperature: float, max_tokens: int, stop: Optional[List[str]] = None):
65
+ payload = {
66
+ "model": self.model_name,
67
+ "prompt": prompt,
68
+ "max_tokens": max_tokens,
69
+ "temperature": temperature,
70
+ }
71
+ if stop:
72
+ payload["stop"] = stop
73
+ resp = requests.post(
74
+ f"{self.base_url}/completions",
75
+ headers=self._headers(),
76
+ json=payload,
77
+ timeout=self.timeout,
78
+ )
79
+ if resp.status_code != 200:
80
+ raise RuntimeError(f"completions HTTP {resp.status_code}: {resp.text}")
81
+ return resp.json()
82
+
83
+ # ------------- tool plumbing ----------
84
+ @staticmethod
85
+ def _validate_tool_tags(s: str) -> bool:
86
+ starts = [m.start() for m in re.finditer(r"<tool_call>", s)]
87
+ ends = [m.start() for m in re.finditer(r"</tool_call>", s)]
88
+ if len(starts) != len(ends):
89
+ return False
90
+ return all(st < en for st, en in zip(starts, ends))
91
+
92
+ def extract_tool_calls(self, text: str) -> List[str]:
93
+ if not self._validate_tool_tags(text):
94
+ return []
95
+ return [m.group(1).strip() for m in self.TOOL_CALL_RE.finditer(text)]
96
+
97
+ @staticmethod
98
+ def _format_tool_call(call_json_str: str) -> str:
99
+ try:
100
+ spec = json.loads(call_json_str)
101
+ fname = spec["name"]
102
+ args = spec.get("arguments", {}) or {}
103
+ args_str = ", ".join(f"{k}={repr(v)}" for k, v in args.items())
104
+ return f"{fname}({args_str})"
105
+ except Exception as e:
106
+ return f"error: parse tool call failed: {e}"
107
+
108
+ def _exec_one_call(self, env: str, call_json_str: str) -> str:
109
+ call_src = self._format_tool_call(call_json_str)
110
+ if call_src.startswith("error:"):
111
+ return call_src
112
+ try:
113
+ response = requests.post(
114
+ f"{self.executor_url}/execute",
115
+ json={"env": env, "call": call_src},
116
+ timeout=self.timeout,
117
+ )
118
+ if response.status_code != 200:
119
+ return f"error: executor HTTP {response.status_code}"
120
+ payload = response.json()
121
+ out = []
122
+ if payload.get("result"):
123
+ out.append(f"result:\n{payload['result']}")
124
+ if payload.get("output"):
125
+ out.append(f"output:\n{payload['output']}")
126
+ if payload.get("error"):
127
+ out.append(f"error:\n{payload['error']}")
128
+ return "\n".join(out).strip() or "ok"
129
+ except requests.exceptions.Timeout:
130
+ return "error: execution timed out"
131
+ except Exception as e:
132
+ return f"error: executor exception: {e}"
133
+
134
+ def execute_tool_calls(self, env: str, tool_calls: List[str]) -> List[str]:
135
+ return [self._exec_one_call(env, c) for c in tool_calls]
136
+
137
+ # ------------- harmony helpers --------
138
+ def _render(self, messages: List[Message]) -> str:
139
+ convo = Conversation.from_messages(messages)
140
+ # Render for a completion (assistant is the next speaker)
141
+ return self.enc.render_conversation_for_completion(convo, Role.ASSISTANT)
142
+
143
+ # ------------- main run loop ----------
144
+ @retry(max=5, sleep=1, fallback=("", []))
145
+ def run(
146
+ self,
147
+ env: str,
148
+ func_schemas,
149
+ question: str,
150
+ system_prompt: str = "",
151
+ temperature: float = 0.2,
152
+ max_tokens: int = 2048,
153
+ max_turns: int = 16,
154
+ stop: Optional[List[str]] = None,
155
+ ):
156
+ # Build the initial harmony conversation.
157
+ # Paste your full system prompt into `system_prompt` before calling.
158
+ # If you want to include func_schemas in your system content, do:
159
+ try:
160
+ sys_msg = system_prompt.format(func_schemas=json.dumps(func_schemas, ensure_ascii=False))
161
+ except Exception:
162
+ sys_msg = system_prompt
163
+
164
+ messages: List[Message] = [
165
+ Message.from_role_and_content(Role.SYSTEM, sys_msg),
166
+ Message.from_role_and_content(Role.USER, question),
167
+ ]
168
+
169
+ transcript_chunks: List[str] = []
170
+ all_tool_calls: List[str] = []
171
+
172
+ for _ in range(max_turns):
173
+ prompt = self._render(messages)
174
+ resp = self._complete(prompt=prompt, temperature=temperature, max_tokens=max_tokens, stop=stop)
175
+ assistant_text = resp["choices"][0]["text"]
176
+ transcript_chunks.append(assistant_text)
177
+ messages.append(Message.from_role_and_content(Role.ASSISTANT, assistant_text))
178
+
179
+ if "<answer>" in assistant_text:
180
+ break
181
+
182
+ tool_calls = self.extract_tool_calls(assistant_text)
183
+ all_tool_calls.extend(tool_calls)
184
+ if not tool_calls:
185
+ continue
186
+
187
+ results = self.execute_tool_calls(env, tool_calls)
188
+ tool_resp_block = "".join(
189
+ f"<tool_response>{tc}\n{res}\n</tool_response>\n"
190
+ for tc, res in zip(tool_calls, results)
191
+ )
192
+ messages.append(Message.from_role_and_content(Role.USER, tool_resp_block))
193
+
194
+ transcript = "".join(transcript_chunks)
195
+ return transcript, all_tool_calls
inference/r1_searcher.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+
5
+
6
+
7
+ #!/usr/bin/env python3
8
+ # r1_searcher_inference.py
9
+ """
10
+ Faithful re‑implementation of the **R1‑Searcher** loop described in
11
+ ‘Reasoning with Retrieval 1’ (2024).
12
+
13
+ Key properties ─────────────────────────────────────────────────────────
14
+ • The policy LLM (“thinker”) reasons inside <think> … </think>.
15
+ • When it needs external knowledge it emits exactly **one** single‑triple
16
+ query inside <|begin_of_query|> … <|end_of_query|>.
17
+ • The wrapper searches *English Wikipedia only* (via Serper.dev).
18
+ • It summarises the *first* retrieved article that contains the query terms
19
+ and injects that summary between
20
+ <|begin_of_documents|> … <|end_of_documents|>
21
+ before handing control back to the thinker.
22
+ • The loop stops when the thinker outputs </answer> or when the configurable
23
+ round‑limit is reached.
24
+
25
+ The class is API‑compatible with the user’s existing `ReCall` wrapper so the
26
+ same benchmarking harness can swap between them with a single flag.
27
+ """
28
+ from __future__ import annotations
29
+
30
+ import os
31
+ import time
32
+ from dataclasses import dataclass
33
+ from typing import List, Optional
34
+
35
+ import requests
36
+ from bs4 import BeautifulSoup
37
+ import trafilatura
38
+ import wikipedia
39
+ from urllib.parse import unquote
40
+ from openai import OpenAI
41
+
42
+ client = OpenAI(api_key = "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")
43
+
44
+ TOKENIZER_DIR = "/home/fractal_admin/shreyas/models/Qwen3-4B"
45
+
46
+ # ───────────────────────── tokenizer ────────────────────────────────────────
47
+ try:
48
+ from transformers import AutoTokenizer
49
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, trust_remote_code=True)
50
+ except Exception as e:
51
+ import sys
52
+ sys.exit(f"❌ Could not load Qwen3 tokenizer: {e}")
53
+
54
+ # ───────────────────────── BASIC UTILS ──────────────────────────────
55
+ def retry(max_attempts: int = 4, sleep: int = 1, fallback=None):
56
+ """Tiny retry decorator with fixed back‑off."""
57
+
58
+ def decorator(func):
59
+ def wrapper(*args, **kwargs):
60
+ for i in range(max_attempts):
61
+ try:
62
+ return func(*args, **kwargs)
63
+ except Exception as exc:
64
+ if i == max_attempts - 1:
65
+ #print(f"[retry] {func.__name__} failed – giving up: {exc}")
66
+ return fallback
67
+ #print(f"[retry] {func.__name__}: attempt {i+1}/{max_attempts} → {exc}")
68
+ time.sleep(sleep)
69
+
70
+ return wrapper
71
+
72
+ return decorator
73
+
74
+
75
+ # ────────────────────────── CONFIG ──────────────────────────────────
76
+ @dataclass
77
+ class R1SearchConfig:
78
+ # Serper.dev parameters
79
+ serper_api_key: str = "7bfe51ead1a1766b656c1355b292d1d29c15c114"
80
+ serper_url: str = "https://google.serper.dev/search"
81
+ gl: str = "us"
82
+ hl: str = "en"
83
+
84
+ # Policy model endpoint (vLLM)
85
+ thinker_temperature: float = 0.0
86
+ thinker_max_tokens: int = 40960
87
+
88
+ # Loop / misc
89
+ max_rounds: int = 16
90
+ summariser_model: str = "gpt-4o-mini"
91
+
92
+
93
+ # ───────────────────────── R1‑Searcher ──────────────────────────────
94
+ class R1Searcher:
95
+ SYSTEM_PROMPT = """
96
+ You are a helpful assistant.
97
+ Given a question, you should answer it by first thinking about the reasoning
98
+ process in the mind and then providing the final answer.
99
+
100
+ The output format of reasoning process and final answer are enclosed within
101
+ <think> </think> and <answer> </answer> tags, respectively, i.e.,
102
+ "<think> reasoning process here </think>
103
+
104
+ <answer> final answer here </answer>".
105
+
106
+ During the thinking process, **you can perform searching for uncertain
107
+ knowledge** if necessary with the format of
108
+ "<|begin_of_query|> keyword_1 keyword_2 ... <|end_of_query|>".
109
+ **A query must involve only a single triple**.
110
+
111
+ Then, the search system will provide you with the retrieval information with
112
+ the format of "<|begin_of_documents|> ...search results... <|end_of_documents|>".
113
+ """.strip()
114
+
115
+ SUMMARY_PROMPT = (
116
+ """## Task Description:\n"
117
+ "Given the search query and the content of the searched webpage, "
118
+ "extract information relevant to the query and write one summary paragraph."\n\n"
119
+ "## Guidelines:\n"
120
+ "(1) The extracted content should be relevant to the query.\n"
121
+ "(2) The form of the extracted content **must be a summary paragraph** rather than a direct answer.\n"
122
+ "(3) If the webpage content is unrelated to the query, output \"None\".\n\n"
123
+ "## Output Format:\n"
124
+ "[Exacted Content]: <summary‑paragraph‑or‑None>\n\n"
125
+ "## Inputs:\n"
126
+ "[Search Query]\n{search_query}\n\n"
127
+ "[Webpage Content]\n{document}\n\n"
128
+ "## Output:\n"""
129
+ )
130
+
131
+ # Tag constants
132
+ EOS_TOKEN = "<|im_end|>"
133
+ THINK_OPEN = "<think>"
134
+ ANSWER_CLOSE = "</answer>"
135
+ Q_OPEN, Q_CLOSE = "<|begin_of_query|>", "<|end_of_query|>"
136
+ DOC_OPEN, DOC_CLOSE = "<|begin_of_documents|>", "<|end_of_documents|>"
137
+
138
+ # Stop strings – must match *exact* token sequences vLLM will see
139
+ STOP_TOKENS = [
140
+ "<|im_end|>",
141
+ "<|endoftext|>",
142
+ "<|end_of_query|>",
143
+ " <|end_of_query|>",
144
+ "<|end_of_query|>\n",
145
+ "<|end_of_query|>\n\n",
146
+ " <|end_of_query|>\n",
147
+ " <|end_of_query|>\n\n",
148
+ ]
149
+ # STOP_TOKENS = []
150
+
151
+
152
+
153
+ def __init__(self, cfg: R1SearchConfig, model_url):
154
+ self.cfg = cfg
155
+ self.openai = client
156
+ self._wiki = wikipedia
157
+ self._wiki.set_lang("en")
158
+
159
+ # Patch wikipedia lib to use a session with proper UA
160
+ sess = requests.Session()
161
+ sess.headers.update({"User-Agent": "r1-searcher-bot/1.0"})
162
+ self._wiki._http = sess
163
+ self.model_url=model_url
164
+
165
+ # ── public entry ─────────────────────────────────────────────────
166
+ def run(self, question: str) -> tuple[str, List[str]]:
167
+ prompt = (
168
+ f"<|im_start|>system\n{self.SYSTEM_PROMPT}<|im_end|>\n"
169
+ f"<|im_start|>user\n{question}<|im_end|>\n"
170
+ f"<|im_start|>assistant\n{self.THINK_OPEN}"
171
+ )
172
+ queries: List[str] = []
173
+
174
+ for _ in range(self.cfg.max_rounds):
175
+ model_out = self._call_thinker(prompt)
176
+ prompt += model_out
177
+
178
+ if self.ANSWER_CLOSE in model_out:
179
+ break
180
+
181
+ query = self._extract_query(model_out)
182
+ if not query:
183
+ break
184
+ queries.append(query)
185
+
186
+ doc_block = self._retrieve_block(query)
187
+ prompt += "<|im_start|>user\n" + doc_block + self.EOS_TOKEN + "<|im_start|>assistant\n" + self.THINK_OPEN # continue loop
188
+
189
+ else: # exceeded round cap
190
+ prompt += "<answer>I don't know.</answer><|im_end|>"
191
+
192
+ return prompt, queries
193
+
194
+ # ── thinker call ────────────────────────────────────────────────
195
+ # @retry()
196
+ def _call_thinker(self, prompt: str) -> str:
197
+ prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
198
+ max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100
199
+ resp = requests.post(
200
+ f"{self.model_url}/generate",
201
+ json={
202
+ "text": prompt,
203
+ "sampling_params": {
204
+ "temperature": self.cfg.thinker_temperature,
205
+ "max_new_tokens": 2048,#max_tokens_left,
206
+ "stop": self.STOP_TOKENS,
207
+ "repetition_penalty": 1.05,
208
+ },
209
+ },
210
+ timeout=60,
211
+ ).json()
212
+ generated = resp["text"] # what you have now
213
+ matched = resp["meta_info"]["finish_reason"].get("matched")
214
+ reason = resp["meta_info"]["finish_reason"].get("type")
215
+ #print("-"*100)
216
+ #print(resp)
217
+ #print(matched)
218
+ #print("-"*100)
219
+ # ⇢ append the tag back only if it was removed
220
+ if reason == "stop" and matched in self.STOP_TOKENS:
221
+ if not "<|end_of_query|>" in generated:
222
+ generated += matched + self.EOS_TOKEN
223
+ if reason == "stop" and matched == 151645:
224
+ if not generated.endswith("<|im_end|>"):
225
+ generated += "<|im_end|>"
226
+ if reason == "stop" and matched == 151643:
227
+ if not generated.endswith("<|endoftext|>"):
228
+ generated += "<|endoftext|>"
229
+ return generated
230
+
231
+ # ── query helpers ───────────────────────────────────────────────
232
+ @staticmethod
233
+ def _extract_query(text: str) -> Optional[str]:
234
+ if R1Searcher.Q_OPEN not in text or R1Searcher.Q_CLOSE not in text:
235
+ return None
236
+ fragment = text.split(R1Searcher.Q_OPEN)[-1].split(R1Searcher.Q_CLOSE)[0]
237
+ #print("*"*10)
238
+ #print(fragment)
239
+ fragment = fragment.split("<|")[0] #handle end_of_query slipping
240
+ return (
241
+ fragment.replace("\t", " ")
242
+ .replace("\"", "")
243
+ .replace("'", "")
244
+ .replace("…", "")
245
+ .strip()
246
+ ) or None
247
+
248
+ # ── retrieval & summary ───────────────────────────────────────
249
+ def _retrieve_block(self, query: str) -> str:
250
+ wiki_links = self._serper_wiki_links(query)
251
+
252
+ for url in wiki_links[:3]:
253
+ text = self._get_wiki_text(url)
254
+ if not text:
255
+ continue
256
+ summary = self._summarise(query, text[:35000])
257
+ if summary.lower() != "none":
258
+ return f"{self.DOC_OPEN}\n{summary}\n{self.DOC_CLOSE}\n\n"
259
+
260
+ return f"{self.DOC_OPEN}\nNone\n{self.DOC_CLOSE}\n\n"
261
+
262
+ # --- Serper ------------------------------------------------------
263
+ @retry()
264
+ def _serper_wiki_links(self, q: str) -> List[str]:
265
+ headers = {"X-API-KEY": self.cfg.serper_api_key, "Content-Type": "application/json"}
266
+ payload = {"q": f"{q} site:en.wikipedia.org", "num": 10, "gl": self.cfg.gl, "hl": self.cfg.hl}
267
+ r = requests.post(self.cfg.serper_url, json=payload, headers=headers, timeout=20)
268
+ r.raise_for_status()
269
+ links = [
270
+ item.get("link")
271
+ for item in r.json().get("organic", [])
272
+ if item.get("link", "").startswith("https://en.wikipedia.org")
273
+ ]
274
+ return links
275
+
276
+ def extract_main_text(self, html: str) -> str:
277
+ txt = trafilatura.extract(html, output_format="txt") or ""
278
+ if len(txt) >= 500:
279
+ return txt
280
+ from readability import Document
281
+ soup = BeautifulSoup(Document(html).summary(), "lxml")
282
+ txt = soup.get_text(" ", strip=True)
283
+ if len(txt) >= 400:
284
+ return txt
285
+ for tag in soup(["script", "style", "noscript"]):
286
+ tag.decompose()
287
+ return re.sub(r"\s+", " ", soup.get_text(" ").strip())
288
+
289
+
290
+ # --- fetch article ----------------------------------------------
291
+ def _get_wiki_text(self, url: str) -> str | None:
292
+ try:
293
+ # 1. Download
294
+ r = requests.get(url, timeout=10)
295
+ r.raise_for_status()
296
+
297
+ # 2. Extract main text
298
+ txt = self.extract_main_text(r.text).strip()
299
+ if not txt:
300
+ return None
301
+
302
+ # 3. Prepend article slug if it isn’t already in the body
303
+ slug = unquote(url.rsplit("/", 1)[-1]).replace("_", " ")
304
+ if slug.lower() not in txt.lower():
305
+ txt = f"{slug}\n\n{txt}"
306
+
307
+ # 4. Return the final value
308
+ return "[Retrieved from Wikipedia] " + txt
309
+
310
+ except Exception as e:
311
+ #print("Failed to fetch Wikipedia page %s: %s", url, e)
312
+ return None
313
+
314
+ # --- call OpenAI to summarise -----------------------------------
315
+ @retry(fallback="None")
316
+ def _summarise(self, query: str, doc: str) -> str:
317
+ prompt = self.SUMMARY_PROMPT.format(search_query=query, document=doc)
318
+ resp = self.openai.chat.completions.create(
319
+ model=self.cfg.summariser_model,
320
+ messages=[{"role": "user", "content": prompt}],
321
+ max_tokens=1024,
322
+ temperature=0.0,
323
+ )
324
+ text = resp.choices[0].message.content
325
+ return text.split("[Exacted Content]:")[-1].strip()
326
+
327
+
328
+ # ─────────────────────────── CLI ────────────────────────────────────
329
+ if __name__ == "__main__":
330
+ import argparse, json
331
+
332
+ ap = argparse.ArgumentParser()
333
+ ap.add_argument("question", type=str, help="Natural‑language question")
334
+ ap.add_argument("--serper-key", type=str, help="Override SERPER_API_KEY env")
335
+ args = ap.parse_args()
336
+
337
+ cfg = R1SearchConfig(serper_api_key=args.serper_key or os.getenv("SERPER_API_KEY", ""))
338
+ agent = R1Searcher(cfg, OpenAI())
339
+ final_prompt, issued_queries = agent.run(args.question)
340
+
341
+ answer = final_prompt.split("<answer>")[-1].split("</answer>")[0]
342
+ #print("\nANSWER:", answer)
343
+ #print("\nQUERIES:", json.dumps(issued_queries, indent=2))
344
+
inference/re_call.py ADDED
@@ -0,0 +1,980 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ import requests
4
+ import time
5
+ from typing import List
6
+ from functools import wraps
7
+ from together import Together # pip install together
8
+ from datetime import datetime # only needed for retries / logging
9
+
10
+
11
+
12
+ # return decorator
13
+ def retry(max: int = 10, sleep: int = 1, fallback=None):
14
+ """
15
+ Retry `max` times and, if still failing, return `fallback`
16
+ instead of raising. This keeps outer loops alive.
17
+ """
18
+ def decorator(func):
19
+ @wraps(func)
20
+ def wrapper(*args, **kwargs):
21
+ for i in range(max):
22
+ try:
23
+ return func(*args, **kwargs)
24
+ except Exception as e:
25
+ print(f"[retry] attempt {i+1}/{max} failed: {e}")
26
+ if i == max - 1: # last try exhausted
27
+ print(f"[retry] giving up – returning {fallback!r}")
28
+ return fallback # ← swallow the error
29
+ if sleep:
30
+ time.sleep(sleep)
31
+ return wrapper
32
+ return decorator
33
+
34
+ class ReCall():
35
+ sys_prompt_websailor = """
36
+ You are a Web Information Seeking Master. Your task is to thoroughly seek the internet for information and provide accurate answers to questions. No matter how complex the query, you will not give up until you find the corresponding information.
37
+ In this environment you have access to a set of tools you can use to assist with the user query.
38
+ You may perform multiple rounds of function calls. In each round, you can call one or more functions.
39
+
40
+ As you proceed, adhere to the following principles:
41
+
42
+ 1. **Persistent Actions for Answers**: You will engage in many interactions, delving deeply into the topic to explore all possible aspects until a satisfactory answer is found.
43
+
44
+ 2. **Repeated Verification**: Before presenting a Final Answer, you will **cross-check** and **validate the information** you've gathered to confirm its accuracy and reliability.
45
+
46
+ 3. **Attention to Detail**: You will carefully analyze each information source to ensure that all data is current, relevant, and from credible origins.
47
+
48
+
49
+
50
+ Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```
51
+
52
+ In your response, you need to first think about the reasoning process in the mind and then conduct function calling to get the information or perform the actions if needed. \
53
+ The reasoning process and function calling are enclosed within <think> </think> and <tool_call> </tool_call> tags. \
54
+ The results of the function calls will be given back to you after execution, \
55
+ and you can continue to call functions until you get the final answer for the user's question. \
56
+ Finally, if you have got the answer, enclose it within \\boxed{{}} with latex format and do not continue to call functions, \
57
+ i.e., <think> Based on the response from the function call, I get the weather information. </think> The weather in Beijing on 2025-04-01 is \\[ \\boxed{{20C}} \\].
58
+
59
+ For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
60
+ <tool_call>
61
+ {{"name": <function-name>, "arguments": <args-json-object>}}
62
+ </tool_call>
63
+ For Multiple Choice Question always give the final answer as one of the options whichever fits the best.s
64
+ Always give your answer as option id. and answer.
65
+ Example:
66
+ What is the Captial of India ?
67
+ \\[ \\boxed{{A. India}} \\]
68
+ """
69
+
70
+ sys_prompt_websailor_deepseek = """
71
+ You are a Web Information Seeking Master. Your task is to thoroughly seek the internet for information and provide accurate answers to questions. No matter how complex the query, you will not give up until you find the corresponding information.
72
+ In this environment you have access to a set of tools you can use to assist with the user query.
73
+ You may perform multiple rounds of function calls. In each round, you can call one or more functions.
74
+
75
+ As you proceed, adhere to the following principles:
76
+
77
+ 1. **Persistent Actions for Answers**: You will engage in many interactions, delving deeply into the topic to explore all possible aspects until a satisfactory answer is found.
78
+
79
+ 2. **Repeated Verification**: Before presenting a Final Answer, you will **cross-check** and **validate the information** you've gathered to confirm its accuracy and reliability.
80
+
81
+ 3. **Attention to Detail**: You will carefully analyze each information source to ensure that all data is current, relevant, and from credible origins.
82
+
83
+
84
+
85
+ Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```
86
+
87
+ In your response, you need to first think about the reasoning process in the mind and then conduct function calling to get the information or perform the actions if needed. \
88
+ The reasoning process and function calling are enclosed within <think> </think> and <tool_calls_begin> <tool_calls_end> tags. \
89
+ The results of the function calls will be given back to you after execution, \
90
+ and you can continue to call functions until you get the final answer for the user's question. \
91
+ Finally, if you have got the answer, enclose it within \\boxed{{}} with latex format and do not continue to call functions, \
92
+ i.e., <think> Based on the response from the function call, I get the weather information. </think> The weather in Beijing on 2025-04-01 is \\[ \\boxed{{20C}} \\].
93
+ """
94
+
95
+ # sys_prompt_websailor_deepseek = """
96
+ # You are a Web Information Seeking Master. Seek the internet thoroughly and provide accurate answers. You may use tools multiple times.
97
+
98
+ # Principles:
99
+ # 1) Persistent Actions for Answers: explore deeply until you find satisfactory information.
100
+ # 2) Repeated Verification: cross-check and validate before the final answer.
101
+ # 3) Attention to Detail: ensure sources are current, relevant, and credible.
102
+
103
+ # You have the following tools (JSONSchema):
104
+ # ```json
105
+ # {func_schemas}
106
+ # Follow this EXACT tool-call I/O protocol.
107
+
108
+ # TO CALL ONE OR MORE TOOLS:
109
+ # Respond only with this block (no extra text before/after):
110
+ # <|tool▁call▁begin|>function<|tool▁sep|>{tool_name}{args_json}
111
+ # <|tool▁call▁end|>
112
+ # ... (repeat <|tool▁call▁begin|>…<|tool▁call▁end|> for multiple tools)
113
+ # <|tool▁calls▁end|><|end▁of▁sentence|>
114
+
115
+ # HOW TOOL RESULTS ARRIVE:
116
+ # I will send tool outputs back embedded inside a single user message, each wrapped like:
117
+ # <tool_response>{one_tool_call_you_made}
118
+ # {tool_return_text_or_json}
119
+ # </tool_response>
120
+
121
+ # WHAT TO DO NEXT:
122
+
123
+ # If you still need info, emit another tool-calls block (same exact format).
124
+
125
+ # If you have the final answer, output:
126
+ # <answer> …your final answer… </answer>
127
+ # and DO NOT call any more tools.
128
+
129
+ # Important:
130
+
131
+ # Do not expose your internal reasoning; keep thoughts private.
132
+
133
+ # When emitting a tool-calls block, do not include any explanations, only the block specified above.
134
+
135
+ # Arguments must be valid JSON.
136
+
137
+ # Stop tokens to respect: <|end▁of▁sentence|>
138
+ # """
139
+
140
+ system_prompt = """In this environment you have access to a set of tools you can use to assist with the user query. \
141
+ You may perform multiple rounds of function calls. \
142
+ In each round, you can call one or more functions. \
143
+
144
+ Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```
145
+
146
+ In your response, you need to first think about the reasoning process in the mind and then conduct function calling to get the information or perform the actions if needed. \
147
+ The reasoning process and function calling are enclosed within <think> </think> and <tool_call> </tool_call> tags. \
148
+ The results of the function calls will be given back to you after execution, \
149
+ and you can continue to call functions until you get the final answer for the user's question. You are encouraged to utilize as many function calls as possible. \
150
+ Finally, if you have got the answer, wrap it in <answer> </answer> **and do not call any more functions**, \
151
+ e.g. <think> Based on the tool results … </think> <answer>20 °C</answer>.
152
+
153
+ For each function call, return a JSON object with function name and arguments within <tool_call></tool_call> XML tags:
154
+ <tool_call>
155
+ {{"name": <function-name-1>, "arguments": <args-json-object>}}
156
+ </tool_call>"""
157
+
158
+ system_prompt_budget = """
159
+ You are an autonomous reasoning agent with access to external tools.
160
+
161
+ The conversation will retain only the *most-recent* <tool_response> block; older ones disappear.
162
+ As soon as you receive tool results, extract the *essential facts tables links etc* that might be needed for later and restate them inside your <think> section.
163
+  **Never copy large bodies of text** or raw JSON from tool output into your visible reply; summarise instead.
164
+
165
+ ◎ **Workflow**
166
+ 1. In every round, start with <think> … </think> to lay out your short reasoning.
167
+ 2. If you need external information or an action, emit one or more <tool_call> … </tool_call> blocks (JSON spec below).
168
+ 3. When the environment returns <tool_response>, continue reasoning; you may call more tools.
169
+ 4. Once you can answer the user, wrap the final result in <answer> … </answer> and STOP calling tools.
170
+
171
+ ◎ **Tool call format** (do **not** restate the schema or any explanations):
172
+ <tool_call>
173
+ {{"name": <function-name-1>, "arguments": <args-json-object>}}
174
+ </tool_call>
175
+
176
+ Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```
177
+ """
178
+
179
+
180
+
181
+ system_prompt_forcing_tool_call = """
182
+ In this environment you have access to a set of tools you can use to assist with the user query.
183
+ You may perform multiple rounds of function calls upto ten. In each round, you can call upto three functions.
184
+
185
+ ──────────────────────── AVAILABLE TOOLS ────────────────────────
186
+ ```json
187
+ [
188
+ {
189
+ "type": "function",
190
+ "function": {
191
+ "name": "pubmed_search",
192
+ "description": "Search PubMed for Medical related queries.",
193
+ "parameters": {
194
+ "type": "object",
195
+ "properties": {
196
+ "query": { "type": "string", "description": "Query to search for." },
197
+ "top_n": { "type": "integer", "description": "Number of hits", "default": 3 }
198
+ },
199
+ "required": ["query"]
200
+ }
201
+ }
202
+ }
203
+ ]
204
+ ```
205
+
206
+ ────────────────────────────── RULES ──────────────────────────────
207
+ 1. You MUST issue one pubmed_search tool call for each answer choice. Each query must relate the clinical context to that option.
208
+ 2. You MAY NOT skip any option or decide based only on internal reasoning. Evidence must be retrieved for all choices.
209
+ 3. You MAY issue follow-up tool calls if your reasoning leads you to need more evidence.
210
+ 4. You MUST wrap all reasoning in <think> </think> tags and all tool usage in <tool_call> </tool_call> tags. Number of <tool_call> and </tool_call> tokens in the entire trace MUST always match.
211
+ 5. Do NOT casually emit the <tool_call> </tool_call> during reasoning unless explicitly calling a tool in the proper format.
212
+ 5. Your final answer must be enclosed a single letter corresponding to the correct option enclosed in the <answer> </answer> tags. Do not output anything else inside these tags.
213
+ 6. DO NOT use any other confusing tags like <thiking> or </thinking>.
214
+ 7. Each <think> </think> block MUST be followed by a <tool_call> </tool_call> or <answer> </answer> or else the program will break without an answer.
215
+
216
+ ───────────────────── DUMMY EXAMPLE INTERLEAVED SKELETON ─────────────────────
217
+ <think>
218
+ We are presented with a 54-year-old woman with invasive ductal carcinoma of the breast and osteolytic lesions in the thoracic spine. This strongly suggests metastatic spread. Our task is to determine the most likely anatomical route of metastasis to the spine.
219
+
220
+ Let’s examine the given options:
221
+ A. Hemiazygos vein
222
+ B. Posterior intercostal veins
223
+ C. Batson’s vertebral venous plexus
224
+ D. Internal mammary lymphatics
225
+
226
+ We'll evaluate each option in turn using available literature and known anatomical pathways.
227
+ **Option A: Hemiazygos vein**
228
+ We begin by evaluating whether the hemiazygos vein could be involved in metastatic spread from breast cancer to the spine.
229
+ </think>
230
+ <tool_call>
231
+ {"name": "pubmed_search", "arguments": {"query": "breast cancer metastasis hemiazygos vein", "top_n": 2}}
232
+ </tool_call>
233
+ <tool_response>
234
+ ...
235
+ </tool_response>
236
+ <think>
237
+ There is limited or no strong evidence suggesting the hemiazygos vein is a common or primary route for vertebral metastasis from breast cancer.
238
+ Lets explore **Option B: Posterior intercostal veins** and **Option C: Batson’s vertebral venous plexus** and **Option D:Internal mammary lymphatics**
239
+ </think>
240
+ <tool_call>
241
+ {"name": "pubmed_search", "arguments": {"query": "posterior intercostal veins breast cancer spinal metastasis", "top_n": 3}}
242
+ </tool_call>
243
+ <tool_call>
244
+ {"name": "pubmed_search", "arguments": {"query": "Batson vertebral venous plexus breast cancer metastasis", "top_n": 3}}
245
+ </tool_call>
246
+ <tool_call>
247
+ {"name": "pubmed_search", "arguments": {"query": "Internal mammary lymphatics breast cancer metastasis", "top_n": 3}}
248
+ </tool_call>
249
+ <tool_response>
250
+ ...
251
+ </tool_response>
252
+ <think>
253
+ While the posterior intercostal veins may be involved in venous drainage, there is insufficient evidence to support them as a primary route for metastasis to the vertebral column.
254
+ where as Batson’s vertebral venous plexus — a valveless venous network that connects the thoracic and abdominal veins directly to the spine. I to find more specific information about option C.
255
+ </think>
256
+ <tool_call>
257
+ {"name": "pubmed_search", "arguments": {"query": ""Batson vertebral venous plexus breast cancer metastasis in people over 50", "top_n": 1}}
258
+ </tool_call>
259
+ <think>
260
+ After evaluating all four options, the most plausible route for breast cancer metastasis to the thoracic spine is clearly via Batson’s vertebral venous plexus:
261
+ </think>
262
+ <answer>C</answer>
263
+ """
264
+ # STOP_TOKENS =STOP_TOKENS = ["<|im_end|>", "<|endoftext|>"
265
+
266
+
267
+ def __init__(self, executor_url):
268
+ self.executor_url = executor_url
269
+
270
+ def init_prompt(self, func_schemas, question):
271
+ system_prompt = f"<|im_start|>system\n{self.sys_prompt_websailor.format(func_schemas=func_schemas)}<|im_end|>"
272
+ user_prompt = f"<|im_start|>user\n{question}<|im_end|>"
273
+ assistant_prefix = f"<|im_start|>assistant\n<think>"
274
+ return system_prompt + "\n" + user_prompt + "\n" + assistant_prefix
275
+
276
+
277
+
278
+ def _strip_old_tool_responses(self, prompt: str) -> str:
279
+ TOOL_RESPONSE_RE = re.compile(r"<tool_response>.*?</tool_response>\s*", re.DOTALL)
280
+ """Remove every existing <tool_response> … </tool_response> block."""
281
+ return TOOL_RESPONSE_RE.sub("", prompt)
282
+
283
+ def cat_assistant_response(self, curr_prompt, assistant_response):
284
+ return curr_prompt + assistant_response + "<|im_end|>"
285
+
286
+ def cat_tool_results(self, curr_prompt, tool_calls, results):
287
+ tool_response_str = ""
288
+ for tool_call, result in zip(tool_calls, results):
289
+ tool_response_str += f"<tool_response>{tool_call}\n{result}\n</tool_response>\n"
290
+ tool_response_str = f"<|im_start|>user\n{tool_response_str}<|im_end|>"
291
+ assistant_prefix = f"<|im_start|>assistant\n<think>"
292
+ return curr_prompt + "\n" + tool_response_str + "\n" + assistant_prefix
293
+
294
+ def format_tool_call(self, tool_call_str: str):
295
+ """Convert JSON function call description to Python executable code string."""
296
+ try:
297
+ call_json = json.loads(tool_call_str)
298
+ func_name = call_json['name']
299
+ arguments = call_json.get('arguments', {})
300
+
301
+ args_str = ', '.join(f"{k}={repr(v)}" for k, v in arguments.items())
302
+ return f"{func_name}({args_str})"
303
+ except Exception as e:
304
+ return f"Parse tool call failed: {e}"
305
+
306
+ def execute_tool_calls(self, env: str, tool_calls: List[str]) -> List[str]:
307
+ def exe_tool_call(env, call):
308
+ url = self.executor_url + '/execute'
309
+
310
+ call_str = self.format_tool_call(call)
311
+ # print(call_str)
312
+ if call_str.startswith("error: parse tool call failed"):
313
+ return call_str
314
+
315
+ try:
316
+ data = {
317
+ 'env': env,
318
+ 'call': call_str
319
+ }
320
+ response = requests.post(url, json=data, timeout=60)
321
+ if response.status_code != 200:
322
+ return f"error: {response.status_code}"
323
+ response = response.json()
324
+ ret_str = ''
325
+ if response['result']:
326
+ ret_str += f'result: \n{response["result"]}\n'
327
+ if response['output']:
328
+ ret_str += f'output: \n{response["output"]}\n'
329
+ if response['error']:
330
+ ret_str += f'error: \n{response["error"]}\n'
331
+ return ret_str.strip()
332
+ except requests.exceptions.Timeout:
333
+ return "error: execution timed out"
334
+ except Exception as e:
335
+ return str(e)
336
+
337
+ results = []
338
+ for tool_call in tool_calls:
339
+ result = exe_tool_call(env, tool_call)
340
+ results.append(result)
341
+ return results
342
+
343
+ def validate_tool_calls(self, output_str):
344
+ start_tags = re.findall(r'<tool_call>', output_str)
345
+ end_tags = re.findall(r'</tool_call>', output_str)
346
+
347
+ if len(start_tags) != len(end_tags):
348
+ return False
349
+
350
+ start_positions = [m.start() for m in re.finditer(r'<tool_call>', output_str)]
351
+ end_positions = [m.start() for m in re.finditer(r'</tool_call>', output_str)]
352
+
353
+ for start, end in zip(start_positions, end_positions):
354
+ if start >= end:
355
+ return False
356
+
357
+ return True
358
+
359
+ def extract_tool_calls(self, output_str):
360
+ if not self.validate_tool_calls(output_str):
361
+ return []
362
+
363
+ try:
364
+ pattern = r'<tool_call>((?:(?!</tool_call>).)*)</tool_call>'
365
+ matches = re.finditer(pattern, output_str, re.DOTALL)
366
+
367
+ return [match.group(1).strip() for match in matches]
368
+ except Exception as e:
369
+ return []
370
+
371
+ def extract_tool_calls_deepseek(self, output_str):
372
+ if not self.validate_tool_calls(output_str):
373
+ return []
374
+
375
+ try:
376
+ pattern = r'<tool_calls_begin>((?:(?!</tool_calls_end>).)*)<tool_calls_end>'
377
+ matches = re.finditer(pattern, output_str, re.DOTALL)
378
+
379
+ return [match.group(1).strip() for match in matches]
380
+ except Exception as e:
381
+ return []
382
+
383
+
384
+
385
+ @retry(max=5, sleep=1, fallback={"score": 0})
386
+ def run_ii_searcher(
387
+ self,
388
+ env: str,
389
+ func_schemas: str,
390
+ question: str,
391
+ tokenizer,
392
+ model_url="http://0.0.0.0:1214",
393
+ temperature: float = 0.0,
394
+ max_new_tokens: int = 40960,
395
+ ):
396
+ curr_prompt = self.init_prompt(func_schemas, question)
397
+ all_tool_calls= []
398
+
399
+ for _ in range(16):
400
+ prompt_tokens = tokenizer(curr_prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
401
+ max_tokens_left = max_new_tokens - len(prompt_tokens) - 100
402
+ # for oss model served via vllm
403
+ # response = requests.post(
404
+ # f'{model_url}/v1/chat/completions',
405
+ # json={
406
+ # "text": curr_prompt,
407
+ # # "reasoning": "medium"
408
+ # },
409
+ # ).json()
410
+ # for sglang served models hf models
411
+ response = requests.post(
412
+ f'{model_url}/generate',
413
+ json={
414
+ "text": curr_prompt,
415
+ "sampling_params": {
416
+ "temperature": temperature,
417
+ "max_new_tokens": max_tokens_left,
418
+ "repetition_penalty": 1.05
419
+ },
420
+
421
+ }
422
+ ).json()
423
+ if "error" in response.keys():
424
+ print("resp",response)
425
+ curr_prompt = self.cat_assistant_response(curr_prompt, response['text'])
426
+
427
+ tool_calls: List[str] = self.extract_tool_calls(response['text'])
428
+ all_tool_calls += tool_calls
429
+
430
+ if len(tool_calls) == 0:
431
+ break
432
+
433
+ else:
434
+ results: List[str] = self.execute_tool_calls(env, tool_calls)
435
+ curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)
436
+
437
+ return curr_prompt, all_tool_calls
438
+
439
+ # @retry(max=5, sleep=1, fallback={"score": 0})
440
+ # def run(
441
+ # self,
442
+ # env: str,
443
+ # func_schemas: str,
444
+ # question: str,
445
+ # tokenizer,
446
+ # model_url="http://0.0.0.0:1214",
447
+ # temperature: float = 0.0,
448
+ # max_new_tokens: int = 40960,
449
+ # ):
450
+ # curr_prompt = self.init_prompt(func_schemas, question)
451
+ # all_tool_calls= []
452
+
453
+ # for i in range(32):
454
+ # prompt_tokens = tokenizer(curr_prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
455
+ # max_tokens_left = max_new_tokens - len(prompt_tokens) - 100
456
+ # # for oss model served via vllm
457
+ # # response = requests.post(
458
+ # # f'{model_url}/v1/chat/completions',
459
+ # # json={
460
+ # # "text": curr_prompt,
461
+ # # # "reasoning": "medium"
462
+ # # },
463
+ # # ).json()
464
+ # # for sglang served models hf models
465
+ # response = requests.post(
466
+ # f'{model_url}/generate',
467
+ # json={
468
+ # "text": curr_prompt,
469
+ # "sampling_params": {
470
+ # "temperature": temperature,
471
+ # "max_new_tokens": max_tokens_left,
472
+ # "repetition_penalty": 1.05
473
+ # },
474
+
475
+ # }
476
+ # ).json()
477
+ # if "error" in response.keys():
478
+ # print("resp",response)
479
+ # curr_prompt = self.cat_assistant_response(curr_prompt, response['text'])
480
+
481
+ # tool_calls: List[str] = self.extract_tool_calls(response['text'])
482
+ # all_tool_calls += tool_calls
483
+
484
+ # if len(tool_calls) == 0:
485
+ # break
486
+
487
+ # else:
488
+ # # print(f"Step-{i+1}")
489
+ # results: List[str] = self.execute_tool_calls(env, tool_calls)
490
+ # curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)
491
+
492
+ # return curr_prompt, all_tool_calls
493
+ from typing import List, Dict, Any, Tuple
494
+ import requests
495
+
496
+ @retry(max=5, sleep=1, fallback={"score": 0})
497
+ def run(
498
+ self,
499
+ env: str,
500
+ func_schemas: str,
501
+ question: str,
502
+ tokenizer,
503
+ model_url: str = "http://0.0.0.0:1214",
504
+ temperature: float = 0.0,
505
+ max_new_tokens: int = 40960,
506
+ ) -> Tuple[str, List[str], List[Dict[str, str]]]:
507
+ """
508
+ Returns:
509
+ curr_prompt: the final prompt buffer (with assistant/tool traces you maintain internally)
510
+ all_tool_calls: flat list of all tool call strings extracted across steps
511
+ chat: a lightweight chat transcript list[{"role": "...", "content": "..."}]
512
+ • 'user' items = the original question + aggregated tool responses
513
+ • 'assistant' items = model responses (and a compact line-list of tool calls)
514
+ """
515
+ # Build runtime prompt and initialize accumulators
516
+ curr_prompt = self.init_prompt(func_schemas, question)
517
+ all_tool_calls: List[str] = []
518
+ chat: List[Dict[str, str]] = []
519
+
520
+ # Seed transcript with JUST the question (no system prompt)
521
+ chat.append({"role": "user", "content": question})
522
+
523
+ for i in range(32):
524
+ # Budget tokens for this step
525
+ prompt_tokens = tokenizer(curr_prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
526
+ max_tokens_left = max(1, max_new_tokens - len(prompt_tokens) - 100)
527
+
528
+ # ---- Model call (sglang/vLLM-style JSON) ----
529
+ # If you switch to /v1/chat/completions, adjust accordingly.
530
+ response = requests.post(
531
+ f"{model_url}/generate",
532
+ json={
533
+ "text": curr_prompt,
534
+ "sampling_params": {
535
+ "temperature": temperature,
536
+ "max_new_tokens": max_tokens_left,
537
+ "repetition_penalty": 1.05,
538
+ },
539
+ },
540
+ timeout=300,
541
+ ).json()
542
+
543
+ if isinstance(response, dict) and "error" in response:
544
+ # Log the error as assistant text for visibility and break
545
+ err_msg = f"[model_error] {response.get('error')}"
546
+ chat.append({"role": "assistant", "content": err_msg})
547
+ break
548
+
549
+ assistant_text = response.get("text", "")
550
+ # Append assistant's raw text to chat
551
+ chat.append({"role": "assistant", "content": assistant_text})
552
+
553
+ # Update your running prompt with assistant text
554
+ curr_prompt = self.cat_assistant_response(curr_prompt, assistant_text)
555
+
556
+ # Extract tool calls from the assistant text
557
+ tool_calls: List[str] = self.extract_tool_calls(assistant_text)
558
+ if tool_calls:
559
+ all_tool_calls.extend(tool_calls)
560
+
561
+ # Log tool calls as an assistant message (newline-joined)
562
+ chat.append({"role": "assistant", "content": "\n".join(tool_calls)})
563
+
564
+ # Execute tools and collect results
565
+ results: List[str] = self.execute_tool_calls(env, tool_calls)
566
+
567
+ # Feed tool results back into prompt
568
+ curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)
569
+
570
+ # Aggregate tool responses into a single user message
571
+ tool_res_blocks = []
572
+ for idx, (call, res) in enumerate(zip(tool_calls, results), 1):
573
+ tool_res_blocks.append(f"[Tool {idx}] Result:\n{res}")
574
+ chat.append({"role": "user", "content": "\n\n".join(tool_res_blocks)})
575
+
576
+ else:
577
+ # No tool calls → model produced a final answer; stop.
578
+ break
579
+
580
+ # Return the original outputs plus the chat-style transcript
581
+ return curr_prompt, all_tool_calls, chat
582
+
583
+
584
+
585
+
586
+
587
+ @retry(max=5, sleep=1, fallback={"score": 0})
588
+ def run_deepseek(
589
+ self,
590
+ env: str,
591
+ func_schemas: str,
592
+ question: str,
593
+ model_name: str,
594
+ temperature: float = 0.0,
595
+ top_p: float = 0.95,
596
+ max_tokens: int = 32768,
597
+ ):
598
+ # print("AA"* 100)
599
+ """
600
+ Chat-based ReCall loop for DeepSeek-R1 on Together.
601
+ """
602
+ sys_content = self.sys_prompt_websailor_deepseek.format(func_schemas=func_schemas)
603
+ # sys_content = self.init_prompt(func_schemas, question)
604
+
605
+ messages = [
606
+ {"role": "system", "content": sys_content},
607
+ {"role": "user", "content": question},
608
+ ]
609
+
610
+ # client = Together(api_key="")
611
+ client = Together(api_key="bcc761f7a821a80c9c5166171ebb36756cd16d505cec226c3b2259b846364000")
612
+ all_tool_calls = []
613
+ for turn in range(32): # up to 10 reasoning turns
614
+ resp = client.chat.completions.create(
615
+ model=model_name,
616
+ # model="Qwen/Qwen3-235B-A22B-fp8-tput",
617
+ messages=messages,
618
+ temperature=temperature,
619
+ top_p=top_p,
620
+ max_tokens=39000,
621
+ stop=["<|end▁of▁sentence|>", "<|im_end|>"]
622
+ )
623
+ # print(resp)
624
+
625
+
626
+ assistant_text = resp.choices[0].message.content
627
+ # print(assistant_text)
628
+ messages.append({"role": "assistant", "content": assistant_text})
629
+ # print(f"assistant_output: {assistant_text}")
630
+
631
+ # ⛑ Safe tool call extraction with diagnostic
632
+ # try:
633
+ # print("Extracting tool calls")
634
+ tool_calls = self.extract_tool_calls_deepseek(assistant_text)
635
+ print(tool_calls)
636
+ all_tool_calls += tool_calls
637
+ # except Exception as e:
638
+ # print(f"Extraction failed with exception {e}")
639
+ # err_msg = f"<tool_response>Tool call extraction failed on turn {turn+1}: {str(e)}</tool_response>"
640
+ # messages.append({"role": "user", "content": err_msg})
641
+ # continue # continue to next turn instead of breaking
642
+ if "<answer>" in assistant_text:
643
+ break
644
+
645
+ if len(tool_calls) != 0:
646
+ results = self.execute_tool_calls(env, tool_calls)
647
+ tool_resp_block = "".join(
648
+ f"<tool_response>{c}\n{r}\n</tool_response>\n"
649
+ for c, r in zip(tool_calls, results)
650
+ )
651
+ messages.append({"role": "user", "content": tool_resp_block})
652
+ # print(f"Tool Response {tool_resp_block}")
653
+ else:
654
+ print("no answer or tool call")
655
+ break
656
+
657
+ trajectory = "\n".join(
658
+ f"<{m['role']}>\n{m['content']}" for m in messages
659
+ if m["role"] != "system"
660
+ )
661
+ return trajectory, all_tool_calls
662
+
663
+
664
+ # ────────────────────────────────────────────────────────────────
665
+ # HF-endpoint version of “retrieve → inject → tool loop”
666
+ # ────────────────────────────────────────────────────────────────
667
+ @retry(max=5, sleep=1, fallback=None)
668
+ def run_with_prompt_injection(
669
+ self,
670
+ env: str,
671
+ func_schemas: str,
672
+ question: str,
673
+ model_url: str = "http://0.0.0.0:1214",
674
+ temperature: float = 0.0,
675
+ max_new_tokens: int = 512,
676
+ top_n: int = 5,
677
+ ):
678
+ """
679
+ 0) call pubmed_search(question, top_n) once via the sandbox
680
+ 1) inject those snippets into the very first user message
681
+ 2) continue with the normal multi-turn ReCall loop against *model_url*
682
+ """
683
+
684
+ # 0️⃣ do a single retrieval tool call
685
+ retrieve_call = json.dumps({
686
+ "name": "pubmed_search",
687
+ "arguments": {"query": question, "top_n": top_n}
688
+ })
689
+ retrieval_raw = self.execute_tool_calls(env, [retrieve_call])[0]
690
+ try:
691
+ snippets_block = retrieval_raw.split("result:", 1)[-1].strip()
692
+ except Exception:
693
+ snippets_block = ""
694
+
695
+ # 1️⃣ build initial prompt with injected snippets
696
+ user_msg = (
697
+ f"Question: {question}\n\n"
698
+ "Here are some relevant PubMed snippets:\n"
699
+ f"{snippets_block}"
700
+ ) if snippets_block else f"Question: {question}"
701
+
702
+ sys_prompt = self.init_prompt(func_schemas, question)
703
+ system_prompt = f"<|im_start|>system\n{sys_prompt}<|im_end|>"
704
+ user_prompt = f"<|im_start|>user\n{user_msg}<|im_end|>"
705
+ assistant_pref= f"<|im_start|>assistant\n<think>"
706
+ curr_prompt = system_prompt + "\n" + user_prompt + "\n" + assistant_pref
707
+
708
+ # 2️⃣ normal ReCall loop hitting the HF inference endpoint
709
+ for _ in range(10):
710
+ resp = requests.post(
711
+ f"{model_url}/generate",
712
+ json={
713
+ "text": curr_prompt,
714
+ "sampling_params": {
715
+ "temperature": temperature,
716
+ "max_new_tokens": max_new_tokens,
717
+ }
718
+ },
719
+ timeout=120,
720
+ ).json()
721
+ if "error" in resp.keys():
722
+ print("resp",response)
723
+ assistant_txt = resp["text"]
724
+ curr_prompt = self.cat_assistant_response(curr_prompt, assistant_txt)
725
+
726
+ tool_calls = self.extract_tool_calls(assistant_txt)
727
+ if len(tool_calls) != 0:
728
+ # break # model produced an answer → done
729
+
730
+ results = self.execute_tool_calls(env, tool_calls)
731
+ curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)
732
+
733
+ else:
734
+ continue
735
+ return curr_prompt
736
+
737
+
738
+
739
+ @retry(max=5, sleep=1, fallback={"score": 0})
740
+ def run_budget(
741
+ self,
742
+ env: str,
743
+ func_schemas: str,
744
+ question: str,
745
+ model_url: str = "http://0.0.0.0:1214",
746
+ temperature: float = 0.0,
747
+ max_new_tokens: int = 2048,
748
+ ) -> str:
749
+ """
750
+ Execute an agentic dialogue with external tools while *pruning* previous
751
+ <tool_response> blocks to prevent context-length explosion.
752
+ """
753
+ curr_prompt = self.init_prompt(func_schemas, question)
754
+
755
+ for _ in range(16): # hard loop-limit
756
+ # ── 1. Call the model
757
+ rsp = requests.post(
758
+ f"{model_url}/generate",
759
+ json={
760
+ "text": curr_prompt,
761
+ "sampling_params": {
762
+ "temperature": temperature,
763
+ "max_new_tokens": max_new_tokens,
764
+ "stop": ["<|im_end|>", "</think>", "</think>\n" "</think>\n\n"],
765
+ },
766
+
767
+ },
768
+ timeout=120,
769
+ ).json()
770
+ generated = rsp["text"] # what you have now
771
+ matched = rsp["meta_info"]["finish_reason"].get("matched")
772
+
773
+ # ⇢ append the tag back only if it was removed
774
+ if matched and not generated.endswith(matched):
775
+ generated += matched
776
+
777
+ # Fail fast on server error
778
+ if "error" in rsp:
779
+ raise RuntimeError(rsp["error"])
780
+
781
+ assistant_text: str = rsp["text"]
782
+ curr_prompt = self.cat_assistant_response(curr_prompt, assistant_text)
783
+
784
+ # ── 2. Check for final answer ────────────────────────────────────
785
+ if "<answer>" in assistant_text:
786
+ break
787
+
788
+ # ── 3. Extract & execute tool calls ──────────────────────────────
789
+ tool_calls: List[str] = self.extract_tool_calls(assistant_text)
790
+ if not tool_calls: # continue reasoning without calling a tool
791
+ continue
792
+
793
+ results: List[str] = self.execute_tool_calls(env, tool_calls)
794
+
795
+
796
+ # ── 4. BEFORE appending new tool output, drop all old ones ───────
797
+ curr_prompt =self. _strip_old_tool_responses(curr_prompt)
798
+
799
+ # ── 5. Append *only* the fresh tool_response block ───────────────
800
+ curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)
801
+
802
+ return curr_prompt
803
+
804
+
805
+
806
+
807
+ def _strip_old_tool_responses_msgs(self, messages: list[dict]) -> list[dict]:
808
+ """
809
+ Return a copy of `messages` with every *user* message that starts with
810
+ <tool_response> removed. Keeps assistant turns untouched.
811
+ """
812
+ return [
813
+ m for m in messages
814
+ if not (m["role"] == "user" and m["content"].lstrip().startswith("<tool_response>"))
815
+ ]
816
+ # ────────── budget version ──────────
817
+ @retry(max=5, sleep=1, fallback={"score": 0})
818
+ def run_deepseek_budget(
819
+ self,
820
+ env: str,
821
+ func_schemas: str,
822
+ question: str,
823
+ api_key: str,
824
+ model_name: str,
825
+ temperature: float = 0.0,
826
+ top_p: float = 0.95,
827
+ max_tokens: int = 32768,
828
+ max_turns: int = 10,
829
+ ):
830
+ """
831
+ Chat-based ReCall loop for DeepSeek-R1 **with context-budget pruning**.
832
+ Keeps only the *latest* <tool_response> block to avoid prompt bloat.
833
+ """
834
+ sys_content = self.system_prompt_budget.format(func_schemas=func_schemas)
835
+
836
+ messages = [
837
+ {"role": "system", "content": sys_content},
838
+ {"role": "user", "content": question},
839
+ ]
840
+
841
+ client = Together(api_key=api_key)
842
+
843
+ for turn in range(max_turns):
844
+ # ── 1. model call ───────────────────────────────────────────────
845
+ resp = client.chat.completions.create(
846
+ model=model_name,
847
+ messages=messages,
848
+ temperature=temperature,
849
+ top_p=top_p,
850
+ max_tokens=max_tokens,
851
+ stop=["</tool_call>", "<|end▁of▁sentence|>"],
852
+ )
853
+ assistant_text = resp.choices[0].message.content
854
+ messages.append({"role": "assistant", "content": assistant_text})
855
+
856
+ print(f"**assistant** \n {assistant_text}")
857
+
858
+ # ── 2. finished? ────────────────────────────────────────────────
859
+ if "<answer>" in assistant_text:
860
+ break
861
+
862
+ # ── 3. parse tool calls ────────────────────────────────────────
863
+ tool_calls = self.extract_tool_calls(assistant_text)
864
+ print(f"**tool_calls** \n {tool_calls}")
865
+ if not tool_calls:
866
+ continue # keep reasoning without tools
867
+
868
+ # ── 4. execute tools ───────────────────────────────────────────
869
+ results = self.execute_tool_calls(env, tool_calls)
870
+ print(f"**tool_response** \n {results}")
871
+
872
+ # ── 5. prune & append fresh tool_response ──────────────────────
873
+ messages = self._strip_old_tool_responses_msgs(messages)
874
+
875
+ tool_resp_block = "".join(
876
+ f"<tool_response>{c}\n{r}\n</tool_response>\n"
877
+ for c, r in zip(tool_calls, results)
878
+ )
879
+ messages.append({"role": "user", "content": tool_resp_block})
880
+
881
+ # ── 6. flatten & return trajectory (sans system for readability) ───
882
+ trajectory = "\n".join(
883
+ f"<{m['role']}>\n{m['content']}" for m in messages if m["role"] != "system"
884
+ )
885
+ return trajectory
886
+
887
+
888
+ @retry(max=5, sleep=1, fallback=None)
889
+ def run_deepseek_with_prompt_injection(
890
+ self,
891
+ env: str,
892
+ func_schemas: str,
893
+ question: str,
894
+ api_key: str,
895
+ model_name: str,
896
+ temperature: float = 0.0,
897
+ top_p: float = 0.95,
898
+ max_tokens: int = 32768,
899
+ ):
900
+ """
901
+ 1) Call pubmed_search(question, top_n=5) as a tool to get snippets.
902
+ 2) Inject them into the first user message.
903
+ 3) Proceed with the usual DeepSeek-R1 tool‐based rollout.
904
+ """
905
+
906
+ # ── Step 0: prepare the single‐tool call for retrieval ───────────────
907
+ retrieve_call = json.dumps({
908
+ "name": "pubmed_search",
909
+ "arguments": {
910
+ "query": question,
911
+ "top_n": 5
912
+ }
913
+ })
914
+
915
+ # Execute it once via your helper
916
+ # note: `env` must include whatever import / client‐setup
917
+ # your sandbox needs to run pubmed_search(...)
918
+ raw_retrieval_results = self.execute_tool_calls(env, [retrieve_call])[0]
919
+ # print("AAAAA"*100)
920
+ try:
921
+ snippets = raw_retrieval_results[9:] #"remove result: str"
922
+ # print(snippets)
923
+ except:
924
+ snippets = ""
925
+ # print(f"[ReCall] Retriever call failed to parse JSON, got:\n{raw_retrieval_results!r}")
926
+
927
+ # ── Step 1: build the injected user prompt ────────────────────────────
928
+ if snippets:
929
+
930
+ user_content = (
931
+ f"Question: {question}\n\n"
932
+ "Here are some relevant PubMed snippets:\n"
933
+ f"{snippets}"
934
+ )
935
+ else:
936
+ user_content = f"Question: {question}"
937
+
938
+ # ── Step 2: start the chat history ────────────────────────────────────
939
+ sys_content = self.system_prompt_forcing_tool_call
940
+ messages = [
941
+ {"role": "system", "content": sys_content},
942
+ {"role": "user", "content": user_content},
943
+ ]
944
+ client = Together(api_key=api_key)
945
+
946
+ # ── Step 3: your normal ReCall tool‐calling loop ─────────────────────
947
+ for turn in range(10):
948
+ resp = client.chat.completions.create(
949
+ model = model_name,
950
+ messages = messages,
951
+ temperature = temperature,
952
+ top_p = top_p,
953
+ max_tokens = max_tokens,
954
+ stop = ["</tool_call>", "<|end▁of▁sentence|>"]
955
+ )
956
+
957
+ assistant_text = resp.choices[0].message.content
958
+ messages.append({"role": "assistant", "content": assistant_text})
959
+
960
+ tool_calls = self.extract_tool_calls(assistant_text)
961
+ if not tool_calls:
962
+ break
963
+
964
+ # Execute all of the tool calls in one go
965
+ results = self.execute_tool_calls(env, tool_calls)
966
+ # and append them back in the required <tool_response> format
967
+ tool_resp_block = "".join(
968
+ f"<tool_response>{call}\n{out}\n</tool_response>\n"
969
+ for call, out in zip(tool_calls, results)
970
+ )
971
+ messages.append({"role": "user", "content": tool_resp_block})
972
+
973
+ # ── Step 4: flatten to a single trajectory ────────────────────────────
974
+ trajectory = "\n".join(
975
+ f"<{m['role']}>\n{m['content']}"
976
+ for m in messages
977
+ if m["role"] != "system"
978
+ )
979
+ return trajectory
980
+
inference/simpledeepsearch.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """o1_searcher_inference.py — Serper‑based Search‑o1 re‑implementation
3
+ with *original* in‑house summarisation workflow, step‑replacement logic and
4
+ bug‑fixes for duplicate queries / ValueError.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import os, re, json, time, string, pathlib
9
+ from dataclasses import dataclass
10
+ from typing import List, Dict, Optional, Tuple
11
+ import requests, trafilatura
12
+
13
+ # -----------------------------------------------------------------------------
14
+ # Optional NLTK sentence tokenizer (fallback to regex) -------------------------
15
+ try:
16
+ from nltk.tokenize import sent_tokenize # type: ignore
17
+ except Exception: # ImportError *or* missing punkt data
18
+ def sent_tokenize(x: str):
19
+ return re.split(r"(?<=[.!?]) +", x)
20
+
21
+ # -----------------------------------------------------------------------------
22
+ # Special tags & constants -----------------------------------------------------
23
+ BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
24
+ END_SEARCH_QUERY = "<|end_search_query|>"
25
+ BEGIN_DOCUMENT_QUERY = "<|begin_of_document|>"
26
+ END_DOCUMENT_QUERY = "<|end_of_document|>"
27
+ THINK_OPEN, THINK_CLOSE = "<think>", "</think>"
28
+ EOS_TOKEN = "<|im_end|>"
29
+ ANSWER_OPEN, ANSWER_CLOSE = "<answer>", "</answer>"
30
+ STOP_STRINGS = [END_SEARCH_QUERY, ANSWER_CLOSE, EOS_TOKEN, "<|endoftext|>"]
31
+ ALLOWED_DATASETS = {"musique", "frames", "simpleqa", "browsercomp"}
32
+ # tokenizer =
33
+ TOKENIZER_DIR = "/home/fractal_admin/shreyas/models/Qwen3-4B"
34
+
35
+ # ───────────────────────── tokenizer ────────────────────────────────────────
36
+ try:
37
+ from transformers import AutoTokenizer
38
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, trust_remote_code=True)
39
+ except Exception as e:
40
+ import sys
41
+ sys.exit(f"❌ Could not load Qwen3 tokenizer: {e}")
42
+
43
+ # -----------------------------------------------------------------------------
44
+ # Helper functions -------------------------------------------------------------
45
+
46
+ def remove_punc(t: str) -> str:
47
+ return t.translate(str.maketrans("", "", string.punctuation))
48
+
49
+ # legacy aliases for older checkpoints ---------------------------------------
50
+ _nopunc = remove_punc
51
+
52
+
53
+ def f1(a: set, b: set) -> float:
54
+ inter = len(a & b)
55
+ return 0.0 if inter == 0 else 2 * inter / (len(a) + len(b))
56
+
57
+ # legacy alias
58
+ _f1 = f1
59
+
60
+
61
+ def extract_snippet_ctx(text: str, snippet: str, win: int = 2500) -> str:
62
+ """Return *window*‑sized context around the sentence most similar to snippet."""
63
+ text = text[:50_000]
64
+ sn_set = set(remove_punc(snippet.lower()).split())
65
+ best, best_score = None, 0.20
66
+ for sent in sent_tokenize(text):
67
+ score = f1(sn_set, set(remove_punc(sent.lower()).split()))
68
+ if score > best_score:
69
+ best, best_score = sent, score
70
+ if best:
71
+ pos = text.find(best)
72
+ return text[max(0, pos - win): pos + len(best) + win]
73
+ return text[: 2 * win]
74
+
75
+ # -----------------------------------------------------------------------------
76
+ # Config dataclass -------------------------------------------------------------
77
+ @dataclass
78
+ class SDSCfg:
79
+ serper_api_key: str = "7bfe51ead1a1766b656c1355b292d1d29c15c114"
80
+ gl: str = "us"; hl: str = "en"
81
+ top_k: int = 10; max_doc_len: int = 3000
82
+ max_search: int = 10; max_turn: int = 15
83
+ use_jina: bool = True
84
+ jina_tpl: str = "https://r.jina.ai/http://{}"
85
+ # generation params
86
+ temperature: float = 0.7; top_p: float = 0.8; top_k_sampling: int = 20
87
+ rep_pen: float = 1.05; thinker_max_tokens: int = 40960
88
+
89
+ # -----------------------------------------------------------------------------
90
+ # Serper search + page fetch ---------------------------------------------------
91
+
92
+ def serper_search(q: str, num: int, key: str, gl="us", hl="en") -> List[Dict]:
93
+ hdr = {"X-API-KEY": key, "Content-Type": "application/json"}
94
+ body = {"q": q, "num": num, "gl": gl, "hl": hl}
95
+ r = requests.post("https://google.serper.dev/search", json=body, headers=hdr, timeout=20)
96
+ r.raise_for_status(); return r.json().get("organic", [])
97
+
98
+
99
+ def fetch_page(url: str, cfg: O1Cfg, snippet: str = "") -> str:
100
+ try:
101
+ txt = ""
102
+ if cfg.use_jina:
103
+ r = requests.get(cfg.jina_tpl.format(url), timeout=15)
104
+ if r.ok and len(r.text.strip()) > 100:
105
+ txt = r.text.strip()
106
+ if txt == "":
107
+ r = requests.get(url, timeout=15); r.raise_for_status()
108
+ txt = trafilatura.extract(r.text, output_format="txt") or ""
109
+ if snippet:
110
+ txt = extract_snippet_ctx(txt, snippet, cfg.max_doc_len)
111
+
112
+ return txt
113
+ except Exception:
114
+ return ""
115
+
116
+ # -----------------------------------------------------------------------------
117
+ # replace_recent_steps --------------------------------------------------------
118
+
119
+ def replace_recent_steps(origin: str, patch: str) -> str:
120
+ """Apply *patch* (containing numbered `Step N:` lines) to *origin*."""
121
+ step_re = re.compile(r"Step\s+(\d+):\s*")
122
+
123
+ def parse(block: str) -> Dict[int, str]:
124
+ cur, buf, out = None, [], {}
125
+ for line in block.splitlines():
126
+ m = step_re.match(line)
127
+ if m:
128
+ if cur is not None:
129
+ out[cur] = "\n".join(buf).strip()
130
+ cur, buf = int(m.group(1)), [line[m.end():].strip()]
131
+ elif cur is not None:
132
+ buf.append(line)
133
+ if cur is not None:
134
+ out[cur] = "\n".join(buf).strip()
135
+ return out
136
+
137
+ base = parse(origin); mod = parse(patch)
138
+ for k, v in mod.items():
139
+ if "DELETE THIS STEP" in v:
140
+ base.pop(k, None)
141
+ else:
142
+ base[k] = v
143
+ return "\n\n".join(base[k] for k in sorted(base))
144
+
145
+ # -----------------------------------------------------------------------------
146
+ # Prompts ----------------------------------------------------------------------
147
+ # from prompts import get_webpage_to_reasonchain_instruction # keep original helper
148
+
149
+ # -----------------------------------------------------------------------------
150
+ # Main agent -------------------------------------------------------------------
151
+ class SDSearcher:
152
+ # STOP_TOKENS = [
153
+ # "<|im_end|>",
154
+ # "<|endoftext|>",
155
+ # "<|end_of_query|>",
156
+ # " <|end_of_query|>",
157
+ # "<|end_of_query|>\n",
158
+ # "<|end_of_query|>\n\n",
159
+ # " <|end_of_query|>\n",
160
+ # " <|end_of_query|>\n\n",
161
+ # ]
162
+ get_webpage_to_reasonchain_instruction = """**Task Instruction:**
163
+
164
+ You are tasked with reading and analyzing web pages based on the following inputs: **Previous Reasoning Steps**, **Current Search Query**, and **Searched Web Pages**. Your objective is to extract relevant and helpful information for **Current Search Query** from the **Searched Web Pages** and seamlessly integrate this information into the **Previous Reasoning Steps** to continue reasoning for the original question.
165
+
166
+ **Guidelines:**
167
+
168
+ 1. **Analyze the Searched Web Pages:**
169
+ - Carefully review the content of each searched web page.
170
+ - Identify factual information that is relevant to the **Current Search Query** and can aid in the reasoning process for the original question.
171
+
172
+ 2. **Extract Relevant Information:**
173
+ - Select the information from the Searched Web Pages that directly contributes to advancing the **Previous Reasoning Steps**.
174
+ - Ensure that the extracted information is accurate and relevant.
175
+
176
+ 3. **Output Format:**
177
+ - **If the web pages provide helpful information for current search query:** Present the information beginning with **Final Information** as shown below.
178
+ **Final Information**
179
+
180
+ [Helpful information]
181
+
182
+ - **If the web pages do not provide any helpful information for current search query:** Output the following text.
183
+
184
+ **Final Information**
185
+
186
+ No helpful information found.
187
+
188
+ **Inputs:**
189
+ - **Previous Reasoning Steps:**
190
+ {prev_reasoning}
191
+
192
+ - **Current Search Query:**
193
+ {search_query}
194
+
195
+ - **Searched Web Pages:**
196
+ {document}
197
+
198
+ Now you should analyze each web page and find helpful information based on the current search query {search_query} and previous reasoning steps.
199
+ Return the Helpful information in the <information></information> tags
200
+ """
201
+
202
+ sys_prompt_multiqa = (
203
+ "You are a reasoning assistant with the ability to perform web searches to help "
204
+ "you answer the user's question accurately. You have special tools:\n\n"
205
+ "- To perform a search: write <|begin_search_query|> your query here <|end_search_query|>.\n"
206
+ "Then, the system will search and analyze relevant web pages, then provide you with helpful information in the format <|begin_search_result|> ...search results... <|end_search_result|>.\n\n"
207
+ f"You can repeat the search process multiple times if necessary. The maximum number of search attempts is limited to 16.\n\n"
208
+ "Once you have all the information you need, continue your reasoning.\n\n"
209
+ "Example:\n"
210
+ "Question: \"Alice David is the voice of Lara Croft in a video game developed by which company?\"\n"
211
+ "Assistant thinking steps:\n"
212
+ "- I need to find out who voices Lara Croft in the video game.\n"
213
+ "- Then, I need to determine which company developed that video game.\n\n"
214
+ "Assistant:\n"
215
+ "<|begin_search_query|>Alice David Lara Croft voice<|end_search_query|>\n\n"
216
+ "(System returns processed information from relevant web pages)\n\n"
217
+ "Assistant thinks: The search results indicate that Alice David is the voice of Lara Croft in a specific video game. Now, I need to find out which company developed that game.\n\n"
218
+ "Assistant:\n"
219
+ "<|begin_search_query|>video game developed by Alice David Lara Croft<|end_search_query|>\n\n"
220
+ "(System returns processed information from relevant web pages)\n\n"
221
+ "Assistant continues reasoning with the new information...\n\n"
222
+ "Remember:\n"
223
+ "- Use <|begin_search_query|> to request a web search and end with <|end_search_query|>.\n"
224
+ "- When done searching, continue your reasoning.\n\n",
225
+ "Finally, if you have got the answer, enclose it within \\boxed{{}} with latex format and do not continue to call functions"
226
+ )
227
+
228
+ def __init__(self, cfg: O1Cfg, thinker_url: str):
229
+ if not cfg.serper_api_key:
230
+ raise ValueError("SERPER_API_KEY required")
231
+ self.cfg, self.model_url = cfg, thinker_url.rstrip("/")
232
+ self.search_cache: Dict[str, List[Dict]] = {}
233
+ self.page_cache: Dict[Tuple[str, str], str] = {}
234
+
235
+ # --- low‑level generation call ------------------------------------------
236
+ def _generate(self, prompt: str) -> str:
237
+ prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
238
+ max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100
239
+ resp = requests.post(
240
+ f"{self.model_url}/generate",
241
+ json={
242
+ "text": prompt,
243
+ "sampling_params": {
244
+ "temperature": self.cfg.temperature,
245
+ "top_p": self.cfg.top_p,
246
+ "max_new_tokens": max_tokens_left,
247
+ "repetition_penalty": self.cfg.rep_pen,
248
+ "stop": STOP_STRINGS,
249
+ },
250
+
251
+ },
252
+ timeout=60,
253
+ ).json()
254
+ generated = resp["text"] # what you have now
255
+ matched = resp["meta_info"]["finish_reason"].get("matched")
256
+ reason = resp["meta_info"]["finish_reason"].get("type")
257
+
258
+ # ⇢ append the tag back only if it was removed
259
+ if reason == "stop" and matched in STOP_STRINGS:
260
+ if not "<|end_of_query|>" in generated:
261
+ generated += matched
262
+ if reason == "stop" and matched == 151645:
263
+ if not generated.endswith("<|im_end|>"):
264
+ generated += "<|im_end|>"
265
+ if reason == "stop" and matched == 151643:
266
+ if not generated.endswith("<|endoftext|>"):
267
+ generated += "<|endoftext|>"
268
+ return generated
269
+
270
+ def _generate_summary(self, prompt: str) -> str:
271
+ summary_url = "http://0.0.0.0:1243"
272
+ prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
273
+ max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100
274
+ resp = requests.post(
275
+ f"{summary_url}/generate",
276
+ json={
277
+ "text": prompt,
278
+ "sampling_params": {
279
+ "temperature": self.cfg.temperature,
280
+ "max_new_tokens": 8192,#max_tokens_left,
281
+ "stop": STOP_STRINGS,
282
+ },
283
+
284
+ },
285
+ timeout=60,
286
+ ).json()
287
+ generated = resp["text"] # what you have now
288
+ matched = resp["meta_info"]["finish_reason"].get("matched")
289
+ reason = resp["meta_info"]["finish_reason"].get("type")
290
+ # ##print("-"*100)
291
+ # ##print(resp)
292
+ # ##print(matched)
293
+ # ##print("-"*100)
294
+ # ⇢ append the tag back only if it was removed
295
+ if reason == "stop" and matched in STOP_STRINGS:
296
+ if not "<|end_of_query|>" in generated:
297
+ generated += matched + EOS_TOKEN
298
+ if reason == "stop" and matched == 151645:
299
+ if not generated.endswith("<|im_end|>"):
300
+ generated += "<|im_end|>"
301
+ if reason == "stop" and matched == 151643:
302
+ if not generated.endswith("<|endoftext|>"):
303
+ generated += "<|endoftext|>"
304
+ return generated
305
+ # --- public entry -------------------------------------------------------
306
+ def run(self, question: str):
307
+ prompt = (
308
+ f"<|im_start|>system\n{self.sys_prompt_multiqa}<|im_end|>\n"
309
+ f"<|im_start|>user\n{question}<|im_end|>\n"
310
+ f"<|im_start|>assistant\n{THINK_OPEN}"
311
+ )
312
+ full_trace = prompt # <-- Track full trace
313
+ queries: List[str] = []
314
+ seen_queries: set[str] = set()
315
+
316
+ for i in range(self.cfg.max_turn):
317
+ chunk = self._generate(prompt)
318
+ prompt += chunk
319
+
320
+ if ANSWER_CLOSE in chunk:
321
+ break
322
+
323
+ ##print(f"step-{i}")
324
+ ##print(chunk)
325
+
326
+ query = self._extract_query(chunk)
327
+ ##print(query)
328
+ if not query or len(queries) >= self.cfg.max_search:
329
+ break
330
+ if query in seen_queries:
331
+ continue
332
+ queries.append(query)
333
+ seen_queries.add(query)
334
+
335
+ doc = self._retrieve_doc(query)
336
+ prev_reasoning = self._extract_reasoning(prompt)
337
+ summary = "\n<|im_start|>user" + self._summarise(prev_reasoning, query, doc) + EOS_TOKEN + "\n<|im_start|>assistant" + THINK_OPEN
338
+ ##print("summary")
339
+ ##print(summary)
340
+ prompt += summary # <-- Log summary to trace
341
+
342
+ # new_reasoning = replace_recent_steps(prev_reasoning, summary)
343
+
344
+ # if prev_reasoning:
345
+ # prompt = prompt.rsplit(prev_reasoning, 1)[0] + new_reasoning + THINK_CLOSE + THINK_OPEN
346
+ # else:
347
+ # prompt += new_reasoning + THINK_CLOSE + THINK_OPEN
348
+
349
+ # full_trace += + THINK_CLOSE + THINK_OPEN + "\n" # <-- Log reasoning to trace
350
+ else:
351
+ final = f"{ANSWER_OPEN}I don't know.{ANSWER_CLOSE}"
352
+ prompt += final
353
+ # full_trace += final
354
+
355
+ return prompt, queries
356
+
357
+
358
+ # ---------------------------------------------------------------------
359
+ # helpers --------------------------------------------------------------
360
+ def _extract_query(self, txt: str) -> Optional[str]:
361
+ if BEGIN_SEARCH_QUERY not in txt or END_SEARCH_QUERY not in txt:
362
+ return None
363
+ frag = txt.split(BEGIN_SEARCH_QUERY)[-1].split(END_SEARCH_QUERY)[0]
364
+ # strip quotes / ellipsis / tabs
365
+ return re.sub(r"[\"'…\t]", " ", frag.split("<|")[0]).strip()
366
+
367
+ def _retrieve_doc(self, query: str) -> str:
368
+ if query not in self.search_cache:
369
+ self.search_cache[query] = serper_search(query, self.cfg.top_k, self.cfg.serper_api_key,
370
+ gl=self.cfg.gl, hl=self.cfg.hl)
371
+ for hit in self.search_cache[query]:
372
+ # ##print("hit")
373
+ # ##print(hit)
374
+ url, sn = hit.get("link", ""), hit.get("snippet", "")
375
+ if not url:
376
+ continue
377
+ key = (url, sn)
378
+ if key not in self.page_cache:
379
+ self.page_cache[key] = fetch_page(url, self.cfg, sn)
380
+ if self.page_cache[key]:
381
+ return self.page_cache[key]
382
+ return ""
383
+
384
+ def _summarise(self, prev: str, query: str, doc: str) -> str:
385
+ rid_prompt = self.get_webpage_to_reasonchain_instruction.format(prev_reasoning = prev, search_query = query, document = doc)
386
+ chat = f"<|im_start|>user\\n{rid_prompt}\\n<|im_end|>\\n<|im_start|>assistant\\n"
387
+ resp = self._generate_summary(chat)
388
+ # ##print("summarization out \n", resp)
389
+ return BEGIN_DOCUMENT_QUERY + self._extract_summary(resp) + END_DOCUMENT_QUERY
390
+ # ##print("summary")
391
+ # ##print(resp)
392
+ # match = re.search(r"Final Information\*\*\s*\n(.+?)<\|im_end\|>", resp)
393
+ # if match:
394
+ # final_info = match.group(1).strip()
395
+ # ##print(final_info)
396
+ # return final_info
397
+
398
+ def _extract_summary(self, prompt: str) -> str:
399
+ if "<information>" in prompt:
400
+ summary = prompt.split("<information>")[-1].split("</information>")[0] if THINK_OPEN in prompt else ""
401
+ return summary
402
+ else:
403
+ match = re.search(r"\*\*Final Information\*\*\s*\n(.+?)<\|im_end\|>", prompt)
404
+ if match:
405
+ final_info = match.group(1).strip()
406
+ return final_info
407
+ return prompt
408
+
409
+ def _extract_reasoning(self, prompt: str) -> str:
410
+ return prompt.split(THINK_OPEN)[-1].split(THINK_CLOSE)[0] if THINK_OPEN in prompt else ""
411
+
412
+ # -----------------------------------------------------------------------------
413
+ # CLI -------------------------------------------------------------------------
414
+ # if __name__ == "__main__":
415
+ # # import argparse, json
416
+ # # parser = argparse.ArgumentParser()
417
+ # # parser.add_argument("question"); parser.add_argument("--dataset", required=True, choices=sorted(ALLOWED_DATASETS)); parser.add_argument("--model-url", required
inference/zerosearch.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # zero_search_inference.py
2
+ """End‑to‑end inference loop that emulates the ZeroSearch prompting style.
3
+
4
+ The policy model ("thinker") must:
5
+ • reason inside <think> … </think>
6
+ • place a query inside <search> … </search> whenever it needs external knowledge
7
+ • return the final short answer inside <answer> … </answer>
8
+
9
+ The wrapper intercepts each <search> request, fulfils it with either:
10
+
11
+ (a) a **simulated search engine** (another small LLM fine‑tuned as ZeroSearch
12
+ retriever) ‑‑ default; or
13
+ (b) a real search backend (e.g. Serper.dev, Bing) if `engine="real"`.
14
+
15
+ It then injects results between <information> … </information> and hands control
16
+ back to the policy model. The loop repeats until </answer> is produced or a
17
+ maximum number of retrieval rounds is reached.
18
+
19
+ The goal is to mirror the ergonomics of the user’s existing `ReCall` class so
20
+ that outer orchestration code can drop this in with minimal friction.
21
+ """
22
+ from __future__ import annotations
23
+
24
+ import json
25
+ import os
26
+ import re
27
+ import time
28
+ from dataclasses import dataclass
29
+ from typing import List, Optional
30
+
31
+ import requests
32
+ from openai import OpenAI
33
+
34
+ __all__ = ["ZeroSearchInference", "ZeroSearchConfig"]
35
+
36
+ TOKENIZER_DIR = "/home/fractal_admin/shreyas/models/Qwen3-4B"
37
+
38
+ # ───────────────────────── tokenizer ────────────────────────────────────────
39
+ try:
40
+ from transformers import AutoTokenizer
41
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, trust_remote_code=True)
42
+ except Exception as e:
43
+ import sys
44
+ sys.exit(f"❌ Could not load Qwen3 tokenizer: {e}")
45
+
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # Utility: retry decorator ---------------------------------------------------
49
+ # ---------------------------------------------------------------------------
50
+
51
+ def retry(max_attempts: int = 4, sleep: int = 1, fallback=None):
52
+ def decorator(func):
53
+ def wrapper(*args, **kwargs):
54
+ for i in range(max_attempts):
55
+ try:
56
+ return func(*args, **kwargs)
57
+ except Exception as exc:
58
+ #print(f"[retry] {func.__name__}: attempt {i + 1}/{max_attempts} failed – {exc}")
59
+ if i == max_attempts - 1:
60
+ return fallback
61
+ time.sleep(sleep)
62
+ return wrapper
63
+ return decorator
64
+
65
+
66
+ # ---------------------------------------------------------------------------
67
+ # Configuration dataclass ----------------------------------------------------
68
+ # ---------------------------------------------------------------------------
69
+
70
+ @dataclass
71
+ class ZeroSearchConfig:
72
+ # thinker LLM endpoint
73
+ thinker_url: str = "http://0.0.0.0:1214"
74
+ thinker_temperature: float = 0.7
75
+ thinker_max_tokens: int = 40960
76
+
77
+ # retrieval engine mode: "sim" or "real"
78
+ engine: str = "real" # simulated search (LLM) by default
79
+
80
+ # simulated search model (only used if engine == "sim")
81
+ retriever_model: str = "gpt-4o-mini"
82
+ retriever_top_k: int = 5
83
+
84
+ # real search backend (engine == "real")
85
+ serper_api_key: Optional[str] = "7bfe51ead1a1766b656c1355b292d1d29c15c114"
86
+ serper_url: str = "https://google.serper.dev/search"
87
+ serper_top_k: int = 5
88
+
89
+ # Loop limits
90
+ max_rounds: int = 16
91
+
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # Main wrapper ---------------------------------------------------------------
95
+ # ---------------------------------------------------------------------------
96
+
97
+ class ZeroSearchInference:
98
+ SEARCH_OPEN = "<search>"
99
+ SEARCH_CLOSE = "</search>"
100
+ INFO_OPEN = "<information>"
101
+ INFO_CLOSE = "</information>"
102
+
103
+ ANSWER_CLOSE = "</answer>"
104
+ THINK_OPEN = "<think>"
105
+ THINK_CLOSE = "</think>"
106
+
107
+ STOP_TOKENS = ["<|im_end|>", "<|endoftext|>", "</search>", " </search>", "</search>\n", " </search>\n", "</search>\n\n", " </search>\n\n"]#, "</think>", "</think>\n", " </think>\n", "</think>\n\n", " </think>\n\n"]
108
+
109
+ def __init__(self, cfg: ZeroSearchConfig):
110
+ self.cfg = cfg
111
+ # ------------------------------------------------------------------
112
+ # Public driver -----------------------------------------------------
113
+ # ------------------------------------------------------------------
114
+
115
+ def run(self, user_question: str) -> str:
116
+ tool_calls = []
117
+ prompt = self._build_initial_prompt(user_question)
118
+ for round_idx in range(self.cfg.max_rounds):
119
+ generated = self._call_thinker(prompt)
120
+ #print("-"*100)
121
+ #print(f"Round: {round_idx}")
122
+ #print(generated)
123
+ prompt += generated
124
+
125
+ if self.ANSWER_CLOSE in generated:
126
+ #print(f"[ZeroSearch] Done in {round_idx + 1} rounds")
127
+ break
128
+
129
+ query = self._extract_query(generated)
130
+
131
+ if not query:
132
+ #print("[ZeroSearch] Model failed to emit <search>; aborting")
133
+ break
134
+ tool_calls.append(query)
135
+ info_block = self._retrieve_and_format(query)
136
+ #print(f"retrived docs: \n{info_block}")
137
+ #print("-"*100)
138
+ prompt += info_block + self.THINK_OPEN # next turn
139
+
140
+ else: # exceeded rounds
141
+ prompt += "<answer>I don't know.</answer><|im_end|>"
142
+ return prompt, tool_calls
143
+
144
+ # ------------------------------------------------------------------
145
+ # Prompt construction helpers --------------------------------------
146
+ # ------------------------------------------------------------------
147
+
148
+ def _build_initial_prompt(self, question: str) -> str:
149
+ user_msg = f"""Answer the given question. \
150
+ You must conduct reasoning inside <think> and </think> first every time you get new information. \
151
+ After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. \
152
+ You can search as many times as your want. \
153
+ If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: {question}\n"""
154
+ return f"<|im_start|>user\n{user_msg}<|im_end|>\n<|im_start|>assistant\n{self.THINK_OPEN}"
155
+
156
+ # ------------------------------------------------------------------
157
+ # Thinker model call ------------------------------------------------
158
+ # ------------------------------------------------------------------
159
+
160
+ @retry(fallback="")
161
+ def _call_thinker(self, prompt: str) -> str:
162
+ prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
163
+ max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100
164
+ resp = requests.post(
165
+ f"{self.cfg.thinker_url}/generate",
166
+ json={
167
+ "text": prompt,
168
+ "sampling_params": {
169
+ "temperature": self.cfg.thinker_temperature,
170
+ "max_new_tokens": max_tokens_left,
171
+ "stop": self.STOP_TOKENS,
172
+ },
173
+
174
+ },
175
+ timeout=60,
176
+ ).json()
177
+ generated = resp["text"] # what you have now
178
+ matched = resp["meta_info"]["finish_reason"].get("matched")
179
+ reason = resp["meta_info"]["finish_reason"].get("type")
180
+ # ⇢ append the tag back only if it was removed
181
+ if reason == "stop" and matched in self.STOP_TOKENS:
182
+ if not generated.endswith(matched):
183
+ generated += matched
184
+ if reason == "stop" and matched == 151645:
185
+ if not generated.endswith("<|im_end|>"):
186
+ generated += "<|im_end|>"
187
+ return generated
188
+
189
+ # ------------------------------------------------------------------
190
+ # Query extraction --------------------------------------------------
191
+ # ------------------------------------------------------------------
192
+
193
+ def _extract_query(self, gen_text: str) -> Optional[str]:
194
+ if self.SEARCH_OPEN not in gen_text or self.SEARCH_CLOSE not in gen_text:
195
+ return None
196
+ query = gen_text.split(self.SEARCH_OPEN)[-1].split(self.SEARCH_CLOSE)[0].strip()
197
+ return query or None
198
+
199
+ # ------------------------------------------------------------------
200
+ # Retrieval path ----------------------------------------------------
201
+ # ------------------------------------------------------------------
202
+
203
+ def _retrieve_and_format(self, query: str) -> str:
204
+ if self.cfg.engine == "real":
205
+ docs = self._real_search(query)
206
+ #print("DOCS")
207
+ #print(docs)
208
+ else:
209
+ docs = self._simulated_search(query)
210
+ return f"{self.INFO_OPEN}\n{docs}\n{self.INFO_CLOSE}\n\n"
211
+
212
+ # --- simulated search with LLM ------------------------------------
213
+
214
+ @retry(fallback="No information available")
215
+ def _simulated_search(self, query: str) -> str:
216
+ messages = [
217
+ {
218
+ "role": "user",
219
+ "content": (
220
+ "You are a search engine. Return up to "
221
+ f"{self.cfg.retriever_top_k} short documents (titles + snippets) "
222
+ "most relevant to the query, each on a new line.\n\n"
223
+ f"Query: {query}"
224
+ ),
225
+ }
226
+ ]
227
+ resp = self.openai.chat.completions.create(
228
+ model=self.cfg.retriever_model,
229
+ messages=messages,
230
+ max_tokens=256,
231
+ )
232
+ return resp.choices[0].message.content.strip()
233
+
234
+ # --- real web search via Serper ----------------------------------
235
+
236
+ @retry(fallback="No information available")
237
+ def _real_search(self, query: str) -> str:
238
+ if not self.cfg.serper_api_key:
239
+ raise ValueError("serper_api_key must be set for real search mode")
240
+ headers = {"X-API-KEY": self.cfg.serper_api_key, "Content-Type": "application/json"}
241
+ payload = {"q": query, "num": self.cfg.serper_top_k}
242
+ resp = requests.post(self.cfg.serper_url, json=payload, headers=headers, timeout=20)
243
+ resp.raise_for_status()
244
+ data = resp.json().get("organic", [])[: self.cfg.serper_top_k]
245
+ lines = []
246
+ for i, item in enumerate(data, 1):
247
+ snippet = f"Title: {item['title']}, \nSnippet{item['snippet']}"
248
+ lines.append(f"Doc {i}: {snippet}")
249
+ return "\n".join(lines) or "No information available"