KevinHuSh
committed on
Commit
·
6ad2626
1
Parent(s):
7483940
fix jina adding issue and term weight refinement (#974)
Browse files

### What problem does this PR solve?
#724 #162
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/llm_app.py +8 -6
- api/db/services/llm_service.py +0 -1
- rag/llm/__init__.py +1 -0
- rag/llm/embedding_model.py +1 -1
- rag/llm/rerank_model.py +1 -1
- rag/nlp/query.py +1 -1
- rag/nlp/term_weight.py +1 -1
api/apps/llm_app.py
CHANGED
@@ -39,17 +39,18 @@ def factories():
|
|
39 |
def set_api_key():
|
40 |
req = request.json
|
41 |
# test if api key works
|
42 |
-
chat_passed = False
|
43 |
factory = req["llm_factory"]
|
44 |
msg = ""
|
45 |
for llm in LLMService.query(fid=factory):
|
46 |
-
if llm.model_type == LLMType.EMBEDDING.value:
|
47 |
mdl = EmbeddingModel[factory](
|
48 |
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
49 |
try:
|
50 |
arr, tc = mdl.encode(["Test if the api key is available"])
|
51 |
if len(arr[0]) == 0 or tc == 0:
|
52 |
raise Exception("Fail")
|
|
|
53 |
except Exception as e:
|
54 |
msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
|
55 |
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
|
@@ -60,20 +61,21 @@ def set_api_key():
|
|
60 |
"temperature": 0.9})
|
61 |
if not tc:
|
62 |
raise Exception(m)
|
63 |
-
chat_passed = True
|
64 |
except Exception as e:
|
65 |
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
66 |
e)
|
67 |
-
|
|
|
68 |
mdl = RerankModel[factory](
|
69 |
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
70 |
try:
|
71 |
-
|
72 |
-
if len(arr
|
73 |
raise Exception("Fail")
|
74 |
except Exception as e:
|
75 |
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
76 |
e)
|
|
|
77 |
|
78 |
if msg:
|
79 |
return get_data_error_result(retmsg=msg)
|
|
|
39 |
def set_api_key():
|
40 |
req = request.json
|
41 |
# test if api key works
|
42 |
+
chat_passed, embd_passed, rerank_passed = False, False, False
|
43 |
factory = req["llm_factory"]
|
44 |
msg = ""
|
45 |
for llm in LLMService.query(fid=factory):
|
46 |
+
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
|
47 |
mdl = EmbeddingModel[factory](
|
48 |
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
49 |
try:
|
50 |
arr, tc = mdl.encode(["Test if the api key is available"])
|
51 |
if len(arr[0]) == 0 or tc == 0:
|
52 |
raise Exception("Fail")
|
53 |
+
embd_passed = True
|
54 |
except Exception as e:
|
55 |
msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
|
56 |
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
|
|
|
61 |
"temperature": 0.9})
|
62 |
if not tc:
|
63 |
raise Exception(m)
|
|
|
64 |
except Exception as e:
|
65 |
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
66 |
e)
|
67 |
+
chat_passed = True
|
68 |
+
elif not rerank_passed and llm.model_type == LLMType.RERANK:
|
69 |
mdl = RerankModel[factory](
|
70 |
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
71 |
try:
|
72 |
+
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
|
73 |
+
if len(arr) == 0 or tc == 0:
|
74 |
raise Exception("Fail")
|
75 |
except Exception as e:
|
76 |
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
|
77 |
e)
|
78 |
+
rerank_passed = True
|
79 |
|
80 |
if msg:
|
81 |
return get_data_error_result(retmsg=msg)
|
api/db/services/llm_service.py
CHANGED
@@ -147,7 +147,6 @@ class TenantLLMService(CommonService):
|
|
147 |
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
|
148 |
.execute()
|
149 |
except Exception as e:
|
150 |
-
print(e)
|
151 |
pass
|
152 |
return num
|
153 |
|
|
|
147 |
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
|
148 |
.execute()
|
149 |
except Exception as e:
|
|
|
150 |
pass
|
151 |
return num
|
152 |
|
rag/llm/__init__.py
CHANGED
@@ -28,6 +28,7 @@ EmbeddingModel = {
|
|
28 |
"FastEmbed": FastEmbed,
|
29 |
"Youdao": YoudaoEmbed,
|
30 |
"BaiChuan": BaiChuanEmbed,
|
|
|
31 |
"BAAI": DefaultEmbedding
|
32 |
}
|
33 |
|
|
|
28 |
"FastEmbed": FastEmbed,
|
29 |
"Youdao": YoudaoEmbed,
|
30 |
"BaiChuan": BaiChuanEmbed,
|
31 |
+
"Jina": JinaEmbed,
|
32 |
"BAAI": DefaultEmbedding
|
33 |
}
|
34 |
|
rag/llm/embedding_model.py
CHANGED
@@ -291,7 +291,7 @@ class JinaEmbed(Base):
|
|
291 |
"input": texts,
|
292 |
'encoding_type': 'float'
|
293 |
}
|
294 |
-
res = requests.post(self.base_url, headers=self.headers, json=data)
|
295 |
return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
|
296 |
|
297 |
def encode_queries(self, text):
|
|
|
291 |
"input": texts,
|
292 |
'encoding_type': 'float'
|
293 |
}
|
294 |
+
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
295 |
return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
|
296 |
|
297 |
def encode_queries(self, text):
|
rag/llm/rerank_model.py
CHANGED
@@ -91,7 +91,7 @@ class JinaRerank(Base):
|
|
91 |
"documents": texts,
|
92 |
"top_n": len(texts)
|
93 |
}
|
94 |
-
res = requests.post(self.base_url, headers=self.headers, json=data)
|
95 |
return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
|
96 |
|
97 |
|
|
|
91 |
"documents": texts,
|
92 |
"top_n": len(texts)
|
93 |
}
|
94 |
+
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
95 |
return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
|
96 |
|
97 |
|
rag/nlp/query.py
CHANGED
@@ -44,7 +44,7 @@ class EsQueryer:
|
|
44 |
|
45 |
def question(self, txt, tbl="qa", min_match="60%"):
|
46 |
txt = re.sub(
|
47 |
-
r"[
|
48 |
" ",
|
49 |
rag_tokenizer.tradi2simp(
|
50 |
rag_tokenizer.strQ2B(
|
|
|
44 |
|
45 |
def question(self, txt, tbl="qa", min_match="60%"):
|
46 |
txt = re.sub(
|
47 |
+
r"[ :\r\n\t,,。??/`!!&\^%%]+",
|
48 |
" ",
|
49 |
rag_tokenizer.tradi2simp(
|
50 |
rag_tokenizer.strQ2B(
|
rag/nlp/term_weight.py
CHANGED
@@ -104,7 +104,7 @@ class Dealer:
|
|
104 |
while i < len(tks):
|
105 |
j = i
|
106 |
if i == 0 and oneTerm(tks[i]) and len(
|
107 |
-
tks) > 1 and len(tks[i + 1]) > 1: # 多 工位
|
108 |
res.append(" ".join(tks[0:2]))
|
109 |
i = 2
|
110 |
continue
|
|
|
104 |
while i < len(tks):
|
105 |
j = i
|
106 |
if i == 0 and oneTerm(tks[i]) and len(
|
107 |
+
tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
|
108 |
res.append(" ".join(tks[0:2]))
|
109 |
i = 2
|
110 |
continue
|