liuhua liuhua commited on
Commit
678763e
·
1 Parent(s): 5000eb5

Fix: rerank_model and pdf_parser bugs | Update: session API (#2601)

Browse files

### What problem does this PR solve?

Fix: rerank_model and pdf_parser bugs | Update: session API
#2575
#2559
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

---------

Co-authored-by: liuhua <[email protected]>

api/apps/sdk/session.py CHANGED
@@ -87,9 +87,9 @@ def completion(tenant_id):
87
  # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
88
  # {"role": "user", "content": "上海有吗?"}
89
  # ]}
90
- if "id" not in req:
91
- return get_data_error_result(retmsg="id is required")
92
- conv = ConversationService.query(id=req["id"])
93
  if not conv:
94
  return get_data_error_result(retmsg="Session does not exist")
95
  conv = conv[0]
@@ -108,7 +108,7 @@ def completion(tenant_id):
108
  msg.append(m)
109
  message_id = msg[-1].get("id")
110
  e, dia = DialogService.get_by_id(conv.dialog_id)
111
- del req["id"]
112
 
113
  if not conv.reference:
114
  conv.reference = []
@@ -168,6 +168,9 @@ def get(tenant_id):
168
  return get_data_error_result(retmsg="Session does not exist")
169
  if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
170
  return get_data_error_result(retmsg="You do not own the session")
 
 
 
171
  conv = conv[0].to_dict()
172
  conv['messages'] = conv.pop("message")
173
  conv["assistant_id"] = conv.pop("dialog_id")
@@ -207,7 +210,7 @@ def list(tenant_id):
207
  assistant_id = request.args["assistant_id"]
208
  if not DialogService.query(tenant_id=tenant_id, id=assistant_id, status=StatusEnum.VALID.value):
209
  return get_json_result(
210
- data=False, retmsg=f'Only owner of the assistant is authorized for this operation.',
211
  retcode=RetCode.OPERATING_ERROR)
212
  convs = ConversationService.query(
213
  dialog_id=assistant_id,
 
87
  # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
88
  # {"role": "user", "content": "上海有吗?"}
89
  # ]}
90
+ if "session_id" not in req:
91
+ return get_data_error_result(retmsg="session_id is required")
92
+ conv = ConversationService.query(id=req["session_id"])
93
  if not conv:
94
  return get_data_error_result(retmsg="Session does not exist")
95
  conv = conv[0]
 
108
  msg.append(m)
109
  message_id = msg[-1].get("id")
110
  e, dia = DialogService.get_by_id(conv.dialog_id)
111
+ del req["session_id"]
112
 
113
  if not conv.reference:
114
  conv.reference = []
 
168
  return get_data_error_result(retmsg="Session does not exist")
169
  if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
170
  return get_data_error_result(retmsg="You do not own the session")
171
+ if "assistant_id" in req:
172
+ if req["assistant_id"] != conv[0].dialog_id:
173
+ return get_data_error_result(retmsg="The session doesn't belong to the assistant")
174
  conv = conv[0].to_dict()
175
  conv['messages'] = conv.pop("message")
176
  conv["assistant_id"] = conv.pop("dialog_id")
 
210
  assistant_id = request.args["assistant_id"]
211
  if not DialogService.query(tenant_id=tenant_id, id=assistant_id, status=StatusEnum.VALID.value):
212
  return get_json_result(
213
+ data=False, retmsg=f"You don't own the assistant.",
214
  retcode=RetCode.OPERATING_ERROR)
215
  convs = ConversationService.query(
216
  dialog_id=assistant_id,
deepdoc/parser/pdf_parser.py CHANGED
@@ -488,7 +488,7 @@ class RAGFlowPdfParser:
488
  i += 1
489
  continue
490
 
491
- if not down["text"].strip():
492
  i += 1
493
  continue
494
 
 
488
  i += 1
489
  continue
490
 
491
+ if not down["text"].strip() or not up["text"].strip():
492
  i += 1
493
  continue
494
 
rag/llm/rerank_model.py CHANGED
@@ -26,9 +26,11 @@ from api.utils.file_utils import get_home_cache_dir
26
  from rag.utils import num_tokens_from_string, truncate
27
  import json
28
 
 
29
  def sigmoid(x):
30
  return 1 / (1 + np.exp(-x))
31
 
 
32
  class Base(ABC):
33
  def __init__(self, key, model_name):
34
  pass
@@ -59,16 +61,19 @@ class DefaultRerank(Base):
59
  with DefaultRerank._model_lock:
60
  if not DefaultRerank._model:
61
  try:
62
- DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), use_fp16=torch.cuda.is_available())
 
 
63
  except Exception as e:
64
- model_dir = snapshot_download(repo_id= model_name,
65
- local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
 
66
  local_dir_use_symlinks=False)
67
  DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
68
  self._model = DefaultRerank._model
69
 
70
  def similarity(self, query: str, texts: list):
71
- pairs = [(query,truncate(t, 2048)) for t in texts]
72
  token_count = 0
73
  for _, t in pairs:
74
  token_count += num_tokens_from_string(t)
@@ -77,8 +82,10 @@ class DefaultRerank(Base):
77
  for i in range(0, len(pairs), batch_size):
78
  scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
79
  scores = sigmoid(np.array(scores)).tolist()
80
- if isinstance(scores, float): res.append(scores)
81
- else: res.extend(scores)
 
 
82
  return np.array(res), token_count
83
 
84
 
@@ -101,7 +108,10 @@ class JinaRerank(Base):
101
  "top_n": len(texts)
102
  }
