KevinHuSh
committed on
Commit
·
a49657b
1
Parent(s):
13080d4
add self-rag (#1070)
Browse files

### What problem does this PR solve?
#1069
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/api_app.py +24 -20
- api/apps/canvas_app.py +112 -0
- api/apps/conversation_app.py +3 -2
- api/apps/dialog_app.py +2 -2
- api/db/services/canvas_service.py +26 -0
- api/db/services/dialog_service.py +59 -10
- deepdoc/parser/pdf_parser.py +2 -0
- rag/llm/rerank_model.py +0 -1
- rag/nlp/query.py +6 -3
api/apps/api_app.py
CHANGED
|
@@ -198,15 +198,18 @@ def completion():
|
|
| 198 |
else: conv.reference[-1] = ans["reference"]
|
| 199 |
conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
|
| 200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
def stream():
|
| 202 |
nonlocal dia, msg, req, conv
|
| 203 |
try:
|
| 204 |
for ans in chat(dia, msg, True, **req):
|
| 205 |
fillin_conv(ans)
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
chunk_i.pop('docnm_kwd')
|
| 209 |
-
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
| 210 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
| 211 |
except Exception as e:
|
| 212 |
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
|
@@ -554,23 +557,24 @@ def completion_faq():
|
|
| 554 |
"content": ""
|
| 555 |
}
|
| 556 |
]
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
fillin_conv(ans)
|
| 561 |
-
API4ConversationService.append_message(conv.id, conv.to_dict())
|
| 562 |
-
|
| 563 |
-
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
| 564 |
-
for chunk_idx in chunk_idxs[:1]:
|
| 565 |
-
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
| 566 |
-
try:
|
| 567 |
-
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
| 568 |
-
response = MINIO.get(bkt, nm)
|
| 569 |
-
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
| 570 |
-
data.append(data_type_picture)
|
| 571 |
-
except Exception as e:
|
| 572 |
-
return server_error_response(e)
|
| 573 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
|
| 575 |
response = {"code": 200, "msg": "success", "data": data}
|
| 576 |
return response
|
|
|
|
| 198 |
else: conv.reference[-1] = ans["reference"]
|
| 199 |
conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
|
| 200 |
|
| 201 |
+
def rename_field(ans):
|
| 202 |
+
for chunk_i in ans['reference'].get('chunks', []):
|
| 203 |
+
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
| 204 |
+
chunk_i.pop('docnm_kwd')
|
| 205 |
+
|
| 206 |
def stream():
|
| 207 |
nonlocal dia, msg, req, conv
|
| 208 |
try:
|
| 209 |
for ans in chat(dia, msg, True, **req):
|
| 210 |
fillin_conv(ans)
|
| 211 |
+
rename_field(rename_field)
|
| 212 |
+
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
|
|
|
|
|
|
| 213 |
API4ConversationService.append_message(conv.id, conv.to_dict())
|
| 214 |
except Exception as e:
|
| 215 |
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
|
|
|
| 557 |
"content": ""
|
| 558 |
}
|
| 559 |
]
|
| 560 |
+
ans = ""
|
| 561 |
+
for a in chat(dia, msg, stream=False, **req):
|
| 562 |
+
ans = a
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
break
|
| 564 |
+
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
| 565 |
+
fillin_conv(ans)
|
| 566 |
+
API4ConversationService.append_message(conv.id, conv.to_dict())
|
| 567 |
+
|
| 568 |
+
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
| 569 |
+
for chunk_idx in chunk_idxs[:1]:
|
| 570 |
+
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
| 571 |
+
try:
|
| 572 |
+
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
| 573 |
+
response = MINIO.get(bkt, nm)
|
| 574 |
+
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
| 575 |
+
data.append(data_type_picture)
|
| 576 |
+
except Exception as e:
|
| 577 |
+
return server_error_response(e)
|
| 578 |
|
| 579 |
response = {"code": 200, "msg": "success", "data": data}
|
| 580 |
return response
|
api/apps/canvas_app.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
#
|
| 16 |
+
import json
|
| 17 |
+
|
| 18 |
+
from flask import request
|
| 19 |
+
from flask_login import login_required, current_user
|
| 20 |
+
|
| 21 |
+
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
| 22 |
+
from api.utils import get_uuid
|
| 23 |
+
from api.utils.api_utils import get_json_result, server_error_response, validate_request
|
| 24 |
+
from graph.canvas import Canvas
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@manager.route('/templates', methods=['GET'])
|
| 28 |
+
@login_required
|
| 29 |
+
def templates():
|
| 30 |
+
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()])
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@manager.route('/list', methods=['GET'])
|
| 34 |
+
@login_required
|
| 35 |
+
def canvas_list():
|
| 36 |
+
|
| 37 |
+
return get_json_result(data=[c.to_dict() for c in UserCanvasService.query(user_id=current_user.id)])
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@manager.route('/rm', methods=['POST'])
|
| 41 |
+
@validate_request("canvas_ids")
|
| 42 |
+
@login_required
|
| 43 |
+
def rm():
|
| 44 |
+
for i in request.json["canvas_ids"]:
|
| 45 |
+
UserCanvasService.delete_by_id(i)
|
| 46 |
+
return get_json_result(data=True)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@manager.route('/set', methods=['POST'])
|
| 50 |
+
@validate_request("dsl", "title")
|
| 51 |
+
@login_required
|
| 52 |
+
def save():
|
| 53 |
+
req = request.json
|
| 54 |
+
req["user_id"] = current_user.id
|
| 55 |
+
if not isinstance(req["dsl"], str):req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
| 56 |
+
try:
|
| 57 |
+
Canvas(req["dsl"])
|
| 58 |
+
except Exception as e:
|
| 59 |
+
return server_error_response(e)
|
| 60 |
+
|
| 61 |
+
req["dsl"] = json.loads(req["dsl"])
|
| 62 |
+
if "id" not in req:
|
| 63 |
+
req["id"] = get_uuid()
|
| 64 |
+
if not UserCanvasService.save(**req):
|
| 65 |
+
return server_error_response("Fail to save canvas.")
|
| 66 |
+
else:
|
| 67 |
+
UserCanvasService.update_by_id(req["id"], req)
|
| 68 |
+
|
| 69 |
+
return get_json_result(data=req)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@manager.route('/get/<canvas_id>', methods=['GET'])
|
| 73 |
+
@login_required
|
| 74 |
+
def get(canvas_id):
|
| 75 |
+
e, c = UserCanvasService.get_by_id(canvas_id)
|
| 76 |
+
if not e:
|
| 77 |
+
return server_error_response("canvas not found.")
|
| 78 |
+
return get_json_result(data=c.to_dict())
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@manager.route('/run', methods=['POST'])
|
| 82 |
+
@validate_request("id", "dsl")
|
| 83 |
+
@login_required
|
| 84 |
+
def run():
|
| 85 |
+
req = request.json
|
| 86 |
+
if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
| 87 |
+
try:
|
| 88 |
+
canvas = Canvas(req["dsl"], current_user.id)
|
| 89 |
+
ans = canvas.run()
|
| 90 |
+
req["dsl"] = json.loads(str(canvas))
|
| 91 |
+
UserCanvasService.update_by_id(req["id"], dsl=req["dsl"])
|
| 92 |
+
return get_json_result(data=req["dsl"])
|
| 93 |
+
except Exception as e:
|
| 94 |
+
return server_error_response(e)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@manager.route('/reset', methods=['POST'])
|
| 98 |
+
@validate_request("canvas_id")
|
| 99 |
+
@login_required
|
| 100 |
+
def reset():
|
| 101 |
+
req = request.json
|
| 102 |
+
try:
|
| 103 |
+
user_canvas = UserCanvasService.get_by_id(req["canvas_id"])
|
| 104 |
+
canvas = Canvas(req["dsl"], current_user.id)
|
| 105 |
+
canvas.reset()
|
| 106 |
+
req["dsl"] = json.loads(str(canvas))
|
| 107 |
+
UserCanvasService.update_by_id(req["canvas_id"], dsl=req["dsl"])
|
| 108 |
+
return get_json_result(data=req["dsl"])
|
| 109 |
+
except Exception as e:
|
| 110 |
+
return server_error_response(e)
|
| 111 |
+
|
| 112 |
+
|
api/apps/conversation_app.py
CHANGED
|
@@ -13,7 +13,8 @@
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
-
from
|
|
|
|
| 17 |
from flask_login import login_required
|
| 18 |
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
| 19 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
|
@@ -121,7 +122,7 @@ def completion():
|
|
| 121 |
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
| 122 |
if not e:
|
| 123 |
return get_data_error_result(retmsg="Conversation not found!")
|
| 124 |
-
conv.message.append(msg[-1])
|
| 125 |
e, dia = DialogService.get_by_id(conv.dialog_id)
|
| 126 |
if not e:
|
| 127 |
return get_data_error_result(retmsg="Dialog not found!")
|
|
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
+
from copy import deepcopy
|
| 17 |
+
from flask import request, Response
|
| 18 |
from flask_login import login_required
|
| 19 |
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
| 20 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
|
|
|
| 122 |
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
| 123 |
if not e:
|
| 124 |
return get_data_error_result(retmsg="Conversation not found!")
|
| 125 |
+
conv.message.append(deepcopy(msg[-1]))
|
| 126 |
e, dia = DialogService.get_by_id(conv.dialog_id)
|
| 127 |
if not e:
|
| 128 |
return get_data_error_result(retmsg="Dialog not found!")
|
api/apps/dialog_app.py
CHANGED
|
@@ -31,8 +31,8 @@ def set_dialog():
|
|
| 31 |
req = request.json
|
| 32 |
dialog_id = req.get("dialog_id")
|
| 33 |
name = req.get("name", "New Dialog")
|
| 34 |
-
icon = req.get("icon", "")
|
| 35 |
description = req.get("description", "A helpful Dialog")
|
|
|
|
| 36 |
top_n = req.get("top_n", 6)
|
| 37 |
top_k = req.get("top_k", 1024)
|
| 38 |
rerank_id = req.get("rerank_id", "")
|
|
@@ -92,7 +92,7 @@ def set_dialog():
|
|
| 92 |
"rerank_id": rerank_id,
|
| 93 |
"similarity_threshold": similarity_threshold,
|
| 94 |
"vector_similarity_weight": vector_similarity_weight,
|
| 95 |
-
"icon": icon
|
| 96 |
}
|
| 97 |
if not DialogService.save(**dia):
|
| 98 |
return get_data_error_result(retmsg="Fail to new a dialog!")
|
|
|
|
| 31 |
req = request.json
|
| 32 |
dialog_id = req.get("dialog_id")
|
| 33 |
name = req.get("name", "New Dialog")
|
|
|
|
| 34 |
description = req.get("description", "A helpful Dialog")
|
| 35 |
+
icon = req.get("icon", "")
|
| 36 |
top_n = req.get("top_n", 6)
|
| 37 |
top_k = req.get("top_k", 1024)
|
| 38 |
rerank_id = req.get("rerank_id", "")
|
|
|
|
| 92 |
"rerank_id": rerank_id,
|
| 93 |
"similarity_threshold": similarity_threshold,
|
| 94 |
"vector_similarity_weight": vector_similarity_weight,
|
| 95 |
+
"icon": icon
|
| 96 |
}
|
| 97 |
if not DialogService.save(**dia):
|
| 98 |
return get_data_error_result(retmsg="Fail to new a dialog!")
|
api/db/services/canvas_service.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
#
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
import peewee
|
| 18 |
+
from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
|
| 19 |
+
from api.db.services.common_service import CommonService
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class CanvasTemplateService(CommonService):
|
| 23 |
+
model = CanvasTemplate
|
| 24 |
+
|
| 25 |
+
class UserCanvasService(CommonService):
|
| 26 |
+
model = UserCanvas
|
api/db/services/dialog_service.py
CHANGED
|
@@ -23,6 +23,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
| 23 |
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
| 24 |
from api.settings import chat_logger, retrievaler
|
| 25 |
from rag.app.resume import forbidden_select_fields4resume
|
|
|
|
| 26 |
from rag.nlp.search import index_name
|
| 27 |
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
| 28 |
|
|
@@ -80,7 +81,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 80 |
if not llm:
|
| 81 |
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
| 82 |
max_tokens = 1024
|
| 83 |
-
else:
|
|
|
|
| 84 |
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
| 85 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
| 86 |
if len(embd_nms) != 1:
|
|
@@ -124,6 +126,16 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 124 |
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
| 125 |
top=1024, aggs=False, rerank_mdl=rerank_mdl)
|
| 126 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
chat_logger.info(
|
| 128 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
| 129 |
|
|
@@ -136,7 +148,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 136 |
|
| 137 |
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
| 138 |
msg.extend([{"role": m["role"], "content": m["content"]}
|
| 139 |
-
|
| 140 |
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
|
| 141 |
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
| 142 |
|
|
@@ -150,9 +162,9 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 150 |
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
| 151 |
answer, idx = retrievaler.insert_citations(answer,
|
| 152 |
[ck["content_ltks"]
|
| 153 |
-
|
| 154 |
[ck["vector"]
|
| 155 |
-
|
| 156 |
embd_mdl,
|
| 157 |
tkweight=1 - dialog.vector_similarity_weight,
|
| 158 |
vtweight=dialog.vector_similarity_weight)
|
|
@@ -166,7 +178,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 166 |
for c in refs["chunks"]:
|
| 167 |
if c.get("vector"):
|
| 168 |
del c["vector"]
|
| 169 |
-
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
|
| 170 |
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
| 171 |
return {"answer": answer, "reference": refs}
|
| 172 |
|
|
@@ -204,7 +216,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
| 204 |
def get_table():
|
| 205 |
nonlocal sys_prompt, user_promt, question, tried_times
|
| 206 |
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
|
| 207 |
-
|
| 208 |
print(user_promt, sql)
|
| 209 |
chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
|
| 210 |
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
|
@@ -273,17 +285,19 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
| 273 |
|
| 274 |
# compose markdown table
|
| 275 |
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
| 276 |
-
|
|
|
|
| 277 |
|
| 278 |
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
|
| 279 |
-
|
| 280 |
|
| 281 |
rows = ["|" +
|
| 282 |
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
| 283 |
"|" for r in tbl["rows"]]
|
| 284 |
if quota:
|
| 285 |
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
| 286 |
-
else:
|
|
|
|
| 287 |
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
|
| 288 |
|
| 289 |
if not docid_idx or not docnm_idx:
|
|
@@ -303,5 +317,40 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
| 303 |
return {
|
| 304 |
"answer": "\n".join([clmns, line, rows]),
|
| 305 |
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
|
| 306 |
-
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
|
|
|
|
| 307 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
| 24 |
from api.settings import chat_logger, retrievaler
|
| 25 |
from rag.app.resume import forbidden_select_fields4resume
|
| 26 |
+
from rag.nlp.rag_tokenizer import is_chinese
|
| 27 |
from rag.nlp.search import index_name
|
| 28 |
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
| 29 |
|
|
|
|
| 81 |
if not llm:
|
| 82 |
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
| 83 |
max_tokens = 1024
|
| 84 |
+
else:
|
| 85 |
+
max_tokens = llm[0].max_tokens
|
| 86 |
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
| 87 |
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
| 88 |
if len(embd_nms) != 1:
|
|
|
|
| 126 |
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
| 127 |
top=1024, aggs=False, rerank_mdl=rerank_mdl)
|
| 128 |
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
| 129 |
+
#self-rag
|
| 130 |
+
if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
|
| 131 |
+
questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
|
| 132 |
+
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
| 133 |
+
dialog.similarity_threshold,
|
| 134 |
+
dialog.vector_similarity_weight,
|
| 135 |
+
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
| 136 |
+
top=1024, aggs=False, rerank_mdl=rerank_mdl)
|
| 137 |
+
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
| 138 |
+
|
| 139 |
chat_logger.info(
|
| 140 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
| 141 |
|
|
|
|
| 148 |
|
| 149 |
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
| 150 |
msg.extend([{"role": m["role"], "content": m["content"]}
|
| 151 |
+
for m in messages if m["role"] != "system"])
|
| 152 |
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
|
| 153 |
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
| 154 |
|
|
|
|
| 162 |
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
| 163 |
answer, idx = retrievaler.insert_citations(answer,
|
| 164 |
[ck["content_ltks"]
|
| 165 |
+
for ck in kbinfos["chunks"]],
|
| 166 |
[ck["vector"]
|
| 167 |
+
for ck in kbinfos["chunks"]],
|
| 168 |
embd_mdl,
|
| 169 |
tkweight=1 - dialog.vector_similarity_weight,
|
| 170 |
vtweight=dialog.vector_similarity_weight)
|
|
|
|
| 178 |
for c in refs["chunks"]:
|
| 179 |
if c.get("vector"):
|
| 180 |
del c["vector"]
|
| 181 |
+
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
| 182 |
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
| 183 |
return {"answer": answer, "reference": refs}
|
| 184 |
|
|
|
|
| 216 |
def get_table():
|
| 217 |
nonlocal sys_prompt, user_promt, question, tried_times
|
| 218 |
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
|
| 219 |
+
"temperature": 0.06})
|
| 220 |
print(user_promt, sql)
|
| 221 |
chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
|
| 222 |
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
|
|
|
| 285 |
|
| 286 |
# compose markdown table
|
| 287 |
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
| 288 |
+
tbl["columns"][i]["name"])) for i in
|
| 289 |
+
clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
| 290 |
|
| 291 |
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
|
| 292 |
+
("|------|" if docid_idx and docid_idx else "")
|
| 293 |
|
| 294 |
rows = ["|" +
|
| 295 |
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
| 296 |
"|" for r in tbl["rows"]]
|
| 297 |
if quota:
|
| 298 |
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
| 299 |
+
else:
|
| 300 |
+
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
| 301 |
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
|
| 302 |
|
| 303 |
if not docid_idx or not docnm_idx:
|
|
|
|
| 317 |
return {
|
| 318 |
"answer": "\n".join([clmns, line, rows]),
|
| 319 |
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
|
| 320 |
+
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
|
| 321 |
+
doc_aggs.items()]}
|
| 322 |
}
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def relevant(tenant_id, llm_id, question, contents: list):
|
| 326 |
+
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
| 327 |
+
prompt = """
|
| 328 |
+
You are a grader assessing relevance of a retrieved document to a user question.
|
| 329 |
+
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
|
| 330 |
+
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
|
| 331 |
+
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
|
| 332 |
+
No other words needed except 'yes' or 'no'.
|
| 333 |
+
"""
|
| 334 |
+
if not contents:return False
|
| 335 |
+
contents = "Documents: \n" + " - ".join(contents)
|
| 336 |
+
contents = f"Question: {question}\n" + contents
|
| 337 |
+
if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
|
| 338 |
+
contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
|
| 339 |
+
ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
|
| 340 |
+
if ans.lower().find("yes") >= 0: return True
|
| 341 |
+
return False
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def rewrite(tenant_id, llm_id, question):
|
| 345 |
+
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
| 346 |
+
prompt = """
|
| 347 |
+
You are an expert at query expansion to generate a paraphrasing of a question.
|
| 348 |
+
I can't retrieval relevant information from the knowledge base by using user's question directly.
|
| 349 |
+
You need to expand or paraphrase user's question by multiple ways such as using synonyms words/phrase,
|
| 350 |
+
writing the abbreviation in its entirety, adding some extra descriptions or explanations,
|
| 351 |
+
changing the way of expression, translating the original question into another language (English/Chinese), etc.
|
| 352 |
+
And return 5 versions of question and one is from translation.
|
| 353 |
+
Just list the question. No other words are needed.
|
| 354 |
+
"""
|
| 355 |
+
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
|
| 356 |
+
return ans
|
deepdoc/parser/pdf_parser.py
CHANGED
|
@@ -1021,6 +1021,8 @@ class RAGFlowPdfParser:
|
|
| 1021 |
|
| 1022 |
self.page_cum_height = np.cumsum(self.page_cum_height)
|
| 1023 |
assert len(self.page_cum_height) == len(self.page_images) + 1
|
|
|
|
|
|
|
| 1024 |
|
| 1025 |
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
|
| 1026 |
self.__images__(fnm, zoomin)
|
|
|
|
| 1021 |
|
| 1022 |
self.page_cum_height = np.cumsum(self.page_cum_height)
|
| 1023 |
assert len(self.page_cum_height) == len(self.page_images) + 1
|
| 1024 |
+
if len(self.boxes) == 0 and zoomin < 9: self.__images__(fnm, zoomin * 3, page_from,
|
| 1025 |
+
page_to, callback)
|
| 1026 |
|
| 1027 |
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
|
| 1028 |
self.__images__(fnm, zoomin)
|
rag/llm/rerank_model.py
CHANGED
|
@@ -129,4 +129,3 @@ class YoudaoRerank(DefaultRerank):
|
|
| 129 |
return np.array(res), token_count
|
| 130 |
|
| 131 |
|
| 132 |
-
|
|
|
|
| 129 |
return np.array(res), token_count
|
| 130 |
|
| 131 |
|
|
|
rag/nlp/query.py
CHANGED
|
@@ -48,7 +48,7 @@ class EsQueryer:
|
|
| 48 |
@staticmethod
|
| 49 |
def rmWWW(txt):
|
| 50 |
patts = [
|
| 51 |
-
(r"是*(
|
| 52 |
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
| 53 |
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
|
| 54 |
]
|
|
@@ -68,7 +68,9 @@ class EsQueryer:
|
|
| 68 |
if not self.isChinese(txt):
|
| 69 |
tks = rag_tokenizer.tokenize(txt).split(" ")
|
| 70 |
tks_w = self.tw.weights(tks)
|
| 71 |
-
tks_w = [(re.sub(r"[ \\\"']
|
|
|
|
|
|
|
| 72 |
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
|
| 73 |
for i in range(1, len(tks_w)):
|
| 74 |
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
|
@@ -118,7 +120,8 @@ class EsQueryer:
|
|
| 118 |
if sm:
|
| 119 |
tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
|
| 120 |
" ".join(sm), " ".join(sm))
|
| 121 |
-
|
|
|
|
| 122 |
|
| 123 |
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
| 124 |
|
|
|
|
| 48 |
@staticmethod
|
| 49 |
def rmWWW(txt):
|
| 50 |
patts = [
|
| 51 |
+
(r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
|
| 52 |
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
| 53 |
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
|
| 54 |
]
|
|
|
|
| 68 |
if not self.isChinese(txt):
|
| 69 |
tks = rag_tokenizer.tokenize(txt).split(" ")
|
| 70 |
tks_w = self.tw.weights(tks)
|
| 71 |
+
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
| 72 |
+
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
|
| 73 |
+
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
| 74 |
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
|
| 75 |
for i in range(1, len(tks_w)):
|
| 76 |
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
|
|
|
| 120 |
if sm:
|
| 121 |
tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
|
| 122 |
" ".join(sm), " ".join(sm))
|
| 123 |
+
if tk.strip():
|
| 124 |
+
tms.append((tk, w))
|
| 125 |
|
| 126 |
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
| 127 |
|