KevinHuSh commited on
Commit
6ad2626
·
1 Parent(s): 7483940

fix jina adding issure and term weight refinement (#974)

Browse files

### What problem does this PR solve?

#724 #162

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)

api/apps/llm_app.py CHANGED
@@ -39,17 +39,18 @@ def factories():
39
  def set_api_key():
40
  req = request.json
41
  # test if api key works
42
- chat_passed = False
43
  factory = req["llm_factory"]
44
  msg = ""
45
  for llm in LLMService.query(fid=factory):
46
- if llm.model_type == LLMType.EMBEDDING.value:
47
  mdl = EmbeddingModel[factory](
48
  req["api_key"], llm.llm_name, base_url=req.get("base_url"))
49
  try:
50
  arr, tc = mdl.encode(["Test if the api key is available"])
51
  if len(arr[0]) == 0 or tc == 0:
52
  raise Exception("Fail")
 
53
  except Exception as e:
54
  msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
55
  elif not chat_passed and llm.model_type == LLMType.CHAT.value:
@@ -60,20 +61,21 @@ def set_api_key():
60
  "temperature": 0.9})
61
  if not tc:
62
  raise Exception(m)
63
- chat_passed = True
64
  except Exception as e:
65
  msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
66
  e)
67
- elif llm.model_type == LLMType.RERANK:
 
68
  mdl = RerankModel[factory](
69
  req["api_key"], llm.llm_name, base_url=req.get("base_url"))
70
  try:
71
- m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
72
- if len(arr[0]) == 0 or tc == 0:
73
  raise Exception("Fail")
74
  except Exception as e:
75
  msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
76
  e)
 
77
 
78
  if msg:
79
  return get_data_error_result(retmsg=msg)
 
39
  def set_api_key():
40
  req = request.json
41
  # test if api key works
42
+ chat_passed, embd_passed, rerank_passed = False, False, False
43
  factory = req["llm_factory"]
44
  msg = ""
45
  for llm in LLMService.query(fid=factory):
46
+ if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
47
  mdl = EmbeddingModel[factory](
48
  req["api_key"], llm.llm_name, base_url=req.get("base_url"))
49
  try:
50
  arr, tc = mdl.encode(["Test if the api key is available"])
51
  if len(arr[0]) == 0 or tc == 0:
52
  raise Exception("Fail")
53
+ embd_passed = True
54
  except Exception as e:
55
  msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
56
  elif not chat_passed and llm.model_type == LLMType.CHAT.value:
 
61
  "temperature": 0.9})
62
  if not tc:
63
  raise Exception(m)
 
64
  except Exception as e:
65
  msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
66
  e)
67
+ chat_passed = True
68
+ elif not rerank_passed and llm.model_type == LLMType.RERANK:
69
  mdl = RerankModel[factory](
70
  req["api_key"], llm.llm_name, base_url=req.get("base_url"))
71
  try:
72
+ arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
73
+ if len(arr) == 0 or tc == 0:
74
  raise Exception("Fail")
75
  except Exception as e:
76
  msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
77
  e)
78
+ rerank_passed = True
79
 
80
  if msg:
81
  return get_data_error_result(retmsg=msg)
api/db/services/llm_service.py CHANGED
@@ -147,7 +147,6 @@ class TenantLLMService(CommonService):
147
  .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
148
  .execute()
149
  except Exception as e:
150
- print(e)
151
  pass
152
  return num
153
 
 
147
  .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
148
  .execute()
149
  except Exception as e:
 
150
  pass
151
  return num
152
 
rag/llm/__init__.py CHANGED
@@ -28,6 +28,7 @@ EmbeddingModel = {
28
  "FastEmbed": FastEmbed,
29
  "Youdao": YoudaoEmbed,
30
  "BaiChuan": BaiChuanEmbed,
 
31
  "BAAI": DefaultEmbedding
32
  }
33
 
 
28
  "FastEmbed": FastEmbed,
29
  "Youdao": YoudaoEmbed,
30
  "BaiChuan": BaiChuanEmbed,
31
+ "Jina": JinaEmbed,
32
  "BAAI": DefaultEmbedding
33
  }
34
 
rag/llm/embedding_model.py CHANGED
@@ -291,7 +291,7 @@ class JinaEmbed(Base):
291
  "input": texts,
292
  'encoding_type': 'float'
293
  }
294
- res = requests.post(self.base_url, headers=self.headers, json=data)
295
  return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
296
 
297
  def encode_queries(self, text):
 
291
  "input": texts,
292
  'encoding_type': 'float'
293
  }
294
+ res = requests.post(self.base_url, headers=self.headers, json=data).json()
295
  return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
296
 
297
  def encode_queries(self, text):
rag/llm/rerank_model.py CHANGED
@@ -91,7 +91,7 @@ class JinaRerank(Base):
91
  "documents": texts,
92
  "top_n": len(texts)
93
  }
94
- res = requests.post(self.base_url, headers=self.headers, json=data)
95
  return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
96
 
97
 
 
91
  "documents": texts,
92
  "top_n": len(texts)
93
  }
94
+ res = requests.post(self.base_url, headers=self.headers, json=data).json()
95
  return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
96
 
97
 
rag/nlp/query.py CHANGED
@@ -44,7 +44,7 @@ class EsQueryer:
44
 
45
  def question(self, txt, tbl="qa", min_match="60%"):
46
  txt = re.sub(
47
- r"[ \r\n\t,,。??/`!!&\^%%]+",
48
  " ",
49
  rag_tokenizer.tradi2simp(
50
  rag_tokenizer.strQ2B(
 
44
 
45
  def question(self, txt, tbl="qa", min_match="60%"):
46
  txt = re.sub(
47
+ r"[ :\r\n\t,,。??/`!!&\^%%]+",
48
  " ",
49
  rag_tokenizer.tradi2simp(
50
  rag_tokenizer.strQ2B(
rag/nlp/term_weight.py CHANGED
@@ -104,7 +104,7 @@ class Dealer:
104
  while i < len(tks):
105
  j = i
106
  if i == 0 and oneTerm(tks[i]) and len(
107
- tks) > 1 and len(tks[i + 1]) > 1: # 多 工位
108
  res.append(" ".join(tks[0:2]))
109
  i = 2
110
  continue
 
104
  while i < len(tks):
105
  j = i
106
  if i == 0 and oneTerm(tks[i]) and len(
107
+ tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
108
  res.append(" ".join(tks[0:2]))
109
  i = 2
110
  continue