103
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
104
- return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
 
 
 
105
 
106
 
107
  class YoudaoRerank(DefaultRerank):
@@ -124,7 +134,7 @@ class YoudaoRerank(DefaultRerank):
124
  "maidalun1020", "InfiniFlow"))
125
 
126
  self._model = YoudaoRerank._model
127
-
128
  def similarity(self, query: str, texts: list):
129
  pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
130
  token_count = 0
@@ -135,8 +145,10 @@ class YoudaoRerank(DefaultRerank):
135
  for i in range(0, len(pairs), batch_size):
136
  scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
137
  scores = sigmoid(np.array(scores)).tolist()
138
- if isinstance(scores, float): res.append(scores)
139
- else: res.extend(scores)
 
 
140
  return np.array(res), token_count
141
 
142
 
@@ -162,7 +174,10 @@ class XInferenceRerank(Base):
162
  "documents": texts
163
  }
164
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
165
- return np.array([d["relevance_score"] for d in res["results"]]), res["meta"]["tokens"]["input_tokens"]+res["meta"]["tokens"]["output_tokens"]
 
 
 
166
 
167
 
168
  class LocalAIRerank(Base):
@@ -175,7 +190,7 @@ class LocalAIRerank(Base):
175
 
176
  class NvidiaRerank(Base):
177
  def __init__(
178
- self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
179
  ):
180
  if not base_url:
181
  base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
@@ -208,9 +223,10 @@ class NvidiaRerank(Base):
208
  "top_n": len(texts),
209
  }
210
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
211
- rank = np.array([d["logit"] for d in res["rankings"]])
212
- indexs = [d["index"] for d in res["rankings"]]
213
- return rank[indexs], token_count
 
214
 
215
 
216
  class LmStudioRerank(Base):
@@ -247,9 +263,10 @@ class CoHereRerank(Base):
247
  top_n=len(texts),
248
  return_documents=False,
249
  )
250
- rank = np.array([d.relevance_score for d in res.results])
251
- indexs = [d.index for d in res.results]
252
- return rank[indexs], token_count
 
253
 
254
 
255
  class TogetherAIRerank(Base):
@@ -262,7 +279,7 @@ class TogetherAIRerank(Base):
262
 
263
  class SILICONFLOWRerank(Base):
264
  def __init__(
265
- self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
266
  ):
267
  if not base_url:
268
  base_url = "https://api.siliconflow.cn/v1/rerank"
@@ -287,10 +304,11 @@ class SILICONFLOWRerank(Base):
287
  response = requests.post(
288
  self.base_url, json=payload, headers=self.headers
289
  ).json()
290
- rank = np.array([d["relevance_score"] for d in response["results"]])
291
- indexs = [d["index"] for d in response["results"]]
 
292
  return (
293
- rank[indexs],
294
  response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
295
  )
296
 
@@ -312,9 +330,10 @@ class BaiduYiyanRerank(Base):
312
  documents=texts,
313
  top_n=len(texts),
314
  ).body
315
- rank = np.array([d["relevance_score"] for d in res["results"]])
316
- indexs = [d["index"] for d in res["results"]]
317
- return rank[indexs], res["usage"]["total_tokens"]
 
