KevinHuSh committed · commit 34b2ab3 · 1 parent: 484e5ab
Test APIs and fix bugs (#41)
- api/apps/chunk_app.py +1 -1
- api/apps/conversation_app.py +5 -3
- api/apps/document_app.py +7 -3
- api/apps/llm_app.py +1 -1
- api/db/db_models.py +2 -2
- api/db/services/llm_service.py +1 -1
- api/utils/file_utils.py +1 -1
- rag/llm/chat_model.py +14 -6
- rag/llm/cv_model.py +3 -1
- rag/nlp/search.py +6 -3
- rag/svr/parse_user_docs.py +5 -5
api/apps/chunk_app.py
CHANGED
@@ -214,7 +214,7 @@ def retrieval_test():
     question = req["question"]
     kb_id = req["kb_id"]
     doc_ids = req.get("doc_ids", [])
-    similarity_threshold = float(req.get("similarity_threshold", 0.
+    similarity_threshold = float(req.get("similarity_threshold", 0.2))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
     top = int(req.get("top", 1024))
     try:
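The change completes the default for similarity_threshold. As a quick illustration (request payload hypothetical), knobs the client omits fall back to the defaults shown in the hunk:

import json

req = json.loads('{"question": "What is RAG?", "kb_id": "kb_42"}')
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top", 1024))
print(similarity_threshold, vector_similarity_weight, top)  # 0.2 0.3 1024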
api/apps/conversation_app.py
CHANGED
@@ -170,7 +170,7 @@ def chat(dialog, messages, **kwargs):
         if p["key"] not in kwargs:
             prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")

-    model_config = TenantLLMService.get_api_key(dialog.tenant_id,
+    model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id)
     if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))

     question = messages[-1]["content"]
@@ -186,10 +186,10 @@ def chat(dialog, messages, **kwargs):
     kwargs["knowledge"] = "\n".join(knowledges)
     gen_conf = dialog.llm_setting[dialog.llm_setting_type]
     msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
-    used_token_count = message_fit_in(msg, int(llm.max_tokens * 0.97))
+    used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
-    mdl = ChatModel[model_config.llm_factory](model_config
+    mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id)
     answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)

     answer = retrievaler.insert_citations(answer,
@@ -198,4 +198,6 @@ def chat(dialog, messages, **kwargs):
                                           embd_mdl,
                                           tkweight=1-dialog.vector_similarity_weight,
                                           vtweight=dialog.vector_similarity_weight)
+    for c in kbinfos["chunks"]:
+        if c.get("vector"):del c["vector"]
     return {"answer": answer, "retrieval": kbinfos}
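Two of these fixes are contract fixes: message_fit_in returns a pair (the token count and the possibly trimmed history), and the chat-model constructors now take (api_key, model_name). A minimal sketch of the message_fit_in contract; the body below is illustrative only, not the repo's implementation:

def message_fit_in(msg, max_length):
    # Illustrative stand-in: drop the oldest turns until a crude token
    # estimate fits the budget, then return (token_count, trimmed_history).
    count = lambda m: len(m["content"].split())
    msg = list(msg)
    while len(msg) > 1 and sum(map(count, msg)) > max_length:
        msg.pop(0)
    return sum(map(count, msg)), msg

# The call site must unpack both values, as the patched line now does:
used_token_count, msg = message_fit_in(
    [{"role": "user", "content": "hello " * 50}], 30)
print(used_token_count, len(msg))

The final hunk strips embedding vectors from the retrieved chunks before the response is serialized, keeping the JSON payload small.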
api/apps/document_app.py
CHANGED
@@ -11,7 +11,8 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License
+# limitations under the License
+#
 #
 import base64
 import pathlib
@@ -65,7 +66,7 @@ def upload():
     while MINIO.obj_exist(kb_id, location):
         location += "_"
     blob = request.files['file'].read()
-    MINIO.put(kb_id,
+    MINIO.put(kb_id, location, blob)
     doc = DocumentService.insert({
         "id": get_uuid(),
         "kb_id": kb.id,
@@ -188,7 +189,10 @@ def rm():
     e, doc = DocumentService.get_by_id(req["doc_id"])
     if not e:
         return get_data_error_result(retmsg="Document not found!")
-
+    tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+    if not tenant_id:
+        return get_data_error_result(retmsg="Tenant not found!")
+    ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))

     DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
     if not DocumentService.delete_by_id(req["doc_id"]):
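The rm() fix deletes the document's chunks from the tenant's Elasticsearch index before the DB row goes away, so no orphaned chunks survive in search. A sketch of the query construction (document id hypothetical; ELASTICSEARCH.deleteByQuery is the project's own wrapper):

from elasticsearch_dsl import Q

q = Q("match", doc_id="3fa85f64-doc")
print(q.to_dict())  # {'match': {'doc_id': '3fa85f64-doc'}}
# ELASTICSEARCH.deleteByQuery(q, idxnm=search.index_name(tenant_id))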
api/apps/llm_app.py
CHANGED
@@ -75,7 +75,7 @@ def list():
     llms = LLMService.get_all()
     llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
     for m in llms:
-        m["available"] = m
+        m["available"] = m["llm_name"] in mdlnms

     res = {}
     for m in llms:
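The old line assigned the whole dict to its own "available" key; the fix makes it a membership test against the tenant's configured model names. In isolation (sample data hypothetical):

mdlnms = {"gpt-3.5-turbo", "qwen-turbo"}
llms = [{"llm_name": "gpt-3.5-turbo"}, {"llm_name": "gpt-4-vision-preview"}]
for m in llms:
    m["available"] = m["llm_name"] in mdlnms
print([m["available"] for m in llms])  # [True, False]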
api/db/db_models.py
CHANGED
@@ -469,7 +469,7 @@ class Knowledgebase(DataBaseModel):
     doc_num = IntegerField(default=0)
     token_num = IntegerField(default=0)
     chunk_num = IntegerField(default=0)
-    similarity_threshold = FloatField(default=0.
+    similarity_threshold = FloatField(default=0.2)
     vector_similarity_weight = FloatField(default=0.3)

     parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
@@ -521,7 +521,7 @@ class Dialog(DataBaseModel):
     prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
                                                    "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})

-    similarity_threshold = FloatField(default=0.
+    similarity_threshold = FloatField(default=0.2)
     vector_similarity_weight = FloatField(default=0.3)
     top_n = IntegerField(default=6)
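Both models get the same completed default. A self-contained peewee check of the repaired fields (in-memory table, model abbreviated):

from peewee import Model, FloatField, SqliteDatabase

db = SqliteDatabase(":memory:")

class Dialog(Model):
    similarity_threshold = FloatField(default=0.2)
    vector_similarity_weight = FloatField(default=0.3)
    class Meta:
        database = db

db.create_tables([Dialog])
print(Dialog.create().similarity_threshold)  # 0.2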
api/db/services/llm_service.py
CHANGED
@@ -63,7 +63,7 @@ class TenantLLMService(CommonService):

         model_config = cls.get_api_key(tenant_id, mdlnm)
         if not model_config: raise LookupError("Model({}) not found".format(mdlnm))
-        model_config = model_config
+        model_config = model_config.to_dict()
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel: return
             return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
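The self-assignment was a no-op; the fix converts the peewee row to a dict so the model_config["llm_factory"] subscripts below it work. A standalone sketch using peewee's generic converter (the repo's to_dict() is assumed to behave the same way):

from peewee import Model, CharField, SqliteDatabase
from playhouse.shortcuts import model_to_dict

db = SqliteDatabase(":memory:")

class TenantLLM(Model):
    llm_factory = CharField()
    llm_name = CharField()
    class Meta:
        database = db

db.create_tables([TenantLLM])
row = TenantLLM.create(llm_factory="OpenAI", llm_name="gpt-3.5-turbo")
cfg = model_to_dict(row)
print(cfg["llm_factory"])  # rows allow row.llm_factory, but not row["llm_factory"]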
api/utils/file_utils.py
CHANGED
@@ -143,7 +143,7 @@ def filename_type(filename):
     if re.match(r".*\.pdf$", filename):
         return FileType.PDF.value

-    if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
+    if re.match(r".*\.(docx|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
         return FileType.DOC.value

     if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
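Because the pattern is anchored with $, ".docx" never matched the old alternation: after "doc" matches, the trailing "x" blocks the anchor, so .docx uploads fell through to the wrong type. Verifying with an abbreviated pattern:

import re

old = r".*\.(doc|ppt|md)$"
new = r".*\.(docx|doc|ppt|md)$"
print(bool(re.match(old, "report.docx")))  # False
print(bool(re.match(new, "report.docx")))  # True
print(bool(re.match(new, "report.doc")))   # True: plain .doc still matches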
rag/llm/chat_model.py
CHANGED
@@ -19,31 +19,39 @@ import os


 class Base(ABC):
+    def __init__(self, key, model_name):
+        pass
+
     def chat(self, system, history, gen_conf):
         raise NotImplementedError("Please implement encode method!")


 class GptTurbo(Base):
-    def __init__(self):
-        self.client = OpenAI(api_key=
+    def __init__(self, key, model_name="gpt-3.5-turbo"):
+        self.client = OpenAI(api_key=key)
+        self.model_name = model_name

     def chat(self, system, history, gen_conf):
         history.insert(0, {"role": "system", "content": system})
         res = self.client.chat.completions.create(
-            model=
+            model=self.model_name,
             messages=history,
             **gen_conf)
         return res.choices[0].message.content.strip()


+from dashscope import Generation
 class QWenChat(Base):
+    def __init__(self, key, model_name=Generation.Models.qwen_turbo):
+        import dashscope
+        dashscope.api_key = key
+        self.model_name = model_name
+
     def chat(self, system, history, gen_conf):
         from http import HTTPStatus
-        from dashscope import Generation
-        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
         history.insert(0, {"role": "system", "content": system})
         response = Generation.call(
-
+            self.model_name,
             messages=history,
             result_format='message'
         )
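The refactor gives every chat model a uniform (key, model_name) constructor, so the factory-dict call site in conversation_app.py can stay generic. A sketch of that consumption; the factory keys ("OpenAI", "Tongyi-Qianwen") are assumptions here, not confirmed by this diff:

ChatModel = {"OpenAI": GptTurbo, "Tongyi-Qianwen": QWenChat}

mdl = ChatModel["OpenAI"]("sk-...", "gpt-3.5-turbo")   # key, model_name
answer = mdl.chat("You are a terse assistant.",
                  [{"role": "user", "content": "ping"}],
                  {"temperature": 0.7})

Note that "from dashscope import Generation" has to sit at module level because QWenChat's default argument Generation.Models.qwen_turbo is evaluated at class-definition time; the trade-off is that dashscope becomes a hard import dependency of this module.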
rag/llm/cv_model.py
CHANGED
@@ -28,6 +28,8 @@ class Base(ABC):
         raise NotImplementedError("Please implement encode method!")

     def image2base64(self, image):
+        if isinstance(image, bytes):
+            return base64.b64encode(image).decode("utf-8")
         if isinstance(image, BytesIO):
             return base64.b64encode(image.getvalue()).decode("utf-8")
         buffered = BytesIO()
@@ -59,7 +61,7 @@ class Base(ABC):

 class GptV4(Base):
     def __init__(self, key, model_name="gpt-4-vision-preview"):
-        self.client = OpenAI(key)
+        self.client = OpenAI(api_key = key)
         self.model_name = model_name

     def describe(self, image, max_tokens=300):
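OpenAI(key) was a bug because the v1 OpenAI client takes api_key as a keyword-only argument, so the positional call raises TypeError. The other fix lets image2base64 short-circuit raw bytes; a quick check that the new branch agrees with the existing BytesIO branch (image bytes are a stand-in):

import base64
from io import BytesIO

raw = b"\x89PNG\r\n..."  # stand-in for real image bytes
as_bytes = base64.b64encode(raw).decode("utf-8")
as_bytesio = base64.b64encode(BytesIO(raw).getvalue()).decode("utf-8")
print(as_bytes == as_bytesio)  # True: both branches agree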
rag/nlp/search.py
CHANGED
@@ -187,9 +187,10 @@ class Dealer:
             if len(t) < 5: continue
             idx.append(i)
             pieces_.append(t)
+        es_logger.info("{} => {}".format(answer, pieces_))
         if not pieces_: return answer

-        ans_v = embd_mdl.encode(pieces_)
+        ans_v, c = embd_mdl.encode(pieces_)
         assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
             len(ans_v[0]), len(chunk_v[0]))

@@ -219,7 +220,7 @@ class Dealer:
             Dealer.trans2floats(
                 sres.field[i]["q_%d_vec" % len(sres.query_vector)]) for i in sres.ids]
         if not ins_embd:
-            return []
+            return [], [], []
         ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") for i in sres.ids]
         sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                         ins_embd,
@@ -235,6 +236,8 @@ class Dealer:

     def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
                   vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
+        ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
+        if not question: return ranks
         req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top,
               "question": question, "vector": True,
               "similarity": similarity_threshold}
@@ -243,7 +246,7 @@ class Dealer:
         sim, tsim, vsim = self.rerank(
             sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
         idx = np.argsort(sim * -1)
-
+
         dim = len(sres.query_vector)
         start_idx = (page - 1) * page_size
         for i in idx:
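Two of these are shape fixes: embd_mdl.encode() returns (vectors, token_count), and rerank() callers unpack three parallel similarity arrays, so the empty case must return three empties. Unpacking a single empty list into three names is exactly the crash the old "return []" caused:

import numpy as np

def rerank_empty():
    return [], [], []              # patched empty case

sim, tsim, vsim = rerank_empty()   # "sim, tsim, vsim = []" raises ValueError
idx = np.argsort(np.array([0.7, 0.9, 0.1]) * -1)
print(idx.tolist())                # [1, 0, 2]: descending-similarity order

The retrieval() guard also gives empty questions a well-formed empty result instead of running a pointless search.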
rag/svr/parse_user_docs.py
CHANGED
@@ -78,6 +78,7 @@ def chuck_doc(name, binary, cvmdl=None):
         field = TextChunker.Fields()
         field.text_chunks = [(txt, binary)]
         field.table_chunks = []
+        return field

     return TextChunker()(binary)

@@ -161,9 +162,9 @@ def build(row, cvmdl):
     doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
     output_buffer = BytesIO()
     docs = []
-    md5 = hashlib.md5()
     for txt, img in obj.text_chunks:
         d = copy.deepcopy(doc)
+        md5 = hashlib.md5()
         md5.update((txt + str(d["doc_id"])).encode("utf-8"))
         d["_id"] = md5.hexdigest()
         d["content_ltks"] = huqie.qie(txt)

@@ -186,6 +187,7 @@ def build(row, cvmdl):
     for i, txt in enumerate(arr):
         d = copy.deepcopy(doc)
         d["content_ltks"] = huqie.qie(txt)
+        md5 = hashlib.md5()
         md5.update((txt + str(d["doc_id"])).encode("utf-8"))
         d["_id"] = md5.hexdigest()
         if not img:

@@ -226,9 +228,6 @@ def embedding(docs, mdl):


 def main(comm, mod):
-    global model
-    from rag.llm import HuEmbedding
-    model = HuEmbedding()
     tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
     tm = findMaxTm(tm_fnm)
     rows = collect(comm, mod, tm)

@@ -260,13 +259,14 @@ def main(comm, mod):
         set_progress(r["id"], random.randint(70, 95) / 100.,
                      "Finished embedding! Start to build index!")
         init_kb(r)
+        chunk_count = len(set([c["_id"] for c in cks]))
         es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
        if es_r:
             set_progress(r["id"], -1, "Index failure!")
             cron_logger.error(str(es_r))
         else:
             set_progress(r["id"], 1., "Done!")
-            DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count,
+            DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, chunk_count, timer()-st_tm)
             cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))

         tmf.write(str(r["update_time"]) + "\n")
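The md5 fix is the subtle one. The hash object used to be created once before the loop, and hashlib digests accumulate across update() calls, so each chunk's _id depended on every chunk hashed before it. Creating a fresh object per chunk makes the id a pure function of (txt, doc_id), which is also what makes the new chunk_count deduplication by _id meaningful:

import hashlib

doc_id, chunks = "doc-1", ["alpha", "beta"]

# Buggy: one shared md5 accumulates across update() calls, so chunk 2's id
# depends on chunk 1's text.
shared = hashlib.md5()
chained = []
for txt in chunks:
    shared.update((txt + doc_id).encode("utf-8"))
    chained.append(shared.hexdigest())

# Fixed: a fresh object per chunk gives a stable, content-addressed id.
stable = [hashlib.md5((txt + doc_id).encode("utf-8")).hexdigest()
          for txt in chunks]

print(chained[0] == stable[0])  # True: first chunk unaffected
print(chained[1] == stable[1])  # False: shared state contaminated chunk 2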