kwang2049 committed
Commit dd26848 · 1 Parent(s): 6d0a17d
Files changed (4)
  1. .gitignore +132 -0
  2. app.py +47 -0
  3. bm25.py +293 -0
  4. requirements.txt +1 -0
.gitignore ADDED
@@ -0,0 +1,132 @@
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ flagged
app.py ADDED
@@ -0,0 +1,47 @@
+ import gradio as gr
+ from typing import Dict, List, Optional, TypedDict
+ from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
+ from bm25 import BM25Index, BM25Retriever
+
+ sciq = load_sciq()
+ bm25_index = BM25Index.build_from_documents(
+     documents=iter(sciq.corpus),
+     ndocs=12160,
+     show_progress_bar=True,
+ )
+ bm25_index.save("output/bm25_sciq_index")
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_sciq_index")
+
+
+ class Hit(TypedDict):
+     cid: str
+     score: float
+     text: str
+
+
+ demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
+ return_type = List[Hit]
+
+ ## YOUR_CODE_STARTS_HERE
+ cid2doc = {doc.collection_id: doc.text for doc in sciq.corpus}
+
+
+ def search(query: str) -> List[Hit]:
+     ranking: Dict[str, float] = bm25_retriever.retrieve(query)
+     # Sort the ranking by score in descending order
+     sorted_ranking = sorted(ranking.items(), key=lambda item: item[1], reverse=True)
+     hits = []
+     for cid, score in sorted_ranking:
+         hits.append(Hit(cid=cid, score=score, text=cid2doc[cid]))
+     return hits
+
+
+ demo = gr.Interface(
+     fn=search,
+     inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
+     outputs="text",
+     title="BM25 Retriever Search",
+     description="Search using a BM25 retriever and return ranked documents with scores.",
+ )
+ ## YOUR_CODE_ENDS_HERE
+ demo.launch()
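Note: as a quick local sanity check (not part of the committed app.py), the search function defined above can be called directly before launching the demo; the query string below is only an illustrative example:

    # hypothetical smoke test, run in app.py's namespace once the index has been built
    hits = search("Which gas do plants absorb during photosynthesis?")
    for hit in hits[:3]:
        print(hit["cid"], round(hit["score"], 3), hit["text"][:80])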
bm25.py ADDED
@@ -0,0 +1,293 @@
+ from __future__ import annotations
+ from dataclasses import dataclass
+ from abc import abstractmethod
+ import pickle
+ import os
+ from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
+ from nlp4web_codebase.ir.data_loaders.dm import Document
+ from nlp4web_codebase.ir.models import BaseRetriever
+ from collections import Counter
+ import re
+ import math
+ import tqdm
+ import nltk
+
+ nltk.download("stopwords", quiet=True)
+ from nltk.corpus import stopwords as nltk_stopwords
+
+ LANGUAGE = "english"
+ word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
+ stopwords = set(nltk_stopwords.words(LANGUAGE))
+
+
+ def word_splitting(text: str) -> List[str]:
+     return word_splitter(text.lower())
+
+
+ def lemmatization(words: List[str]) -> List[str]:
+     return words  # We ignore lemmatization here for simplicity
+
+
+ def simple_tokenize(text: str) -> List[str]:
+     words = word_splitting(text)
+     tokenized = list(filter(lambda w: w not in stopwords, words))
+     tokenized = lemmatization(tokenized)
+     return tokenized
+
+
+ T = TypeVar("T", bound="InvertedIndex")
+
+
+ @dataclass
+ class PostingList:
+     term: str  # The term
+     docid_postings: List[
+         int
+     ]  # docid_postings[i] means the docid (int) of the i-th associated posting
+     tweight_postings: List[
+         float
+     ]  # tweight_postings[i] means the term weight (float) of the i-th associated posting
+
+
+ @dataclass
+ class InvertedIndex:
+     posting_lists: List[PostingList]  # tid -> posting_list
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]  # collection_id -> docid
+     collection_ids: List[str]  # docid -> collection_id
+     doc_texts: Optional[List[str]] = None  # docid -> document text
+
+     def save(self, output_dir: str) -> None:
+         os.makedirs(output_dir, exist_ok=True)
+         with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
+             pickle.dump(self, f)
+
+     @classmethod
+     def from_saved(cls: Type[T], saved_dir: str) -> T:
+         index = cls(
+             posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
+         )
+         with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
+             index = pickle.load(f)
+         return index
+
+
+ # The output of the counting function:
+ @dataclass
+ class Counting:
+     posting_lists: List[PostingList]
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]
+     collection_ids: List[str]
+     dfs: List[int]  # tid -> df
+     dls: List[int]  # docid -> doc length
+     avgdl: float
+     nterms: int
+     doc_texts: Optional[List[str]] = None
+
+
+ def run_counting(
+     documents: Iterable[Document],
+     tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
+     store_raw: bool = True,  # store the document text in doc_texts
+     ndocs: Optional[int] = None,
+     show_progress_bar: bool = True,
+ ) -> Counting:
+     """Counting TFs, DFs, doc_lengths, etc."""
+     posting_lists: List[PostingList] = []
+     vocab: Dict[str, int] = {}
+     cid2docid: Dict[str, int] = {}
+     collection_ids: List[str] = []
+     dfs: List[int] = []  # tid -> df
+     dls: List[int] = []  # docid -> doc length
+     nterms: int = 0
+     doc_texts: Optional[List[str]] = []
+     for doc in tqdm.tqdm(
+         documents,
+         desc="Counting",
+         total=ndocs,
+         disable=not show_progress_bar,
+     ):
+         if doc.collection_id in cid2docid:
+             continue
+         collection_ids.append(doc.collection_id)
+         docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
+         toks = tokenize_fn(doc.text)
+         tok2tf = Counter(toks)
+         dls.append(sum(tok2tf.values()))
+         for tok, tf in tok2tf.items():
+             nterms += tf
+             tid = vocab.get(tok, None)
+             if tid is None:
+                 posting_lists.append(
+                     PostingList(term=tok, docid_postings=[], tweight_postings=[])
+                 )
+                 tid = vocab.setdefault(tok, len(vocab))
+             posting_lists[tid].docid_postings.append(docid)
+             posting_lists[tid].tweight_postings.append(tf)
+             if tid < len(dfs):
+                 dfs[tid] += 1
+             else:
+                 dfs.append(1)  # df starts at 1: this first document contains the term
+         if store_raw:
+             doc_texts.append(doc.text)
+         else:
+             doc_texts = None
+     return Counting(
+         posting_lists=posting_lists,
+         vocab=vocab,
+         cid2docid=cid2docid,
+         collection_ids=collection_ids,
+         dfs=dfs,
+         dls=dls,
+         avgdl=sum(dls) / len(dls),
+         nterms=nterms,
+         doc_texts=doc_texts,
+     )
+
+
+ @dataclass
+ class BM25Index(InvertedIndex):
+
+     @staticmethod
+     def tokenize(text: str) -> List[str]:
+         return simple_tokenize(text)
+
+     @staticmethod
+     def cache_term_weights(
+         posting_lists: List[PostingList],
+         total_docs: int,
+         avgdl: float,
+         dfs: List[int],
+         dls: List[int],
+         k1: float,
+         b: float,
+     ) -> None:
+         """Compute term weights and cache them in the posting lists."""
+
+         N = total_docs
+         for tid, posting_list in enumerate(
+             tqdm.tqdm(posting_lists, desc="Regularizing TFs")
+         ):
+             idf = BM25Index.calc_idf(df=dfs[tid], N=N)
+             for i in range(len(posting_list.docid_postings)):
+                 docid = posting_list.docid_postings[i]
+                 tf = posting_list.tweight_postings[i]
+                 dl = dls[docid]
+                 regularized_tf = BM25Index.calc_regularized_tf(
+                     tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
+                 )
+                 posting_list.tweight_postings[i] = regularized_tf * idf
+
+     @staticmethod
+     def calc_regularized_tf(
+         tf: int, dl: float, avgdl: float, k1: float, b: float
+     ) -> float:
+         return tf / (tf + k1 * (1 - b + b * dl / avgdl))
+
+     @staticmethod
+     def calc_idf(df: int, N: int) -> float:
+         return math.log(1 + (N - df + 0.5) / (df + 0.5))
+
+     @classmethod
+     def build_from_documents(
+         cls: Type[BM25Index],
+         documents: Iterable[Document],
+         store_raw: bool = True,
+         output_dir: Optional[str] = None,
+         ndocs: Optional[int] = None,
+         show_progress_bar: bool = True,
+         k1: float = 0.9,
+         b: float = 0.4,
+     ) -> BM25Index:
+         # Counting TFs, DFs, doc_lengths, etc.:
+         counting = run_counting(
+             documents=documents,
+             tokenize_fn=BM25Index.tokenize,
+             store_raw=store_raw,
+             ndocs=ndocs,
+             show_progress_bar=show_progress_bar,
+         )
+
+         # Compute term weights and cache them:
+         posting_lists = counting.posting_lists
+         total_docs = len(counting.cid2docid)
+         BM25Index.cache_term_weights(
+             posting_lists=posting_lists,
+             total_docs=total_docs,
+             avgdl=counting.avgdl,
+             dfs=counting.dfs,
+             dls=counting.dls,
+             k1=k1,
+             b=b,
+         )
+
+         # Assemble and return the index:
+         index = BM25Index(
+             posting_lists=posting_lists,
+             vocab=counting.vocab,
+             cid2docid=counting.cid2docid,
+             collection_ids=counting.collection_ids,
+             doc_texts=counting.doc_texts,
+         )
+         return index
+
+
+ class BaseInvertedIndexRetriever(BaseRetriever):
+
+     @property
+     @abstractmethod
+     def index_class(self) -> Type[InvertedIndex]:
+         pass
+
+     def __init__(self, index_dir: str) -> None:
+         self.index = self.index_class.from_saved(index_dir)
+
+     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         target_docid = self.index.cid2docid[cid]
+         term_weights = {}
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             posting_list = self.index.posting_lists[tid]
+             for docid, tweight in zip(
+                 posting_list.docid_postings, posting_list.tweight_postings
+             ):
+                 if docid == target_docid:
+                     term_weights[tok] = tweight
+                     break
+         return term_weights
+
+     def score(self, query: str, cid: str) -> float:
+         return sum(self.get_term_weights(query=query, cid=cid).values())
+
+     def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         docid2score: Dict[int, float] = {}
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             posting_list = self.index.posting_lists[tid]
+             for docid, tweight in zip(
+                 posting_list.docid_postings, posting_list.tweight_postings
+             ):
+                 docid2score.setdefault(docid, 0)
+                 docid2score[docid] += tweight
+         docid2score = dict(
+             sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
+         )
+         return {
+             self.index.collection_ids[docid]: score
+             for docid, score in docid2score.items()
+         }
+
+
+ class BM25Retriever(BaseInvertedIndexRetriever):
+
+     @property
+     def index_class(self) -> Type[BM25Index]:
+         return BM25Index
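Note: the weight cached by cache_term_weights is the standard per-term BM25 contribution idf(t) * tf / (tf + k1 * (1 - b + b * dl / avgdl)), with idf(t) = ln(1 + (N - df + 0.5) / (df + 0.5)). A minimal sketch (not part of the committed bm25.py; all statistics below are made-up example values) reproducing one cached weight by hand:

    import math

    # hypothetical statistics for a single term/document pair
    N, df, tf, dl, avgdl, k1, b = 12160, 50, 3, 40, 57.0, 0.9, 0.4
    idf = math.log(1 + (N - df + 0.5) / (df + 0.5))
    weight = tf / (tf + k1 * (1 - b + b * dl / avgdl)) * idf
    print(round(weight, 4))  # the value BM25Index would store in tweight_postings for these inputs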
requirements.txt ADDED
@@ -0,0 +1 @@
+ nlp4web-codebase @ git+https://github.com/kwang2049/nlp4web-codebase.git@00627e75881a5bb33a695c125d9b0c4016e735c1