318
 
319
 
320
  class VoyageRerank(Base):
@@ -328,6 +347,7 @@ class VoyageRerank(Base):
328
  res = self.client.rerank(
329
  query=query, documents=texts, model=self.model_name, top_k=len(texts)
330
  )
331
- rank = np.array([r.relevance_score for r in res.results])
332
- indexs = [r.index for r in res.results]
333
- return rank[indexs], res.total_tokens
 
 
26
  from rag.utils import num_tokens_from_string, truncate
27
  import json
28
 
29
+
30
  def sigmoid(x):
31
  return 1 / (1 + np.exp(-x))
32
 
33
+
34
  class Base(ABC):
35
  def __init__(self, key, model_name):
36
  pass
 
61
  with DefaultRerank._model_lock:
62
  if not DefaultRerank._model:
63
  try:
64
+ DefaultRerank._model = FlagReranker(
65
+ os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
66
+ use_fp16=torch.cuda.is_available())
67
  except Exception as e:
68
+ model_dir = snapshot_download(repo_id=model_name,
69
+ local_dir=os.path.join(get_home_cache_dir(),
70
+ re.sub(r"^[a-zA-Z]+/", "", model_name)),
71
  local_dir_use_symlinks=False)
72
  DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
73
  self._model = DefaultRerank._model
74
 
75
  def similarity(self, query: str, texts: list):
76
+ pairs = [(query, truncate(t, 2048)) for t in texts]
77
  token_count = 0
78
  for _, t in pairs:
79
  token_count += num_tokens_from_string(t)
 
82
  for i in range(0, len(pairs), batch_size):
83
  scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
84
  scores = sigmoid(np.array(scores)).tolist()
85
+ if isinstance(scores, float):
86
+ res.append(scores)
87
+ else:
88
+ res.extend(scores)
89
  return np.array(res), token_count
90
 
91
 
 
108
  "top_n": len(texts)
109
  }
110
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
111
+ rank = np.zeros(len(texts), dtype=float)
112
+ for d in res["results"]:
113
+ rank[d["index"]] = d["relevance_score"]
114
+ return rank, res["usage"]["total_tokens"]
115
 
116
 
117
  class YoudaoRerank(DefaultRerank):
 
134
  "maidalun1020", "InfiniFlow"))
135
 
136
  self._model = YoudaoRerank._model
137
+
138
  def similarity(self, query: str, texts: list):
139
  pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
140
  token_count = 0
 
145
  for i in range(0, len(pairs), batch_size):
146
  scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
147
  scores = sigmoid(np.array(scores)).tolist()
148
+ if isinstance(scores, float):
149
+ res.append(scores)
150
+ else:
151
+ res.extend(scores)
152
  return np.array(res), token_count
153
 
154
 
 
174
  "documents": texts
175
  }
176
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
177
+ rank = np.zeros(len(texts), dtype=float)
178
+ for d in res["results"]:
179
+ rank[d["index"]] = d["relevance_score"]
180
+ return rank, res["meta"]["tokens"]["input_tokens"] + res["meta"]["tokens"]["output_tokens"]
181
 
182
 
183
  class LocalAIRerank(Base):
 
190
 
191
  class NvidiaRerank(Base):
192
  def __init__(
193
+ self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
194
  ):
195
  if not base_url:
196
  base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
 
223
  "top_n": len(texts),
224
  }
225
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
226
+ rank = np.zeros(len(texts), dtype=float)
227
+ for d in res["rankings"]:
228
+ rank[d["index"]] = d["logit"]
229
+ return rank, token_count
230
 
231
 
232
  class LmStudioRerank(Base):
 
263
  top_n=len(texts),
264
  return_documents=False,
265
  )
266
+ rank = np.zeros(len(texts), dtype=float)
267
+ for d in res.results:
268
+ rank[d.index] = d.relevance_score
269
+ return rank, token_count
270
 
271
 
272
  class TogetherAIRerank(Base):
 
279
 
280
  class SILICONFLOWRerank(Base):
281
  def __init__(
282
+ self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
283
  ):
284
  if not base_url:
285
  base_url = "https://api.siliconflow.cn/v1/rerank"
 
304
  response = requests.post(
305
  self.base_url, json=payload, headers=self.headers
306
  ).json()
