KevinHuSh
commited on
Commit
·
bf46bd5
1
Parent(s):
c054af2
fix english query bug (#840)
Browse files### What problem does this PR solve?
#834
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- api/db/services/dialog_service.py +1 -1
- rag/llm/chat_model.py +44 -0
- rag/llm/rpc_server.py +29 -1
- rag/nlp/query.py +6 -5
api/db/services/dialog_service.py
CHANGED
@@ -118,7 +118,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
118 |
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
119 |
dialog.similarity_threshold,
|
120 |
dialog.vector_similarity_weight,
|
121 |
-
doc_ids=kwargs
|
122 |
top=1024, aggs=False)
|
123 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
124 |
chat_logger.info(
|
|
|
118 |
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
119 |
dialog.similarity_threshold,
|
120 |
dialog.vector_similarity_weight,
|
121 |
+
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
122 |
top=1024, aggs=False)
|
123 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
124 |
chat_logger.info(
|
rag/llm/chat_model.py
CHANGED
@@ -20,6 +20,7 @@ from openai import OpenAI
|
|
20 |
import openai
|
21 |
from ollama import Client
|
22 |
from rag.nlp import is_english
|
|
|
23 |
|
24 |
|
25 |
class Base(ABC):
|
@@ -255,3 +256,46 @@ class OllamaChat(Base):
|
|
255 |
except Exception as e:
|
256 |
yield ans + "\n**ERROR**: " + str(e)
|
257 |
yield 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
import openai
|
21 |
from ollama import Client
|
22 |
from rag.nlp import is_english
|
23 |
+
from rag.utils import num_tokens_from_string
|
24 |
|
25 |
|
26 |
class Base(ABC):
|
|
|
256 |
except Exception as e:
|
257 |
yield ans + "\n**ERROR**: " + str(e)
|
258 |
yield 0
|
259 |
+
|
260 |
+
|
261 |
+
class LocalLLM(Base):
|
262 |
+
class RPCProxy:
|
263 |
+
def __init__(self, host, port):
|
264 |
+
self.host = host
|
265 |
+
self.port = int(port)
|
266 |
+
self.__conn()
|
267 |
+
|
268 |
+
def __conn(self):
|
269 |
+
from multiprocessing.connection import Client
|
270 |
+
self._connection = Client(
|
271 |
+
(self.host, self.port), authkey=b'infiniflow-token4kevinhu')
|
272 |
+
|
273 |
+
def __getattr__(self, name):
|
274 |
+
import pickle
|
275 |
+
|
276 |
+
def do_rpc(*args, **kwargs):
|
277 |
+
for _ in range(3):
|
278 |
+
try:
|
279 |
+
self._connection.send(
|
280 |
+
pickle.dumps((name, args, kwargs)))
|
281 |
+
return pickle.loads(self._connection.recv())
|
282 |
+
except Exception as e:
|
283 |
+
self.__conn()
|
284 |
+
raise Exception("RPC connection lost!")
|
285 |
+
|
286 |
+
return do_rpc
|
287 |
+
|
288 |
+
def __init__(self, key, model_name="glm-3-turbo"):
|
289 |
+
self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
|
290 |
+
|
291 |
+
def chat(self, system, history, gen_conf):
|
292 |
+
if system:
|
293 |
+
history.insert(0, {"role": "system", "content": system})
|
294 |
+
try:
|
295 |
+
ans = self.client.chat(
|
296 |
+
history,
|
297 |
+
gen_conf
|
298 |
+
)
|
299 |
+
return ans, num_tokens_from_string(ans)
|
300 |
+
except Exception as e:
|
301 |
+
return "**ERROR**: " + str(e), 0
|
rag/llm/rpc_server.py
CHANGED
@@ -2,9 +2,10 @@ import argparse
|
|
2 |
import pickle
|
3 |
import random
|
4 |
import time
|
|
|
5 |
from multiprocessing.connection import Listener
|
6 |
from threading import Thread
|
7 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
8 |
|
9 |
|
10 |
def torch_gc():
|
@@ -95,6 +96,32 @@ def chat(messages, gen_conf):
|
|
95 |
return str(e)
|
96 |
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
def Model():
|
99 |
global models
|
100 |
random.seed(time.time())
|
@@ -113,6 +140,7 @@ if __name__ == "__main__":
|
|
113 |
|
114 |
handler = RPCHandler()
|
115 |
handler.register_function(chat)
|
|
|
116 |
|
117 |
models = []
|
118 |
for _ in range(1):
|
|
|
2 |
import pickle
|
3 |
import random
|
4 |
import time
|
5 |
+
from copy import deepcopy
|
6 |
from multiprocessing.connection import Listener
|
7 |
from threading import Thread
|
8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
9 |
|
10 |
|
11 |
def torch_gc():
|
|
|
96 |
return str(e)
|
97 |
|
98 |
|
99 |
+
def chat_streamly(messages, gen_conf):
|
100 |
+
global tokenizer
|
101 |
+
model = Model()
|
102 |
+
try:
|
103 |
+
torch_gc()
|
104 |
+
conf = deepcopy(gen_conf)
|
105 |
+
print(messages, conf)
|
106 |
+
text = tokenizer.apply_chat_template(
|
107 |
+
messages,
|
108 |
+
tokenize=False,
|
109 |
+
add_generation_prompt=True
|
110 |
+
)
|
111 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
112 |
+
streamer = TextStreamer(tokenizer)
|
113 |
+
conf["inputs"] = model_inputs.input_ids
|
114 |
+
conf["streamer"] = streamer
|
115 |
+
conf["max_new_tokens"] = conf["max_tokens"]
|
116 |
+
del conf["max_tokens"]
|
117 |
+
thread = Thread(target=model.generate, kwargs=conf)
|
118 |
+
thread.start()
|
119 |
+
for _, new_text in enumerate(streamer):
|
120 |
+
yield new_text
|
121 |
+
except Exception as e:
|
122 |
+
yield "**ERROR**: " + str(e)
|
123 |
+
|
124 |
+
|
125 |
def Model():
|
126 |
global models
|
127 |
random.seed(time.time())
|
|
|
140 |
|
141 |
handler = RPCHandler()
|
142 |
handler.register_function(chat)
|
143 |
+
handler.register_function(chat_streamly)
|
144 |
|
145 |
models = []
|
146 |
for _ in range(1):
|
rag/nlp/query.py
CHANGED
@@ -36,7 +36,7 @@ class EsQueryer:
|
|
36 |
patts = [
|
37 |
(r"是*(什么样的|哪家|一下|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
|
38 |
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
39 |
-
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down)", " ")
|
40 |
]
|
41 |
for r, p in patts:
|
42 |
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
@@ -44,7 +44,7 @@ class EsQueryer:
|
|
44 |
|
45 |
def question(self, txt, tbl="qa", min_match="60%"):
|
46 |
txt = re.sub(
|
47 |
-
r"[ \r\n\t
|
48 |
" ",
|
49 |
rag_tokenizer.tradi2simp(
|
50 |
rag_tokenizer.strQ2B(
|
@@ -53,9 +53,10 @@ class EsQueryer:
|
|
53 |
|
54 |
if not self.isChinese(txt):
|
55 |
tks = rag_tokenizer.tokenize(txt).split(" ")
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
59 |
if not q:
|
60 |
q.append(txt)
|
61 |
return Q("bool",
|
|
|
36 |
patts = [
|
37 |
(r"是*(什么样的|哪家|一下|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
|
38 |
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
39 |
+
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
|
40 |
]
|
41 |
for r, p in patts:
|
42 |
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
|
|
44 |
|
45 |
def question(self, txt, tbl="qa", min_match="60%"):
|
46 |
txt = re.sub(
|
47 |
+
r"[ \r\n\t,,。??/`!!&\^%%]+",
|
48 |
" ",
|
49 |
rag_tokenizer.tradi2simp(
|
50 |
rag_tokenizer.strQ2B(
|
|
|
53 |
|
54 |
if not self.isChinese(txt):
|
55 |
tks = rag_tokenizer.tokenize(txt).split(" ")
|
56 |
+
tks_w = self.tw.weights(tks)
|
57 |
+
q = [re.sub(r"[ \\\"']+", "", tk)+"^{:.4f}".format(w) for tk, w in tks_w]
|
58 |
+
for i in range(1, len(tks_w)):
|
59 |
+
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
60 |
if not q:
|
61 |
q.append(txt)
|
62 |
return Q("bool",
|