liuhua
liuhua
commited on
Commit
·
678763e
1
Parent(s):
5000eb5
Fix: renrank_model and pdf_parser bugs | Update: session API (#2601)
Browse files### What problem does this PR solve?
Fix: renrank_model and pdf_parser bugs | Update: session API
#2575
#2559
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
---------
Co-authored-by: liuhua <[email protected]>
api/apps/sdk/session.py
CHANGED
@@ -87,9 +87,9 @@ def completion(tenant_id):
|
|
87 |
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
|
88 |
# {"role": "user", "content": "上海有吗?"}
|
89 |
# ]}
|
90 |
-
if "
|
91 |
-
return get_data_error_result(retmsg="
|
92 |
-
conv = ConversationService.query(id=req["
|
93 |
if not conv:
|
94 |
return get_data_error_result(retmsg="Session does not exist")
|
95 |
conv = conv[0]
|
@@ -108,7 +108,7 @@ def completion(tenant_id):
|
|
108 |
msg.append(m)
|
109 |
message_id = msg[-1].get("id")
|
110 |
e, dia = DialogService.get_by_id(conv.dialog_id)
|
111 |
-
del req["
|
112 |
|
113 |
if not conv.reference:
|
114 |
conv.reference = []
|
@@ -168,6 +168,9 @@ def get(tenant_id):
|
|
168 |
return get_data_error_result(retmsg="Session does not exist")
|
169 |
if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
170 |
return get_data_error_result(retmsg="You do not own the session")
|
|
|
|
|
|
|
171 |
conv = conv[0].to_dict()
|
172 |
conv['messages'] = conv.pop("message")
|
173 |
conv["assistant_id"] = conv.pop("dialog_id")
|
@@ -207,7 +210,7 @@ def list(tenant_id):
|
|
207 |
assistant_id = request.args["assistant_id"]
|
208 |
if not DialogService.query(tenant_id=tenant_id, id=assistant_id, status=StatusEnum.VALID.value):
|
209 |
return get_json_result(
|
210 |
-
data=False, retmsg=f'
|
211 |
retcode=RetCode.OPERATING_ERROR)
|
212 |
convs = ConversationService.query(
|
213 |
dialog_id=assistant_id,
|
|
|
87 |
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
|
88 |
# {"role": "user", "content": "上海有吗?"}
|
89 |
# ]}
|
90 |
+
if "session_id" not in req:
|
91 |
+
return get_data_error_result(retmsg="session_id is required")
|
92 |
+
conv = ConversationService.query(id=req["session_id"])
|
93 |
if not conv:
|
94 |
return get_data_error_result(retmsg="Session does not exist")
|
95 |
conv = conv[0]
|
|
|
108 |
msg.append(m)
|
109 |
message_id = msg[-1].get("id")
|
110 |
e, dia = DialogService.get_by_id(conv.dialog_id)
|
111 |
+
del req["session_id"]
|
112 |
|
113 |
if not conv.reference:
|
114 |
conv.reference = []
|
|
|
168 |
return get_data_error_result(retmsg="Session does not exist")
|
169 |
if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
170 |
return get_data_error_result(retmsg="You do not own the session")
|
171 |
+
if "assistant_id" in req:
|
172 |
+
if req["assistant_id"] != conv[0].dialog_id:
|
173 |
+
return get_data_error_result(retmsg="The session doesn't belong to the assistant")
|
174 |
conv = conv[0].to_dict()
|
175 |
conv['messages'] = conv.pop("message")
|
176 |
conv["assistant_id"] = conv.pop("dialog_id")
|
|
|
210 |
assistant_id = request.args["assistant_id"]
|
211 |
if not DialogService.query(tenant_id=tenant_id, id=assistant_id, status=StatusEnum.VALID.value):
|
212 |
return get_json_result(
|
213 |
+
data=False, retmsg=f"You don't own the assistant.",
|
214 |
retcode=RetCode.OPERATING_ERROR)
|
215 |
convs = ConversationService.query(
|
216 |
dialog_id=assistant_id,
|
deepdoc/parser/pdf_parser.py
CHANGED
@@ -488,7 +488,7 @@ class RAGFlowPdfParser:
|
|
488 |
i += 1
|
489 |
continue
|
490 |
|
491 |
-
if not down["text"].strip():
|
492 |
i += 1
|
493 |
continue
|
494 |
|
|
|
488 |
i += 1
|
489 |
continue
|
490 |
|
491 |
+
if not down["text"].strip() or not up["text"].strip():
|
492 |
i += 1
|
493 |
continue
|
494 |
|
rag/llm/rerank_model.py
CHANGED
@@ -26,9 +26,11 @@ from api.utils.file_utils import get_home_cache_dir
|
|
26 |
from rag.utils import num_tokens_from_string, truncate
|
27 |
import json
|
28 |
|
|
|
29 |
def sigmoid(x):
|
30 |
return 1 / (1 + np.exp(-x))
|
31 |
|
|
|
32 |
class Base(ABC):
|
33 |
def __init__(self, key, model_name):
|
34 |
pass
|
@@ -59,16 +61,19 @@ class DefaultRerank(Base):
|
|
59 |
with DefaultRerank._model_lock:
|
60 |
if not DefaultRerank._model:
|
61 |
try:
|
62 |
-
DefaultRerank._model = FlagReranker(
|
|
|
|
|
63 |
except Exception as e:
|
64 |
-
model_dir = snapshot_download(repo_id=
|
65 |
-
local_dir=os.path.join(get_home_cache_dir(),
|
|
|
66 |
local_dir_use_symlinks=False)
|
67 |
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
|
68 |
self._model = DefaultRerank._model
|
69 |
|
70 |
def similarity(self, query: str, texts: list):
|
71 |
-
pairs = [(query,truncate(t, 2048)) for t in texts]
|
72 |
token_count = 0
|
73 |
for _, t in pairs:
|
74 |
token_count += num_tokens_from_string(t)
|
@@ -77,8 +82,10 @@ class DefaultRerank(Base):
|
|
77 |
for i in range(0, len(pairs), batch_size):
|
78 |
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
|
79 |
scores = sigmoid(np.array(scores)).tolist()
|
80 |
-
if isinstance(scores, float):
|
81 |
-
|
|
|
|
|
82 |
return np.array(res), token_count
|
83 |
|
84 |
|
@@ -101,7 +108,10 @@ class JinaRerank(Base):
|
|
101 |
"top_n": len(texts)
|
102 |
}
|
103 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
104 |
-
|
|
|
|
|
|
|
105 |
|
106 |
|
107 |
class YoudaoRerank(DefaultRerank):
|
@@ -124,7 +134,7 @@ class YoudaoRerank(DefaultRerank):
|
|
124 |
"maidalun1020", "InfiniFlow"))
|
125 |
|
126 |
self._model = YoudaoRerank._model
|
127 |
-
|
128 |
def similarity(self, query: str, texts: list):
|
129 |
pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
|
130 |
token_count = 0
|
@@ -135,8 +145,10 @@ class YoudaoRerank(DefaultRerank):
|
|
135 |
for i in range(0, len(pairs), batch_size):
|
136 |
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
|
137 |
scores = sigmoid(np.array(scores)).tolist()
|
138 |
-
if isinstance(scores, float):
|
139 |
-
|
|
|
|
|
140 |
return np.array(res), token_count
|
141 |
|
142 |
|
@@ -162,7 +174,10 @@ class XInferenceRerank(Base):
|
|
162 |
"documents": texts
|
163 |
}
|
164 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
165 |
-
|
|
|
|
|
|
|
166 |
|
167 |
|
168 |
class LocalAIRerank(Base):
|
@@ -175,7 +190,7 @@ class LocalAIRerank(Base):
|
|
175 |
|
176 |
class NvidiaRerank(Base):
|
177 |
def __init__(
|
178 |
-
|
179 |
):
|
180 |
if not base_url:
|
181 |
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
@@ -208,9 +223,10 @@ class NvidiaRerank(Base):
|
|
208 |
"top_n": len(texts),
|
209 |
}
|
210 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
211 |
-
rank = np.
|
212 |
-
|
213 |
-
|
|
|
214 |
|
215 |
|
216 |
class LmStudioRerank(Base):
|
@@ -247,9 +263,10 @@ class CoHereRerank(Base):
|
|
247 |
top_n=len(texts),
|
248 |
return_documents=False,
|
249 |
)
|
250 |
-
rank = np.
|
251 |
-
|
252 |
-
|
|
|
253 |
|
254 |
|
255 |
class TogetherAIRerank(Base):
|
@@ -262,7 +279,7 @@ class TogetherAIRerank(Base):
|
|
262 |
|
263 |
class SILICONFLOWRerank(Base):
|
264 |
def __init__(
|
265 |
-
|
266 |
):
|
267 |
if not base_url:
|
268 |
base_url = "https://api.siliconflow.cn/v1/rerank"
|
@@ -287,10 +304,11 @@ class SILICONFLOWRerank(Base):
|
|
287 |
response = requests.post(
|
288 |
self.base_url, json=payload, headers=self.headers
|
289 |
).json()
|
290 |
-
rank = np.
|
291 |
-
|
|
|
292 |
return (
|
293 |
-
rank
|
294 |
response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
|
295 |
)
|
296 |
|
@@ -312,9 +330,10 @@ class BaiduYiyanRerank(Base):
|
|
312 |
documents=texts,
|
313 |
top_n=len(texts),
|
314 |
).body
|
315 |
-
rank = np.
|
316 |
-
|
317 |
-
|
|
|
318 |
|
319 |
|
320 |
class VoyageRerank(Base):
|
@@ -328,6 +347,7 @@ class VoyageRerank(Base):
|
|
328 |
res = self.client.rerank(
|
329 |
query=query, documents=texts, model=self.model_name, top_k=len(texts)
|
330 |
)
|
331 |
-
rank = np.
|
332 |
-
|
333 |
-
|
|
|
|
26 |
from rag.utils import num_tokens_from_string, truncate
|
27 |
import json
|
28 |
|
29 |
+
|
30 |
def sigmoid(x):
|
31 |
return 1 / (1 + np.exp(-x))
|
32 |
|
33 |
+
|
34 |
class Base(ABC):
|
35 |
def __init__(self, key, model_name):
|
36 |
pass
|
|
|
61 |
with DefaultRerank._model_lock:
|
62 |
if not DefaultRerank._model:
|
63 |
try:
|
64 |
+
DefaultRerank._model = FlagReranker(
|
65 |
+
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
66 |
+
use_fp16=torch.cuda.is_available())
|
67 |
except Exception as e:
|
68 |
+
model_dir = snapshot_download(repo_id=model_name,
|
69 |
+
local_dir=os.path.join(get_home_cache_dir(),
|
70 |
+
re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
71 |
local_dir_use_symlinks=False)
|
72 |
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
|
73 |
self._model = DefaultRerank._model
|
74 |
|
75 |
def similarity(self, query: str, texts: list):
|
76 |
+
pairs = [(query, truncate(t, 2048)) for t in texts]
|
77 |
token_count = 0
|
78 |
for _, t in pairs:
|
79 |
token_count += num_tokens_from_string(t)
|
|
|
82 |
for i in range(0, len(pairs), batch_size):
|
83 |
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
|
84 |
scores = sigmoid(np.array(scores)).tolist()
|
85 |
+
if isinstance(scores, float):
|
86 |
+
res.append(scores)
|
87 |
+
else:
|
88 |
+
res.extend(scores)
|
89 |
return np.array(res), token_count
|
90 |
|
91 |
|
|
|
108 |
"top_n": len(texts)
|
109 |
}
|
110 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
111 |
+
rank = np.zeros(len(texts), dtype=float)
|
112 |
+
for d in res["results"]:
|
113 |
+
rank[d["index"]] = d["relevance_score"]
|
114 |
+
return rank, res["usage"]["total_tokens"]
|
115 |
|
116 |
|
117 |
class YoudaoRerank(DefaultRerank):
|
|
|
134 |
"maidalun1020", "InfiniFlow"))
|
135 |
|
136 |
self._model = YoudaoRerank._model
|
137 |
+
|
138 |
def similarity(self, query: str, texts: list):
|
139 |
pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
|
140 |
token_count = 0
|
|
|
145 |
for i in range(0, len(pairs), batch_size):
|
146 |
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
|
147 |
scores = sigmoid(np.array(scores)).tolist()
|
148 |
+
if isinstance(scores, float):
|
149 |
+
res.append(scores)
|
150 |
+
else:
|
151 |
+
res.extend(scores)
|
152 |
return np.array(res), token_count
|
153 |
|
154 |
|
|
|
174 |
"documents": texts
|
175 |
}
|
176 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
177 |
+
rank = np.zeros(len(texts), dtype=float)
|
178 |
+
for d in res["results"]:
|
179 |
+
rank[d["index"]] = d["relevance_score"]
|
180 |
+
return rank, res["meta"]["tokens"]["input_tokens"] + res["meta"]["tokens"]["output_tokens"]
|
181 |
|
182 |
|
183 |
class LocalAIRerank(Base):
|
|
|
190 |
|
191 |
class NvidiaRerank(Base):
|
192 |
def __init__(
|
193 |
+
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
194 |
):
|
195 |
if not base_url:
|
196 |
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
|
|
223 |
"top_n": len(texts),
|
224 |
}
|
225 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
226 |
+
rank = np.zeros(len(texts), dtype=float)
|
227 |
+
for d in res["rankings"]:
|
228 |
+
rank[d["index"]] = d["logit"]
|
229 |
+
return rank, token_count
|
230 |
|
231 |
|
232 |
class LmStudioRerank(Base):
|
|
|
263 |
top_n=len(texts),
|
264 |
return_documents=False,
|
265 |
)
|
266 |
+
rank = np.zeros(len(texts), dtype=float)
|
267 |
+
for d in res.results:
|
268 |
+
rank[d.index] = d.relevance_score
|
269 |
+
return rank, token_count
|
270 |
|
271 |
|
272 |
class TogetherAIRerank(Base):
|
|
|
279 |
|
280 |
class SILICONFLOWRerank(Base):
|
281 |
def __init__(
|
282 |
+
self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
|
283 |
):
|
284 |
if not base_url:
|
285 |
base_url = "https://api.siliconflow.cn/v1/rerank"
|
|
|
304 |
response = requests.post(
|
305 |
self.base_url, json=payload, headers=self.headers
|
306 |
).json()
|
307 |
+
rank = np.zeros(len(texts), dtype=float)
|
308 |
+
for d in response["results"]:
|
309 |
+
rank[d["index"]] = d["relevance_score"]
|
310 |
return (
|
311 |
+
rank,
|
312 |
response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
|
313 |
)
|
314 |
|
|
|
330 |
documents=texts,
|
331 |
top_n=len(texts),
|
332 |
).body
|
333 |
+
rank = np.zeros(len(texts), dtype=float)
|
334 |
+
for d in res["results"]:
|
335 |
+
rank[d["index"]] = d["relevance_score"]
|
336 |
+
return rank, res["usage"]["total_tokens"]
|
337 |
|
338 |
|
339 |
class VoyageRerank(Base):
|
|
|
347 |
res = self.client.rerank(
|
348 |
query=query, documents=texts, model=self.model_name, top_k=len(texts)
|
349 |
)
|
350 |
+
rank = np.zeros(len(texts), dtype=float)
|
351 |
+
for r in res.results:
|
352 |
+
rank[r.index] = r.relevance_score
|
353 |
+
return rank, res.total_tokens
|
sdk/python/ragflow/modules/assistant.py
CHANGED
@@ -76,7 +76,7 @@ class Assistant(Base):
|
|
76 |
raise Exception(res["retmsg"])
|
77 |
|
78 |
def get_session(self, id) -> Session:
|
79 |
-
res = self.get("/session/get", {"id": id})
|
80 |
res = res.json()
|
81 |
if res.get("retmsg") == "success":
|
82 |
return Session(self.rag, res["data"])
|
|
|
76 |
raise Exception(res["retmsg"])
|
77 |
|
78 |
def get_session(self, id) -> Session:
|
79 |
+
res = self.get("/session/get", {"id": id,"assistant_id":self.id})
|
80 |
res = res.json()
|
81 |
if res.get("retmsg") == "success":
|
82 |
return Session(self.rag, res["data"])
|
sdk/python/ragflow/modules/session.py
CHANGED
@@ -16,9 +16,12 @@ class Session(Base):
|
|
16 |
if "reference" in message:
|
17 |
message.pop("reference")
|
18 |
res = self.post("/session/completion",
|
19 |
-
{"
|
20 |
for line in res.iter_lines():
|
21 |
line = line.decode("utf-8")
|
|
|
|
|
|
|
22 |
if line.startswith("data:"):
|
23 |
json_data = json.loads(line[5:])
|
24 |
if json_data["data"] != True:
|
@@ -69,6 +72,7 @@ class Message(Base):
|
|
69 |
self.reference = None
|
70 |
self.role = "assistant"
|
71 |
self.prompt = None
|
|
|
72 |
super().__init__(rag, res_dict)
|
73 |
|
74 |
|
@@ -76,10 +80,10 @@ class Chunk(Base):
|
|
76 |
def __init__(self, rag, res_dict):
|
77 |
self.id = None
|
78 |
self.content = None
|
79 |
-
self.document_id =
|
80 |
-
self.document_name =
|
81 |
-
self.knowledgebase_id =
|
82 |
-
self.image_id =
|
83 |
self.similarity = None
|
84 |
self.vector_similarity = None
|
85 |
self.term_similarity = None
|
|
|
16 |
if "reference" in message:
|
17 |
message.pop("reference")
|
18 |
res = self.post("/session/completion",
|
19 |
+
{"session_id": self.id, "question": question, "stream": True}, stream=stream)
|
20 |
for line in res.iter_lines():
|
21 |
line = line.decode("utf-8")
|
22 |
+
if line.startswith("{"):
|
23 |
+
json_data = json.loads(line)
|
24 |
+
raise Exception(json_data["retmsg"])
|
25 |
if line.startswith("data:"):
|
26 |
json_data = json.loads(line[5:])
|
27 |
if json_data["data"] != True:
|
|
|
72 |
self.reference = None
|
73 |
self.role = "assistant"
|
74 |
self.prompt = None
|
75 |
+
self.id = None
|
76 |
super().__init__(rag, res_dict)
|
77 |
|
78 |
|
|
|
80 |
def __init__(self, rag, res_dict):
|
81 |
self.id = None
|
82 |
self.content = None
|
83 |
+
self.document_id = ""
|
84 |
+
self.document_name = ""
|
85 |
+
self.knowledgebase_id = ""
|
86 |
+
self.image_id = ""
|
87 |
self.similarity = None
|
88 |
self.vector_similarity = None
|
89 |
self.term_similarity = None
|
sdk/python/test/t_session.py
CHANGED
@@ -19,7 +19,7 @@ class TestSession:
|
|
19 |
question = "What is AI"
|
20 |
for ans in session.chat(question, stream=True):
|
21 |
pass
|
22 |
-
assert ans.content
|
23 |
|
24 |
def test_delete_session_with_success(self):
|
25 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
|
|
19 |
question = "What is AI"
|
20 |
for ans in session.chat(question, stream=True):
|
21 |
pass
|
22 |
+
assert not ans.content.startswith("**ERROR**"), "Please check this error."
|
23 |
|
24 |
def test_delete_session_with_success(self):
|
25 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|