KevinHuSh
commited on
Commit
·
bf46bd5
1
Parent(s):
c054af2
fix english query bug (#840)
Browse files### What problem does this PR solve?
#834
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- api/db/services/dialog_service.py +1 -1
- rag/llm/chat_model.py +44 -0
- rag/llm/rpc_server.py +29 -1
- rag/nlp/query.py +6 -5
api/db/services/dialog_service.py
CHANGED
|
@@ -118,7 +118,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 118 |
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
| 119 |
dialog.similarity_threshold,
|
| 120 |
dialog.vector_similarity_weight,
|
| 121 |
-
doc_ids=kwargs
|
| 122 |
top=1024, aggs=False)
|
| 123 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
| 124 |
chat_logger.info(
|
|
|
|
| 118 |
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
| 119 |
dialog.similarity_threshold,
|
| 120 |
dialog.vector_similarity_weight,
|
| 121 |
+
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
| 122 |
top=1024, aggs=False)
|
| 123 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
| 124 |
chat_logger.info(
|
rag/llm/chat_model.py
CHANGED
|
@@ -20,6 +20,7 @@ from openai import OpenAI
|
|
| 20 |
import openai
|
| 21 |
from ollama import Client
|
| 22 |
from rag.nlp import is_english
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
class Base(ABC):
|
|
@@ -255,3 +256,46 @@ class OllamaChat(Base):
|
|
| 255 |
except Exception as e:
|
| 256 |
yield ans + "\n**ERROR**: " + str(e)
|
| 257 |
yield 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
import openai
|
| 21 |
from ollama import Client
|
| 22 |
from rag.nlp import is_english
|
| 23 |
+
from rag.utils import num_tokens_from_string
|
| 24 |
|
| 25 |
|
| 26 |
class Base(ABC):
|
|
|
|
| 256 |
except Exception as e:
|
| 257 |
yield ans + "\n**ERROR**: " + str(e)
|
| 258 |
yield 0
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class LocalLLM(Base):
|
| 262 |
+
class RPCProxy:
|
| 263 |
+
def __init__(self, host, port):
|
| 264 |
+
self.host = host
|
| 265 |
+
self.port = int(port)
|
| 266 |
+
self.__conn()
|
| 267 |
+
|
| 268 |
+
def __conn(self):
|
| 269 |
+
from multiprocessing.connection import Client
|
| 270 |
+
self._connection = Client(
|
| 271 |
+
(self.host, self.port), authkey=b'infiniflow-token4kevinhu')
|
| 272 |
+
|
| 273 |
+
def __getattr__(self, name):
|
| 274 |
+
import pickle
|
| 275 |
+
|
| 276 |
+
def do_rpc(*args, **kwargs):
|
| 277 |
+
for _ in range(3):
|
| 278 |
+
try:
|
| 279 |
+
self._connection.send(
|
| 280 |
+
pickle.dumps((name, args, kwargs)))
|
| 281 |
+
return pickle.loads(self._connection.recv())
|
| 282 |
+
except Exception as e:
|
| 283 |
+
self.__conn()
|
| 284 |
+
raise Exception("RPC connection lost!")
|
| 285 |
+
|
| 286 |
+
return do_rpc
|
| 287 |
+
|
| 288 |
+
def __init__(self, key, model_name="glm-3-turbo"):
|
| 289 |
+
self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
|
| 290 |
+
|
| 291 |
+
def chat(self, system, history, gen_conf):
|
| 292 |
+
if system:
|
| 293 |
+
history.insert(0, {"role": "system", "content": system})
|
| 294 |
+
try:
|
| 295 |
+
ans = self.client.chat(
|
| 296 |
+
history,
|
| 297 |
+
gen_conf
|
| 298 |
+
)
|
| 299 |
+
return ans, num_tokens_from_string(ans)
|
| 300 |
+
except Exception as e:
|
| 301 |
+
return "**ERROR**: " + str(e), 0
|
rag/llm/rpc_server.py
CHANGED
|
@@ -2,9 +2,10 @@ import argparse
|
|
| 2 |
import pickle
|
| 3 |
import random
|
| 4 |
import time
|
|
|
|
| 5 |
from multiprocessing.connection import Listener
|
| 6 |
from threading import Thread
|
| 7 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 8 |
|
| 9 |
|
| 10 |
def torch_gc():
|
|
@@ -95,6 +96,32 @@ def chat(messages, gen_conf):
|
|
| 95 |
return str(e)
|
| 96 |
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
def Model():
|
| 99 |
global models
|
| 100 |
random.seed(time.time())
|
|
@@ -113,6 +140,7 @@ if __name__ == "__main__":
|
|
| 113 |
|
| 114 |
handler = RPCHandler()
|
| 115 |
handler.register_function(chat)
|
|
|
|
| 116 |
|
| 117 |
models = []
|
| 118 |
for _ in range(1):
|
|
|
|
| 2 |
import pickle
|
| 3 |
import random
|
| 4 |
import time
|
| 5 |
+
from copy import deepcopy
|
| 6 |
from multiprocessing.connection import Listener
|
| 7 |
from threading import Thread
|
| 8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
| 9 |
|
| 10 |
|
| 11 |
def torch_gc():
|
|
|
|
| 96 |
return str(e)
|
| 97 |
|
| 98 |
|
| 99 |
+
def chat_streamly(messages, gen_conf):
|
| 100 |
+
global tokenizer
|
| 101 |
+
model = Model()
|
| 102 |
+
try:
|
| 103 |
+
torch_gc()
|
| 104 |
+
conf = deepcopy(gen_conf)
|
| 105 |
+
print(messages, conf)
|
| 106 |
+
text = tokenizer.apply_chat_template(
|
| 107 |
+
messages,
|
| 108 |
+
tokenize=False,
|
| 109 |
+
add_generation_prompt=True
|
| 110 |
+
)
|
| 111 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
| 112 |
+
streamer = TextStreamer(tokenizer)
|
| 113 |
+
conf["inputs"] = model_inputs.input_ids
|
| 114 |
+
conf["streamer"] = streamer
|
| 115 |
+
conf["max_new_tokens"] = conf["max_tokens"]
|
| 116 |
+
del conf["max_tokens"]
|
| 117 |
+
thread = Thread(target=model.generate, kwargs=conf)
|
| 118 |
+
thread.start()
|
| 119 |
+
for _, new_text in enumerate(streamer):
|
| 120 |
+
yield new_text
|
| 121 |
+
except Exception as e:
|
| 122 |
+
yield "**ERROR**: " + str(e)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
def Model():
|
| 126 |
global models
|
| 127 |
random.seed(time.time())
|
|
|
|
| 140 |
|
| 141 |
handler = RPCHandler()
|
| 142 |
handler.register_function(chat)
|
| 143 |
+
handler.register_function(chat_streamly)
|
| 144 |
|
| 145 |
models = []
|
| 146 |
for _ in range(1):
|
rag/nlp/query.py
CHANGED
|
@@ -36,7 +36,7 @@ class EsQueryer:
|
|
| 36 |
patts = [
|
| 37 |
(r"是*(什么样的|哪家|一下|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
|
| 38 |
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
| 39 |
-
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down)", " ")
|
| 40 |
]
|
| 41 |
for r, p in patts:
|
| 42 |
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
|
@@ -44,7 +44,7 @@ class EsQueryer:
|
|
| 44 |
|
| 45 |
def question(self, txt, tbl="qa", min_match="60%"):
|
| 46 |
txt = re.sub(
|
| 47 |
-
r"[ \r\n\t
|
| 48 |
" ",
|
| 49 |
rag_tokenizer.tradi2simp(
|
| 50 |
rag_tokenizer.strQ2B(
|
|
@@ -53,9 +53,10 @@ class EsQueryer:
|
|
| 53 |
|
| 54 |
if not self.isChinese(txt):
|
| 55 |
tks = rag_tokenizer.tokenize(txt).split(" ")
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
| 59 |
if not q:
|
| 60 |
q.append(txt)
|
| 61 |
return Q("bool",
|
|
|
|
| 36 |
patts = [
|
| 37 |
(r"是*(什么样的|哪家|一下|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
|
| 38 |
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
| 39 |
+
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
|
| 40 |
]
|
| 41 |
for r, p in patts:
|
| 42 |
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
|
|
|
| 44 |
|
| 45 |
def question(self, txt, tbl="qa", min_match="60%"):
|
| 46 |
txt = re.sub(
|
| 47 |
+
r"[ \r\n\t,,。??/`!!&\^%%]+",
|
| 48 |
" ",
|
| 49 |
rag_tokenizer.tradi2simp(
|
| 50 |
rag_tokenizer.strQ2B(
|
|
|
|
| 53 |
|
| 54 |
if not self.isChinese(txt):
|
| 55 |
tks = rag_tokenizer.tokenize(txt).split(" ")
|
| 56 |
+
tks_w = self.tw.weights(tks)
|
| 57 |
+
q = [re.sub(r"[ \\\"']+", "", tk)+"^{:.4f}".format(w) for tk, w in tks_w]
|
| 58 |
+
for i in range(1, len(tks_w)):
|
| 59 |
+
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
| 60 |
if not q:
|
| 61 |
q.append(txt)
|
| 62 |
return Q("bool",
|