KevinHuSh committed
Commit bf46bd5 · 1 Parent(s): c054af2

fix english query bug (#840)


### What problem does this PR solve?

#834

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

api/db/services/dialog_service.py CHANGED
@@ -118,7 +118,7 @@ def chat(dialog, messages, stream=True, **kwargs):
     kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
                                     dialog.similarity_threshold,
                                     dialog.vector_similarity_weight,
-                                    doc_ids=kwargs.get("doc_ids", "").split(","),
+                                    doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
                                     top=1024, aggs=False)
     knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
     chat_logger.info(
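
Why this matters: in Python, "".split(",") returns [''], not [], so the old default silently passed a one-element list containing an empty document id and filtered every chunk out of retrieval. With the new expression, a missing "doc_ids" yields None and no doc-id filter is applied at all. A minimal standalone illustration (the helper name is hypothetical):

def parse_doc_ids(kwargs):
    # Old behavior: "".split(",") == [''] -- a truthy, wrongly restrictive filter.
    old = kwargs.get("doc_ids", "").split(",")
    # New behavior: None when the key is absent, so no doc-id filter is applied.
    new = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    return old, new

print(parse_doc_ids({}))                  # ([''], None)
print(parse_doc_ids({"doc_ids": "a,b"}))  # (['a', 'b'], ['a', 'b'])
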
rag/llm/chat_model.py CHANGED
@@ -20,6 +20,7 @@ from openai import OpenAI
 import openai
 from ollama import Client
 from rag.nlp import is_english
+from rag.utils import num_tokens_from_string
 
 
 class Base(ABC):
@@ -255,3 +256,46 @@ class OllamaChat(Base):
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
         yield 0
+
+
+class LocalLLM(Base):
+    class RPCProxy:
+        def __init__(self, host, port):
+            self.host = host
+            self.port = int(port)
+            self.__conn()
+
+        def __conn(self):
+            from multiprocessing.connection import Client
+            self._connection = Client(
+                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')
+
+        def __getattr__(self, name):
+            import pickle
+
+            def do_rpc(*args, **kwargs):
+                for _ in range(3):
+                    try:
+                        self._connection.send(
+                            pickle.dumps((name, args, kwargs)))
+                        return pickle.loads(self._connection.recv())
+                    except Exception as e:
+                        self.__conn()
+                raise Exception("RPC connection lost!")
+
+            return do_rpc
+
+    def __init__(self, key, model_name="glm-3-turbo"):
+        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            ans = self.client.chat(
+                history,
+                gen_conf
+            )
+            return ans, num_tokens_from_string(ans)
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
rag/llm/rpc_server.py CHANGED
@@ -2,9 +2,10 @@ import argparse
 import pickle
 import random
 import time
+from copy import deepcopy
 from multiprocessing.connection import Listener
 from threading import Thread
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
 
 
 def torch_gc():
@@ -95,6 +96,32 @@ def chat(messages, gen_conf):
         return str(e)
 
 
+def chat_streamly(messages, gen_conf):
+    global tokenizer
+    model = Model()
+    try:
+        torch_gc()
+        conf = deepcopy(gen_conf)
+        print(messages, conf)
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+        streamer = TextStreamer(tokenizer)
+        conf["inputs"] = model_inputs.input_ids
+        conf["streamer"] = streamer
+        conf["max_new_tokens"] = conf["max_tokens"]
+        del conf["max_tokens"]
+        thread = Thread(target=model.generate, kwargs=conf)
+        thread.start()
+        for _, new_text in enumerate(streamer):
+            yield new_text
+    except Exception as e:
+        yield "**ERROR**: " + str(e)
+
+
 def Model():
     global models
     random.seed(time.time())
@@ -113,6 +140,7 @@ if __name__ == "__main__":
 
     handler = RPCHandler()
     handler.register_function(chat)
+    handler.register_function(chat_streamly)
 
     models = []
     for _ in range(1):
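
The chat_streamly addition follows the standard transformers streaming recipe: run model.generate on a background thread and consume decoded text from a streamer as it arrives. A hedged standalone sketch of that recipe, using TextIteratorStreamer (the streamer variant transformers provides for iteration) and an illustrative model id, neither of which is what the server above uses:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

def stream_reply(messages, max_new_tokens=128,
                 model_id="Qwen/Qwen1.5-0.5B-Chat"):  # illustrative model id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # skip_prompt=True yields only newly generated text, not the prompt echo.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    # generate() blocks, so it runs on a worker thread while we consume chunks.
    Thread(target=model.generate,
           kwargs=dict(inputs=inputs.input_ids, streamer=streamer,
                       max_new_tokens=max_new_tokens)).start()
    for new_text in streamer:  # blocks until the next decoded chunk arrives
        yield new_text

Usage: for chunk in stream_reply([{"role": "user", "content": "hi"}]): print(chunk, end=""). The sketch loads the model per call to stay self-contained; the RPC server instead loads its models once at startup and picks one per request via Model().
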
rag/nlp/query.py CHANGED
@@ -36,7 +36,7 @@ class EsQueryer:
         patts = [
             (r"是*(什么样的|哪家|一下|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
             (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
-            (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down)", " ")
+            (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
         ]
         for r, p in patts:
             txt = re.sub(r, p, txt, flags=re.IGNORECASE)
@@ -44,7 +44,7 @@ class EsQueryer:
 
     def question(self, txt, tbl="qa", min_match="60%"):
        txt = re.sub(
-            r"[ \r\n\t,,。??/`!!&]+",
+            r"[ \r\n\t,,。??/`!!&\^%%]+",
            " ",
            rag_tokenizer.tradi2simp(
                rag_tokenizer.strQ2B(
@@ -53,9 +53,10 @@ class EsQueryer:
 
        if not self.isChinese(txt):
            tks = rag_tokenizer.tokenize(txt).split(" ")
-            q = copy.deepcopy(tks)
-            for i in range(1, len(tks)):
-                q.append("\"%s %s\"^2" % (tks[i - 1], tks[i]))
+            tks_w = self.tw.weights(tks)
+            q = [re.sub(r"[ \\\"']+", "", tk)+"^{:.4f}".format(w) for tk, w in tks_w]
+            for i in range(1, len(tks_w)):
+                q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
            if not q:
                q.append(txt)
            return Q("bool",