KevinHuSh committed
Commit e32ef75 · 1 Parent(s): 34b2ab3

Test chat API and refine ppt chunker (#42)
- api/apps/conversation_app.py  +4 -8
- api/db/db_models.py  +1 -0
- api/db/services/llm_service.py  +96 -17
- api/utils/file_utils.py  +2 -2
- rag/llm/chat_model.py  +3 -3
- rag/llm/cv_model.py  +3 -3
- rag/llm/embedding_model.py  +29 -11
- rag/nlp/huchunk.py  +42 -19
- rag/nlp/search.py  +36 -20
- rag/svr/parse_user_docs.py  +10 -8
api/apps/conversation_app.py CHANGED

@@ -17,7 +17,7 @@ from flask import request
 from flask_login import login_required
 from api.db.services.dialog_service import DialogService, ConversationService
 from api.db import LLMType
-from api.db.services.llm_service import LLMService, TenantLLMService
+from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result

@@ -170,12 +170,9 @@ def chat(dialog, messages, **kwargs):
         if p["key"] not in kwargs:
             prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")

-    model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id)
-    if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))
-
     question = messages[-1]["content"]
-    embd_mdl =
-
+    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
+    chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
     kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
                                     dialog.vector_similarity_weight, top=1024, aggs=False)
     knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]]

@@ -189,8 +186,7 @@ def chat(dialog, messages, **kwargs):
     used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
-
-    answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
+    answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)

     answer = retrievaler.insert_citations(answer,
                                           [ck["content_ltks"] for ck in kbinfos["chunks"]],
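The rewired chat flow budgets the completion against what the prompt leaves of the model's context window: messages are first fitted into ~97% of the window, then max_tokens is capped by the remainder. A minimal sketch of that clamp, with illustrative names (clamp_gen_conf is not a repository function):

def clamp_gen_conf(gen_conf: dict, max_tokens: int, used_token_count: int) -> dict:
    # Cap the completion budget by whatever the prompt has not consumed.
    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
    return gen_conf

print(clamp_gen_conf({"max_tokens": 512}, max_tokens=4096, used_token_count=3900))
# {'max_tokens': 196}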
api/db/db_models.py CHANGED

@@ -524,6 +524,7 @@ class Dialog(DataBaseModel):
     similarity_threshold = FloatField(default=0.2)
     vector_similarity_weight = FloatField(default=0.3)
     top_n = IntegerField(default=6)
+    do_refer = CharField(max_length=1, null=False, help_text="it needs to insert reference index into answer or not", default="1")

     kb_ids = JSONField(null=False, default=[])
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
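do_refer follows the table's existing one-character flag convention ("1" on, "0" off), the same style status already uses. A minimal peewee sketch of that convention, using an in-memory stub rather than the real Dialog model:

from peewee import CharField, Model, SqliteDatabase

# Stub model illustrating the one-character flag convention; not the
# repository's Dialog, which carries many more fields.
db = SqliteDatabase(":memory:")

class Dialog(Model):
    do_refer = CharField(max_length=1, null=False, default="1")
    class Meta:
        database = db

db.create_tables([Dialog])
d = Dialog.create()
print(d.do_refer == "1")  # True: citations are inserted by default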
api/db/services/llm_service.py CHANGED

@@ -14,12 +14,12 @@
 # limitations under the License.
 #
 from api.db.services.user_service import TenantService
-from
+from api.settings import database_logger
+from rag.llm import EmbeddingModel, CvModel, ChatModel
 from api.db import LLMType
 from api.db.db_models import DB, UserTenant
 from api.db.db_models import LLMFactories, LLM, TenantLLM
 from api.db.services.common_service import CommonService
-from api.db import StatusEnum


 class LLMFactoriesService(CommonService):

@@ -37,13 +37,19 @@ class TenantLLMService(CommonService):
     @DB.connection_context()
     def get_api_key(cls, tenant_id, model_name):
         objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
-        if not objs:
+        if not objs:
+            return
         return objs[0]

     @classmethod
     @DB.connection_context()
     def get_my_llms(cls, tenant_id):
-        fields = [
+        fields = [
+            cls.model.llm_factory,
+            LLMFactories.logo,
+            LLMFactories.tags,
+            cls.model.model_type,
+            cls.model.llm_name]
         objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
             cls.model.tenant_id == tenant_id).dicts()

@@ -51,23 +57,96 @@ class TenantLLMService(CommonService):

     @classmethod
     @DB.connection_context()
-    def model_instance(cls, tenant_id, llm_type):
-        e,tenant = TenantService.get_by_id(tenant_id)
-        if not e:
+    def model_instance(cls, tenant_id, llm_type, llm_name=None):
+        e, tenant = TenantService.get_by_id(tenant_id)
+        if not e:
+            raise LookupError("Tenant not found")

-        if llm_type == LLMType.EMBEDDING.value:
-        elif llm_type == LLMType.
+        if llm_type == LLMType.EMBEDDING.value:
+            mdlnm = tenant.embd_id
+        elif llm_type == LLMType.SPEECH2TEXT.value:
+            mdlnm = tenant.asr_id
+        elif llm_type == LLMType.IMAGE2TEXT.value:
+            mdlnm = tenant.img2txt_id
+        elif llm_type == LLMType.CHAT.value:
+            mdlnm = tenant.llm_id if not llm_name else llm_name
+        else:
+            assert False, "LLM type error"

         model_config = cls.get_api_key(tenant_id, mdlnm)
-        if not model_config:
+        if not model_config:
+            raise LookupError("Model({}) not found".format(mdlnm))
         model_config = model_config.to_dict()
         if llm_type == LLMType.EMBEDDING.value:
-            if model_config["llm_factory"] not in EmbeddingModel:
+            if model_config["llm_factory"] not in EmbeddingModel:
+                return
+            return EmbeddingModel[model_config["llm_factory"]](
+                model_config["api_key"], model_config["llm_name"])

         if llm_type == LLMType.IMAGE2TEXT.value:
-            if model_config["llm_factory"] not in CvModel:
+            if model_config["llm_factory"] not in CvModel:
+                return
+            return CvModel[model_config["llm_factory"]](
+                model_config["api_key"], model_config["llm_name"])
+
+        if llm_type == LLMType.CHAT.value:
+            if model_config["llm_factory"] not in ChatModel:
+                return
+            return ChatModel[model_config["llm_factory"]](
+                model_config["api_key"], model_config["llm_name"])
+
+    @classmethod
+    @DB.connection_context()
+    def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
+        e, tenant = TenantService.get_by_id(tenant_id)
+        if not e:
+            raise LookupError("Tenant not found")
+
+        if llm_type == LLMType.EMBEDDING.value:
+            mdlnm = tenant.embd_id
+        elif llm_type == LLMType.SPEECH2TEXT.value:
+            mdlnm = tenant.asr_id
+        elif llm_type == LLMType.IMAGE2TEXT.value:
+            mdlnm = tenant.img2txt_id
+        elif llm_type == LLMType.CHAT.value:
+            mdlnm = tenant.llm_id if not llm_name else llm_name
+        else:
+            assert False, "LLM type error"
+
+        num = cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)\
+            .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
+            .execute()
+        return num
+
+
+class LLMBundle(object):
+    def __init__(self, tenant_id, llm_type, llm_name=None):
+        self.tenant_id = tenant_id
+        self.llm_type = llm_type
+        self.llm_name = llm_name
+        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name)
+        assert self.mdl, "Can't find mole for {}/{}/{}".format(tenant_id, llm_type, llm_name)
+
+    def encode(self, texts: list, batch_size=32):
+        emd, used_tokens = self.mdl.encode(texts, batch_size)
+        if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
+        return emd, used_tokens
+
+    def encode_queries(self, query: str):
+        emd, used_tokens = self.mdl.encode_queries(query)
+        if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
+        return emd, used_tokens
+
+    def describe(self, image, max_tokens=300):
+        txt, used_tokens = self.mdl.describe(image, max_tokens)
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            database_logger.error("Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
+        return txt
+
+    def chat(self, system, history, gen_conf):
+        txt, used_tokens = self.mdl.chat(system, history, gen_conf)
+        if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+            database_logger.error("Can't update token usage for {}/CHAT".format(self.tenant_id))
+        return txt
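LLMBundle pairs a tenant's resolved model instance with usage metering: every call forwards to the underlying model and reports the returned token count back through increase_usage. A standalone sketch of that wrap-and-meter pattern; EchoModel, UsageMeter and Bundle are stand-ins invented for illustration, not repository classes:

class EchoModel:
    # Fake model: returns (text, used_tokens) like the real wrappers do.
    def chat(self, system, history, gen_conf):
        answer = "echo: " + history[-1]["content"]
        return answer, len(answer.split())

class UsageMeter:
    # Stand-in for TenantLLMService.increase_usage.
    def __init__(self):
        self.used = 0
    def increase(self, tokens):
        self.used += tokens
        return True  # pretend the DB update succeeded

class Bundle:
    def __init__(self, mdl, meter):
        self.mdl, self.meter = mdl, meter
    def chat(self, system, history, gen_conf):
        txt, used = self.mdl.chat(system, history, gen_conf)
        if not self.meter.increase(used):
            print("Can't update token usage")
        return txt

bundle = Bundle(EchoModel(), UsageMeter())
print(bundle.chat("", [{"role": "user", "content": "hello there"}], {}))
print(bundle.meter.used)  # 3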
api/utils/file_utils.py CHANGED

@@ -143,11 +143,11 @@ def filename_type(filename):
     if re.match(r".*\.pdf$", filename):
         return FileType.PDF.value

-    if re.match(r".*\.(docx|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
+    if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
         return FileType.DOC.value

     if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
         return FileType.AURAL.value

     if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
-        return FileType.VISUAL
+        return FileType.VISUAL
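The extension fix matters because the alternation is anchored with $: without "pptx" in the list, a .pptx name cannot match through the "ppt" branch. A quick check:

import re

DOC_PATT = r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$"

# Before the fix the pattern lacked "pptx", so "slides.pptx" was not
# classified as a document; with it, both variants match.
for name in ("slides.ppt", "slides.pptx", "photo.png"):
    print(name, bool(re.match(DOC_PATT, name)))
# slides.ppt True / slides.pptx True / photo.png False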
rag/llm/chat_model.py CHANGED

@@ -37,7 +37,7 @@ class GptTurbo(Base):
             model=self.model_name,
             messages=history,
             **gen_conf)
-        return res.choices[0].message.content.strip()
+        return res.choices[0].message.content.strip(), res.usage.completion_tokens


 from dashscope import Generation

@@ -56,5 +56,5 @@ class QWenChat(Base):
             result_format='message'
         )
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content']
-        return response.message
+            return response.output.choices[0]['message']['content'], response.usage.output_tokens
+        return response.message, 0
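Both chat wrappers now return a (content, token_usage) pair instead of bare text, so callers such as LLMBundle.chat can meter consumption. A minimal sketch of the same shape against the OpenAI v1 client; it requires a real OPENAI_API_KEY in the environment, and the model name is illustrative:

from openai import OpenAI

def chat(history, model="gpt-3.5-turbo", **gen_conf):
    # Return the reply together with its completion-token count,
    # mirroring the (text, used_tokens) contract the commit introduces.
    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    res = client.chat.completions.create(model=model, messages=history, **gen_conf)
    return res.choices[0].message.content.strip(), res.usage.completion_tokens

if __name__ == "__main__":
    text, used = chat([{"role": "user", "content": "Say hi."}])
    print(used, text)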
rag/llm/cv_model.py CHANGED

@@ -72,7 +72,7 @@ class GptV4(Base):
             messages=self.prompt(b64),
             max_tokens=max_tokens,
         )
-        return res.choices[0].message.content.strip()
+        return res.choices[0].message.content.strip(), res.usage.total_tokens


 class QWenCV(Base):

@@ -87,5 +87,5 @@ class QWenCV(Base):
         response = MultiModalConversation.call(model=self.model_name,
                                                messages=self.prompt(self.image2base64(image)))
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content']
-        return response.message
+            return response.output.choices[0]['message']['content'], response.usage.output_tokens
+        return response.message, 0
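The vision wrappers make the same (content, token_usage) change, and they hand the picture to the API as base64 inside the prompt payload. A tiny sketch of that encoding step; image2base64 here is a stand-in for the helper the wrappers call:

import base64

def image2base64(image_bytes: bytes) -> str:
    # Vision APIs take the image as a base64 string in the prompt payload.
    return base64.b64encode(image_bytes).decode("utf-8")

print(image2base64(b"\x89PNG demo bytes")[:12])  # truncated demo output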
rag/llm/embedding_model.py CHANGED

@@ -36,6 +36,9 @@ class Base(ABC):
     def encode(self, texts: list, batch_size=32):
         raise NotImplementedError("Please implement encode method!")

+    def encode_queries(self, text: str):
+        raise NotImplementedError("Please implement encode method!")
+

 class HuEmbedding(Base):
     def __init__(self, key="", model_name=""):

@@ -68,15 +71,18 @@ class HuEmbedding(Base):

 class OpenAIEmbed(Base):
     def __init__(self, key, model_name="text-embedding-ada-002"):
-        self.client = OpenAI(key)
+        self.client = OpenAI(api_key=key)
         self.model_name = model_name

     def encode(self, texts: list, batch_size=32):
-        token_count = 0
-        for t in texts: token_count += num_tokens_from_string(t)
         res = self.client.embeddings.create(input=texts,
                                             model=self.model_name)
-        return [d
+        return np.array([d.embedding for d in res.data]), res.usage.total_tokens
+
+    def encode_queries(self, text):
+        res = self.client.embeddings.create(input=[text],
+                                            model=self.model_name)
+        return np.array(res.data[0].embedding), res.usage.total_tokens


 class QWenEmbed(Base):

@@ -84,16 +90,28 @@ class QWenEmbed(Base):
         dashscope.api_key = key
         self.model_name = model_name

-    def encode(self, texts: list, batch_size=
+    def encode(self, texts: list, batch_size=10):
         import dashscope
         res = []
         token_count = 0
-        for txt in texts
+        texts = [txt[:2048] for txt in texts]
+        for i in range(0, len(texts), batch_size):
             resp = dashscope.TextEmbedding.call(
                 model=self.model_name,
-                input=
-                text_type=
+                input=texts[i:i+batch_size],
+                text_type="document"
+            )
+            embds = [[]] * len(resp["output"]["embeddings"])
+            for e in resp["output"]["embeddings"]:
+                embds[e["text_index"]] = e["embedding"]
+            res.extend(embds)
+            token_count += resp["usage"]["input_tokens"]
+        return np.array(res), token_count
+
+    def encode_queries(self, text):
+        resp = dashscope.TextEmbedding.call(
+            model=self.model_name,
+            input=text[:2048],
+            text_type="query"
         )
-
-        token_count += resp["usage"]["total_tokens"]
-        return res, token_count
+        return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]
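dashscope may return a batch's embeddings out of input order, each record tagged with a text_index, so QWenEmbed.encode rebuilds row order before stacking. A standalone sketch of that reordering step with a fabricated response dict standing in for the API reply:

import numpy as np

# Fabricated response: two embeddings delivered out of order.
resp = {"output": {"embeddings": [
    {"text_index": 1, "embedding": [0.3, 0.4]},
    {"text_index": 0, "embedding": [0.1, 0.2]},
]}, "usage": {"input_tokens": 7}}

# Place each embedding at the index of the text it belongs to.
embds = [[]] * len(resp["output"]["embeddings"])
for e in resp["output"]["embeddings"]:
    embds[e["text_index"]] = e["embedding"]

print(np.array(embds))                   # rows follow input order
print(resp["usage"]["input_tokens"])     # 7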
rag/nlp/huchunk.py CHANGED

@@ -11,6 +11,11 @@ from io import BytesIO

 class HuChunker:

+    @dataclass
+    class Fields:
+        text_chunks: List = None
+        table_chunks: List = None
+
     def __init__(self):
         self.MAX_LVL = 12
         self.proj_patt = [

@@ -228,11 +233,6 @@ class HuChunker:

 class PdfChunker(HuChunker):

-    @dataclass
-    class Fields:
-        text_chunks: List = None
-        table_chunks: List = None
-
     def __init__(self, pdf_parser):
         self.pdf = pdf_parser
         super().__init__()

@@ -293,11 +293,6 @@ class PdfChunker(HuChunker):

 class DocxChunker(HuChunker):

-    @dataclass
-    class Fields:
-        text_chunks: List = None
-        table_chunks: List = None
-
     def __init__(self, doc_parser):
         self.doc = doc_parser
         super().__init__()

@@ -344,11 +339,6 @@ class DocxChunker(HuChunker):

 class ExcelChunker(HuChunker):

-    @dataclass
-    class Fields:
-        text_chunks: List = None
-        table_chunks: List = None
-
     def __init__(self, excel_parser):
         self.excel = excel_parser
         super().__init__()

@@ -370,18 +360,51 @@ class PptChunker(HuChunker):
     def __init__(self):
         super().__init__()

+    def __extract(self, shape):
+        if shape.shape_type == 19:
+            tb = shape.table
+            rows = []
+            for i in range(1, len(tb.rows)):
+                rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
+            return "\n".join(rows)
+
+        if shape.has_text_frame:
+            return shape.text_frame.text
+
+        if shape.shape_type == 6:
+            texts = []
+            for p in shape.shapes:
+                t = self.__extract(p)
+                if t: texts.append(t)
+            return "\n".join(texts)
+
     def __call__(self, fnm):
         from pptx import Presentation
         ppt = Presentation(fnm) if isinstance(
             fnm, str) else Presentation(
             BytesIO(fnm))
-
-        flds.text_chunks = []
+        txts = []
         for slide in ppt.slides:
+            texts = []
             for shape in slide.shapes:
-
-
+                txt = self.__extract(shape)
+                if txt: texts.append(txt)
+            txts.append("\n".join(texts))
+
+        import aspose.slides as slides
+        import aspose.pydrawing as drawing
+        imgs = []
+        with slides.Presentation(BytesIO(fnm)) as presentation:
+            for slide in presentation.slides:
+                buffered = BytesIO()
+                slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg)
+                imgs.append(buffered.getvalue())
+        assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
+
+        flds = self.Fields()
+        flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))]
         flds.table_chunks = []
+
        return flds
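PptChunker.__extract walks shapes recursively: shape type 19 is a table (flattened into "header: cell" rows), type 6 is a group whose children are visited in turn, and any shape with a text frame contributes its text. A sketch of the same walk using python-pptx's enum names in place of the bare numbers (MSO_SHAPE_TYPE.TABLE is 19, MSO_SHAPE_TYPE.GROUP is 6):

from pptx import Presentation
from pptx.enum.shapes import MSO_SHAPE_TYPE

def extract(shape):
    # Tables: pair each body cell with its column header.
    if shape.shape_type == MSO_SHAPE_TYPE.TABLE:
        tb = shape.table
        return "\n".join(
            "; ".join(f"{tb.cell(0, j).text}: {tb.cell(i, j).text}"
                      for j in range(len(tb.columns)))
            for i in range(1, len(tb.rows)))
    # Plain text boxes, titles, placeholders.
    if shape.has_text_frame:
        return shape.text_frame.text
    # Groups: recurse into the member shapes.
    if shape.shape_type == MSO_SHAPE_TYPE.GROUP:
        return "\n".join(t for t in (extract(p) for p in shape.shapes) if t)

if __name__ == "__main__":
    import sys
    for slide in Presentation(sys.argv[1]).slides:
        print("\n".join(t for t in (extract(s) for s in slide.shapes) if t))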
rag/nlp/search.py CHANGED

@@ -58,7 +58,8 @@ class Dealer:
         if req["available_int"] == 0:
             bqry.filter.append(Q("range", available_int={"lt": 1}))
         else:
-            bqry.filter.append(
+            bqry.filter.append(
+                Q("bool", must_not=Q("range", available_int={"lt": 1})))
         bqry.boost = 0.05

         s = Search()

@@ -87,9 +88,12 @@ class Dealer:
         q_vec = []
         if req.get("vector"):
             assert emb_mdl, "No embedding model selected"
-            s["knn"] = self._vector(
+            s["knn"] = self._vector(
+                qst, emb_mdl, req.get(
+                    "similarity", 0.4), ps)
             s["knn"]["filter"] = bqry.to_dict()
-            if "highlight" in s:
+            if "highlight" in s:
+                del s["highlight"]
             q_vec = s["knn"]["query_vector"]
         es_logger.info("【Q】: {}".format(json.dumps(s)))
         res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)

@@ -175,7 +179,8 @@ class Dealer:
     def trans2floats(txt):
         return [float(t) for t in txt.split("\t")]

-    def insert_citations(self, answer, chunks, chunk_v,
+    def insert_citations(self, answer, chunks, chunk_v,
+                         embd_mdl, tkweight=0.3, vtweight=0.7):
         pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
         for i in range(1, len(pieces)):
             if re.match(r"[a-z][.?;!][ \n]", pieces[i]):

@@ -184,47 +189,57 @@ class Dealer:
         idx = []
         pieces_ = []
         for i, t in enumerate(pieces):
-            if len(t) < 5:
+            if len(t) < 5:
+                continue
             idx.append(i)
             pieces_.append(t)
         es_logger.info("{} => {}".format(answer, pieces_))
-        if not pieces_:
+        if not pieces_:
+            return answer

-        ans_v,
+        ans_v, _ = embd_mdl.encode(pieces_)
         assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
             len(ans_v[0]), len(chunk_v[0]))

         chunks_tks = [huqie.qie(ck).split(" ") for ck in chunks]
         cites = {}
-        for i,a in enumerate(pieces_):
+        for i, a in enumerate(pieces_):
             sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
                                                             chunk_v,
-                                                            huqie.qie(
+                                                            huqie.qie(
+                                                                pieces_[i]).split(" "),
                                                             chunks_tks,
                                                             tkweight, vtweight)
             mx = np.max(sim) * 0.99
-            if mx < 0.55:
-
+            if mx < 0.55:
+                continue
+            cites[idx[i]] = list(
+                set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4]

         res = ""
-        for i,p in enumerate(pieces):
+        for i, p in enumerate(pieces):
             res += p
-            if i not in idx:
-
+            if i not in idx:
+                continue
+            if i not in cites:
+                continue
+            res += "##%s$$" % "$".join(cites[i])

         return res

-    def rerank(self, sres, query, tkweight=0.3,
+    def rerank(self, sres, query, tkweight=0.3,
+               vtweight=0.7, cfield="content_ltks"):
         ins_embd = [
             Dealer.trans2floats(
-                sres.field[i]
+                sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
         if not ins_embd:
             return [], [], []
-        ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ")
+        ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ")
+                  for i in sres.ids]
         sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                         ins_embd,
-                                                        huqie.qie(
+                                                        huqie.qie(
+                                                            query).split(" "),
                                                         ins_tw, tkweight, vtweight)
         return sim, tksim, vtsim

@@ -237,7 +252,8 @@ class Dealer:
     def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
                   vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
         ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
-        if not question:
+        if not question:
+            return ranks
         req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top,
                "question": question, "vector": True,
                "similarity": similarity_threshold}
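insert_citations splits the answer at sentence boundaries, embeds each piece, and cites only chunks whose similarity clears 0.99 of that piece's maximum, with a 0.55 floor and at most four citations, appended as "##i$j$$" markers. A toy sketch of the marking step with fabricated similarity scores (the real code computes them via qryr.hybrid_similarity over token and vector similarity):

import numpy as np

pieces = ["RAG retrieves chunks first. ", "Then the LLM answers. "]
sim = np.array([[0.83, 0.20, 0.81],    # piece 0 vs three chunks
                [0.30, 0.40, 0.35]])   # piece 1: max < 0.55, no citation

res = ""
for i, p in enumerate(pieces):
    res += p
    mx = sim[i].max() * 0.99           # per-piece dynamic threshold
    if mx < 0.55:
        continue
    cites = [str(j) for j in range(sim.shape[1]) if sim[i][j] > mx][:4]
    res += "##%s$$" % "$".join(cites)
print(res)
# RAG retrieves chunks first. ##0$$Then the LLM answers.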
rag/svr/parse_user_docs.py CHANGED

@@ -49,7 +49,7 @@ from rag.nlp.huchunk import (
 )
 from api.db import LLMType
 from api.db.services.document_service import DocumentService
-from api.db.services.llm_service import TenantLLMService
+from api.db.services.llm_service import TenantLLMService, LLMBundle
 from api.settings import database_logger
 from api.utils import get_format_time
 from api.utils.file_utils import get_project_base_directory

@@ -62,7 +62,7 @@ EXC = ExcelChunker(ExcelParser())
 PPT = PptChunker()


-def chuck_doc(name, binary, cvmdl=None):
+def chuck_doc(name, binary, tenant_id, cvmdl=None):
     suff = os.path.split(name)[-1].lower().split(".")[-1]
     if suff.find("pdf") >= 0:
         return PDF(binary)

@@ -127,7 +127,7 @@ def build(row, cvmdl):
         100., "Finished preparing! Start to slice file!", True)
     try:
         cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
-        obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
+        obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), row["tenant_id"], cvmdl)
     except Exception as e:
         if re.search("(No such file|not found)", str(e)):
             set_progress(

@@ -236,12 +236,14 @@ def main(comm, mod):

     tmf = open(tm_fnm, "a+")
     for _, r in rows.iterrows():
-
-
-
-
+        try:
+            embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
+            cv_mdl = LLMBundle(r["tenant_id"], LLMType.IMAGE2TEXT)
+            #TODO: sequence2text model
+        except Exception as e:
+            set_progress(r["id"], -1, str(e))
             continue
-
+
         st_tm = timer()
         cks = build(r, cv_mdl)
         if not cks: