Kevin Hu committed
Commit b71b66c · 1 Parent(s): 6b2a674

add function: upload and parse (#1889)


### What problem does this PR solve?

#1880

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/apps/conversation_app.py CHANGED
@@ -118,6 +118,8 @@ def completion():
         if m["role"] == "assistant" and not msg:
             continue
         msg.append({"role": m["role"], "content": m["content"]})
+        if "doc_ids" in m:
+            msg[-1]["doc_ids"] = m["doc_ids"]
     try:
         e, conv = ConversationService.get_by_id(req["conversation_id"])
         if not e:
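
With the two added lines, any `doc_ids` carried on an incoming message survives the rebuilding of the message list, so retrieval can later be scoped to those attachments. An illustrative message list after this change (the ids are placeholders for values returned by `upload_and_parse`):

```python
# Illustrative shape only; not code from the PR.
msg = [
    {"role": "user",
     "content": "Summarize the attached report.",
     "doc_ids": ["<doc-id-1>", "<doc-id-2>"]},
    {"role": "assistant", "content": "..."},
]
```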
api/apps/document_app.py CHANGED
@@ -13,10 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
-
+import datetime
+import hashlib
+import json
 import os
 import pathlib
 import re
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+from copy import deepcopy
+from io import BytesIO
 
 import flask
 from elasticsearch_dsl import Q
@@ -24,22 +30,26 @@ from flask import request
 from flask_login import login_required, current_user
 
 from api.db.db_models import Task, File
+from api.db.services.dialog_service import DialogService, ConversationService
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
+from api.db.services.llm_service import LLMBundle
 from api.db.services.task_service import TaskService, queue_tasks
+from api.db.services.user_service import TenantService
+from graphrag.mind_map_extractor import MindMapExtractor
+from rag.app import naive
 from rag.nlp import search
 from rag.utils.es_conn import ELASTICSEARCH
 from api.db.services import duplicate_name
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
-from api.db import FileType, TaskStatus, ParserType, FileSource
+from api.db import FileType, TaskStatus, ParserType, FileSource, LLMType
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode
+from api.settings import RetCode, stat_logger
 from api.utils.api_utils import get_json_result
 from rag.utils.minio_conn import MINIO
-from api.utils.file_utils import filename_type, thumbnail
-from api.utils.web_utils import html2pdf, is_valid_url
+from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
 from api.utils.web_utils import html2pdf, is_valid_url
 
 
@@ -65,55 +75,7 @@ def upload():
     if not e:
         raise LookupError("Can't find this knowledgebase!")
 
-    root_folder = FileService.get_root_folder(current_user.id)
-    pf_id = root_folder["id"]
-    FileService.init_knowledgebase_docs(pf_id, current_user.id)
-    kb_root_folder = FileService.get_kb_folder(current_user.id)
-    kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
-
-    err = []
-    for file in file_objs:
-        try:
-            MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-            if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
-                raise RuntimeError("Exceed the maximum file number of a free user!")
-
-            filename = duplicate_name(
-                DocumentService.query,
-                name=file.filename,
-                kb_id=kb.id)
-            filetype = filename_type(filename)
-            if filetype == FileType.OTHER.value:
-                raise RuntimeError("This type of file has not been supported yet!")
-
-            location = filename
-            while MINIO.obj_exist(kb_id, location):
-                location += "_"
-            blob = file.read()
-            MINIO.put(kb_id, location, blob)
-            doc = {
-                "id": get_uuid(),
-                "kb_id": kb.id,
-                "parser_id": kb.parser_id,
-                "parser_config": kb.parser_config,
-                "created_by": current_user.id,
-                "type": filetype,
-                "name": filename,
-                "location": location,
-                "size": len(blob),
-                "thumbnail": thumbnail(filename, blob)
-            }
-            if doc["type"] == FileType.VISUAL:
-                doc["parser_id"] = ParserType.PICTURE.value
-            if doc["type"] == FileType.AURAL:
-                doc["parser_id"] = ParserType.AUDIO.value
-            if re.search(r"\.(ppt|pptx|pages)$", filename):
-                doc["parser_id"] = ParserType.PRESENTATION.value
-            DocumentService.insert(doc)
-
-            FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
-        except Exception as e:
-            err.append(file.filename + ": " + str(e))
+    err, _ = FileService.upload_document(kb, file_objs)
     if err:
         return get_json_result(
             data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
@@ -149,7 +111,7 @@ def web_crawl():
     try:
         filename = duplicate_name(
             DocumentService.query,
-            name=name+".pdf",
+            name=name + ".pdf",
            kb_id=kb.id)
        filetype = filename_type(filename)
        if filetype == FileType.OTHER.value:
@@ -414,7 +376,7 @@ def get(doc_id):
     if not e:
         return get_data_error_result(retmsg="Document not found!")
 
-    b,n = File2DocumentService.get_minio_address(doc_id=doc_id)
+    b, n = File2DocumentService.get_minio_address(doc_id=doc_id)
     response = flask.make_response(MINIO.get(b, n))
 
     ext = re.search(r"\.([^.]+)$", doc.name)
@@ -484,3 +446,133 @@ def get_image(image_id):
         return response
     except Exception as e:
         return server_error_response(e)
+
+
+@manager.route('/upload_and_parse', methods=['POST'])
+@login_required
+@validate_request("conversation_id")
+def upload_and_parse():
+    req = request.json
+    if 'file' not in request.files:
+        return get_json_result(
+            data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
+
+    file_objs = request.files.getlist('file')
+    for file_obj in file_objs:
+        if file_obj.filename == '':
+            return get_json_result(
+                data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
+
+    e, conv = ConversationService.get_by_id(req["conversation_id"])
+    if not e:
+        return get_data_error_result(retmsg="Conversation not found!")
+    e, dia = DialogService.get_by_id(conv.dialog_id)
+    kb_id = dia.kb_ids[0]
+    e, kb = KnowledgebaseService.get_by_id(kb_id)
+    if not e:
+        raise LookupError("Can't find this knowledgebase!")
+
+    idxnm = search.index_name(kb.tenant_id)
+    if not ELASTICSEARCH.indexExist(idxnm):
+        ELASTICSEARCH.createIdx(idxnm, json.load(
+            open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
+
+    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
+
+    err, files = FileService.upload_document(kb, file_objs)
+    if err:
+        return get_json_result(
+            data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
+
+    def dummy(prog=None, msg=""):
+        pass
+
+    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?。;!?", "layout_recognize": False}
+    exe = ThreadPoolExecutor(max_workers=12)
+    threads = []
+    for d, blob in files:
+        kwargs = {
+            "callback": dummy,
+            "parser_config": parser_config,
+            "from_page": 0,
+            "to_page": 100000
+        }
+        threads.append(exe.submit(naive.chunk, d["name"], blob, **kwargs))
+
+    for (docinfo, _), th in zip(files, threads):
+        docs = []
+        doc = {
+            "doc_id": docinfo["id"],
+            "kb_id": [kb.id]
+        }
+        for ck in th.result():
+            d = deepcopy(doc)
+            d.update(ck)
+            md5 = hashlib.md5()
+            md5.update((ck["content_with_weight"] +
+                        str(d["doc_id"])).encode("utf-8"))
+            d["_id"] = md5.hexdigest()
+            d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
+            d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
+            if not d.get("image"):
+                docs.append(d)
+                continue
+
+            output_buffer = BytesIO()
+            if isinstance(d["image"], bytes):
+                output_buffer = BytesIO(d["image"])
+            else:
+                d["image"].save(output_buffer, format='JPEG')
+
+            MINIO.put(kb.id, d["_id"], output_buffer.getvalue())
+            d["img_id"] = "{}-{}".format(kb.id, d["_id"])
+            del d["image"]
+            docs.append(d)
+
+    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
+    docids = [d["id"] for d, _ in files]
+    chunk_counts = {id: 0 for id in docids}
+    token_counts = {id: 0 for id in docids}
+    es_bulk_size = 64
+
+    def embedding(doc_id, cnts, batch_size=16):
+        nonlocal embd_mdl, chunk_counts, token_counts
+        vects = []
+        for i in range(0, len(cnts), batch_size):
+            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
+            vects.extend(vts.tolist())
+            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
+            token_counts[doc_id] += c
+        return vects
+
+    _, tenant = TenantService.get_by_id(kb.tenant_id)
+    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
+    for doc_id in docids:
+        cks = [c for c in docs if c["doc_id"] == doc_id]
+
+        if parser_ids[doc_id] != ParserType.PICTURE.value:
+            mindmap = MindMapExtractor(llm_bdl)
+            try:
+                mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output, ensure_ascii=False, indent=2)
+                if len(mind_map) < 32: raise Exception("Few content: " + mind_map)
+                cks.append({
+                    "doc_id": doc_id,
+                    "kb_id": [kb.id],
+                    "content_with_weight": mind_map,
+                    "knowledge_graph_kwd": "mind_map"
+                })
+            except Exception as e:
+                stat_logger.error("Mind map generation error:", traceback.format_exc())
+
+        vects = embedding(doc_id, cks)
+        assert len(cks) == len(vects)
+        for i, d in enumerate(cks):
+            v = vects[i]
+            d["q_%d_vec" % len(v)] = v
+        for b in range(0, len(cks), es_bulk_size):
+            ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
+
+        DocumentService.increment_chunk_num(
+            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
+
+    return get_json_result(data=[d["id"] for d, _ in files])
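
A detail worth calling out in `upload_and_parse`: the Elasticsearch `_id` of each chunk is an MD5 over the chunk text plus the owning document id, so re-ingesting identical content updates the same index entries instead of duplicating them. A standalone restatement of that scheme (a sketch, not code from the PR):

```python
import hashlib

def chunk_es_id(content_with_weight: str, doc_id: str) -> str:
    """Deterministic chunk id, mirroring the scheme in upload_and_parse."""
    md5 = hashlib.md5()
    md5.update((content_with_weight + str(doc_id)).encode("utf-8"))
    return md5.hexdigest()

# Same text + same document => same id; same text in another document => new id.
assert chunk_es_id("a chunk", "doc1") == chunk_es_id("a chunk", "doc1")
assert chunk_es_id("a chunk", "doc1") != chunk_es_id("a chunk", "doc2")
```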
api/db/services/dialog_service.py CHANGED
@@ -104,7 +104,11 @@ def chat(dialog, messages, stream=True, **kwargs):
     is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
     retr = retrievaler if not is_kg else kg_retrievaler
 
-    questions = [m["content"] for m in messages if m["role"] == "user"]
+    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
+    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
+    if "doc_ids" in messages[-1]:
+        attachments = messages[-1]["doc_ids"]
+
     embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
     if llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
@@ -144,7 +148,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
                                  dialog.similarity_threshold,
                                  dialog.vector_similarity_weight,
-                                 doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
+                                 doc_ids=attachments,
                                  top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
         knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
         #self-rag
@@ -153,7 +157,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
                                  dialog.similarity_threshold,
                                  dialog.vector_similarity_weight,
-                                 doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
+                                 doc_ids=attachments,
                                  top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
         knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
 
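The net effect in `chat()` is a small precedence rule: `doc_ids` attached to the latest message (as set up by the conversation_app change above) override the comma-separated `doc_ids` kwarg, and with neither present retrieval is unscoped. Restated as a standalone helper (a sketch, not code from the PR):

```python
def resolve_attachments(messages, **kwargs):
    # Default: the kwarg form, a comma-separated string of ids (or no scoping).
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    # Override: doc_ids attached directly to the latest message win.
    if messages and "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]
    return attachments

assert resolve_attachments([{"role": "user", "content": "hi"}]) is None
assert resolve_attachments([{"role": "user", "content": "hi", "doc_ids": ["a"]}],
                           doc_ids="b,c") == ["a"]
```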
api/db/services/document_service.py CHANGED
@@ -26,7 +26,7 @@ from rag.utils.es_conn import ELASTICSEARCH
 from rag.utils.minio_conn import MINIO
 from rag.nlp import search
 
-from api.db import FileType, TaskStatus
+from api.db import FileType, TaskStatus, ParserType
 from api.db.db_models import DB, Knowledgebase, Tenant, Task
 from api.db.db_models import Document
 from api.db.services.common_service import CommonService
api/db/services/file_service.py CHANGED
@@ -13,16 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import re
+
 from flask_login import current_user
 from peewee import fn
 
-from api.db import FileType, KNOWLEDGEBASE_FOLDER_NAME, FileSource
+from api.db import FileType, KNOWLEDGEBASE_FOLDER_NAME, FileSource, ParserType
 from api.db.db_models import DB, File2Document, Knowledgebase
 from api.db.db_models import File, Document
+from api.db.services import duplicate_name
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.utils import get_uuid
+from api.utils.file_utils import filename_type, thumbnail
+from rag.utils.minio_conn import MINIO
 
 
 class FileService(CommonService):
@@ -318,4 +323,60 @@ class FileService(CommonService):
             cls.filter_update((cls.model.id << file_ids, ), { 'parent_id': folder_id })
         except Exception as e:
             print(e)
-            raise RuntimeError("Database error (File move)!")
+            raise RuntimeError("Database error (File move)!")
+
+    @classmethod
+    @DB.connection_context()
+    def upload_document(self, kb, file_objs):
+        root_folder = self.get_root_folder(current_user.id)
+        pf_id = root_folder["id"]
+        self.init_knowledgebase_docs(pf_id, current_user.id)
+        kb_root_folder = self.get_kb_folder(current_user.id)
+        kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
+
+        err, files = [], []
+        for file in file_objs:
+            try:
+                MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
+                if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
+                    raise RuntimeError("Exceed the maximum file number of a free user!")
+
+                filename = duplicate_name(
+                    DocumentService.query,
+                    name=file.filename,
+                    kb_id=kb.id)
+                filetype = filename_type(filename)
+                if filetype == FileType.OTHER.value:
+                    raise RuntimeError("This type of file has not been supported yet!")
+
+                location = filename
+                while MINIO.obj_exist(kb.id, location):
+                    location += "_"
+                blob = file.read()
+                MINIO.put(kb.id, location, blob)
+                doc = {
+                    "id": get_uuid(),
+                    "kb_id": kb.id,
+                    "parser_id": kb.parser_id,
+                    "parser_config": kb.parser_config,
+                    "created_by": current_user.id,
+                    "type": filetype,
+                    "name": filename,
+                    "location": location,
+                    "size": len(blob),
+                    "thumbnail": thumbnail(filename, blob)
+                }
+                if doc["type"] == FileType.VISUAL:
+                    doc["parser_id"] = ParserType.PICTURE.value
+                if doc["type"] == FileType.AURAL:
+                    doc["parser_id"] = ParserType.AUDIO.value
+                if re.search(r"\.(ppt|pptx|pages)$", filename):
+                    doc["parser_id"] = ParserType.PRESENTATION.value
+                DocumentService.insert(doc)
+
+                FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
+                files.append((doc, blob))
+            except Exception as e:
+                err.append(file.filename + ": " + str(e))
+
+        return err, files
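
The refactor moves the per-file upload loop almost verbatim out of `document_app.upload` and into the service layer; the behavioral addition is the `(doc, blob)` list in the return value, which `upload_and_parse` needs so it can chunk the raw bytes without re-fetching them from MinIO. Typical consumption of the new `(err, files)` contract (a sketch; `parse_later` is a hypothetical downstream step):

```python
err, files = FileService.upload_document(kb, file_objs)
if err:                       # per-file failures, formatted "filename: reason"
    report = "\n".join(err)
for doc, blob in files:       # doc: inserted Document record; blob: raw bytes
    parse_later(doc["name"], blob)
```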
graphrag/index.py CHANGED
@@ -30,24 +30,6 @@ from rag.nlp import rag_tokenizer
 from rag.utils import num_tokens_from_string
 
 
-def be_children(obj: dict, keyset: set):
-    if isinstance(obj, str):
-        obj = [obj]
-    if isinstance(obj, list):
-        for i in obj: keyset.add(i)
-        return [{"id": re.sub(r"\*+", "", i), "children": []} for i in obj]
-    arr = []
-    for k, v in obj.items():
-        k = re.sub(r"\*+", "", k)
-        if not k or k in keyset: continue
-        keyset.add(k)
-        arr.append({
-            "id": k,
-            "children": be_children(v, keyset)
-        })
-    return arr
-
-
 def graph_merge(g1, g2):
     g = g2.copy()
     for n, attr in g1.nodes(data=True):
@@ -153,16 +135,10 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
     mg = mindmap(_chunks).output
     if not len(mg.keys()): return chunks
 
-    if len(mg.keys()) > 1:
-        keyset = set([re.sub(r"\*+", "", k) for k, v in mg.items() if isinstance(v, dict) and re.sub(r"\*+", "", k)])
-        md_map = {"id": "root", "children": [{"id": re.sub(r"\*+", "", k), "children": be_children(v, keyset)} for k, v in mg.items() if isinstance(v, dict) and re.sub(r"\*+", "", k)]}
-    else:
-        k = re.sub(r"\*+", "", list(mg.keys())[0])
-        md_map = {"id": k, "children": be_children(list(mg.items())[0][1], set([k]))}
-    print(json.dumps(md_map, ensure_ascii=False, indent=2))
+    print(json.dumps(mg, ensure_ascii=False, indent=2))
     chunks.append(
         {
-            "content_with_weight": json.dumps(md_map, ensure_ascii=False, indent=2),
+            "content_with_weight": json.dumps(mg, ensure_ascii=False, indent=2),
             "knowledge_graph_kwd": "mind_map"
         })
 
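After this change `build_knowlege_graph_chunks` no longer post-processes the mind map: it serializes whatever the extractor returns, because the root-wrapping and `be_children` conversion now live inside `MindMapExtractor` (next file). The chunk it appends therefore already carries a tree; illustratively (values are made up):

```python
import json

# Illustrative: after this PR the extractor output is already a tree.
mind_map = {"id": "root", "children": [{"id": "Overview", "children": []}]}
chunks = []
chunks.append({
    "content_with_weight": json.dumps(mind_map, ensure_ascii=False, indent=2),
    "knowledge_graph_kwd": "mind_map",
})
```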
graphrag/mind_map_extractor.py CHANGED
@@ -57,6 +57,26 @@ class MindMapExtractor:
         self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
         self._on_error = on_error or (lambda _e, _s, _d: None)
 
+    def _key(self, k):
+        return re.sub(r"\*+", "", k)
+
+    def _be_children(self, obj: dict, keyset: set):
+        if isinstance(obj, str):
+            obj = [obj]
+        if isinstance(obj, list):
+            for i in obj: keyset.add(i)
+            return [{"id": re.sub(r"\*+", "", i), "children": []} for i in obj]
+        arr = []
+        for k, v in obj.items():
+            k = self._key(k)
+            if not k or k in keyset: continue
+            keyset.add(k)
+            arr.append({
+                "id": k,
+                "children": self._be_children(v, keyset)
+            })
+        return arr
+
     def __call__(
         self, sections: list[str], prompt_variables: dict[str, Any] | None = None
     ) -> MindMapResult:
@@ -86,13 +106,23 @@ class MindMapExtractor:
                 res.append(_.result())
 
             merge_json = reduce(self._merge, res)
-            merge_json = self._list_to_kv(merge_json)
+            if len(merge_json.keys()) > 1:
+                keyset = set(
+                    [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict) and re.sub(r"\*+", "", k)])
+                merge_json = {"id": "root",
+                              "children": [{"id": self._key(k), "children": self._be_children(v, keyset)} for k, v in
+                                           merge_json.items() if isinstance(v, dict) and self._key(k)]}
+            else:
+                k = self._key(list(merge_json.keys())[0])
+                merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], set([k]))}
+
         except Exception as e:
             logging.exception("error mind graph")
            self._on_error(
                e,
                traceback.format_exc(), None
            )
+            merge_json = {"error": str(e)}
 
         return MindMapResult(output=merge_json)
 
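`_be_children` flattens the markdown-derived nesting into uniform `{"id", "children"}` nodes, stripping `*` emphasis via `_key` and using `keyset` to drop duplicate headings. A hand-worked example of the conversion (a demo, not a test from the PR):

```python
from graphrag.mind_map_extractor import MindMapExtractor

# Bypass __init__ for the demo; _key/_be_children use no instance state.
ex = MindMapExtractor.__new__(MindMapExtractor)
merged = {"**Chapter 1**": {"Background": ["History", "Motivation"]}}
keyset = set()
tree = [{"id": ex._key(k), "children": ex._be_children(v, keyset)}
        for k, v in merged.items()]
# tree == [{"id": "Chapter 1", "children": [
#     {"id": "Background", "children": [
#         {"id": "History", "children": []},
#         {"id": "Motivation", "children": []}]}]}]
```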
graphrag/mind_map_prompt.py CHANGED
@@ -23,6 +23,7 @@ MIND_MAP_EXTRACTION_PROMPT = """
 4. Add a short content summary of the bottom level section.
 
 - Output requirement:
+  - Generate at least 4 levels.
   - Always try to maximize the number of sub-sections.
   - In language of 'Text'
   - MUST IN FORMAT OF MARKDOWN