307
+ rank = np.zeros(len(texts), dtype=float)
308
+ for d in response["results"]:
309
+ rank[d["index"]] = d["relevance_score"]
310
  return (
311
+ rank,
312
  response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
313
  )
314
 
 
330
  documents=texts,
331
  top_n=len(texts),
332
  ).body
333
+ rank = np.zeros(len(texts), dtype=float)
334
+ for d in res["results"]:
335
+ rank[d["index"]] = d["relevance_score"]
336
+ return rank, res["usage"]["total_tokens"]
337
 
338
 
339
  class VoyageRerank(Base):
 
347
  res = self.client.rerank(
348
  query=query, documents=texts, model=self.model_name, top_k=len(texts)
349
  )
350
+ rank = np.zeros(len(texts), dtype=float)
351
+ for r in res.results:
352
+ rank[r.index] = r.relevance_score
353
+ return rank, res.total_tokens
sdk/python/ragflow/modules/assistant.py CHANGED
@@ -76,7 +76,7 @@ class Assistant(Base):
76
  raise Exception(res["retmsg"])
77
 
78
  def get_session(self, id) -> Session:
79
- res = self.get("/session/get", {"id": id})
80
  res = res.json()
81
  if res.get("retmsg") == "success":
82
  return Session(self.rag, res["data"])
 
76
  raise Exception(res["retmsg"])
77
 
78
  def get_session(self, id) -> Session:
79
+ res = self.get("/session/get", {"id": id,"assistant_id":self.id})
80
  res = res.json()
81
  if res.get("retmsg") == "success":
82
  return Session(self.rag, res["data"])
sdk/python/ragflow/modules/session.py CHANGED
@@ -16,9 +16,12 @@ class Session(Base):
16
  if "reference" in message:
17
  message.pop("reference")
18
  res = self.post("/session/completion",
19
- {"id": self.id, "question": question, "stream": stream}, stream=True)
20
  for line in res.iter_lines():
21
  line = line.decode("utf-8")
 
 
 
22
  if line.startswith("data:"):
23
  json_data = json.loads(line[5:])
24
  if json_data["data"] != True:
@@ -69,6 +72,7 @@ class Message(Base):
69
  self.reference = None
70
  self.role = "assistant"
71
  self.prompt = None
 
72
  super().__init__(rag, res_dict)
73
 
74
 
@@ -76,10 +80,10 @@ class Chunk(Base):
76
  def __init__(self, rag, res_dict):
77
  self.id = None
78
  self.content = None
79
- self.document_id = None
80
- self.document_name = None
81
- self.knowledgebase_id = None
82
- self.image_id = None
83
  self.similarity = None
84
  self.vector_similarity = None
85
  self.term_similarity = None
 
16
  if "reference" in message:
17
  message.pop("reference")
18
  res = self.post("/session/completion",
19
+ {"session_id": self.id, "question": question, "stream": True}, stream=stream)
20
  for line in res.iter_lines():
21
  line = line.decode("utf-8")
22
+ if line.startswith("{"):
23
+ json_data = json.loads(line)
24
+ raise Exception(json_data["retmsg"])
25
  if line.startswith("data:"):
26
  json_data = json.loads(line[5:])
27
  if json_data["data"] != True:
 
72
  self.reference = None
73
  self.role = "assistant"
74
  self.prompt = None
75
+ self.id = None
76
  super().__init__(rag, res_dict)
77
 
78
 
 
80
  def __init__(self, rag, res_dict):
81
  self.id = None
82
  self.content = None
83
+ self.document_id = ""
84
+ self.document_name = ""
85
+ self.knowledgebase_id = ""
86
+ self.image_id = ""
87
  self.similarity = None
88
  self.vector_similarity = None
89
  self.term_similarity = None
sdk/python/test/t_session.py CHANGED
@@ -19,7 +19,7 @@ class TestSession:
19
  question = "What is AI"
20
  for ans in session.chat(question, stream=True):
21
  pass
22
- assert ans.content!="\n**ERROR**", "Please check this error."
23
 
24
  def test_delete_session_with_success(self):
25
  rag = RAGFlow(API_KEY, HOST_ADDRESS)
 
19
  question = "What is AI"
20
  for ans in session.chat(question, stream=True):
21
  pass
22
+ assert not ans.content.startswith("**ERROR**"), "Please check this error."
23
 
24
  def test_delete_session_with_success(self):
25
  rag = RAGFlow(API_KEY, HOST_ADDRESS)