Commit
·
d919631
1
Parent(s):
4d2f593
Change default error message to English (#3838)
Browse files### What problem does this PR solve?
As title
### Type of change
- [x] Refactoring
---------
Signed-off-by: Jin Hai <[email protected]>
- rag/llm/chat_model.py +55 -27
- rag/nlp/__init__.py +8 -0
rag/llm/chat_model.py
CHANGED
@@ -22,7 +22,7 @@ from abc import ABC
|
|
22 |
from openai import OpenAI
|
23 |
import openai
|
24 |
from ollama import Client
|
25 |
-
from rag.nlp import
|
26 |
from rag.utils import num_tokens_from_string
|
27 |
from groq import Groq
|
28 |
import os
|
@@ -30,6 +30,8 @@ import json
|
|
30 |
import requests
|
31 |
import asyncio
|
32 |
|
|
|
|
|
33 |
|
34 |
class Base(ABC):
|
35 |
def __init__(self, key, model_name, base_url):
|
@@ -47,8 +49,10 @@ class Base(ABC):
|
|
47 |
**gen_conf)
|
48 |
ans = response.choices[0].message.content.strip()
|
49 |
if response.choices[0].finish_reason == "length":
|
50 |
-
|
51 |
-
|
|
|
|
|
52 |
return ans, response.usage.total_tokens
|
53 |
except openai.APIError as e:
|
54 |
return "**ERROR**: " + str(e), 0
|
@@ -80,8 +84,10 @@ class Base(ABC):
|
|
80 |
else: total_tokens = resp.usage.total_tokens
|
81 |
|
82 |
if resp.choices[0].finish_reason == "length":
|
83 |
-
|
84 |
-
|
|
|
|
|
85 |
yield ans
|
86 |
|
87 |
except openai.APIError as e:
|
@@ -167,8 +173,10 @@ class BaiChuanChat(Base):
|
|
167 |
**self._format_params(gen_conf))
|
168 |
ans = response.choices[0].message.content.strip()
|
169 |
if response.choices[0].finish_reason == "length":
|
170 |
-
|
171 |
-
|
|
|
|
|
172 |
return ans, response.usage.total_tokens
|
173 |
except openai.APIError as e:
|
174 |
return "**ERROR**: " + str(e), 0
|
@@ -207,8 +215,10 @@ class BaiChuanChat(Base):
|
|
207 |
else resp.usage["total_tokens"]
|
208 |
)
|
209 |
if resp.choices[0].finish_reason == "length":
|
210 |
-
|
211 |
-
|
|
|
|
|
212 |
yield ans
|
213 |
|
214 |
except Exception as e:
|
@@ -242,8 +252,10 @@ class QWenChat(Base):
|
|
242 |
ans += response.output.choices[0]['message']['content']
|
243 |
tk_count += response.usage.total_tokens
|
244 |
if response.output.choices[0].get("finish_reason", "") == "length":
|
245 |
-
|
246 |
-
|
|
|
|
|
247 |
return ans, tk_count
|
248 |
|
249 |
return "**ERROR**: " + response.message, tk_count
|
@@ -276,8 +288,10 @@ class QWenChat(Base):
|
|
276 |
ans = resp.output.choices[0]['message']['content']
|
277 |
tk_count = resp.usage.total_tokens
|
278 |
if resp.output.choices[0].get("finish_reason", "") == "length":
|
279 |
-
|
280 |
-
|
|
|
|
|
281 |
yield ans
|
282 |
else:
|
283 |
yield ans + "\n**ERROR**: " + resp.message if not re.search(r" (key|quota)", str(resp.message).lower()) else "Out of credit. Please set the API key in **settings > Model providers.**"
|
@@ -308,8 +322,10 @@ class ZhipuChat(Base):
|
|
308 |
)
|
309 |
ans = response.choices[0].message.content.strip()
|
310 |
if response.choices[0].finish_reason == "length":
|
311 |
-
|
312 |
-
|
|
|
|
|
313 |
return ans, response.usage.total_tokens
|
314 |
except Exception as e:
|
315 |
return "**ERROR**: " + str(e), 0
|
@@ -333,8 +349,10 @@ class ZhipuChat(Base):
|
|
333 |
delta = resp.choices[0].delta.content
|
334 |
ans += delta
|
335 |
if resp.choices[0].finish_reason == "length":
|
336 |
-
|
337 |
-
|
|
|
|
|
338 |
tk_count = resp.usage.total_tokens
|
339 |
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
|
340 |
yield ans
|
@@ -525,8 +543,10 @@ class MiniMaxChat(Base):
|
|
525 |
response = response.json()
|
526 |
ans = response["choices"][0]["message"]["content"].strip()
|
527 |
if response["choices"][0]["finish_reason"] == "length":
|
528 |
-
|
529 |
-
|
|
|
|
|
530 |
return ans, response["usage"]["total_tokens"]
|
531 |
except Exception as e:
|
532 |
return "**ERROR**: " + str(e), 0
|
@@ -594,8 +614,10 @@ class MistralChat(Base):
|
|
594 |
**gen_conf)
|
595 |
ans = response.choices[0].message.content
|
596 |
if response.choices[0].finish_reason == "length":
|
597 |
-
|
598 |
-
|
|
|
|
|
599 |
return ans, response.usage.total_tokens
|
600 |
except openai.APIError as e:
|
601 |
return "**ERROR**: " + str(e), 0
|
@@ -618,8 +640,10 @@ class MistralChat(Base):
|
|
618 |
ans += resp.choices[0].delta.content
|
619 |
total_tokens += 1
|
620 |
if resp.choices[0].finish_reason == "length":
|
621 |
-
|
622 |
-
|
|
|
|
|
623 |
yield ans
|
624 |
|
625 |
except openai.APIError as e:
|
@@ -811,8 +835,10 @@ class GroqChat:
|
|
811 |
)
|
812 |
ans = response.choices[0].message.content
|
813 |
if response.choices[0].finish_reason == "length":
|
814 |
-
|
815 |
-
|
|
|
|
|
816 |
return ans, response.usage.total_tokens
|
817 |
except Exception as e:
|
818 |
return ans + "\n**ERROR**: " + str(e), 0
|
@@ -838,8 +864,10 @@ class GroqChat:
|
|
838 |
ans += resp.choices[0].delta.content
|
839 |
total_tokens += 1
|
840 |
if resp.choices[0].finish_reason == "length":
|
841 |
-
|
842 |
-
|
|
|
|
|
843 |
yield ans
|
844 |
|
845 |
except Exception as e:
|
|
|
22 |
from openai import OpenAI
|
23 |
import openai
|
24 |
from ollama import Client
|
25 |
+
from rag.nlp import is_chinese
|
26 |
from rag.utils import num_tokens_from_string
|
27 |
from groq import Groq
|
28 |
import os
|
|
|
30 |
import requests
|
31 |
import asyncio
|
32 |
|
33 |
+
LENGTH_NOTIFICATION_CN = "······\n由于长度的原因，回答被截断了，要继续吗？"
|
34 |
+
LENGTH_NOTIFICATION_EN = "...\nFor the content length reason, it stopped, continue?"
|
35 |
|
36 |
class Base(ABC):
|
37 |
def __init__(self, key, model_name, base_url):
|
|
|
49 |
**gen_conf)
|
50 |
ans = response.choices[0].message.content.strip()
|
51 |
if response.choices[0].finish_reason == "length":
|
52 |
+
if is_chinese(ans):
|
53 |
+
ans += LENGTH_NOTIFICATION_CN
|
54 |
+
else:
|
55 |
+
ans += LENGTH_NOTIFICATION_EN
|
56 |
return ans, response.usage.total_tokens
|
57 |
except openai.APIError as e:
|
58 |
return "**ERROR**: " + str(e), 0
|
|
|
84 |
else: total_tokens = resp.usage.total_tokens
|
85 |
|
86 |
if resp.choices[0].finish_reason == "length":
|
87 |
+
if is_chinese(ans):
|
88 |
+
ans += LENGTH_NOTIFICATION_CN
|
89 |
+
else:
|
90 |
+
ans += LENGTH_NOTIFICATION_EN
|
91 |
yield ans
|
92 |
|
93 |
except openai.APIError as e:
|
|
|
173 |
**self._format_params(gen_conf))
|
174 |
ans = response.choices[0].message.content.strip()
|
175 |
if response.choices[0].finish_reason == "length":
|
176 |
+
if is_chinese([ans]):
|
177 |
+
ans += LENGTH_NOTIFICATION_CN
|
178 |
+
else:
|
179 |
+
ans += LENGTH_NOTIFICATION_EN
|
180 |
return ans, response.usage.total_tokens
|
181 |
except openai.APIError as e:
|
182 |
return "**ERROR**: " + str(e), 0
|
|
|
215 |
else resp.usage["total_tokens"]
|
216 |
)
|
217 |
if resp.choices[0].finish_reason == "length":
|
218 |
+
if is_chinese([ans]):
|
219 |
+
ans += LENGTH_NOTIFICATION_CN
|
220 |
+
else:
|
221 |
+
ans += LENGTH_NOTIFICATION_EN
|
222 |
yield ans
|
223 |
|
224 |
except Exception as e:
|
|
|
252 |
ans += response.output.choices[0]['message']['content']
|
253 |
tk_count += response.usage.total_tokens
|
254 |
if response.output.choices[0].get("finish_reason", "") == "length":
|
255 |
+
if is_chinese([ans]):
|
256 |
+
ans += LENGTH_NOTIFICATION_CN
|
257 |
+
else:
|
258 |
+
ans += LENGTH_NOTIFICATION_EN
|
259 |
return ans, tk_count
|
260 |
|
261 |
return "**ERROR**: " + response.message, tk_count
|
|
|
288 |
ans = resp.output.choices[0]['message']['content']
|
289 |
tk_count = resp.usage.total_tokens
|
290 |
if resp.output.choices[0].get("finish_reason", "") == "length":
|
291 |
+
if is_chinese(ans):
|
292 |
+
ans += LENGTH_NOTIFICATION_CN
|
293 |
+
else:
|
294 |
+
ans += LENGTH_NOTIFICATION_EN
|
295 |
yield ans
|
296 |
else:
|
297 |
yield ans + "\n**ERROR**: " + resp.message if not re.search(r" (key|quota)", str(resp.message).lower()) else "Out of credit. Please set the API key in **settings > Model providers.**"
|
|
|
322 |
)
|
323 |
ans = response.choices[0].message.content.strip()
|
324 |
if response.choices[0].finish_reason == "length":
|
325 |
+
if is_chinese(ans):
|
326 |
+
ans += LENGTH_NOTIFICATION_CN
|
327 |
+
else:
|
328 |
+
ans += LENGTH_NOTIFICATION_EN
|
329 |
return ans, response.usage.total_tokens
|
330 |
except Exception as e:
|
331 |
return "**ERROR**: " + str(e), 0
|
|
|
349 |
delta = resp.choices[0].delta.content
|
350 |
ans += delta
|
351 |
if resp.choices[0].finish_reason == "length":
|
352 |
+
if is_chinese(ans):
|
353 |
+
ans += LENGTH_NOTIFICATION_CN
|
354 |
+
else:
|
355 |
+
ans += LENGTH_NOTIFICATION_EN
|
356 |
tk_count = resp.usage.total_tokens
|
357 |
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
|
358 |
yield ans
|
|
|
543 |
response = response.json()
|
544 |
ans = response["choices"][0]["message"]["content"].strip()
|
545 |
if response["choices"][0]["finish_reason"] == "length":
|
546 |
+
if is_chinese(ans):
|
547 |
+
ans += LENGTH_NOTIFICATION_CN
|
548 |
+
else:
|
549 |
+
ans += LENGTH_NOTIFICATION_EN
|
550 |
return ans, response["usage"]["total_tokens"]
|
551 |
except Exception as e:
|
552 |
return "**ERROR**: " + str(e), 0
|
|
|
614 |
**gen_conf)
|
615 |
ans = response.choices[0].message.content
|
616 |
if response.choices[0].finish_reason == "length":
|
617 |
+
if is_chinese(ans):
|
618 |
+
ans += LENGTH_NOTIFICATION_CN
|
619 |
+
else:
|
620 |
+
ans += LENGTH_NOTIFICATION_EN
|
621 |
return ans, response.usage.total_tokens
|
622 |
except openai.APIError as e:
|
623 |
return "**ERROR**: " + str(e), 0
|
|
|
640 |
ans += resp.choices[0].delta.content
|
641 |
total_tokens += 1
|
642 |
if resp.choices[0].finish_reason == "length":
|
643 |
+
if is_chinese(ans):
|
644 |
+
ans += LENGTH_NOTIFICATION_CN
|
645 |
+
else:
|
646 |
+
ans += LENGTH_NOTIFICATION_EN
|
647 |
yield ans
|
648 |
|
649 |
except openai.APIError as e:
|
|
|
835 |
)
|
836 |
ans = response.choices[0].message.content
|
837 |
if response.choices[0].finish_reason == "length":
|
838 |
+
if is_chinese(ans):
|
839 |
+
ans += LENGTH_NOTIFICATION_CN
|
840 |
+
else:
|
841 |
+
ans += LENGTH_NOTIFICATION_EN
|
842 |
return ans, response.usage.total_tokens
|
843 |
except Exception as e:
|
844 |
return ans + "\n**ERROR**: " + str(e), 0
|
|
|
864 |
ans += resp.choices[0].delta.content
|
865 |
total_tokens += 1
|
866 |
if resp.choices[0].finish_reason == "length":
|
867 |
+
if is_chinese(ans):
|
868 |
+
ans += LENGTH_NOTIFICATION_CN
|
869 |
+
else:
|
870 |
+
ans += LENGTH_NOTIFICATION_EN
|
871 |
yield ans
|
872 |
|
873 |
except Exception as e:
|
rag/nlp/__init__.py
CHANGED
@@ -230,6 +230,14 @@ def is_english(texts):
|
|
230 |
return True
|
231 |
return False
|
232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
def tokenize(d, t, eng):
|
235 |
d["content_with_weight"] = t
|
|
|
230 |
return True
|
231 |
return False
|
232 |
|
233 |
def is_chinese(text):
    """Return True when more than 20% of the characters in *text* are
    CJK Unified Ideographs (U+4E00-U+9FFF).

    Used to pick between the Chinese and English "response truncated"
    notifications appended to an LLM answer.

    Fix over the original: an empty *text* now returns False instead of
    raising ZeroDivisionError in the ratio test — an empty answer that
    finishes with reason "length" must not crash the chat path.

    NOTE(review): some call sites pass ``[ans]`` (a one-element list)
    rather than ``ans``; that form compares whole strings and effectively
    inspects only the first character — confirm against callers.
    """
    if not text:
        # Guard: len(text) == 0 would divide by zero below.
        return False
    chinese = sum(1 for ch in text if '\u4e00' <= ch <= '\u9fff')
    return chinese / len(text) > 0.2
|
241 |
|
242 |
def tokenize(d, t, eng):
|
243 |
d["content_with_weight"] = t
|