lidp commited on
Commit
e1017ef
·
1 Parent(s): 172caf6

Add benchmark ndcg@10 (#2326)

Browse files

### What problem does this PR solve?

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Files changed (3) hide show
  1. rag/benchmark.py +94 -0
  2. requirements.txt +1 -0
  3. requirements_arm.txt +1 -0
rag/benchmark.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+
17
+ import argparse
18
+ from collections import defaultdict
19
+ from api.db import FileType, TaskStatus, ParserType, LLMType
20
+ from api.db.services.llm_service import LLMBundle
21
+ from api.db.services.knowledgebase_service import KnowledgebaseService
22
+ from api.settings import retrievaler
23
+ from api.utils import get_uuid
24
+ from rag.nlp import tokenize, search
25
+ from rag.utils.es_conn import ELASTICSEARCH
26
+ from ranx import evaluate
27
+
28
+
29
+ class benchmark_ndcg10:
30
+ def __init__(self, kb_id):
31
+ e, kb = KnowledgebaseService.get_by_id(kb_id)
32
+ self.similarity_threshold = kb.similarity_threshold
33
+ self.vector_similarity_weight = kb.vector_similarity_weight
34
+ self.embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
35
+
36
+ def _get_benchmarks(self, query, count=16):
37
+ req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
38
+ sres = retrievaler.search(req, search.index_name("benchmark"), self.embd_mdl)
39
+ return sres
40
+
41
+ def _get_retrieval(self, qrels):
42
+ run = defaultdict(dict)
43
+ query_list = list(qrels.keys())
44
+ for query in query_list:
45
+ sres = self._get_benchmarks(query)
46
+ sim, _, _ = retrievaler.rerank(sres, query, 1 - self.vector_similarity_weight,
47
+ self.vector_similarity_weight)
48
+ for index, id in enumerate(sres.ids):
49
+ run[query][id] = sim[index]
50
+ return run
51
+
52
+ def embedding(self, docs, batch_size=16):
53
+ vects = []
54
+ cnts = [d["content_with_weight"] for d in docs]
55
+ for i in range(0, len(cnts), batch_size):
56
+ vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
57
+ vects.extend(vts.tolist())
58
+ assert len(docs) == len(vects)
59
+ for i, d in enumerate(docs):
60
+ v = vects[i]
61
+ d["q_%d_vec" % len(v)] = v
62
+ return docs
63
+
64
+ def __call__(self, file_path):
65
+ qrels = defaultdict(dict)
66
+
67
+ docs = []
68
+ with open(file_path) as f:
69
+ for line in f:
70
+ query, text, rel = line.strip('\n').split()
71
+ d = {
72
+ "id": get_uuid()
73
+ }
74
+ tokenize(d, text)
75
+ docs.append(d)
76
+ if len(docs) >= 32:
77
+ ELASTICSEARCH.bulk(docs, search.index_name("benchmark"))
78
+ docs = []
79
+ qrels[query][d["id"]] = float(rel)
80
+ docs = self.embedding(docs)
81
+ ELASTICSEARCH.bulk(docs, search.index_name("benchmark"))
82
+
83
+ run = self._get_retrieval(qrels)
84
+ return evaluate(qrels, run, "ndcg@10")
85
+
86
+
87
+ if __name__ == '__main__':
88
+ parser = argparse.ArgumentParser()
89
+ parser.add_argument('-f', '--filepath', default='', help="file path", action='store', required=True)
90
+ parser.add_argument('-k', '--kb_id', default='', help="kb_id", action='store', required=True)
91
+ args = parser.parse_args()
92
+
93
+ ex = benchmark_ndcg10(args.kb_id)
94
+ print(ex(args.filepath))
requirements.txt CHANGED
@@ -70,6 +70,7 @@ python_dateutil==2.8.2
70
  python_pptx==0.6.23
71
  pywencai==0.12.2
72
  qianfan==0.4.6
 
73
  readability_lxml==0.8.1
74
  redis==5.0.3
75
  Requests==2.32.2
 
70
  python_pptx==0.6.23
71
  pywencai==0.12.2
72
  qianfan==0.4.6
73
+ ranx==0.3.20
74
  readability_lxml==0.8.1
75
  redis==5.0.3
76
  Requests==2.32.2
requirements_arm.txt CHANGED
@@ -171,3 +171,4 @@ vertexai==1.64.0
171
  yfinance==0.2.43
172
  pywencai==0.12.2
173
  akshare==1.14.72
 
 
171
  yfinance==0.2.43
172
  pywencai==0.12.2
173
  akshare==1.14.72
174
+ ranx==0.3.20