File size: 4,638 Bytes
6054f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b691127
6054f54
 
b691127
6054f54
 
 
 
 
b691127
 
 
 
 
 
 
6054f54
b691127
 
 
6054f54
b691127
6054f54
b691127
 
 
 
 
 
 
 
 
 
 
 
 
 
6054f54
b691127
 
 
 
 
6054f54
b691127
6054f54
 
 
 
b691127
6054f54
b691127
 
 
 
 
 
 
6054f54
 
b691127
 
 
 
 
 
 
 
6054f54
 
b691127
 
 
 
 
 
 
 
6054f54
 
 
 
 
 
b691127
6054f54
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
from copy import deepcopy
from typing import Dict

import pandas as pd
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr

from rag.nlp.search import Dealer


class KGSearch(Dealer):
    """Knowledge-graph aware retrieval.

    ``search`` runs three retrievals against the same index: entities,
    community reports and original text chunks. Each kind is selected through
    the ``knowledge_graph_kwd`` filter. The hits of each stage are collapsed
    into one synthetic chunk, so the result holds at most three chunks.
    """

    @staticmethod
    def _merge_into_first(sres, title="") -> Dict[str, dict]:
        """Collapse all hits of one retrieval stage into a single chunk.

        :param sres: mapping of chunk id -> field dict, as returned by
            ``dataStore.getFields``.
        :param title: header line prepended to the merged content.
        :returns: ``{first_id: fields}``, where ``first_id`` is the first key of
            ``sres`` and ``fields`` is a deep copy of that chunk's fields with
            ``content_with_weight`` replaced by the merged content. Returns
            ``{}`` when ``sres`` is empty.
        """
        if not sres:
            return {}

        rows, texts = [], []
        for fields in sres.values():
            raw = fields.get("content_with_weight")
            if raw is None:
                # Nothing to merge for this chunk; skip rather than crash.
                continue
            try:
                parsed = json.loads(raw)
            except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
                parsed = None
            # Only JSON objects become table rows. Scalars or lists that merely
            # happen to be valid JSON are kept as plain text.
            if isinstance(parsed, dict):
                rows.append(parsed)
            else:
                texts.append(str(raw))

        body = []
        if rows:
            body.append(pd.DataFrame(rows).to_csv())
        # Plain-text entries are kept even when structured rows are present,
        # so no retrieved content is silently dropped.
        body.extend(texts)
        content_with_weight = title + "\n" + "\n".join(body)

        first_id = next(iter(sres))
        first_source = deepcopy(sres[first_id])
        first_source["content_with_weight"] = content_with_weight
        return {first_id: first_source}

    def _kg_retrieve_stage(self, req, kg_kind, title, src, match_exprs, limit,
                           idxnm, kb_ids, extra_condition=None):
        """Run one retrieval stage restricted to ``knowledge_graph_kwd == kg_kind``.

        :returns: ``(fields, merged)``. ``fields`` is the per-chunk field map
            from ``dataStore.getFields``. ``merged`` is the result of
            ``_merge_into_first(fields, title)``.
        """
        condition = self.get_filters(req)
        if extra_condition:
            condition.update(extra_condition)
        condition["knowledge_graph_kwd"] = [kg_kind]
        res = self.dataStore.search(src, [], condition, match_exprs, OrderByExpr(), 0, limit, idxnm, kb_ids)
        fields = self.dataStore.getFields(res, src)
        return fields, self._merge_into_first(fields, title)

    def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
        """Retrieve entities, community reports and text chunks for a question.

        :param req: request dict. This method reads ``question``,
            ``similarity`` and ``fields``, plus whatever ``get_filters`` reads.
        :param idxnm: index name forwarded to ``dataStore.search``.
        :param kb_ids: knowledge-base ids forwarded to ``dataStore.search``.
        :param emb_mdl: embedding model used to vectorise the question.
        :param highlight: ignored; kept for signature compatibility with
            ``Dealer.search``.
        :returns: ``SearchResult`` whose ``ids`` and ``field`` hold at most one
            merged chunk per stage, in the order entities, community reports,
            text.
        :raises AssertionError: if ``emb_mdl`` is falsy.
        """
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``. The exception type is unchanged for existing callers.
        if not emb_mdl:
            raise AssertionError("No embedding model selected")

        qst = req.get("question", "")
        match_text, _ = self.qryr.question(qst, min_match=0.05)
        match_dense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
        q_vec = match_dense.embedding_data
        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                 "doc_id", f"q_{len(q_vec)}_vec", "position_list", "name_kwd",
                                 "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
                                 "weight_int", "weight_flt", "rank_int"
                                 ])
        # Equal-weight fusion of the full-text and dense-vector matches.
        fusion_expr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})
        match_exprs = [match_text, match_dense, fusion_expr]

        ## Entity retrieval
        ent_fields, ent_content = self._kg_retrieve_stage(
            req, "entity", "-Entities-", src, match_exprs, 32, idxnm, kb_ids)
        # Entities without a name cannot be used to filter community reports.
        entities = [d["name_kwd"] for d in ent_fields.values() if d.get("name_kwd")]

        ## Community retrieval, restricted to reports mentioning the entities above
        _, comm_content = self._kg_retrieve_stage(
            req, "community_report", "-Community Report-", src, match_exprs, 32, idxnm, kb_ids,
            extra_condition={"entities_kwd": entities})

        ## Text content retrieval
        _, txt_content = self._kg_retrieve_stage(
            req, "text", "-Original Content-", src, match_exprs, 6, idxnm, kb_ids)

        # Derive ids from the merged content so every id has a matching
        # entry in ``field``.
        ids = [*ent_content, *comm_content, *txt_content]
        return self.SearchResult(
            total=len(ids),
            ids=ids,
            query_vector=q_vec,
            highlight=None,
            field={**ent_content, **comm_content, **txt_content},
            keywords=[]
        )