Upload folder using huggingface_hub
- __init__.py +6 -0
- __pycache__/__init__.cpython-310.pyc +0 -0
- inference/__init__.py +0 -0
- inference/__pycache__/__init__.cpython-310.pyc +0 -0
- inference/__pycache__/o1_searcher.cpython-310.pyc +0 -0
- inference/__pycache__/r1_searcher.cpython-310.pyc +0 -0
- inference/__pycache__/re_call.cpython-310.pyc +0 -0
- inference/__pycache__/simpledeepsearch.cpython-310.pyc +0 -0
- inference/__pycache__/zerosearch.cpython-310.pyc +0 -0
- inference/o1_searcher.py +481 -0
- inference/oss.py +195 -0
- inference/r1_searcher.py +344 -0
- inference/re_call.py +980 -0
- inference/simpledeepsearch.py +417 -0
- inference/zerosearch.py +249 -0
__init__.py
ADDED
@@ -0,0 +1,6 @@
from .inference.re_call import ReCall
from .inference.r1_searcher import R1Searcher, R1SearchConfig
from .inference.zerosearch import ZeroSearchInference, ZeroSearchConfig
from .inference.o1_searcher import O1Cfg, O1Searcher
from .inference.simpledeepsearch import SDSCfg, SDSearcher
__all__ = ["ReCall", "R1Searcher", "ZeroSearchInference", "ZeroSearchConfig", "R1SearchConfig", "O1Cfg", "O1Searcher", "SDSCfg", "SDSearcher"]
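The package root re-exports every searcher together with its config dataclass, so downstream code needs only one import. A minimal usage sketch, assuming this folder is importable as a package (the package name `agents` and the model endpoint URL below are placeholders, not part of this commit):

# Sketch only: `agents` stands in for whatever name this folder is installed under,
# and the URL is a hypothetical /generate endpoint for the thinker model.
from agents import R1Searcher, R1SearchConfig

cfg = R1SearchConfig()                              # defaults defined in inference/r1_searcher.py
agent = R1Searcher(cfg, model_url="http://localhost:30000")
trace, queries = agent.run("Alice David is the voice of Lara Croft in a game developed by which company?")
print(trace.split("<answer>")[-1].split("</answer>")[0])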
__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (613 Bytes).

inference/__init__.py
ADDED
File without changes

inference/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (187 Bytes).

inference/__pycache__/o1_searcher.cpython-310.pyc
ADDED
Binary file (15.9 kB).

inference/__pycache__/r1_searcher.cpython-310.pyc
ADDED
Binary file (11.1 kB).

inference/__pycache__/re_call.cpython-310.pyc
ADDED
Binary file (27.6 kB).

inference/__pycache__/simpledeepsearch.cpython-310.pyc
ADDED
Binary file (13.7 kB).

inference/__pycache__/zerosearch.cpython-310.pyc
ADDED
Binary file (7.86 kB).
inference/o1_searcher.py
ADDED
@@ -0,0 +1,481 @@
#!/usr/bin/env python3
"""o1_searcher_inference.py — Serper‑based Search‑o1 re‑implementation
with *original* in‑house summarisation workflow, step‑replacement logic and
bug‑fixes for duplicate queries / ValueError.
"""
from __future__ import annotations

import os, re, json, time, string, pathlib
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple
import requests, trafilatura
import threading
from openai import OpenAI, APIStatusError

# -----------------------------------------------------------------------------
# Optional NLTK sentence tokenizer (fallback to regex) -------------------------
try:
    from nltk.tokenize import sent_tokenize  # type: ignore
except Exception:  # ImportError *or* missing punkt data
    def sent_tokenize(x: str):
        return re.split(r"(?<=[.!?]) +", x)


def _oa() -> OpenAI:
    # One OpenAI client per thread, created lazily.
    th = threading.current_thread()
    if not hasattr(th, "_oa"):
        th._oa = OpenAI()
    return th._oa


# -----------------------------------------------------------------------------
# Special tags & constants -----------------------------------------------------
BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
END_SEARCH_QUERY = "<|end_search_query|>"
BEGIN_DOCUMENT_QUERY = "<|begin_of_document|>"
END_DOCUMENT_QUERY = "<|end_of_document|>"
THINK_OPEN, THINK_CLOSE = "<think>", "</think>"
EOS_TOKEN = "<|im_end|>"
ANSWER_OPEN, ANSWER_CLOSE = "<answer>", "</answer>"
STOP_STRINGS = [END_SEARCH_QUERY, ANSWER_CLOSE, EOS_TOKEN, "<|endoftext|>"]
ALLOWED_DATASETS = {"musique", "frames", "simpleqa", "browsercomp"}
TOKENIZER_DIR = "/home/fractal_admin/shreyas/models/Qwen3-4B"


# ───────────────────────── BASIC UTILS ──────────────────────────────
def retry(max_attempts: int = 4, sleep: int = 1, fallback=None):
    """Tiny retry decorator with fixed back‑off."""

    def decorator(func):
        def wrapper(*args, **kwargs):
            for i in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    if i == max_attempts - 1:
                        # print(f"[retry] {func.__name__} failed – giving up: {exc}")
                        return fallback
                    # print(f"[retry] {func.__name__}: attempt {i+1}/{max_attempts} → {exc}")
                    time.sleep(sleep)

        return wrapper

    return decorator


# ───────────────────────── tokenizer ────────────────────────────────────────
try:
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, trust_remote_code=True)
except Exception as e:
    import sys
    sys.exit(f"❌ Could not load Qwen3 tokenizer: {e}")

# -----------------------------------------------------------------------------
# Helper functions -------------------------------------------------------------

def remove_punc(t: str) -> str:
    return t.translate(str.maketrans("", "", string.punctuation))

# legacy aliases for older checkpoints ---------------------------------------
_nopunc = remove_punc


def f1(a: set, b: set) -> float:
    inter = len(a & b)
    return 0.0 if inter == 0 else 2 * inter / (len(a) + len(b))

# legacy alias
_f1 = f1


def extract_snippet_ctx(text: str, snippet: str, win: int = 2500) -> str:
    """Return *window*‑sized context around the sentence most similar to snippet."""
    text = text[:50_000]
    sn_set = set(remove_punc(snippet.lower()).split())
    best, best_score = None, 0.20
    for sent in sent_tokenize(text):
        score = f1(sn_set, set(remove_punc(sent.lower()).split()))
        if score > best_score:
            best, best_score = sent, score
    if best:
        pos = text.find(best)
        return text[max(0, pos - win): pos + len(best) + win]
    return text[: 2 * win]


# -----------------------------------------------------------------------------
# Config dataclass -------------------------------------------------------------
@dataclass
class O1Cfg:
    serper_api_key: str = os.environ.get("SERPER_API_KEY", "")  # read from the environment instead of hard-coding a key
    gl: str = "us"; hl: str = "en"
    top_k: int = 10; max_doc_len: int = 3000
    max_search: int = 10; max_turn: int = 15
    use_jina: bool = True
    jina_tpl: str = "https://r.jina.ai/http://{}"
    # generation params
    temperature: float = 0.7; top_p: float = 0.8; top_k_sampling: int = 20
    rep_pen: float = 1.05; thinker_max_tokens: int = 32768
    summariser_model: str = "gpt-4o-mini"


# -----------------------------------------------------------------------------
# Serper search + page fetch ---------------------------------------------------

def serper_search(q: str, num: int, key: str, gl="us", hl="en") -> List[Dict]:
    hdr = {"X-API-KEY": key, "Content-Type": "application/json"}
    body = {"q": q, "num": num, "gl": gl, "hl": hl}
    r = requests.post("https://google.serper.dev/search", json=body, headers=hdr, timeout=20)
    r.raise_for_status()
    return r.json().get("organic", [])


def fetch_page(url: str, cfg: O1Cfg, snippet: str = "") -> str:
    try:
        txt = ""
        if cfg.use_jina:
            r = requests.get(cfg.jina_tpl.format(url), timeout=15)
            if r.ok and len(r.text.strip()) > 100:
                txt = r.text.strip()
        if txt == "":
            r = requests.get(url, timeout=15)
            r.raise_for_status()
            txt = trafilatura.extract(r.text, output_format="txt") or ""
        if snippet:
            txt = extract_snippet_ctx(txt, snippet, cfg.max_doc_len)
        return txt
    except Exception:
        return ""


# -----------------------------------------------------------------------------
# replace_recent_steps --------------------------------------------------------

def replace_recent_steps(origin: str, patch: str) -> str:
    """Apply *patch* (containing numbered `Step N:` lines) to *origin*."""
    step_re = re.compile(r"Step\s+(\d+):\s*")

    def parse(block: str) -> Dict[int, str]:
        cur, buf, out = None, [], {}
        for line in block.splitlines():
            m = step_re.match(line)
            if m:
                if cur is not None:
                    out[cur] = "\n".join(buf).strip()
                cur, buf = int(m.group(1)), [line[m.end():].strip()]
            elif cur is not None:
                buf.append(line)
        if cur is not None:
            out[cur] = "\n".join(buf).strip()
        return out

    base = parse(origin); mod = parse(patch)
    for k, v in mod.items():
        if "DELETE THIS STEP" in v:
            base.pop(k, None)
        else:
            base[k] = v
    return "\n\n".join(base[k] for k in sorted(base))


# -----------------------------------------------------------------------------
# Prompts ----------------------------------------------------------------------
# from prompts import get_webpage_to_reasonchain_instruction  # keep original helper

# -----------------------------------------------------------------------------
# Main agent -------------------------------------------------------------------
class O1Searcher:
    get_webpage_to_reasonchain_instruction = """**Task Instruction:**

You are tasked with reading and analyzing web pages based on the following inputs: **Previous Reasoning Steps**, **Current Search Query**, and **Searched Web Pages**. Your objective is to extract relevant and helpful information for **Current Search Query** from the **Searched Web Pages** and seamlessly integrate this information into the **Previous Reasoning Steps** to continue reasoning for the original question.

**Guidelines:**

1. **Analyze the Searched Web Pages:**
- Carefully review the content of each searched web page.
- Identify factual information that is relevant to the **Current Search Query** and can aid in the reasoning process for the original question.

2. **Extract Relevant Information:**
- Select the information from the Searched Web Pages that directly contributes to advancing the **Previous Reasoning Steps**.
- Ensure that the extracted information is accurate and relevant.

3. **Output Format:**
- **If the web pages provide helpful information for current search query:** Present the information beginning with **Final Information** as shown below.
**Final Information**

[Helpful information]

- **If the web pages do not provide any helpful information for current search query:** Output the following text.

**Final Information**

No helpful information found.

**Inputs:**
- **Previous Reasoning Steps:**
{prev_reasoning}

- **Current Search Query:**
{search_query}

- **Searched Web Pages:**
{document}

Now you should analyze each web page and find helpful information based on the current search query {search_query} and previous reasoning steps.
Return the helpful information in the <information></information> tags.
"""

    SUMMARY_PROMPT = (
        "## Task Description:\n"
        "Given the search query and the content of the searched webpage, "
        "extract information relevant to the query and write one summary paragraph.\n\n"
        "## Guidelines:\n"
        "(1) The extracted content should be relevant to the query.\n"
        "(2) The form of the extracted content **must be a summary paragraph** rather than a direct answer.\n"
        "(3) If the webpage content is unrelated to the query, output \"None\".\n\n"
        "## Output Format:\n"
        "[Exacted Content]: <summary-paragraph-or-None>\n\n"
        "## Inputs:\n"
        "[Search Query]\n{search_query}\n\n"
        "[Webpage Content]\n{document}\n\n"
        "## Output:\n"
    )

    sys_prompt_multiqa = (
        "You are a reasoning assistant with the ability to perform web searches to help "
        "you answer the user's question accurately. You have special tools:\n\n"
        "- To perform a search: write <|begin_search_query|> your query here <|end_search_query|>.\n"
        "Then, the system will search and analyze relevant web pages, then provide you with helpful information in the format <|begin_search_result|> ...search results... <|end_search_result|>.\n\n"
        "You can repeat the search process multiple times if necessary. The maximum number of search attempts is limited to 16.\n\n"
        "Once you have all the information you need, continue your reasoning.\n\n"
        "Example:\n"
        "Question: \"Alice David is the voice of Lara Croft in a video game developed by which company?\"\n"
        "Assistant thinking steps:\n"
        "- I need to find out who voices Lara Croft in the video game.\n"
        "- Then, I need to determine which company developed that video game.\n\n"
        "Assistant:\n"
        "<|begin_search_query|>Alice David Lara Croft voice<|end_search_query|>\n\n"
        "(System returns processed information from relevant web pages)\n\n"
        "Assistant thinks: The search results indicate that Alice David is the voice of Lara Croft in a specific video game. Now, I need to find out which company developed that game.\n\n"
        "Assistant:\n"
        "<|begin_search_query|>video game developed by Alice David Lara Croft<|end_search_query|>\n\n"
        "(System returns processed information from relevant web pages)\n\n"
        "Assistant continues reasoning with the new information...\n\n"
        "Remember:\n"
        "- Use <|begin_search_query|> to request a web search and end with <|end_search_query|>.\n"
        "- When done searching, continue your reasoning.\n\n"
        "Always give your final answer between <answer></answer> tags"
    )

    def __init__(self, cfg: O1Cfg, thinker_url: str):
        if not cfg.serper_api_key:
            raise ValueError("SERPER_API_KEY required")
        self.cfg, self.model_url = cfg, thinker_url.rstrip("/")
        self.search_cache: Dict[str, List[Dict]] = {}
        self.page_cache: Dict[Tuple[str, str], str] = {}
        self.openai = _oa()

    # --- low‑level generation call ------------------------------------------
    @retry(4, 1)
    def _generate(self, prompt: str) -> str:
        prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
        max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100
        resp = requests.post(
            f"{self.model_url}/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": self.cfg.temperature,
                    "top_p": self.cfg.top_p,
                    "max_new_tokens": max_tokens_left,
                    "repetition_penalty": self.cfg.rep_pen,
                    "stop": STOP_STRINGS,
                },
            },
            timeout=60,
        ).json()
        generated = resp["text"]
        matched = resp["meta_info"]["finish_reason"].get("matched")
        reason = resp["meta_info"]["finish_reason"].get("type")

        # ⇢ append the stop tag back only if it was stripped from the output
        if reason == "stop" and matched in STOP_STRINGS:
            if "<|end_of_query|>" not in generated:
                generated += matched
        if reason == "stop" and matched == 151645:
            if not generated.endswith("<|im_end|>"):
                generated += "<|im_end|>"
        if reason == "stop" and matched == 151643:
            if not generated.endswith("<|endoftext|>"):
                generated += "<|endoftext|>"

        return generated

    # @retry(fallback="None")
    def _summarise_openai(self, query: str, doc: str) -> str:
        prompt = self.SUMMARY_PROMPT.format(search_query=query, document=doc)
        resp = self.openai.chat.completions.create(
            model=self.cfg.summariser_model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1024,
            temperature=0.0,
        )
        text = resp.choices[0].message.content
        return text.split("[Exacted Content]:")[-1].strip()

    def _generate_summary(self, prompt: str) -> str:
        summary_url = "http://0.0.0.0:1241"
        prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
        max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100
        resp = requests.post(
            f"{summary_url}/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": self.cfg.temperature,
                    "max_new_tokens": 8192,  # max_tokens_left
                    "stop": STOP_STRINGS,
                },
            },
            timeout=60,
        ).json()
        generated = resp["text"]
        matched = resp["meta_info"]["finish_reason"].get("matched")
        reason = resp["meta_info"]["finish_reason"].get("type")
        # ⇢ append the stop tag back only if it was stripped from the output
        if reason == "stop" and matched in STOP_STRINGS:
            if "<|end_of_query|>" not in generated:
                generated += matched + EOS_TOKEN
        if reason == "stop" and matched == 151645:
            if not generated.endswith("<|im_end|>"):
                generated += "<|im_end|>"
        if reason == "stop" and matched == 151643:
            if not generated.endswith("<|endoftext|>"):
                generated += "<|endoftext|>"
        return generated

    # --- public entry -------------------------------------------------------
    def run(self, question: str):
        prompt = (
            f"<|im_start|>system\n{self.sys_prompt_multiqa}<|im_end|>\n"
            f"<|im_start|>user\n{question}<|im_end|>\n"
            f"<|im_start|>assistant\n{THINK_OPEN}"
        )
        queries: List[str] = []
        seen_queries: set[str] = set()

        for i in range(self.cfg.max_turn):
            chunk = self._generate(prompt)
            prompt += chunk

            if ANSWER_CLOSE in chunk:
                break

            query = self._extract_query(chunk)
            if not query or len(queries) >= self.cfg.max_search:
                break
            if query in seen_queries:
                continue
            queries.append(query)
            seen_queries.add(query)

            doc = self._retrieve_doc(query)
            prev_reasoning = self._extract_reasoning(prompt)
            # summary = "\n<|im_start|>user" + self._summarise_openai(query, doc) + EOS_TOKEN + "\n<|im_start|>assistant" + THINK_OPEN
            summary = "\n<|im_start|>user" + self._summarise(prev_reasoning, query, doc) + EOS_TOKEN + "\n<|im_start|>assistant" + THINK_OPEN
            prompt += summary
        else:
            final = f"{ANSWER_OPEN}I don't know.{ANSWER_CLOSE}"
            prompt += final

        return prompt, queries

    # ---------------------------------------------------------------------
    # helpers --------------------------------------------------------------
    def _extract_query(self, txt: str) -> Optional[str]:
        if BEGIN_SEARCH_QUERY not in txt or END_SEARCH_QUERY not in txt:
            return None
        frag = txt.split(BEGIN_SEARCH_QUERY)[-1].split(END_SEARCH_QUERY)[0]
        # strip quotes / ellipsis / tabs
        return re.sub(r"[\"'…\t]", " ", frag.split("<|")[0]).strip()

    def _retrieve_doc(self, query: str) -> str:
        if query not in self.search_cache:
            self.search_cache[query] = serper_search(query, self.cfg.top_k, self.cfg.serper_api_key,
                                                     gl=self.cfg.gl, hl=self.cfg.hl)
        for hit in self.search_cache[query]:
            url, sn = hit.get("link", ""), hit.get("snippet", "")
            if not url:
                continue
            key = (url, sn)
            if key not in self.page_cache:
                self.page_cache[key] = fetch_page(url, self.cfg, sn)
            if self.page_cache[key]:
                return self.page_cache[key]
        return ""

    def _summarise(self, prev: str, query: str, doc: str) -> str:
        rid_prompt = self.get_webpage_to_reasonchain_instruction.format(prev_reasoning=prev, search_query=query, document=doc)
        chat = f"<|im_start|>user\n{rid_prompt}\n<|im_end|>\n<|im_start|>assistant\n"
        resp = self._generate_summary(chat)
        return BEGIN_DOCUMENT_QUERY + self._extract_summary(resp) + END_DOCUMENT_QUERY

    def _extract_summary(self, prompt: str) -> str:
        if "<information>" in prompt:
            summary = prompt.split("<information>")[-1].split("</information>")[0] if THINK_OPEN in prompt else ""
            return summary
        else:
            match = re.search(r"\*\*Final Information\*\*\s*\n(.+?)<\|im_end\|>", prompt)
            if match:
                final_info = match.group(1).strip()
                return final_info
        return prompt

    def _extract_reasoning(self, prompt: str) -> str:
        return prompt.split(THINK_OPEN)[-1].split(THINK_CLOSE)[0] if THINK_OPEN in prompt else ""

# -----------------------------------------------------------------------------
# CLI -------------------------------------------------------------------------
# if __name__ == "__main__":
#     # import argparse, json
#     # parser = argparse.ArgumentParser()
#     # parser.add_argument("question"); parser.add_argument("--dataset", required=True, choices=sorted(ALLOWED_DATASETS)); parser.add_argument("--model-url", required
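A minimal driver for O1Searcher might look like the sketch below; the endpoint URL is an assumption (the thinker is expected to expose the `/generate` route used by `_generate` above) and SERPER_API_KEY must be set in the environment.

# Illustrative driver; the URL and the exact question are placeholders.
import os
from inference.o1_searcher import O1Cfg, O1Searcher

cfg = O1Cfg(serper_api_key=os.environ["SERPER_API_KEY"], max_turn=8)
searcher = O1Searcher(cfg, thinker_url="http://localhost:30000")
trace, queries = searcher.run("Who founded the studio that developed Tomb Raider (2013)?")
print(queries)
print(trace.split("<answer>")[-1].split("</answer>")[0])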
inference/oss.py
ADDED
@@ -0,0 +1,195 @@
# recall_oss_harmony.py
# ReCall loop using the Harmony chat format renderer, calling vLLM's
# OpenAI-compatible **/v1/completions** endpoint with a rendered prompt.

import json
import re
import time
from functools import wraps
from typing import List, Optional

import requests
from openai_harmony import (
    load_harmony_encoding,
    HarmonyEncodingName,
    Role,
    Message,
    Conversation,
)


def retry(max: int = 5, sleep: float = 1.0, fallback=None):
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for i in range(max):
                try:
                    return fn(*args, **kwargs)
                except Exception as e:
                    print(f"[retry] attempt {i+1}/{max} failed: {e}")
                    if i + 1 == max:
                        print(f"[retry] giving up – returning {fallback!r}")
                        return fallback
                    if sleep:
                        time.sleep(sleep)
        return wrapper
    return decorator


class ReCallOSSHarmony:
    TOOL_CALL_RE = re.compile(r"<tool_call>((?:(?!</tool_call>).)*)</tool_call>", re.DOTALL)

    def __init__(
        self,
        executor_url: str,
        base_url: str,
        model_name: str,
        api_key: Optional[str] = None,
        request_timeout: int = 120,
    ):
        self.executor_url = executor_url.rstrip("/")
        self.base_url = base_url.rstrip("/")
        self.model_name = model_name
        self.api_key = api_key
        self.timeout = request_timeout
        self.enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

    # ---------------- HTTP ----------------
    def _headers(self):
        h = {"Content-Type": "application/json"}
        if self.api_key:
            h["Authorization"] = f"Bearer {self.api_key}"
        return h

    @retry(max=5, sleep=1, fallback={"choices": [{"text": ""}]})
    def _complete(self, prompt: str, temperature: float, max_tokens: int, stop: Optional[List[str]] = None):
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        if stop:
            payload["stop"] = stop
        resp = requests.post(
            f"{self.base_url}/completions",
            headers=self._headers(),
            json=payload,
            timeout=self.timeout,
        )
        if resp.status_code != 200:
            raise RuntimeError(f"completions HTTP {resp.status_code}: {resp.text}")
        return resp.json()

    # ------------- tool plumbing ----------
    @staticmethod
    def _validate_tool_tags(s: str) -> bool:
        starts = [m.start() for m in re.finditer(r"<tool_call>", s)]
        ends = [m.start() for m in re.finditer(r"</tool_call>", s)]
        if len(starts) != len(ends):
            return False
        return all(st < en for st, en in zip(starts, ends))

    def extract_tool_calls(self, text: str) -> List[str]:
        if not self._validate_tool_tags(text):
            return []
        return [m.group(1).strip() for m in self.TOOL_CALL_RE.finditer(text)]

    @staticmethod
    def _format_tool_call(call_json_str: str) -> str:
        try:
            spec = json.loads(call_json_str)
            fname = spec["name"]
            args = spec.get("arguments", {}) or {}
            args_str = ", ".join(f"{k}={repr(v)}" for k, v in args.items())
            return f"{fname}({args_str})"
        except Exception as e:
            return f"error: parse tool call failed: {e}"

    def _exec_one_call(self, env: str, call_json_str: str) -> str:
        call_src = self._format_tool_call(call_json_str)
        if call_src.startswith("error:"):
            return call_src
        try:
            response = requests.post(
                f"{self.executor_url}/execute",
                json={"env": env, "call": call_src},
                timeout=self.timeout,
            )
            if response.status_code != 200:
                return f"error: executor HTTP {response.status_code}"
            payload = response.json()
            out = []
            if payload.get("result"):
                out.append(f"result:\n{payload['result']}")
            if payload.get("output"):
                out.append(f"output:\n{payload['output']}")
            if payload.get("error"):
                out.append(f"error:\n{payload['error']}")
            return "\n".join(out).strip() or "ok"
        except requests.exceptions.Timeout:
            return "error: execution timed out"
        except Exception as e:
            return f"error: executor exception: {e}"

    def execute_tool_calls(self, env: str, tool_calls: List[str]) -> List[str]:
        return [self._exec_one_call(env, c) for c in tool_calls]

    # ------------- harmony helpers --------
    def _render(self, messages: List[Message]) -> str:
        convo = Conversation.from_messages(messages)
        # Render for a completion (assistant is the next speaker)
        return self.enc.render_conversation_for_completion(convo, Role.ASSISTANT)

    # ------------- main run loop ----------
    @retry(max=5, sleep=1, fallback=("", []))
    def run(
        self,
        env: str,
        func_schemas,
        question: str,
        system_prompt: str = "",
        temperature: float = 0.2,
        max_tokens: int = 2048,
        max_turns: int = 16,
        stop: Optional[List[str]] = None,
    ):
        # Build the initial harmony conversation.
        # Paste your full system prompt into `system_prompt` before calling.
        # If you want to include func_schemas in your system content, do:
        try:
            sys_msg = system_prompt.format(func_schemas=json.dumps(func_schemas, ensure_ascii=False))
        except Exception:
            sys_msg = system_prompt

        messages: List[Message] = [
            Message.from_role_and_content(Role.SYSTEM, sys_msg),
            Message.from_role_and_content(Role.USER, question),
        ]

        transcript_chunks: List[str] = []
        all_tool_calls: List[str] = []

        for _ in range(max_turns):
            prompt = self._render(messages)
            resp = self._complete(prompt=prompt, temperature=temperature, max_tokens=max_tokens, stop=stop)
            assistant_text = resp["choices"][0]["text"]
            transcript_chunks.append(assistant_text)
            messages.append(Message.from_role_and_content(Role.ASSISTANT, assistant_text))

            if "<answer>" in assistant_text:
                break

            tool_calls = self.extract_tool_calls(assistant_text)
            all_tool_calls.extend(tool_calls)
            if not tool_calls:
                continue

            results = self.execute_tool_calls(env, tool_calls)
            tool_resp_block = "".join(
                f"<tool_response>{tc}\n{res}\n</tool_response>\n"
                for tc, res in zip(tool_calls, results)
            )
            messages.append(Message.from_role_and_content(Role.USER, tool_resp_block))

        transcript = "".join(transcript_chunks)
        return transcript, all_tool_calls
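A sketch of how the Harmony-based loop above might be driven; the endpoints, model name, tool schema, and system prompt are illustrative assumptions (the system prompt is expected to contain a {func_schemas} placeholder and the <tool_call>/<answer> conventions that the loop parses).

# Illustrative only: URLs, model name, and schema are placeholders.
from inference.oss import ReCallOSSHarmony

agent = ReCallOSSHarmony(
    executor_url="http://localhost:8000",   # sandbox exposing POST /execute
    base_url="http://localhost:8001/v1",    # vLLM OpenAI-compatible server
    model_name="gpt-oss-20b",
)
func_schemas = [{"name": "search", "description": "web search", "parameters": {"query": "str"}}]
system_prompt = (
    "You may call the tools described by {func_schemas}. "
    "Wrap each call in <tool_call>...</tool_call> and the final answer in <answer>...</answer>."
)
transcript, calls = agent.run(
    env="",
    func_schemas=func_schemas,
    question="What year was Python 3.0 released?",
    system_prompt=system_prompt,
)
print(transcript)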
inference/r1_searcher.py
ADDED
@@ -0,0 +1,344 @@
#!/usr/bin/env python3
# r1_searcher_inference.py
"""
Faithful re‑implementation of the **R1‑Searcher** loop described in
‘Reasoning with Retrieval 1’ (2024).

Key properties ─────────────────────────────────────────────────────────
• The policy LLM (“thinker”) reasons inside <think> … </think>.
• When it needs external knowledge it emits exactly **one** single‑triple
  query inside <|begin_of_query|> … <|end_of_query|>.
• The wrapper searches *English Wikipedia only* (via Serper.dev).
• It summarises the *first* retrieved article that contains the query terms
  and injects that summary between
  <|begin_of_documents|> … <|end_of_documents|>
  before handing control back to the thinker.
• The loop stops when the thinker outputs </answer> or when the configurable
  round‑limit is reached.

The class is API‑compatible with the user’s existing `ReCall` wrapper so the
same benchmarking harness can swap between them with a single flag.
"""
from __future__ import annotations

import os
import re
import time
from dataclasses import dataclass
from typing import List, Optional

import requests
from bs4 import BeautifulSoup
import trafilatura
import wikipedia
from urllib.parse import unquote
from openai import OpenAI

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))  # read from the environment instead of hard-coding a key

TOKENIZER_DIR = "/home/fractal_admin/shreyas/models/Qwen3-4B"

# ───────────────────────── tokenizer ────────────────────────────────────────
try:
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, trust_remote_code=True)
except Exception as e:
    import sys
    sys.exit(f"❌ Could not load Qwen3 tokenizer: {e}")

# ───────────────────────── BASIC UTILS ──────────────────────────────
def retry(max_attempts: int = 4, sleep: int = 1, fallback=None):
    """Tiny retry decorator with fixed back‑off."""

    def decorator(func):
        def wrapper(*args, **kwargs):
            for i in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    if i == max_attempts - 1:
                        # print(f"[retry] {func.__name__} failed – giving up: {exc}")
                        return fallback
                    # print(f"[retry] {func.__name__}: attempt {i+1}/{max_attempts} → {exc}")
                    time.sleep(sleep)

        return wrapper

    return decorator


# ────────────────────────── CONFIG ──────────────────────────────────
@dataclass
class R1SearchConfig:
    # Serper.dev parameters
    serper_api_key: str = os.environ.get("SERPER_API_KEY", "")
    serper_url: str = "https://google.serper.dev/search"
    gl: str = "us"
    hl: str = "en"

    # Policy model endpoint (vLLM)
    thinker_temperature: float = 0.0
    thinker_max_tokens: int = 40960

    # Loop / misc
    max_rounds: int = 16
    summariser_model: str = "gpt-4o-mini"


# ───────────────────────── R1‑Searcher ──────────────────────────────
class R1Searcher:
    SYSTEM_PROMPT = """
You are a helpful assistant.
Given a question, you should answer it by first thinking about the reasoning
process in the mind and then providing the final answer.

The output format of reasoning process and final answer are enclosed within
<think> </think> and <answer> </answer> tags, respectively, i.e.,
"<think> reasoning process here </think>

<answer> final answer here </answer>".

During the thinking process, **you can perform searching for uncertain
knowledge** if necessary with the format of
"<|begin_of_query|> keyword_1 keyword_2 ... <|end_of_query|>".
**A query must involve only a single triple**.

Then, the search system will provide you with the retrieval information with
the format of "<|begin_of_documents|> ...search results... <|end_of_documents|>".
""".strip()

    SUMMARY_PROMPT = (
        "## Task Description:\n"
        "Given the search query and the content of the searched webpage, "
        "extract information relevant to the query and write one summary paragraph.\n\n"
        "## Guidelines:\n"
        "(1) The extracted content should be relevant to the query.\n"
        "(2) The form of the extracted content **must be a summary paragraph** rather than a direct answer.\n"
        "(3) If the webpage content is unrelated to the query, output \"None\".\n\n"
        "## Output Format:\n"
        "[Exacted Content]: <summary-paragraph-or-None>\n\n"
        "## Inputs:\n"
        "[Search Query]\n{search_query}\n\n"
        "[Webpage Content]\n{document}\n\n"
        "## Output:\n"
    )

    # Tag constants
    EOS_TOKEN = "<|im_end|>"
    THINK_OPEN = "<think>"
    ANSWER_CLOSE = "</answer>"
    Q_OPEN, Q_CLOSE = "<|begin_of_query|>", "<|end_of_query|>"
    DOC_OPEN, DOC_CLOSE = "<|begin_of_documents|>", "<|end_of_documents|>"

    # Stop strings – must match *exact* token sequences vLLM will see
    STOP_TOKENS = [
        "<|im_end|>",
        "<|endoftext|>",
        "<|end_of_query|>",
        " <|end_of_query|>",
        "<|end_of_query|>\n",
        "<|end_of_query|>\n\n",
        " <|end_of_query|>\n",
        " <|end_of_query|>\n\n",
    ]

    def __init__(self, cfg: R1SearchConfig, model_url):
        self.cfg = cfg
        self.openai = client
        self._wiki = wikipedia
        self._wiki.set_lang("en")

        # Patch wikipedia lib to use a session with proper UA
        sess = requests.Session()
        sess.headers.update({"User-Agent": "r1-searcher-bot/1.0"})
        self._wiki._http = sess
        self.model_url = model_url

    # ── public entry ─────────────────────────────────────────────────
    def run(self, question: str) -> tuple[str, List[str]]:
        prompt = (
            f"<|im_start|>system\n{self.SYSTEM_PROMPT}<|im_end|>\n"
            f"<|im_start|>user\n{question}<|im_end|>\n"
            f"<|im_start|>assistant\n{self.THINK_OPEN}"
        )
        queries: List[str] = []

        for _ in range(self.cfg.max_rounds):
            model_out = self._call_thinker(prompt)
            prompt += model_out

            if self.ANSWER_CLOSE in model_out:
                break

            query = self._extract_query(model_out)
            if not query:
                break
            queries.append(query)

            doc_block = self._retrieve_block(query)
            prompt += "<|im_start|>user\n" + doc_block + self.EOS_TOKEN + "<|im_start|>assistant\n" + self.THINK_OPEN  # continue loop

        else:  # exceeded round cap
            prompt += "<answer>I don't know.</answer><|im_end|>"

        return prompt, queries

    # ── thinker call ────────────────────────────────────────────────
    # @retry()
    def _call_thinker(self, prompt: str) -> str:
        prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
        max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100
        resp = requests.post(
            f"{self.model_url}/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": self.cfg.thinker_temperature,
                    "max_new_tokens": 2048,  # max_tokens_left
                    "stop": self.STOP_TOKENS,
                    "repetition_penalty": 1.05,
                },
            },
            timeout=60,
        ).json()
        generated = resp["text"]
        matched = resp["meta_info"]["finish_reason"].get("matched")
        reason = resp["meta_info"]["finish_reason"].get("type")
        # ⇢ append the stop tag back only if it was stripped from the output
        if reason == "stop" and matched in self.STOP_TOKENS:
            if "<|end_of_query|>" not in generated:
                generated += matched + self.EOS_TOKEN
        if reason == "stop" and matched == 151645:
            if not generated.endswith("<|im_end|>"):
                generated += "<|im_end|>"
        if reason == "stop" and matched == 151643:
            if not generated.endswith("<|endoftext|>"):
                generated += "<|endoftext|>"
        return generated

    # ── query helpers ───────────────────────────────────────────────
    @staticmethod
    def _extract_query(text: str) -> Optional[str]:
        if R1Searcher.Q_OPEN not in text or R1Searcher.Q_CLOSE not in text:
            return None
        fragment = text.split(R1Searcher.Q_OPEN)[-1].split(R1Searcher.Q_CLOSE)[0]
        fragment = fragment.split("<|")[0]  # handle end_of_query slipping into the fragment
        return (
            fragment.replace("\t", " ")
            .replace("\"", "")
            .replace("'", "")
            .replace("…", "")
            .strip()
        ) or None

    # ── retrieval & summary ───────────────────────────────────────
    def _retrieve_block(self, query: str) -> str:
        wiki_links = self._serper_wiki_links(query)

        for url in wiki_links[:3]:
            text = self._get_wiki_text(url)
            if not text:
                continue
            summary = self._summarise(query, text[:35000])
            if summary.lower() != "none":
                return f"{self.DOC_OPEN}\n{summary}\n{self.DOC_CLOSE}\n\n"

        return f"{self.DOC_OPEN}\nNone\n{self.DOC_CLOSE}\n\n"

    # --- Serper ------------------------------------------------------
    @retry()
    def _serper_wiki_links(self, q: str) -> List[str]:
        headers = {"X-API-KEY": self.cfg.serper_api_key, "Content-Type": "application/json"}
        payload = {"q": f"{q} site:en.wikipedia.org", "num": 10, "gl": self.cfg.gl, "hl": self.cfg.hl}
        r = requests.post(self.cfg.serper_url, json=payload, headers=headers, timeout=20)
        r.raise_for_status()
        links = [
            item.get("link")
            for item in r.json().get("organic", [])
            if item.get("link", "").startswith("https://en.wikipedia.org")
        ]
        return links

    def extract_main_text(self, html: str) -> str:
        txt = trafilatura.extract(html, output_format="txt") or ""
        if len(txt) >= 500:
            return txt
        from readability import Document
        soup = BeautifulSoup(Document(html).summary(), "lxml")
        txt = soup.get_text(" ", strip=True)
        if len(txt) >= 400:
            return txt
        for tag in soup(["script", "style", "noscript"]):
            tag.decompose()
        return re.sub(r"\s+", " ", soup.get_text(" ").strip())

    # --- fetch article ----------------------------------------------
    def _get_wiki_text(self, url: str) -> str | None:
        try:
            # 1. Download
            r = requests.get(url, timeout=10)
            r.raise_for_status()

            # 2. Extract main text
            txt = self.extract_main_text(r.text).strip()
            if not txt:
                return None

            # 3. Prepend article slug if it isn’t already in the body
            slug = unquote(url.rsplit("/", 1)[-1]).replace("_", " ")
            if slug.lower() not in txt.lower():
                txt = f"{slug}\n\n{txt}"

            # 4. Return the final value
            return "[Retrieved from Wikipedia] " + txt

        except Exception:
            return None

    # --- call OpenAI to summarise -----------------------------------
    @retry(fallback="None")
    def _summarise(self, query: str, doc: str) -> str:
        prompt = self.SUMMARY_PROMPT.format(search_query=query, document=doc)
        resp = self.openai.chat.completions.create(
            model=self.cfg.summariser_model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1024,
            temperature=0.0,
        )
        text = resp.choices[0].message.content
        return text.split("[Exacted Content]:")[-1].strip()


# ─────────────────────────── CLI ────────────────────────────────────
if __name__ == "__main__":
    import argparse, json

    ap = argparse.ArgumentParser()
    ap.add_argument("question", type=str, help="Natural‑language question")
    ap.add_argument("--serper-key", type=str, help="Override SERPER_API_KEY env")
    ap.add_argument("--model-url", type=str, required=True, help="Base URL of the thinker's /generate server")
    args = ap.parse_args()

    cfg = R1SearchConfig(serper_api_key=args.serper_key or os.getenv("SERPER_API_KEY", ""))
    agent = R1Searcher(cfg, args.model_url)
    final_prompt, issued_queries = agent.run(args.question)

    answer = final_prompt.split("<answer>")[-1].split("</answer>")[0]
    print("\nANSWER:", answer)
    print("\nQUERIES:", json.dumps(issued_queries, indent=2))
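Both `_call_thinker` above and `_generate` in o1_searcher.py assume the thinker endpoint returns an SGLang-style `/generate` payload. A sketch of the shape the parsing code expects, with field names taken from the code and purely illustrative values:

# Shape assumed by the parsing code above; the concrete values are made up.
resp = {
    "text": "...generated continuation...",
    "meta_info": {
        "finish_reason": {
            "type": "stop",                  # or "length"
            "matched": "<|end_of_query|>",   # stop string, or a token id such as 151645
        }
    },
}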
inference/re_call.py
ADDED
@@ -0,0 +1,980 @@
import re
import json
import requests
import time
from typing import List
from functools import wraps
from together import Together  # pip install together
from datetime import datetime  # only needed for retries / logging


# return decorator
def retry(max: int = 10, sleep: int = 1, fallback=None):
    """
+
Retry `max` times and, if still failing, return `fallback`
|
| 16 |
+
instead of raising. This keeps outer loops alive.
|
| 17 |
+
"""
|
| 18 |
+
def decorator(func):
|
| 19 |
+
@wraps(func)
|
| 20 |
+
def wrapper(*args, **kwargs):
|
| 21 |
+
for i in range(max):
|
| 22 |
+
try:
|
| 23 |
+
return func(*args, **kwargs)
|
| 24 |
+
except Exception as e:
|
| 25 |
+
print(f"[retry] attempt {i+1}/{max} failed: {e}")
|
| 26 |
+
if i == max - 1: # last try exhausted
|
| 27 |
+
print(f"[retry] giving up – returning {fallback!r}")
|
| 28 |
+
return fallback # ← swallow the error
|
| 29 |
+
if sleep:
|
| 30 |
+
time.sleep(sleep)
|
| 31 |
+
return wrapper
|
| 32 |
+
return decorator
|
| 33 |
+
|
| 34 |
+
class ReCall():
|
| 35 |
+
sys_prompt_websailor = """
|
| 36 |
+
You are a Web Information Seeking Master. Your task is to thoroughly seek the internet for information and provide accurate answers to questions. No matter how complex the query, you will not give up until you find the corresponding information.
|
| 37 |
+
In this environment you have access to a set of tools you can use to assist with the user query.
|
| 38 |
+
You may perform multiple rounds of function calls. In each round, you can call one or more functions.
|
| 39 |
+
|
| 40 |
+
As you proceed, adhere to the following principles:
|
| 41 |
+
|
| 42 |
+
1. **Persistent Actions for Answers**: You will engage in many interactions, delving deeply into the topic to explore all possible aspects until a satisfactory answer is found.
|
| 43 |
+
|
| 44 |
+
2. **Repeated Verification**: Before presenting a Final Answer, you will **cross-check** and **validate the information** you've gathered to confirm its accuracy and reliability.
|
| 45 |
+
|
| 46 |
+
3. **Attention to Detail**: You will carefully analyze each information source to ensure that all data is current, relevant, and from credible origins.
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```
|
| 51 |
+
|
| 52 |
+
In your response, you need to first think about the reasoning process in the mind and then conduct function calling to get the information or perform the actions if needed. \
|
| 53 |
+
The reasoning process and function calling are enclosed within <think> </think> and <tool_call> </tool_call> tags. \
|
| 54 |
+
The results of the function calls will be given back to you after execution, \
|
| 55 |
+
and you can continue to call functions until you get the final answer for the user's question. \
|
| 56 |
+
Finally, if you have got the answer, enclose it within \\boxed{{}} with latex format and do not continue to call functions, \
|
| 57 |
+
i.e., <think> Based on the response from the function call, I get the weather information. </think> The weather in Beijing on 2025-04-01 is \\[ \\boxed{{20C}} \\].
|
| 58 |
+
|
| 59 |
+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
| 60 |
+
<tool_call>
|
| 61 |
+
{{"name": <function-name>, "arguments": <args-json-object>}}
|
| 62 |
+
</tool_call>
|
| 63 |
+
For Multiple Choice Question always give the final answer as one of the options whichever fits the best.s
|
| 64 |
+
Always give your answer as option id. and answer.
|
| 65 |
+
Example:
|
| 66 |
+
What is the Captial of India ?
|
| 67 |
+
\\[ \\boxed{{A. India}} \\]
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
sys_prompt_websailor_deepseek = """
|
| 71 |
+
You are a Web Information Seeking Master. Your task is to thoroughly seek the internet for information and provide accurate answers to questions. No matter how complex the query, you will not give up until you find the corresponding information.
|
| 72 |
+
In this environment you have access to a set of tools you can use to assist with the user query.
|
| 73 |
+
You may perform multiple rounds of function calls. In each round, you can call one or more functions.
|
| 74 |
+
|
| 75 |
+
As you proceed, adhere to the following principles:
|
| 76 |
+
|
| 77 |
+
1. **Persistent Actions for Answers**: You will engage in many interactions, delving deeply into the topic to explore all possible aspects until a satisfactory answer is found.
|
| 78 |
+
|
| 79 |
+
2. **Repeated Verification**: Before presenting a Final Answer, you will **cross-check** and **validate the information** you've gathered to confirm its accuracy and reliability.
|
| 80 |
+
|
| 81 |
+
3. **Attention to Detail**: You will carefully analyze each information source to ensure that all data is current, relevant, and from credible origins.
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```
|
| 86 |
+
|
| 87 |
+
In your response, you need to first think about the reasoning process in the mind and then conduct function calling to get the information or perform the actions if needed. \
|
| 88 |
+
The reasoning process and function calling are enclosed within <think> </think> and <tool_calls_begin> <tool_calls_end> tags. \
|
| 89 |
+
The results of the function calls will be given back to you after execution, \
|
| 90 |
+
and you can continue to call functions until you get the final answer for the user's question. \
|
| 91 |
+
Finally, if you have got the answer, enclose it within \\boxed{{}} with latex format and do not continue to call functions, \
|
| 92 |
+
i.e., <think> Based on the response from the function call, I get the weather information. </think> The weather in Beijing on 2025-04-01 is \\[ \\boxed{{20C}} \\].
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
# sys_prompt_websailor_deepseek = """
|
| 96 |
+
# You are a Web Information Seeking Master. Seek the internet thoroughly and provide accurate answers. You may use tools multiple times.
|
| 97 |
+
|
| 98 |
+
# Principles:
|
| 99 |
+
# 1) Persistent Actions for Answers: explore deeply until you find satisfactory information.
|
| 100 |
+
# 2) Repeated Verification: cross-check and validate before the final answer.
|
| 101 |
+
# 3) Attention to Detail: ensure sources are current, relevant, and credible.
|
| 102 |
+
|
| 103 |
+
# You have the following tools (JSONSchema):
|
| 104 |
+
# ```json
|
| 105 |
+
# {func_schemas}
|
| 106 |
+
# Follow this EXACT tool-call I/O protocol.
|
| 107 |
+
|
| 108 |
+
# TO CALL ONE OR MORE TOOLS:
|
| 109 |
+
# Respond only with this block (no extra text before/after):
|
| 110 |
+
# <|tool▁call▁begin|>function<|tool▁sep|>{tool_name}{args_json}
|
| 111 |
+
# <|tool▁call▁end|>
|
| 112 |
+
# ... (repeat <|tool▁call▁begin|>…<|tool▁call▁end|> for multiple tools)
|
| 113 |
+
# <|tool▁calls▁end|><|end▁of▁sentence|>
|
| 114 |
+
|
| 115 |
+
# HOW TOOL RESULTS ARRIVE:
|
| 116 |
+
# I will send tool outputs back embedded inside a single user message, each wrapped like:
|
| 117 |
+
# <tool_response>{one_tool_call_you_made}
|
| 118 |
+
# {tool_return_text_or_json}
|
| 119 |
+
# </tool_response>
|
| 120 |
+
|
| 121 |
+
# WHAT TO DO NEXT:
|
| 122 |
+
|
| 123 |
+
# If you still need info, emit another tool-calls block (same exact format).
|
| 124 |
+
|
| 125 |
+
# If you have the final answer, output:
|
| 126 |
+
# <answer> …your final answer… </answer>
|
| 127 |
+
# and DO NOT call any more tools.
|
| 128 |
+
|
| 129 |
+
# Important:
|
| 130 |
+
|
| 131 |
+
# Do not expose your internal reasoning; keep thoughts private.
|
| 132 |
+
|
| 133 |
+
# When emitting a tool-calls block, do not include any explanations, only the block specified above.
|
| 134 |
+
|
| 135 |
+
# Arguments must be valid JSON.
|
| 136 |
+
|
| 137 |
+
# Stop tokens to respect: <|end▁of▁sentence|>
|
| 138 |
+
# """
|
| 139 |
+
|
| 140 |
+
system_prompt = """In this environment you have access to a set of tools you can use to assist with the user query. \
|
| 141 |
+
You may perform multiple rounds of function calls. \
|
| 142 |
+
In each round, you can call one or more functions. \
|
| 143 |
+
|
| 144 |
+
Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```
|
| 145 |
+
|
| 146 |
+
In your response, you need to first think about the reasoning process in the mind and then conduct function calling to get the information or perform the actions if needed. \
|
| 147 |
+
The reasoning process and function calling are enclosed within <think> </think> and <tool_call> </tool_call> tags. \
|
| 148 |
+
The results of the function calls will be given back to you after execution, \
|
| 149 |
+
and you can continue to call functions until you get the final answer for the user's question. You are encouraged to utilize as many function calls as possible. \
|
| 150 |
+
Finally, if you have got the answer, wrap it in <answer> </answer> **and do not call any more functions**, \
|
| 151 |
+
e.g. <think> Based on the tool results … </think> <answer>20 °C</answer>.
|
| 152 |
+
|
| 153 |
+
For each function call, return a JSON object with function name and arguments within <tool_call></tool_call> XML tags:
|
| 154 |
+
<tool_call>
|
| 155 |
+
{{"name": <function-name-1>, "arguments": <args-json-object>}}
|
| 156 |
+
</tool_call>"""
|
| 157 |
+
|
| 158 |
+
system_prompt_budget = """
|
| 159 |
+
You are an autonomous reasoning agent with access to external tools.
|
| 160 |
+
|
| 161 |
+
The conversation will retain only the *most-recent* <tool_response> block; older ones disappear.
|
| 162 |
+
As soon as you receive tool results, extract the *essential facts tables links etc* that might be needed for later and restate them inside your <think> section.
|
| 163 |
+
**Never copy large bodies of text** or raw JSON from tool output into your visible reply; summarise instead.
|
| 164 |
+
|
| 165 |
+
◎ **Workflow**
|
| 166 |
+
1. In every round, start with <think> … </think> to lay out your short reasoning.
|
| 167 |
+
2. If you need external information or an action, emit one or more <tool_call> … </tool_call> blocks (JSON spec below).
|
| 168 |
+
3. When the environment returns <tool_response>, continue reasoning; you may call more tools.
|
| 169 |
+
4. Once you can answer the user, wrap the final result in <answer> … </answer> and STOP calling tools.
|
| 170 |
+
|
| 171 |
+
◎ **Tool call format** (do **not** restate the schema or any explanations):
|
| 172 |
+
<tool_call>
|
| 173 |
+
{{"name": <function-name-1>, "arguments": <args-json-object>}}
|
| 174 |
+
</tool_call>
|
| 175 |
+
|
| 176 |
+
Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
system_prompt_forcing_tool_call = """
|
| 182 |
+
In this environment you have access to a set of tools you can use to assist with the user query.
|
| 183 |
+
You may perform multiple rounds of function calls upto ten. In each round, you can call upto three functions.
|
| 184 |
+
|
| 185 |
+
──────────────────────── AVAILABLE TOOLS ────────────────────────
|
| 186 |
+
```json
|
| 187 |
+
[
|
| 188 |
+
{
|
| 189 |
+
"type": "function",
|
| 190 |
+
"function": {
|
| 191 |
+
"name": "pubmed_search",
|
| 192 |
+
"description": "Search PubMed for Medical related queries.",
|
| 193 |
+
"parameters": {
|
| 194 |
+
"type": "object",
|
| 195 |
+
"properties": {
|
| 196 |
+
"query": { "type": "string", "description": "Query to search for." },
|
| 197 |
+
"top_n": { "type": "integer", "description": "Number of hits", "default": 3 }
|
| 198 |
+
},
|
| 199 |
+
"required": ["query"]
|
| 200 |
+
}
|
| 201 |
+
}
|
| 202 |
+
}
|
| 203 |
+
]
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
────────────────────────────── RULES ──────────────────────────────
|
| 207 |
+
1. You MUST issue one pubmed_search tool call for each answer choice. Each query must relate the clinical context to that option.
|
| 208 |
+
2. You MAY NOT skip any option or decide based only on internal reasoning. Evidence must be retrieved for all choices.
|
| 209 |
+
3. You MAY issue follow-up tool calls if your reasoning leads you to need more evidence.
|
| 210 |
+
4. You MUST wrap all reasoning in <think> </think> tags and all tool usage in <tool_call> </tool_call> tags. Number of <tool_call> and </tool_call> tokens in the entire trace MUST always match.
|
| 211 |
+
5. Do NOT casually emit the <tool_call> </tool_call> during reasoning unless explicitly calling a tool in the proper format.
|
| 212 |
+
5. Your final answer must be enclosed a single letter corresponding to the correct option enclosed in the <answer> </answer> tags. Do not output anything else inside these tags.
|
| 213 |
+
6. DO NOT use any other confusing tags like <thiking> or </thinking>.
|
| 214 |
+
7. Each <think> </think> block MUST be followed by a <tool_call> </tool_call> or <answer> </answer> or else the program will break without an answer.
|
| 215 |
+
|
| 216 |
+
───────────────────── DUMMY EXAMPLE INTERLEAVED SKELETON ─────────────────────
|
| 217 |
+
<think>
|
| 218 |
+
We are presented with a 54-year-old woman with invasive ductal carcinoma of the breast and osteolytic lesions in the thoracic spine. This strongly suggests metastatic spread. Our task is to determine the most likely anatomical route of metastasis to the spine.
|
| 219 |
+
|
| 220 |
+
Let’s examine the given options:
|
| 221 |
+
A. Hemiazygos vein
|
| 222 |
+
B. Posterior intercostal veins
|
| 223 |
+
C. Batson’s vertebral venous plexus
|
| 224 |
+
D. Internal mammary lymphatics
|
| 225 |
+
|
| 226 |
+
We'll evaluate each option in turn using available literature and known anatomical pathways.
|
| 227 |
+
**Option A: Hemiazygos vein**
|
| 228 |
+
We begin by evaluating whether the hemiazygos vein could be involved in metastatic spread from breast cancer to the spine.
|
| 229 |
+
</think>
|
| 230 |
+
<tool_call>
|
| 231 |
+
{"name": "pubmed_search", "arguments": {"query": "breast cancer metastasis hemiazygos vein", "top_n": 2}}
|
| 232 |
+
</tool_call>
|
| 233 |
+
<tool_response>
|
| 234 |
+
...
|
| 235 |
+
</tool_response>
|
| 236 |
+
<think>
|
| 237 |
+
There is limited or no strong evidence suggesting the hemiazygos vein is a common or primary route for vertebral metastasis from breast cancer.
|
| 238 |
+
Lets explore **Option B: Posterior intercostal veins** and **Option C: Batson’s vertebral venous plexus** and **Option D:Internal mammary lymphatics**
|
| 239 |
+
</think>
|
| 240 |
+
<tool_call>
|
| 241 |
+
{"name": "pubmed_search", "arguments": {"query": "posterior intercostal veins breast cancer spinal metastasis", "top_n": 3}}
|
| 242 |
+
</tool_call>
|
| 243 |
+
<tool_call>
|
| 244 |
+
{"name": "pubmed_search", "arguments": {"query": "Batson vertebral venous plexus breast cancer metastasis", "top_n": 3}}
|
| 245 |
+
</tool_call>
|
| 246 |
+
<tool_call>
|
| 247 |
+
{"name": "pubmed_search", "arguments": {"query": "Internal mammary lymphatics breast cancer metastasis", "top_n": 3}}
|
| 248 |
+
</tool_call>
|
| 249 |
+
<tool_response>
|
| 250 |
+
...
|
| 251 |
+
</tool_response>
|
| 252 |
+
<think>
|
| 253 |
+
While the posterior intercostal veins may be involved in venous drainage, there is insufficient evidence to support them as a primary route for metastasis to the vertebral column.
|
| 254 |
+
where as Batson’s vertebral venous plexus — a valveless venous network that connects the thoracic and abdominal veins directly to the spine. I to find more specific information about option C.
|
| 255 |
+
</think>
|
| 256 |
+
<tool_call>
|
| 257 |
+
{"name": "pubmed_search", "arguments": {"query": ""Batson vertebral venous plexus breast cancer metastasis in people over 50", "top_n": 1}}
|
| 258 |
+
</tool_call>
|
| 259 |
+
<think>
|
| 260 |
+
After evaluating all four options, the most plausible route for breast cancer metastasis to the thoracic spine is clearly via Batson’s vertebral venous plexus:
|
| 261 |
+
</think>
|
| 262 |
+
<answer>C</answer>
|
| 263 |
+
"""
|
| 264 |
+
# STOP_TOKENS =STOP_TOKENS = ["<|im_end|>", "<|endoftext|>"
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def __init__(self, executor_url):
|
| 268 |
+
self.executor_url = executor_url
|
| 269 |
+
|
| 270 |
+
def init_prompt(self, func_schemas, question):
|
| 271 |
+
system_prompt = f"<|im_start|>system\n{self.sys_prompt_websailor.format(func_schemas=func_schemas)}<|im_end|>"
|
| 272 |
+
user_prompt = f"<|im_start|>user\n{question}<|im_end|>"
|
| 273 |
+
assistant_prefix = f"<|im_start|>assistant\n<think>"
|
| 274 |
+
return system_prompt + "\n" + user_prompt + "\n" + assistant_prefix
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def _strip_old_tool_responses(self, prompt: str) -> str:
|
| 279 |
+
TOOL_RESPONSE_RE = re.compile(r"<tool_response>.*?</tool_response>\s*", re.DOTALL)
|
| 280 |
+
"""Remove every existing <tool_response> … </tool_response> block."""
|
| 281 |
+
return TOOL_RESPONSE_RE.sub("", prompt)
|
| 282 |
+
|
| 283 |
+
def cat_assistant_response(self, curr_prompt, assistant_response):
|
| 284 |
+
return curr_prompt + assistant_response + "<|im_end|>"
|
| 285 |
+
|
| 286 |
+
def cat_tool_results(self, curr_prompt, tool_calls, results):
|
| 287 |
+
tool_response_str = ""
|
| 288 |
+
for tool_call, result in zip(tool_calls, results):
|
| 289 |
+
tool_response_str += f"<tool_response>{tool_call}\n{result}\n</tool_response>\n"
|
| 290 |
+
tool_response_str = f"<|im_start|>user\n{tool_response_str}<|im_end|>"
|
| 291 |
+
assistant_prefix = f"<|im_start|>assistant\n<think>"
|
| 292 |
+
return curr_prompt + "\n" + tool_response_str + "\n" + assistant_prefix
|
| 293 |
+
|
| 294 |
+
def format_tool_call(self, tool_call_str: str):
|
| 295 |
+
"""Convert JSON function call description to Python executable code string."""
|
| 296 |
+
try:
|
| 297 |
+
call_json = json.loads(tool_call_str)
|
| 298 |
+
func_name = call_json['name']
|
| 299 |
+
arguments = call_json.get('arguments', {})
|
| 300 |
+
|
| 301 |
+
args_str = ', '.join(f"{k}={repr(v)}" for k, v in arguments.items())
|
| 302 |
+
return f"{func_name}({args_str})"
|
| 303 |
+
except Exception as e:
|
| 304 |
+
return f"Parse tool call failed: {e}"
|
| 305 |
+
|
| 306 |
+
def execute_tool_calls(self, env: str, tool_calls: List[str]) -> List[str]:
|
| 307 |
+
def exe_tool_call(env, call):
|
| 308 |
+
url = self.executor_url + '/execute'
|
| 309 |
+
|
| 310 |
+
call_str = self.format_tool_call(call)
|
| 311 |
+
# print(call_str)
|
| 312 |
+
if call_str.startswith("error: parse tool call failed"):
|
| 313 |
+
return call_str
|
| 314 |
+
|
| 315 |
+
try:
|
| 316 |
+
data = {
|
| 317 |
+
'env': env,
|
| 318 |
+
'call': call_str
|
| 319 |
+
}
|
| 320 |
+
response = requests.post(url, json=data, timeout=60)
|
| 321 |
+
if response.status_code != 200:
|
| 322 |
+
return f"error: {response.status_code}"
|
| 323 |
+
response = response.json()
|
| 324 |
+
ret_str = ''
|
| 325 |
+
if response['result']:
|
| 326 |
+
ret_str += f'result: \n{response["result"]}\n'
|
| 327 |
+
if response['output']:
|
| 328 |
+
ret_str += f'output: \n{response["output"]}\n'
|
| 329 |
+
if response['error']:
|
| 330 |
+
ret_str += f'error: \n{response["error"]}\n'
|
| 331 |
+
return ret_str.strip()
|
| 332 |
+
except requests.exceptions.Timeout:
|
| 333 |
+
return "error: execution timed out"
|
| 334 |
+
except Exception as e:
|
| 335 |
+
return str(e)
|
| 336 |
+
|
| 337 |
+
results = []
|
| 338 |
+
for tool_call in tool_calls:
|
| 339 |
+
result = exe_tool_call(env, tool_call)
|
| 340 |
+
results.append(result)
|
| 341 |
+
return results
|
| 342 |
+
|
| 343 |
+
def validate_tool_calls(self, output_str):
|
| 344 |
+
start_tags = re.findall(r'<tool_call>', output_str)
|
| 345 |
+
end_tags = re.findall(r'</tool_call>', output_str)
|
| 346 |
+
|
| 347 |
+
if len(start_tags) != len(end_tags):
|
| 348 |
+
return False
|
| 349 |
+
|
| 350 |
+
start_positions = [m.start() for m in re.finditer(r'<tool_call>', output_str)]
|
| 351 |
+
end_positions = [m.start() for m in re.finditer(r'</tool_call>', output_str)]
|
| 352 |
+
|
| 353 |
+
for start, end in zip(start_positions, end_positions):
|
| 354 |
+
if start >= end:
|
| 355 |
+
return False
|
| 356 |
+
|
| 357 |
+
return True
|
| 358 |
+
|
| 359 |
+
def extract_tool_calls(self, output_str):
|
| 360 |
+
if not self.validate_tool_calls(output_str):
|
| 361 |
+
return []
|
| 362 |
+
|
| 363 |
+
try:
|
| 364 |
+
pattern = r'<tool_call>((?:(?!</tool_call>).)*)</tool_call>'
|
| 365 |
+
matches = re.finditer(pattern, output_str, re.DOTALL)
|
| 366 |
+
|
| 367 |
+
return [match.group(1).strip() for match in matches]
|
| 368 |
+
except Exception as e:
|
| 369 |
+
return []
|
| 370 |
+
|
| 371 |
+
def extract_tool_calls_deepseek(self, output_str):
|
| 372 |
+
if not self.validate_tool_calls(output_str):
|
| 373 |
+
return []
|
| 374 |
+
|
| 375 |
+
try:
|
| 376 |
+
pattern = r'<tool_calls_begin>((?:(?!</tool_calls_end>).)*)<tool_calls_end>'
|
| 377 |
+
matches = re.finditer(pattern, output_str, re.DOTALL)
|
| 378 |
+
|
| 379 |
+
return [match.group(1).strip() for match in matches]
|
| 380 |
+
except Exception as e:
|
| 381 |
+
return []
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
@retry(max=5, sleep=1, fallback={"score": 0})
|
| 386 |
+
def run_ii_searcher(
|
| 387 |
+
self,
|
| 388 |
+
env: str,
|
| 389 |
+
func_schemas: str,
|
| 390 |
+
question: str,
|
| 391 |
+
tokenizer,
|
| 392 |
+
model_url="http://0.0.0.0:1214",
|
| 393 |
+
temperature: float = 0.0,
|
| 394 |
+
max_new_tokens: int = 40960,
|
| 395 |
+
):
|
| 396 |
+
curr_prompt = self.init_prompt(func_schemas, question)
|
| 397 |
+
all_tool_calls= []
|
| 398 |
+
|
| 399 |
+
for _ in range(16):
|
| 400 |
+
prompt_tokens = tokenizer(curr_prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
|
| 401 |
+
max_tokens_left = max_new_tokens - len(prompt_tokens) - 100
|
| 402 |
+
# for oss model served via vllm
|
| 403 |
+
# response = requests.post(
|
| 404 |
+
# f'{model_url}/v1/chat/completions',
|
| 405 |
+
# json={
|
| 406 |
+
# "text": curr_prompt,
|
| 407 |
+
# # "reasoning": "medium"
|
| 408 |
+
# },
|
| 409 |
+
# ).json()
|
| 410 |
+
# for sglang served models hf models
|
| 411 |
+
response = requests.post(
|
| 412 |
+
f'{model_url}/generate',
|
| 413 |
+
json={
|
| 414 |
+
"text": curr_prompt,
|
| 415 |
+
"sampling_params": {
|
| 416 |
+
"temperature": temperature,
|
| 417 |
+
"max_new_tokens": max_tokens_left,
|
| 418 |
+
"repetition_penalty": 1.05
|
| 419 |
+
},
|
| 420 |
+
|
| 421 |
+
}
|
| 422 |
+
).json()
|
| 423 |
+
if "error" in response.keys():
|
| 424 |
+
print("resp",response)
|
| 425 |
+
curr_prompt = self.cat_assistant_response(curr_prompt, response['text'])
|
| 426 |
+
|
| 427 |
+
tool_calls: List[str] = self.extract_tool_calls(response['text'])
|
| 428 |
+
all_tool_calls += tool_calls
|
| 429 |
+
|
| 430 |
+
if len(tool_calls) == 0:
|
| 431 |
+
break
|
| 432 |
+
|
| 433 |
+
else:
|
| 434 |
+
results: List[str] = self.execute_tool_calls(env, tool_calls)
|
| 435 |
+
curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)
|
| 436 |
+
|
| 437 |
+
return curr_prompt, all_tool_calls
|
| 438 |
+
|
| 439 |
+
# @retry(max=5, sleep=1, fallback={"score": 0})
|
| 440 |
+
# def run(
|
| 441 |
+
# self,
|
| 442 |
+
# env: str,
|
| 443 |
+
# func_schemas: str,
|
| 444 |
+
# question: str,
|
| 445 |
+
# tokenizer,
|
| 446 |
+
# model_url="http://0.0.0.0:1214",
|
| 447 |
+
# temperature: float = 0.0,
|
| 448 |
+
# max_new_tokens: int = 40960,
|
| 449 |
+
# ):
|
| 450 |
+
# curr_prompt = self.init_prompt(func_schemas, question)
|
| 451 |
+
# all_tool_calls= []
|
| 452 |
+
|
| 453 |
+
# for i in range(32):
|
| 454 |
+
# prompt_tokens = tokenizer(curr_prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
|
| 455 |
+
# max_tokens_left = max_new_tokens - len(prompt_tokens) - 100
|
| 456 |
+
# # for oss model served via vllm
|
| 457 |
+
# # response = requests.post(
|
| 458 |
+
# # f'{model_url}/v1/chat/completions',
|
| 459 |
+
# # json={
|
| 460 |
+
# # "text": curr_prompt,
|
| 461 |
+
# # # "reasoning": "medium"
|
| 462 |
+
# # },
|
| 463 |
+
# # ).json()
|
| 464 |
+
# # for sglang served models hf models
|
| 465 |
+
# response = requests.post(
|
| 466 |
+
# f'{model_url}/generate',
|
| 467 |
+
# json={
|
| 468 |
+
# "text": curr_prompt,
|
| 469 |
+
# "sampling_params": {
|
| 470 |
+
# "temperature": temperature,
|
| 471 |
+
# "max_new_tokens": max_tokens_left,
|
| 472 |
+
# "repetition_penalty": 1.05
|
| 473 |
+
# },
|
| 474 |
+
|
| 475 |
+
# }
|
| 476 |
+
# ).json()
|
| 477 |
+
# if "error" in response.keys():
|
| 478 |
+
# print("resp",response)
|
| 479 |
+
# curr_prompt = self.cat_assistant_response(curr_prompt, response['text'])
|
| 480 |
+
|
| 481 |
+
# tool_calls: List[str] = self.extract_tool_calls(response['text'])
|
| 482 |
+
# all_tool_calls += tool_calls
|
| 483 |
+
|
| 484 |
+
# if len(tool_calls) == 0:
|
| 485 |
+
# break
|
| 486 |
+
|
| 487 |
+
# else:
|
| 488 |
+
# # print(f"Step-{i+1}")
|
| 489 |
+
# results: List[str] = self.execute_tool_calls(env, tool_calls)
|
| 490 |
+
# curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)
|
| 491 |
+
|
| 492 |
+
# return curr_prompt, all_tool_calls
|
| 493 |
+
from typing import List, Dict, Any, Tuple
|
| 494 |
+
import requests
|
| 495 |
+
|
| 496 |
+
@retry(max=5, sleep=1, fallback={"score": 0})
|
| 497 |
+
def run(
|
| 498 |
+
self,
|
| 499 |
+
env: str,
|
| 500 |
+
func_schemas: str,
|
| 501 |
+
question: str,
|
| 502 |
+
tokenizer,
|
| 503 |
+
model_url: str = "http://0.0.0.0:1214",
|
| 504 |
+
temperature: float = 0.0,
|
| 505 |
+
max_new_tokens: int = 40960,
|
| 506 |
+
) -> Tuple[str, List[str], List[Dict[str, str]]]:
|
| 507 |
+
"""
|
| 508 |
+
Returns:
|
| 509 |
+
curr_prompt: the final prompt buffer (with assistant/tool traces you maintain internally)
|
| 510 |
+
all_tool_calls: flat list of all tool call strings extracted across steps
|
| 511 |
+
chat: a lightweight chat transcript list[{"role": "...", "content": "..."}]
|
| 512 |
+
• 'user' items = the original question + aggregated tool responses
|
| 513 |
+
• 'assistant' items = model responses (and a compact line-list of tool calls)
|
| 514 |
+
"""
|
| 515 |
+
# Build runtime prompt and initialize accumulators
|
| 516 |
+
curr_prompt = self.init_prompt(func_schemas, question)
|
| 517 |
+
all_tool_calls: List[str] = []
|
| 518 |
+
chat: List[Dict[str, str]] = []
|
| 519 |
+
|
| 520 |
+
# Seed transcript with JUST the question (no system prompt)
|
| 521 |
+
chat.append({"role": "user", "content": question})
|
| 522 |
+
|
| 523 |
+
for i in range(32):
|
| 524 |
+
# Budget tokens for this step
|
| 525 |
+
prompt_tokens = tokenizer(curr_prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
|
| 526 |
+
max_tokens_left = max(1, max_new_tokens - len(prompt_tokens) - 100)
|
| 527 |
+
|
| 528 |
+
# ---- Model call (sglang/vLLM-style JSON) ----
|
| 529 |
+
# If you switch to /v1/chat/completions, adjust accordingly.
|
| 530 |
+
response = requests.post(
|
| 531 |
+
f"{model_url}/generate",
|
| 532 |
+
json={
|
| 533 |
+
"text": curr_prompt,
|
| 534 |
+
"sampling_params": {
|
| 535 |
+
"temperature": temperature,
|
| 536 |
+
"max_new_tokens": max_tokens_left,
|
| 537 |
+
"repetition_penalty": 1.05,
|
| 538 |
+
},
|
| 539 |
+
},
|
| 540 |
+
timeout=300,
|
| 541 |
+
).json()
|
| 542 |
+
|
| 543 |
+
if isinstance(response, dict) and "error" in response:
|
| 544 |
+
# Log the error as assistant text for visibility and break
|
| 545 |
+
err_msg = f"[model_error] {response.get('error')}"
|
| 546 |
+
chat.append({"role": "assistant", "content": err_msg})
|
| 547 |
+
break
|
| 548 |
+
|
| 549 |
+
assistant_text = response.get("text", "")
|
| 550 |
+
# Append assistant's raw text to chat
|
| 551 |
+
chat.append({"role": "assistant", "content": assistant_text})
|
| 552 |
+
|
| 553 |
+
# Update your running prompt with assistant text
|
| 554 |
+
curr_prompt = self.cat_assistant_response(curr_prompt, assistant_text)
|
| 555 |
+
|
| 556 |
+
# Extract tool calls from the assistant text
|
| 557 |
+
tool_calls: List[str] = self.extract_tool_calls(assistant_text)
|
| 558 |
+
if tool_calls:
|
| 559 |
+
all_tool_calls.extend(tool_calls)
|
| 560 |
+
|
| 561 |
+
# Log tool calls as an assistant message (newline-joined)
|
| 562 |
+
chat.append({"role": "assistant", "content": "\n".join(tool_calls)})
|
| 563 |
+
|
| 564 |
+
# Execute tools and collect results
|
| 565 |
+
results: List[str] = self.execute_tool_calls(env, tool_calls)
|
| 566 |
+
|
| 567 |
+
# Feed tool results back into prompt
|
| 568 |
+
curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)
|
| 569 |
+
|
| 570 |
+
# Aggregate tool responses into a single user message
|
| 571 |
+
tool_res_blocks = []
|
| 572 |
+
for idx, (call, res) in enumerate(zip(tool_calls, results), 1):
|
| 573 |
+
tool_res_blocks.append(f"[Tool {idx}] Result:\n{res}")
|
| 574 |
+
chat.append({"role": "user", "content": "\n\n".join(tool_res_blocks)})
|
| 575 |
+
|
| 576 |
+
else:
|
| 577 |
+
# No tool calls → model produced a final answer; stop.
|
| 578 |
+
break
|
| 579 |
+
|
| 580 |
+
# Return the original outputs plus the chat-style transcript
|
| 581 |
+
return curr_prompt, all_tool_calls, chat
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
@retry(max=5, sleep=1, fallback={"score": 0})
|
| 588 |
+
def run_deepseek(
|
| 589 |
+
self,
|
| 590 |
+
env: str,
|
| 591 |
+
func_schemas: str,
|
| 592 |
+
question: str,
|
| 593 |
+
model_name: str,
|
| 594 |
+
temperature: float = 0.0,
|
| 595 |
+
top_p: float = 0.95,
|
| 596 |
+
max_tokens: int = 32768,
|
| 597 |
+
):
|
| 598 |
+
# print("AA"* 100)
|
| 599 |
+
"""
|
| 600 |
+
Chat-based ReCall loop for DeepSeek-R1 on Together.
|
| 601 |
+
"""
|
| 602 |
+
sys_content = self.sys_prompt_websailor_deepseek.format(func_schemas=func_schemas)
|
| 603 |
+
# sys_content = self.init_prompt(func_schemas, question)
|
| 604 |
+
|
| 605 |
+
messages = [
|
| 606 |
+
{"role": "system", "content": sys_content},
|
| 607 |
+
{"role": "user", "content": question},
|
| 608 |
+
]
|
| 609 |
+
|
| 610 |
+
# client = Together(api_key="")
|
| 611 |
+
client = Together(api_key="bcc761f7a821a80c9c5166171ebb36756cd16d505cec226c3b2259b846364000")
|
| 612 |
+
all_tool_calls = []
|
| 613 |
+
for turn in range(32): # up to 10 reasoning turns
|
| 614 |
+
resp = client.chat.completions.create(
|
| 615 |
+
model=model_name,
|
| 616 |
+
# model="Qwen/Qwen3-235B-A22B-fp8-tput",
|
| 617 |
+
messages=messages,
|
| 618 |
+
temperature=temperature,
|
| 619 |
+
top_p=top_p,
|
| 620 |
+
max_tokens=39000,
|
| 621 |
+
stop=["<|end▁of▁sentence|>", "<|im_end|>"]
|
| 622 |
+
)
|
| 623 |
+
# print(resp)
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
assistant_text = resp.choices[0].message.content
|
| 627 |
+
# print(assistant_text)
|
| 628 |
+
messages.append({"role": "assistant", "content": assistant_text})
|
| 629 |
+
# print(f"assistant_output: {assistant_text}")
|
| 630 |
+
|
| 631 |
+
# ⛑ Safe tool call extraction with diagnostic
|
| 632 |
+
# try:
|
| 633 |
+
# print("Extracting tool calls")
|
| 634 |
+
tool_calls = self.extract_tool_calls_deepseek(assistant_text)
|
| 635 |
+
print(tool_calls)
|
| 636 |
+
all_tool_calls += tool_calls
|
| 637 |
+
# except Exception as e:
|
| 638 |
+
# print(f"Extraction failed with exception {e}")
|
| 639 |
+
# err_msg = f"<tool_response>Tool call extraction failed on turn {turn+1}: {str(e)}</tool_response>"
|
| 640 |
+
# messages.append({"role": "user", "content": err_msg})
|
| 641 |
+
# continue # continue to next turn instead of breaking
|
| 642 |
+
if "<answer>" in assistant_text:
|
| 643 |
+
break
|
| 644 |
+
|
| 645 |
+
if len(tool_calls) != 0:
|
| 646 |
+
results = self.execute_tool_calls(env, tool_calls)
|
| 647 |
+
tool_resp_block = "".join(
|
| 648 |
+
f"<tool_response>{c}\n{r}\n</tool_response>\n"
|
| 649 |
+
for c, r in zip(tool_calls, results)
|
| 650 |
+
)
|
| 651 |
+
messages.append({"role": "user", "content": tool_resp_block})
|
| 652 |
+
# print(f"Tool Response {tool_resp_block}")
|
| 653 |
+
else:
|
| 654 |
+
print("no answer or tool call")
|
| 655 |
+
break
|
| 656 |
+
|
| 657 |
+
trajectory = "\n".join(
|
| 658 |
+
f"<{m['role']}>\n{m['content']}" for m in messages
|
| 659 |
+
if m["role"] != "system"
|
| 660 |
+
)
|
| 661 |
+
return trajectory, all_tool_calls
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
# ────────────────────────────────────────────────────────────────
|
| 665 |
+
# HF-endpoint version of “retrieve → inject → tool loop”
|
| 666 |
+
# ────────────────────────────────────────────────────────────────
|
| 667 |
+
@retry(max=5, sleep=1, fallback=None)
|
| 668 |
+
def run_with_prompt_injection(
|
| 669 |
+
self,
|
| 670 |
+
env: str,
|
| 671 |
+
func_schemas: str,
|
| 672 |
+
question: str,
|
| 673 |
+
model_url: str = "http://0.0.0.0:1214",
|
| 674 |
+
temperature: float = 0.0,
|
| 675 |
+
max_new_tokens: int = 512,
|
| 676 |
+
top_n: int = 5,
|
| 677 |
+
):
|
| 678 |
+
"""
|
| 679 |
+
0) call pubmed_search(question, top_n) once via the sandbox
|
| 680 |
+
1) inject those snippets into the very first user message
|
| 681 |
+
2) continue with the normal multi-turn ReCall loop against *model_url*
|
| 682 |
+
"""
|
| 683 |
+
|
| 684 |
+
# 0️⃣ do a single retrieval tool call
|
| 685 |
+
retrieve_call = json.dumps({
|
| 686 |
+
"name": "pubmed_search",
|
| 687 |
+
"arguments": {"query": question, "top_n": top_n}
|
| 688 |
+
})
|
| 689 |
+
retrieval_raw = self.execute_tool_calls(env, [retrieve_call])[0]
|
| 690 |
+
try:
|
| 691 |
+
snippets_block = retrieval_raw.split("result:", 1)[-1].strip()
|
| 692 |
+
except Exception:
|
| 693 |
+
snippets_block = ""
|
| 694 |
+
|
| 695 |
+
# 1️⃣ build initial prompt with injected snippets
|
| 696 |
+
user_msg = (
|
| 697 |
+
f"Question: {question}\n\n"
|
| 698 |
+
"Here are some relevant PubMed snippets:\n"
|
| 699 |
+
f"{snippets_block}"
|
| 700 |
+
) if snippets_block else f"Question: {question}"
|
| 701 |
+
|
| 702 |
+
sys_prompt = self.init_prompt(func_schemas, question)
|
| 703 |
+
system_prompt = f"<|im_start|>system\n{sys_prompt}<|im_end|>"
|
| 704 |
+
user_prompt = f"<|im_start|>user\n{user_msg}<|im_end|>"
|
| 705 |
+
assistant_pref= f"<|im_start|>assistant\n<think>"
|
| 706 |
+
curr_prompt = system_prompt + "\n" + user_prompt + "\n" + assistant_pref
|
| 707 |
+
|
| 708 |
+
# 2️⃣ normal ReCall loop hitting the HF inference endpoint
|
| 709 |
+
for _ in range(10):
|
| 710 |
+
resp = requests.post(
|
| 711 |
+
f"{model_url}/generate",
|
| 712 |
+
json={
|
| 713 |
+
"text": curr_prompt,
|
| 714 |
+
"sampling_params": {
|
| 715 |
+
"temperature": temperature,
|
| 716 |
+
"max_new_tokens": max_new_tokens,
|
| 717 |
+
}
|
| 718 |
+
},
|
| 719 |
+
timeout=120,
|
| 720 |
+
).json()
|
| 721 |
+
if "error" in resp.keys():
|
| 722 |
+
print("resp",response)
|
| 723 |
+
assistant_txt = resp["text"]
|
| 724 |
+
curr_prompt = self.cat_assistant_response(curr_prompt, assistant_txt)
|
| 725 |
+
|
| 726 |
+
tool_calls = self.extract_tool_calls(assistant_txt)
|
| 727 |
+
if len(tool_calls) != 0:
|
| 728 |
+
# break # model produced an answer → done
|
| 729 |
+
|
| 730 |
+
results = self.execute_tool_calls(env, tool_calls)
|
| 731 |
+
curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)
|
| 732 |
+
|
| 733 |
+
else:
|
| 734 |
+
continue
|
| 735 |
+
return curr_prompt
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
@retry(max=5, sleep=1, fallback={"score": 0})
|
| 740 |
+
def run_budget(
|
| 741 |
+
self,
|
| 742 |
+
env: str,
|
| 743 |
+
func_schemas: str,
|
| 744 |
+
question: str,
|
| 745 |
+
model_url: str = "http://0.0.0.0:1214",
|
| 746 |
+
temperature: float = 0.0,
|
| 747 |
+
max_new_tokens: int = 2048,
|
| 748 |
+
) -> str:
|
| 749 |
+
"""
|
| 750 |
+
Execute an agentic dialogue with external tools while *pruning* previous
|
| 751 |
+
<tool_response> blocks to prevent context-length explosion.
|
| 752 |
+
"""
|
| 753 |
+
curr_prompt = self.init_prompt(func_schemas, question)
|
| 754 |
+
|
| 755 |
+
for _ in range(16): # hard loop-limit
|
| 756 |
+
# ── 1. Call the model
|
| 757 |
+
rsp = requests.post(
|
| 758 |
+
f"{model_url}/generate",
|
| 759 |
+
json={
|
| 760 |
+
"text": curr_prompt,
|
| 761 |
+
"sampling_params": {
|
| 762 |
+
"temperature": temperature,
|
| 763 |
+
"max_new_tokens": max_new_tokens,
|
| 764 |
+
"stop": ["<|im_end|>", "</think>", "</think>\n" "</think>\n\n"],
|
| 765 |
+
},
|
| 766 |
+
|
| 767 |
+
},
|
| 768 |
+
timeout=120,
|
| 769 |
+
).json()
|
| 770 |
+
generated = rsp["text"] # what you have now
|
| 771 |
+
matched = rsp["meta_info"]["finish_reason"].get("matched")
|
| 772 |
+
|
| 773 |
+
# ⇢ append the tag back only if it was removed
|
| 774 |
+
if matched and not generated.endswith(matched):
|
| 775 |
+
generated += matched
|
| 776 |
+
|
| 777 |
+
# Fail fast on server error
|
| 778 |
+
if "error" in rsp:
|
| 779 |
+
raise RuntimeError(rsp["error"])
|
| 780 |
+
|
| 781 |
+
assistant_text: str = rsp["text"]
|
| 782 |
+
curr_prompt = self.cat_assistant_response(curr_prompt, assistant_text)
|
| 783 |
+
|
| 784 |
+
# ── 2. Check for final answer ────────────────────────────────────
|
| 785 |
+
if "<answer>" in assistant_text:
|
| 786 |
+
break
|
| 787 |
+
|
| 788 |
+
# ── 3. Extract & execute tool calls ──────────────────────────────
|
| 789 |
+
tool_calls: List[str] = self.extract_tool_calls(assistant_text)
|
| 790 |
+
if not tool_calls: # continue reasoning without calling a tool
|
| 791 |
+
continue
|
| 792 |
+
|
| 793 |
+
results: List[str] = self.execute_tool_calls(env, tool_calls)
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
# ── 4. BEFORE appending new tool output, drop all old ones ───────
|
| 797 |
+
curr_prompt =self. _strip_old_tool_responses(curr_prompt)
|
| 798 |
+
|
| 799 |
+
# ── 5. Append *only* the fresh tool_response block ───────────────
|
| 800 |
+
curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)
|
| 801 |
+
|
| 802 |
+
return curr_prompt
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
def _strip_old_tool_responses_msgs(self, messages: list[dict]) -> list[dict]:
|
| 808 |
+
"""
|
| 809 |
+
Return a copy of `messages` with every *user* message that starts with
|
| 810 |
+
<tool_response> removed. Keeps assistant turns untouched.
|
| 811 |
+
"""
|
| 812 |
+
return [
|
| 813 |
+
m for m in messages
|
| 814 |
+
if not (m["role"] == "user" and m["content"].lstrip().startswith("<tool_response>"))
|
| 815 |
+
]
|
| 816 |
+
# ────────── budget version ──────────
|
| 817 |
+
@retry(max=5, sleep=1, fallback={"score": 0})
|
| 818 |
+
def run_deepseek_budget(
|
| 819 |
+
self,
|
| 820 |
+
env: str,
|
| 821 |
+
func_schemas: str,
|
| 822 |
+
question: str,
|
| 823 |
+
api_key: str,
|
| 824 |
+
model_name: str,
|
| 825 |
+
temperature: float = 0.0,
|
| 826 |
+
top_p: float = 0.95,
|
| 827 |
+
max_tokens: int = 32768,
|
| 828 |
+
max_turns: int = 10,
|
| 829 |
+
):
|
| 830 |
+
"""
|
| 831 |
+
Chat-based ReCall loop for DeepSeek-R1 **with context-budget pruning**.
|
| 832 |
+
Keeps only the *latest* <tool_response> block to avoid prompt bloat.
|
| 833 |
+
"""
|
| 834 |
+
sys_content = self.system_prompt_budget.format(func_schemas=func_schemas)
|
| 835 |
+
|
| 836 |
+
messages = [
|
| 837 |
+
{"role": "system", "content": sys_content},
|
| 838 |
+
{"role": "user", "content": question},
|
| 839 |
+
]
|
| 840 |
+
|
| 841 |
+
client = Together(api_key=api_key)
|
| 842 |
+
|
| 843 |
+
for turn in range(max_turns):
|
| 844 |
+
# ── 1. model call ───────────────────────────────────────────────
|
| 845 |
+
resp = client.chat.completions.create(
|
| 846 |
+
model=model_name,
|
| 847 |
+
messages=messages,
|
| 848 |
+
temperature=temperature,
|
| 849 |
+
top_p=top_p,
|
| 850 |
+
max_tokens=max_tokens,
|
| 851 |
+
stop=["</tool_call>", "<|end▁of▁sentence|>"],
|
| 852 |
+
)
|
| 853 |
+
assistant_text = resp.choices[0].message.content
|
| 854 |
+
messages.append({"role": "assistant", "content": assistant_text})
|
| 855 |
+
|
| 856 |
+
print(f"**assistant** \n {assistant_text}")
|
| 857 |
+
|
| 858 |
+
# ── 2. finished? ────────────────────────────────────────────────
|
| 859 |
+
if "<answer>" in assistant_text:
|
| 860 |
+
break
|
| 861 |
+
|
| 862 |
+
# ── 3. parse tool calls ────────────────────────────────────────
|
| 863 |
+
tool_calls = self.extract_tool_calls(assistant_text)
|
| 864 |
+
print(f"**tool_calls** \n {tool_calls}")
|
| 865 |
+
if not tool_calls:
|
| 866 |
+
continue # keep reasoning without tools
|
| 867 |
+
|
| 868 |
+
# ── 4. execute tools ───────────────────────────────────────────
|
| 869 |
+
results = self.execute_tool_calls(env, tool_calls)
|
| 870 |
+
print(f"**tool_response** \n {results}")
|
| 871 |
+
|
| 872 |
+
# ── 5. prune & append fresh tool_response ──────────────────────
|
| 873 |
+
messages = self._strip_old_tool_responses_msgs(messages)
|
| 874 |
+
|
| 875 |
+
tool_resp_block = "".join(
|
| 876 |
+
f"<tool_response>{c}\n{r}\n</tool_response>\n"
|
| 877 |
+
for c, r in zip(tool_calls, results)
|
| 878 |
+
)
|
| 879 |
+
messages.append({"role": "user", "content": tool_resp_block})
|
| 880 |
+
|
| 881 |
+
# ── 6. flatten & return trajectory (sans system for readability) ───
|
| 882 |
+
trajectory = "\n".join(
|
| 883 |
+
f"<{m['role']}>\n{m['content']}" for m in messages if m["role"] != "system"
|
| 884 |
+
)
|
| 885 |
+
return trajectory
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
@retry(max=5, sleep=1, fallback=None)
|
| 889 |
+
def run_deepseek_with_prompt_injection(
|
| 890 |
+
self,
|
| 891 |
+
env: str,
|
| 892 |
+
func_schemas: str,
|
| 893 |
+
question: str,
|
| 894 |
+
api_key: str,
|
| 895 |
+
model_name: str,
|
| 896 |
+
temperature: float = 0.0,
|
| 897 |
+
top_p: float = 0.95,
|
| 898 |
+
max_tokens: int = 32768,
|
| 899 |
+
):
|
| 900 |
+
"""
|
| 901 |
+
1) Call pubmed_search(question, top_n=5) as a tool to get snippets.
|
| 902 |
+
2) Inject them into the first user message.
|
| 903 |
+
3) Proceed with the usual DeepSeek-R1 tool‐based rollout.
|
| 904 |
+
"""
|
| 905 |
+
|
| 906 |
+
# ── Step 0: prepare the single‐tool call for retrieval ───────────────
|
| 907 |
+
retrieve_call = json.dumps({
|
| 908 |
+
"name": "pubmed_search",
|
| 909 |
+
"arguments": {
|
| 910 |
+
"query": question,
|
| 911 |
+
"top_n": 5
|
| 912 |
+
}
|
| 913 |
+
})
|
| 914 |
+
|
| 915 |
+
# Execute it once via your helper
|
| 916 |
+
# note: `env` must include whatever import / client‐setup
|
| 917 |
+
# your sandbox needs to run pubmed_search(...)
|
| 918 |
+
raw_retrieval_results = self.execute_tool_calls(env, [retrieve_call])[0]
|
| 919 |
+
# print("AAAAA"*100)
|
| 920 |
+
try:
|
| 921 |
+
snippets = raw_retrieval_results[9:] #"remove result: str"
|
| 922 |
+
# print(snippets)
|
| 923 |
+
except:
|
| 924 |
+
snippets = ""
|
| 925 |
+
# print(f"[ReCall] Retriever call failed to parse JSON, got:\n{raw_retrieval_results!r}")
|
| 926 |
+
|
| 927 |
+
# ── Step 1: build the injected user prompt ────────────────────────────
|
| 928 |
+
if snippets:
|
| 929 |
+
|
| 930 |
+
user_content = (
|
| 931 |
+
f"Question: {question}\n\n"
|
| 932 |
+
"Here are some relevant PubMed snippets:\n"
|
| 933 |
+
f"{snippets}"
|
| 934 |
+
)
|
| 935 |
+
else:
|
| 936 |
+
user_content = f"Question: {question}"
|
| 937 |
+
|
| 938 |
+
# ── Step 2: start the chat history ────────────────────────────────────
|
| 939 |
+
sys_content = self.system_prompt_forcing_tool_call
|
| 940 |
+
messages = [
|
| 941 |
+
{"role": "system", "content": sys_content},
|
| 942 |
+
{"role": "user", "content": user_content},
|
| 943 |
+
]
|
| 944 |
+
client = Together(api_key=api_key)
|
| 945 |
+
|
| 946 |
+
# ── Step 3: your normal ReCall tool‐calling loop ─────────────────────
|
| 947 |
+
for turn in range(10):
|
| 948 |
+
resp = client.chat.completions.create(
|
| 949 |
+
model = model_name,
|
| 950 |
+
messages = messages,
|
| 951 |
+
temperature = temperature,
|
| 952 |
+
top_p = top_p,
|
| 953 |
+
max_tokens = max_tokens,
|
| 954 |
+
stop = ["</tool_call>", "<|end▁of▁sentence|>"]
|
| 955 |
+
)
|
| 956 |
+
|
| 957 |
+
assistant_text = resp.choices[0].message.content
|
| 958 |
+
messages.append({"role": "assistant", "content": assistant_text})
|
| 959 |
+
|
| 960 |
+
tool_calls = self.extract_tool_calls(assistant_text)
|
| 961 |
+
if not tool_calls:
|
| 962 |
+
break
|
| 963 |
+
|
| 964 |
+
# Execute all of the tool calls in one go
|
| 965 |
+
results = self.execute_tool_calls(env, tool_calls)
|
| 966 |
+
# and append them back in the required <tool_response> format
|
| 967 |
+
tool_resp_block = "".join(
|
| 968 |
+
f"<tool_response>{call}\n{out}\n</tool_response>\n"
|
| 969 |
+
for call, out in zip(tool_calls, results)
|
| 970 |
+
)
|
| 971 |
+
messages.append({"role": "user", "content": tool_resp_block})
|
| 972 |
+
|
| 973 |
+
# ── Step 4: flatten to a single trajectory ────────────────────────────
|
| 974 |
+
trajectory = "\n".join(
|
| 975 |
+
f"<{m['role']}>\n{m['content']}"
|
| 976 |
+
for m in messages
|
| 977 |
+
if m["role"] != "system"
|
| 978 |
+
)
|
| 979 |
+
return trajectory
|
| 980 |
+
|
inference/simpledeepsearch.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""o1_searcher_inference.py — Serper‑based Search‑o1 re‑implementation
|
| 3 |
+
with *original* in‑house summarisation workflow, step‑replacement logic and
|
| 4 |
+
bug‑fixes for duplicate queries / ValueError.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import os, re, json, time, string, pathlib
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import List, Dict, Optional, Tuple
|
| 11 |
+
import requests, trafilatura
|
| 12 |
+
|
| 13 |
+
# -----------------------------------------------------------------------------
|
| 14 |
+
# Optional NLTK sentence tokenizer (fallback to regex) -------------------------
|
| 15 |
+
try:
|
| 16 |
+
from nltk.tokenize import sent_tokenize # type: ignore
|
| 17 |
+
except Exception: # ImportError *or* missing punkt data
|
| 18 |
+
def sent_tokenize(x: str):
|
| 19 |
+
return re.split(r"(?<=[.!?]) +", x)
|
| 20 |
+
|
| 21 |
+
# -----------------------------------------------------------------------------
|
| 22 |
+
# Special tags & constants -----------------------------------------------------
|
| 23 |
+
BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
|
| 24 |
+
END_SEARCH_QUERY = "<|end_search_query|>"
|
| 25 |
+
BEGIN_DOCUMENT_QUERY = "<|begin_of_document|>"
|
| 26 |
+
END_DOCUMENT_QUERY = "<|end_of_document|>"
|
| 27 |
+
THINK_OPEN, THINK_CLOSE = "<think>", "</think>"
|
| 28 |
+
EOS_TOKEN = "<|im_end|>"
|
| 29 |
+
ANSWER_OPEN, ANSWER_CLOSE = "<answer>", "</answer>"
|
| 30 |
+
STOP_STRINGS = [END_SEARCH_QUERY, ANSWER_CLOSE, EOS_TOKEN, "<|endoftext|>"]
|
| 31 |
+
ALLOWED_DATASETS = {"musique", "frames", "simpleqa", "browsercomp"}
|
| 32 |
+
# tokenizer =
|
| 33 |
+
TOKENIZER_DIR = "/home/fractal_admin/shreyas/models/Qwen3-4B"
|
| 34 |
+
|
| 35 |
+
# ───────────────────────── tokenizer ────────────────────────────────────────
|
| 36 |
+
try:
|
| 37 |
+
from transformers import AutoTokenizer
|
| 38 |
+
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, trust_remote_code=True)
|
| 39 |
+
except Exception as e:
|
| 40 |
+
import sys
|
| 41 |
+
sys.exit(f"❌ Could not load Qwen3 tokenizer: {e}")
|
| 42 |
+
|
| 43 |
+
# -----------------------------------------------------------------------------
|
| 44 |
+
# Helper functions -------------------------------------------------------------
|
| 45 |
+
|
| 46 |
+
def remove_punc(t: str) -> str:
|
| 47 |
+
return t.translate(str.maketrans("", "", string.punctuation))
|
| 48 |
+
|
| 49 |
+
# legacy aliases for older checkpoints ---------------------------------------
|
| 50 |
+
_nopunc = remove_punc
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def f1(a: set, b: set) -> float:
|
| 54 |
+
inter = len(a & b)
|
| 55 |
+
return 0.0 if inter == 0 else 2 * inter / (len(a) + len(b))
|
| 56 |
+
|
| 57 |
+
# legacy alias
|
| 58 |
+
_f1 = f1
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def extract_snippet_ctx(text: str, snippet: str, win: int = 2500) -> str:
|
| 62 |
+
"""Return *window*‑sized context around the sentence most similar to snippet."""
|
| 63 |
+
text = text[:50_000]
|
| 64 |
+
sn_set = set(remove_punc(snippet.lower()).split())
|
| 65 |
+
best, best_score = None, 0.20
|
| 66 |
+
for sent in sent_tokenize(text):
|
| 67 |
+
score = f1(sn_set, set(remove_punc(sent.lower()).split()))
|
| 68 |
+
if score > best_score:
|
| 69 |
+
best, best_score = sent, score
|
| 70 |
+
if best:
|
| 71 |
+
pos = text.find(best)
|
| 72 |
+
return text[max(0, pos - win): pos + len(best) + win]
|
| 73 |
+
return text[: 2 * win]
|
| 74 |
+
|
| 75 |
+
# -----------------------------------------------------------------------------
|
| 76 |
+
# Config dataclass -------------------------------------------------------------
|
| 77 |
+
@dataclass
|
| 78 |
+
class SDSCfg:
|
| 79 |
+
serper_api_key: str = "7bfe51ead1a1766b656c1355b292d1d29c15c114"
|
| 80 |
+
gl: str = "us"; hl: str = "en"
|
| 81 |
+
top_k: int = 10; max_doc_len: int = 3000
|
| 82 |
+
max_search: int = 10; max_turn: int = 15
|
| 83 |
+
use_jina: bool = True
|
| 84 |
+
jina_tpl: str = "https://r.jina.ai/http://{}"
|
| 85 |
+
# generation params
|
| 86 |
+
temperature: float = 0.7; top_p: float = 0.8; top_k_sampling: int = 20
|
| 87 |
+
rep_pen: float = 1.05; thinker_max_tokens: int = 40960
|
| 88 |
+
|
| 89 |
+
# -----------------------------------------------------------------------------
|
| 90 |
+
# Serper search + page fetch ---------------------------------------------------
|
| 91 |
+
|
| 92 |
+
def serper_search(q: str, num: int, key: str, gl="us", hl="en") -> List[Dict]:
|
| 93 |
+
hdr = {"X-API-KEY": key, "Content-Type": "application/json"}
|
| 94 |
+
body = {"q": q, "num": num, "gl": gl, "hl": hl}
|
| 95 |
+
r = requests.post("https://google.serper.dev/search", json=body, headers=hdr, timeout=20)
|
| 96 |
+
r.raise_for_status(); return r.json().get("organic", [])
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def fetch_page(url: str, cfg: O1Cfg, snippet: str = "") -> str:
|
| 100 |
+
try:
|
| 101 |
+
txt = ""
|
| 102 |
+
if cfg.use_jina:
|
| 103 |
+
r = requests.get(cfg.jina_tpl.format(url), timeout=15)
|
| 104 |
+
if r.ok and len(r.text.strip()) > 100:
|
| 105 |
+
txt = r.text.strip()
|
| 106 |
+
if txt == "":
|
| 107 |
+
r = requests.get(url, timeout=15); r.raise_for_status()
|
| 108 |
+
txt = trafilatura.extract(r.text, output_format="txt") or ""
|
| 109 |
+
if snippet:
|
| 110 |
+
txt = extract_snippet_ctx(txt, snippet, cfg.max_doc_len)
|
| 111 |
+
|
| 112 |
+
return txt
|
| 113 |
+
except Exception:
|
| 114 |
+
return ""
|
| 115 |
+
|
| 116 |
+
# -----------------------------------------------------------------------------
|
| 117 |
+
# replace_recent_steps --------------------------------------------------------
|
| 118 |
+
|
| 119 |
+
def replace_recent_steps(origin: str, patch: str) -> str:
|
| 120 |
+
"""Apply *patch* (containing numbered `Step N:` lines) to *origin*."""
|
| 121 |
+
step_re = re.compile(r"Step\s+(\d+):\s*")
|
| 122 |
+
|
| 123 |
+
def parse(block: str) -> Dict[int, str]:
|
| 124 |
+
cur, buf, out = None, [], {}
|
| 125 |
+
for line in block.splitlines():
|
| 126 |
+
m = step_re.match(line)
|
| 127 |
+
if m:
|
| 128 |
+
if cur is not None:
|
| 129 |
+
out[cur] = "\n".join(buf).strip()
|
| 130 |
+
cur, buf = int(m.group(1)), [line[m.end():].strip()]
|
| 131 |
+
elif cur is not None:
|
| 132 |
+
buf.append(line)
|
| 133 |
+
if cur is not None:
|
| 134 |
+
out[cur] = "\n".join(buf).strip()
|
| 135 |
+
return out
|
| 136 |
+
|
| 137 |
+
base = parse(origin); mod = parse(patch)
|
| 138 |
+
for k, v in mod.items():
|
| 139 |
+
if "DELETE THIS STEP" in v:
|
| 140 |
+
base.pop(k, None)
|
| 141 |
+
else:
|
| 142 |
+
base[k] = v
|
| 143 |
+
return "\n\n".join(base[k] for k in sorted(base))
|
| 144 |
+
|
| 145 |
+
# -----------------------------------------------------------------------------
|
| 146 |
+
# Prompts ----------------------------------------------------------------------
|
| 147 |
+
# from prompts import get_webpage_to_reasonchain_instruction # keep original helper
|
| 148 |
+
|
| 149 |
+
# -----------------------------------------------------------------------------
|
| 150 |
+
# Main agent -------------------------------------------------------------------
|
| 151 |
+
class SDSearcher:
|
| 152 |
+
# STOP_TOKENS = [
|
| 153 |
+
# "<|im_end|>",
|
| 154 |
+
# "<|endoftext|>",
|
| 155 |
+
# "<|end_of_query|>",
|
| 156 |
+
# " <|end_of_query|>",
|
| 157 |
+
# "<|end_of_query|>\n",
|
| 158 |
+
# "<|end_of_query|>\n\n",
|
| 159 |
+
# " <|end_of_query|>\n",
|
| 160 |
+
# " <|end_of_query|>\n\n",
|
| 161 |
+
# ]
|
| 162 |
+
get_webpage_to_reasonchain_instruction = """**Task Instruction:**
|
| 163 |
+
|
| 164 |
+
You are tasked with reading and analyzing web pages based on the following inputs: **Previous Reasoning Steps**, **Current Search Query**, and **Searched Web Pages**. Your objective is to extract relevant and helpful information for **Current Search Query** from the **Searched Web Pages** and seamlessly integrate this information into the **Previous Reasoning Steps** to continue reasoning for the original question.
|
| 165 |
+
|
| 166 |
+
**Guidelines:**
|
| 167 |
+
|
| 168 |
+
1. **Analyze the Searched Web Pages:**
|
| 169 |
+
- Carefully review the content of each searched web page.
|
| 170 |
+
- Identify factual information that is relevant to the **Current Search Query** and can aid in the reasoning process for the original question.
|
| 171 |
+
|
| 172 |
+
2. **Extract Relevant Information:**
|
| 173 |
+
- Select the information from the Searched Web Pages that directly contributes to advancing the **Previous Reasoning Steps**.
|
| 174 |
+
- Ensure that the extracted information is accurate and relevant.
|
| 175 |
+
|
| 176 |
+
3. **Output Format:**
|
| 177 |
+
- **If the web pages provide helpful information for current search query:** Present the information beginning with **Final Information** as shown below.
|
| 178 |
+
**Final Information**
|
| 179 |
+
|
| 180 |
+
[Helpful information]
|
| 181 |
+
|
| 182 |
+
- **If the web pages do not provide any helpful information for current search query:** Output the following text.
|
| 183 |
+
|
| 184 |
+
**Final Information**
|
| 185 |
+
|
| 186 |
+
No helpful information found.
|
| 187 |
+
|
| 188 |
+
**Inputs:**
|
| 189 |
+
- **Previous Reasoning Steps:**
|
| 190 |
+
{prev_reasoning}
|
| 191 |
+
|
| 192 |
+
- **Current Search Query:**
|
| 193 |
+
{search_query}
|
| 194 |
+
|
| 195 |
+
- **Searched Web Pages:**
|
| 196 |
+
{document}
|
| 197 |
+
|
| 198 |
+
Now you should analyze each web page and find helpful information based on the current search query {search_query} and previous reasoning steps.
|
| 199 |
+
Return the Helpful information in the <information></information> tags
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
sys_prompt_multiqa = (
|
| 203 |
+
"You are a reasoning assistant with the ability to perform web searches to help "
|
| 204 |
+
"you answer the user's question accurately. You have special tools:\n\n"
|
| 205 |
+
"- To perform a search: write <|begin_search_query|> your query here <|end_search_query|>.\n"
|
| 206 |
+
"Then, the system will search and analyze relevant web pages, then provide you with helpful information in the format <|begin_search_result|> ...search results... <|end_search_result|>.\n\n"
|
| 207 |
+
f"You can repeat the search process multiple times if necessary. The maximum number of search attempts is limited to 16.\n\n"
|
| 208 |
+
"Once you have all the information you need, continue your reasoning.\n\n"
|
| 209 |
+
"Example:\n"
|
| 210 |
+
"Question: \"Alice David is the voice of Lara Croft in a video game developed by which company?\"\n"
|
| 211 |
+
"Assistant thinking steps:\n"
|
| 212 |
+
"- I need to find out who voices Lara Croft in the video game.\n"
|
| 213 |
+
"- Then, I need to determine which company developed that video game.\n\n"
|
| 214 |
+
"Assistant:\n"
|
| 215 |
+
"<|begin_search_query|>Alice David Lara Croft voice<|end_search_query|>\n\n"
|
| 216 |
+
"(System returns processed information from relevant web pages)\n\n"
|
| 217 |
+
"Assistant thinks: The search results indicate that Alice David is the voice of Lara Croft in a specific video game. Now, I need to find out which company developed that game.\n\n"
|
| 218 |
+
"Assistant:\n"
|
| 219 |
+
"<|begin_search_query|>video game developed by Alice David Lara Croft<|end_search_query|>\n\n"
|
| 220 |
+
"(System returns processed information from relevant web pages)\n\n"
|
| 221 |
+
"Assistant continues reasoning with the new information...\n\n"
|
| 222 |
+
"Remember:\n"
|
| 223 |
+
"- Use <|begin_search_query|> to request a web search and end with <|end_search_query|>.\n"
|
| 224 |
+
"- When done searching, continue your reasoning.\n\n",
|
| 225 |
+
"Finally, if you have got the answer, enclose it within \\boxed{{}} with latex format and do not continue to call functions"
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
def __init__(self, cfg: SDSCfg, thinker_url: str):
|
| 229 |
+
if not cfg.serper_api_key:
|
| 230 |
+
raise ValueError("SERPER_API_KEY required")
|
| 231 |
+
self.cfg, self.model_url = cfg, thinker_url.rstrip("/")
|
| 232 |
+
self.search_cache: Dict[str, List[Dict]] = {}
|
| 233 |
+
self.page_cache: Dict[Tuple[str, str], str] = {}
|
| 234 |
+
|
| 235 |
+
# --- low‑level generation call ------------------------------------------
|
| 236 |
+
def _generate(self, prompt: str) -> str:
|
| 237 |
+
prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
|
| 238 |
+
max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100
|
| 239 |
+
resp = requests.post(
|
| 240 |
+
f"{self.model_url}/generate",
|
| 241 |
+
json={
|
| 242 |
+
"text": prompt,
|
| 243 |
+
"sampling_params": {
|
| 244 |
+
"temperature": self.cfg.temperature,
|
| 245 |
+
"top_p": self.cfg.top_p,
|
| 246 |
+
"max_new_tokens": max_tokens_left,
|
| 247 |
+
"repetition_penalty": self.cfg.rep_pen,
|
| 248 |
+
"stop": STOP_STRINGS,
|
| 249 |
+
},
|
| 250 |
+
|
| 251 |
+
},
|
| 252 |
+
timeout=60,
|
| 253 |
+
).json()
|
| 254 |
+
generated = resp["text"] # what you have now
|
| 255 |
+
matched = resp["meta_info"]["finish_reason"].get("matched")
|
| 256 |
+
reason = resp["meta_info"]["finish_reason"].get("type")
|
| 257 |
+
|
| 258 |
+
# ⇢ append the tag back only if it was removed
|
| 259 |
+
if reason == "stop" and matched in STOP_STRINGS:
|
| 260 |
+
if not "<|end_of_query|>" in generated:
|
| 261 |
+
generated += matched
|
| 262 |
+
if reason == "stop" and matched == 151645:
|
| 263 |
+
if not generated.endswith("<|im_end|>"):
|
| 264 |
+
generated += "<|im_end|>"
|
| 265 |
+
if reason == "stop" and matched == 151643:
|
| 266 |
+
if not generated.endswith("<|endoftext|>"):
|
| 267 |
+
generated += "<|endoftext|>"
|
| 268 |
+
return generated
|
| 269 |
+
|
| 270 |
+
def _generate_summary(self, prompt: str) -> str:
|
| 271 |
+
summary_url = "http://0.0.0.0:1243"
|
| 272 |
+
prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
|
| 273 |
+
max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100
|
| 274 |
+
resp = requests.post(
|
| 275 |
+
f"{summary_url}/generate",
|
| 276 |
+
json={
|
| 277 |
+
"text": prompt,
|
| 278 |
+
"sampling_params": {
|
| 279 |
+
"temperature": self.cfg.temperature,
|
| 280 |
+
"max_new_tokens": 8192,#max_tokens_left,
|
| 281 |
+
"stop": STOP_STRINGS,
|
| 282 |
+
},
|
| 283 |
+
|
| 284 |
+
},
|
| 285 |
+
timeout=60,
|
| 286 |
+
).json()
|
| 287 |
+
generated = resp["text"] # what you have now
|
| 288 |
+
matched = resp["meta_info"]["finish_reason"].get("matched")
|
| 289 |
+
reason = resp["meta_info"]["finish_reason"].get("type")
|
| 290 |
+
# ##print("-"*100)
|
| 291 |
+
# ##print(resp)
|
| 292 |
+
# ##print(matched)
|
| 293 |
+
# ##print("-"*100)
|
| 294 |
+
# ⇢ append the tag back only if it was removed
|
| 295 |
+
if reason == "stop" and matched in STOP_STRINGS:
|
| 296 |
+
if not "<|end_of_query|>" in generated:
|
| 297 |
+
generated += matched + EOS_TOKEN
|
| 298 |
+
if reason == "stop" and matched == 151645:
|
| 299 |
+
if not generated.endswith("<|im_end|>"):
|
| 300 |
+
generated += "<|im_end|>"
|
| 301 |
+
if reason == "stop" and matched == 151643:
|
| 302 |
+
if not generated.endswith("<|endoftext|>"):
|
| 303 |
+
generated += "<|endoftext|>"
|
| 304 |
+
return generated
|
| 305 |
+
# --- public entry -------------------------------------------------------
|
| 306 |
+
def run(self, question: str):
|
| 307 |
+
prompt = (
|
| 308 |
+
f"<|im_start|>system\n{self.sys_prompt_multiqa}<|im_end|>\n"
|
| 309 |
+
f"<|im_start|>user\n{question}<|im_end|>\n"
|
| 310 |
+
f"<|im_start|>assistant\n{THINK_OPEN}"
|
| 311 |
+
)
|
| 312 |
+
full_trace = prompt # <-- Track full trace
|
| 313 |
+
queries: List[str] = []
|
| 314 |
+
seen_queries: set[str] = set()
|
| 315 |
+
|
| 316 |
+
for i in range(self.cfg.max_turn):
|
| 317 |
+
chunk = self._generate(prompt)
|
| 318 |
+
prompt += chunk
|
| 319 |
+
|
| 320 |
+
if ANSWER_CLOSE in chunk:
|
| 321 |
+
break
|
| 322 |
+
|
| 323 |
+
##print(f"step-{i}")
|
| 324 |
+
##print(chunk)
|
| 325 |
+
|
| 326 |
+
query = self._extract_query(chunk)
|
| 327 |
+
##print(query)
|
| 328 |
+
if not query or len(queries) >= self.cfg.max_search:
|
| 329 |
+
break
|
| 330 |
+
if query in seen_queries:
|
| 331 |
+
continue
|
| 332 |
+
queries.append(query)
|
| 333 |
+
seen_queries.add(query)
|
| 334 |
+
|
| 335 |
+
doc = self._retrieve_doc(query)
|
| 336 |
+
prev_reasoning = self._extract_reasoning(prompt)
|
| 337 |
+
summary = "\n<|im_start|>user" + self._summarise(prev_reasoning, query, doc) + EOS_TOKEN + "\n<|im_start|>assistant" + THINK_OPEN
|
| 338 |
+
##print("summary")
|
| 339 |
+
##print(summary)
|
| 340 |
+
prompt += summary  # append the summarised evidence block to the running prompt
|
| 341 |
+
|
| 342 |
+
# new_reasoning = replace_recent_steps(prev_reasoning, summary)
|
| 343 |
+
|
| 344 |
+
# if prev_reasoning:
|
| 345 |
+
# prompt = prompt.rsplit(prev_reasoning, 1)[0] + new_reasoning + THINK_CLOSE + THINK_OPEN
|
| 346 |
+
# else:
|
| 347 |
+
# prompt += new_reasoning + THINK_CLOSE + THINK_OPEN
|
| 348 |
+
|
| 349 |
+
# full_trace += + THINK_CLOSE + THINK_OPEN + "\n" # <-- Log reasoning to trace
|
| 350 |
+
else:
|
| 351 |
+
final = f"{ANSWER_OPEN}I don't know.{ANSWER_CLOSE}"
|
| 352 |
+
prompt += final
|
| 353 |
+
# full_trace += final
|
| 354 |
+
|
| 355 |
+
return prompt, queries
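# Illustrative end-to-end call (not part of the original file). thinker_url is
# assumed to point at an SGLang-style server exposing POST /generate, which is
# what _generate() below expects; the URL here is a placeholder.
#
#   agent = SDSearcher(SDSCfg(), thinker_url="http://0.0.0.0:1214")
#   trace, queries = agent.run("Alice David is the voice of Lara Croft in a "
#                              "video game developed by which company?")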
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
# ---------------------------------------------------------------------
|
| 359 |
+
# helpers --------------------------------------------------------------
|
| 360 |
+
def _extract_query(self, txt: str) -> Optional[str]:
|
| 361 |
+
if BEGIN_SEARCH_QUERY not in txt or END_SEARCH_QUERY not in txt:
|
| 362 |
+
return None
|
| 363 |
+
frag = txt.split(BEGIN_SEARCH_QUERY)[-1].split(END_SEARCH_QUERY)[0]
|
| 364 |
+
# strip quotes / ellipsis / tabs
|
| 365 |
+
return re.sub(r"[\"'…\t]", " ", frag.split("<|")[0]).strip()
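# Illustrative behaviour (not part of the original file), assuming the
# <|begin_search_query|> / <|end_search_query|> tags are bound to BEGIN_SEARCH_QUERY
# and END_SEARCH_QUERY elsewhere in this module:
#
#   _extract_query('...<|begin_search_query|>"Alice David" Lara Croft<|end_search_query|>')
#   # -> 'Alice David  Lara Croft'   (quotes/ellipsis/tabs replaced by spaces, ends stripped)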
|
| 366 |
+
|
| 367 |
+
def _retrieve_doc(self, query: str) -> str:
|
| 368 |
+
if query not in self.search_cache:
|
| 369 |
+
self.search_cache[query] = serper_search(query, self.cfg.top_k, self.cfg.serper_api_key,
|
| 370 |
+
gl=self.cfg.gl, hl=self.cfg.hl)
|
| 371 |
+
for hit in self.search_cache[query]:
|
| 372 |
+
# ##print("hit")
|
| 373 |
+
# ##print(hit)
|
| 374 |
+
url, sn = hit.get("link", ""), hit.get("snippet", "")
|
| 375 |
+
if not url:
|
| 376 |
+
continue
|
| 377 |
+
key = (url, sn)
|
| 378 |
+
if key not in self.page_cache:
|
| 379 |
+
self.page_cache[key] = fetch_page(url, self.cfg, sn)
|
| 380 |
+
if self.page_cache[key]:
|
| 381 |
+
return self.page_cache[key]
|
| 382 |
+
return ""
|
| 383 |
+
|
| 384 |
+
def _summarise(self, prev: str, query: str, doc: str) -> str:
|
| 385 |
+
rid_prompt = self.get_webpage_to_reasonchain_instruction.format(prev_reasoning = prev, search_query = query, document = doc)
|
| 386 |
+
chat = f"<|im_start|>user\\n{rid_prompt}\\n<|im_end|>\\n<|im_start|>assistant\\n"
|
| 387 |
+
resp = self._generate_summary(chat)
|
| 388 |
+
# ##print("summarization out \n", resp)
|
| 389 |
+
return BEGIN_DOCUMENT_QUERY + self._extract_summary(resp) + END_DOCUMENT_QUERY
|
| 390 |
+
# ##print("summary")
|
| 391 |
+
# ##print(resp)
|
| 392 |
+
# match = re.search(r"Final Information\*\*\s*\n(.+?)<\|im_end\|>", resp)
|
| 393 |
+
# if match:
|
| 394 |
+
# final_info = match.group(1).strip()
|
| 395 |
+
# ##print(final_info)
|
| 396 |
+
# return final_info
|
| 397 |
+
|
| 398 |
+
def _extract_summary(self, prompt: str) -> str:
|
| 399 |
+
if "<information>" in prompt:
|
| 400 |
+
summary = prompt.split("<information>")[-1].split("</information>")[0] if THINK_OPEN in prompt else ""
|
| 401 |
+
return summary
|
| 402 |
+
else:
|
| 403 |
+
match = re.search(r"\*\*Final Information\*\*\s*\n(.+?)<\|im_end\|>", prompt)
|
| 404 |
+
if match:
|
| 405 |
+
final_info = match.group(1).strip()
|
| 406 |
+
return final_info
|
| 407 |
+
return prompt
|
| 408 |
+
|
| 409 |
+
def _extract_reasoning(self, prompt: str) -> str:
|
| 410 |
+
return prompt.split(THINK_OPEN)[-1].split(THINK_CLOSE)[0] if THINK_OPEN in prompt else ""
|
| 411 |
+
|
| 412 |
+
# -----------------------------------------------------------------------------
|
| 413 |
+
# CLI -------------------------------------------------------------------------
|
| 414 |
+
# if __name__ == "__main__":
|
| 415 |
+
# # import argparse, json
|
| 416 |
+
# # parser = argparse.ArgumentParser()
|
| 417 |
+
# # parser.add_argument("question"); parser.add_argument("--dataset", required=True, choices=sorted(ALLOWED_DATASETS)); parser.add_argument("--model-url", required
|
inference/zerosearch.py
ADDED
|
@@ -0,0 +1,249 @@
|
| 1 |
+
# zero_search_inference.py
|
| 2 |
+
"""End‑to‑end inference loop that emulates the ZeroSearch prompting style.
|
| 3 |
+
|
| 4 |
+
The policy model ("thinker") must:
|
| 5 |
+
• reason inside <think> … </think>
|
| 6 |
+
• place a query inside <search> … </search> whenever it needs external knowledge
|
| 7 |
+
• return the final short answer inside <answer> … </answer>
|
| 8 |
+
|
| 9 |
+
The wrapper intercepts each <search> request, fulfils it with either:
|
| 10 |
+
|
| 11 |
+
(a) a **simulated search engine** (another small LLM fine‑tuned as ZeroSearch
|
| 12 |
+
retriever), used when `engine="sim"`; or
|
| 13 |
+
(b) a real search backend (e.g. Serper.dev, Bing) if `engine="real"`.
|
| 14 |
+
|
| 15 |
+
It then injects results between <information> … </information> and hands control
|
| 16 |
+
back to the policy model. The loop repeats until </answer> is produced or a
|
| 17 |
+
maximum number of retrieval rounds is reached.
|
| 18 |
+
|
| 19 |
+
The goal is to mirror the ergonomics of the user’s existing `ReCall` class so
|
| 20 |
+
that outer orchestration code can drop this in with minimal friction.
|
| 21 |
+
"""
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import json
|
| 25 |
+
import os
|
| 26 |
+
import re
|
| 27 |
+
import time
|
| 28 |
+
from dataclasses import dataclass
|
| 29 |
+
from typing import List, Optional
|
| 30 |
+
|
| 31 |
+
import requests
|
| 32 |
+
from openai import OpenAI
|
| 33 |
+
|
| 34 |
+
__all__ = ["ZeroSearchInference", "ZeroSearchConfig"]
|
| 35 |
+
|
| 36 |
+
TOKENIZER_DIR = "/home/fractal_admin/shreyas/models/Qwen3-4B"
|
| 37 |
+
|
| 38 |
+
# ───────────────────────── tokenizer ────────────────────────────────────────
|
| 39 |
+
try:
|
| 40 |
+
from transformers import AutoTokenizer
|
| 41 |
+
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, trust_remote_code=True)
|
| 42 |
+
except Exception as e:
|
| 43 |
+
import sys
|
| 44 |
+
sys.exit(f"❌ Could not load Qwen3 tokenizer: {e}")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# Utility: retry decorator ---------------------------------------------------
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
def retry(max_attempts: int = 4, sleep: int = 1, fallback=None):
|
| 52 |
+
def decorator(func):
|
| 53 |
+
def wrapper(*args, **kwargs):
|
| 54 |
+
for i in range(max_attempts):
|
| 55 |
+
try:
|
| 56 |
+
return func(*args, **kwargs)
|
| 57 |
+
except Exception as exc:
|
| 58 |
+
#print(f"[retry] {func.__name__}: attempt {i + 1}/{max_attempts} failed – {exc}")
|
| 59 |
+
if i == max_attempts - 1:
|
| 60 |
+
return fallback
|
| 61 |
+
time.sleep(sleep)
|
| 62 |
+
return wrapper
|
| 63 |
+
return decorator
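# Illustrative usage (not part of the original file): wrap a flaky call so it is
# retried up to max_attempts times, sleeping between attempts, and returns the
# fallback value if every attempt fails.
#
#   @retry(max_attempts=3, sleep=2, fallback={})
#   def fetch_json(url: str) -> dict:
#       return requests.get(url, timeout=10).json()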
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
# Configuration dataclass ----------------------------------------------------
|
| 68 |
+
# ---------------------------------------------------------------------------
|
| 69 |
+
|
| 70 |
+
@dataclass
|
| 71 |
+
class ZeroSearchConfig:
|
| 72 |
+
# thinker LLM endpoint
|
| 73 |
+
thinker_url: str = "http://0.0.0.0:1214"
|
| 74 |
+
thinker_temperature: float = 0.7
|
| 75 |
+
thinker_max_tokens: int = 40960
|
| 76 |
+
|
| 77 |
+
# retrieval engine mode: "sim" or "real"
|
| 78 |
+
engine: str = "real" # simulated search (LLM) by default
|
| 79 |
+
|
| 80 |
+
# simulated search model (only used if engine == "sim")
|
| 81 |
+
retriever_model: str = "gpt-4o-mini"
|
| 82 |
+
retriever_top_k: int = 5
|
| 83 |
+
|
| 84 |
+
# real search backend (engine == "real")
|
| 85 |
+
serper_api_key: Optional[str] = "7bfe51ead1a1766b656c1355b292d1d29c15c114"
|
| 86 |
+
serper_url: str = "https://google.serper.dev/search"
|
| 87 |
+
serper_top_k: int = 5
|
| 88 |
+
|
| 89 |
+
# Loop limits
|
| 90 |
+
max_rounds: int = 16
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
# Main wrapper ---------------------------------------------------------------
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
|
| 97 |
+
class ZeroSearchInference:
|
| 98 |
+
SEARCH_OPEN = "<search>"
|
| 99 |
+
SEARCH_CLOSE = "</search>"
|
| 100 |
+
INFO_OPEN = "<information>"
|
| 101 |
+
INFO_CLOSE = "</information>"
|
| 102 |
+
|
| 103 |
+
ANSWER_CLOSE = "</answer>"
|
| 104 |
+
THINK_OPEN = "<think>"
|
| 105 |
+
THINK_CLOSE = "</think>"
|
| 106 |
+
|
| 107 |
+
STOP_TOKENS = ["<|im_end|>", "<|endoftext|>", "</search>", " </search>", "</search>\n", " </search>\n", "</search>\n\n", " </search>\n\n"]#, "</think>", "</think>\n", " </think>\n", "</think>\n\n", " </think>\n\n"]
|
| 108 |
+
|
| 109 |
+
def __init__(self, cfg: ZeroSearchConfig):
|
| 110 |
+
self.cfg = cfg
# OpenAI client for the simulated-retriever path; assumed to read OPENAI_API_KEY
# from the environment. Only needed when engine == "sim".
self.openai = OpenAI() if cfg.engine == "sim" else None
|
| 111 |
+
# ------------------------------------------------------------------
|
| 112 |
+
# Public driver -----------------------------------------------------
|
| 113 |
+
# ------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
def run(self, user_question: str) -> tuple[str, List[str]]:
|
| 116 |
+
tool_calls = []
|
| 117 |
+
prompt = self._build_initial_prompt(user_question)
|
| 118 |
+
for round_idx in range(self.cfg.max_rounds):
|
| 119 |
+
generated = self._call_thinker(prompt)
|
| 120 |
+
#print("-"*100)
|
| 121 |
+
#print(f"Round: {round_idx}")
|
| 122 |
+
#print(generated)
|
| 123 |
+
prompt += generated
|
| 124 |
+
|
| 125 |
+
if self.ANSWER_CLOSE in generated:
|
| 126 |
+
#print(f"[ZeroSearch] Done in {round_idx + 1} rounds")
|
| 127 |
+
break
|
| 128 |
+
|
| 129 |
+
query = self._extract_query(generated)
|
| 130 |
+
|
| 131 |
+
if not query:
|
| 132 |
+
#print("[ZeroSearch] Model failed to emit <search>; aborting")
|
| 133 |
+
break
|
| 134 |
+
tool_calls.append(query)
|
| 135 |
+
info_block = self._retrieve_and_format(query)
|
| 136 |
+
#print(f"retrived docs: \n{info_block}")
|
| 137 |
+
#print("-"*100)
|
| 138 |
+
prompt += info_block + self.THINK_OPEN # next turn
|
| 139 |
+
|
| 140 |
+
else: # exceeded rounds
|
| 141 |
+
prompt += "<answer>I don't know.</answer><|im_end|>"
|
| 142 |
+
return prompt, tool_calls
|
| 143 |
+
|
| 144 |
+
# ------------------------------------------------------------------
|
| 145 |
+
# Prompt construction helpers --------------------------------------
|
| 146 |
+
# ------------------------------------------------------------------
|
| 147 |
+
|
| 148 |
+
def _build_initial_prompt(self, question: str) -> str:
|
| 149 |
+
user_msg = f"""Answer the given question. \
|
| 150 |
+
You must conduct reasoning inside <think> and </think> first every time you get new information. \
|
| 151 |
+
After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. \
|
| 152 |
+
You can search as many times as you want. \
|
| 153 |
+
If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: {question}\n"""
|
| 154 |
+
return f"<|im_start|>user\n{user_msg}<|im_end|>\n<|im_start|>assistant\n{self.THINK_OPEN}"
|
| 155 |
+
|
| 156 |
+
# ------------------------------------------------------------------
|
| 157 |
+
# Thinker model call ------------------------------------------------
|
| 158 |
+
# ------------------------------------------------------------------
|
| 159 |
+
|
| 160 |
+
@retry(fallback="")
|
| 161 |
+
def _call_thinker(self, prompt: str) -> str:
|
| 162 |
+
prompt_tokens = tokenizer(prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
|
| 163 |
+
max_tokens_left = self.cfg.thinker_max_tokens - len(prompt_tokens) - 100
|
| 164 |
+
resp = requests.post(
|
| 165 |
+
f"{self.cfg.thinker_url}/generate",
|
| 166 |
+
json={
|
| 167 |
+
"text": prompt,
|
| 168 |
+
"sampling_params": {
|
| 169 |
+
"temperature": self.cfg.thinker_temperature,
|
| 170 |
+
"max_new_tokens": max_tokens_left,
|
| 171 |
+
"stop": self.STOP_TOKENS,
|
| 172 |
+
},
|
| 173 |
+
|
| 174 |
+
},
|
| 175 |
+
timeout=60,
|
| 176 |
+
).json()
|
| 177 |
+
generated = resp["text"] # what you have now
|
| 178 |
+
matched = resp["meta_info"]["finish_reason"].get("matched")
|
| 179 |
+
reason = resp["meta_info"]["finish_reason"].get("type")
|
| 180 |
+
# ⇢ append the tag back only if it was removed
|
| 181 |
+
if reason == "stop" and matched in self.STOP_TOKENS:
|
| 182 |
+
if not generated.endswith(matched):
|
| 183 |
+
generated += matched
|
| 184 |
+
if reason == "stop" and matched == 151645:
|
| 185 |
+
if not generated.endswith("<|im_end|>"):
|
| 186 |
+
generated += "<|im_end|>"
|
| 187 |
+
return generated
|
| 188 |
+
|
| 189 |
+
# ------------------------------------------------------------------
|
| 190 |
+
# Query extraction --------------------------------------------------
|
| 191 |
+
# ------------------------------------------------------------------
|
| 192 |
+
|
| 193 |
+
def _extract_query(self, gen_text: str) -> Optional[str]:
|
| 194 |
+
if self.SEARCH_OPEN not in gen_text or self.SEARCH_CLOSE not in gen_text:
|
| 195 |
+
return None
|
| 196 |
+
query = gen_text.split(self.SEARCH_OPEN)[-1].split(self.SEARCH_CLOSE)[0].strip()
|
| 197 |
+
return query or None
|
| 198 |
+
|
| 199 |
+
# ------------------------------------------------------------------
|
| 200 |
+
# Retrieval path ----------------------------------------------------
|
| 201 |
+
# ------------------------------------------------------------------
|
| 202 |
+
|
| 203 |
+
def _retrieve_and_format(self, query: str) -> str:
|
| 204 |
+
if self.cfg.engine == "real":
|
| 205 |
+
docs = self._real_search(query)
|
| 206 |
+
#print("DOCS")
|
| 207 |
+
#print(docs)
|
| 208 |
+
else:
|
| 209 |
+
docs = self._simulated_search(query)
|
| 210 |
+
return f"{self.INFO_OPEN}\n{docs}\n{self.INFO_CLOSE}\n\n"
|
| 211 |
+
|
| 212 |
+
# --- simulated search with LLM ------------------------------------
|
| 213 |
+
|
| 214 |
+
@retry(fallback="No information available")
|
| 215 |
+
def _simulated_search(self, query: str) -> str:
|
| 216 |
+
messages = [
|
| 217 |
+
{
|
| 218 |
+
"role": "user",
|
| 219 |
+
"content": (
|
| 220 |
+
"You are a search engine. Return up to "
|
| 221 |
+
f"{self.cfg.retriever_top_k} short documents (titles + snippets) "
|
| 222 |
+
"most relevant to the query, each on a new line.\n\n"
|
| 223 |
+
f"Query: {query}"
|
| 224 |
+
),
|
| 225 |
+
}
|
| 226 |
+
]
|
| 227 |
+
resp = self.openai.chat.completions.create(
|
| 228 |
+
model=self.cfg.retriever_model,
|
| 229 |
+
messages=messages,
|
| 230 |
+
max_tokens=256,
|
| 231 |
+
)
|
| 232 |
+
return resp.choices[0].message.content.strip()
|
| 233 |
+
|
| 234 |
+
# --- real web search via Serper ----------------------------------
|
| 235 |
+
|
| 236 |
+
@retry(fallback="No information available")
|
| 237 |
+
def _real_search(self, query: str) -> str:
|
| 238 |
+
if not self.cfg.serper_api_key:
|
| 239 |
+
raise ValueError("serper_api_key must be set for real search mode")
|
| 240 |
+
headers = {"X-API-KEY": self.cfg.serper_api_key, "Content-Type": "application/json"}
|
| 241 |
+
payload = {"q": query, "num": self.cfg.serper_top_k}
|
| 242 |
+
resp = requests.post(self.cfg.serper_url, json=payload, headers=headers, timeout=20)
|
| 243 |
+
resp.raise_for_status()
|
| 244 |
+
data = resp.json().get("organic", [])[: self.cfg.serper_top_k]
|
| 245 |
+
lines = []
|
| 246 |
+
for i, item in enumerate(data, 1):
|
| 247 |
+
snippet = f"Title: {item['title']}, \nSnippet{item['snippet']}"
|
| 248 |
+
lines.append(f"Doc {i}: {snippet}")
|
| 249 |
+
return "\n".join(lines) or "No information available"
|